function g = mdngrad(net, x, t)
%MDNGRAD Evaluate gradient of error function for Mixture Density Network.
%
%	Description
%	G = MDNGRAD(NET, X, T) takes a mixture density network data
%	structure NET, a matrix X of input vectors and a matrix T of target
%	vectors, and evaluates the gradient G of the error function with
%	respect to the network weights. The error function is negative log
%	likelihood of the target data. Each row of X corresponds to one
%	input vector and each row of T corresponds to one target vector.
%
%	See also
%	MDN, MDNFWD, MDNERR, MDNPROB, MLPBKP
%

%	Copyright (c) Ian T Nabney (1996-2001)
%	David J Evans (1998)

% Check arguments for consistency
errstring = consist(net, 'mdn', x, t);
if ~isempty(errstring)
  error(errstring);
end

[mixparams, y, z] = mdnfwd(net, x);

% Compute gradients at MLP outputs: put the answer in deltas
ncentres = net.mdnmixes.ncentres;
dim_target = net.mdnmixes.dim_target;
nmixparams = net.mdnmixes.nparams;
ntarget = size(t, 1);
deltas = zeros(ntarget, net.mlp.nout);
e = ones(ncentres, 1);
f = ones(1, dim_target);

post = mdnpost(mixparams, t);

% Calculate prior derivatives
deltas(:,1:ncentres) = mixparams.mixcoeffs - post;

% Calculate centre derivatives
long_t = kron(ones(1, ncentres), t);
centre_err = mixparams.centres - long_t;

% Get the post to match each u_jk:
% this array will be (ntarget, (ncentres*dim_target))
long_post = kron(ones(dim_target, 1), post);
long_post = reshape(long_post, ntarget, (ncentres*dim_target));

% Get the variance to match each u_jk:
var = mixparams.covars;
var = kron(ones(dim_target, 1), var);
var = reshape(var, ntarget, (ncentres*dim_target));

% Compute centre deltas
deltas(:, (ncentres+1):(ncentres*(1+dim_target))) = ...
    (centre_err.*long_post)./var;

% Compute variance deltas
dist2 = mdndist2(mixparams, t);
c = dim_target*ones(ntarget, ncentres);
deltas(:, (ncentres*(1+dim_target)+1):nmixparams) = ...
    post.*((dist2./mixparams.covars)-c)./(-2);

% Now back-propagate deltas through MLP
g = mlpbkp(net.mlp, x, z, deltas);
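
% Note on the deltas (a sketch read off the assignments to deltas in this
% function, not part of the original documentation). The symbols are
% shorthand introduced here for one pattern with target t:
%   alpha_j = mixparams.mixcoeffs, u_jk = mixparams.centres,
%   s_j = mixparams.covars, p_j = post, d = dim_target,
%   E = negative log likelihood, a = MLP output (pre-activation).
% Assuming MDNFWD maps the MLP outputs through a softmax to alpha_j,
% directly to u_jk, and through exp to s_j, the three column blocks of
% deltas hold, in order:
%
%   dE/da_j  (mixing coefficients) = alpha_j - p_j
%   dE/da_jk (centres)             = p_j * (u_jk - t_k) / s_j
%   dE/da_j  (variances)           = -(p_j / 2) * (||t - u_j||^2 / s_j - d)
%
% Here ||t - u_j||^2 is the squared distance returned by MDNDIST2.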