function [W, Xi, Diagnostics] = rmlr_train(X, Y, Cslack, varargin)
%[W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, stochastic, lam, STRUCTREG)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, ...)
%
%   X = d*n data matrix
%   Y = either n-by-1 label of vectors
%       OR
%       n-by-2 cell array where
%       Y{q,1} contains relevant indices for q, and
%       Y{q,2} contains irrelevant indices for q
%
%   C >= 0 slack trade-off parameter (default=1)
%
%   W   = the learned metric
%   Xi  = slack value on the learned metric
%   D   = diagnostics
%
%   Optional arguments:
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:          no regularization
%           1:          1-norm: trace(W)            (default)
%           2:          2-norm: trace(W' * W)
%           3:          Kernel: trace(W * X), assumes X is square
%                       and positive-definite
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Not implemented: learning a diagonal W metric reduces to plain MLR,
%       because ||W||_2,1 = trace(W) when W is diagonal.
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where B > 0 enables stochastic optimization with batch size B
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda)
%       lambda is the coefficient of the ||W||_2,1 penalty (default: 1)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda, CC)
%       Set ConstraintClock to CC (default: 500)
%
%   [W, Xi, D] = rmlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, lambda, CC, E)
%       Set convergence threshold to E (default: 1e-3)
%

    TIME_START = tic();

    global C;
    C = Cslack;

    [d,n,m] = size(X);

    % A 3-dimensional X means one kernel slice per page: MKL mode.
    if m > 1
        MKL = 1;
    else
        MKL = 0;
    end

    if nargin < 3
        C = 1;
    end

    %%%
    % Default options:

    global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT STRUCTKERNEL DUALW THRESH INIT;
    global RHO;

    %%%
    % Augmented lagrangian factor
    RHO = 1;

    global FEASIBLE_COUNT;
    FEASIBLE_COUNT = 0;

    CP          = @cuttingPlaneFull;
    SO          = @separationOracleAUC;
    PSI         = @metricPsiPO;

    if ~MKL
        INIT        = @initializeFull;
        REG         = @regularizeTraceFull;
        STRUCTKERNEL= @structKernelLinear;
        DUALW       = @dualWLinear;
        FEASIBLE    = @feasibleFull;
        THRESH      = @threshFull_admmMixed;
        CPGRADIENT  = @cpGradientFull;
        DISTANCE    = @distanceFull;
        SETDISTANCE = @setDistanceFull;
        LOSS        = @lossHinge;
        Regularizer = 'Trace';
    else
        INIT        = @initializeFullMKL;
        REG         = @regularizeMKLFull;
        STRUCTKERNEL= @structKernelMKL;
        DUALW       = @dualWMKL;
        FEASIBLE    = @feasibleFullMKL;
        THRESH      = @threshFull_admmMixed;
        CPGRADIENT  = @cpGradientFullMKL;
        DISTANCE    = @distanceFullMKL;
        SETDISTANCE = @setDistanceFullMKL;
        LOSS        = @lossHingeFullMKL;
        Regularizer = 'Trace';
    end

    Loss        = 'AUC';
    Feature     = 'metricPsiPO';

    %%%
    % Default k for prec@k, ndcg
    k           = 3;

    %%%
    % Stochastic violator selection?
    STOCHASTIC  = 0;
    batchSize   = n;
    SAMPLES     = 1:n;

    if nargin > 3
        switch lower(varargin{1})
            case {'auc'}
                SO      = @separationOracleAUC;
                PSI     = @metricPsiPO;
                Loss    = 'AUC';
                Feature = 'metricPsiPO';
            case {'knn'}
                SO      = @separationOracleKNN;
                PSI     = @metricPsiPO;
                Loss    = 'KNN';
                Feature = 'metricPsiPO';
            case {'prec@k'}
                SO      = @separationOraclePrecAtK;
                PSI     = @metricPsiPO;
                Loss    = 'Prec@k';
                Feature = 'metricPsiPO';
            case {'map'}
                SO      = @separationOracleMAP;
                PSI     = @metricPsiPO;
                Loss    = 'MAP';
                Feature = 'metricPsiPO';
            case {'mrr'}
                SO      = @separationOracleMRR;
                PSI     = @metricPsiPO;
                Loss    = 'MRR';
                Feature = 'metricPsiPO';
            case {'ndcg'}
                SO      = @separationOracleNDCG;
                PSI     = @metricPsiPO;
                Loss    = 'NDCG';
                Feature = 'metricPsiPO';
            otherwise
                error('MLR:LOSS', ...
                    'Unknown loss function: %s', varargin{1});
        end
    end

    if nargin > 4
        k = varargin{2};
    end

    Diagonal = 0;
    % Diagonal case is not implemented.  Use mlr_train for that.

    if nargin > 5
        switch(varargin{3})
            case {0}
                REG         = @regularizeNone;
                Regularizer = 'None';
                THRESH      = @threshFull_admmMixed;

            case {1}
                if MKL
                    REG         = @regularizeMKLFull;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                else
                    REG         = @regularizeTraceFull;
                    STRUCTKERNEL= @structKernelLinear;
                    DUALW       = @dualWLinear;
                end
                Regularizer = 'Trace';

            case {2}
                REG         = @regularizeTwoFull;
                Regularizer = '2-norm';
                error('MLR:REGULARIZER', '2-norm regularization no longer supported');

            case {3}
                if MKL
                    REG         = @regularizeMKLFull;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                else
                    REG         = @regularizeKernel;
                    STRUCTKERNEL= @structKernelMKL;
                    DUALW       = @dualWMKL;
                end
                Regularizer = 'Kernel';

            otherwise
                error('MLR:REGULARIZER', ...
                    'Unknown regularization: %s', varargin{3});
        end
    end

    % Are we in stochastic optimization mode?
    if nargin > 7 && varargin{5} > 0
        if varargin{5} < n
            STOCHASTIC  = 1;
            CP          = @cuttingPlaneRandom;
            batchSize   = varargin{5};
        end
    end

    % Algorithm
    %
    % Working <- []
    %
    % repeat:
    %   (W, Xi) <- solver(X, Y, C, Working)
    %
    %   for i = 1:|X|
    %       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
    %
    %   Working <- Working + (y^_1,y^_2,...,y^_n)
    % until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i)))
    %           <= Xi + epsilon

    if nargin > 8
        lam = varargin{6};
    else
        lam = 1;
    end

    disp(['lam = ' num2str(lam)])

    global DEBUG;

    if isempty(DEBUG)
        DEBUG = 0;
    end

    DEBUG = 1;

    %%%
    % Max calls to separation oracle
    MAX_CALLS = 200;
    MIN_CALLS = 10;

    %%%
    % Timer to eliminate old constraints
    ConstraintClock = 500;  % standard: 500

    if nargin > 9 && varargin{7} > 0
        ConstraintClock = varargin{7};
    end

    %%%
    % Convergence criteria for worst-violated constraint
    E = 1e-3;
    if nargin > 10 && varargin{8} > 0
        E = varargin{8};
    end

    %XXX:   2012-01-31 21:29:50 by Brian McFee
    %       no longer belongs here
    % Initialize
    W = INIT(X);

    % ADMM split variables and scaled dual variables, shared with the solver.
    global ADMM_Z ADMM_V ADMM_UW ADMM_UV;
    ADMM_Z  = W;
    ADMM_V  = W;
    ADMM_UW = 0 * ADMM_Z;
    ADMM_UV = 0 * ADMM_Z;

    ClassScores = [];

    if isa(Y, 'double')
        Ypos        = [];
        Yneg        = [];
        ClassScores = synthesizeRelevance(Y);

    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
        dbprint(2, 'Using supplied Ypos/Yneg');
        Ypos = Y(:,1);
        Yneg = Y(:,2);

        % Compute the valid samples
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
        dbprint(2, 'Using supplied Ypos/synthesized Yneg');
        Ypos = Y(:,1);
        Yneg = [];
        SAMPLES = find( ~(cellfun(@isempty, Y(:,1))));
    else
        error('MLR:LABELS', 'Incorrect format for Y.');
    end

    %%
    % If we don't have enough data to make the batch, cut the batch
    batchSize = min([batchSize, length(SAMPLES)]);

    Diagnostics = struct(   'loss',                 Loss, ...           % Which loss are we optimizing?
                            'feature',              Feature, ...        % Which ranking feature is used?
                            'k',                    k, ...              % What is the ranking length?
                            'regularizer',          Regularizer, ...    % What regularization is used?
                            'diagonal',             Diagonal, ...       % 0 for full metric, 1 for diagonal
                            'num_calls_SO',         0, ...              % Calls to separation oracle
                            'num_calls_solver',     0, ...              % Calls to solver
                            'time_SO',              0, ...              % Time in separation oracle
                            'time_solver',          0, ...              % Time in solver
                            'time_total',           0, ...              % Total time
                            'f',                    [], ...             % Objective value
                            'num_steps',            [], ...             % Number of steps for each solver run
                            'num_constraints',      [], ...             % Number of constraints for each run
                            'Xi',                   [], ...             % Slack achieved for each run
                            'Delta',                [], ...             % Mean loss for each SO call
                            'gap',                  [], ...             % Gap between loss and slack
                            'C',                    C, ...              % Slack trade-off
                            'epsilon',              E, ...              % Convergence threshold
                            'feasible_count',       FEASIBLE_COUNT, ... % Counter for # svd's
                            'constraint_timer',     ConstraintClock);   % Time before evicting old constraints

    global PsiR;
    global PsiClock;

    PsiR        = {};
    PsiClock    = [];

    Xi          = -Inf;
    Margins     = [];
    H           = [];
    Q           = [];

    if STOCHASTIC
        dbprint(2, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
    end

    dbprint(2, ['Regularizer is "' Regularizer '"']);

    while 1
        if Diagnostics.num_calls_solver > MAX_CALLS
            dbprint(2, ['Calls to SO >= ' num2str(MAX_CALLS)]);
            break;
        end

        dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);

        % Generate a constraint set
        Termination = -Inf;

        dbprint(2, 'Calling separation oracle...');

        [PsiNew, Mnew, SO_time] = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
        Termination              = LOSS(W, PsiNew, Mnew, 0);

        Diagnostics.num_calls_SO = Diagnostics.num_calls_SO + 1;
        Diagnostics.time_SO      = Diagnostics.time_SO + SO_time;

        % Append the new constraint and grow the cached kernel/regularizer.
        Margins     = cat(1, Margins,   Mnew);
        PsiR        = cat(1, PsiR,      PsiNew);
        PsiClock    = cat(1, PsiClock,  0);
        H           = expandKernel(H);
        Q           = expandRegularizer(Q, X, W);

        dbprint(2, '\n\tActive constraints    : %d',    length(PsiClock));
        dbprint(2, '\t           Mean loss  : %0.4f',   Mnew);
        dbprint(2, '\t  Current loss Xi     : %0.4f',   Xi);
        dbprint(2, '\t  Termination -Xi < E : %0.4f <? %0.4f', Termination - Xi, E);

        % NOTE(review): the termination test below was corrupted in the
        % annotated source; it is reconstructed from the surviving fragments
        % ('MIN_CALLS', commented-out '%if Termination - Xi <= E') — confirm
        % against the upstream repository.
        if Termination - Xi <= E && Diagnostics.num_calls_solver > MIN_CALLS
        %if Termination - Xi <= E
            dbprint(1, 'Done.');
            break;
        end

        dbprint(1, 'Calling solver...');

        PsiClock    = PsiClock + 1;

        Solver_time = tic();
        % disp('Robust MLR')
        [W, Xi, Dsolver] = rmlr_admm(C, X, Margins, H, Q, lam);

        Diagnostics.time_solver      = Diagnostics.time_solver + toc(Solver_time);
        Diagnostics.num_calls_solver = Diagnostics.num_calls_solver + 1;

        Diagnostics.Xi          = cat(1, Diagnostics.Xi,        Xi);
        Diagnostics.f           = cat(1, Diagnostics.f,         Dsolver.f);
        Diagnostics.num_steps   = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

        %%%
        % Cull the old constraints
        GC          = PsiClock < ConstraintClock;
        Margins     = Margins(GC);
        PsiR        = PsiR(GC);
        PsiClock    = PsiClock(GC);
        H           = H(GC, GC);
        Q           = Q(GC);

        Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
    end

    % Finish diagnostics

    Diagnostics.time_total      = toc(TIME_START);
    Diagnostics.feasible_count  = FEASIBLE_COUNT;
end

function H = expandKernel(H)
% Grow the constraint Gram matrix H by one row/column for the newest
% constraint in PsiR, using the structural kernel STRUCTKERNEL.

    global STRUCTKERNEL;
    global PsiR;

    m = length(H);
    % Pad with a zero row and column, then fill them in symmetrically.
    H = padarray(H, [1 1], 0, 'post');

    for i = 1:m+1
        H(i, m+1)   = STRUCTKERNEL( PsiR{i}, PsiR{m+1} );
        H(m+1, i)   = H(i, m+1);
    end
end

function Q = expandRegularizer(Q, K, W)
% Append the regularizer/constraint inner product for the newest
% constraint in PsiR to the column vector Q.

    % FIXME:  2012-01-31 21:34:15 by Brian McFee
    %         does not support unregularized learning

    global PsiR;
    global STRUCTKERNEL REG;

    m = length(Q);
    Q(m+1,1) = STRUCTKERNEL(REG(W,K,1), PsiR{m+1});
end

function ClassScores = synthesizeRelevance(Y)
% Build per-class relevance masks from an n-by-1 label vector Y:
% Ypos{c} marks samples of class c, Yneg{c} marks everything else.

    classes     = unique(Y);
    nClasses    = length(classes);

    ClassScores = struct(   'Y',        Y, ...
                            'classes',  classes, ...
                            'Ypos',     [], ...
                            'Yneg',     []);

    Ypos = cell(nClasses, 1);
    Yneg = cell(nClasses, 1);
    for c = 1:nClasses
        Ypos{c} = (Y == classes(c));
        Yneg{c} = ~Ypos{c};
    end

    ClassScores.Ypos = Ypos;
    ClassScores.Yneg = Yneg;
end