comparison tests/test_classification.py @ 93:f9513664fe42 branch-tests

notebooks update
author mpanteli <m.x.panteli@gmail.com>
date Mon, 02 Oct 2017 18:58:39 +0100
parents d118b6ca8370
children
comparison
equal deleted inserted replaced
92:ce525367960e 93:f9513664fe42
17 X = np.random.randn(100, 3) 17 X = np.random.randn(100, 3)
18 # create 2 classes by shifting the entries of half the samples 18 # create 2 classes by shifting the entries of half the samples
19 X[-50:, :] = X[-50:, :] + 10 19 X[-50:, :] = X[-50:, :] + 10
20 Y = np.concatenate([np.repeat('a', 50), np.repeat('b', 50)]) 20 Y = np.concatenate([np.repeat('a', 50), np.repeat('b', 50)])
21 X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.6, random_state=1, stratify=Y) 21 X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.6, random_state=1, stratify=Y)
22 accuracy, _ = classification.confusion_matrix(X_train, Y_train, X_test, Y_test) 22 accuracy, _, _ = classification.confusion_matrix(X_train, Y_train, X_test, Y_test)
23 # expect perfect accuracy for this 'easy' dataset 23 # expect perfect accuracy for this 'easy' dataset
24 assert accuracy == 1.0 24 assert accuracy == 1.0
25 25