comparison Code/genre_classification/learning/preprocess_spectrograms_gtzan.py @ 24:68a62ca32441

Organized python scripts
author Paulo Chiliguano <p.e.chiilguano@se14.qmul.ac.uk>
date Sat, 15 Aug 2015 19:16:17 +0100
parents
children
comparison
equal deleted inserted replaced
23:45e6f85d0ba4 24:68a62ca32441
1 # -*- coding: utf-8 -*-
2 """
3 Created on Thu Jul 23 21:55:58 2015
4
5 @author: paulochiliguano
6 """
7
8
9 import tables
10 import numpy as np
11 import cPickle
12 import sklearn.preprocessing as preprocessing
13
# Load the log-mel spectrograms stored under node ``x`` of the HDF5 file.
# The path is spelled as two adjacent literals, which Python joins into a
# single string at compile time.
filename = (
    '/homes/pchilguano/msc_project/dataset/gtzan/features/'
    'feats_3sec_9.h5'
)
with tables.openFile(filename, 'r') as f:
    # read() pulls the whole node into memory as a NumPy array.
    features = f.root.x.read()
    #filenames = f.root.filenames.read()
20
# Standardise the spectrograms: every column of the last axis is shifted and
# scaled to mean 0 and standard deviation 1.
number_of_features = features.shape[-1]
# Product of the axes between the first and the last one; only referenced by
# the disabled code paths, kept because it is a module-level name.
n_per_example = np.prod(features.shape[1:-1])

# Collapse every leading axis so each row is one vector of
# ``number_of_features`` values, which is the 2-D layout the scaler expects.
flat_data = features.reshape(-1, number_of_features)

# Fit the scaler on all rows and apply it in a single step; ``scaler`` stays
# available (fitted) for later use.
scaler = preprocessing.StandardScaler()
flat_data = scaler.fit_transform(flat_data)

# Regroup the standardised rows so there is one flat row per entry along the
# first axis of ``features``.
flat_data = flat_data.reshape(features.shape[0], -1)

# Disabled experiments, kept for reference:
#initial_shape = features.shape[1:]
#flat_targets = filenames.repeat(n_per_example)
#genre = np.asarray([line.strip().split('\t')[1] for line in open(filename,'r').readlines()])
32
# Read labels from ground truth.
# Each non-empty line is tab-separated and the tag is taken from its second
# column. The file is parsed once and the tags are kept in file order.
filename = '/homes/pchilguano/msc_project/dataset/gtzan/lists/ground_truth.txt'
with open(filename, 'r') as f:
    # Blank lines (e.g. a trailing newline at the end of the file) carry no
    # label and are skipped rather than raising IndexError.
    tags = [line.strip().split('\t')[1] for line in f if line.strip()]

# Assign each distinct tag a discrete number, following the sorted tag order
# so the mapping is deterministic across runs.
tag_set = set(tags)
tag_dict = {tag: index for index, tag in enumerate(sorted(tag_set))}

# Build the target vector in one allocation. Growing it with np.append inside
# the loop copied the array on every line and promoted it to the platform
# default integer type, silently dropping the requested int32 dtype.
target = np.asarray([tag_dict[tag] for tag in tags], dtype='int32')
49
# Split inputs and targets into train / validation / test partitions:
# the first half, the next quarter and the final quarter, in file order.
# Floor division (//) keeps the split points integral regardless of the
# division semantics in effect; with true division the original `*1/2` and
# `*3/4` expressions yield floats, which np.array_split cannot slice with.
n_inputs = flat_data.shape[0]
train_input, valid_input, test_input = np.array_split(
    flat_data,
    [n_inputs // 2, n_inputs * 3 // 4]
)

# The targets are split with the same proportions, computed from their own
# length (which is not required to equal the number of input rows).
n_targets = target.shape[0]
train_target, valid_target, test_target = np.array_split(
    target,
    [n_targets // 2, n_targets * 3 // 4]
)
60
# Serialise the three (input, target) pairs into a single pickle file.
# The context manager closes the file even if pickling fails part-way, and
# open() replaces the Python 2-only file() builtin with identical behaviour.
output_filename = (
    '/homes/pchilguano/msc_project/dataset/gtzan/features/'
    'gtzan_3sec_9.pkl'
)
with open(output_filename, 'wb') as f:
    cPickle.dump(
        (
            (train_input, train_target),
            (valid_input, valid_target),
            (test_input, test_target)
        ),
        f,
        # Highest protocol: binary format, compact for large arrays.
        protocol=cPickle.HIGHEST_PROTOCOL
    )
73
74 '''
75 flat_target = target.repeat(n_per_example)
76
77 train_input, valid_input, test_input = np.array_split(flat_data, [flat_data.shape[0]*4/5, flat_data.shape[0]*9/10])
78 train_target, valid_target, test_target = np.array_split(flat_target, [flat_target.shape[0]*4/5, flat_target.shape[0]*9/10])
79
80 f = file('/homes/pchilguano/deep_learning/gtzan_logistic.pkl', 'wb')
81 cPickle.dump(((train_input, train_target), (valid_input, valid_target), (test_input, test_target)), f, protocol=cPickle.HIGHEST_PROTOCOL)
82 f.close()
83 '''