view toolboxes/FullBNT-1.0.7/netlabKPM/mlphdotv_weighted.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line source
function hdv = mlphdotv_weighted(net, x, t, eso_w, v)
%MLPHDOTV Evaluate the product of the data Hessian with a vector. 
%
%	Description
%
%	HDV = MLPHDOTV(NET, X, T, V) takes an MLP network data structure NET,
%	together with the matrix X of input vectors, the matrix T of target
%	vectors and an arbitrary row vector V whose length equals the number
%	of parameters in the network, and returns the product of the data-
%	dependent contribution to the Hessian matrix with V. The
%	implementation is based on the R-propagation algorithm of
%	Pearlmutter.
%
%	See also
%	MLP, MLPHESS, HESSCHEK
%

%	Copyright (c) Ian T Nabney (1996-9)

% Check arguments for consistency
errstring = consist(net, 'mlp', x, t);
if ~isempty(errstring);
  error(errstring);
end

ndata = size(x, 1);

[y, z] = mlpfwd(net, x);		% Standard forward propagation.
zprime = (1 - z.*z);			% Hidden unit first derivatives.
zpprime = -2.0*z.*zprime;		% Hidden unit second derivatives.

vnet = mlpunpak(net, v);	% 		Unpack the v vector.

% Do the R-forward propagation.

ra1 = x*vnet.w1 + ones(ndata, 1)*vnet.b1;
rz = zprime.*ra1;
ra2 = rz*net.w2 + z*vnet.w2 + ones(ndata, 1)*vnet.b2;

switch net.actfn
  case 'softmax'     % Softmax outputs
  
    nout = size(t, 2);
    ry = y.*ra2 - y.*(sum(y.*ra2, 2)*ones(1, nout));

  otherwise
    error(['Unknown activation function ', net.actfn]);  
end

% Evaluate a weighted delta for the output units.
temp = y - t;
for m=1:ndata,
      delout(m,:)=eso_w(m,1)*temp(m,:);
end
clear temp;

% Do the standard backpropagation.

delhid = zprime.*(delout*net.w2');

% Now do the R-backpropagation.

rdelhid = zpprime.*ra1.*(delout*net.w2') + zprime.*(delout*vnet.w2') + ...
          zprime.*(ry*net.w2');

% Finally, evaluate the components of hdv and then merge into long vector.

hw1 = x'*rdelhid;
hb1 = sum(rdelhid, 1);
hw2 = z'*ry + rz'*delout;
hb2 = sum(ry, 1);

hdv = [hw1(:)', hb1, hw2(:)', hb2];