Mercurial > hg > sfx-subgrouping
comparison phase2/rfFeatureSelection.m @ 6:54446ca7e6cb
onePass and cut methods both working for random forest feature selection
author | DaveM |
---|---|
date | Thu, 09 Feb 2017 21:43:20 +0000 |
parents | 7848d183c7ab |
children | cf00dc8be4f7 |
comparison
legend: equal, deleted, inserted, replaced
5:7848d183c7ab | 6:54446ca7e6cb |
---|---|
1 function features = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees) | 1 function featureVector = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees, featureVector) |
2 % rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees) | 2 % rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees) |
3 % | 3 % |
4 % using random forests to perform feature selection for a given data set | 4 % using random forests to perform feature selection for a given data set |
5 % data has size (x,y), where x is the number of labels and y, the number of | 5 % data has size (x,y), where x is the number of labels and y, the number of |
6 % features. | 6 % features. |
22 % measure is the increase in prediction error if the values of that | 22 % measure is the increase in prediction error if the values of that |
23 % variable are permuted across the out-of-bag observations. This | 23 % variable are permuted across the out-of-bag observations. This |
24 % measure is computed for every tree, then averaged over the entire | 24 % measure is computed for every tree, then averaged over the entire |
25 % ensemble and divided by the standard deviation over the entire | 25 % ensemble and divided by the standard deviation over the entire |
26 % ensemble. | 26 % ensemble. |
27 % featureVector is a list of the features to use, for recursive purposes. | |
27 | 28 |
28 if(length(labels) ~= size(data,1)) | 29 if(length(labels) ~= size(data,1)) |
29 error('labels and data do not match up'); | 30 error('labels and data do not match up'); |
30 end | 31 end |
31 | 32 |
39 iterMethod = 'onePass'; | 40 iterMethod = 'onePass'; |
40 end | 41 end |
41 if(nargin < 5) | 42 if(nargin < 5) |
42 numTrees = 200; | 43 numTrees = 200; |
43 end | 44 end |
45 if(nargin < 5) | |
46 featureVector = 1:size(data,2); | |
47 end | |
44 | 48 |
45 | 49 |
46 options = statset('UseParallel', true); | 50 if(length(featureVector) > numFeatures) |
47 b = TreeBagger(numTrees, data, labels,'OOBVarImp','On',... | 51 options = statset('UseParallel', true); |
48 'SampleWithReplacement', 'Off','FBoot', 0.632,'Options', options); | 52 b = TreeBagger(numTrees, data(:,featureVector), labels,'OOBVarImp','On',... |
49 [FI,I] = sort(b.OOBPermutedVarDeltaError,'descend'); | 53 'SampleWithReplacement', 'Off','FBoot', 0.632,'Options', options); |
50 features = I; | 54 [FI,I] = sort(b.OOBPermutedVarDeltaError,'descend'); |
55 featureVector = featureVector(I); | |
51 | 56 |
52 if(strcmp(iterMethod,'onePass')) | 57 if(strcmp(iterMethod,'onePass')) |
53 disp('onePass') | 58 disp('onePass') |
54 features = features(1:numFeatures); | 59 featureVector = featureVector(1:numFeatures); |
55 elseif(strcmp(iterMethod(1:3),'cut')) | 60 elseif(strcmp(iterMethod(1:3),'cut')) |
56 disp(iterMethod) | 61 disp(iterMethod) |
57 cutPercentage = str2int(iterMethod(4:end)); | 62 cutPercentage = str2double(iterMethod(4:end)); |
58 cutSize = max(floor(length(features)*cutPercentage/100),1); | 63 cutSize = max(floor(length(featureVector)*cutPercentage/100),1); |
59 features = features(1:end-cutSize); | 64 if(length(featureVector) - cutSize < numFeatures) |
60 data = data(:,I); | 65 cutSize = length(featureVector) - numFeatures; |
61 features = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees); | 66 end |
62 elseif(strcmp(iterMethod,'oobErr')) | 67 featureVector = featureVector(1:end-cutSize); |
63 warning('This method has not been implemented yet, using onePass to return results') | 68 % data = data(:,sort(featureVector)); |
64 features = features(1:numFeatures); | 69 featureVector = rfFeatureSelection(data, labels, numFeatures, iterMethod, numTrees, featureVector); |
65 elseif(strcmp(iterMethod,'featureDeltaErr')) | 70 elseif(strcmp(iterMethod,'oobErr')) |
66 warning('This method has not been implemented yet, using onePass to return results') | 71 warning('This method has not been implemented yet, using onePass to return results') |
67 % this will use variable FI | 72 featureVector = featureVector(1:numFeatures); |
68 features = features(1:numFeatures); | 73 elseif(strcmp(iterMethod,'featureDeltaErr')) |
74 warning('This method has not been implemented yet, using onePass to return results') | |
75 % this will use variable FI | |
76 featureVector = featureVector(1:numFeatures); | |
77 end | |
69 end | 78 end |
70 end | 79 end |