toolboxes/distance_learning/mlr/mlr_train_primal.m @ 0:e9a9cd732c1e (tip)
author: wolffd
date:   Tue, 10 Feb 2015 15:05:51 +0000
commit: first hg version after svn
function [W, Xi, Diagnostics] = mlr_train(X, Y, Cslack, varargin)
%
%   [W, Xi, D] = mlr_train(X, Y, C,...)
%
%       X = d*n data matrix
%       Y = either n-by-1 vector of labels
%               OR
%           n-by-2 cell array where
%               Y{q,1} contains relevant indices for q, and
%               Y{q,2} contains irrelevant indices for q
%
%       C >= 0 slack trade-off parameter (default=1)
%
%       W   = the learned metric
%       Xi  = slack value on the learned metric
%       D   = diagnostics
%
%   Optional arguments:
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:  no regularization
%           1:  1-norm: trace(W)        (default)
%           2:  2-norm: trace(W' * W)
%           3:  Kernel: trace(W * X), assumes X is square and positive-definite
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0:   learn a full d-by-d W        (default)
%       Diagonal = 1:   learn diagonally-constrained W (d-by-1)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%

TIME_START = tic();

global C;
C = Cslack;

[d,n,m] = size(X);

if m > 1
    MKL = 1;
else
    MKL = 0;
end

if nargin < 3
    C = 1;
end

%%%
% Default options:

global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT;
global FEASIBLE_COUNT;
FEASIBLE_COUNT = 0;

CP          = @cuttingPlaneFull;
SO          = @separationOracleAUC;
PSI         = @metricPsiPO;

if ~MKL
    INIT        = @initializeFull;
    REG         = @regularizeTraceFull;
    FEASIBLE    = @feasibleFull;
    CPGRADIENT  = @cpGradientFull;
    DISTANCE    = @distanceFull;
    SETDISTANCE = @setDistanceFull;
    LOSS        = @lossHinge;
    Regularizer = 'Trace';
else
    INIT        = @initializeFullMKL;
    REG         = @regularizeMKLFull;
    FEASIBLE    = @feasibleFullMKL;
    CPGRADIENT  = @cpGradientFullMKL;
    DISTANCE    = @distanceFullMKL;
    SETDISTANCE = @setDistanceFullMKL;
    LOSS        = @lossHingeFullMKL;
    Regularizer = 'Trace';
end
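
% Note: the globals above act as a dispatch table. Each algorithmic step
% (cutting plane, separation oracle, feature map, regularizer, projection,
% distance) is swapped out below according to the loss, regularization,
% Diagonal, and MKL settings, so the main loop stays generic.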


Loss        = 'AUC';
Feature     = 'metricPsiPO';


%%%
% Default k for prec@k, ndcg
k           = 3;

%%%
% Stochastic violator selection?
STOCHASTIC  = 0;
batchSize   = n;
SAMPLES     = 1:n;


if nargin > 3
    switch lower(varargin{1})
        case {'auc'}
            SO      = @separationOracleAUC;
            PSI     = @metricPsiPO;
            Loss    = 'AUC';
            Feature = 'metricPsiPO';
        case {'knn'}
            SO      = @separationOracleKNN;
            PSI     = @metricPsiPO;
            Loss    = 'KNN';
            Feature = 'metricPsiPO';
        case {'prec@k'}
            SO      = @separationOraclePrecAtK;
            PSI     = @metricPsiPO;
            Loss    = 'Prec@k';
            Feature = 'metricPsiPO';
        case {'map'}
            SO      = @separationOracleMAP;
            PSI     = @metricPsiPO;
            Loss    = 'MAP';
            Feature = 'metricPsiPO';
        case {'mrr'}
            SO      = @separationOracleMRR;
            PSI     = @metricPsiPO;
            Loss    = 'MRR';
            Feature = 'metricPsiPO';
        case {'ndcg'}
            SO      = @separationOracleNDCG;
            PSI     = @metricPsiPO;
            Loss    = 'NDCG';
            Feature = 'metricPsiPO';
        otherwise
            error('MLR:LOSS', ...
                'Unknown loss function: %s', varargin{1});
    end
end

if nargin > 4
    k = varargin{2};
end

Diagonal = 0;
if nargin > 6 && varargin{4} > 0
    Diagonal = varargin{4};

    if ~MKL
        INIT        = @initializeDiag;
        REG         = @regularizeTraceDiag;
        FEASIBLE    = @feasibleDiag;
        CPGRADIENT  = @cpGradientDiag;
        DISTANCE    = @distanceDiag;
        SETDISTANCE = @setDistanceDiag;
        Regularizer = 'Trace';
    else
        INIT        = @initializeDiagMKL;
        REG         = @regularizeMKLDiag;
        FEASIBLE    = @feasibleDiagMKL;
        CPGRADIENT  = @cpGradientDiagMKL;
        DISTANCE    = @distanceDiagMKL;
        SETDISTANCE = @setDistanceDiagMKL;
        LOSS        = @lossHingeDiagMKL;
        Regularizer = 'Trace';
    end
end

if nargin > 5
    switch(varargin{3})
        case {0}
            REG         = @regularizeNone;
            Regularizer = 'None';

        case {1}
            if MKL
                if Diagonal == 0
                    REG = @regularizeMKLFull;
                elseif Diagonal == 1
                    REG = @regularizeMKLDiag;
                end
            else
                if Diagonal
                    REG = @regularizeTraceDiag;
                else
                    REG = @regularizeTraceFull;
                end
            end
            Regularizer = 'Trace';

        case {2}
            if Diagonal
                REG = @regularizeTwoDiag;
            else
                REG = @regularizeTwoFull;
            end
            Regularizer = '2-norm';

        case {3}
            if MKL
                if Diagonal == 0
                    REG = @regularizeMKLFull;
                elseif Diagonal == 1
                    REG = @regularizeMKLDiag;
                end
            else
                if Diagonal
                    REG = @regularizeMKLDiag;
                else
                    REG = @regularizeKernel;
                end
            end
            Regularizer = 'Kernel';

        otherwise
            error('MLR:REGULARIZER', ...
                'Unknown regularization: %d', varargin{3});
    end
end


% Are we in stochastic optimization mode?
if nargin > 7 && varargin{5} > 0
    if varargin{5} < n
        STOCHASTIC  = 1;
        CP          = @cuttingPlaneRandom;
        batchSize   = varargin{5};
    end
end
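
% Note: only 0 < B < n actually enables stochastic constraint generation;
% a batch size of n or larger keeps the default full cutting plane.
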
% Algorithm
%
%   Working <- []
%
%   repeat:
%       (W, Xi) <- solver(X, Y, C, Working)
%
%       for i = 1:|X|
%           y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
%
%       Working <- Working + (y^_1,y^_2,...,y^_n)
%   until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
%           <= Xi + epsilon
252
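% In the loop below, CP performs the inner argmax (one most-violated
% ranking per sample), Mnew and PsiNew are the batch means of Delta and of
% the Psi differences, LOSS evaluates the left-hand side of the stopping
% test, and the loop exits once it drops to within E of the slack Xi.
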
global DEBUG;

if isempty(DEBUG)
    DEBUG = 0;
end

%%%
% Timer to eliminate old constraints
ConstraintClock = 100;

%%%
% Convergence criteria for worst-violated constraint
E = 1e-3;

% Initialize
W = INIT(X);

ClassScores = [];

if isa(Y, 'double')
    Ypos        = [];
    Yneg        = [];
    ClassScores = synthesizeRelevance(Y);

elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
    dbprint(1, 'Using supplied Ypos/Yneg');
    Ypos = Y(:,1);
    Yneg = Y(:,2);

    % Compute the valid samples
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
    dbprint(1, 'Using supplied Ypos/synthesized Yneg');
    Ypos    = Y(:,1);
    Yneg    = [];
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));
else
    error('MLR:LABELS', 'Incorrect format for Y.');
end
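
% Example of the cell label format (hypothetical indices): for a query q
% whose relevant items are 2 and 7 and whose irrelevant items are 3:5,
%   Y{q,1} = [2 7];     Y{q,2} = 3:5;
% Queries with an empty Y{q,1} or Y{q,2} are excluded from SAMPLES above.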

%%
% If we don't have enough data to make the batch, cut the batch
batchSize = min([batchSize, length(SAMPLES)]);


Diagnostics = struct(   'loss',             Loss, ...           % Which loss are we optimizing?
                        'feature',          Feature, ...        % Which ranking feature is used?
                        'k',                k, ...              % What is the ranking length?
                        'regularizer',      Regularizer, ...    % What regularization is used?
                        'diagonal',         Diagonal, ...       % 0 for full metric, 1 for diagonal
                        'num_calls_SO',     0, ...              % Calls to separation oracle
                        'num_calls_solver', 0, ...              % Calls to solver
                        'time_SO',          0, ...              % Time in separation oracle
                        'time_solver',      0, ...              % Time in solver
                        'time_total',       0, ...              % Total time
                        'f',                [], ...             % Objective value
                        'num_steps',        [], ...             % Number of steps for each solver run
                        'num_constraints',  [], ...             % Number of constraints for each run
                        'Xi',               [], ...             % Slack achieved for each run
                        'Delta',            [], ...             % Mean loss for each SO call
                        'gap',              [], ...             % Gap between loss and slack
                        'C',                C, ...              % Slack trade-off
                        'epsilon',          E, ...              % Convergence threshold
                        'feasible_count',   0, ...              % Counter for projections
                        'constraint_timer', ConstraintClock);   % Time before evicting old constraints



global PsiR;
global PsiClock;

PsiR        = {};
PsiClock    = [];

Xi          = -Inf;
Margins     = [];

if STOCHASTIC
    dbprint(1, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
end

while 1
    dbprint(1, 'Round %03d', Diagnostics.num_calls_solver);
    % Generate a constraint set
    Termination = 0;


    dbprint(2, 'Calling separation oracle...');

    [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
    Termination             = LOSS(W, PsiNew, Mnew, 0);

    Diagnostics.num_calls_SO    = Diagnostics.num_calls_SO + 1;
    Diagnostics.time_SO         = Diagnostics.time_SO + SO_time;

    Margins     = cat(1, Margins,   Mnew);
    PsiR        = cat(1, PsiR,      PsiNew);
    PsiClock    = cat(1, PsiClock,  0);

    dbprint(2, '\n\tActive constraints    : %d',             length(PsiClock));
    dbprint(2, '\t            Mean loss : %0.4f',            Mnew);
    dbprint(2, '\t  Termination -Xi < E : %0.4f <? %0.4f\n', Termination - Xi, E);

    Diagnostics.gap     = cat(1, Diagnostics.gap,   Termination - Xi);
    Diagnostics.Delta   = cat(1, Diagnostics.Delta, Mnew);

    if Termination <= Xi + E
        dbprint(1, 'Done.');
        break;
    end

    dbprint(1, 'Calling solver...');
    PsiClock                        = PsiClock + 1;
    Solver_time                     = tic();
    [W, Xi, Dsolver]                = mlr_solver(C, Margins, W, X);
    Diagnostics.time_solver         = Diagnostics.time_solver + toc(Solver_time);
    Diagnostics.num_calls_solver    = Diagnostics.num_calls_solver + 1;

    Diagnostics.Xi          = cat(1, Diagnostics.Xi,        Xi);
    Diagnostics.f           = cat(1, Diagnostics.f,         Dsolver.f);
    Diagnostics.num_steps   = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

    %%%
    % Cull the old constraints
    GC          = PsiClock < ConstraintClock;
    Margins     = Margins(GC);
    PsiR        = PsiR(GC);
    PsiClock    = PsiClock(GC);

    Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
end


% Finish diagnostics

Diagnostics.time_total      = toc(TIME_START);
Diagnostics.feasible_count  = FEASIBLE_COUNT;
end


function ClassScores = synthesizeRelevance(Y)

    classes     = unique(Y);
    nClasses    = length(classes);

    ClassScores = struct(   'Y',        Y, ...
                            'classes',  classes, ...
                            'Ypos',     [], ...
                            'Yneg',     []);

    Ypos = cell(nClasses, 1);
    Yneg = cell(nClasses, 1);
    for c = 1:nClasses
        Ypos{c} = (Y == classes(c));
        Yneg{c} = ~Ypos{c};
    end

    ClassScores.Ypos = Ypos;
    ClassScores.Yneg = Yneg;

end
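
% Worked example (illustrative): for Y = [1; 2; 1], classes = [1; 2], so
% Ypos{1} = [true; false; true] marks the class-1 members and Yneg{1} is
% its complement; likewise for class 2. Each point's relevant set is its
% own class, and everything else is treated as irrelevant.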