function [W, Xi, Diagnostics] = mlr_train(X, Y, Cslack, varargin)
%
%   [W, Xi, D] = mlr_train(X, Y, C, ...)
%
%       X = d*n data matrix
%       Y = either an n-by-1 vector of labels
%           OR
%           an n-by-2 cell array where
%               Y{q,1} contains relevant indices for q, and
%               Y{q,2} contains irrelevant indices for q
%
%       C >= 0  slack trade-off parameter (default=1)
%
%       W   = the learned metric
%       Xi  = slack value on the learned metric
%       D   = diagnostics
%
%   Optional arguments:
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:  no regularization
%           1:  1-norm: trace(W)            (default)
%           2:  2-norm: trace(W' * W)
%           3:  Kernel: trace(W * X), assumes X is square and positive-definite
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0:   learn a full d-by-d W           (default)
%       Diagonal = 1:   learn a diagonally-constrained W (d-by-1)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%
%   // added by Daniel Wolff
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC)
%       Set ConstraintClock to CC, the number of rounds before old
%       constraints are evicted (default: 100)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E)
%       Set the convergence threshold E (default: 1e-3)
%
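%   Example (an illustrative sketch, not part of the original interface notes):
%   with X a d-by-n feature matrix and Ylab an n-by-1 label vector,
%
%       [W, Xi, D] = mlr_train(X, Ylab, 1,  'map');         % MAP loss,  C = 1
%       [W, Xi, D] = mlr_train(X, Ylab, 10, 'prec@k', 5);   % Prec@5,    C = 10
%
%   For a full metric W, the learned distance between columns i and j is
%       dij = (X(:,i) - X(:,j))' * W * (X(:,i) - X(:,j));
%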

TIME_START = tic();

global C;
C = Cslack;

[d,n,m] = size(X);

if m > 1
    MKL = 1;
else
    MKL = 0;
end

if nargin < 3
    C = 1;
end

%%%
% Default options:

global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT STRUCTKERNEL DUALW INIT;

global FEASIBLE_COUNT;
FEASIBLE_COUNT = 0;

CP          = @cuttingPlaneFull;
SO          = @separationOracleAUC;
PSI         = @metricPsiPO;

if ~MKL
    INIT        = @initializeFull;
    REG         = @regularizeTraceFull;
    STRUCTKERNEL= @structKernelLinear;
    DUALW       = @dualWLinear;
    FEASIBLE    = @feasibleFull;
    CPGRADIENT  = @cpGradientFull;
    DISTANCE    = @distanceFull;
    SETDISTANCE = @setDistanceFull;
    LOSS        = @lossHinge;
    Regularizer = 'Trace';
else
    INIT        = @initializeFullMKL;
    REG         = @regularizeMKLFull;
    STRUCTKERNEL= @structKernelMKL;
    DUALW       = @dualWMKL;
    FEASIBLE    = @feasibleFullMKL;
    CPGRADIENT  = @cpGradientFullMKL;
    DISTANCE    = @distanceFullMKL;
    SETDISTANCE = @setDistanceFullMKL;
    LOSS        = @lossHingeFullMKL;
    Regularizer = 'Trace';
end


Loss        = 'AUC';
Feature     = 'metricPsiPO';


%%%
% Default k for prec@k, ndcg
k           = 3;

%%%
% Stochastic violator selection?
STOCHASTIC  = 0;
batchSize   = n;
SAMPLES     = 1:n;

if nargin > 3
    switch lower(varargin{1})
        case {'auc'}
            SO          = @separationOracleAUC;
            PSI         = @metricPsiPO;
            Loss        = 'AUC';
            Feature     = 'metricPsiPO';
        case {'knn'}
            SO          = @separationOracleKNN;
            PSI         = @metricPsiPO;
            Loss        = 'KNN';
            Feature     = 'metricPsiPO';
        case {'prec@k'}
            SO          = @separationOraclePrecAtK;
            PSI         = @metricPsiPO;
            Loss        = 'Prec@k';
            Feature     = 'metricPsiPO';
        case {'map'}
            SO          = @separationOracleMAP;
            PSI         = @metricPsiPO;
            Loss        = 'MAP';
            Feature     = 'metricPsiPO';
        case {'mrr'}
            SO          = @separationOracleMRR;
            PSI         = @metricPsiPO;
            Loss        = 'MRR';
            Feature     = 'metricPsiPO';
        case {'ndcg'}
            SO          = @separationOracleNDCG;
            PSI         = @metricPsiPO;
            Loss        = 'NDCG';
            Feature     = 'metricPsiPO';
        otherwise
            error('MLR:LOSS', ...
                'Unknown loss function: %s', varargin{1});
    end
end

if nargin > 4
    k = varargin{2};
end

Diagonal = 0;
if nargin > 6 && varargin{4} > 0
    Diagonal = varargin{4};

    if ~MKL
        INIT        = @initializeDiag;
        REG         = @regularizeTraceDiag;
        STRUCTKERNEL= @structKernelDiag;
        DUALW       = @dualWDiag;
        FEASIBLE    = @feasibleDiag;
        CPGRADIENT  = @cpGradientDiag;
        DISTANCE    = @distanceDiag;
        SETDISTANCE = @setDistanceDiag;
        Regularizer = 'Trace';
    else
        INIT        = @initializeDiagMKL;
        REG         = @regularizeMKLDiag;
        STRUCTKERNEL= @structKernelDiagMKL;
        DUALW       = @dualWDiagMKL;
        FEASIBLE    = @feasibleDiagMKL;
        CPGRADIENT  = @cpGradientDiagMKL;
        DISTANCE    = @distanceDiagMKL;
        SETDISTANCE = @setDistanceDiagMKL;
        LOSS        = @lossHingeDiagMKL;
        Regularizer = 'Trace';
    end
end

if nargin > 5
    switch(varargin{3})
        case {0}
            REG         = @regularizeNone;
            Regularizer = 'None';

        case {1}
            if MKL
                if Diagonal == 0
                    REG         = @regularizeMKLFull;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                elseif Diagonal == 1
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                end
            else
                if Diagonal
                    REG         = @regularizeTraceDiag;
                    STRUCTKERNEL= @structKernelDiag;
                    DUALW       = @dualWDiag;
                else
                    REG         = @regularizeTraceFull;
                    STRUCTKERNEL= @structKernelLinear;
                    DUALW       = @dualWLinear;
                end
            end
            Regularizer = 'Trace';

        case {2}
            if Diagonal
                REG         = @regularizeTwoDiag;
            else
                REG         = @regularizeTwoFull;
            end
            Regularizer = '2-norm';
            error('MLR:REGULARIZER', '2-norm regularization no longer supported');


        case {3}
            if MKL
                if Diagonal == 0
                    REG         = @regularizeMKLFull;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                elseif Diagonal == 1
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                end
            else
                if Diagonal
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                else
                    REG         = @regularizeKernel;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                end
            end
            Regularizer = 'Kernel';

        otherwise
            error('MLR:REGULARIZER', ...
                'Unknown regularization: %s', varargin{3});
    end
end


% Are we in stochastic optimization mode?
if nargin > 7 && varargin{5} > 0
    if varargin{5} < n
        STOCHASTIC  = 1;
        CP          = @cuttingPlaneRandom;
        batchSize   = varargin{5};
    end
end

% Algorithm
%
% Working <- []
%
% repeat:
%   (W, Xi) <- solver(X, Y, C, Working)
%
%   for i = 1:|X|
%       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
%
%   Working <- Working + (y^_1, y^_2, ..., y^_n)
% until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
%           <= Xi + epsilon

global DEBUG;

if isempty(DEBUG)
    DEBUG = 0;
end

%%%
% Timer to eliminate old constraints
ConstraintClock = 100;      % standard: 100

if nargin > 8 && varargin{6} > 0
    ConstraintClock = varargin{6};
end

%%%
% Convergence criteria for worst-violated constraint
E = 1e-3;
if nargin > 9 && varargin{7} > 0
    E = varargin{7};
end


%XXX:   2012-01-31 21:29:50 by Brian McFee <bmcfee@cs.ucsd.edu>
%   no longer belongs here
% Initialize
W = INIT(X);


global ADMM_Z ADMM_U RHO;
ADMM_Z      = W;
ADMM_U      = 0 * ADMM_Z;

%%%
% Augmented lagrangian factor
RHO         = 1;

ClassScores = [];

if isa(Y, 'double')
    Ypos        = [];
    Yneg        = [];
    ClassScores = synthesizeRelevance(Y);

elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
    dbprint(2, 'Using supplied Ypos/Yneg');
    Ypos        = Y(:,1);
    Yneg        = Y(:,2);

    % Compute the valid samples
    SAMPLES     = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
    dbprint(2, 'Using supplied Ypos/synthesized Yneg');
    Ypos        = Y(:,1);
    Yneg        = [];
    SAMPLES     = find( ~(cellfun(@isempty, Y(:,1))));
else
    error('MLR:LABELS', 'Incorrect format for Y.');
end
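% (Added note, illustrative only) The n-by-2 relevance format described in the
% header can be built by hand, e.g.:
%
%     Ycell = cell(n, 2);
%     Ycell{q,1} = [3 7 12];      % indices relevant to query q
%     Ycell{q,2} = [1 5 9 20];    % indices irrelevant to query q
%
% Queries with an empty relevant or irrelevant set are excluded from SAMPLES above.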

%%
% If we don't have enough data to make the batch, cut the batch
batchSize = min([batchSize, length(SAMPLES)]);


Diagnostics = struct(   'loss',                 Loss, ...           % Which loss are we optimizing?
                        'feature',              Feature, ...        % Which ranking feature is used?
                        'k',                    k, ...              % What is the ranking length?
                        'regularizer',          Regularizer, ...    % What regularization is used?
                        'diagonal',             Diagonal, ...       % 0 for full metric, 1 for diagonal
                        'num_calls_SO',         0, ...              % Calls to separation oracle
                        'num_calls_solver',     0, ...              % Calls to solver
                        'time_SO',              0, ...              % Time in separation oracle
                        'time_solver',          0, ...              % Time in solver
                        'time_total',           0, ...              % Total time
                        'f',                    [], ...             % Objective value
                        'num_steps',            [], ...             % Number of steps for each solver run
                        'num_constraints',      [], ...             % Number of constraints for each run
                        'Xi',                   [], ...             % Slack achieved for each run
                        'Delta',                [], ...             % Mean loss for each SO call
                        'gap',                  [], ...             % Gap between loss and slack
                        'C',                    C, ...              % Slack trade-off
                        'epsilon',              E, ...              % Convergence threshold
                        'feasible_count',       FEASIBLE_COUNT, ... % Counter for # svd's
                        'constraint_timer',     ConstraintClock);   % Time before evicting old constraints
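% (Added note) After training, Diagnostics.f holds the objective value from each
% solver call and Diagnostics.gap the termination gap from each oracle call, so
% either can be inspected (e.g. plotted) to monitor convergence.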


global PsiR;
global PsiClock;

PsiR        = {};
PsiClock    = [];

Xi          = -Inf;
Margins     = [];
H           = [];
Q           = [];

if STOCHASTIC
    dbprint(1, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
end

MAXITER = 200;
% while 1
while Diagnostics.num_calls_solver < MAXITER
    dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);
    % Generate a constraint set
    Termination = 0;


    dbprint(2, 'Calling separation oracle...');

    [PsiNew, Mnew, SO_time]     = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
    Termination                 = LOSS(W, PsiNew, Mnew, 0);

    Diagnostics.num_calls_SO    = Diagnostics.num_calls_SO + 1;
    Diagnostics.time_SO         = Diagnostics.time_SO + SO_time;

    Margins     = cat(1, Margins, Mnew);
    PsiR        = cat(1, PsiR, PsiNew);
    PsiClock    = cat(1, PsiClock, 0);
    H           = expandKernel(H);
    Q           = expandRegularizer(Q, X, W);


    dbprint(2, '\n\tActive constraints : %d', length(PsiClock));
    dbprint(2, '\t Mean loss : %0.4f', Mnew);
    dbprint(2, '\t Current loss Xi : %0.4f', Xi);
    dbprint(2, '\t Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);

    Diagnostics.gap     = cat(1, Diagnostics.gap, Termination - Xi);
    Diagnostics.Delta   = cat(1, Diagnostics.Delta, Mnew);

    if Termination <= Xi + E
        dbprint(2, 'Done.');
        break;
    end

    dbprint(2, 'Calling solver...');
    PsiClock                        = PsiClock + 1;
    Solver_time                     = tic();
    [W, Xi, Dsolver]                = mlr_admm(C, X, Margins, H, Q);
    Diagnostics.time_solver         = Diagnostics.time_solver + toc(Solver_time);
    Diagnostics.num_calls_solver    = Diagnostics.num_calls_solver + 1;

    Diagnostics.Xi          = cat(1, Diagnostics.Xi, Xi);
    Diagnostics.f           = cat(1, Diagnostics.f, Dsolver.f);
    Diagnostics.num_steps   = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

    %%%
    % Cull the old constraints
    GC          = PsiClock < ConstraintClock;
    Margins     = Margins(GC);
    PsiR        = PsiR(GC);
    PsiClock    = PsiClock(GC);
    H           = H(GC, GC);
    Q           = Q(GC);

    Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
end


% Finish diagnostics

Diagnostics.time_total      = toc(TIME_START);
Diagnostics.feasible_count  = FEASIBLE_COUNT;
end

function H = expandKernel(H)
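% (Added note) H is the Gram matrix of the cached constraint features PsiR
% under STRUCTKERNEL; this helper grows it by one row/column, computing only
% the entries that involve the newest constraint.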

global STRUCTKERNEL;
global PsiR;

m = length(H);
H = padarray(H, [1 1], 0, 'post');


for i = 1:m+1
    H(i, m+1)   = STRUCTKERNEL( PsiR{i}, PsiR{m+1} );
    H(m+1, i)   = H(i, m+1);
end
end

function Q = expandRegularizer(Q, K, W)
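% (Added note) Appends one entry to Q for the newest cached constraint,
% pairing the regularizer term REG(W, K, 1) with PsiR{m+1} under STRUCTKERNEL.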

% FIXME:  2012-01-31 21:34:15 by Brian McFee <bmcfee@cs.ucsd.edu>
%   does not support unregularized learning

global PsiR;
global STRUCTKERNEL REG;

m = length(Q);
Q(m+1,1) = STRUCTKERNEL(REG(W,K,1), PsiR{m+1});

end

function ClassScores = synthesizeRelevance(Y)

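% (Added worked example, illustrative) For Y = [1; 2; 1]:
%   classes = [1; 2], Ypos{1} = logical([1; 0; 1]), Yneg{1} = ~Ypos{1},
% i.e. each class receives a relevance mask over all n samples.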
classes     = unique(Y);
nClasses    = length(classes);

ClassScores = struct(   'Y',        Y, ...
                        'classes',  classes, ...
                        'Ypos',     [], ...
                        'Yneg',     []);

Ypos = cell(nClasses, 1);
Yneg = cell(nClasses, 1);
for c = 1:nClasses
    Ypos{c} = (Y == classes(c));
    Yneg{c} = ~Ypos{c};
end

ClassScores.Ypos = Ypos;
ClassScores.Yneg = Yneg;

end
|