function [beta,post,lli] = logistK(x,y,w,beta)
% [beta,post,lli] = logistK(x,y,w,beta)
%
% k-class logistic regression with optional sample weights
%
% k = number of classes
% n = number of samples
% d = dimensionality of samples
%
% INPUT
%   x       dxn matrix of n input column vectors
%   y       kxn matrix of class assignments
%   [w]     1xn vector of sample weights
%   [beta]  dxk matrix of model coefficients
%
% OUTPUT
%   beta    dxk matrix of fitted model coefficients
%           (beta(:,k) are fixed at 0)
%   post    kxn matrix of fitted class posteriors
%   lli     log likelihood
%
% Let p(i,j) = exp(beta(:,j)'*x(:,i)).
% Class j posterior for observation i is:
%   post(j,i) = p(i,j) / (p(i,1) + ... + p(i,k))
%
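% EXAMPLE (illustrative sketch only; the data is synthetic and the variable
% names are hypothetical):
%   x = [randn(2,50), randn(2,50)+2; ones(1,100)];           % 2-d points plus a bias row
%   y = [ones(1,50), zeros(1,50); zeros(1,50), ones(1,50)];  % 1-of-k class labels
%   [beta,post,lli] = logistK(x,y);
%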
% See also logistK_eval.
%
% David Martin <dmartin@eecs.berkeley.edu>
% May 3, 2002

% Copyright (C) 2002 David R. Martin <dmartin@eecs.berkeley.edu>
%
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License as
% published by the Free Software Foundation; either version 2 of the
% License, or (at your option) any later version.
%
% This program is distributed in the hope that it will be useful, but
% WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
% General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
% 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.

% TODO - this code would be faster if x were transposed
error(nargchk(2,4,nargin));

debug = 0;
if debug>0,
  h = figure(1);
  set(h,'DoubleBuffer','on');
end

% get sizes
[d,nx] = size(x);
[k,ny] = size(y);

% check sizes
if k < 2,
  error('Input y must encode at least 2 classes.');
end
if nx ~= ny,
  error('Inputs x,y not the same length.');
end

n = nx;

% make sure class assignments have unit L1-norm
sumy = sum(y,1);
if any(abs(1-sumy) > eps),
  for i = 1:k, y(i,:) = y(i,:) ./ sumy; end
end
clear sumy;

% if sample weights weren't specified, set them to 1
if nargin < 3,
  w = ones(1,n);
end

% normalize sample weights so max is 1
w = w / max(w);
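% Note: rescaling w by a positive constant moves the log likelihood but not
% the location of its maximum; normalizing so max(w)==1 just keeps lli (and
% the convergence tests below) on a consistent scale.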

% if starting beta wasn't specified, initialize randomly
if nargin < 4,
  beta = 1e-3*rand(d,k);
  beta(:,k) = 0; % fix beta for class k at zero
else
  if any(beta(:,k) ~= 0),
    error('beta(:,k) ~= 0');
  end
end

stepsize = 1;
minstepsize = 1e-2;

post = computePost(beta,x);
lli = computeLogLik(post,y,w);

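% main loop: damped Newton-Raphson on the weighted log likelihood, halving
% the step size whenever a full step fails to increase lli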
for iter = 1:100,
  %disp(sprintf(' logist iter=%d lli=%g',iter,lli));
  vis(x,y,beta,lli,d,k,iter,debug);

  % gradient and hessian
  [g,h] = derivs(post,x,y,w);

  % make sure Hessian is well conditioned
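  % (if h is near singular, the block below inflates its diagonal by a
  % factor (1 + 10^i), Levenberg-Marquardt style, using the smallest i
  % that makes h invertible to working precision)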
  if rcond(h) < eps,
    % condition with Levenberg-Marquardt method
    for i = -16:16,
      h2 = h .* ((1 + 10^i)*eye(size(h)) + (1-eye(size(h))));
      if rcond(h2) > eps, break, end
    end
    if rcond(h2) < eps,
      warning(['Stopped at iteration ' num2str(iter) ...
               ' because Hessian can''t be conditioned']);
      break
    end
    h = h2;
  end

  % save lli before update
  lli_prev = lli;

  % Newton-Raphson with step-size halving
  while stepsize >= minstepsize,
    % Newton-Raphson update step
    step = stepsize * (h \ g);
    beta2 = beta;
    beta2(:,1:k-1) = beta2(:,1:k-1) - reshape(step,d,k-1);

    % get the new log likelihood
    post2 = computePost(beta2,x);
    lli2 = computeLogLik(post2,y,w);

    % if the log likelihood increased, then stop
    if lli2 > lli,
      post = post2; lli = lli2; beta = beta2;
      break
    end

    % otherwise, reduce step size by half
    stepsize = 0.5 * stepsize;
  end

  % stop if the average per-sample likelihood is close enough to 1
  if 1-exp(lli/n) < 1e-2, break, end

  % stop if the log likelihood changed by a small enough fraction
  dlli = (lli_prev-lli) / lli;
  if abs(dlli) < 1e-3, break, end

  % stop if the step size has gotten too small
  if stepsize < minstepsize, break, end

  % stop if the log likelihood has decreased; this shouldn't happen
  if lli < lli_prev,
    warning(['Stopped at iteration ' num2str(iter) ...
             ' because the log likelihood decreased from ' ...
             num2str(lli_prev) ' to ' num2str(lli) '.' ...
             ' This may be a bug.']);
    break
  end
end

if debug>0,
  vis(x,y,beta,lli,d,k,iter,2);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% class posteriors
function post = computePost(beta,x)
[d,n] = size(x);
[d,k] = size(beta);
post = zeros(k,n);
bx = zeros(k,n);
for j = 1:k,
  bx(j,:) = beta(:,j)'*x;
end
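% post(j,:) = exp(bx(j,:)) / sum_l exp(bx(l,:)), computed by dividing the
% numerator and denominator by exp(bx(j,:)); if a remaining term overflows,
% post(j,:) goes to 0, which is the correct limit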
for j = 1:k,
  post(j,:) = 1 ./ sum(exp(bx - repmat(bx(j,:),k,1)),1);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% log likelihood
function lli = computeLogLik(post,y,w)
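% weighted log likelihood: sum over classes and samples of w.*y.*log(post);
% eps inside the log guards against log(0)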
[k,n] = size(post);
lli = 0;
for j = 1:k,
  lli = lli + sum(w.*y(j,:).*log(post(j,:)+eps));
end
if isnan(lli),
  error('lli is nan');
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% gradient and hessian
%% These are computed in what seems a verbose manner, but it is
%% done this way to use minimal memory. x should be transposed
%% to make it faster.
function [g,h] = derivs(post,x,y,w)

[k,n] = size(post);
[d,n] = size(x);

% first derivative of likelihood w.r.t. beta
g = zeros(d,k-1);
for j = 1:k-1,
  wyp = w .* (y(j,:) - post(j,:));
  for ii = 1:d,
    g(ii,j) = x(ii,:) * wyp';
  end
end
g = reshape(g,d*(k-1),1);

% hessian of likelihood w.r.t. beta
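% h is a (k-1)x(k-1) grid of dxd blocks; block (i,j) is
%   sum_n w(n) * post(i,n) * (post(j,n) - (i==j)) * x(:,n)*x(:,n)'
% (the log likelihood is concave in beta, so h is negative semidefinite)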
h = zeros(d*(k-1),d*(k-1));
for i = 1:k-1, % diagonal
  wt = w .* post(i,:) .* (1 - post(i,:));
  hii = zeros(d,d);
  for a = 1:d,
    wxa = wt .* x(a,:);
    for b = a:d,
      hii_ab = wxa * x(b,:)';
      hii(a,b) = hii_ab;
      hii(b,a) = hii_ab;
    end
  end
  h( (i-1)*d+1 : i*d , (i-1)*d+1 : i*d ) = -hii;
end
for i = 1:k-1, % off-diagonal
  for j = i+1:k-1,
    wt = w .* post(j,:) .* post(i,:);
    hij = zeros(d,d);
    for a = 1:d,
      wxa = wt .* x(a,:);
      for b = a:d,
        hij_ab = wxa * x(b,:)';
        hij(a,b) = hij_ab;
        hij(b,a) = hij_ab;
      end
    end
    h( (i-1)*d+1 : i*d , (j-1)*d+1 : j*d ) = hij;
    h( (j-1)*d+1 : j*d , (i-1)*d+1 : i*d ) = hij;
  end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% debug/visualization
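%% debug=1 prints the log likelihood each iteration; debug>=2 also plots the
%% current decision regions, but only when d==3 (2-d data plus a bias row)
%% and k<=10 (one marker style per class)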
function vis(x,y,beta,lli,d,k,iter,debug)

if debug<=0, return, end

disp(['iter=' num2str(iter) ' lli=' num2str(lli)]);
if debug<=1, return, end

if d~=3 | k>10, return, end

figure(1);
res = 100;
r = abs(max(max(x)));
dom = linspace(-r,r,res);
[px,py] = meshgrid(dom,dom);
xx = px(:); yy = py(:);
points = [xx' ; yy' ; ones(1,res*res)];
func = zeros(k,res*res);
for j = 1:k,
  func(j,:) = exp(beta(:,j)'*points);
end
[mval,ind] = max(func,[],1);
hold off;
im = reshape(ind,res,res);
imagesc(xx,yy,im);
hold on;
syms = {'w.' 'wx' 'w+' 'wo' 'w*' 'ws' 'wd' 'wv' 'w^' 'w<'};
for j = 1:k,
  [mval,ind] = max(y,[],1);
  ind = find(ind==j);
  plot(x(1,ind),x(2,ind),syms{j});
end
pause(0.1);

% eof