Mercurial > hg > plosone_underreview
comparison scripts/load_dataset.py @ 13:98718fdd8326 branch-tests
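edits in the core functions: define `subset_labels` and `get_train_val_test_idx` locally in load_dataset.py (using sklearn's `train_test_split`) instead of importing them from `util_dataset`, change `WIN_SIZE` from 2 to 8, and remove commented-out output paths and test code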
author | Maria Panteli <m.x.panteli@gmail.com> |
---|---|
date | Tue, 12 Sep 2017 18:03:47 +0100 |
parents | e50c63cf96be |
children | 9847b954c217 |
comparison
equal
deleted
inserted
replaced
10:8e897e82af51 | 13:98718fdd8326 |
---|---|
6 """ | 6 """ |
7 | 7 |
8 import numpy as np | 8 import numpy as np |
9 import pandas as pd | 9 import pandas as pd |
10 import pickle | 10 import pickle |
11 from sklearn.model_selection import train_test_split | |
11 | 12 |
12 import load_features | 13 import load_features |
13 import util_dataset | |
14 import util_filter_dataset | 14 import util_filter_dataset |
15 | 15 |
16 | 16 |
17 #METADATA_FILE = 'sample_dataset/metadata.csv' | 17 #METADATA_FILE = 'sample_dataset/metadata.csv' |
18 #OUTPUT_FILES = ['sample_dataset/train_data.pickle', 'sample_dataset/val_data.pickle', 'sample_dataset/test_data.pickle'] | 18 #OUTPUT_FILES = ['sample_dataset/train_data.pickle', 'sample_dataset/val_data.pickle', 'sample_dataset/test_data.pickle'] |
19 WIN_SIZE = 2 | 19 WIN_SIZE = 8 |
20 METADATA_FILE = 'data/metadata_BLSM_language_all.csv' | 20 METADATA_FILE = 'data/metadata_BLSM_language_all.csv' |
21 #OUTPUT_FILES = ['/import/c4dm-04/mariap/train_data_cf.pickle', '/import/c4dm-04/mariap/val_data_cf.pickle', '/import/c4dm-04/mariap/test_data_cf.pickle'] | |
22 #OUTPUT_FILES = ['/import/c4dm-04/mariap/train_data_cf_4.pickle', '/import/c4dm-04/mariap/val_data_cf_4.pickle', '/import/c4dm-04/mariap/test_data_cf_4.pickle'] | |
23 OUTPUT_FILES = ['/import/c4dm-04/mariap/train_data_melodia_'+str(WIN_SIZE)+'.pickle', | 21 OUTPUT_FILES = ['/import/c4dm-04/mariap/train_data_melodia_'+str(WIN_SIZE)+'.pickle', |
24 '/import/c4dm-04/mariap/val_data_melodia_'+str(WIN_SIZE)+'.pickle', | 22 '/import/c4dm-04/mariap/val_data_melodia_'+str(WIN_SIZE)+'.pickle', |
25 '/import/c4dm-04/mariap/test_data_melodia_'+str(WIN_SIZE)+'.pickle'] | 23 '/import/c4dm-04/mariap/test_data_melodia_'+str(WIN_SIZE)+'.pickle'] |
24 | |
25 | |
26 def get_train_val_test_idx(X, Y, seed=None): | |
27 """ Split in train, validation, test sets. | |
28 | |
29 Parameters | |
30 ---------- | |
31 X : np.array | |
32 Data or indices. | |
33 Y : np.array | |
34 Class labels for data in X. | |
35 seed: int | |
36 Random seed. | |
37 Returns | |
38 ------- | |
39 (X_train, Y_train) : tuple | |
40 Data X and labels y for the train set | |
41 (X_val, Y_val) : tuple | |
42 Data X and labels y for the validation set | |
43 (X_test, Y_test) : tuple | |
44 Data X and labels y for the test set | |
45 | |
46 """ | |
47 X_train, X_val_test, Y_train, Y_val_test = train_test_split(X, Y, train_size=0.6, random_state=seed, stratify=Y) | |
48 X_val, X_test, Y_val, Y_test = train_test_split(X_val_test, Y_val_test, train_size=0.5, random_state=seed, stratify=Y_val_test) | |
49 return (X_train, Y_train), (X_val, Y_val), (X_test, Y_test) | |
50 | |
51 | |
52 def subset_labels(Y, N_min=10, N_max=100, seed=None): | |
53 """ Subset dataset to contain minimum N_min and maximum N_max instances | |
54 per class. Return indices for this subset. | |
55 | |
56 Parameters | |
57 ---------- | |
58 Y : np.array | |
59 Class labels | |
60 N_min : int | |
61 Minimum instances per class | |
62 N_max : int | |
63 Maximum instances per class | |
64 seed: int | |
65 Random seed. | |
66 | |
67 Returns | |
68 ------- | |
69 subset_idx : np.array | |
70 Indices for a subset with classes of size bounded by N_min, N_max | |
71 | |
72 """ | |
73 np.random.seed(seed=seed) | |
74 subset_idx = [] | |
75 labels = np.unique(Y) | |
76 for label in labels: | |
77 label_idx = np.where(Y==label)[0] | |
78 counts = len(label_idx) | |
79 if counts>=N_max: | |
80 subset_idx.append(np.random.choice(label_idx, N_max, replace=False)) | |
81 elif counts>=N_min and counts<N_max: | |
82 subset_idx.append(label_idx) | |
83 else: | |
84 # not enough samples for this class, skip | |
85 continue | |
86 if len(subset_idx)>0: | |
87 subset_idx = np.concatenate(subset_idx, axis=0) | |
88 return subset_idx | |
89 | |
26 | 90 |
27 def extract_features(df, win2sec=8.0): | 91 def extract_features(df, win2sec=8.0): |
28 """Extract features from melspec and chroma. | 92 """Extract features from melspec and chroma. |
29 | 93 |
30 Parameters | 94 Parameters |
54 | 118 |
55 if __name__ == '__main__': | 119 if __name__ == '__main__': |
56 # load dataset | 120 # load dataset |
57 df = pd.read_csv(METADATA_FILE) | 121 df = pd.read_csv(METADATA_FILE) |
58 df = util_filter_dataset.remove_missing_data(df) | 122 df = util_filter_dataset.remove_missing_data(df) |
59 subset_idx = util_dataset.subset_labels(df['Country'].get_values()) | 123 subset_idx = subset_labels(df['Country'].get_values()) |
60 df = df.iloc[subset_idx, :] | 124 df = df.iloc[subset_idx, :] |
61 X, Y = np.arange(len(df)), df['Country'].get_values() | 125 X, Y = np.arange(len(df)), df['Country'].get_values() |
62 | 126 |
63 # split in train, val, test set | 127 # split in train, val, test set |
64 train_set, val_set, test_set = util_dataset.get_train_val_test_idx(X, Y) | 128 train_set, val_set, test_set = get_train_val_test_idx(X, Y) |
65 | 129 |
66 # extract features and write output | 130 # extract features and write output |
67 X_train, Y_train, Y_audio_train = extract_features(df.iloc[train_set[0], :], win2sec=WIN_SIZE) | 131 X_train, Y_train, Y_audio_train = extract_features(df.iloc[train_set[0], :], win2sec=WIN_SIZE) |
68 with open(OUTPUT_FILES[0], 'wb') as f: | 132 with open(OUTPUT_FILES[0], 'wb') as f: |
69 pickle.dump([X_train, Y_train, Y_audio_train], f) | 133 pickle.dump([X_train, Y_train, Y_audio_train], f) |
74 | 138 |
75 X_test, Y_test, Y_audio_test = extract_features(df.iloc[test_set[0], :], win2sec=WIN_SIZE) | 139 X_test, Y_test, Y_audio_test = extract_features(df.iloc[test_set[0], :], win2sec=WIN_SIZE) |
76 with open(OUTPUT_FILES[2], 'wb') as f: | 140 with open(OUTPUT_FILES[2], 'wb') as f: |
77 pickle.dump([X_test, Y_test, Y_audio_test], f) | 141 pickle.dump([X_test, Y_test, Y_audio_test], f) |
78 | 142 |
79 #out_file = '/import/c4dm-04/mariap/test_data_melodia_1_test.pickle' | |
80 # pickle.dump([X_test, Y_test, Y_audio_test], f) | |
81 #with open(out_file, 'wb') as f: |