Mercurial > hg > hybrid-music-recommender-using-content-based-and-social-information
diff Code/genre_classification/learning/preprocess_spectrograms_gtzan.py @ 24:68a62ca32441
Organized python scripts
author | Paulo Chiliguano <p.e.chiilguano@se14.qmul.ac.uk> |
---|---|
date | Sat, 15 Aug 2015 19:16:17 +0100 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Code/genre_classification/learning/preprocess_spectrograms_gtzan.py Sat Aug 15 19:16:17 2015 +0100 @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Jul 23 21:55:58 2015 + +@author: paulochiliguano +""" + + +import tables +import numpy as np +import cPickle +import sklearn.preprocessing as preprocessing + +#Read HDF5 file that contains log-mel spectrograms +filename = '/homes/pchilguano/msc_project/dataset/gtzan/features/\ +feats_3sec_9.h5' +with tables.openFile(filename, 'r') as f: + features = f.root.x.read() + #filenames = f.root.filenames.read() + +#Pre-processing of spectrograms mean=0 and std=1 +#initial_shape = features.shape[1:] +n_per_example = np.prod(features.shape[1:-1]) +number_of_features = features.shape[-1] +flat_data = features.view() +flat_data.shape = (-1, number_of_features) +scaler = preprocessing.StandardScaler().fit(flat_data) +flat_data = scaler.transform(flat_data) +flat_data.shape = (features.shape[0], -1) +#flat_targets = filenames.repeat(n_per_example) +#genre = np.asarray([line.strip().split('\t')[1] for line in open(filename,'r').readlines()]) + +#Read labels from ground truth +filename = '/homes/pchilguano/msc_project/dataset/gtzan/lists/ground_truth.txt' +with open(filename, 'r') as f: + tag_set = set() + for line in f: + tag = line.strip().split('\t')[1] + tag_set.add(tag) + +#Assign label to a discrete number +tag_dict = dict([(item, index) for index, item in enumerate(sorted(tag_set))]) +with open(filename, 'r') as f: + target = np.asarray([], dtype='int32') + mp3_dict = {} + for line in f: + tag = line.strip().split('\t')[1] + target = np.append(target, tag_dict[tag]) + +train_input, valid_input, test_input = np.array_split( + flat_data, + [flat_data.shape[0]*1/2, + flat_data.shape[0]*3/4] +) +train_target, valid_target, test_target = np.array_split( + target, + [target.shape[0]*1/2, + target.shape[0]*3/4] +) + +f = file('/homes/pchilguano/msc_project/dataset/gtzan/features/\ +gtzan_3sec_9.pkl', 'wb') +cPickle.dump( + ( + (train_input, train_target), + (valid_input, valid_target), + (test_input, test_target) + ), + f, + protocol=cPickle.HIGHEST_PROTOCOL +) +f.close() + +''' +flat_target = target.repeat(n_per_example) + +train_input, valid_input, test_input = np.array_split(flat_data, [flat_data.shape[0]*4/5, flat_data.shape[0]*9/10]) +train_target, valid_target, test_target = np.array_split(flat_target, [flat_target.shape[0]*4/5, flat_target.shape[0]*9/10]) + +f = file('/homes/pchilguano/deep_learning/gtzan_logistic.pkl', 'wb') +cPickle.dump(((train_input, train_target), (valid_input, valid_target), (test_input, test_target)), f, protocol=cPickle.HIGHEST_PROTOCOL) +f.close() +'''