changeset 15:a7b513f0057c

separate out a handy trainAndTest() method
author Dan Stowell <danstowell@users.sourceforge.net>
date Tue, 04 Dec 2012 22:25:00 +0000
parents 3c6b9180b740
children f872eceb2c9c
files smacpy.py
diffstat 1 files changed, 22 insertions(+), 18 deletions(-) [+]
line wrap: on
line diff
--- a/smacpy.py	Thu Nov 29 14:10:52 2012 +0000
+++ b/smacpy.py	Tue Dec 04 22:25:00 2012 +0000
@@ -22,7 +22,6 @@
 
 #######################################################################
 # some settings
-
 framelen = 1024
 fs = 44100.0
 verbose = True
@@ -141,6 +140,21 @@
 		return ret
 
 #######################################################################
+def trainAndTest(trainpath, trainwavs, testpath, testwavs):
+	"Trains a model, tests it on wavs of known class. Returns (numcorrect, numtotal, numclasses)."
+	print("TRAINING")
+	model = Smacpy(trainpath, trainwavs)
+	print("TESTING")
+	ncorrect = 0
+	for wavpath,label in testwavs.items():
+		result = model.classify(os.path.join(testpath, wavpath))
+		if verbose:
+			print(" inferred: %s" % result)
+		if result == label:
+			ncorrect += 1
+	return (ncorrect, len(testwavs), len(model.gmms))
+
+#######################################################################
 # If this file is invoked as a script, it carries out a simple runthrough
 # of training on some wavs, then testing, with classnames being the start of the filenames
 if __name__ == '__main__':
@@ -148,7 +162,7 @@
 	# Handle the command-line arguments for where the train/test data comes from:
 	parser = argparse.ArgumentParser()
 	parser.add_argument('-t', '--trainpath', default='wavs', help="Path to the WAV files used for training")
-	parser.add_argument('-T', '--testpath',  default='wavs', help="Path to the WAV files used for testing")
+	parser.add_argument('-T', '--testpath',                  help="Path to the WAV files used for testing")
 	parser.add_argument('-q', dest='quiet', action='store_true', help="Be less verbose, don't output much text during processing")
 	group = parser.add_mutually_exclusive_group()
 	group.add_argument('-c', '--charsplit',  default='_',    help="Character used to split filenames: anything BEFORE this character is the class")
@@ -156,6 +170,10 @@
 	args = vars(parser.parse_args())
 	verbose = not args['quiet']
 
+	if args['testpath']==None:
+		print("TODO: loocv") # TODO!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+		args['testpath'] = args['trainpath']
+
 	# Build up lists of the training and testing WAV files:
 	wavsfound = {'trainpath':{}, 'testpath':{}}
 	for onepath in ['trainpath', 'testpath']:
@@ -174,20 +192,6 @@
 			for wavpath,label in sorted(wavsfound[onepath].items()):
 				print(" %s: \t %s" % (label, wavpath))
 
-	print("##################################################")
-	print("TRAINING")
-	model = Smacpy(args['trainpath'], wavsfound['trainpath'])
+	ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath'])
+	print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses))
 
-	print("##################################################")
-	print("TESTING")
-	if args['trainpath'] == args['testpath']:
-		print(" (nb testing on the same files as used for training - for true evaluation please train and test on independent data):")
-	ncorrect = 0
-	for wavpath,label in wavsfound['testpath'].items():
-		result = model.classify(os.path.join(args['testpath'], wavpath))
-		if verbose:
-			print(" inferred: %s" % result)
-		if result == label:
-			ncorrect += 1
-	print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, len(wavsfound['testpath']), len(model.gmms)))
-