annotate tests/test_load_dataset.py @ 105:edd82eb89b4b branch-tests tip

Merge
author Maria Panteli
date Sun, 15 Oct 2017 13:36:59 +0100
parents ed109218dd4b
children
rev   line source
Maria@18 1 # -*- coding: utf-8 -*-
Maria@18 2 """
Maria@18 3 Created on Fri Sep 1 19:11:52 2017
Maria@18 4
Maria@18 5 @author: mariapanteli
Maria@18 6 """
Maria@18 7
Maria@18 8 import pytest
Maria@18 9
Maria@18 10 import numpy as np
Maria@18 11
Maria@18 12 import scripts.load_dataset as load_dataset
Maria@18 13
Maria@18 14
def test_get_train_val_test_idx():
    """A 10-sample dataset is split 6/2/2 into train/val/test."""
    features = np.arange(10)
    labels = np.concatenate([np.ones(5), np.zeros(5)])
    splits = load_dataset.get_train_val_test_idx(features, labels, seed=1)
    train, val, test = splits
    # Element 0 of each split is measured, as in the original indexing.
    sizes = (len(train[0]), len(val[0]), len(test[0]))
    assert sizes == (6, 2, 2)
Maria@18 20
Maria@18 21
def test_get_train_val_test_idx_stratify():
    """Every split (train, val and test) contains every class label of Y.

    The test name indicates that ``get_train_val_test_idx`` stratifies on
    ``Y``. With two balanced classes, even the 2-sample val and test splits
    must therefore each contain both labels. Previously only val was compared
    against train, so a test split missing a class went undetected.
    """
    X = np.arange(10)
    Y = np.concatenate([np.ones(5), np.zeros(5)])
    train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1)
    # Element 1 of each split is treated as its labels, as in the original
    # assertion. This presumes (X, Y) pairs; confirm against load_dataset.
    train_classes = np.unique(train[1])
    assert np.array_equal(train_classes, np.unique(Y))
    assert np.array_equal(np.unique(val[1]), train_classes)
    assert np.array_equal(np.unique(test[1]), train_classes)
Maria@18 27
Maria@18 28
def test_subset_labels():
    """Check ``subset_labels`` on classes of sizes 5, 10 and 100.

    The expected result is indices 5..114. That means every sample of the
    5-sample class (label 1, indices 0-4) is dropped, and every sample of
    labels 2 and 3 is kept.
    """
    Y = np.concatenate([np.ones(5), 2 * np.ones(10), 3 * np.ones(100)])
    expected = np.arange(5, 115)
    # Sort so the comparison does not depend on the order of returned indices.
    selected = np.sort(load_dataset.subset_labels(Y, seed=1))
    assert np.array_equal(selected, expected)
Maria@18 35
Maria@18 36