# HG changeset patch # User Dan Stowell # Date 1354700606 0 # Node ID f872eceb2c9c986b9f44d9c885dd4ebe38257bdb # Parent a7b513f0057c91aba497d73d59c704d5a9158371 implement leave-one-out crossvalidation diff -r a7b513f0057c -r f872eceb2c9c smacpy.py --- a/smacpy.py Tue Dec 04 22:25:00 2012 +0000 +++ b/smacpy.py Wed Dec 05 09:43:26 2012 +0000 @@ -30,20 +30,16 @@ # main class class Smacpy: - """Smacpy - simple-minded audio classifier in python. -This is a classifier that you can train on a set of labelled audio files, and then it predicts a label for further audio files. -It is designed with two main aims: - (1) to provide a baseline against which to test more advanced audio classifiers; - (2) to provide a simple code example of a classifier which people are free to build on. + """Smacpy - simple-minded audio classifier in python. See the README file for more details. -It uses the very common workflow of taking audio, converting to MFCCs, and modelling the MFCC "bag of frames" with a GMM. + USAGE EXAMPLE: + In this hypothetical example we train on four audio files, labelled as either 'usa' or 'uk', and then test on a separate audio file of someone called hubert: -USAGE EXAMPLE: -In this hypothetical example we train on four audio files, labelled as either 'usa' or 'uk', and then test on a separate audio file of someone called hubert: + from smacpy import Smacpy + model = Smacpy("wavs/training", {'karen01.wav':'usa', 'john01.wav':'uk', 'steve02.wav':'usa', 'joe03.wav':'uk'}) + model.classify('wavs/testing/hubert01.wav') -from smacpy import Smacpy -model = Smacpy("wavs/training", {'karen01.wav':'usa', 'john01.wav':'uk', 'steve02.wav':'usa', 'joe03.wav':'uk'}) -model.classify('wavs/testing/hubert01.wav') + Note for developers: this code should aim to be understandable, and not too long. Don't add too much functionality, or efficiency ;) """ def __init__(self, wavfolder, trainingdata): @@ -79,12 +75,10 @@ # For each label's aggregated features, train a GMM and remember it self.gmms = {} for label, aggf in aggfeatures.items(): - if verbose: - print(" Training a GMM for label %s, using data of shape %s" % (label, str(np.shape(aggf)))) + if verbose: print(" Training a GMM for label %s, using data of shape %s" % (label, str(np.shape(aggf)))) self.gmms[label] = GMM(n_components=10, cvtype='full') self.gmms[label].fit(aggf) - if verbose: - print(" Trained %i classes from %i input files" % (len(self.gmms), len(trainingdata))) + if verbose: print(" Trained %i classes from %i input files" % (len(self.gmms), len(trainingdata))) def __normalise(self, data): "Normalises data using the mean and stdev of the training data - so that everything is on a common scale." @@ -109,7 +103,7 @@ if verbose: print("Reading %s" % wavpath) if not os.path.isfile(wavpath): raise ValueError("path %s not found" % path) sf = Sndfile(wavpath, "r") - if (sf.channels != 1) and verbose: print(" Sound file has multiple channels (%i) - channels will be mixed to mono." % sf.channels) + if (sf.channels != 1) and verbose: print(" Sound file has multiple channels (%i) - channels will be mixed to mono." % sf.channels) if sf.samplerate != fs: raise ValueError("wanted sample rate %g - got %g." % (fs, sf.samplerate)) window = np.hamming(framelen) features = [] @@ -136,20 +130,18 @@ except RuntimeError: break sf.close() - ret = np.array(features) - return ret + return np.array(features) ####################################################################### def trainAndTest(trainpath, trainwavs, testpath, testwavs): - "Trains a model, tests it on wavs of known class. Returns (numcorrect, numtotal, numclasses)." + "Handy function for evaluating your code: trains a model, tests it on wavs of known class. Returns (numcorrect, numtotal, numclasses)." print("TRAINING") model = Smacpy(trainpath, trainwavs) print("TESTING") ncorrect = 0 for wavpath,label in testwavs.items(): result = model.classify(os.path.join(testpath, wavpath)) - if verbose: - print(" inferred: %s" % result) + if verbose: print(" inferred: %s" % result) if result == label: ncorrect += 1 return (ncorrect, len(testwavs), len(model.gmms)) @@ -171,7 +163,6 @@ verbose = not args['quiet'] if args['testpath']==None: - print("TODO: loocv") # TODO!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! args['testpath'] = args['trainpath'] # Build up lists of the training and testing WAV files: @@ -192,6 +183,18 @@ for wavpath,label in sorted(wavsfound[onepath].items()): print(" %s: \t %s" % (label, wavpath)) - ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath']) - print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses)) + if args['testpath'] != args['trainpath']: + # Separate train-and-test collections + ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath']) + print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses)) + else: + # "Leave-one-out-crossvalidation": leave one at a time out of the training set, and see how it is classified + totcorrect, tottotal = (0,0) + for whichone in wavsfound['trainpath']: + chosenone = {k:v for (k, v) in wavsfound['trainpath'].items() if k == whichone} + alltherest = {k:v for (k, v) in wavsfound['trainpath'].items() if k != whichone} + ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], alltherest, args['trainpath'], chosenone) + totcorrect += ncorrect + tottotal += ntotal + print("Got %i correct out of %i (leave-one-out crossvalidation)" % (totcorrect, tottotal))