function g = mdngrad(net, x, t)
%MDNGRAD Evaluate gradient of error function for Mixture Density Network.
%
%	Description
%	G = MDNGRAD(NET, X, T) takes a mixture density network data
%	structure NET, a matrix X of input vectors and a matrix T of target
%	vectors, and evaluates the gradient G of the error function with
%	respect to the network weights. The error function is the negative
%	log likelihood of the target data. Each row of X corresponds to one
%	input vector and each row of T corresponds to one target vector.
%
%	An error is raised if CONSIST reports that NET, X and T are
%	inconsistent.
%
%	See also
%	MDN, MDNFWD, MDNERR, MDNPROB, MLPBKP
%

%	Copyright (c) Ian T Nabney (1996-2001)
%	David J Evans (1998)

% Check arguments for consistency
errstring = consist(net, 'mdn', x, t);
if ~isempty(errstring)
  error(errstring);
end

% Forward propagate. The second output is not needed here; it is kept as a
% named placeholder (rather than ~) so the file still runs on old releases.
% z is handed unchanged to MLPBKP below.
[mixparams, y, z] = mdnfwd(net, x);

% Compute gradients at MLP outputs: put the answer in deltas.
% Column layout of deltas, as fixed by the index ranges used below:
%   1 : ncentres                                   mixing coefficients
%   ncentres+1 : ncentres*(1+dim_target)           centres
%   ncentres*(1+dim_target)+1 : nmixparams         variances
ncentres = net.mdnmixes.ncentres;
dim_target = net.mdnmixes.dim_target;
nmixparams = net.mdnmixes.nparams;
ntarget = size(t, 1);
deltas = zeros(ntarget, net.mlp.nout);

post = mdnpost(mixparams, t);

% Calculate prior derivatives
deltas(:, 1:ncentres) = mixparams.mixcoeffs - post;

% Calculate centre derivatives: tile t once per centre so that it lines up
% with the column blocks of mixparams.centres
long_t = kron(ones(1, ncentres), t);
centre_err = mixparams.centres - long_t;

% Get the post and the variance to match each u_jk: every column is
% repeated dim_target times, giving (ntarget, ncentres*dim_target) arrays.
% kron with a row of ones yields the same values as the former
% kron(ones(dim_target, 1), A) followed by reshape.
rep = ones(1, dim_target);
long_post = kron(post, rep);
long_covars = kron(mixparams.covars, rep);

% Compute centre deltas
deltas(:, (ncentres+1):(ncentres*(1+dim_target))) = ...
    (centre_err.*long_post)./long_covars;

% Compute variance deltas
dist2 = mdndist2(mixparams, t);
deltas(:, (ncentres*(1+dim_target)+1):nmixparams) = ...
    post.*((dist2./mixparams.covars) - dim_target)./(-2);

% Now back-propagate deltas through MLP
g = mlpbkp(net.mlp, x, z, deltas);