# utils/MutualInfo.py
import sys
from math import log

import numpy as np
import numpy.random as nr
import scipy.spatial as ss
from scipy.special import digamma
from scipy.stats import norm
from scipy.linalg import eigh, cholesky

def mi(x,y,k=3,base=2):
	""" Mutual information of x and y
		x,y should be a list of vectors, e.g. x = [[1.3],[3.7],[5.1],[2.4]]
		if x is a one-dimensional scalar and we have four samples
	"""
	assert len(x)==len(y), "Lists should have same length"
	assert k <= len(x) - 1, "Set k smaller than num. samples"
	intens = 1e-10 #small noise to break degeneracy, see doc.
	
	x = [list(p + intens*nr.rand(len(x[0]))) for p in x]
	y = [list(p + intens*nr.rand(len(y[0]))) for p in y]
	points = zip2(x,y)
	#Find nearest neighbors in joint space, p=inf means max-norm
	tree = ss.cKDTree(points)
	dvec = [tree.query(point,k+1,p=float('inf'))[0][k] for point in points]
	
	a,b,c,d = avgdigamma(x,dvec), avgdigamma(y,dvec), digamma(k), digamma(len(x)) 
	
	return (-a-b+c+d)/log(base)
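
# mi(), mi2() and mi3() implement the Kraskov-Stoegbauer-Grassberger (KSG)
# k-nearest-neighbour estimator:
#   I(X;Y) ~ psi(k) + psi(N) - <psi(n_x + 1) + psi(n_y + 1)>
# where n_x and n_y count marginal neighbours within each point's k-NN
# max-norm distance (avgdigamma folds the +1 in by counting the center
# point itself); dividing by log(base) converts nats to the chosen base.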

def mi2(x,y,k=3,base=2):
	""" Mutual information of x and y (in bits by default).
		Unlike mi(), x and y should be 1-D numpy arrays of scalar
		samples, e.g. x = np.array([1.3,3.7,5.1,2.4]) for four samples.
	"""
	assert len(x)==len(y), "Lists should have same length"
	assert k <= len(x) - 1, "Set k smaller than num. samples"
	intens = 1e-10 #small noise to break degeneracy, see doc.

	x = x + intens * nr.rand(len(x)) # add noise to a copy; don't mutate the caller's array
	y = y + intens * nr.rand(len(y))
	points = np.array([x,y]).T

	#Find nearest neighbors in joint space, p=inf means max-norm
	tree = ss.cKDTree(points)
	dvec = [tree.query(point,k+1,p=float('inf'))[0][k] for point in points]
	a,b,c,d = avgdigamma(x[np.newaxis,:].T, dvec), avgdigamma(y[np.newaxis,:].T, dvec), digamma(k), digamma(len(x)) 
	
	mi = (-a-b+c+d)/log(base)
	if mi < 0:
		return 0.0
	return mi
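
# Example usage for mi2 (a minimal sketch; unlike mi() it takes 1-D arrays):
#   x = norm.rvs(size=1000)
#   y = x + 0.5 * norm.rvs(size=1000)
#   mi2(x, y) # estimate in bits (base=2 by default)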

def mi3(x,y,k=3,base=2):
	""" Mutual information of x and y, estimated by subsampling.
		x,y should be 1-D numpy arrays of scalar samples; inputs with
		fewer than 1000 samples fall back to mi2().
	"""
	if len(x) < 1000:
		return mi2(x,y,k,base)
	
	intens = 1e-10 #small noise to break degeneracy, see doc.
	
	sampleSize = 500
	c = digamma(k)
	d = digamma(sampleSize)
	num_iter = 1 + int(len(x)/1000)

	mi_mean = np.zeros(num_iter,dtype=np.float64)
	for i in range(num_iter):
		ix = np.random.randint(low=0, high=len(x), size=sampleSize)
		
		xs = x[ix]
		ys = y[ix]
		xs += intens * nr.rand(len(xs))
		ys += intens * nr.rand(len(ys))
		points = np.array([xs,ys]).T

		#Find nearest neighbors in joint space, p=inf means max-norm
		tree = ss.cKDTree(points)
		dvec = [tree.query(point,k+1,p=float('inf'))[0][k] for point in points]
		a,b = avgdigamma(xs[np.newaxis,:].T, dvec), avgdigamma(ys[np.newaxis,:].T, dvec)
	
		mi_mean[i] = (-a-b+c+d)/log(base)
	
	mi = np.mean(mi_mean)
	if mi < 0:
		return 0.0
	return mi
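
# Design note: mi3 trades a little accuracy for speed on large inputs.
# Rather than building one k-d tree over all n points, it averages the KSG
# estimate over roughly n/1000 random subsamples of 500 points each; every
# subsample is cheap to query and the mean keeps the variance down.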
	
	
def mic(xs,ys,intens,s,k):
	'''Shared inner step for dmi: returns the two marginal <psi(n)> terms.'''
	xs = xs + intens * nr.rand(s) # add noise to copies; don't mutate the caller's arrays
	ys = ys + intens * nr.rand(s)
	points = np.array([xs,ys]).T
	tree = ss.cKDTree(points)
	dvec = [tree.query(point,k+1,p=float('inf'))[0][k] for point in points]
	return avgdigamma(xs[np.newaxis,:].T, dvec), avgdigamma(ys[np.newaxis,:].T, dvec)
	

def dmi(x,y,k=3,base=2):
	''' Mutual information distance between x and y.'''
	
	if np.array_equal(x, y):
		return 0.0
	intens = 1e-10 #small noise to break degeneracy
	c = digamma(k)
	s = len(x)
	lb = 1.0/log(base)
	
	# for small samples calculate mi directly
	if s < 1000:
		a,b = mic(x,y,intens,s,k)
		d = digamma(s)
		mx  = (-c+d)*lb
		nmi = (-a-b+c+d)*lb / mx
		if nmi > 1: nmi = 1.0 # handle the case when mi of correlated samples is overestimated
		if nmi < 0: nmi = 0.0 # handle estimation error resulting in small negative values
		return 1.0-nmi	
	
	sampleSize = 500
	num_iter = 1 + int(s/1000)
	d = digamma(sampleSize)	

	mi_mean = np.zeros(num_iter,dtype=np.float64)
	for i in range(num_iter):
		ix = np.random.randint(low = 0, high = s, size=sampleSize)
		a,b = mic(x[ix],y[ix],intens,sampleSize,k)
		mi_mean[i] = (-a-b+c+d)*lb
	
	mmi = np.mean(mi_mean)
	mx  = (-c+d)*lb
	nmi = mmi / mx
	# print(mmi, mx, nmi) # debug output
	
	if nmi > 1: nmi = 1.0 # handle the case when mi of correlated samples is overestimated
	if nmi < 0: nmi = 0.0 # handle estimation error resulting in small negative values
	return 1.0-nmi	
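
# dmi normalizes the estimate by mx = (psi(N) - psi(k))/log(base), which is
# effectively what the estimator returns for identical samples, so nmi lies
# in [0,1] and 1-nmi behaves as a distance: 0 for identical series, near 1
# for independent ones.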
	
	
def avgdigamma(points,dvec):
	#Counts the neighbors within a given radius in the marginal space and
	#returns the expectation value <psi(n_x)>.
	N = len(points)
	tree = ss.cKDTree(points)
	avg = 0.
	for i in range(N):
		dist = dvec[i]
		#Subtlety: the ball query below excludes the boundary point itself, but
		#counts the center point, implicitly adding 1 as in the Kraskov definition.
		num_points = len(tree.query_ball_point(points[i],dist-1e-15,p=float('inf'))) 
		avg += digamma(num_points)/N
	return avg

def zip2(*args):
	#zip2(x,y) takes the lists of vectors and makes it a list of vectors in a joint space
	#E.g. zip2([[1],[2],[3]],[[4],[5],[6]]) = [[1,4],[2,5],[3,6]]
	return [sum(sublist,[]) for sublist in zip(*args)]


def test_mi(num_samples=9000):
	'''
	Generate correlated multivariate random variables:
	'''

	# num_samples = 9000

	# Generate samples from three independent normally distributed random
	# variables (with mean 0 and std. dev. 1).
	X = norm.rvs(size=(3, num_samples))

	# The desired covariance matrix.
	r = np.array([
	        [  3.40, -2.75, -2.00],
	        [ -2.75,  5.50,  1.50],
	        [ -2.00,  1.50,  1.25]
	    ])

	# Choice of cholesky or eigenvector method.
	method = 'cholesky'
	#method = 'eigenvectors'

	if method == 'cholesky':
	    # Compute the Cholesky decomposition.
	    c = cholesky(r, lower=True)
	else:
	    # Compute the eigenvalues and eigenvectors.
	    evals, evecs = eigh(r)
	    # Construct c, so that c*c^T = r.
	    c = np.dot(evecs, np.diag(np.sqrt(evals)))
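
	# Why this works: X has identity covariance, so cov(c @ X) = c c^T = r.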

	# Convert the data to correlated random variables. 
	Y1 = np.dot(c, X)[2,:]
	Y2 = norm.rvs(size=(3, num_samples))[0,:]
	X = X[0,:]

	xx = mi2(X, X)
	xy1 = mi2(X, Y1)
	xy2 = mi2(X, Y2)
	print('identical', xx)
	print('correlated', xy1)
	print('uncorrelated', xy2)

	xx = mi3(X, X)
	xy1 = mi3(X, Y1)
	xy2 = mi3(X, Y2)
	print('identical', xx)
	print('correlated', xy1)
	print('uncorrelated', xy2)
	
	xx = dmi(X, X)
	xy1 = dmi(X, Y1)
	xy2 = dmi(X, Y2)
	print('identical', xx)
	print('correlated', xy1)
	print('uncorrelated', xy2)
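
# A hedged sanity check (assumes mi2's 1-D array interface): for a bivariate
# Gaussian with correlation rho the true MI is -0.5*log(1 - rho**2) nats,
# printed here in bits to match the estimators' default base=2.
def test_gaussian_mi(num_samples=2000, rho=0.6, k=3):
	'''Compare the k-NN estimate with the analytic MI of a bivariate Gaussian.'''
	x = norm.rvs(size=num_samples)
	y = rho * x + np.sqrt(1.0 - rho**2) * norm.rvs(size=num_samples)
	print('analytic', -0.5 * np.log(1.0 - rho**2) / log(2))
	print('estimate', mi2(x, y, k=k))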
	
	
def print_progress(counter="", message=""):
	sys.stdout.write("%s: %s\r" % (counter, message))
	sys.stdout.flush()
	
def test_direct(num_samples):
	X = norm.rvs(size=(1, num_samples))[0,:]
	return mi2(X, X)
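
# Note: for X paired with itself the KSG formula reduces (up to the tiny
# tie-breaking noise) to (digamma(N) - digamma(k))/log(2), so the sweep in
# main() traces digamma's growth in N rather than a flat line.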
		
def main():
	test_mi()
	raise SystemExit # remove this line to run the sample-size sweep below
	
	import matplotlib.pyplot as plt
	figure = plt.figure()
	axis = figure.add_subplot(111)
	series = np.linspace(100,25000,20)
	# series = np.linspace(10,250,20)

	# result = [test_direct(int(x)) for x in series]
	result = []
	for i,x in enumerate(series):
		print_progress(i)
		result.append(test_direct(int(x)))
	axis.plot(series,result)
	plt.show()
	# test_direct(1500)
	
if __name__ == '__main__':
	main()