function [A, C, Q, R, initx, initV, LL] = ...
    learn_kalman(data, A, C, Q, R, initx, initV, max_iter, diagQ, diagR, ARmode, constr_fun, varargin)
% LEARN_KALMAN Find the ML parameters of a stochastic Linear Dynamical System using EM.
%
% [A, C, Q, R, INITX, INITV, LL] = LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0) fits
% the parameters which are defined as follows
%   x(t+1) = A*x(t) + w(t),  w ~ N(0, Q),  x(0) ~ N(init_x, init_V)
%   y(t)   = C*x(t) + v(t),  v ~ N(0, R)
% A0 is the initial value, A is the final value, etc.
% DATA(:,t,l) is the observation vector at time t for sequence l. If the sequences are of
% different lengths, you can pass in a cell array, so DATA{l} is an O*T matrix
% (one column per time step).
% LL is the "learning curve": a vector of the log lik. values at each iteration.
% LL might go positive, since prob. densities can exceed 1, although this probably
% indicates that something has gone wrong e.g., a variance has collapsed to 0.
%
% There are several optional arguments, that should be passed in the following order.
% LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, ARMODE)
% MAX_ITER specifies the maximum number of EM iterations (default 10).
% DIAGQ=1 specifies that the Q matrix should be diagonal. (Default 0).
% DIAGR=1 specifies that the R matrix should also be diagonal. (Default 0).
% ARMODE=1 specifies that C=I, R=0. i.e., a Gauss-Markov process. (Default 0).
%   In this mode the observations are used directly as the hidden states (so the
%   state and observation dimensions must be equal), and C and R are returned unchanged.
% This problem has a global MLE. Hence the initial parameter values are not important.
%
% LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, ARMODE, F, P1, P2, ...)
% calls [A,C,Q,R,initx,initV] = f(A,C,Q,R,initx,initV,P1,P2,...) after every M step. f can be
% used to enforce any constraints on the params.
%
% For details, see
% - Ghahramani and Hinton, "Parameter Estimation for LDS", U. Toronto tech. report, 1996
% - Digalakis, Rohlicek and Ostendorf, "ML Estimation of a stochastic linear system with the EM
%      algorithm and its application to speech recognition",
%       IEEE Trans. Speech and Audio Proc., 1(4):431--442, 1993.

% learn_kalman(data, A, C, Q, R, initx, initV, max_iter, diagQ, diagR, ARmode, constr_fun, varargin)
if nargin < 8, max_iter = 10; end
if nargin < 9, diagQ = 0; end
if nargin < 10, diagR = 0; end
if nargin < 11, ARmode = 0; end
if nargin < 12, constr_fun = []; end
verbose = 1;
thresh = 1e-4;

% Normalize the input to a cell array holding one O*T matrix per sequence.
if ~iscell(data)
  data = num2cell(data, [1 2]); % each elt of the 3rd dim gets its own cell
end
N = numel(data);
ss = size(A, 1);
os = size(C, 1);

% alpha = sum over all sequences and time steps of y(:,t)*y(:,t)'.
% It does not depend on the parameters, so it is computed once, up front.
% Tsum is the total number of time steps over all sequences.
alpha = zeros(os, os);
Tsum = 0;
for ex = 1:N
  y = data{ex};
  % Count columns explicitly: length() returns the LARGEST dimension, which is
  % wrong whenever the observation dimension exceeds the sequence length.
  Tsum = Tsum + size(y, 2);
  alpha = alpha + y*y'; % equals the sum of y(:,t)*y(:,t)' over t
end

previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

% Convert to inline function as needed.
if ~isempty(constr_fun)
  constr_fun = fcnchk(constr_fun, length(varargin));
end

while ~converged && (num_iter <= max_iter)

  %%% E step: accumulate the sufficient statistics over all sequences.
  delta = zeros(os, ss);
  gamma = zeros(ss, ss);
  gamma1 = zeros(ss, ss);
  gamma2 = zeros(ss, ss);
  beta = zeros(ss, ss);
  P1sum = zeros(ss, ss);
  x1sum = zeros(ss, 1);
  loglik = 0;

  for ex = 1:N
    [beta_t, gamma_t, delta_t, gamma1_t, gamma2_t, x1, V1, loglik_t] = ...
        Estep(data{ex}, A, C, Q, R, initx, initV, ARmode);
    beta = beta + beta_t;
    gamma = gamma + gamma_t;
    delta = delta + delta_t;
    gamma1 = gamma1 + gamma1_t;
    gamma2 = gamma2 + gamma2_t;
    P1sum = P1sum + V1 + x1*x1';
    x1sum = x1sum + x1;
    loglik = loglik + loglik_t;
  end
  LL = [LL loglik];
  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;

  %%% M step
  % Tsum  = total number of time steps (N*T for equal-length sequences)
  % Tsum1 = total number of transitions (N*(T-1) for equal-length sequences)
  Tsum1 = Tsum - N;
  % Use mrdivide instead of forming an explicit inverse: identical in exact
  % arithmetic, but numerically better behaved.
  A = beta / gamma1;
  Q = (gamma2 - A*beta') / Tsum1;
  Q = (Q + Q') / 2; % Q is a covariance; remove round-off asymmetry
  if diagQ
    Q = diag(diag(Q));
  end
  if ~ARmode
    C = delta / gamma;
    R = (alpha - C*delta') / Tsum;
    R = (R + R') / 2; % R is a covariance; remove round-off asymmetry
    if diagR
      R = diag(diag(R));
    end
  end
  initx = x1sum / N;
  initV = P1sum/N - initx*initx';

  if ~isempty(constr_fun)
    [A,C,Q,R,initx,initV] = feval(constr_fun, A, C, Q, R, initx, initV, varargin{:});
  end

  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
end



%%%%%%%%%

function [beta, gamma, delta, gamma1, gamma2, x1, V1, loglik] = ...
    Estep(y, A, C, Q, R, initx, initV, ARmode)
%
% Compute the (expected) sufficient statistics for a single Kalman filter sequence.
%
% y is an O*T matrix with one column per time step. With xsmooth/Vsmooth/VVsmooth
% the smoothed means, covariances and cross-covariances (or, in ARmode, y itself
% with zero covariances), the outputs are
%   delta  = sum_{t=1..T} y(:,t)*xsmooth(:,t)'
%   gamma  = sum_{t=1..T} xsmooth(:,t)*xsmooth(:,t)' + Vsmooth(:,:,t)
%   beta   = sum_{t=2..T} xsmooth(:,t)*xsmooth(:,t-1)' + VVsmooth(:,:,t)
%   gamma1 = gamma without its t = T term (sum over t = 1..T-1)
%   gamma2 = gamma without its t = 1 term (sum over t = 2..T)
%   x1, V1 = smoothed mean and covariance of the first state
%   loglik = log-likelihood reported by kalman_smoother (0 in ARmode)

T = size(y, 2);
ss = length(A);

if ARmode
  xsmooth = y;
  Vsmooth = zeros(ss, ss, T); % no uncertainty about the hidden states
  VVsmooth = zeros(ss, ss, T);
  loglik = 0;
else
  [xsmooth, Vsmooth, VVsmooth, loglik] = kalman_smoother(y, A, C, Q, R, initx, initV);
end

% Vectorized forms of the per-time-step sums. For T == 1 the beta terms are
% products/sums of empty arrays and therefore evaluate to ss*ss zeros.
delta = y * xsmooth';
gamma = xsmooth * xsmooth' + sum(Vsmooth, 3);
beta = xsmooth(:,2:T) * xsmooth(:,1:T-1)' + sum(VVsmooth(:,:,2:T), 3);

gamma1 = gamma - xsmooth(:,T)*xsmooth(:,T)' - Vsmooth(:,:,T);
gamma2 = gamma - xsmooth(:,1)*xsmooth(:,1)' - Vsmooth(:,:,1);

x1 = xsmooth(:,1);
V1 = Vsmooth(:,:,1);