annotate toolboxes/distance_learning/mlr/mlr_train.m @ 0:cc4b1211e677 tip

initial commit to HG from Changeset: 646 (e263d8a21543) added further path and more save "camirversion.m"
author Daniel Wolff
date Fri, 19 Aug 2016 13:07:06 +0200
parents
children
rev   line source
function [W, Xi, Diagnostics] = mlr_train(X, Y, C, varargin)
%
% [W, Xi, D] = mlr_train(X, Y, C,...)
%
%   X = d*n data matrix
%       (a d*n*m array, m > 1, selects the multiple-kernel (MKL) code paths)
%   Y = either n-by-1 vector of labels
%       OR
%       n-by-2 cell array where
%           Y{q,1} contains relevant indices for q, and
%           Y{q,2} contains irrelevant indices for q
%       OR
%       n-by-1 cell array of relevant indices (irrelevant sets are left
%       empty here and synthesized downstream)
%       Queries q with an empty relevant (or, for n-by-2, irrelevant) set
%       are excluded from training.
%
%   C >= 0 slack trade-off parameter (default=1)
%
%   W   = the learned metric
%   Xi  = slack value on the learned metric
%   D   = diagnostics
%
% Optional arguments (pass [] for any of them to keep its default):
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of (case-insensitive):
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:  no regularization
%           1:  1-norm: trace(W)            (default)
%           2:  2-norm: trace(W' * W)
%           3:  Kernel: trace(W * X), assumes X is square and positive-definite
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0: learn a full d-by-d W (default)
%       Diagonal = 1: learn diagonally-constrained W (d-by-1)
%       Diagonal = 2: MKL input only -- use the *DODMKL variants
%                     (for non-MKL input any Diagonal > 0 means diagonal)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%       (only takes effect when B < n)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC)
%       Set ConstraintClock to CC (default: 20): each constraint takes
%       part in CC solver calls and is then evicted
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E)
%       Set the convergence threshold to E (default: 1e-3)
%

    global globalvars;
    global DEBUG;

    % Debug verbosity comes from the shared globalvars struct, if set.
    if isfield(globalvars, 'debug')
        DEBUG = globalvars.debug;
    else
        DEBUG = 0;
    end

    TIME_START = tic();

%   addpath('cuttingPlane', 'distance', 'feasible', 'initialize', 'loss', ...
%           'metricPsi', 'regularize', 'separationOracle', 'util');

    [d,n,m] = size(X);

    % A third dimension > 1 means several kernels/feature maps (MKL mode).
    if m > 1
        MKL = 1;
    else
        MKL = 0;
    end

    if nargin < 3 || isempty(C)
        C = 1;
    end

    %%%
    % Default options:
    % The helper functions are communicated to the rest of the toolbox
    % through these globals.

    global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT METRICK;

    CP  = @cuttingPlaneFull;
    SO  = @separationOracleAUC;
    PSI = @metricPsiPO;

    if ~MKL
        INIT        = @initializeFull;
        REG         = @regularizeTraceFull;
        FEASIBLE    = @feasibleFull;
        CPGRADIENT  = @cpGradientFull;
        DISTANCE    = @distanceFull;
        SETDISTANCE = @setDistanceFull;
        LOSS        = @lossHinge;
        Regularizer = 'Trace';
    else
        INIT        = @initializeFullMKL;
        REG         = @regularizeMKLFull;
        FEASIBLE    = @feasibleFullMKL;
        CPGRADIENT  = @cpGradientFullMKL;
        DISTANCE    = @distanceFullMKL;
        SETDISTANCE = @setDistanceFullMKL;
        LOSS        = @lossHingeFullMKL;
        Regularizer = 'Trace';
    end

    Loss    = 'AUC';
    Feature = 'metricPsiPO';

    %%%
    % Default k for prec@k, ndcg
    k = 3;

    %%%
    % Stochastic violator selection?
    STOCHASTIC  = 0;
    batchSize   = n;
    SAMPLES     = 1:n;

    %%%
    % Ranking loss (separation oracle)
    if nargin > 3 && ~isempty(varargin{1})
        switch lower(varargin{1})
            case {'auc'}
                SO          = @separationOracleAUC;
                PSI         = @metricPsiPO;
                Loss        = 'AUC';
                Feature     = 'metricPsiPO';
            case {'knn'}
                SO          = @separationOracleKNN;
                PSI         = @metricPsiPO;
                Loss        = 'KNN';
                Feature     = 'metricPsiPO';
            case {'prec@k'}
                SO          = @separationOraclePrecAtK;
                PSI         = @metricPsiPO;
                Loss        = 'Prec@k';
                Feature     = 'metricPsiPO';
            case {'map'}
                SO          = @separationOracleMAP;
                PSI         = @metricPsiPO;
                Loss        = 'MAP';
                Feature     = 'metricPsiPO';
            case {'mrr'}
                SO          = @separationOracleMRR;
                PSI         = @metricPsiPO;
                Loss        = 'MRR';
                Feature     = 'metricPsiPO';
            case {'ndcg'}
                SO          = @separationOracleNDCG;
                PSI         = @metricPsiPO;
                Loss        = 'NDCG';
                Feature     = 'metricPsiPO';
            otherwise
                error('MLR:LOSS', ...
                    'Unknown loss function: %s', varargin{1});
        end
    end

    if nargin > 4 && ~isempty(varargin{2})
        k = varargin{2};
    end

    METRICK = k;

    %%%
    % Diagonal / DOD constraint on W.  Must run before the REG switch below,
    % which inspects Diagonal.
    Diagonal = 0;
    if nargin > 6 && ~isempty(varargin{4}) && varargin{4} > 0
        Diagonal = varargin{4};

        if ~MKL
            INIT        = @initializeDiag;
            REG         = @regularizeTraceDiag;
            FEASIBLE    = @feasibleDiag;
            CPGRADIENT  = @cpGradientDiag;
            DISTANCE    = @distanceDiag;
            SETDISTANCE = @setDistanceDiag;
            Regularizer = 'Trace';
        else
            if Diagonal > 1
                INIT        = @initializeDODMKL;
                REG         = @regularizeMKLDOD;
                FEASIBLE    = @feasibleDODMKL;
                CPGRADIENT  = @cpGradientDODMKL;
                DISTANCE    = @distanceDODMKL;
                SETDISTANCE = @setDistanceDODMKL;
                LOSS        = @lossHingeDODMKL;
                Regularizer = 'Trace';
            else
                INIT        = @initializeDiagMKL;
                REG         = @regularizeMKLDiag;
                FEASIBLE    = @feasibleDiagMKL;
                CPGRADIENT  = @cpGradientDiagMKL;
                DISTANCE    = @distanceDiagMKL;
                SETDISTANCE = @setDistanceDiagMKL;
                LOSS        = @lossHingeDiagMKL;
                Regularizer = 'Trace';
            end
        end
    end

    %%%
    % Regularizer selection (overrides the REG chosen above)
    if nargin > 5 && ~isempty(varargin{3})
        switch(varargin{3})
            case {0}
                REG         = @regularizeNone;
                Regularizer = 'None';

            case {1}
                if MKL
                    if Diagonal == 0
                        REG = @regularizeMKLFull;
                    elseif Diagonal == 1
                        REG = @regularizeMKLDiag;
                    elseif Diagonal == 2
                        REG = @regularizeMKLDOD;
                    end
                else
                    if Diagonal
                        REG = @regularizeTraceDiag;
                    else
                        REG = @regularizeTraceFull;
                    end
                end
                Regularizer = 'Trace';

            case {2}
                % NOTE(review): the MKL flag is not consulted here, so MKL
                % input gets the non-MKL 2-norm regularizers -- confirm.
                if Diagonal
                    REG = @regularizeTwoDiag;
                else
                    REG = @regularizeTwoFull;
                end
                Regularizer = '2-norm';

            case {3}
                if MKL
                    if Diagonal == 0
                        REG = @regularizeMKLFull;
                    elseif Diagonal == 1
                        REG = @regularizeMKLDiag;
                    elseif Diagonal == 2
                        REG = @regularizeMKLDOD;
                    end
                else
                    % NOTE(review): non-MKL diagonal kernel regularization
                    % reuses regularizeMKLDiag -- confirm this is intended.
                    if Diagonal
                        REG = @regularizeMKLDiag;
                    else
                        REG = @regularizeKernel;
                    end
                end
                Regularizer = 'Kernel';

            otherwise
                % num2str so numeric REG values print as digits, not as
                % character codes, while char input is passed through.
                error('MLR:REGULARIZER', ...
                    'Unknown regularization: %s', num2str(varargin{3}));
        end
    end

    % Are we in stochastic optimization mode?
    if nargin > 7 && ~isempty(varargin{5}) && varargin{5} > 0
        if varargin{5} < n
            STOCHASTIC  = 1;
            CP          = @cuttingPlaneRandom;
            batchSize   = varargin{5};
        end
    end

    %%%
    % Timer to eliminate old constraints
    ConstraintClock = 20;

    if nargin > 8 && ~isempty(varargin{6}) && varargin{6} > 0
        ConstraintClock = varargin{6};
    end

    %%%
    % Convergence criteria for worst-violated constraint
    E = 1e-3;
    if nargin > 9 && ~isempty(varargin{7}) && varargin{7} > 0
        E = varargin{7};
    end

    % Algorithm
    %
    %   Working <- []
    %
    %   repeat:
    %       (W, Xi) <- solver(X, Y, C, Working)
    %
    %       for i = 1:|X|
    %           y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
    %
    %       Working <- Working + (y^_1,y^_2,...,y^_n)
    %   until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
    %           <= Xi + epsilon

    % Initialize
    W = INIT(X);

    ClassScores = [];

    if isa(Y, 'double')
        % Label vector: relevance is synthesized per class.
        Ypos        = [];
        Yneg        = [];
        ClassScores = synthesizeRelevance(Y);

    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
        dbprint(2, 'Using supplied Ypos/Yneg');
        Ypos = Y(:,1);
        Yneg = Y(:,2);

        % Compute the valid samples
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
        dbprint(2, 'Using supplied Ypos/synthesized Yneg');
        Ypos    = Y(:,1);
        Yneg    = [];
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));
    else
        error('MLR:LABELS', 'Incorrect format for Y.');
    end

    Diagnostics = struct(   'loss',             Loss, ...           % Which loss are we optimizing?
                            'feature',          Feature, ...        % Which ranking feature is used?
                            'k',                k, ...              % What is the ranking length?
                            'regularizer',      Regularizer, ...    % What regularization is used?
                            'diagonal',         Diagonal, ...       % 0 for full metric, 1 for diagonal
                            'num_calls_SO',     0, ...              % Calls to separation oracle
                            'num_calls_solver', 0, ...              % Calls to solver
                            'time_SO',          0, ...              % Time in separation oracle
                            'time_solver',      0, ...              % Time in solver
                            'time_total',       0, ...              % Total time
                            'f',                [], ...             % Objective value
                            'num_steps',        [], ...             % Number of steps for each solver run
                            'num_constraints',  [], ...             % Number of constraints for each run
                            'Xi',               [], ...             % Slack achieved for each run
                            'Delta',            [], ...             % Mean loss for each SO call
                            'gap',              [], ...             % Gap between loss and slack
                            'C',                C, ...              % Slack trade-off
                            'epsilon',          E, ...              % Convergence threshold
                            'constraint_timer', ConstraintClock);   % Time before evicting old constraints

    % Working set of constraints, shared with the solver through globals.
    global PsiR;
    global PsiClock;

    PsiR        = {};
    PsiClock    = [];

    Xi      = -Inf;
    Margins = [];

    if STOCHASTIC
        dbprint(2, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
    end

    while 1
        dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);

        % Generate a constraint set
        dbprint(2, 'Calling separation oracle...');

        [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
        Termination             = LOSS(W, PsiNew, Mnew, 0);

        Diagnostics.num_calls_SO    = Diagnostics.num_calls_SO + 1;
        Diagnostics.time_SO         = Diagnostics.time_SO + SO_time;

        % Add the new constraint to the working set with a fresh clock.
        Margins     = cat(1, Margins,   Mnew);
        PsiR        = cat(1, PsiR,      PsiNew);
        PsiClock    = cat(1, PsiClock,  0);

        dbprint(2, '\n\tActive constraints    : %d',            length(PsiClock));
        dbprint(2, '\t           Mean loss  : %0.4f',           Mnew);
        dbprint(2, '\t  Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);

        Diagnostics.gap     = cat(1, Diagnostics.gap,       Termination - Xi);
        Diagnostics.Delta   = cat(1, Diagnostics.Delta,     Mnew);

        % Converged: the most violated constraint is within E of the slack.
        if Termination <= Xi + E
            dbprint(2, 'Done.');
            break;
        end

        dbprint(2, 'Calling solver...');
        PsiClock                        = PsiClock + 1;
        Solver_time                     = tic();
        [W, Xi, Dsolver]                = mlr_solver(C, Margins, W, X);
        Diagnostics.time_solver         = Diagnostics.time_solver + toc(Solver_time);
        Diagnostics.num_calls_solver    = Diagnostics.num_calls_solver + 1;

        Diagnostics.Xi          = cat(1, Diagnostics.Xi,        Xi);
        Diagnostics.f           = cat(1, Diagnostics.f,         Dsolver.f);
        Diagnostics.num_steps   = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

        %%%
        % Cull the old constraints
        GC          = PsiClock < ConstraintClock;
        Margins     = Margins(GC);
        PsiR        = PsiR(GC);
        PsiClock    = PsiClock(GC);

        Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
    end

    % Finish diagnostics
    Diagnostics.time_total = toc(TIME_START);
end
Daniel@0 426
Daniel@0 427
function ClassScores = synthesizeRelevance(Y)
% SYNTHESIZERELEVANCE  Derive per-class relevance masks from a label vector.
%
%   ClassScores = synthesizeRelevance(Y) returns a struct with fields
%       Y       - the input labels, unchanged
%       classes - unique(Y)
%       Ypos    - nClasses-by-1 cell; Ypos{c} is the logical mask Y == classes(c)
%       Yneg    - nClasses-by-1 cell; Yneg{c} is the complement ~Ypos{c}

    labelSet = unique(Y);

    % One membership mask per distinct label; force a column cell array
    % regardless of whether Y (and hence labelSet) is a row or a column.
    inClass  = arrayfun(@(label) Y == label, labelSet, 'UniformOutput', false);
    inClass  = reshape(inClass, [], 1);
    outClass = cellfun(@not, inClass, 'UniformOutput', false);

    ClassScores         = struct('Y', Y, 'classes', labelSet, 'Ypos', [], 'Yneg', []);
    ClassScores.Ypos    = inClass;
    ClassScores.Yneg    = outClass;
end