annotate misc/RandomForestEMFeatureSelection.m @ 37:d9a9a6b93026 tip

Add README
author DaveM
date Sat, 01 Apr 2017 17:03:14 +0100
parents 985cd163ba54
children
rev   line source
% RandomForestEMFeatureSelection
% Recursive feature elimination driven by Random Forest permutation importance.
%
% Each pass:
%   1. grows a TreeBagger forest on the current feature set and records the
%      out-of-bag (OOB) error curve and per-feature permutation importance;
%   2. fits Gaussian mixture models with 2, 3, ... components on the
%      NUM_FOLDS cross-validation training splits, stopping at the first
%      component count whose mean AIC is worse than the previous count's;
%   3. ranks features by importance, drops the lowest-ranked REMOVE_PERCENT
%      (at least one feature), and checkpoints all results to RESULTS_FILE.
% Passes repeat until MIN_FEATURES or fewer features remain.
%
% NOTE(review): from usage, AdobeStratified.mat is expected to hold DataTrain
% (rows = observations, columns = features), LabelsTrain and FeatureNames
% (one name per DataTrain column) -- confirm against the data file.
% Uses TreeBagger/fitgmdist/statset (Statistics and Machine Learning Toolbox)
% and crossvalind (Bioinformatics Toolbox).

clearvars;

% Load only the variables this script uses; fails fast if one is missing.
load('AdobeStratified.mat', 'DataTrain', 'LabelsTrain', 'FeatureNames');

NUM_FEATURES   = 1450;                  % leading DataTrain columns used in the first pass
NUM_TREES      = 200;                   % trees grown per Random Forest
NUM_FOLDS      = 10;                    % CV folds for the EM/AIC cluster search
REMOVE_PERCENT = 1;                     % percentage of features dropped per pass
MIN_FEATURES   = 2;                     % stop once this many (or fewer) features remain
RESULTS_FILE   = 'Results1Percent.mat'; % checkpoint written after every pass

morefeatures = true;
idxvar = 1:NUM_FEATURES;
count = 1;

% The {} value makes struct() return a 0x0 struct array with these fields;
% featuredata(count) grows it by one element per pass.
featuredata = struct('IdxVar', [], 'FeatureNamesRanked', {}, 'FeatureImportance', [], ...
    'OOBError', [], 'LastOOBError', [], 'EMClusters', [], 'AIC', [], 'PreviousAIC', []);

% Loop-invariant, so built once instead of on every pass.
options = statset('UseParallel', true);

while morefeatures
    % Keep only the surviving features (in the ranked order left by the
    % previous pass) and renumber them 1..n relative to the reduced set.
    DataTrain = DataTrain(:, idxvar);
    FeatureNames = FeatureNames(idxvar);
    idxvar = 1:numel(FeatureNames);
    fprintf('\n Growing a Random Forest of %i trees using %i features\n', NUM_TREES, numel(idxvar))

    % Reseed the global generator at the start of every pass.
    rng(1945, 'twister')
    tic
    b = TreeBagger(NUM_TREES, DataTrain, LabelsTrain, 'OOBVarImp', 'On', ...
        'SampleWithReplacement', 'Off', 'FBoot', 0.632, 'Options', options);
    toc

    % Cumulative OOB error per number of trees; the last entry covers the
    % whole ensemble.
    oobErr = oobError(b);
    LastoobErr = oobErr(end);

    fprintf('\n The cumulative OOB Error at %i trees is %f\n', NUM_TREES, LastoobErr);

    % Fold label (1..NUM_FOLDS) for each row of DataTrain.
    Indices = crossvalind('Kfold', size(DataTrain, 1), NUM_FOLDS);

    % Search over the number of GMM components, starting at 2. Each step
    % increments NumClusters and averages the AIC over the CV training
    % splits; the search stops at the first count whose mean AIC exceeds the
    % previous count's. On exit, NumClusters/AICNext describe that first
    % *worse* count and AICInitial the count before it (or the 1e16 sentinel
    % if the very first count already failed the test).
    AICInitial = 1e16;   % sentinel: no previous AIC yet
    AICNext = -1e16;     % sentinel: guarantees the search body runs at least once
    AICAvg = zeros(NUM_FOLDS, 1);
    NumClusters = 1;

    while AICNext <= AICInitial

        if NumClusters ~= 1
            AICInitial = AICNext;
        end
        NumClusters = NumClusters + 1;

        fprintf('\n Performing EM using %i fold CV and %i clusters and %i features\n', ...
            NUM_FOLDS, NumClusters, numel(idxvar))

        for i = 1:NUM_FOLDS
            % Fit on every fold except fold i.
            emidx = (Indices ~= i);
            EMDataTrain = DataTrain(emidx, :);
            GMModelCV = fitgmdist(EMDataTrain, NumClusters, 'RegularizationValue', 1e-5);
            AICAvg(i) = GMModelCV.AIC;
        end

        AICNext = mean(AICAvg);
        fprintf('The average AIC was %f\n', AICNext);
    end

    % Rank features by OOB permutation importance, most important first.
    FI = b.OOBPermutedVarDeltaError;
    [FI, I] = sort(FI, 'descend');
    idxvar = idxvar(I);
    FeatureNamesRanked = FeatureNames(I);

    % NOTE: IdxVar indexes this pass's reduced feature set, not the original
    % NUM_FEATURES columns; FeatureNamesRanked maps the ranks back to names.
    featuredata(count).IdxVar = idxvar;
    featuredata(count).FeatureNamesRanked = FeatureNamesRanked;
    featuredata(count).FeatureImportance = FI;
    featuredata(count).OOBError = oobErr;
    featuredata(count).LastOOBError = LastoobErr;
    featuredata(count).EMClusters = NumClusters;
    featuredata(count).AIC = AICNext;
    featuredata(count).PreviousAIC = AICInitial;

    % Drop the lowest-ranked REMOVE_PERCENT of features, but always at least
    % one: with plain rounding, fewer than 50 features gives round(n/100) == 0,
    % nothing is removed, and the loop never terminates.
    nRemove = max(1, round(numel(idxvar) / 100 * REMOVE_PERCENT));
    fprintf('\n %i features will be removed.\n', nRemove)
    idxvar = idxvar(1:end - nRemove);
    count = count + 1;

    % Checkpoint after every pass so partial results survive an interrupted run.
    save(RESULTS_FILE, 'featuredata');

    % '<=' rather than '==' so a pass that removes several features at once
    % cannot step past the threshold.
    if numel(idxvar) <= MIN_FEATURES
        morefeatures = false;
    end
end