#!/bin/env python
#
# smacpy - simple-minded audio classifier in python
#
# Copyright (c) 2012 Dan Stowell and Queen Mary University of London
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
danstowell@0
|
12
|
danstowell@0
|
13 import os.path
|
danstowell@0
|
14 import numpy as np
|
danstowell@7
|
15 import argparse
|
danstowell@0
|
16 from glob import glob
|
danstowell@33
|
17 import librosa
|
danstowell@32
|
18 from sklearn.mixture import GaussianMixture as GMM
|
danstowell@0
|
19
|
danstowell@0
|
20 from MFCC import melScaling
|
danstowell@0
|
21
|
danstowell@0
|
22 #######################################################################
|
danstowell@0
|
23 # some settings
|
danstowell@0
|
# Number of audio samples per analysis frame. file_to_features() windows and
# FFTs the audio in consecutive, non-overlapping chunks of this length.
framelen = 1024
# Sample rate: passed to librosa.load() as the target rate ('sr') and, cast to
# int, to melScaling().
fs = 44100.0
# When True, progress messages are printed. Running as a script with -q sets
# this to False.
verbose = True
|
danstowell@0
|
27
|
danstowell@0
|
28 #######################################################################
|
danstowell@0
|
29 # main class
|
danstowell@0
|
30
|
danstowell@0
|
class Smacpy:
    """Smacpy - simple-minded audio classifier in python. See the README file for more details.

    USAGE EXAMPLE:
    In this hypothetical example we train on four audio files, labelled as either 'usa' or 'uk', and then test on a separate audio file of someone called hubert:

    from smacpy import Smacpy
    model = Smacpy("wavs/training", {'karen01.wav':'usa', 'john01.wav':'uk', 'steve02.wav':'usa', 'joe03.wav':'uk'})
    model.classify('wavs/testing/hubert01.wav')

    Note for developers: this code should aim to be understandable, and not too long. Don't add too much functionality, or efficiency ;)
    """

    def __init__(self, wavfolder, trainingdata):
        """Initialise the classifier and train it on some WAV files.

        'wavfolder' is the base folder, to be prepended to all WAV paths.
        'trainingdata' is a dictionary of wavpath:label pairs.

        After construction the instance holds:
          self.means, self.invstds - per-feature normalisation statistics
          self.gmms                - dictionary of label:fitted GMM

        Raises ValueError if 'trainingdata' is empty, or if any file is missing
        or too short to yield a frame (see file_to_features)."""
        if not trainingdata:
            # Fail early with a clear message rather than inside np.vstack([]).
            raise ValueError("trainingdata is empty - need at least one labelled WAV file")

        # NOTE(review): melScaling comes from the local MFCC module; the arguments
        # are presumably (sample rate, number of spectrum bins, number of mel
        # filters) - confirm against MFCC.melScaling.
        self.mfccMaker = melScaling(int(fs), framelen // 2, 40)
        self.mfccMaker.update()

        allfeatures = {wavpath: self.file_to_features(os.path.join(wavfolder, wavpath))
                       for wavpath in trainingdata}

        # Determine the normalisation stats, and remember them
        allconcat = np.vstack(list(allfeatures.values()))
        self.means = np.mean(allconcat, 0)
        stds = np.std(allconcat, 0)
        # A feature with zero variance gets a scale factor of 1.0, which avoids
        # dividing by zero.
        stds[stds == 0.0] = 1.0
        self.invstds = 1.0 / stds

        # For each label, collect the normalised features of every file with that label
        perlabel = {}
        for wavpath, features in allfeatures.items():
            perlabel.setdefault(trainingdata[wavpath], []).append(self.__normalise(features))

        # For each label's aggregated features, train a GMM and remember it
        self.gmms = {}
        for label, pieces in perlabel.items():
            aggf = np.vstack(pieces)
            if verbose:
                print(" Training a GMM for label %s, using data of shape %s" % (label, str(np.shape(aggf))))
            self.gmms[label] = GMM(n_components=10)
            self.gmms[label].fit(aggf)
        if verbose:
            print(" Trained %i classes from %i input files" % (len(self.gmms), len(trainingdata)))

    def __normalise(self, data):
        "Normalises data using the mean and stdev of the training data - so that everything is on a common scale."
        return (data - self.means) * self.invstds

    def classify(self, wavpath):
        """Specify the path to an audio file, and this returns the max-likelihood class, as a string label.

        Each label's GMM scores every frame of the file; the per-frame
        log-likelihoods are summed and the label with the largest total wins.
        Returns '' if no GMM scores above the initial floor (e.g. no GMMs)."""
        features = self.__normalise(self.file_to_features(wavpath))
        # For each label GMM, find the overall log-likelihood and choose the strongest
        bestlabel = ''
        bestll = -9e99
        for label, gmm in self.gmms.items():
            # GaussianMixture.score_samples returns one log-likelihood per row
            # (frame). Sum over ALL frames: indexing with [0] would score only
            # the first frame of the file.
            ll = np.sum(gmm.score_samples(features))
            if ll > bestll:
                bestll = ll
                bestlabel = label
        return bestlabel

    def file_to_features(self, wavpath):
        """Reads through an audio file, converting each frame to the required features.

        The file is loaded via librosa at sample rate 'fs' and mixed to mono,
        then cut into consecutive non-overlapping frames of 'framelen' samples;
        a trailing partial frame is discarded. Returns a 2D array with one row
        per frame and at most 13 columns (MFCCs, zeroth coefficient excluded).

        Raises ValueError if the path is not a file, or if the audio does not
        yield a single frame."""
        if verbose:
            print("Reading %s" % wavpath)
        if not os.path.isfile(wavpath):
            raise ValueError("path %s not found" % wavpath)

        audiodata, _ = librosa.load(wavpath, sr=fs, mono=True)
        window = np.hamming(framelen)
        features = []
        # Frame start positions such that a full frame always fits.
        for chunkpos in range(0, len(audiodata) - framelen + 1, framelen):
            chunk = audiodata[chunkpos:chunkpos + framelen]
            framespectrum = np.fft.fft(window * chunk)
            magspec = abs(framespectrum[:framelen // 2])

            try:
                # do the frequency warping and MFCC computation
                melSpectrum = self.mfccMaker.warpSpectrum(magspec)
                melCepstrum = self.mfccMaker.getMFCCs(melSpectrum, cn=True)
            except RuntimeError:
                # NOTE(review): kept from the original, which stopped reading
                # (keeping the frames so far) on RuntimeError - confirm whether
                # the MFCC helper can actually raise this.
                break
            melCepstrum = melCepstrum[1:]   # exclude zeroth coefficient
            melCepstrum = melCepstrum[:13]  # limit to lower MFCCs

            framefeatures = melCepstrum  # todo: include deltas? that can be your homework.

            features.append(framefeatures)

        if not features:
            # Without this, an empty feature array fails later with an obscure
            # broadcasting/stacking ValueError.
            raise ValueError("no complete frame of %i samples could be read from %s" % (framelen, wavpath))
        features = np.array(features)
        if verbose:
            print(" Data shape: %s" % str(features.shape))
        return features
|
danstowell@0
|
133
|
danstowell@0
|
134 #######################################################################
|
danstowell@15
|
def trainAndTest(trainpath, trainwavs, testpath, testwavs):
    """Handy function for evaluating your code: trains a model, tests it on wavs of known class.

    'trainwavs' and 'testwavs' are dictionaries of wavpath:label pairs, with
    paths relative to 'trainpath' and 'testpath' respectively.
    Returns (numcorrect, numtotal, numclasses)."""
    print("TRAINING")
    model = Smacpy(trainpath, trainwavs)
    print("TESTING")
    numhits = 0
    for relpath in testwavs:
        expected = testwavs[relpath]
        guess = model.classify(os.path.join(testpath, relpath))
        if verbose:
            print(" inferred: %s" % guess)
        if guess == expected:
            numhits = numhits + 1
    numtested = len(testwavs)
    numclasses = len(model.gmms)
    return (numhits, numtested, numclasses)
|
danstowell@15
|
147
|
danstowell@15
|
148 #######################################################################
|
danstowell@4
|
149 # If this file is invoked as a script, it carries out a simple runthrough
|
danstowell@14
|
150 # of training on some wavs, then testing, with classnames being the start of the filenames
|
danstowell@0
|
if __name__ == '__main__':

    # Handle the command-line arguments for where the train/test data comes from:
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--trainpath', default='wavs', help="Path to the WAV files used for training")
    parser.add_argument('-T', '--testpath', help="Path to the WAV files used for testing")
    parser.add_argument('-q', dest='quiet', action='store_true', help="Be less verbose, don't output much text during processing")
    group = parser.add_mutually_exclusive_group()
    group.add_argument('-c', '--charsplit', default='_', help="Character used to split filenames: anything BEFORE this character is the class")
    group.add_argument('-n', '--numchars', default=0, help="Instead of splitting using 'charsplit', use this fixed number of characters from the start of the filename", type=int)
    args = vars(parser.parse_args())
    # This is module scope, so the assignment rebinds the global 'verbose'
    # that the class and trainAndTest() read.
    verbose = not args['quiet']

    # With no separate test folder, test on the training folder (crossvalidation below).
    if args['testpath'] is None:
        args['testpath'] = args['trainpath']

    # Build up lists of the training and testing WAV files:
    wavsfound = {'trainpath': {}, 'testpath': {}}
    for onepath in ['trainpath', 'testpath']:
        pattern = os.path.join(args[onepath], '*.wav')
        for wavpath in glob(pattern):
            if args['numchars'] != 0:
                label = os.path.basename(wavpath)[:args['numchars']]
            else:
                label = os.path.basename(wavpath).split(args['charsplit'])[0]
            shortwavpath = os.path.relpath(wavpath, args[onepath])
            wavsfound[onepath][shortwavpath] = label
        if not wavsfound[onepath]:
            raise RuntimeError("Found no files using this pattern: %s" % pattern)
        if verbose:
            print("Class-labels and filenames to be used from %s:" % onepath)
            for wavpath, label in sorted(wavsfound[onepath].items()):
                print(" %s: \t %s" % (label, wavpath))

    if args['testpath'] != args['trainpath']:
        # Separate train-and-test collections
        ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath'])
        print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses))
    else:
        # This runs "stratified leave-one-out crossvalidation": test multiple times by leaving one-of-each-class out and training on the rest.
        # First we need to build a list of files grouped by each classlabel
        labelsinuse = sorted(set(wavsfound['trainpath'].values()))
        grouped = {label: [] for label in labelsinuse}
        # glob() returns files in arbitrary order; sorting makes the assignment
        # of files to folds reproducible from run to run.
        for wavpath, label in sorted(wavsfound['trainpath'].items()):
            grouped[label].append(wavpath)
        # The smallest class limits the number of folds. Files beyond that count
        # in larger classes are used neither for training nor for testing.
        numfolds = min(len(collection) for collection in grouped.values())
        if numfolds < 2:
            # With a single fold, "all the rest" would be an empty training set.
            raise RuntimeError("Crossvalidation needs at least 2 files of every class, but the smallest class has only %i" % numfolds)
        # Each "fold" will be a collection of one item of each label
        folds = [{wavpaths[index]: label for label, wavpaths in grouped.items()} for index in range(numfolds)]
        totcorrect, tottotal = (0, 0)
        # Then we go through, each time training on all-but-one and testing on the one left out
        for index in range(numfolds):
            print("Fold %i of %i" % (index + 1, numfolds))
            chosenfold = folds[index]
            alltherest = {}
            for whichfold, otherfold in enumerate(folds):
                if whichfold != index:
                    alltherest.update(otherfold)
            ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], alltherest, args['trainpath'], chosenfold)
            totcorrect += ncorrect
            tottotal += ntotal
        print("Got %i correct out of %i (using stratified leave-one-out crossvalidation, %i folds)" % (totcorrect, tottotal, numfolds))
|
danstowell@0
|
212
|