annotate reeval/classification/perform_classification.m @ 4:a1f6a08f624c tip

Completed version 0.0.2
author Francisco Rodriguez Algarra <f.rodriguezalgarra@qmul.ac.uk>
date Tue, 03 Nov 2015 21:24:41 +0000
parents b1cd83874633
children
rev   line source
function [results] = perform_classification(experiment, db, condition)
% PERFORM_CLASSIFICATION  Train and evaluate a Gaussian-kernel SVM for each fold condition.
%
%   results = perform_classification(experiment, db)
%   results = perform_classification(experiment, db, condition)
%
%   Inputs:
%     experiment - experiment name. It is only used to choose
%                  optt.full_test_kernel ('time_scat_l3' sets it to 0,
%                  everything else to 1).
%     db         - database struct; db.features and db.src.objects(k).ind
%                  are read here, and db is passed to the SVM helpers after
%                  svm_calc_kernel has been applied to it.
%     condition  - (optional) fold condition name(s), anything accepted by
%                  cellstr. Each name is passed to createFolds and used as a
%                  field name of the output. Defaults to {'none', 'fault'}.
%
%   Output:
%     results    - struct with one field per condition, each containing:
%                    .err     value returned by classif_err on the test set
%                    .labels  labels returned by svm_test on the test set
%                    .dev_err lowest mean development error from the search
%                    .C       selected C value
%                    .gamma   selected gamma value
%                    .tab     table with columns id, true_label, pred_label
%                  An empty struct is returned when no condition is given.

% Precompute the Gaussian kernel. The 1:8:end index vector is passed as the
% kernel-set argument of svm_calc_kernel (NOTE(review): confirm what subset
% of db.features it selects).
db = svm_calc_kernel(db, 'gaussian', 'square', 1:8:size(db.features, 2));

% SVM options. C and gamma are the initial search grids handed to
% svm_adaptive_param_search, which runs with the given search_depth.
optt.kernel_type = 'gaussian';
optt.C = 2.^(0:4:8);
optt.gamma = 2.^(-16:4:-8);
optt.search_depth = 3;

% NOTE(review): according to the original author, full_test_kernel = 0
% (used only for 'time_scat_l3') yields lower accuracy than it could. The
% reason for disabling it for that experiment is not visible here - confirm.
switch experiment
    case 'time_scat_l3'
        optt.full_test_kernel = 0;
    otherwise
        optt.full_test_kernel = 1;
end

if nargin < 3
    condition = {'none', 'fault'};
else
    condition = cellstr(condition);
end

% Indices of the excerpts present in db.src. Only these may appear in the
% train/test sets. Computed once because it does not depend on the
% condition. reshape keeps it a column even when objects is empty.
src_ids = reshape([db.src.objects.ind], [], 1);

% NOTE(review): ground truth is hard-coded as 10 classes x 100 excerpts,
% stored class by class (1000 items in total). Confirm before using this
% with a dataset of a different size or ordering.
all_labels = kron(1:10, ones(1, 100));

% Initialise so the output is defined even when condition is empty.
results = struct();

for ii = 1:numel(condition)
    cond_name = condition{ii};

    [train_mask, test_mask] = createFolds(cond_name);

    % Convert fold masks to indices, keeping only excerpts present in db.src.
    train_set = intersect(find(train_mask)', src_ids);
    test_set = intersect(find(test_mask)', src_ids);

    % Parameter search on the training set. Pick the (C, gamma) pair with the
    % lowest mean (over dimension 2) development error at the last search
    % level returned.
    [dev_err_grid, C_grid, gamma_grid] = ...
        svm_adaptive_param_search(db, train_set, [], optt);
    [dev_err, best] = min(mean(dev_err_grid{end}, 2));
    C = C_grid{end}(best);
    gamma = gamma_grid{end}(best);

    % Retrain with the selected parameters and evaluate on the test set.
    optt_best = optt;
    optt_best.C = C;
    optt_best.gamma = gamma;

    model = svm_train(db, train_set, optt_best);
    labels = svm_test(db, model, test_set);
    err = classif_err(labels, test_set, db.src);

    % These variable names become the column names of the table below.
    id = test_set;
    true_label = all_labels(test_set)';
    pred_label = labels';

    results.(cond_name).err = err;
    results.(cond_name).labels = labels;
    results.(cond_name).dev_err = dev_err;
    results.(cond_name).C = C;
    results.(cond_name).gamma = gamma;
    results.(cond_name).tab = table(id, true_label, pred_label);
end

end