# HG changeset patch
# User mas01mc
# Date 1286219520 0
# Node ID 92f034aa8f283cfe6e53f1f7932d0cdd263d2353
# Parent  1e6cc843563a598eba5ed990ecc2ebce9c919d39
added lsh.py

diff -r 1e6cc843563a -r 92f034aa8f28 lsh.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/lsh.py	Mon Oct 04 19:12:00 2010 +0000
@@ -0,0 +1,512 @@
+import random, numpy, pickle, os, operator, traceback, sys, math
+
+
+# Python implementation of Andoni's e2LSH. This version is fast because it
+# uses Python hashes to implement the buckets. The numerics are handled
+# by the numpy routines, so this should be close to optimal in speed
+# (although there is no control over the hash tables' layout in memory.)
+
+# Malcolm Slaney, Yahoo! Research, 2009
+
+# Class that can be used to return random prime numbers. We need these
+# for our hash functions.
+class primes():
+    # This list is from http://primes.utm.edu/lists/2small/0bit.html
+    def __init__(self):
+        p = {}
+        # Each number 2**k - m below is prime.
+        p[8] = [5, 15, 17, 23, 27, 29, 33, 45, 57, 59]
+        p[9] = [3, 9, 13, 21, 25, 33, 45, 49, 51, 55]
+        p[10] = [3, 5, 11, 15, 27, 33, 41, 47, 53, 57]
+        p[11] = [9, 19, 21, 31, 37, 45, 49, 51, 55, 61]
+        p[12] = [3, 5, 17, 23, 39, 45, 47, 69, 75, 77]
+        p[13] = [1, 13, 21, 25, 31, 45, 69, 75, 81, 91]
+        p[14] = [3, 15, 21, 23, 35, 45, 51, 65, 83, 111]
+        p[15] = [19, 49, 51, 55, 61, 75, 81, 115, 121, 135]
+        p[16] = [15, 17, 39, 57, 87, 89, 99, 113, 117, 123]
+        p[17] = [1, 9, 13, 31, 49, 61, 63, 85, 91, 99]
+        p[18] = [5, 11, 17, 23, 33, 35, 41, 65, 75, 93]
+        p[19] = [1, 19, 27, 31, 45, 57, 67, 69, 85, 87]
+        p[20] = [3, 5, 17, 27, 59, 69, 129, 143, 153, 185]
+        p[21] = [9, 19, 21, 55, 61, 69, 105, 111, 121, 129]
+        p[22] = [3, 17, 27, 33, 57, 87, 105, 113, 117, 123]
+        p[23] = [15, 21, 27, 37, 61, 69, 135, 147, 157, 159]
+        p[24] = [3, 17, 33, 63, 75, 77, 89, 95, 117, 167]
+        p[25] = [39, 49, 61, 85, 91, 115, 141, 159, 165, 183]
+        p[26] = [5, 27, 45, 87, 101, 107, 111, 117, 125, 135]
+        p[27] = [39, 79, 111, 115, 135, 187, 199, 219, 231, 235]
+        p[28] = [57, 89, 95, 119, 125, 143, 165, 183, 213, 273]
+        p[29] = [3, 33, 43, 63, 73, 75, 93, 99, 121, 133]
+        p[30] = [35, 41, 83, 101, 105, 107, 135, 153, 161, 173]
+        p[31] = [1, 19, 61, 69, 85, 99, 105, 151, 159, 171]
+
+        primes = []
+        for k in p:
+            for m in p[k]:
+                primes.append(2**k - m)
+        self.primes = primes
+        self.len = len(primes)
+
+    def sample(self):
+        return self.primes[int(random.random()*self.len)]
+
+
+# This class implements k projections of the data into k integers
+# (after quantization) and then reduces that integer vector
+# into a T1 and a T2 hash. Data can either be entered into a
+# table, or retrieved.
+class lsh:
+    def __init__(self, k, w, N, t2size = (2**16 - 15)):
+        self.k = k              # Number of projections
+        self.w = w              # Bin width
+        self.N = N              # Number of buckets
+        self.t2size = t2size
+        self.projections = None
+        self.buckets = {}
+
+    # Create the random constants needed for the projections.
+    def CreateProjections(self, dim):
+        self.dim = dim
+        p = primes()
+        self.projections = numpy.random.randn(self.k, self.dim)
+        self.bias = numpy.random.rand(self.k)
+        # Use the bias array just for its size (ignore its contents).
+        self.t1hash = map(lambda x: p.sample(), self.bias)
+        self.t2hash = map(lambda x: p.sample(), self.bias)
+        if 0:
+            print "Dim is", self.dim
+            print 'Projections:\n', self.projections
+            print 'T1 hash:\n', self.t1hash
+            print 'T2 hash:\n', self.t2hash
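+
+    # How the hash methods below work: each input vector is reduced to k
+    # quantized projections, bins[i] = int(dot(projections[i,:], data)/w +
+    # bias[i]), following Andoni's e2LSH. The integer bin vector is then
+    # folded into two hashes: t1 (mod N) selects the bucket and t2
+    # (mod t2size) is a fingerprint that separates chains within a bucket.
+    # A worked example (illustrative numbers, not from the original code):
+    # with k=1, w=0.5, projection [1, 0] and bias 0.25, the point [0.6, 0.9]
+    # quantizes to int(0.6/0.5 + 0.25) = int(1.45) = 1.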
+
+    # Actually create the t1 and t2 hashes for some data.
+    def CalculateHashes(self, data):
+        if self.projections is None:
+            self.CreateProjections(len(data))
+        bins = numpy.zeros(self.k, 'i')
+        for i in range(0, self.k):
+            bins[i] = numpy.sum(data * self.projections[i,:])/self.w \
+                + self.bias[i]
+        t1 = numpy.sum(bins * self.t1hash) % self.N
+        t2 = numpy.sum(bins * self.t2hash) % self.t2size
+        return t1, t2
+
+    # Like CalculateHashes, but also return the raw bin vector.
+    def CalculateHashes2(self, data):
+        if self.projections is None:
+            self.CreateProjections(len(data))
+        bins = numpy.zeros(self.k, 'i')
+        for i in range(0, self.k):
+            bins[i] = numpy.sum(data * self.projections[i,:])/self.w \
+                + self.bias[i]
+        t1 = numpy.sum(bins * self.t1hash) % self.N
+        t2 = numpy.sum(bins * self.t2hash) % self.t2size
+        return t1, t2, bins
+
+    # Put some data into the hash bucket for this LSH projection.
+    def InsertIntoTable(self, id, data):
+        (t1, t2) = self.CalculateHashes(data)
+        if t1 not in self.buckets:
+            self.buckets[t1] = {t2: [id]}
+        else:
+            if t2 not in self.buckets[t1]:
+                self.buckets[t1][t2] = [id]
+            else:
+                self.buckets[t1][t2].append(id)
+
+    # Find some data in the hash bucket. Return all the ids
+    # that we find for this T1-T2 pair.
+    def Find(self, data):
+        (t1, t2) = self.CalculateHashes(data)
+        if t1 not in self.buckets:
+            return []
+        row = self.buckets[t1]
+        if t2 not in row:
+            return []
+        return row[t2]
+
+    # Create a dictionary showing all the buckets an ID appears in.
+    def CreateDictionary(self, theDictionary, prefix):
+        for b in self.buckets:                  # Over all buckets
+            w = prefix + str(b)
+            for c in self.buckets[b]:           # Over all T2 hashes
+                for i in self.buckets[b][c]:    # Over ids
+                    if not i in theDictionary:
+                        theDictionary[i] = [w]
+                    else:
+                        theDictionary[i] += [w]
+        return theDictionary
+
+    # Print some stats for these LSH buckets.
+    def Stats(self):
+        maxCount = 0; sumCount = 0
+        numCount = 0; bucketLens = []
+        for b in self.buckets:
+            for c in self.buckets[b]:
+                l = len(self.buckets[b][c])
+                if l > maxCount:
+                    maxCount = l
+                    maxLoc = (b, c)
+                    # print b, c, self.buckets[b][c]
+                sumCount += l
+                numCount += 1
+                bucketLens.append(l)
+        theValues = sorted(bucketLens)
+        med = theValues[(len(theValues)+1)/2-1]
+        print "Bucket Counts:"
+        print "\tTotal indexed points:", sumCount
+        print "\tT1 Buckets filled: %d/%d" % (len(self.buckets), self.N)
+        print "\tT2 Buckets used: %d/%d" % (numCount, self.N)
+        print "\tMaximum T2 chain length:", maxCount, "at", maxLoc
+        print "\tAverage T2 chain length:", float(sumCount)/numCount
+        print "\tMedian T2 chain length:", med
+
+    # Get a list of all IDs that are contained in these hash buckets.
+    def GetAllIndices(self):
+        theList = []
+        for b in self.buckets:
+            for c in self.buckets[b]:
+                theList += self.buckets[b][c]
+        return theList
+
+    # Put some data into the hash table, and see how many collisions we get.
+    def Test(self, n):
+        self.buckets = {}
+        self.projections = None
+        d = numpy.array([.2, .3])
+        for i in range(0, n):
+            self.InsertIntoTable(i, d + i)
+        for i in range(0, n):
+            r = self.Find(d + i)
+            matches = sum(map(lambda x: x == i, r))
+            if matches == 0:
+                print "Couldn't find item", i
+            elif matches == 1:
+                pass
+            if len(r) > 1:
+                print "Found big bin for", i, ":", r
+
+
+# Put together several LSH projections to form an index. The only
+# new parameter is the number of groups of projections (one LSH class
+# object per group.)
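+#
+# As a rough guide from the standard LSH analysis (illustrative numbers,
+# not a claim from this code): if a near neighbor collides with the query
+# on one projection with probability p, a table of k projections matches it
+# with probability p**k, and l independent tables together miss it with
+# probability (1 - p**k)**l. With p = 0.9, k = 10 and l = 4, a single table
+# matches only about 35% of the time, but the four tables together miss
+# only about 18% of the time.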
+class index:
+    def __init__(self, k, l, w, N):
+        self.k = k
+        self.l = l
+        self.w = w
+        self.N = N
+        self.projections = []
+        for i in range(0, l):           # Create all LSH groups
+            self.projections.append(lsh(k, w, N))
+
+    # Insert some data into all LSH buckets.
+    def InsertIntoTable(self, id, data):
+        for p in self.projections:
+            p.InsertIntoTable(id, data)
+
+    # Find some data in all the LSH buckets. Return a list of
+    # (id, count) tuples, sorted by the number of tables that matched.
+    def Find(self, data):
+        items = []
+        for p in self.projections:
+            items += p.Find(data)       # Concatenate
+        # print "Find results are:", items
+        results = {}
+        for item in items:
+            results.setdefault(item, 0)
+            results[item] += 1
+        s = sorted(results.items(), key=operator.itemgetter(1), \
+            reverse=True)
+        return s
+
+    # Return a list of results: (id, distance**2, count)
+    def FindExact(self, data, GetData):
+        s = self.Find(data)
+        # print "Intermediate results are:", s
+        d = map(lambda (id, count): \
+            (id, ((GetData(id)-data)**2).sum(), count), s)
+        ds = sorted(d, key=operator.itemgetter(1))
+        return ds
+
+    # Do an exhaustive distance calculation, looking for all points and
+    # their distances. Return a list of results: (id, distance**2, count)
+    def FindAll(self, query, GetData):
+        s = []
+        allIDs = self.GetAllIndices()
+        for id in allIDs:
+            dist = ((GetData(id)-query)**2).sum()
+            s.append((id, dist, 0))
+        # print "Intermediate results are:", s
+        ds = sorted(s, key=operator.itemgetter(1))
+        return ds
+
+    # Return the number of points that are closer than radius to the query.
+    def CountInsideRadius(self, data, GetData, radius):
+        matches = self.FindExact(data, GetData)
+        # print "CountInsideRadius found", len(matches), "matches"
+        radius2 = radius**2
+        count = sum(map(lambda (id, distance, count): distance < radius2, \
+            matches))
+        return count
+
+    # Return a list of all the IDs stored in this index.
+    def GetAllIndices(self):
+        if len(self.projections) > 0:
+            p = self.projections[0]
+            return p.GetAllIndices()
+        return None
+
+    # Return the buckets (t1 and t2 hashes) associated with a data point.
+    def GetBuckets(self, data):
+        b = []
+        for p in self.projections:
+            h = p.CalculateHashes(data)
+            b += h
+        return b
+
+    # Create a dictionary, keyed by ID, listing which buckets are used
+    # for each ID.
+    def CreateDictionary(self):
+        theDictionary = {}
+        prefixes = 'abcdefghijklmnopqrstuvwxyz'
+        pi = 0
+        for p in self.projections:
+            prefix = 'W'
+            pc = pi
+            while pc > 0:               # Create a unique ID for this bucket
+                prefix += prefixes[pc % len(prefixes)]
+                pc /= len(prefixes)
+            theDictionary = p.CreateDictionary(theDictionary, \
+                prefix)
+            pi += 1
+        return theDictionary
+
+    # Use the expression in "Analysis of Minimum Distances in High-Dimensional
+    # Musical Spaces" to calculate the underlying dimensionality of the data.
+    # For a random selection of ids, find the nearest neighbors and use them
+    # to calculate the dimensionality of the data.
+    def MeasureDimensionality(self, allData, N):
+        allIDs = self.GetAllIndices()
+        sampleIDs = random.sample(allIDs, N)
+        L = 0.0; S = 0.0
+        for id in sampleIDs:
+            res = self.FindExact(allData[id,:], lambda i: allData[i,:])
+            if len(res) > 1:
+                (nnid, dist, count) = res[1]
+                S += dist
+                L += math.log(dist)
+            else:
+                N -= 1
+        print "S=" + str(S), "L=" + str(L), "N=" + str(N)
+        if N > 1:
+            x = math.log(S/N) - L/N             # Equation 17
+            d = 2*InvertFunction(x, lambda y: math.log(y) - digamma(y))
+            print d
+            return d
+        else:
+            return 0
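+
+# Note on MeasureDimensionality above (an explanatory gloss on the code):
+# with x = log(S/N) - L/N, the method solves log(y) - digamma(y) = x for y
+# and reports d = 2*y. Since log(y) - digamma(y) is approximately 1/(2*y)
+# for large y, the estimate behaves like d = 1/x when x is small, so a
+# small gap between the log-mean and the mean-log of the squared
+# nearest-neighbor distances indicates high dimensionality.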
+
+# Invert a function numerically. Only works for monotonic functions; uses
+# the geometric midpoint to reduce the search range, looking for the input
+# whose function output equals the given value.
+# Test with:
+#   lsh.InvertFunction(2, math.sqrt)
+#   lsh.InvertFunction(2, lambda x: 1.0/x)
+# Needed for inverting the digamma expression in the MeasureDimensionality
+# method.
+def InvertFunction(x, func):
+    min = 0.0001; max = 1000
+    if func(min) < func(max):
+        sign = 1
+    else:
+        sign = -1
+    print "Looking for Y() =", str(x), "d'=", sign
+    while min + 1e-7 < max:
+        mid = math.sqrt(min*max)
+        Y = func(mid)
+        # print min, mid, Y, max
+        if sign*Y > sign*x:
+            max = mid
+        else:
+            min = mid
+    return mid
+
+##### A bunch of routines used to generate data we can use to test
+# this LSH implementation.
+
+global gLSHTestData
+gLSHTestData = []
+
+# Find a point in the array of data. (Needed so FindExact can get the
+# data it needs.)
+def FindLSHTestData(id):
+    global gLSHTestData
+    if id < gLSHTestData.shape[0]:
+        return gLSHTestData[id,:]
+    return None
+
+# Fill the test array with uniform random data between -1 and 1.
+def CreateRandomLSHTestData(numPoints, dim):
+    global gLSHTestData
+    gLSHTestData = (numpy.random.rand(numPoints, dim) - .5)*2.0
+
+# Fill the test array with a regular grid of points between -1 and 1.
+def CreateRegularLSHTestData(numDivs):
+    global gLSHTestData
+    gLSHTestData = numpy.zeros(((2*numDivs+1)**2, 2))
+    i = 0
+    for x in range(-numDivs, numDivs+1):
+        for y in range(-numDivs, numDivs+1):
+            gLSHTestData[i,0] = x/float(numDivs)
+            gLSHTestData[i,1] = y/float(numDivs)
+            i += 1
+
+# Use nearest-neighbor properties to calculate the dimensionality.
+def TestDimensionality(N):
+    numPoints = 100000
+    k = 10
+    CreateRandomLSHTestData(numPoints, 3)
+    ind = index(k, 2, .1, 100)
+    for i in range(0, numPoints):
+        ind.InsertIntoTable(i, FindLSHTestData(i))
+    ind.MeasureDimensionality(gLSHTestData, N)
+
+# Use dimension doubling to measure the dimensionality of a random
+# set of data. Generate some data (either uniform random or a grid),
+# then count the number of points that fall within the given radius of
+# each query point.
+def TestDimensionality2():
+    global gLSHTestData
+    binWidth = .5
+    if True:
+        numPoints = 100000
+        CreateRandomLSHTestData(numPoints, 3)
+    else:
+        CreateRegularLSHTestData(100)
+        numPoints = gLSHTestData.shape[0]
+    k = 4; l = 2; N = 1000
+    ind = index(k, l, binWidth, N)
+    for i in range(0, numPoints):
+        ind.InsertIntoTable(i, gLSHTestData[i,:])
+    rBig = binWidth/8.0
+    rSmall = rBig/2.0
+    cBig = 0.0; cSmall = 0.0
+    for id in random.sample(ind.GetAllIndices(), 2):
+        qp = FindLSHTestData(id)
+        cBig += ind.CountInsideRadius(qp, FindLSHTestData, rBig)
+        cSmall += ind.CountInsideRadius(qp, FindLSHTestData, rSmall)
+    if cBig > cSmall and cSmall > 0:
+        dim = math.log(cBig/cSmall)/math.log(rBig/rSmall)
+    else:
+        dim = 0
+    print cBig, cSmall, dim
+    return ind
+
+# Call an external process to compute the digamma function (from the GNU
+# Scientific Library).
+import subprocess
+def digamma(x):
+    y = subprocess.Popen(["./digamma", str(x)], \
+        stdout=subprocess.PIPE).communicate()[0]
+    return float(y.strip())
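+
+# A possible pure-Python fallback for when the external ./digamma binary is
+# not built. This is only a sketch, added for convenience rather than part
+# of the GSL-based path above; it uses the standard recurrence
+# psi(x+1) = psi(x) + 1/x and the asymptotic series for large x.
+def digamma_approx(x):
+    result = 0.0
+    while x < 6.0:              # Shift x up into the asymptotic regime
+        result -= 1.0/x
+        x += 1.0
+    inv = 1.0/x
+    inv2 = inv*inv
+    # psi(x) ~ log(x) - 1/(2x) - 1/(12x**2) + 1/(120x**4) - 1/(252x**6)
+    result += math.log(x) - 0.5*inv \
+        - inv2*(1.0/12.0 - inv2*(1.0/120.0 - inv2/252.0))
+    return result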
+
+
+# Generate some 2-dimensional data, put it into an index, and then
+# show the points retrieved. This is all done as a function of the number
+# of projections per bucket, the number of buckets to use for each index,
+# and the number of LSH buckets (the T1 size). Write out the data so we
+# can plot it (in Matlab).
+def GraphicalTest(k, l, N):
+    global gLSHTestData
+    numPoints = 1000
+    CreateRandomLSHTestData(numPoints, 2)
+    ind = index(k, l, .1, N)
+    for i in range(0, numPoints):
+        ind.InsertIntoTable(i, FindLSHTestData(i))
+    i = 42
+    # Map each retrieved id to its vote count.
+    r = dict(ind.Find(gLSHTestData[i,:]))
+    fp = open('lshtestpoints.txt', 'w')
+    for i in range(0, numPoints):
+        if i in r:
+            c = r[i]
+        else:
+            c = 0
+        fp.write("%g %g %d\n" % (gLSHTestData[i,0], gLSHTestData[i,1], c))
+    fp.close()
+    return r
+
+
+# Run one LSH test. Look for point 42 in the data.
+def ExactTest():
+    global gLSHTestData
+    numPoints = 1000
+    CreateRandomLSHTestData(numPoints, 2)
+    ind = index(10, 2, .1, 100)
+    for i in range(0, numPoints):
+        ind.InsertIntoTable(i, FindLSHTestData(i))
+    data = FindLSHTestData(42)
+    res = ind.FindExact(data, FindLSHTestData)
+    return res
+
+# Create a file with the distances retrieved as a function of k and l.
+# The first line is the exact result, showing all points in the DB.
+# Successive lines are the results for an LSH index.
+def TestRetrieval():
+    dims = 3
+    numPoints = 100000
+    CreateRandomLSHTestData(numPoints, dims)
+    qp = FindLSHTestData(0)*0.0
+    fp = open('TestRetrieval.txt', 'w')
+    for l in range(1, 5):
+        for k in range(1, 6):
+            for iter in range(1, 10):
+                print "Building an index with l=" + str(l) + ", k=" + str(k)
+                ind = index(k, l, .1, 100)      # Build a new index
+                for i in range(0, numPoints):
+                    ind.InsertIntoTable(i, FindLSHTestData(i))
+                if k == 1 and l == 1:
+                    matches = ind.FindAll(qp, FindLSHTestData)
+                    fp.write(' '.join(map(lambda (i, d, c): str(d), matches)))
+                    fp.write('\n')
+                matches = ind.FindExact(qp, FindLSHTestData)
+                fp.write(' '.join(map(lambda (i, d, c): str(d), matches)))
+                # Fill the rest of the line with -1s so every row has
+                # numPoints entries.
+                fp.write(' ')
+                fp.write(' '.join(map(str, \
+                    (-numpy.ones(numPoints - len(matches))).tolist())))
+                fp.write('\n')
+    fp.close()
+
+
+# Save an LSH index to a pickle file.
+def SaveIndex(filename, ind):
+    try:
+        fp = open(filename, 'w')
+        pickle.dump(ind, fp)
+        fp.close()
+        statinfo = os.stat(filename)
+        if statinfo:
+            print "Wrote out", statinfo.st_size, "bytes to", \
+                filename
+    except:
+        print "Couldn't pickle index to file", filename
+        traceback.print_exc(file=sys.stderr)
+
+# Read an LSH index from a pickle file.
+def LoadIndex(filename):
+    try:
+        fp = open(filename, 'r')
+        ind = pickle.load(fp)
+        fp.close()
+        return ind
+    except:
+        print "Couldn't read pickle file", filename
+        traceback.print_exc(file=sys.stderr)
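+
+
+# A minimal usage sketch (the parameter values here are illustrative, not
+# tuned): build an index over random test data, print the bucket statistics
+# for one table, then query it both approximately and exactly.
+if __name__ == '__main__':
+    numPoints = 1000
+    CreateRandomLSHTestData(numPoints, 4)   # 1000 random 4-d points
+    ind = index(10, 4, .1, 100)             # k=10 projections, l=4 tables
+    for i in range(0, numPoints):
+        ind.InsertIntoTable(i, FindLSHTestData(i))
+    ind.projections[0].Stats()              # Bucket statistics for one table
+    # Approximate lookup: (id, vote count) pairs, most votes first.
+    print ind.Find(FindLSHTestData(42))[:5]
+    # Exact re-ranking of the LSH candidates: (id, distance**2, count).
+    print ind.FindExact(FindLSHTestData(42), FindLSHTestData)[:5]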