#!/usr/bin/env python
# encoding: utf-8
"""
TempoPathTrackerUtil.py

Created by George Fazekas on 2014-04-06.
Copyright (c) 2014 . All rights reserved.

This program implements a max path tracker combining ideas from dynamic programming and the Hough line transform.
It may be used to track tempo tracks in tempograms, partials in STFT spectrograms, or similar tasks.

"""

import os
import sys
import itertools
from os.path import join, isdir, isfile, abspath, dirname, basename, split, splitext
from math import ceil, floor
from math import sqrt
from copy import deepcopy

import numpy as np
from numpy import linspace
from numpy.linalg import norm
from scipy import ndimage
# scipy.ndimage.filters is a deprecated alias namespace; the same functions are exposed by scipy.ndimage.
from scipy.ndimage import maximum_filter, minimum_filter, median_filter, uniform_filter
import scipy.spatial as ss

# Plotting is only needed by main(); keep the tracker importable without matplotlib.
try:
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
except ImportError:
    plt = None
    mpimg = None

# Peak picking is only needed by PathTracker.get_peak_local().
try:
    from skimage.feature import peak_local_max
except ImportError:
    peak_local_max = None

SSM_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/iso/ssm_data_combined'
GT_PATH = '/Users/mitian/Documents/audio/annotation/isophonics'
TRACK_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/iso/tracks'
# SSM_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/qupujicheng/ssm_data1'
# GT_PATH = '/Users/mitian/Documents/experiments/mit/annotation/qupujicheng1/lowercase'
# TRACK_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/qupujicheng/tracks'


class Track(object):
    '''A single path in the data, stored as a list of (x, y) nodes grown from a start point.

    Attributes:
        node_array: list of (x, y) tuples visited by the trace (the start point itself is not included).
        pair_array: list of (point, pair) tuples recorded through add_pairs().
        start:      the (x, y) point the trace was started from.
        end:        the greatest (x, y) point, in tuple order, among `start` and the nodes added
                    through add_point() / join() / concatenate().
        id:         unique integer taken from the class-wide counter `track_ID`.
    '''

    track_ID = 0

    def __init__(self, start):
        self.node_array = []
        self.pair_array = []
        self.start = start
        self.id = Track.track_ID
        Track.track_ID += 1
        self.sorted = False
        # `end` is a plain attribute (it must stay assignable). It used to be frozen at the
        # value computed here; it is now kept up to date by add_point(), join() and concatenate().
        self.end = start

    def __eq__(self, other):
        '''Tracks are equal when they carry the same id.'''
        if not isinstance(other, Track):
            return NotImplemented
        return self.id == other.id

    def __ne__(self, other):
        # Python 2 does not derive != from ==.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):
        # Defining __eq__ alone makes the class unhashable on Python 3.
        return hash(self.id)

    def add_point(self, point):
        '''Add a node/point to the trace and advance `end` if the point lies beyond it.'''
        self.node_array.append(point)
        if tuple(point) > tuple(self.end):
            self.end = point

    def add_pairs(self, point, pair):
        '''Add neighbouring points to aid pruning double traces.'''
        self.pair_array.append((point, pair))

    @property
    def length(self):
        '''Extent of the nodes on the x (time) axis; 0 for a track without nodes.'''
        nodes = np.array(self.node_array)
        if len(nodes):
            return nodes[:, 0].max() - nodes[:, 0].min()
        return 0

    @property
    def mean(self):
        '''Mean y coordinate of the nodes (the y of the start point if there are no nodes).'''
        if not self.node_array:
            return float(self.start[1])
        nodes = np.array(self.node_array)
        # axis=0 averages each coordinate separately; a plain mean() returns a scalar
        # that cannot be indexed.
        return nodes.mean(axis=0)[1]

    @property
    def start_x(self):
        '''x coordinate of the start point.'''
        return self.start[0]

    def get_end(self):
        '''Return the last node in sorted order, or the start point if there are no nodes.

        Side effect: while `self.sorted` is false (nothing in this file sets it to True) the node
        list is re-sorted and `self.start` is overwritten with the first node.
        '''
        if not self.node_array:
            return self.start
        if not self.sorted:
            self.node_array = sorted(self.node_array)
            self.start = self.node_array[0]
        return self.node_array[-1]

    def join(self, other):
        '''Join self with other by absorbing the nodes of other (duplicates are dropped).

        Does nothing when `other` has no nodes. Always returns None.
        '''
        if not other.node_array:
            return None
        self.node_array = list(set(self.node_array).union(other.node_array))
        # The original compared against other.end, which at that time was frozen at other's
        # start point; comparing the start points keeps that effective behaviour.
        if other.start[0] < self.start[0]:
            print("Info: Starting point moved from  %s  to  %s" % (self.start[0], other.start[0]))
            self.start = other.start
        if tuple(other.end) > tuple(self.end):
            self.end = other.end
        return None

    def concatenate(self, other):
        '''Append the nodes of a succeeding track and keep the node list sorted.

        Does nothing when either track has no nodes. Always returns None.
        '''
        if not other.node_array or not self.node_array:
            return None
        if tuple(other.end) > tuple(self.end):
            self.end = other.end
        self.node_array = sorted(set(self.node_array).union(other.node_array))
        return None


class PathTracker(object):
    '''The main tracker object.

    Attributes set by process():
        ssm:        the 2-D matrix being tracked.
        max_index:  the seed maxima as (x, y) points.
        kd_tree:    cKDTree built over the seed maxima.
        track_list: list of Track objects found in the matrix.
        groups / group_num: clusters of tracks built by count_groups().
    '''

    def __init__(self):
        self.track_list = []
        self.ssm = None
        self.max_index = None
        self.kd_tree = None
        self.group_num = 0
        self.group = None
        self.groups = []

    def get_local_maxima(self, ssm, threshold=0.7, neighborhood_size=4):
        '''Return the cells of the ssm that exceed `threshold`.

        An earlier revision smoothed the matrix and used a minfilt/maxfilt comparison;
        the current behaviour is plain thresholding and `neighborhood_size` is unused.

        Returns:
            (indices, maxima): a list of (x, y) tuples and the boolean mask they index.
        '''
        maxima = (ssm > threshold)
        # nonzero() yields (rows, cols); points are stored as (x, y) = (col, row)
        iy, ix = maxima.nonzero()
        indices = list(zip(ix, iy))
        return indices, maxima

    def get_peak_local(self, ssm, thresh=0.8, min_distance=10, threshold_rel=0.8):
        '''Find local maxima using skimage's peak_local_max and return them as coordinates and a boolean array.

        NOTE(review): everything after the first statement of this method was lost in the copy this
        file was recovered from. The body below is reconstructed from the signature, the docstring
        and the call site in process() (which unpacks three values) - confirm against the repository.

        Args:
            ssm:           2-D matrix; it is not modified.
            thresh:        values below this are zeroed before peak picking.
            min_distance:  passed to peak_local_max.
            threshold_rel: passed to peak_local_max.

        Returns:
            (reduced_ssm, indices, maxima): the thresholded copy, a list of (x, y) tuples, and a
            boolean mask of the same shape that is True at every peak.

        Raises:
            ImportError: if scikit-image is not installed.
        '''
        if peak_local_max is None:
            raise ImportError("scikit-image is required for PathTracker.get_peak_local()")

        reduced_ssm = deepcopy(ssm)
        reduced_ssm[reduced_ssm < thresh] = 0.0

        coords = peak_local_max(reduced_ssm, min_distance=min_distance, threshold_rel=threshold_rel)
        coords = np.asarray(coords, dtype=int).reshape(-1, 2)

        # Build the mask from the coordinates rather than asking skimage for it, so the code
        # does not depend on the `indices` keyword.
        maxima = np.zeros(reduced_ssm.shape, dtype=bool)
        maxima[coords[:, 0], coords[:, 1]] = True

        # peak_local_max reports (row, col); the tracker works with (x, y) = (col, row),
        # the same convention get_local_maxima() uses.
        indices = [(int(col), int(row)) for row, col in coords]
        return reduced_ssm, indices, maxima

    def prune_duplicates(self, maxima, size):
        '''Remove the shorter of two tracks that start close together and pass near each other.

        NOTE(review): the head of this method was lost in the copy this file was recovered from.
        The pair iteration below is reconstructed, and one additional pre-filter of the form
        "... > 10: continue" preceded the two start-point checks; its left-hand side could not be
        recovered and is therefore omitted. The method is not called by process().

        Args:
            maxima: unused by the recovered part of the method.
            size:   a pair is considered duplicated when the smallest *squared* distance
                    between their nodes is below this value.
        '''
        for track1, track2 in itertools.combinations(list(self.track_list), 2):
            points1, points2 = track1.node_array, track2.node_array
            if abs(track1.start[1] - track2.start[1]) > 10:
                continue
            if abs(track1.start[0] - track2.start[0]) > 10:
                continue
            dist = [((i[0] - j[0]) ** 2 + (i[1] - j[1]) ** 2) for i in points1 for j in points2]
            if dist and min(dist) < size:
                # Nearby track found: discard the one with fewer nodes.
                duplicate = track1 if len(points1) < len(points2) else track2
                if duplicate in self.track_list:
                    self.track_list.remove(duplicate)
        print("self.track_list pruned %d" % len(self.track_list))

    def count_groups(self, max_gap=10.0):
        '''Cluster the tracks whose start points are close on the x axis.

        The track list is sorted by start x (in place); a new group is opened whenever the start x
        of a track is more than `max_gap` after that of the previous track. The result is stored in
        `self.groups` (a list of lists of tracks) and `self.group_num` (the number of groups).
        '''
        self.track_list.sort(key=lambda track: track.start_x)

        groups = []
        previous_x = None
        for track in self.track_list:
            if previous_x is None or track.start_x - previous_x > max_gap:
                groups.append([])
            groups[-1].append(track)
            previous_x = track.start_x

        self.groups = groups
        self.group_num = len(groups)
        print('self.groups %d' % len(self.groups))

    def histogram(self):
        '''Compare pairwise distance for tracks within the same x-axis location and group by histogramming the distance.

        Not implemented yet: the groups are only converted to arrays and discarded.
        '''
        for group in self.groups:
            group_track = np.array(group)

    def process(self, ssm, thresh=0.8, min_local_dist=20, slice_size=2, step_thresh=0.25, track_min_len=50, track_gap=50):
        '''Track paths in the ssm and build a mask from the set of discrete paths found.

        Args:
            ssm:            2-D matrix to track.
            thresh:         values below this are zeroed before peak picking (see get_peak_local).
            min_local_dist: minimum distance between seed peaks.
            slice_size:     half-width of the window examined at every step.
            step_thresh:    a trace stops when the weighted window diagonal falls below this.
            track_min_len:  tracks shorter than this (on the x axis) are left out of the mask.
            track_gap:      search range used by concatenate_tracks().

        Returns:
            (reduced_ssm, max_index, tracks): the thresholded matrix, the list of seed points that
            were not discarded during forward tracing, and the mask of the retained tracks.
            When no maxima are found, max_index is empty and the mask is all zeros.
        '''
        # Start from a clean state so a tracker can be reused for several matrices.
        self.track_list = []
        self.groups = []
        self.group_num = 0

        self.ssm = ssm
        print("ssm.shape %s" % (ssm.shape,))

        reduced_ssm, max_index, maxima = self.get_peak_local(
            ssm, thresh=thresh, min_distance=min_local_dist, threshold_rel=0.5)

        # build a spatial binary search tree to aid removing maxima already passed by a trace
        self.max_index = np.array(max_index)
        if not len(self.max_index):
            print('No maxima found.')
            return reduced_ssm, [], np.zeros_like(maxima)
        self.kd_tree = ss.cKDTree(self.max_index)

        seeds = [(int(x), int(y)) for x, y in self.max_index]
        discard_maxima = set()

        # trace forwards
        for seed in seeds:
            if seed in discard_maxima:
                continue
            track = Track(seed)
            self.track_list.append(track)
            self._trace(track, "forward", slice_size, step_thresh, maxima, discard_maxima)
        print("discarded maxima:  %d" % len(discard_maxima))

        self.max_index = [seed for seed in seeds if seed not in discard_maxima]

        # trace back
        print("Tracing back...")
        for seed in self.max_index:
            track = Track(seed)
            self.track_list.append(track)
            self._trace(track, "backward", slice_size, step_thresh, maxima)
        print("tracing done.")

        print('tracks after tracing: %d' % len(self.track_list))
        # join forward and back traces with the same starting point
        self.join_tracks()

        # concatenate nearby tracks on the same diagonal direction
        self.concatenate_tracks(size=track_gap)

        # TODO: smooth paths, experiment with segmentation of individual tracks...
        self.count_groups()

        # empty mask for visualisation / further processing
        tracks = np.zeros_like(maxima)
        n_rows, n_cols = tracks.shape
        # assess tracks individually, skip short ones and add the rest of the tracks to the mask
        for track in self.track_list:
            if not track.node_array or track.length < track_min_len:
                continue
            track.node_array.sort()
            first, last = track.node_array[0], track.node_array[-1]
            # nodes are (x, y); the mask is indexed [row, col] = [y, x]
            row0, col0 = first[1], first[0]
            steps = np.arange(last[1] - first[1])
            rows, cols = row0 + steps, col0 + steps
            inside = (rows < n_rows) & (cols < n_cols)
            tracks[rows[inside], cols[inside]] = 1.0
        print('number of final tracks %d' % len(self.track_list))

        return reduced_ssm, self.max_index, tracks

    def _trace(self, track, direction, slice_size, step_thresh, maxima, discard_maxima=None):
        '''Grow `track` from its start point with step() until step() terminates.

        Every visited point is appended to the track and marked in `maxima`. When a
        `discard_maxima` set is given, known maxima near the path (other than the
        track's own start) are added to it.
        '''
        point = track.start
        while True:
            window = self.get_neighbourhood(point, size=slice_size)
            x, y = self.step(point, window, threshold=step_thresh, direction=direction)
            if x is None:
                break
            point = (x, y)
            if discard_maxima is not None:
                nearest = self.get_nearest_maxima(point)
                if nearest and nearest != track.start:
                    discard_maxima.add(nearest)
            maxima[y, x] = True
            track.add_point(point)

    def join_tracks(self):
        '''Join tracks which share a common starting point.

        This function is essentially trying to join forward traces and back traces: when exactly
        two tracks share a start point, the later one absorbs the earlier one, which is removed.

        Returns:
            The (modified in place) track list.
        '''
        by_start = {}
        for track in self.track_list:
            by_start.setdefault(track.start, []).append(track)
        print("Initial Traces before joining: %d" % len(self.track_list))
        print("Unique start points: %d" % len(by_start))

        absorbed = set()
        for shared_tracks in by_start.values():
            if len(shared_tracks) == 2:
                shared_tracks[1].join(shared_tracks[0])
                absorbed.add(shared_tracks[0].id)
        self.track_list[:] = [track for track in self.track_list if track.id not in absorbed]

        print("Final tracklist after joining %d" % len(self.track_list))
        return self.track_list

    def concatenate_tracks(self, size=3):
        '''Concatenate the end point and start point of two sequential tracks.

        For every track end (of a track longer than 1), the points 1 .. size-1 steps further along
        the diagonal are examined; a track starting at such a point is absorbed and removed.

        Returns:
            The (modified in place) track list.
        '''
        start_points = set(track.start for track in self.track_list)
        end_points = set(track.end for track in self.track_list)
        print("Traces before concatenation: %d %d %d" % (len(self.track_list), len(start_points), len(end_points)))

        for end in end_points:
            xe, ye = end
            candidates = [t for t in self.track_list if t.end == end and t.length > 1]
            if not candidates:
                continue
            # operate on the track that passed the length check
            track = candidates[0]
            for i in range(1, size):
                target = (xe + i, ye + i)
                if target not in start_points:
                    continue
                successors = [t for t in self.track_list if t.start == target and t is not track]
                if not successors:
                    continue
                succeeding_track = successors[0]
                track.concatenate(succeeding_track)
                self.track_list.remove(succeeding_track)

        print("Traces after concatenation: %d" % len(self.track_list))
        return self.track_list

    def get_nearest_maxima(self, point, threshold=5.0):
        '''Find the nearest maximum to a given point using NN search in the array of known maxima.

        NN search is done using a KD-Tree approach because pairwise comparison is way too slow.

        Returns:
            The maximum as a tuple, or None when none lies within `threshold`.
        '''
        # query tree parameters: k is the number of nearest neighbours to return, p is the distance
        # type used (2: Euclidean), distance_upper_bound specifies search realm
        d, i = self.kd_tree.query(point, k=1, p=2, distance_upper_bound=threshold)
        if d != np.inf:
            # asarray: max_index may have been rebound to a plain list
            return tuple(np.asarray(self.max_index)[i].tolist())
        return None

    def get_neighbourhood(self, point, size=1):
        '''Return a square matrix centered around a given point
        with zero padding if point is close to the edges of the data array.'''
        height, width = self.ssm.shape

        # calculate boundaries
        xs, xe = point[0] - size, point[0] + size + 1
        ys, ye = point[1] - size, point[1] + size + 1

        # part of the window that actually overlaps the data
        r0, r1 = max(0, ys), min(ye, height)
        c0, c1 = max(0, xs), min(xe, width)

        if (r0, r1, c0, c1) == (ys, ye, xs, xe):
            # fully inside: no padding needed
            return self.ssm[ys:ye, xs:xe]

        window = np.zeros((ye - ys, xe - xs))
        if r1 > r0 and c1 > c0:
            window[r0 - ys:r1 - ys, c0 - xs:c1 - xs] = self.ssm[r0:r1, c0:c1]
        return window

    def step(self, point, slice, threshold=0.3, direction="forward"):
        '''Choose a step from the given point and return the coordinate of the selected point.

        inputs:
            point (x,y) is the starting coordinate in the data matrix,
            slice is a square matrix centered around the given point,
            threshold helps to decide where to terminate a track,
            direction {"forward" | "backward"} describes which way to track along the X axis
            (any value other than "backward" is treated as forward).

        output:
            The output is always a tuple.
            (None,None) in case the track is terminated or reached the boundary of the data matrix.
            (x,y) for the next valid step forwards or backwards.

        Note: The algorithm never steps straight up or down, i.e. the next coordinate relates to
        either the next or the previous point on the x axis.

        Note2: The intuition of this algorithm relates to both classical dynamic programming search
        and that of the Hough line transform. At each step a weighted line segment is considered
        corresponding to the slice of the data around the considered point. In this implementation
        only the main diagonal of the slice is evaluated, so every step is one cell along the diagonal.
        '''
        backward = (direction == 'backward')
        x, y = point

        # create direction specific weight vector (emphasise the cells ahead of the point)
        weights = np.linspace(0.0, 1.0, slice.shape[0])
        if backward:
            weights = weights[::-1]

        # calculate weighted mean of the main diagonal
        segment_weight = (slice.diagonal() * weights).sum() / weights.sum()

        # Terminate tracking if the weighted mean of the segment is below a threshold
        if segment_weight < threshold:
            return None, None

        # adjust steps for desired direction
        offset = -1 if backward else 1
        xs, ys = x + offset, y + offset
        yd, xd = self.ssm.shape

        # Terminate tracking if data matrix bounds are reached
        if xs < 0 or xs >= xd or ys < 0 or ys >= yd:
            return None, None

        return xs, ys


def main():
    '''Track every SSM file under SSM_PATH, save the track mask (and a PDF plot) under TRACK_PATH.

    SSM files and ground-truth files are paired by their position in the sorted directory listings.
    Files whose mask already exists in TRACK_PATH are skipped.
    '''
    # NOTE(review): the original parsed a "-p" flag and then overrode it with True;
    # plotting therefore stays always on here - confirm whether the flag should be honoured.
    plot = True
    if plot and plt is None:
        print('matplotlib is not available; plots will be skipped.')
        plot = False

    tracker = PathTracker()

    ssm_files = sorted(join(SSM_PATH, x) for x in os.listdir(SSM_PATH) if not x.startswith('.'))
    gt_files = sorted(join(GT_PATH, x) for x in os.listdir(GT_PATH) if not x.startswith('.'))

    for i, ssm_file in enumerate(ssm_files):
        gt_file = gt_files[i]
        audio_name = splitext(basename(gt_file))[0]
        # check before loading: parsing the matrices is the expensive part
        if isfile(join(TRACK_PATH, audio_name + '.txt')):
            continue

        ssm = np.genfromtxt(ssm_file, delimiter=',')
        gt = np.genfromtxt(gt_file, usecols=0)
        print('Processing: %s' % audio_name)

        reduced_ssm, maxima, tracks = tracker.process(
            ssm, thresh=0.5, min_local_dist=20, slice_size=20,
            step_thresh=0.4, track_min_len=50, track_gap=50)
        np.savetxt(join(TRACK_PATH, audio_name + '.txt'), tracks, delimiter=',')

        if plot:
            n = len(tracks)
            # ground-truth times rescaled to matrix indices
            boundaries = gt / gt[-1] * n

            ax1 = plt.subplot(131)
            ax1.imshow(ssm, cmap='Greys')
            ax1.vlines(boundaries, 0, n, colors='r')

            ax2 = plt.subplot(132)
            ax2.imshow(reduced_ssm, cmap='Greys')
            peaks = np.asarray(maxima)
            if len(peaks):
                ax2.scatter(peaks[:, 0], peaks[:, 1], s=5, c='y')
            ax2.set_xlim([0, n])
            ax2.set_ylim([n, 0])

            ax3 = plt.subplot(133)
            ax3.imshow(tracks, cmap='Greys')
            ax3.vlines(boundaries, 0, n, colors='r')
            ax3.set_xlim([0, n])
            ax3.set_ylim([n, 0])

            plt.savefig(join(TRACK_PATH, audio_name + '.pdf'), format='pdf')
            plt.close()
    # smoothing funcs


if __name__ == '__main__':
    main()