annotate core/tools/machine_learning/cvpartition_trunctrain_incsubsets.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
wolffd@0 1 % ---
wolffd@0 2 % class cvpartition_trunctrain_incsubsets
wolffd@0 3 % NOTE: this is a fake cvpartition double for
wolffd@0 4 % using cvpartitions in truncated-training size experiments
wolffd@0 5 %
wolffd@0 6 % unlike cvpartition_trunctrain, the per-test-set permutations are kept in a
wolffd@0 7 % global variable and reused across calls, so the training partitions generated
wolffd@0 8 % for smaller training sizes are subsets of those generated for bigger ones
wolffd@0 9 % ---
classdef cvpartition_trunctrain_incsubsets
    % Partition object mimicking the cvpartition interface (N, NumTestSets,
    % TrainSize, TestSize, test(i), training(i)) whose training sets are
    % truncated copies of the training sets of an input partition.
    %
    % The per-test-set orderings used for truncation are stored in the global
    % variable globalvars.camir.cvpartition_trunctrain_incsubsets and reused
    % as long as they match the input partition's N and NumTestSets. Hence,
    % for the same input partition, the training set obtained with a smaller
    % fraction is a subset of the one obtained with a larger fraction.

    properties (Hidden)
        mtest;      % cell array; mtest{i} holds the value of Pin.test(i)
        mtraining;  % cell array; mtraining{i} is an N-by-1 logical training mask
    end

    properties
        N;            % number of observations (copied from the input partition)
        NumTestSets;  % number of test sets (copied from the input partition)
        TrainSize;    % truncated training-set size for each test set
        TestSize;     % test-set size for each test set (copied from the input partition)
    end

    methods

        function P = cvpartition_trunctrain_incsubsets(Pin, perctrain)
            % Build the truncated partition.
            %
            % Pin       - input partition providing N, NumTestSets, TrainSize,
            %             TestSize, test(i) and training(i), e.g. a cvpartition
            % perctrain - numeric scalar in [0, 1]; test set i keeps
            %             ceil(perctrain * Pin.TrainSize(i)) training samples
            %
            % NOTE: a different permutation is used for each test set, as
            % otherwise the very small training sets of all test sets would
            % contain about the same data.

            % values outside [0, 1] would otherwise fail with an opaque
            % index error (> 1) or silently yield empty training sets (< 0)
            if ~(isnumeric(perctrain) && isscalar(perctrain) ...
                    && perctrain >= 0 && perctrain <= 1)
                error('cvpartition_trunctrain_incsubsets:badPerctrain', ...
                    'perctrain must be a numeric scalar in [0, 1]');
            end

            % reuse the stored permutations whenever they fit Pin, so that
            % repeated calls with different perctrain give nested subsets
            if ~cvpartition_trunctrain_incsubsets.exists_permutation(Pin)
                cvpartition_trunctrain_incsubsets.renew_permutation(Pin);
            end

            P.N = Pin.N;
            P.NumTestSets = Pin.NumTestSets;

            nSets = Pin.NumTestSets;
            P.TestSize = zeros(1, nSets);
            P.TrainSize = zeros(1, nSets);
            P.mtest = cell(1, nSets);
            P.mtraining = cell(1, nSets);

            for i = 1:nSets

                % the test set is copied unchanged
                P.TestSize(i) = Pin.TestSize(i);
                P.mtest{i} = Pin.test(i);

                % size of the truncated training set
                fullSize = Pin.TrainSize(i);
                P.TrainSize(i) = ceil(perctrain * fullSize);

                % observation indices of the untruncated training set
                idx = find(Pin.training(i));

                % test-set-specific ordering of 1:fullSize; a prefix of it
                % selects the kept samples, so smaller prefixes are subsets
                % of larger ones
                permu = cvpartition_trunctrain_incsubsets.get_permutation(i, fullSize);
                idx = idx(permu(1:P.TrainSize(i)));

                % build the truncated training mask
                P.mtraining{i} = false(P.N, 1);
                P.mtraining{i}(idx) = true;
            end
        end

        function out = test(P, i)
            % Return the test set of test set i, as copied from Pin.test(i).
            out = P.mtest{i};
        end

        function out = training(P, i)
            % Return the truncated N-by-1 logical training mask of test set i.
            out = P.mtraining{i};
        end
    end

    methods (Static)

        function renew_permutation(P)
            % Draw a fresh randperm(P.N) for each of the P.NumTestSets test
            % sets and store them in
            % globalvars.camir.cvpartition_trunctrain_incsubsets.permutation,
            % replacing (not partially overwriting) any stored permutations.
            global globalvars;

            if cvpartition_trunctrain_incsubsets.has_store()
                warning('renwewing permutations for train sets');
            end

            perms = struct('data', cell(1, P.NumTestSets));
            for i = 1:P.NumTestSets
                perms(i).data = randperm(P.N);
            end

            % assign the whole array so no stale entries from a previous
            % partition with more test sets survive
            globalvars.camir.cvpartition_trunctrain_incsubsets.permutation = perms;
        end

        function idx = get_permutation(testId, trainSize)
            % Return the stored permutation of test set testId restricted to
            % values <= trainSize, keeping their relative order; when
            % trainSize <= N this is an ordering of 1:trainSize.
            global globalvars;

            idx = globalvars.camir.cvpartition_trunctrain_incsubsets.permutation(testId).data;

            % drop the numbers exceeding the training-set size
            idx = idx(idx <= trainSize);
        end

        function out = exists_permutation(P)
            % True when stored permutations exist for exactly P.NumTestSets
            % test sets and each of them has length P.N.
            global globalvars;

            out = false;
            if ~cvpartition_trunctrain_incsubsets.has_store()
                return;
            end

            store = globalvars.camir.cvpartition_trunctrain_incsubsets;
            if ~isstruct(store) || ~isfield(store, 'permutation') ...
                    || numel(store.permutation) ~= P.NumTestSets
                return;
            end

            out = all(arrayfun(@(s) numel(s.data) == P.N, store.permutation));
        end
    end

    methods (Static, Access = private)

        function out = has_store()
            % True when globalvars.camir.cvpartition_trunctrain_incsubsets
            % exists; safe when globalvars is still uninitialised ([]).
            global globalvars;

            out = isstruct(globalvars) && isfield(globalvars, 'camir') ...
                && isstruct(globalvars.camir) ...
                && isfield(globalvars.camir, 'cvpartition_trunctrain_incsubsets');
        end
    end
end