Mercurial > hg > plosone_underreview
comparison scripts/classification.py @ 45:ef829b187308 branch-tests
fixed get train indices
author | Maria Panteli <m.x.panteli@gmail.com> |
---|---|
date | Fri, 15 Sep 2017 16:23:23 +0100 |
parents | e8084526f7e5 |
children | 081ff4ea7da7 |
comparison
equal
deleted
inserted
replaced
44:06e5711f9f62 | 45:ef829b187308 |
---|---|
19 X_list, Y, Yaudio = pickle.load(open(filename,'rb')) | 19 X_list, Y, Yaudio = pickle.load(open(filename,'rb')) |
20 X = np.concatenate(data_list, axis=1) | 20 X = np.concatenate(data_list, axis=1) |
21 return X, Y, Yaudio | 21 return X, Y, Yaudio |
22 | 22 |
23 | 23 |
24 def get_train_test_indices(): | 24 def get_train_test_indices(audiolabs): |
25 trainset, valset, testset = map_and_average.load_train_val_test_sets() | 25 trainset, valset, testset = map_and_average.load_train_val_test_sets() |
26 trainaudiolabels, testaudiolabels = trainset[2], testset[2] | 26 trainaudiolabels, testaudiolabels = trainset[2], testset[2] |
27 # train, test indices | 27 # train, test indices |
28 aa_train = np.unique(trainaudiolabels) | 28 aa_train = np.unique(trainaudiolabels) |
29 aa_test = np.unique(testaudiolabels) | 29 aa_test = np.unique(testaudiolabels) |
79 df_results = classify_for_filenames(file_list=FILENAMES) | 79 df_results = classify_for_filenames(file_list=FILENAMES) |
80 max_i = np.argmax(df_results[:, 1]) | 80 max_i = np.argmax(df_results[:, 1]) |
81 feat_learning_i = max_i % 4 # 4 classifiers for each feature learning method | 81 feat_learning_i = max_i % 4 # 4 classifiers for each feature learning method |
82 filename = FILENAMES[feat_learning_i] | 82 filename = FILENAMES[feat_learning_i] |
83 X, Y, Yaudio = load_data_from_pickle(filename) | 83 X, Y, Yaudio = load_data_from_pickle(filename) |
84 traininds, testinds = get_train_test_indices() | 84 traininds, testinds = get_train_test_indices(Yaudio) |
85 X_train, Y_train, X_test, Y_test = get_train_test_sets(X, Y, traininds, testinds) | 85 X_train, Y_train, X_test, Y_test = get_train_test_sets(X, Y, traininds, testinds) |
86 confusion_matrix(X_train, Y_train, X_test, Y_test, saveCF=True, plots=True) | 86 confusion_matrix(X_train, Y_train, X_test, Y_test, saveCF=True, plots=True) |
87 | 87 |