function [W, Xi, Diagnostics] = rmlr_train(X, Y, Cslack, varargin)
%[W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, stochastic, lam, CC, E)
%
% [W, Xi, D] = rmlr_train(X, Y, C,...)
%
%   X = d*n data matrix
%   Y = either n-by-1 label of vectors
%       OR
%       n-by-2 cell array where
%       Y{q,1} contains relevant indices for q, and
%       Y{q,2} contains irrelevant indices for q
%
%   C >= 0 slack trade-off parameter (default=1)
%
%   W  = the learned metric
%   Xi = slack value on the learned metric
%   D  = diagnostics
%
% Optional arguments:
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:          no regularization
%           1:          1-norm: trace(W)            (default)
%           2:          2-norm: trace(W' * W)       (no longer supported)
%           3:          Kernel: trace(W * X), assumes X is square
%                       and positive-definite
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Not implemented: learning a diagonal W reduces to plain MLR,
%       because ||W||_2,1 = trace(W) when W is diagonal.  The argument
%       is accepted (for signature compatibility) but ignored.
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda)
%       lambda is the coefficient of the ||W||_2,1 term (default: 1)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda, CC)
%       Set the constraint-eviction clock to CC (default: 500)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda, CC, E)
%       Set the convergence threshold for the worst-violated constraint
%       to E (default: 1e-3)

TIME_START = tic();

global C;
C = Cslack;

% m > 1 means X is a stack of kernel matrices => multiple-kernel mode.
[d,n,m] = size(X);

if m > 1
    MKL = 1;
else
    MKL = 0;
end

if nargin < 3
    C = 1;
end

%%%
% Default options:

global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT STRUCTKERNEL DUALW THRESH INIT;
global RHO;

%%%
% Augmented lagrangian factor
RHO = 1;

% Counter for the number of projection (svd) operations performed.
global FEASIBLE_COUNT;
FEASIBLE_COUNT = 0;

CP          = @cuttingPlaneFull;
SO          = @separationOracleAUC;
PSI         = @metricPsiPO;

% Select the linear vs. multiple-kernel implementations of each operation.
if ~MKL
    INIT        = @initializeFull;
    REG         = @regularizeTraceFull;
    STRUCTKERNEL= @structKernelLinear;
    DUALW       = @dualWLinear;
    FEASIBLE    = @feasibleFull;
    THRESH      = @threshFull_admmMixed;
    CPGRADIENT  = @cpGradientFull;
    DISTANCE    = @distanceFull;
    SETDISTANCE = @setDistanceFull;
    LOSS        = @lossHinge;
    Regularizer = 'Trace';
else
    INIT        = @initializeFullMKL;
    REG         = @regularizeMKLFull;
    STRUCTKERNEL= @structKernelMKL;
    DUALW       = @dualWMKL;
    FEASIBLE    = @feasibleFullMKL;
    THRESH      = @threshFull_admmMixed;
    CPGRADIENT  = @cpGradientFullMKL;
    DISTANCE    = @distanceFullMKL;
    SETDISTANCE = @setDistanceFullMKL;
    LOSS        = @lossHingeFullMKL;
    Regularizer = 'Trace';
end

Loss        = 'AUC';
Feature     = 'metricPsiPO';

%%%
% Default k for prec@k, ndcg
k           = 3;

%%%
% Stochastic violator selection?
STOCHASTIC  = 0;
batchSize   = n;
SAMPLES     = 1:n;

% Optional argument 1: ranking loss to optimize.
if nargin > 3
    switch lower(varargin{1})
        case {'auc'}
            SO          = @separationOracleAUC;
            PSI         = @metricPsiPO;
            Loss        = 'AUC';
            Feature     = 'metricPsiPO';
        case {'knn'}
            SO          = @separationOracleKNN;
            PSI         = @metricPsiPO;
            Loss        = 'KNN';
            Feature     = 'metricPsiPO';
        case {'prec@k'}
            SO          = @separationOraclePrecAtK;
            PSI         = @metricPsiPO;
            Loss        = 'Prec@k';
            Feature     = 'metricPsiPO';
        case {'map'}
            SO          = @separationOracleMAP;
            PSI         = @metricPsiPO;
            Loss        = 'MAP';
            Feature     = 'metricPsiPO';
        case {'mrr'}
            SO          = @separationOracleMRR;
            PSI         = @metricPsiPO;
            Loss        = 'MRR';
            Feature     = 'metricPsiPO';
        case {'ndcg'}
            SO          = @separationOracleNDCG;
            PSI         = @metricPsiPO;
            Loss        = 'NDCG';
            Feature     = 'metricPsiPO';
        otherwise
            error('MLR:LOSS', ...
                'Unknown loss function: %s', varargin{1});
    end
end

% Optional argument 2: truncation depth for Prec@k / NDCG.
if nargin > 4
    k = varargin{2};
end

% Diagonal case is not implemented; use mlr_train for that.
Diagonal = 0;

% Optional argument 3: regularizer selection.
if nargin > 5
    switch(varargin{3})
        case {0}
            REG         = @regularizeNone;
            Regularizer = 'None';
            THRESH      = @threshFull_admmMixed;
        case {1}
            if MKL
                REG         = @regularizeMKLFull;
                STRUCTKERNEL= @structKernelMKL;
                DUALW       = @dualWMKL;
            else
                REG         = @regularizeTraceFull;
                STRUCTKERNEL= @structKernelLinear;
                DUALW       = @dualWLinear;
            end
            Regularizer = 'Trace';

        case {2}
            REG         = @regularizeTwoFull;
            Regularizer = '2-norm';
            error('MLR:REGULARIZER', '2-norm regularization no longer supported');

        case {3}
            if MKL
                REG         = @regularizeMKLFull;
                STRUCTKERNEL= @structKernelMKL;
                DUALW       = @dualWMKL;
            else
                REG         = @regularizeKernel;
                STRUCTKERNEL= @structKernelMKL;
                DUALW       = @dualWMKL;
            end
            Regularizer = 'Kernel';

        otherwise
            error('MLR:REGULARIZER', ...
                'Unknown regularization: %s', varargin{3});
    end
end

% Optional argument 5: are we in stochastic optimization mode?
% (varargin{4}, the Diagonal flag, is intentionally ignored — see help.)
if nargin > 7 && varargin{5} > 0
    if varargin{5} < n
        STOCHASTIC  = 1;
        CP          = @cuttingPlaneRandom;
        batchSize   = varargin{5};
    end
end

% Algorithm (1-slack cutting-plane, Joachims-style):
%
% Working <- []
%
% repeat:
%   (W, Xi) <- solver(X, Y, C, Working)
%
%   for i = 1:|X|
%       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
%
%   Working <- Working + (y^_1,y^_2,...,y^_n)
% until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
%           <= Xi + epsilon

% Optional argument 6: coefficient of ||W||_2,1.
if nargin > 8
    lam = varargin{6};
else
    lam = 1;
end

disp(['lam = ' num2str(lam)])

global DEBUG;

if isempty(DEBUG)
    DEBUG = 0;
end

% NOTE: verbosity is currently forced on, overriding the guard above.
DEBUG = 1;

%%%
% Max calls to separation oracle
MAX_CALLS = 200;
MIN_CALLS = 10;

%%%
% Timer to eliminate old constraints
ConstraintClock = 500;      % standard: 500

% Optional argument 7: constraint-eviction clock.
if nargin > 9 && varargin{7} > 0
    ConstraintClock = varargin{7};
end

%%%
% Convergence criteria for worst-violated constraint
E = 1e-3;
% Optional argument 8: convergence threshold.
if nargin > 10 && varargin{8} > 0
    E = varargin{8};
end

% XXX: 2012-01-31 21:29:50 by Brian McFee <bmcfee@cs.ucsd.edu>
%   no longer belongs here
% Initialize
W = INIT(X);

% ADMM state shared with the solver: splitting variables and scaled duals.
global ADMM_Z ADMM_V ADMM_UW ADMM_UV;
ADMM_Z  = W;
ADMM_V  = W;
ADMM_UW = 0 * ADMM_Z;
ADMM_UV = 0 * ADMM_Z;

ClassScores = [];

% Normalize the supervision: plain labels are converted to relevance
% sets; cell-array input supplies (and possibly synthesizes) Ypos/Yneg.
if isa(Y, 'double')
    Ypos        = [];
    Yneg        = [];
    ClassScores = synthesizeRelevance(Y);

elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
    dbprint(2, 'Using supplied Ypos/Yneg');
    Ypos        = Y(:,1);
    Yneg        = Y(:,2);

    % Compute the valid samples
    SAMPLES     = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
    dbprint(2, 'Using supplied Ypos/synthesized Yneg');
    Ypos        = Y(:,1);
    Yneg        = [];
    SAMPLES     = find( ~(cellfun(@isempty, Y(:,1))));
else
    error('MLR:LABELS', 'Incorrect format for Y.');
end
%%
% If we don't have enough data to make the batch, cut the batch
batchSize = min([batchSize, length(SAMPLES)]);

Diagnostics = struct(   'loss',                 Loss, ...           % Which loss are we optimizing?
                        'feature',              Feature, ...        % Which ranking feature is used?
                        'k',                    k, ...              % What is the ranking length?
                        'regularizer',          Regularizer, ...    % What regularization is used?
                        'diagonal',             Diagonal, ...       % 0 for full metric, 1 for diagonal
                        'num_calls_SO',         0, ...              % Calls to separation oracle
                        'num_calls_solver',     0, ...              % Calls to solver
                        'time_SO',              0, ...              % Time in separation oracle
                        'time_solver',          0, ...              % Time in solver
                        'time_total',           0, ...              % Total time
                        'f',                    [], ...             % Objective value
                        'num_steps',            [], ...             % Number of steps for each solver run
                        'num_constraints',      [], ...             % Number of constraints for each run
                        'Xi',                   [], ...             % Slack achieved for each run
                        'Delta',                [], ...             % Mean loss for each SO call
                        'gap',                  [], ...             % Gap between loss and slack
                        'C',                    C, ...              % Slack trade-off
                        'epsilon',              E, ...              % Convergence threshold
                        'feasible_count',       FEASIBLE_COUNT, ... % Counter for # svd's
                        'constraint_timer',     ConstraintClock);   % Time before evicting old constraints

% Working set of constraint features and their ages.
global PsiR;
global PsiClock;

PsiR        = {};
PsiClock    = [];

Xi          = -Inf;
Margins     = [];
H           = [];
Q           = [];

if STOCHASTIC
    dbprint(2, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
end

dbprint(2,['Regularizer is "' Regularizer '"']);

% Main cutting-plane loop: alternate between finding the most violated
% constraint (separation oracle) and re-solving with the working set.
while 1
    if Diagnostics.num_calls_solver > MAX_CALLS
        dbprint(2,['Calls to SO >= ' num2str(MAX_CALLS)]);
        break;
    end

    dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);
    % Generate a constraint set
    Termination = -Inf;

    dbprint(2, 'Calling separation oracle...');
    [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);

    Termination                 = LOSS(W, PsiNew, Mnew, 0);
    Diagnostics.num_calls_SO    = Diagnostics.num_calls_SO + 1;
    Diagnostics.time_SO         = Diagnostics.time_SO + SO_time;

    % Append the new constraint and expand the cached Gram/regularizer terms.
    Margins     = cat(1, Margins, Mnew);
    PsiR        = cat(1, PsiR, PsiNew);
    PsiClock    = cat(1, PsiClock, 0);
    H           = expandKernel(H);
    Q           = expandRegularizer(Q, X, W);

    dbprint(2, '\n\tActive constraints    : %d',            length(PsiClock));
    dbprint(2, '\t           Mean loss  : %0.4f',           Mnew);
    dbprint(2, '\t  Current loss Xi     : %0.4f',           Xi);
    dbprint(2, '\t  Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);

    Diagnostics.gap     = cat(1, Diagnostics.gap,   Termination - Xi);
    Diagnostics.Delta   = cat(1, Diagnostics.Delta, Mnew);

    % Converged when the worst violation is within E of the current slack
    % (but always run at least MIN_CALLS solver rounds).
    if Termination <= Xi + E && Diagnostics.num_calls_solver > MIN_CALLS
        dbprint(1, 'Done.');
        break;
    end

    dbprint(1, 'Calling solver...');
    PsiClock                        = PsiClock + 1;
    Solver_time                     = tic();
    [W, Xi, Dsolver]                = rmlr_admm(C, X, Margins, H, Q, lam);

    Diagnostics.time_solver         = Diagnostics.time_solver + toc(Solver_time);
    Diagnostics.num_calls_solver    = Diagnostics.num_calls_solver + 1;

    Diagnostics.Xi                  = cat(1, Diagnostics.Xi, Xi);
    Diagnostics.f                   = cat(1, Diagnostics.f, Dsolver.f);
    Diagnostics.num_steps           = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

    %%%
    % Cull the old constraints
    GC          = PsiClock < ConstraintClock;
    Margins     = Margins(GC);
    PsiR        = PsiR(GC);
    PsiClock    = PsiClock(GC);
    H           = H(GC, GC);
    Q           = Q(GC);

    Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
end

% Finish diagnostics

Diagnostics.time_total      = toc(TIME_START);
Diagnostics.feasible_count  = FEASIBLE_COUNT;
end
function H = expandKernel(H)
% Grow the constraint Gram matrix H by one row/column for the newly
% appended constraint feature PsiR{end}, filling the new entries with
% STRUCTKERNEL inner products (H stays symmetric).

global STRUCTKERNEL;
global PsiR;

m = length(H);

% Indexed assignment auto-expands H with zero fill; this replaces the
% previous PADARRAY call and drops the Image Processing Toolbox dependency.
H(m+1, m+1) = 0;

for i = 1:m+1
    H(i, m+1) = STRUCTKERNEL( PsiR{i}, PsiR{m+1} );
    H(m+1, i) = H(i, m+1);
end
end
function Q = expandRegularizer(Q, K, W)
% Append to Q the STRUCTKERNEL inner product between the regularizer
% gradient REG(W, K, 1) and the newest constraint feature PsiR{end}.
%
% FIXME: 2012-01-31 21:34:15 by Brian McFee <bmcfee@cs.ucsd.edu>
%   does not support unregularized learning

global PsiR;
global STRUCTKERNEL REG;

nCon            = length(Q);
regGradient     = REG(W, K, 1);
Q(nCon+1, 1)    = STRUCTKERNEL(regGradient, PsiR{nCon+1});

end
function ClassScores = synthesizeRelevance(Y)
% Build per-class relevance sets from a plain label vector Y:
% for each distinct label, Ypos{c} marks the points of that class
% and Yneg{c} marks everything else.

labels      = unique(Y);
numClasses  = length(labels);

% Field values are set to [] here and filled in below: passing cell
% arrays directly to STRUCT would create a struct array instead of a
% scalar struct with cell-valued fields.
ClassScores = struct(   'Y',        Y, ...
                        'classes',  labels, ...
                        'Ypos',     [], ...
                        'Yneg',     []);

posSets = cell(numClasses, 1);
negSets = cell(numClasses, 1);
for c = 1:numClasses
    posSets{c} = (Y == labels(c));
    negSets{c} = ~posSets{c};
end

ClassScores.Ypos = posSets;
ClassScores.Yneg = negSets;

end