comparison toolboxes/distance_learning/mlr/mlr_train.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [W, Xi, Diagnostics] = mlr_train(X, Y, Cslack, varargin)
%
% [W, Xi, D] = mlr_train(X, Y, C,...)
%
%   X = d*n data matrix
%   Y = either n-by-1 label vector
%       OR
%       n-by-2 cell array where
%           Y{q,1} contains relevant indices for q, and
%           Y{q,2} contains irrelevant indices for q
%
%   C >= 0 slack trade-off parameter (default=1)
%
%   W   = the learned metric
%   Xi  = slack value on the learned metric
%   D   = diagnostics
%
% Optional arguments:
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:          no regularization
%           1:          1-norm: trace(W)            (default)
%           2:          2-norm: trace(W' * W)
%           3:          Kernel: trace(W * X), assumes X is square and positive-definite
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0:   learn a full d-by-d W       (default)
%       Diagonal = 1:   learn diagonally-constrained W (d-by-1)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%
%   // added by Daniel Wolff
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC)
%       Set the constraint-eviction clock to CC (default: 100)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E)
%       Set the convergence threshold to E (default: 1e-3)
%

TIME_START = tic();

global C;
% Apply the documented default C=1 BEFORE touching Cslack: when fewer
% than three arguments are supplied, Cslack does not exist and reading
% it would raise an error.
if nargin < 3
    C = 1;
else
    C = Cslack;
end

[d,n,m] = size(X);

% m > 1 means X carries multiple kernel slices: multiple kernel learning.
if m > 1
    MKL = 1;
else
    MKL = 0;
end

%%%
% Default options:

global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT STRUCTKERNEL DUALW INIT;

global FEASIBLE_COUNT;
FEASIBLE_COUNT = 0;

CP          = @cuttingPlaneFull;
SO          = @separationOracleAUC;
PSI         = @metricPsiPO;

if ~MKL
    INIT        = @initializeFull;
    REG         = @regularizeTraceFull;
    STRUCTKERNEL= @structKernelLinear;
    DUALW       = @dualWLinear;
    FEASIBLE    = @feasibleFull;
    CPGRADIENT  = @cpGradientFull;
    DISTANCE    = @distanceFull;
    SETDISTANCE = @setDistanceFull;
    LOSS        = @lossHinge;
    Regularizer = 'Trace';
else
    INIT        = @initializeFullMKL;
    REG         = @regularizeMKLFull;
    STRUCTKERNEL= @structKernelMKL;
    DUALW       = @dualWMKL;
    FEASIBLE    = @feasibleFullMKL;
    CPGRADIENT  = @cpGradientFullMKL;
    DISTANCE    = @distanceFullMKL;
    SETDISTANCE = @setDistanceFullMKL;
    LOSS        = @lossHingeFullMKL;
    Regularizer = 'Trace';
end


Loss        = 'AUC';
Feature     = 'metricPsiPO';


%%%
% Default k for prec@k, ndcg
k           = 3;

%%%
% Stochastic violator selection?
STOCHASTIC  = 0;
batchSize   = n;
SAMPLES     = 1:n;


if nargin > 3
    switch lower(varargin{1})
        case {'auc'}
            SO          = @separationOracleAUC;
            PSI         = @metricPsiPO;
            Loss        = 'AUC';
            Feature     = 'metricPsiPO';
        case {'knn'}
            SO          = @separationOracleKNN;
            PSI         = @metricPsiPO;
            Loss        = 'KNN';
            Feature     = 'metricPsiPO';
        case {'prec@k'}
            SO          = @separationOraclePrecAtK;
            PSI         = @metricPsiPO;
            Loss        = 'Prec@k';
            Feature     = 'metricPsiPO';
        case {'map'}
            SO          = @separationOracleMAP;
            PSI         = @metricPsiPO;
            Loss        = 'MAP';
            Feature     = 'metricPsiPO';
        case {'mrr'}
            SO          = @separationOracleMRR;
            PSI         = @metricPsiPO;
            Loss        = 'MRR';
            Feature     = 'metricPsiPO';
        case {'ndcg'}
            SO          = @separationOracleNDCG;
            PSI         = @metricPsiPO;
            Loss        = 'NDCG';
            Feature     = 'metricPsiPO';
        otherwise
            error('MLR:LOSS', ...
                'Unknown loss function: %s', varargin{1});
    end
end

if nargin > 4
    k = varargin{2};
end

Diagonal = 0;
% Short-circuit && is required here: with plain &, varargin{4} would be
% evaluated even when fewer than 7 arguments were passed, raising an
% index-out-of-bounds error.
if nargin > 6 && varargin{4} > 0
    Diagonal = varargin{4};

    if ~MKL
        INIT        = @initializeDiag;
        REG         = @regularizeTraceDiag;
        STRUCTKERNEL= @structKernelDiag;
        DUALW       = @dualWDiag;
        FEASIBLE    = @feasibleDiag;
        CPGRADIENT  = @cpGradientDiag;
        DISTANCE    = @distanceDiag;
        SETDISTANCE = @setDistanceDiag;
        Regularizer = 'Trace';
    else
        INIT        = @initializeDiagMKL;
        REG         = @regularizeMKLDiag;
        STRUCTKERNEL= @structKernelDiagMKL;
        DUALW       = @dualWDiagMKL;
        FEASIBLE    = @feasibleDiagMKL;
        CPGRADIENT  = @cpGradientDiagMKL;
        DISTANCE    = @distanceDiagMKL;
        SETDISTANCE = @setDistanceDiagMKL;
        LOSS        = @lossHingeDiagMKL;
        Regularizer = 'Trace';
    end
end

if nargin > 5
    switch(varargin{3})
        case {0}
            REG         = @regularizeNone;
            Regularizer = 'None';

        case {1}
            if MKL
                if Diagonal == 0
                    REG         = @regularizeMKLFull;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                elseif Diagonal == 1
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                end
            else
                if Diagonal
                    REG         = @regularizeTraceDiag;
                    STRUCTKERNEL= @structKernelDiag;
                    DUALW       = @dualWDiag;
                else
                    REG         = @regularizeTraceFull;
                    STRUCTKERNEL= @structKernelLinear;
                    DUALW       = @dualWLinear;
                end
            end
            Regularizer = 'Trace';

        case {2}
            if Diagonal
                REG         = @regularizeTwoDiag;
            else
                REG         = @regularizeTwoFull;
            end
            Regularizer = '2-norm';
            error('MLR:REGULARIZER', '2-norm regularization no longer supported');


        case {3}
            if MKL
                if Diagonal == 0
                    REG         = @regularizeMKLFull;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                elseif Diagonal == 1
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                end
            else
                if Diagonal
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                else
                    REG         = @regularizeKernel;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                end
            end
            Regularizer = 'Kernel';

        otherwise
            % REG codes are numeric, so format with %d (not %s).
            error('MLR:REGULARIZER', ...
                'Unknown regularization: %d', varargin{3});
    end
end


% Are we in stochastic optimization mode?
if nargin > 7 && varargin{5} > 0
    if varargin{5} < n
        STOCHASTIC  = 1;
        CP          = @cuttingPlaneRandom;
        batchSize   = varargin{5};
    end
end

% Algorithm
%
% Working <- []
%
% repeat:
%   (W, Xi) <- solver(X, Y, C, Working)
%
%   for i = 1:|X|
%       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
%
%   Working <- Working + (y^_1,y^_2,...,y^_n)
% until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
%           <= Xi + epsilon

global DEBUG;

if isempty(DEBUG)
    DEBUG = 0;
end

%%%
% Timer to eliminate old constraints
ConstraintClock = 100; % standard: 100

if nargin > 8 && varargin{6} > 0
    ConstraintClock = varargin{6};
end

%%%
% Convergence criteria for worst-violated constraint
E = 1e-3;
if nargin > 9 && varargin{7} > 0
    E = varargin{7};
end


%XXX:   2012-01-31 21:29:50 by Brian McFee <bmcfee@cs.ucsd.edu>
% no longer belongs here
% Initialize
W = INIT(X);


global ADMM_Z ADMM_U RHO;
ADMM_Z = W;
ADMM_U = 0 * ADMM_Z;

%%%
% Augmented lagrangian factor
RHO = 1;

ClassScores = [];

if isa(Y, 'double')
    Ypos = [];
    Yneg = [];
    ClassScores = synthesizeRelevance(Y);

elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
    dbprint(2, 'Using supplied Ypos/Yneg');
    Ypos = Y(:,1);
    Yneg = Y(:,2);

    % Compute the valid samples
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
    dbprint(2, 'Using supplied Ypos/synthesized Yneg');
    Ypos = Y(:,1);
    Yneg = [];
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));
else
    error('MLR:LABELS', 'Incorrect format for Y.');
end

%%
% If we don't have enough data to make the batch, cut the batch
batchSize = min([batchSize, length(SAMPLES)]);


Diagnostics = struct(   'loss',                 Loss, ...           % Which loss are we optimizing?
                        'feature',              Feature, ...        % Which ranking feature is used?
                        'k',                    k, ...              % What is the ranking length?
                        'regularizer',          Regularizer, ...    % What regularization is used?
                        'diagonal',             Diagonal, ...       % 0 for full metric, 1 for diagonal
                        'num_calls_SO',         0, ...              % Calls to separation oracle
                        'num_calls_solver',     0, ...              % Calls to solver
                        'time_SO',              0, ...              % Time in separation oracle
                        'time_solver',          0, ...              % Time in solver
                        'time_total',           0, ...              % Total time
                        'f',                    [], ...             % Objective value
                        'num_steps',            [], ...             % Number of steps for each solver run
                        'num_constraints',      [], ...             % Number of constraints for each run
                        'Xi',                   [], ...             % Slack achieved for each run
                        'Delta',                [], ...             % Mean loss for each SO call
                        'gap',                  [], ...             % Gap between loss and slack
                        'C',                    C, ...              % Slack trade-off
                        'epsilon',              E, ...              % Convergence threshold
                        'feasible_count',       FEASIBLE_COUNT, ... % Counter for # svd's
                        'constraint_timer',     ConstraintClock);   % Time before evicting old constraints



global PsiR;
global PsiClock;

PsiR        = {};
PsiClock    = [];

Xi          = -Inf;
Margins     = [];
H           = [];
Q           = [];

if STOCHASTIC
    dbprint(1, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
end

MAXITER = 200;
% while 1
while Diagnostics.num_calls_solver < MAXITER
    dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);
    % Generate a constraint set
    dbprint(2, 'Calling separation oracle...');

    [PsiNew, Mnew, SO_time]     = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
    Termination                 = LOSS(W, PsiNew, Mnew, 0);

    Diagnostics.num_calls_SO    = Diagnostics.num_calls_SO + 1;
    Diagnostics.time_SO         = Diagnostics.time_SO + SO_time;

    Margins     = cat(1, Margins,   Mnew);
    PsiR        = cat(1, PsiR,      PsiNew);
    PsiClock    = cat(1, PsiClock,  0);
    H           = expandKernel(H);
    Q           = expandRegularizer(Q, X, W);


    dbprint(2, '\n\tActive constraints    : %d',            length(PsiClock));
    dbprint(2, '\t           Mean loss  : %0.4f',           Mnew);
    dbprint(2, '\t  Current loss Xi    : %0.4f',            Xi);
    dbprint(2, '\t  Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);

    Diagnostics.gap     = cat(1, Diagnostics.gap,   Termination - Xi);
    Diagnostics.Delta   = cat(1, Diagnostics.Delta, Mnew);

    if Termination <= Xi + E
        dbprint(2, 'Done.');
        break;
    end

    dbprint(2, 'Calling solver...');
    PsiClock                        = PsiClock + 1;
    Solver_time                     = tic();
    [W, Xi, Dsolver]                = mlr_admm(C, X, Margins, H, Q);
    Diagnostics.time_solver         = Diagnostics.time_solver + toc(Solver_time);
    Diagnostics.num_calls_solver    = Diagnostics.num_calls_solver + 1;

    Diagnostics.Xi                  = cat(1, Diagnostics.Xi,    Xi);
    Diagnostics.f                   = cat(1, Diagnostics.f,     Dsolver.f);
    Diagnostics.num_steps           = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

    %%%
    % Cull the old constraints
    GC          = PsiClock < ConstraintClock;
    Margins     = Margins(GC);
    PsiR        = PsiR(GC);
    PsiClock    = PsiClock(GC);
    H           = H(GC, GC);
    Q           = Q(GC);

    Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
end


% Finish diagnostics

Diagnostics.time_total      = toc(TIME_START);
Diagnostics.feasible_count  = FEASIBLE_COUNT;
end
458
function H = expandKernel(H)
% Grow the constraint Gram matrix H by one row/column for the newest
% constraint in PsiR, filling the new entries with structural-kernel
% inner products and keeping H symmetric.

global STRUCTKERNEL;
global PsiR;

m = length(H);

% Indexed assignment past the current bounds grows H with zero padding,
% so the Image Processing Toolbox's PADARRAY is not needed here.
H(m+1, m+1) = 0;

for i = 1:m+1
    H(i, m+1) = STRUCTKERNEL( PsiR{i}, PsiR{m+1} );
    H(m+1, i) = H(i, m+1);
end
end
473
function Q = expandRegularizer(Q, K, W)
% Append to the column vector Q the regularizer inner-product term for
% the newest constraint in PsiR.
%
% FIXME: 2012-01-31 21:34:15 by Brian McFee <bmcfee@cs.ucsd.edu>
%   does not support unregularized learning

global PsiR;
global STRUCTKERNEL REG;

nExisting           = length(Q);
Q(nExisting+1, 1)   = STRUCTKERNEL(REG(W, K, 1), PsiR{nExisting+1});

end
486
function ClassScores = synthesizeRelevance(Y)
% Derive per-class relevance masks from a plain label vector Y.
% For each distinct class c, Ypos{c} is a logical mask selecting the
% samples labeled c, and Yneg{c} selects every other sample.

classes     = unique(Y);
nClasses    = length(classes);

Ypos        = cell(nClasses, 1);
Yneg        = cell(nClasses, 1);
for c = 1:nClasses
    inClass     = (Y == classes(c));
    Ypos{c}     = inClass;
    Yneg{c}     = ~inClass;
end

% Cell values are wrapped in { } so STRUCT builds one scalar struct
% rather than a struct array.
ClassScores = struct(   'Y',        Y, ...
                        'classes',  classes, ...
                        'Ypos',     {Ypos}, ...
                        'Yneg',     {Yneg});

end