changeset 17:a7c4cc56ac6f

implement stratified LOOCV, better and faster than LOOCV
author Dan Stowell <danstowell@users.sourceforge.net>
date Fri, 07 Dec 2012 11:39:13 +0000
parents f872eceb2c9c
children 6a34565c8a74
files smacpy.py
diffstat 1 files changed, 20 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/smacpy.py	Wed Dec 05 09:43:26 2012 +0000
+++ b/smacpy.py	Fri Dec 07 11:39:13 2012 +0000
@@ -103,7 +103,7 @@
 		if verbose: print("Reading %s" % wavpath)
 		if not os.path.isfile(wavpath): raise ValueError("path %s not found" % path)
 		sf = Sndfile(wavpath, "r")
-		if (sf.channels != 1) and verbose: print(" Sound file has multiple channels (%i) - channels will be mixed to mono." % sf.channels)
+		#if (sf.channels != 1) and verbose: print(" Sound file has multiple channels (%i) - channels will be mixed to mono." % sf.channels)
 		if sf.samplerate != fs:         raise ValueError("wanted sample rate %g - got %g." % (fs, sf.samplerate))
 		window = np.hamming(framelen)
 		features = []
@@ -170,7 +170,7 @@
 	for onepath in ['trainpath', 'testpath']:
 		pattern = os.path.join(args[onepath], '*.wav')
 		for wavpath in glob(pattern):
-			if args['numchars'] > 0:
+			if args['numchars'] != 0:
 				label = os.path.basename(wavpath)[:args['numchars']]
 			else:
 				label = os.path.basename(wavpath).split(args['charsplit'])[0]
@@ -188,13 +188,25 @@
 		ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath'])
 		print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses))
 	else:
-		# "Leave-one-out-crossvalidation": leave one at a time out of the training set, and see how it is classified
+		# This runs "stratified leave-one-out crossvalidation": test multiple times by leaving one-of-each-class out and training on the rest.
+		# First we need to build a list of files grouped by each classlabel
+		labelsinuse = list(set(wavsfound['trainpath'].values()))
+		grouped = {label:[] for label in labelsinuse}
+		for wavpath,label in wavsfound['trainpath'].items():
+			grouped[label].append(wavpath)
+		numfolds = min(len(collection) for collection in grouped.values())
+		# Each "fold" will be a collection of one item of each label
+		folds = [{wavpaths[index]:label for label,wavpaths in grouped.items()} for index in range(numfolds)]
 		totcorrect, tottotal = (0,0)
-		for whichone in wavsfound['trainpath']:
-			chosenone  = {k:v for (k, v) in wavsfound['trainpath'].items() if k == whichone}
-			alltherest = {k:v for (k, v) in wavsfound['trainpath'].items() if k != whichone}
-			ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], alltherest, args['trainpath'], chosenone)
+		# Then we go through, each time training on all-but-one and testing on the one left out
+		for index in range(numfolds):
+			chosenfold = folds[index]
+			alltherest = {}
+			for whichfold, otherfold in enumerate(folds):
+				if whichfold != index:
+					alltherest.update(otherfold)
+			ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], alltherest, args['trainpath'], chosenfold)
 			totcorrect += ncorrect
 			tottotal   += ntotal
-		print("Got %i correct out of %i (leave-one-out crossvalidation)" % (totcorrect, tottotal))
+		print("Got %i correct out of %i (using stratified leave-one-out crossvalidation, %i folds)" % (totcorrect, tottotal, numfolds))