annotate toolboxes/distance_learning/mlr/mlr_train.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [W, Xi, Diagnostics] = mlr_train(X, Y, Cslack, varargin)
%
% [W, Xi, D] = mlr_train(X, Y, C,...)
%
%   X = d*n data matrix
%   Y = either n-by-1 label of vectors
%       OR
%       n-by-2 cell array where
%           Y{q,1} contains relevant indices for q, and
%           Y{q,2} contains irrelevant indices for q
%
%   C >= 0 slack trade-off parameter (default=1)
%
%   W  = the learned metric
%   Xi = slack value on the learned metric
%   D  = diagnostics
%
% Optional arguments:
%
% [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%   where LOSS is one of:
%       'AUC':      Area under ROC curve (default)
%       'KNN':      KNN accuracy
%       'Prec@k':   Precision-at-k
%       'MAP':      Mean Average Precision
%       'MRR':      Mean Reciprocal Rank
%       'NDCG':     Normalized Discounted Cumulative Gain
%
% [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%   where k is the number of neighbors for Prec@k or NDCG
%   (default=3)
%
% [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%   where REG defines the regularization on W, and is one of:
%       0:          no regularization
%       1:          1-norm: trace(W)            (default)
%       2:          2-norm: trace(W' * W)
%       3:          Kernel: trace(W * X), assumes X is square and positive-definite
%
% [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%   Diagonal = 0:   learn a full d-by-d W       (default)
%   Diagonal = 1:   learn diagonally-constrained W (d-by-1)
%
% [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%   where B > 0 enables stochastic optimization with batch size B
%
% // added by Daniel Wolff
% [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC)
%   Set ConstraintClock to CC (default: 100)
%
% [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E)
%   Set convergence threshold to E (default: 1e-3)
%

    TIME_START = tic();

    global C;
    C = Cslack;

    [d,n,m] = size(X);

    % A third dimension on X means one kernel per slice:
    % multiple kernel learning (MKL) mode.
    if m > 1
        MKL = 1;
    else
        MKL = 0;
    end

    if nargin < 3
        C = 1;
    end

    %%%
    % Default options:

    global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT STRUCTKERNEL DUALW INIT;

    global FEASIBLE_COUNT;
    FEASIBLE_COUNT = 0;

    CP  = @cuttingPlaneFull;
    SO  = @separationOracleAUC;
    PSI = @metricPsiPO;

    if ~MKL
        INIT        = @initializeFull;
        REG         = @regularizeTraceFull;
        STRUCTKERNEL= @structKernelLinear;
        DUALW       = @dualWLinear;
        FEASIBLE    = @feasibleFull;
        CPGRADIENT  = @cpGradientFull;
        DISTANCE    = @distanceFull;
        SETDISTANCE = @setDistanceFull;
        LOSS        = @lossHinge;
        Regularizer = 'Trace';
    else
        INIT        = @initializeFullMKL;
        REG         = @regularizeMKLFull;
        STRUCTKERNEL= @structKernelMKL;
        DUALW       = @dualWMKL;
        FEASIBLE    = @feasibleFullMKL;
        CPGRADIENT  = @cpGradientFullMKL;
        DISTANCE    = @distanceFullMKL;
        SETDISTANCE = @setDistanceFullMKL;
        LOSS        = @lossHingeFullMKL;
        Regularizer = 'Trace';
    end

    Loss    = 'AUC';
    Feature = 'metricPsiPO';

    %%%
    % Default k for prec@k, ndcg
    k = 3;

    %%%
    % Stochastic violator selection?
    STOCHASTIC = 0;
    batchSize  = n;
    SAMPLES    = 1:n;

    if nargin > 3
        switch lower(varargin{1})
            case {'auc'}
                SO      = @separationOracleAUC;
                PSI     = @metricPsiPO;
                Loss    = 'AUC';
                Feature = 'metricPsiPO';
            case {'knn'}
                SO      = @separationOracleKNN;
                PSI     = @metricPsiPO;
                Loss    = 'KNN';
                Feature = 'metricPsiPO';
            case {'prec@k'}
                SO      = @separationOraclePrecAtK;
                PSI     = @metricPsiPO;
                Loss    = 'Prec@k';
                Feature = 'metricPsiPO';
            case {'map'}
                SO      = @separationOracleMAP;
                PSI     = @metricPsiPO;
                Loss    = 'MAP';
                Feature = 'metricPsiPO';
            case {'mrr'}
                SO      = @separationOracleMRR;
                PSI     = @metricPsiPO;
                Loss    = 'MRR';
                Feature = 'metricPsiPO';
            case {'ndcg'}
                SO      = @separationOracleNDCG;
                PSI     = @metricPsiPO;
                Loss    = 'NDCG';
                Feature = 'metricPsiPO';
            otherwise
                error('MLR:LOSS', ...
                    'Unknown loss function: %s', varargin{1});
        end
    end

    if nargin > 4
        k = varargin{2};
    end

    Diagonal = 0;
    % Short-circuit && (consistent with the other nargin checks below), with
    % an explicit isempty guard so an empty argument still means "not set"
    % instead of erroring (the original non-short-circuit `&` treated the
    % empty comparison result as false).
    if nargin > 6 && ~isempty(varargin{4}) && varargin{4} > 0
        Diagonal = varargin{4};

        if ~MKL
            INIT        = @initializeDiag;
            REG         = @regularizeTraceDiag;
            STRUCTKERNEL= @structKernelDiag;
            DUALW       = @dualWDiag;
            FEASIBLE    = @feasibleDiag;
            CPGRADIENT  = @cpGradientDiag;
            DISTANCE    = @distanceDiag;
            SETDISTANCE = @setDistanceDiag;
            Regularizer = 'Trace';
        else
            INIT        = @initializeDiagMKL;
            REG         = @regularizeMKLDiag;
            STRUCTKERNEL= @structKernelDiagMKL;
            DUALW       = @dualWDiagMKL;
            FEASIBLE    = @feasibleDiagMKL;
            CPGRADIENT  = @cpGradientDiagMKL;
            DISTANCE    = @distanceDiagMKL;
            SETDISTANCE = @setDistanceDiagMKL;
            LOSS        = @lossHingeDiagMKL;
            Regularizer = 'Trace';
        end
    end

    if nargin > 5
        switch(varargin{3})
            case {0}
                REG         = @regularizeNone;
                Regularizer = 'None';

            case {1}
                if MKL
                    if Diagonal == 0
                        REG         = @regularizeMKLFull;
                        STRUCTKERNEL= @structKernelMKL;
                        DUALW       = @dualWMKL;
                    elseif Diagonal == 1
                        REG         = @regularizeMKLDiag;
                        STRUCTKERNEL= @structKernelDiagMKL;
                        DUALW       = @dualWDiagMKL;
                    end
                else
                    if Diagonal
                        REG         = @regularizeTraceDiag;
                        STRUCTKERNEL= @structKernelDiag;
                        DUALW       = @dualWDiag;
                    else
                        REG         = @regularizeTraceFull;
                        STRUCTKERNEL= @structKernelLinear;
                        DUALW       = @dualWLinear;
                    end
                end
                Regularizer = 'Trace';

            case {2}
                if Diagonal
                    REG = @regularizeTwoDiag;
                else
                    REG = @regularizeTwoFull;
                end
                Regularizer = '2-norm';
                error('MLR:REGULARIZER', '2-norm regularization no longer supported');

            case {3}
                if MKL
                    if Diagonal == 0
                        REG         = @regularizeMKLFull;
                        STRUCTKERNEL= @structKernelMKL;
                        DUALW       = @dualWMKL;
                    elseif Diagonal == 1
                        REG         = @regularizeMKLDiag;
                        STRUCTKERNEL= @structKernelDiagMKL;
                        DUALW       = @dualWDiagMKL;
                    end
                else
                    if Diagonal
                        REG         = @regularizeMKLDiag;
                        STRUCTKERNEL= @structKernelDiagMKL;
                        DUALW       = @dualWDiagMKL;
                    else
                        REG         = @regularizeKernel;
                        STRUCTKERNEL= @structKernelMKL;
                        DUALW       = @dualWMKL;
                    end
                end
                Regularizer = 'Kernel';

            otherwise
                error('MLR:REGULARIZER', ...
                    'Unknown regularization: %s', varargin{3});
        end
    end

    % Are we in stochastic optimization mode?
    if nargin > 7 && varargin{5} > 0
        if varargin{5} < n
            STOCHASTIC = 1;
            CP         = @cuttingPlaneRandom;
            batchSize  = varargin{5};
        end
    end

    % Algorithm
    %
    % Working <- []
    %
    % repeat:
    %   (W, Xi) <- solver(X, Y, C, Working)
    %
    %   for i = 1:|X|
    %       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
    %
    %   Working <- Working + (y^_1,y^_2,...,y^_n)
    % until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
    %           <= Xi + epsilon

    global DEBUG;

    if isempty(DEBUG)
        DEBUG = 0;
    end

    %%%
    % Timer to eliminate old constraints
    ConstraintClock = 100; % standard: 100

    if nargin > 8 && varargin{6} > 0
        ConstraintClock = varargin{6};
    end

    %%%
    % Convergence criteria for worst-violated constraint
    E = 1e-3;
    if nargin > 9 && varargin{7} > 0
        E = varargin{7};
    end

    %XXX: 2012-01-31 21:29:50 by Brian McFee <bmcfee@cs.ucsd.edu>
    % no longer belongs here
    % Initialize
    W = INIT(X);

    global ADMM_Z ADMM_U RHO;
    ADMM_Z = W;
    ADMM_U = 0 * ADMM_Z;

    %%%
    % Augmented lagrangian factor
    RHO = 1;

    ClassScores = [];

    if isa(Y, 'double')
        % Plain label vector: build relevance sets from class membership.
        Ypos        = [];
        Yneg        = [];
        ClassScores = synthesizeRelevance(Y);

    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
        % Both relevant and irrelevant index sets supplied per query.
        dbprint(2, 'Using supplied Ypos/Yneg');
        Ypos = Y(:,1);
        Yneg = Y(:,2);

        % Compute the valid samples
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
        dbprint(2, 'Using supplied Ypos/synthesized Yneg');
        Ypos = Y(:,1);
        Yneg = [];
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));
    else
        error('MLR:LABELS', 'Incorrect format for Y.');
    end

    %%
    % If we don't have enough data to make the batch, cut the batch
    batchSize = min([batchSize, length(SAMPLES)]);

    Diagnostics = struct( 'loss', Loss, ... % Which loss are we optimizing?
                          'feature', Feature, ... % Which ranking feature is used?
                          'k', k, ... % What is the ranking length?
                          'regularizer', Regularizer, ... % What regularization is used?
                          'diagonal', Diagonal, ... % 0 for full metric, 1 for diagonal
                          'num_calls_SO', 0, ... % Calls to separation oracle
                          'num_calls_solver', 0, ... % Calls to solver
                          'time_SO', 0, ... % Time in separation oracle
                          'time_solver', 0, ... % Time in solver
                          'time_total', 0, ... % Total time
                          'f', [], ... % Objective value
                          'num_steps', [], ... % Number of steps for each solver run
                          'num_constraints', [], ... % Number of constraints for each run
                          'Xi', [], ... % Slack achieved for each run
                          'Delta', [], ... % Mean loss for each SO call
                          'gap', [], ... % Gap between loss and slack
                          'C', C, ... % Slack trade-off
                          'epsilon', E, ... % Convergence threshold
                          'feasible_count', FEASIBLE_COUNT, ... % Counter for # svd's
                          'constraint_timer', ConstraintClock); % Time before evicting old constraints

    global PsiR;
    global PsiClock;

    PsiR     = {};
    PsiClock = [];

    Xi      = -Inf;
    Margins = [];
    H       = [];
    Q       = [];

    if STOCHASTIC
        dbprint(1, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
    end

    MAXITER = 200;
    % while 1
    while Diagnostics.num_calls_solver < MAXITER
        dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);
        % Generate a constraint set
        Termination = 0;

        dbprint(2, 'Calling separation oracle...');

        [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
        Termination = LOSS(W, PsiNew, Mnew, 0);

        Diagnostics.num_calls_SO = Diagnostics.num_calls_SO + 1;
        Diagnostics.time_SO      = Diagnostics.time_SO + SO_time;

        % Append the new constraint and keep the cached kernel/regularizer
        % quantities in sync with PsiR.
        Margins  = cat(1, Margins, Mnew);
        PsiR     = cat(1, PsiR, PsiNew);
        PsiClock = cat(1, PsiClock, 0);
        H        = expandKernel(H);
        Q        = expandRegularizer(Q, X, W);

        dbprint(2, '\n\tActive constraints : %d', length(PsiClock));
        dbprint(2, '\t Mean loss : %0.4f', Mnew);
        dbprint(2, '\t Current loss Xi : %0.4f', Xi);
        dbprint(2, '\t Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);

        Diagnostics.gap   = cat(1, Diagnostics.gap, Termination - Xi);
        Diagnostics.Delta = cat(1, Diagnostics.Delta, Mnew);

        if Termination <= Xi + E
            dbprint(2, 'Done.');
            break;
        end

        dbprint(2, 'Calling solver...');
        PsiClock    = PsiClock + 1;
        Solver_time = tic();
        [W, Xi, Dsolver] = mlr_admm(C, X, Margins, H, Q);
        Diagnostics.time_solver      = Diagnostics.time_solver + toc(Solver_time);
        Diagnostics.num_calls_solver = Diagnostics.num_calls_solver + 1;

        Diagnostics.Xi        = cat(1, Diagnostics.Xi, Xi);
        Diagnostics.f         = cat(1, Diagnostics.f, Dsolver.f);
        Diagnostics.num_steps = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

        %%%
        % Cull the old constraints
        GC       = PsiClock < ConstraintClock;
        Margins  = Margins(GC);
        PsiR     = PsiR(GC);
        PsiClock = PsiClock(GC);
        H        = H(GC, GC);
        Q        = Q(GC);

        Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
    end

    % Finish diagnostics

    Diagnostics.time_total     = toc(TIME_START);
    Diagnostics.feasible_count = FEASIBLE_COUNT;
end
wolffd@0 458
function H = expandKernel(H)
% Grow the constraint Gram matrix H by one row/column for the newest
% constraint PsiR{m+1}, filling the new entries (symmetrically) with the
% structural kernel values STRUCTKERNEL(PsiR{i}, PsiR{m+1}).
%
% Assumes the caller has already appended the new constraint to PsiR.

    global STRUCTKERNEL;
    global PsiR;

    m = length(H);

    % Append a zero row and column. Plain concatenation is equivalent to
    % padarray(H, [1 1], 0, 'post') (also for empty H) but does not require
    % the Image Processing Toolbox.
    H = [H, zeros(m, 1); zeros(1, m + 1)];

    for i = 1:m+1
        H(i, m+1) = STRUCTKERNEL( PsiR{i}, PsiR{m+1} );
        H(m+1, i) = H(i, m+1);
    end
end
wolffd@0 473
function Q = expandRegularizer(Q, K, W)
% Append one entry to the column vector Q for the newest constraint:
% the structural kernel between the regularizer gradient at W and PsiR{m+1}.
%
% FIXME: 2012-01-31 21:34:15 by Brian McFee <bmcfee@cs.ucsd.edu>
%  does not support unregularized learning

    global PsiR;
    global STRUCTKERNEL REG;

    % Gradient of the regularizer at the current metric W.
    regGrad = REG(W, K, 1);

    nextRow = length(Q) + 1;
    Q(nextRow, 1) = STRUCTKERNEL(regGrad, PsiR{nextRow});
end
wolffd@0 486
function ClassScores = synthesizeRelevance(Y)
% Build per-class relevance sets from an n-by-1 label vector Y.
%
% Returns a scalar struct with fields:
%   Y       : the original label vector
%   classes : sorted unique labels found in Y
%   Ypos    : cell array; Ypos{c} is a logical mask of points in class c
%   Yneg    : cell array; Yneg{c} is the complementary mask (~Ypos{c})

    classes  = unique(Y);
    nClasses = numel(classes);

    % One logical membership mask per class (column cell arrays).
    Ypos = arrayfun(@(c) Y == classes(c), (1:nClasses)', 'UniformOutput', false);
    Yneg = cellfun(@(mask) ~mask, Ypos, 'UniformOutput', false);

    % Build the struct with empty placeholders first, then assign the cell
    % arrays: passing cells directly to struct() would create a struct array.
    ClassScores = struct( 'Y', Y, ...
                          'classes', classes, ...
                          'Ypos', [], ...
                          'Yneg', []);
    ClassScores.Ypos = Ypos;
    ClassScores.Yneg = Yneg;

end