Mercurial > hg > scatter_reeval
diff reeval/classification/perform_classification.m @ 2:b1cd83874633
Major structural revision. Modular organization of functionalities
| author | Francisco Rodriguez Algarra <f.rodriguezalgarra@qmul.ac.uk> |
|---|---|
| date | Wed, 28 Oct 2015 16:15:47 +0000 |
| parents | |
| children | a1f6a08f624c |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/reeval/classification/perform_classification.m Wed Oct 28 16:15:47 2015 +0000 @@ -0,0 +1,65 @@ +function [results] = perform_classification(experiment, db, condition) + + db = svm_calc_kernel(db,'gaussian','square',1:8:size(db.features,2)); + + optt.kernel_type = 'gaussian'; + optt.C = 2.^[0:4:8]; + optt.gamma = 2.^[-16:4:-8]; + optt.search_depth = 3; + optt.full_test_kernel = 1; + + if nargin < 3 + condition = cellstr(['none '; 'fault']); + end + + for ii=1:length(condition) + + [train_set,test_set] = createFolds(condition{ii}); + train_set = find(train_set)'; + test_set = find(test_set)'; + + % Only consider excerpts that are in src struct + + elems = length(db.src.objects); + ids = zeros(elems, 1); + for jj=1:elems + ids(jj) = db.src.objects(jj).ind; + end + + train_set = intersect(train_set, ids); + test_set = intersect(test_set, ids); + + % + + [dev_err_grid,C_grid,gamma_grid] = ... + svm_adaptive_param_search(db,train_set,[],optt); + + [dev_err,ind] = min(mean(dev_err_grid{end},2)); + C = C_grid{end}(ind); + gamma = gamma_grid{end}(ind); + + optt1 = optt; + optt1.C = C; + optt1.gamma = gamma; + + model = svm_train(db,train_set,optt1); + labels = svm_test(db,model,test_set); + err = classif_err(labels,test_set,db.src); + + % dummy renaming of variables + id = test_set; + true_label = kron((1:10), ones(1, 100)); + true_label = true_label(test_set)'; + pred_label = labels'; + + % saving results in struct + results.(condition{ii}).err = err; + results.(condition{ii}).labels = labels; + results.(condition{ii}).dev_err = dev_err; + results.(condition{ii}).C = C; + results.(condition{ii}).gamma = gamma; + results.(condition{ii}).tab = table(id, true_label, pred_label); + + end + +end
