diff tests/test_load_dataset.py @ 18:ed109218dd4b branch-tests

rename result scripts and more tests
author Maria Panteli
date Tue, 12 Sep 2017 23:18:19 +0100
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/test_load_dataset.py	Tue Sep 12 23:18:19 2017 +0100
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Sep  1 19:11:52 2017
+
+@author: mariapanteli
+"""
+
+import pytest
+
+import numpy as np
+
+import scripts.load_dataset as load_dataset
+
+
+def test_get_train_val_test_idx():
+    X = np.arange(10)
+    Y = np.concatenate([np.ones(5), np.zeros(5)])
+    train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1)
+    assert len(train[0]) == 6 and len(val[0]) == 2 and len(test[0]) == 2
+
+
+def test_get_train_val_test_idx_stratify():
+    X = np.arange(10)
+    Y = np.concatenate([np.ones(5), np.zeros(5)])
+    train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1)
+    assert np.array_equal(np.unique(train[1]), np.unique(val[1]))
+
+
+def test_subset_labels():
+    Y = np.concatenate([np.ones(5), 2*np.ones(10), 3*np.ones(100)])
+    subset_idx = load_dataset.subset_labels(Y, seed=1)
+    subset_idx = np.sort(subset_idx)
+    subset_idx_true = np.arange(5, 115)
+    assert np.array_equal(subset_idx, subset_idx_true)
+
+    
\ No newline at end of file