function CPD = softmax_CPD(bnet, self, varargin)
% SOFTMAX_CPD Make a softmax (multinomial logit) CPD
%
% To define this CPD precisely, let W be an (m x n) weight matrix, with W(i,:) denoting its i-th row
% => we can define the following vectorial function:
%
%   softmax: R^n |--> R^m
%   softmax(z,i-th) = exp(W(i,:)*z) / sum_k(exp(W(k,:)*z))
%
% (this constructor augments z with a one at the beginning to introduce an offset term (= bias, intercept))
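%
% For instance, a minimal numeric sketch (hypothetical W and z, offset omitted):
%   W = [1 0; 0 1];  z = [1; 2];
%   p = exp(W*z) ./ sum(exp(W*z));   % p = [0.2689; 0.7311], i.e. p(i) = softmax(z,i-th)
%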
% Now call the continuous (cts) and always observed (obs) parents X,
% the discrete parents (if any) Q, and this node Y; then we use the discrete parent(s) just to index
% the parameter matrices (cf. conditional Gaussian nodes); that is:
%   prob(Y=i | X=x, Q=j) = softmax(x,i-th|j)
% where '|j' means that we are using the j-th (m x n) parameter matrix W(:,:,j).
% If there are no discrete parents, this is a regular softmax node.
% If Y is binary, this is a logistic (sigmoid) function.
%
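% For instance, a minimal evaluation sketch (hypothetical W, x and j):
%   z = [1; x(:)];               % augment x with a one for the offset term
%   a = W(:,:,j) * z;            % linear part, using the j-th parameter matrix
%   p = exp(a) ./ sum(exp(a));   % p(i) = prob(Y=i | X=x, Q=j)
%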
% CPD = softmax_CPD(bnet, node_num, ...) will create a softmax CPD with random parameters,
% where node_num is the number of a node in this equivalence class.
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
% (Let ns(i) be the size of node i, X = ns(X), Y = ns(Y), Q1 = ns(dps(1)), Q2 = ns(dps(2)), ...
%  where dps are the discrete parents; if there are no discrete parents, we set Q1 = 1.)
%
% discrete - the discrete parents that we want to treat like the cts ones [ [] ].
%            This can be used to define a sigmoid belief network - see the references below.
%            For example, suppose that Y has one cts parent X and two discrete ones, Q and C1, where:
%            -> Q is binary (1/2) and used just to index the parameters of 'self'
%            -> C1 is ternary (1/2/3) and treated as a cts node <=> its values appear in the linear
%               part of the softmax function
%            then:
%               prob(Y | X=x, Q=q, C1=c1) = softmax(W(:,:,q)' * y)
%            where y = [1 | delta(C1,1) delta(C1,2) delta(C1,3) | x(:)']' and delta(Y,a) = indicator(Y=a).
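%            For instance, a sketch of how y is assembled (hypothetical x and c1 values):
%              x = [0.5; -1];  c1 = 2;
%              y = [1, 0 1 0, 0.5 -1]';   % bias | one-hot encoding of C1=2 | x(:)'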
% weights  - (w(:,j,a,b,...) - w(:,j',a,b,...)) is proportional to the decision boundary
%            between j,j' given Q1=a, Q2=b, ... [ randn(X,Y,Q1,Q2,...) ]
% offset   - (b(j,a,b,...) - b(j',a,b,...)) is the offset to the decision boundary
%            between j,j' given Q1=a, Q2=b, ... [ randn(Y,Q1,Q2,...) ]
%
% e.g., CPD = softmax_CPD(bnet, i, 'offset', zeros(ns(i),1));
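%
% A fuller usage sketch (a hypothetical 2-node net, cts X -> discrete Y):
%   dag = zeros(2); dag(1,2) = 1;
%   ns = [3 4];                              % X is a 3-dim cts node, Y has 4 values
%   bnet = mk_bnet(dag, ns, 'discrete', 2);
%   bnet.CPD{1} = gaussian_CPD(bnet, 1);
%   bnet.CPD{2} = softmax_CPD(bnet, 2, 'offset', zeros(ns(2),1));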
%
% The following fields control the behavior of the M step, which uses
% a weighted version of the Iteratively Reweighted Least Squares (WIRLS) algorithm if dps_as_cps=[], or
% a weighted SCG otherwise, as implemented in Netlab and modified by Pierpaolo Brutti.
%
% clamped     - 'yes' means don't adjust params during learning ['no']
% max_iter    - the maximum number of steps to take [10]
% verbose     - 'yes' means print the LL at each step of IRLS ['no']
% wthresh     - convergence threshold for weights [1e-2]
% llthresh    - convergence threshold for log likelihood [1e-2]
% approx_hess - 'yes' means approximate the Hessian for speed ['no']
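%
% For instance, a sketch of tuning the M step (hypothetical settings):
%   CPD = softmax_CPD(bnet, i, 'max_iter', 30, 'verbose', 'yes', 'wthresh', 1e-3);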
%
% For backwards compatibility with BNT2, you can also specify the parameters in the following order:
%   softmax_CPD(bnet, self, w, b, clamped, max_iter, verbose, wthresh, llthresh, approx_hess)
%
% REFERENCES
% For details on sigmoid belief nets, see:
% - Neal (1992). Connectionist learning of belief networks. Artificial Intelligence, 56, pp. 71-113.
% - Saul, Jaakkola, Jordan (1996). Mean field theory for sigmoid belief networks. Journal of
%   Artificial Intelligence Research, 4, pp. 61-76.
%
% For details on the M step, see:
% - K. Chen, L. Xu, H. Chi (1999). Improved learning algorithms for mixtures of experts in multiclass
%   classification. Neural Networks 12, pp. 1229-1252.
% - M.I. Jordan, R.A. Jacobs (1994). Hierarchical Mixtures of Experts and the EM algorithm.
%   Neural Computation 6, pp. 181-214.
% - S.R. Waterhouse, A.J. Robinson (1994). Classification Using Hierarchical Mixtures of Experts.
%   In Proc. IEEE Workshop on Neural Networks for Signal Processing IV, pp. 177-186.

if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  CPD = class(CPD, 'softmax_CPD', discrete_CPD(0, []));
  return;
elseif isa(bnet, 'softmax_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end
CPD = init_fields;

assert(myismember(self, bnet.dnodes));
ns = bnet.node_sizes;
ps = parents(bnet.dag, self);
dps = myintersect(ps, bnet.dnodes);
cps = myintersect(ps, bnet.cnodes);

clamped = 0;
CPD = class(CPD, 'softmax_CPD', discrete_CPD(clamped, ns([ps self])));

dps_as_cpssz = 0;
dps_as_cps = [];
% determine if any discrete parents are to be treated as cts
if nargin >= 3 & isstr(varargin{1}) % might have passed in 'discrete'
  for i=1:2:length(varargin)
    if strcmp(varargin{i}, 'discrete')
      dps_as_cps = varargin{i+1};
      assert(myismember(dps_as_cps, dps));
      dps = mysetdiff(dps, dps_as_cps); % remove the dps that are to be treated as cts
      CPD.dps_as_cps.ndx = find_equiv_posns(dps_as_cps, ps);
      CPD.dps_as_cps.separator = [0 cumsum(ns(dps_as_cps(1:end-1)))]; % separators between the concatenated dps_as_cps dims
      dps_as_cpssz = sum(ns(dps_as_cps));
      break;
    end
  end
end
assert(~isempty(union(cps, dps_as_cps))); % there must be at least one cts or dps_as_cps parent
self_size = ns(self);
cpsz = sum(ns(cps));
glimsz = prod(ns(dps));
CPD.dpndx = find_equiv_posns(dps, ps); % contains only the indices of the 'pure' dps
CPD.cpndx = find_equiv_posns(cps, ps);

CPD.self = self;
CPD.solo = (length(ns)<=2);
CPD.sizes = bnet.node_sizes([ps self]);

% set default params
CPD.max_iter = 10;
CPD.verbose = 0;
CPD.wthresh = 1e-2;
CPD.llthresh = 1e-2;
CPD.approx_hess = 0;
CPD.glim = cell(1,glimsz);
for i=1:glimsz
  CPD.glim{i} = glm(dps_as_cpssz + cpsz, self_size, 'softmax');
end

if nargin >= 3
  args = varargin;
  nargs = length(args);
  if ~isstr(args{1})
    % softmax_CPD(bnet, self, w, b, clamped, max_iter, verbose, wthresh, llthresh, approx_hess)
    if nargs >= 1 & ~isempty(args{1}), CPD = set_fields(CPD, 'weights', args{1}); end
    if nargs >= 2 & ~isempty(args{2}), CPD = set_fields(CPD, 'offset', args{2}); end
    if nargs >= 3 & ~isempty(args{3}), CPD = set_clamped(CPD, args{3}); end
    if nargs >= 4 & ~isempty(args{4}), CPD.max_iter = args{4}; end
    if nargs >= 5 & ~isempty(args{5}), CPD.verbose = args{5}; end
    if nargs >= 6 & ~isempty(args{6}), CPD.wthresh = args{6}; end
    if nargs >= 7 & ~isempty(args{7}), CPD.llthresh = args{7}; end
    if nargs >= 8 & ~isempty(args{8}), CPD.approx_hess = args{8}; end
  else
    CPD = set_fields(CPD, args{:});
  end
end

% sufficient statistics
% Since the softmax CPD is not in the exponential family, we must store all the raw data.
CPD.parent_vals = []; % X(l,:) = value of cts parents in l'th example
CPD.self_vals = [];   % Y(l,:) = value of self in l'th example

CPD.eso_weights = []; % weights used by the WIRLS algorithm

% For BIC
CPD.nsamples = 0;
if ~adjustable_CPD(CPD)
  CPD.nparams = 0;
else
  [W, b] = extract_params(CPD);
  CPD.nparams = prod(size(W)) + prod(size(b));
end

%%%%%%%%%%%

function CPD = init_fields()
% This ensures we define the fields in the same order
% no matter whether we load an object from a file,
% or create it from scratch. (Matlab requires this.)

CPD.glim = {};
CPD.self = [];
CPD.solo = [];
CPD.max_iter = [];
CPD.verbose = [];
CPD.wthresh = [];
CPD.llthresh = [];
CPD.approx_hess = [];
CPD.sizes = [];
CPD.parent_vals = [];
CPD.eso_weights = [];
CPD.self_vals = [];
CPD.nsamples = [];
CPD.nparams = [];
CPD.dpndx = [];
CPD.cpndx = [];
CPD.dps_as_cps.ndx = [];
CPD.dps_as_cps.separator = [];