function [W, Xi, Diagnostics] = mlr_train(X, Y, Cslack, varargin)
% MLR_TRAIN  Metric Learning to Rank: learn a metric W via structural SVM.
%
% [W, Xi, D] = mlr_train(X, Y, C, ...)
%
%   X = d*n data matrix
%   Y = either n-by-1 vector of labels
%       OR
%       n-by-2 cell array where
%           Y{q,1} contains relevant indices for q, and
%           Y{q,2} contains irrelevant indices for q
%
%   C >= 0 slack trade-off parameter (default=1)
%
%   W  = the learned metric
%   Xi = slack value on the learned metric
%   D  = diagnostics
%
% Optional arguments:
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:  no regularization
%           1:  1-norm: trace(W)            (default)
%           2:  2-norm: trace(W' * W)
%           3:  Kernel: trace(W * X), assumes X is square
%               and positive-definite
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0:   learn a full d-by-d W          (default)
%       Diagonal = 1:   learn diagonally-constrained W (d-by-1)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%
%   // added by Daniel Wolff
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC)
%       Set the constraint-eviction clock to CC (default: 100)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E)
%       Set the convergence threshold to E (default: 1e-3)
%

TIME_START = tic();

global C;

% Default the slack trade-off BEFORE copying it into the global C:
% the original did "C = Cslack;" first, which throws 'not enough input
% arguments' whenever C is omitted, making the documented default of 1
% unreachable.
if nargin < 3
    Cslack = 1;
end
C = Cslack;

[d, n, m] = size(X);

% Multiple-kernel mode iff X carries more than one slice along dim 3.
MKL = (m > 1);

%%%
% Default options:

global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT STRUCTKERNEL DUALW INIT;

global FEASIBLE_COUNT;
FEASIBLE_COUNT = 0;

CP  = @cuttingPlaneFull;
SO  = @separationOracleAUC;
PSI = @metricPsiPO;

if ~MKL
    INIT        = @initializeFull;
    REG         = @regularizeTraceFull;
    STRUCTKERNEL= @structKernelLinear;
    DUALW       = @dualWLinear;
    FEASIBLE    = @feasibleFull;
    CPGRADIENT  = @cpGradientFull;
    DISTANCE    = @distanceFull;
    SETDISTANCE = @setDistanceFull;
    LOSS        = @lossHinge;
    Regularizer = 'Trace';
else
    INIT        = @initializeFullMKL;
    REG         = @regularizeMKLFull;
    STRUCTKERNEL= @structKernelMKL;
    DUALW       = @dualWMKL;
    FEASIBLE    = @feasibleFullMKL;
    CPGRADIENT  = @cpGradientFullMKL;
    DISTANCE    = @distanceFullMKL;
    SETDISTANCE = @setDistanceFullMKL;
    LOSS        = @lossHingeFullMKL;
    Regularizer = 'Trace';
end

Loss    = 'AUC';
Feature = 'metricPsiPO';

%%%
% Default k for prec@k, ndcg
k = 3;

%%%
% Stochastic violator selection?
STOCHASTIC = 0;
batchSize  = n;
SAMPLES    = 1:n;

% Every loss uses PSI = @metricPsiPO / Feature = 'metricPsiPO' (already
% set above); only the separation oracle and the loss label differ.
if nargin > 3
    switch lower(varargin{1})
        case {'auc'}
            SO   = @separationOracleAUC;
            Loss = 'AUC';
        case {'knn'}
            SO   = @separationOracleKNN;
            Loss = 'KNN';
        case {'prec@k'}
            SO   = @separationOraclePrecAtK;
            Loss = 'Prec@k';
        case {'map'}
            SO   = @separationOracleMAP;
            Loss = 'MAP';
        case {'mrr'}
            SO   = @separationOracleMRR;
            Loss = 'MRR';
        case {'ndcg'}
            SO   = @separationOracleNDCG;
            Loss = 'NDCG';
        otherwise
            error('MLR:LOSS', ...
                'Unknown loss function: %s', varargin{1});
    end
end

if nargin > 4
    k = varargin{2};
end

% NB: Diagonal is resolved before the REG argument (below) because the
% regularizer dispatch depends on it.
Diagonal = 0;
if nargin > 6 && varargin{4} > 0    % && (was &): guard varargin{4} access
    Diagonal = varargin{4};

    if ~MKL
        INIT        = @initializeDiag;
        REG         = @regularizeTraceDiag;
        STRUCTKERNEL= @structKernelDiag;
        DUALW       = @dualWDiag;
        FEASIBLE    = @feasibleDiag;
        CPGRADIENT  = @cpGradientDiag;
        DISTANCE    = @distanceDiag;
        SETDISTANCE = @setDistanceDiag;
        Regularizer = 'Trace';
    else
        INIT        = @initializeDiagMKL;
        REG         = @regularizeMKLDiag;
        STRUCTKERNEL= @structKernelDiagMKL;
        DUALW       = @dualWDiagMKL;
        FEASIBLE    = @feasibleDiagMKL;
        CPGRADIENT  = @cpGradientDiagMKL;
        DISTANCE    = @distanceDiagMKL;
        SETDISTANCE = @setDistanceDiagMKL;
        LOSS        = @lossHingeDiagMKL;
        Regularizer = 'Trace';
    end
end

if nargin > 5
    switch(varargin{3})
        case {0}
            REG         = @regularizeNone;
            Regularizer = 'None';

        case {1}
            if MKL
                if Diagonal == 0
                    REG         = @regularizeMKLFull;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                elseif Diagonal == 1
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                end
            else
                if Diagonal
                    REG         = @regularizeTraceDiag;
                    STRUCTKERNEL= @structKernelDiag;
                    DUALW       = @dualWDiag;
                else
                    REG         = @regularizeTraceFull;
                    STRUCTKERNEL= @structKernelLinear;
                    DUALW       = @dualWLinear;
                end
            end
            Regularizer = 'Trace';

        case {2}
            if Diagonal
                REG = @regularizeTwoDiag;
            else
                REG = @regularizeTwoFull;
            end
            Regularizer = '2-norm';
            error('MLR:REGULARIZER', '2-norm regularization no longer supported');

        case {3}
            if MKL
                if Diagonal == 0
                    REG         = @regularizeMKLFull;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                elseif Diagonal == 1
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                end
            else
                if Diagonal
                    REG         = @regularizeMKLDiag;
                    STRUCTKERNEL= @structKernelDiagMKL;
                    DUALW       = @dualWDiagMKL;
                else
                    REG         = @regularizeKernel;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                end
            end
            Regularizer = 'Kernel';

        otherwise
            % %d, not %s: the REG selector is numeric, and applying %s
            % to a double prints character garbage.
            error('MLR:REGULARIZER', ...
                'Unknown regularization: %d', varargin{3});
    end
end

% Are we in stochastic optimization mode?
if nargin > 7 && varargin{5} > 0
    if varargin{5} < n
        STOCHASTIC = 1;
        CP         = @cuttingPlaneRandom;
        batchSize  = varargin{5};
    end
end

% Algorithm
%
% Working <- []
%
% repeat:
%   (W, Xi) <- solver(X, Y, C, Working)
%
%   for i = 1:|X|
%       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
%
%   Working <- Working + (y^_1,y^_2,...,y^_n)
% until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
%       <= Xi + epsilon

global DEBUG;

if isempty(DEBUG)
    DEBUG = 0;
end

%%%
% Timer to eliminate old constraints
ConstraintClock = 100;      % standard: 100

if nargin > 8 && varargin{6} > 0
    ConstraintClock = varargin{6};
end

%%%
% Convergence criteria for worst-violated constraint
E = 1e-3;
if nargin > 9 && varargin{7} > 0
    E = varargin{7};
end

%XXX:   2012-01-31 21:29:50 by Brian McFee
%       no longer belongs here
% Initialize
W = INIT(X);

global ADMM_Z ADMM_U RHO;
ADMM_Z = W;
ADMM_U = 0 * ADMM_Z;

%%%
% Augmented lagrangian factor
RHO = 1;

ClassScores = [];

if isa(Y, 'double')
    Ypos        = [];
    Yneg        = [];
    ClassScores = synthesizeRelevance(Y);

elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
    % Message fixed: this branch has a supplied (not synthesized) Yneg;
    % the original string was copy-pasted from the 1-column branch.
    dbprint(2, 'Using supplied Ypos/Yneg');
    Ypos = Y(:,1);
    Yneg = Y(:,2);

    % Compute the valid samples: queries with both sets non-empty.
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));

elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
    dbprint(2, 'Using supplied Ypos/synthesized Yneg');
    Ypos    = Y(:,1);
    Yneg    = [];
    SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));

else
    error('MLR:LABELS', 'Incorrect format for Y.');
end

%%
% If we don't have enough data to make the batch, cut the batch
batchSize = min([batchSize, length(SAMPLES)]);

Diagnostics = struct(   'loss',             Loss, ...           % Which loss are we optimizing?
                        'feature',          Feature, ...        % Which ranking feature is used?
                        'k',                k, ...              % What is the ranking length?
                        'regularizer',      Regularizer, ...    % What regularization is used?
                        'diagonal',         Diagonal, ...       % 0 for full metric, 1 for diagonal
                        'num_calls_SO',     0, ...              % Calls to separation oracle
                        'num_calls_solver', 0, ...              % Calls to solver
                        'time_SO',          0, ...              % Time in separation oracle
                        'time_solver',      0, ...              % Time in solver
                        'time_total',       0, ...              % Total time
                        'f',                [], ...             % Objective value
                        'num_steps',        [], ...             % Number of steps for each solver run
                        'num_constraints',  [], ...             % Number of constraints for each run
                        'Xi',               [], ...             % Slack achieved for each run
                        'Delta',            [], ...             % Mean loss for each SO call
                        'gap',              [], ...             % Gap between loss and slack
                        'C',                C, ...              % Slack trade-off
                        'epsilon',          E, ...              % Convergence threshold
                        'feasible_count',   FEASIBLE_COUNT, ... % Counter for # svd's
'constraint_timer',     ConstraintClock);                       % Time before evicting old constraints

global PsiR;
global PsiClock;

PsiR     = {};
PsiClock = [];

Xi      = -Inf;
Margins = [];
H       = [];
Q       = [];

if STOCHASTIC
    dbprint(1, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
end

MAXITER = 200;
% while 1
while Diagnostics.num_calls_solver < MAXITER

    dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);

    % Generate a constraint set
    Termination = 0;

    dbprint(2, 'Calling separation oracle...');

    [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
    Termination = LOSS(W, PsiNew, Mnew, 0);

    Diagnostics.num_calls_SO = Diagnostics.num_calls_SO + 1;
    Diagnostics.time_SO      = Diagnostics.time_SO + SO_time;

    % Append the new constraint to the working set.
    Margins  = cat(1, Margins, Mnew);
    PsiR     = cat(1, PsiR, PsiNew);
    PsiClock = cat(1, PsiClock, 0);
    H        = expandKernel(H);
    Q        = expandRegularizer(Q, X, W);

    dbprint(2, '\n\tActive constraints : %d', length(PsiClock));
    dbprint(2, '\t Mean loss : %0.4f', Mnew);
    dbprint(2, '\t Current loss Xi : %0.4f', Xi);
    % NOTE(review): the format string below was truncated in the source
    % excerpt ("'\t Termination -Xi < E : %0.4f") and is completed
    % minimally here — confirm against the upstream MLR distribution.
    dbprint(2, '\t Termination -Xi < E : %0.4f', Termination - Xi);

    % =====================================================================
    % NOTE(review): the original source is TRUNCATED at this point.
    % Missing from this excerpt: the solver invocation (nothing below
    % ever advances Diagnostics.num_calls_solver or updates W/Xi),
    % constraint eviction via PsiClock/ConstraintClock, per-round
    % Diagnostics bookkeeping (f, num_steps, num_constraints, Xi, Delta,
    % gap), and the helper expandKernel referenced above. The
    % convergence break below is reconstructed from the dbprint label
    % and the algorithm comment ("until ... <= Xi + epsilon"); restore
    % the full loop body from the upstream MLR distribution before use.
    % =====================================================================
    if Termination - Xi < E
        break;
    end
end

% Total elapsed time (reconstructed: TIME_START and the 'time_total'
% Diagnostics field are both declared above — confirm against upstream).
Diagnostics.time_total = toc(TIME_START);

end


function Q = expandRegularizer(Q, K, W)
% Append the regularizer inner-product term for the newest constraint:
% Q(m+1) = < REG(W,K,1), PsiR{m+1} > under the current structural kernel.
%
% NOTE(review): the 'function' line was lost in this excerpt and is
% reconstructed from the call site "Q = expandRegularizer(Q, X, W)" and
% the body's use of Q, K and W — confirm against upstream.
%
% does not support unregularized learning

global PsiR;
global STRUCTKERNEL REG;

m = length(Q);
Q(m+1,1) = STRUCTKERNEL(REG(W,K,1), PsiR{m+1});

end


function ClassScores = synthesizeRelevance(Y)
% Build per-class relevance sets from a label vector: for each class c,
% Ypos{c} is a logical mask of the points labeled c, Yneg{c} the rest.

classes  = unique(Y);
nClasses = length(classes);

ClassScores = struct(   'Y',        Y, ...
                        'classes',  classes, ...
                        'Ypos',     [], ...
                        'Yneg',     []);

Ypos = cell(nClasses, 1);
Yneg = cell(nClasses, 1);
for c = 1:nClasses
    Ypos{c} = (Y == classes(c));
    Yneg{c} = ~Ypos{c};
end

ClassScores.Ypos = Ypos;
ClassScores.Yneg = Yneg;

end