changeset 16:f872eceb2c9c

implement leave-one-out crossvalidation
author Dan Stowell <danstowell@users.sourceforge.net>
date Wed, 05 Dec 2012 09:43:26 +0000
parents a7b513f0057c
children a7c4cc56ac6f
files smacpy.py
diffstat 1 files changed, 27 insertions(+), 24 deletions(-) [+]
line wrap: on
line diff
--- a/smacpy.py	Tue Dec 04 22:25:00 2012 +0000
+++ b/smacpy.py	Wed Dec 05 09:43:26 2012 +0000
@@ -30,20 +30,16 @@
 # main class
 
 class Smacpy:
-	"""Smacpy - simple-minded audio classifier in python.
-This is a classifier that you can train on a set of labelled audio files, and then it predicts a label for further audio files.
-It is designed with two main aims:
- (1) to provide a baseline against which to test more advanced audio classifiers;
- (2) to provide a simple code example of a classifier which people are free to build on.
+	"""Smacpy - simple-minded audio classifier in python. See the README file for more details.
 
-It uses the very common workflow of taking audio, converting to MFCCs, and modelling the MFCC "bag of frames" with a GMM.
+	USAGE EXAMPLE:
+	In this hypothetical example we train on four audio files, labelled as either 'usa' or 'uk', and then test on a separate audio file of someone called hubert:
 
-USAGE EXAMPLE:
-In this hypothetical example we train on four audio files, labelled as either 'usa' or 'uk', and then test on a separate audio file of someone called hubert:
+	from smacpy import Smacpy
+	model = Smacpy("wavs/training", {'karen01.wav':'usa', 'john01.wav':'uk', 'steve02.wav':'usa', 'joe03.wav':'uk'})
+	model.classify('wavs/testing/hubert01.wav')
 
-from smacpy import Smacpy
-model = Smacpy("wavs/training", {'karen01.wav':'usa', 'john01.wav':'uk', 'steve02.wav':'usa', 'joe03.wav':'uk'})
-model.classify('wavs/testing/hubert01.wav')
+	Note for developers: this code should aim to be understandable, and not too long. Don't add too much functionality, or efficiency ;)
 	"""
 
 	def __init__(self, wavfolder, trainingdata):
@@ -79,12 +75,10 @@
 		# For each label's aggregated features, train a GMM and remember it
 		self.gmms = {}
 		for label, aggf in aggfeatures.items():
-			if verbose:
-				print("    Training a GMM for label %s, using data of shape %s" % (label, str(np.shape(aggf))))
+			if verbose: print("    Training a GMM for label %s, using data of shape %s" % (label, str(np.shape(aggf))))
 			self.gmms[label] = GMM(n_components=10, cvtype='full')
 			self.gmms[label].fit(aggf)
-		if verbose:
-			print("  Trained %i classes from %i input files" % (len(self.gmms), len(trainingdata)))
+		if verbose: print("  Trained %i classes from %i input files" % (len(self.gmms), len(trainingdata)))
 
 	def __normalise(self, data):
 		"Normalises data using the mean and stdev of the training data - so that everything is on a common scale."
@@ -109,7 +103,7 @@
 		if verbose: print("Reading %s" % wavpath)
 		if not os.path.isfile(wavpath): raise ValueError("path %s not found" % path)
 		sf = Sndfile(wavpath, "r")
-		if (sf.channels != 1) and verbose:            print(" Sound file has multiple channels (%i) - channels will be mixed to mono." % sf.channels)
+		if (sf.channels != 1) and verbose: print(" Sound file has multiple channels (%i) - channels will be mixed to mono." % sf.channels)
 		if sf.samplerate != fs:         raise ValueError("wanted sample rate %g - got %g." % (fs, sf.samplerate))
 		window = np.hamming(framelen)
 		features = []
@@ -136,20 +130,18 @@
 			except RuntimeError:
 				break
 		sf.close()
-		ret = np.array(features)
-		return ret
+		return np.array(features)
 
 #######################################################################
 def trainAndTest(trainpath, trainwavs, testpath, testwavs):
-	"Trains a model, tests it on wavs of known class. Returns (numcorrect, numtotal, numclasses)."
+	"Handy function for evaluating your code: trains a model, tests it on wavs of known class. Returns (numcorrect, numtotal, numclasses)."
 	print("TRAINING")
 	model = Smacpy(trainpath, trainwavs)
 	print("TESTING")
 	ncorrect = 0
 	for wavpath,label in testwavs.items():
 		result = model.classify(os.path.join(testpath, wavpath))
-		if verbose:
-			print(" inferred: %s" % result)
+		if verbose: print(" inferred: %s" % result)
 		if result == label:
 			ncorrect += 1
 	return (ncorrect, len(testwavs), len(model.gmms))
@@ -171,7 +163,6 @@
 	verbose = not args['quiet']
 
 	if args['testpath']==None:
-		print("TODO: loocv") # TODO!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 		args['testpath'] = args['trainpath']
 
 	# Build up lists of the training and testing WAV files:
@@ -192,6 +183,18 @@
 			for wavpath,label in sorted(wavsfound[onepath].items()):
 				print(" %s: \t %s" % (label, wavpath))
 
-	ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath'])
-	print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses))
+	if args['testpath'] != args['trainpath']:
+		# Separate train-and-test collections
+		ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath'])
+		print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses))
+	else:
+		# "Leave-one-out-crossvalidation": leave one at a time out of the training set, and see how it is classified
+		totcorrect, tottotal = (0,0)
+		for whichone in wavsfound['trainpath']:
+			chosenone  = {k:v for (k, v) in wavsfound['trainpath'].items() if k == whichone}
+			alltherest = {k:v for (k, v) in wavsfound['trainpath'].items() if k != whichone}
+			ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], alltherest, args['trainpath'], chosenone)
+			totcorrect += ncorrect
+			tottotal   += ntotal
+		print("Got %i correct out of %i (leave-one-out crossvalidation)" % (totcorrect, tottotal))