% toolboxes/distance_learning/mlr/rmlr_train.m
function [W, Xi, Diagnostics] = rmlr_train(X, Y, Cslack, varargin)
%[W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda, CC, E)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, ...)
%
%       X = d*n data matrix
%       Y = either n-by-1 vector of labels
%               OR
%           n-by-2 cell array where
%               Y{q,1} contains relevant indices for q, and
%               Y{q,2} contains irrelevant indices for q
%
%       C >= 0 slack trade-off parameter (default=1)
%
%       W  = the learned metric
%       Xi = slack value on the learned metric
%       D  = diagnostics
%
%   Optional arguments:
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:  no regularization
%           1:  1-norm: trace(W)        (default)
%           2:  2-norm: trace(W' * W)
%           3:  Kernel: trace(W * X), assumes X is square and positive-definite
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Not implemented: learning a diagonal W reduces to standard MLR,
%       since ||W||_2,1 = trace(W) when W is diagonal.
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda)
%       where lambda is the coefficient of the ||W||_2,1 regularization term
%       (default: 1)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda, CC)
%       Set the constraint clock to CC (default: 500)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda, CC, E)
%       Set the convergence threshold to E (default: 1e-3)
%
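% Example (an illustrative sketch, not from the original documentation;
% assumes the MLR toolbox is on the path):
%
%   X = randn(10, 200);                       % 10-d features, 200 points
%   Y = randi(4, 200, 1);                     % synthetic labels, 4 classes
%   [W, Xi, D] = rmlr_train(X, Y, 1, 'AUC');  % C = 1, AUC loss
%   % squared learned distance between points 1 and 2:
%   d12 = (X(:,1) - X(:,2))' * W * (X(:,1) - X(:,2));
%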


TIME_START = tic();

global C;
C = Cslack;

[d,n,m] = size(X);

if m > 1
    MKL = 1;
else
    MKL = 0;
end

if nargin < 3
    C = 1;
end

%%%
% Default options:

global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT STRUCTKERNEL DUALW THRESH INIT;
global RHO;


%%%
% Augmented Lagrangian penalty factor
RHO = 1;


global FEASIBLE_COUNT;
FEASIBLE_COUNT = 0;

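% Function handles used by the structural solver: CP generates cutting
% planes, SO is the separation oracle for the chosen ranking loss, and PSI
% is the joint feature map.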
CP  = @cuttingPlaneFull;
SO  = @separationOracleAUC;
PSI = @metricPsiPO;

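%%%
% Select implementations for the single-kernel or MKL case: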
if ~MKL
    INIT        = @initializeFull;
    REG         = @regularizeTraceFull;
    STRUCTKERNEL= @structKernelLinear;
    DUALW       = @dualWLinear;
    FEASIBLE    = @feasibleFull;
    THRESH      = @threshFull_admmMixed;
    CPGRADIENT  = @cpGradientFull;
    DISTANCE    = @distanceFull;
    SETDISTANCE = @setDistanceFull;
    LOSS        = @lossHinge;
    Regularizer = 'Trace';
else
    INIT        = @initializeFullMKL;
    REG         = @regularizeMKLFull;
    STRUCTKERNEL= @structKernelMKL;
    DUALW       = @dualWMKL;
    FEASIBLE    = @feasibleFullMKL;
    THRESH      = @threshFull_admmMixed;
    CPGRADIENT  = @cpGradientFullMKL;
    DISTANCE    = @distanceFullMKL;
    SETDISTANCE = @setDistanceFullMKL;
    LOSS        = @lossHingeFullMKL;
    Regularizer = 'Trace';
end


Loss    = 'AUC';
Feature = 'metricPsiPO';


%%%
% Default k for Prec@k, NDCG
k = 3;

%%%
% Stochastic violator selection?
STOCHASTIC = 0;
batchSize  = n;
SAMPLES    = 1:n;


if nargin > 3
    switch lower(varargin{1})
        case {'auc'}
            SO      = @separationOracleAUC;
            PSI     = @metricPsiPO;
            Loss    = 'AUC';
            Feature = 'metricPsiPO';
        case {'knn'}
            SO      = @separationOracleKNN;
            PSI     = @metricPsiPO;
            Loss    = 'KNN';
            Feature = 'metricPsiPO';
        case {'prec@k'}
            SO      = @separationOraclePrecAtK;
            PSI     = @metricPsiPO;
            Loss    = 'Prec@k';
            Feature = 'metricPsiPO';
        case {'map'}
            SO      = @separationOracleMAP;
            PSI     = @metricPsiPO;
            Loss    = 'MAP';
            Feature = 'metricPsiPO';
        case {'mrr'}
            SO      = @separationOracleMRR;
            PSI     = @metricPsiPO;
            Loss    = 'MRR';
            Feature = 'metricPsiPO';
        case {'ndcg'}
            SO      = @separationOracleNDCG;
            PSI     = @metricPsiPO;
            Loss    = 'NDCG';
            Feature = 'metricPsiPO';
        otherwise
            error('MLR:LOSS', ...
                'Unknown loss function: %s', varargin{1});
    end
end
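% All losses share the partial-order feature map (metricPsiPO); only the
% separation oracle differs per loss.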

if nargin > 4
    k = varargin{2};
end

Diagonal = 0;
% The diagonal case is not implemented; use mlr_train for that.

if nargin > 5
    switch(varargin{3})
        case {0}
            REG         = @regularizeNone;
            Regularizer = 'None';
            THRESH      = @threshFull_admmMixed;
        case {1}
            if MKL
                REG         = @regularizeMKLFull;
                STRUCTKERNEL= @structKernelMKL;
                DUALW       = @dualWMKL;
            else
                REG         = @regularizeTraceFull;
                STRUCTKERNEL= @structKernelLinear;
                DUALW       = @dualWLinear;
            end
            Regularizer = 'Trace';

        case {2}
            REG         = @regularizeTwoFull;
            Regularizer = '2-norm';
            error('MLR:REGULARIZER', '2-norm regularization no longer supported');


        case {3}
            if MKL
                REG         = @regularizeMKLFull;
                STRUCTKERNEL= @structKernelMKL;
                DUALW       = @dualWMKL;
            else
                REG         = @regularizeKernel;
                STRUCTKERNEL= @structKernelMKL;
                DUALW       = @dualWMKL;
            end
            Regularizer = 'Kernel';


        otherwise
            error('MLR:REGULARIZER', ...
                'Unknown regularization: %d', varargin{3});
    end
end


% Are we in stochastic optimization mode?
if nargin > 7 && varargin{5} > 0
    if varargin{5} < n
        STOCHASTIC = 1;
        CP         = @cuttingPlaneRandom;
        batchSize  = varargin{5};
    end
end
% Algorithm
%
%   Working <- []
%
%   repeat:
%       (W, Xi) <- solver(X, Y, C, Working)
%
%       for i = 1:|X|
%           y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
%
%       Working <- Working + (y^_1, y^_2, ..., y^_n)
%   until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
%         <= Xi + epsilon

if nargin > 8
    lam = varargin{6};
else
    lam = 1;
end

disp(['lam = ' num2str(lam)])

global DEBUG;

if isempty(DEBUG)
    DEBUG = 0;
end

% Note: debug output is currently forced on, overriding the check above.
DEBUG = 1;

%%%
% Max and min calls to the separation oracle
MAX_CALLS = 200;
MIN_CALLS = 10;

%%%
% Timer to eliminate old constraints
ConstraintClock = 500;  % standard: 500

if nargin > 9 && varargin{7} > 0
    ConstraintClock = varargin{7};
end

%%%
% Convergence criterion for the worst-violated constraint
E = 1e-3;
if nargin > 10 && varargin{8} > 0
    E = varargin{8};
end

% XXX: 2012-01-31 21:29:50 by Brian McFee <bmcfee@cs.ucsd.edu>
%   no longer belongs here
% Initialize
W = INIT(X);


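% ADMM splitting variables: Z and V start as copies of W, with UW and UV
% as the corresponding dual variables, initialized to zero.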
global ADMM_Z ADMM_V ADMM_UW ADMM_UV;
ADMM_Z  = W;
ADMM_V  = W;
ADMM_UW = 0 * ADMM_Z;
ADMM_UV = 0 * ADMM_Z;

ClassScores = [];

if isa(Y, 'double')
    Ypos        = [];
    Yneg        = [];
    ClassScores = synthesizeRelevance(Y);

elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
    dbprint(2, 'Using supplied Ypos/Yneg');
    Ypos = Y(:,1);
    Yneg = Y(:,2);

    % Compute the valid samples
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
    dbprint(2, 'Using supplied Ypos/synthesized Yneg');
    Ypos = Y(:,1);
    Yneg = [];
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));
else
    error('MLR:LABELS', 'Incorrect format for Y.');
end
%%%
% If we don't have enough data to make the batch, cut the batch
batchSize = min([batchSize, length(SAMPLES)]);

Diagnostics = struct(   'loss',             Loss, ...           % Which loss are we optimizing?
                        'feature',          Feature, ...        % Which ranking feature is used?
                        'k',                k, ...              % What is the ranking length?
                        'regularizer',      Regularizer, ...    % What regularization is used?
                        'diagonal',         Diagonal, ...       % 0 for full metric, 1 for diagonal
                        'num_calls_SO',     0, ...              % Calls to separation oracle
                        'num_calls_solver', 0, ...              % Calls to solver
                        'time_SO',          0, ...              % Time in separation oracle
                        'time_solver',      0, ...              % Time in solver
                        'time_total',       0, ...              % Total time
                        'f',                [], ...             % Objective value
                        'num_steps',        [], ...             % Number of steps for each solver run
                        'num_constraints',  [], ...             % Number of constraints for each run
                        'Xi',               [], ...             % Slack achieved for each run
                        'Delta',            [], ...             % Mean loss for each SO call
                        'gap',              [], ...             % Gap between loss and slack
                        'C',                C, ...              % Slack trade-off
                        'epsilon',          E, ...              % Convergence threshold
                        'feasible_count',   FEASIBLE_COUNT, ... % Counter for # SVDs
                        'constraint_timer', ConstraintClock);   % Time before evicting old constraints



global PsiR;
global PsiClock;

PsiR     = {};
PsiClock = [];

Xi      = -Inf;
Margins = [];
H       = [];
Q       = [];

if STOCHASTIC
    dbprint(2, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
end

dbprint(2, ['Regularizer is "' Regularizer '"']);
while 1
    if Diagnostics.num_calls_solver > MAX_CALLS
        dbprint(2, ['Calls to solver > ' num2str(MAX_CALLS)]);
        break;
    end

    dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);
    % Generate a constraint set
    Termination = -Inf;


    dbprint(2, 'Calling separation oracle...');
    [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);

    Termination              = LOSS(W, PsiNew, Mnew, 0);
    Diagnostics.num_calls_SO = Diagnostics.num_calls_SO + 1;
    Diagnostics.time_SO      = Diagnostics.time_SO + SO_time;

    Margins  = cat(1, Margins, Mnew);
    PsiR     = cat(1, PsiR, PsiNew);
    PsiClock = cat(1, PsiClock, 0);
    H        = expandKernel(H);
    Q        = expandRegularizer(Q, X, W);


    dbprint(2, '\n\tActive constraints    : %d',    length(PsiClock));
    dbprint(2, '\t Mean loss            : %0.4f',   Mnew);
    dbprint(2, '\t Current slack Xi     : %0.4f',   Xi);
    dbprint(2, '\t Termination - Xi < E : %0.4f <? %0.4f\n', Termination - Xi, E);

    Diagnostics.gap   = cat(1, Diagnostics.gap,   Termination - Xi);
    Diagnostics.Delta = cat(1, Diagnostics.Delta, Mnew);

    %if Termination <= Xi + E
    if Termination <= Xi + E && Diagnostics.num_calls_solver > MIN_CALLS
        %if Termination - Xi <= E
        dbprint(1, 'Done.');
        break;
    end



    dbprint(1, 'Calling solver...');
    PsiClock    = PsiClock + 1;
    Solver_time = tic();
    % disp('Robust MLR')
    [W, Xi, Dsolver] = rmlr_admm(C, X, Margins, H, Q, lam);

    Diagnostics.time_solver      = Diagnostics.time_solver + toc(Solver_time);
    Diagnostics.num_calls_solver = Diagnostics.num_calls_solver + 1;

    Diagnostics.Xi        = cat(1, Diagnostics.Xi,        Xi);
    Diagnostics.f         = cat(1, Diagnostics.f,         Dsolver.f);
    Diagnostics.num_steps = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

    %%%
    % Cull the old constraints
    GC       = PsiClock < ConstraintClock;
    Margins  = Margins(GC);
    PsiR     = PsiR(GC);
    PsiClock = PsiClock(GC);
    H        = H(GC, GC);
    Q        = Q(GC);

    Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
end


% Finish diagnostics

Diagnostics.time_total     = toc(TIME_START);
Diagnostics.feasible_count = FEASIBLE_COUNT;
end

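%%%
% Grow the constraint Gram matrix H by one row/column of structural-kernel
% inner products between the newest constraint feature and all existing ones.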
function H = expandKernel(H)

global STRUCTKERNEL;
global PsiR;

m = length(H);
% padarray (Image Processing Toolbox) appends a zero row and column
H = padarray(H, [1 1], 0, 'post');


for i = 1:m+1
    H(i, m+1) = STRUCTKERNEL( PsiR{i}, PsiR{m+1} );
    H(m+1, i) = H(i, m+1);
end
end

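%%%
% Append to Q the structural-kernel inner product between the regularizer
% term REG(W,K,1) and the newest constraint feature.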
function Q = expandRegularizer(Q, K, W)

% FIXME: 2012-01-31 21:34:15 by Brian McFee <bmcfee@cs.ucsd.edu>
%   does not support unregularized learning

global PsiR;
global STRUCTKERNEL REG;

m = length(Q);
Q(m+1,1) = STRUCTKERNEL(REG(W,K,1), PsiR{m+1});

end

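%%%
% Build per-class relevance sets from a label vector: points sharing a
% class label are relevant (Ypos); all others are irrelevant (Yneg).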
function ClassScores = synthesizeRelevance(Y)

classes  = unique(Y);
nClasses = length(classes);

ClassScores = struct(   'Y',       Y, ...
                        'classes', classes, ...
                        'Ypos',    [], ...
                        'Yneg',    []);

Ypos = cell(nClasses, 1);
Yneg = cell(nClasses, 1);
for c = 1:nClasses
    Ypos{c} = (Y == classes(c));
    Yneg{c} = ~Ypos{c};
end

ClassScores.Ypos = Ypos;
ClassScores.Yneg = Yneg;

end