annotate toolboxes/FullBNT-1.0.7/Kalman/learn_kalman.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [A, C, Q, R, initx, initV, LL] = ...
    learn_kalman(data, A, C, Q, R, initx, initV, max_iter, diagQ, diagR, ARmode, constr_fun, varargin)
% LEARN_KALMAN Find the ML parameters of a stochastic Linear Dynamical System using EM.
%
% [A, C, Q, R, INITX, INITV, LL] = LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0) fits
% the parameters which are defined as follows
%   x(t+1) = A*x(t) + w(t),  w ~ N(0, Q),  x(0) ~ N(init_x, init_V)
%   y(t)   = C*x(t) + v(t),  v ~ N(0, R)
% A0 is the initial value, A is the final value, etc.
% DATA(:,t,l) is the observation vector at time t for sequence l. If the sequences are of
% different lengths, you can pass in a cell array, so DATA{l} is an O*T matrix
% (O = observation dimension, T = number of time steps of that sequence).
% LL is the "learning curve": a vector of the log lik. values at each iteration.
% LL(i) is the log lik. of the parameters in effect at the start of iteration i, i.e.
% before that iteration's M step; the returned parameters are those after the last M step.
% LL might go positive, since prob. densities can exceed 1, although this probably
% indicates that something has gone wrong e.g., a variance has collapsed to 0.
%
% There are several optional arguments, that should be passed in the following order.
% LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, ARMODE)
% MAX_ITER specifies the maximum number of EM iterations (default 10).
% DIAGQ=1 specifies that the Q matrix should be diagonal. (Default 0).
% DIAGR=1 specifies that the R matrix should also be diagonal. (Default 0).
% ARMODE=1 specifies that C=I, R=0. i.e., a Gauss-Markov process. (Default 0).
%   In this mode the observations are used directly as the states; C and R are not
%   re-estimated (they are returned as passed in) and the log lik. is not computed,
%   so LL contains zeros.
%   This problem has a global MLE. Hence the initial parameter values are not important.
%
% LEARN_KALMAN(DATA, A0, C0, Q0, R0, INITX0, INITV0, MAX_ITER, DIAGQ, DIAGR, ARMODE, F, P1, P2, ...)
% calls [A,C,Q,R,initx,initV] = f(A,C,Q,R,initx,initV,P1,P2,...) after every M step. f can be
% used to enforce any constraints on the params.
%
% EM stops after MAX_ITER iterations, or earlier when EM_CONVERGED reports convergence
% of the log lik. (threshold 1e-4).
%
% For details, see
% - Ghahramani and Hinton, "Parameter Estimation for LDS", U. Toronto tech. report, 1996
% - Digalakis, Rohlicek and Ostendorf, "ML Estimation of a stochastic linear system with the EM
%   algorithm and its application to speech recognition",
%   IEEE Trans. Speech and Audio Proc., 1(4):431--442, 1993.

if nargin < 8, max_iter = 10; end
if nargin < 9, diagQ = 0; end
if nargin < 10, diagR = 0; end
if nargin < 11, ARmode = 0; end
if nargin < 12, constr_fun = []; end
verbose = 1;
thresh = 1e-4;  % convergence threshold handed to em_converged

% Normalise the input to a cell array holding one O*T matrix per sequence.
if ~iscell(data)
  data = num2cell(data, [1 2]); % each elt of the 3rd dim gets its own cell
end
N = length(data);
ss = size(A, 1);
os = size(C, 1);

% alpha = sum over all sequences and time steps of y(t)*y(t)'. It does not
% depend on the parameters, so it is computed once, outside the EM loop.
% Tsum = total number of time steps over all sequences.
alpha = zeros(os, os);
Tsum = 0;
for ex = 1:N
  y = data{ex};
  % Use the column count: length(y) would return the observation dimension
  % whenever a sequence has fewer time steps than observation components.
  Tsum = Tsum + size(y, 2);
  alpha = alpha + y*y';
end

previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

% Convert to inline function as needed.
if ~isempty(constr_fun)
  constr_fun = fcnchk(constr_fun, length(varargin));
end

while ~converged && (num_iter <= max_iter)

  %%% E step: accumulate the expected sufficient statistics of all sequences.

  delta = zeros(os, ss);
  gamma = zeros(ss, ss);
  gamma1 = zeros(ss, ss);
  gamma2 = zeros(ss, ss);
  beta = zeros(ss, ss);
  P1sum = zeros(ss, ss);
  x1sum = zeros(ss, 1);
  loglik = 0;

  for ex = 1:N
    [beta_t, gamma_t, delta_t, gamma1_t, gamma2_t, x1, V1, loglik_t] = ...
        Estep(data{ex}, A, C, Q, R, initx, initV, ARmode);
    beta = beta + beta_t;
    gamma = gamma + gamma_t;
    delta = delta + delta_t;
    gamma1 = gamma1 + gamma1_t;
    gamma2 = gamma2 + gamma2_t;
    P1sum = P1sum + V1 + x1*x1';
    x1sum = x1sum + x1;
    loglik = loglik + loglik_t;
  end
  LL = [LL loglik]; %#ok<AGROW>
  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;

  %%% M step

  % Each sequence contributes one transition fewer than it has time steps,
  % so the transition statistics are normalised by Tsum1 = Tsum - N.
  Tsum1 = Tsum - N;
  % Right division solves X*gamma1 = beta without forming inv(gamma1),
  % which is cheaper and numerically more accurate.
  A = beta / gamma1;
  Q = (gamma2 - A*beta') / Tsum1;
  if diagQ
    Q = diag(diag(Q));
  end
  if ~ARmode
    C = delta / gamma;
    R = (alpha - C*delta') / Tsum;
    if diagR
      R = diag(diag(R));
    end
  end
  initx = x1sum / N;
  initV = P1sum/N - initx*initx';

  if ~isempty(constr_fun)
    [A,C,Q,R,initx,initV] = feval(constr_fun, A, C, Q, R, initx, initV, varargin{:});
  end

  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
end
wolffd@0 144
wolffd@0 145
wolffd@0 146
wolffd@0 147 %%%%%%%%%
wolffd@0 148
function [beta, gamma, delta, gamma1, gamma2, x1, V1, loglik] = ...
    Estep(y, A, C, Q, R, initx, initV, ARmode)
% ESTEP Expected sufficient statistics of a single observation sequence.
%
% Y holds one observation per column (OS rows, T columns). The state estimates
% XSMOOTH / VSMOOTH / VVSMOOTH come from KALMAN_SMOOTHER, except in AR mode,
% where the observations themselves are used as the states with zero
% covariance and LOGLIK is set to 0.
%
% With x(t) = xsmooth(:,t), V(t) = Vsmooth(:,:,t), VV(t) = VVsmooth(:,:,t):
%   delta  = sum_{t=1..T} y(:,t)*x(t)'
%   gamma  = sum_{t=1..T} x(t)*x(t)' + V(t)
%   beta   = sum_{t=2..T} x(t)*x(t-1)' + VV(t)
%   gamma1 = gamma without its t=T term
%   gamma2 = gamma without its t=1 term
%   x1, V1 = x(1) and V(1)

[os, T] = size(y);
ss = length(A);

if ~ARmode
  [xsmooth, Vsmooth, VVsmooth, loglik] = kalman_smoother(y, A, C, Q, R, initx, initV);
else
  % Gauss-Markov mode: the states are observed exactly.
  xsmooth = y;
  Vsmooth = zeros(ss, ss, T);
  VVsmooth = zeros(ss, ss, T);
  loglik = 0;
end

% Per-time-step terms, accumulated in time order.
delta = zeros(os, ss);
gamma = zeros(ss, ss);
for k = 1:T
  xk = xsmooth(:,k);
  delta = delta + y(:,k)*xk';
  gamma = gamma + xk*xk' + Vsmooth(:,:,k);
end

% Lag-one cross terms; there is none for the first time step.
beta = zeros(ss, ss);
for k = 2:T
  beta = beta + xsmooth(:,k)*xsmooth(:,k-1)' + VVsmooth(:,:,k);
end

x1 = xsmooth(:,1);
V1 = Vsmooth(:,:,1);
xlast = xsmooth(:,T);
gamma1 = gamma - xlast*xlast' - Vsmooth(:,:,T);
gamma2 = gamma - x1*x1' - V1;
wolffd@0 180
wolffd@0 181
wolffd@0 182