comparison tests/test_load_dataset.py @ 18:ed109218dd4b branch-tests

rename result scripts and more tests
author Maria Panteli
date Tue, 12 Sep 2017 23:18:19 +0100
parents
children
comparison
equal deleted inserted replaced
17:2e487b9c0a7b 18:ed109218dd4b
1 # -*- coding: utf-8 -*-
2 """
3 Created on Fri Sep 1 19:11:52 2017
4
5 @author: mariapanteli
6 """
7
8 import pytest
9
10 import numpy as np
11
12 import scripts.load_dataset as load_dataset
13
14
def test_get_train_val_test_idx():
    """Check the split sizes for a 10-sample dataset (expected 6 / 2 / 2).

    Element [0] of each returned split is measured. Presumably it is the
    feature subset drawn from X (TODO confirm against
    load_dataset.get_train_val_test_idx).
    """
    X = np.arange(10)
    Y = np.concatenate([np.ones(5), np.zeros(5)])
    train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1)
    # One assert per split: a single `a and b and c` assert only reports
    # "False" on failure, hiding which split had the wrong size.
    assert len(train[0]) == 6
    assert len(val[0]) == 2
    assert len(test[0]) == 2
20
21
def test_get_train_val_test_idx_stratify():
    """Check that the train/val/test split preserves the set of class labels.

    Y is balanced (five 1.0s and five 0.0s), so a stratified split should
    put both classes into every split. Element [1] of each split is compared
    by label value, so it must hold labels rather than indices; disjoint
    index sets could never have equal np.unique results.
    """
    X = np.arange(10)
    Y = np.concatenate([np.ones(5), np.zeros(5)])
    train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1)
    all_labels = np.unique(Y)
    train_labels = np.unique(train[1])
    # Guard against every split sharing the same *missing* class.
    assert np.array_equal(train_labels, all_labels)
    assert np.array_equal(train_labels, np.unique(val[1]))
    # Previously unchecked: stratification must also hold for the test split.
    assert np.array_equal(train_labels, np.unique(test[1]))
27
28
def test_subset_labels():
    """Check which samples subset_labels keeps for an imbalanced label vector.

    Y holds label 1 five times, label 2 ten times and label 3 a hundred
    times, in that order. After sorting, the returned indices must be
    exactly positions 5..114: every sample of labels 2 and 3 and none of
    label 1. Presumably label 1 is dropped for having too few samples
    (TODO confirm the threshold in load_dataset.subset_labels).
    """
    Y = np.repeat([1.0, 2.0, 3.0], [5, 10, 100])
    expected = np.arange(5, 5 + 10 + 100)
    selected = np.sort(load_dataset.subset_labels(Y, seed=1))
    assert np.array_equal(selected, expected)
35
36