comparison toolboxes/FullBNT-1.0.7/Kalman/learn_kalman.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [A, C, Q, R, initx, initV, LL] = ...
    learn_kalman(data, A, C, Q, R, initx, initV, max_iter, diagQ, diagR, ARmode, constr_fun, varargin)
% LEARN_KALMAN Find the ML parameters of a stochastic Linear Dynamical System using EM.
%
% [A, C, Q, R, INITX, INITV, LL] = LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0) fits
% the parameters which are defined as follows
%   x(t+1) = A*x(t) + w(t),  w ~ N(0, Q),  x(0) ~ N(init_x, init_V)
%   y(t)   = C*x(t) + v(t),  v ~ N(0, R)
% A0 is the initial value, A is the final value, etc.
% DATA(:,t,l) is the observation vector at time t for sequence l. If the sequences are of
% different lengths, you can pass in a cell array, so DATA{l} is an O*T matrix.
% LL is the "learning curve": a vector of the log lik. values at each iteration.
% LL might go positive, since prob. densities can exceed 1, although this probably
% indicates that something has gone wrong e.g., a variance has collapsed to 0.
%
% There are several optional arguments, that should be passed in the following order.
% LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, ARmode)
% MAX_ITER specifies the maximum number of EM iterations (default 10).
% DIAGQ=1 specifies that the Q matrix should be diagonal. (Default 0).
% DIAGR=1 specifies that the R matrix should also be diagonal. (Default 0).
% ARMODE=1 specifies that C=I, R=0. i.e., a Gauss-Markov process. (Default 0).
% This problem has a global MLE. Hence the initial parameter values are not important.
%
% LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, F, P1, P2, ...)
% calls [A,C,Q,R,initx,initV] = f(A,C,Q,R,initx,initV,P1,P2,...) after every M step. f can be
% used to enforce any constraints on the params.
%
% For details, see
% - Ghahramani and Hinton, "Parameter Estimation for LDS", U. Toronto tech. report, 1996
% - Digalakis, Rohlicek and Ostendorf, "ML Estimation of a stochastic linear system with the EM
%      algorithm and its application to speech recognition",
%       IEEE Trans. Speech and Audio Proc., 1(4):431--442, 1993.

if nargin < 8, max_iter = 10; end
if nargin < 9, diagQ = 0; end
if nargin < 10, diagR = 0; end
if nargin < 11, ARmode = 0; end
if nargin < 12, constr_fun = []; end
verbose = 1;
thresh = 1e-4;

% Normalize DATA to a cell array with one O*T matrix per sequence.
if ~iscell(data)
  N = size(data, 3);
  data = num2cell(data, [1 2]); % each elt of the 3rd dim gets its own cell
else
  N = length(data);
end

ss = size(A, 1); % state size
os = size(C, 1); % observation size

% alpha = sum_{l,t} y(:,t,l)*y(:,t,l)' does not depend on the parameters,
% so accumulate it once before the EM loop. Tsum = total number of time steps.
alpha = zeros(os, os);
Tsum = 0;
for ex = 1:N
  y = data{ex};
  % Use size(y,2), not length(y): length returns max(size(y)), which is the
  % observation dimension (not T) whenever a sequence is shorter than os.
  T = size(y, 2);
  Tsum = Tsum + T;
  alpha_temp = zeros(os, os);
  for t = 1:T
    alpha_temp = alpha_temp + y(:,t)*y(:,t)';
  end
  alpha = alpha + alpha_temp;
end

previous_loglik = -inf;
loglik = 0;
converged = 0;
num_iter = 1;
LL = [];

% Convert to inline function as needed.
if ~isempty(constr_fun)
  constr_fun = fcnchk(constr_fun, length(varargin));
end

while ~converged && (num_iter <= max_iter)

  %%% E step

  delta = zeros(os, ss);
  gamma = zeros(ss, ss);
  gamma1 = zeros(ss, ss);
  gamma2 = zeros(ss, ss);
  beta = zeros(ss, ss);
  P1sum = zeros(ss, ss);
  x1sum = zeros(ss, 1);
  loglik = 0;

  % Accumulate expected sufficient statistics over all sequences.
  for ex = 1:N
    y = data{ex};
    [beta_t, gamma_t, delta_t, gamma1_t, gamma2_t, x1, V1, loglik_t] = ...
        Estep(y, A, C, Q, R, initx, initV, ARmode);
    beta = beta + beta_t;
    gamma = gamma + gamma_t;
    delta = delta + delta_t;
    gamma1 = gamma1 + gamma1_t;
    gamma2 = gamma2 + gamma2_t;
    P1sum = P1sum + V1 + x1*x1';
    x1sum = x1sum + x1;
    %fprintf(1, 'example %d, ll/T %5.3f\n', ex, loglik_t/size(y,2));
    loglik = loglik + loglik_t;
  end
  LL = [LL loglik];
  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  %fprintf(1, 'iteration %d, loglik/NT = %f\n', num_iter, loglik/Tsum);
  num_iter = num_iter + 1;

  %%% M step

  % Tsum  = sum_l T(l)     (transitions + initial states)
  % Tsum1 = sum_l (T(l)-1) (number of transitions)
  Tsum1 = Tsum - N;
  % Solve A*gamma1 = beta with mrdivide instead of forming inv(gamma1)
  % explicitly; mathematically A = beta*inv(gamma1) but better conditioned.
  A = beta / gamma1;
  Q = (gamma2 - A*beta') / Tsum1;
  if diagQ
    Q = diag(diag(Q));
  end
  if ~ARmode
    C = delta / gamma; % C = delta * inv(gamma)
    R = (alpha - C*delta') / Tsum;
    if diagR
      R = diag(diag(R));
    end
  end
  initx = x1sum / N;
  initV = P1sum/N - initx*initx';

  % Optional user callback to project the params back onto a constraint set.
  if ~isempty(constr_fun)
    [A,C,Q,R,initx,initV] = feval(constr_fun, A, C, Q, R, initx, initV, varargin{:});
  end

  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
end
144
145
146
147 %%%%%%%%%
148
function [beta, gamma, delta, gamma1, gamma2, x1, V1, loglik] = ...
    Estep(y, A, C, Q, R, initx, initV, ARmode)
% ESTEP Compute the (expected) sufficient statistics for one Kalman filter sequence.
% Returns the smoothed second-moment sums
%   beta   = sum_{t=2}^T  E[x(t) x(t-1)']
%   gamma  = sum_{t=1}^T  E[x(t) x(t)']
%   delta  = sum_{t=1}^T  y(t) E[x(t)]'
%   gamma1 = gamma minus the t=T term,  gamma2 = gamma minus the t=1 term
% plus the smoothed initial state mean x1, its covariance V1, and the
% sequence log-likelihood.

[os T] = size(y);
ss = length(A);

if ARmode
  % Gauss-Markov case: the states ARE the observations, so there is
  % no uncertainty about them and no likelihood contribution here.
  xsmooth = y;
  Vsmooth = zeros(ss, ss, T);
  VVsmooth = zeros(ss, ss, T);
  loglik = 0;
else
  [xsmooth, Vsmooth, VVsmooth, loglik] = kalman_smoother(y, A, C, Q, R, initx, initV);
end

delta = zeros(os, ss);
gamma = zeros(ss, ss);
beta = zeros(ss, ss);
for tt = 1:T
  xt = xsmooth(:,tt);
  delta = delta + y(:,tt) * xt';
  gamma = gamma + xt * xt' + Vsmooth(:,:,tt);
  if tt > 1
    beta = beta + xt * xsmooth(:,tt-1)' + VVsmooth(:,:,tt);
  end
end
% Drop the boundary terms to form the "all but last" / "all but first" sums.
gamma1 = gamma - xsmooth(:,T) * xsmooth(:,T)' - Vsmooth(:,:,T);
gamma2 = gamma - xsmooth(:,1) * xsmooth(:,1)' - Vsmooth(:,:,1);

x1 = xsmooth(:,1);
V1 = Vsmooth(:,:,1);
180
181
182