Mercurial > hg > plosone_underreview
comparison scripts/load_dataset.py @ 20:65b9330afdd8 branch-tests
return train/test sets in load_dataset
author | Maria Panteli <m.x.panteli@gmail.com> |
---|---|
date | Wed, 13 Sep 2017 12:53:57 +0100 |
parents | 9847b954c217 |
children |
comparison
equal
deleted
inserted
replaced
19:0bba6f63f4fd | 20:65b9330afdd8 |
---|---|
150 train_set, val_set, test_set = get_train_val_test_idx(X_idx, Y) | 150 train_set, val_set, test_set = get_train_val_test_idx(X_idx, Y) |
151 X_train, Y_train, Y_audio_train = extract_features(df.iloc[train_set[0], :], win2sec=WIN_SIZE) | 151 X_train, Y_train, Y_audio_train = extract_features(df.iloc[train_set[0], :], win2sec=WIN_SIZE) |
152 X_val, Y_val, Y_audio_val = extract_features(df.iloc[val_set[0], :], win2sec=WIN_SIZE) | 152 X_val, Y_val, Y_audio_val = extract_features(df.iloc[val_set[0], :], win2sec=WIN_SIZE) |
153 X_test, Y_test, Y_audio_test = extract_features(df.iloc[test_set[0], :], win2sec=WIN_SIZE) | 153 X_test, Y_test, Y_audio_test = extract_features(df.iloc[test_set[0], :], win2sec=WIN_SIZE) |
154 | 154 |
| | 155 train = [X_train, Y_train, Y_audio_train] |
| | 156 val = [X_val, Y_val, Y_audio_val] |
| | 157 test = [X_test, Y_test, Y_audio_test] |
155 if write_output: | 158 if write_output: |
156 with open(OUTPUT_FILES[0], 'wb') as f: | 159 with open(OUTPUT_FILES[0], 'wb') as f: |
157 pickle.dump([X_train, Y_train, Y_audio_train], f) | 160 pickle.dump(train, f) |
158 with open(OUTPUT_FILES[1], 'wb') as f: | 161 with open(OUTPUT_FILES[1], 'wb') as f: |
159 pickle.dump([X_val, Y_val, Y_audio_val], f) | 162 pickle.dump(val, f) |
160 with open(OUTPUT_FILES[2], 'wb') as f: | 163 with open(OUTPUT_FILES[2], 'wb') as f: |
161 pickle.dump([X_test, Y_test, Y_audio_test], f) | 164 pickle.dump(test, f) |
| | 165 return train, val, test |
162 | 166 |
163 | 167 |
164 if __name__ == '__main__': | 168 if __name__ == '__main__': |
165 # load dataset | 169 # load dataset |
166 df = sample_dataset(csv_file=METADATA_FILE) | 170 df = sample_dataset(csv_file=METADATA_FILE) |
167 features_for_train_test_sets(df, write_output=True) | 171 train, val, test = features_for_train_test_sets(df, write_output=True) |
168 | 172 |