Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/netlab3.3/mlphdotv.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/netlab3.3/mlphdotv.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,79 @@ +function hdv = mlphdotv(net, x, t, v) +%MLPHDOTV Evaluate the product of the data Hessian with a vector. +% +% Description +% +% HDV = MLPHDOTV(NET, X, T, V) takes an MLP network data structure NET, +% together with the matrix X of input vectors, the matrix T of target +% vectors and an arbitrary row vector V whose length equals the number +% of parameters in the network, and returns the product of the data- +% dependent contribution to the Hessian matrix with V. The +% implementation is based on the R-propagation algorithm of +% Pearlmutter. +% +% See also +% MLP, MLPHESS, HESSCHEK +% + +% Copyright (c) Ian T Nabney (1996-2001) + +% Check arguments for consistency +errstring = consist(net, 'mlp', x, t); +if ~isempty(errstring); + error(errstring); +end + +ndata = size(x, 1); + +[y, z] = mlpfwd(net, x); % Standard forward propagation. +zprime = (1 - z.*z); % Hidden unit first derivatives. +zpprime = -2.0*z.*zprime; % Hidden unit second derivatives. + +vnet = mlpunpak(net, v); % Unpack the v vector. + +% Do the R-forward propagation. + +ra1 = x*vnet.w1 + ones(ndata, 1)*vnet.b1; +rz = zprime.*ra1; +ra2 = rz*net.w2 + z*vnet.w2 + ones(ndata, 1)*vnet.b2; + +switch net.outfn + + case 'linear' % Linear outputs + + ry = ra2; + + case 'logistic' % Logistic outputs + + ry = y.*(1 - y).*ra2; + + case 'softmax' % Softmax outputs + + nout = size(t, 2); + ry = y.*ra2 - y.*(sum(y.*ra2, 2)*ones(1, nout)); + + otherwise + error(['Unknown activation function ', net.outfn]); +end + +% Evaluate delta for the output units. + +delout = y - t; + +% Do the standard backpropagation. + +delhid = zprime.*(delout*net.w2'); + +% Now do the R-backpropagation. + +rdelhid = zpprime.*ra1.*(delout*net.w2') + zprime.*(delout*vnet.w2') + ... + zprime.*(ry*net.w2'); + +% Finally, evaluate the components of hdv and then merge into long vector. + +hw1 = x'*rdelhid; +hb1 = sum(rdelhid, 1); +hw2 = z'*ry + rz'*delout; +hb2 = sum(ry, 1); + +hdv = [hw1(:)', hb1, hw2(:)', hb2];