changeset 20:65b9330afdd8 branch-tests

return train/test sets in load_dataset
author Maria Panteli <m.x.panteli@gmail.com>
date Wed, 13 Sep 2017 12:53:57 +0100
parents 0bba6f63f4fd
children 4aa0ce25fabd 852d4377f6ca
files scripts/load_dataset.py scripts/outliers.py tests/test_outliers.py
diffstat 3 files changed, 21 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/load_dataset.py	Wed Sep 13 12:09:55 2017 +0100
+++ b/scripts/load_dataset.py	Wed Sep 13 12:53:57 2017 +0100
@@ -152,17 +152,21 @@
     X_val, Y_val, Y_audio_val = extract_features(df.iloc[val_set[0], :], win2sec=WIN_SIZE)   
     X_test, Y_test, Y_audio_test = extract_features(df.iloc[test_set[0], :], win2sec=WIN_SIZE)
    
+    train = [X_train, Y_train, Y_audio_train]
+    val = [X_val, Y_val, Y_audio_val]
+    test = [X_test, Y_test, Y_audio_test]
     if write_output:
         with open(OUTPUT_FILES[0], 'wb') as f:
-            pickle.dump([X_train, Y_train, Y_audio_train], f)            
+            pickle.dump(train, f)            
         with open(OUTPUT_FILES[1], 'wb') as f:
-            pickle.dump([X_val, Y_val, Y_audio_val], f)
+            pickle.dump(val, f)
         with open(OUTPUT_FILES[2], 'wb') as f:
-            pickle.dump([X_test, Y_test, Y_audio_test], f)
+            pickle.dump(test, f)
+    return train, val, test
 
 
 if __name__ == '__main__':
     # load dataset
     df = sample_dataset(csv_file=METADATA_FILE)
-    features_for_train_test_sets(df, write_output=True)
+    train, val, test = features_for_train_test_sets(df, write_output=True)
 
--- a/scripts/outliers.py	Wed Sep 13 12:09:55 2017 +0100
+++ b/scripts/outliers.py	Wed Sep 13 12:53:57 2017 +0100
@@ -65,7 +65,7 @@
     return ddf
 
 
-def clusters_metadata(df, cl_pred, out_file=None):
+def print_clusters_metadata(df, cl_pred, out_file=None):
     def get_top_N_counts(labels, N=3):
         ulab, ucount = np.unique(labels, return_counts=True)
         inds = np.argsort(ucount)
@@ -116,7 +116,7 @@
     centroids = cluster_model.cluster_centers_
     cl_pred = cluster_model.predict(X)
     ddf['Clusters'] = cl_pred
-    clusters_metadata(ddf, cl_pred)
+    print_clusters_metadata(ddf, cl_pred)
 
     # how similar are the cultures and which ones seem to be global outliers
     cluster_freq = utils.get_cluster_freq_linear(X, Y, centroids)
--- a/tests/test_outliers.py	Wed Sep 13 12:09:55 2017 +0100
+++ b/tests/test_outliers.py	Wed Sep 13 12:53:57 2017 +0100
@@ -9,6 +9,8 @@
 
 import numpy as np
 import pandas as pd
+import pickle
+import os
 
 import scripts.outliers as outliers
 
@@ -29,4 +31,12 @@
 
 
 def test_get_outliers_df():
-    assert True
\ No newline at end of file
+    np.random.seed(1)
+    X = np.random.randn(100, 3)
+    # create outliers by shifting the entries of the last 5 samples
+    X[-5:, :] = X[-5:, :] + 10
+    Y = np.concatenate([np.repeat('a', 95), np.repeat('b', 5)])
+    df, threshold, MD = outliers.get_outliers_df(X, Y)
+    # expect that items from country 'b' are detected as outliers
+    assert np.array_equal(df['Outliers'].get_values(), np.array([0., 1.0]))
+