changeset 26:54612c177bcc

Change smacpy commandline to do input/output as specified for AASP challenge
author Dan Stowell <danstowell@users.sourceforge.net>
date Wed, 23 Jan 2013 10:48:12 +0000
parents 1a2cfd98e737
children 55b7b5a5cf43
files smacpy.py
diffstat 1 files changed, 36 insertions(+), 54 deletions(-) [+]
line wrap: on
line diff
--- a/smacpy.py	Thu Jan 10 14:31:40 2013 +0000
+++ b/smacpy.py	Wed Jan 23 10:48:12 2013 +0000
@@ -17,6 +17,7 @@
 from scikits.audiolab import Sndfile
 from scikits.audiolab import Format
 from sklearn.mixture import GMM
+import csv
 
 from MFCC import melScaling
 
@@ -147,67 +148,48 @@
 	return (ncorrect, len(testwavs), len(model.gmms))
 
 #######################################################################
-# If this file is invoked as a script, it carries out a simple runthrough
-# of training on some wavs, then testing, with classnames being the start of the filenames
+# This handles the invocation as set up for the AASP challenge
+# eg:  python smacpy.py --trainlist trainlist.example.txt --testlist testlist.example.txt --outlist output.txt
 if __name__ == '__main__':
-
 	# Handle the command-line arguments for where the train/test data comes from:
 	parser = argparse.ArgumentParser()
-	parser.add_argument('-t', '--trainpath', default='wavs', help="Path to the WAV files used for training")
-	parser.add_argument('-T', '--testpath',                  help="Path to the WAV files used for testing")
+	parser.add_argument('-t', '--trainlist', default='NOT_SPECIFIED', help="Path to the file listing WAV files used for training")
+	parser.add_argument('-T', '--testlist',  default='NOT_SPECIFIED', help="Path to the file listing WAV files used for testing")
+	parser.add_argument('-o', '--outlist',   default='output.txt',    help="Path to write results to")
 	parser.add_argument('-q', dest='quiet', action='store_true', help="Be less verbose, don't output much text during processing")
-	group = parser.add_mutually_exclusive_group()
-	group.add_argument('-c', '--charsplit',  default='_',    help="Character used to split filenames: anything BEFORE this character is the class")
-	group.add_argument('-n', '--numchars' ,  default=0  ,    help="Instead of splitting using 'charsplit', use this fixed number of characters from the start of the filename", type=int)
 	args = vars(parser.parse_args())
 	verbose = not args['quiet']
 
-	if args['testpath']==None:
-		args['testpath'] = args['trainpath']
+	# Load the training list as a dictionary of "filename->classification" data
+	trainlist = {}
+	for row in csv.reader(file(args['trainlist']), delimiter="\t"):
+		if len(row)==2:
+			trainlist[row[0]] = row[1]
+		elif len(row)>2:
+			raise ValueError("Row has more than one tab character in: ", row)
 
-	# Build up lists of the training and testing WAV files:
-	wavsfound = {'trainpath':{}, 'testpath':{}}
-	for onepath in ['trainpath', 'testpath']:
-		pattern = os.path.join(args[onepath], '*.wav')
-		for wavpath in glob(pattern):
-			if args['numchars'] != 0:
-				label = os.path.basename(wavpath)[:args['numchars']]
-			else:
-				label = os.path.basename(wavpath).split(args['charsplit'])[0]
-			shortwavpath = os.path.relpath(wavpath, args[onepath])
-			wavsfound[onepath][shortwavpath] = label
-		if len(wavsfound[onepath])==0:
-			raise RuntimeError("Found no files using this pattern: %s" % pattern)
-		if verbose:
-			print("Class-labels and filenames to be used from %s:" % onepath)
-			for wavpath,label in sorted(wavsfound[onepath].items()):
-				print(" %s: \t %s" % (label, wavpath))
+	print("Training files to be used:")
+	for wavpath,label in sorted(trainlist.items()):
+		print(" %s: \t %s" % (label, wavpath))
 
-	if args['testpath'] != args['trainpath']:
-		# Separate train-and-test collections
-		ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath'])
-		print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses))
-	else:
-		# This runs "stratified leave-one-out crossvalidation": test multiple times by leaving one-of-each-class out and training on the rest.
-		# First we need to build a list of files grouped by each classlabel
-		labelsinuse = sorted(list(set(wavsfound['trainpath'].values())))
-		grouped = {label:[] for label in labelsinuse}
-		for wavpath,label in wavsfound['trainpath'].items():
-			grouped[label].append(wavpath)
-		numfolds = min(len(collection) for collection in grouped.values())
-		# Each "fold" will be a collection of one item of each label
-		folds = [{wavpaths[index]:label for label,wavpaths in grouped.items()} for index in range(numfolds)]
-		totcorrect, tottotal = (0,0)
-		# Then we go through, each time training on all-but-one and testing on the one left out
-		for index in range(numfolds):
-			print("Fold %i of %i" % (index+1, numfolds))
-			chosenfold = folds[index]
-			alltherest = {}
-			for whichfold, otherfold in enumerate(folds):
-				if whichfold != index:
-					alltherest.update(otherfold)
-			ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], alltherest, args['trainpath'], chosenfold)
-			totcorrect += ncorrect
-			tottotal   += ntotal
-		print("Got %i correct out of %i (using stratified leave-one-out crossvalidation, %i folds)" % (totcorrect, tottotal, numfolds))
+	# Load the test list
+	testlist = []
+	for row in csv.reader(file(args['testlist']), delimiter="\t"):
+		if len(row)==1:
+			testlist.append(row[0])
+		elif len(row)>1:
+			raise ValueError("Row has a tab character in - SHOULD NOT happen for testlist: ", row)
+	print("Testing files to be used:")
+	for item in testlist: print(item)
 
+	# OK so let's go
+	outlist = file(args['outlist'], 'wb')
+	print("TRAINING")
+	model = Smacpy('/', trainlist)
+	print("TESTING")
+	for wavpath in testlist:
+		result = model.classify(wavpath)
+		outlist.write("%s\t%s\n" % (wavpath, result))
+	outlist.close()
+	print("Finished.")
+