annotate scripts/classification.py @ 18:ed109218dd4b branch-tests

rename result scripts and more tests
author Maria Panteli
date Tue, 12 Sep 2017 23:18:19 +0100
parents
children e8084526f7e5
rev   line source
Maria@18 1 # -*- coding: utf-8 -*-
Maria@18 2 """
Maria@18 3 Created on Thu Nov 10 15:10:32 2016
Maria@18 4
Maria@18 5 @author: mariapanteli
Maria@18 6 """
import pickle

import numpy as np
import pandas as pd
from sklearn import metrics

import map_and_average
import util_feature_learning
Maria@18 13
Maria@18 14
# Data files iterated by classify_for_filenames() by default (and indexed by the
# __main__ script); each entry is read with load_data_from_pickle().
FILENAMES = map_and_average.OUTPUT_FILES
Maria@18 16
Maria@18 17
def load_data_from_pickle(filename):
    """Load a pickled ``(X_list, Y, Yaudio)`` triple and stack its features.

    Parameters
    ----------
    filename : str
        Path to a pickle file containing a 3-tuple ``(X_list, Y, Yaudio)``.

    Returns
    -------
    X : numpy.ndarray
        The arrays of ``X_list`` concatenated along axis 1 (side by side),
        so they must all have the same number of rows.
    Y, Yaudio
        The second and third elements of the pickled tuple, unchanged.
    """
    # pickle.load can execute arbitrary code: only load trusted files.
    with open(filename, 'rb') as fh:
        X_list, Y, Yaudio = pickle.load(fh)
    X = np.concatenate(X_list, axis=1)
    return X, Y, Yaudio
Maria@18 22
Maria@18 23
def get_train_test_indices(audiolabs=None):
    """Split the rows of a feature matrix into train and test row indices.

    The split comes from ``map_and_average.load_train_val_test_sets()``:
    element ``[2]`` of its train and test sets holds the audio labels of each
    split. Row ``i`` is assigned to a split when ``audiolabs[i]`` is one of
    that split's audio labels. The validation set is loaded but not used, so
    rows whose label is in neither train nor test appear in neither output.

    Parameters
    ----------
    audiolabs : array-like
        Audio label of every row, in row order (the ``Yaudio`` returned by
        ``load_data_from_pickle``). Must be supplied.

    Returns
    -------
    traininds, testinds : numpy.ndarray of int
        Increasing row positions belonging to the train / test split.

    Raises
    ------
    ValueError
        If ``audiolabs`` is not given.
    """
    if audiolabs is None:
        raise ValueError('audiolabs is required: pass the Yaudio labels '
                         'returned by load_data_from_pickle()')
    trainset, valset, testset = map_and_average.load_train_val_test_sets()
    trainaudiolabels, testaudiolabels = trainset[2], testset[2]
    audiolabs = np.asarray(audiolabs)
    # One vectorised membership pass per split instead of a Python loop with
    # an O(n) array search per row; flatnonzero always returns an integer
    # array, so an empty split is still a valid index.
    traininds = np.flatnonzero(np.isin(audiolabs, trainaudiolabels))
    testinds = np.flatnonzero(np.isin(audiolabs, testaudiolabels))
    return traininds, testinds


def get_train_test_sets(X, Y, traininds, testinds):
    """Select the train and test rows of the 2-D matrix ``X`` and of ``Y``.

    Returns
    -------
    X_train, Y_train, X_test, Y_test
    """
    X_train, Y_train = X[traininds, :], Y[traininds]
    X_test, Y_test = X[testinds, :], Y[testinds]
    return X_train, Y_train, X_test, Y_test


def classify_for_filenames(file_list=FILENAMES):
    """Run ``Transformer.classify`` on every data file and stack the results.

    For each file: load it with ``load_data_from_pickle``, split its rows
    with ``get_train_test_indices`` using the file's own audio labels, and
    classify with ``util_feature_learning.Transformer().classify``.

    Parameters
    ----------
    file_list : sequence of str
        Pickle files to process, in order.

    Returns
    -------
    pandas.DataFrame
        The per-file result frames concatenated in file order with a fresh
        integer index; an empty DataFrame when ``file_list`` is empty.
    """
    feat_learner = util_feature_learning.Transformer()
    results = []
    for filename in file_list:
        X, Y, Yaudio = load_data_from_pickle(filename)
        traininds, testinds = get_train_test_indices(audiolabs=Yaudio)
        X_train, Y_train, X_test, Y_test = get_train_test_sets(X, Y, traininds, testinds)
        results.append(feat_learner.classify(X_train, Y_train, X_test, Y_test))
    if not results:
        return pd.DataFrame()
    # A single concat avoids re-copying the accumulated frame on every file.
    return pd.concat(results, axis=0, ignore_index=True)


def plot_CF(CF, labels=None, figurename=None):
    """Draw the confusion matrix ``CF`` as a greyscale image with tick labels.

    Parameters
    ----------
    CF : 2-D array-like
        Confusion matrix whose rows/columns are ordered like ``labels``.
    labels : array-like of str, optional
        Tick label per row/column; defaults to the positions
        ``0 .. len(CF) - 1``. The caller's array is not modified.
    figurename : str, optional
        When given, the figure is saved to this path.
    """
    # NOTE(review): ``plt`` (matplotlib.pyplot) is not among this module's
    # imports; add ``import matplotlib.pyplot as plt`` before calling this.
    if labels is None:
        labels = np.arange(len(CF))
    else:
        # Work on a copy so the abbreviation below does not rewrite the
        # caller's label array.
        labels = np.array(labels)
        labels[labels == 'United States of America'] = 'United States Amer.'
    ticks = range(len(labels))
    plt.imshow(CF, cmap="Greys")
    plt.xticks(ticks, labels, rotation='vertical', fontsize=4)
    plt.yticks(ticks, labels, fontsize=4)
    if figurename is not None:
        plt.savefig(figurename, bbox_inches='tight')


def confusion_matrix(X_train, Y_train, X_test, Y_test, saveCF=False, plots=False):
    """Classify the test set with the LDA model and build its confusion matrix.

    Predictions come from ``util_feature_learning.classification_accuracy``
    with ``model=util_feature_learning.modelLDA``. The matrix rows/columns
    follow the sorted unique labels of ``Y_test``.

    Parameters
    ----------
    saveCF : bool
        Write the labels to ``data/CFlabels.csv`` and the matrix to
        ``data/CF.csv``.
    plots : bool
        Plot the matrix via ``plot_CF`` into ``data/conf_matrix.pdf``.

    Returns
    -------
    accuracy, predictions
        As returned by ``classification_accuracy``.
    """
    accuracy, predictions = util_feature_learning.classification_accuracy(
        X_train, Y_train, X_test, Y_test, model=util_feature_learning.modelLDA)
    labels = np.unique(Y_test)  # TODO: countries in geographical proximity
    CF = metrics.confusion_matrix(Y_test, predictions, labels=labels)
    if saveCF:
        np.savetxt('data/CFlabels.csv', labels, fmt='%s')
        np.savetxt('data/CF.csv', CF, fmt='%10.5f')
    if plots:
        plot_CF(CF, labels=labels, figurename='data/conf_matrix.pdf')
    return accuracy, predictions


if __name__ == '__main__':
    df_results = classify_for_filenames(file_list=FILENAMES)
    # Best row over all (file, classifier) results; column 1 is assumed to
    # hold the score to maximise -- TODO confirm against Transformer.classify.
    max_i = int(np.argmax(df_results.iloc[:, 1].values))
    # Rows are concatenated file by file and each file is assumed to add the
    # same number of rows (4 classifiers per feature learning method), so the
    # file index is the row index integer-divided by that count (the
    # remainder would be the classifier index instead).
    rows_per_file = len(df_results) // len(FILENAMES)
    feat_learning_i = max_i // rows_per_file
    filename = FILENAMES[feat_learning_i]
    X, Y, Yaudio = load_data_from_pickle(filename)
    traininds, testinds = get_train_test_indices(audiolabs=Yaudio)
    X_train, Y_train, X_test, Y_test = get_train_test_sets(X, Y, traininds, testinds)
    confusion_matrix(X_train, Y_train, X_test, Y_test, saveCF=True, plots=True)
Maria@18 87