changeset 740:92f034aa8f28 multiprobeLSH

added lsh.py
author mas01mc
date Mon, 04 Oct 2010 19:12:00 +0000
parents 1e6cc843563a
children 50a7fd50578f
files lsh.py
diffstat 1 files changed, 512 insertions(+), 0 deletions(-) [+]
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/lsh.py	Mon Oct 04 19:12:00 2010 +0000
@@ -0,0 +1,512 @@
+import random, numpy, pickle, os, operator, traceback, sys, math
+
+
+# Python implementation of Andoni's e2LSH.  This version is fast because it
+# uses Python dictionaries to implement the buckets.  The numerics are handled
+# by numpy routines, so this should be close to optimal in speed (although
+# there is no control over the hash tables' layout in memory).
+
+# Malcolm Slaney, Yahoo! Research, 2009
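+
+# A minimal usage sketch (hypothetical data and parameter values; the index
+# class is defined below).  Build an index, insert points by ID, then query:
+#	import numpy, lsh
+#	data = numpy.random.randn(100, 4)
+#	ind = lsh.index(k=4, l=2, w=.5, N=1000)
+#	for i in range(100):
+#		ind.InsertIntoTable(i, data[i,:])
+#	print ind.Find(data[42,:])	# Sorted list of (id, count) pairs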
+
+# Class that can be used to return random prime numbers.  We need this
+# for our hash functions.
+class primes():
+	# This list is from http://primes.utm.edu/lists/2small/0bit.html
+	def __init__(self):
+		p = {}
+		# These numbers, 2**i - j, are prime
+		p[8] = [5, 15, 17, 23, 27, 29, 33, 45, 57, 59]
+		p[9] = [3, 9, 13, 21, 25, 33, 45, 49, 51, 55]
+		p[10] = [3, 5, 11, 15, 27, 33, 41, 47, 53, 57]
+		p[11] = [9, 19, 21, 31, 37, 45, 49, 51, 55, 61]
+		p[12] = [3, 5, 17, 23, 39, 45, 47, 69, 75, 77]
+		p[13] = [1, 13, 21, 25, 31, 45, 69, 75, 81, 91]
+		p[14] = [3, 15, 21, 23, 35, 45, 51, 65, 83, 111]
+		p[15] = [19, 49, 51, 55, 61, 75, 81, 115, 121, 135]
+		p[16] = [15, 17, 39, 57, 87, 89, 99, 113, 117, 123]
+		p[17] = [1, 9, 13, 31, 49, 61, 63, 85, 91, 99]
+		p[18] = [5, 11, 17, 23, 33, 35, 41, 65, 75, 93]
+		p[19] = [1, 19, 27, 31, 45, 57, 67, 69, 85, 87]
+		p[20] = [3, 5, 17, 27, 59, 69, 129, 143, 153, 185]
+		p[21] = [9, 19, 21, 55, 61, 69, 105, 111, 121, 129]
+		p[22] = [3, 17, 27, 33, 57, 87, 105, 113, 117, 123]
+		p[23] = [15, 21, 27, 37, 61, 69, 135, 147, 157, 159]
+		p[24] = [3, 17, 33, 63, 75, 77, 89, 95, 117, 167]
+		p[25] = [39, 49, 61, 85, 91, 115, 141, 159, 165, 183]
+		p[26] = [5, 27, 45, 87, 101, 107, 111, 117, 125, 135]
+		p[27] = [39, 79, 111, 115, 135, 187, 199, 219, 231, 235]
+		p[28] = [57, 89, 95, 119, 125, 143, 165, 183, 213, 273]
+		p[29] = [3, 33, 43, 63, 73, 75, 93, 99, 121, 133]
+		p[30] = [35, 41, 83, 101, 105, 107, 135, 153, 161, 173]
+		p[31] = [1, 19, 61, 69, 85, 99, 105, 151, 159, 171]
+
+		primes = []
+		for k in p:
+			for m in p[k]:
+				primes.append(2**k - m)
+		self.primes = primes
+		self.len = len(primes)
+		
+	def sample(self):
+		return random.choice(self.primes)
+		
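+# For example (a sketch; the result varies with the random seed):
+#	p = primes()
+#	p.sample()	# One random prime just below a power of two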
+
+# This class implements k projections of the data into k integers (after
+# quantization) and then reduces that integer vector into T1 and T2 hashes.
+# Data can either be entered into the table or retrieved.
+class lsh:
+	def __init__(self, k, w, N, t2size = (2**16 - 15)):
+		self.k = k	# Number of projections
+		self.w = w	# Bin width
+		self.N = N	# Number of buckets
+		self.t2size = t2size
+		self.projections = None
+		self.buckets = {}
+		
+	# Create the random constants needed for the projections.
+	def CreateProjections(self, dim):
+		self.dim = dim
+		p = primes()
+		self.projections = numpy.random.randn(self.k, self.dim)
+		self.bias = numpy.random.rand(self.k)
+		# Draw one random prime per projection for each of the two hashes.
+		self.t1hash = [p.sample() for i in range(self.k)]
+		self.t2hash = [p.sample() for i in range(self.k)]
+		if 0:
+			print "Dim is", self.dim
+			print 'Projections:\n', self.projections
+			print 'T1 hash:\n', self.t1hash
+			print 'T2 hash:\n', self.t2hash
+			
+	# Actually create the t1 and t2 hashes for some data.
+	def CalculateHashes(self, data):
+		(t1, t2, bins) = self.CalculateHashes2(data)
+		return t1, t2
+
+	# Like CalculateHashes, but also return the raw quantized projections.
+	def CalculateHashes2(self, data):
+		if self.projections is None:
+			self.CreateProjections(len(data))
+		bins = numpy.zeros(self.k, 'i')
+		for i in range(0, self.k):
+			# Quantize each projection with a floor, as in e2LSH.
+			bins[i] = numpy.floor(numpy.sum(data * self.projections[i,:])/self.w \
+				+ self.bias[i])
+		t1 = numpy.sum(bins * self.t1hash) % self.N
+		t2 = numpy.sum(bins * self.t2hash) % self.t2size
+		return t1, t2, bins
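+	# (The raw bins are presumably exposed so a multi-probe variant can
+	# perturb individual bins and probe neighboring buckets.)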
+		
+	# Put some data into the hash bucket for this LSH projection
+	def InsertIntoTable(self, id, data):
+		(t1, t2) = self.CalculateHashes(data)
+		if t1 not in self.buckets:
+			self.buckets[t1] = {t2: [id]}
+		else:
+			if t2 not in self.buckets[t1]:
+				self.buckets[t1][t2] = [id]
+			else:
+				self.buckets[t1][t2].append(id)
+	
+	# Find some data in the hash bucket.  Return all the ids
+	# that we find for this T1-T2 pair.
+	def Find(self, data):
+		(t1, t2) = self.CalculateHashes(data)
+		if t1 not in self.buckets:
+			return []
+		row = self.buckets[t1]
+		if t2 not in row:
+			return []
+		return row[t2]
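+
+	# For example, a sketch of one insert/lookup round trip (made-up values):
+	#	t = lsh(k=5, w=.5, N=1000)
+	#	t.InsertIntoTable(7, numpy.array([.1, .2]))
+	#	t.Find(numpy.array([.1, .2]))	# -> [7]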
+		
+	# Create a dictionary mapping each ID to all the buckets it appears in.
+	def CreateDictionary(self, theDictionary, prefix):
+		for b in self.buckets:			# Over all T1 buckets
+			w = prefix + str(b)
+			for c in self.buckets[b]:	# Over all T2 hashes
+				for i in self.buckets[b][c]:	# Over IDs
+					if i not in theDictionary:
+						theDictionary[i] = [w]
+					else:
+						theDictionary[i].append(w)
+		return theDictionary
+
+	# Print some stats for these lsh buckets
+	def Stats(self):
+		maxCount = 0; sumCount = 0
+		numCount = 0; bucketLens = []
+		maxLoc = None
+		for b in self.buckets:
+			for c in self.buckets[b]:
+				l = len(self.buckets[b][c])
+				if l > maxCount:
+					maxCount = l
+					maxLoc = (b, c)
+					# print b,c,self.buckets[b][c]
+				sumCount += l
+				numCount += 1
+				bucketLens.append(l)
+		theValues = sorted(bucketLens)
+		med = theValues[(len(theValues)+1)/2-1]
+		print "Bucket Counts:"
+		print "\tTotal indexed points:", sumCount
+		print "\tT1 Buckets filled: %d/%d" % (len(self.buckets), self.N)
+		print "\tT2 Buckets used:", numCount
+		print "\tMaximum T2 chain length:", maxCount, "at", maxLoc
+		print "\tAverage T2 chain length:", float(sumCount)/numCount
+		print "\tMedian T2 chain length:", med
+	
+	# Get a list of all IDs that are contained in these hash buckets
+	def GetAllIndices(self):
+		theList = []
+		for b in self.buckets:
+			for c in self.buckets[b]:
+				theList += self.buckets[b][c]
+		return theList
+
+	# Put some data into the hash table, see how many collisions we get.
+	def Test(self, n):
+		self.buckets = {}
+		self.projections = None
+		d = numpy.array([.2,.3])
+		for i in range(0,n):
+			self.InsertIntoTable(i, d+i)
+		for i in range(0,n):
+			r = self.Find(d+i)
+			matches = sum(map(lambda x: x==i, r))
+			if matches == 0:
+				print "Couldn't find item", i
+			elif matches == 1:
+				pass
+			if len(r) > 1: 
+				print "Found big bin for", i,":", r
+	
+
+# Put together several LSH projections to form an index.  The only new
+# parameter is the number of groups of projections (one lsh instance per
+# group).
+class index:
+	def __init__(self, k, l, w, N):
+		self.k = k
+		self.l = l
+		self.w = w
+		self.N = N
+		self.projections = []
+		for i in range(0, l):	# Create one lsh table per group
+			self.projections.append(lsh(k, w, N))
+	# Insert some data into all LSH buckets
+	def InsertIntoTable(self, id, data):
+		for p in self.projections:
+			p.InsertIntoTable(id, data)
+	# Find some data in all the LSH buckets.
+	def Find(self, data):
+		items = []
+		for p in self.projections:
+			items += p.Find(data)	# Concatenate
+		# print "Find results are:", items
+		results = {}
+		for item in items:
+			results.setdefault(item, 0)
+			results[item] += 1
+		s = sorted(results.items(), key=operator.itemgetter(1), \
+			reverse=True)
+		return s
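+
+	# The result is sorted by vote count, e.g. (hypothetical query q and
+	# hypothetical output):
+	#	ind.Find(q)	# -> [(42, 2), (17, 1)]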
+		
+	# Return a list of results: (id, distance**2, count)
+	def FindExact(self, data, GetData):
+		s = self.Find(data)
+		# print "Intermediate results are:", s
+		d = map(lambda (id,count): (id,((GetData(id)-data)**2).sum(),count), s)
+		ds = sorted(d, key=operator.itemgetter(1))
+		return ds
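+
+	# GetData is any callable mapping an ID to its vector; for example,
+	# against the global test data defined at the bottom of this file:
+	#	ind.FindExact(q, FindLSHTestData)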
+	
+	# Do an exhaustive distance calculation, returning all points and their distances.
+	# Return a list of results: (id, distance**2, count)
+	def FindAll(self, query, GetData):
+		s = []
+		allIDs = self.GetAllIndices()
+		for id in allIDs:
+			dist = ((GetData(id)-query)**2).sum()
+			s.append((id, dist, 0))
+		# print "Intermediate results are:", s
+		# d = map(lambda (id,count): (id,((GetData(id)-data)**2).sum(),count), s)
+		ds = sorted(s, key=operator.itemgetter(1))
+		return ds
+
+	# Return the number of points that are closer than radius to the query
+	def CountInsideRadius(self, data, GetData, radius):
+		matches = self.FindExact(data, GetData)
+		# print "CountInsideRadius found",len(matches),"matches"
+		radius2 = radius**2
+		count = sum(map(lambda (id,distance,count): distance<radius2, matches))
+		return count
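+
+	# For example (a sketch): ind.CountInsideRadius(q, FindLSHTestData, .1)
+	# counts indexed points within distance .1 of q.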
+		
+	# Put some data into the hash tables.
+	def Test(self, n):
+		d = numpy.array([.2,.3])
+		for i in range(0,n): 
+			self.InsertIntoTable(i, d+i)
+		for i in range(0,n):
+			r = self.Find(d+i)
+			print r
+	
+	# Print the statistics of each hash table.
+	def Stats(self):
+		for i in range(0, len(self.projections)):
+			p = self.projections[i]
+			print "Buckets", i, 
+			p.Stats()
+
+	# Get all the IDs that are part of this index.  Just check one hash
+	# table, since each table indexes the same IDs.
+	def GetAllIndices(self):
+		if self.projections and len(self.projections) > 0:
+			p = self.projections[0]
+			return p.GetAllIndices()
+		return None
+			
+	# Return the buckets (t1 and t2 hashes) associated with a data point
+	def GetBuckets(self, data):
+		b = []
+		for p in self.projections:
+			h = p.CalculateHashes(data)
+			b += h
+		return b
+	
+	# Create a dictionary, keyed by ID, listing which buckets each ID uses.
+	def CreateDictionary(self):
+		theDictionary = {}
+		prefixes = 'abcdefghijklmnopqrstuvwxyz'
+		pi = 0
+		for p in self.projections:
+			prefix = 'W'
+			pc = pi
+			while pc > 0:	# Create a unique prefix for this table's buckets
+				prefix += prefixes[pc%len(prefixes)]
+				pc /= len(prefixes)
+			theDictionary = p.CreateDictionary(theDictionary, prefix)
+			pi += 1
+		return theDictionary
+	
+	# Use the expression in "Analysis of Minimum Distances in High-Dimensional
+	# Musical Spaces" to estimate the underlying dimensionality of the data.
+	# For a random selection of IDs, find each point's nearest neighbor and
+	# use those distances to estimate the dimensionality of the data.
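+	# Concretely, with S the sum of squared nearest-neighbor distances and
+	# L the sum of their logs over N sampled points, the code below solves
+	#	log(m) - digamma(m) = log(S/N) - L/N
+	# for m and reports d = 2*m (this is a reading of the code, not the paper).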
+	def MeasureDimensionality(self, allData, N):
+		allIDs = self.GetAllIndices()
+		sampleIDs = random.sample(allIDs, N)
+		L = 0.0; S = 0.0
+		for id in sampleIDs:
+			res = self.FindExact(allData[id,:], lambda i: allData[i,:])
+			# res[0] is the query point itself; res[1] is its nearest
+			# neighbor.  Skip zero distances so the log is defined.
+			if len(res) > 1 and res[1][1] > 0:
+				(nnid, dist, count) = res[1]
+				S += dist
+				L += math.log(dist)
+			else:
+				N -= 1
+		print "S="+str(S), "L="+str(L), "N="+str(N)
+		if N > 1:
+			x = math.log(S/N) - L/N		# Equation 17
+			d = 2*InvertFunction(x, lambda y: math.log(y) - digamma(y))
+			print d
+			return d
+		else:
+			return 0
+
+# Numerically invert a function.  Only works for monotonic functions; uses the
+# geometric midpoint to reduce the search range, looking for the input whose
+# function output equals the given value.  Test with:
+#	lsh.InvertFunction(2, math.sqrt)
+#	lsh.InvertFunction(2, lambda x: 1.0/x)
+# Needed to invert log(y)-digamma(y) in the MeasureDimensionality method.
+def InvertFunction(x, func):
+	lo = 0.0001; hi = 1000.0
+	if func(lo) < func(hi):
+		sign = 1
+	else:
+		sign = -1
+	print "Looking for Y() =", str(x), "d'=", sign
+	mid = math.sqrt(lo*hi)
+	while lo + 1e-7 < hi:
+		mid = math.sqrt(lo*hi)	# Geometric midpoint of the bracket
+		Y = func(mid)
+		# print lo, mid, Y, hi
+		if sign*Y > sign*x:
+			hi = mid
+		else:
+			lo = mid
+	return mid
+
+#####  A bunch of routines used to generate data we can use to test
+# this LSH implementation.
+
+gLSHTestData = []
+
+# Find a point in the array of data.  (Needed so FindExact can get the
+# data it needs.)
+def FindLSHTestData(id):
+	global gLSHTestData
+	if id < gLSHTestData.shape[0]:
+		return gLSHTestData[id,:]
+	return None
+
+# Fill the test array with uniform random data between -1 and 1
+def CreateRandomLSHTestData(numPoints, dim):
+	global gLSHTestData
+	gLSHTestData = (numpy.random.rand(numPoints, dim) - .5)*2.0
+
+# Fill the test array with a regular grid of points between -1 and 1
+def CreateRegularLSHTestData(numDivs):
+	global gLSHTestData
+	gLSHTestData = numpy.zeros(((2*numDivs+1)**2, 2))
+	i = 0
+	for x in range(-numDivs, numDivs+1):
+		for y in range(-numDivs, numDivs+1):
+			gLSHTestData[i,0] = x/float(numDivs)
+			gLSHTestData[i,1] = y/float(numDivs)
+			i += 1
+
+# Use Nearest Neighbor properties to calculate dimensionality.
+def TestDimensionality(N):
+	numPoints = 100000
+	k = 10
+	CreateRandomLSHTestData(numPoints, 3)	
+	ind = index(k, 2, .1, 100)
+	for i in range(0,numPoints):
+		ind.InsertIntoTable(i, FindLSHTestData(i))
+	ind.MeasureDimensionality(gLSHTestData, N)
+
+# Use Dimension Doubling to measure the dimensionality of a random
+# set of data.  Generate some data (either random Gaussian or a grid)
+# Then count the number of points that fall within the given radius of this query.
+def TestDimensionality2():
+	global gLSHTestData
+	binWidth = .5
+	if True:
+		numPoints = 100000
+		CreateRandomLSHTestData(numPoints, 3)	
+	else:
+		CreateRegularLSHTestData(100)
+		numPoints = gLSHTestData.shape[0]
+	k = 4; l = 2; N = 1000
+	ind = index(k, l, binWidth, N)
+	for i in range(0,numPoints):
+		ind.InsertIntoTable(i, gLSHTestData[i,:])
+	rBig = binWidth/8.0
+	rSmall = rBig/2.0
+	cBig = 0.0; cSmall = 0.0
+	for id in random.sample(ind.GetAllIndices(), 2):
+		qp = FindLSHTestData(id)
+		cBig += ind.CountInsideRadius(qp, FindLSHTestData, rBig)
+		cSmall += ind.CountInsideRadius(qp, FindLSHTestData, rSmall)
+	if cBig > cSmall and cSmall > 0:
+		dim = math.log(cBig/cSmall)/math.log(rBig/rSmall)
+	else:
+		dim = 0
+	print cBig, cSmall, dim
+	return ind
+
+# Call an external process to compute the digamma function (from the GNU Scientific Library)
+import subprocess
+def digamma(x):
+	y = subprocess.Popen( ["./digamma", str(x)], stdout=subprocess.PIPE).communicate()[0]
+	return float(y.strip())
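+
+# If SciPy is available, the same value can be computed in-process; a sketch
+# (scipy.special.psi is the digamma function):
+#	from scipy.special import psi
+#	def digamma(x):
+#		return float(psi(x))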
+
+# Generate some 2-dimensional data, put it into an index and then show the
+# points retrieved.  This is all done as a function of the number of
+# projections per hash (k), the number of hash tables per index (l), and the
+# number of LSH buckets (N, the T1 size).  Write out the data so we can plot
+# it (in Matlab).
+def GraphicalTest(k, l, N):
+	global gLSHTestData
+	numPoints = 1000
+	CreateRandomLSHTestData(numPoints, 2)
+	ind = index(k, l, .1, N)
+	for i in range(0, numPoints):
+		ind.InsertIntoTable(i, FindLSHTestData(i))
+	i = 42
+	r = dict(ind.Find(gLSHTestData[i,:]))	# Map each ID to its match count
+	fp = open('lshtestpoints.txt', 'w')
+	for i in range(0, numPoints):
+		if i in r:
+			c = r[i]
+		else:
+			c = 0
+		fp.write("%g %g %d\n" % (gLSHTestData[i,0], gLSHTestData[i,1], c))
+	fp.close()
+	return r
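+
+# The output can be plotted in Matlab (as noted above) or in Python with
+# matplotlib; a sketch, assuming matplotlib is installed:
+#	import pylab
+#	d = numpy.loadtxt('lshtestpoints.txt')
+#	pylab.scatter(d[:,0], d[:,1], c=d[:,2])
+#	pylab.show()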
+		
+
+# Run one LSH test.  Look for point 42 in the data.
+def ExactTest():
+	global gLSHTestData
+	numPoints = 1000
+	CreateRandomLSHTestData(numPoints, 2)
+	ind = index(10, 2, .1, 100)
+	for i in range(0,numPoints):
+		ind.InsertIntoTable(i, FindLSHTestData(i))
+	data = FindLSHTestData(42)
+	res = ind.FindExact(data, FindLSHTestData)
+	return res
+	
+# Create a file with distances retrieved as a function of k and l.
+# The first line is the exact result, showing all points in the database.
+# Successive lines are results for an LSH index.
+def TestRetrieval():
+	dims = 3
+	numPoints = 100000
+	CreateRandomLSHTestData(numPoints, dims)
+	qp = FindLSHTestData(0)*0.0		# Query at the origin
+	fp = open('TestRetrieval.txt', 'w')
+	for l in range(1, 5):
+		for k in range(1, 6):
+			for trial in range(1, 10):
+				print "Building an index with l="+str(l)+", k="+str(k)
+				ind = index(k, l, .1, 100)	# Build new index
+				for i in range(0, numPoints):
+					ind.InsertIntoTable(i, FindLSHTestData(i))
+				if k == 1 and l == 1:
+					matches = ind.FindAll(qp, FindLSHTestData)
+					fp.write(' '.join(map(lambda (i,d,c): str(d), matches)))
+					fp.write('\n')
+				matches = ind.FindExact(qp, FindLSHTestData)
+				fp.write(' '.join(map(lambda (i,d,c): str(d), matches)))
+				# Pad the rest of the line with -1s so every row has the
+				# same number of columns.
+				fp.write(' ')
+				fp.write(' '.join(map(str, (-numpy.ones(numPoints - len(matches))).tolist())))
+				fp.write('\n')
+	fp.close()
+	
+			
+# Save an LSH index to a pickle file.
+def SaveIndex(filename, ind):
+	try:
+		fp = open(filename, 'wb')
+		pickle.dump(ind, fp)
+		fp.close()
+		statinfo = os.stat(filename)
+		if statinfo:
+			print "Wrote out", statinfo.st_size, "bytes to", \
+				filename
+	except:
+		print "Couldn't pickle index to file", filename
+		traceback.print_exc(file=sys.stderr)
+
+# Read an LSH index from a pickle file.
+def LoadIndex(filename):
+	try:
+		fp = open(filename, 'rb')
+		ind = pickle.load(fp)
+		fp.close()
+		return ind
+	except:
+		print "Couldn't read pickle file", filename
+		traceback.print_exc(file=sys.stderr)
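+
+# For example, a save/load round trip (a sketch; 'lshindex.pkl' is an
+# arbitrary file name):
+#	SaveIndex('lshindex.pkl', ind)
+#	ind = LoadIndex('lshindex.pkl')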