comparison toolboxes/distance_learning/mlr/mlr_train.m @ 0:cc4b1211e677 tip

initial commit to HG from Changeset: 646 (e263d8a21543) added further path and more save "camirversion.m"
author Daniel Wolff
date Fri, 19 Aug 2016 13:07:06 +0200
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:cc4b1211e677
function [W, Xi, Diagnostics] = mlr_train(X, Y, C, varargin)
% MLR_TRAIN   Metric Learning to Rank (structural SVM with a ranking loss).
%
% [W, Xi, D] = mlr_train(X, Y, C,...)
%
%       X = d*n data matrix
%       Y = either n-by-1 label vector
%           OR
%           n-by-2 cell array where
%               Y{q,1} contains relevant indices for q, and
%               Y{q,2} contains irrelevant indices for q
%
%       C >= 0 slack trade-off parameter (default=1)
%
%       W   = the learned metric
%       Xi  = slack value on the learned metric
%       D   = diagnostics
%
%   Optional arguments:
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:          no regularization
%           1:          1-norm: trace(W)            (default)
%           2:          2-norm: trace(W' * W)
%           3:          Kernel: trace(W * X), assumes X is square
%                       and positive-definite
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0:   learn a full d-by-d W       (default)
%       Diagonal = 1:   learn diagonally-constrained W (d-by-1)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC)
%       Set ConstraintClock to CC (default: 20)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E)
%       Set convergence tolerance to E (default: 1e-3)
%

    global globalvars;
    global DEBUG;

    % Pull the debug verbosity from the global configuration when available.
    if isfield(globalvars, 'debug')
        DEBUG = globalvars.debug;
    else
        DEBUG = 0;
    end

    TIME_START = tic();

    % addpath('cuttingPlane', 'distance', 'feasible', 'initialize', 'loss', ...
    %         'metricPsi', 'regularize', 'separationOracle', 'util');

    % X may carry multiple kernels along the third dimension (MKL mode).
    [d, n, m] = size(X);

    if m > 1
        MKL = 1;
    else
        MKL = 0;
    end

    if nargin < 3
        C = 1;
    end

    %%%
    % Default options:
    % (The sub-routines communicate through these globals.)

    global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT METRICK;

    CP  = @cuttingPlaneFull;
    SO  = @separationOracleAUC;
    PSI = @metricPsiPO;

    if ~MKL
        INIT        = @initializeFull;
        REG         = @regularizeTraceFull;
        FEASIBLE    = @feasibleFull;
        CPGRADIENT  = @cpGradientFull;
        DISTANCE    = @distanceFull;
        SETDISTANCE = @setDistanceFull;
        LOSS        = @lossHinge;
        Regularizer = 'Trace';
    else
        INIT        = @initializeFullMKL;
        REG         = @regularizeMKLFull;
        FEASIBLE    = @feasibleFullMKL;
        CPGRADIENT  = @cpGradientFullMKL;
        DISTANCE    = @distanceFullMKL;
        SETDISTANCE = @setDistanceFullMKL;
        LOSS        = @lossHingeFullMKL;
        Regularizer = 'Trace';
    end

    Loss    = 'AUC';
    Feature = 'metricPsiPO';

    %%%
    % Default k for prec@k, ndcg
    k = 3;

    %%%
    % Stochastic violator selection?
    STOCHASTIC = 0;
    batchSize  = n;
    SAMPLES    = 1:n;

    %%%
    % Optional argument 1: ranking loss
    if nargin > 3
        switch lower(varargin{1})
            case {'auc'}
                SO      = @separationOracleAUC;
                PSI     = @metricPsiPO;
                Loss    = 'AUC';
                Feature = 'metricPsiPO';
            case {'knn'}
                SO      = @separationOracleKNN;
                PSI     = @metricPsiPO;
                Loss    = 'KNN';
                Feature = 'metricPsiPO';
            case {'prec@k'}
                SO      = @separationOraclePrecAtK;
                PSI     = @metricPsiPO;
                Loss    = 'Prec@k';
                Feature = 'metricPsiPO';
            case {'map'}
                SO      = @separationOracleMAP;
                PSI     = @metricPsiPO;
                Loss    = 'MAP';
                Feature = 'metricPsiPO';
            case {'mrr'}
                SO      = @separationOracleMRR;
                PSI     = @metricPsiPO;
                Loss    = 'MRR';
                Feature = 'metricPsiPO';
            case {'ndcg'}
                SO      = @separationOracleNDCG;
                PSI     = @metricPsiPO;
                Loss    = 'NDCG';
                Feature = 'metricPsiPO';
            otherwise
                error('MLR:LOSS', ...
                      'Unknown loss function: %s', varargin{1});
        end
    end

    %%%
    % Optional argument 2: ranking length k
    if nargin > 4
        k = varargin{2};
    end

    METRICK = k;

    %%%
    % Optional argument 4: diagonal constraint on W.
    % NOTE: use short-circuit && so varargin{4} is only touched when present.
    Diagonal = 0;
    if nargin > 6 && varargin{4} > 0
        Diagonal = varargin{4};

        if ~MKL
            INIT        = @initializeDiag;
            REG         = @regularizeTraceDiag;
            FEASIBLE    = @feasibleDiag;
            CPGRADIENT  = @cpGradientDiag;
            DISTANCE    = @distanceDiag;
            SETDISTANCE = @setDistanceDiag;
            Regularizer = 'Trace';
        else
            if Diagonal > 1
                % Diagonal-off-diagonal (DOD) MKL variant
                INIT        = @initializeDODMKL;
                REG         = @regularizeMKLDOD;
                FEASIBLE    = @feasibleDODMKL;
                CPGRADIENT  = @cpGradientDODMKL;
                DISTANCE    = @distanceDODMKL;
                SETDISTANCE = @setDistanceDODMKL;
                LOSS        = @lossHingeDODMKL;
                Regularizer = 'Trace';
            else
                INIT        = @initializeDiagMKL;
                REG         = @regularizeMKLDiag;
                FEASIBLE    = @feasibleDiagMKL;
                CPGRADIENT  = @cpGradientDiagMKL;
                DISTANCE    = @distanceDiagMKL;
                SETDISTANCE = @setDistanceDiagMKL;
                LOSS        = @lossHingeDiagMKL;
                Regularizer = 'Trace';
            end
        end
    end

    %%%
    % Optional argument 3: regularizer selection.
    % Processed after the Diagonal flag because the choice depends on it.
    if nargin > 5
        switch(varargin{3})
            case {0}
                REG         = @regularizeNone;
                Regularizer = 'None';

            case {1}
                if MKL
                    if Diagonal == 0
                        REG = @regularizeMKLFull;
                    elseif Diagonal == 1
                        REG = @regularizeMKLDiag;
                    elseif Diagonal == 2
                        REG = @regularizeMKLDOD;
                    end
                else
                    if Diagonal
                        REG = @regularizeTraceDiag;
                    else
                        REG = @regularizeTraceFull;
                    end
                end
                Regularizer = 'Trace';

            case {2}
                if Diagonal
                    REG = @regularizeTwoDiag;
                else
                    REG = @regularizeTwoFull;
                end
                Regularizer = '2-norm';

            case {3}
                if MKL
                    if Diagonal == 0
                        REG = @regularizeMKLFull;
                    elseif Diagonal == 1
                        REG = @regularizeMKLDiag;
                    elseif Diagonal == 2
                        REG = @regularizeMKLDOD;
                    end
                else
                    if Diagonal
                        REG = @regularizeMKLDiag;
                    else
                        REG = @regularizeKernel;
                    end
                end
                Regularizer = 'Kernel';

            otherwise
                % REG codes are numeric, so format with %d (was %s).
                error('MLR:REGULARIZER', ...
                      'Unknown regularization: %d', varargin{3});
        end
    end

    % Are we in stochastic optimization mode?
    if nargin > 7 && varargin{5} > 0
        if varargin{5} < n
            STOCHASTIC = 1;
            CP         = @cuttingPlaneRandom;
            batchSize  = varargin{5};
        end
    end

    %%%
    % Timer to eliminate old constraints
    ConstraintClock = 20;

    if nargin > 8 && varargin{6} > 0
        ConstraintClock = varargin{6};
    end

    %%%
    % Convergence criteria for worst-violated constraint
    E = 1e-3;
    if nargin > 9 && varargin{7} > 0
        E = varargin{7};
    end

    % Algorithm
    %
    % Working <- []
    %
    % repeat:
    %   (W, Xi) <- solver(X, Y, C, Working)
    %
    %   for i = 1:|X|
    %       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
    %
    %   Working <- Working + (y^_1,y^_2,...,y^_n)
    % until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
    %           <= Xi + epsilon

    % Initialize
    W = INIT(X);

    ClassScores = [];

    % Interpret the supervision Y: plain labels, pos/neg index pairs,
    % or positives-only with synthesized negatives.
    if isa(Y, 'double')
        Ypos        = [];
        Yneg        = [];
        ClassScores = synthesizeRelevance(Y);

    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
        dbprint(2, 'Using supplied Ypos/Yneg');
        Ypos = Y(:,1);
        Yneg = Y(:,2);

        % Compute the valid samples
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
        dbprint(2, 'Using supplied Ypos/synthesized Yneg');
        Ypos    = Y(:,1);
        Yneg    = [];
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));
    else
        error('MLR:LABELS', 'Incorrect format for Y.');
    end

    Diagnostics = struct( 'loss',               Loss, ...           % Which loss are we optimizing?
                          'feature',            Feature, ...        % Which ranking feature is used?
                          'k',                  k, ...              % What is the ranking length?
                          'regularizer',        Regularizer, ...    % What regularization is used?
                          'diagonal',           Diagonal, ...       % 0 for full metric, 1 for diagonal
                          'num_calls_SO',       0, ...              % Calls to separation oracle
                          'num_calls_solver',   0, ...              % Calls to solver
                          'time_SO',            0, ...              % Time in separation oracle
                          'time_solver',        0, ...              % Time in solver
                          'time_total',         0, ...              % Total time
                          'f',                  [], ...             % Objective value
                          'num_steps',          [], ...             % Number of steps for each solver run
                          'num_constraints',    [], ...             % Number of constraints for each run
                          'Xi',                 [], ...             % Slack achieved for each run
                          'Delta',              [], ...             % Mean loss for each SO call
                          'gap',                [], ...             % Gap between loss and slack
                          'C',                  C, ...              % Slack trade-off
                          'epsilon',            E, ...              % Convergence threshold
                          'constraint_timer',   ConstraintClock);   % Time before evicting old constraints

    % Working set of constraints (shared with the cutting-plane routines).
    global PsiR;
    global PsiClock;

    PsiR     = {};
    PsiClock = [];

    Xi      = -Inf;
    Margins = [];

    if STOCHASTIC
        dbprint(2, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
    end

    % Cutting-plane main loop: alternate between finding the most violated
    % constraint (separation oracle) and re-solving the QP over the
    % accumulated working set, until the violation is within E of the slack.
    while 1
        dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);
        % Generate a constraint set
        Termination = 0;

        dbprint(2, 'Calling separation oracle...');

        [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
        Termination = LOSS(W, PsiNew, Mnew, 0);

        Diagnostics.num_calls_SO = Diagnostics.num_calls_SO + 1;
        Diagnostics.time_SO      = Diagnostics.time_SO + SO_time;

        Margins  = cat(1, Margins, Mnew);
        PsiR     = cat(1, PsiR, PsiNew);
        PsiClock = cat(1, PsiClock, 0);

        dbprint(2, '\n\tActive constraints    : %d',            length(PsiClock));
        dbprint(2, '\t           Mean loss  : %0.4f',           Mnew);
        dbprint(2, '\t  Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);

        Diagnostics.gap   = cat(1, Diagnostics.gap,   Termination - Xi);
        Diagnostics.Delta = cat(1, Diagnostics.Delta, Mnew);

        if Termination <= Xi + E
            dbprint(2, 'Done.');
            break;
        end

        dbprint(2, 'Calling solver...');
        PsiClock    = PsiClock + 1;
        Solver_time = tic();
        [W, Xi, Dsolver] = mlr_solver(C, Margins, W, X);
        Diagnostics.time_solver      = Diagnostics.time_solver + toc(Solver_time);
        Diagnostics.num_calls_solver = Diagnostics.num_calls_solver + 1;

        Diagnostics.Xi        = cat(1, Diagnostics.Xi, Xi);
        Diagnostics.f         = cat(1, Diagnostics.f,  Dsolver.f);
        Diagnostics.num_steps = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

        %%%
        % Cull the old constraints: drop any that have not been active
        % for ConstraintClock rounds.
        GC       = PsiClock < ConstraintClock;
        Margins  = Margins(GC);
        PsiR     = PsiR(GC);
        PsiClock = PsiClock(GC);

        Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
    end

    % Finish diagnostics

    Diagnostics.time_total = toc(TIME_START);
end
426
427
function ClassScores = synthesizeRelevance(Y)
% Build per-class relevance masks from a plain label vector Y.
%
% For each distinct label, every sample carrying that label is marked
% relevant (Ypos) and every other sample irrelevant (Yneg).

    labelSet  = unique(Y);
    numLabels = length(labelSet);

    % Masks are computed first, then attached to the struct afterwards:
    % passing cell arrays directly to struct() would create a struct array.
    posMask = cell(numLabels, 1);
    negMask = cell(numLabels, 1);
    for idx = 1:numLabels
        posMask{idx} = (Y == labelSet(idx));
        negMask{idx} = ~posMask{idx};
    end

    ClassScores = struct( 'Y',       Y, ...
                          'classes', labelSet, ...
                          'Ypos',    [], ...
                          'Yneg',    []);

    ClassScores.Ypos = posMask;
    ClassScores.Yneg = negMask;

end