changeset 13:4a29489a577b

Allow a --numchars option: take the class label from a fixed number of leading filename characters instead of splitting on --charsplit
author Dan Stowell <danstowell@users.sourceforge.net>
date Thu, 29 Nov 2012 13:51:36 +0000
parents 383cc866a221
children 3c6b9180b740
files smacpy.py
diffstat 1 files changed, 9 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- a/smacpy.py	Thu Nov 29 13:17:51 2012 +0000
+++ b/smacpy.py	Thu Nov 29 13:51:36 2012 +0000
@@ -147,10 +147,12 @@
 
 	# Handle the command-line arguments for where the train/test data comes from:
 	parser = argparse.ArgumentParser()
-	parser.add_argument('-t', '--trainpath',  default='wavs', help="Path to the WAV files used for training")
-	parser.add_argument('-T', '--testpath',   default='wavs', help="Path to the WAV files used for testing")
-	parser.add_argument('-c', '--charsplit',  default='_',   help="Character used to split filenames: anything BEFORE this character is the class")
+	parser.add_argument('-t', '--trainpath', default='wavs', help="Path to the WAV files used for training")
+	parser.add_argument('-T', '--testpath',  default='wavs', help="Path to the WAV files used for testing")
 	parser.add_argument('-q', dest='quiet', action='store_true', help="Be less verbose, don't output much text during processing")
+	group = parser.add_mutually_exclusive_group()
+	group.add_argument('-c', '--charsplit',  default='_',    help="Character used to split filenames: anything BEFORE this character is the class")
+	group.add_argument('-n', '--numchars' ,  default=0  ,    help="Instead of splitting using 'charsplit', use this fixed number of characters from the start of the filename", type=int)
 	args = vars(parser.parse_args())
 	verbose = not args['quiet']
 
@@ -159,9 +161,10 @@
 	for onepath in ['trainpath', 'testpath']:
 		pattern = os.path.join(args[onepath], '*.wav')
 		for wavpath in glob(pattern):
-			label = os.path.basename(wavpath).split(args['charsplit'])[0]
-			# a little hack to use first 4 chars as label:
-			# label = os.path.basename(wavpath)[:4]
+			if args['numchars'] > 0:
+				label = os.path.basename(wavpath)[:args['numchars']]
+			else:
+				label = os.path.basename(wavpath).split(args['charsplit'])[0]
 			shortwavpath = os.path.relpath(wavpath, args[onepath])
 			wavsfound[onepath][shortwavpath] = label
 		if len(wavsfound[onepath])==0: