function hdv = mlphdotv(net, x, t, v)
%MLPHDOTV Evaluate the product of the data Hessian with a vector.
%
%	Description
%
%	HDV = MLPHDOTV(NET, X, T, V) takes an MLP network data structure NET,
%	together with the matrix X of input vectors, the matrix T of target
%	vectors and an arbitrary row vector V whose length equals the number
%	of parameters in the network, and returns the product of the data-
%	dependent contribution to the Hessian matrix with V. The
%	implementation is based on the R-propagation algorithm of
%	Pearlmutter.
%
%	HDV is a row vector assembled as [W1(:)', B1, W2(:)', B2], i.e.
%	first-layer weights (column-wise), first-layer biases, second-layer
%	weights (column-wise), second-layer biases.
%
%	See also
%	MLP, MLPHESS, HESSCHEK
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Check arguments for consistency
errstring = consist(net, 'mlp', x, t);
if ~isempty(errstring)
  error(errstring);
end

ndata = size(x, 1);

% Standard forward propagation. The hidden-unit derivatives are written in
% terms of the hidden outputs z, using the tanh identities
% dz/da = 1 - z.^2 and d2z/da2 = -2*z.*(1 - z.^2).
[y, z] = mlpfwd(net, x);
zprime = (1 - z.*z);           % Hidden unit first derivatives.
zpprime = -2.0*z.*zprime;      % Hidden unit second derivatives.

% Unpack V into weight/bias fields shaped like those of NET.
% NOTE(review): assumes MLPUNPAK splits V in the same
% [w1(:)', b1, w2(:)', b2] order used to assemble HDV below - confirm.
vnet = mlpunpak(net, v);

% R-forward propagation: R{a1}, R{z}, R{a2}.
ra1 = x*vnet.w1 + ones(ndata, 1)*vnet.b1;
rz = zprime.*ra1;
ra2 = rz*net.w2 + z*vnet.w2 + ones(ndata, 1)*vnet.b2;

% R{y}: push R{a2} through the derivative of the output activation.
switch net.outfn

  case 'linear'      % Linear outputs: dy/da = 1

    ry = ra2;

  case 'logistic'    % Logistic outputs: dy/da = y.*(1 - y)

    ry = y.*(1 - y).*ra2;

  case 'softmax'     % Softmax outputs: multiply by the softmax Jacobian

    nout = size(t, 2);
    ry = y.*ra2 - y.*(sum(y.*ra2, 2)*ones(1, nout));

  otherwise
    error(['Unknown activation function ', net.outfn]);
end

% Evaluate delta for the output units. R{delout} is therefore simply ry.
delout = y - t;

% Output error back-propagated through the second-layer weights. The
% ordinary hidden deltas (zprime.*backdel) are not needed for H*v, so only
% this shared product is formed. It is evaluated once and reused below.
backdel = delout*net.w2';

% R-backpropagation of the hidden deltas:
%   R{delhid} = R{zprime}.*backdel + zprime.*R{backdel}
% where R{zprime} = zpprime.*ra1 and
%       R{backdel} = delout*vnet.w2' + ry*net.w2'.
rdelhid = zpprime.*ra1.*backdel + zprime.*(delout*vnet.w2') + ...
  zprime.*(ry*net.w2');

% Finally, evaluate the components of hdv and then merge into long vector.
hw1 = x'*rdelhid;
hb1 = sum(rdelhid, 1);
hw2 = z'*ry + rz'*delout;
hb2 = sum(ry, 1);

hdv = [hw1(:)', hb1, hw2(:)', hb2];