Mercurial > hg > camir-aes2014
comparison toolboxes/distance_learning/mlr/mlr_train.m @ 0:e9a9cd732c1e tip
first hg version after svn
author:   wolffd
date:     Tue, 10 Feb 2015 15:05:51 +0000
parents:  (none)
children: (none)
comparison: equal, deleted, inserted, replaced  -1:000000000000 -> 0:e9a9cd732c1e
function [W, Xi, Diagnostics] = mlr_train(X, Y, Cslack, varargin)
% MLR_TRAIN  Metric Learning to Rank: learn a metric W via cutting-plane
% optimization of a structural-SVM objective.
%
%   [W, Xi, D] = mlr_train(X, Y, C, ...)
%
%       X = d*n data matrix (d*n*m for multiple-kernel learning)
%       Y = either n-by-1 label vector,
%           OR n-by-2 cell array where
%               Y{q,1} contains relevant indices for q, and
%               Y{q,2} contains irrelevant indices for q
%
%       C >= 0 slack trade-off parameter (default=1)
%
%       W  = the learned metric
%       Xi = slack value on the learned metric
%       D  = diagnostics
%
%   Optional arguments:
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:          no regularization
%           1:          1-norm: trace(W)            (default)
%           2:          2-norm: trace(W' * W)
%           3:          Kernel: trace(W * X), assumes X is square
%                       and positive-definite
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0:   learn a full d-by-d W (default)
%       Diagonal = 1:   learn diagonally-constrained W (d-by-1)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%
%   // added by Daniel Wolff
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC)
%       Set ConstraintClock to CC (default: 100)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E)
%       Set convergence threshold to E (default: 1e-3)
%

    TIME_START = tic();

    global C;
    % Default C before touching Cslack: referencing an unsupplied input
    % argument would throw "Undefined variable", so guard on nargin first.
    if nargin < 3
        C = 1;
    else
        C = Cslack;
    end

    % d = feature dimension, n = number of samples, m = number of kernels.
    [d,n,m] = size(X);

    % More than one kernel slice enables multiple-kernel learning (MKL).
    if m > 1
        MKL = 1;
    else
        MKL = 0;
    end

    %%%
    % Default options:

    global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT STRUCTKERNEL DUALW INIT;

    global FEASIBLE_COUNT;
    FEASIBLE_COUNT = 0;

    CP          = @cuttingPlaneFull;
    SO          = @separationOracleAUC;
    PSI         = @metricPsiPO;

    % Select the full-matrix implementations; MKL variants where needed.
    if ~MKL
        INIT        = @initializeFull;
        REG         = @regularizeTraceFull;
        STRUCTKERNEL= @structKernelLinear;
        DUALW       = @dualWLinear;
        FEASIBLE    = @feasibleFull;
        CPGRADIENT  = @cpGradientFull;
        DISTANCE    = @distanceFull;
        SETDISTANCE = @setDistanceFull;
        LOSS        = @lossHinge;
        Regularizer = 'Trace';
    else
        INIT        = @initializeFullMKL;
        REG         = @regularizeMKLFull;
        STRUCTKERNEL= @structKernelMKL;
        DUALW       = @dualWMKL;
        FEASIBLE    = @feasibleFullMKL;
        CPGRADIENT  = @cpGradientFullMKL;
        DISTANCE    = @distanceFullMKL;
        SETDISTANCE = @setDistanceFullMKL;
        LOSS        = @lossHingeFullMKL;
        Regularizer = 'Trace';
    end

    Loss        = 'AUC';
    Feature     = 'metricPsiPO';

    %%%
    % Default k for prec@k, ndcg
    k           = 3;

    %%%
    % Stochastic violator selection?
    STOCHASTIC  = 0;
    batchSize   = n;
    SAMPLES     = 1:n;

    % Optional argument 1: ranking loss to optimize.
    if nargin > 3
        switch lower(varargin{1})
            case {'auc'}
                SO          = @separationOracleAUC;
                PSI         = @metricPsiPO;
                Loss        = 'AUC';
                Feature     = 'metricPsiPO';
            case {'knn'}
                SO          = @separationOracleKNN;
                PSI         = @metricPsiPO;
                Loss        = 'KNN';
                Feature     = 'metricPsiPO';
            case {'prec@k'}
                SO          = @separationOraclePrecAtK;
                PSI         = @metricPsiPO;
                Loss        = 'Prec@k';
                Feature     = 'metricPsiPO';
            case {'map'}
                SO          = @separationOracleMAP;
                PSI         = @metricPsiPO;
                Loss        = 'MAP';
                Feature     = 'metricPsiPO';
            case {'mrr'}
                SO          = @separationOracleMRR;
                PSI         = @metricPsiPO;
                Loss        = 'MRR';
                Feature     = 'metricPsiPO';
            case {'ndcg'}
                SO          = @separationOracleNDCG;
                PSI         = @metricPsiPO;
                Loss        = 'NDCG';
                Feature     = 'metricPsiPO';
            otherwise
                error('MLR:LOSS', ...
                    'Unknown loss function: %s', varargin{1});
        end
    end

    % Optional argument 2: ranking length k.
    if nargin > 4
        k = varargin{2};
    end

    % Optional argument 4: diagonal constraint on W.
    % Short-circuit && is required: with plain &, varargin{4} would be
    % evaluated (and error) even when fewer arguments were supplied.
    Diagonal = 0;
    if nargin > 6 && varargin{4} > 0
        Diagonal = varargin{4};

        if ~MKL
            INIT        = @initializeDiag;
            REG         = @regularizeTraceDiag;
            STRUCTKERNEL= @structKernelDiag;
            DUALW       = @dualWDiag;
            FEASIBLE    = @feasibleDiag;
            CPGRADIENT  = @cpGradientDiag;
            DISTANCE    = @distanceDiag;
            SETDISTANCE = @setDistanceDiag;
            Regularizer = 'Trace';
        else
            INIT        = @initializeDiagMKL;
            REG         = @regularizeMKLDiag;
            STRUCTKERNEL= @structKernelDiagMKL;
            DUALW       = @dualWDiagMKL;
            FEASIBLE    = @feasibleDiagMKL;
            CPGRADIENT  = @cpGradientDiagMKL;
            DISTANCE    = @distanceDiagMKL;
            SETDISTANCE = @setDistanceDiagMKL;
            LOSS        = @lossHingeDiagMKL;
            Regularizer = 'Trace';
        end
    end

    % Optional argument 3: regularizer choice (processed after Diagonal so
    % the Full/Diag variant can be selected correctly).
    if nargin > 5
        switch(varargin{3})
            case {0}
                REG         = @regularizeNone;
                Regularizer = 'None';

            case {1}
                if MKL
                    if Diagonal == 0
                        REG         = @regularizeMKLFull;
                        STRUCTKERNEL= @structKernelMKL;
                        DUALW       = @dualWMKL;
                    elseif Diagonal == 1
                        REG         = @regularizeMKLDiag;
                        STRUCTKERNEL= @structKernelDiagMKL;
                        DUALW       = @dualWDiagMKL;
                    end
                else
                    if Diagonal
                        REG         = @regularizeTraceDiag;
                        STRUCTKERNEL= @structKernelDiag;
                        DUALW       = @dualWDiag;
                    else
                        REG         = @regularizeTraceFull;
                        STRUCTKERNEL= @structKernelLinear;
                        DUALW       = @dualWLinear;
                    end
                end
                Regularizer = 'Trace';

            case {2}
                if Diagonal
                    REG         = @regularizeTwoDiag;
                else
                    REG         = @regularizeTwoFull;
                end
                Regularizer = '2-norm';
                error('MLR:REGULARIZER', '2-norm regularization no longer supported');

            case {3}
                if MKL
                    if Diagonal == 0
                        REG         = @regularizeMKLFull;
                        STRUCTKERNEL= @structKernelMKL;
                        DUALW       = @dualWMKL;
                    elseif Diagonal == 1
                        REG         = @regularizeMKLDiag;
                        STRUCTKERNEL= @structKernelDiagMKL;
                        DUALW       = @dualWDiagMKL;
                    end
                else
                    if Diagonal
                        REG         = @regularizeMKLDiag;
                        STRUCTKERNEL= @structKernelDiagMKL;
                        DUALW       = @dualWDiagMKL;
                    else
                        REG         = @regularizeKernel;
                        STRUCTKERNEL= @structKernelMKL;
                        DUALW       = @dualWMKL;
                    end
                end
                Regularizer = 'Kernel';

            otherwise
                error('MLR:REGULARIZER', ...
                    'Unknown regularization: %s', varargin{3});
        end
    end

    % Optional argument 5: are we in stochastic optimization mode?
    if nargin > 7 && varargin{5} > 0
        if varargin{5} < n
            STOCHASTIC  = 1;
            CP          = @cuttingPlaneRandom;
            batchSize   = varargin{5};
        end
    end

    % Algorithm
    %
    % Working <- []
    %
    % repeat:
    %   (W, Xi) <- solver(X, Y, C, Working)
    %
    %   for i = 1:|X|
    %       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
    %
    %   Working <- Working + (y^_1,y^_2,...,y^_n)
    % until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
    %       <= Xi + epsilon

    global DEBUG;

    if isempty(DEBUG)
        DEBUG = 0;
    end

    %%%
    % Timer to eliminate old constraints
    ConstraintClock = 100; % standard: 100

    % Optional argument 6: constraint eviction clock.
    if nargin > 8 && varargin{6} > 0
        ConstraintClock = varargin{6};
    end

    %%%
    % Convergence criteria for worst-violated constraint
    E = 1e-3;
    % Optional argument 7: convergence threshold.
    if nargin > 9 && varargin{7} > 0
        E = varargin{7};
    end

    %XXX:   2012-01-31 21:29:50 by Brian McFee <bmcfee@cs.ucsd.edu>
    %       no longer belongs here
    % Initialize
    W = INIT(X);

    global ADMM_Z ADMM_U RHO;
    ADMM_Z = W;
    ADMM_U = 0 * ADMM_Z;

    %%%
    % Augmented lagrangian factor
    RHO = 1;

    ClassScores = [];

    % Normalize the supervision: plain labels are converted to per-class
    % relevance masks; cell input supplies explicit (ir)relevant indices.
    if isa(Y, 'double')
        Ypos        = [];
        Yneg        = [];
        ClassScores = synthesizeRelevance(Y);

    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
        dbprint(2, 'Using supplied Ypos/synthesized Yneg');
        Ypos        = Y(:,1);
        Yneg        = Y(:,2);

        % Compute the valid samples
        SAMPLES     = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
        dbprint(2, 'Using supplied Ypos/synthesized Yneg');
        Ypos        = Y(:,1);
        Yneg        = [];
        SAMPLES     = find( ~(cellfun(@isempty, Y(:,1))));
    else
        error('MLR:LABELS', 'Incorrect format for Y.');
    end

    %%
    % If we don't have enough data to make the batch, cut the batch
    batchSize = min([batchSize, length(SAMPLES)]);

    Diagnostics = struct(   'loss',                 Loss, ...           % Which loss are we optimizing?
                            'feature',              Feature, ...        % Which ranking feature is used?
                            'k',                    k, ...              % What is the ranking length?
                            'regularizer',          Regularizer, ...    % What regularization is used?
                            'diagonal',             Diagonal, ...       % 0 for full metric, 1 for diagonal
                            'num_calls_SO',         0, ...              % Calls to separation oracle
                            'num_calls_solver',     0, ...              % Calls to solver
                            'time_SO',              0, ...              % Time in separation oracle
                            'time_solver',          0, ...              % Time in solver
                            'time_total',           0, ...              % Total time
                            'f',                    [], ...             % Objective value
                            'num_steps',            [], ...             % Number of steps for each solver run
                            'num_constraints',      [], ...             % Number of constraints for each run
                            'Xi',                   [], ...             % Slack achieved for each run
                            'Delta',                [], ...             % Mean loss for each SO call
                            'gap',                  [], ...             % Gap between loss and slack
                            'C',                    C, ...              % Slack trade-off
                            'epsilon',              E, ...              % Convergence threshold
                            'feasible_count',       FEASIBLE_COUNT, ... % Counter for # svd's
                            'constraint_timer',     ConstraintClock);   % Time before evicting old constraints

    global PsiR;
    global PsiClock;

    PsiR        = {};
    PsiClock    = [];

    Xi          = -Inf;
    Margins     = [];
    H           = [];
    Q           = [];

    if STOCHASTIC
        dbprint(1, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
    end

    % Cutting-plane main loop: alternate between finding the most-violated
    % constraint (separation oracle) and re-solving with the working set.
    MAXITER = 200;
    while Diagnostics.num_calls_solver < MAXITER
        dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);
        % Generate a constraint set
        Termination = 0;

        dbprint(2, 'Calling separation oracle...');

        [PsiNew, Mnew, SO_time]     = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
        Termination                 = LOSS(W, PsiNew, Mnew, 0);

        Diagnostics.num_calls_SO    = Diagnostics.num_calls_SO + 1;
        Diagnostics.time_SO         = Diagnostics.time_SO + SO_time;

        % Append the new constraint and grow the cached kernel/regularizer.
        Margins     = cat(1, Margins,   Mnew);
        PsiR        = cat(1, PsiR,      PsiNew);
        PsiClock    = cat(1, PsiClock,  0);
        H           = expandKernel(H);
        Q           = expandRegularizer(Q, X, W);

        dbprint(2, '\n\tActive constraints    : %d',            length(PsiClock));
        dbprint(2, '\t           Mean loss  : %0.4f',           Mnew);
        dbprint(2, '\t  Current loss Xi     : %0.4f',           Xi);
        dbprint(2, '\t  Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);

        Diagnostics.gap     = cat(1, Diagnostics.gap,   Termination - Xi);
        Diagnostics.Delta   = cat(1, Diagnostics.Delta, Mnew);

        % Converged: worst violation is within E of the current slack.
        if Termination <= Xi + E
            dbprint(2, 'Done.');
            break;
        end

        dbprint(2, 'Calling solver...');
        PsiClock                        = PsiClock + 1;
        Solver_time                     = tic();
        [W, Xi, Dsolver]                = mlr_admm(C, X, Margins, H, Q);
        Diagnostics.time_solver         = Diagnostics.time_solver + toc(Solver_time);
        Diagnostics.num_calls_solver    = Diagnostics.num_calls_solver + 1;

        Diagnostics.Xi          = cat(1, Diagnostics.Xi,        Xi);
        Diagnostics.f           = cat(1, Diagnostics.f,         Dsolver.f);
        Diagnostics.num_steps   = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

        %%%
        % Cull the old constraints
        GC          = PsiClock < ConstraintClock;
        Margins     = Margins(GC);
        PsiR        = PsiR(GC);
        PsiClock    = PsiClock(GC);
        H           = H(GC, GC);
        Q           = Q(GC);

        Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
    end

    % Finish diagnostics

    Diagnostics.time_total      = toc(TIME_START);
    Diagnostics.feasible_count  = FEASIBLE_COUNT;
end
458 | |
function H = expandKernel(H)
% EXPANDKERNEL  Grow the constraint Gram matrix H by one row/column for the
% newest constraint in PsiR, filling the new entries with struct-kernel
% inner products and keeping H symmetric.

    global STRUCTKERNEL;
    global PsiR;

    m = length(H);

    % Zero-pad one trailing row and column with base-MATLAB concatenation;
    % the original padarray call needlessly required the Image Processing
    % Toolbox for this.
    H = [H, zeros(m, 1); zeros(1, m + 1)];

    % Fill the new row/column symmetrically against every constraint.
    for i = 1:m+1
        H(i, m+1) = STRUCTKERNEL( PsiR{i}, PsiR{m+1} );
        H(m+1, i) = H(i, m+1);
    end
end
473 | |
function Q = expandRegularizer(Q, K, W)
% EXPANDREGULARIZER  Append one entry to the regularizer vector Q: the
% struct-kernel inner product of the regularizer gradient with the newest
% constraint in PsiR.

% FIXME:  2012-01-31 21:34:15 by Brian McFee <bmcfee@cs.ucsd.edu>
%   does not support unregularized learning

    global PsiR;
    global STRUCTKERNEL REG;

    nExisting   = length(Q);
    regGradient = REG(W, K, 1);

    Q(nExisting + 1, 1) = STRUCTKERNEL(regGradient, PsiR{nExisting + 1});

end
486 | |
function ClassScores = synthesizeRelevance(Y)
% SYNTHESIZERELEVANCE  Build per-class relevance masks from a label vector.
%
%   For each distinct label c, Ypos{c} is a logical mask over Y marking the
%   points labeled c, and Yneg{c} marks every other point.

    classes  = unique(Y);
    nClasses = length(classes);

    % Build the positive/negative masks before assembling the struct.
    posMask = cell(nClasses, 1);
    negMask = cell(nClasses, 1);
    for idx = 1:nClasses
        posMask{idx} = (Y == classes(idx));
        negMask{idx} = ~posMask{idx};
    end

    ClassScores = struct(   'Y',        Y, ...
                            'classes',  classes, ...
                            'Ypos',     [], ...
                            'Yneg',     []);

    ClassScores.Ypos = posMask;
    ClassScores.Yneg = negMask;

end