diff reeval/classification/perform_classification.m @ 2:b1cd83874633

Major structural revision. Modular organization of functionalities
author Francisco Rodriguez Algarra <f.rodriguezalgarra@qmul.ac.uk>
date Wed, 28 Oct 2015 16:15:47 +0000
parents
children a1f6a08f624c
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/reeval/classification/perform_classification.m	Wed Oct 28 16:15:47 2015 +0000
@@ -0,0 +1,65 @@
+function [results] = perform_classification(experiment, db, condition)
+
+  db = svm_calc_kernel(db,'gaussian','square',1:8:size(db.features,2));
+
+  optt.kernel_type = 'gaussian';
+  optt.C = 2.^[0:4:8];
+  optt.gamma = 2.^[-16:4:-8];
+  optt.search_depth = 3;
+  optt.full_test_kernel = 1;
+
+  if nargin < 3
+    condition = cellstr(['none '; 'fault']);
+  end
+
+  for ii=1:length(condition)
+
+    [train_set,test_set] = createFolds(condition{ii});
+    train_set = find(train_set)';
+    test_set = find(test_set)';
+
+    % Only consider excerpts that are in src struct
+
+    elems = length(db.src.objects);
+    ids = zeros(elems, 1);
+    for jj=1:elems
+      ids(jj) = db.src.objects(jj).ind;
+    end
+
+    train_set = intersect(train_set, ids);
+    test_set = intersect(test_set, ids);
+
+    %
+
+    [dev_err_grid,C_grid,gamma_grid] = ...
+      svm_adaptive_param_search(db,train_set,[],optt);
+
+    [dev_err,ind] = min(mean(dev_err_grid{end},2));
+    C = C_grid{end}(ind);
+    gamma = gamma_grid{end}(ind);
+
+    optt1 = optt;
+    optt1.C = C;
+    optt1.gamma = gamma;
+
+    model = svm_train(db,train_set,optt1);
+    labels = svm_test(db,model,test_set);
+    err = classif_err(labels,test_set,db.src);
+
+    % dummy renaming of variables
+    id = test_set;
+    true_label = kron((1:10), ones(1, 100));
+    true_label = true_label(test_set)';
+    pred_label = labels';
+
+    % saving results in struct
+    results.(condition{ii}).err = err;
+    results.(condition{ii}).labels = labels;
+    results.(condition{ii}).dev_err = dev_err;
+    results.(condition{ii}).C = C;
+    results.(condition{ii}).gamma = gamma;
+    results.(condition{ii}).tab = table(id, true_label, pred_label);
+
+  end
+
+end