_FullBNT/BNT/CPDs/@mlp_CPD/mlp_CPD.m @ 8:b5b38998ef3b

function CPD = mlp_CPD(bnet, self, nhidden, w1, b1, w2, b2, clamped, max_iter, verbose, wthresh, llthresh)
% MLP_CPD Make a CPD from a Multi Layer Perceptron (i.e., a feedforward neural network)
%
% We use a different MLP for each discrete parent combination (if there are any discrete parents).
% We currently assume this node (the child) is discrete.
%
% CPD = mlp_CPD(bnet, self, nhidden)
% will create a CPD with random parameters, where self is the number of this node and
% nhidden is the number of hidden units.
% The params are drawn from N(0, s*I), where s = 1/sqrt(n+1), n = length(X).
%
% CPD = mlp_CPD(bnet, self, nhidden, w1, b1, w2, b2) allows you to specify the params, where
%  w1 = first-layer weight matrix
%  b1 = first-layer bias vector
%  w2 = second-layer weight matrix
%  b2 = second-layer bias vector
% These are assumed to be the same for each discrete parent combination.
% If any of these are [], random values will be created.
%
% CPD = mlp_CPD(bnet, self, nhidden, w1, b1, w2, b2, clamped) allows you to prevent the params from being
% updated during learning (if clamped = 1). Default: clamped = 0.
%
% CPD = mlp_CPD(bnet, self, nhidden, w1, b1, w2, b2, clamped, max_iter, verbose, wthresh, llthresh)
% allows you to specify params that control the M step:
%  max_iter - the maximum number of steps to take (default: 10)
%  verbose - controls whether to print progress (default: 0, i.e., silent)
%  wthresh - a measure of the precision required for the value of
%     the weights W at the solution. Default: 1e-2.
%  llthresh - a measure of the precision required of the objective
%     function (log-likelihood) at the solution. Both this and the previous condition must
%     be satisfied for termination. Default: 1e-2.
%
% For learning, we use a weighted version of scaled conjugate gradient in the M step.
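%
% Example (an illustrative sketch added to this help text; it assumes bnet is an
% existing BNT network in which node 3 is a discrete child with continuous parents):
%   bnet.CPD{3} = mlp_CPD(bnet, 3, 10);                    % random params, 10 hidden units
%   bnet.CPD{3} = mlp_CPD(bnet, 3, 10, [], [], [], [], 1); % as above, but clamped during learning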

if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  CPD = class(CPD, 'mlp_CPD', discrete_CPD(0,[]));
  return;
elseif isa(bnet, 'mlp_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end
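
% Note (added remark): the two branches above are the usual idiom for BNT's
% old-style MATLAB class constructors: with no arguments we build an empty object
% (so that load can restore one from disk), and when passed an existing mlp_CPD
% object the constructor simply acts as a copy constructor.
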
CPD = init_fields;

assert(myismember(self, bnet.dnodes));
ns = bnet.node_sizes;

ps = parents(bnet.dag, self);
dnodes = mysetdiff(1:length(bnet.dag), bnet.cnodes);
dps = myintersect(ps, dnodes);
cps = myintersect(ps, bnet.cnodes);
dpsz = prod(ns(dps));
cpsz = sum(ns(cps));
self_size = ns(self);
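
% Note (added remark): dpsz is the number of joint configurations of the discrete
% parents and cpsz is the total dimensionality of the continuous parents. For
% example, with two binary discrete parents and one 3-dimensional continuous
% parent, dpsz = 2*2 = 4 and cpsz = 3, so below we build 4 separate MLPs, each
% mapping a 3-dimensional input to a distribution over the self_size values of
% this node.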

% discrete/cts parent index - which ones of my parents are discrete/cts?
CPD.dpndx = find_equiv_posns(dps, ps);
CPD.cpndx = find_equiv_posns(cps, ps);

CPD.mlp = cell(1,dpsz);
for i=1:dpsz
    CPD.mlp{i} = mlp(cpsz, nhidden, self_size, 'softmax');
    if nargin >= 4 & ~isempty(w1)
        CPD.mlp{i}.w1 = w1;
    end
    if nargin >= 5 & ~isempty(b1)
        CPD.mlp{i}.b1 = b1;
    end
    if nargin >= 6 & ~isempty(w2)
        CPD.mlp{i}.w2 = w2;
    end
    if nargin >= 7 & ~isempty(b2)
        CPD.mlp{i}.b2 = b2;
    end
    W1app(:,:,i) = CPD.mlp{i}.w1;
    W2app(:,:,i) = CPD.mlp{i}.w2;
    b1app(i,:) = CPD.mlp{i}.b1;
    b2app(i,:) = CPD.mlp{i}.b2;
end
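
% Note (added remark): W1app, W2app, b1app and b2app stack the per-configuration
% parameters of the Netlab-style mlp objects created above. Assuming the usual
% Netlab layout (w1 is cpsz-by-nhidden, b1 is 1-by-nhidden, w2 is
% nhidden-by-self_size, b2 is 1-by-self_size), W1app and W2app are 3-D arrays
% indexed by the discrete parent configuration i, while b1app and b2app hold one
% bias row per configuration.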

if nargin < 8, clamped = 0; end
if nargin < 9, max_iter = 10; end
if nargin < 10, verbose = 0; end
if nargin < 11, wthresh = 1e-2; end
if nargin < 12, llthresh = 1e-2; end

CPD.self = self;
CPD.max_iter = max_iter;
CPD.verbose = verbose;
CPD.wthresh = wthresh;
CPD.llthresh = llthresh;
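
% Note (added remark): these settings are only stored here; per the help text
% above, they control the weighted scaled conjugate gradient fit that implements
% the M step for this CPD.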

% sufficient statistics
% Since MLP is not in the exponential family, we must store all the raw data.
%
CPD.W1 = W1app;                 % stacked copies of the per-configuration parameters,
CPD.W2 = W2app;                 % used for handling observed discrete parents
CPD.b1 = b1app;                 %
CPD.b2 = b2app;                 %
nparaW = prod(size(W1app)) + prod(size(W2app));  % total number of weights
nparab = prod(size(b1app)) + prod(size(b2app));  % total number of biases

CPD.sizes = bnet.node_sizes(:); % used in CPD_to_table to pump up the node sizes

CPD.parent_vals = [];           % X(l,:) = value of cts parents in l'th example

CPD.eso_weights = [];           % weights used by the SCG algorithm

CPD.self_vals = [];             % Y(l,:) = value of self in l'th example

% For BIC
CPD.nsamples = 0;
CPD.nparams = nparaW + nparab;  % total number of free parameters
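% Note (added remark): nsamples and nparams are stored for BIC-style scoring; the
% conventional BIC penalty has the form loglik - 0.5*nparams*log(nsamples).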

CPD = class(CPD, 'mlp_CPD', discrete_CPD(clamped, ns([ps self])));

%%%%%%%%%%%

function CPD = init_fields()
% This ensures we define the fields in the same order
% no matter whether we load an object from a file,
% or create it from scratch. (Matlab requires this.)

CPD.mlp = {};
CPD.self = [];
CPD.max_iter = [];
CPD.verbose = [];
CPD.wthresh = [];
CPD.llthresh = [];
CPD.approx_hess = [];
CPD.W1 = [];
CPD.W2 = [];
CPD.b1 = [];
CPD.b2 = [];
CPD.sizes = [];
CPD.parent_vals = [];
CPD.eso_weights = [];
CPD.self_vals = [];
CPD.nsamples = [];
CPD.nparams = [];