view core/tools/machine_learning/cvpartition_trunctrain_incsubsets.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line source
% ---
% class cvpartition_trunctrain_incsubsets
% NOTE: this is a fake cvpartition double for
% using cvpartitions in truncated-training-size experiments
%
% unlike cvpartition_trunctrain, we take all the training sizes
% at once and generate training partitions where the smaller ones are subsets
% of the bigger ones
% ---
classdef cvpartition_trunctrain_incsubsets
    % ---
    % Stand-in for a cvpartition whose training sets are truncated to a
    % fraction of their original size.
    %
    % The truncation order for each test set is taken from a permutation
    % stored in the global variable
    %   globalvars.camir.cvpartition_trunctrain_incsubsets.permutation(i).data
    % As long as that stored state matches the input partition (same
    % NumTestSets and N) it is reused, so objects built with a smaller
    % perctrain select a subset of the indices selected with a bigger one.
    % ---

properties (Hidden)

    % cell array, one logical test mask per test set (copied from Pin)
    mtest;
    % cell array, one logical N-by-1 truncated training mask per test set
    mtraining;
end
properties
    N;            % copied from Pin.N
    NumTestSets;  % copied from Pin.NumTestSets
    TrainSize;    % truncated training size per test set
    TestSize;     % copied from Pin.TestSize
end


methods

% ---
% constructor: directly calculates the truncated training sets
%
% Pin       - source partition; only Pin.N, Pin.NumTestSets,
%             Pin.TestSize(i), Pin.TrainSize(i), Pin.test(i) and
%             Pin.training(i) are used
% perctrain - fraction of each training set to keep; the new size is
%             ceil(perctrain * Pin.TrainSize(i)).
%             NOTE(review): presumably expected in (0, 1]; larger values
%             index past the end of the permutation - confirm with callers
% ---
function P = cvpartition_trunctrain_incsubsets(Pin, perctrain)

    % ---
    % NOTE: we use a different permutation for each cv-run (test set),
    % as otherwise the very small training sets will have about the same
    % data. The permutations are only (re)generated when the stored ones
    % do not match Pin.
    % ---
    if ~cvpartition_trunctrain_incsubsets.exists_permutation(Pin)
        cvpartition_trunctrain_incsubsets.renew_permutation(Pin);
    end

    numSets = Pin.NumTestSets;

    P.N = Pin.N;
    P.NumTestSets = numSets;

    % preallocate the per-test-set containers
    P.TestSize = zeros(1, numSets);
    P.TrainSize = zeros(1, numSets);
    P.mtest = cell(1, numSets);
    P.mtraining = cell(1, numSets);

    for i = 1:numSets

        % copy testing data
        P.TestSize(i) = Pin.TestSize(i);
        P.mtest{i} = Pin.test(i);

        % calculate new training size
        P.TrainSize(i) = ceil(perctrain * Pin.TrainSize(i));

        % get actual training indices
        idx = find(Pin.training(i));

        % ---
        % NOTE: the test-set-specific permutation is applied;
        % it only contains positions 1..Pin.TrainSize(i)
        % ---
        permu = cvpartition_trunctrain_incsubsets.get_permutation( ...
            i, Pin.TrainSize(i));

        % truncate the indices
        idx = idx(permu(1:P.TrainSize(i)));

        % build truncated training set
        mask = false(P.N, 1);
        mask(idx) = true;
        P.mtraining{i} = mask;
    end
end

function out = test(P, i)
    % returns the stored test mask of test set i
    out = P.mtest{i};
end

function out = training(P, i)
    % returns the truncated training mask of test set i
    out = P.mtraining{i};
end
end

methods (Static)

    % ---
    % Draws a fresh random permutation of 1..P.N for each of the
    % P.NumTestSets test sets and stores them in the global variable.
    % A warning is issued when existing state is overwritten.
    % ---
    function renew_permutation(P)
        global globalvars;

        if cvpartition_trunctrain_incsubsets.has_global_state()
            warning('renwewing permutations for train sets');
        end

        % ---
        % build a fresh struct array instead of overwriting entries in
        % place: otherwise entries left over from an earlier run with more
        % test sets survive, the stored count never matches again and the
        % permutations get re-drawn on every construction
        % ---
        permutation = struct('data', cell(1, P.NumTestSets));
        for i = 1:P.NumTestSets
            permutation(i).data = randperm(P.N);
        end

        globalvars.camir.cvpartition_trunctrain_incsubsets.permutation = ...
            permutation;
    end

    function idx = get_permutation(testId, trainSize)
        % returns the stored permutation for a specific test set,
        % restricted to the values 1..trainSize (stored order is kept)
        global globalvars;

        idx = globalvars.camir.cvpartition_trunctrain_incsubsets ...
            .permutation(testId).data;

        % cut the permutation to contain no excess numbers
        idx = idx(idx <= trainSize);
    end

    function out = exists_permutation(P)
        % true if the global variable holds exactly P.NumTestSets
        % permutations and each of them has P.N entries
        global globalvars;

        out = false;
        if ~cvpartition_trunctrain_incsubsets.has_global_state()
            return;
        end

        state = globalvars.camir.cvpartition_trunctrain_incsubsets;
        if ~isfield(state, 'permutation')
            return;
        end

        permutation = state.permutation;

        % check every entry (not only the first one); this is also safe
        % for an empty struct array
        out = (numel(permutation) == P.NumTestSets) ...
            && all(arrayfun(@(p) numel(p.data) == P.N, permutation));
    end

end

methods (Static, Access = private)

    function tf = has_global_state()
        % true if globalvars.camir.cvpartition_trunctrain_incsubsets
        % exists; isfield returns false for non-struct input, so an
        % uninitialised global ([]) does not raise an error here
        global globalvars;

        tf = isfield(globalvars, 'camir') ...
            && isfield(globalvars.camir, 'cvpartition_trunctrain_incsubsets');
    end

end
end