#!/usr/bin/env python
# encoding: utf-8
"""
gmmdist.py

Created by George Fazekas on 2014-07-03.
Copyright (c) 2014. All rights reserved.
"""

import sys, os

from numpy import sum, isnan, isinf, vstack
from numpy.random import rand
import numpy as np
from numpy import log, power, pi, exp, transpose, zeros, ones, dot
from sklearn.mixture import GMM
from sklearn.metrics.pairwise import pairwise_distances
from scipy.linalg import *
from pylab import plt
from scipy.stats import norm
#from gmmplot import plot_gmm
from scipy.io import savemat, loadmat
from numpy import trace
from numpy.linalg import det, inv, eigvals

FLOAT_MAX = np.finfo(np.float64).max

def is_pos_def(x):
    '''Check if matrix is positive definite.'''
    return np.all(eigvals(x) > 0)

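# Illustrative usage (not part of the original module): a matrix of the form
# A.dot(A.T) plus a small diagonal jitter term is positive definite, so it
# passes the check above. The 3x3 size and the 1e-6 jitter are arbitrary.
#
#   A = np.random.rand(3, 3)
#   cov = A.dot(A.T) + 1e-6 * np.eye(3)
#   assert is_pos_def(cov)
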
def skl_models(g1, g2):
    '''Wrapper function with error checking and adaptation to the GmmDistance/GaussianComponent classes.
    This function compares two Gaussian mixture models with an equal number of components and full covariance matrices.
    Covariance matrices must be positive definite.
    '''
    m1, m2 = g1.means, g2.means
    v1, v2 = g1.covars.swapaxes(0, 2), g2.covars.swapaxes(0, 2)
    w1, w2 = g1.weights[:, np.newaxis], g2.weights[:, np.newaxis]
    assert m1.shape[1] > 1, "The minimum number of features is 2."
    assert w1.shape == w2.shape, "Models must have the same number of components."
    # print 'v1, v2', v1.shape, v2.shape
    # assert (is_pos_def(v1) and is_pos_def(v2)) == True, "Covariance matrices must be positive definite."
    d = skl_gmm(m1, m2, v1, v2, w1, w2)
    if isnan(d):  # or isinf(d):
        return FLOAT_MAX
    return d

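# Usage sketch (assumptions flagged): the GmmDistance/GaussianComponent classes
# used with skl_models() are defined elsewhere, so a minimal stand-in object is
# assumed here, exposing .means (n_components x n_features), .covars
# (n_components x n_features x n_features, as the old sklearn GMM stores full
# covariances) and .weights (n_components,). X and Y stand for two
# (n_samples x n_features) feature matrices.
#
#   class _Model(object):
#       def __init__(self, gmm):
#           self.means = gmm.means_
#           self.covars = gmm.covars_
#           self.weights = gmm.weights_
#
#   gmm_a = GMM(n_components=3, covariance_type='full').fit(X)
#   gmm_b = GMM(n_components=3, covariance_type='full').fit(Y)
#   print(skl_models(_Model(gmm_a), _Model(gmm_b)))
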
# def kldiv_full(m0,m1,s0,s1):
#     '''Naive (unoptimised) implementation of the KL divergence between two single Gaussians with fully defined covariances (s).'''
#     return 0.5*(np.trace(s0/s1)+np.trace(s1/s0)+ np.dot( np.dot((m0-m1).T, np.linalg.inv(s0+s1)), (m0-m1)))-m0.shape[0]
#
# def skl_full(p0,p1):
#     '''Symmetrised KL divergence computed from 2 KL divergences using mean( KL(p||q), KL(q||p) )'''
#     d = (kldiv_full(p0.means,p0.covars,p1.means,p1.covars) + kldiv_full(p1.means,p1.covars,p0.means,p0.covars)) * 0.5
#     d = sum(d)
#     if isnan(d) :
#         return np.finfo(np.float64).max
#     return d

def kldiv_full(m1, m2, s1, s2):
    '''Closed-form KL divergence between two single Gaussians with full covariance matrices.'''
    m1, m2 = m1[:, None], m2[:, None]
    logdet1, logdet2 = log(det(s1)), log(det(s2))
    inv1, inv2 = inv(s1), inv(s2)
    m = m1 - m2
    d = m.shape[0]  # number of dimensions
    return 0.5 * ((logdet1 - logdet2) + trace(dot(inv1, s2)) + dot(dot(m.T, inv1), m) - d)[0][0]

def _skl_full(m1, m2, s1, s2):
    '''Symmetrised KL divergence between two single Gaussians: mean of the two directed divergences.'''
    m1, m2 = m1[:, None], m2[:, None]
    logdet1, logdet2 = log(det(s1)), log(det(s2))
    inv1, inv2 = inv(s1), inv(s2)
    m12 = m1 - m2
    m21 = m2 - m1
    d = m12.shape[0]  # number of dimensions
    kl12 = 0.5 * ((logdet1 - logdet2) + trace(dot(inv1, s2)) + dot(dot(m12.T, inv1), m12) - d)[0][0]
    kl21 = 0.5 * ((logdet2 - logdet1) + trace(dot(inv2, s1)) + dot(dot(m21.T, inv2), m21) - d)[0][0]
    return 0.5 * (kl12 + kl21)

def skl_full(p1, p2):
    '''Symmetrised KL divergence between two single-Gaussian models exposing .means and .covars.'''
    m1, m2 = p1.means, p2.means
    s1, s2 = p1.covars, p2.covars
    return _skl_full(m1, m2, s1, s2)

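# Illustrative check (not part of the original module): the symmetrised
# divergence of a Gaussian with itself is ~0 and grows as the means move
# apart. The 2-D means and covariance below are arbitrary example values.
#
#   mu = np.array([0.0, 0.0])
#   cov = np.array([[1.0, 0.2], [0.2, 1.0]])
#   print(_skl_full(mu, mu, cov, cov))                    # ~0.0
#   print(_skl_full(mu, np.array([2.0, 0.0]), cov, cov))  # > 0
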
def skl_gmm(m1, m2, v1, v2, w1, w2):
    '''Take the mean of the absolute values of KL(gmm1||gmm2) and KL(gmm2||gmm1) to symmetrise the divergence.'''
    return (abs(kldiv_gmm(m1, m2, v1, v2, w1, w2)) + abs(kldiv_gmm(m2, m1, v2, v1, w2, w1))) * 0.5

def kldiv_gmm(m1, m2, v1, v2, w1, w2):
    '''Low-level implementation of a variational approximation of the KL divergence between Gaussian mixture models.
    See first: J. R. Hershey and P. A. Olsen. "Approximating the Kullback-Leibler Divergence Between Gaussian Mixture Models."
    IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Volume 4, pp. 317–320, April, 2007.
    Further theory and refinement: J. L. Durrieu, J. P. Thiran, F. Kelly. "Lower and Upper Bounds for Approximation
    of the Kullback-Leibler Divergence Between Gaussian Mixture Models", ICASSP, 2012.
    This implementation is by George Fazekas, Centre for Digital Music, QMUL, London, UK.

    Inputs:
    m(x) : matrix of component means of gmm(x), one row per component (n_components x n_features)
    v(x) : covariance matrices of gmm(x) (n_features x n_features x n_components)
    w(x) : component weight vector of gmm(x) (n_components x 1)

    Output:
    kl_full_12 : Kullback-Leibler divergence of the PDFs approximated by the two Gaussian mixture models.

    This implementation uses a variational approximation rather than the conventional (and expensive)
    Monte Carlo simulation based approach; see the cited papers for details. The divergence is not symmetrised
    and may be negative (unlike the closed form) or inf. If the output is complex, the covariance matrices
    do not fulfil the required criteria, i.e. they are badly formed, singular, not positive definite etc.
    '''
    # TODO: consider dealing better with inf/-inf outcomes in the final distance computation
    # - note: the max of the rows of kl12 are not always zero like that of kl11
    # - eliminate the need for swapaxes of the covariances

    n = m1.shape[0]  # number of components
    d = m1.shape[1]  # number of dimensions

    ixm = ones((n, 1), dtype=int).T            # selector of mean matrix components
    ixd = range(0, d * d, d + 1)               # indices of diagonal elements of a DxD matrix
    t1 = v1.swapaxes(1, 2).reshape(d, n * d)   # concatenate gmm1 covariance matrices
    loopn = xrange(n)

    # step 1) precompute log(determinant()) of covariance matrices of gmm1
    logdet1 = zeros((n, 1))
    for i in loopn:
        logdet1[i] = log(det(v1[:, :, i]))

    # step 2) compute reference kldiv between individual components of gmm1
    kl11 = zeros((n, n))
    for i in loopn:
        inv1 = inv(v1[:, :, i])
        mm1 = m1 - m1[i * ixm, :][0]
        b1 = dot(inv1, t1).swapaxes(0, 1).reshape(n, power(d, 2)).T
        kl11[:, i] = 0.5 * ((logdet1[i] - d - logdet1)[:, 0] + sum(b1[ixd, :], 0).T + sum(dot(mm1, inv1) * mm1, 1))
    # print kl11

    # step 3) precompute log(determinant()) of covariance matrices of gmm2
    logdet2 = zeros((n, 1))
    for i in loopn:
        logdet2[i] = log(det(v2[:, :, i]))

    # step 4) compute pair-wise kldiv between components of gmm1 and gmm2
    kl12 = zeros((n, n))
    for i in loopn:
        inv2 = inv(v2[:, :, i])
        m12 = m1 - m2[i * ixm, :][0]
        b2 = dot(inv2, t1).swapaxes(0, 1).reshape(n, power(d, 2)).T
        kl12[:, i] = 0.5 * ((logdet2[i] - d - logdet1)[:, 0] + sum(b2[ixd, :], 0).T + sum(dot(m12, inv2) * m12, 1))
    # print kl12

    # step 5) compute the final variational distance between gmm1 and gmm2
    kl_full_12 = dot(w1.T, (log(sum(exp(-kl11) * w1, 1))) - log(sum(exp(-kl12) * w2, 1)))[0]
    # print "KL divergence between gmm1 || gmm2:", kl_full_12
    return kl_full_12

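# Quick sanity check (illustrative only, not part of the original module): with
# this low-level interface the expected shapes are means (n, d), covariances
# (d, d, n) and weights (n, 1). Comparing a mixture with itself should give a
# divergence of ~0. The two-component, two-dimensional mixture below is made up.
#
#   m = np.array([[0.0, 0.0], [3.0, 3.0]])   # (n, d) component means
#   v = np.dstack([np.eye(2), np.eye(2)])    # (d, d, n) covariance matrices
#   w = np.array([[0.5], [0.5]])             # (n, 1) component weights
#   print(kldiv_gmm(m, m, v, v, w, w))       # ~0.0
#   print(skl_gmm(m, m, v, v, w, w))         # ~0.0
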
# models = loadmat("gmms.mat")
# # print models.keys()
#
# X = models['X']
# Y = models['Y']
#
# print "Data shape:"
# print X.shape
# print Y.shape
#
# # # plot the fitted model
# # gmm1 = GMM(n_components = 3, covariance_type='full')
# # model1 = gmm1.fit(X)
# # plot_gmm(gmm1,X)
# #
# # # plot the fitted model
# # gmm2 = GMM(n_components = 3, covariance_type='full')
# # model2 = gmm2.fit(Y)
# # plot_gmm(gmm2,Y)
#
# # print "KL=",kldiv_full(gmm1.means_[0],gmm1.means_[1],gmm1.covars_[0],gmm1.covars_[1])
# #
# # print "gmm1_covars:\n", gmm1.covars_, gmm1.covars_.shape
#
#
# # m1 = gmm1.means_
# # v1 = gmm1.covars_.swapaxes(0,2)
# # w1 = gmm1.weights_
# #
# # m2 = gmm2.means_
# # v2 = gmm2.covars_.swapaxes(0,2)
# # w2 = gmm2.weights_
#
# m1 = models['gmm1_means']
# v1 = models['gmm1_covars']
# w1 = models['gmm1_weights']
#
# m2 = models['gmm2_means']
# v2 = models['gmm2_covars']
# w2 = models['gmm2_weights']
#
# print "KL divergence between gmm1 || gmm2:", kldiv_gmm(m1,m2,v1,v2,w1,w2)
# print "KL divergence between gmm2 || gmm1:", kldiv_gmm(m2,m1,v2,v1,w2,w1)
# print "Symmetrised KL distance between gmm1 || gmm2:", skl_gmm(m1,m2,v1,v2,w1,w2)

def main():
    pass


if __name__ == '__main__':
    main()