Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/KPMstats/logistK.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [beta,post,lli] = logistK(x,y,w,beta)
% [beta,post,lli] = logistK(x,y,[w],[beta])
%
% k-class logistic regression with optional sample weights
%
% k = number of classes
% n = number of samples
% d = dimensionality of samples
%
% INPUT
% 	x 	dxn matrix of n input column vectors
% 	y 	kxn vector of class assignments
% 	[w]	1xn vector of sample weights 
% 	[beta]	dxk matrix of initial model coefficients
%		(beta(:,k) must be all zero)
%
% OUTPUT
% 	beta 	dxk matrix of fitted model coefficients 
% 		(beta(:,k) are fixed at 0)
% 	post 	kxn matrix of fitted class posteriors
% 	lli	log likelihood
%
% Let p(i,j) = exp(beta(:,j)'*x(:,i)),
% Class j posterior for observation i is:
% 	post(j,i) = p(i,j) / (p(i,1) + ... p(i,k))
%
% See also logistK_eval.
%
% David Martin <dmartin@eecs.berkeley.edu>
% May 3, 2002

% Copyright (C) 2002 David R. Martin <dmartin@eecs.berkeley.edu>
%
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License as
% published by the Free Software Foundation; either version 2 of the
% License, or (at your option) any later version.
%
% This program is distributed in the hope that it will be useful, but
% WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
% General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
% 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.

% TODO - this code would be faster if x were transposed

error(nargchk(2,4,nargin));

debug = 0;
if debug>0,
  h=figure(1);
  set(h,'DoubleBuffer','on');
end

% get sizes
[d,nx] = size(x);
[k,ny] = size(y);

% check sizes
if k < 2,
  error('Input y must encode at least 2 classes.');
end
if nx ~= ny,
  error('Inputs x,y not the same length.'); 
end

n = nx;

% make sure class assignments have unit L1-norm
sumy = sum(y,1);
% BUG FIX: the original tested `if abs(1-sumy) > eps` on a 1xn vector; a
% vector condition in MATLAB is true only when EVERY element is true, so
% normalization was skipped whenever at least one column already summed
% to 1.  Use any() so we normalize when any column deviates.
if any(abs(1-sumy) > eps),
  for i = 1:k, y(i,:) = y(i,:) ./ sumy; end
end
clear sumy;

% if sample weights weren't specified, set them to 1
if nargin < 3, 
  w = ones(1,n); 
end

% normalize sample weights so max is 1
w = w / max(w);

% if starting beta wasn't specified, initialize randomly
if nargin < 4,
  beta = 1e-3*rand(d,k);
  beta(:,k) = 0;	% fix beta for class k at zero
else
  if sum(beta(:,k)) ~= 0,
    error('beta(:,k) ~= 0');
  end
end

stepsize = 1;
minstepsize = 1e-2;

post = computePost(beta,x);
lli = computeLogLik(post,y,w);

% Newton-Raphson iterations with step-size halving.
% NOTE(review): stepsize is intentionally NOT reset at the top of each
% outer iteration (matches the original behavior); once halved it stays
% halved for the remainder of the fit.
for iter = 1:100,
  %disp(sprintf('  logist iter=%d lli=%g',iter,lli));
  vis(x,y,beta,lli,d,k,iter,debug);

  % gradient and hessian
  [g,h] = derivs(post,x,y,w);

  % make sure Hessian is well conditioned
  if rcond(h) < eps,
    % condition with Levenberg-Marquardt method: scale up the diagonal
    % until the Hessian becomes numerically invertible
    for i = -16:16,
      h2 = h .* ((1 + 10^i)*eye(size(h)) + (1-eye(size(h))));
      if rcond(h2) > eps, break, end
    end
    if rcond(h2) < eps,
      warning(['Stopped at iteration ' num2str(iter) ...
               ' because Hessian can''t be conditioned']);
      break
    end
    h = h2;
  end

  % save lli before update
  lli_prev = lli;

  % Newton-Raphson with step-size halving 
  while stepsize >= minstepsize,
    % Newton-Raphson update step
    step = stepsize * (h \ g);
    beta2 = beta;
    beta2(:,1:k-1) = beta2(:,1:k-1) - reshape(step,d,k-1);

    % get the new log likelihood
    post2 = computePost(beta2,x);
    lli2 = computeLogLik(post2,y,w);

    % if the log likelihood increased, then stop
    if lli2 > lli,
      post = post2; lli = lli2; beta = beta2;
      break
    end

    % otherwise, reduce step size by half
    stepsize = 0.5 * stepsize;
  end

  % stop if the average log likelihood has gotten small enough
  if 1-exp(lli/n) < 1e-2, break, end

  % stop if the log likelihood changed by a small enough fraction
  dlli = (lli_prev-lli) / lli;
  if abs(dlli) < 1e-3, break, end

  % stop if the step size has gotten too small
  % BUG FIX: the original read `brea` (undefined identifier), which raised
  % a runtime error whenever this stopping condition was reached.
  if stepsize < minstepsize, break, end

  % stop if the log likelihood has decreased; this shouldn't happen
  if lli < lli_prev,
    warning(['Stopped at iteration ' num2str(iter) ...
             ' because the log likelihood decreased from ' ...
             num2str(lli_prev) ' to ' num2str(lli) '.' ...
             ' This may be a bug.']);
    break
  end
end

if debug>0,
  vis(x,y,beta,lli,d,k,iter,2);
end
173 | |
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% class posteriors
function post = computePost(beta,x)
% Softmax class posteriors: post(j,i) = exp(bx(j,i)) / sum_m exp(bx(m,i))
% where bx(j,i) = beta(:,j)'*x(:,i).  Each column is evaluated as
% 1./sum(exp(bx - bx(j,:))) so the normalization is built into the
% exponentials rather than dividing two potentially huge numbers.
[d,n] = size(x);
[d,k] = size(beta);
bx = beta' * x;                       % kxn matrix of linear responses
post = zeros(k,n);
for c = 1:k,
  shifted = bx - repmat(bx(c,:),k,1); % responses relative to class c
  post(c,:) = 1 ./ sum(exp(shifted),1);
end
187 | |
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% log likelihood
function lli = computeLogLik(post,y,w)
% Weighted log likelihood: sum over classes c and samples i of
% w(i)*y(c,i)*log(post(c,i)).  eps guards against log(0) when a
% posterior underflows to exactly zero.
[k,n] = size(post);
perClass = zeros(1,k);
for c = 1:k,
  perClass(c) = sum(w .* y(c,:) .* log(post(c,:)+eps));
end
lli = sum(perClass);
if isnan(lli),
  error('lli is nan');
end
199 | |
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% gradient and hessian
%% These are computed in what seems a verbose manner, but it is
%% done this way to use minimal memory.  x should be transposed
%% to make it faster.
%
% INPUT
%   post  kxn matrix of current class posteriors
%   x     dxn matrix of input column vectors
%   y     kxn matrix of (normalized) class assignments
%   w     1xn vector of sample weights
%
% OUTPUT
%   g     d*(k-1) x 1 gradient of the log likelihood w.r.t.
%         beta(:,1:k-1), stacked column-by-column (class k's beta is
%         fixed at zero, so it has no free parameters)
%   h     d*(k-1) x d*(k-1) Hessian, laid out in dxd blocks; block
%         (i,j) corresponds to classes i and j
function [g,h] = derivs(post,x,y,w)

[k,n] = size(post);
[d,n] = size(x);

% first derivative of likelihood w.r.t. beta
% gradient block for class j: sum_i w(i)*(y(j,i)-post(j,i))*x(:,i)
g = zeros(d,k-1);
for j = 1:k-1,
  wyp = w .* (y(j,:) - post(j,:));
  for ii = 1:d, 
    g(ii,j) = x(ii,:) * wyp';
  end
end
% stack the per-class gradient columns into one long vector, matching
% the block layout of the Hessian below
g = reshape(g,d*(k-1),1);

% hessian of likelihood w.r.t. beta
% diagonal blocks are negated (-hii): d2L/dbeta_i^2 = -sum w*p_i*(1-p_i)*x*x'
h = zeros(d*(k-1),d*(k-1));
for i = 1:k-1, % diagonal
  wt = w .* post(i,:) .* (1 - post(i,:));
  hii = zeros(d,d);
  % only compute the upper triangle (b >= a); mirror into hii(b,a)
  % since each dxd block is symmetric
  for a = 1:d,
    wxa = wt .* x(a,:);
    for b = a:d,
      hii_ab = wxa * x(b,:)';
      hii(a,b) = hii_ab;
      hii(b,a) = hii_ab;
    end
  end
  h( (i-1)*d+1 : i*d , (i-1)*d+1 : i*d ) = -hii;
end
% off-diagonal blocks are positive: d2L/dbeta_i dbeta_j = +sum w*p_i*p_j*x*x'
for i = 1:k-1, % off-diagonal
  for j = i+1:k-1,
    wt = w .* post(j,:) .* post(i,:);
    hij = zeros(d,d);
    for a = 1:d,
      wxa = wt .* x(a,:);
      for b = a:d,
        hij_ab = wxa * x(b,:)';
        hij(a,b) = hij_ab;
        hij(b,a) = hij_ab;
      end
    end
    % the full Hessian is symmetric, so block (i,j) = block (j,i)
    h( (i-1)*d+1 : i*d , (j-1)*d+1 : j*d ) = hij;
    h( (j-1)*d+1 : j*d , (i-1)*d+1 : i*d ) = hij;
  end
end
251 | |
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% debug/visualization
% Plot the current decision map and the data points, colored by class.
% debug<=0: no output; debug==1: print iteration/lli only; debug>=2:
% also draw the figure (only when d==3, i.e. 2-D data plus a constant
% bias row -- presumably x(3,:) is all ones; verify against caller --
% and at most 10 classes, limited by the marker table below).
function vis (x,y,beta,lli,d,k,iter,debug)

if debug<=0, return, end

disp(['iter=' num2str(iter) ' lli=' num2str(lli)]);
if debug<=1, return, end

if d~=3 | k>10, return, end

figure(1);
res = 100;
r = abs(max(max(x)));
dom = linspace(-r,r,res);
[px,py] = meshgrid(dom,dom);
xx = px(:); yy = py(:);
points = [xx' ; yy' ; ones(1,res*res)];
% evaluate the (unnormalized) class scores on the grid; argmax gives
% the decision map
func = zeros(k,res*res);
for j = 1:k,
  func(j,:) = exp(beta(:,j)'*points);
end
[mval,ind] = max(func,[],1);
hold off;
im = reshape(ind,res,res);
imagesc(xx,yy,im);
hold on;
syms = {'w.' 'wx' 'w+' 'wo' 'w*' 'ws' 'wd' 'wv' 'w^' 'w<'};
% FIX: the class label of each sample is loop-invariant -- the original
% recomputed max(y,[],1) on every pass (and clobbered the decision-map
% variable `ind`), so hoist it out and give it its own name
[mval,lbl] = max(y,[],1);
for j = 1:k,
  ind = find(lbl==j);
  plot(x(1,ind),x(2,ind),syms{j});
end
pause(0.1);
286 | |
287 % eof |