Mercurial > hg > plosone_underreview
diff scripts/load_dataset.py @ 20:65b9330afdd8 branch-tests
return train/test sets in load_dataset
author | Maria Panteli <m.x.panteli@gmail.com> |
---|---|
date | Wed, 13 Sep 2017 12:53:57 +0100 |
parents | 9847b954c217 |
children |
line wrap: on
line diff
--- a/scripts/load_dataset.py Wed Sep 13 12:09:55 2017 +0100 +++ b/scripts/load_dataset.py Wed Sep 13 12:53:57 2017 +0100 @@ -152,17 +152,21 @@ X_val, Y_val, Y_audio_val = extract_features(df.iloc[val_set[0], :], win2sec=WIN_SIZE) X_test, Y_test, Y_audio_test = extract_features(df.iloc[test_set[0], :], win2sec=WIN_SIZE) + train = [X_train, Y_train, Y_audio_train] + val = [X_val, Y_val, Y_audio_val] + test = [X_test, Y_test, Y_audio_test] if write_output: with open(OUTPUT_FILES[0], 'wb') as f: - pickle.dump([X_train, Y_train, Y_audio_train], f) + pickle.dump(train, f) with open(OUTPUT_FILES[1], 'wb') as f: - pickle.dump([X_val, Y_val, Y_audio_val], f) + pickle.dump(val, f) with open(OUTPUT_FILES[2], 'wb') as f: - pickle.dump([X_test, Y_test, Y_audio_test], f) + pickle.dump(test, f) + return train, val, test if __name__ == '__main__': # load dataset df = sample_dataset(csv_file=METADATA_FILE) - features_for_train_test_sets(df, write_output=True) + train, val, test = features_for_train_test_sets(df, write_output=True)