annotate toolboxes/FullBNT-1.0.7/netlab3.3/mdngrad.m @ 0:cc4b1211e677 tip

initial commit to HG from Changeset: 646 (e263d8a21543) added further path and more save "camirversion.m"
author Daniel Wolff
date Fri, 19 Aug 2016 13:07:06 +0200
function g = mdngrad(net, x, t)
%MDNGRAD Evaluate gradient of error function for Mixture Density Network.
%
% Description
% G = MDNGRAD(NET, X, T) takes a mixture density network data
% structure NET, a matrix X of input vectors and a matrix T of target
% vectors, and evaluates the gradient G of the error function with
% respect to the network weights. The error function is the negative
% log likelihood of the target data. Each row of X corresponds to one
% input vector and each row of T corresponds to one target vector.
%
% See also
% MDN, MDNFWD, MDNERR, MDNPROB, MLPBKP
%

% Copyright (c) Ian T Nabney (1996-2001)
% David J Evans (1998)

% Check arguments for consistency
errstring = consist(net, 'mdn', x, t);
if ~isempty(errstring)
  error(errstring);
end

[mixparams, y, z] = mdnfwd(net, x);

% Compute the error gradients with respect to the MLP output
% activations and store them in deltas
ncentres = net.mdnmixes.ncentres;
dim_target = net.mdnmixes.dim_target;
nmixparams = net.mdnmixes.nparams;
ntarget = size(t, 1);
deltas = zeros(ntarget, net.mlp.nout);
e = ones(ncentres, 1);
f = ones(1, dim_target);

post = mdnpost(mixparams, t);

% Derivatives with respect to the mixing-coefficient (prior) activations
deltas(:, 1:ncentres) = mixparams.mixcoeffs - post;

% Derivatives with respect to the centres
long_t = kron(ones(1, ncentres), t);
centre_err = mixparams.centres - long_t;

% Replicate the posteriors so that each centre component u_jk is
% matched by the posterior of its centre j;
% this array will be (ntarget, (ncentres*dim_target))
long_post = kron(ones(dim_target, 1), post);
long_post = reshape(long_post, ntarget, (ncentres*dim_target));

% Replicate the variances in the same way to match each u_jk
% (named long_var to avoid shadowing the built-in function var)
long_var = mixparams.covars;
long_var = kron(ones(dim_target, 1), long_var);
long_var = reshape(long_var, ntarget, (ncentres*dim_target));

% Compute centre deltas
deltas(:, (ncentres+1):(ncentres*(1+dim_target))) = ...
  (centre_err.*long_post)./long_var;

% Compute deltas for the variance activations (covars = exp(activation))
dist2 = mdndist2(mixparams, t);
c = dim_target*ones(ntarget, ncentres);
deltas(:, (ncentres*(1+dim_target)+1):nmixparams) = ...
  post.*((dist2./mixparams.covars)-c)./(-2);

% Now back-propagate deltas through MLP
g = mlpbkp(net.mlp, x, z, deltas);
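The three blocks of deltas are the derivatives of the per-pattern error E_n = -ln sum_j pi_j N(t_n | mu_j, sigma_j^2 I) with respect to the MLP output activations. They are pi_j - post_j for the softmax mixing activations, post_j (mu_jk - t_k) / sigma_j^2 for the centres, and -post_j (||t - mu_j||^2 / sigma_j^2 - d) / 2 for the activations whose exponential gives the variance. The sketch below is a minimal, self-contained MATLAB/Octave check (no Netlab required) that compares these formulas, including the kron/reshape column bookkeeping, with central finite differences. It assumes the output layout implied by the indexing in mdngrad: mixing activations first, then the centres with dim_target consecutive columns per centre, then one log-variance per centre. The sizes M, D, N, the random data and the step h are arbitrary illustrative choices.

```matlab
% Sketch: check the MDN output deltas used in mdngrad against central
% finite differences. Self-contained (no Netlab). Assumed layout of a:
%   a(:, 1:M)             mixing-coefficient activations (softmax)
%   a(:, M+1:M*(1+D))     centres, D consecutive columns per centre
%   a(:, M*(1+D)+1:end)   log-variances, one per centre (covars = exp(.))
M = 2;  D = 2;  N = 3;            % ncentres, dim_target, ntarget (arbitrary)
a = 0.5*randn(N, M*(2+D));        % MLP output activations
t = randn(N, D);                  % targets

% Mixture parameters and squared distances as functions of a
mixc  = @(a) bsxfun(@rdivide, exp(a(:, 1:M)), sum(exp(a(:, 1:M)), 2));
cvar  = @(a) exp(a(:, M*(1+D)+1:end));
cdist = @(a) reshape(sum(reshape(((a(:, M+1:M*(1+D)) - ...
          kron(ones(1, M), t)).^2)', D, []), 1), M, N)';
% Total negative log likelihood of the targets
nll = @(a) -sum(log(sum(mixc(a) .* (2*pi*cvar(a)).^(-D/2) ...
          .* exp(-cdist(a)./(2*cvar(a))), 2)));

% Analytic deltas, computed the same way as in mdngrad
pic  = mixc(a);  v = cvar(a);  d2 = cdist(a);
lik  = pic .* (2*pi*v).^(-D/2) .* exp(-d2./(2*v));
post = bsxfun(@rdivide, lik, sum(lik, 2));      % posteriors, N x M
deltas = zeros(N, M*(2+D));
deltas(:, 1:M) = pic - post;
long_post = reshape(kron(ones(D, 1), post), N, M*D);
long_var  = reshape(kron(ones(D, 1), v), N, M*D);
deltas(:, M+1:M*(1+D)) = ...
    (a(:, M+1:M*(1+D)) - kron(ones(1, M), t)) .* long_post ./ long_var;
deltas(:, M*(1+D)+1:end) = post .* (d2./v - D) ./ (-2);

% Central finite differences of the total error, one activation at a time
h = 1e-6;
fd = zeros(size(a));
for k = 1:numel(a)
  ap = a;  ap(k) = ap(k) + h;
  am = a;  am(k) = am(k) - h;
  fd(k) = (nll(ap) - nll(am)) / (2*h);
end
maxerr = max(abs(deltas(:) - fd(:)));
fprintf('max |analytic - numeric| = %g\n', maxerr);
```

Because each pattern's error depends only on its own row of activations, the derivative of the summed error with respect to a(n, i) equals the per-pattern delta, so the two matrices can be compared entry by entry. Changing M, D or N exercises the kron/reshape layout; a misaligned column block would show up as a large discrepancy.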