view core/tools/machine_learning/cvpartition_trunctrain.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line source
% ---
% class cvpartition_trunctrain
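% NOTE: this is a stand-in (test double) for a cvpartition object. It keeps
% the test sets of the wrapped partition and randomly truncates each training
% set, for experiments with reduced training-set sizes.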
% ---
classdef cvpartition_trunctrain
% CVPARTITION_TRUNCTRAIN  cvpartition stand-in with truncated training sets.
%
%   P = cvpartition_trunctrain(Pin, perctrain) copies the test sets of the
%   partition Pin and, for every test set, keeps only a random subset of
%   the corresponding training indices. The subset size is
%   ceil(perctrain * Pin.TrainSize(i)).
%
%   The object exposes the properties N, NumTestSets, TrainSize and
%   TestSize and the methods test(P, i) and training(P, i), which is the
%   part of the partition interface this class reads from Pin.

properties (Hidden)

    % cell array, one logical test mask per test set (copied from Pin)
    mtest;
    % cell array, one logical N-by-1 truncated training mask per test set
    mtraining;
end
properties
    % total number of observations (copied from Pin.N)
    N;
    % number of test sets / folds (copied from Pin.NumTestSets)
    NumTestSets;
    % number of training samples kept for each test set
    TrainSize;
    % number of test samples for each test set (copied from Pin)
    TestSize;
end


methods

% ---
% constructor: directly calculates the truncated training sets
%
% Pin       - partition object providing N, NumTestSets, TrainSize(i),
%             TestSize(i), test(i) and training(i)
% perctrain - fraction of each training set to keep, numeric scalar
%             in [0, 1]
%
% raises cvpartition_trunctrain:invalidFraction if perctrain is not a
% numeric scalar in [0, 1]
% ---
function P = cvpartition_trunctrain(Pin, perctrain)

    % A fraction outside [0, 1] would either index past the end of the
    % permutation below or yield a negative training size, so reject it
    % up front. NaN fails both comparisons and is rejected as well.
    if ~(isnumeric(perctrain) && isscalar(perctrain) ...
            && perctrain >= 0 && perctrain <= 1)
        error('cvpartition_trunctrain:invalidFraction', ...
              'perctrain must be a numeric scalar in [0, 1].');
    end

    P.N = Pin.N;
    P.NumTestSets = Pin.NumTestSets;

    % preallocate the per-set containers
    nsets = Pin.NumTestSets;
    P.TestSize = zeros(1, nsets);
    P.TrainSize = zeros(1, nsets);
    P.mtest = cell(1, nsets);
    P.mtraining = cell(1, nsets);

    for i = 1:nsets

        % copy testing data
        P.TestSize(i) = Pin.TestSize(i);
        P.mtest{i} = Pin.test(i);

        % get actual training indices
        idx = find(Pin.training(i));

        % calculate new training size; never more than the number of
        % training indices actually available in the mask
        P.TrainSize(i) = min(ceil(perctrain * Pin.TrainSize(i)), numel(idx));

        % ---
        % TODO: save the permutation in a global variable,
        % to make the same smaller set available
        % for all further experiments.
        % moreover, it would be great if the smaller training sets
        % are subsets of the bigger ones
        % ---
        % draw a full permutation and keep its head; the random draw is
        % deliberately left in this form so that seeded runs select the
        % same subsets as before
        tokeep = randperm(numel(idx));
        tokeep = tokeep(1:P.TrainSize(i));

        % get indices to keep
        idx = idx(tokeep);

        % build truncated training set
        P.mtraining{i} = false(P.N, 1);
        P.mtraining{i}(idx) = true;
    end
end

% ---
% returns the logical test mask of test set i (default: i = 1)
% ---
function out = test(P, i)

    % allow the index to be omitted, e.g. for single-set partitions
    if nargin < 2
        i = 1;
    end
    out = P.mtest{i};
end

% ---
% returns the truncated logical training mask of test set i
% (default: i = 1)
% ---
function out = training(P, i)

    % allow the index to be omitted, e.g. for single-set partitions
    if nargin < 2
        i = 1;
    end
    out = P.mtraining{i};
end

end
end