Mercurial > hg > smacpy
changeset 15:a7b513f0057c
separate out a handy trainAndTest() method
author | Dan Stowell <danstowell@users.sourceforge.net> |
---|---|
date | Tue, 04 Dec 2012 22:25:00 +0000 |
parents | 3c6b9180b740 |
children | f872eceb2c9c |
files | smacpy.py |
diffstat | 1 files changed, 22 insertions(+), 18 deletions(-) [+] |
line wrap: on
line diff
--- a/smacpy.py Thu Nov 29 14:10:52 2012 +0000 +++ b/smacpy.py Tue Dec 04 22:25:00 2012 +0000 @@ -22,7 +22,6 @@ ####################################################################### # some settings - framelen = 1024 fs = 44100.0 verbose = True @@ -141,6 +140,21 @@ return ret ####################################################################### +def trainAndTest(trainpath, trainwavs, testpath, testwavs): + "Trains a model, tests it on wavs of known class. Returns (numcorrect, numtotal, numclasses)." + print("TRAINING") + model = Smacpy(trainpath, trainwavs) + print("TESTING") + ncorrect = 0 + for wavpath,label in testwavs.items(): + result = model.classify(os.path.join(testpath, wavpath)) + if verbose: + print(" inferred: %s" % result) + if result == label: + ncorrect += 1 + return (ncorrect, len(testwavs), len(model.gmms)) + +####################################################################### # If this file is invoked as a script, it carries out a simple runthrough # of training on some wavs, then testing, with classnames being the start of the filenames if __name__ == '__main__': @@ -148,7 +162,7 @@ # Handle the command-line arguments for where the train/test data comes from: parser = argparse.ArgumentParser() parser.add_argument('-t', '--trainpath', default='wavs', help="Path to the WAV files used for training") - parser.add_argument('-T', '--testpath', default='wavs', help="Path to the WAV files used for testing") + parser.add_argument('-T', '--testpath', help="Path to the WAV files used for testing") parser.add_argument('-q', dest='quiet', action='store_true', help="Be less verbose, don't output much text during processing") group = parser.add_mutually_exclusive_group() group.add_argument('-c', '--charsplit', default='_', help="Character used to split filenames: anything BEFORE this character is the class") @@ -156,6 +170,10 @@ args = vars(parser.parse_args()) verbose = not args['quiet'] + if args['testpath']==None: + print("TODO: loocv") # TODO!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + args['testpath'] = args['trainpath'] + # Build up lists of the training and testing WAV files: wavsfound = {'trainpath':{}, 'testpath':{}} for onepath in ['trainpath', 'testpath']: @@ -174,20 +192,6 @@ for wavpath,label in sorted(wavsfound[onepath].items()): print(" %s: \t %s" % (label, wavpath)) - print("##################################################") - print("TRAINING") - model = Smacpy(args['trainpath'], wavsfound['trainpath']) + ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath']) + print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses)) - print("##################################################") - print("TESTING") - if args['trainpath'] == args['testpath']: - print(" (nb testing on the same files as used for training - for true evaluation please train and test on independent data):") - ncorrect = 0 - for wavpath,label in wavsfound['testpath'].items(): - result = model.classify(os.path.join(args['testpath'], wavpath)) - if verbose: - print(" inferred: %s" % result) - if result == label: - ncorrect += 1 - print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, len(wavsfound['testpath']), len(model.gmms))) -