wolffd@0
|
function [A, C, Q, R, initx, initV, LL] = ...
    learn_kalman(data, A, C, Q, R, initx, initV, max_iter, diagQ, diagR, ARmode, constr_fun, varargin)
% LEARN_KALMAN Find the ML parameters of a stochastic Linear Dynamical System using EM.
%
% [A, C, Q, R, INITX, INITV, LL] = LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0) fits
% the parameters which are defined as follows
%   x(t+1) = A*x(t) + w(t),  w ~ N(0, Q),  x(0) ~ N(init_x, init_V)
%   y(t)   = C*x(t) + v(t),  v ~ N(0, R)
% A0 is the initial value, A is the final value, etc.
% DATA(:,t,l) is the observation vector at time t for sequence l. If the sequences are of
% different lengths, you can pass in a cell array, so DATA{l} is an O*T matrix.
% LL is the "learning curve": a vector of the log lik. values at each iteration.
% LL might go positive, since prob. densities can exceed 1, although this probably
% indicates that something has gone wrong e.g., a variance has collapsed to 0.
%
% There are several optional arguments, that should be passed in the following order.
% LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, ARmode)
% MAX_ITER specifies the maximum number of EM iterations (default 10).
% DIAGQ=1 specifies that the Q matrix should be diagonal. (Default 0).
% DIAGR=1 specifies that the R matrix should also be diagonal. (Default 0).
% ARMODE=1 specifies that C=I, R=0. i.e., a Gauss-Markov process. (Default 0).
% This problem has a global MLE. Hence the initial parameter values are not important.
% Passing [] for any of MAX_ITER, DIAGQ, DIAGR or ARMODE selects its default.
%
% LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, ARmode, F, P1, P2, ...)
% calls [A,C,Q,R,initx,initV] = f(A,C,Q,R,initx,initV,P1,P2,...) after every M step. f can be
% used to enforce any constraints on the params.
%
% For details, see
% - Ghahramani and Hinton, "Parameter Estimation for LDS", U. Toronto tech. report, 1996
% - Digalakis, Rohlicek and Ostendorf, "ML Estimation of a stochastic linear system with the EM
%      algorithm and its application to speech recognition",
%       IEEE Trans. Speech and Audio Proc., 1(4):431--442, 1993.

% Optional arguments: a missing or empty value falls back to the default.
if nargin < 8  || isempty(max_iter), max_iter = 10; end
if nargin < 9  || isempty(diagQ),    diagQ = 0;     end
if nargin < 10 || isempty(diagR),    diagR = 0;     end
if nargin < 11 || isempty(ARmode),   ARmode = 0;    end
if nargin < 12, constr_fun = []; end
verbose = 1;
thresh = 1e-4;

% Normalise the input to a cell array with one O*T matrix per sequence.
if ~iscell(data)
  data = num2cell(data, [1 2]); % each elt of the 3rd dim gets its own cell
end
N = numel(data);

ss = size(A, 1);
os = size(C, 1);

% alpha = sum over all sequences and time steps of y(:,t)*y(:,t)'.
% It depends only on the data, so it is computed once, outside the EM loop.
alpha = zeros(os, os);
Tsum = 0;
for ex = 1:N
  y = data{ex};
  % Time runs along the columns; length(y) would be wrong whenever a
  % sequence has fewer time steps than observation dimensions.
  Tsum = Tsum + size(y, 2);
  alpha = alpha + y*y';
end

previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

% Convert to inline function as needed.
if ~isempty(constr_fun)
  constr_fun = fcnchk(constr_fun, length(varargin));
end

while ~converged && (num_iter <= max_iter)

  %%% E step: accumulate the expected sufficient statistics over all sequences

  delta = zeros(os, ss);
  gamma = zeros(ss, ss);
  gamma1 = zeros(ss, ss);
  gamma2 = zeros(ss, ss);
  beta = zeros(ss, ss);
  P1sum = zeros(ss, ss);
  x1sum = zeros(ss, 1);
  loglik = 0;

  for ex = 1:N
    y = data{ex};
    [beta_t, gamma_t, delta_t, gamma1_t, gamma2_t, x1, V1, loglik_t] = ...
        Estep(y, A, C, Q, R, initx, initV, ARmode);
    beta = beta + beta_t;
    gamma = gamma + gamma_t;
    delta = delta + delta_t;
    gamma1 = gamma1 + gamma1_t;
    gamma2 = gamma2 + gamma2_t;
    P1sum = P1sum + V1 + x1*x1';
    x1sum = x1sum + x1;
    loglik = loglik + loglik_t;
  end
  LL = [LL loglik];
  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;

  %%% M step

  % Tsum  = total number of time steps      (N*T   for equal-length sequences)
  % Tsum1 = total number of transitions     (N*(T-1) for equal-length sequences)
  Tsum1 = Tsum - N;
  % Solve the linear systems with mrdivide instead of forming explicit inverses.
  A = beta / gamma1;
  Q = (gamma2 - A*beta') / Tsum1;
  if diagQ
    Q = diag(diag(Q));
  end
  if ~ARmode
    C = delta / gamma;
    R = (alpha - C*delta') / Tsum;
    if diagR
      R = diag(diag(R));
    end
  end
  initx = x1sum / N;
  initV = P1sum/N - initx*initx';

  % Optional user hook to enforce constraints on the freshly estimated parameters.
  if ~isempty(constr_fun)
    [A,C,Q,R,initx,initV] = feval(constr_fun, A, C, Q, R, initx, initV, varargin{:});
  end

  % em_converged is defined elsewhere; it is given the current and previous
  % log-likelihood together with thresh and returns the stopping flag.
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
end
|
wolffd@0
|
144
|
wolffd@0
|
145
|
wolffd@0
|
146
|
wolffd@0
|
147 %%%%%%%%%
|
wolffd@0
|
148
|
wolffd@0
|
function [beta, gamma, delta, gamma1, gamma2, x1, V1, loglik] = ...
    Estep(y, A, C, Q, R, initx, initV, ARmode)
% ESTEP Expected sufficient statistics for one observation sequence.
%
% y is an os-by-T matrix (one column per time step). With xs(:,t) the state
% estimate and Vs(:,:,t), VVs(:,:,t) the accompanying covariance terms:
%   delta  = sum_{t=1..T} y(:,t)*xs(:,t)'
%   gamma  = sum_{t=1..T} xs(:,t)*xs(:,t)' + Vs(:,:,t)
%   beta   = sum_{t=2..T} xs(:,t)*xs(:,t-1)' + VVs(:,:,t)
%   gamma1 = gamma without its t=T term
%   gamma2 = gamma without its t=1 term
%   x1, V1 = xs(:,1), Vs(:,:,1)
% When ARmode is set the observations themselves are used as the states, all
% covariance terms are zero and loglik is 0; otherwise xs, Vs, VVs and loglik
% are whatever kalman_smoother (defined elsewhere) returns.

[os, T] = size(y);
ss = length(A);

if ARmode
  % States are observed directly: no uncertainty about the hidden states.
  xs = y;
  Vs = zeros(ss, ss, T);
  VVs = zeros(ss, ss, T);
  loglik = 0;
else
  [xs, Vs, VVs, loglik] = kalman_smoother(y, A, C, Q, R, initx, initV);
end

% Same-time statistics, accumulated over every time step.
delta = zeros(os, ss);
gamma = zeros(ss, ss);
for t = 1:T
  xt = xs(:, t);
  delta = delta + y(:, t)*xt';
  gamma = gamma + xt*xt' + Vs(:, :, t);
end

% Cross-time statistic: pairs (t, t-1), so it starts at the second step.
beta = zeros(ss, ss);
for t = 2:T
  beta = beta + xs(:, t)*xs(:, t-1)' + VVs(:, :, t);
end

% Drop the last / first term of gamma to get the sums over 1..T-1 and 2..T.
xLast = xs(:, T);
xFirst = xs(:, 1);
gamma1 = gamma - xLast*xLast' - Vs(:, :, T);
gamma2 = gamma - xFirst*xFirst' - Vs(:, :, 1);

x1 = xFirst;
V1 = Vs(:, :, 1);
|
wolffd@0
|
180
|
wolffd@0
|
181
|
wolffd@0
|
182
|