Mercurial > hg > camir-aes2014
view toolboxes/FullBNT-1.0.7/KPMstats/logistK.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line source
function [beta,post,lli] = logistK(x,y,w,beta)
% [beta,post,lli] = logistK(x,y,w,beta)
%
% k-class logistic regression with optional sample weights
%
% k = number of classes
% n = number of samples
% d = dimensionality of samples
%
% INPUT
% 	x 	dxn matrix of n input column vectors
% 	y 	kxn vector of class assignments
% 	[w]	1xn vector of sample weights 
%	[beta]	dxk matrix of model coefficients
%
% OUTPUT
% 	beta 	dxk matrix of fitted model coefficients 
%		(beta(:,k) are fixed at 0)
%	post	kxn matrix of fitted class posteriors
%	lli	log likelihood
%
% Let p(i,j) = exp(beta(:,j)'*x(:,i)),
% Class j posterior for observation i is:
%	post(j,i) = p(i,j) / (p(i,1) + ... p(i,k))
%
% See also logistK_eval.
%
% David Martin <dmartin@eecs.berkeley.edu>
% May 3, 2002
 
% Copyright (C) 2002 David R. Martin <dmartin@eecs.berkeley.edu>
%
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License as
% published by the Free Software Foundation; either version 2 of the
% License, or (at your option) any later version.
%
% This program is distributed in the hope that it will be useful, but
% WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
% General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
% 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.

% TODO - this code would be faster if x were transposed

error(nargchk(2,4,nargin));

% debug>0 enables the vis() visualization hooks below
debug = 0;
if debug>0,
  h=figure(1);
  set(h,'DoubleBuffer','on');
end

% get sizes
[d,nx] = size(x);
[k,ny] = size(y);

% check sizes
if k < 2,
  error('Input y must encode at least 2 classes.');
end
if nx ~= ny,
  error('Inputs x,y not the same length.'); 
end

n = nx;

% make sure class assignments have unit L1-norm
% NOTE: the original code used "if abs(1-sumy) > eps", which applies
% `if` to a logical vector; in MATLAB that is true only when ALL
% elements are true, so normalization was silently skipped whenever
% only some columns of y failed to sum to 1.  any() fixes this.
sumy = sum(y,1);
if any(abs(1-sumy) > eps),
  sumy = sum(y,1);
  for i = 1:k, y(i,:) = y(i,:) ./ sumy; end
end
clear sumy;

% if sample weights weren't specified, set them to 1
if nargin < 3, 
  w = ones(1,n); 
end

% normalize sample weights so max is 1
w = w / max(w);

% if starting beta wasn't specified, initialize randomly
if nargin < 4,
  beta = 1e-3*rand(d,k);
  beta(:,k) = 0;	% fix beta for class k at zero
else
  if sum(beta(:,k)) ~= 0,
    error('beta(:,k) ~= 0');
  end
end

% Newton-Raphson with step-size halving; stepsize is halved whenever
% a full Newton step fails to increase the log likelihood
stepsize = 1;
minstepsize = 1e-2;

post = computePost(beta,x);
lli = computeLogLik(post,y,w);

for iter = 1:100,
  %disp(sprintf('  logist iter=%d lli=%g',iter,lli));
  vis(x,y,beta,lli,d,k,iter,debug);

  % gradient and hessian
  [g,h] = derivs(post,x,y,w);

  % make sure Hessian is well conditioned
  if rcond(h) < eps,
    % condition with Levenberg-Marquardt method: scale up the
    % diagonal until the matrix becomes numerically invertible
    for i = -16:16,
      h2 = h .* ((1 + 10^i)*eye(size(h)) + (1-eye(size(h))));
      if rcond(h2) > eps, break, end
    end
    if rcond(h2) < eps,
      warning(['Stopped at iteration ' num2str(iter) ...
               ' because Hessian can''t be conditioned']);
      break
    end
    h = h2;
  end

  % save lli before update
  lli_prev = lli;

  % Newton-Raphson with step-size halving
  while stepsize >= minstepsize,
    % Newton-Raphson update step
    step = stepsize * (h \ g);
    beta2 = beta;
    beta2(:,1:k-1) = beta2(:,1:k-1) - reshape(step,d,k-1);

    % get the new log likelihood
    post2 = computePost(beta2,x);
    lli2 = computeLogLik(post2,y,w);

    % if the log likelihood increased, then stop
    if lli2 > lli,
      post = post2; lli = lli2; beta = beta2;
      break
    end

    % otherwise, reduce step size by half
    stepsize = 0.5 * stepsize;
  end

  % stop if the average log likelihood has gotten small enough
  if 1-exp(lli/n) < 1e-2, break, end

  % stop if the log likelihood changed by a small enough fraction
  dlli = (lli_prev-lli) / lli;
  if abs(dlli) < 1e-3, break, end

  % stop if the step size has gotten too small
  % BUG FIX: original read "brea", an undefined identifier that
  % crashed at runtime whenever step-halving bottomed out
  if stepsize < minstepsize, break, end

  % stop if the log likelihood has decreased; this shouldn't happen
  if lli < lli_prev,
    warning(['Stopped at iteration ' num2str(iter) ...
             ' because the log likelihood decreased from ' ...
             num2str(lli_prev) ' to ' num2str(lli) '.' ...
             ' This may be a bug.']);
    break
  end
end

if debug>0,
  vis(x,y,beta,lli,d,k,iter,2);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% class posteriors
%% post(j,i) = exp(beta(:,j)'*x(:,i)) / sum_m exp(beta(:,m)'*x(:,i)),
%% computed via differences of the linear scores for numerical stability
function post = computePost(beta,x)
[d,n] = size(x);
[d,k] = size(beta);
post = zeros(k,n);
bx = zeros(k,n);
for j = 1:k, 
  bx(j,:) = beta(:,j)'*x; 
end
for j = 1:k, 
  post(j,:) = 1 ./ sum(exp(bx - repmat(bx(j,:),k,1)),1);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% log likelihood
%% lli = sum_i sum_j w(i) * y(j,i) * log(post(j,i)); eps guards log(0)
function lli = computeLogLik(post,y,w)
[k,n] = size(post);
lli = 0;
for j = 1:k,
  lli = lli + sum(w.*y(j,:).*log(post(j,:)+eps));
end
if isnan(lli), 
  error('lli is nan'); 
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% gradient and hessian
%% These are computed in what seems a verbose manner, but it is
%% done this way to use minimal memory.
%% x should be transposed
%% to make it faster.
function [g,h] = derivs(post,x,y,w)

[k,n] = size(post);
[d,n] = size(x);

% first derivative of likelihood w.r.t. beta
% Only the first k-1 classes are free parameters (beta(:,k) is fixed
% at 0), so g has d*(k-1) entries.  Column j is the weighted sum of
% inputs times the residual (y - post) for class j.
g = zeros(d,k-1);
for j = 1:k-1,
  wyp = w .* (y(j,:) - post(j,:));
  for ii = 1:d,
    g(ii,j) = x(ii,:) * wyp';
  end
end
g = reshape(g,d*(k-1),1);

% hessian of likelihood w.r.t. beta
% Block matrix with (k-1) x (k-1) blocks of size d x d.  Each dxd
% block is a weighted outer-product sum of the input vectors; the
% inner b = a:d loop fills only the upper triangle and mirrors it,
% since every block is symmetric.
h = zeros(d*(k-1),d*(k-1));
for i = 1:k-1, % diagonal
  % diagonal block weight: w * post_i * (1 - post_i); note the
  % leading minus sign when stored into h
  wt = w .* post(i,:) .* (1 - post(i,:));
  hii = zeros(d,d);
  for a = 1:d,
    wxa = wt .* x(a,:);
    for b = a:d,
      hii_ab = wxa * x(b,:)';
      hii(a,b) = hii_ab;
      hii(b,a) = hii_ab;
    end
  end
  h( (i-1)*d+1 : i*d , (i-1)*d+1 : i*d ) = -hii;
end
for i = 1:k-1, % off-diagonal
  for j = i+1:k-1,
    % off-diagonal block weight: w * post_j * post_i (stored with
    % positive sign, and mirrored across the block diagonal)
    wt = w .* post(j,:) .* post(i,:);
    hij = zeros(d,d);
    for a = 1:d,
      wxa = wt .* x(a,:);
      for b = a:d,
        hij_ab = wxa * x(b,:)';
        hij(a,b) = hij_ab;
        hij(b,a) = hij_ab;
      end
    end
    h( (i-1)*d+1 : i*d , (j-1)*d+1 : j*d ) = hij;
    h( (j-1)*d+1 : j*d , (i-1)*d+1 : i*d ) = hij;
  end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% debug/visualization
%% Prints the iteration/log-likelihood at debug>=1; at debug>=2 and
%% d==3 (2D data plus a bias row) it also renders the current decision
%% regions over a grid and overlays the training points per class.
function vis (x,y,beta,lli,d,k,iter,debug)

if debug<=0, return, end

disp(['iter=' num2str(iter) ' lli=' num2str(lli)]);
if debug<=1, return, end

% plotting only supports 3-row x (2D + bias) and at most 10 classes
% (limited by the marker-style list below)
if d~=3 | k>10, return, end

figure(1);
res = 100;
r = abs(max(max(x)));
dom = linspace(-r,r,res);
[px,py] = meshgrid(dom,dom);
xx = px(:); yy = py(:);
% grid points lifted to homogeneous coordinates (bias term = 1)
points = [xx' ; yy' ; ones(1,res*res)];
func = zeros(k,res*res);
for j = 1:k,
  func(j,:) = exp(beta(:,j)'*points);
end
% argmax over classes at each grid point = predicted class label
[mval,ind] = max(func,[],1);
hold off;
im = reshape(ind,res,res);
imagesc(xx,yy,im);
hold on;
% one white marker style per class
syms = {'w.' 'wx' 'w+' 'wo' 'w*' 'ws' 'wd' 'wv' 'w^' 'w<'};
for j = 1:k,
  % plot the samples whose hard class assignment (argmax of y) is j
  [mval,ind] = max(y,[],1);
  ind = find(ind==j);
  plot(x(1,ind),x(2,ind),syms{j});
end
pause(0.1);

% eof