view lsh.py @ 741:50a7fd50578f multiprobeLSH

fixed error in lsh.py
author mas01mc
date Mon, 04 Oct 2010 19:24:55 +0000
parents 92f034aa8f28
children
line wrap: on
line source
import random, numpy, pickle, os, operator, traceback, sys, math


# Python implementation of Andoni's e2LSH.  This version is fast because it
# uses Python hashes to implement the buckets.  The numerics are handled 
# by the numpy routine so this should be close to optimal in speed (although
# there is no control of the hash tables layout in memory.)

# Malcolm Slaney, Yahoo! Research, 2009

# Source of random primes for the LSH hash functions: every prime in the
# table has the form 2**k - m, and sample() hands one back at random.
class primes():
	# For each exponent k, the offsets m such that 2**k - m is prime.
	# Taken from http://primes.utm.edu/lists/2small/0bit.html
	_OFFSETS = {
		8: (5, 15, 17, 23, 27, 29, 33, 45, 57, 59),
		9: (3, 9, 13, 21, 25, 33, 45, 49, 51, 55),
		10: (3, 5, 11, 15, 27, 33, 41, 47, 53, 57),
		11: (9, 19, 21, 31, 37, 45, 49, 51, 55, 61),
		12: (3, 5, 17, 23, 39, 45, 47, 69, 75, 77),
		13: (1, 13, 21, 25, 31, 45, 69, 75, 81, 91),
		14: (3, 15, 21, 23, 35, 45, 51, 65, 83, 111),
		15: (19, 49, 51, 55, 61, 75, 81, 115, 121, 135),
		16: (15, 17, 39, 57, 87, 89, 99, 113, 117, 123),
		17: (1, 9, 13, 31, 49, 61, 63, 85, 91, 99),
		18: (5, 11, 17, 23, 33, 35, 41, 65, 75, 93),
		19: (1, 19, 27, 31, 45, 57, 67, 69, 85, 87),
		20: (3, 5, 17, 27, 59, 69, 129, 143, 153, 185),
		21: (9, 19, 21, 55, 61, 69, 105, 111, 121, 129),
		22: (3, 17, 27, 33, 57, 87, 105, 113, 117, 123),
		23: (15, 21, 27, 37, 61, 69, 135, 147, 157, 159),
		24: (3, 17, 33, 63, 75, 77, 89, 95, 117, 167),
		25: (39, 49, 61, 85, 91, 115, 141, 159, 165, 183),
		26: (5, 27, 45, 87, 101, 107, 111, 117, 125, 135),
		27: (39, 79, 111, 115, 135, 187, 199, 219, 231, 235),
		28: (57, 89, 95, 119, 125, 143, 165, 183, 213, 273),
		29: (3, 33, 43, 63, 73, 75, 93, 99, 121, 133),
		30: (35, 41, 83, 101, 105, 107, 135, 153, 161, 173),
		31: (1, 19, 61, 69, 85, 99, 105, 151, 159, 171),
	}

	def __init__(self):
		# Flatten the table: exponents ascending, offsets in listed order.
		self.primes = [2**k - m
			for k in sorted(self._OFFSETS)
			for m in self._OFFSETS[k]]
		self.len = len(self.primes)

	def sample(self):
		"""Return one of the tabulated primes, chosen uniformly at random."""
		slot = int(random.random() * self.len)
		return self.primes[slot]
		

# One LSH table: k random projections of a point are quantized into k
# integers, and that integer vector is reduced to a T1 hash (bucket number,
# modulo N) and a T2 hash (chain key, modulo t2size).  Data can either be
# entered into the table or retrieved from it.
class lsh:
	def __init__(self, k, w, N, t2size = (2**16 - 15)):
		self.k = k	# Number of projections
		self.w = w	# Bin width
		self.N = N	# Number of T1 buckets
		self.t2size = t2size	# Modulus of the T2 hash
		self.projections = None	# Created lazily on the first hash
		self.buckets = {}	# t1 -> {t2: [ids]}

	# Create the random constants needed for the projections.
	def CreateProjections(self, dim):
		"""Draw k Gaussian projection vectors of length dim, k offsets in
		[0, 1), and k random prime multipliers for each of T1 and T2."""
		self.dim = dim
		p = primes()
		self.projections = numpy.random.randn(self.k, self.dim)
		self.bias = numpy.random.rand(self.k)
		# Real lists: on Python 3 map() is a lazy iterator that numpy
		# cannot multiply against the bins.
		self.t1hash = [p.sample() for _ in range(self.k)]
		self.t2hash = [p.sample() for _ in range(self.k)]

	# Actually create the t1 and t2 hashes for some data.
	def CalculateHashes(self, data):
		"""Return (t1, t2) for data; see CalculateHashes2."""
		t1, t2, _ = self.CalculateHashes2(data)
		return t1, t2

	def CalculateHashes2(self, data):
		"""Return (t1, t2, bins) for the vector data.

		bins[i] = floor(<data, projection_i> / w + bias[i]) as an int32 ('i')
		array.  t1 = sum(bins * t1hash) % N, t2 = sum(bins * t2hash) % t2size.
		Projections are created on first use with dim = len(data).
		"""
		if self.projections is None:
			self.CreateProjections(len(data))
		scaled = numpy.dot(self.projections, data) / self.w + self.bias
		# Floor, not truncation toward zero: truncating would map the whole
		# open interval (-1, 1) to bin 0, making that bin twice as wide.
		bins = numpy.floor(scaled).astype('i')
		# Explicit int64 so bin * prime sums do not wrap on platforms whose
		# default numpy integer is 32 bits.
		wide = bins.astype(numpy.int64)
		t1 = numpy.sum(wide * numpy.asarray(self.t1hash, dtype=numpy.int64)) % self.N
		t2 = numpy.sum(wide * numpy.asarray(self.t2hash, dtype=numpy.int64)) % self.t2size
		return t1, t2, bins

	# Put some data into the hash bucket for this LSH projection
	def InsertIntoTable(self, id, data):
		"""Append id to the (t1, t2) chain that data hashes to."""
		t1, t2 = self.CalculateHashes(data)
		self.buckets.setdefault(t1, {}).setdefault(t2, []).append(id)

	# Find some data in the hash bucket.  Return all the ids
	# that we find for this T1-T2 pair.
	def Find(self, data):
		"""Return the ids stored in data's (t1, t2) chain ([] if none).
		A non-empty result is the table's own list, not a copy."""
		t1, t2 = self.CalculateHashes(data)
		return self.buckets.get(t1, {}).get(t2, [])

	# Create a dictionary showing all the buckets an ID appears in
	def CreateDictionary(self, theDictionary, prefix):
		"""Add, for every id in this table, the label prefix + str(t1) of each
		chain it is stored in.  Mutates and returns theDictionary."""
		for t1, row in self.buckets.items():
			label = prefix + str(t1)
			for ids in row.values():
				for i in ids:
					# append, not +=: += on a list with a string would add
					# the string's individual characters.
					theDictionary.setdefault(i, []).append(label)
		return theDictionary

	# Print some stats for these lsh buckets
	def Stats(self):
		"""Print point counts and T2 chain-length statistics."""
		chains = [(len(ids), (t1, t2))
			for t1, row in self.buckets.items()
			for t2, ids in row.items()]
		if not chains:
			print("Bucket Counts: (table is empty)")
			return
		lengths = sorted(length for length, _ in chains)
		numCount = len(chains)
		sumCount = sum(lengths)
		# max() keeps the first maximal chain, like a strict > scan.
		maxCount, maxLoc = max(chains, key=lambda chain: chain[0])
		med = lengths[(len(lengths) + 1) // 2 - 1]
		print("Bucket Counts:")
		print("\tTotal indexed points: %s" % sumCount)
		print("\tT1 Buckets filled: %d/%d" % (len(self.buckets), self.N))
		print("\tT2 Buckets used: %d/%d" % (numCount, self.N))
		print("\tMaximum T2 chain length: %s at %s" % (maxCount, maxLoc))
		print("\tAverage T2 chain length: %s" % (float(sumCount) / numCount))
		print("\tMedian T2 chain length: %s" % med)

	# Get a list of all IDs that are contained in these hash buckets
	def GetAllIndices(self):
		return [i for row in self.buckets.values()
			for ids in row.values()
			for i in ids]

	# Put some data into the hash table, see how many collisions we get.
	def Test(self, n):
		"""Rebuild the table from n 2-d points and report misses and shared
		chains."""
		self.buckets = {}
		self.projections = None
		d = numpy.array([.2, .3])
		for i in range(n):
			self.InsertIntoTable(i, d + i)
		for i in range(n):
			r = self.Find(d + i)
			if i not in r:
				print("Couldn't find item %s" % i)
			if len(r) > 1:
				print("Found big bin for %s : %s" % (i, r))
	

# Put together several LSH tables to form an index.  The only new parameter
# is l, the number of tables (one lsh object each).  Every table shares k, w
# and N but creates its own random projections the first time it hashes.
class index:
	def __init__(self, k, l, w, N):
		self.k = k	# Projections per table
		self.l = l	# Number of tables
		self.w = w	# Bin width
		self.N = N	# Number of T1 buckets per table
		# Despite its name, this holds the l lsh tables.
		self.projections = [lsh(k, w, N) for _ in range(l)]

	# Insert some data into all LSH tables.
	def InsertIntoTable(self, id, data):
		for table in self.projections:
			table.InsertIntoTable(id, data)

	# Find some data in all the LSH tables.
	def Find(self, data):
		"""Return [(id, count), ...] sorted by decreasing count, where count is
		how many times id turned up in data's chains across all tables."""
		counts = {}
		for table in self.projections:
			for item in table.Find(data):
				counts[item] = counts.get(item, 0) + 1
		return sorted(counts.items(), key=operator.itemgetter(1),
			reverse=True)

	# Return a list of results: (id, distance**2, count), nearest first.
	def FindExact(self, data, GetData):
		"""Rank the candidates from Find by squared Euclidean distance.

		GetData(id) is expected to return the stored vector for id (something
		that can be subtracted from data)."""
		candidates = self.Find(data)
		scored = [(cid, ((GetData(cid) - data)**2).sum(), count)
			for (cid, count) in candidates]
		return sorted(scored, key=operator.itemgetter(1))

	# Do an exhaustive distance calculation looking for all points and their distance.
	# Return a list of results: (id, distance**2, 0), nearest first.
	def FindAll(self, query, GetData):
		# GetAllIndices returns None for an index with no tables.
		allIDs = self.GetAllIndices() or []
		scored = [(cid, ((GetData(cid) - query)**2).sum(), 0)
			for cid in allIDs]
		return sorted(scored, key=operator.itemgetter(1))

	# Return the number of LSH candidates closer than radius to the query
	def CountInsideRadius(self, data, GetData, radius):
		radius2 = radius**2
		matches = self.FindExact(data, GetData)
		return sum(1 for match in matches if match[1] < radius2)

	# Put some data into the hash tables and print what Find returns.
	def Test(self, n):
		d = numpy.array([.2, .3])
		for i in range(n):
			self.InsertIntoTable(i, d + i)
		for i in range(n):
			print(self.Find(d + i))

	# Print the statistics of each hash table.
	def Stats(self):
		for i, table in enumerate(self.projections):
			# Header shares a line with the table's first Stats line.
			sys.stdout.write("Buckets %d " % i)
			table.Stats()

	# Get all the IDs that are part of this index.  Just check one table,
	# since every id is inserted into all of them.
	def GetAllIndices(self):
		if self.projections:
			return self.projections[0].GetAllIndices()
		return None

	# Return the buckets (t1 and t2 hashes) associated with a data point
	def GetBuckets(self, data):
		"""Return [t1_0, t2_0, t1_1, t2_1, ...], one pair per table."""
		buckets = []
		for table in self.projections:
			buckets.extend(table.CalculateHashes(data))
		return buckets

	# Create a dictionary keyed by ID listing which buckets are used for each ID
	def CreateDictionary(self):
		"""Map each id to bucket labels; table i's labels get prefix 'W'
		followed by i written little-endian in base 26 with letters a-z."""
		theDictionary = {}
		letters = 'abcdefghijklmnopqrstuvwxyz'
		for tableNumber, table in enumerate(self.projections):
			prefix = 'W'
			pc = tableNumber
			while pc > 0:	# Create a unique prefix for this table
				prefix += letters[pc % len(letters)]
				pc //= len(letters)	# floor division (also on Python 3)
			theDictionary = table.CreateDictionary(theDictionary, prefix)
		return theDictionary

	# Use the expression in "Analysis of Minimum Distances in High-Dimensional
	# Musical Spaces" to calculate the underlying dimensionality of the data
	# For a random selection of ids, find the nearest neighbors and use this
	# to calculate the dimensionality of the data.
	def MeasureDimensionality(self, allData, N):
		"""Estimate dimensionality from N sampled ids' nearest-neighbor squared
		distances.  Returns the estimate, or 0 if fewer than two samples had
		a usable neighbor."""
		allIDs = self.GetAllIndices()
		sampleIDs = random.sample(allIDs, N)
		L = 0.0
		S = 0.0
		for qid in sampleIDs:
			res = self.FindExact(allData[qid, :], lambda i: allData[i, :])
			# Nearest candidate other than the query itself (with duplicate
			# points the query is not necessarily res[0]).
			dists = [match[1] for match in res if match[0] != qid]
			if dists and dists[0] > 0:	# log(0) is undefined
				S += dists[0]
				L += math.log(dists[0])
			else:
				N -= 1
		print("S=" + str(S) + " L=" + str(L) + " N=" + str(N))
		if N > 1:
			x = math.log(S / N) - L / N	# Equation 17
			d = 2 * InvertFunction(x, lambda y: math.log(y) - digamma(y))
			print(d)
			return d
		return 0

# Only works for monotonic functions... Uses the geometric midpoint to
# shrink the search range, looking for the input whose function output
# equals the given value.
# Test with:
#	lsh.InvertFunction(2,math.sqrt)		# about 4
#	lsh.InvertFunction(2,lambda x:1.0/x)	# about 0.5
# MeasureDimensionality uses it to invert log(y) - digamma(y).
def InvertFunction(x, func, lo=0.0001, hi=1000, tol=1e-7):
	"""Return y in [lo, hi] with func(y) approximately equal to x.

	func must be monotonic on [lo, hi]; its direction is detected from the
	endpoints.  The search stops once hi - lo <= tol.  lo must be positive
	because the midpoint is sqrt(lo * hi).  If x lies outside func's range
	on [lo, hi], the result converges toward the corresponding endpoint.
	Raises ValueError unless 0 < lo < hi.
	"""
	if lo <= 0 or hi <= lo:
		raise ValueError("InvertFunction needs 0 < lo < hi")
	if func(lo) < func(hi):
		sign = 1
	else:
		sign = -1
	print("Looking for Y() = %s d'= %s" % (x, sign))
	# Defined up front so the result exists even if the loop never runs.
	mid = math.sqrt(lo * hi)
	while lo + tol < hi:
		mid = math.sqrt(lo * hi)
		if sign * func(mid) > sign * x:
			hi = mid
		else:
			lo = mid
	return mid

#####  A bunch of routines used to generate data we can use to test
# this LSH implementation.

# Module-level store for generated test data.  It starts as an empty list;
# CreateRandomLSHTestData rebinds it to a (numPoints, dim) numpy array.
gLSHTestData = []

# Find a point in the array of data.  (Needed so FindExact can get the
# data it needs.)
def FindLSHTestData(id):
	"""Return row id of gLSHTestData, or None if id is not a valid row
	(including negative ids, or before any test data has been generated)."""
	global gLSHTestData
	data = gLSHTestData
	# Until test data is generated the store is a plain (empty) list.
	if not hasattr(data, 'shape'):
		return None
	# Reject negative ids too: numpy would silently index from the end.
	if 0 <= id < data.shape[0]:
		return data[id, :]
	return None

# Fill the test array with uniform random data between -1 and 1.
def CreateRandomLSHTestData(numPoints, dim):
	"""Replace gLSHTestData with a (numPoints, dim) array of values drawn
	uniformly from [-1, 1), and return it."""
	global gLSHTestData
	# rand() is uniform on [0, 1); shift and scale to [-1, 1).
	gLSHTestData = (numpy.random.rand(numPoints, dim) - .5) * 2.0
	return gLSHTestData

# Fill the test array with a regular grid of 2-d points between -1 and 1.
def CreateRegularLSHTestData(numDivs):
	"""Replace gLSHTestData with a (2*numDivs+1)**2 x 2 grid of points
	spaced 1/numDivs apart on [-1, 1] x [-1, 1], and return it.

	Rows are ordered with x changing slowest and y fastest.
	Raises ValueError if numDivs < 1.
	"""
	# Without this declaration the grid was only bound to a local name.
	global gLSHTestData
	if numDivs < 1:
		raise ValueError("numDivs must be at least 1")
	side = 2 * numDivs + 1
	coords = numpy.arange(-numDivs, numDivs + 1) / float(numDivs)
	grid = numpy.zeros((side * side, 2))
	grid[:, 0] = numpy.repeat(coords, side)	# x: outer loop
	grid[:, 1] = numpy.tile(coords, side)	# y: inner loop
	gLSHTestData = grid
	return grid

# Use Nearest Neighbor properties to calculate dimensionality
# (see index.MeasureDimensionality).
def TestDimensionality(N, numPoints=100000, dim=3):
	"""Estimate the dimensionality of numPoints uniform random points in dim
	dimensions from N sampled nearest-neighbor queries.

	Builds a 2-table index (k=10 projections per table, bin width .1, 100 T1
	buckets) and returns MeasureDimensionality's estimate (0 when fewer than
	two usable samples remain).
	"""
	global gLSHTestData
	k = 10
	CreateRandomLSHTestData(numPoints, dim)
	ind = index(k, 2, .1, 100)
	for i in range(numPoints):
		ind.InsertIntoTable(i, FindLSHTestData(i))
	return ind.MeasureDimensionality(gLSHTestData, N)

# Use Dimension Doubling to measure the dimensionality of a data set (random
# uniform points, or a grid when the toggle below is flipped): count the
# neighbors found inside two radii, one twice the other, around a couple
# of query points.
def TestDimensionality2():
	"""Print the big-radius count, the small-radius count and the resulting
	dimensionality estimate (0 if the counts are unusable); return the index."""
	global gLSHTestData
	binWidth = .5
	useRandomData = True	# flip to False to use the regular grid instead
	if useRandomData:
		numPoints = 100000
		CreateRandomLSHTestData(numPoints, 3)
	else:
		CreateRegularLSHTestData(100)
		numPoints = gLSHTestData.shape[0]
	ind = index(4, 2, binWidth, 1000)	# k=4, l=2, N=1000
	for row in range(numPoints):
		ind.InsertIntoTable(row, gLSHTestData[row, :])
	outerRadius = binWidth / 8.0
	innerRadius = outerRadius / 2.0
	outerCount = 0.0
	innerCount = 0.0
	for queryID in random.sample(ind.GetAllIndices(), 2):
		queryPoint = FindLSHTestData(queryID)
		outerCount += ind.CountInsideRadius(queryPoint, FindLSHTestData, outerRadius)
		innerCount += ind.CountInsideRadius(queryPoint, FindLSHTestData, innerRadius)
	dim = 0
	if innerCount > 0 and outerCount > innerCount:
		# Count grows like radius**dim, so dim = log(count ratio) / log(radius ratio).
		dim = math.log(outerCount / innerCount) / math.log(outerRadius / innerRadius)
	print("%s %s %s" % (outerCount, innerCount, dim))
	return ind

# Compute the digamma function psi(x) = d/dx log(Gamma(x)) in pure Python,
# instead of running an external GSL-based "./digamma" program once per
# evaluation (which depended on the current directory and cost a process
# per call).  The subprocess import is kept so module-level names stay as
# they were.
import subprocess
def digamma(x):
	"""Return psi(x) as a float.

	Positive x: shift up with psi(x) = psi(x + 1) - 1/x until x >= 10, then
	apply the asymptotic series (Abramowitz & Stegun 6.3.18).  Negative
	non-integers use the reflection psi(x) = psi(1 - x) - pi / tan(pi * x).
	Raises ValueError at the poles x = 0, -1, -2, ...
	"""
	x = float(x)
	if x <= 0 and x == math.floor(x):
		raise ValueError("digamma is undefined at non-positive integers")
	if x < 0:
		return digamma(1.0 - x) - math.pi / math.tan(math.pi * x)
	result = 0.0
	while x < 10.0:
		result -= 1.0 / x
		x += 1.0
	inv = 1.0 / x
	inv2 = inv * inv
	# ln x - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6) + 1/(240x^8) - 1/(132x^10)
	series = inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 * (1.0 / 252
		- inv2 * (1.0 / 240 - inv2 * (1.0 / 132)))))
	return result + math.log(x) - 0.5 * inv - series
		   
                   
# Generate some random 3-dimensional data, put it into an index and then
# show the points retrieved for one query point.  This is all done as a
# function of the number of projections per table (k), the number of tables
# in the index (l), and the number of LSH buckets (N, the T1 size).  Write
# out "x y count" lines (first two coordinates only) so we can plot it
# (in Matlab).
def GraphicalTest(k, l, N, queryID=42):
	"""Write lshtestpoints.txt and return index.Find's result for point
	queryID: a list of (id, count) pairs."""
	global gLSHTestData
	numPoints = 1000
	CreateRandomLSHTestData(numPoints, 3)
	ind = index(k, l, .1, N)
	for i in range(numPoints):
		ind.InsertIntoTable(i, FindLSHTestData(i))
	data = gLSHTestData
	r = ind.Find(data[queryID, :])
	# Find returns (id, count) pairs, not a mapping, so index them by id
	# before looking up each point's count.
	counts = dict(r)
	with open('lshtestpoints.txt', 'w') as fp:
		for i in range(numPoints):
			fp.write("%g %g %d\n" % (data[i, 0], data[i, 1], counts.get(i, 0)))
	return r
		

# Run one small LSH test: index 1000 random 2-d points in 2 tables and look
# up point 42 with exact distances.
def ExactTest():
	"""Return index.FindExact's (id, distance**2, count) list for point 42."""
	numPoints = 1000
	CreateRandomLSHTestData(numPoints, 2)
	ind = index(10, 2, .1, 100)
	for pointID in range(numPoints):
		ind.InsertIntoTable(pointID, FindLSHTestData(pointID))
	query = FindLSHTestData(42)
	return ind.FindExact(query, FindLSHTestData)
	
# Format one line of TestRetrieval output: the squared distances from a list
# of (id, distance**2, count) matches, space separated and padded with -1
# so the line always holds `width` values.
def _FormatDistanceRow(matches, width):
	values = [str(match[1]) for match in matches]
	values.extend(['-1'] * (width - len(values)))
	return ' '.join(values) + '\n'

# Create a file with distances retrieved as a function of l and k.
# First line is the exact result, showing all points in the dB.
# Successive lines are results for an LSH index (9 indexes per (l, k)
# pair), each padded with -1 to numPoints values so all lines match.
def TestRetrieval():
	dims = 3
	numPoints = 100000
	CreateRandomLSHTestData(numPoints, dims)
	qp = FindLSHTestData(0)*0.0	# Query at the origin
	wroteExact = False
	with open('TestRetrieval.txt', 'w') as fp:
		for l in range(1, 5):
			for k in range(1, 6):
				for trial in range(1, 10):
					print("Building an index with l=" + str(l) + ", k=" + str(k))
					ind = index(k, l, .1, 100)	# Build new index
					for i in range(numPoints):
						ind.InsertIntoTable(i, FindLSHTestData(i))
					if not wroteExact:
						# FindAll ranks every point, independent of the
						# index, so the exact row is written only once.
						exact = ind.FindAll(qp, FindLSHTestData)
						fp.write(_FormatDistanceRow(exact, numPoints))
						wroteExact = True
					matches = ind.FindExact(qp, FindLSHTestData)
					fp.write(_FormatDistanceRow(matches, numPoints))
	
			
# Save an LSH index to a pickle file.
def SaveIndex(filename, ind):
	"""Pickle ind into filename and print the file size.

	Best effort: failures are reported (message plus traceback on stderr)
	rather than raised.
	"""
	try:
		# Binary mode: pickle data is bytes (required on Python 3, and avoids
		# newline translation on Windows).  'with' closes the file on error.
		with open(filename, 'wb') as fp:
			pickle.dump(ind, fp)
		statinfo = os.stat(filename)
		print("Wrote out %s bytes to %s" % (statinfo.st_size, filename))
	except Exception:
		print("Couldn't pickle index to file %s" % filename)
		traceback.print_exc(file=sys.stderr)

# Read an LSH index from a pickle file.
def LoadIndex(filename):
	"""Return the object unpickled from filename, or None (after printing
	the error and a traceback on stderr) if it cannot be read.

	Only load files you trust: unpickling can execute arbitrary code.
	"""
	try:
		# Binary mode to match how pickle data is written; 'with' closes the
		# file even when unpickling fails.
		with open(filename, 'rb') as fp:
			return pickle.load(fp)
	except Exception:
		print("Couldn't read pickle file %s" % filename)
		traceback.print_exc(file=sys.stderr)
		return None