diff gmm_baseline_experiments/run_experiments.py @ 2:cb535b80218a
Remaining scripts and brief documentation
author | peterf
---|---
date | Fri, 10 Jul 2015 23:24:23 +0100
parents |
children | b523456082ca
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/gmm_baseline_experiments/run_experiments.py	Fri Jul 10 23:24:23 2015 +0100
@@ -0,0 +1,135 @@
+#!/usr/bin/python
+
+#
+# run_experiments.py:
+# Main script for CHiME-Home dataset baseline GMM evaluation
+#
+# Author: Peter Foster
+# (c) 2015 Peter Foster
+#
+
+from pylab import *
+from sklearn import cross_validation
+import os
+from pandas import Series, DataFrame
+from collections import defaultdict
+from extract_features import FeatureExtractor
+import exper002
+import custompickler
+from compute_performance_statistics import compute_performance_statistics
+import pdb
+
+Settings = {'paths':{}, 'algorithms':{}}
+Settings['paths'] = {'chime_home': {}, 'resultsdir':'/import/c4dm-scratch/peterf/audex/results/', 'featuresdir':'/import/c4dm-scratch/peterf/audex/features/'}
+Settings['paths']['chime_home'] = {'basepath':'/import/c4dm-02/people/peterf/audex/datasets/chime_home/'}
+
+#Read data sets and class assignments
+Datasets = {'chime_home':{}}
+
+#Read in annotations
+Chunks = list(Series.from_csv(Settings['paths']['chime_home']['basepath'] + 'release_chunks_refined.csv',header=None))
+Annotations = []
+for chunk in Chunks:
+    Annotations.append(Series.from_csv(Settings['paths']['chime_home']['basepath'] + 'chunks/' + chunk + '.csv'))
+Datasets['chime_home']['dataset'] = DataFrame(Annotations)
+
+#Compute label statistics
+Datasets['chime_home']['labelstats'] = defaultdict(lambda: 0)
+for item in Datasets['chime_home']['dataset']['majorityvote']:
+    for label in item:
+        Datasets['chime_home']['labelstats'][label] += 1
+#Labels to consider for multilabel classification -- based on label set used in Stowell and Plumbley (2013)
+Datasets['chime_home']['consideredlabels'] = ['c', 'b', 'f', 'm', 'o', 'p', 'v']
+#Populate binary label assignments
+for label in Datasets['chime_home']['consideredlabels']:
+    Datasets['chime_home']['dataset'][label] = [label in item for item in Datasets['chime_home']['dataset']['majorityvote']]
+#Obtain statistics for considered labels
+sum(Datasets['chime_home']['dataset'][Datasets['chime_home']['consideredlabels']]) / len(Datasets['chime_home']['dataset'])
+#Create partition for 10-fold cross-validation. Shuffling ensures each fold has an approximately equal proportion of label occurrences
+np.random.seed(475686)
+Datasets['chime_home']['crossval_10fold'] = cross_validation.KFold(len(Datasets['chime_home']['dataset']), 10, shuffle=True)
+
+Datasets['chime_home']['dataset']['wavfile'] = Datasets['chime_home']['dataset']['chunkname'].apply(lambda s: Settings['paths']['chime_home']['basepath'] + 'chunks/' + s + '.wav')
+
+#Extract features and assign them to Datasets structure
+for dataset in Datasets.keys():
+    picklepath = os.path.join(Settings['paths']['featuresdir'],'features_' + dataset)
+    if not(os.path.isfile(picklepath)):
+        if dataset == 'chime_home':
+            featureExtractor = FeatureExtractor(samplingRate=48000, frameLength=1024, hopLength=512)
+        else:
+            raise NotImplementedError()
+        FeatureList = featureExtractor.files_to_features(Datasets[dataset]['dataset']['wavfile'])
+        custompickler.pickle_save(FeatureList,picklepath)
+    else:
+        FeatureList = custompickler.pickle_load(picklepath)
+    #Integrity check
+    for features in FeatureList:
+        for feature in features.values():
+            assert(all(isfinite(feature.ravel())))
+    Datasets[dataset]['dataset']['features'] = FeatureList
+
+#GMM experiments using CHiME-Home dataset
+EXPER005 = {}
+EXPER005['name'] = 'GMM_Baseline_EXPER005'
+EXPER005['path'] = os.path.join(Settings['paths']['resultsdir'],'exploratory','saved_objects','EXPER005')
+EXPER005['settings'] = {'numcomponents': (1,2,4,8), 'features': ('librosa_mfccs',)}
+EXPER005['datasets'] = {}
+EXPER005['datasets']['chime_home'] = exper002.exper002_multilabelclassification(Datasets['chime_home']['dataset'], Datasets['chime_home']['consideredlabels'], Datasets['chime_home']['crossval_10fold'], Settings, numComponentValues=EXPER005['settings']['numcomponents'], featureTypeValues=EXPER005['settings']['features'])
+EXPER005 = compute_performance_statistics(EXPER005, Datasets, Settings, iterableParameters=['numcomponents', 'features'])
+custompickler.pickle_save(EXPER005, EXPER005['path'])
+
+#Collate results
+def accumulate_results(EXPER):
+    EXPER['summaryresults'] = {}
+    ds = EXPER['datasets'].keys()[0]
+    for numComponents in EXPER['settings']['numcomponents']:
+        EXPER['summaryresults'][numComponents] = {}
+        for label in Datasets[ds]['consideredlabels']:
+            EXPER['summaryresults'][numComponents][label] = EXPER['datasets'][ds][(numComponents, 'librosa_mfccs')]['performance']['classwise'][label]['auc_precisionrecall']
+    EXPER['summaryresults'] = DataFrame(EXPER['summaryresults'])
+accumulate_results(EXPER005)
+
+#Generate plot
+def plot_performance(EXPER):
+    fig_width_pt = 246.0                    # Get this from LaTeX using \showthe\columnwidth
+    inches_per_pt = 1.0/72.27               # Convert pt to inch
+    golden_mean = (sqrt(5)-1.0)/2.0         # Aesthetic ratio
+    fig_width = fig_width_pt*inches_per_pt  # width in inches
+    fig_height = fig_width*golden_mean      # height in inches
+    fig_size = [fig_width,fig_height]
+    params = {'backend': 'ps',
+              'axes.labelsize': 8,
+              'text.fontsize': 8,
+              'legend.fontsize': 7.0,
+              'xtick.labelsize': 8,
+              'ytick.labelsize': 8,
+              'text.usetex': False,
+              'figure.figsize': fig_size}
+    rcParams.update(params)
+    ind = np.arange(len(EXPER['summaryresults'][1]))  # the x locations for the groups
+    width = 0.22                                      # the width of the bars
+    fig, ax = plt.subplots()
+    rects = []
+    colours = ('r', 'y', 'g', 'b', 'c')
+    for numComponents, i in zip(EXPER['summaryresults'], range(len(EXPER['summaryresults']))):
+        rects.append(ax.bar(ind+width*i, EXPER['summaryresults'][numComponents][['c','m','f','v','p','b','o']], width, color=colours[i], align='center'))
+    # add text for labels, title and axes ticks
+    ax.set_ylabel('AUC')
+    ax.set_xlabel('Label')
+    ax.set_xticks(ind+width)
+    ax.set_xticklabels(('c','m','f','v','p','b','o'))
+    ax.legend((rect[0] for rect in rects), ('k=1', 'k=2', 'k=4', 'k=8'), loc='lower right')
+    #Tweak x-axis limit
+    ax.set_xlim(left=-0.5)
+    ax.set_ylim(top=1.19)
+    plt.gcf().subplots_adjust(left=0.15)  #Prevent y-axis label from being chopped off
+    def autolabel(r):
+        for rects in r:
+            for rect in rects:
+                height = rect.get_height()
+                ax.text(rect.get_x()+0.14, 0.04+height, '%1.2f'%float(height), ha='center', va='bottom', rotation='vertical', size=6.0)
+    autolabel(rects)
+    plt.draw()
+    plt.savefig('figures/predictionperformance' + EXPER['name'] + '.pdf')
+plot_performance(EXPER005)
\ No newline at end of file
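
The script leans on four helper modules that are not part of this changeset: custompickler, extract_features, exper002 and compute_performance_statistics. The sketches below reconstruct their apparent interfaces from the call sites above; they are assumptions for orientation, not the committed implementations.

custompickler is used as pickle_save(obj, path) / pickle_load(path) to cache features and results, so plain pickle semantics would suffice:

    # Hypothetical sketch of custompickler's interface (Python 2, matching
    # the script's #!/usr/bin/python); the real module is not in this diff.
    import cPickle as pickle

    def pickle_save(obj, path):
        # Serialise obj to path; run_experiments.py uses this to cache
        # extracted features and the EXPER005 results object.
        with open(path, 'wb') as f:
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)

    def pickle_load(path):
        # Restore an object previously written by pickle_save.
        with open(path, 'rb') as f:
            return pickle.load(f)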
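FeatureExtractor is constructed with samplingRate, frameLength and hopLength, and files_to_features returns one {featurename: array} dict per WAV file, keyed by 'librosa_mfccs' in EXPER005. A sketch assuming the MFCCs come directly from librosa:

    # Hypothetical sketch of extract_features.FeatureExtractor; the
    # constructor arguments and the 'librosa_mfccs' key are taken from
    # usage above, the librosa calls are an assumption.
    import librosa

    class FeatureExtractor(object):
        def __init__(self, samplingRate, frameLength, hopLength):
            self.samplingRate = samplingRate
            self.frameLength = frameLength
            self.hopLength = hopLength

        def files_to_features(self, wavfiles):
            # One {featurename: ndarray} dict per file, matching how
            # run_experiments.py iterates over FeatureList.
            featureList = []
            for path in wavfiles:
                y, sr = librosa.load(path, sr=self.samplingRate)
                mfccs = librosa.feature.mfcc(y=y, sr=sr, n_fft=self.frameLength, hop_length=self.hopLength)
                featureList.append({'librosa_mfccs': mfccs})
            return featureList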
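exper002_multilabelclassification evaluates each of the seven considered labels independently, across the 10 folds and the component counts (1, 2, 4, 8). One plausible core of such a GMM baseline, using the sklearn mixture API contemporary with this changeset, scores a test chunk by the difference in mean frame log-likelihood between a label-present and a label-absent model:

    # Hypothetical per-label scoring step; exper002.py is not in this diff
    # and may differ in detail.
    import numpy as np
    from sklearn.mixture import GMM  # sklearn < 0.18 API

    def gmm_label_scores(trainFeatures, trainLabels, testFeatures, numComponents):
        # trainFeatures/testFeatures: per-chunk (frames x dims) arrays;
        # trainLabels: one boolean per training chunk.
        posFrames = np.vstack([X for X, y in zip(trainFeatures, trainLabels) if y])
        negFrames = np.vstack([X for X, y in zip(trainFeatures, trainLabels) if not y])
        gmmPos = GMM(n_components=numComponents).fit(posFrames)
        gmmNeg = GMM(n_components=numComponents).fit(negFrames)
        # Higher score: label more likely present in the chunk.
        return np.array([gmmPos.score(X).mean() - gmmNeg.score(X).mean() for X in testFeatures])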
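compute_performance_statistics fills EXPER005['datasets'][...]['performance']['classwise'][label]['auc_precisionrecall'], which accumulate_results tabulates and plot_performance draws. The key name indicates area under the precision-recall curve; per label, that figure could be computed as:

    # Hypothetical PR-AUC computation for one label; the committed
    # compute_performance_statistics.py is not shown in this changeset.
    from sklearn.metrics import precision_recall_curve, auc

    def auc_precisionrecall(y_true, y_score):
        # y_true: binary label assignments; y_score: classifier scores.
        precision, recall, _ = precision_recall_curve(y_true, y_score)
        return auc(recall, precision)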