annotate toolboxes/FullBNT-1.0.7/KPMstats/logistK.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [beta,post,lli] = logistK(x,y,w,beta)
% [beta,post,lli] = logistK(x,y,w,beta)
%
% k-class logistic regression with optional sample weights
%
% k = number of classes
% n = number of samples
% d = dimensionality of samples
%
% INPUT
%   x       dxn matrix of n input column vectors
%   y       kxn matrix of class assignments; if any column does not
%           sum to 1, every column is rescaled to unit sum
%   [w]     1xn vector of sample weights (omitted or [] : all ones);
%           rescaled internally so that max(w) == 1
%   [beta]  dxk matrix of initial model coefficients (omitted or [] :
%           1e-3*rand(d,k) with the last column zeroed); beta(:,k)
%           must be all zeros
%
% OUTPUT
%   beta    dxk matrix of fitted model coefficients
%           (beta(:,k) are fixed at 0)
%   post    kxn matrix of fitted class posteriors
%   lli     weighted log likelihood of the fit (computed with the
%           rescaled weights)
%
% Let p(i,j) = exp(beta(:,j)'*x(:,i)),
% Class j posterior for observation i is:
%   post(j,i) = p(i,j) / (p(i,1) + ... p(i,k))
%
% The fit runs at most 100 Newton-Raphson iterations with step halving.
%
% See also logistK_eval.
%
% David Martin <dmartin@eecs.berkeley.edu>
% May 3, 2002

% Copyright (C) 2002 David R. Martin <dmartin@eecs.berkeley.edu>
%
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License as
% published by the Free Software Foundation; either version 2 of the
% License, or (at your option) any later version.
%
% This program is distributed in the hope that it will be useful, but
% WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
% General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
% 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.

% TODO - this code would be faster if x were transposed

error(nargchk(2,4,nargin));

debug = 0;
if debug>0,
  h=figure(1);
  set(h,'DoubleBuffer','on');
end

% get sizes
[d,nx] = size(x);
[k,ny] = size(y);

% check sizes
if k < 2,
  error('Input y must encode at least 2 classes.');
end
if nx ~= ny,
  error('Inputs x,y not the same length.');
end

n = nx;

% make sure class assignments have unit L1-norm.
% sumy is 1xn, so the test needs any(): an "if" on a vector is true only
% when ALL of its elements are nonzero, which would skip the rescaling
% whenever even a single column already summed to 1.
sumy = sum(y,1);
if any(abs(1-sumy) > eps),
  for i = 1:k, y(i,:) = y(i,:) ./ sumy; end
end
clear sumy;

% if sample weights weren't specified (or were passed as []), set them to 1
if nargin < 3 || isempty(w),
  w = ones(1,n);
end
if numel(w) ~= n,
  error('Input w must contain one weight per sample.');
end
% force a row vector so w combines elementwise with the 1xn rows of y
w = reshape(w,1,n);
if max(w) <= 0,
  error('Input w must contain at least one positive weight.');
end

% normalize sample weights so max is 1
w = w / max(w);

% if starting beta wasn't specified (or was passed as []), initialize randomly
if nargin < 4 || isempty(beta),
  beta = 1e-3*rand(d,k);
  beta(:,k) = 0; % fix beta for class k at zero
else
  if size(beta,1) ~= d || size(beta,2) ~= k,
    error('Input beta must be a dxk matrix.');
  end
  % test every entry: checking sum(beta(:,k)) would accept e.g. [1;-1]
  if any(beta(:,k) ~= 0),
    error('beta(:,k) ~= 0');
  end
end

% Newton step scale. NOTE: it is not reset to 1 between iterations, so
% once halved it stays reduced for all later iterations.
stepsize = 1;
minstepsize = 1e-2;

post = computePost(beta,x);
lli = computeLogLik(post,y,w);

for iter = 1:100,
  %disp(sprintf('  logist iter=%d lli=%g',iter,lli));
  vis(x,y,beta,lli,d,k,iter,debug);

  % gradient and hessian
  [g,h] = derivs(post,x,y,w);

  % make sure Hessian is well conditioned
  if rcond(h) < eps,
    % condition with Levenberg-Marquardt method: scale the diagonal by
    % (1 + 10^i) for increasing i until rcond exceeds eps
    for i = -16:16,
      h2 = h .* ((1 + 10^i)*eye(size(h)) + (1-eye(size(h))));
      if rcond(h2) > eps, break, end
    end
    if rcond(h2) < eps,
      warning(['Stopped at iteration ' num2str(iter) ...
               ' because Hessian can''t be conditioned']);
      break
    end
    h = h2;
  end

  % save lli before update
  lli_prev = lli;

  % full Newton direction; it does not depend on stepsize, so solve once
  newton = h \ g;

  % Newton-Raphson with step-size halving
  while stepsize >= minstepsize,
    % Newton-Raphson update step on the free columns 1..k-1
    step = stepsize * newton;
    beta2 = beta;
    beta2(:,1:k-1) = beta2(:,1:k-1) - reshape(step,d,k-1);

    % get the new log likelihood
    post2 = computePost(beta2,x);
    lli2 = computeLogLik(post2,y,w);

    % if the log likelihood increased, accept the step and stop
    if lli2 > lli,
      post = post2; lli = lli2; beta = beta2;
      break
    end

    % otherwise, reduce step size by half
    stepsize = 0.5 * stepsize;
  end

  % stop if the average log likelihood lli/n is close enough to 0
  if 1-exp(lli/n) < 1e-2, break, end

  % stop if the log likelihood changed by a small enough fraction
  dlli = (lli_prev-lli) / lli;
  if abs(dlli) < 1e-3, break, end

  % stop if the step size has gotten too small
  if stepsize < minstepsize, break, end

  % stop if the log likelihood has decreased; this shouldn't happen
  if lli < lli_prev,
    warning(['Stopped at iteration ' num2str(iter) ...
             ' because the log likelihood decreased from ' ...
             num2str(lli_prev) ' to ' num2str(lli) '.' ...
             ' This may be a bug.']);
    break
  end
end

if debug>0,
  vis(x,y,beta,lli,d,k,iter,2);
end
wolffd@0 173
wolffd@0 174 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% class posteriors
function post = computePost(beta,x)
% post = computePost(beta,x)
% Softmax class posteriors for every sample.
%   beta  dxk coefficient matrix
%   x     dxn matrix of input column vectors
%   post  kxn, post(j,i) = exp(s(j,i)) / sum_c exp(s(c,i)),
%         where s(c,i) = beta(:,c)'*x(:,i)
% Each entry is evaluated as 1 / sum_c exp(s(c,i) - s(j,i)). The c == j
% term is exp(0) = 1, so the denominator is never below 1; a huge score
% gap gives an Inf denominator and hence a posterior of exactly 0.
numClasses = size(beta,2);
numSamples = size(x,2);

% linear scores, one row per class
scores = zeros(numClasses,numSamples);
for c = 1:numClasses,
  scores(c,:) = beta(:,c)'*x;
end

post = zeros(numClasses,numSamples);
for j = 1:numClasses,
  % scores(j*ones(1,numClasses),:) stacks copies of row j (k x n)
  gap = scores - scores(j*ones(1,numClasses),:);
  post(j,:) = 1 ./ sum(exp(gap),1);
end
wolffd@0 187
wolffd@0 188 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% log likelihood
function lli = computeLogLik(post,y,w)
% lli = computeLogLik(post,y,w)
% Weighted log likelihood
%   lli = sum_j sum_i w(i) * y(j,i) * log(post(j,i) + eps)
% eps keeps log() finite when a posterior is exactly 0.
% Raises an error if the result is NaN.
logPost = log(post + eps);
lli = 0;
for cls = 1:size(post,1),
  lli = lli + sum(w.*y(cls,:).*logPost(cls,:));
end
if isnan(lli), error('lli is nan'); end
wolffd@0 199
wolffd@0 200 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% gradient and hessian
%% These are computed in what seems a verbose manner, but it is
%% done this way to use minimal memory. x should be transposed
%% to make it faster.
function [g,h] = derivs(post,x,y,w)
% [g,h] = derivs(post,x,y,w)
% Gradient g and Hessian h of the weighted log likelihood with respect to
% the free coefficients beta(:,1:k-1). The coefficients are stacked column
% by column (d entries per class), so g is d*(k-1) x 1 and h is square of
% that size.
%   gradient block c:      sum_s w(s)*(y(c,s)-post(c,s))*x(:,s)
%   hessian block (c,c):  -sum_s w(s)*post(c,s)*(1-post(c,s))*x(:,s)*x(:,s)'
%   hessian block (c,c2): +sum_s w(s)*post(c2,s)*post(c,s)*x(:,s)*x(:,s)'
numFree = size(post,1) - 1;
dim = size(x,1);

% first derivative of likelihood w.r.t. beta
g = zeros(dim,numFree);
for c = 1:numFree,
  resid = w .* (y(c,:) - post(c,:));
  for a = 1:dim,
    g(a,c) = x(a,:) * resid';
  end
end
g = g(:);

% hessian of likelihood w.r.t. beta (blocks do not overlap, so the
% diagonal and off-diagonal blocks can be filled in one pass)
h = zeros(dim*numFree);
for c = 1:numFree,
  rows = (c-1)*dim+1 : c*dim;
  h(rows,rows) = -weightedGram(x, w .* post(c,:) .* (1 - post(c,:)));
  for c2 = c+1:numFree,
    cols = (c2-1)*dim+1 : c2*dim;
    blk = weightedGram(x, w .* post(c2,:) .* post(c,:));
    h(rows,cols) = blk;
    h(cols,rows) = blk;
  end
end

%% weighted Gram matrix G(a,b) = sum_s wt(s)*x(a,s)*x(b,s); only the upper
%% triangle is computed (row by row, to keep memory low) and then mirrored
function G = weightedGram(x, wt)
dim = size(x,1);
G = zeros(dim,dim);
for a = 1:dim,
  wxa = wt .* x(a,:);
  for b = a:dim,
    G(a,b) = wxa * x(b,:)';
    G(b,a) = G(a,b);
  end
end
wolffd@0 251
wolffd@0 252 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% debug/visualization
function vis (x,y,beta,lli,d,k,iter,debug)
% vis(x,y,beta,lli,d,k,iter,debug)
% Progress display for logistK.
%   debug <= 0      : do nothing
%   0 < debug <= 1  : print the iteration number and log likelihood
%   debug > 1       : additionally, when d == 3 and k <= 10, draw in
%                     figure 1 the class regions of beta over the first
%                     two input coordinates, overlaid with the samples
%                     marked by their largest y entry
% Grid points are built as [x1; x2; 1].
% NOTE(review): this assumes x(3,:) is a constant row of ones
% (a bias input) -- confirm against callers.

if debug<=0, return, end

disp(['iter=' num2str(iter) ' lli=' num2str(lli)]);
if debug<=1, return, end

if d~=3 || k>10, return, end

figure(1);
res = 100;
% half-width of the plotted square: largest magnitude over the two plotted
% coordinates (max(max(x)) missed negative extremes and the bias row)
r = max(max(abs(x(1:2,:))));
if r == 0, r = 1; end
dom = linspace(-r,r,res);
[px,py] = meshgrid(dom,dom);
points = [px(:)' ; py(:)' ; ones(1,res*res)];

% class region = argmax of the linear scores; exp() is monotonic so the
% argmax is unchanged, and skipping it avoids overflow to Inf and tied maxima
score = zeros(k,res*res);
for j = 1:k,
  score(j,:) = beta(:,j)'*points;
end
[smax,region] = max(score,[],1);
hold off;
imagesc(dom,dom,reshape(region,res,res));
hold on;

markers = {'w.' 'wx' 'w+' 'wo' 'w*' 'ws' 'wd' 'wv' 'w^' 'w<'};
% label of each sample = index of its largest y entry (loop-invariant)
[ymax,label] = max(y,[],1);
for j = 1:k,
  idx = find(label==j);
  plot(x(1,idx),x(2,idx),markers{j});
end
pause(0.1);
wolffd@0 286
wolffd@0 287 % eof