Mercurial > hg > smacpy
changeset 17:a7c4cc56ac6f
Implement stratified leave-one-out cross-validation (LOOCV), which is better and faster than plain LOOCV
author | Dan Stowell <danstowell@users.sourceforge.net> |
---|---|
date | Fri, 07 Dec 2012 11:39:13 +0000 |
parents | f872eceb2c9c |
children | 6a34565c8a74 |
files | smacpy.py |
diffstat | 1 files changed, 20 insertions(+), 8 deletions(-) [+] |
line wrap: on
line diff
--- a/smacpy.py Wed Dec 05 09:43:26 2012 +0000 +++ b/smacpy.py Fri Dec 07 11:39:13 2012 +0000 @@ -103,7 +103,7 @@ if verbose: print("Reading %s" % wavpath) if not os.path.isfile(wavpath): raise ValueError("path %s not found" % path) sf = Sndfile(wavpath, "r") - if (sf.channels != 1) and verbose: print(" Sound file has multiple channels (%i) - channels will be mixed to mono." % sf.channels) + #if (sf.channels != 1) and verbose: print(" Sound file has multiple channels (%i) - channels will be mixed to mono." % sf.channels) if sf.samplerate != fs: raise ValueError("wanted sample rate %g - got %g." % (fs, sf.samplerate)) window = np.hamming(framelen) features = [] @@ -170,7 +170,7 @@ for onepath in ['trainpath', 'testpath']: pattern = os.path.join(args[onepath], '*.wav') for wavpath in glob(pattern): - if args['numchars'] > 0: + if args['numchars'] != 0: label = os.path.basename(wavpath)[:args['numchars']] else: label = os.path.basename(wavpath).split(args['charsplit'])[0] @@ -188,13 +188,25 @@ ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], wavsfound['trainpath'], args['testpath'], wavsfound['testpath']) print("Got %i correct out of %i (trained on %i classes)" % (ncorrect, ntotal, nclasses)) else: - # "Leave-one-out-crossvalidation": leave one at a time out of the training set, and see how it is classified + # This runs "stratified leave-one-out crossvalidation": test multiple times by leaving one-of-each-class out and training on the rest. 
+ # First we need to build a list of files grouped by each classlabel + labelsinuse = list(set(wavsfound['trainpath'].values())) + grouped = {label:[] for label in labelsinuse} + for wavpath,label in wavsfound['trainpath'].items(): + grouped[label].append(wavpath) + numfolds = min(len(collection) for collection in grouped.values()) + # Each "fold" will be a collection of one item of each label + folds = [{wavpaths[index]:label for label,wavpaths in grouped.items()} for index in range(numfolds)] totcorrect, tottotal = (0,0) - for whichone in wavsfound['trainpath']: - chosenone = {k:v for (k, v) in wavsfound['trainpath'].items() if k == whichone} - alltherest = {k:v for (k, v) in wavsfound['trainpath'].items() if k != whichone} - ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], alltherest, args['trainpath'], chosenone) + # Then we go through, each time training on all-but-one and testing on the one left out + for index in range(numfolds): + chosenfold = folds[index] + alltherest = {} + for whichfold, otherfold in enumerate(folds): + if whichfold != index: + alltherest.update(otherfold) + ncorrect, ntotal, nclasses = trainAndTest(args['trainpath'], alltherest, args['trainpath'], chosenfold) totcorrect += ncorrect tottotal += ntotal - print("Got %i correct out of %i (leave-one-out crossvalidation)" % (totcorrect, tottotal)) + print("Got %i correct out of %i (using stratified leave-one-out crossvalidation, %i folds)" % (totcorrect, tottotal, numfolds))