# -*- coding: utf-8 -*-
"""
Created on Thu Jul 23 21:55:58 2015

@author: paulochiliguano

Build a pickled train/validation/test split of the GTZAN log-mel features.

Steps:
1. Read the feature array ``x`` from an HDF5 file.
2. Standardise every feature (last axis) to zero mean / unit variance using
   statistics pooled over all other axes, then flatten each example (axis 0)
   into a single row.
3. Read labels from the ground-truth list (second tab-separated column) and
   map them to integers in sorted-label order.
4. Split inputs and labels 1/2 : 1/4 : 1/4 in file order and pickle
   ``((train_x, train_y), (valid_x, valid_y), (test_x, test_y))``.
"""

import tables
import numpy as np
try:
    import cPickle  # Python 2
except ImportError:
    import pickle as cPickle  # Python 3: cPickle was merged into pickle
import sklearn.preprocessing as preprocessing

FEATURES_PATH = ('/homes/pchilguano/msc_project/dataset/gtzan/features/'
                 'feats_3sec_9.h5')
GROUND_TRUTH_PATH = ('/homes/pchilguano/msc_project/dataset/gtzan/lists/'
                     'ground_truth.txt')
OUTPUT_PATH = ('/homes/pchilguano/msc_project/dataset/gtzan/features/'
               'gtzan_3sec_9.pkl')

# PyTables 3.x renamed ``openFile`` to ``open_file``; support both.
try:
    _open_hdf5 = tables.open_file
except AttributeError:
    _open_hdf5 = tables.openFile


def _read_tags(path):
    """Return the second tab-separated column of every non-blank line."""
    tags = []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank lines such as a trailing newline
                tags.append(line.split('\t')[1])
    return tags


def _split_half_quarter_quarter(array):
    """Split *array* along axis 0 into its first 1/2, next 1/4 and last 1/4.

    ``//`` keeps the split points integral on both Python 2 and Python 3
    (with ``/`` Python 3 would hand float indices to ``np.array_split``).
    """
    n = array.shape[0]
    return np.array_split(array, [n // 2, n * 3 // 4])


# Read HDF5 file that contains log-mel spectrograms.
with _open_hdf5(FEATURES_PATH, 'r') as f:
    features = f.root.x.read()

# Standardise each feature (last axis) to mean=0 and std=1, pooling the
# statistics over every other axis, then flatten each example into one row.
number_of_features = features.shape[-1]
rows = features.reshape(-1, number_of_features)
scaler = preprocessing.StandardScaler().fit(rows)
flat_data = scaler.transform(rows).reshape(features.shape[0], -1)

# Read labels from the ground truth (once) and assign each a discrete
# number following sorted label order.
tags = _read_tags(GROUND_TRUTH_PATH)
tag_dict = {tag: index for index, tag in enumerate(sorted(set(tags)))}
# Build in a single step: repeated np.append is quadratic and promotes the
# result away from the requested int32 dtype.
target = np.array([tag_dict[tag] for tag in tags], dtype='int32')

# Inputs and labels are split independently below, so they must line up
# one-to-one; fail loudly instead of silently misaligning them.
if flat_data.shape[0] != target.shape[0]:
    raise ValueError(
        'feature rows ({0}) and ground-truth labels ({1}) differ in '
        'count'.format(flat_data.shape[0], target.shape[0]))

# NOTE(review): the split is taken in file order without shuffling. If
# ground_truth.txt is grouped by genre, the train/valid/test sets will have
# very different label distributions -- confirm the file ordering.
train_input, valid_input, test_input = _split_half_quarter_quarter(flat_data)
train_target, valid_target, test_target = _split_half_quarter_quarter(target)

with open(OUTPUT_PATH, 'wb') as f:
    cPickle.dump(
        (
            (train_input, train_target),
            (valid_input, valid_target),
            (test_input, test_target),
        ),
        f,
        protocol=cPickle.HIGHEST_PROTOCOL,
    )