comparison phase2/rfFeatureSelection.m @ 6:54446ca7e6cb

The onePass and cut methods are both now working for random forest feature selection
author DaveM
date Thu, 09 Feb 2017 21:43:20 +0000
parents 7848d183c7ab
children cf00dc8be4f7
comparison
equal deleted inserted replaced
5:7848d183c7ab 6:54446ca7e6cb
1 function features = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees) 1 function featureVector = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees, featureVector)
2 % rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees) 2 % rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees)
3 % 3 %
4 % using random forests to perform feature selection for a given data set 4 % using random forests to perform feature selection for a given data set
5 % data has size (x,y), where x is the number of labels and y, the number of 5 % data has size (x,y), where x is the number of labels and y, the number of
6 % features. 6 % features.
22 % measure is the increase in prediction error if the values of that 22 % measure is the increase in prediction error if the values of that
23 % variable are permuted across the out-of-bag observations. This 23 % variable are permuted across the out-of-bag observations. This
24 % measure is computed for every tree, then averaged over the entire 24 % measure is computed for every tree, then averaged over the entire
25 % ensemble and divided by the standard deviation over the entire 25 % ensemble and divided by the standard deviation over the entire
26 % ensemble. 26 % ensemble.
27 % featureVector is a list of the features to use, for recursive purposes.
27 28
28 if(length(labels) ~= size(data,1)) 29 if(length(labels) ~= size(data,1))
29 error('labels and data do not match up'); 30 error('labels and data do not match up');
30 end 31 end
31 32
39 iterMethod = 'onePass'; 40 iterMethod = 'onePass';
40 end 41 end
41 if(nargin < 5) 42 if(nargin < 5)
42 numTrees = 200; 43 numTrees = 200;
43 end 44 end
45 if(nargin < 5)
46 featureVector = 1:size(data,2);
47 end
44 48
45 49
46 options = statset('UseParallel', true); 50 if(length(featureVector) > numFeatures)
47 b = TreeBagger(numTrees, data, labels,'OOBVarImp','On',... 51 options = statset('UseParallel', true);
48 'SampleWithReplacement', 'Off','FBoot', 0.632,'Options', options); 52 b = TreeBagger(numTrees, data(:,featureVector), labels,'OOBVarImp','On',...
49 [FI,I] = sort(b.OOBPermutedVarDeltaError,'descend'); 53 'SampleWithReplacement', 'Off','FBoot', 0.632,'Options', options);
50 features = I; 54 [FI,I] = sort(b.OOBPermutedVarDeltaError,'descend');
55 featureVector = featureVector(I);
51 56
52 if(strcmp(iterMethod,'onePass')) 57 if(strcmp(iterMethod,'onePass'))
53 disp('onePass') 58 disp('onePass')
54 features = features(1:numFeatures); 59 featureVector = featureVector(1:numFeatures);
55 elseif(strcmp(iterMethod(1:3),'cut')) 60 elseif(strcmp(iterMethod(1:3),'cut'))
56 disp(iterMethod) 61 disp(iterMethod)
57 cutPercentage = str2int(iterMethod(4:end)); 62 cutPercentage = str2double(iterMethod(4:end));
58 cutSize = max(floor(length(features)*cutPercentage/100),1); 63 cutSize = max(floor(length(featureVector)*cutPercentage/100),1);
59 features = features(1:end-cutSize); 64 if(length(featureVector) - cutSize < numFeatures)
60 data = data(:,I); 65 cutSize = length(featureVector) - numFeatures;
61 features = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees); 66 end
62 elseif(strcmp(iterMethod,'oobErr')) 67 featureVector = featureVector(1:end-cutSize);
63 warning('This method has not been implemented yet, using onePass to return results') 68 % data = data(:,sort(featureVector));
64 features = features(1:numFeatures); 69 featureVector = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees, featureVector);
65 elseif(strcmp(iterMethod,'featureDeltaErr')) 70 elseif(strcmp(iterMethod,'oobErr'))
66 warning('This method has not been implemented yet, using onePass to return results') 71 warning('This method has not been implemented yet, using onePass to return results')
67 % this will use variable FI 72 featureVector = featureVector(1:numFeatures);
68 features = features(1:numFeatures); 73 elseif(strcmp(iterMethod,'featureDeltaErr'))
74 warning('This method has not been implemented yet, using onePass to return results')
75 % this will use variable FI
76 featureVector = featureVector(1:numFeatures);
77 end
69 end 78 end
70 end 79 end