Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlab3.3/gtmem.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [net, options, errlog] = gtmem(net, t, options)
%GTMEM  EM algorithm for Generative Topographic Mapping.
%
%  Description
%  [NET, OPTIONS, ERRLOG] = GTMEM(NET, T, OPTIONS) uses the Expectation
%  Maximization algorithm to estimate the parameters of a GTM defined by
%  a data structure NET. The matrix T represents the data whose
%  expectation is maximized, with each row corresponding to a vector.
%  It is assumed that the latent data NET.X has been set following a
%  call to GTMINIT, for example. The optional parameters have the
%  following interpretations.
%
%  OPTIONS(1) is set to 1 to display error values. If OPTIONS(1) is set
%  to 0, then only warning messages are displayed. If OPTIONS(1) is -1,
%  then nothing is displayed.
%
%  OPTIONS(3) is a measure of the absolute precision required of the
%  error function at the solution. If the change in log likelihood
%  between two steps of the EM algorithm is less than this value, then
%  the function terminates.
%
%  OPTIONS(14) is the maximum number of iterations; default 100.
%
%  The optional return value OPTIONS contains the final error value
%  (i.e. the negative log likelihood of the data) in OPTIONS(8).
%
%  The optional return value ERRLOG, filled whenever it is requested as
%  an output, holds the error value of each cycle. The error for cycle N
%  is evaluated before that cycle's M-step. If training stops early
%  because of the OPTIONS(3) test, the remaining entries are left zero.
%
%  See also
%  GTM, GTMINIT
%

%  Copyright (c) Ian T Nabney (1996-2001)

% Check that inputs are consistent
errstring = consist(net, 'gtm', t);
if ~isempty(errstring)
  error(errstring);
end

% Sort out the options
if (options(14))
  niters = options(14);
else
  niters = 100;
end

% Display level: 1 = print error every cycle, 0 = warnings only,
% -1 = silent.  (Not called DISPLAY to avoid shadowing the built-in.)
disp_level = options(1);
store = 0;
if (nargout > 2)
  store = 1;    % Store the error values to return them
  errlog = zeros(1, niters);
end
test = 0;
if options(3) > 0.0
  test = 1;     % Test log likelihood for termination
end

% Calculate various quantities that remain constant during training
[ndata, tdim] = size(t);
ND = ndata*tdim;
[net.gmmnet.centres, Phi] = rbffwd(net.rbfnet, net.X);
Phi = [Phi ones(size(net.X, 1), 1)];    % append a bias column
PhiT = Phi';
[K, Mplus1] = size(Phi);

% Use a sparse representation for the weight regularizing matrix.
% The last (bias) row of the weights is not regularized.
if (net.rbfnet.alpha > 0)
  Alpha = net.rbfnet.alpha*speye(Mplus1);
  Alpha(Mplus1, Mplus1) = 0;
end

for n = 1:niters
  % Calculate responsibilities
  [R, act] = gtmpost(net, t);
  % Calculate error value only if it will be displayed, stored or tested
  if ((disp_level > 0) | store | test)
    prob = act*(net.gmmnet.priors)';
    % Error value is negative log likelihood of data
    e = - sum(log(max(prob,eps)));
    if store
      errlog(n) = e;
    end
    if disp_level > 0
      fprintf(1, 'Cycle %4d Error %11.6f\n', n, e);
    end
    if test
      if (n > 1 & abs(e - eold) < options(3))
        options(8) = e;
        return;
      else
        eold = e;
      end
    end
  end

  % Calculate matrix to be inverted (Phi'*G*Phi + alpha*I in the papers).
  % G is diagonal, holding the summed responsibilities of each latent
  % point; a sparse representation normally executes faster and saves
  % memory.
  G = spdiags(sum(R)', 0, K, K);
  if (net.rbfnet.alpha > 0)
    A = full(PhiT*G*Phi + (Alpha.*net.gmmnet.covars(1)));
  else
    A = full(PhiT*G*Phi);
  end
  % A is a symmetric matrix likely to be positive definite, so try
  % fast Cholesky decomposition to calculate W, otherwise use pinv
  % (SVD based). (PhiT*(R'*t)) is computed right-to-left, as R and t
  % are normally (much) larger than PhiT.
  rhs = PhiT*(R'*t);
  [cholDcmp, singular] = chol(A);
  if (singular)
    % Warnings are shown unless the caller asked for silence (-1).
    if (disp_level >= 0)
      fprintf(1, ...
        'gtmem: Warning -- M-Step matrix singular, using pinv.\n');
    end
    W = pinv(A)*rhs;
  else
    W = cholDcmp \ (cholDcmp' \ rhs);
  end
  % Put new weights into network to calculate responsibilities
  % net.rbfnet = netunpak(net.rbfnet, W);
  net.rbfnet.w2 = W(1:net.rbfnet.nhidden, :);
  net.rbfnet.b2 = W(net.rbfnet.nhidden+1, :);
  % The mapped latent points Phi*W are the new mixture centres; store
  % them so the returned NET.GMMNET stays consistent with NET.RBFNET.
  Y = Phi*W;
  net.gmmnet.centres = Y;
  % Calculate new distances
  d = dist2(t, Y);

  % Calculate new value for beta (stored as the common variance 1/beta)
  net.gmmnet.covars = ones(1, net.gmmnet.ncentres)*(sum(sum(d.*R))/ND);
end

% Guard against log(0) exactly as the per-cycle error does.
options(8) = -sum(log(max(gtmprob(net, t), eps)));
if (disp_level >= 0)
  disp(maxitmess);
end