Mercurial > hg > chime-home-dataset-annotation-and-baseline-evaluation-code
comparison gmm_baseline_experiments/run_experiments.py @ 2:cb535b80218a
Remaining scripts and brief documentation
author | peterf |
---|---|
date | Fri, 10 Jul 2015 23:24:23 +0100 |
parents | |
children | b523456082ca |
comparison
equal
deleted
inserted
replaced
1:f079d2de4aa2 | 2:cb535b80218a |
---|---|
1 #!/usr/bin/python | |
2 | |
3 # | |
4 # run_experiments.py: | |
5 # Main script for CHiME-Home dataset baseline GMM evaluation | |
6 # | |
7 # Author: Peter Foster | |
8 # (c) 2015 Peter Foster | |
9 # | |
10 | |
11 from pylab import * | |
12 from sklearn import cross_validation | |
13 import os | |
14 from pandas import Series, DataFrame | |
15 from collections import defaultdict | |
16 from extract_features import FeatureExtractor | |
17 import exper002 | |
18 import custompickler | |
19 from compute_performance_statistics import compute_performance_statistics | |
20 import pdb | |
21 | |
# Experiment configuration: file-system locations for the CHiME-Home dataset,
# result objects, and cached features, plus a (currently unused) slot for
# algorithm settings. NOTE(review): paths are machine-specific.
Settings = {
    'paths': {
        'chime_home': {'basepath': '/import/c4dm-02/people/peterf/audex/datasets/chime_home/'},
        'resultsdir': '/import/c4dm-scratch/peterf/audex/results/',
        'featuresdir': '/import/c4dm-scratch/peterf/audex/features/',
    },
    'algorithms': {},
}
25 | |
#Read data sets and class assignments
Datasets = {'chime_home': {}}

#Read in annotations: the refined chunk list names one per-chunk CSV each,
#which are collected row-wise into a single DataFrame.
Chunks = list(Series.from_csv(Settings['paths']['chime_home']['basepath'] + 'release_chunks_refined.csv', header=None))
Annotations = [Series.from_csv(Settings['paths']['chime_home']['basepath'] + 'chunks/' + chunk + '.csv')
               for chunk in Chunks]
Datasets['chime_home']['dataset'] = DataFrame(Annotations)
35 | |
#Compute label statistics: count how often each label character occurs across
#the majority-vote annotation strings.
Datasets['chime_home']['labelstats'] = defaultdict(lambda: 0)
for item in Datasets['chime_home']['dataset']['majorityvote']:
    for label in item:
        Datasets['chime_home']['labelstats'][label] += 1
#Labels to consider for multilabel classification -- based on label set used
#in Stowell and Plumbley (2013)
Datasets['chime_home']['consideredlabels'] = ['c', 'b', 'f', 'm', 'o', 'p', 'v']
#Populate binary label assignments: one boolean column per considered label
for label in Datasets['chime_home']['consideredlabels']:
    Datasets['chime_home']['dataset'][label] = [label in item for item in Datasets['chime_home']['dataset']['majorityvote']]
#Obtain occurrence proportions for the considered labels.
#BUGFIX: the original computed this expression but discarded the result
#(leftover from an interactive session); store it so the statistic is
#actually retained.
Datasets['chime_home']['labelproportions'] = sum(Datasets['chime_home']['dataset'][Datasets['chime_home']['consideredlabels']]) / len(Datasets['chime_home']['dataset'])
#Create partition for 10-fold cross-validation. Shuffling ensures each fold
#has approximately equal proportion of label occurrences.
np.random.seed(475686)
Datasets['chime_home']['crossval_10fold'] = cross_validation.KFold(len(Datasets['chime_home']['dataset']), 10, shuffle=True)

#Full path to each chunk's audio file, derived from the chunk name
Datasets['chime_home']['dataset']['wavfile'] = Datasets['chime_home']['dataset']['chunkname'].apply(lambda s: Settings['paths']['chime_home']['basepath'] + 'chunks/' + s + '.wav')
53 | |
#Extract features for each dataset and assign them to the Datasets structure.
#Features are cached on disk as a pickle; subsequent runs reuse the cache.
for dataset in Datasets:  # IDIOM: iterate the dict directly, not .keys()
    picklepath = os.path.join(Settings['paths']['featuresdir'], 'features_' + dataset)
    if not os.path.isfile(picklepath):
        if dataset == 'chime_home':
            # Frame/hop lengths are in samples at the 48 kHz sampling rate
            featureExtractor = FeatureExtractor(samplingRate=48000, frameLength=1024, hopLength=512)
        else:
            raise NotImplementedError()
        FeatureList = featureExtractor.files_to_features(Datasets[dataset]['dataset']['wavfile'])
        custompickler.pickle_save(FeatureList, picklepath)
    else:
        FeatureList = custompickler.pickle_load(picklepath)
    #Integrity check: every feature array must be free of NaN/inf values
    for features in FeatureList:
        for feature in features.values():
            assert(all(isfinite(feature.ravel())))
    Datasets[dataset]['dataset']['features'] = FeatureList
71 | |
#GMM experiments using CHiME home dataset: run experiment EXPER005
#(GMM baseline over MFCC features, sweeping the number of mixture components),
#compute performance statistics, and persist the result object.
EXPER005 = {
    'name': 'GMM_Baseline_EXPER005',
    'path': os.path.join(Settings['paths']['resultsdir'], 'exploratory', 'saved_objects', 'EXPER005'),
    'settings': {'numcomponents': (1, 2, 4, 8), 'features': ('librosa_mfccs',)},
    'datasets': {},
}
EXPER005['datasets']['chime_home'] = exper002.exper002_multilabelclassification(
    Datasets['chime_home']['dataset'],
    Datasets['chime_home']['consideredlabels'],
    Datasets['chime_home']['crossval_10fold'],
    Settings,
    numComponentValues=EXPER005['settings']['numcomponents'],
    featureTypeValues=EXPER005['settings']['features'])
EXPER005 = compute_performance_statistics(EXPER005, Datasets, Settings,
                                          iterableParameters=['numcomponents', 'features'])
custompickler.pickle_save(EXPER005, EXPER005['path'])
81 | |
#Collate results
def accumulate_results(EXPER):
    """Summarise per-label AUC (precision-recall) into EXPER['summaryresults'].

    For each GMM component count in EXPER['settings']['numcomponents'], collect
    the classwise precision-recall AUC of every considered label for the
    experiment's (single) dataset, and store the result as a DataFrame whose
    columns are component counts and whose rows are labels.
    Mutates EXPER in place; reads the module-level Datasets structure.
    """
    #BUGFIX: dict.keys()[0] is Python-2-only (dict views are not indexable in
    #Python 3); next(iter(...)) is equivalent and works on both.
    ds = next(iter(EXPER['datasets']))
    summary = {}
    for numComponents in EXPER['settings']['numcomponents']:
        #NOTE(review): feature key 'librosa_mfccs' is hard-coded here, matching
        #EXPER005's single feature type -- generalise if more features are added.
        summary[numComponents] = {
            label: EXPER['datasets'][ds][(numComponents, 'librosa_mfccs')]['performance']['classwise'][label]['auc_precisionrecall']
            for label in Datasets[ds]['consideredlabels']
        }
    EXPER['summaryresults'] = DataFrame(summary)
accumulate_results(EXPER005)
92 | |
#Generate plot
def plot_performance(EXPER):
    """Plot per-label AUC as grouped bars and save as a PDF.

    One group of bars per label, one bar per GMM component count, taken from
    EXPER['summaryresults'] as produced by accumulate_results. The figure is
    written to figures/predictionperformance<EXPER name>.pdf.
    """
    # Figure geometry sized to fit a LaTeX column
    fig_width_pt = 246.0  # Get this from LaTeX using \showthe\columnwidth
    inches_per_pt = 1.0 / 72.27  # Convert pt to inch
    golden_mean = (sqrt(5) - 1.0) / 2.0  # Aesthetic ratio
    fig_width = fig_width_pt * inches_per_pt  # width in inches
    fig_height = fig_width * golden_mean  # height in inches
    fig_size = [fig_width, fig_height]
    # NOTE(review): 'text.fontsize' was renamed 'font.size' in later matplotlib
    # releases -- confirm against the matplotlib version in use.
    params = {'backend': 'ps',
              'axes.labelsize': 8,
              'text.fontsize': 8,
              'legend.fontsize': 7.0,
              'xtick.labelsize': 8,
              'ytick.labelsize': 8,
              'text.usetex': False,
              'figure.figsize': fig_size}
    rcParams.update(params)
    ind = np.arange(len(EXPER['summaryresults'][1]))  # x locations for the label groups
    width = 0.22  # the width of the bars
    fig, ax = plt.subplots()
    rects = []
    colours = ('r', 'y', 'g', 'b', 'c')
    # One bar series per component count (the DataFrame's columns).
    # IDIOM: enumerate replaces zip(..., range(len(...))); both yield one
    # (index, column) pair per DataFrame column.
    for i, numComponents in enumerate(EXPER['summaryresults']):
        rects.append(ax.bar(ind + width * i,
                            EXPER['summaryresults'][numComponents][['c', 'm', 'f', 'v', 'p', 'b', 'o']],
                            width, color=colours[i], align='center'))
    # add text for labels, title and axes ticks
    ax.set_ylabel('AUC')
    ax.set_xlabel('Label')
    ax.set_xticks(ind + width)
    ax.set_xticklabels(('c', 'm', 'f', 'v', 'p', 'b', 'o'))
    # FIX: pass legend handles as a list rather than a generator expression;
    # matplotlib's legend expects a sequence of handles.
    ax.legend([rect[0] for rect in rects], ('k=1', 'k=2', 'k=4', 'k=8'), loc='lower right')
    #Tweak axis limits so bars and their value labels fit inside the axes
    ax.set_xlim(left=-0.5)
    ax.set_ylim(top=1.19)
    plt.gcf().subplots_adjust(left=0.15)  # Prevent y-axis label from being chopped off

    def autolabel(r):
        # Annotate each bar with its height (the AUC value), rotated vertically
        for series in r:
            for rect in series:
                height = rect.get_height()
                ax.text(rect.get_x() + 0.14, 0.04 + height, '%1.2f' % float(height),
                        ha='center', va='bottom', rotation='vertical', size=6.0)
    autolabel(rects)
    plt.draw()
    plt.savefig('figures/predictionperformance' + EXPER['name'] + '.pdf')
plot_performance(EXPER005)