changeset 45:ef829b187308 branch-tests

fixed get train indices
author Maria Panteli <m.x.panteli@gmail.com>
date Fri, 15 Sep 2017 16:23:23 +0100
parents 06e5711f9f62
children 3ed4c6af5a93
files scripts/classification.py
diffstat 1 files changed, 2 insertions(+), 2 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/classification.py	Fri Sep 15 16:18:16 2017 +0100
+++ b/scripts/classification.py	Fri Sep 15 16:23:23 2017 +0100
@@ -21,7 +21,7 @@
     return X, Y, Yaudio
 
 
-def get_train_test_indices():
+def get_train_test_indices(audiolabs):
     trainset, valset, testset = map_and_average.load_train_val_test_sets()
     trainaudiolabels, testaudiolabels = trainset[2], testset[2]
     # train, test indices
@@ -81,7 +81,7 @@
     feat_learning_i = max_i % 4  # 4 classifiers for each feature learning method
     filename = FILENAMES[feat_learning_i]
     X, Y, Yaudio = load_data_from_pickle(filename)
-    traininds, testinds = get_train_test_indices()
+    traininds, testinds = get_train_test_indices(Yaudio)
     X_train, Y_train, X_test, Y_test = get_train_test_sets(X, Y, traininds, testinds)
     confusion_matrix(X_train, Y_train, X_test, Y_test, saveCF=True, plots=True)