annotate toolboxes/FullBNT-1.0.7/netlab3.3/mdninit.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
function net = mdninit(net, prior, t, options)
%MDNINIT Initialise the weights in a Mixture Density Network.
%
% Description
%
% NET = MDNINIT(NET, PRIOR) takes a Mixture Density Network NET and
% sets the weights and biases by sampling from a Gaussian distribution.
% It calls MLPINIT for the MLP component of NET.
%
% NET = MDNINIT(NET, PRIOR, T, OPTIONS) uses the target data T to
% initialise the biases for the output units after initialising the
% other weights as above. It calls GMMINIT, with T and OPTIONS as
% arguments, to obtain a model of the unconditional density of T. The
% biases are then set so that NET will output the values in the
% Gaussian mixture model.
%
% See also
% MDN, MLP, MLPINIT, GMMINIT
%

% Copyright (c) Ian T Nabney (1996-2001)
% David J Evans (1998)

% Initialise network weights from prior: this gives noise around values
% determined later
net.mlp = mlpinit(net.mlp, prior);

if nargin > 2
  % Initialise priors, centres and variances from target data
  temp_mix = gmm(net.mdnmixes.dim_target, net.mdnmixes.ncentres, 'spherical');
  temp_mix = gmminit(temp_mix, t, options);

  ncentres = net.mdnmixes.ncentres;
  dim_target = net.mdnmixes.dim_target;

  % Now set parameters in MLP to yield the right values.
  % This involves setting the biases correctly.

  % Priors
  net.mlp.b2(1:ncentres) = temp_mix.priors;

  % Centres are arranged in mlp such that we have
  % u11, u12, u13, ..., u1c, ..., uj1, uj2, uj3, ..., ujc, ..., uM1, uM2,
  % ..., uMc
  % where ujk is component k of centre j, M = ncentres and c = dim_target.
  % This is achieved by transposing temp_mix.centres before reshaping
  end_centres = ncentres*(dim_target+1);
  net.mlp.b2(ncentres+1:end_centres) = ...
    reshape(temp_mix.centres', 1, ncentres*dim_target);

  % Variances (stored as logs)
  net.mlp.b2((end_centres+1):net.mlp.nout) = ...
    log(temp_mix.covars);
end
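
The bias layout that the function writes into net.mlp.b2 can be checked in isolation. The following is a minimal standalone sketch, not part of Netlab: the values of priors, centres and covars are made-up stand-ins for what GMMINIT would return, and b2 and nout are local variables standing in for net.mlp.b2 and net.mlp.nout. It assumes a spherical mixture, so there are ncentres*(dim_target+2) output biases in total.

```matlab
% Standalone sketch of the output-bias layout; all numbers are illustrative.
ncentres   = 2;                  % number of mixture components (M)
dim_target = 3;                  % dimension of the target space (c)

% Hypothetical mixture parameters standing in for the GMMINIT result.
priors  = [0.3 0.7];             % 1 x ncentres
centres = [1 2 3;                % ncentres x dim_target, one row per centre
           4 5 6];
covars  = [0.5 2.0];             % 1 x ncentres, one variance per centre

% Assumed total for a spherical mixture: priors + centres + variances.
nout = ncentres*(dim_target + 2);
b2   = zeros(1, nout);

% Block 1: priors.
b2(1:ncentres) = priors;

% Block 2: centres. Transposing first makes reshape (which reads
% column by column) emit each centre's coordinates contiguously.
end_centres = ncentres*(dim_target + 1);
b2(ncentres+1:end_centres) = reshape(centres', 1, ncentres*dim_target);

% Block 3: log variances.
b2(end_centres+1:nout) = log(covars);

disp(b2);
```

Without the transpose, reshape(centres, 1, ncentres*dim_target) would interleave the centres coordinate by coordinate (1 4 2 5 3 6 in this example) instead of keeping each centre's components together.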