annotate _FullBNT/BNT/CPDs/@softmax_CPD/softmax_CPD.m @ 9:4ea6619cb3f5 tip

removed log files
author matthiasm
date Fri, 11 Apr 2014 15:55:11 +0100
parents b5b38998ef3b
function CPD = softmax_CPD(bnet, self, varargin)
% SOFTMAX_CPD Make a softmax (multinomial logit) CPD
%
% To define this CPD precisely, let W be an (m x n) matrix whose i-th row W(i,:)
% holds the weights associated with the i-th value of the output.
% We can then define the following vector-valued function:
%
%   softmax: R^n |--> R^m
%   softmax(z,i-th) = exp(W(i,:)*z) / sum_k exp(W(k,:)*z)
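%
%   Worked example (illustrative numbers, not BNT output): with m = n = 2,
%     W = eye(2); z = [log(2); 0];
%     p = exp(W*z) / sum(exp(W*z));   % p = [2/3; 1/3]
%   i.e. Y=1 is twice as likely as Y=2, because W(1,:)*z - W(2,:)*z = log(2).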
%
% (This constructor augments z with a one at the beginning to introduce an offset term (= bias, intercept).)
% Now call the continuous (cts) and always observed (obs) parents X,
% the discrete parents (if any) Q, and this node Y. The discrete parent(s) are used only to index
% the parameter vectors (cf. conditional Gaussian nodes); that is:
%   prob(Y=i | X=x, Q=j) = softmax(x,i-th|j)
% where '|j' means that we are using the j-th (m x n) parameter matrix W(:,:,j).
% If there are no discrete parents, this is a regular softmax node.
% If Y is binary, this is a logistic (sigmoid) function.
%
% CPD = softmax_CPD(bnet, node_num, ...) will create a softmax CPD with random parameters,
% where node_num is the number of a node in this equivalence class.
%
% The following optional arguments can be specified in the form of name/value pairs
% [default value in brackets].
% (Let ns(i) be the size of node i, X = sum(ns(X)) (total size of the cts parents), Y = ns(Y),
% Q1 = ns(dps(1)), Q2 = ns(dps(2)), ..., where dps are the discrete parents;
% if there are no discrete parents, we set Q1 = 1.)
%
% discrete - the discrete parents that we want to treat like the cts ones [ [] ].
%    This can be used to define a sigmoid belief network (see the references below).
%    For example, suppose that Y has one cts parent X and two discrete ones, Q and C1, where:
%    -> Q is binary (1/2) and used just to index the parameters of 'self'
%    -> C1 is ternary (1/2/3) and treated as a cts node, i.e., its values appear in the linear
%       part of the softmax function
%    then:
%      prob(Y | X=x, Q=q, C1=c1) = softmax(W(:,:,q)' * y)
%    where y = [1 | delta(C1,1) delta(C1,2) delta(C1,3) | x(:)']' and delta(V,a) = indicator(V=a).
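%    Illustration (hypothetical values c1 = 2, x = [0.5; -1]), building y by hand:
%      y = [1; (1:3)' == c1; x(:)];   % y = [1; 0; 1; 0; 0.5; -1]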
% weights - (w(:,j,a,b,...) - w(:,j',a,b,...)) is perpendicular to the decision boundary
%    between classes j and j' given Q1=a, Q2=b, ... [ randn(X,Y,Q1,Q2,...) ]
% offset - (b(j,a,b,...) - b(j',a,b,...)) is the offset of the decision boundary
%    between classes j and j' given Q1=a, Q2=b, ... [ randn(Y,Q1,Q2,...) ]
%
% e.g., CPD = softmax_CPD(bnet, i, 'offset', zeros(ns(i),1));
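%
% A slightly fuller sketch (hypothetical 2-node network: a 3-dimensional Gaussian
% node 1 that is the only parent of a binary softmax node 2; the sizes and the
% 'max_iter' value are arbitrary choices for illustration):
%   dag = zeros(2); dag(1,2) = 1;
%   bnet = mk_bnet(dag, [3 2], 'discrete', 2);
%   bnet.CPD{1} = gaussian_CPD(bnet, 1);
%   bnet.CPD{2} = softmax_CPD(bnet, 2, 'max_iter', 20);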
%
% The following fields control the behavior of the M step, which uses
% a weighted version of Iteratively Reweighted Least Squares (WIRLS) if dps_as_cps = [],
% or a weighted SCG (scaled conjugate gradient) otherwise, as implemented in Netlab
% and modified by Pierpaolo Brutti.
%
% clamped     - 'yes' means don't adjust params during learning ['no']
% max_iter    - the maximum number of steps to take [10]
% verbose     - 'yes' means print the LL at each step of IRLS ['no']
% wthresh     - convergence threshold for weights [1e-2]
% llthresh    - convergence threshold for log likelihood [1e-2]
% approx_hess - 'yes' means approximate the Hessian for speed ['no']
%
% For backwards compatibility with BNT2, you can also specify the parameters in the following order:
%   softmax_CPD(bnet, self, w, b, clamped, max_iter, verbose, wthresh, llthresh, approx_hess)
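%   For example (hypothetical values), the positional call
%     CPD = softmax_CPD(bnet, i, [], zeros(ns(i),1), [], 20);
%   should set the same options as
%     CPD = softmax_CPD(bnet, i, 'offset', zeros(ns(i),1), 'max_iter', 20);
%   because empty positional arguments are skipped and keep their defaults
%   (the weights are still drawn at random in both cases).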
%
% REFERENCES
% For details on sigmoid belief nets, see:
% - R.M. Neal (1992). Connectionist learning of belief networks. Artificial Intelligence 56, pp. 71-113.
% - L.K. Saul, T. Jaakkola, M.I. Jordan (1996). Mean field theory for sigmoid belief networks.
%   Journal of Artificial Intelligence Research 4, pp. 61-76.
%
% For details on the M step, see:
% - K. Chen, L. Xu, H. Chi (1999). Improved learning algorithms for mixtures of experts in multiclass
%   classification. Neural Networks 12, pp. 1229-1252.
% - M.I. Jordan, R.A. Jacobs (1994). Hierarchical mixtures of experts and the EM algorithm.
%   Neural Computation 6, pp. 181-214.
% - S.R. Waterhouse, A.J. Robinson (1994). Classification using hierarchical mixtures of experts.
%   In Proc. IEEE Workshop on Neural Networks for Signal Processing IV, pp. 177-186.

if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  CPD = class(CPD, 'softmax_CPD', discrete_CPD(0, []));
  return;
elseif isa(bnet, 'softmax_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end
CPD = init_fields;

assert(myismember(self, bnet.dnodes));
ns = bnet.node_sizes;
ps = parents(bnet.dag, self);
dps = myintersect(ps, bnet.dnodes);
cps = myintersect(ps, bnet.cnodes);

clamped = 0;
CPD = class(CPD, 'softmax_CPD', discrete_CPD(clamped, ns([ps self])));

dps_as_cpssz = 0;
dps_as_cps = [];
% determine if any discrete parents are to be treated as cts
if nargin >= 3 && ischar(varargin{1}) % might have passed in 'discrete'
  for i=1:2:length(varargin)
    if strcmp(varargin{i}, 'discrete')
      dps_as_cps = varargin{i+1};
      assert(myismember(dps_as_cps, dps));
      dps = mysetdiff(dps, dps_as_cps); % remove the dps treated as cts
      CPD.dps_as_cps.ndx = find_equiv_posns(dps_as_cps, ps);
      CPD.dps_as_cps.separator = [0 cumsum(ns(dps_as_cps(1:end-1)))]; % start offsets of each block in the concatenated dps_as_cps encoding
      dps_as_cpssz = sum(ns(dps_as_cps));
      break;
    end
  end
end
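% Illustration (hypothetical node numbers/sizes): with dps_as_cps = [4 7], ns(4) = 3 and
% ns(7) = 2, the one-hot encodings are concatenated into a 5-element block, so
% separator = [0 3] (start offset of each sub-block) and dps_as_cpssz = 5.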
assert(~isempty(union(cps, dps_as_cps))); % there must be at least one cts parent or one dps_as_cps parent
self_size = ns(self);
cpsz = sum(ns(cps));
glimsz = prod(ns(dps));
CPD.dpndx = find_equiv_posns(dps, ps); % contains only the indices of the 'pure' dps
CPD.cpndx = find_equiv_posns(cps, ps);

CPD.self = self;
CPD.solo = (length(ns)<=2);
CPD.sizes = bnet.node_sizes([ps self]);

% set default params
CPD.max_iter = 10;
CPD.verbose = 0;
CPD.wthresh = 1e-2;
CPD.llthresh = 1e-2;
CPD.approx_hess = 0;
CPD.glim = cell(1,glimsz);
for i=1:glimsz
  CPD.glim{i} = glm(dps_as_cpssz + cpsz, self_size, 'softmax');
end
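% One Netlab GLM is built per joint configuration of the 'pure' discrete parents.
% E.g. (illustrative) two binary pure dps give glimsz = 2*2 = 4 GLMs, each mapping
% dps_as_cpssz + cpsz inputs to self_size softmax outputs.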

if nargin >= 3
  args = varargin;
  nargs = length(args);
  if ~ischar(args{1})
    % softmax_CPD(bnet, self, w, b, clamped, max_iter, verbose, wthresh, llthresh, approx_hess)
    if nargs >= 1 && ~isempty(args{1}), CPD = set_fields(CPD, 'weights', args{1}); end
    if nargs >= 2 && ~isempty(args{2}), CPD = set_fields(CPD, 'offset', args{2}); end
    if nargs >= 3 && ~isempty(args{3}), CPD = set_clamped(CPD, args{3}); end
    if nargs >= 4 && ~isempty(args{4}), CPD.max_iter = args{4}; end
    if nargs >= 5 && ~isempty(args{5}), CPD.verbose = args{5}; end
    if nargs >= 6 && ~isempty(args{6}), CPD.wthresh = args{6}; end
    if nargs >= 7 && ~isempty(args{7}), CPD.llthresh = args{7}; end
    if nargs >= 8 && ~isempty(args{8}), CPD.approx_hess = args{8}; end
  else
    CPD = set_fields(CPD, args{:});
  end
end

% sufficient statistics
% Since the softmax CPD is not in the exponential family, we must store all the raw data.
CPD.parent_vals = []; % X(l,:) = value of cts parents in l'th example
CPD.self_vals = []; % Y(l,:) = value of self in l'th example

CPD.eso_weights = []; % weights used by the WIRLS algorithm

% For BIC
CPD.nsamples = 0;
if ~adjustable_CPD(CPD)
  CPD.nparams = 0;
else
  [W, b] = extract_params(CPD);
  CPD.nparams = numel(W) + numel(b);
end
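% Illustration, assuming extract_params returns W and b with the default shapes
% listed in the help text (W: X x Y x Q1 x ..., b: Y x Q1 x ...): with X = 3, Y = 4
% and one binary discrete parent (Q1 = 2), nparams = 3*4*2 + 4*2 = 32.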

%%%%%%%%%%%

function CPD = init_fields()
% This ensures we define the fields in the same order
% no matter whether we load an object from a file,
% or create it from scratch. (Matlab requires this.)

CPD.glim = {};
CPD.self = [];
CPD.solo = [];
CPD.max_iter = [];
CPD.verbose = [];
CPD.wthresh = [];
CPD.llthresh = [];
CPD.approx_hess = [];
CPD.sizes = [];
CPD.parent_vals = [];
CPD.eso_weights = [];
CPD.self_vals = [];
CPD.nsamples = [];
CPD.nparams = [];
CPD.dpndx = [];
CPD.cpndx = [];
CPD.dps_as_cps.ndx = [];
CPD.dps_as_cps.separator = [];