danstowell@0
|
1 #!/bin/env python
|
danstowell@0
|
2 #
|
danstowell@0
|
3 # smacpy - simple-minded audio classifier in python
|
danstowell@0
|
4 #
|
danstowell@0
|
5 # Copyright (c) 2012 Dan Stowell and Queen Mary University of London
|
danstowell@0
|
6 #
|
danstowell@0
|
7 # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
danstowell@0
|
8 #
|
danstowell@0
|
9 # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
danstowell@0
|
10 #
|
danstowell@0
|
11 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
danstowell@0
|
12
|
danstowell@0
|
13 import os.path
|
danstowell@0
|
14 import sys
|
danstowell@0
|
15 import numpy as np
|
danstowell@0
|
16 from glob import glob
|
danstowell@0
|
17 from scikits.audiolab import Sndfile
|
danstowell@0
|
18 from scikits.audiolab import Format
|
danstowell@0
|
19 from sklearn.mixture import GMM
|
danstowell@0
|
20
|
danstowell@0
|
21 from MFCC import melScaling
|
danstowell@0
|
22
|
danstowell@0
|
23 #######################################################################
|
danstowell@0
|
24 # some settings
|
danstowell@0
|
25
|
danstowell@0
|
framelen = 1024   # samples per analysis frame; also the FFT length used per frame
fs = 44100.0      # the only sample rate (Hz) accepted for input audio files
verbose = True    # print progress / diagnostic messages to stdout
|
danstowell@0
|
29
|
danstowell@0
|
30 #######################################################################
|
danstowell@0
|
31 # main class
|
danstowell@0
|
32
|
danstowell@0
|
class Smacpy(object):
    """Smacpy - simple-minded audio classifier in python.

    This is a classifier that you can train on a set of labelled audio files,
    and then it predicts a label for further audio files.
    It is designed with two main aims:
     (1) to provide a baseline against which to test more advanced audio classifiers;
     (2) to provide a simple code example of a classifier which people are free to build on.

    It uses the very common workflow of taking audio, converting to MFCCs, and
    modelling the MFCC "bag of frames" with one GMM per class label.

    USAGE EXAMPLE:
    In this hypothetical example we train on four audio files, labelled as either
    'usa' or 'uk', and then test on a separate audio file of someone called hubert:

        from smacpy import Smacpy
        model = Smacpy("wavs/training", {'karen01.wav':'usa', 'john01.wav':'uk', 'steve02.wav':'usa', 'joe03.wav':'uk'})
        model.classify('wavs/testing/hubert01.wav')
    """

    def __init__(self, wavfolder, trainingdata):
        """Initialise the classifier and train it on some WAV files.

        'wavfolder' is the base folder, to be prepended to all WAV paths.
        'trainingdata' is a dictionary of wavpath:label pairs.

        Raises ValueError if 'trainingdata' is empty, or (via file_to_features)
        if a training file is missing, is not mono, or is not at sample rate 'fs'.
        """
        if not trainingdata:
            raise ValueError("trainingdata is empty - at least one labelled WAV file is required.")

        # NOTE(review): arguments presumably mean (sample rate, spectrum length,
        # number of mel bands) - confirm against MFCC.melScaling.
        self.mfccMaker = melScaling(int(fs), framelen // 2, 40)
        self.mfccMaker.update()

        allfeatures = {wavpath: self.file_to_features(os.path.join(wavfolder, wavpath)) for wavpath in trainingdata}

        # Determine the normalisation stats over every training frame, and remember them
        self._set_normalisation(np.vstack(list(allfeatures.values())))

        # For each label, collect the normalised per-file arrays and stack them once
        # (stacking inside the loop would re-copy the growing array for every file).
        perlabel = {}
        for wavpath, features in allfeatures.items():
            perlabel.setdefault(trainingdata[wavpath], []).append(self.__normalise(features))
        aggfeatures = {label: np.vstack(chunks) for label, chunks in perlabel.items()}

        # For each label's aggregated features, train a GMM and remember it
        self.gmms = {}
        for label, aggf in aggfeatures.items():
            if verbose:
                print(" Training a GMM for label %s, using data of shape %s" % (label, str(np.shape(aggf))))
            # NOTE(review): GMM(cvtype=...) and GMM.eval are old scikit-learn APIs;
            # newer releases renamed/removed them (see sklearn.mixture.GaussianMixture).
            self.gmms[label] = GMM(n_components=10, cvtype='full')
            self.gmms[label].fit(aggf)
        if verbose:
            print(" Trained %i classes from %i input files" % (len(self.gmms), len(trainingdata)))

    def _set_normalisation(self, allconcat):
        """Compute and store per-column mean and inverse stdev of 'allconcat' (rows are frames).

        Columns with zero variance get an inverse stdev of 1.0, so normalising them
        only subtracts the mean rather than dividing by zero.
        """
        self.means = np.mean(allconcat, 0)
        stds = np.std(allconcat, 0)
        self.invstds = 1.0 / np.where(stds == 0.0, 1.0, stds)

    def __normalise(self, data):
        "Normalises data using the mean and stdev of the training data - so that everything is on a common scale."
        return (data - self.means) * self.invstds

    def classify(self, wavpath):
        """Specify the path to an audio file, and this returns the max-likelihood class, as a string label.

        Each label's GMM scores the file's normalised frames; the per-frame values
        taken from element [0] of gmm.eval() are summed (frames treated as
        independent) and the label with the highest total wins. Ties go to the
        label visited first. Returns '' if no GMMs have been trained.
        """
        features = self.__normalise(self.file_to_features(wavpath))
        bestlabel = ''
        bestll = None  # no sentinel value: any finite score, however low, can win
        for label, gmm in self.gmms.items():
            ll = np.sum(gmm.eval(features)[0])
            if bestll is None or ll > bestll:
                bestll = ll
                bestlabel = label
        return bestlabel

    def file_to_features(self, wavpath):
        """Reads through a mono WAV file, converting each frame to the required features.

        The audio is read in consecutive frames of 'framelen' samples; each frame is
        Hamming-windowed, FFT'd, mel-warped and converted to MFCCs, keeping
        coefficients 1..13 (the zeroth is dropped). A trailing partial frame is ignored.

        Returns a 2D array with one row per frame (a file shorter than one frame
        yields an empty 1D array).
        Raises ValueError if the path is not a file, the audio is not mono, or its
        sample rate differs from the module-level 'fs'.
        """
        if not os.path.isfile(wavpath):
            raise ValueError("path %s not found" % wavpath)
        if verbose:
            print("Reading %s" % wavpath)
        window = np.hamming(framelen)
        halflen = framelen // 2
        features = []
        sf = Sndfile(wavpath, "r")
        try:  # ensure the file handle is released even if validation fails
            if sf.channels != 1:
                raise ValueError("sound file has multiple channels (%i) - mono audio required." % sf.channels)
            if sf.samplerate != fs:
                raise ValueError("wanted sample rate %g - got %g." % (fs, sf.samplerate))
            while True:
                try:
                    chunk = sf.read_frames(framelen, dtype=np.float32)
                except RuntimeError:
                    # presumably raised by audiolab when fewer than framelen samples remain (end of file)
                    break
                if len(chunk) != framelen:
                    if verbose:
                        print("Not read sufficient samples - returning")
                    break
                magspec = abs(np.fft.fft(window * chunk)[:halflen])

                # do the frequency warping and MFCC computation
                melSpectrum = self.mfccMaker.warpSpectrum(magspec)
                melCepstrum = self.mfccMaker.getMFCCs(melSpectrum, cn=True)
                # exclude the zeroth coefficient and keep only the lower MFCCs (1..13).
                # todo: include deltas? that can be your homework.
                features.append(melCepstrum[1:14])
        finally:
            sf.close()
        return np.array(features)
|
danstowell@0
|
140
|
danstowell@0
|
141 #######################################################################
|
danstowell@4
|
142 # If this file is invoked as a script, it carries out a simple runthrough
|
danstowell@4
|
143 # of training on some wavs, then testing (on the same ones, just for confirmation, not for eval)
|
danstowell@0
|
if __name__ == '__main__':
    # Optional first command-line argument: the folder of WAVs (defaults to 'wavs').
    foldername = sys.argv[1] if len(sys.argv) > 1 else 'wavs'

    # Map each WAV's folder-relative path to its label: the part of the filename
    # before the first underscore (e.g. 'dog_003.wav' -> 'dog').
    pattern = os.path.join(foldername, '*.wav')
    trainingdata = dict(
        (os.path.relpath(fullpath, foldername), os.path.basename(fullpath).split('_')[0])
        for fullpath in glob(pattern)
    )
    if not trainingdata:
        raise RuntimeError("Found no files using this pattern: %s" % pattern)

    if verbose:
        print("Class-labels and filenames to be used in training:")
        for shortpath, truelabel in sorted(trainingdata.items()):
            print(" %s: \t %s" % (truelabel, shortpath))

    print("##################################################")
    print("TRAINING")
    model = Smacpy(foldername, trainingdata)

    # Re-classify the training files themselves: a sanity check, not an evaluation.
    print("##################################################")
    print("TESTING (nb on the same files as used for training - for true evaluation please train and test on independent data):")
    ncorrect = 0
    for shortpath, truelabel in trainingdata.items():
        inferred = model.classify(os.path.join(foldername, shortpath))
        print(" inferred: %s" % inferred)
        ncorrect += int(inferred == truelabel)
    print("Got %i correct out of %i" % (ncorrect, len(trainingdata)))
|
danstowell@0
|
175
|