annotate toolboxes/distance_learning/mlr/mlr_train_primal.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [W, Xi, Diagnostics] = mlr_train(X, Y, Cslack, varargin)
%
% [W, Xi, D] = mlr_train(X, Y, C,...)
%
%   X = d*n data matrix
%   Y = either n-by-1 label of vectors
%       OR
%       n-by-2 cell array where
%           Y{q,1} contains relevant indices for q, and
%           Y{q,2} contains irrelevant indices for q
%
%   C >= 0 slack trade-off parameter (default=1)
%
%   W   = the learned metric
%   Xi  = slack value on the learned metric
%   D   = diagnostics
%
% Optional arguments:
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:          no regularization
%           1:          1-norm: trace(W)            (default)
%           2:          2-norm: trace(W' * W)
%           3:          Kernel: trace(W * X), assumes X is square
%                       and positive-definite
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0:   learn a full d-by-d W       (default)
%       Diagonal = 1:   learn diagonally-constrained W (d-by-1)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%

    TIME_START = tic();

    global C;
    % BUGFIX: the original assigned C = Cslack unconditionally before the
    % nargin check, so omitting C raised an undefined-variable error instead
    % of applying the documented default of 1.
    if nargin < 3
        C = 1;
    else
        C = Cslack;
    end

    % m > 1 means X stacks several kernel matrices => multiple kernel learning.
    [d, n, m] = size(X); %#ok<ASGLU>

    if m > 1
        MKL = 1;
    else
        MKL = 0;
    end

    %%%
    % Default options: cutting-plane generator, separation oracle, and the
    % family of helpers (initialization, regularizer, projection, gradient,
    % distance) appropriate for full-matrix W, with or without MKL.

    global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT;
    global FEASIBLE_COUNT;
    FEASIBLE_COUNT = 0;

    CP  = @cuttingPlaneFull;
    SO  = @separationOracleAUC;
    PSI = @metricPsiPO;

    if ~MKL
        INIT        = @initializeFull;
        REG         = @regularizeTraceFull;
        FEASIBLE    = @feasibleFull;
        CPGRADIENT  = @cpGradientFull;
        DISTANCE    = @distanceFull;
        SETDISTANCE = @setDistanceFull;
        LOSS        = @lossHinge;
        Regularizer = 'Trace';
    else
        INIT        = @initializeFullMKL;
        REG         = @regularizeMKLFull;
        FEASIBLE    = @feasibleFullMKL;
        CPGRADIENT  = @cpGradientFullMKL;
        DISTANCE    = @distanceFullMKL;
        SETDISTANCE = @setDistanceFullMKL;
        LOSS        = @lossHingeFullMKL;
        Regularizer = 'Trace';
    end

    Loss    = 'AUC';
    Feature = 'metricPsiPO';

    %%%
    % Default k for prec@k, ndcg
    k = 3;

    %%%
    % Stochastic violator selection?
    STOCHASTIC = 0;
    batchSize  = n;
    SAMPLES    = 1:n;

    %%%
    % Argument 4: ranking loss to optimize.
    if nargin > 3
        switch lower(varargin{1})
            case {'auc'}
                SO      = @separationOracleAUC;
                PSI     = @metricPsiPO;
                Loss    = 'AUC';
                Feature = 'metricPsiPO';
            case {'knn'}
                SO      = @separationOracleKNN;
                PSI     = @metricPsiPO;
                Loss    = 'KNN';
                Feature = 'metricPsiPO';
            case {'prec@k'}
                SO      = @separationOraclePrecAtK;
                PSI     = @metricPsiPO;
                Loss    = 'Prec@k';
                Feature = 'metricPsiPO';
            case {'map'}
                SO      = @separationOracleMAP;
                PSI     = @metricPsiPO;
                Loss    = 'MAP';
                Feature = 'metricPsiPO';
            case {'mrr'}
                SO      = @separationOracleMRR;
                PSI     = @metricPsiPO;
                Loss    = 'MRR';
                Feature = 'metricPsiPO';
            case {'ndcg'}
                SO      = @separationOracleNDCG;
                PSI     = @metricPsiPO;
                Loss    = 'NDCG';
                Feature = 'metricPsiPO';
            otherwise
                error('MLR:LOSS', ...
                    'Unknown loss function: %s', varargin{1});
        end
    end

    %%%
    % Argument 5: ranking length k.
    if nargin > 4
        k = varargin{2};
    end

    %%%
    % Argument 7: diagonal constraint on W.
    % BUGFIX: the original used non-short-circuit '&', which evaluates
    % varargin{4} even when fewer than 7 arguments were supplied, causing an
    % index-out-of-bounds error.  '&&' matches the batch-size check below.
    Diagonal = 0;
    if nargin > 6 && varargin{4} > 0
        Diagonal = varargin{4};

        if ~MKL
            INIT        = @initializeDiag;
            REG         = @regularizeTraceDiag;
            FEASIBLE    = @feasibleDiag;
            CPGRADIENT  = @cpGradientDiag;
            DISTANCE    = @distanceDiag;
            SETDISTANCE = @setDistanceDiag;
            Regularizer = 'Trace';
        else
            INIT        = @initializeDiagMKL;
            REG         = @regularizeMKLDiag;
            FEASIBLE    = @feasibleDiagMKL;
            CPGRADIENT  = @cpGradientDiagMKL;
            DISTANCE    = @distanceDiagMKL;
            SETDISTANCE = @setDistanceDiagMKL;
            LOSS        = @lossHingeDiagMKL;
            Regularizer = 'Trace';
        end
    end

    %%%
    % Argument 6: regularizer choice (may depend on Diagonal, set above).
    if nargin > 5
        switch(varargin{3})
            case {0}
                REG         = @regularizeNone;
                Regularizer = 'None';

            case {1}
                if MKL
                    if Diagonal == 0
                        REG = @regularizeMKLFull;
                    elseif Diagonal == 1
                        REG = @regularizeMKLDiag;
                    end
                else
                    if Diagonal
                        REG = @regularizeTraceDiag;
                    else
                        REG = @regularizeTraceFull;
                    end
                end
                Regularizer = 'Trace';

            case {2}
                if Diagonal
                    REG = @regularizeTwoDiag;
                else
                    REG = @regularizeTwoFull;
                end
                Regularizer = '2-norm';

            case {3}
                if MKL
                    if Diagonal == 0
                        REG = @regularizeMKLFull;
                    elseif Diagonal == 1
                        REG = @regularizeMKLDiag;
                    end
                else
                    if Diagonal
                        REG = @regularizeMKLDiag;
                    else
                        REG = @regularizeKernel;
                    end
                end
                Regularizer = 'Kernel';

            otherwise
                % BUGFIX: REG codes are numeric; '%s' rendered them as a
                % garbage character in the error message.  Use '%d'.
                error('MLR:REGULARIZER', ...
                    'Unknown regularization: %d', varargin{3});
        end
    end

    %%%
    % Argument 8: are we in stochastic optimization mode?
    if nargin > 7 && varargin{5} > 0
        if varargin{5} < n
            STOCHASTIC = 1;
            CP         = @cuttingPlaneRandom;
            batchSize  = varargin{5};
        end
    end

    % Algorithm
    %
    % Working <- []
    %
    % repeat:
    %   (W, Xi) <- solver(X, Y, C, Working)
    %
    %   for i = 1:|X|
    %       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
    %
    %   Working <- Working + (y^_1,y^_2,...,y^_n)
    % until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
    %           <= Xi + epsilon

    global DEBUG;

    if isempty(DEBUG)
        DEBUG = 0;
    end

    %%%
    % Timer to eliminate old constraints
    ConstraintClock = 100;

    %%%
    % Convergence criteria for worst-violated constraint
    E = 1e-3;

    % Initialize
    W = INIT(X);

    ClassScores = [];

    if isa(Y, 'double')
        % Plain label vector: synthesize per-class relevance sets.
        Ypos = [];
        Yneg = [];
        ClassScores = synthesizeRelevance(Y);

    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
        dbprint(1, 'Using supplied Ypos/Yneg');
        Ypos = Y(:,1);
        Yneg = Y(:,2);

        % Compute the valid samples: those with both relevant and
        % irrelevant sets non-empty.
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
        dbprint(1, 'Using supplied Ypos/synthesized Yneg');
        Ypos = Y(:,1);
        Yneg = [];
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));
    else
        error('MLR:LABELS', 'Incorrect format for Y.');
    end

    %%
    % If we don't have enough data to make the batch, cut the batch
    batchSize = min([batchSize, length(SAMPLES)]);

    Diagnostics = struct( 'loss',               Loss, ...           % Which loss are we optimizing?
                          'feature',            Feature, ...        % Which ranking feature is used?
                          'k',                  k, ...              % What is the ranking length?
                          'regularizer',        Regularizer, ...    % What regularization is used?
                          'diagonal',           Diagonal, ...       % 0 for full metric, 1 for diagonal
                          'num_calls_SO',       0, ...              % Calls to separation oracle
                          'num_calls_solver',   0, ...              % Calls to solver
                          'time_SO',            0, ...              % Time in separation oracle
                          'time_solver',        0, ...              % Time in solver
                          'time_total',         0, ...              % Total time
                          'f',                  [], ...             % Objective value
                          'num_steps',          [], ...             % Number of steps for each solver run
                          'num_constraints',    [], ...             % Number of constraints for each run
                          'Xi',                 [], ...             % Slack achieved for each run
                          'Delta',              [], ...             % Mean loss for each SO call
                          'gap',                [], ...             % Gap between loss and slack
                          'C',                  C, ...              % Slack trade-off
                          'epsilon',            E, ...              % Convergence threshold
                          'feasible_count',     0, ...              % Counter for projections
                          'constraint_timer',   ConstraintClock);   % Time before evicting old constraints

    global PsiR;
    global PsiClock;

    PsiR     = {};
    PsiClock = [];

    Xi      = -Inf;
    Margins = [];

    if STOCHASTIC
        dbprint(1, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
    end

    % Cutting-plane main loop: alternate between generating the most-violated
    % constraint (separation oracle) and re-solving for W over the working set.
    while 1
        dbprint(1, 'Round %03d', Diagnostics.num_calls_solver);
        % Generate a constraint set
        Termination = 0;

        dbprint(2, 'Calling separation oracle...');

        [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
        Termination = LOSS(W, PsiNew, Mnew, 0);

        Diagnostics.num_calls_SO = Diagnostics.num_calls_SO + 1;
        Diagnostics.time_SO      = Diagnostics.time_SO + SO_time;

        Margins  = cat(1, Margins,  Mnew);
        PsiR     = cat(1, PsiR,     PsiNew);
        PsiClock = cat(1, PsiClock, 0);

        dbprint(2, '\n\tActive constraints    : %d',            length(PsiClock));
        dbprint(2, '\t           Mean loss  : %0.4f',           Mnew);
        dbprint(2, '\t  Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);

        Diagnostics.gap   = cat(1, Diagnostics.gap,   Termination - Xi);
        Diagnostics.Delta = cat(1, Diagnostics.Delta, Mnew);

        % Stop once the worst violation is within epsilon of the slack.
        if Termination <= Xi + E
            dbprint(1, 'Done.');
            break;
        end

        dbprint(1, 'Calling solver...');
        PsiClock    = PsiClock + 1;
        Solver_time = tic();
        [W, Xi, Dsolver] = mlr_solver(C, Margins, W, X);
        Diagnostics.time_solver      = Diagnostics.time_solver + toc(Solver_time);
        Diagnostics.num_calls_solver = Diagnostics.num_calls_solver + 1;

        Diagnostics.Xi        = cat(1, Diagnostics.Xi,        Xi);
        Diagnostics.f         = cat(1, Diagnostics.f,         Dsolver.f);
        Diagnostics.num_steps = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

        %%%
        % Cull the old constraints
        GC       = PsiClock < ConstraintClock;
        Margins  = Margins(GC);
        PsiR     = PsiR(GC);
        PsiClock = PsiClock(GC);

        Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
    end

    % Finish diagnostics

    Diagnostics.time_total     = toc(TIME_START);
    Diagnostics.feasible_count = FEASIBLE_COUNT;
end
wolffd@0 391
wolffd@0 392
function ClassScores = synthesizeRelevance(Y)
% Build per-class relevance structure from a numeric label vector.
%
%   Y           n-by-1 vector of class labels
%
%   ClassScores struct with fields:
%       .Y        the original label vector
%       .classes  sorted unique class labels
%       .Ypos     cell array; Ypos{c} is a logical mask of points in class c
%       .Yneg     cell array; Yneg{c} is the complement of Ypos{c}

    classLabels = unique(Y);
    numClasses  = length(classLabels);

    ClassScores = struct( 'Y',       Y, ...
                          'classes', classLabels, ...
                          'Ypos',    [], ...
                          'Yneg',    []);

    % One relevance/irrelevance mask pair per class.
    posMask = cell(numClasses, 1);
    negMask = cell(numClasses, 1);
    for classIdx = 1:numClasses
        posMask{classIdx} = (Y == classLabels(classIdx));
        negMask{classIdx} = ~posMask{classIdx};
    end

    ClassScores.Ypos = posMask;
    ClassScores.Yneg = negMask;
end