Mercurial > hg > camir-aes2014
comparison toolboxes/distance_learning/mlr/mlr_train_primal.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [W, Xi, Diagnostics] = mlr_train(X, Y, Cslack, varargin)
%
% [W, Xi, D] = mlr_train(X, Y, C,...)
%
%   Learns a metric W with a cutting-plane loop: repeatedly call the
%   separation oracle for the most-violated ranking constraints, add them
%   to the working set, and re-solve, until the new constraint is violated
%   by no more than epsilon beyond the current slack.
%
% X = d*n data matrix
%     If X is d*n*m with m > 1, the *MKL variants of every internal
%     routine (initialization, regularizer, distance, loss, ...) are used.
%
% Y = either n-by-1 vector of labels
%     OR
%     n-by-2 cell array where
%         Y{q,1} contains relevant indices for q, and
%         Y{q,2} contains irrelevant indices for q
%     OR
%     n-by-1 cell array of relevant indices (irrelevant indices are
%     synthesized; Yneg is passed on empty)
%     For cell-array Y, queries with an empty entry are excluded from
%     training.
%
% C >= 0 slack trade-off parameter (default=1; [] also selects the default)
%
% W  = the learned metric
% Xi = slack value on the learned metric
% D  = diagnostics struct (timings, call counts, per-round objective,
%      slack, loss and gap)
%
% Optional arguments:
%
% [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
% [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
% [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:      no regularization
%           1:      1-norm: trace(W)            (default)
%           2:      2-norm: trace(W' * W)
%           3:      Kernel: trace(W * X), assumes X is square and positive-definite
%
% [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0:   learn a full d-by-d W       (default)
%       Diagonal = 1:   learn diagonally-constrained W (d-by-1)
%
% [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%       (ignored when B >= n)
%
% Side effects:
%   Sets the globals C, CP, SO, PSI, REG, FEASIBLE, LOSS, DISTANCE,
%   SETDISTANCE, CPGRADIENT, FEASIBLE_COUNT, PsiR and PsiClock, and sets
%   DEBUG to 0 if it is empty. These globals are presumably how the helper
%   routines receive their configuration -- confirm before refactoring.
%

    TIME_START = tic();

    %%%
    % Slack trade-off. The default must be applied BEFORE Cslack is read:
    % with fewer than three arguments Cslack does not exist at all.
    if nargin < 3 || isempty(Cslack)
        Cslack = 1;
    end

    global C;
    C = Cslack;

    % A third dimension > 1 selects the MKL code path. (d is not used.)
    [d, n, m] = size(X);
    MKL = m > 1;

    %%%
    % Default options:

    global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT;
    global FEASIBLE_COUNT;
    FEASIBLE_COUNT = 0;

    CP  = @cuttingPlaneFull;
    SO  = @separationOracleAUC;
    PSI = @metricPsiPO;

    if ~MKL
        INIT        = @initializeFull;
        REG         = @regularizeTraceFull;
        FEASIBLE    = @feasibleFull;
        CPGRADIENT  = @cpGradientFull;
        DISTANCE    = @distanceFull;
        SETDISTANCE = @setDistanceFull;
        LOSS        = @lossHinge;
        Regularizer = 'Trace';
    else
        INIT        = @initializeFullMKL;
        REG         = @regularizeMKLFull;
        FEASIBLE    = @feasibleFullMKL;
        CPGRADIENT  = @cpGradientFullMKL;
        DISTANCE    = @distanceFullMKL;
        SETDISTANCE = @setDistanceFullMKL;
        LOSS        = @lossHingeFullMKL;
        Regularizer = 'Trace';
    end

    Loss    = 'AUC';
    Feature = 'metricPsiPO';

    %%%
    % Default k for prec@k, ndcg
    k = 3;

    %%%
    % Stochastic violator selection?
    STOCHASTIC = 0;
    batchSize  = n;
    SAMPLES    = 1:n;

    %%%
    % Loss function / separation oracle
    if nargin > 3
        switch lower(varargin{1})
            case {'auc'}
                SO      = @separationOracleAUC;
                PSI     = @metricPsiPO;
                Loss    = 'AUC';
                Feature = 'metricPsiPO';
            case {'knn'}
                SO      = @separationOracleKNN;
                PSI     = @metricPsiPO;
                Loss    = 'KNN';
                Feature = 'metricPsiPO';
            case {'prec@k'}
                SO      = @separationOraclePrecAtK;
                PSI     = @metricPsiPO;
                Loss    = 'Prec@k';
                Feature = 'metricPsiPO';
            case {'map'}
                SO      = @separationOracleMAP;
                PSI     = @metricPsiPO;
                Loss    = 'MAP';
                Feature = 'metricPsiPO';
            case {'mrr'}
                SO      = @separationOracleMRR;
                PSI     = @metricPsiPO;
                Loss    = 'MRR';
                Feature = 'metricPsiPO';
            case {'ndcg'}
                SO      = @separationOracleNDCG;
                PSI     = @metricPsiPO;
                Loss    = 'NDCG';
                Feature = 'metricPsiPO';
            otherwise
                error('MLR:LOSS', ...
                    'Unknown loss function: %s', varargin{1});
        end
    end

    if nargin > 4
        k = varargin{2};
    end

    %%%
    % Diagonal constraint. Must be resolved before the regularizer switch
    % below, which consults Diagonal.
    Diagonal = 0;
    if nargin > 6 && varargin{4} > 0
        Diagonal = varargin{4};

        if ~MKL
            INIT        = @initializeDiag;
            REG         = @regularizeTraceDiag;
            FEASIBLE    = @feasibleDiag;
            CPGRADIENT  = @cpGradientDiag;
            DISTANCE    = @distanceDiag;
            SETDISTANCE = @setDistanceDiag;
            Regularizer = 'Trace';
        else
            INIT        = @initializeDiagMKL;
            REG         = @regularizeMKLDiag;
            FEASIBLE    = @feasibleDiagMKL;
            CPGRADIENT  = @cpGradientDiagMKL;
            DISTANCE    = @distanceDiagMKL;
            SETDISTANCE = @setDistanceDiagMKL;
            LOSS        = @lossHingeDiagMKL;
            Regularizer = 'Trace';
        end
    end

    %%%
    % Regularizer
    if nargin > 5
        switch(varargin{3})
            case {0}
                REG         = @regularizeNone;
                Regularizer = 'None';

            case {1}
                if MKL
                    if Diagonal == 0
                        REG = @regularizeMKLFull;
                    elseif Diagonal == 1
                        REG = @regularizeMKLDiag;
                    end
                else
                    if Diagonal
                        REG = @regularizeTraceDiag;
                    else
                        REG = @regularizeTraceFull;
                    end
                end
                Regularizer = 'Trace';

            case {2}
                if Diagonal
                    REG = @regularizeTwoDiag;
                else
                    REG = @regularizeTwoFull;
                end
                Regularizer = '2-norm';

            case {3}
                if MKL
                    if Diagonal == 0
                        REG = @regularizeMKLFull;
                    elseif Diagonal == 1
                        REG = @regularizeMKLDiag;
                    end
                else
                    if Diagonal
                        REG = @regularizeMKLDiag;
                    else
                        REG = @regularizeKernel;
                    end
                end
                Regularizer = 'Kernel';

            otherwise
                % REG is normally numeric; '%s' on a raw number would
                % print a control character, so stringify it first.
                error('MLR:REGULARIZER', ...
                    'Unknown regularization: %s', num2str(varargin{3}));
        end
    end

    %%%
    % Are we in stochastic optimization mode?
    % A batch size of n or more is equivalent to full (non-stochastic) mode.
    if nargin > 7 && varargin{5} > 0
        if varargin{5} < n
            STOCHASTIC = 1;
            CP         = @cuttingPlaneRandom;
            batchSize  = varargin{5};
        end
    end

    % Algorithm
    %
    %   Working <- []
    %
    %   repeat:
    %       (W, Xi) <- solver(X, Y, C, Working)
    %
    %       for i = 1:|X|
    %           y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
    %
    %       Working <- Working + (y^_1,y^_2,...,y^_n)
    %   until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
    %           <= Xi + epsilon

    global DEBUG;

    if isempty(DEBUG)
        DEBUG = 0;
    end

    %%%
    % Timer to eliminate old constraints: a constraint is dropped once it
    % has survived this many solver rounds.
    ConstraintClock = 100;

    %%%
    % Convergence criteria for worst-violated constraint
    E = 1e-3;

    % Initialize
    W = INIT(X);

    ClassScores = [];

    if isa(Y, 'double')
        % Label vector: relevance sets are derived per class.
        Ypos        = [];
        Yneg        = [];
        ClassScores = synthesizeRelevance(Y);

    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
        dbprint(1, 'Using supplied Ypos/Yneg');
        Ypos = Y(:,1);
        Yneg = Y(:,2);

        % Compute the valid samples
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));

    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
        dbprint(1, 'Using supplied Ypos/synthesized Yneg');
        Ypos    = Y(:,1);
        Yneg    = [];
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));

    else
        error('MLR:LABELS', 'Incorrect format for Y.');
    end

    %%
    % If we don't have enough data to make the batch, cut the batch
    batchSize = min([batchSize, length(SAMPLES)]);

    Diagnostics = struct(   'loss',             Loss, ...           % Which loss are we optimizing?
                            'feature',          Feature, ...        % Which ranking feature is used?
                            'k',                k, ...              % What is the ranking length?
                            'regularizer',      Regularizer, ...    % What regularization is used?
                            'diagonal',         Diagonal, ...       % 0 for full metric, 1 for diagonal
                            'num_calls_SO',     0, ...              % Calls to separation oracle
                            'num_calls_solver', 0, ...              % Calls to solver
                            'time_SO',          0, ...              % Time in separation oracle
                            'time_solver',      0, ...              % Time in solver
                            'time_total',       0, ...              % Total time
                            'f',                [], ...             % Objective value
                            'num_steps',        [], ...             % Number of steps for each solver run
                            'num_constraints',  [], ...             % Number of constraints for each run
                            'Xi',               [], ...             % Slack achieved for each run
                            'Delta',            [], ...             % Mean loss for each SO call
                            'gap',              [], ...             % Gap between loss and slack
                            'C',                C, ...              % Slack trade-off
                            'epsilon',          E, ...              % Convergence threshold
                            'feasible_count',   0, ...              % Counter for projections
                            'constraint_timer', ConstraintClock);   % Time before evicting old constraints

    % Working set of constraints and their ages (in solver rounds).
    global PsiR;
    global PsiClock;

    PsiR     = {};
    PsiClock = [];

    Xi      = -Inf;
    Margins = [];

    if STOCHASTIC
        dbprint(1, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
    end

    while 1
        dbprint(1, 'Round %03d', Diagnostics.num_calls_solver);

        % Generate a constraint set
        dbprint(2, 'Calling separation oracle...');

        [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
        Termination             = LOSS(W, PsiNew, Mnew, 0);

        Diagnostics.num_calls_SO = Diagnostics.num_calls_SO + 1;
        Diagnostics.time_SO      = Diagnostics.time_SO + SO_time;

        Margins  = cat(1, Margins,  Mnew);
        PsiR     = cat(1, PsiR,     PsiNew);
        PsiClock = cat(1, PsiClock, 0);

        dbprint(2, '\n\tActive constraints    : %d',          length(PsiClock));
        dbprint(2, '\t           Mean loss  : %0.4f',        Mnew);
        dbprint(2, '\t  Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);

        Diagnostics.gap   = cat(1, Diagnostics.gap,   Termination - Xi);
        Diagnostics.Delta = cat(1, Diagnostics.Delta, Mnew);

        % Converged: the newest constraint is violated by at most E beyond
        % the current slack.
        if Termination <= Xi + E
            dbprint(1, 'Done.');
            break;
        end

        dbprint(1, 'Calling solver...');
        PsiClock    = PsiClock + 1;
        Solver_time = tic();
        [W, Xi, Dsolver] = mlr_solver(C, Margins, W, X);
        Diagnostics.time_solver      = Diagnostics.time_solver + toc(Solver_time);
        Diagnostics.num_calls_solver = Diagnostics.num_calls_solver + 1;

        Diagnostics.Xi        = cat(1, Diagnostics.Xi,        Xi);
        Diagnostics.f         = cat(1, Diagnostics.f,         Dsolver.f);
        Diagnostics.num_steps = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

        %%%
        % Cull the old constraints
        GC       = PsiClock < ConstraintClock;
        Margins  = Margins(GC);
        PsiR     = PsiR(GC);
        PsiClock = PsiClock(GC);

        Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
    end

    % Finish diagnostics
    Diagnostics.time_total     = toc(TIME_START);
    Diagnostics.feasible_count = FEASIBLE_COUNT;
end
391 | |
392 | |
function ClassScores = synthesizeRelevance(Y)
% SYNTHESIZERELEVANCE  Build one-vs-rest relevance masks from a label vector.
%
%   ClassScores = synthesizeRelevance(Y)
%
%   Y           - vector of class labels
%
%   ClassScores - scalar struct with fields
%       .Y        the input labels, unchanged
%       .classes  the distinct labels, as returned by unique(Y)
%       .Ypos     nClasses-by-1 cell; Ypos{c} is the logical mask Y == classes(c)
%       .Yneg     nClasses-by-1 cell; Yneg{c} is the complement of Ypos{c}

    labelSet  = unique(Y);
    numLabels = length(labelSet);

    positives = cell(numLabels, 1);
    negatives = cell(numLabels, 1);

    for idx = 1:numLabels
        isMember       = (Y == labelSet(idx));
        positives{idx} = isMember;
        negatives{idx} = ~isMember;
    end

    % Wrap the cell arrays in {} so struct() stores them as single field
    % values instead of expanding into a struct array.
    ClassScores = struct('Y',       Y, ...
                         'classes', labelSet, ...
                         'Ypos',    {positives}, ...
                         'Yneg',    {negatives});
end