view toolboxes/distance_learning/mlr/mlr_train.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line source
function [W, Xi, Diagnostics] = mlr_train(X, Y, Cslack, varargin)
%
%   [W, Xi, D] = mlr_train(X, Y, C,...)
%
%       X = d*n data matrix
%           (a d-by-n-by-m array with m > 1 selects the multiple-kernel
%           (MKL) variants of all helper routines)
%       Y = either n-by-1 vector of labels
%           OR
%           n-by-2 cell array where 
%               Y{q,1} contains relevant indices for q, and
%               Y{q,2} contains irrelevant indices for q
%           OR
%           n-by-1 cell array where Y{q} contains relevant indices for q
%           (no irrelevant indices are supplied; Yneg is left empty)
%           Queries q with empty index lists are excluded from training.
%       
%       C >= 0 slack trade-off parameter (default=1)
%
%       W   = the learned metric
%       Xi  = slack value on the learned metric
%       D   = diagnostics
%
%   Optional arguments:
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS)
%       where LOSS is one of:
%           'AUC':      Area under ROC curve (default)
%           'KNN':      KNN accuracy
%           'Prec@k':   Precision-at-k
%           'MAP':      Mean Average Precision
%           'MRR':      Mean Reciprocal Rank
%           'NDCG':     Normalized Discounted Cumulative Gain
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k)
%       where k is the number of neighbors for Prec@k or NDCG
%       (default=3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG)
%       where REG defines the regularization on W, and is one of:
%           0:          no regularization
%           1:          1-norm: trace(W)            (default)
%           2:          2-norm: trace(W' * W)       (no longer supported; raises MLR:REGULARIZER)
%           3:          Kernel: trace(W * X), assumes X is square and positive-definite
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal)
%       Diagonal = 0:   learn a full d-by-d W       (default)
%       Diagonal = 1:   learn diagonally-constrained W (d-by-1)
%       
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B)
%       where 0 < B < n enables stochastic optimization with batch size B
%       (B is further capped at the number of usable training queries)
%
% // added by Daniel Wolff
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC)
%       Set ConstraintClock to CC (default: 100): constraints that have
%       survived CC solver rounds are evicted from the working set
% 
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E)
%       Set the convergence threshold epsilon to E (default: 1e-3)
%
%   [W, Xi, D] = mlr_train(X, Y, C, LOSS, k, REG, Diagonal, B, CC, E, MAXITER)
%       Stop after at most MAXITER solver calls (default: 200)
%
%   For B, CC, E and MAXITER a non-positive value keeps the default.
%
%   NOTE: configuration is shared with the helper routines through global
%   variables (C, CP, SO, PSI, REG, LOSS, ...). Nested or concurrent calls
%   therefore overwrite each other's settings.
%

    TIME_START = tic();

    % Apply the documented default for C *before* reading Cslack: with
    % fewer than three arguments Cslack is undefined.
    if nargin < 3
        Cslack = 1;
    end

    global C;
    C = Cslack;

    [d,n,m] = size(X);

    % A third dimension in X (m > 1) switches to multiple-kernel learning.
    if m > 1
        MKL = 1;
    else
        MKL = 0;
    end

    %%%
    % Default options:

    global CP SO PSI REG FEASIBLE LOSS DISTANCE SETDISTANCE CPGRADIENT STRUCTKERNEL DUALW INIT;

    global FEASIBLE_COUNT;
    FEASIBLE_COUNT = 0;

    CP          = @cuttingPlaneFull;
    SO          = @separationOracleAUC;
    PSI         = @metricPsiPO;

    if ~MKL
        INIT        = @initializeFull;
        REG         = @regularizeTraceFull;
        STRUCTKERNEL= @structKernelLinear;
        DUALW       = @dualWLinear;
        FEASIBLE    = @feasibleFull;
        CPGRADIENT  = @cpGradientFull;
        DISTANCE    = @distanceFull;
        SETDISTANCE = @setDistanceFull;
        LOSS        = @lossHinge;
        Regularizer = 'Trace';
    else
        INIT        = @initializeFullMKL;
        REG         = @regularizeMKLFull;
        STRUCTKERNEL= @structKernelMKL;
        DUALW       = @dualWMKL;
        FEASIBLE    = @feasibleFullMKL;
        CPGRADIENT  = @cpGradientFullMKL;
        DISTANCE    = @distanceFullMKL;
        SETDISTANCE = @setDistanceFullMKL;
        LOSS        = @lossHingeFullMKL;
        Regularizer = 'Trace';
    end

    
    Loss        = 'AUC';
    Feature     = 'metricPsiPO';


    %%%
    % Default k for prec@k, ndcg
    k           = 3;

    %%%
    % Stochastic violator selection?
    STOCHASTIC  = 0;
    batchSize   = n;
    SAMPLES     = 1:n;


    if nargin > 3
        switch lower(varargin{1})
            case {'auc'}
                SO          = @separationOracleAUC;
                PSI         = @metricPsiPO;
                Loss        = 'AUC';
                Feature     = 'metricPsiPO';
            case {'knn'}
                SO          = @separationOracleKNN;
                PSI         = @metricPsiPO;
                Loss        = 'KNN';
                Feature     = 'metricPsiPO';
            case {'prec@k'}
                SO          = @separationOraclePrecAtK;
                PSI         = @metricPsiPO;
                Loss        = 'Prec@k';
                Feature     = 'metricPsiPO';
            case {'map'}
                SO          = @separationOracleMAP;
                PSI         = @metricPsiPO;
                Loss        = 'MAP';
                Feature     = 'metricPsiPO';
            case {'mrr'}
                SO          = @separationOracleMRR;
                PSI         = @metricPsiPO;
                Loss        = 'MRR';
                Feature     = 'metricPsiPO';
            case {'ndcg'}
                SO          = @separationOracleNDCG;
                PSI         = @metricPsiPO;
                Loss        = 'NDCG';
                Feature     = 'metricPsiPO';
            otherwise
                error('MLR:LOSS', ...
                    'Unknown loss function: %s', varargin{1});
        end
    end

    if nargin > 4
        k = varargin{2};
    end

    Diagonal = 0;
    if nargin > 6 && varargin{4} > 0
        Diagonal = varargin{4};

        if ~MKL
            INIT        = @initializeDiag;
            REG         = @regularizeTraceDiag;
            STRUCTKERNEL= @structKernelDiag;
            DUALW       = @dualWDiag;
            FEASIBLE    = @feasibleDiag;
            CPGRADIENT  = @cpGradientDiag;
            DISTANCE    = @distanceDiag;
            SETDISTANCE = @setDistanceDiag;
            Regularizer = 'Trace';
        else
            INIT        = @initializeDiagMKL;
            REG         = @regularizeMKLDiag;
            STRUCTKERNEL= @structKernelDiagMKL;
            DUALW       = @dualWDiagMKL;
            FEASIBLE    = @feasibleDiagMKL;
            CPGRADIENT  = @cpGradientDiagMKL;
            DISTANCE    = @distanceDiagMKL;
            SETDISTANCE = @setDistanceDiagMKL;
            LOSS        = @lossHingeDiagMKL;
            Regularizer = 'Trace';
        end
    end

    % The regularizer is resolved after Diagonal, because the choice of
    % REG/STRUCTKERNEL/DUALW depends on both.
    if nargin > 5
        switch(varargin{3})
            case {0}
                REG         = @regularizeNone;
                Regularizer = 'None';

            case {1}
                if MKL
                    if Diagonal == 0
                        REG         = @regularizeMKLFull;
                        STRUCTKERNEL= @structKernelMKL;
                        DUALW       = @dualWMKL;
                    elseif Diagonal == 1
                        REG         = @regularizeMKLDiag;
                        STRUCTKERNEL= @structKernelDiagMKL;
                        DUALW       = @dualWDiagMKL;
                    end
                else
                    if Diagonal 
                        REG         = @regularizeTraceDiag;
                        STRUCTKERNEL= @structKernelDiag;
                        DUALW       = @dualWDiag;
                    else
                        REG         = @regularizeTraceFull;
                        STRUCTKERNEL= @structKernelLinear;
                        DUALW       = @dualWLinear;
                    end
                end
                Regularizer = 'Trace';

            case {2}
                if Diagonal
                    REG         = @regularizeTwoDiag;
                else
                    REG         = @regularizeTwoFull;
                end
                Regularizer = '2-norm';
                error('MLR:REGULARIZER', '2-norm regularization no longer supported');
                

            case {3}
                if MKL
                    if Diagonal == 0
                        REG         = @regularizeMKLFull;
                        STRUCTKERNEL= @structKernelMKL;
                        DUALW       = @dualWMKL;
                    elseif Diagonal == 1
                        REG         = @regularizeMKLDiag;
                        STRUCTKERNEL= @structKernelDiagMKL;
                        DUALW       = @dualWDiagMKL;
                    end
                else
                    if Diagonal
                        REG         = @regularizeMKLDiag;
                        STRUCTKERNEL= @structKernelDiagMKL;
                        DUALW       = @dualWDiagMKL;
                    else
                        REG         = @regularizeKernel;
                        STRUCTKERNEL= @structKernelMKL;
                        DUALW       = @dualWMKL;
                    end
                end
                Regularizer = 'Kernel';

            otherwise
                % REG is normally numeric; num2str keeps the message readable
                % for both numeric and char inputs.
                error('MLR:REGULARIZER', ... 
                    'Unknown regularization: %s', num2str(varargin{3}));
        end
    end


    % Are we in stochastic optimization mode?
    % Only a batch strictly smaller than n switches to random cutting planes.
    if nargin > 7 && varargin{5} > 0
        if varargin{5} < n
            STOCHASTIC  = 1;
            CP          = @cuttingPlaneRandom;
            batchSize   = varargin{5};
        end
    end

    % Algorithm
    %
    % Working <- []
    %
    % repeat:
    %   (W, Xi) <- solver(X, Y, C, Working)
    %
    %   for i = 1:|X|
    %       y^_i <- argmax_y^ ( Delta(y*_i, y^) + w' Psi(x_i, y^) )
    %
    %   Working <- Working + (y^_1,y^_2,...,y^_n)
    % until mean(Delta(y*_i, y_i)) - mean(w' (Psi(x_i,y_i) - Psi(x_i,y^_i))) 
    %           <= Xi + epsilon

    global DEBUG;
    
    if isempty(DEBUG)
        DEBUG = 0;
    end

    %%%
    % Timer to eliminate old constraints
    ConstraintClock = 100; % standard: 100
    
    if nargin > 8 && varargin{6} > 0
        ConstraintClock = varargin{6};
    end
    
    %%%
    % Convergence criteria for worst-violated constraint
    E = 1e-3;
    if nargin > 9 && varargin{7} > 0
        E = varargin{7};
    end

    %%%
    % Upper bound on the number of solver rounds
    MAXITER = 200;
    if nargin > 10 && varargin{8} > 0
        MAXITER = varargin{8};
    end
    
    
    %XXX:    2012-01-31 21:29:50 by Brian McFee <bmcfee@cs.ucsd.edu>
    % no longer belongs here
    % Initialize
    W           = INIT(X);


    global ADMM_Z ADMM_U RHO;
    ADMM_Z      = W;
    ADMM_U      = 0 * ADMM_Z;

    %%%
    % Augmented lagrangian factor
    RHO = 1;

    ClassScores = [];

    if isa(Y, 'double')
        % Label vector: per-class relevance sets are synthesized.
        Ypos        = [];
        Yneg        = [];
        ClassScores = synthesizeRelevance(Y);

    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 2
        dbprint(2, 'Using supplied Ypos/Yneg');
        Ypos        = Y(:,1);
        Yneg        = Y(:,2);

        % Compute the valid samples
        SAMPLES     = find( ~(cellfun(@isempty, Y(:,1)) | cellfun(@isempty, Y(:,2))));
    elseif isa(Y, 'cell') && size(Y,1) == n && size(Y,2) == 1
        dbprint(2, 'Using supplied Ypos/synthesized Yneg');
        Ypos        = Y(:,1);
        Yneg        = [];
        SAMPLES     = find( ~(cellfun(@isempty, Y(:,1))));
    else
        error('MLR:LABELS', 'Incorrect format for Y.');
    end

    %%
    % If we don't have enough data to make the batch, cut the batch
    batchSize = min([batchSize, length(SAMPLES)]);


    Diagnostics = struct(   'loss',                 Loss, ...           % Which loss are we optimizing?
                            'feature',              Feature, ...        % Which ranking feature is used?
                            'k',                    k, ...              % What is the ranking length?
                            'regularizer',          Regularizer, ...    % What regularization is used?
                            'diagonal',             Diagonal, ...       % 0 for full metric, 1 for diagonal
                            'num_calls_SO',         0, ...              % Calls to separation oracle
                            'num_calls_solver',     0, ...              % Calls to solver
                            'time_SO',              0, ...              % Time in separation oracle
                            'time_solver',          0, ...              % Time in solver
                            'time_total',           0, ...              % Total time
                            'f',                    [], ...             % Objective value
                            'num_steps',            [], ...             % Number of steps for each solver run
                            'num_constraints',      [], ...             % Number of constraints for each run
                            'Xi',                   [], ...             % Slack achieved for each run
                            'Delta',                [], ...             % Mean loss for each SO call
                            'gap',                  [], ...             % Gap between loss and slack
                            'C',                    C, ...              % Slack trade-off
                            'epsilon',              E, ...              % Convergence threshold
                            'feasible_count',       FEASIBLE_COUNT, ... % Counter for # svd's
                            'constraint_timer',     ConstraintClock);   % Time before evicting old constraints



    global PsiR;
    global PsiClock;

    PsiR        = {};
    PsiClock    = [];

    Xi          = -Inf;
    Margins     = [];
    H           = [];
    Q           = [];

    if STOCHASTIC
        dbprint(1, 'STOCHASTIC OPTIMIZATION: Batch size is %d/%d', batchSize, n);
    end

    while Diagnostics.num_calls_solver < MAXITER
        dbprint(2, 'Round %03d', Diagnostics.num_calls_solver);

        % Generate a constraint set
        dbprint(2, 'Calling separation oracle...');

        [PsiNew, Mnew, SO_time]     = CP(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores);
        Termination                 = LOSS(W, PsiNew, Mnew, 0);

        Diagnostics.num_calls_SO    = Diagnostics.num_calls_SO + 1;
        Diagnostics.time_SO         = Diagnostics.time_SO + SO_time;

        % Append the new constraint to the working set (PsiR must be
        % extended before expandKernel/expandRegularizer read PsiR{end}).
        Margins     = cat(1, Margins,   Mnew);
        PsiR        = cat(1, PsiR,      PsiNew);
        PsiClock    = cat(1, PsiClock,  0);
        H           = expandKernel(H);
        Q           = expandRegularizer(Q, X, W);


        dbprint(2, '\n\tActive constraints    : %d',            length(PsiClock));
        dbprint(2, '\t           Mean loss  : %0.4f',           Mnew);
        dbprint(2, '\t  Current loss Xi     : %0.4f',           Xi);
        dbprint(2, '\t  Termination -Xi < E : %0.4f <? %.04f\n', Termination - Xi, E);
        
        Diagnostics.gap     = cat(1, Diagnostics.gap,   Termination - Xi);
        Diagnostics.Delta   = cat(1, Diagnostics.Delta, Mnew);

        % Converged: the most violated constraint is within E of the slack.
        if Termination <= Xi + E
            dbprint(2, 'Done.');
            break;
        end
        
        dbprint(2, 'Calling solver...');
        PsiClock                        = PsiClock + 1;
        Solver_time                     = tic();
            [W, Xi, Dsolver]            = mlr_admm(C, X, Margins, H, Q);
        Diagnostics.time_solver         = Diagnostics.time_solver + toc(Solver_time);
        Diagnostics.num_calls_solver    = Diagnostics.num_calls_solver + 1;

        Diagnostics.Xi                  = cat(1, Diagnostics.Xi, Xi);
        Diagnostics.f                   = cat(1, Diagnostics.f, Dsolver.f);
        Diagnostics.num_steps           = cat(1, Diagnostics.num_steps, Dsolver.num_steps);

        %%%
        % Cull the old constraints
        GC          = PsiClock < ConstraintClock;
        Margins     = Margins(GC);
        PsiR        = PsiR(GC);
        PsiClock    = PsiClock(GC);
        H           = H(GC, GC);
        Q           = Q(GC);

        Diagnostics.num_constraints = cat(1, Diagnostics.num_constraints, length(PsiR));
    end


    % Finish diagnostics

    Diagnostics.time_total = toc(TIME_START);
    Diagnostics.feasible_count = FEASIBLE_COUNT;
end

function H = expandKernel(H)
% H = expandKernel(H)
%
%   Grows the m-by-m Gram matrix H of the working-set constraints by one
%   row and one column. The new entries are STRUCTKERNEL evaluated between
%   every stored constraint PsiR{i} (i = 1..m+1) and the newest constraint
%   PsiR{m+1}; the result is kept symmetric.
%
%   Reads the globals STRUCTKERNEL and PsiR. PsiR must already contain at
%   least m+1 entries (the caller appends the new constraint first).

    global STRUCTKERNEL;
    global PsiR;

    m = size(H, 1);

    % Indexed assignment past the end zero-pads H in core MATLAB/Octave,
    % so padarray (Image Processing Toolbox) is not required.
    H(m+1, m+1) = 0;

    for i = 1:m+1
        H(i, m+1)   = STRUCTKERNEL( PsiR{i}, PsiR{m+1} );
        H(m+1, i)   = H(i, m+1);
    end
end

function Q = expandRegularizer(Q, K, W)
% Q = expandRegularizer(Q, K, W)
%
%   Appends one row to the column vector Q. The new entry is the structural
%   kernel STRUCTKERNEL between the regularizer term REG(W, K, 1) and the
%   constraint PsiR{length(Q)+1}, i.e. the one just added to the working set.
%
%   Reads the globals PsiR, STRUCTKERNEL and REG.
%
% FIXME:  2012-01-31 21:34:15 by Brian McFee <bmcfee@cs.ucsd.edu>
%  does not support unregularized learning

    global STRUCTKERNEL REG;
    global PsiR;

    nActive             = length(Q);
    regTerm             = REG(W, K, 1);
    newestPsi           = PsiR{nActive + 1};
    Q(nActive + 1, 1)   = STRUCTKERNEL(regTerm, newestPsi);
end

function ClassScores = synthesizeRelevance(Y)
% ClassScores = synthesizeRelevance(Y)
%
%   Builds per-class relevance masks from a label vector Y.
%
%   Returns a struct with fields:
%       Y       - the input labels
%       classes - unique(Y)
%       Ypos    - nClasses-by-1 cell; Ypos{c} is the logical mask Y == classes(c)
%       Yneg    - nClasses-by-1 cell; Yneg{c} is the complement ~Ypos{c}

    labelSet = unique(Y);

    % One logical mask per class, forced into a column cell array.
    posMasks = arrayfun(@(lbl) Y == lbl, labelSet, 'UniformOutput', false);
    posMasks = reshape(posMasks, [], 1);
    negMasks = cellfun(@not, posMasks, 'UniformOutput', false);

    % Cell values are wrapped in {} so struct() yields a scalar struct.
    ClassScores = struct(   'Y',    Y, ...
                            'classes', labelSet, ...
                            'Ypos', {posMasks}, ...
                            'Yneg', {negMasks});
end