annotate utils/MutualInfo.py @ 19:890cfe424f4a tip

added annotations
author mitian
date Fri, 11 Dec 2015 09:47:40 +0000
parents 26838b1f560f
children
import sys
import scipy.spatial as ss
from scipy.special import digamma, gamma
from math import log, pi
import numpy.random as nr
import numpy as np
import random
from sklearn.metrics.pairwise import pairwise_distances
from scipy.stats import ttest_ind, ttest_rel, pearsonr, norm
from scipy.linalg import eigh, cholesky

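# Background note (added): the estimators below follow the
# Kraskov-Stoegbauer-Grassberger (KSG) k-nearest-neighbour estimate of
# mutual information,
#     I(X;Y) ~ psi(k) + psi(N) - <psi(n_x) + psi(n_y)>,
# where psi is the digamma function, N is the sample size, and n_x, n_y count
# the marginal neighbours within each point's k-th joint-space distance
# (avgdigamma below counts the centre point itself, absorbing Kraskov's "+1").
# Dividing by log(base) converts the result from nats to the requested base.
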
def mi(x, y, k=3, base=2):
    """ Mutual information of x and y.
        x, y should be lists of vectors, e.g. x = [[1.3],[3.7],[5.1],[2.4]]
        if x is one-dimensional and we have four samples.
    """
    assert len(x) == len(y), "Lists should have same length"
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"
    intens = 1e-10  # small noise to break degeneracy, see doc.

    x = [list(p + intens * nr.rand(len(x[0]))) for p in x]
    y = [list(p + intens * nr.rand(len(y[0]))) for p in y]
    points = zip2(x, y)
    # Find nearest neighbours in the joint space; p=inf means max-norm.
    tree = ss.cKDTree(points)
    dvec = [tree.query(point, k + 1, p=float('inf'))[0][k] for point in points]

    a, b, c, d = avgdigamma(x, dvec), avgdigamma(y, dvec), digamma(k), digamma(len(x))

    return (-a - b + c + d) / log(base)

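# A minimal usage sketch for mi() (added for illustration; the helper and its
# variable names are hypothetical): scalar series must be wrapped one sample
# per row before calling mi().
def _demo_mi(n=500):
    xs = [[v] for v in nr.rand(n)]
    ys = [[v[0] + 0.05 * nr.rand()] for v in xs]  # strongly dependent on xs
    zs = [[v] for v in nr.rand(n)]                # independent of xs
    print('dependent  :', mi(xs, ys))             # clearly positive
    print('independent:', mi(xs, zs))             # close to zero
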
def mi2(x, y, k=3, base=2):
    """ Mutual information of x and y.
        Unlike mi(), x and y are flat 1-D arrays of samples.
    """
    assert len(x) == len(y), "Lists should have same length"
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"
    intens = 1e-10  # small noise to break degeneracy, see doc.

    # Add the noise out of place so the caller's arrays are not mutated
    # (x += ... would modify them in place, twice over when x and y are the
    # same array, as in mi2(X, X)).
    x = np.asarray(x, dtype=np.float64) + intens * nr.rand(len(x))
    y = np.asarray(y, dtype=np.float64) + intens * nr.rand(len(y))
    points = np.array([x, y]).T

    # Find nearest neighbours in the joint space; p=inf means max-norm.
    tree = ss.cKDTree(points)
    dvec = [tree.query(point, k + 1, p=float('inf'))[0][k] for point in points]
    a, b, c, d = avgdigamma(x[:, np.newaxis], dvec), avgdigamma(y[:, np.newaxis], dvec), digamma(k), digamma(len(x))

    mi = (-a - b + c + d) / log(base)
    # The true MI is non-negative; clamp small negative estimation errors.
    if mi < 0:
        return 0.0
    return mi

def mi3(x, y, k=3, base=2):
    """ Mutual information of x and y, estimated by subsampling.
        For large inputs the estimate is averaged over random subsamples of
        500 points, which keeps the kd-tree queries cheap; small inputs fall
        back to the direct estimator mi2().
    """
    if len(x) < 1000:
        return mi2(x, y, k, base)

    intens = 1e-10  # small noise to break degeneracy, see doc.

    sampleSize = 500
    c = digamma(k)
    d = digamma(sampleSize)
    num_iter = 1 + int(len(x) / 1000)

    mi_mean = np.zeros(num_iter, dtype=np.float64)
    for i in range(num_iter):
        ix = np.random.randint(low=0, high=len(x), size=sampleSize)

        xs = x[ix] + intens * nr.rand(sampleSize)
        ys = y[ix] + intens * nr.rand(sampleSize)
        points = np.array([xs, ys]).T

        # Find nearest neighbours in the joint space; p=inf means max-norm.
        tree = ss.cKDTree(points)
        dvec = [tree.query(point, k + 1, p=float('inf'))[0][k] for point in points]
        a, b = avgdigamma(xs[:, np.newaxis], dvec), avgdigamma(ys[:, np.newaxis], dvec)

        mi_mean[i] = (-a - b + c + d) / log(base)

    mi = np.mean(mi_mean)
    # The true MI is non-negative; clamp small negative estimation errors.
    if mi < 0:
        return 0.0
    return mi

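# Rough consistency check between the direct and subsampled estimators
# (hypothetical helper, added for illustration): on large inputs mi3() should
# track mi2() at lower cost, with extra variance from subsampling.
def _compare_mi2_mi3(n=5000):
    x = nr.rand(n)
    y = x + 0.1 * nr.rand(n)
    print('mi2:', mi2(x, y))
    print('mi3:', mi3(x, y))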

def mic(xs, ys, intens, s, k):
    ''' Shared core of dmi(): the two marginal <psi(n)> terms for one
        (sub)sample of size s. '''
    # Add the noise out of place so the caller's arrays are not mutated.
    xs = xs + intens * nr.rand(s)
    ys = ys + intens * nr.rand(s)
    points = np.array([xs, ys]).T
    tree = ss.cKDTree(points)
    dvec = [tree.query(point, k + 1, p=float('inf'))[0][k] for point in points]
    return avgdigamma(xs[:, np.newaxis], dvec), avgdigamma(ys[:, np.newaxis], dvec)


def dmi(x, y, k=3, base=2):
    ''' Mutual information distance between x and y: 1 - MI(x,y)/MI_max,
        where MI_max is the estimate for identical inputs. '''

    if np.array_equal(x, y):
        return 0.0
    intens = 1e-10  # small noise to break degeneracy
    c = digamma(k)
    s = len(x)
    lb = 1.0 / log(base)

    # For small samples, calculate MI directly.
    if s < 1000:
        a, b = mic(x, y, intens, s, k)
        d = digamma(s)
        mx = (-c + d) * lb
        nmi = (-a - b + c + d) * lb / mx
        if nmi > 1: nmi = 1.0  # handle the case when the MI of correlated samples is overestimated
        if nmi < 0: nmi = 0.0  # handle estimation error resulting in small negative values
        return 1.0 - nmi

    # For large samples, average the estimate over random subsamples.
    sampleSize = 500
    num_iter = 1 + int(s / 1000)
    d = digamma(sampleSize)

    mi_mean = np.zeros(num_iter, dtype=np.float64)
    for i in range(num_iter):
        ix = np.random.randint(low=0, high=s, size=sampleSize)
        a, b = mic(x[ix], y[ix], intens, sampleSize, k)
        mi_mean[i] = (-a - b + c + d) * lb

    mmi = np.mean(mi_mean)
    mx = (-c + d) * lb
    nmi = mmi / mx
    print(mmi, mx, nmi)  # debug output

    if nmi > 1: nmi = 1.0  # handle the case when the MI of correlated samples is overestimated
    if nmi < 0: nmi = 0.0  # handle estimation error resulting in small negative values
    return 1.0 - nmi

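# Normalisation note (added): for identical inputs each marginal
# neighbourhood contains exactly the k joint-space neighbours, so the KSG
# formula reduces to psi(N) - psi(k). That is the mx term above, making
# nmi = mi/mx a [0, 1] similarity and dmi = 1 - nmi a distance.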

def avgdigamma(points, dvec):
    # Counts the neighbours within each point's radius in the marginal space
    # and returns the expectation value <psi(n)>.
    N = len(points)
    tree = ss.cKDTree(points)
    avg = 0.
    for i in range(N):
        dist = dvec[i]
        # Subtlety: we don't include the boundary point, but we are implicitly
        # adding 1 to the Kraskov definition because the centre point is included.
        num_points = len(tree.query_ball_point(points[i], dist - 1e-15, p=float('inf')))
        avg += digamma(num_points) / N
    return avg

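# A vectorised sketch of the same computation (added for illustration; it
# assumes a SciPy version whose cKDTree.query_ball_point accepts an array of
# query points with per-point radii, which older releases do not):
def _avgdigamma_vectorized(points, dvec):
    tree = ss.cKDTree(points)
    radii = np.asarray(dvec) - 1e-15
    neighbourhoods = tree.query_ball_point(points, radii, p=float('inf'))
    return np.mean([digamma(len(nn)) for nn in neighbourhoods])
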
def zip2(*args):
    # zip2(x, y) takes lists of vectors and makes a list of vectors in the
    # joint space, e.g. zip2([[1],[2],[3]], [[4],[5],[6]]) = [[1,4],[2,5],[3,6]]
    return [sum(sublist, []) for sublist in zip(*args)]


def test_mi(num_samples=9000):
    '''
    Generate correlated multivariate random variables and compare the
    estimators on identical, correlated and uncorrelated pairs.
    '''

    # Generate samples from three independent normally distributed random
    # variables (with mean 0 and std. dev. 1).
    X = norm.rvs(size=(3, num_samples))

    # The desired covariance matrix.
    r = np.array([
        [ 3.40, -2.75, -2.00],
        [-2.75,  5.50,  1.50],
        [-2.00,  1.50,  1.25]
    ])

    # Choice of Cholesky or eigenvector method.
    method = 'cholesky'
    #method = 'eigenvectors'

    if method == 'cholesky':
        # Compute the Cholesky decomposition.
        c = cholesky(r, lower=True)
    else:
        # Compute the eigenvalues and eigenvectors.
        evals, evecs = eigh(r)
        # Construct c, so that c * c^T = r.
        c = np.dot(evecs, np.diag(np.sqrt(evals)))

    # Convert the data to correlated random variables.
    Y1 = np.dot(c, X)[2, :]
    Y2 = norm.rvs(size=(3, num_samples))[0, :]
    X = X[0, :]

    xx = mi2(X, X)
    xy1 = mi2(X, Y1)
    xy2 = mi2(X, Y2)
    print('identical', xx)
    print('correlated', xy1)
    print('uncorrelated', xy2)

    xx = mi3(X, X)
    xy1 = mi3(X, Y1)
    xy2 = mi3(X, Y2)
    print('identical', xx)
    print('correlated', xy1)
    print('uncorrelated', xy2)

    xx = dmi(X, X)
    xy1 = dmi(X, Y1)
    xy2 = dmi(X, Y2)
    print('identical', xx)
    print('correlated', xy1)
    print('uncorrelated', xy2)

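# Sanity check for the correlation construction above (hypothetical helper,
# added for illustration): the sample covariance of np.dot(c, X) should
# approximate the target matrix r.
def _check_covariance(num_samples=100000):
    X = norm.rvs(size=(3, num_samples))
    r = np.array([[ 3.40, -2.75, -2.00],
                  [-2.75,  5.50,  1.50],
                  [-2.00,  1.50,  1.25]])
    c = cholesky(r, lower=True)
    print(np.round(np.cov(np.dot(c, X)), 2))  # expected to be close to r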

def print_progress(counter="", message=""):
    sys.stdout.write("%(counter)s: %(message)s" % vars())
    sys.stdout.flush()
    sys.stdout.write("\r\r")

def test_direct(num_samples):
    X = norm.rvs(size=(1, num_samples))[0, :]
    return mi2(X, X)

def main():
    test_mi()
    raise SystemExit  # stop here; the timing benchmark below is disabled

    import matplotlib.pyplot as plt
    figure = plt.figure()
    axis = figure.add_subplot(111)
    series = np.linspace(100, 25000, 20)
    # series = np.linspace(10, 250, 20)

    # result = [test_direct(int(x)) for x in series]
    result = []
    for i, x in enumerate(series):
        print_progress(i)
        result.append(test_direct(int(x)))
    axis.plot(series, result)
    plt.show()
    # test_direct(1500)

if __name__ == '__main__':
    main()