# -*- coding: utf-8 -*-
"""
Tests for ``scripts.load_dataset``: train/val/test splitting and label subsetting.

Created on Fri Sep 1 19:11:52 2017

@author: mariapanteli
"""

import pytest

import numpy as np

import scripts.load_dataset as load_dataset


def _balanced_binary_data():
    """Return a 10-sample toy dataset with five samples in each of two classes.

    Returns:
        tuple: ``(X, Y)`` where ``X`` is ``np.arange(10)`` and ``Y`` holds
        five ``1.0`` labels followed by five ``0.0`` labels.
    """
    X = np.arange(10)
    Y = np.concatenate([np.ones(5), np.zeros(5)])
    return X, Y


def test_get_train_val_test_idx():
    """Splitting 10 samples with seed=1 gives 6 train, 2 val and 2 test items."""
    X, Y = _balanced_binary_data()
    train, val, test = load_dataset.get_train_val_test_idx(X, Y, seed=1)
    # Element [0] of each returned split is what gets counted here; presumably
    # it is the feature/index part of an (X, Y) pair -- TODO confirm against
    # load_dataset.get_train_val_test_idx.
    # One assert per split so a failure pinpoints the offending split size.
    assert len(train[0]) == 6
    assert len(val[0]) == 2
    assert len(test[0]) == 2


def test_get_train_val_test_idx_stratify():
    """The train and validation splits contain the same set of label values."""
    X, Y = _balanced_binary_data()
    train, val, _ = load_dataset.get_train_val_test_idx(X, Y, seed=1)
    # Element [1] of each split is compared as labels; presumably the Y part
    # of an (X, Y) pair -- TODO confirm.
    assert np.array_equal(np.unique(train[1]), np.unique(val[1]))


def test_subset_labels():
    """``subset_labels`` keeps every index of labels 2 and 3 and none of label 1.

    The input has 5 samples of label 1 (indices 0-4), 10 of label 2
    (indices 5-14) and 100 of label 3 (indices 15-114). With seed=1 the
    expected sorted subset is exactly indices 5..114.
    """
    Y = np.concatenate([np.ones(5), 2 * np.ones(10), 3 * np.ones(100)])
    subset_idx = load_dataset.subset_labels(Y, seed=1)
    # Order of the returned indices is not asserted; compare after sorting.
    subset_idx = np.sort(subset_idx)
    subset_idx_true = np.arange(5, 115)
    assert np.array_equal(subset_idx, subset_idx_true)