Mercurial > hg > plosone_underreview
view tests/test_load_dataset.py @ 105:edd82eb89b4b branch-tests tip
Merge
author | Maria Panteli |
---|---|
date | Sun, 15 Oct 2017 13:36:59 +0100 |
parents | ed109218dd4b |
children |
line wrap: on
line source
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 1 19:11:52 2017

@author: mariapanteli

Unit tests for the dataset helpers in ``scripts.load_dataset``
(``get_train_val_test_idx`` and ``subset_labels``).
"""

import pytest
import numpy as np

import scripts.load_dataset as load_dataset


def _toy_dataset():
    """Return a 10-sample toy dataset with two balanced classes.

    Returns:
        X: ``np.arange(10)``, used as the sample values.
        Y: five labels ``1.0`` followed by five labels ``0.0``.
    """
    X = np.arange(10)
    Y = np.concatenate([np.ones(5), np.zeros(5)])
    return X, Y


def test_get_train_val_test_idx():
    """Splitting 10 samples yields 6 train, 2 validation and 2 test samples."""
    X, Y = _toy_dataset()
    train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1)
    # NOTE(review): each split is assumed to be indexable as [samples, labels]
    # (index 0 = samples, index 1 = labels) -- confirm against load_dataset.
    # Separate asserts so a failure pinpoints which split has the wrong size.
    assert len(train[0]) == 6
    assert len(val[0]) == 2
    assert len(test[0]) == 2


def test_get_train_val_test_idx_stratify():
    """Every split (train, val, test) contains every class present in Y."""
    X, Y = _toy_dataset()
    train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1)
    all_labels = np.unique(Y)
    # Compare against the full label set rather than train-vs-val only, so a
    # split that collapsed to a single class (or a skipped test split) fails.
    assert np.array_equal(np.unique(train[1]), all_labels)
    assert np.array_equal(np.unique(val[1]), all_labels)
    assert np.array_equal(np.unique(test[1]), all_labels)


def test_subset_labels():
    """Rare labels are dropped; the remaining labels' indices are all kept."""
    # Label 1 at indices 0-4 (5 samples), label 2 at 5-14 (10 samples),
    # label 3 at 15-114 (100 samples).
    Y = np.concatenate([np.ones(5), 2 * np.ones(10), 3 * np.ones(100)])
    subset_idx = np.sort(load_dataset.subset_labels(Y, seed=1))
    # Expected: label 1 removed, labels 2 and 3 kept in full (indices 5-114).
    # NOTE(review): presumably this reflects subset_labels' default per-label
    # min/max sample thresholds -- confirm against its signature.
    assert np.array_equal(subset_idx, np.arange(5, 115))