diff ProbSegmenter.py @ 13:cc8ceb270e79

add some gmm ipynbs
author mitian
date Fri, 05 Jun 2015 18:02:05 +0100
parents
children 6dae41887406
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/ProbSegmenter.py	Fri Jun 05 18:02:05 2015 +0100
@@ -0,0 +1,262 @@
+#!/usr/bin/env python
+# encoding: utf-8
+"""
+ProbSegmenter.py
+"""
+
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+import sys, os, optparse, cPickle
+import numpy as np
+from numpy import abs, log, exp, floor, sum, sqrt, cos, hstack, power
+from math import ceil
+from scipy.signal import correlate2d, convolve2d
+from itertools import combinations
+from os.path import join, isdir, isfile, abspath, dirname, basename, split, splitext
+
+from sklearn.decomposition import PCA
+from sklearn.mixture import GMM
+from sklearn.metrics.pairwise import pairwise_distances
+from sklearn.cluster import KMeans, DBSCAN
+from scipy.spatial.distance import squareform, pdist
+
+from utils.GmmMetrics import GmmDistance
+from utils.RankClustering import rClustering
+from utils.kmeans import Kmeans
+from utils.ComputationCache import Meta, with_pickle_dump
+from utils.SegUtil import normaliseFeature, upSample, getSSM, getMean, getStd, getDelta
+
+def parse_args():
+	# define parser
+	op = optparse.OptionParser()
+	# IO options
+	op.add_option('-i', '--input', action="store", dest="INPUT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/features1', type="str", help="Directory to load features from.")
+	op.add_option('-a', '--audioset', action="store", dest="AUDIO", default=None, type="str", help="Select audio dataset ['qupujicheng', 'salami', 'beatles'].")
+	op.add_option('-g', '--groundtruth', action="store", dest="GT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/annotations/lowercase', type="str", help="Directory to load annotation files from.")
+	op.add_option('-o', '--output', action="store", dest="OUTPUT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/clustering', type="str", help="Directory to write clustering results to.")
+	op.add_option('-c', '--cache', action="store", dest="CACHE", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/clustering/cache', type="str", help="Directory for cached intermediate results.")
+	op.add_option('-t', '--test', action="store_true", dest="TEST", help="Select TEST mode.")
+	op.add_option('-v', '--verbose', action="store_true", dest="VERBOSE", help="Select VERBOSE mode to save features/SSMs.")
+
+	return op.parse_args()
+
+options, args = parse_args()
+
+class FeatureObj(object):
+	__slots__ = ['key', 'audio', 'timestamps', 'features']
+
+class AudioObj(object):
+	__slots__ = ['feature_list', 'feature_matrix', 'distance_matrix', 'name', 'gt']
+
+class Segmenter(object):
+	'''The main segmentation object'''
+	
+	meta = Meta()
+	
+	@classmethod	
+	def set_cache_filename(cls, filename, cache = True, cache_location=""):
+		cls.meta.cache = cache
+		cls.meta.cache_file_base = filename	
+		cls.meta.cache_location = cache_location	
+	
+	@with_pickle_dump(meta)
+	def getGMMs(self, feature, filename, gmmWindow=10, stepsize=1, save=True):
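+		# fit one GMM per sliding window of frames (windows overlap when
+		# stepsize < gmmWindow); each model is wrapped in GmmDistance so
+		# divergences between windows can be computed later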
+		gmm_list = []
+		steps = int((feature.shape[0] - gmmWindow + stepsize) / stepsize)
+		for i in xrange(steps):
+			gmm_list.append(GmmDistance(feature[i*stepsize:(i*stepsize+gmmWindow), :].T, components = 2))
+		# if save:
+		# 	with open(join('cache', filename), 'w+') as f:
+		# 		f.write(cPickle.dumps(gmm_list))
+		return gmm_list
+	
+	def pairwiseSKL(self, gmm_list):
+		'''Compute pairwise symmetrised KL divergence of a list of GMMs.'''
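+		# For reference, the closed-form KL divergence between two Gaussians
+		# N(mu0, S0) and N(mu1, S1) in d dimensions is
+		#   KL = 0.5 * (tr(S1^-1 S0) + (mu1-mu0)^T S1^-1 (mu1-mu0) - d + ln(det S1 / det S0)),
+		# and the symmetrised divergence is SKL(p, q) = KL(p||q) + KL(q||p).
+		# skl_distance_full is assumed to compute or approximate this for GMM
+		# pairs (no closed form exists for mixtures).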
+		n_GMMs = len(gmm_list)
+		distance_matrix = np.zeros((n_GMMs, n_GMMs))
+		for i in xrange(n_GMMs):
+			for j in xrange(i, n_GMMs):
+				distance_matrix[i][j] = gmm_list[i].skl_distance_full(gmm_list[j])
+				distance_matrix[j][i] = distance_matrix[i][j]
+
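+		# numerical issues (e.g. near-singular covariances) can yield NaN or inf
+		# divergences: NaNs are zeroed, infs are clamped below to the largest
+		# finite value observed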
+		np.fill_diagonal(distance_matrix, 0.0)
+		distance_matrix[np.isnan(distance_matrix)] = 0
+		# distance_matrix[np.isinf(distance_matrix)] = np.finfo(np.float64).max
+		if np.isinf(distance_matrix).any():
+			data = np.sort(distance_matrix.flatten())
+			pos = np.where(data == np.inf)[0][0]
+			fMax = data[pos-1]
+			print len(data), pos, fMax
+			distance_matrix[np.isinf(distance_matrix)] = fMax
+		return distance_matrix
+	
+	def smoothLabels(self, label_list, size=5):
+		'''Smooth a label sequence by suppressing runs shorter than `size`.'''
+		prev_label = -1
+		chain = 0
+		for i in xrange(size, len(label_list)-size):
+			label = label_list[i]
+			if label == prev_label:
+				chain += 1
+			else:
+				if chain < size:
+					label_list[i] = prev_label
+				chain = 0
+			prev_label = label_list[i]
+		return label_list
+	
+	def getInitialCentroids(self, neighborhood_size, k=10):
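+		# a node is a candidate centroid when its neighborhood size is a local
+		# maximum; candidates are ranked by their larger one-sided drop and the
+		# indices of the top k are returned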
+		candidates = []
+		size = len(neighborhood_size)
+		for i in xrange(1, size-1):
+			d1 = neighborhood_size[i] - neighborhood_size[i-1]
+			d2 = neighborhood_size[i] - neighborhood_size[i+1]
+			if d1 > 0 and d2 > 0:
+				candidates.append((i, max(d1, d2)))
+		print 'candidates', len(candidates), candidates, size
+		ranked_nodes = sorted(candidates, key=lambda x: x[1], reverse=True)[:k]
+		return [node[0] for node in ranked_nodes]
+				
+		
+	def process(self):
+		
+		audio_files = [x for x in os.listdir(options.GT) if not x.startswith(".") ]
+		audio_files.sort()
+		if options.TEST:
+			# audio_files = ["""17 We Are The Champions.wav"""]
+			audio_files = audio_files[:2]
+		audio_list = []
+
+		fobj_list = []
+		# all feature directories under the input path (overridden just below by a fixed subset)
+		feature_list = [i for i in os.listdir(options.INPUT) if not i.startswith('.')]
+		feature_list = ['pcamean', 'dct', 'contrast6']
+		
+		feature_list.sort()
+		
+		winlength = 50
+		stepsize = 50
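+		# 50-frame GMM windows with a 50-frame hop: non-overlapping texture windows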
+		
+		if options.AUDIO is None:
+			print 'Must specify an audio dataset for evaluation!'
+			sys.exit(1)
+			
+		for audio in audio_files:
+			ao = AudioObj()
+			ao.name = splitext(audio)[0]
+			
+			ao_featureset = []
+			for feature in feature_list :
+				for f in os.listdir(join(options.INPUT, feature)):
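+					# feature files are assumed to be Vamp plugin CSV exports named '<audio name>_vamp_...'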
+					if f[:f.find('_vamp')]==ao.name: 
+						data = np.genfromtxt(join(options.INPUT, feature, f), delimiter=',', filling_values=0.0)[:, 1:]
+						ao_featureset.append(data)
+						break
+
+			n_features = len(ao_featureset)		
+			if n_features == 0: continue 
+
+			if n_features > 1:
+				# truncate all features to the smallest frame count (the last few frames are generally empty)
+				n_frame = np.min([x.shape[0] for x in ao_featureset])
+				ao_featureset = [x[:n_frame, :] for x in ao_featureset]
+				feature_matrix = np.hstack(ao_featureset)
+			else:
+				feature_matrix = ao_featureset[0]
+						
+			print "Processing data for audio file:", audio, n_features
+			if options.AUDIO == 'salami':
+				annotation_file = join(options.GT, ao.name+'.txt') # iso, salami
+				ao.gt = np.genfromtxt(annotation_file, usecols=0)	
+			elif options.AUDIO == 'qupujicheng':
+				annotation_file = join(options.GT, ao.name+'.csv') # qupujicheng
+				ao.gt = np.genfromtxt(annotation_file, usecols=0, delimiter=',')	
+			elif options.AUDIO == 'beatles':
+				annotation_file = join(options.GT, ao.name+'.lab') # beatles
+				ao.gt = np.genfromtxt(annotation_file, usecols=(0,1))
+				ao.gt = np.unique(ao.gt)
+				
+			n_frames = feature_matrix.shape[0]
+			
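+			# NB: `feature` and `f` still refer to the last feature file matched in
+			# the loop above, so timestamps are read from that file's first column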
+			timestamps = np.genfromtxt(join(options.INPUT, feature, f), delimiter=',', filling_values=0.0, usecols=0)
+			# map timestamps to the reduced representations
+			timestamps = timestamps[0::stepsize]
+			
+			# normalise the feature matrix: remove negative values and add a small constant for numerical stability
+			feature_matrix = normaliseFeature(feature_matrix)
+			
+			# np.savetxt('test/feature_maxtrix-'+ao.name+"-wl%i-ss%i.txt" %(winlength,stepsize), feature_matrix, delimiter=',')
+			fr = basename(options.INPUT)
+			print 'fr', fr, options.INPUT
+			feature_name = ''.join('-' + feature for feature in feature_list)
+			# np.savetxt(join(options.OUTPUT, test, ao.name+feature_name+"-%s-wl%i-ss%i-ssm.txt" %(fr,winlength,stepsize)), feature_matrix, delimiter=',')
+			
+			# PCA: reduce the concatenated features to 6 components before GMM fitting
+			pca = PCA(n_components=6)
+			feature_matrix = pca.fit_transform(feature_matrix)
+			
+			cache_filename = ao.name+feature_name+"-%s-wl%i-ss%i.txt" % (fr, winlength, stepsize)
+			self.set_cache_filename(filename=cache_filename, cache_location=options.CACHE)
+			gmm_list = self.getGMMs(feature_matrix, filename=cache_filename, gmmWindow=winlength, stepsize=stepsize)
+			print 'number of GMMs:', len(gmm_list)
+			
+			skl_matrix = self.pairwiseSKL(gmm_list)
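+			# getSSM is assumed to convert the pairwise divergence matrix into a self-similarity matrix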
+			ssm = getSSM(skl_matrix)
+			np.savetxt(join(options.CACHE, ao.name+feature_name+"-%s-wl%i-ss%i-ssm.txt" %(fr,winlength,stepsize)), ssm, delimiter=",")
+			
+			# # 1. DBSCAN clustering of raw feature
+			# db1 = DBSCAN(eps=10, min_samples=10).fit(feature_array)
+			# core_samples_mask1 = np.zeros_like(db1.labels_, dtype=bool)
+			# core_samples_mask1[db1.core_sample_indices_] = True
+			# labels1 = db1.labels_
+			
+			# # 2. DBSCAN clustering of GMMs
+			# db2 = DBSCAN(eps=0.05, min_samples=10, metric='precomputed').fit(skl_matrix)
+			# core_samples_mask2 = np.zeros_like(db2.labels_, dtype=bool)
+			# core_samples_mask2[db2.core_sample_indices_] = True
+			# labels2 = db2.labels_
+			
+			# 3. RC clustering of raw GMMs
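+			# rClustering parameters: eps is assumed to be the neighborhood radius on
+			# the GMM divergences and k the neighbor count used for ranking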
+			rc = rClustering(eps=1.15, k=5, rank='max_neighbors')
+			rc.set_cache_filename(ao.name+feature_name+"-%s-wl%i-ss%i.txt" %(fr,winlength,stepsize), cache_location=options.CACHE)
+			rc.fit(gmm_list)
+			rc.test()
+			classification = rc.classification
+			print 'classification', classification
+			neighborhood_size, average_div, node_rank = rc.getNodeRank()
+			
+			# centroid_list = self.getInitialCentroids(neighborhood_size, k=10)
+			# print 'initial centroids', centroid_list
+			
+			# k-means clustering of GMMs 
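+			# seed k-means with the set of distinct labels found by rank clustering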
+			KmeansClustering = Kmeans(gmm_list, K=5, initial_centroids=set(classification))
+			labels = KmeansClustering.fit()
+						
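+			# pair each output sequence with its timestamps for export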
+			f1 = np.array(zip(timestamps[:len(labels)], labels))
+			# f2 = np.array(zip(timestamps[:len(labels2)], labels2))
+			f3 = np.array(zip(timestamps[:len(classification)], classification))
+			f4 = np.array(zip(timestamps[:len(neighborhood_size)], neighborhood_size))
+			f5 = np.array(zip(timestamps[:len(node_rank)], node_rank))
+			f6 = np.array(zip(timestamps[:len(average_div)], average_div))
+			# 
+			np.savetxt(join(options.OUTPUT, 'kmeans')+splitext(audio)[0]+feature_name+'.csv', f1, delimiter=',')
+			# np.savetxt(join(options.OUTPUT, 'dbscan_gmm')+splitext(audio)[0]+'.csv', f2, delimiter=',')
+			# np.savetxt(join(options.OUTPUT, 'classification')+splitext(audio)[0]+'.csv', f3, delimiter=',')
+			np.savetxt(join(options.OUTPUT, 'neighborhood_size')+splitext(audio)[0]+'.csv', f4, delimiter=',')
+			np.savetxt(join(options.OUTPUT, 'node_rank')+splitext(audio)[0]+'.csv', f5, delimiter=',')
+			np.savetxt(join(options.OUTPUT, 'average_div')+splitext(audio)[0]+'.csv', f6, delimiter=',')
+			
+def main():
+	segmenter = Segmenter()
+	segmenter.process()
+	
+
+if __name__ == '__main__':
+	main()
+