annotate toolboxes/FullBNT-1.0.7/bnt/CPDs/@mlp_CPD/mlp_CPD.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function CPD = mlp_CPD(bnet, self, nhidden, w1, b1, w2, b2, clamped, max_iter, verbose, wthresh, llthresh)
% MLP_CPD Make a CPD from a Multi Layer Perceptron (i.e., feedforward neural network)
%
% We use a different MLP for each discrete parent combination (if there are any discrete parents).
% We currently assume this node (the child) is discrete.
%
% CPD = mlp_CPD(bnet, self, nhidden)
% will create a CPD with random parameters, where self is the number of this node and
% nhidden the number of the hidden nodes.
% The params are drawn from N(0, s*I), where s = 1/sqrt(n+1), n = length(X).
%
% CPD = mlp_CPD(bnet, self, nhidden, w1, b1, w2, b2) allows you to specify the params, where
%  w1 = first-layer weight matrix
%  b1 = first-layer bias vector
%  w2 = second-layer weight matrix
%  b2 = second-layer bias vector
% These are assumed to be the same for each discrete parent combination.
% If any of these are [], random values will be created.
%
% CPD = mlp_CPD(bnet, self, nhidden, w1, b1, w2, b2, clamped) allows you to prevent the params
% from being updated during learning (if clamped = 1). Default: clamped = 0.
%
% CPD = mlp_CPD(bnet, self, nhidden, w1, b1, w2, b2, clamped, max_iter, verbose, wthresh, llthresh)
% allows you to specify params that control the M step:
%  max_iter - the maximum number of steps to take (default: 10)
%  verbose  - controls whether to print (default: 0 means silent).
%  wthresh  - a measure of the precision required for the value of
%             the weights W at the solution. Default: 1e-2.
%  llthresh - a measure of the precision required of the objective
%             function (log-likelihood) at the solution. Both this and the previous
%             condition must be satisfied for termination. Default: 1e-2.
%
% CPD = mlp_CPD() returns an empty object (used when loading an object from a file).
% CPD = mlp_CPD(obj), where obj is already an mlp_CPD, returns obj unchanged.
%
% For learning, we use a weighted version of scaled conjugate gradient in the M step.

if nargin == 0
    % This occurs if we are trying to load an object from a file.
    CPD = init_fields;
    CPD = class(CPD, 'mlp_CPD', discrete_CPD(0, []));
    return;
elseif isa(bnet, 'mlp_CPD')
    % This might occur if we are copying an object.
    CPD = bnet;
    return;
end
CPD = init_fields;

% Fill in defaults for every optional argument up front, so the code below
% never has to consult nargin again. An empty weight/bias means "keep the
% random values produced by mlp()".
if nargin < 4,  w1 = []; end
if nargin < 5,  b1 = []; end
if nargin < 6,  w2 = []; end
if nargin < 7,  b2 = []; end
if nargin < 8,  clamped = 0; end
if nargin < 9,  max_iter = 10; end
if nargin < 10, verbose = 0; end
if nargin < 11, wthresh = 1e-2; end
if nargin < 12, llthresh = 1e-2; end

% The child node itself must be discrete.
assert(myismember(self, bnet.dnodes));
ns = bnet.node_sizes;

% Split the parents of this node into discrete and continuous ones.
ps = parents(bnet.dag, self);
dnodes = mysetdiff(1:length(bnet.dag), bnet.cnodes);
dps = myintersect(ps, dnodes);
cps = myintersect(ps, bnet.cnodes);
dpsz = prod(ns(dps));   % number of discrete parent combinations = number of MLPs
cpsz = sum(ns(cps));    % total size of the continuous parents = MLP input size
self_size = ns(self);   % MLP output size

% discrete/cts parent index - which ones of my parents are discrete/cts?
CPD.dpndx = find_equiv_posns(dps, ps);
CPD.cpndx = find_equiv_posns(cps, ps);

% One softmax MLP per discrete parent combination. User-supplied parameters
% (when non-empty) are shared by all of them.
CPD.mlp = cell(1, dpsz);
for i = 1:dpsz
    CPD.mlp{i} = mlp(cpsz, nhidden, self_size, 'softmax');
    if ~isempty(w1)
        CPD.mlp{i}.w1 = w1;
    end
    if ~isempty(b1)
        CPD.mlp{i}.b1 = b1;
    end
    if ~isempty(w2)
        CPD.mlp{i}.w2 = w2;
    end
    if ~isempty(b2)
        CPD.mlp{i}.b2 = b2;
    end
    % Stack the parameters of all networks: weights along the 3rd dimension,
    % biases one row per network.
    W1app(:,:,i) = CPD.mlp{i}.w1;
    W2app(:,:,i) = CPD.mlp{i}.w2;
    b1app(i,:) = CPD.mlp{i}.b1;
    b2app(i,:) = CPD.mlp{i}.b2;
end

CPD.self = self;
CPD.max_iter = max_iter;
CPD.verbose = verbose;
CPD.wthresh = wthresh;
CPD.llthresh = llthresh;

% sufficient statistics
% Since MLP is not in the exponential family, we must store all the raw data.
%
% Extract all the parameters of the node for handling discrete obs parents
CPD.W1 = W1app;
CPD.W2 = W2app;
CPD.b1 = b1app;
CPD.b2 = b2app;

CPD.sizes = bnet.node_sizes(:);   % used in CPD_to_table to pump up the node sizes

CPD.parent_vals = [];             % X(l,:) = value of cts parents in l'th example

CPD.eso_weights = [];             % weights used by the SCG algorithm

CPD.self_vals = [];               % Y(l,:) = value of self in l'th example

% For BIC
CPD.nsamples = 0;
% The number of parameters is the total number of weights and biases over
% all networks, i.e. the SUM of the element counts of the four arrays.
% (Taking prod over the concatenated size vectors would multiply the element
% counts of W1 and W2, and of b1 and b2, grossly overstating the count.)
CPD.nparams = numel(W1app) + numel(W2app) + numel(b1app) + numel(b2app);
CPD = class(CPD, 'mlp_CPD', discrete_CPD(clamped, ns([ps self])));
wolffd@0 115
wolffd@0 116 %%%%%%%%%%%
wolffd@0 117
function CPD = init_fields()
% INIT_FIELDS Create the struct underlying an mlp_CPD object, with every field empty.
%
% This ensures we define the fields in the same order
% no matter whether we load an object from a file,
% or create it from scratch. (Matlab requires this.)
%
% Every field the constructor assigns must be declared here; otherwise an
% object built through the no-argument (load) path would have a different
% field set from one built normally.

CPD.mlp = {};
CPD.self = [];
CPD.max_iter = [];
CPD.verbose = [];
CPD.wthresh = [];
CPD.llthresh = [];
CPD.approx_hess = [];
CPD.W1 = [];
CPD.W2 = [];
CPD.b1 = [];
CPD.b2 = [];
CPD.sizes = [];
CPD.parent_vals = [];
CPD.eso_weights = [];
CPD.self_vals = [];
CPD.nsamples = [];
CPD.nparams = [];
% The constructor also sets dpndx and cpndx. They are declared last so that
% the field order of normally constructed objects is unchanged.
CPD.dpndx = [];
CPD.cpndx = [];