Mercurial > hg > scatter_reeval
view reeval/classification/perform_classification.m @ 4:a1f6a08f624c tip
Completed version 0.0.2
author | Francisco Rodriguez Algarra <f.rodriguezalgarra@qmul.ac.uk> |
---|---|
date | Tue, 03 Nov 2015 21:24:41 +0000 |
parents | b1cd83874633 |
children |
line wrap: on
line source
function [results] = perform_classification(experiment, db, condition)
% PERFORM_CLASSIFICATION  Train and test an SVM once per fold condition.
%
%   results = PERFORM_CLASSIFICATION(experiment, db)
%   results = PERFORM_CLASSIFICATION(experiment, db, condition)
%
%   Inputs:
%     experiment - experiment name. 'time_scat_l3' sets
%                  optt.full_test_kernel to 0; any other value sets it to 1.
%     db         - database struct. This function reads db.features and
%                  db.src.objects(k).ind, and passes db (and db.src) on to
%                  the svm_* / classif_err helpers.
%     condition  - (optional) fold condition name(s) in any form accepted
%                  by cellstr. Defaults to {'none'; 'fault'}. Each name is
%                  passed to createFolds and is also used as a dynamic field
%                  name of the output, so it must be a valid field name.
%
%   Output:
%     results    - struct with one field per condition. Each field holds:
%                    err      value returned by classif_err
%                    labels   value returned by svm_test
%                    dev_err  smallest row-mean of the last error grid
%                    C, gamma parameters selected at that minimum
%                    tab      table with columns id, true_label, pred_label
%                  An empty struct is returned when condition is empty.

    db = svm_calc_kernel(db, 'gaussian', 'square', 1:8:size(db.features, 2));

    % Options handed to the parameter search and to training.
    optt.kernel_type = 'gaussian';
    optt.C = 2.^(0:4:8);
    optt.gamma = 2.^(-16:4:-8);
    optt.search_depth = 3;

    % Original author's note (it sat between the search_depth setting and
    % this switch): "This causes the accuracy to be lower than it could!"
    switch (experiment)
        case ('time_scat_l3')
            optt.full_test_kernel = 0;
        otherwise
            optt.full_test_kernel = 1;
    end

    if nargin < 3
        % Rows of a char matrix must have equal width, hence 'none ';
        % cellstr strips the trailing blank.
        condition = cellstr(['none '; 'fault']);
    else
        condition = cellstr(condition);
    end

    % Only excerpts that are in the src struct are considered. Their indices
    % do not depend on the condition, so they are collected once up front.
    elems = length(db.src.objects);
    ids = zeros(elems, 1);
    for jj = 1:elems
        ids(jj) = db.src.objects(jj).ind;
    end

    % 1x1000 vector: each of the labels 1..10 repeated 100 times in a row.
    % NOTE(review): this hard-codes 10 classes of 100 consecutive excerpts;
    % presumably it matches the dataset layout -- confirm before reusing
    % with other data.
    all_true_labels = kron((1:10), ones(1, 100));

    % Guarantee the output is assigned even when there are no conditions.
    results = struct();

    for ii = 1:length(condition)
        cond_name = condition{ii};

        [train_set, test_set] = createFolds(cond_name);
        train_set = find(train_set)';
        test_set = find(test_set)';

        % Drop fold members that have no entry in db.src.objects.
        train_set = intersect(train_set, ids);
        test_set = intersect(test_set, ids);

        [dev_err_grid, C_grid, gamma_grid] = ...
            svm_adaptive_param_search(db, train_set, [], optt);

        % Average the last grid along its second dimension and keep the
        % parameter pair with the smallest mean error.
        [dev_err, ind] = min(mean(dev_err_grid{end}, 2));
        best_C = C_grid{end}(ind);
        best_gamma = gamma_grid{end}(ind);

        % Train with the selected parameters and evaluate on the test set.
        optt1 = optt;
        optt1.C = best_C;
        optt1.gamma = best_gamma;
        model = svm_train(db, train_set, optt1);
        labels = svm_test(db, model, test_set);
        err = classif_err(labels, test_set, db.src);

        % These three variable names become the column names of the table
        % below, so they must not be renamed.
        id = test_set;
        true_label = all_true_labels(test_set)';
        pred_label = labels';

        % Save the results for this condition.
        results.(cond_name).err = err;
        results.(cond_name).labels = labels;
        results.(cond_name).dev_err = dev_err;
        results.(cond_name).C = best_C;
        results.(cond_name).gamma = best_gamma;
        results.(cond_name).tab = table(id, true_label, pred_label);
    end
end