Mercurial > hg > plosone_underreview
annotate tests/test_load_dataset.py @ 105:edd82eb89b4b branch-tests tip
Merge
author | Maria Panteli |
---|---|
date | Sun, 15 Oct 2017 13:36:59 +0100 |
parents | ed109218dd4b |
children |
rev | line source |
---|---|
Maria@18 | 1 # -*- coding: utf-8 -*- |
Maria@18 | 2 """ |
Maria@18 | 3 Created on Fri Sep 1 19:11:52 2017 |
Maria@18 | 4 |
Maria@18 | 5 @author: mariapanteli |
Maria@18 | 6 """ |
Maria@18 | 7 |
Maria@18 | 8 import pytest |
Maria@18 | 9 |
Maria@18 | 10 import numpy as np |
Maria@18 | 11 |
Maria@18 | 12 import scripts.load_dataset as load_dataset |
Maria@18 | 13 |
Maria@18 | 14 |
def test_get_train_val_test_idx():
    """A 10-sample, two-class dataset is partitioned 6/2/2 into train/val/test.

    Only element ``[0]`` of each returned split is measured, matching how the
    splits are indexed elsewhere in this module.
    """
    features = np.arange(10)
    labels = np.concatenate([np.ones(5), np.zeros(5)])
    train, val, test = load_dataset.get_train_val_test_idx(features, labels, seed=1)
    split_sizes = tuple(len(part[0]) for part in (train, val, test))
    assert split_sizes == (6, 2, 2)
Maria@18 | 20 |
Maria@18 | 21 |
def test_get_train_val_test_idx_stratify():
    """Every split (train, val and test) should contain every class label.

    ``Y`` is balanced (5 ones, 5 zeros). A stratified 6/2/2 split must
    therefore keep both classes in all three partitions.

    The previous version only compared train labels against val labels. It
    could pass even if both splits lacked the same class, and it never looked
    at the test split.
    """
    X = np.arange(10)
    Y = np.concatenate([np.ones(5), np.zeros(5)])
    train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1)
    classes = np.unique(Y)
    # Element [1] of each split is treated as its labels, as in the original
    # test. Presumably each split is (X_part, Y_part); confirm against
    # scripts.load_dataset.
    assert np.array_equal(np.unique(train[1]), classes)
    assert np.array_equal(np.unique(val[1]), classes)
    assert np.array_equal(np.unique(test[1]), classes)
Maria@18 | 27 |
Maria@18 | 28 |
def test_subset_labels():
    """subset_labels keeps exactly indices 5..114 for this label layout.

    Layout of the labels:
      - class 1 at indices 0-4 (5 samples)
      - class 2 at indices 5-14 (10 samples)
      - class 3 at indices 15-114 (100 samples)

    With seed=1 the expected result is every index of classes 2 and 3 and
    none of class 1. Results are sorted first, so their order is ignored.
    """
    labels = np.concatenate([np.ones(5), 2 * np.ones(10), 3 * np.ones(100)])
    expected = np.arange(5, 115)
    selected = np.sort(load_dataset.subset_labels(labels, seed=1))
    assert np.array_equal(selected, expected)
Maria@18 | 35 |
Maria@18 | 36 |