Mercurial > hg > smacpy
changeset 26:54612c177bcc
Change smacpy commandline to do input/output as specified for AASP challenge
author | Dan Stowell <danstowell@users.sourceforge.net> |
---|---|
date | Wed, 23 Jan 2013 10:48:12 +0000 |
parents | 1a2cfd98e737 |
children | 55b7b5a5cf43 |
files | smacpy.py |
diffstat | 1 files changed, 36 insertions(+), 54 deletions(-) [+] |
line wrap: on
line diff
```diff
--- a/smacpy.py Thu Jan 10 14:31:40 2013 +0000
+++ b/smacpy.py Wed Jan 23 10:48:12 2013 +0000
@@ -17,6 +17,7 @@
 from scikits.audiolab import Sndfile
 from scikits.audiolab import Format
 from sklearn.mixture import GMM
+import csv
 
 from MFCC import melScaling
 
@@ -147,67 +148,48 @@
     return (ncorrect, len(testwavs), len(model.gmms))
 
 #######################################################################
-# If this file is invoked as a script, it carries out a simple runthrough
-# of training on some wavs, then testing, with classnames being the start of the filenames
+# This handles the invocation as set up for the AASP challenge
+# eg: python smacpy.py --trainlist trainlist.example.txt --testlist testlist.example.txt --outlist output.txt
 if __name__ == '__main__':
 
-    # Handle the command-line arguments for where the train/test data comes from:
     parser = argparse.ArgumentParser()
-    parser.add_argument('-t', '--trainpath', default='wavs', help="Path to the WAV files used for training")
-    parser.add_argument('-T', '--testpath', help="Path to the WAV files used for testing")
+    parser.add_argument('-t', '--trainlist', default='NOT_SPECIFIED', help="Path to the file listing WAV files used for training")
+    parser.add_argument('-T', '--testlist', default='NOT_SPECIFIED', help="Path to the file listing WAV files used for testing")
+    parser.add_argument('-o', '--outlist', default='output.txt', help="Path to write results to")
     parser.add_argument('-q', dest='quiet', action='store_true', help="Be less verbose, don't output much text during processing")
-    group = parser.add_mutually_exclusive_group()
-    group.add_argument('-c', '--charsplit', default='_', help="Character used to split filenames: anything BEFORE this character is the class")
-    group.add_argument('-n', '--numchars' , default=0 , help="Instead of splitting using 'charsplit', use this fixed number of characters from the start of the filename", type=int)
     args = vars(parser.parse_args())
     verbose = not args['quiet']
 
-    if args['testpath']==None:
-        args['testpath'] = args['trainpath']
+    # Load the training list as a dictionary of "filename->classification" data
+    trainlist = {}
+    for row in csv.reader(file(args['trainlist']), delimiter="\t"):
+        if len(row)==2:
+            trainlist[row[0]] = row[1]
+        elif len(row)>2:
+            raise ValueError("Row has more than one tab character in: ", row)
 
-    # Build up lists of the training and testing WAV files:
-    wavsfound = {'trainpath':{}, 'testpath':{}}
-    for onepath in ['trainpath', 'testpath']:
-        pattern = os.path.join(args[onepath], '*.wav')
-        for wavpath in glob(pattern):
-            if args['numchars'] != 0:
-                label = os.path.basename(wavpath)[:args['numchars']]
-            else:
-                label = os.path.basename(wavpath).split(args['charsplit'])[0]
-            shortwavpath = os.path.relpath(wavpath, args[onepath])
-            wavsfound[onepath][shortwavpath] = label
-        if len(wavsfound[onepath])==0:
-            raise RuntimeError("Found no files using this pattern: %s" % pattern)
-        if verbose:
-            print("Class-labels and filenames to be used from %s:" % onepath)
-            for wavpath,label in sorted(wavsfound[onepath].items()):
-                print(" %s: \t %s" % (label, wavpath))
+    print("Training files to be used:")
+    for wavpath,label in sorted(trainlist.items()):
+        print(" %s: \t %s" % (label, wavpath))
 
-    if args['testpath'] != args['trainpath']:
-        # Separate train-and-test collections
-        ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath'])
-        print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses))
-    else:
-        # This runs "stratified leave-one-out crossvalidation": test multiple times by leaving one-of-each-class out and training on the rest.
-        # First we need to build a list of files grouped by each classlabel
-        labelsinuse = sorted(list(set(wavsfound['trainpath'].values())))
-        grouped = {label:[] for label in labelsinuse}
-        for wavpath,label in wavsfound['trainpath'].items():
-            grouped[label].append(wavpath)
-        numfolds = min(len(collection) for collection in grouped.values())
-        # Each "fold" will be a collection of one item of each label
-        folds = [{wavpaths[index]:label for label,wavpaths in grouped.items()} for index in range(numfolds)]
-        totcorrect, tottotal = (0,0)
-        # Then we go through, each time training on all-but-one and testing on the one left out
-        for index in range(numfolds):
-            print("Fold %i of %i" % (index+1, numfolds))
-            chosenfold = folds[index]
-            alltherest = {}
-            for whichfold, otherfold in enumerate(folds):
-                if whichfold != index:
-                    alltherest.update(otherfold)
-            ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], alltherest, args['trainpath'], chosenfold)
-            totcorrect += ncorrect
-            tottotal += ntotal
-        print("Got %i correct out of %i (using stratified leave-one-out crossvalidation, %i folds)" % (totcorrect, tottotal, numfolds))
+    # Load the test list
+    testlist = []
+    for row in csv.reader(file(args['testlist']), delimiter="\t"):
+        if len(row)==1:
+            testlist.append(row[0])
+        elif len(row)>1:
+            raise ValueError("Row has a tab character in - SHOULD NOT happen for testlist: ", row)
+    print("Testing files to be used:")
+    for item in testlist: print(item)
+    # OK so let's go
+    outlist = file(args['outlist'], 'wb')
+    print("TRAINING")
+    model = Smacpy('/', trainlist)
+    print("TESTING")
+    for wavpath in testlist:
+        result = model.classify(wavpath)
+        outlist.write("%s\t%s\n" % (wavpath, result))
+    outlist.close()
+    print("Finished.")
+
 
```