annotate toolboxes/FullBNT-1.0.7/netlab3.3/mdngrad.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function g = mdngrad(net, x, t)
%MDNGRAD Evaluate gradient of error function for Mixture Density Network.
%
% Description
% G = MDNGRAD(NET, X, T) takes a mixture density network data
% structure NET, a matrix X of input vectors and a matrix T of target
% vectors, and evaluates the gradient G of the error function with
% respect to the network weights. The error function is negative log
% likelihood of the target data. Each row of X corresponds to one
% input vector and each row of T corresponds to one target vector.
%
% The gradient is formed in two stages: first the derivatives of the
% error with respect to every MLP output (the "deltas") are computed from
% the mixture parameters and component posteriors, then they are
% back-propagated through the underlying MLP with MLPBKP.
%
% See also
% MDN, MDNFWD, MDNERR, MDNPROB, MLPBKP
%

% Copyright (c) Ian T Nabney (1996-2001)
% David J Evans (1998)

% Check arguments for consistency
errstring = consist(net, 'mdn', x, t);
if ~isempty(errstring)
  error(errstring);
end

% y (the raw MLP outputs) is not needed here; it is kept as a named output
% rather than '~' so the file still runs on older MATLAB releases.
[mixparams, y, z] = mdnfwd(net, x);

% Compute gradients at MLP outputs: put the answer in deltas
ncentres   = net.mdnmixes.ncentres;
dim_target = net.mdnmixes.dim_target;
nmixparams = net.mdnmixes.nparams;
ntarget    = size(t, 1);
deltas = zeros(ntarget, net.mlp.nout);

% Posterior of each mixture component for each target row
% (ntarget x ncentres).
post = mdnpost(mixparams, t);

% Prior (mixing coefficient) derivatives: mixcoeffs - post.
deltas(:, 1:ncentres) = mixparams.mixcoeffs - post;

% Centre coordinates u_jk are stored centre-by-centre: the dim_target
% coordinates of centre 1, then those of centre 2, and so on. This map
% gives, for each of those ncentres*dim_target columns, the index j of the
% centre it belongs to: [1 1 .. 1 2 2 .. 2 ...].
centre_of_coord = kron(1:ncentres, ones(1, dim_target));

% Difference between every centre coordinate and the matching target
% coordinate (targets tiled once per centre).
centre_err = mixparams.centres - repmat(t, 1, ncentres);

% Posterior and variance of the owning centre, aligned with each u_jk.
% Named 'coord_covars' rather than 'var' to avoid shadowing the builtin.
coord_post   = post(:, centre_of_coord);
coord_covars = mixparams.covars(:, centre_of_coord);

% Centre derivatives: post_j * (u_jk - t_k) / covar_j
deltas(:, (ncentres+1):(ncentres*(1+dim_target))) = ...
  (centre_err .* coord_post) ./ coord_covars;

% Variance derivatives: post_j * (dim_target - dist2_j / covar_j) / 2,
% where dist2_j comes from MDNDIST2 (presumably the squared distance from
% each target to centre j -- see MDNDIST2). The form of this term suggests
% the MLP outputs for these columns are log-variances; confirm in MDNFWD.
dist2 = mdndist2(mixparams, t);
deltas(:, (ncentres*(1+dim_target)+1):nmixparams) = ...
  post .* (dim_target - dist2 ./ mixparams.covars) ./ 2;

% Now back-propagate deltas through MLP
g = mlpbkp(net.mlp, x, z, deltas);