comparison scripts/classification.py @ 47:081ff4ea7da7 branch-tests

sensitivity experiment split
author Maria Panteli <m.x.panteli@gmail.com>
date Fri, 15 Sep 2017 17:33:14 +0100
parents ef829b187308
children 08b9327f1935
comparison
equal deleted inserted replaced
46:3ed4c6af5a93 47:081ff4ea7da7
def classify_for_filenames(file_list=FILENAMES):
    """Classify the data of every file in ``file_list`` and stack the results.

    For each filename the data is loaded with ``load_data_from_pickle``,
    train/test indices are derived from ``Yaudio`` via
    ``get_train_test_indices``, the train/test sets are built with
    ``get_train_test_sets`` and passed to ``Transformer.classify``.

    Returns the per-file results concatenated along axis 0 with the index
    renumbered; an empty ``file_list`` yields an empty DataFrame.
    """
    learner = util_feature_learning.Transformer()
    # Seed with an empty frame so that no input files still gives a DataFrame.
    per_file_frames = [pd.DataFrame()]
    for pickle_name in file_list:
        X, Y, Yaudio = load_data_from_pickle(pickle_name)
        train_idx, test_idx = get_train_test_indices(Yaudio)
        X_train, Y_train, X_test, Y_test = get_train_test_sets(
            X, Y, train_idx, test_idx)
        per_file_frames.append(
            learner.classify(X_train, Y_train, X_test, Y_test))
    return pd.concat(per_file_frames, axis=0, ignore_index=True)
53
54
def classify_each_feature(X_train, Y_train, X_test, Y_test, feat_learner=None):
    """Classify with all features together, then with each feature subset alone.

    Parameters
    ----------
    X_train, X_test :
        Feature matrices; they are column-sliced as ``X[:, inds]`` and
        ``X_train.shape[1]`` is taken as the number of feature dimensions.
    Y_train, Y_test :
        Targets, passed unchanged to ``feat_learner.classify``.
    feat_learner : optional
        Object exposing ``classify(X_train, Y_train, X_test, Y_test)``.
        When omitted a new ``util_feature_learning.Transformer()`` is
        created, so the function no longer depends on a ``feat_learner``
        name existing outside its own scope.

    Returns
    -------
    The result for all features followed by one result per feature subset,
    concatenated along axis 1 with the resulting labels renumbered.
    """
    if feat_learner is None:
        feat_learner = util_feature_learning.Transformer()
    n_dim = X_train.shape[1]
    # Only the column indices are used here; the feature labels are not.
    _, feat_inds = map_and_average.get_feat_inds(n_dim=n_dim)
    # first the classification with all features together
    df_results = feat_learner.classify(X_train, Y_train, X_test, Y_test)
    # then append for each feature separately
    for inds in feat_inds:
        df_result = feat_learner.classify(X_train[:, inds], Y_train,
                                          X_test[:, inds], Y_test)
        # NOTE(review): results are joined side by side (axis=1) here, whereas
        # classify_for_filenames stacks along axis=0 -- confirm this is intended.
        df_results = pd.concat([df_results, df_result], axis=1,
                               ignore_index=True)
    return df_results
53 67
54 68
55 def plot_CF(CF, labels=None, figurename=None): 69 def plot_CF(CF, labels=None, figurename=None):
56 labels[labels=='United States of America'] = 'United States Amer.' 70 labels[labels=='United States of America'] = 'United States Amer.'
57 plt.imshow(CF, cmap="Greys") 71 plt.imshow(CF, cmap="Greys")