changeset 6:54446ca7e6cb

onePass and cut methods both working for random forest feature selection
author DaveM
date Thu, 09 Feb 2017 21:43:20 +0000
parents 7848d183c7ab
children cf00dc8be4f7
files phase2/rfFeatureSelection.m
diffstat 1 files changed, 32 insertions(+), 23 deletions(-) [+]
line wrap: on
line diff
--- a/phase2/rfFeatureSelection.m	Thu Feb 09 18:14:44 2017 +0000
+++ b/phase2/rfFeatureSelection.m	Thu Feb 09 21:43:20 2017 +0000
@@ -1,4 +1,4 @@
-function features = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees)
+function featureVector = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees, featureVector)
 % rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees)
 %
 % using random forests to perform feature selection for a given data set
@@ -24,6 +24,7 @@
 %       measure is computed for every tree, then averaged over the entire
 %       ensemble and divided by the standard deviation over the entire
 %       ensemble.
+% featureVector (optional) lists the column indices of data still under consideration; the 'cut' method passes it to its own recursive calls.
 
 if(length(labels) ~= size(data,1))
     error('labels and data do not match up');
@@ -41,30 +42,38 @@
 if(nargin < 5)
     numTrees = 200;
 end
+if(nargin < 6) % featureVector is the 6th input; when omitted, start from every column of data
+    featureVector = 1:size(data,2);
+end
 
 
-options = statset('UseParallel', true);
-b = TreeBagger(numTrees, data, labels,'OOBVarImp','On',...
-    'SampleWithReplacement', 'Off','FBoot', 0.632,'Options', options);
-[FI,I] = sort(b.OOBPermutedVarDeltaError,'descend'); 
-features = I;
+if(length(featureVector) > numFeatures)
+    options = statset('UseParallel', true);
+    b = TreeBagger(numTrees, data(:,featureVector), labels,'OOBVarImp','On',...
+        'SampleWithReplacement', 'Off','FBoot', 0.632,'Options', options);
+    [FI,I] = sort(b.OOBPermutedVarDeltaError,'descend'); 
+    featureVector = featureVector(I);
 
-if(strcmp(iterMethod,'onePass'))
-    disp('onePass')
-    features = features(1:numFeatures);
-elseif(strcmp(iterMethod(1:3),'cut'))
-    disp(iterMethod)
-    cutPercentage = str2int(iterMethod(4:end));
-    cutSize = max(floor(length(features)*cutPercentage/100),1);
-    features = features(1:end-cutSize);
-    data = data(:,I);
-    features = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees);
-elseif(strcmp(iterMethod,'oobErr'))
-    warning('This method has not been implemented yet, using onePass to return results')
-	features = features(1:numFeatures);
-elseif(strcmp(iterMethod,'featureDeltaErr'))
-    warning('This method has not been implemented yet, using onePass to return results')
-    % this will use variable FI
-	features = features(1:numFeatures);
+    if(strcmp(iterMethod,'onePass'))
+        disp('onePass')
+        featureVector = featureVector(1:numFeatures);
+    elseif(strcmp(iterMethod(1:3),'cut'))
+        disp(iterMethod)
+        cutPercentage = str2double(iterMethod(4:end));
+        cutSize = max(floor(length(featureVector)*cutPercentage/100),1);
+        if(length(featureVector) - cutSize < numFeatures)
+            cutSize = length(featureVector) - numFeatures;
+        end
+        featureVector = featureVector(1:end-cutSize);
+    %     data = data(:,sort(featureVector));
+        featureVector = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees, featureVector);
+    elseif(strcmp(iterMethod,'oobErr'))
+        warning('This method has not been implemented yet, using onePass to return results')
+        featureVector = featureVector(1:numFeatures);
+    elseif(strcmp(iterMethod,'featureDeltaErr'))
+        warning('This method has not been implemented yet, using onePass to return results')
+        % this will use variable FI
+        featureVector = featureVector(1:numFeatures);
+    end
 end
 end
\ No newline at end of file