Daniel@0: function [W, Xi, Diagnostics] = mlr_train(X, Y, C, varargin) Daniel@0: % Daniel@0: % [W, Xi, D] = mlr_train(X, Y, C,...) Daniel@0: % Daniel@0: % X = d*n data matrix Daniel@0: % Y = either n-by-1 vector of labels Daniel@0: % OR Daniel@0: % n-by-2 cell array where Daniel@0: % Y{q,1} contains relevant indices for q, and Daniel@0: % Y{q,2} contains irrelevant indices for q Daniel@0: % Daniel@0: % C >= 0 slack trade-off parameter (default=1) Daniel@0: % Daniel@0: % W = the learned metric Daniel@0: % Xi = slack value on the learned metric Daniel@0: % D = diagnostics Daniel@0: % Daniel@0: % Optional arguments: Daniel@0: % Daniel@0: % [W, Xi, D] = mlr_train(X, Y, C, LOSS) Daniel@0: % where LOSS is one of: Daniel@0: % 'AUC': Area under ROC curve (default) Daniel@0: % 'KNN': KNN accuracy Daniel@0: % 'Prec@k': Precision-at-k Daniel@0: % 'MAP': Mean Average Precision Daniel@0: % 'MRR': Mean Reciprocal Rank Daniel@0: % 'NDCG': Normalized Discounted Cumulative Gain Daniel@0: % Daniel@0: % [W, Xi, D] = mlr_train(X, Y, C, LOSS, k) Daniel@0: % where k is the number of neighbors for Prec@k or NDCG Daniel@0: % (default=3) Daniel@0: % Daniel@0: % [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG) Daniel@0: % where REG defines the regularization on W, and is one of: Daniel@0: % 0: no regularization Daniel@0: % 1: 1-norm: trace(W) (default) Daniel@0: % 2: 2-norm: trace(W' * W) Daniel@0: % 3: Kernel: trace(W * X), assumes X is square and positive-definite Daniel@0: % Daniel@0: % [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal) Daniel@0: % Diagonal = 0: learn a full d-by-d W (default) Daniel@0: % Diagonal = 1: learn diagonally-constrained W (d-by-1) Daniel@0: % Daniel@0: % [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B) Daniel@0: % where B > 0 enables stochastic optimization with batch size B Daniel@0: % Daniel@0: % [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC) Daniel@0: % Set ConstraintClock to CC (default: 20, 100) Daniel@0: % Daniel@0: % [W, Xi, 
D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E) Daniel@0: % Set convergence threshold epsilon to E (default: 1e-3) Daniel@0: % Daniel@0: Daniel@0: Daniel@0: global globalvars; Daniel@0: global DEBUG; Daniel@0: Daniel@0: if isfield(globalvars, 'debug') Daniel@0: Daniel@0: DEBUG = globalvars.debug; Daniel@0: else Daniel@0: Daniel@0: DEBUG = 0; Daniel@0: end Daniel@0: Daniel@0: Daniel@0: Daniel@0: TIME_START = tic(); Daniel@0: Daniel@0: % addpath('cuttingPlane', 'distance', 'feasible', 'initialize', 'loss', ... Daniel@0: % 'metricPsi', 'regularize', 'separationOracle', 'util'); Daniel@0: Daniel@0: [d,n,m] = size(X); Daniel@0: Daniel@0: if m > 1 Daniel@0: MKL = 1; Daniel@0: else Daniel@0: MKL = 0; Daniel@0: end Daniel@0: Daniel@0: if nargin < 3 Daniel@0: C = 1; Daniel@0: end Daniel@0: Daniel@0: %%% Daniel@0: % Default options: Daniel@0: Daniel@0: global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT METRICK; Daniel@0: Daniel@0: CP = @cuttingPlaneFull; Daniel@0: SO = @separationOracleAUC; Daniel@0: PSI = @metricPsiPO; Daniel@0: Daniel@0: if ~MKL Daniel@0: INIT = @initializeFull; Daniel@0: REG = @regularizeTraceFull; Daniel@0: FEASIBLE = @feasibleFull; Daniel@0: CPGRADIENT = @cpGradientFull; Daniel@0: DISTANCE = @distanceFull; Daniel@0: SETDISTANCE = @setDistanceFull; Daniel@0: LOSS = @lossHinge; Daniel@0: Regularizer = 'Trace'; Daniel@0: else Daniel@0: INIT = @initializeFullMKL; Daniel@0: REG = @regularizeMKLFull; Daniel@0: FEASIBLE = @feasibleFullMKL; Daniel@0: CPGRADIENT = @cpGradientFullMKL; Daniel@0: DISTANCE = @distanceFullMKL; Daniel@0: SETDISTANCE = @setDistanceFullMKL; Daniel@0: LOSS = @lossHingeFullMKL; Daniel@0: Regularizer = 'Trace'; Daniel@0: end Daniel@0: Daniel@0: Daniel@0: Loss = 'AUC'; Daniel@0: Feature = 'metricPsiPO'; Daniel@0: Daniel@0: Daniel@0: %%% Daniel@0: % Default k for prec@k, ndcg Daniel@0: k = 3; Daniel@0: Daniel@0: %%% Daniel@0: % Stochastic violator selection? 
Daniel@0: STOCHASTIC = 0; Daniel@0: batchSize = n; Daniel@0: SAMPLES = 1:n; Daniel@0: Daniel@0: Daniel@0: if nargin > 3 Daniel@0: switch lower(varargin{1}) Daniel@0: case {'auc'} Daniel@0: SO = @separationOracleAUC; Daniel@0: PSI = @metricPsiPO; Daniel@0: Loss = 'AUC'; Daniel@0: Feature = 'metricPsiPO'; Daniel@0: case {'knn'} Daniel@0: SO = @separationOracleKNN; Daniel@0: PSI = @metricPsiPO; Daniel@0: Loss = 'KNN'; Daniel@0: Feature = 'metricPsiPO'; Daniel@0: case {'prec@k'} Daniel@0: SO = @separationOraclePrecAtK; Daniel@0: PSI = @metricPsiPO; Daniel@0: Loss = 'Prec@k'; Daniel@0: Feature = 'metricPsiPO'; Daniel@0: case {'map'} Daniel@0: SO = @separationOracleMAP; Daniel@0: PSI = @metricPsiPO; Daniel@0: Loss = 'MAP'; Daniel@0: Feature = 'metricPsiPO'; Daniel@0: case {'mrr'} Daniel@0: SO = @separationOracleMRR; Daniel@0: PSI = @metricPsiPO; Daniel@0: Loss = 'MRR'; Daniel@0: Feature = 'metricPsiPO'; Daniel@0: case {'ndcg'} Daniel@0: SO = @separationOracleNDCG; Daniel@0: PSI = @metricPsiPO; Daniel@0: Loss = 'NDCG'; Daniel@0: Feature = 'metricPsiPO'; Daniel@0: otherwise Daniel@0: error('MLR:LOSS', ... 
Daniel@0: 'Unknown loss function: %s', varargin{1}); Daniel@0: end Daniel@0: end Daniel@0: Daniel@0: if nargin > 4 Daniel@0: k = varargin{2}; Daniel@0: end Daniel@0: Daniel@0: METRICK = k; Daniel@0: Daniel@0: Diagonal = 0; Daniel@0: if nargin > 6 & varargin{4} > 0 Daniel@0: Diagonal = varargin{4}; Daniel@0: Daniel@0: if ~MKL Daniel@0: INIT = @initializeDiag; Daniel@0: REG = @regularizeTraceDiag; Daniel@0: FEASIBLE = @feasibleDiag; Daniel@0: CPGRADIENT = @cpGradientDiag; Daniel@0: DISTANCE = @distanceDiag; Daniel@0: SETDISTANCE = @setDistanceDiag; Daniel@0: Regularizer = 'Trace'; Daniel@0: else Daniel@0: if Diagonal > 1 Daniel@0: INIT = @initializeDODMKL; Daniel@0: REG = @regularizeMKLDOD; Daniel@0: FEASIBLE = @feasibleDODMKL; Daniel@0: CPGRADIENT = @cpGradientDODMKL; Daniel@0: DISTANCE = @distanceDODMKL; Daniel@0: SETDISTANCE = @setDistanceDODMKL; Daniel@0: LOSS = @lossHingeDODMKL; Daniel@0: Regularizer = 'Trace'; Daniel@0: else Daniel@0: INIT = @initializeDiagMKL; Daniel@0: REG = @regularizeMKLDiag; Daniel@0: FEASIBLE = @feasibleDiagMKL; Daniel@0: CPGRADIENT = @cpGradientDiagMKL; Daniel@0: DISTANCE = @distanceDiagMKL; Daniel@0: SETDISTANCE = @setDistanceDiagMKL; Daniel@0: LOSS = @lossHingeDiagMKL; Daniel@0: Regularizer = 'Trace'; Daniel@0: end Daniel@0: end Daniel@0: end Daniel@0: Daniel@0: if nargin > 5 Daniel@0: switch(varargin{3}) Daniel@0: case {0} Daniel@0: REG = @regularizeNone; Daniel@0: Regularizer = 'None'; Daniel@0: Daniel@0: case {1} Daniel@0: if MKL Daniel@0: if Diagonal == 0 Daniel@0: REG = @regularizeMKLFull; Daniel@0: elseif Diagonal == 1 Daniel@0: REG = @regularizeMKLDiag; Daniel@0: elseif Diagonal == 2 Daniel@0: REG = @regularizeMKLDOD; Daniel@0: end Daniel@0: else Daniel@0: if Diagonal Daniel@0: REG = @regularizeTraceDiag; Daniel@0: else Daniel@0: REG = @regularizeTraceFull; Daniel@0: end Daniel@0: end Daniel@0: Regularizer = 'Trace'; Daniel@0: Daniel@0: case {2} Daniel@0: if Diagonal Daniel@0: REG = @regularizeTwoDiag; Daniel@0: else Daniel@0: 
REG = @regularizeTwoFull; Daniel@0: end Daniel@0: Regularizer = '2-norm'; Daniel@0: Daniel@0: case {3} Daniel@0: if MKL Daniel@0: if Diagonal == 0 Daniel@0: REG = @regularizeMKLFull; Daniel@0: elseif Diagonal == 1 Daniel@0: REG = @regularizeMKLDiag; Daniel@0: elseif Diagonal == 2 Daniel@0: REG = @regularizeMKLDOD; Daniel@0: end Daniel@0: else Daniel@0: if Diagonal Daniel@0: REG = @regularizeMKLDiag; Daniel@0: else Daniel@0: REG = @regularizeKernel; Daniel@0: end Daniel@0: end Daniel@0: Regularizer = 'Kernel'; Daniel@0: Daniel@0: otherwise Daniel@0: error('MLR:REGULARIZER', ... Daniel@0: 'Unknown regularization: %s', varargin{3}); Daniel@0: end Daniel@0: end Daniel@0: Daniel@0: Daniel@0: % Are we in stochastic optimization mode? Daniel@0: if nargin > 7 && varargin{5} > 0 Daniel@0: if varargin{5} < n Daniel@0: STOCHASTIC = 1; Daniel@0: CP = @cuttingPlaneRandom; Daniel@0: batchSize = varargin{5}; Daniel@0: end Daniel@0: end Daniel@0: Daniel@0: %%% Daniel@0: % Timer to eliminate old constraints Daniel@0: ConstraintClock = 20; Daniel@0: Daniel@0: if nargin > 8 && varargin{6} > 0 Daniel@0: ConstraintClock = varargin{6}; Daniel@0: end Daniel@0: Daniel@0: %%% Daniel@0: % Convergence criteria for worst-violated constraint Daniel@0: E = 1e-3; Daniel@0: if nargin > 9 && varargin{7} > 0 Daniel@0: E = varargin{7}; Daniel@0: end Daniel@0: Daniel@0: % Algorithm Daniel@0: % Daniel@0: % Working <- [] Daniel@0: % Daniel@0: % repeat: Daniel@0: % (W, Xi) <- solver(X, Y, C, Working) Daniel@0: % Daniel@0: % for i = 1:|X| Daniel@0: % y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) ) Daniel@0: % Daniel@0: % Working <- Working + (y^_1,y^_2,...,y^_n) Daniel@0: % until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i))) Daniel@0: % <= Xi + epsilon Daniel@0: Daniel@0: Daniel@0: Daniel@0: Daniel@0: Daniel@0: % Initialize Daniel@0: W = INIT(X); Daniel@0: Daniel@0: ClassScores = []; Daniel@0: Daniel@0: if isa(Y, 'double') Daniel@0: Ypos = []; Daniel@0: Yneg = []; Daniel@0: 
ClassScores = synthesizeRelevance(Y); Daniel@0: Daniel@0: elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2 Daniel@0: dbprint(2, 'Using supplied Ypos/Yneg'); Daniel@0: Ypos = Y(:,1); Daniel@0: Yneg = Y(:,2); Daniel@0: Daniel@0: % Compute the valid samples Daniel@0: SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2)))); Daniel@0: elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1 Daniel@0: dbprint(2, 'Using supplied Ypos/synthesized Yneg'); Daniel@0: Ypos = Y(:,1); Daniel@0: Yneg = []; Daniel@0: SAMPLES = find( ~(cellfun(@isempty, Y(:,1)))); Daniel@0: else Daniel@0: error('MLR:LABELS', 'Incorrect format for Y.'); Daniel@0: end Daniel@0: Daniel@0: Daniel@0: Diagnostics = struct( 'loss', Loss, ... % Which loss are we optimizing? Daniel@0: 'feature', Feature, ... % Which ranking feature is used? Daniel@0: 'k', k, ... % What is the ranking length? Daniel@0: 'regularizer', Regularizer, ... % What regularization is used? Daniel@0: 'diagonal', Diagonal, ... % 0 for full metric, 1 for diagonal Daniel@0: 'num_calls_SO', 0, ... % Calls to separation oracle Daniel@0: 'num_calls_solver', 0, ... % Calls to solver Daniel@0: 'time_SO', 0, ... % Time in separation oracle Daniel@0: 'time_solver', 0, ... % Time in solver Daniel@0: 'time_total', 0, ... % Total time Daniel@0: 'f', [], ... % Objective value Daniel@0: 'num_steps', [], ... % Number of steps for each solver run Daniel@0: 'num_constraints', [], ... % Number of constraints for each run Daniel@0: 'Xi', [], ... % Slack achieved for each run Daniel@0: 'Delta', [], ... % Mean loss for each SO call Daniel@0: 'gap', [], ... % Gap between loss and slack Daniel@0: 'C', C, ... % Slack trade-off Daniel@0: 'epsilon', E, ... 
% Convergence threshold Daniel@0: 'constraint_timer', ConstraintClock); % Time before evicting old constraints Daniel@0: Daniel@0: Daniel@0: Daniel@0: global PsiR; Daniel@0: global PsiClock; Daniel@0: Daniel@0: PsiR = {}; Daniel@0: PsiClock = []; Daniel@0: Daniel@0: Xi = -Inf; Daniel@0: Margins = []; Daniel@0: Daniel@0: if STOCHASTIC Daniel@0: dbprint(2, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n); Daniel@0: end Daniel@0: Daniel@0: while 1 Daniel@0: dbprint(2, 'Round %03d', Diagnostics.num_calls_solver); Daniel@0: % Generate a constraint set Daniel@0: Termination = 0; Daniel@0: Daniel@0: Daniel@0: dbprint(2, 'Calling separation oracle...'); Daniel@0: Daniel@0: [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores); Daniel@0: Termination = LOSS(W, PsiNew, Mnew, 0); Daniel@0: Daniel@0: Diagnostics.num_calls_SO = Diagnostics.num_calls_SO + 1; Daniel@0: Diagnostics.time_SO = Diagnostics.time_SO + SO_time; Daniel@0: Daniel@0: Margins = cat(1, Margins, Mnew); Daniel@0: PsiR = cat(1, PsiR, PsiNew); Daniel@0: PsiClock = cat(1, PsiClock, 0); Daniel@0: Daniel@0: dbprint(2, '\n\tActive constraints : %d', length(PsiClock)); Daniel@0: dbprint(2, '\t Mean loss : %0.4f', Mnew); Daniel@0: dbprint(2, '\t Termination -Xi < E : %0.4f