Mercurial > hg > hybrid-music-recommender-using-content-based-and-social-information
view Code/prepare_dataset.py @ 21:e68dbee1f6db
Modified code
New datasets
Updated report
author | Paulo Chiliguano <p.e.chiilguano@se14.qmul.ac.uk> |
---|---|
date | Tue, 11 Aug 2015 10:50:36 +0100 |
parents | 1dbd24575d44 |
children |
line wrap: on
line source
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 23 21:55:58 2015

@author: paulochiliguano

Prepare the pickled train/validation/test dataset used for training.

The script performs four steps, all at module level:

1. Read the spectrogram array stored at node ``x`` of the HDF5 feature file.
2. Standardise the data with a ``StandardScaler`` fitted on rows whose length
   is the size of the last axis, then flatten it to one row per example.
3. Read the tab-separated ground-truth list and map every tag (second
   column) to an integer, numbered in sorted tag order.
4. Split inputs and targets 1/2 - 1/4 - 1/4, in file order, and pickle them as
   ``((train_input, train_target), (valid_input, valid_target),
   (test_input, test_target))``.

NOTE(review): ``cPickle`` and ``tables.openFile`` are the Python 2 /
pre-3.0 PyTables spellings; a port would need ``pickle`` and
``tables.open_file`` - confirm the target environment before changing them.
"""

import tables
import numpy as np
import cPickle
import sklearn.preprocessing as preprocessing

# ---------------------------------------------------------------------------
# Read HDF5 file that contains log-mel spectrograms
# ---------------------------------------------------------------------------
filename = '/homes/pchilguano/deep_learning/features/feats.h5'
with tables.openFile(filename, 'r') as f:
    features = f.root.x.read()

# ---------------------------------------------------------------------------
# Pre-processing of spectrograms: mean=0 and std=1
# ---------------------------------------------------------------------------
# Number of rows each example contributes once the middle axes are collapsed;
# only needed by the frame-level variant described at the end of the file.
n_per_example = np.prod(features.shape[1:-1])
number_of_features = features.shape[-1]

# Collapse every leading axis so each row holds `number_of_features` values,
# which makes the scaler standardise each last-axis position independently.
flat_data = features.reshape(-1, number_of_features)
scaler = preprocessing.StandardScaler().fit(flat_data)

# Back to one (flattened) row per example.
flat_data = scaler.transform(flat_data).reshape(features.shape[0], -1)

# ---------------------------------------------------------------------------
# Read labels from ground truth
# ---------------------------------------------------------------------------
filename = '/homes/pchilguano/deep_learning/lists/ground_truth.txt'
with open(filename, 'r') as f:
    # The tag is the second tab-separated field. Blank lines (e.g. a trailing
    # newline at the end of the list) carry no label and are skipped.
    tags = [line.strip().split('\t')[1] for line in f if line.strip()]

# ---------------------------------------------------------------------------
# Assign label to a discrete number
# ---------------------------------------------------------------------------
tag_set = set(tags)
tag_dict = {tag: index for index, tag in enumerate(sorted(tag_set))}

# Build the array in one go: growing it with np.append copied the whole array
# on every line and promoted the declared int32 to the default integer dtype.
target = np.asarray([tag_dict[tag] for tag in tags], dtype='int32')

# Inputs and targets are split with the same fractions and paired afterwards,
# so differing lengths would silently misalign examples and labels.
if target.shape[0] != flat_data.shape[0]:
    raise ValueError(
        'ground truth has %d labels but the feature file has %d examples'
        % (target.shape[0], flat_data.shape[0])
    )

# ---------------------------------------------------------------------------
# Split 1/2 train, 1/4 validation, 1/4 test and save
# ---------------------------------------------------------------------------
# NOTE(review): nothing is shuffled here, so the split follows file order -
# presumably the inputs are already shuffled upstream; confirm.
n_examples = flat_data.shape[0]
# Floor division keeps the split points integral under any division semantics.
split_points = [n_examples // 2, n_examples * 3 // 4]

train_input, valid_input, test_input = np.array_split(flat_data, split_points)
train_target, valid_target, test_target = np.array_split(target, split_points)

# The context manager closes the file even if pickling fails part-way.
with open('/homes/pchilguano/deep_learning/gtzan.pkl', 'wb') as f:
    cPickle.dump(
        (
            (train_input, train_target),
            (valid_input, valid_target),
            (test_input, test_target),
        ),
        f,
        protocol=cPickle.HIGHEST_PROTOCOL
    )

# ---------------------------------------------------------------------------
# Disabled alternative kept for reference: frame-level dataset
# (4/5 - 1/10 - 1/10 split, every label repeated `n_per_example` times).
# ---------------------------------------------------------------------------
# flat_target = target.repeat(n_per_example)
#
# train_input, valid_input, test_input = np.array_split(
#     flat_data,
#     [flat_data.shape[0] * 4 // 5, flat_data.shape[0] * 9 // 10]
# )
# train_target, valid_target, test_target = np.array_split(
#     flat_target,
#     [flat_target.shape[0] * 4 // 5, flat_target.shape[0] * 9 // 10]
# )
#
# with open('/homes/pchilguano/deep_learning/gtzan_logistic.pkl', 'wb') as f:
#     cPickle.dump(
#         (
#             (train_input, train_target),
#             (valid_input, valid_target),
#             (test_input, test_target),
#         ),
#         f,
#         protocol=cPickle.HIGHEST_PROTOCOL
#     )