mi@0
|
1 #!/usr/bin/env python
|
mi@0
|
2 # encoding: utf-8
|
mi@0
|
3 """
|
mi@0
|
4 TempoPathTrackerUtil.py
|
mi@0
|
5
|
mi@0
|
6 Created by George Fazekas on 2014-04-06.
|
mi@0
|
7 Copyright (c) 2014 . All rights reserved.
|
mi@0
|
8
|
mi@0
|
9 This program implements max path tracker combining ideas from dynamic programming and the Hough line transform.
|
mi@0
|
10 It may be used to track tempo tracks in tempograms, partials in STFT spectrograms, or similar tasks.
|
mi@0
|
11
|
mi@0
|
12 """
|
mi@0
|
13
|
mi@0
|
14 import os, sys, itertools
|
mi@0
|
15 from os.path import join, isdir, isfile, abspath, dirname, basename, split, splitext
|
mi@0
|
16 from scipy import ndimage
|
mi@0
|
17 from scipy.ndimage.filters import maximum_filter, minimum_filter, median_filter, uniform_filter
|
mi@0
|
18 from math import ceil, floor
|
mi@0
|
19 from numpy import linspace
|
mi@0
|
20 from numpy.linalg import norm
|
mi@0
|
21 import numpy as np
|
mi@0
|
22 import matplotlib.pyplot as plt
|
mi@0
|
23 import matplotlib.image as mpimg
|
mi@0
|
24 import scipy.spatial as ss
|
mi@0
|
25 from math import sqrt
|
mi@0
|
26 from copy import deepcopy
|
mi@0
|
27 from skimage.feature import peak_local_max
|
mi@0
|
28
|
# Input/output directories used by main(). Each one can be overridden with an
# environment variable of the same name; the defaults are the original
# machine-specific paths.
#   SSM_PATH   - directory of comma-separated self-similarity matrices
#   GT_PATH    - directory of ground-truth annotation files (first column is read)
#   TRACK_PATH - output directory for the track masks (.txt) and plots (.pdf)
SSM_PATH = os.environ.get(
    'SSM_PATH',
    '/Users/mitian/Documents/experiments/mit/segmentation/combined/iso/ssm_data_combined')
GT_PATH = os.environ.get(
    'GT_PATH',
    '/Users/mitian/Documents/audio/annotation/isophonics')
TRACK_PATH = os.environ.get(
    'TRACK_PATH',
    '/Users/mitian/Documents/experiments/mit/segmentation/combined/iso/tracks')
# Alternative data set (qupujicheng):
# SSM_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/qupujicheng/ssm_data1'
# GT_PATH = '/Users/mitian/Documents/experiments/mit/annotation/qupujicheng1/lowercase'
# TRACK_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/qupujicheng/tracks'
mi@0
|
35
|
class Track(object):
    '''A single path traced through the data, stored as a list of nodes.

    Nodes are the points handed to add_point() (PathTracker passes (x, y)
    tuples). Tracks compare equal when they carry the same id, so deep copies
    of a track are interchangeable with the original in list look-ups.
    '''

    # Class-wide counter that hands out a unique id to every new Track.
    track_ID = 0

    def __init__(self, start):
        self.node_array = []  # nodes of the track (see add_point/join/concatenate)
        self.pair_array = []  # (point, pair) tuples recorded by add_pairs()
        self.start = start
        self.id = Track.track_ID
        Track.track_ID += 1
        self.sorted = False  # never set to True, so get_end() always sorts
        # NOTE(review): `end` is computed once, while node_array is still empty,
        # so it equals `start` and does not follow later add_point() calls; only
        # concatenate() reassigns it. PathTracker.concatenate_tracks() and
        # prune_duplicates() read this value, so the behaviour is kept as is.
        self.end = self.get_end()

    def __eq__(self, other):
        '''Two tracks are equal when they share the same id.'''
        if not isinstance(other, Track):
            return NotImplemented
        return self.id == other.id

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Hash by id so hashing agrees with __eq__ (Python 3 would otherwise
        # make the class unhashable because __eq__ is overridden).
        return hash(self.id)

    def add_point(self, point):
        '''Append a node/point to the track.'''
        self.node_array.append(point)

    def add_pairs(self, point, pair):
        '''Record a (point, pair) tuple of neighbouring points to aid pruning double traces.'''
        self.pair_array.append((point, pair))

    @property
    def length(self):
        '''Extent of the track along the first node coordinate (time axis); 0 when empty.'''
        nodes = np.array(self.node_array)
        if len(nodes):
            return max(nodes[:, 0]) - min(nodes[:, 0])
        return 0

    @property
    def mean(self):
        '''Mean of the second node coordinate, or None for a track without nodes.'''
        if not self.node_array:
            return None
        nodes = np.array(self.node_array)
        # Average column-wise: a plain .mean() collapses to a single scalar,
        # which cannot be indexed.
        return nodes.mean(axis=0)[1]

    @property
    def start_x(self):
        '''First coordinate of the start point.'''
        return self.start[0]

    # `end` used to be a read-only property; it is now a plain attribute filled
    # by get_end() so that concatenate() can assign it without raising
    # "AttributeError: can't set attribute".
    def get_end(self):
        '''Return the last node in sorted order, or `start` when there are no nodes.

        Side effect: when nodes exist, node_array is rebound to a sorted copy and
        `start` is reset to the first node.
        '''
        if not self.node_array:
            return self.start
        if not self.sorted:
            self.node_array = sorted(self.node_array)
            self.start = self.node_array[0]
        return self.node_array[-1]

    def join(self, other):
        '''Absorb the nodes of `other` into this track, dropping duplicate nodes.

        Returns None without changes when `other` has no nodes. If other.end lies
        before this track's start on the first axis, `start` is moved to it.
        The resulting node order is arbitrary (it passes through a set).
        '''
        if not other.node_array:
            return None
        self.node_array.extend(other.node_array)
        self.node_array = list(set(self.node_array))
        if other.end[0] < self.start[0]:
            print("Info: Starting point moved from  %s  to  %s" % (self.start[0], other.end[0]))
            self.start = other.end

    def concatenate(self, other):
        '''Append the nodes of `other` to this track and take over its `end`.

        Does nothing (returns None) when either track has no nodes. The merged
        node list is de-duplicated and sorted.
        '''
        if not other.node_array or not self.node_array:
            return None
        self.end = other.end
        self.node_array.extend(other.node_array)
        self.node_array = sorted(set(self.node_array))
mi@0
|
115
|
class PathTracker(object):
    '''The main tracker object.

    process() finds peaks in a self-similarity matrix (SSM), traces diagonal
    paths forwards and backwards from them, merges the resulting Track objects
    and renders the surviving tracks into a boolean mask.
    '''

    def __init__(self):
        self.track_list = []    # Track objects from the latest process() call
        self.ssm = None         # matrix being traced, set by process()
        self.max_index = None   # peak coordinates (array while tracing, list of tuples afterwards)
        self.kd_tree = None     # cKDTree over the peak coordinates
        self.kd_points = None   # the exact coordinate array kd_tree was built from
        self.group_num = 0      # number of groups found by count_groups()
        self.group = None       # unused; kept for backward compatibility
        self.groups = []        # lists of tracks filled by count_groups()

    def _reset(self):
        '''Clear per-call state so that process() can be run on several matrices in turn.'''
        self.track_list = []
        self.max_index = None
        self.kd_tree = None
        self.kd_points = None
        self.group_num = 0
        self.groups = []

    def get_local_maxima(self, ssm, threshold=0.7, neighborhood_size=4):
        '''Return the coordinates and mask of all cells of `ssm` above `threshold`.

        Despite the name this is a plain threshold (the former min/max-filter
        peak detection was disabled); `neighborhood_size` is currently unused.

        Returns:
            (indices, maxima): a list of (column, row) tuples of the cells above
            the threshold, and the boolean mask ssm > threshold.
        '''
        maxima = (ssm > threshold)
        rows, cols = maxima.nonzero()
        return list(zip(cols, rows)), maxima

    def get_peak_local(self, ssm, thresh=0.8, min_distance=10, threshold_rel=0.8, noise_floor=0.6):
        '''Find local maxima with skimage's peak_local_max.

        Args:
            ssm: 2-D numpy array. NOTE: modified in place -- values below
                `noise_floor` are set to 0. process() traces on this same array
                object, so the tracking sees the thresholded values.
            thresh: values below this are zeroed in the copy used for peak picking.
            min_distance, threshold_rel: forwarded to peak_local_max.
            noise_floor: in-place threshold applied to `ssm` (default keeps the
                previously hard-coded 0.6).

        Returns:
            (reduced_ssm, indices, maxima): the thresholded copy with a zeroed
            main diagonal, an (N, 2) integer array of peak coordinates as
            returned by peak_local_max, and a boolean mask True at those peaks.
        '''
        reduced_ssm = ssm.copy()
        reduced_ssm[reduced_ssm < thresh] = 0.0  # hard threshold before peak picking
        ssm[ssm < noise_floor] = 0.0
        # Zero the main diagonal so it is not picked as the only maximum in a
        # neighbourhood.
        np.fill_diagonal(reduced_ssm, 0)
        # A single call, with the mask derived from the coordinates: the
        # `indices=` keyword previously used is gone from newer scikit-image.
        indices = np.asarray(
            peak_local_max(reduced_ssm, min_distance=min_distance, threshold_rel=threshold_rel),
            dtype=int).reshape(-1, 2)
        maxima = np.zeros(reduced_ssm.shape, dtype=bool)
        maxima[indices[:, 0], indices[:, 1]] = True
        return reduced_ssm, indices, maxima

    def prune_duplicates(self, maxima, size):
        '''Drop empty tracks and the shorter track of each pair of nearby tracks.

        Two tracks are compared when their start points (both coordinates) and
        the second coordinate of their end points differ by at most 10. If any
        pair of their nodes has a squared Euclidean distance below `size`, the
        track with fewer nodes is removed. `maxima` is accepted for API
        compatibility but unused.
        '''
        # In-place so other references to the list see the same contents.
        self.track_list[:] = [track for track in self.track_list if track.node_array]
        print("self.track_list start %d" % len(self.track_list))

        # Pairs are drawn from the list as it was before any removal.
        snapshot = list(self.track_list)
        for track1, track2 in itertools.combinations(snapshot, 2):
            if abs(track1.end[1] - track2.end[1]) > 10:
                continue
            if abs(track1.start[1] - track2.start[1]) > 10:
                continue
            if abs(track1.start[0] - track2.start[0]) > 10:
                continue
            points1 = np.asarray(track1.node_array, dtype=float)
            points2 = np.asarray(track2.node_array, dtype=float)
            # All pairwise squared distances in one vectorised step.
            diff = points1[:, None, :] - points2[None, :, :]
            if (diff ** 2).sum(axis=-1).min() < size:
                # Nearby track found: discard the one with fewer nodes.
                duplicate = track1 if len(points1) < len(points2) else track2
                if duplicate in self.track_list:
                    self.track_list.remove(duplicate)
        print("self.track_list pruned %d" % len(self.track_list))

    def count_groups(self, gap=10.0):
        '''Cluster tracks whose start x positions lie close together.

        Tracks are sorted in place by start x; a new group is opened whenever the
        start x jumps by more than `gap` from the previous track. Sets
        self.groups (a list of lists of tracks) and self.group_num.
        '''
        self.track_list.sort(key=lambda track: track.start_x)
        groups = []
        previous_x = None
        for track in self.track_list:
            x = track.start_x
            if previous_x is None or x - previous_x > gap:
                groups.append([])
            groups[-1].append(track)
            previous_x = x
        self.groups = groups
        self.group_num = len(groups)
        print('self.groups %d' % len(self.groups))

    def histogram(self):
        '''Placeholder for comparing pairwise track distances within each group.

        Not implemented: it does nothing and returns None.
        '''
        # TODO: histogram pairwise track distances within each of self.groups.
        return None

    def process(self, ssm, thresh=0.8, min_local_dist=20, slice_size=2, step_thresh=0.25,
                track_min_len=50, track_gap=50):
        '''Track paths in `ssm` and render the surviving tracks into a mask.

        Args:
            ssm: 2-D numpy array; get_peak_local() zeroes its values below 0.6
                in place.
            thresh: value threshold applied before peak picking (forwarded to
                get_peak_local).
            min_local_dist: minimum distance between detected peaks.
            slice_size: half-width of the neighbourhood examined at each step.
            step_thresh: a trace stops when the weighted diagonal mean of the
                neighbourhood drops below this value.
            track_min_len: tracks with a smaller length are not rendered.
            track_gap: diagonal search range used by concatenate_tracks().

        Returns:
            (reduced_ssm, max_index, tracks): the thresholded copy used for peak
            picking, the list of peaks that seeded backward traces, and a
            boolean mask with the rendered tracks.
        '''
        # State from a previous call must not leak into this one.
        self._reset()
        self.ssm = ssm
        print("ssm.shape %s" % (ssm.shape,))

        reduced_ssm, max_index, maxima = self.get_peak_local(
            ssm, thresh=thresh, min_distance=min_local_dist, threshold_rel=0.5)

        # NOTE(review): peak_local_max yields (row, col) pairs while the tracing
        # code unpacks them as (x, y) = (column, row); for a symmetric SSM the
        # two readings coincide.
        self.max_index = np.array(max_index)
        if not len(self.max_index):
            print('No maxima found.')
            self.max_index = []
            # Keep the (reduced_ssm, max_index, tracks) shape of the normal return.
            return reduced_ssm, self.max_index, np.zeros_like(maxima)

        # Spatial index over the peaks, used to discard peaks a forward trace passes.
        self.kd_points = self.max_index
        self.kd_tree = ss.cKDTree(self.max_index)

        discard_maxima = set()

        # Trace forwards from every peak not already passed by an earlier trace.
        for ix, iy in self.max_index:
            start = (ix, iy)
            if start in discard_maxima:
                continue
            track = Track(start)
            self.track_list.append(track)
            self._trace(track, start, slice_size, step_thresh, "forward", maxima, discard_maxima)
        print("discarded maxima:  %d" % len(discard_maxima))

        self.max_index = [(x, y) for x, y in self.max_index if (x, y) not in discard_maxima]

        # Trace backwards from the remaining peaks.
        print("Tracing back...")
        for point in self.max_index:
            track = Track(point)
            self.track_list.append(track)
            self._trace(track, point, slice_size, step_thresh, "backward", maxima)
        print("tracing done.")
        print('tracks after tracing: %d' % len(self.track_list))

        # Join forward and backward traces sharing a starting point.
        self.join_tracks()
        # Concatenate nearby tracks on the same diagonal.
        self.concatenate_tracks(size=track_gap)
        # prune_duplicates(maxima, size=10) is available but currently disabled.
        maxima = maximum_filter(maxima, size=2)
        # TODO: smooth paths, experiment with segmentation of individual tracks...
        self.count_groups()

        # Empty mask (same shape/dtype as the peak mask) for visualisation.
        tracks = self._render_tracks(np.zeros_like(maxima), track_min_len)
        print('number of final tracks %d' % len(self.track_list))
        return reduced_ssm, self.max_index, tracks

    def _trace(self, track, point, slice_size, step_thresh, direction, maxima, discard=None):
        '''Step from `point` in `direction` until step() stops, adding each visited point to `track`.

        Every visited point is also marked True in `maxima`. When `discard` is a
        set, the nearest known peak to each visited point (other than the trace's
        own start) is added to it.
        '''
        start = point
        while True:
            patch = self.get_neighbourhood(point, size=slice_size)
            x, y = self.step(point, patch, threshold=step_thresh, direction=direction)
            if x is None:
                return
            point = (x, y)
            if discard is not None:
                nearest = self.get_nearest_maxima(point)
                if nearest and nearest != start:
                    discard.add(nearest)
            maxima[y, x] = True
            track.add_point(point)

    def _render_tracks(self, tracks, track_min_len):
        '''Draw every track with length >= `track_min_len` into `tracks` and return it.

        Each track becomes a main-diagonal segment starting at its first sorted
        node; cells outside the matrix are skipped.
        '''
        ssm_len = tracks.shape[0]
        for track in self.track_list:
            if not track.node_array or track.length < track_min_len:
                continue
            track.node_array.sort()
            first, last = track.node_array[0], track.node_array[-1]
            row, col = first[1], first[0]
            for i in range(last[1] - first[1]):
                if max(row + i, col + i) < ssm_len:
                    tracks[row + i, col + i] = 1.0
        return tracks

    def join_tracks(self):
        '''Join tracks which share a common starting point.

        Intended to merge a forward trace with the backward trace from the same
        peak: whenever exactly two tracks share a start, the second absorbs the
        first, which is then removed. Returns self.track_list.
        '''
        start_points = {track.start for track in self.track_list}
        print("Initial Traces before joining: %d" % len(self.track_list))
        print("Unique start points: %d" % len(start_points))

        for start in start_points:
            shared_tracks = [track for track in self.track_list if track.start == start]
            if len(shared_tracks) == 2:
                shared_tracks[1].join(shared_tracks[0])
                self.track_list.remove(shared_tracks[0])
        print("Final tracklist after joining %d" % len(self.track_list))
        return self.track_list

    def concatenate_tracks(self, size=3):
        '''Concatenate a track with tracks starting just after its end on the diagonal.

        For every end point (xe, ye) of a track with length > 1, tracks starting
        at (xe + i, ye + i) for 1 <= i < size are appended to it and removed from
        the list. Returns self.track_list.
        '''
        start_points = {track.start for track in self.track_list}
        end_points = {track.end for track in self.track_list}
        print("Traces before concatenation: %d %d %d"
              % (len(self.track_list), len(start_points), len(end_points)))
        for end in end_points:
            xe, ye = end
            candidates = [t for t in self.track_list if t.end == end and t.length > 1]
            if not candidates:
                continue
            # Use a track that actually passed the length filter.
            track = candidates[0]
            for i in range(1, size):
                successor_start = (xe + i, ye + i)
                if successor_start not in start_points:
                    continue
                # Never concatenate a track with itself (that would delete it).
                successors = [t for t in self.track_list
                              if t.start == successor_start and t is not track]
                if not successors:
                    continue
                track.concatenate(successors[0])
                self.track_list.remove(successors[0])
        print("Traces after concatenation: %d" % len(self.track_list))
        return self.track_list

    def get_nearest_maxima(self, point, threshold=5.0):
        '''Return the known peak nearest to `point` as a tuple, or None if none is within `threshold`.

        Uses a single-nearest-neighbour Euclidean query on the kd-tree, because
        pairwise comparison is too slow.
        '''
        distance, index = self.kd_tree.query(point, k=1, p=2, distance_upper_bound=threshold)
        if not np.isfinite(distance):
            return None
        # Index into the array the tree was built from: process() later replaces
        # self.max_index with a filtered list whose positions no longer match.
        points = getattr(self, 'kd_points', None)
        if points is None:
            points = np.asarray(self.max_index)
        return tuple(points[index, :])

    def get_neighbourhood(self, point, size=1):
        '''Return the (2*size+1) x (2*size+1) window of self.ssm centred on `point`.

        `point` is (x, y), i.e. (column, row). Parts of the window outside the
        matrix are filled with zeros.
        '''
        xs, xe = point[0] - size, point[0] + size + 1
        ys, ye = point[1] - size, point[1] + size + 1

        rows, cols = self.ssm.shape
        patch = self.ssm[max(0, ys):min(ye, rows), max(0, xs):min(xe, cols)]

        # left/right zero padding
        if xs < 0:
            patch = np.hstack([np.zeros((patch.shape[0], -xs)), patch])
        if xe > cols:
            patch = np.hstack([patch, np.zeros((patch.shape[0], xe - cols))])
        # top/bottom zero padding
        if ys < 0:
            patch = np.vstack([np.zeros((-ys, patch.shape[1])), patch])
        if ye > rows:
            patch = np.vstack([patch, np.zeros((ye - rows, patch.shape[1]))])
        return patch

    def step(self, point, slice, threshold=0.3, direction="forward"):
        '''Choose the next diagonal step from `point`.

        Args:
            point: (x, y) starting coordinate in the data matrix.
            slice: square window of the data centred on `point`
                (see get_neighbourhood).
            threshold: minimum weighted diagonal mean needed to keep tracking.
            direction: "backward" steps to (x-1, y-1); any other value steps
                to (x+1, y+1).

        Returns:
            (x, y) of the next point, or (None, None) when the weighted mean is
            below `threshold` or the step would leave the data matrix.

        The main diagonal of the window is weighted linearly (0 -> 1) towards
        the direction of travel, which is a Hough-line style scoring of the
        diagonal segment. The algorithm never steps straight up or down.
        '''
        backward = (direction == 'backward')
        x, y = point

        weights = np.linspace(0.0, 1.0, slice.shape[0])
        if backward:
            weights = weights[::-1]
        diagonal = slice.diagonal()
        total_weight = sum(weights)
        if total_weight:
            segment_weight = sum(diagonal * weights) / total_weight
        else:
            # A 1x1 window has all-zero weights; use the plain value instead of 0/0.
            segment_weight = diagonal.mean()

        delta = -1 if backward else 1
        xs, ys = x + delta, y + delta
        rows, cols = self.ssm.shape

        # Terminate if the weighted mean of the segment is below the threshold.
        if segment_weight < threshold:
            return None, None
        # Terminate if the data matrix bounds are reached.
        if xs < 0 or xs >= cols or ys < 0 or ys >= rows:
            return None, None
        return xs, ys
mi@0
|
462
|
mi@0
|
463
|
def _plot_tracks(ssm, reduced_ssm, maxima, tracks, gt, pdf_path):
    '''Save a three-panel figure (SSM, detected peaks, tracks) to `pdf_path`.

    Ground-truth times `gt` are drawn as red vertical lines, scaled so that the
    last annotation time maps to the last matrix frame.
    '''
    n = len(tracks)
    boundaries = gt / gt[-1] * n

    ax1 = plt.subplot(131)
    ax1.imshow(ssm, cmap='Greys')
    ax1.vlines(boundaries, 0, n, colors='r')

    ax2 = plt.subplot(132)
    ax2.imshow(reduced_ssm, cmap='Greys')
    if len(maxima):
        ax2.scatter([p[0] for p in maxima], [p[1] for p in maxima], s=5, c='y')
    ax2.set_xlim([0, n])
    ax2.set_ylim([n, 0])

    ax3 = plt.subplot(133)
    ax3.imshow(tracks, cmap='Greys')
    ax3.vlines(boundaries, 0, n, colors='r')
    ax3.set_xlim([0, n])
    ax3.set_ylim([n, 0])

    plt.savefig(pdf_path, format='pdf')
    plt.close()


def main():
    '''Run the path tracker on every SSM in SSM_PATH and write the results to TRACK_PATH.

    SSM files and ground-truth files are paired by position after sorting each
    directory listing (hidden files ignored). Songs whose output mask
    (<name>.txt) already exists are skipped. Plots (<name>.pdf) are written
    unless --no-plot is passed on the command line.
    '''
    plot = '--no-plot' not in sys.argv

    ssm_files = sorted(join(SSM_PATH, name) for name in os.listdir(SSM_PATH)
                       if not name.startswith('.'))
    gt_files = sorted(join(GT_PATH, name) for name in os.listdir(GT_PATH)
                      if not name.startswith('.'))
    if len(ssm_files) != len(gt_files):
        print('Warning: %d SSM files but %d annotation files; unpaired files are ignored.'
              % (len(ssm_files), len(gt_files)))

    for ssm_file, gt_file in zip(ssm_files, gt_files):
        audio_name = splitext(basename(gt_file))[0]
        out_txt = join(TRACK_PATH, audio_name + '.txt')
        # Check before loading: the result for this song already exists.
        if isfile(out_txt):
            continue

        ssm = np.genfromtxt(ssm_file, delimiter=',')
        gt = np.genfromtxt(gt_file, usecols=0)
        print('Processing: %s' % audio_name)

        # A fresh tracker per song so no state carries over between songs.
        tracker = PathTracker()
        reduced_ssm, maxima, tracks = tracker.process(
            ssm, thresh=0.5, min_local_dist=20, slice_size=20, step_thresh=0.4,
            track_min_len=50, track_gap=50)
        np.savetxt(out_txt, tracks, delimiter=',')

        if plot:
            _plot_tracks(ssm, reduced_ssm, maxima, tracks, gt,
                         join(TRACK_PATH, audio_name + '.pdf'))
mi@0
|
525
|
mi@0
|
526
|
# Script entry point: run the batch job only when executed directly, not on import.
if __name__ == '__main__':
    sys.exit(main())