comparison scripts/classification.py @ 45:ef829b187308 branch-tests

Fixed get_train_test_indices: it now takes the audio labels as an argument (called with Yaudio)
author Maria Panteli <m.x.panteli@gmail.com>
date Fri, 15 Sep 2017 16:23:23 +0100
parents e8084526f7e5
children 081ff4ea7da7
comparison
equal deleted inserted replaced
44:06e5711f9f62 45:ef829b187308
19 X_list, Y, Yaudio = pickle.load(open(filename,'rb')) 19 X_list, Y, Yaudio = pickle.load(open(filename,'rb'))
20 X = np.concatenate(data_list, axis=1) 20 X = np.concatenate(data_list, axis=1)
21 return X, Y, Yaudio 21 return X, Y, Yaudio
22 22
23 23
24 def get_train_test_indices(): 24 def get_train_test_indices(audiolabs):
25 trainset, valset, testset = map_and_average.load_train_val_test_sets() 25 trainset, valset, testset = map_and_average.load_train_val_test_sets()
26 trainaudiolabels, testaudiolabels = trainset[2], testset[2] 26 trainaudiolabels, testaudiolabels = trainset[2], testset[2]
27 # train, test indices 27 # train, test indices
28 aa_train = np.unique(trainaudiolabels) 28 aa_train = np.unique(trainaudiolabels)
29 aa_test = np.unique(testaudiolabels) 29 aa_test = np.unique(testaudiolabels)
79 df_results = classify_for_filenames(file_list=FILENAMES) 79 df_results = classify_for_filenames(file_list=FILENAMES)
80 max_i = np.argmax(df_results[:, 1]) 80 max_i = np.argmax(df_results[:, 1])
81 feat_learning_i = max_i % 4 # 4 classifiers for each feature learning method 81 feat_learning_i = max_i % 4 # 4 classifiers for each feature learning method
82 filename = FILENAMES[feat_learning_i] 82 filename = FILENAMES[feat_learning_i]
83 X, Y, Yaudio = load_data_from_pickle(filename) 83 X, Y, Yaudio = load_data_from_pickle(filename)
84 traininds, testinds = get_train_test_indices() 84 traininds, testinds = get_train_test_indices(Yaudio)
85 X_train, Y_train, X_test, Y_test = get_train_test_sets(X, Y, traininds, testinds) 85 X_train, Y_train, X_test, Y_test = get_train_test_sets(X, Y, traininds, testinds)
86 confusion_matrix(X_train, Y_train, X_test, Y_test, saveCF=True, plots=True) 86 confusion_matrix(X_train, Y_train, X_test, Y_test, saveCF=True, plots=True)
87 87