Mercurial > hg > camir-aes2014
comparison toolboxes/distance_learning/mlr/mlr_train.m @ 0:e9a9cd732c1e tip
first hg version after svn
| author | wolffd |
|---|---|
| date | Tue, 10 Feb 2015 15:05:51 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:e9a9cd732c1e |
|---|---|
function [W, Xi, Diagnostics] = mlr_train(X, Y, Cslack, varargin)
%
% [W, Xi, D] = mlr_train(X, Y, C,...)
%
%   X = d*n data matrix
%   Y = either n-by-1 label of vectors
%       OR
%       n-by-2 cell array where
%           Y{q,1} contains relevant indices for q, and
%           Y{q,2} contains irrelevant indices for q
%
%   C >= 0 slack trade-off parameter (default=1)
%
%   W   = the learned metric
%   Xi  = slack value on the learned metric
%   D   = diagnostics
%
%   Optional arguments:
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:          no regularization
%           1:          1-norm: trace(W)            (default)
%           2:          2-norm: trace(W' * W)
%           3:          Kernel: trace(W * X), assumes X is square and positive-definite
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0:   learn a full d-by-d W       (default)
%       Diagonal = 1:   learn diagonally-constrained W (d-by-1)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%
%   // added by Daniel Wolff
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC)
%       Set ConstraintClock to CC (default: 100)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E)
%       Set convergence tolerance to E (default: 1e-3)
%

TIME_START = tic();

% Slack trade-off is shared with the solver via a global.
global C;
C = Cslack;

[d,n,m] = size(X);

% A third dimension on X indicates one kernel matrix per slice (MKL mode).
if m > 1
    MKL = 1;
else
    MKL = 0;
end

if nargin < 3
    C = 1;
end

%%%
% Default options:

global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT STRUCTKERNEL DUALW INIT;

global FEASIBLE_COUNT;
FEASIBLE_COUNT = 0;

CP          = @cuttingPlaneFull;
SO          = @separationOracleAUC;
PSI         = @metricPsiPO;

if ~MKL
    INIT        = @initializeFull;
    REG         = @regularizeTraceFull;
    STRUCTKERNEL= @structKernelLinear;
    DUALW       = @dualWLinear;
    FEASIBLE    = @feasibleFull;
    CPGRADIENT  = @cpGradientFull;
    DISTANCE    = @distanceFull;
    SETDISTANCE = @setDistanceFull;
    LOSS        = @lossHinge;
    Regularizer = 'Trace';
else
    INIT        = @initializeFullMKL;
    REG         = @regularizeMKLFull;
    STRUCTKERNEL= @structKernelMKL;
    DUALW       = @dualWMKL;
    FEASIBLE    = @feasibleFullMKL;
    CPGRADIENT  = @cpGradientFullMKL;
    DISTANCE    = @distanceFullMKL;
    SETDISTANCE = @setDistanceFullMKL;
    LOSS        = @lossHingeFullMKL;
    Regularizer = 'Trace';
end


Loss        = 'AUC';
Feature     = 'metricPsiPO';


%%%
% Default k for prec@k, ndcg
k           = 3;

%%%
% Stochastic violator selection?
STOCHASTIC  = 0;
batchSize   = n;
SAMPLES     = 1:n;


if nargin > 3
    switch lower(varargin{1})
        case {'auc'}
            SO          = @separationOracleAUC;
            PSI         = @metricPsiPO;
            Loss        = 'AUC';
            Feature     = 'metricPsiPO';
        case {'knn'}
            SO          = @separationOracleKNN;
            PSI         = @metricPsiPO;
            Loss        = 'KNN';
            Feature     = 'metricPsiPO';
        case {'prec@k'}
            SO          = @separationOraclePrecAtK;
            PSI         = @metricPsiPO;
            Loss        = 'Prec@k';
            Feature     = 'metricPsiPO';
        case {'map'}
            SO          = @separationOracleMAP;
            PSI         = @metricPsiPO;
            Loss        = 'MAP';
            Feature     = 'metricPsiPO';
        case {'mrr'}
            SO          = @separationOracleMRR;
            PSI         = @metricPsiPO;
            Loss        = 'MRR';
            Feature     = 'metricPsiPO';
        case {'ndcg'}
            SO          = @separationOracleNDCG;
            PSI         = @metricPsiPO;
            Loss        = 'NDCG';
            Feature     = 'metricPsiPO';
        otherwise
            error('MLR:LOSS', ...
                'Unknown loss function: %s', varargin{1});
    end
end

if nargin > 4
    k = varargin{2};
end

Diagonal = 0;
% NOTE: use short-circuit '&&' (consistent with the B and CC guards below);
% the plain '&' used previously relies on MATLAB's implicit short-circuiting
% inside if-conditions and is flagged by mlint.
if nargin > 6 && varargin{4} > 0
    Diagonal = varargin{4};

    if ~MKL
        INIT        = @initializeDiag;
        REG         = @regularizeTraceDiag;
        STRUCTKERNEL= @structKernelDiag;
        DUALW       = @dualWDiag;
        FEASIBLE    = @feasibleDiag;
        CPGRADIENT  = @cpGradientDiag;
        DISTANCE    = @distanceDiag;
        SETDISTANCE = @setDistanceDiag;
        Regularizer = 'Trace';
    else
        INIT        = @initializeDiagMKL;
        REG         = @regularizeMKLDiag;
        STRUCTKERNEL= @structKernelDiagMKL;
        DUALW       = @dualWDiagMKL;
        FEASIBLE    = @feasibleDiagMKL;
        CPGRADIENT  = @cpGradientDiagMKL;
        DISTANCE    = @distanceDiagMKL;
        SETDISTANCE = @setDistanceDiagMKL;
        LOSS        = @lossHingeDiagMKL;
        Regularizer = 'Trace';
    end
end

if nargin > 5
    switch(varargin{3})
        case {0}
            REG         = @regularizeNone;
            Regularizer = 'None';

        case {1}
            if MKL
                if Diagonal == 0
                    REG         = @regularizeMKLFull;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                elseif Diagonal == 1
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                end
            else
                if Diagonal
                    REG         = @regularizeTraceDiag;
                    STRUCTKERNEL= @structKernelDiag;
                    DUALW       = @dualWDiag;
                else
                    REG         = @regularizeTraceFull;
                    STRUCTKERNEL= @structKernelLinear;
                    DUALW       = @dualWLinear;
                end
            end
            Regularizer = 'Trace';

        case {2}
            if Diagonal
                REG     = @regularizeTwoDiag;
            else
                REG     = @regularizeTwoFull;
            end
            Regularizer = '2-norm';
            error('MLR:REGULARIZER', '2-norm regularization no longer supported');


        case {3}
            if MKL
                if Diagonal == 0
                    REG         = @regularizeMKLFull;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                elseif Diagonal == 1
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                end
            else
                if Diagonal
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                else
                    REG         = @regularizeKernel;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                end
            end
            Regularizer = 'Kernel';

        otherwise
            % num2str: REG is numeric here; '%s' on a number would render
            % char codes rather than the value.
            error('MLR:REGULARIZER', ...
                'Unknown regularization: %s', num2str(varargin{3}));
    end
end


% Are we in stochastic optimization mode?
if nargin > 7 && varargin{5} > 0
    if varargin{5} < n
        STOCHASTIC  = 1;
        CP          = @cuttingPlaneRandom;
        batchSize   = varargin{5};
    end
end

% Algorithm
%
% Working <- []
%
% repeat:
%   (W, Xi) <- solver(X, Y, C, Working)
%
%   for i = 1:|X|
%       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
%
%   Working <- Working + (y^_1,y^_2,...,y^_n)
% until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
%           <= Xi + epsilon

global DEBUG;

if isempty(DEBUG)
    DEBUG = 0;
end

%%%
% Timer to eliminate old constraints
ConstraintClock = 100; % standard: 100

if nargin > 8 && varargin{6} > 0
    ConstraintClock = varargin{6};
end

%%%
% Convergence criteria for worst-violated constraint
E = 1e-3;
if nargin > 9 && varargin{7} > 0
    E = varargin{7};
end


%XXX:   2012-01-31 21:29:50 by Brian McFee <bmcfee@cs.ucsd.edu>
% no longer belongs here
% Initialize
W = INIT(X);


global ADMM_Z ADMM_U RHO;
ADMM_Z = W;
ADMM_U = 0 * ADMM_Z;

%%%
% Augmented lagrangian factor
RHO = 1;

ClassScores = [];

if isa(Y, 'double')
    % Plain label vector: build per-class relevance sets.
    Ypos        = [];
    Yneg        = [];
    ClassScores = synthesizeRelevance(Y);

elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
    dbprint(2, 'Using supplied Ypos/Yneg');
    Ypos        = Y(:,1);
    Yneg        = Y(:,2);

    % Compute the valid samples
    SAMPLES     = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
    dbprint(2, 'Using supplied Ypos/synthesized Yneg');
    Ypos        = Y(:,1);
    Yneg        = [];
    SAMPLES     = find( ~(cellfun(@isempty, Y(:,1))));
else
    error('MLR:LABELS', 'Incorrect format for Y.');
end

%%
% If we don't have enough data to make the batch, cut the batch
batchSize = min([batchSize, length(SAMPLES)]);


Diagnostics = struct(   'loss',                 Loss, ...           % Which loss are we optimizing?
                        'feature',              Feature, ...        % Which ranking feature is used?
                        'k',                    k, ...              % What is the ranking length?
                        'regularizer',          Regularizer, ...    % What regularization is used?
                        'diagonal',             Diagonal, ...       % 0 for full metric, 1 for diagonal
                        'num_calls_SO',         0, ...              % Calls to separation oracle
                        'num_calls_solver',     0, ...              % Calls to solver
                        'time_SO',              0, ...              % Time in separation oracle
                        'time_solver',          0, ...              % Time in solver
                        'time_total',           0, ...              % Total time
                        'f',                    [], ...             % Objective value
                        'num_steps',            [], ...             % Number of steps for each solver run
                        'num_constraints',      [], ...             % Number of constraints for each run
                        'Xi',                   [], ...             % Slack achieved for each run
                        'Delta',                [], ...             % Mean loss for each SO call
                        'gap',                  [], ...             % Gap between loss and slack
                        'C',                    C, ...              % Slack trade-off
                        'epsilon',              E, ...              % Convergence threshold
                        'feasible_count',       FEASIBLE_COUNT, ... % Counter for # svd's
                        'constraint_timer',     ConstraintClock);   % Time before evicting old constraints



global PsiR;
global PsiClock;

PsiR        = {};
PsiClock    = [];

Xi          = -Inf;
Margins     = [];
H           = [];
Q           = [];

if STOCHASTIC
    dbprint(1, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
end

MAXITER = 200;
while Diagnostics.num_calls_solver < MAXITER
    dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);
    % Generate a constraint set
    Termination = 0;


    dbprint(2, 'Calling separation oracle...');

    [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
    Termination = LOSS(W, PsiNew, Mnew, 0);

    Diagnostics.num_calls_SO = Diagnostics.num_calls_SO + 1;
    Diagnostics.time_SO = Diagnostics.time_SO + SO_time;

    % Admit the new constraint and grow the cached kernel/regularizer terms.
    Margins     = cat(1, Margins, Mnew);
    PsiR        = cat(1, PsiR, PsiNew);
    PsiClock    = cat(1, PsiClock, 0);
    H           = expandKernel(H);
    Q           = expandRegularizer(Q, X, W);


    dbprint(2, '\n\tActive constraints : %d', length(PsiClock));
    dbprint(2, '\t Mean loss : %0.4f', Mnew);
    dbprint(2, '\t Current loss Xi : %0.4f', Xi);
    dbprint(2, '\t Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);

    Diagnostics.gap     = cat(1, Diagnostics.gap, Termination - Xi);
    Diagnostics.Delta   = cat(1, Diagnostics.Delta, Mnew);

    % Converged: worst violation is within E of the current slack.
    if Termination <= Xi + E
        dbprint(2, 'Done.');
        break;
    end

    dbprint(2, 'Calling solver...');
    PsiClock                = PsiClock + 1;
    Solver_time             = tic();
    [W, Xi, Dsolver]        = mlr_admm(C, X, Margins, H, Q);
    Diagnostics.time_solver = Diagnostics.time_solver + toc(Solver_time);
    Diagnostics.num_calls_solver = Diagnostics.num_calls_solver + 1;

    Diagnostics.Xi          = cat(1, Diagnostics.Xi, Xi);
    Diagnostics.f           = cat(1, Diagnostics.f, Dsolver.f);
    Diagnostics.num_steps   = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

    %%%
    % Cull the old constraints
    GC          = PsiClock < ConstraintClock;
    Margins     = Margins(GC);
    PsiR        = PsiR(GC);
    PsiClock    = PsiClock(GC);
    H           = H(GC, GC);
    Q           = Q(GC);

    Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
end


% Finish diagnostics

Diagnostics.time_total      = toc(TIME_START);
Diagnostics.feasible_count  = FEASIBLE_COUNT;
end
| 458 | |
function H = expandKernel(H)
% Grow the constraint Gram matrix H by one row/column for the newest
% constraint feature PsiR{end}, filling it with structural-kernel inner
% products against every stored feature. H is kept symmetric.

global STRUCTKERNEL;
global PsiR;

m = length(H);

% Zero-extend H by one row and column via out-of-bounds assignment.
% The original used PADARRAY, which needlessly requires the Image
% Processing Toolbox; MATLAB indexing grows arrays with zeros natively.
H(m+1, m+1) = 0;

for i = 1:m+1
    H(i, m+1) = STRUCTKERNEL( PsiR{i}, PsiR{m+1} );
    % Mirror into the new row to preserve symmetry.
    H(m+1, i) = H(i, m+1);
end
end
| 473 | |
function Q = expandRegularizer(Q, K, W)
% Append one entry to the column vector Q: the structural-kernel inner
% product of the regularizer gradient REG(W,K,1) with the newest
% constraint feature PsiR{end}.

% FIXME: 2012-01-31 21:34:15 by Brian McFee <bmcfee@cs.ucsd.edu>
%   does not support unregularized learning

global PsiR;
global STRUCTKERNEL REG;

idx       = length(Q) + 1;
Q(idx, 1) = STRUCTKERNEL( REG(W, K, 1), PsiR{idx} );

end
| 486 | |
function ClassScores = synthesizeRelevance(Y)
% Build per-class relevance masks from an n-by-1 label vector Y.
%
% Returns a struct with fields:
%   Y       - the original label vector
%   classes - sorted unique labels found in Y
%   Ypos    - cell array; Ypos{c} is a logical mask of points in class c
%   Yneg    - cell array; Yneg{c} is the complementary mask

classes  = unique(Y);
nClasses = length(classes);

% Membership masks are filled in a loop; the negative set for each class
% is simply the complement of its positive set.
Ypos = cell(nClasses, 1);
Yneg = cell(nClasses, 1);
for c = 1:nClasses
    members = (Y == classes(c));
    Ypos{c} = members;
    Yneg{c} = ~members;
end

% Struct is created with empty placeholders and the cell arrays assigned
% afterwards, so STRUCT does not expand into a struct array.
ClassScores = struct( 'Y',       Y, ...
                      'classes', classes, ...
                      'Ypos',    [], ...
                      'Yneg',    []);

ClassScores.Ypos = Ypos;
ClassScores.Yneg = Yneg;

end
