Mercurial > hg > hybrid-music-recommender-using-content-based-and-social-information
comparison Code/genre_classification/learning/preprocess_spectrograms_gtzan.py @ 24:68a62ca32441
Organized python scripts
author | Paulo Chiliguano <p.e.chiilguano@se14.qmul.ac.uk> |
---|---|
date | Sat, 15 Aug 2015 19:16:17 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
23:45e6f85d0ba4 | 24:68a62ca32441 |
---|---|
1 # -*- coding: utf-8 -*- | |
2 """ | |
3 Created on Thu Jul 23 21:55:58 2015 | |
4 | |
5 @author: paulochiliguano | |
6 """ | |
7 | |
8 | |
9 import tables | |
10 import numpy as np | |
11 import cPickle | |
12 import sklearn.preprocessing as preprocessing | |
13 | |
# Read the HDF5 file that contains the log-mel spectrograms.
# The path is built from adjacent string literals rather than a backslash
# continuation inside the literal, which would embed any leading whitespace
# of the continuation line in the path.
filename = ('/homes/pchilguano/msc_project/dataset/gtzan/features/'
            'feats_3sec_9.h5')

# PyTables 3.0 renamed openFile() to open_file() and the camelCase alias was
# dropped in later releases; prefer the new name, fall back to the old one.
try:
    open_hdf5 = tables.open_file
except AttributeError:
    open_hdf5 = tables.openFile

with open_hdf5(filename, 'r') as f:
    # Load the whole `x` node under the file root into memory.
    features = f.root.x.read()
    #filenames = f.root.filenames.read()
20 | |
# Pre-processing of spectrograms: standardise to mean=0 and std=1.
#initial_shape = features.shape[1:]
# Product of the middle axes of `features`; only referenced by the disabled
# variant kept in the string literal at the end of this script.
n_per_example = np.prod(features.shape[1:-1])
number_of_features = features.shape[-1]

# Collapse every leading axis so that each row holds `number_of_features`
# values. reshape() is used instead of assigning to `.shape`, which NumPy
# discourages and which fails when the array cannot be reshaped in place.
flat_data = features.reshape(-1, number_of_features)

# Fit the per-column mean/std on all rows, then apply the scaling.
scaler = preprocessing.StandardScaler().fit(flat_data)
flat_data = scaler.transform(flat_data)

# Back to one row per entry of the first axis of `features`, with all the
# remaining axes flattened into that row.
flat_data = flat_data.reshape(features.shape[0], -1)
#flat_targets = filenames.repeat(n_per_example)
#genre = np.asarray([line.strip().split('\t')[1] for line in open(filename,'r').readlines()])
32 | |
# Read labels from ground truth.
# Every non-blank line is tab-separated and the label is its second field.
# The file is read once; blank lines (e.g. a trailing newline at the end of
# the list) are skipped instead of raising IndexError.
filename = '/homes/pchilguano/msc_project/dataset/gtzan/lists/ground_truth.txt'
with open(filename, 'r') as f:
    tags = [line.strip().split('\t')[1] for line in f if line.strip()]

# Assign each distinct label a discrete number, following the sorted order
# of the labels so that the mapping is reproducible between runs.
tag_set = set(tags)
tag_dict = dict((item, index) for index, item in enumerate(sorted(tag_set)))

# Build the target vector in one step. The previous np.append() loop copied
# the array on every iteration and promoted it to the platform default
# integer type, silently discarding the intended int32 dtype.
target = np.asarray([tag_dict[tag] for tag in tags], dtype='int32')
49 | |
# Split inputs and targets at the same relative positions:
# first half -> training, next quarter -> validation, last quarter -> test.
# Floor division (//) keeps the split points integral regardless of the
# interpreter's division semantics; np.array_split rejects float indices.
# NOTE(review): rows are split in their stored order without shuffling and
# nothing here checks that flat_data and target have the same length --
# confirm that the ground-truth list matches the feature rows and is not
# grouped by label.
train_input, valid_input, test_input = np.array_split(
    flat_data,
    [flat_data.shape[0] * 1 // 2,
     flat_data.shape[0] * 3 // 4]
)
train_target, valid_target, test_target = np.array_split(
    target,
    [target.shape[0] * 1 // 2,
     target.shape[0] * 3 // 4]
)
60 | |
# Serialise the three (input, target) pairs into a single pickle file.
# open() inside a `with` block replaces the Python 2-only file() builtin and
# guarantees the handle is closed even if cPickle.dump() raises. The path is
# built from adjacent literals so that no continuation-line whitespace can
# leak into it.
with open('/homes/pchilguano/msc_project/dataset/gtzan/features/'
          'gtzan_3sec_9.pkl', 'wb') as f:
    cPickle.dump(
        (
            (train_input, train_target),
            (valid_input, valid_target),
            (test_input, test_target)
        ),
        f,
        protocol=cPickle.HIGHEST_PROTOCOL
    )
73 | |
# Disabled earlier variant, kept for reference only: the bare string literal
# below is evaluated and discarded, so none of the code inside it runs.
# It repeated every label n_per_example times, split inputs and targets at
# 4/5 and 9/10 of their length, and pickled the result to gtzan_logistic.pkl.
'''
flat_target = target.repeat(n_per_example)

train_input, valid_input, test_input = np.array_split(flat_data, [flat_data.shape[0]*4/5, flat_data.shape[0]*9/10])
train_target, valid_target, test_target = np.array_split(flat_target, [flat_target.shape[0]*4/5, flat_target.shape[0]*9/10])

f = file('/homes/pchilguano/deep_learning/gtzan_logistic.pkl', 'wb')
cPickle.dump(((train_input, train_target), (valid_input, valid_target), (test_input, test_target)), f, protocol=cPickle.HIGHEST_PROTOCOL)
f.close()
'''