Mercurial > hg > segmentation
comparison ProbSegmenter.py @ 13:cc8ceb270e79
add some gmm ipynbs
author | mitian |
---|---|
date | Fri, 05 Jun 2015 18:02:05 +0100 |
parents | |
children | 6dae41887406 |
comparison
equal
deleted
inserted
replaced
12:c23658e8ae38 | 13:cc8ceb270e79 |
---|---|
1 #!/usr/bin/env python | |
2 # encoding: utf-8 | |
3 """ | |
4 ProbSegmenter.py | |
5 """ | |
6 | |
7 import matplotlib | |
8 matplotlib.use('Agg') | |
9 import matplotlib.pyplot as plt | |
10 | |
11 import sys, os, optparse, cPickle | |
12 import numpy as np | |
13 from numpy import abs,log,exp,floor,sum,sqrt,cos,hstack, power | |
14 from math import ceil, floor | |
15 from scipy.signal import correlate2d, convolve2d | |
16 from itertools import combinations | |
17 from os.path import join, isdir, isfile, abspath, dirname, basename, split, splitext | |
18 | |
19 from sklearn.decomposition import PCA | |
20 from sklearn.mixture import GMM | |
21 from sklearn.metrics.pairwise import pairwise_distances | |
22 from sklearn.cluster import KMeans, DBSCAN | |
23 from scipy.spatial.distance import squareform, pdist | |
24 | |
25 from utils.GmmMetrics import GmmDistance | |
26 from utils.RankClustering import rClustering | |
27 from utils.kmeans import Kmeans | |
28 from utils.ComputationCache import Meta, with_pickle_dump | |
29 from utils.SegUtil import normaliseFeature, upSample, getSSM, getMean, getStd, getDelta | |
30 | |
def parse_args(argv=None):
    '''Parse the command-line options of the segmentation script.

    argv: list of argument strings to parse. When None (the default), optparse
        falls back to sys.argv[1:], which is the original behaviour.
    Returns the (options, positional_args) pair produced by optparse.
    '''
    # define parser
    op = optparse.OptionParser()
    # IO options
    op.add_option('-i', '--input', action="store", dest="INPUT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/features1', type="str", help="Loading features from.." )
    op.add_option('-a', '--audioset', action="store", dest="AUDIO", default=None, type="str", help="Select audio datasets ['qupujicheng', 'salami', 'beatles'] ")
    op.add_option('-g', '--groundtruth', action="store", dest="GT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/annotations/lowercase', type="str", help="Loading annotation files from.. ")
    # The output directory receives the per-track result CSVs.
    op.add_option('-o', '--output', action="store", dest="OUTPUT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/clustering', type="str", help="Saving segmentation results to.. ")
    # The cache directory holds cached GMM computations and the saved SSMs.
    op.add_option('-c', '--cache', action="store", dest="CACHE", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/clustering/cache', type="str", help="Reading/writing cached GMMs and SSMs at.. ")
    op.add_option('-t', '--test', action="store_true", dest="TEST", help="Select TEST mode.")
    op.add_option('-v', '--verbose', action="store_true", dest="VERBOSE", help="Select VERBOSE mode to save features/SSMs.")

    return op.parse_args(argv)
44 | |
# Parse the command line once, at import time. `options` is read as a
# module-level global by Segmenter.process(); `args` (positional arguments)
# is not used anywhere in this file.
options, args = parse_args()
46 | |
class FeatureObj(object):
    '''Slot-only record with key, audio, timestamps and features attributes.

    Inherits from object: under Python 2 a classic class silently ignores
    __slots__, so the declaration below would otherwise have no effect.
    '''
    # Fixed attribute set; instances get no per-instance __dict__.
    __slots__ = ['key', 'audio', 'timestamps', 'features']
49 | |
class AudioObj(object):
    '''Per-track record; Segmenter.process() assigns its `name` and `gt`.

    Inherits from object so __slots__ is honoured under Python 2 too, and
    declares 'gt' because process() assigns it: without that slot the
    assignment raises AttributeError on any class where __slots__ is
    enforced (every class on Python 3).
    '''
    # Fixed attribute set; instances get no per-instance __dict__.
    __slots__ = ['feature_list', 'feature_matrix', 'distance_matrix', 'name', 'gt']
52 | |
class Segmenter(object):
    '''The main segmentation object.

    Fits sliding-window GMMs to PCA-reduced audio features, builds a pairwise
    symmetrised-KL distance matrix between them, and clusters the GMMs with
    rank clustering followed by k-means. Configuration is read from the
    module-level `options` produced by parse_args().
    '''

    # Shared cache metadata handed to the `with_pickle_dump` decorator on getGMMs.
    meta = Meta()

    @classmethod
    def set_cache_filename(cls, filename, cache=True, cache_location=""):
        '''Store cache settings on the shared metadata used by getGMMs' decorator.

        filename: base name stored as meta.cache_file_base.
        cache: flag stored as meta.cache (presumably enables caching --
            TODO confirm against utils.ComputationCache).
        cache_location: directory stored as meta.cache_location.
        '''
        cls.meta.cache = cache
        cls.meta.cache_file_base = filename
        cls.meta.cache_location = cache_location

    @with_pickle_dump(meta)
    def getGMMs(self, feature, filename, gmmWindow=10, stepsize=1, save=True):
        '''Fit one 2-component GmmDistance model per sliding window of `feature`.

        feature: 2-D array sliced as (frame, dimension); each window of
            `gmmWindow` consecutive frames is transposed before fitting.
        filename, save: not read by this body; presumably consumed by the
            `with_pickle_dump` cache decorator (TODO confirm).
        gmmWindow: window length in frames.
        stepsize: hop between window starts in frames.
        Returns a list with one model per window start 0, stepsize, 2*stepsize, ...
        covering only windows that fit entirely inside `feature`.
        '''
        # Same count as the original int((N - W + S) / S) under Python 2 floor
        # division; a non-positive count yields an empty list.
        n_steps = (feature.shape[0] - gmmWindow) // stepsize + 1
        return [GmmDistance(feature[start:start + gmmWindow, :].T, components=2)
                for start in range(0, n_steps * stepsize, stepsize)]

    def pairwiseSKL(self, gmm_list):
        '''Compute pairwise symmetrised KL divergences of a list of GMMs.

        Returns an (n, n) symmetric matrix with a zero diagonal. NaN entries are
        set to 0, and infinite entries are clamped to the largest finite entry
        so that later similarity computations stay finite.
        '''
        n_gmms = len(gmm_list)
        distance_matrix = np.zeros((n_gmms, n_gmms))
        # The matrix is symmetric with a zero diagonal, so only the strict upper
        # triangle needs evaluating.
        for i in range(n_gmms):
            for j in range(i + 1, n_gmms):
                d = gmm_list[i].skl_distance_full(gmm_list[j])
                distance_matrix[i, j] = d
                distance_matrix[j, i] = d

        distance_matrix[np.isnan(distance_matrix)] = 0
        inf_mask = np.isinf(distance_matrix)
        if inf_mask.any():
            finite = distance_matrix[~inf_mask]
            f_max = finite.max() if finite.size else 0.0
            print('pairwiseSKL: clamping %d infinite entries to %s' % (inf_mask.sum(), f_max))
            distance_matrix[inf_mask] = f_max
        return distance_matrix

    def smoothLabels(self, label_list, size=5):
        '''Smooth a label sequence by absorbing runs shorter than `size`.

        Every run of identical consecutive labels (other than the first run)
        that is shorter than `size` is relabelled with the label immediately
        before it. The sequence is modified in place and also returned.
        '''
        n = len(label_list)
        start = 0
        while start < n:
            # Find the end (exclusive) of the run beginning at `start`.
            end = start + 1
            while end < n and label_list[end] == label_list[start]:
                end += 1
            if start > 0 and end - start < size:
                previous = label_list[start - 1]
                for j in range(start, end):
                    label_list[j] = previous
            start = end
        return label_list

    def getInitialCentroids(self, neighborhood_size, k=10):
        '''Return the indices of up to `k` strict local maxima of `neighborhood_size`.

        A position is a candidate when it is strictly greater than both
        neighbours; candidates are ranked by the larger of the two differences,
        strongest first. The end points are never candidates.
        '''
        candidates = []
        for i in range(1, len(neighborhood_size) - 1):
            d1 = neighborhood_size[i] - neighborhood_size[i - 1]
            d2 = neighborhood_size[i] - neighborhood_size[i + 1]
            if d1 > 0 and d2 > 0:
                candidates.append((i, max(d1, d2)))
        print('candidates %d %s %d' % (len(candidates), candidates, len(neighborhood_size)))
        ranked = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
        return [index for index, _ in ranked]

    def _loadFeatures(self, name, feature_list):
        '''Load each feature CSV for track `name`; return (feature arrays, timestamps).

        For each feature, the first file in options.INPUT/<feature> whose name
        starts with "<name>_vamp" is read: column 0 is taken as timestamps and
        the remaining columns as feature values. Features with no matching
        file are skipped. Timestamps come from the last feature that matched,
        or are None when nothing matched.
        '''
        featureset = []
        timestamps = None
        prefix = name + '_vamp'
        for feature in feature_list:
            feature_dir = join(options.INPUT, feature)
            for fname in os.listdir(feature_dir):
                if fname.startswith(prefix):
                    raw = np.genfromtxt(join(feature_dir, fname), delimiter=',', filling_values=0.0)
                    featureset.append(raw[:, 1:])
                    timestamps = raw[:, 0]
                    break
        return featureset, timestamps

    def _loadAnnotations(self, name):
        '''Return the sorted unique annotation times for track `name`, or None.

        The file extension and parsing depend on options.AUDIO; None is
        returned when the dataset is unset or unrecognised.
        '''
        if options.AUDIO == 'salami':
            gt = np.genfromtxt(join(options.GT, name + '.txt'), usecols=0)
        elif options.AUDIO == 'qupujicheng':
            gt = np.genfromtxt(join(options.GT, name + '.csv'), usecols=0, delimiter=',')
        elif options.AUDIO == 'beatles':
            # Two columns are read (presumably segment start/end times).
            gt = np.genfromtxt(join(options.GT, name + '.lab'), usecols=(0, 1))
        else:
            return None
        # np.unique flattens its input, so two-column annotations become one
        # sorted list of distinct times.
        return np.unique(gt)

    @staticmethod
    def _saveWithTimestamps(path, timestamps, values):
        '''Write a two-column (timestamp, value) CSV to `path`.

        Both sequences are truncated to the shorter length, as zip() would.
        Assumes `values` is one-dimensional -- TODO confirm for rClustering
        and Kmeans outputs.
        '''
        values = np.asarray(values)
        n = min(len(timestamps), len(values))
        np.savetxt(path, np.column_stack((timestamps[:n], values[:n])), delimiter=',')

    def process(self):
        '''Segment every annotated track and write the clustering results to disk.

        Reads the module-level `options`: track names come from the files in
        options.GT, features from options.INPUT/<feature>/, cache files and
        SSMs go to options.CACHE and result CSVs to options.OUTPUT.
        '''
        audio_files = sorted(x for x in os.listdir(options.GT) if not x.startswith("."))
        if options.TEST:
            # Test mode processes only the first two tracks.
            audio_files = audio_files[:2]

        # Features concatenated for every track. Sorting fixes both the column
        # order and the feature tag used in output file names.
        feature_list = sorted(['pcamean', 'dct', 'contrast6'])
        feature_name = ''.join('-' + feature for feature in feature_list)
        fr = basename(options.INPUT)

        # GMM window length and hop, both in feature frames.
        winlength = 50
        stepsize = 50
        # Suffix shared by this configuration's cache and SSM files.
        config_tag = "-%s-wl%i-ss%i" % (fr, winlength, stepsize)

        if options.AUDIO is None:
            print('Must specify audio dataset for evaluation!')

        for audio in audio_files:
            ao = AudioObj()
            ao.name = splitext(audio)[0]

            featureset, timestamps = self._loadFeatures(ao.name, feature_list)
            n_features = len(featureset)
            if n_features == 0:
                continue

            # Feature files may differ by a few (generally empty) trailing
            # frames: truncate all of them to the shortest before stacking.
            n_frame = min(x.shape[0] for x in featureset)
            feature_matrix = np.hstack([x[:n_frame, :] for x in featureset])

            print('Processing data for audio file: %s %d' % (audio, n_features))
            # Stored on the record; not read again in this method.
            ao.gt = self._loadAnnotations(ao.name)

            # One timestamp per GMM window start.
            timestamps = timestamps[0::stepsize]

            feature_matrix = normaliseFeature(feature_matrix)
            print('fr %s %s' % (fr, options.INPUT))

            # Reduce to 6 principal components before fitting the GMMs.
            pca = PCA(n_components=6)
            feature_matrix = pca.fit(feature_matrix).transform(feature_matrix)

            cache_filename = ao.name + feature_name + config_tag + ".txt"
            self.set_cache_filename(filename=cache_filename, cache_location=options.CACHE)
            gmm_list = self.getGMMs(feature_matrix, filename=cache_filename,
                                    gmmWindow=winlength, stepsize=stepsize)
            print('number of GMMs: %d' % len(gmm_list))

            skl_matrix = self.pairwiseSKL(gmm_list)
            ssm = getSSM(skl_matrix)
            np.savetxt(join(options.CACHE, ao.name + feature_name + config_tag + "-ssm.txt"),
                       ssm, delimiter=",")

            # Rank clustering of the GMMs.
            rc = rClustering(eps=1.15, k=5, rank='max_neighbors')
            rc.set_cache_filename(cache_filename, cache_location=options.CACHE)
            rc.fit(gmm_list)
            rc.test()
            classification = rc.classification
            print('classification %s' % (classification,))
            neighborhood_size, average_div, node_rank = rc.getNodeRank()

            # k-means clustering of the GMMs, seeded with the rank-clustering labels.
            kmeans = Kmeans(gmm_list, K=5, initial_centroids=set(classification))
            labels = kmeans.fit()

            # Output names are "<OUTPUT>/<kind><track>...csv" (kind prefix joined
            # directly to the track name). NOTE(review): the track name appears
            # twice in the k-means file name; kept because downstream
            # evaluation may rely on it.
            out_dir = options.OUTPUT
            self._saveWithTimestamps(join(out_dir, 'kmeans') + ao.name + feature_name + ao.name + '.csv',
                                     timestamps, labels)
            self._saveWithTimestamps(join(out_dir, 'neighborhood_size') + ao.name + '.csv',
                                     timestamps, neighborhood_size)
            self._saveWithTimestamps(join(out_dir, 'node_rank') + ao.name + '.csv',
                                     timestamps, node_rank)
            self._saveWithTimestamps(join(out_dir, 'average_div') + ao.name + '.csv',
                                     timestamps, average_div)
239 labels = KmeansClustering.fit() | |
240 | |
241 f1 = np.array(zip(timestamps[:len(labels)], labels)) | |
242 # f2 = np.array(zip(timestamps[:len(labels2)], labels2)) | |
243 f3 = np.array(zip(timestamps[:len(classification)], classification)) | |
244 f4 = np.array(zip(timestamps[:len(neighborhood_size)], neighborhood_size)) | |
245 f5 = np.array(zip(timestamps[:len(node_rank)], node_rank)) | |
246 f6 = np.array(zip(timestamps[:len(average_div)], average_div)) | |
247 # | |
248 np.savetxt(join(options.OUTPUT, 'kmeans')+splitext(audio)[0]+feature_name+splitext(audio)[0]+'.csv', f1, delimiter=',') | |
249 # np.savetxt(join(options.OUTPUT, 'dbscan_gmm')+splitext(audio)[0]+'.csv', f2, delimiter=',') | |
250 # np.savetxt(join(options.OUTPUT, 'classification')+splitext(audio)[0]+'.csv', f3, delimiter=',') | |
251 np.savetxt(join(options.OUTPUT, 'neighborhood_size')+splitext(audio)[0]+'.csv', f4, delimiter=',') | |
252 np.savetxt(join(options.OUTPUT, 'node_rank')+splitext(audio)[0]+'.csv', f5, delimiter=',') | |
253 np.savetxt(join(options.OUTPUT, 'average_div')+splitext(audio)[0]+'.csv', f6, delimiter=',') | |
254 | |
def main():
    '''Command-line entry point: run the segmentation pipeline on the parsed options.'''
    Segmenter().process()


if __name__ == '__main__':
    main()
262 |