phase2/rfFeatureSelection.m @ 6:54446ca7e6cb
onePass and cut methods both working for random forest feature selection
author: DaveM
date: Thu, 09 Feb 2017 21:43:20 +0000
function featureVector = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees, featureVector)
% rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees)
%
% Uses random forests to perform feature selection on a given data set.
% data has size (x,y), where x is the number of observations (one label
% per observation) and y is the number of features.
% labels is the set of labels for the data.
% numFeatures is the dimension of the output vector (default 5).
% iterMethod is the method by which the features are cut down:
%   * 'onePass' simply selects the top numFeatures features and
%     reports them.
%   * 'cutX' iteratively cuts the bottom X percent of the features,
%     then reruns random forest feature selection on the reduced set,
%     until the desired number of features remains (e.g. 'cut20'
%     removes the bottom 20 percent on each pass).
%   * 'oobErr' is intended to rank features using the out-of-bag
%     error, but this has not been implemented yet.
%   * 'featureDeltaErr' is intended to rank features using the
%     feature importance prediction error, but this has not been
%     implemented yet. The OOBPermutedVarDeltaError property is a
%     numeric array of size 1-by-Nvars containing a measure of
%     importance for each predictor variable (feature). For any
%     variable, the measure is the increase in prediction error if the
%     values of that variable are permuted across the out-of-bag
%     observations. This measure is computed for every tree, then
%     averaged over the entire ensemble and divided by the standard
%     deviation over the entire ensemble.
% featureVector is the list of candidate features; it is supplied
%   automatically on recursive calls and defaults to all features.
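%
% Example: a quick usage sketch on hypothetical data (X and y below are
% placeholders; requires the Statistics and Machine Learning Toolbox):
%   X = randn(100, 20);          % 100 observations, 20 features
%   y = randi([0 1], 100, 1);    % binary labels
%   top5 = rfFeatureSelection(X, y, 5, 'cut20');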
if(nargin < 2)
    error('must pass data and labels into function');
end
if(length(labels) ~= size(data,1))
    error('labels and data do not match up');
end
if(nargin < 3)
    numFeatures = 5;
end
if(nargin < 4)
    iterMethod = 'onePass';
end
if(nargin < 5)
    numTrees = 200;
end
if(nargin < 6)
    % first (non-recursive) call: start from the full feature set
    featureVector = 1:size(data,2);
end

if(length(featureVector) > numFeatures)
    options = statset('UseParallel', true);
    % grow the forest on the surviving features, with out-of-bag
    % variable importance turned on
    b = TreeBagger(numTrees, data(:,featureVector), labels, 'OOBVarImp', 'On',...
        'SampleWithReplacement', 'Off', 'FBoot', 0.632, 'Options', options);
    % rank the surviving features from most to least important
    [FI,I] = sort(b.OOBPermutedVarDeltaError, 'descend');
    featureVector = featureVector(I);

    if(strcmp(iterMethod,'onePass'))
        disp('onePass')
        featureVector = featureVector(1:numFeatures);
    elseif(strncmp(iterMethod,'cut',3))
        disp(iterMethod)
        % e.g. 'cut25' on 40 features: cutSize = floor(40*25/100) = 10
        cutPercentage = str2double(iterMethod(4:end));
        cutSize = max(floor(length(featureVector)*cutPercentage/100),1);
        % never cut past the requested number of features
        if(length(featureVector) - cutSize < numFeatures)
            cutSize = length(featureVector) - numFeatures;
        end
        featureVector = featureVector(1:end-cutSize);
        % data = data(:,sort(featureVector));
        featureVector = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees, featureVector);
    elseif(strcmp(iterMethod,'oobErr'))
        warning('This method has not been implemented yet, using onePass to return results')
        featureVector = featureVector(1:numFeatures);
    elseif(strcmp(iterMethod,'featureDeltaErr'))
        warning('This method has not been implemented yet, using onePass to return results')
        % this will use the variable FI when implemented
        featureVector = featureVector(1:numFeatures);
    end
end
end
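% A minimal sketch, assuming the Statistics and Machine Learning Toolbox,
% of inspecting the raw OOBPermutedVarDeltaError scores that drive the
% ranking above (X and y are hypothetical data as in the help example):
%
%   b = TreeBagger(200, X, y, 'OOBVarImp', 'On');
%   bar(b.OOBPermutedVarDeltaError);
%   xlabel('feature index'); ylabel('permuted OOB delta error');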