view utils/PathTracker.py @ 18:b4bf37f94e92

prepared to add another annotation
author mitian
date Wed, 09 Dec 2015 16:27:10 +0000
parents 26838b1f560f
children
line wrap: on
line source
#!/usr/bin/env python
# encoding: utf-8
"""
TempoPathTrackerUtil.py 

Created by George Fazekas on 2014-04-06.
Copyright (c) 2014 . All rights reserved.

This program implements a maximum-path tracker that combines ideas from dynamic programming and the Hough line transform.
It may be used to track tempo tracks in tempograms, partials in STFT spectrograms, or similar tasks.

"""

import os, sys, itertools
from os.path import join, isdir, isfile, abspath, dirname, basename, split, splitext
from scipy import ndimage
from scipy.ndimage.filters import maximum_filter, minimum_filter, median_filter, uniform_filter
from math import ceil, floor
from numpy import linspace
from numpy.linalg import norm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import scipy.spatial as ss
from math import sqrt
from copy import deepcopy
from skimage.feature import peak_local_max

# Input/output locations. Each one can be overridden with an environment variable of
# the same name; the defaults are machine-specific paths for the isophonics experiment.
SSM_PATH = os.environ.get('SSM_PATH', '/Users/mitian/Documents/experiments/mit/segmentation/combined/iso/ssm_data_combined')
GT_PATH = os.environ.get('GT_PATH', '/Users/mitian/Documents/audio/annotation/isophonics')
TRACK_PATH = os.environ.get('TRACK_PATH', '/Users/mitian/Documents/experiments/mit/segmentation/combined/iso/tracks')
# Alternative dataset (qupujicheng):
# SSM_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/qupujicheng/ssm_data1'
# GT_PATH = '/Users/mitian/Documents/experiments/mit/annotation/qupujicheng1/lowercase'
# TRACK_PATH = '/Users/mitian/Documents/experiments/mit/segmentation/combined/qupujicheng/tracks'

class Track(object):
	'''A single path traced through the data.

	A track is created at a seed point ``start`` and grown by appending (x, y)
	nodes with ``add_point``; the seed itself is not stored in ``node_array``.
	Tracks compare equal (and hash) by their unique ``id``.
	'''

	# class-wide counter handing out a unique id to every Track instance
	track_ID = 0

	def __init__(self, start):
		self.node_array = []  # (x, y) nodes added by add_point / join / concatenate
		self.pair_array = []  # (point, pair) tuples added by add_pairs
		self.start = start
		self.id = Track.track_ID
		Track.track_ID += 1
		self.sorted = False
		# NOTE(review): ``end`` is a plain attribute evaluated once, here, while node_array
		# is still empty, so it starts out equal to ``start``. add_point() and join() do
		# not update it; only concatenate() reassigns it. It is an attribute rather than a
		# read-only property so that concatenate() can assign to it.
		self.end = self.get_end()

	def __eq__(self, other):
		'''Two tracks are equal when they carry the same id.'''
		if not isinstance(other, Track):
			return NotImplemented
		return self.id == other.id

	def __ne__(self, other):
		# Python 2 does not derive != from __eq__.
		equal = self.__eq__(other)
		if equal is NotImplemented:
			return equal
		return not equal

	def __hash__(self):
		# Defining __eq__ drops the default hash on Python 3; hash consistently with __eq__.
		return hash(self.id)

	def add_point(self, point):
		'''Append a node/point to the trace.'''
		self.node_array.append(point)

	def add_pairs(self, point, pair):
		'''Record a (point, pair) tuple of neighbouring points, meant to help prune double traces.'''
		self.pair_array.append((point, pair))

	@property
	def length(self):
		'''Extent of the track on the x axis: max(x) - min(x) over its nodes, or 0 if empty.'''
		if not self.node_array:
			return 0
		xs = np.array(self.node_array)[:, 0]
		return xs.max() - xs.min()

	@property
	def mean(self):
		'''Mean of the y coordinate over the track nodes, or None if the track is empty.'''
		if not self.node_array:
			return None
		# average per column, then take the y column
		return np.array(self.node_array).mean(axis=0)[1]

	@property
	def start_x(self):
		'''x coordinate of the start point.'''
		return self.start[0]

	def get_end(self):
		'''Return the last node in sorted order, or ``start`` if the track has no nodes.

		Side effect: when there are nodes (and ``sorted`` is False), node_array is
		replaced by its sorted copy and ``start`` is reset to the first sorted node.
		'''
		if not self.node_array:
			return self.start
		if not self.sorted:
			self.node_array = sorted(self.node_array)
			self.start = self.node_array[0]
		return self.node_array[-1]

	def join(self, other):
		'''Absorb the nodes of ``other`` (duplicates removed, order not preserved).

		If ``other.end`` lies left of this track's start on the x axis, ``start`` is
		moved there. Returns None; does nothing if ``other`` has no nodes.
		'''
		if not other.node_array:
			return None
		self.node_array.extend(other.node_array)
		self.node_array = list(set(self.node_array))
		if other.end[0] < self.start[0]:
			print("Info: Starting point moved from  %s  to  %s" % (self.start[0], other.end[0]))
			self.start = other.end

	def concatenate(self, other):
		'''Merge the nodes of ``other`` into this track and adopt ``other.end``.

		The merged node list is de-duplicated and sorted. Returns None; does nothing
		if either track has no nodes.
		'''
		if not other.node_array or not self.node_array:
			return None
		self.end = other.end
		self.node_array = sorted(set(self.node_array).union(other.node_array))
		
class PathTracker(object):
	'''The main tracker object.

	``process`` picks peaks in a 2-D matrix (presumably a symmetric self-similarity
	matrix), follows each peak diagonally forwards and backwards, merges the
	resulting traces and renders the surviving tracks into a mask of the same shape.
	'''

	def __init__(self):
		self.track_list = []      # Track objects found by the last call to process()
		self.ssm = None           # matrix the tracker steps through (set by process())
		self.max_index = None     # peak coordinates; a list of tuples once process() filters them
		self.kd_tree = None       # cKDTree over the peak coordinates, for nearest-peak queries
		self._tree_points = None  # the exact peak array kd_tree was built from
		self.group_num = 0        # number of groups found by count_groups()
		self.groups = []          # track groups built by count_groups()
		self.group = None         # unused; kept for backward compatibility

	def get_local_maxima(self, ssm, threshold=0.7, neighborhood_size=4):
		'''Return the cells of ``ssm`` strictly above ``threshold``.

		Despite the name this is a plain threshold (an earlier minfilt/maxfilt
		local-maximum detector is disabled); ``neighborhood_size`` is currently unused.

		Returns (indices, maxima): a list of (column, row) tuples and the boolean mask.
		'''
		maxima = ssm > threshold
		rows, cols = maxima.nonzero()
		indices = list(zip(cols, rows))
		return indices, maxima

	def get_peak_local(self, ssm, thresh=0.8, min_distance=10, threshold_rel=0.8):
		'''Find local maxima of ``ssm`` with skimage's ``peak_local_max``.

		Peaks are picked from a copy of ``ssm`` in which values below ``thresh`` and
		the main diagonal are set to 0 (so the diagonal cannot be picked as the only
		maximum in a neighbourhood). ``ssm`` itself is not modified.

		Returns:
			reduced_ssm: the thresholded copy the peaks were picked from,
			indices: (n, 2) integer array of peak coordinates as returned by peak_local_max,
			maxima: boolean mask of ssm's shape, True at the peak coordinates.
		'''
		reduced_ssm = np.array(ssm, copy=True)
		reduced_ssm[reduced_ssm < thresh] = 0.0
		np.fill_diagonal(reduced_ssm, 0)
		# Call once and build the mask ourselves: the ``indices=`` keyword is gone in newer scikit-image.
		indices = peak_local_max(reduced_ssm, min_distance=min_distance, threshold_rel=threshold_rel)
		indices = np.asarray(indices, dtype=int).reshape(-1, 2)
		maxima = np.zeros(reduced_ssm.shape, dtype=bool)
		maxima[tuple(indices.T)] = True
		return reduced_ssm, indices, maxima

	def prune_duplicates(self, maxima, size):
		'''Remove tracks that run close to another track.

		Empty tracks are dropped first. Then, for every pair of tracks whose end y,
		start y and start x each differ by at most 10, the track with fewer nodes
		(the second one on a tie) is removed if the smallest squared distance between
		their nodes is below ``size``. ``maxima`` is unused.
		'''
		self.track_list = [track for track in self.track_list if track.node_array]

		# iterate over a snapshot; removals only affect self.track_list
		candidates = list(self.track_list)
		print("self.track_list start %s" % len(self.track_list))
		for track1, track2 in itertools.combinations(candidates, 2):
			if abs(track1.end[1] - track2.end[1]) > 10:
				continue
			if abs(track1.start[1] - track2.start[1]) > 10:
				continue
			if abs(track1.start[0] - track2.start[0]) > 10:
				continue
			points1 = np.asarray(track1.node_array)
			points2 = np.asarray(track2.node_array)
			# all pairwise squared distances between the two node sets
			sq_dist = ((points1[:, None, :] - points2[None, :, :]) ** 2).sum(axis=-1)
			if sq_dist.min() < size:
				duplicate = track1 if len(points1) < len(points2) else track2
				if duplicate in self.track_list:
					self.track_list.remove(duplicate)
		print("self.track_list pruned %s" % len(self.track_list))

	def count_groups(self, gap=10.0):
		'''Cluster tracks by the x coordinate of their start point.

		track_list is sorted by start x; a new group begins wherever consecutive start
		x values differ by more than ``gap``. Sets self.groups (a list of lists of
		tracks) and self.group_num, and returns self.groups.
		'''
		self.track_list.sort(key=lambda track: track.start_x)
		self.groups = []
		previous_x = None
		for track in self.track_list:
			if previous_x is None or track.start_x - previous_x > gap:
				self.groups.append([])
			self.groups[-1].append(track)
			previous_x = track.start_x
		self.group_num = len(self.groups)
		print("self.groups %s" % len(self.groups))
		return self.groups

	def histogram(self):
		'''Placeholder: compare pairwise distances of tracks within each group and
		group them by histogramming the distances. Not implemented; does nothing.
		'''
		# TODO: implement on top of self.groups built by count_groups().
		return None

	def process(self, ssm, thresh=0.8, min_local_dist=20, slice_size=2, step_thresh=0.25, track_min_len=50, track_gap=50, noise_floor=0.6):
		'''Trace paths through ``ssm`` and render them into a mask.

		Args:
			ssm: 2-D array to track in; it is not modified.
			thresh: hard threshold applied to the copy used for peak picking.
			min_local_dist: ``min_distance`` passed to peak_local_max.
			slice_size: half-width of the window evaluated at each step.
			step_thresh: weighted diagonal score below which a trace terminates.
			track_min_len: tracks with a smaller Track.length are not rendered.
			track_gap: diagonal search range used when concatenating tracks.
			noise_floor: values of ``ssm`` below this are treated as 0 while stepping.

		Returns:
			(reduced_ssm, max_index, tracks): the thresholded matrix peaks were picked
			from, the list of peak coordinates that seeded traces, and a float mask of
			ssm's shape with 1.0 on rendered tracks.
		'''
		ssm = np.asarray(ssm)
		# start from a clean state: one tracker may be reused for several matrices
		self.track_list = []
		self.groups = []
		self.group_num = 0

		print("ssm.shape %s" % (ssm.shape,))
		reduced_ssm, max_index, _ = self.get_peak_local(ssm, thresh=thresh, min_distance=min_local_dist, threshold_rel=0.5)
		# stepping works on a private copy with low values suppressed
		self.ssm = np.where(ssm < noise_floor, 0.0, ssm)

		self.max_index = np.asarray(max_index)
		if not len(self.max_index):
			print('No maxima found.')
			self.max_index = []
			return reduced_ssm, self.max_index, np.zeros(ssm.shape)

		# NOTE(review): peak_local_max yields (row, col) while get_neighbourhood/step treat
		# a point as (x=col, y=row); this only matters if ssm is not symmetric — confirm.
		# A KD-tree over the peaks makes "which peak did this trace just pass?" cheap.
		self._tree_points = self.max_index
		self.kd_tree = ss.cKDTree(self._tree_points)

		# trace forwards; peaks passed by an earlier trace do not seed a trace of their own
		discard_maxima = set()
		for ix, iy in self.max_index:
			seed = (ix, iy)
			if seed in discard_maxima:
				continue
			track = Track(seed)
			self.track_list.append(track)
			self._trace(track, slice_size, step_thresh, "forward", discard_maxima)
		print("discarded maxima:  %s" % len(discard_maxima))

		self.max_index = [(x, y) for x, y in self.max_index if (x, y) not in discard_maxima]

		# trace backwards from every remaining peak
		print("Tracing back...")
		for seed in self.max_index:
			track = Track(seed)
			self.track_list.append(track)
			self._trace(track, slice_size, step_thresh, "backward")
		print("tracing done.")
		print("tracks after tracing: %s" % len(self.track_list))

		# join forward and backward traces sharing a seed, then chain traces on the same diagonal
		self.join_tracks()
		self.concatenate_tracks(size=track_gap)
		# duplicate pruning is currently disabled:
		# self.prune_duplicates(None, size=10)
		# TODO: smooth paths, experiment with segmentation of individual tracks...
		self.count_groups()

		tracks = self._render(ssm.shape, track_min_len)
		print("number of final tracks %s" % len(self.track_list))
		return reduced_ssm, self.max_index, tracks

	def _trace(self, track, slice_size, step_thresh, direction, discard=None):
		'''Grow ``track`` from its start point until step() terminates; return the track.

		If ``discard`` is a set, each peak that get_nearest_maxima finds near a visited
		point (other than the track's own seed) is added to it.
		'''
		seed = track.start
		point = seed
		while True:
			window = self.get_neighbourhood(point, size=slice_size)
			x, y = self.step(point, window, threshold=step_thresh, direction=direction)
			if x is None:
				return track
			point = (x, y)
			if discard is not None:
				nearest = self.get_nearest_maxima(point)
				if nearest is not None and nearest != seed:
					discard.add(nearest)
			track.add_point(point)

	def _render(self, shape, track_min_len):
		'''Draw every track with length >= ``track_min_len`` into a zero float mask of ``shape``.

		Each track is drawn as a straight 45-degree segment that starts at its first
		sorted node and spans, inclusively, its nodes' y extent. A node (x, y) maps to
		mask cell [y, x]; cells outside the mask are skipped.
		'''
		n_rows, n_cols = shape
		tracks = np.zeros(shape)
		for track in self.track_list:
			if not track.node_array or track.length < track_min_len:
				continue
			track.node_array.sort()
			col0, row0 = track.node_array[0]
			span = track.node_array[-1][1] - row0 + 1  # +1: include the last node
			for i in range(span):
				if row0 + i < n_rows and col0 + i < n_cols:
					tracks[row0 + i, col0 + i] = 1.0
		return tracks

	def join_tracks(self):
		'''Join tracks which share a common starting point.

		Intended to join each forward trace with the backward trace from the same
		seed: when exactly two tracks share a start, the second absorbs the first,
		which is removed. Returns self.track_list.
		'''
		start_points = set(track.start for track in self.track_list)
		print("Initial Traces before joining: %s" % len(self.track_list))
		print("Unique start points: %s" % len(start_points))

		for start in start_points:
			shared_tracks = [track for track in self.track_list if track.start == start]
			if len(shared_tracks) == 2:
				shared_tracks[1].join(shared_tracks[0])
				self.track_list.remove(shared_tracks[0])
		print("Final tracklist after joining %s" % len(self.track_list))
		return self.track_list

	def concatenate_tracks(self, size=3):
		'''Chain tracks lying on the same diagonal.

		For every distinct ``end`` point, a track ending there with length > 1 absorbs
		each other track whose ``start`` is i cells further along the diagonal, for
		i in 1..size-1; absorbed tracks are removed. Returns self.track_list.

		NOTE(review): Track.end is only updated by Track.concatenate, so for a freshly
		traced track it still equals the seed; the diagonal search therefore starts
		from seed points rather than from the last traced node.
		'''
		start_points = set(track.start for track in self.track_list)
		end_points = set(track.end for track in self.track_list)
		print("Traces before concatenation: %s %s %s" % (len(self.track_list), len(start_points), len(end_points)))
		for end in end_points:
			xe, ye = end
			owners = [track for track in self.track_list if track.end == end and track.length > 1]
			if not owners:
				continue
			track = owners[0]
			for i in range(1, size):
				successor_start = (xe + i, ye + i)
				if successor_start not in start_points:
					continue
				successors = [t for t in self.track_list if t.start == successor_start and t is not track]
				if not successors:
					continue
				successor = successors[0]
				track.concatenate(successor)
				self.track_list.remove(successor)
		print("Traces after concatenation: %s" % len(self.track_list))
		return self.track_list

	def get_nearest_maxima(self, point, threshold=5.0):
		'''Return the known peak nearest to ``point`` (Euclidean), or None if none lies within ``threshold``.

		The search uses the KD-tree built in process(); pairwise comparison is far too slow.
		'''
		points = getattr(self, '_tree_points', None)
		if points is None:
			points = np.asarray(self.max_index)
		# k: number of neighbours, p=2: Euclidean distance, distance_upper_bound: search radius
		dist, idx = self.kd_tree.query(point, k=1, p=2, distance_upper_bound=threshold)
		if np.isinf(dist):
			return None
		return tuple(points[idx])

	def get_neighbourhood(self, point, size=1):
		'''Return the window of side 2*size+1 of self.ssm centred on ``point``.

		``point`` is (x, y) = (column, row). Parts of the window that fall outside
		the matrix are zero-padded.
		'''
		x, y = point
		n_rows, n_cols = self.ssm.shape
		x0, x1 = x - size, x + size + 1
		y0, y1 = y - size, y + size + 1
		window = self.ssm[max(0, y0):min(y1, n_rows), max(0, x0):min(x1, n_cols)]
		padding = ((max(0, -y0), max(0, y1 - n_rows)), (max(0, -x0), max(0, x1 - n_cols)))
		return np.pad(window, padding, mode='constant')

	def step(self, point, slice, threshold=0.3, direction="forward"):
		'''Choose the next step from ``point`` and return its coordinate.

		inputs:
			point (x, y) is the current coordinate in the data matrix,
			slice is the square window centred on point (see get_neighbourhood),
			threshold is the weighted diagonal score below which the track terminates,
			direction is "forward"/"forwards" or "backward"/"backwards" along the x axis.

		output:
			Always a tuple: (None, None) if the track terminates or would leave the
			data matrix, otherwise (x, y) of the next step.

		The score is the weighted mean of the window's main diagonal, with weights
		rising linearly from 0 to 1 towards the stepping direction. Only the main
		diagonal is evaluated, so a step is always (+1, +1) forwards or (-1, -1)
		backwards and never straight up or down. The idea relates to dynamic
		programming search and to the Hough line transform.
		'''
		backward = direction in ('backward', 'backwards')
		x, y = point

		# direction-specific weights emphasise the part of the diagonal we move into
		weights = np.linspace(0.0, 1.0, slice.shape[0])
		if backward:
			weights = weights[::-1]
		segment_weight = sum(slice.diagonal() * weights) / sum(weights)

		delta = -1 if backward else 1
		xs, ys = x + delta, y + delta
		n_rows, n_cols = self.ssm.shape

		# terminate if the weighted mean of the segment is below the threshold
		if segment_weight < threshold:
			return None, None
		# terminate if the data matrix bounds are reached
		if xs < 0 or xs >= n_cols or ys < 0 or ys >= n_rows:
			return None, None
		return xs, ys


def _plot_results(ssm, reduced_ssm, maxima, tracks, gt, out_path):
	'''Save a three-panel figure to ``out_path``.

	Panels: the input matrix, the reduced matrix with the picked maxima, and the
	track mask. Ground-truth times ``gt`` are rescaled so the last one maps to the
	matrix size and drawn as red vertical lines on the first and last panels.
	'''
	n = len(tracks)
	# skip boundaries for an empty annotation or a zero last time (division by zero)
	boundaries = gt / gt[-1] * n if len(gt) and gt[-1] else None

	ax1 = plt.subplot(131)
	ax1.imshow(ssm, cmap='Greys')
	if boundaries is not None:
		ax1.vlines(boundaries, 0, n, colors='r')

	ax2 = plt.subplot(132)
	ax2.imshow(reduced_ssm, cmap='Greys')
	if len(maxima):
		ax2.scatter([p[0] for p in maxima], [p[1] for p in maxima], s=5, c='y')
	ax2.set_xlim([0, n])
	ax2.set_ylim([n, 0])

	ax3 = plt.subplot(133)
	ax3.imshow(tracks, cmap='Greys')
	if boundaries is not None:
		ax3.vlines(boundaries, 0, n, colors='r')
	ax3.set_xlim([0, n])
	ax3.set_ylim([n, 0])
	plt.savefig(out_path, format='pdf')
	plt.close()


def main():
	'''Track paths in every SSM under SSM_PATH and write the results to TRACK_PATH.

	SSM and ground-truth files (hidden files excluded) are paired by position in
	their sorted listings, so both directories are assumed to hold matching files
	in the same order. For each pair, ``<name>.txt`` (the comma-separated track
	mask) and ``<name>.pdf`` (a diagnostic plot) are written, where ``<name>`` is
	the ground-truth file name without extension. Pairs whose ``.txt`` output
	already exists are skipped.
	'''
	# NOTE(review): this was meant to come from a "-p" command-line flag, but the
	# flag value was always overridden; plotting is always on.
	plot = True

	ssm_files = sorted(join(SSM_PATH, name) for name in os.listdir(SSM_PATH) if not name.startswith('.'))
	gt_files = sorted(join(GT_PATH, name) for name in os.listdir(GT_PATH) if not name.startswith('.'))
	if len(ssm_files) != len(gt_files):
		print("Warning: %d SSM files but %d ground-truth files; unpaired files are ignored." % (len(ssm_files), len(gt_files)))

	for ssm_file, gt_file in zip(ssm_files, gt_files):
		audio_name = splitext(basename(gt_file))[0]
		track_file = join(TRACK_PATH, audio_name + '.txt')
		# check for existing output before loading the (potentially large) SSM
		if isfile(track_file):
			continue
		ssm = np.genfromtxt(ssm_file, delimiter=',')
		gt = np.atleast_1d(np.genfromtxt(gt_file, usecols=0))
		print('Processing: %s' % audio_name)

		# a fresh tracker per file, so tracks found for one SSM never leak into the next
		tracker = PathTracker()
		reduced_ssm, maxima, tracks = tracker.process(ssm, thresh=0.5, min_local_dist=20, slice_size=20, step_thresh=0.4, track_min_len=50, track_gap=50)
		np.savetxt(track_file, tracks, delimiter=',')

		if plot:
			_plot_results(ssm, reduced_ssm, maxima, tracks, gt, join(TRACK_PATH, audio_name + '.pdf'))


if __name__ == '__main__':
	main()