annotate toolboxes/FullBNT-1.0.7/netlab3.3/mdnfwd.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [mixparams, y, z, a] = mdnfwd(net, x)
%MDNFWD Forward propagation through Mixture Density Network.
%
%	Description
%	MIXPARAMS = MDNFWD(NET, X) takes a mixture density network data
%	structure NET and a matrix X of input vectors, and forward propagates
%	the inputs through the network to generate a structure MIXPARAMS
%	which contains the parameters of several mixture models.  Each row
%	of X represents one input vector and the corresponding row of the
%	matrices in MIXPARAMS represents the parameters of a mixture model
%	for the conditional probability of target vectors given the input
%	vector.  This is not represented as an array of GMM structures to
%	improve the efficiency of MDN training.
%
%	The fields in MIXPARAMS are
%	  type       = type string copied from NET.MDNMIXES
%	  ncentres   = number of mixture components
%	  dim_target = dimension of target space
%	  nparams    = number of parameters
%	  mixcoeffs  = mixing coefficients (softmax of the first NCENTRES
%	               columns of the MLP output)
%	  centres    = means of Gaussians: stored as one row per pattern
%	  covars     = variances of Gaussians (exp of the remaining columns
%	               of the MLP output)
%
%	[MIXPARAMS, Y, Z] = MDNFWD(NET, X) also generates a matrix Y of the
%	outputs of the MLP and a matrix Z of the hidden unit activations
%	where each row corresponds to one pattern.
%
%	[MIXPARAMS, Y, Z, A] = MDNFWD(NET, X) also returns a matrix A giving
%	the summed inputs to each output unit, where each row corresponds to
%	one pattern.
%
%	An error is raised if CONSIST reports that NET and X are
%	inconsistent.
%
%	See also
%	MDN, MDN2GMM, MDNERR, MDNGRAD, MLPFWD
%

%	Copyright (c) Ian T Nabney (1996-2001)
%	David J Evans (1998)

% Check arguments for consistency
errstring = consist(net, 'mdn', x);
if ~isempty(errstring)
  error(errstring);
end

% Extract mlp and mixture model descriptors
mlpnet = net.mlp;
mixes = net.mdnmixes;

ncentres = mixes.ncentres;      % Number of components in mixture model
dim_target = mixes.dim_target;  % Dimension of targets
nparams = mixes.nparams;        % Number of parameters in mixture model

% Propagate forwards through MLP
[y, z, a] = mlpfwd(mlpnet, x);

% Compute the position of each parameter group within the columns of
% the MLP output matrix: priors first, then centres, then variances.
prior_cols = 1:ncentres;
centre_cols = (ncentres + 1):(ncentres*(1 + dim_target));
width_cols = (ncentres*(1 + dim_target) + 1):nparams;

% Convert output values into mixture model parameters

% Use softmax to calculate priors
% Prevent overflow and underflow: use same bounds as glmfwd
% Ensure that sum(exp(y), 2) does not overflow
maxcut = log(realmax) - log(ncentres);
% Ensure that exp(y) > 0
mincut = log(realmin);
temp = min(y(:, prior_cols), maxcut);
temp = max(temp, mincut);
temp = exp(temp);
% Normalise each row so that the mixing coefficients sum to one
mixpriors = temp./(sum(temp, 2)*ones(1, ncentres));

% Centres are just copies of network outputs
mixcentres = y(:, centre_cols);

% Variances are exp of network outputs
mixwidths = exp(y(:, width_cols));

% Return parameters
mixparams.type = mixes.type;
mixparams.ncentres = mixes.ncentres;
mixparams.dim_target = mixes.dim_target;
mixparams.nparams = mixes.nparams;

mixparams.mixcoeffs = mixpriors;
mixparams.centres = mixcentres;
mixparams.covars = mixwidths;