Mercurial > hg > plosone_underreview
diff tests/test_load_dataset.py @ 18:ed109218dd4b branch-tests
rename result scripts and more tests
author | Maria Panteli |
---|---|
date | Tue, 12 Sep 2017 23:18:19 +0100 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tests/test_load_dataset.py Tue Sep 12 23:18:19 2017 +0100 @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Sep 1 19:11:52 2017 + +@author: mariapanteli +""" + +import pytest + +import numpy as np + +import scripts.load_dataset as load_dataset + + +def test_get_train_val_test_idx(): + X = np.arange(10) + Y = np.concatenate([np.ones(5), np.zeros(5)]) + train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1) + assert len(train[0]) == 6 and len(val[0]) == 2 and len(test[0]) == 2 + + +def test_get_train_val_test_idx_stratify(): + X = np.arange(10) + Y = np.concatenate([np.ones(5), np.zeros(5)]) + train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1) + assert np.array_equal(np.unique(train[1]), np.unique(val[1])) + + +def test_subset_labels(): + Y = np.concatenate([np.ones(5), 2*np.ones(10), 3*np.ones(100)]) + subset_idx = load_dataset.subset_labels(Y, seed=1) + subset_idx = np.sort(subset_idx) + subset_idx_true = np.arange(5, 115) + assert np.array_equal(subset_idx, subset_idx_true) + + \ No newline at end of file