Mercurial repository: camir-aes2014
File: toolboxes/FullBNT-1.0.7/netlabKPM/mlphdotv_weighted.m @ 0:e9a9cd732c1e (tip)
Commit message: first hg version after svn
Author: wolffd
Date: Tue, 10 Feb 2015 15:05:51 +0000
function hdv = mlphdotv_weighted(net, x, t, eso_w, v)
%MLPHDOTV_WEIGHTED Evaluate the product of a weighted data Hessian with a vector.
%
% Description
%
% HDV = MLPHDOTV_WEIGHTED(NET, X, T, ESO_W, V) takes an MLP network data
% structure NET, together with the matrix X of input vectors, the matrix
% T of target vectors, a matrix ESO_W whose first column holds one weight
% per pattern (row of X), and an arbitrary row vector V whose length
% equals the number of parameters in the network. It returns the product
% of the weighted data-dependent contribution to the Hessian matrix with
% V, where the contribution of each pattern is scaled by its weight. Only
% networks with softmax outputs are supported. This is a weighted variant
% of MLPHDOTV; the implementation is based on the R-propagation algorithm
% of Pearlmutter.
%
% See also
% MLP, MLPHESS, MLPHDOTV, HESSCHEK
%

% Copyright (c) Ian T Nabney (1996-9)

% Check arguments for consistency
errstring = consist(net, 'mlp', x, t);
if ~isempty(errstring)
  error(errstring);
end

ndata = size(x, 1);

[y, z] = mlpfwd(net, x);      % Standard forward propagation.
zprime = (1 - z.*z);          % Hidden unit first derivatives.
zpprime = -2.0*z.*zprime;     % Hidden unit second derivatives.

vnet = mlpunpak(net, v);      % Unpack the v vector.

% Do the R-forward propagation.
ra1 = x*vnet.w1 + ones(ndata, 1)*vnet.b1;
rz = zprime.*ra1;
ra2 = rz*net.w2 + z*vnet.w2 + ones(ndata, 1)*vnet.b2;

switch net.actfn

  case 'softmax'        % Softmax outputs
    nout = size(t, 2);
    ry = y.*ra2 - y.*(sum(y.*ra2, 2)*ones(1, nout));

  otherwise
    error(['Unknown activation function ', net.actfn]);
end

% Evaluate the weighted deltas for the output units. The per-pattern
% weights must scale both the output deltas and their R-derivatives,
% otherwise the terms involving ry are left unweighted.
wmat = eso_w(1:ndata, 1)*ones(1, nout);   % Replicate weights across outputs.
delout = wmat.*(y - t);
rdelout = wmat.*ry;

% Do the standard backpropagation (delhid is not used below).
delhid = zprime.*(delout*net.w2');

% Now do the R-backpropagation.
rdelhid = zpprime.*ra1.*(delout*net.w2') + zprime.*(delout*vnet.w2') + ...
          zprime.*(rdelout*net.w2');

% Finally, evaluate the components of hdv and then merge into long vector.
hw1 = x'*rdelhid;
hb1 = sum(rdelhid, 1);
hw2 = z'*rdelout + rz'*delout;
hb2 = sum(rdelout, 1);

hdv = [hw1(:)', hb1, hw2(:)', hb2];
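One way to test a weighted Hessian-vector product is to compare the R-propagation result with a central finite difference of the weighted gradient, (g(w + epsilon*v) - g(w - epsilon*v))/(2*epsilon). The following is a minimal sketch, assuming a hypothetical tiny network (tanh hidden layer, softmax outputs, weighted cross-entropy E = sum_n ew(n)*E_n) written out directly instead of through Netlab's MLP, MLPFWD and MLPUNPAK, so it runs on its own in MATLAB or Octave. The names `ew`, `W1`, `B1`, `W2`, `B2` and the layer sizes are illustrative assumptions; the parameter ordering [w1(:)', b1, w2(:)', b2] follows the merge at the end of the function above. Because R{delout} = ew.*R{y}, the per-pattern weights have to multiply ry as well as y - t.

```matlab
% Sketch: check a weighted R-propagation Hessian-vector product against
% central finite differences of the weighted gradient. Self-contained
% (no Netlab); the network, sizes and weights ew are hypothetical.
nin = 3; nhidden = 4; nout = 2; ndata = 5;
x = randn(ndata, nin);
t = zeros(ndata, nout);
t(sub2ind(size(t), (1:ndata)', randi(nout, ndata, 1))) = 1;  % one-hot targets
ew = rand(ndata, 1);                      % hypothetical per-pattern weights

% Parameters packed as [w1(:)', b1, w2(:)', b2].
i1 = nin*nhidden; i2 = i1 + nhidden; i3 = i2 + nhidden*nout;
npar = i3 + nout;
W1 = @(p) reshape(p(1:i1), nin, nhidden);
B1 = @(p) p(i1+1:i2);
W2 = @(p) reshape(p(i2+1:i3), nhidden, nout);
B2 = @(p) p(i3+1:npar);
wvec = 0.5*randn(1, npar);                % point at which H is evaluated
v = randn(1, npar);                       % arbitrary direction
wmat = ew*ones(1, nout);                  % weights replicated across outputs

% Finite differences: weighted gradient at wvec +/- epsilon*v.
epsilon = 1e-5;
g = zeros(2, npar);
signs = [1, -1];
for k = 1:2
  p = wvec + signs(k)*epsilon*v;
  z = tanh(x*W1(p) + ones(ndata, 1)*B1(p));
  a2 = z*W2(p) + ones(ndata, 1)*B2(p);
  y = exp(a2 - max(a2, [], 2)*ones(1, nout));   % numerically stable softmax
  y = y./(sum(y, 2)*ones(1, nout));
  delout = wmat.*(y - t);                 % dE/da2 for weighted cross-entropy
  delhid = (1 - z.^2).*(delout*W2(p)');
  gw1 = x'*delhid; gw2 = z'*delout;
  g(k, :) = [gw1(:)', sum(delhid, 1), gw2(:)', sum(delout, 1)];
end
hv_fd = (g(1, :) - g(2, :))/(2*epsilon);

% R-propagation at wvec, with the weights applied to R{y} as well.
z = tanh(x*W1(wvec) + ones(ndata, 1)*B1(wvec));
a2 = z*W2(wvec) + ones(ndata, 1)*B2(wvec);
y = exp(a2 - max(a2, [], 2)*ones(1, nout));
y = y./(sum(y, 2)*ones(1, nout));
zprime = 1 - z.^2;                        % hidden unit first derivatives
zpprime = -2*z.*zprime;                   % hidden unit second derivatives
ra1 = x*W1(v) + ones(ndata, 1)*B1(v);
rz = zprime.*ra1;
ra2 = rz*W2(wvec) + z*W2(v) + ones(ndata, 1)*B2(v);
ry = y.*ra2 - y.*(sum(y.*ra2, 2)*ones(1, nout));
delout = wmat.*(y - t);
rdelout = wmat.*ry;                       % R{delout} = ew .* R{y}
rdelhid = zpprime.*ra1.*(delout*W2(wvec)') + zprime.*(delout*W2(v)') + ...
          zprime.*(rdelout*W2(wvec)');
hw1 = x'*rdelhid; hw2 = z'*rdelout + rz'*delout;
hv_r = [hw1(:)', sum(rdelhid, 1), hw2(:)', sum(rdelout, 1)];

relerr = norm(hv_r - hv_fd)/norm(hv_fd);
fprintf('relative difference: %g\n', relerr);
```

The printed relative difference should be many orders of magnitude below one. Replacing rdelout with ry in the R-backpropagation and in the hw2 and bias terms gives a variant whose weights touch only y - t, and that variant in general no longer matches the finite differences unless every weight equals one.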