Mercurial > hg > plosone_underreview
comparison tests/test_load_dataset.py @ 18:ed109218dd4b branch-tests
rename result scripts and more tests
author | Maria Panteli |
---|---|
date | Tue, 12 Sep 2017 23:18:19 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
17:2e487b9c0a7b | 18:ed109218dd4b |
---|---|
1 # -*- coding: utf-8 -*- | |
2 """ | |
3 Created on Fri Sep 1 19:11:52 2017 | |
4 | |
5 @author: mariapanteli | |
6 """ | |
7 | |
8 import pytest | |
9 | |
10 import numpy as np | |
11 | |
12 import scripts.load_dataset as load_dataset | |
13 | |
14 | |
def test_get_train_val_test_idx():
    """Split 10 samples into train/val/test sets of sizes 6/2/2.

    ``get_train_val_test_idx`` is expected to return three splits; element 0
    of each split is checked here. Because ``X`` is ``np.arange(10)``, sample
    values and sample indices coincide. The coverage check below therefore
    holds whether element 0 holds X values or indices into X.
    """
    X = np.arange(10)
    Y = np.concatenate([np.ones(5), np.zeros(5)])
    train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1)
    # Separate asserts so a failure reports which split has the wrong size
    # (a single ``a and b and c`` assert hides that from pytest's report).
    assert len(train[0]) == 6
    assert len(val[0]) == 2
    assert len(test[0]) == 2
    # 6 + 2 + 2 == len(X): together the splits should cover every sample
    # exactly once (no overlap, nothing dropped).
    all_samples = np.sort(np.concatenate([train[0], val[0], test[0]]))
    assert np.array_equal(all_samples, X)
20 | |
21 | |
def test_get_train_val_test_idx_stratify():
    """Every split should contain every class label present in ``Y``.

    ``Y`` is balanced (five 1s and five 0s), so a stratified split is expected
    to give each of the train, val and test sets both labels. Element 1 of each
    returned split is taken to be that split's labels.
    """
    X = np.arange(10)
    Y = np.concatenate([np.ones(5), np.zeros(5)])
    train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1)
    classes = np.unique(Y)
    # Check all three splits, including test. Comparing only train against
    # val would miss a test split that lost a class. The split name is the
    # assertion message so a failure says which split is unbalanced.
    for name, split in (("train", train), ("val", val), ("test", test)):
        assert np.array_equal(np.unique(split[1]), classes), name
27 | |
28 | |
def test_subset_labels():
    """Check which sample indices ``subset_labels`` keeps for a skewed label set.

    ``Y`` holds five samples of label 1 (indices 0-4), ten of label 2
    (indices 5-14) and one hundred of label 3 (indices 15-114). With
    ``seed=1`` the expected subset is every sample of labels 2 and 3 and none
    of label 1. Presumably label 1 falls below a minimum per-label count —
    confirm against ``subset_labels``.
    """
    labels = np.concatenate([np.ones(5), 2*np.ones(10), 3*np.ones(100)])
    expected = np.arange(5, 115)
    # The returned order is not part of the contract being tested, so sort first.
    observed = np.sort(load_dataset.subset_labels(labels, seed=1))
    assert np.array_equal(observed, expected)
35 | |
36 |