Mercurial > hg > segmentation
view utils/PathTracker.py @ 17:c01fcb752221
new annotations
author | mitian |
---|---|
date | Fri, 21 Aug 2015 10:15:29 +0100 |
parents | 26838b1f560f |
children |
line wrap: on
line source
#!/usr/bin/env python
# encoding: utf-8
"""
TempoPathTrackerUtil.py

Created by George Fazekas on 2014-04-06.
Copyright (c) 2014 . All rights reserved.

This program implements a max-path tracker combining ideas from dynamic
programming and the Hough line transform. It may be used to track tempo
tracks in tempograms, partials in STFT spectrograms, or similar tasks.
"""

import os
import sys
import itertools
from os.path import join, isdir, isfile, abspath, dirname, basename, split, splitext
from scipy import ndimage
# scipy.ndimage.filters is a deprecated alias; the top-level scipy.ndimage
# namespace exports the same functions in every SciPy version.
from scipy.ndimage import maximum_filter, minimum_filter, median_filter, uniform_filter
from math import ceil, floor
from numpy import linspace
from numpy.linalg import norm
import numpy as np
import scipy.spatial as ss
from math import sqrt
from copy import deepcopy

# Optional dependencies: matplotlib is only needed for plotting in main(), and
# scikit-image only by PathTracker.get_peak_local(). Guarding these imports
# keeps the rest of the tracker importable when they are not installed.
try:
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
except ImportError:
    plt = mpimg = None
try:
    from skimage.feature import peak_local_max
except ImportError:
    peak_local_max = None

SSM_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/iso/ssm_data_combined'
GT_PATH = '/Users/mitian/Documents/audio/annotation/isophonics'
TRACK_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/iso/tracks'
# SSM_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/qupujicheng/ssm_data1'
# GT_PATH = '/Users/mitian/Documents/experiments/mit/annotation/qupujicheng1/lowercase'
# TRACK_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/qupujicheng/tracks'


class Track(object):
    '''A single path traced through the data, stored as a list of (x, y) nodes.

    Every track gets a unique integer id; equality and hashing use that id.
    '''

    track_ID = 0  # class-wide counter used to hand out unique track ids

    def __init__(self, start):
        self.node_array = []
        self.pair_array = []
        self.start = start
        self.id = Track.track_ID
        Track.track_ID += 1
        self.sorted = False
        # NOTE(review): `end` is a snapshot taken here. node_array is still
        # empty, so it equals `start`. add_point() and join() do not refresh
        # it; only concatenate() reassigns it. get_end() derives the value from
        # the current nodes. join_tracks()/concatenate_tracks() rely on this
        # attribute as-is.
        self.end = self.get_end()

    def __eq__(self, other):
        if not isinstance(other, Track):
            return NotImplemented
        return self.id == other.id

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Must agree with __eq__; without it Track is unhashable on Python 3.
        return hash(self.id)

    def add_point(self, point):
        '''Add a node/point to the trace.'''
        self.node_array.append(point)

    def add_pairs(self, point, pair):
        '''Record a neighbouring point pair (intended to aid pruning double traces).'''
        self.pair_array.append((point, pair))

    @property
    def length(self):
        '''Extent of the track on the x (time) axis; 0 for an empty track.'''
        nodes = np.array(self.node_array)
        if len(nodes):
            return max(nodes[:, 0]) - min(nodes[:, 0])
        return 0

    @property
    def mean(self):
        '''Mean y coordinate of the track's nodes, or None for an empty track.'''
        if not self.node_array:
            return None
        # Average each coordinate column, then take y.
        return np.array(self.node_array).mean(axis=0)[1]

    @property
    def start_x(self):
        '''x coordinate of the start point.'''
        return self.start[0]

    # `end` used to be a read-only property computed like get_end(). It was
    # replaced by a plain attribute (set in __init__) because concatenate()
    # assigns to it, which raised "AttributeError: can't set attribute".
    def get_end(self):
        '''Return the last node in sorted order (or `start` if there are none).

        Side effect: sorts node_array and moves `start` to the first sorted node.
        '''
        if not self.node_array:
            return self.start
        if not self.sorted:
            self.node_array = sorted(self.node_array)
            self.start = self.node_array[0]
        return self.node_array[-1]

    def join(self, other):
        '''Absorb the nodes of `other` into this track (duplicates removed).

        Does nothing when `other` has no nodes. The resulting node order is
        unspecified. `start` moves to `other.end` if that lies further left.
        '''
        if not other.node_array:
            return None
        self.node_array = list(set(self.node_array + other.node_array))
        if other.end[0] < self.start[0]:
            print('Info: Starting point moved from  %s  to  %s' % (self.start[0], other.end[0]))
            self.start = other.end

    def concatenate(self, other):
        '''Append the nodes of `other`, adopting its end point.

        Does nothing when either track has no nodes. The merged node list is
        deduplicated and sorted.
        '''
        if not other.node_array or not self.node_array:
            return None
        self.end = other.end
        self.node_array = sorted(set(self.node_array + other.node_array))


class PathTracker(object):
    '''The main tracker object: finds maxima in an SSM and traces diagonal paths.'''

    def __init__(self):
        self.track_list = []
        self.ssm = None
        self.max_index = None
        self.kd_tree = None
        self.group_num = 0
        self.group = None  # legacy attribute, not used by this module
        self.groups = []

    def get_local_maxima(self, ssm, threshold=0.7, neighborhood_size=4):
        '''Return every element of `ssm` above `threshold`.

        `neighborhood_size` is accepted for compatibility but unused; this is a
        plain threshold, not a min/max-filter peak picker.

        Returns (indices, maxima): a list of (x, y) = (column, row) tuples, and
        the boolean mask `ssm > threshold`.
        '''
        maxima = ssm > threshold
        rows, cols = maxima.nonzero()
        indices = list(zip(cols, rows))
        return indices, maxima

    def get_peak_local(self, ssm, thresh=0.8, min_distance=10, threshold_rel=0.8):
        '''Find local maxima with scikit-image's peak_local_max.

        Values of a copy of `ssm` below `thresh` are zeroed, as is its main
        diagonal, so the diagonal is not picked as the only maximum.

        NOTE: `ssm` itself is also modified in place (values below 0.6 are
        zeroed); process() relies on this because it traces through that array.

        Returns (reduced_ssm, indices, maxima): the thresholded copy, an (N, 2)
        array of peak coordinates as returned by peak_local_max, and a boolean
        mask of the same shape as `ssm` marking those peaks.
        '''
        if peak_local_max is None:
            raise ImportError('scikit-image is required by PathTracker.get_peak_local')
        reduced_ssm = ssm.copy()
        reduced_ssm[reduced_ssm < thresh] = 0.0
        # a hard thresholding of the input for finding/tracing maxima
        ssm[ssm < 0.6] = 0.0
        np.fill_diagonal(reduced_ssm, 0)
        indices = peak_local_max(reduced_ssm, min_distance=min_distance,
                                 threshold_rel=threshold_rel)
        # Build the boolean mask from the coordinates instead of a second
        # peak_local_max(..., indices=False) call: same result, half the work,
        # and `indices=` no longer exists in recent scikit-image releases.
        coords = np.asarray(indices, dtype=int).reshape(-1, reduced_ssm.ndim)
        maxima = np.zeros(reduced_ssm.shape, dtype=bool)
        maxima[tuple(coords.T)] = True
        return reduced_ssm, indices, maxima

    def prune_duplicates(self, maxima, size):
        '''Remove near-duplicate tracks from self.track_list.

        Empty tracks are dropped first. For each pair of tracks whose start
        points and end y coordinates are within 10 of each other, the shorter
        track (fewer nodes) is removed if any pair of their nodes has a squared
        Euclidean distance below `size`. `maxima` is accepted but unused.
        '''
        self.track_list[:] = [track for track in self.track_list if track.node_array]
        print('self.track_list start %d' % len(self.track_list))
        # Iterate over a snapshot because tracks are removed from the live list.
        for track1, track2 in itertools.combinations(list(self.track_list), 2):
            if abs(track1.end[1] - track2.end[1]) > 10:
                continue
            if abs(track1.start[1] - track2.start[1]) > 10:
                continue
            if abs(track1.start[0] - track2.start[0]) > 10:
                continue
            points1 = track1.node_array
            points2 = track2.node_array
            dist = [(i[0] - j[0]) ** 2 + (i[1] - j[1]) ** 2 for i in points1 for j in points2]
            if dist and min(dist) < size:
                duplicate = track1 if len(points1) < len(points2) else track2
                if duplicate in self.track_list:
                    self.track_list.remove(duplicate)
        print('self.track_list pruned %d' % len(self.track_list))

    def count_groups(self, gap=10.0):
        '''Cluster tracks whose start points are horizontally close.

        Tracks are sorted by the x coordinate of their start point. A new group
        begins whenever two consecutive start points are more than `gap` apart
        on the x axis. Stores the groups (lists of tracks) in self.groups and
        their count in self.group_num; repeated calls recompute from scratch.
        '''
        self.track_list.sort(key=lambda track: track.start_x)
        self.groups = []
        previous_x = None
        for track in self.track_list:
            if previous_x is None or track.start_x - previous_x > gap:
                self.groups.append([])
            self.groups[-1].append(track)
            previous_x = track.start_x
        self.group_num = len(self.groups)
        print('self.groups %d' % len(self.groups))

    def histogram(self):
        '''Compare pairwise distances of tracks within each group (not implemented).

        TODO: currently a no-op placeholder; only converts each group to an array.
        '''
        for group in self.groups:
            group_track = np.array(group)

    def process(self, ssm, thresh=0.8, min_local_dist=20, slice_size=2,
                step_thresh=0.25, track_min_len=50, track_gap=50):
        '''Trace diagonal paths in `ssm` and render them into a mask.

        Parameters
        ----------
        ssm : 2-D array. It is copied, so the caller's array is not modified.
        thresh : currently unused; get_peak_local() runs with its own default.
            NOTE(review): probably meant to be forwarded to get_peak_local --
            confirm before changing, as it alters the detected maxima.
        min_local_dist : `min_distance` passed to get_peak_local().
        slice_size : half-width of the neighbourhood inspected at each step.
        step_thresh : tracing stops when the weighted diagonal falls below this.
        track_min_len : tracks whose x extent is shorter are not drawn.
        track_gap : maximum diagonal gap bridged by concatenate_tracks().

        Returns
        -------
        (reduced_ssm, max_index, tracks): the thresholded SSM copy from
        get_peak_local(), the maxima used as trace starting points, and a mask
        (same shape/dtype as the maxima mask) with the accepted tracks drawn in.
        '''
        # The tracker may be reused for several matrices (main() does this), so
        # results from a previous call must not leak into this one.
        self.track_list = []
        self.group_num = 0
        self.groups = []

        # get_peak_local() zeroes low values in place and the tracer reads that
        # thresholded array through self.ssm; work on a private copy.
        ssm = np.array(ssm, copy=True)
        self.ssm = ssm
        print('ssm.shape %s' % (ssm.shape,))
        reduced_ssm, max_index, maxima = self.get_peak_local(
            ssm, min_distance=min_local_dist, threshold_rel=0.5)

        # KD-tree over the maxima, used to discard maxima a trace passes near.
        self.max_index = np.array(max_index)
        if not len(self.max_index):
            print('No maxima found.')
            return reduced_ssm, self.max_index, np.zeros_like(maxima)
        self.kd_tree = ss.cKDTree(self.max_index)
        discard_maxima = set()

        # trace forwards, discarding maxima that lie on an existing trace
        for ix, iy in self.max_index:
            point = (ix, iy)
            if point in discard_maxima:
                continue
            track = Track(point)
            self.track_list.append(track)
            self._trace(track, point, slice_size, step_thresh, "forward", maxima,
                        discard=discard_maxima)
        print('discarded maxima:  %d' % len(discard_maxima))
        self.max_index = [(x, y) for x, y in self.max_index if (x, y) not in discard_maxima]

        # trace back from every surviving maximum
        print('Tracing back...')
        for point in self.max_index:
            track = Track(point)
            self.track_list.append(track)
            self._trace(track, point, slice_size, step_thresh, "backward", maxima)
        print('tracing done.')
        print('tracks after tracing: %d' % len(self.track_list))

        # join forward and back traces with the same starting point
        self.join_tracks()
        # concatenate nearby tracks on the same diagonal direction
        self.concatenate_tracks(size=track_gap)
        # prune_duplicates() is deliberately not part of the pipeline.
        maxima = maximum_filter(maxima, size=2)
        # TODO: smooth paths, experiment with segmentation of individual tracks...
        self.count_groups()

        tracks = self._render_tracks(np.zeros_like(maxima), track_min_len)
        print('number of final tracks %d' % len(self.track_list))
        return reduced_ssm, self.max_index, tracks

    def _trace(self, track, point, slice_size, step_thresh, direction, maxima, discard=None):
        '''Step from `point` until step() terminates, recording each new point.

        Each visited point is added to `track` and marked in `maxima` (indexed
        [y, x]). When `discard` is a set, the nearest known maximum of every
        visited point (other than the starting point) is added to it.
        '''
        start = point
        while True:
            window = self.get_neighbourhood(point, size=slice_size)
            x, y = self.step(point, window, threshold=step_thresh, direction=direction)
            if x is None:
                break
            point = (x, y)
            if discard is not None:
                nearest = self.get_nearest_maxima(point)
                if nearest is not None and nearest != start:
                    discard.add(nearest)
            maxima[y, x] = True
            track.add_point(point)

    def _render_tracks(self, mask, track_min_len):
        '''Draw every track of x extent >= `track_min_len` into `mask` (in place).

        Each track is drawn as a 45-degree line from its first sorted node,
        spanning the track's y extent and clipped to the mask's first dimension.
        '''
        size = mask.shape[0]
        for track in self.track_list:
            if track.length < track_min_len:
                continue
            track.node_array.sort()
            first, last = track.node_array[0], track.node_array[-1]
            row0, col0 = first[1], first[0]
            for i in range(last[1] - row0):
                if max(row0 + i, col0 + i) < size:
                    mask[row0 + i, col0 + i] = 1.0
        return mask

    def join_tracks(self):
        '''Join tracks that share a starting point (forward + backward traces).

        For every start point shared by exactly two tracks, the second absorbs
        the first, which is removed. Returns the updated self.track_list.
        '''
        start_points = set(track.start for track in self.track_list)
        print('Initial Traces before joining: %d' % len(self.track_list))
        print('Unique start points: %d' % len(start_points))
        for start in start_points:
            shared_tracks = [track for track in self.track_list if track.start == start]
            if len(shared_tracks) == 2:
                shared_tracks[1].join(shared_tracks[0])
                self.track_list.remove(shared_tracks[0])
        print('Final tracklist after joining %d' % len(self.track_list))
        return self.track_list

    def concatenate_tracks(self, size=3):
        '''Concatenate tracks whose start lies diagonally just after another's end.

        For each end point (xe, ye) belonging to a track longer than 1, tracks
        starting at (xe+i, ye+i) for 1 <= i < size are absorbed into it and
        removed. Returns the updated self.track_list.
        '''
        start_points = set(track.start for track in self.track_list)
        end_points = set(track.end for track in self.track_list)
        print('Traces before concatenation: %d %d %d'
              % (len(self.track_list), len(start_points), len(end_points)))
        for end in end_points:
            xe, ye = end
            # Pick the long track the guard is about; the first track with this
            # end may be empty, and concatenating into it would drop the
            # successor's nodes.
            candidates = [track for track in self.track_list
                          if track.end == end and track.length > 1]
            if not candidates:
                continue
            track = candidates[0]
            for i in range(1, size):
                successor_start = (xe + i, ye + i)
                if successor_start not in start_points:
                    continue
                successors = [t for t in self.track_list if t.start == successor_start]
                if not successors:
                    continue
                successor = successors[0]
                if successor is track:
                    continue
                track.concatenate(successor)
                self.track_list.remove(successor)
        print('Traces after concatenation: %d' % len(self.track_list))
        return self.track_list

    def get_nearest_maxima(self, point, threshold=5.0):
        '''Return the known maximum nearest to `point`, or None if none is within `threshold`.

        Uses a nearest-neighbour query on self.kd_tree (Euclidean distance);
        pairwise comparison would be far too slow. Requires self.max_index to
        be the array the tree was built from. The result is a tuple row of it.
        '''
        # k: neighbours to return; p=2: Euclidean metric; the upper bound limits
        # the search radius (a miss returns d = inf).
        d, i = self.kd_tree.query(point, k=1, p=2, distance_upper_bound=threshold)
        if d != np.inf:
            return tuple(self.max_index[i, :])
        return None

    def get_neighbourhood(self, point, size=1):
        '''Return the (2*size+1)-square window of self.ssm centred on `point`.

        `point` is (x, y) = (column, row). Parts of the window outside the
        matrix are zero-padded.
        '''
        x0, x1 = point[0] - size, point[0] + size + 1
        y0, y1 = point[1] - size, point[1] + size + 1
        rows, cols = self.ssm.shape
        window = self.ssm[max(0, y0):min(y1, rows), max(0, x0):min(x1, cols)]
        if x0 < 0:
            window = np.hstack([np.zeros((window.shape[0], -x0)), window])
        if x1 > cols:
            window = np.hstack([window, np.zeros((window.shape[0], x1 - cols))])
        if y0 < 0:
            window = np.vstack([np.zeros((-y0, window.shape[1])), window])
        if y1 > rows:
            window = np.vstack([window, np.zeros((y1 - rows, window.shape[1]))])
        return window

    def step(self, point, slice, threshold=0.3, direction="forward"):
        '''Choose the next step from `point` and return its coordinate.

        Inputs: `point` (x, y) is the current coordinate in the data matrix;
        `slice` is the square window centred on it; `threshold` decides when to
        terminate a track; `direction` is "forward" or "backward" along x.

        Output: always a tuple. (None, None) when the track is terminated or
        the next step would leave the data matrix; otherwise the next (x, y).

        The main diagonal of `slice` is weighted by a linear ramp towards the
        direction of travel. If its weighted mean is at least `threshold`, the
        step is (+1, +1) forward or (-1, -1) backward. Only this one diagonal
        is evaluated; the tracker never steps straight up, down or sideways.
        '''
        backward = direction == 'backward'
        x, y = point
        weights = np.linspace(0.0, 1.0, slice.shape[0])
        if backward:
            weights = weights[::-1]
        segment_weight = sum(slice.diagonal() * weights) / sum(weights)
        offset = -1 if backward else 1
        next_x, next_y = x + offset, y + offset
        rows, cols = self.ssm.shape
        # Terminate tracking if the weighted mean of the segment is too low.
        if segment_weight < threshold:
            return None, None
        # Terminate tracking if the data matrix bounds are reached.
        if next_x < 0 or next_x >= cols or next_y < 0 or next_y >= rows:
            return None, None
        return next_x, next_y


def _plot_results(ssm, reduced_ssm, maxima, tracks, gt, pdf_path):
    '''Save a three-panel PDF: SSM, reduced SSM with maxima, and the track mask.'''
    n = len(tracks)
    # Annotation column 0 rescaled so its last value maps to n.
    boundaries = gt / gt[-1] * n
    ax1 = plt.subplot(131)
    ax1.imshow(ssm, cmap='Greys')
    ax1.vlines(boundaries, 0, n, colors='r')
    ax2 = plt.subplot(132)
    ax2.imshow(reduced_ssm, cmap='Greys')
    points = np.asarray(maxima)
    if points.size:
        ax2.scatter(points[:, 0], points[:, 1], s=5, c='y')
    ax2.set_xlim([0, n])
    ax2.set_ylim([n, 0])
    ax3 = plt.subplot(133)
    ax3.imshow(tracks, cmap='Greys')
    ax3.vlines(boundaries, 0, n, colors='r')
    ax3.set_xlim([0, n])
    ax3.set_ylim([n, 0])
    plt.savefig(pdf_path, format='pdf')
    plt.close()


def main():
    '''Process every SSM in SSM_PATH and save track masks (and plots) to TRACK_PATH.

    SSM files and annotation files in GT_PATH are paired by sorted file name.
    Items whose output .txt already exists in TRACK_PATH are skipped.
    '''
    # The original "-p" command-line switch was always overridden: plots are
    # always written.
    plot = True
    if plot and plt is None:
        raise ImportError('matplotlib is required for plotting')
    tracker = PathTracker()

    ssm_files = sorted(join(SSM_PATH, x) for x in os.listdir(SSM_PATH) if not x.startswith('.'))
    gt_files = sorted(join(GT_PATH, x) for x in os.listdir(GT_PATH) if not x.startswith('.'))
    if len(ssm_files) != len(gt_files):
        print('Warning: %d SSM files but %d annotation files.' % (len(ssm_files), len(gt_files)))

    for ssm_file, gt_file in zip(ssm_files, gt_files):
        audio_name = splitext(basename(gt_file))[0]
        track_file = join(TRACK_PATH, audio_name + '.txt')
        if isfile(track_file):
            continue
        ssm = np.genfromtxt(ssm_file, delimiter=',')
        gt = np.genfromtxt(gt_file, usecols=0)
        print('Processing: %s' % audio_name)
        reduced_ssm, maxima, tracks = tracker.process(
            ssm, thresh=0.5, min_local_dist=20, slice_size=20, step_thresh=0.4,
            track_min_len=50, track_gap=50)
        np.savetxt(track_file, tracks, delimiter=',')
        if plot:
            _plot_results(ssm, reduced_ssm, maxima, tracks, gt,
                          join(TRACK_PATH, audio_name + '.pdf'))


# smoothing funcs

if __name__ == '__main__':
    main()