view core/magnatagatune/tests_evals/svm_light/svmlight_wrapper.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line source
function [A, diag] = svmlight_wrapper(X, Ytrain, trainparams)
% SVMLIGHT_WRAPPER  wrapper to make SVMlight accessible via a general interface
%
%   [A, diag] = svmlight_wrapper(X, Ytrain, trainparams)
%
%   Converts the ranking data into inequality constraints, writes them to
%   a training file, runs the external "svm_learn" binary on that file and
%   turns the learnt weight vector into a diagonal matrix.
%
%   Inputs:
%     X           - feature data, handed unchanged to
%                   get_svmlight_inequalities_from_ranking
%     Ytrain      - ranking data, handed to the same helper (after optional
%                   rescaling, see trainparams.weighted)
%     trainparams - struct with the fields
%         .C               value passed to svm_learn via "-c"
%         .weighted        (optional) if present and nonzero, constraint
%                          weights are written to the training file;
%                          if > 1, Ytrain is first rescaled with
%                          scale_ratings(Ytrain, trainparams.weighted)
%         .deltafun        (optional) function name, converted with
%                          str2func and passed to the inequality helper
%         .deltafun_params parameters passed along with deltafun
%                          (read whenever .deltafun is present)
%
%   Outputs:
%     A    - sparse diagonal matrix holding the weight vector, zero-padded
%            to the feature dimension reported by the inequality helper
%     diag - captured console output of the svm_learn call, with runs of
%            seven dots removed
%
%   Uses the relative file names 'tmp_train.dat' and 'tmp_model.dat', i.e.
%   files in the current working directory; they are not removed afterwards.
%
%   Raises an error if the training file cannot be written.

% define model file
modelfile = 'tmp_model.dat';
% define constraints file
trainfile = 'tmp_train.dat';

use_weights = isfield(trainparams,'weighted') && trainparams.weighted;

% ---
% include weighting from Y
% ---
if use_weights && trainparams.weighted > 1

    % ---
    % NOTE: this is to maintain a tradeoff between weighting
    % and space usage :
    %
    % scale the rating if the "weighted" parameter
    % gives a maximum weight
    %
    % TODO: try logarithmic weight scaling
    % ---
    Ytrain = scale_ratings(Ytrain, trainparams.weighted);
end

% ---
% arguments for the inequality helper; the optional delta function is
% appended only when configured. Built after the rescaling above so the
% scaled ratings are the ones being used.
% ---
ineq_args = {Ytrain, X};
if isfield(trainparams,'deltafun')
    ineq_args = [ineq_args, ...
        {str2func(trainparams.deltafun), trainparams.deltafun_params}];
end

% ---
% get squared pointwise distance for labeled features
% and save the data file to disk
% ---
if use_weights
    % The third output is the feature dimension (it is used as such in the
    % unweighted case and needed below for padding w). Previously it was
    % bound to "c" here, which left "dim" undefined and made every
    % weighted run fail when w was padded.
    % NOTE(review): assumes the helper returns the per-constraint weights
    % as its fourth output - confirm against
    % get_svmlight_inequalities_from_ranking.
    [lhs, rhs, dim, c] = get_svmlight_inequalities_from_ranking(ineq_args{:});
    success = save_svmlight_inequalities(lhs, rhs, c, trainfile);
else
    [lhs, rhs, dim] = get_svmlight_inequalities_from_ranking(ineq_args{:});
    success = save_svmlight_inequalities(lhs, rhs, trainfile);
end

if ~success
    error('cannot write svm training file');
end

% call svmlight solver
% evalc captures everything the solver prints, so it can be filtered
% before being shown. The command string is evaluated in this workspace
% and therefore refers to trainparams, trainfile and modelfile by name.
[diag, ~] = evalc('dos(sprintf(''svm_learn -z o -c %d %s %s'', trainparams.C, trainfile, modelfile));');

% Strip some dots from the display
diag = strrep(diag,'.......','');
cprint(2, diag)

% ---
% get dual weight vector
% TODO: check whether this is the actual w in Schultz2003
% ---
w = svmlight2weight(modelfile);

% prolong w to feature dimension
w = [w; zeros(dim - numel(w),1)];

% ---
% matrix from weights
% ---
A = spdiags(w, 0, numel(w), numel(w));
end