Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlab3.3/mdngrad.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function g = mdngrad(net, x, t)
%MDNGRAD Evaluate gradient of error function for Mixture Density Network.
%
%	Description
%	 G = MDNGRAD(NET, X, T) takes a mixture density network data
%	structure NET, a matrix X of input vectors and a matrix T of target
%	vectors, and evaluates the gradient G of the error function with
%	respect to the network weights. The error function is negative log
%	likelihood of the target data. Each row of X corresponds to one
%	input vector and each row of T corresponds to one target vector.
%
%	See also
%	MDN, MDNFWD, MDNERR, MDNPROB, MLPBKP
%

%	Copyright (c) Ian T Nabney (1996-2001)
%	David J Evans (1998)

% Check arguments for consistency
errstring = consist(net, 'mdn', x, t);
if ~isempty(errstring)
  error(errstring);
end

% Forward pass: mixture parameters plus the MLP hidden-unit activations z
% that MLPBKP needs for back-propagation.
[mixparams, y, z] = mdnfwd(net, x);

% Sizes of the mixture model and of the data set
ncentres = net.mdnmixes.ncentres;
dim_target = net.mdnmixes.dim_target;
nmixparams = net.mdnmixes.nparams;
ntarget = size(t, 1);

% deltas holds dE/d(MLP output) for every pattern (one row per target).
% Its columns are filled below in this order:
%   1 : ncentres                              -> mixing coefficients
%   ncentres+1 : ncentres*(1+dim_target)      -> centres
%   ncentres*(1+dim_target)+1 : nmixparams    -> variances
deltas = zeros(ntarget, net.mlp.nout);

% Posterior probability of each centre for each target, used below as an
% (ntarget, ncentres) matrix.
post = mdnpost(mixparams, t);

% Mixing-coefficient derivatives: prior minus posterior.
% NOTE(review): this form assumes MDNFWD produces the mixing coefficients
% with a softmax - TODO confirm against mdnfwd.
deltas(:, 1:ncentres) = mixparams.mixcoeffs - post;

% Centre derivatives: post_j * (mu_jk - t_k) / sigma_j^2.
% Centre j occupies dim_target consecutive columns of mixparams.centres,
% so t is replicated once per centre, while each column of post and of
% the per-centre variances is replicated once per target dimension.
% (kron(A, ones(1, d)) repeats every column of A d times side by side.)
centre_err = mixparams.centres - repmat(t, 1, ncentres);
long_post = kron(post, ones(1, dim_target));
long_covars = kron(mixparams.covars, ones(1, dim_target));
deltas(:, (ncentres+1):(ncentres*(1+dim_target))) = ...
  (centre_err.*long_post)./long_covars;

% Variance derivatives: -post_j * (||t - mu_j||^2 / sigma_j^2 - dim_target) / 2.
% NOTE(review): this form assumes MDNFWD maps the variance outputs through
% exp() - TODO confirm against mdnfwd.  The scalar dim_target broadcasts
% over the (ntarget, ncentres) matrix.
dist2 = mdndist2(mixparams, t);
deltas(:, (ncentres*(1+dim_target)+1):nmixparams) = ...
  post.*((dist2./mixparams.covars) - dim_target)./(-2);

% Now back-propagate deltas through MLP
g = mlpbkp(net.mlp, x, z, deltas);