view toolboxes/FullBNT-1.0.7/netlab3.3/mdnfwd.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line source
function [mixparams, y, z, a] = mdnfwd(net, x)
%MDNFWD	Forward propagation through Mixture Density Network.
%
%	Description
%	 MIXPARAMS = MDNFWD(NET, X) takes a mixture density network data
%	structure NET and a matrix X of input vectors, and forward propagates
%	the inputs through the network to generate a structure MIXPARAMS
%	which contains the parameters of several mixture models.   Each row
%	of X represents one input vector and the corresponding row of the
%	matrices in MIXPARAMS  represents the parameters of a mixture model
%	for the conditional probability of target vectors given the input
%	vector.  This is not represented as an array of GMM structures to
%	improve the efficiency of MDN training.
%
%	The fields in MIXPARAMS are
%	  type = 'mdnmixes' (copied from NET.MDNMIXES.TYPE)
%	  ncentres = number of mixture components
%	  dim_target = dimension of target space
%	  nparams = number of parameters
%	  mixcoeffs = mixing coefficients: softmax of the first NCENTRES
%	              network outputs, one row per pattern
%	  centres = means of Gaussians: stored as one row per pattern
%	  covars = variances of Gaussians: exp of the remaining network
%	           outputs, one row per pattern
%
%	[MIXPARAMS, Y, Z] = MDNFWD(NET, X) also generates a matrix Y of the
%	outputs of the MLP and a matrix Z of the hidden unit activations
%	where each row corresponds to one pattern.
%
%	[MIXPARAMS, Y, Z, A] = MDNFWD(NET, X) also returns a matrix A  giving
%	the summed inputs to each output unit, where each row  corresponds to
%	one pattern.
%
%	An error is raised if CONSIST reports that NET and X are inconsistent.
%
%	See also
%	MDN, MDN2GMM, MDNERR, MDNGRAD, MLPFWD
%

%	Copyright (c) Ian T Nabney (1996-2001)
%	David J Evans (1998)

% Check arguments for consistency
errstring = consist(net, 'mdn', x);
if ~isempty(errstring)
  error(errstring);
end

% Extract mlp and mixture model descriptors
mlpnet = net.mlp;
mixes = net.mdnmixes;

ncentres = mixes.ncentres;	% Number of components in mixture model
dim_target = mixes.dim_target;	% Dimension of targets
nparams = mixes.nparams;	% Number of parameters in mixture model

% Propagate forwards through MLP
[y, z, a] = mlpfwd(mlpnet, x);

% Column positions of each parameter group within the MLP output matrix y:
% mixing coefficients first, then centres, then variances.
ncentre_cols = ncentres*(1 + dim_target);	% Last column holding a centre
prior_cols  = 1:ncentres;
centre_cols = (ncentres + 1):ncentre_cols;
width_cols  = (ncentre_cols + 1):nparams;

% Convert output values into mixture model parameters

% Use softmax to calculate priors
% Prevent overflow and underflow: use same bounds as glmfwd
% Ensure that sum(exp(y), 2) does not overflow
maxcut = log(realmax) - log(ncentres);
% Ensure that exp(y) > 0
mincut = log(realmin);
temp = min(y(:, prior_cols), maxcut);
temp = max(temp, mincut);
temp = exp(temp);
% Normalise each row so that the mixing coefficients sum to one
mixpriors = temp./(sum(temp, 2)*ones(1, ncentres));

% Centres are just copies of network outputs
mixcentres = y(:, centre_cols);

% Variances are exp of network outputs
mixwidths = exp(y(:, width_cols));

% Return parameters: descriptor fields first, then the per-pattern matrices
mixparams.type = mixes.type;
mixparams.ncentres = mixes.ncentres;
mixparams.dim_target = mixes.dim_target;
mixparams.nparams = mixes.nparams;

mixparams.mixcoeffs = mixpriors;
mixparams.centres   = mixcentres;
mixparams.covars    = mixwidths;