comparison ProbSegmenter.py @ 13:cc8ceb270e79

add some gmm ipynbs
author mitian
date Fri, 05 Jun 2015 18:02:05 +0100
#!/usr/bin/env python
# encoding: utf-8
"""
ProbSegmenter.py

GMM-based music segmentation: fit a GMM to each sliding window of a
feature matrix, compute pairwise symmetrised KL divergences between the
window models, and cluster them with rank clustering and k-means.
"""

import matplotlib
matplotlib.use('Agg')  # render off-screen; no display needed
import matplotlib.pyplot as plt

import sys, os, optparse, cPickle
import numpy as np
from numpy import abs, log, exp, sum, sqrt, cos, hstack, power
from math import ceil, floor  # import floor once; numpy's floor was shadowed here anyway
from scipy.signal import correlate2d, convolve2d
from itertools import combinations
from os.path import join, isdir, isfile, abspath, dirname, basename, split, splitext

from sklearn.decomposition import PCA
from sklearn.mixture import GMM
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.cluster import KMeans, DBSCAN
from scipy.spatial.distance import squareform, pdist

from utils.GmmMetrics import GmmDistance
from utils.RankClustering import rClustering
from utils.kmeans import Kmeans
from utils.ComputationCache import Meta, with_pickle_dump
from utils.SegUtil import normaliseFeature, upSample, getSSM, getMean, getStd, getDelta

def parse_args():
    # define parser
    op = optparse.OptionParser()
    # IO options
    op.add_option('-i', '--input', action="store", dest="INPUT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/features1', type="str", help="Load features from this directory.")
    op.add_option('-a', '--audioset', action="store", dest="AUDIO", default=None, type="str", help="Select audio dataset ['qupujicheng', 'salami', 'beatles'].")
    op.add_option('-g', '--groundtruth', action="store", dest="GT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/annotations/lowercase', type="str", help="Load annotation files from this directory.")
    op.add_option('-o', '--output', action="store", dest="OUTPUT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/clustering', type="str", help="Write clustering results to this directory.")
    op.add_option('-c', '--cache', action="store", dest="CACHE", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/clustering/cache', type="str", help="Cache intermediate results in this directory.")
    op.add_option('-t', '--test', action="store_true", dest="TEST", help="Select TEST mode.")
    op.add_option('-v', '--verbose', action="store_true", dest="VERBOSE", help="Select VERBOSE mode to save features/SSMs.")

    return op.parse_args()

options, args = parse_args()

class FeatureObj(object):
    __slots__ = ['key', 'audio', 'timestamps', 'features']

class AudioObj(object):
    __slots__ = ['feature_list', 'feature_matrix', 'distance_matrix', 'name', 'gt']
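# Note: __slots__ only takes effect on new-style classes, hence the explicit
# object base. 'gt' is included in AudioObj's slots because Segmenter.process()
# assigns ao.gt when loading the ground-truth annotations.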

class Segmenter(object):
    '''The main segmentation object'''

    meta = Meta()

    @classmethod
    def set_cache_filename(cls, filename, cache=True, cache_location=""):
        cls.meta.cache = cache
        cls.meta.cache_file_base = filename
        cls.meta.cache_location = cache_location

    @with_pickle_dump(meta)
    def getGMMs(self, feature, filename, gmmWindow=10, stepsize=1, save=True):
        '''Fit a 2-component GMM to each sliding window over the feature matrix.'''
        gmm_list = []
        steps = int((feature.shape[0] - gmmWindow + stepsize) / stepsize)
        for i in xrange(steps):
            gmm_list.append(GmmDistance(feature[i*stepsize:(i*stepsize+gmmWindow), :].T, components=2))
        # if save:
        #     with open(join('cache', filename), 'w+') as f:
        #         f.write(cPickle.dumps(gmm_list))
        return gmm_list
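    # Window-count check: with the gmmWindow=50, stepsize=50 used in process(),
    # a 1000-frame feature matrix yields int((1000 - 50 + 50) / 50) = 20
    # non-overlapping window GMMs.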

    def pairwiseSKL(self, gmm_list):
        '''Compute pairwise symmetrised KL divergence of a list of GMMs.'''
        n_GMMs = len(gmm_list)
        distance_matrix = np.zeros((n_GMMs, n_GMMs))
        for i in xrange(n_GMMs):
            for j in xrange(i, n_GMMs):
                distance_matrix[i][j] = gmm_list[i].skl_distance_full(gmm_list[j])
                distance_matrix[j][i] = distance_matrix[i][j]

        np.fill_diagonal(distance_matrix, 0.0)
        distance_matrix[np.isnan(distance_matrix)] = 0
        # distance_matrix[np.isinf(distance_matrix)] = np.finfo(np.float64).max
        # Replace any infinities with the largest finite divergence in the matrix.
        if np.isinf(distance_matrix).any():
            data = np.sort(np.ndarray.flatten(distance_matrix))
            pos = np.where(data == np.inf)[0][0]
            fMax = data[pos-1]
            print len(data), pos, fMax
            distance_matrix[np.isinf(distance_matrix)] = fMax
        return distance_matrix
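    # For reference: the symmetrised KL divergence between models p and q is
    #     SKL(p, q) = 0.5 * (KL(p || q) + KL(q || p)).
    # skl_distance_full (from the local utils.GmmMetrics module) is assumed to
    # compute this for full-covariance GMMs; the KL divergence between two GMMs
    # has no exact closed form, so it is typically approximated.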

    def smoothLabels(self, label_list, size=5):
        '''Smooth a label list: relabel runs shorter than the given length.'''
        prev_label = -1
        chain = 0
        for i in xrange(size, len(label_list) - size):
            label = label_list[i]
            if label == prev_label:
                chain += 1
            else:
                if chain < size:
                    label_list[i] = prev_label
                chain = 0
            prev_label = label_list[i]
        return label_list

    def getInitialCentroids(self, neighborhood_size, k=10):
        '''Rank local peaks of the neighborhood-size curve; return the top k as initial centroids.'''
        candidates = []
        size = len(neighborhood_size)
        for i in xrange(1, size-1):
            d1 = neighborhood_size[i] - neighborhood_size[i-1]
            d2 = neighborhood_size[i] - neighborhood_size[i+1]
            if d1 > 0 and d2 > 0:
                candidates.append((i, max(d1, d2)))
        print 'candidates', len(candidates), candidates, size
        ranked_nodes = sorted(candidates, key=lambda x: x[1], reverse=True)[:k]
        ranked_nodes = [ranked_nodes[i][0] for i in xrange(len(ranked_nodes))]
        return ranked_nodes
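    # Worked example: for neighborhood_size = [3, 5, 2, 4, 1], indices 1 and 3
    # are local peaks (both differences positive), giving candidates
    # [(1, 3), (3, 3)]; with k=10 both survive the ranking, so [1, 3] is returned.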


    def process(self):

        audio_files = [x for x in os.listdir(options.GT) if not x.startswith(".")]
        audio_files.sort()
        if options.TEST:
            # audio_files = ["""17 We Are The Champions.wav"""]
            audio_files = audio_files[:2]
        audio_list = []

        fobj_list = []
        feature_list = [i for i in os.listdir(options.INPUT) if not i.startswith('.')]
        # override the directory listing with a fixed feature subset
        feature_list = ['pcamean', 'dct', 'contrast6']

        feature_list.sort()

        winlength = 50
        stepsize = 50

        if options.AUDIO is None:
            print 'Must specify an audio dataset for evaluation!'
            sys.exit(1)

        for i, audio in enumerate(audio_files):
            ao = AudioObj()
            ao.name = splitext(audio)[0]

            ao_featureset = []
            for feature in feature_list:
                for f in os.listdir(join(options.INPUT, feature)):
                    if f[:f.find('_vamp')] == ao.name:
                        data = np.genfromtxt(join(options.INPUT, feature, f), delimiter=',', filling_values=0.0)[:, 1:]
                        ao_featureset.append(data)
                        break

            n_features = len(ao_featureset)
            if n_features == 0: continue

            if n_features > 1:
                # align on the feature with the fewest frames (the last few frames are generally empty)
                n_frame = np.min([x.shape[0] for x in ao_featureset])
                ao_featureset = [x[:n_frame, :] for x in ao_featureset]
                feature_matrix = np.hstack(ao_featureset)
            else:
                feature_matrix = ao_featureset[0]

            print "Processing data for audio file:", audio, n_features
            if options.AUDIO == 'salami':
                annotation_file = join(options.GT, ao.name+'.txt')  # iso, salami
                ao.gt = np.genfromtxt(annotation_file, usecols=0)
            elif options.AUDIO == 'qupujicheng':
                annotation_file = join(options.GT, ao.name+'.csv')  # qupujicheng
                ao.gt = np.genfromtxt(annotation_file, usecols=0, delimiter=',')
            elif options.AUDIO == 'beatles':
                annotation_file = join(options.GT, ao.name+'.lab')  # beatles
                ao.gt = np.genfromtxt(annotation_file, usecols=(0, 1))
                ao.gt = np.unique(np.ndarray.flatten(ao.gt))

            n_frames = feature_matrix.shape[0]

            # NB: reuses the last (feature, f) pair from the feature-loading loop above
            timestamps = np.genfromtxt(join(options.INPUT, feature, f), delimiter=',', filling_values=0.0, usecols=0)
            # map timestamps to the reduced representations
            timestamps = timestamps[0::stepsize]

            # normalise the feature matrix: remove negative features and add a
            # small constant for numerical stability
            feature_matrix = normaliseFeature(feature_matrix)

            # np.savetxt('test/feature_maxtrix-'+ao.name+"-wl%i-ss%i.txt" % (winlength, stepsize), feature_matrix, delimiter=',')
            fr = basename(options.INPUT)
            print 'fr', fr, options.INPUT
            feature_name = ''
            for feature in feature_list:
                feature_name += ('-' + feature)
            # np.savetxt(join(options.OUTPUT, test, ao.name+feature_name+"-%s-wl%i-ss%i-ssm.txt" % (fr, winlength, stepsize)), feature_matrix, delimiter=',')

            # PCA
            pca = PCA(n_components=6)
            pca.fit(feature_matrix)
            feature_matrix = pca.transform(feature_matrix)

            cache_filename = ao.name+feature_name+"-%s-wl%i-ss%i.txt" % (fr, winlength, stepsize)
            self.set_cache_filename(filename=cache_filename, cache_location=options.CACHE)
            gmm_list = self.getGMMs(feature_matrix, filename=cache_filename, gmmWindow=winlength, stepsize=stepsize)
            print 'number of GMMs:', len(gmm_list)

            skl_matrix = self.pairwiseSKL(gmm_list)
            ssm = getSSM(skl_matrix)
            np.savetxt(join(options.CACHE, ao.name+feature_name+"-%s-wl%i-ss%i-ssm.txt" % (fr, winlength, stepsize)), ssm, delimiter=",")

            # # 1. DBSCAN clustering of the raw features
            # db1 = DBSCAN(eps=10, min_samples=10).fit(feature_matrix)
            # core_samples_mask1 = np.zeros_like(db1.labels_, dtype=bool)
            # core_samples_mask1[db1.core_sample_indices_] = True
            # labels1 = db1.labels_

            # # 2. DBSCAN clustering of GMMs
            # db2 = DBSCAN(eps=0.05, min_samples=10, metric='precomputed').fit(skl_matrix)
            # core_samples_mask2 = np.zeros_like(db2.labels_, dtype=bool)
            # core_samples_mask2[db2.core_sample_indices_] = True
            # labels2 = db2.labels_

            # 3. Rank clustering of the raw GMMs
            rc = rClustering(eps=1.15, k=5, rank='max_neighbors')
            rc.set_cache_filename(ao.name+feature_name+"-%s-wl%i-ss%i.txt" % (fr, winlength, stepsize), cache_location=options.CACHE)
            rc.fit(gmm_list)
            rc.test()
            classification = rc.classification
            print 'classification', classification
            neighborhood_size, average_div, node_rank = rc.getNodeRank()

            # centroid_list = self.getInitialCentroids(neighborhood_size, k=10)
            # print 'initial centroids', centroid_list

            # k-means clustering of GMMs
            KmeansClustering = Kmeans(gmm_list, K=5, initial_centroids=set(classification))
            labels = KmeansClustering.fit()

            f1 = np.array(zip(timestamps[:len(labels)], labels))
            # f2 = np.array(zip(timestamps[:len(labels2)], labels2))
            f3 = np.array(zip(timestamps[:len(classification)], classification))
            f4 = np.array(zip(timestamps[:len(neighborhood_size)], neighborhood_size))
            f5 = np.array(zip(timestamps[:len(node_rank)], node_rank))
            f6 = np.array(zip(timestamps[:len(average_div)], average_div))

            np.savetxt(join(options.OUTPUT, 'kmeans')+splitext(audio)[0]+feature_name+'.csv', f1, delimiter=',')
            # np.savetxt(join(options.OUTPUT, 'dbscan_gmm')+splitext(audio)[0]+'.csv', f2, delimiter=',')
            # np.savetxt(join(options.OUTPUT, 'classification')+splitext(audio)[0]+'.csv', f3, delimiter=',')
            np.savetxt(join(options.OUTPUT, 'neighborhood_size')+splitext(audio)[0]+'.csv', f4, delimiter=',')
            np.savetxt(join(options.OUTPUT, 'node_rank')+splitext(audio)[0]+'.csv', f5, delimiter=',')
            np.savetxt(join(options.OUTPUT, 'average_div')+splitext(audio)[0]+'.csv', f6, delimiter=',')

def main():
    segmenter = Segmenter()
    segmenter.process()

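# Example invocation (using only the flags defined in parse_args; the default
# c4dm-scratch paths are assumed to be mounted):
#     python ProbSegmenter.py -a qupujicheng -t -v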
if __name__ == '__main__':
    main()