annotate smacpy.py @ 3:7a20cff05bd6

couple of bugfixes, almost runs
author Dan Stowell <danstowell@users.sourceforge.net>
date Wed, 14 Nov 2012 13:23:02 +0000
parents 33a9f41169fc
children b1b9676f8791
rev   line source
danstowell@0 1 #!/bin/env python
danstowell@0 2 #
danstowell@0 3 # smacpy - simple-minded audio classifier in python
danstowell@0 4 #
danstowell@0 5 # Copyright (c) 2012 Dan Stowell and Queen Mary University of London
danstowell@0 6 #
danstowell@0 7 # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
danstowell@0 8 #
danstowell@0 9 # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
danstowell@0 10 #
danstowell@0 11 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
danstowell@0 12
danstowell@0 13 import os.path
danstowell@0 14 import sys
danstowell@0 15 import numpy as np
danstowell@0 16 from glob import glob
danstowell@0 17 from scikits.audiolab import Sndfile
danstowell@0 18 from scikits.audiolab import Format
danstowell@0 19 from sklearn.mixture import GMM
danstowell@0 20
danstowell@0 21 from MFCC import melScaling
danstowell@0 22
#######################################################################
# Global settings shared by the classifier and the feature extractor

# Analysis frame length in samples; each frame of this many samples is windowed and FFT'd.
framelen = 1024
# Required sample rate of input audio (Hz); file_to_features rejects files at any other rate.
fs = float(44100)
# When True, progress and diagnostic messages are printed to stdout.
verbose = True
danstowell@0 29
#######################################################################
# main class

class Smacpy:
    """Smacpy - simple-minded audio classifier in python.
    This is a classifier that you can train on a set of labelled audio files, and then it predicts a label for further audio files.
    It is designed with two main aims:
    (1) to provide a baseline against which to test more advanced audio classifiers;
    (2) to provide a simple code example of a classifier which people are free to build on.

    It uses the very common workflow of taking audio, converting to MFCCs, and modelling the MFCC "bag of frames" with a GMM.

    USAGE EXAMPLE:
    In this hypothetical example we train on four audio files, labelled as either 'usa' or 'uk', and then test on a separate audio file of someone called hubert:

    from smacpy import Smacpy
    model = Smacpy("wavs/training", {'karen01.wav':'usa', 'john01.wav':'uk', 'steve02.wav':'usa', 'joe03.wav':'uk'})
    model.classify('wavs/testing/hubert01.wav')
    """

    def __init__(self, wavfolder, trainingdata):
        """Initialise the classifier and train it on some WAV files.
        'wavfolder' is the base folder, to be prepended to all WAV paths.
        'trainingdata' is a dictionary of wavpath:label pairs.
        Raises ValueError if 'trainingdata' is empty (there would be nothing to model)."""
        if not trainingdata:
            raise ValueError("no training data supplied")

        # One 2D feature array (one row per frame) per training file
        allfeatures = {wavpath: file_to_features(os.path.join(wavfolder, wavpath)) for wavpath in trainingdata}

        # Now determine the normalisation stats over every frame of every file, and remember them
        self._fit_normalisation(np.vstack(list(allfeatures.values())))

        # For each label, compile a normalised array of all that label's frames (stacked row-wise)
        perlabel = {}
        for wavpath, features in allfeatures.items():
            label = trainingdata[wavpath]
            perlabel.setdefault(label, []).append(self.__normalise(features))
        aggfeatures = {label: np.vstack(chunks) for label, chunks in perlabel.items()}

        # For each label, train a GMM and remember it
        self.gmms = {}
        for label, aggf in aggfeatures.items():
            if verbose:
                print(" Training a GMM for label %s, using data of shape %s" % (label, str(np.shape(aggf))))
            # NOTE(review): newer scikit-learn releases name this parameter 'covariance_type' - confirm against the installed version.
            self.gmms[label] = GMM(n_components=10, cvtype='full')
            self.gmms[label].fit(aggf)
        if verbose:
            print(" Trained %i classes from %i input files" % (len(self.gmms), len(trainingdata)))

    def _fit_normalisation(self, allframes):
        "Remember the per-column mean and inverse stdev of 'allframes' (a 2D array, one row per frame)."
        self.means = np.mean(allframes, 0)
        stds = np.std(allframes, 0)
        # A constant column has zero stdev; scale it by 1.0 rather than dividing by zero.
        self.invstds = 1.0 / np.where(stds == 0.0, 1.0, stds)

    def __normalise(self, data):
        "Normalises data using the mean and stdev of the training data - so that everything is on a common scale."
        return (data - self.means) * self.invstds

    def classify(self, wavpath):
        """Specify the path to an audio file, and this returns the max-likelihood class, as a string label.
        Returns '' if no model achieves a log-likelihood above -inf."""
        features = self.__normalise(file_to_features(wavpath))
        # For each label GMM, find the overall log-likelihood (sum over frames) and choose the strongest
        bestlabel = ''
        bestll = -np.inf
        for label, gmm in self.gmms.items():
            # score() gives one log-probability per frame; eval() returns a (logprob, responsibilities) tuple
            ll = np.sum(gmm.score(features))
            if ll > bestll:
                bestll = ll
                bestlabel = label
        return bestlabel
danstowell@0 101
#######################################################################
# auxiliary functions

def file_to_features(wavpath):
    """Reads through a mono WAV file, converting each frame to the required features. Returns a 2D array.

    Each complete frame of 'framelen' samples is Hamming-windowed, FFT'd, and converted to
    MFCCs (zeroth coefficient dropped, keeping at most the next 13); one row per frame.
    A trailing partial frame is discarded.
    Raises ValueError if the file does not exist, is not mono, or its sample rate is not 'fs'."""
    if verbose:
        print("Reading %s" % wavpath)
    if not os.path.isfile(wavpath):
        raise ValueError("path %s not found" % wavpath)
    sf = Sndfile(wavpath, "r")
    try:  # ensure the sound file is closed even if a check below raises
        if sf.channels != 1:
            raise ValueError("sound file has multiple channels (%i) - mono audio required." % sf.channels)
        if sf.samplerate != fs:
            raise ValueError("wanted sample rate %g - got %g." % (fs, sf.samplerate))
        halflen = framelen // 2  # integer division so it is valid as a size/slice bound on Python 2 and 3
        window = np.hamming(framelen)
        features = []
        mfccMaker = melScaling(int(fs), halflen, 40)
        mfccMaker.update()
        while True:
            try:
                chunk = sf.read_frames(framelen, dtype=np.float32)
            except RuntimeError:
                # presumably raised by audiolab when reading past the end of the file - TODO confirm
                break
            if len(chunk) != framelen:
                if verbose:
                    print("Not read sufficient samples - returning")
                break
            framespectrum = np.fft.fft(window * chunk)
            magspec = abs(framespectrum[:halflen])

            # do the frequency warping and MFCC computation
            melSpectrum = mfccMaker.warpSpectrum(magspec)
            melCepstrum = mfccMaker.getMFCCs(melSpectrum, cn=True)
            melCepstrum = melCepstrum[1:]   # exclude zeroth coefficient
            melCepstrum = melCepstrum[:13]  # limit to lower MFCCs

            framefeatures = melCepstrum  # todo: include deltas? that can be your homework.

            features.append(framefeatures)
    finally:
        sf.close()
    ret = np.array(features)
    if verbose:
        print("file_to_features() produced array shape " + str(np.shape(ret)))
    return ret
danstowell@0 141
#######################################################################
if __name__ == '__main__':
    foldername = 'wavs'
    if len(sys.argv) > 1:
        foldername = sys.argv[1]

    # Each file's label is the part of its filename before the first underscore (e.g. "dog_01.wav" -> "dog")
    trainingdata = {}
    pattern = os.path.join(foldername, '*.wav')
    for wavpath in glob(pattern):
        label = os.path.basename(wavpath).split('_')[0]
        shortwavpath = os.path.relpath(wavpath, foldername)
        trainingdata[shortwavpath] = label
    if not trainingdata:
        raise RuntimeError("Found no files using this pattern: %s" % pattern)
    if verbose:
        print("Class-labels and filenames to be used in training:")
        for wavpath, label in sorted(trainingdata.items()):
            print(" %s: \t %s" % (label, wavpath))

    model = Smacpy(foldername, trainingdata)

    #################################
    # Demo: classify the training files themselves (sorted, for deterministic output) against their true labels
    print("Inferred classifications:")
    for wavpath, label in sorted(trainingdata.items()):
        print(" %s" % wavpath)
        print(" true: %s" % label)
        result = model.classify(os.path.join(foldername, wavpath))
        print(" inferred: %s" % result)
danstowell@0 170