Mercurial > hg > camir-ismir2012
comparison toolboxes/distance_learning/mlr/mlr_train.m @ 0:cc4b1211e677 tip
initial commit to HG from
Changeset:
646 (e263d8a21543) added further path and more save "camirversion.m"
author | Daniel Wolff |
---|---|
date | Fri, 19 Aug 2016 13:07:06 +0200 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:cc4b1211e677 |
---|---|
function [W, Xi, Diagnostics] = mlr_train(X, Y, C, varargin)
%
% [W, Xi, D] = mlr_train(X, Y, C,...)
%
%     X = d*n data matrix
%     Y = either n-by-1 label of vectors
%         OR
%         n-by-2 cell array where
%             Y{q,1} contains relevant indices for q, and
%             Y{q,2} contains irrelevant indices for q
%
%     C >= 0 slack trade-off parameter (default=1)
%
%     W   = the learned metric
%     Xi  = slack value on the learned metric
%     D   = diagnostics
%
%   Optional arguments:
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:          no regularization
%           1:          1-norm: trace(W)            (default)
%           2:          2-norm: trace(W' * W)
%           3:          Kernel: trace(W * X), assumes X is square and positive-definite
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0:   learn a full d-by-d W       (default)
%       Diagonal = 1:   learn diagonally-constrained W (d-by-1)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC)
%       Set ConstraintClock to CC (default: 20)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E)
%       Set convergence threshold to E (default: 1e-3)
%


global globalvars;
global DEBUG;

% Pick up the debug verbosity from the global configuration, if present.
if isfield(globalvars, 'debug')

    DEBUG = globalvars.debug;
else

    DEBUG = 0;
end



TIME_START = tic();

% addpath('cuttingPlane', 'distance', 'feasible', 'initialize', 'loss', ...
%     'metricPsi', 'regularize', 'separationOracle', 'util');

[d,n,m] = size(X);

% A third dimension on X (a stack of kernel matrices) selects
% multiple-kernel-learning (MKL) mode.
if m > 1
    MKL = 1;
else
    MKL = 0;
end

if nargin < 3
    C = 1;
end

%%%
% Default options:
% These globals are the dispatch points read by the rest of the MLR
% toolbox (cutting plane, separation oracle, regularizer, etc.).

global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT METRICK;

CP = @cuttingPlaneFull;
SO = @separationOracleAUC;
PSI = @metricPsiPO;

if ~MKL
    INIT = @initializeFull;
    REG = @regularizeTraceFull;
    FEASIBLE = @feasibleFull;
    CPGRADIENT = @cpGradientFull;
    DISTANCE = @distanceFull;
    SETDISTANCE = @setDistanceFull;
    LOSS = @lossHinge;
    Regularizer = 'Trace';
else
    INIT = @initializeFullMKL;
    REG = @regularizeMKLFull;
    FEASIBLE = @feasibleFullMKL;
    CPGRADIENT = @cpGradientFullMKL;
    DISTANCE = @distanceFullMKL;
    SETDISTANCE = @setDistanceFullMKL;
    LOSS = @lossHingeFullMKL;
    Regularizer = 'Trace';
end


Loss = 'AUC';
Feature = 'metricPsiPO';


%%%
% Default k for prec@k, ndcg
k = 3;

%%%
% Stochastic violator selection?
STOCHASTIC = 0;
batchSize = n;
SAMPLES = 1:n;


% Argument 4: ranking loss to optimize.
if nargin > 3
    switch lower(varargin{1})
        case {'auc'}
            SO = @separationOracleAUC;
            PSI = @metricPsiPO;
            Loss = 'AUC';
            Feature = 'metricPsiPO';
        case {'knn'}
            SO = @separationOracleKNN;
            PSI = @metricPsiPO;
            Loss = 'KNN';
            Feature = 'metricPsiPO';
        case {'prec@k'}
            SO = @separationOraclePrecAtK;
            PSI = @metricPsiPO;
            Loss = 'Prec@k';
            Feature = 'metricPsiPO';
        case {'map'}
            SO = @separationOracleMAP;
            PSI = @metricPsiPO;
            Loss = 'MAP';
            Feature = 'metricPsiPO';
        case {'mrr'}
            SO = @separationOracleMRR;
            PSI = @metricPsiPO;
            Loss = 'MRR';
            Feature = 'metricPsiPO';
        case {'ndcg'}
            SO = @separationOracleNDCG;
            PSI = @metricPsiPO;
            Loss = 'NDCG';
            Feature = 'metricPsiPO';
        otherwise
            error('MLR:LOSS', ...
                'Unknown loss function: %s', varargin{1});
    end
end

% Argument 5: ranking length k for Prec@k / NDCG.
if nargin > 4
    k = varargin{2};
end

METRICK = k;

% Argument 7: diagonal constraint.  This is resolved BEFORE the
% regularizer (argument 6) because the regularizer choice below
% depends on the value of Diagonal.
Diagonal = 0;
% NOTE: must be the short-circuit '&&' here -- element-wise '&' would
% evaluate varargin{4} even when fewer than 7 arguments were passed,
% raising an index-out-of-bounds error.
if nargin > 6 && varargin{4} > 0
    Diagonal = varargin{4};

    if ~MKL
        INIT = @initializeDiag;
        REG = @regularizeTraceDiag;
        FEASIBLE = @feasibleDiag;
        CPGRADIENT = @cpGradientDiag;
        DISTANCE = @distanceDiag;
        SETDISTANCE = @setDistanceDiag;
        Regularizer = 'Trace';
    else
        % Diagonal > 1 selects the diagonal-off-diagonal (DOD) MKL variant.
        if Diagonal > 1
            INIT = @initializeDODMKL;
            REG = @regularizeMKLDOD;
            FEASIBLE = @feasibleDODMKL;
            CPGRADIENT = @cpGradientDODMKL;
            DISTANCE = @distanceDODMKL;
            SETDISTANCE = @setDistanceDODMKL;
            LOSS = @lossHingeDODMKL;
            Regularizer = 'Trace';
        else
            INIT = @initializeDiagMKL;
            REG = @regularizeMKLDiag;
            FEASIBLE = @feasibleDiagMKL;
            CPGRADIENT = @cpGradientDiagMKL;
            DISTANCE = @distanceDiagMKL;
            SETDISTANCE = @setDistanceDiagMKL;
            LOSS = @lossHingeDiagMKL;
            Regularizer = 'Trace';
        end
    end
end

% Argument 6: regularizer selection (uses Diagonal, set above).
if nargin > 5
    switch(varargin{3})
        case {0}
            REG = @regularizeNone;
            Regularizer = 'None';

        case {1}
            if MKL
                if Diagonal == 0
                    REG = @regularizeMKLFull;
                elseif Diagonal == 1
                    REG = @regularizeMKLDiag;
                elseif Diagonal == 2
                    REG = @regularizeMKLDOD;
                end
            else
                if Diagonal
                    REG = @regularizeTraceDiag;
                else
                    REG = @regularizeTraceFull;
                end
            end
            Regularizer = 'Trace';

        case {2}
            if Diagonal
                REG = @regularizeTwoDiag;
            else
                REG = @regularizeTwoFull;
            end
            Regularizer = '2-norm';

        case {3}
            if MKL
                if Diagonal == 0
                    REG = @regularizeMKLFull;
                elseif Diagonal == 1
                    REG = @regularizeMKLDiag;
                elseif Diagonal == 2
                    REG = @regularizeMKLDOD;
                end
            else
                if Diagonal
                    REG = @regularizeMKLDiag;
                else
                    REG = @regularizeKernel;
                end
            end
            Regularizer = 'Kernel';

        otherwise
            error('MLR:REGULARIZER', ...
                'Unknown regularization: %s', varargin{3});
    end
end


% Are we in stochastic optimization mode?
% (Argument 8: batch size B; only meaningful when B < n.)
if nargin > 7 && varargin{5} > 0
    if varargin{5} < n
        STOCHASTIC = 1;
        CP = @cuttingPlaneRandom;
        batchSize = varargin{5};
    end
end

%%%
% Timer to eliminate old constraints
ConstraintClock = 20;

if nargin > 8 && varargin{6} > 0
    ConstraintClock = varargin{6};
end

%%%
% Convergence criteria for worst-violated constraint
E = 1e-3;
if nargin > 9 && varargin{7} > 0
    E = varargin{7};
end

% Algorithm
%
% Working <- []
%
% repeat:
%   (W, Xi) <- solver(X, Y, C, Working)
%
%   for i = 1:|X|
%       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
%
%   Working <- Working + (y^_1,y^_2,...,y^_n)
% until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
%           <= Xi + epsilon




% Initialize
W = INIT(X);

ClassScores = [];

if isa(Y, 'double')
    % Plain label vector: synthesize relevant/irrelevant sets per class.
    Ypos = [];
    Yneg = [];
    ClassScores = synthesizeRelevance(Y);

elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
    dbprint(2, 'Using supplied Ypos/Yneg');
    Ypos = Y(:,1);
    Yneg = Y(:,2);

    % Compute the valid samples
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
    dbprint(2, 'Using supplied Ypos/synthesized Yneg');
    Ypos = Y(:,1);
    Yneg = [];
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));
else
    error('MLR:LABELS', 'Incorrect format for Y.');
end


Diagnostics = struct(   'loss',                 Loss, ...           % Which loss are we optimizing?
                        'feature',              Feature, ...        % Which ranking feature is used?
                        'k',                    k, ...              % What is the ranking length?
                        'regularizer',          Regularizer, ...    % What regularization is used?
                        'diagonal',             Diagonal, ...       % 0 for full metric, 1 for diagonal
                        'num_calls_SO',         0, ...              % Calls to separation oracle
                        'num_calls_solver',     0, ...              % Calls to solver
                        'time_SO',              0, ...              % Time in separation oracle
                        'time_solver',          0, ...              % Time in solver
                        'time_total',           0, ...              % Total time
                        'f',                    [], ...             % Objective value
                        'num_steps',            [], ...             % Number of steps for each solver run
                        'num_constraints',      [], ...             % Number of constraints for each run
                        'Xi',                   [], ...             % Slack achieved for each run
                        'Delta',                [], ...             % Mean loss for each SO call
                        'gap',                  [], ...             % Gap between loss and slack
                        'C',                    C, ...              % Slack trade-off
                        'epsilon',              E, ...              % Convergence threshold
                        'constraint_timer',     ConstraintClock);   % Time before evicting old constraints



global PsiR;
global PsiClock;

PsiR = {};
PsiClock = [];

Xi = -Inf;
Margins = [];

if STOCHASTIC
    dbprint(2, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
end

% Main cutting-plane loop: alternate between finding the most-violated
% constraint (separation oracle) and re-solving for W, until the
% worst violation is within E of the current slack.
while 1
    dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);
    % Generate a constraint set
    Termination = 0;


    dbprint(2, 'Calling separation oracle...');

    [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
    Termination = LOSS(W, PsiNew, Mnew, 0);

    Diagnostics.num_calls_SO = Diagnostics.num_calls_SO + 1;
    Diagnostics.time_SO = Diagnostics.time_SO + SO_time;

    Margins = cat(1, Margins, Mnew);
    PsiR = cat(1, PsiR, PsiNew);
    PsiClock = cat(1, PsiClock, 0);

    dbprint(2, '\n\tActive constraints    : %d', length(PsiClock));
    dbprint(2, '\t           Mean loss  : %0.4f', Mnew);
    dbprint(2, '\t  Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);

    Diagnostics.gap = cat(1, Diagnostics.gap, Termination - Xi);
    Diagnostics.Delta = cat(1, Diagnostics.Delta, Mnew);

    if Termination <= Xi + E
        dbprint(2, 'Done.');
        break;
    end

    dbprint(2, 'Calling solver...');
    PsiClock = PsiClock + 1;
    Solver_time = tic();
    [W, Xi, Dsolver] = mlr_solver(C, Margins, W, X);
    Diagnostics.time_solver = Diagnostics.time_solver + toc(Solver_time);
    Diagnostics.num_calls_solver = Diagnostics.num_calls_solver + 1;

    Diagnostics.Xi = cat(1, Diagnostics.Xi, Xi);
    Diagnostics.f = cat(1, Diagnostics.f, Dsolver.f);
    Diagnostics.num_steps = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

    %%%
    % Cull the old constraints
    GC = PsiClock < ConstraintClock;
    Margins = Margins(GC);
    PsiR = PsiR(GC);
    PsiClock = PsiClock(GC);

    Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
end


% Finish diagnostics

Diagnostics.time_total = toc(TIME_START);
end
426 | |
427 | |
function ClassScores = synthesizeRelevance(Y)
% synthesizeRelevance  Derive per-class relevance masks from a label vector.
%
%   ClassScores = synthesizeRelevance(Y)
%
%   Y           n-by-1 vector of class labels
%
%   ClassScores struct with fields:
%       Y       the original label vector
%       classes the distinct labels, as returned by unique(Y)
%       Ypos    cell array: Ypos{c} is a logical mask of the points
%               belonging to classes(c)
%       Yneg    cell array: Yneg{c} is the complement of Ypos{c}

    labelSet  = unique(Y);
    numLabels = length(labelSet);

    posMask = cell(numLabels, 1);
    negMask = cell(numLabels, 1);
    for idx = 1:numLabels
        posMask{idx} = (Y == labelSet(idx));
        negMask{idx} = ~posMask{idx};
    end

    % Build the struct with empty placeholders first, then fill in the
    % cell arrays: passing cells directly to struct() would expand the
    % result into a struct array rather than a scalar struct.
    ClassScores = struct(   'Y',        Y, ...
                            'classes',  labelSet, ...
                            'Ypos',     [], ...
                            'Yneg',     []);

    ClassScores.Ypos = posMask;
    ClassScores.Yneg = negMask;

end