view scripts/classification.py @ 30:e8084526f7e5 branch-tests

additional test functions
author Maria Panteli <m.x.panteli@gmail.com>
date Wed, 13 Sep 2017 19:57:49 +0100
parents ed109218dd4b
children ef829b187308
line wrap: on
line source
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 10 15:10:32 2016

@author: mariapanteli
"""
import pickle

import numpy as np
import pandas as pd
from sklearn import metrics

import map_and_average
import util_feature_learning


# Pickle files to evaluate, taken from map_and_average.OUTPUT_FILES
# (presumably one file per feature-learning method -- confirm there).
FILENAMES = map_and_average.OUTPUT_FILES


def load_data_from_pickle(filename):
    """Load a pickled feature set and stack its feature blocks column-wise.

    The pickle must hold a 3-tuple ``(X_list, Y, Yaudio)``. The arrays in
    ``X_list`` are concatenated along axis 1 into one feature matrix.

    Parameters
    ----------
    filename : str
        Path to the pickle file.

    Returns
    -------
    X : np.ndarray
        The arrays of ``X_list`` concatenated along axis 1.
    Y : object
        Second element of the pickled tuple, returned unchanged
        (presumably the class labels -- confirm against the writer).
    Yaudio : object
        Third element of the pickled tuple, returned unchanged
        (presumably per-row audio identifiers -- confirm against the writer).

    Notes
    -----
    ``pickle.load`` can execute arbitrary code: only load files produced by
    this project, never untrusted input.
    """
    # NOTE(review): unpickling runs arbitrary code -- trusted files only.
    with open(filename, 'rb') as f:
        X_list, Y, Yaudio = pickle.load(f)
    X = np.concatenate(X_list, axis=1)
    return X, Y, Yaudio


def get_train_test_indices(audiolabs=None):
    """Map the stored train/test split onto row indices of the data.

    The split comes from ``map_and_average.load_train_val_test_sets()``. The
    third element of the train and test sets (``trainset[2]``,
    ``testset[2]``) is used as that split's audio labels. Row ``i`` is put
    in the train (test) split when ``audiolabs[i]`` appears among the train
    (test) audio labels. Rows matching neither are omitted (presumably the
    validation rows).

    Parameters
    ----------
    audiolabs : array-like
        Audio label of every row of the data matrix, in row order (e.g. the
        ``Yaudio`` returned by ``load_data_from_pickle``). It has a default
        only for signature compatibility and must be supplied.

    Returns
    -------
    traininds, testinds : np.ndarray of int
        Positions in ``audiolabs`` belonging to the train and test splits.
        Both are integer arrays even when empty, so they are always valid
        indices.

    Raises
    ------
    ValueError
        If ``audiolabs`` is not given.
    """
    if audiolabs is None:
        raise ValueError("audiolabs (the per-row audio labels of the data) "
                         "is required to map the split to row indices")
    trainset, valset, testset = map_and_average.load_train_val_test_sets()
    trainaudiolabels, testaudiolabels = trainset[2], testset[2]
    audiolabs = np.asarray(audiolabs)
    # np.isin is a vectorized membership test. flatnonzero always returns an
    # integer array; np.array([]) would be float and unusable as an index.
    traininds = np.flatnonzero(np.isin(audiolabs, np.unique(trainaudiolabels)))
    testinds = np.flatnonzero(np.isin(audiolabs, np.unique(testaudiolabels)))
    return traininds, testinds


def get_train_test_sets(X, Y, traininds, testinds):
    """Split features and labels into train and test partitions.

    Rows of ``X`` and entries of ``Y`` are selected with the given index
    arrays.

    Returns
    -------
    X_train, Y_train, X_test, Y_test
        Train rows/labels followed by test rows/labels.
    """
    def _select(rows):
        # Pick the same positions from the feature matrix and the labels.
        return X[rows, :], Y[rows]

    X_train, Y_train = _select(traininds)
    X_test, Y_test = _select(testinds)
    return X_train, Y_train, X_test, Y_test


def classify_for_filenames(file_list=FILENAMES):
    """Run the feature-learning classifiers on every file and stack results.

    For each pickle in ``file_list`` the features are loaded and split into
    train/test rows using that file's audio labels. The split is then passed
    to ``util_feature_learning.Transformer().classify``.

    Parameters
    ----------
    file_list : sequence of str
        Pickle files to evaluate; defaults to ``FILENAMES``.

    Returns
    -------
    pd.DataFrame
        The per-file result frames concatenated in ``file_list`` order with a
        fresh integer index, or an empty DataFrame if ``file_list`` is empty.
    """
    feat_learner = util_feature_learning.Transformer()
    results = []
    for filename in file_list:
        X, Y, Yaudio = load_data_from_pickle(filename)
        # The row indices depend on this file's own per-row audio labels.
        traininds, testinds = get_train_test_indices(Yaudio)
        X_train, Y_train, X_test, Y_test = get_train_test_sets(
            X, Y, traininds, testinds)
        results.append(feat_learner.classify(X_train, Y_train, X_test, Y_test))
    if not results:
        return pd.DataFrame()
    # One concat at the end avoids re-copying the accumulated frame per file.
    return pd.concat(results, axis=0, ignore_index=True)


def plot_CF(CF, labels=None, figurename=None):
    """Draw a confusion matrix as a grey-scale image with class tick labels.

    Parameters
    ----------
    CF : array-like, 2-D
        Confusion matrix whose rows/columns are ordered like ``labels``.
    labels : array-like, optional
        Class names for both axes. Defaults to the row positions
        ``0 .. len(CF) - 1``. The caller's sequence is never modified.
    figurename : str, optional
        If given, the figure is saved to this path with a tight bounding box.
    """
    # NOTE(review): `plt` (matplotlib.pyplot) is not imported in this file's
    # import block -- add `import matplotlib.pyplot as plt` at the top.
    if labels is None:
        labels = range(len(CF))
    # Work on an object-dtype copy. This leaves the caller's array untouched
    # and makes `==` elementwise even when a plain list was passed (on a list,
    # `labels == '...'` is False and labels[False] would overwrite labels[0]).
    labels = np.array(labels, dtype=object)
    # Abbreviate this long country name for the tick labels.
    labels[labels == 'United States of America'] = 'United States Amer.'
    plt.imshow(CF, cmap="Greys")
    plt.xticks(range(len(labels)), labels, rotation='vertical', fontsize=4)
    plt.yticks(range(len(labels)), labels, fontsize=4)
    if figurename is not None:
        plt.savefig(figurename, bbox_inches='tight')


def confusion_matrix(X_train, Y_train, X_test, Y_test, saveCF=False, plots=False):
    """Evaluate the Transformer's LDA model and build its confusion matrix.

    The model ``Transformer.modelLDA`` is run through
    ``Transformer.classification_accuracy`` on the given train/test data. The
    confusion matrix is then computed over the sorted unique test labels.

    Parameters
    ----------
    X_train, Y_train, X_test, Y_test : array-like
        Train/test features and labels.
    saveCF : bool
        If True, write the labels to ``data/CFlabels.csv`` and the matrix to
        ``data/CF.csv``.
    plots : bool
        If True, render the matrix with ``plot_CF`` to
        ``data/conf_matrix.pdf``.

    Returns
    -------
    accuracy, predictions
        As returned by ``classification_accuracy``.
    """
    learner = util_feature_learning.Transformer()
    outcome = learner.classification_accuracy(
        X_train, Y_train, X_test, Y_test, model=learner.modelLDA)
    accuracy, predictions = outcome
    class_names = np.unique(Y_test)  # TODO: countries in geographical proximity
    cf_matrix = metrics.confusion_matrix(Y_test, predictions, labels=class_names)
    if saveCF:
        np.savetxt('data/CFlabels.csv', class_names, fmt='%s')
        np.savetxt('data/CF.csv', cf_matrix, fmt='%10.5f')
    if plots:
        plot_CF(cf_matrix, labels=class_names,
                figurename='data/conf_matrix.pdf')
    return accuracy, predictions


if __name__ == '__main__':
    # Evaluate every feature file, then inspect the best-scoring one.
    df_results = classify_for_filenames(file_list=FILENAMES)
    # Positional column 1 is the score being maximized. A DataFrame cannot be
    # indexed as df[:, 1], so select it with .iloc.
    # NOTE(review): confirm column 1 is the accuracy column of
    # Transformer.classify's output.
    scores = np.asarray(df_results.iloc[:, 1])
    max_i = int(np.argmax(scores))
    # Rows are stacked file by file. This assumes each file contributes the
    # same number of rows (one per classifier -- 4 when this was written).
    # The file index is therefore the quotient, not the remainder, of the
    # row index by the rows per file.
    rows_per_file = len(df_results) // len(FILENAMES)
    feat_learning_i = max_i // rows_per_file
    filename = FILENAMES[feat_learning_i]
    X, Y, Yaudio = load_data_from_pickle(filename)
    traininds, testinds = get_train_test_indices(Yaudio)
    X_train, Y_train, X_test, Y_test = get_train_test_sets(X, Y, traininds, testinds)
    confusion_matrix(X_train, Y_train, X_test, Y_test, saveCF=True, plots=True)