Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/Kalman/learn_kalman.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [A, C, Q, R, initx, initV, LL] = ...
    learn_kalman(data, A, C, Q, R, initx, initV, max_iter, diagQ, diagR, ARmode, constr_fun, varargin)
% LEARN_KALMAN Find the ML parameters of a stochastic Linear Dynamical System using EM.
%
% [A, C, Q, R, INITX, INITV, LL] = LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0) fits
% the parameters which are defined as follows
%   x(t+1) = A*x(t) + w(t),  w ~ N(0, Q),  x(0) ~ N(init_x, init_V)
%   y(t)   = C*x(t) + v(t),  v ~ N(0, R)
% A0 is the initial value, A is the final value, etc.
% DATA(:,t,l) is the observation vector at time t for sequence l. If the sequences are of
% different lengths, you can pass in a cell array, so DATA{l} is an O*T matrix.
% LL is the "learning curve": a vector of the log lik. values at each iteration.
% LL might go positive, since prob. densities can exceed 1, although this probably
% indicates that something has gone wrong e.g., a variance has collapsed to 0.
%
% There are several optional arguments, that should be passed in the following order.
% LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, ARmode)
% MAX_ITER specifies the maximum number of EM iterations (default 10).
% DIAGQ=1 specifies that the Q matrix should be diagonal. (Default 0).
% DIAGR=1 specifies that the R matrix should also be diagonal. (Default 0).
% ARMODE=1 specifies that C=I, R=0. i.e., a Gauss-Markov process. (Default 0).
% This problem has a global MLE. Hence the initial parameter values are not important.
%
% LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, F, P1, P2, ...)
% calls [A,C,Q,R,initx,initV] = f(A,C,Q,R,initx,initV,P1,P2,...) after every M step. f can be
% used to enforce any constraints on the params.
%
% For details, see
% - Ghahramani and Hinton, "Parameter Estimation for LDS", U. Toronto tech. report, 1996
% - Digalakis, Rohlicek and Ostendorf, "ML Estimation of a stochastic linear system with the EM
%      algorithm and its application to speech recognition",
%      IEEE Trans. Speech and Audio Proc., 1(4):431--442, 1993.

% Defaults for the optional arguments.
if nargin < 8,  max_iter = 10; end
if nargin < 9,  diagQ = 0; end
if nargin < 10, ARmode = 0; diagR = 0; end %#ok<NASGU> % diagR default
if nargin < 10, diagR = 0; end
if nargin < 11, ARmode = 0; end
if nargin < 12, constr_fun = []; end
verbose = 1;
thresh = 1e-4;  % relative log-likelihood change threshold passed to em_converged

% Normalize DATA to a cell array with one O*T matrix per sequence.
if ~iscell(data)
  N = size(data, 3);
  data = num2cell(data, [1 2]); % each elt of the 3rd dim gets its own cell
else
  N = length(data);
end

ss = size(A, 1);  % state-space size
os = size(C, 1);  % observation size

% alpha = sum over all sequences and time steps of y(t)*y(t)'.
% It does not depend on the parameters, so compute it once before the EM loop.
% NOTE: sequence length is size(y,2); the old code used length(y), which is
% max(size(y)) and therefore wrong whenever os > T.
alpha = zeros(os, os);
Tsum = 0;
for ex = 1:N
  y = data{ex};
  T = size(y, 2);
  Tsum = Tsum + T;
  alpha = alpha + y * y';  % identical to summing y(:,t)*y(:,t)' over t
end

previous_loglik = -inf;
loglik = 0;
converged = 0;
num_iter = 1;
LL = [];

% Convert to inline function as needed.
if ~isempty(constr_fun)
  constr_fun = fcnchk(constr_fun, length(varargin));
end

while ~converged && (num_iter <= max_iter)

  %%% E step: accumulate expected sufficient statistics over all sequences.

  delta  = zeros(os, ss);
  gamma  = zeros(ss, ss);
  gamma1 = zeros(ss, ss);
  gamma2 = zeros(ss, ss);
  beta   = zeros(ss, ss);
  P1sum  = zeros(ss, ss);
  x1sum  = zeros(ss, 1);
  loglik = 0;

  for ex = 1:N
    y = data{ex};
    [beta_t, gamma_t, delta_t, gamma1_t, gamma2_t, x1, V1, loglik_t] = ...
        Estep(y, A, C, Q, R, initx, initV, ARmode);
    beta   = beta + beta_t;
    gamma  = gamma + gamma_t;
    delta  = delta + delta_t;
    gamma1 = gamma1 + gamma1_t;
    gamma2 = gamma2 + gamma2_t;
    P1sum  = P1sum + V1 + x1*x1';
    x1sum  = x1sum + x1;
    loglik = loglik + loglik_t;
  end
  LL = [LL loglik];
  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;

  %%% M step: closed-form parameter updates given the sufficient statistics.

  % Tsum  = total number of time steps over all sequences
  % Tsum1 = total number of transitions = Tsum - N (one fewer per sequence)
  Tsum1 = Tsum - N;
  % Solve A*gamma1 = beta via mrdivide instead of forming inv(gamma1)
  % explicitly: same result, better conditioned.
  A = beta / gamma1;
  Q = (gamma2 - A*beta') / Tsum1;
  if diagQ
    Q = diag(diag(Q));
  end
  if ~ARmode
    C = delta / gamma;  % was delta * inv(gamma)
    R = (alpha - C*delta') / Tsum;
    if diagR
      R = diag(diag(R));
    end
  end
  initx = x1sum / N;
  initV = P1sum/N - initx*initx';

  % Optional user hook to enforce constraints on the parameters.
  if ~isempty(constr_fun)
    [A,C,Q,R,initx,initV] = feval(constr_fun, A, C, Q, R, initx, initV, varargin{:});
  end

  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
end
144 | |
145 | |
146 | |
147 %%%%%%%%% | |
148 | |
function [beta, gamma, delta, gamma1, gamma2, x1, V1, loglik] = ...
    Estep(y, A, C, Q, R, initx, initV, ARmode)
% ESTEP Expected sufficient statistics for a single Kalman-filter sequence.
%
% Returns the accumulated second-moment sums over the smoothed posterior:
%   delta  = sum_t y(t) E[x(t)]'
%   gamma  = sum_{t=1..T}   E[x(t) x(t)']
%   beta   = sum_{t=2..T}   E[x(t) x(t-1)']
%   gamma1 = gamma restricted to t=1..T-1
%   gamma2 = gamma restricted to t=2..T
% plus the first smoothed mean x1, its covariance V1, and the sequence
% log-likelihood.

[os, T] = size(y);
ss = length(A);

if ARmode
  % Fully observed state (C=I, R=0): the smoothed means are the data and
  % there is no posterior uncertainty, so all covariances are zero.
  xs  = y;
  Ps  = zeros(ss, ss, T);
  PPs = zeros(ss, ss, T);
  loglik = 0;
else
  [xs, Ps, PPs, loglik] = kalman_smoother(y, A, C, Q, R, initx, initV);
end

% Accumulate the sufficient statistics one time step at a time.
delta = zeros(os, ss);
gamma = zeros(ss, ss);
beta  = zeros(ss, ss);
for k = 1:T
  xk = xs(:,k);
  delta = delta + y(:,k) * xk';
  gamma = gamma + xk * xk' + Ps(:,:,k);
  if k > 1
    beta = beta + xk * xs(:,k-1)' + PPs(:,:,k);
  end
end
gamma1 = gamma - xs(:,T) * xs(:,T)' - Ps(:,:,T);  % drop the t=T term
gamma2 = gamma - xs(:,1) * xs(:,1)' - Ps(:,:,1);  % drop the t=1 term

x1 = xs(:,1);
V1 = Ps(:,:,1);
180 | |
181 | |
182 |