comparison toolboxes/distance_learning/mlr/rmlr_train.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
function [W, Xi, Diagnostics] = rmlr_train(X, Y, Cslack, varargin)
%[W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda, CC, E)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, ...)
%
%       X = d*n data matrix
%       Y = either an n-by-1 vector of labels
%           OR
%           an n-by-2 cell array where
%               Y{q,1} contains relevant indices for q, and
%               Y{q,2} contains irrelevant indices for q
%
%       C >= 0 slack trade-off parameter (default=1)
%
%       W  = the learned metric
%       Xi = slack value on the learned metric
%       D  = diagnostics
%
%   Optional arguments:
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0: no regularization
%           1: 1-norm: trace(W)      (default)
%           2: 2-norm: trace(W' * W)
%           3: Kernel: trace(W * X), assumes X is square and positive-definite
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Not implemented: learning a diagonal W reduces to standard MLR,
%       since ||W||_{2,1} = trace(W) when W is diagonal (use mlr_train instead).
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda)
%       where lambda is the coefficient of the ||W||_{2,1} regularization
%       term (default=1)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda, CC)
%       Set ConstraintClock to CC (default: 500)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda, CC, E)
%       Set convergence threshold to E (default: 1e-3)
%
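% Example (a minimal sketch on synthetic data; the values below are
% illustrative assumptions, not from the original toolbox):
%
%     X = randn(10, 200);               % 10-dimensional features, 200 points
%     Y = randi(4, 200, 1);             % integer class labels
%     [W, Xi, D] = rmlr_train(X, Y, 1, 'auc');
%
%     % Squared distances under the learned metric,
%     % d^2(i,j) = (x_i - x_j)' * W * (x_i - x_j):
%     G  = X' * W * X;
%     D2 = bsxfun(@plus, diag(G), diag(G)') - 2 * G;
%
%     plot(D.f);                        % objective value per solver call
%
% Y may instead be given as an n-by-2 cell array of relevant/irrelevant
% index lists, as described above.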


TIME_START = tic();

% Default the slack trade-off parameter if it was not supplied;
% it must be set before it is copied into the global C.
if nargin < 3
    Cslack = 1;
end

global C;
C = Cslack;

[d,n,m] = size(X);

if m > 1
    MKL = 1;
else
    MKL = 0;
end

%%%
% Default options:

global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT STRUCTKERNEL DUALW THRESH INIT;
global RHO;


%%%
% Augmented Lagrangian factor
RHO = 1;


global FEASIBLE_COUNT;
FEASIBLE_COUNT = 0;

CP  = @cuttingPlaneFull;
SO  = @separationOracleAUC;
PSI = @metricPsiPO;

if ~MKL
    INIT         = @initializeFull;
    REG          = @regularizeTraceFull;
    STRUCTKERNEL = @structKernelLinear;
    DUALW        = @dualWLinear;
    FEASIBLE     = @feasibleFull;
    THRESH       = @threshFull_admmMixed;
    CPGRADIENT   = @cpGradientFull;
    DISTANCE     = @distanceFull;
    SETDISTANCE  = @setDistanceFull;
    LOSS         = @lossHinge;
    Regularizer  = 'Trace';
else
    INIT         = @initializeFullMKL;
    REG          = @regularizeMKLFull;
    STRUCTKERNEL = @structKernelMKL;
    DUALW        = @dualWMKL;
    FEASIBLE     = @feasibleFullMKL;
    THRESH       = @threshFull_admmMixed;
    CPGRADIENT   = @cpGradientFullMKL;
    DISTANCE     = @distanceFullMKL;
    SETDISTANCE  = @setDistanceFullMKL;
    LOSS         = @lossHingeFullMKL;
    Regularizer  = 'Trace';
end


Loss    = 'AUC';
Feature = 'metricPsiPO';


%%%
% Default k for Prec@k, NDCG
k = 3;

%%%
% Stochastic violator selection?
STOCHASTIC = 0;
batchSize  = n;
SAMPLES    = 1:n;


if nargin > 3
    switch lower(varargin{1})
        case {'auc'}
            SO      = @separationOracleAUC;
            PSI     = @metricPsiPO;
            Loss    = 'AUC';
            Feature = 'metricPsiPO';
        case {'knn'}
            SO      = @separationOracleKNN;
            PSI     = @metricPsiPO;
            Loss    = 'KNN';
            Feature = 'metricPsiPO';
        case {'prec@k'}
            SO      = @separationOraclePrecAtK;
            PSI     = @metricPsiPO;
            Loss    = 'Prec@k';
            Feature = 'metricPsiPO';
        case {'map'}
            SO      = @separationOracleMAP;
            PSI     = @metricPsiPO;
            Loss    = 'MAP';
            Feature = 'metricPsiPO';
        case {'mrr'}
            SO      = @separationOracleMRR;
            PSI     = @metricPsiPO;
            Loss    = 'MRR';
            Feature = 'metricPsiPO';
        case {'ndcg'}
            SO      = @separationOracleNDCG;
            PSI     = @metricPsiPO;
            Loss    = 'NDCG';
            Feature = 'metricPsiPO';
        otherwise
            error('MLR:LOSS', ...
                'Unknown loss function: %s', varargin{1});
    end
end

if nargin > 4
    k = varargin{2};
end

Diagonal = 0;
% The diagonal case is not implemented; use mlr_train for that.

if nargin > 5
    switch(varargin{3})
        case {0}
            REG         = @regularizeNone;
            Regularizer = 'None';
            THRESH      = @threshFull_admmMixed;
        case {1}
            if MKL
                REG          = @regularizeMKLFull;
                STRUCTKERNEL = @structKernelMKL;
                DUALW        = @dualWMKL;
            else
                REG          = @regularizeTraceFull;
                STRUCTKERNEL = @structKernelLinear;
                DUALW        = @dualWLinear;
            end
            Regularizer = 'Trace';

        case {2}
            error('MLR:REGULARIZER', '2-norm regularization no longer supported');

        case {3}
            if MKL
                REG          = @regularizeMKLFull;
                STRUCTKERNEL = @structKernelMKL;
                DUALW        = @dualWMKL;
            else
                REG          = @regularizeKernel;
                STRUCTKERNEL = @structKernelMKL;
                DUALW        = @dualWMKL;
            end
            Regularizer = 'Kernel';

        otherwise
            error('MLR:REGULARIZER', ...
                'Unknown regularization: %d', varargin{3});
    end
end


% Are we in stochastic optimization mode?
if nargin > 7 && varargin{5} > 0
    if varargin{5} < n
        STOCHASTIC = 1;
        CP         = @cuttingPlaneRandom;
        batchSize  = varargin{5};
    end
end
% Algorithm
%
%   Working <- []
%
%   repeat:
%       (W, Xi) <- solver(X, Y, C, Working)
%
%       for i = 1:|X|
%           y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
%
%       Working <- Working + (y^_1, y^_2, ..., y^_n)
%   until mean(Delta(y*_i, y^_i)) - mean(w' (Psi(x_i, y*_i) - Psi(x_i, y^_i)))
%         <= Xi + epsilon
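%
% In equation form (an editorial sketch under the default settings, i.e.
% trace regularization plus the ||W||_{2,1} term weighted by lam; not taken
% verbatim from the original authors), each solver call approximately solves
%
%   min_{W PSD, Xi >= 0}  trace(W) + lam * ||W||_{2,1} + C * Xi
%   s.t.  < W, mean_i( Psi(x_i, y*_i) - Psi(x_i, y^_i) ) >
%             >= mean_i( Delta(y*_i, y^_i) ) - Xi
%         for each batch (y^_1, ..., y^_n) in the working set.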

if nargin > 8
    lam = varargin{6};
else
    lam = 1;
end

disp(['lam = ' num2str(lam)]);

global DEBUG;

if isempty(DEBUG)
    DEBUG = 0;
end

% Force verbose diagnostic output, overriding any existing DEBUG setting.
DEBUG = 1;

%%%
% Bounds on the number of solver/separation-oracle calls
MAX_CALLS = 200;
MIN_CALLS = 10;

%%%
% Timer to eliminate old constraints
ConstraintClock = 500;  % standard: 500

if nargin > 9 && varargin{7} > 0
    ConstraintClock = varargin{7};
end

%%%
% Convergence criterion for the worst-violated constraint
E = 1e-3;
if nargin > 10 && varargin{8} > 0
    E = varargin{8};
end

% XXX: 2012-01-31 21:29:50 by Brian McFee <bmcfee@cs.ucsd.edu>
%   no longer belongs here
% Initialize
W = INIT(X);

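% A note on the ADMM state below (an editorial interpretation; only the
% variable names come from the original code): ADMM_Z and ADMM_V appear to
% be the two auxiliary copies of W used by the mixed-regularizer solver
% (rmlr_admm with threshFull_admmMixed), and ADMM_UW, ADMM_UV the
% corresponding scaled dual variables, initialized to zero as is standard
% for ADMM.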
global ADMM_Z ADMM_V ADMM_UW ADMM_UV;
ADMM_Z  = W;
ADMM_V  = W;
ADMM_UW = 0 * ADMM_Z;
ADMM_UV = 0 * ADMM_Z;

ClassScores = [];

if isa(Y, 'double')
    Ypos        = [];
    Yneg        = [];
    ClassScores = synthesizeRelevance(Y);

elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
    dbprint(2, 'Using supplied Ypos/Yneg');
    Ypos = Y(:,1);
    Yneg = Y(:,2);

    % Compute the valid samples
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
    dbprint(2, 'Using supplied Ypos/synthesized Yneg');
    Ypos    = Y(:,1);
    Yneg    = [];
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));
else
    error('MLR:LABELS', 'Incorrect format for Y.');
end

%%
% If we don't have enough data to make the batch, cut the batch
batchSize = min([batchSize, length(SAMPLES)]);

Diagnostics = struct( 'loss',               Loss, ...            % Which loss are we optimizing?
                      'feature',            Feature, ...         % Which ranking feature is used?
                      'k',                  k, ...               % What is the ranking length?
                      'regularizer',        Regularizer, ...     % What regularization is used?
                      'diagonal',           Diagonal, ...        % 0 for full metric, 1 for diagonal
                      'num_calls_SO',       0, ...               % Calls to separation oracle
                      'num_calls_solver',   0, ...               % Calls to solver
                      'time_SO',            0, ...               % Time in separation oracle
                      'time_solver',        0, ...               % Time in solver
                      'time_total',         0, ...               % Total time
                      'f',                  [], ...              % Objective value
                      'num_steps',          [], ...              % Number of steps for each solver run
                      'num_constraints',    [], ...              % Number of constraints for each run
                      'Xi',                 [], ...              % Slack achieved for each run
                      'Delta',              [], ...              % Mean loss for each SO call
                      'gap',                [], ...              % Gap between loss and slack
                      'C',                  C, ...               % Slack trade-off
                      'epsilon',            E, ...               % Convergence threshold
                      'feasible_count',     FEASIBLE_COUNT, ...  % Counter for # SVDs
                      'constraint_timer',   ConstraintClock);    % Time before evicting old constraints


global PsiR;
global PsiClock;

PsiR     = {};
PsiClock = [];

Xi      = -Inf;
Margins = [];
H       = [];
Q       = [];

if STOCHASTIC
    dbprint(2, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
end

dbprint(2, ['Regularizer is "' Regularizer '"']);
while 1
    if Diagnostics.num_calls_solver > MAX_CALLS
        dbprint(2, ['Solver calls >= ' num2str(MAX_CALLS)]);
        break;
    end

    dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);
    % Generate a constraint set
    Termination = -Inf;

    dbprint(2, 'Calling separation oracle...');
    [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);

    Termination              = LOSS(W, PsiNew, Mnew, 0);
    Diagnostics.num_calls_SO = Diagnostics.num_calls_SO + 1;
    Diagnostics.time_SO      = Diagnostics.time_SO + SO_time;

    Margins  = cat(1, Margins, Mnew);
    PsiR     = cat(1, PsiR, PsiNew);
    PsiClock = cat(1, PsiClock, 0);
    H        = expandKernel(H);
    Q        = expandRegularizer(Q, X, W);

    dbprint(2, '\n\t Active constraints   : %d',             length(PsiClock));
    dbprint(2, '\t Mean loss            : %0.4f',            Mnew);
    dbprint(2, '\t Current loss Xi      : %0.4f',            Xi);
    dbprint(2, '\t Termination - Xi < E : %0.4f <? %0.4f\n', Termination - Xi, E);

    Diagnostics.gap   = cat(1, Diagnostics.gap,   Termination - Xi);
    Diagnostics.Delta = cat(1, Diagnostics.Delta, Mnew);

    if Termination <= Xi + E && Diagnostics.num_calls_solver > MIN_CALLS
        dbprint(1, 'Done.');
        break;
    end

    dbprint(1, 'Calling solver...');
    PsiClock    = PsiClock + 1;
    Solver_time = tic();

    [W, Xi, Dsolver] = rmlr_admm(C, X, Margins, H, Q, lam);

    Diagnostics.time_solver      = Diagnostics.time_solver + toc(Solver_time);
    Diagnostics.num_calls_solver = Diagnostics.num_calls_solver + 1;

    Diagnostics.Xi        = cat(1, Diagnostics.Xi,        Xi);
    Diagnostics.f         = cat(1, Diagnostics.f,         Dsolver.f);
    Diagnostics.num_steps = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

    %%%
    % Cull the old constraints
    GC       = PsiClock < ConstraintClock;
    Margins  = Margins(GC);
    PsiR     = PsiR(GC);
    PsiClock = PsiClock(GC);
    H        = H(GC, GC);
    Q        = Q(GC);

    Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
end


% Finish diagnostics
Diagnostics.time_total     = toc(TIME_START);
Diagnostics.feasible_count = FEASIBLE_COUNT;
end

function H = expandKernel(H)

global STRUCTKERNEL;
global PsiR;

% Grow the constraint Gram matrix by one row/column for the newest
% constraint feature. Note: padarray requires the Image Processing Toolbox.
m = length(H);
H = padarray(H, [1 1], 0, 'post');

for i = 1:m+1
    H(i, m+1) = STRUCTKERNEL( PsiR{i}, PsiR{m+1} );
    H(m+1, i) = H(i, m+1);
end
end

function Q = expandRegularizer(Q, K, W)

% FIXME: 2012-01-31 21:34:15 by Brian McFee <bmcfee@cs.ucsd.edu>
%   does not support unregularized learning

global PsiR;
global STRUCTKERNEL REG;

m        = length(Q);
Q(m+1,1) = STRUCTKERNEL(REG(W, K, 1), PsiR{m+1});

end

function ClassScores = synthesizeRelevance(Y)

classes  = unique(Y);
nClasses = length(classes);

ClassScores = struct( 'Y',       Y, ...
                      'classes', classes, ...
                      'Ypos',    [], ...
                      'Yneg',    []);

Ypos = cell(nClasses, 1);
Yneg = cell(nClasses, 1);
for c = 1:nClasses
    Ypos{c} = (Y == classes(c));
    Yneg{c} = ~Ypos{c};
end

ClassScores.Ypos = Ypos;
ClassScores.Yneg = Yneg;

end