toolboxes/FullBNT-1.0.7/bnt/CPDs/@softmax_CPD/softmax_CPD.m @ 0:e9a9cd732c1e (tip)

author:  wolffd
date:    Tue, 10 Feb 2015 15:05:51 +0000
summary: first hg version after svn
function CPD = softmax_CPD(bnet, self, varargin)
% SOFTMAX_CPD Make a softmax (multinomial logit) CPD
%
% To define this CPD precisely, let W be an (m x n) weight matrix and write W(i,:)
% for its i-th row; we can then define the following vector-valued function:
%
% softmax: R^n |--> R^m
% softmax(z,i) = exp(W(i,:)*z) / sum_k exp(W(k,:)*z)
%
% (this constructor augments z with a leading one to introduce an offset term (= bias, intercept)).
% Now call the continuous (cts) and always observed (obs) parents X,
% the discrete parents (if any) Q, and this node Y. We use the discrete parent(s) just to index
% the parameter matrices (c.f., conditional Gaussian nodes); that is:
% prob(Y=i | X=x, Q=j) = softmax(x,i | j)
% where '|j' means that we are using the j-th (m x n) parameter matrix W(:,:,j).
% If there are no discrete parents, this is a regular softmax node.
% If Y is binary, this is a logistic (sigmoid) function.
%
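% For intuition, a small numeric sketch (the values are illustrative only): with
% m=2, n=2, W = [0.5 -1.0; 2.0 0.3] and z = [1; 2], the scores are W*z = [-1.5; 2.6],
% so softmax gives p = exp(W*z)/sum(exp(W*z)) ~ [0.016; 0.984]; each row of W
% scores one value of Y, and the outputs always sum to 1.
%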
% CPD = softmax_CPD(bnet, node_num, ...) will create a softmax CPD with random parameters,
% where node_num is the number of a node in this equivalence class.
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
% (Let ns(i) be the size of node i, X = ns(X), Y = ns(Y), Q1=ns(dps(1)), Q2=ns(dps(2)), ...
% where dps are the discrete parents; if there are no discrete parents, we set Q1=1.)
%
% discrete - the discrete parents that we want to treat like the cts ones [ [] ].
%      This can be used to define a sigmoid belief network - see the reference below,
%      and the illustrative call after this list.
%      For example, suppose that Y has one cts parent X and two discrete ones, Q and C1, where:
%      -> Q is binary (1/2) and used just to index the parameters of 'self'
%      -> C1 is ternary (1/2/3) and treated as a cts node, i.e., its values enter the linear
%         part of the softmax function
%      then:
%      prob(Y | X=x, Q=q, C1=c1) = softmax(W(:,:,q)' * y)
%      where y = [1 | delta(C1,1) delta(C1,2) delta(C1,3) | x(:)']' and delta(Y,a)=indicator(Y=a).
% weights - (w(:,j,a,b,...) - w(:,j',a,b,...)) is proportional to the decision boundary
%      between j,j' given Q1=a, Q2=b, ... [ randn(X,Y,Q1,Q2,...) ]
% offset - (b(j,a,b,...) - b(j',a,b,...)) is the offset to the decision boundary
%      between j,j' given Q1=a, Q2=b, ... [ randn(Y,Q1,Q2,...) ]
%
% e.g., CPD = softmax_CPD(bnet, i, 'offset', zeros(ns(i),1));
%
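% An illustrative call for the 'discrete' option (the node numbers are made up):
% if node 3 is a discrete parent of node i whose values should enter the linear part
% of the softmax, as in a sigmoid belief network, write
%   CPD = softmax_CPD(bnet, i, 'discrete', 3);
% node 3 is then 1-of-K encoded and concatenated with the cts inputs.
%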
% The following fields control the behavior of the M step, which uses
% a weighted version of Iteratively Reweighted Least Squares (WIRLS) if dps_as_cps=[], or
% a weighted scaled conjugate gradient (SCG) otherwise, as implemented in Netlab
% and modified by Pierpaolo Brutti.
%
% clamped - 'yes' means don't adjust params during learning ['no']
% max_iter - the maximum number of steps to take [10]
% verbose - 'yes' means print the LL at each step of IRLS ['no']
% wthresh - convergence threshold for weights [1e-2]
% llthresh - convergence threshold for log likelihood [1e-2]
% approx_hess - 'yes' means approximate the Hessian for speed ['no']
%
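% e.g., an M step that runs at most 50 iterations and reports the LL as it goes
% (an illustrative call, using only the name/value pairs listed above):
%   CPD = softmax_CPD(bnet, i, 'max_iter', 50, 'verbose', 'yes');
%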
% For backwards compatibility with BNT2, you can also specify the parameters in the following order:
%   softmax_CPD(bnet, self, w, b, clamped, max_iter, verbose, wthresh, llthresh, approx_hess)
%
% REFERENCES
% For details on sigmoid belief nets, see:
% - Neal (1992). Connectionist learning of belief networks. Artificial Intelligence, 56, 71-113.
% - Saul, Jaakkola, Jordan (1996). Mean field theory for sigmoid belief networks. Journal of
%   Artificial Intelligence Research, 4, pp. 61-76.
%
% For details on the M step, see:
% - K. Chen, L. Xu, H. Chi (1999). Improved learning algorithms for mixtures of experts in multiclass
%   classification. Neural Networks, 12, pp. 1229-1252.
% - M.I. Jordan, R.A. Jacobs (1994). Hierarchical Mixtures of Experts and the EM algorithm.
%   Neural Computation, 6, pp. 181-214.
% - S.R. Waterhouse, A.J. Robinson (1994). Classification Using Hierarchical Mixtures of Experts.
%   In Proc. IEEE Workshop on Neural Networks for Signal Processing IV, pp. 177-186.
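%
% EXAMPLE
% A minimal construction sketch (illustrative only, assuming the usual BNT helpers
% mk_bnet and gaussian_CPD; the sizes are made up):
%   dag = zeros(2); dag(1,2) = 1;              % cts parent X -> discrete child Y
%   ns  = [3 2];                               % X has dimension 3, Y is binary
%   bnet = mk_bnet(dag, ns, 'discrete', 2);    % node 2 is the only discrete node
%   bnet.CPD{1} = gaussian_CPD(bnet, 1);       % root cts node
%   bnet.CPD{2} = softmax_CPD(bnet, 2, 'offset', zeros(ns(2),1));
% Since Y is binary, this particular CPD reduces to a logistic regression of Y on x.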

if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  CPD = class(CPD, 'softmax_CPD', discrete_CPD(0, []));
  return;
elseif isa(bnet, 'softmax_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end
CPD = init_fields;

assert(myismember(self, bnet.dnodes));
ns = bnet.node_sizes;
ps = parents(bnet.dag, self);
dps = myintersect(ps, bnet.dnodes);
cps = myintersect(ps, bnet.cnodes);

clamped = 0;
CPD = class(CPD, 'softmax_CPD', discrete_CPD(clamped, ns([ps self])));

dps_as_cpssz = 0;
dps_as_cps = [];
% determine if any discrete parents are to be treated as cts
if nargin >= 3 & isstr(varargin{1}) % might have passed in 'discrete'
  for i=1:2:length(varargin)
    if strcmp(varargin{i}, 'discrete')
      dps_as_cps = varargin{i+1};
      assert(myismember(dps_as_cps, dps));
      dps = mysetdiff(dps, dps_as_cps); % drop the dps that are treated as cts
      CPD.dps_as_cps.ndx = find_equiv_posns(dps_as_cps, ps);
      CPD.dps_as_cps.separator = [0 cumsum(ns(dps_as_cps(1:end-1)))]; % offsets of the concatenated dps_as_cps blocks
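      % e.g., with two dps_as_cps of sizes [3 2], separator = [0 3]: the 1-of-K
      % encoding of the first parent fills input slots 1:3, the second slots 4:5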
      dps_as_cpssz = sum(ns(dps_as_cps));
      break;
    end
  end
end
assert(~isempty(union(cps, dps_as_cps))); % there must be at least one cts or dps_as_cps parent
self_size = ns(self);
cpsz = sum(ns(cps));
glimsz = prod(ns(dps));
CPD.dpndx = find_equiv_posns(dps, ps); % contains only the indices of the 'pure' dps
CPD.cpndx = find_equiv_posns(cps, ps);

CPD.self = self;
CPD.solo = (length(ns)<=2);
CPD.sizes = bnet.node_sizes([ps self]);

% set default params
CPD.max_iter = 10;
CPD.verbose = 0;
CPD.wthresh = 1e-2;
CPD.llthresh = 1e-2;
CPD.approx_hess = 0;
CPD.glim = cell(1,glimsz);
for i=1:glimsz
  CPD.glim{i} = glm(dps_as_cpssz + cpsz, self_size, 'softmax');
end
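% note: glimsz = prod(ns(dps)), so there is one Netlab GLM with a softmax output
% per joint configuration of the 'pure' discrete parents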

if nargin >= 3
  args = varargin;
  nargs = length(args);
  if ~isstr(args{1})
    % softmax_CPD(bnet, self, w, b, clamped, max_iter, verbose, wthresh, llthresh, approx_hess)
    if nargs >= 1 & ~isempty(args{1}), CPD = set_fields(CPD, 'weights', args{1}); end
    if nargs >= 2 & ~isempty(args{2}), CPD = set_fields(CPD, 'offset', args{2}); end
    if nargs >= 3 & ~isempty(args{3}), CPD = set_clamped(CPD, args{3}); end
    if nargs >= 4 & ~isempty(args{4}), CPD.max_iter = args{4}; end
    if nargs >= 5 & ~isempty(args{5}), CPD.verbose = args{5}; end
    if nargs >= 6 & ~isempty(args{6}), CPD.wthresh = args{6}; end
    if nargs >= 7 & ~isempty(args{7}), CPD.llthresh = args{7}; end
    if nargs >= 8 & ~isempty(args{8}), CPD.approx_hess = args{8}; end
  else
    CPD = set_fields(CPD, args{:});
  end
end

% sufficient statistics
% Since the softmax is not in the exponential family, we must store all the raw data.
CPD.parent_vals = []; % X(l,:) = value of cts parents in l'th example
CPD.self_vals = [];   % Y(l,:) = value of self in l'th example

CPD.eso_weights = []; % weights used by the WIRLS algorithm

% For BIC
CPD.nsamples = 0;
if ~adjustable_CPD(CPD),
  CPD.nparams = 0;
else
  [W, b] = extract_params(CPD);
  CPD.nparams = prod(size(W)) + prod(size(b));
end

%%%%%%%%%%%

function CPD = init_fields()
% This ensures we define the fields in the same order
% no matter whether we load an object from a file,
% or create it from scratch. (Matlab requires this.)

CPD.glim = {};
CPD.self = [];
CPD.solo = [];
CPD.max_iter = [];
CPD.verbose = [];
CPD.wthresh = [];
CPD.llthresh = [];
CPD.approx_hess = [];
CPD.sizes = [];
CPD.parent_vals = [];
CPD.eso_weights = [];
CPD.self_vals = [];
CPD.nsamples = [];
CPD.nparams = [];
CPD.dpndx = [];
CPD.cpndx = [];
CPD.dps_as_cps.ndx = [];
CPD.dps_as_cps.separator = [];
187 CPD.dps_as_cps.separator = []; |