Mercurial > hg > segmentation
diff ProbSegmenter.py @ 13:cc8ceb270e79
add some gmm ipynbs
author | mitian |
---|---|
date | Fri, 05 Jun 2015 18:02:05 +0100 |
parents | |
children | 6dae41887406 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/ProbSegmenter.py Fri Jun 05 18:02:05 2015 +0100 @@ -0,0 +1,262 @@ +#!/usr/bin/env python +# encoding: utf-8 +""" +ProbSegmenter.py +""" + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import sys, os, optparse, cPickle +import numpy as np +from numpy import abs,log,exp,floor,sum,sqrt,cos,hstack, power +from math import ceil, floor +from scipy.signal import correlate2d, convolve2d +from itertools import combinations +from os.path import join, isdir, isfile, abspath, dirname, basename, split, splitext + +from sklearn.decomposition import PCA +from sklearn.mixture import GMM +from sklearn.metrics.pairwise import pairwise_distances +from sklearn.cluster import KMeans, DBSCAN +from scipy.spatial.distance import squareform, pdist + +from utils.GmmMetrics import GmmDistance +from utils.RankClustering import rClustering +from utils.kmeans import Kmeans +from utils.ComputationCache import Meta, with_pickle_dump +from utils.SegUtil import normaliseFeature, upSample, getSSM, getMean, getStd, getDelta + +def parse_args(): + # define parser + op = optparse.OptionParser() + # IO options + op.add_option('-i', '--input', action="store", dest="INPUT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/features1', type="str", help="Loading features from.." ) + op.add_option('-a', '--audioset', action="store", dest="AUDIO", default=None, type="str", help="Select audio datasets ['qupujicheng', 'salami', 'beatles'] ") + op.add_option('-g', '--groundtruth', action="store", dest="GT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/annotations/lowercase', type="str", help="Loading annotation files from.. ") + op.add_option('-o', '--output', action="store", dest="OUTPUT", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/clustering', type="str", help="Loading annotation files from.. ") + op.add_option('-c', '--cache', action="store", dest="CACHE", default='/Volumes/c4dm-scratch/mi/seg/qupujicheng/clustering/cache', type="str", help="Loading annotation files from.. ") + op.add_option('-t', '--test', action="store_true", dest="TEST", help="Select TEST mode.") + op.add_option('-v', '--verbose', action="store_true", dest="VERBOSE", help="Select VERBOSE mode to save features/SSMs.") + + return op.parse_args() + +options, args = parse_args() + +class FeatureObj() : + __slots__ = ['key','audio','timestamps','features'] + +class AudioObj(): + __slots__ = ['feature_list','feature_matrix','distance_matrix','name'] + +class Segmenter(object): + '''The main segmentation object''' + + meta = Meta() + + @classmethod + def set_cache_filename(cls, filename, cache = True, cache_location=""): + cls.meta.cache = cache + cls.meta.cache_file_base = filename + cls.meta.cache_location = cache_location + + @with_pickle_dump(meta) + def getGMMs(self, feature, filename, gmmWindow=10, stepsize=1, save=True): + gmm_list = [] + steps = int((feature.shape[0] - gmmWindow + stepsize) / stepsize) + for i in xrange(steps): + gmm_list.append(GmmDistance(feature[i*stepsize:(i*stepsize+gmmWindow), :].T, components = 2)) + # if save: + # with open(join('cache', filename), 'w+') as f: + # f.write(cPickle.dumps(gmm_list)) + return gmm_list + + def pairwiseSKL(self, gmm_list): + '''Compute pairwise symmetrised KL divergence of a list of GMMs.''' + n_GMMs = len(gmm_list) + distance_matrix = np.zeros((n_GMMs, n_GMMs)) + for i in xrange(n_GMMs): + for j in xrange(i, n_GMMs): + distance_matrix[i][j] = gmm_list[i].skl_distance_full(gmm_list[j]) + distance_matrix[j][i] = distance_matrix[i][j] + + np.fill_diagonal(distance_matrix, 0.0) + distance_matrix[np.isnan(distance_matrix)] = 0 + # distance_matrix[np.isinf(distance_matrix)] = np.finfo(np.float64).max + if np.isinf(distance_matrix).any(): + data = np.sort(np.ndarray.flatten(distance_matrix)) + pos = np.where(data == np.inf)[0][0] + fMax = data[pos-1] + print len(data), pos, fMax + distance_matrix[np.isinf(distance_matrix)] = fMax + return distance_matrix + + def smoothLabels(self, label_list, size=5): + '''Smooth label list within given length.''' + prev_labels = -1 + next_labels = -1 + chain = 0 + for i in xrange(size, len(label_list)-size): + label = label_list[i] + if label == prev_label: + chain += 1 + else: + if chain < size: + label_list[i] = prev_label + chain = 0 + prev_label = label_list[i] + print chain + return label_list + + def getInitialCentroids(self, neighborhood_size, k=10): + candidates = [] + size = len(neighborhood_size) + for i in xrange(1, size-1): + d1 = neighborhood_size[i] - neighborhood_size[i-1] + d2 = neighborhood_size[i] - neighborhood_size[i+1] + if d1 > 0 and d2 > 0: + candidates.append((i, max(d1, d2))) + print 'candidates', len(candidates), candidates, size + ranked_nodes = sorted(candidates, key=lambda x: x[1],reverse=True)[:k] + ranked_nodes = [ranked_nodes[i][0] for i in xrange(len(ranked_nodes))] + return ranked_nodes + + + def process(self): + + audio_files = [x for x in os.listdir(options.GT) if not x.startswith(".") ] + audio_files.sort() + if options.TEST: + # audio_files = ["""17 We Are The Champions.wav"""] + audio_files = audio_files[:2] + audio_list = [] + + fobj_list = [] + feature_list = [i for i in os.listdir(options.INPUT) if not i.startswith('.')] + feature_list = ['pcamean', 'dct', 'contrast6'] + + feature_list.sort() + + winlength = 50 + stepsize = 50 + + if options.AUDIO == None: + print 'Must specify audio dataset for evaluvation!' + + for i, audio in enumerate(audio_files) : + ao = AudioObj() + ao.name = splitext(audio)[0] + + ao_featureset = [] + for feature in feature_list : + for f in os.listdir(join(options.INPUT, feature)): + if f[:f.find('_vamp')]==ao.name: + data = np.genfromtxt(join(options.INPUT, feature, f), delimiter=',', filling_values=0.0)[:, 1:] + ao_featureset.append(data) + break + + n_features = len(ao_featureset) + if n_features == 0: continue + + if n_features > 1: + # find the feature with the fewer number of frames (the last a few frames should be generally empty) + n_frame = np.min([x.shape[0] for x in ao_featureset]) + ao_featureset = [x[:n_frame,:] for x in ao_featureset] + feature_matrix = np.hstack((ao_featureset)) + else: + feature_matrix = ao_featureset[0] + + print "Processing data for audio file:", audio, n_features + if options.AUDIO == 'salami': + annotation_file = join(options.GT, ao.name+'.txt') # iso, salami + ao.gt = np.genfromtxt(annotation_file, usecols=0) + elif options.AUDIO == 'qupujicheng': + annotation_file = join(options.GT, ao.name+'.csv') # qupujicheng + ao.gt = np.genfromtxt(annotation_file, usecols=0, delimiter=',') + elif options.AUDIO == 'beatles': + annotation_file = join(options.GT, ao.name+'.lab') # beatles + ao.gt = np.genfromtxt(annotation_file, usecols=(0,1)) + ao.gt = np.unique(np.ndarray.flatten(ao.gt)) + + n_frames = feature_matrix.shape[0] + + timestamps = np.genfromtxt(join(options.INPUT, feature, f), delimiter=',', filling_values=0.0, usecols=0) + # map timestamps to the reduced representations + timestamps = timestamps[0::stepsize] + + # # normalise the feature matrix, get rid of negative features, ensure numerical stability by adding a small constant + feature_matrix = normaliseFeature(feature_matrix) + + # np.savetxt('test/feature_maxtrix-'+ao.name+"-wl%i-ss%i.txt" %(winlength,stepsize), feature_matrix, delimiter=',') + fr = basename(options.INPUT) + print 'fr', fr, options.INPUT + feature_name = '' + for feature in feature_list: + feature_name += ('-' + feature) + # np.savetxt(join(options.OUTPUT, test, ao.name+feature_name+"-%s-wl%i-ss%i-ssm.txt" %(fr,winlength,stepsize)), feature_matrix, delimiter=',') + + # PCA + pca = PCA(n_components=6) + pca.fit(feature_matrix) + feature_matrix = pca.transform(feature_matrix) + + cach_filename = ao.name+feature_name+"-%s-wl%i-ss%i.txt" %(fr,winlength,stepsize) + self.set_cache_filename(filename=cach_filename, cache_location=options.CACHE) + gmm_list = self.getGMMs(feature_matrix, filename=cach_filename, gmmWindow=winlength, stepsize=stepsize) + print 'number of GMMs:', len(gmm_list) + + skl_matrix = self.pairwiseSKL(gmm_list) + ssm = getSSM(skl_matrix) + np.savetxt(join(options.CACHE, ao.name+feature_name+"-%s-wl%i-ss%i-ssm.txt" %(fr,winlength,stepsize)), ssm, delimiter=",") + + # # 1. DBSCAN clustering of raw feature + # db1 = DBSCAN(eps=10, min_samples=10).fit(feature_array) + # core_samples_mask1 = np.zeros_like(db1.labels_, dtype=bool) + # core_samples_mask1[db1.core_sample_indices_] = True + # labels1 = db1.labels_ + + # # 2. DBSCAN clustering of GMMs + # db2 = DBSCAN(eps=0.05, min_samples=10, metric='precomputed').fit(skl_matrix) + # core_samples_mask2 = np.zeros_like(db2.labels_, dtype=bool) + # core_samples_mask2[db2.core_sample_indices_] = True + # labels2 = db2.labels_ + + # 3. RC clustering of raw GMMs + rc = rClustering(eps=1.15, k=5, rank='max_neighbors') + rc.set_cache_filename(ao.name+feature_name+"-%s-wl%i-ss%i.txt" %(fr,winlength,stepsize), cache_location=options.CACHE) + rc.fit(gmm_list) + rc.test() + classification = rc.classification + print 'classification', classification + neighborhood_size, average_div, node_rank = rc.getNodeRank() + + # centroid_list = self.getInitialCentroids(neighborhood_size, k=10) + # print 'initial centroids', centroid_list + + # k-means clustering of GMMs + KmeansClustering = Kmeans(gmm_list, K=5, initial_centroids=set(classification)) + labels = KmeansClustering.fit() + + f1 = np.array(zip(timestamps[:len(labels)], labels)) + # f2 = np.array(zip(timestamps[:len(labels2)], labels2)) + f3 = np.array(zip(timestamps[:len(classification)], classification)) + f4 = np.array(zip(timestamps[:len(neighborhood_size)], neighborhood_size)) + f5 = np.array(zip(timestamps[:len(node_rank)], node_rank)) + f6 = np.array(zip(timestamps[:len(average_div)], average_div)) + # + np.savetxt(join(options.OUTPUT, 'kmeans')+splitext(audio)[0]+feature_name+splitext(audio)[0]+'.csv', f1, delimiter=',') + # np.savetxt(join(options.OUTPUT, 'dbscan_gmm')+splitext(audio)[0]+'.csv', f2, delimiter=',') + # np.savetxt(join(options.OUTPUT, 'classification')+splitext(audio)[0]+'.csv', f3, delimiter=',') + np.savetxt(join(options.OUTPUT, 'neighborhood_size')+splitext(audio)[0]+'.csv', f4, delimiter=',') + np.savetxt(join(options.OUTPUT, 'node_rank')+splitext(audio)[0]+'.csv', f5, delimiter=',') + np.savetxt(join(options.OUTPUT, 'average_div')+splitext(audio)[0]+'.csv', f6, delimiter=',') + +def main(): + segmenter = Segmenter() + segmenter.process() + + +if __name__ == '__main__': + main() +