def trainAndTest(trainpath, trainwavs, testpath, testwavs):
    """Train a Smacpy model, then score it against WAVs whose class is known.

    Args:
        trainpath: Passed as the first argument to ``Smacpy(...)``
            (presumably the folder holding the training WAVs -- confirm
            against the ``Smacpy`` constructor).
        trainwavs: Passed as the second argument to ``Smacpy(...)``.
        testpath: Folder that each key of ``testwavs`` is joined onto
            before the file is classified.
        testwavs: Mapping of WAV filename -> expected class label.

    Returns:
        A tuple ``(numcorrect, numtotal, numclasses)``: how many test
        files were labelled correctly, how many test files there were,
        and ``len(model.gmms)`` for the trained model.

    Progress messages are always printed; the per-file inferred label is
    printed only when the module-level ``verbose`` flag is true.
    """
    print("TRAINING")
    classifier = Smacpy(trainpath, trainwavs)

    print("TESTING")
    hits = 0
    for filename, expected in testwavs.items():
        fullpath = os.path.join(testpath, filename)
        guess = classifier.classify(fullpath)
        if verbose:
            print(" inferred: %s" % guess)
        # Tally one hit whenever the inferred label equals the known one.
        hits += 1 if guess == expected else 0

    total = len(testwavs)
    nclasses = len(classifier.gmms)
    return (hits, total, nclasses)
'--charsplit', default='_', help="Character used to split filenames: anything BEFORE this character is the class") @@ -156,6 +170,10 @@ args = vars(parser.parse_args()) verbose = not args['quiet'] + if args['testpath']==None: + print("TODO: loocv") # TODO!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + args['testpath'] = args['trainpath'] + # Build up lists of the training and testing WAV files: wavsfound = {'trainpath':{}, 'testpath':{}} for onepath in ['trainpath', 'testpath']: @@ -174,20 +192,6 @@ for wavpath,label in sorted(wavsfound[onepath].items()): print(" %s: \t %s" % (label, wavpath)) - print("##################################################") - print("TRAINING") - model = Smacpy(args['trainpath'], wavsfound['trainpath']) + ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath']) + print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses)) - print("##################################################") - print("TESTING") - if args['trainpath'] == args['testpath']: - print(" (nb testing on the same files as used for training - for true evaluation please train and test on independent data):") - ncorrect = 0 - for wavpath,label in wavsfound['testpath'].items(): - result = model.classify(os.path.join(args['testpath'], wavpath)) - if verbose: - print(" inferred: %s" % result) - if result == label: - ncorrect += 1 - print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, len(wavsfound['testpath']), len(model.gmms))) -