lsh.py @ 740:92f034aa8f28 (branch: multiprobeLSH)

author:   mas01mc
date:     Mon, 04 Oct 2010 19:12:00 +0000
summary:  added lsh.py
children: 50a7fd50578f
import random, numpy, pickle, os, operator, traceback, sys, math

# Python implementation of Andoni's e2LSH.  This version is fast because it
# uses Python hashes to implement the buckets.  The numerics are handled
# by the numpy routines, so this should be close to optimal in speed
# (although there is no control over the hash tables' layout in memory).

# Malcolm Slaney, Yahoo! Research, 2009

# Class that can be used to return random prime numbers.  We need this
# for our hash functions.
class primes():
    # This list is from http://primes.utm.edu/lists/2small/0bit.html
    def __init__(self):
        p = {}
        # These numbers, 2**i - j, are primes
        p[8] = [5, 15, 17, 23, 27, 29, 33, 45, 57, 59]
        p[9] = [3, 9, 13, 21, 25, 33, 45, 49, 51, 55]
        p[10] = [3, 5, 11, 15, 27, 33, 41, 47, 53, 57]
        p[11] = [9, 19, 21, 31, 37, 45, 49, 51, 55, 61]
        p[12] = [3, 5, 17, 23, 39, 45, 47, 69, 75, 77]
        p[13] = [1, 13, 21, 25, 31, 45, 69, 75, 81, 91]
        p[14] = [3, 15, 21, 23, 35, 45, 51, 65, 83, 111]
        p[15] = [19, 49, 51, 55, 61, 75, 81, 115, 121, 135]
        p[16] = [15, 17, 39, 57, 87, 89, 99, 113, 117, 123]
        p[17] = [1, 9, 13, 31, 49, 61, 63, 85, 91, 99]
        p[18] = [5, 11, 17, 23, 33, 35, 41, 65, 75, 93]
        p[19] = [1, 19, 27, 31, 45, 57, 67, 69, 85, 87]
        p[20] = [3, 5, 17, 27, 59, 69, 129, 143, 153, 185]
        p[21] = [9, 19, 21, 55, 61, 69, 105, 111, 121, 129]
        p[22] = [3, 17, 27, 33, 57, 87, 105, 113, 117, 123]
        p[23] = [15, 21, 27, 37, 61, 69, 135, 147, 157, 159]
        p[24] = [3, 17, 33, 63, 75, 77, 89, 95, 117, 167]
        p[25] = [39, 49, 61, 85, 91, 115, 141, 159, 165, 183]
        p[26] = [5, 27, 45, 87, 101, 107, 111, 117, 125, 135]
        p[27] = [39, 79, 111, 115, 135, 187, 199, 219, 231, 235]
        p[28] = [57, 89, 95, 119, 125, 143, 165, 183, 213, 273]
        p[29] = [3, 33, 43, 63, 73, 75, 93, 99, 121, 133]
        p[30] = [35, 41, 83, 101, 105, 107, 135, 153, 161, 173]
        p[31] = [1, 19, 61, 69, 85, 99, 105, 151, 159, 171]
        primes = []
        for k in p:
            for m in p[k]:
                primes.append(2**k - m)
        self.primes = primes
        self.len = len(primes)

    def sample(self):
        return self.primes[int(random.random()*self.len)]
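
# A minimal illustration of the quantized-projection step used by the
# lsh class below: each of k random projections of a vector v is
# quantized into an integer bin,
#     bin_i = floor(a_i . v / w + b_i),
# and those k integers are later mixed into the T1 (bucket) and T2
# (chain) hashes.  DemoQuantizedProjection is a hypothetical helper
# added for illustration only; its defaults (k=4, w=0.5) are arbitrary.
def DemoQuantizedProjection(v, k=4, w=0.5):
    a = numpy.random.randn(k, len(v))      # Random projection directions
    b = numpy.random.rand(k)               # Random per-projection bias
    return numpy.floor(numpy.dot(a, v)/w + b).astype('i')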
# This class implements k projections into k integers (after
# quantization) and then reduces that integer vector to a T1 and a T2
# hash.  Data can either be entered into the table or retrieved.
class lsh:
    def __init__(self, k, w, N, t2size=(2**16 - 15)):
        self.k = k                          # Number of projections
        self.w = w                          # Bin width
        self.N = N                          # Number of buckets
        self.t2size = t2size
        self.projections = None
        self.buckets = {}

    # Create the random constants needed for the projections.
    def CreateProjections(self, dim):
        self.dim = dim
        p = primes()
        self.projections = numpy.random.randn(self.k, self.dim)
        self.bias = numpy.random.rand(self.k)
        # Use the bias array just for its size (ignore its contents)
        self.t1hash = map(lambda x: p.sample(), self.bias)
        self.t2hash = map(lambda x: p.sample(), self.bias)
        if 0:
            print "Dim is", self.dim
            print 'Projections:\n', self.projections
            print 'T1 hash:\n', self.t1hash
            print 'T2 hash:\n', self.t2hash

    # Actually create the t1 and t2 hashes for some data.
    def CalculateHashes(self, data):
        if self.projections is None:
            self.CreateProjections(len(data))
        bins = numpy.zeros(self.k, 'i')
        for i in range(0, self.k):
            bins[i] = numpy.sum(data * self.projections[i, :])/self.w \
                + self.bias[i]
        t1 = numpy.sum(bins * self.t1hash) % self.N
        t2 = numpy.sum(bins * self.t2hash) % self.t2size
        return t1, t2

    # Like CalculateHashes, but also return the quantized bins.
    def CalculateHashes2(self, data):
        if self.projections is None:
            self.CreateProjections(len(data))
        bins = numpy.zeros(self.k, 'i')
        for i in range(0, self.k):
            bins[i] = numpy.sum(data * self.projections[i, :])/self.w \
                + self.bias[i]
        t1 = numpy.sum(bins * self.t1hash) % self.N
        t2 = numpy.sum(bins * self.t2hash) % self.t2size
        return t1, t2, bins

    # Put some data into the hash bucket for this LSH projection.
    def InsertIntoTable(self, id, data):
        (t1, t2) = self.CalculateHashes(data)
        if t1 not in self.buckets:
            self.buckets[t1] = {t2: [id]}
        else:
            if t2 not in self.buckets[t1]:
                self.buckets[t1][t2] = [id]
            else:
                self.buckets[t1][t2].append(id)

    # Find some data in the hash bucket.  Return all the ids
    # we find for this T1-T2 pair.
    def Find(self, data):
        (t1, t2) = self.CalculateHashes(data)
        if t1 not in self.buckets:
            return []
        row = self.buckets[t1]
        if t2 not in row:
            return []
        return row[t2]

    # Create a dictionary showing all the buckets an ID appears in.
    def CreateDictionary(self, theDictionary, prefix):
        for b in self.buckets:                  # Over all T1 buckets
            w = prefix + str(b)
            for c in self.buckets[b]:           # Over all T2 hashes
                for i in self.buckets[b][c]:    # Over all ids
                    if i not in theDictionary:
                        theDictionary[i] = [w]
                    else:
                        theDictionary[i] += [w]  # Append this bucket label
        return theDictionary

    # Print some stats for these LSH buckets.
    def Stats(self):
        maxCount = 0; sumCount = 0; numCount = 0; bucketLens = []
        maxLoc = None
        for b in self.buckets:
            for c in self.buckets[b]:
                l = len(self.buckets[b][c])
                if l > maxCount:
                    maxCount = l
                    maxLoc = (b, c)
                sumCount += l
                numCount += 1
                bucketLens.append(l)
        theValues = sorted(bucketLens)
        med = theValues[(len(theValues)+1)/2 - 1]
        print "Bucket Counts:"
        print "\tTotal indexed points:", sumCount
        print "\tT1 Buckets filled: %d/%d" % (len(self.buckets), self.N)
        print "\tT2 Buckets used: %d/%d" % (numCount, self.N)
        print "\tMaximum T2 chain length:", maxCount, "at", maxLoc
        print "\tAverage T2 chain length:", float(sumCount)/numCount
        print "\tMedian T2 chain length:", med

    # Get a list of all the IDs that are contained in these hash buckets.
    def GetAllIndices(self):
        theList = []
        for b in self.buckets:
            for c in self.buckets[b]:
                theList += self.buckets[b][c]
        return theList

    # Put some data into the hash table and see how many collisions we get.
    def Test(self, n):
        self.buckets = {}
        self.projections = None
        d = numpy.array([.2, .3])
        for i in range(0, n):
            self.InsertIntoTable(i, d + i)
        for i in range(0, n):
            r = self.Find(d + i)
            matches = sum(map(lambda x: x == i, r))
            if matches == 0:
                print "Couldn't find item", i
            if len(r) > 1:
                print "Found big bin for", i, ":", r
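
# A minimal usage sketch for a single lsh table.  The parameter values
# (k=4 projections, bin width w=1.0, N=100 T1 buckets) and the helper
# name DemoSingleTable are arbitrary choices for this demo; lsh.Test
# above exercises the same path more thoroughly.
def DemoSingleTable():
    table = lsh(4, 1.0, 100)
    d = numpy.array([.2, .3])
    for i in range(0, 10):
        table.InsertIntoTable(i, d + i)
    print table.Find(d + 3)                # Should include id 3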
# Put together several LSH projections to form an index.  The only
# new parameter is the number of groups of projections (one lsh class
# object per group).
class index:
    def __init__(self, k, l, w, N):
        self.k = k
        self.l = l
        self.w = w
        self.N = N
        self.projections = []
        for i in range(0, l):              # Create all LSH buckets
            self.projections.append(lsh(k, w, N))

    # Insert some data into all the LSH buckets.
    def InsertIntoTable(self, id, data):
        for p in self.projections:
            p.InsertIntoTable(id, data)

    # Find some data in all the LSH buckets.  Return a list of
    # (id, count) pairs, sorted so the most-voted ids come first.
    def Find(self, data):
        items = []
        for p in self.projections:
            items += p.Find(data)          # Concatenate
        results = {}
        for item in items:
            results.setdefault(item, 0)
            results[item] += 1
        s = sorted(results.items(), key=operator.itemgetter(1),
                   reverse=True)
        return s

    # Return a list of results: (id, distance**2, count)
    def FindExact(self, data, GetData):
        s = self.Find(data)
        d = map(lambda (id, count):
                (id, ((GetData(id) - data)**2).sum(), count), s)
        ds = sorted(d, key=operator.itemgetter(1))
        return ds

    # Do an exhaustive distance calculation over all points.
    # Return a list of results: (id, distance**2, count)
    def FindAll(self, query, GetData):
        s = []
        allIDs = self.GetAllIndices()
        for id in allIDs:
            dist = ((GetData(id) - query)**2).sum()
            s.append((id, dist, 0))
        ds = sorted(s, key=operator.itemgetter(1))
        return ds

    # Return the number of points that are closer than radius to the query.
    def CountInsideRadius(self, data, GetData, radius):
        matches = self.FindExact(data, GetData)
        radius2 = radius**2
        count = sum(map(lambda (id, distance, count):
                        distance < radius2, matches))
        return count

    # Put some data into the hash tables.
    def Test(self, n):
        d = numpy.array([.2, .3])
        for i in range(0, n):
            self.InsertIntoTable(i, d + i)
        for i in range(0, n):
            r = self.Find(d + i)
            print r

    # Print the statistics of each hash table.
    def Stats(self):
        for i in range(0, len(self.projections)):
            p = self.projections[i]
            print "Buckets", i
            p.Stats()

    # Get all the IDs that are part of this index.  Just check one hash.
    def GetAllIndices(self):
        if self.projections and len(self.projections) > 0:
            p = self.projections[0]
            return p.GetAllIndices()
        return None

    # Return the buckets (t1 and t2 hashes) associated with a data point.
    def GetBuckets(self, data):
        b = []
        for p in self.projections:
            h = p.CalculateHashes(data)
            b += h
        return b

    # Create a dictionary, keyed by ID, listing which buckets hold each ID.
    def CreateDictionary(self):
        theDictionary = {}
        prefixes = 'abcdefghijklmnopqrstuvwxyz'
        pi = 0
        for p in self.projections:
            prefix = 'W'
            pc = pi
            while pc > 0:                  # Create a unique ID for this bucket
                prefix += prefixes[pc % len(prefixes)]
                pc /= len(prefixes)
            theDictionary = p.CreateDictionary(theDictionary, prefix)
            pi += 1
        return theDictionary

    # Use the expression in "Analysis of Minimum Distances in
    # High-Dimensional Musical Spaces" to calculate the underlying
    # dimensionality of the data.  For a random selection of ids, find
    # the nearest neighbor and use those distances to estimate the
    # dimensionality of the data.
    def MeasureDimensionality(self, allData, N):
        allIDs = self.GetAllIndices()
        sampleIDs = random.sample(allIDs, N)
        L = 0.0; S = 0.0
        for id in sampleIDs:
            res = self.FindExact(allData[id, :], lambda i: allData[i, :])
            if len(res) > 1:
                (nnid, dist, count) = res[1]
                S += dist
                L += math.log(dist)
            else:
                N -= 1
        print "S=" + str(S), "L=" + str(L), "N=" + str(N)
        if N > 1:
            x = math.log(S/N) - L/N        # Equation 17
            d = 2*InvertFunction(x, lambda y: math.log(y) - digamma(y))
            print d
            return d
        else:
            return 0
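
# A minimal usage sketch for the full index.  The parameters (k=4
# projections per group, l=2 groups, w=1.0, N=100) and the helper name
# DemoIndex are arbitrary choices for this demo.
def DemoIndex():
    ind = index(4, 2, 1.0, 100)
    d = numpy.array([.2, .3])
    for i in range(0, 100):
        ind.InsertIntoTable(i, d + i)
    # Find returns (id, vote count) pairs, most-voted first; FindExact
    # re-ranks those candidates by true squared distance.
    print ind.Find(d + 42)[:3]
    print ind.FindExact(d + 42, lambda i: d + i)[:3]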
# Invert a function: only works for monotonic functions.  Uses the
# geometric midpoint to reduce the search range, looking for the input
# whose function output equals the given value.
# Test with:
#   lsh.InvertFunction(2, math.sqrt)
#   lsh.InvertFunction(2, lambda x: 1.0/x)
# Needed for inverting the log-digamma expression in the
# MeasureDimensionality method.
def InvertFunction(x, func):
    min = 0.0001
    max = 1000
    if func(min) < func(max):
        sign = 1
    else:
        sign = -1
    print "Looking for Y() =", str(x), "d'=", sign
    while min + 1e-7 < max:
        mid = math.sqrt(min*max)
        Y = func(mid)
        if sign*Y > sign*x:
            max = mid
        else:
            min = mid
    return mid

##### A bunch of routines used to generate data we can use to test
# this LSH implementation.
gLSHTestData = []

# Find a point in the array of test data.  (Needed so FindExact can get
# the data it needs.)
def FindLSHTestData(id):
    global gLSHTestData
    if id < gLSHTestData.shape[0]:
        return gLSHTestData[id, :]
    return None

# Fill the test array with uniform random data between -1 and 1.
def CreateRandomLSHTestData(numPoints, dim):
    global gLSHTestData
    gLSHTestData = (numpy.random.rand(numPoints, dim) - .5)*2.0

# Fill the test array with a regular grid of points between -1 and 1.
def CreateRegularLSHTestData(numDivs):
    global gLSHTestData
    gLSHTestData = numpy.zeros(((2*numDivs + 1)**2, 2))
    i = 0
    for x in range(-numDivs, numDivs + 1):
        for y in range(-numDivs, numDivs + 1):
            gLSHTestData[i, 0] = x/float(numDivs)
            gLSHTestData[i, 1] = y/float(numDivs)
            i += 1

# Use nearest-neighbor properties to calculate the dimensionality.
def TestDimensionality(N):
    numPoints = 100000
    k = 10
    CreateRandomLSHTestData(numPoints, 3)
    ind = index(k, 2, .1, 100)
    for i in range(0, numPoints):
        ind.InsertIntoTable(i, FindLSHTestData(i))
    ind.MeasureDimensionality(gLSHTestData, N)

# Use dimension doubling to measure the dimensionality of a random set
# of data.  Generate some data (either uniform random points or a grid),
# then count the number of points that fall within the given radius of
# a few random query points.
def TestDimensionality2():
    global gLSHTestData
    binWidth = .5
    if True:
        numPoints = 100000
        CreateRandomLSHTestData(numPoints, 3)
    else:
        CreateRegularLSHTestData(100)
        numPoints = gLSHTestData.shape[0]
    k = 4; l = 2; N = 1000
    ind = index(k, l, binWidth, N)
    for i in range(0, numPoints):
        ind.InsertIntoTable(i, gLSHTestData[i, :])
    rBig = binWidth/8.0
    rSmall = rBig/2.0
    cBig = 0.0; cSmall = 0.0
    for id in random.sample(ind.GetAllIndices(), 2):
        qp = FindLSHTestData(id)
        cBig += ind.CountInsideRadius(qp, FindLSHTestData, rBig)
        cSmall += ind.CountInsideRadius(qp, FindLSHTestData, rSmall)
    if cBig > cSmall and cSmall > 0:
        dim = math.log(cBig/cSmall)/math.log(rBig/rSmall)
    else:
        dim = 0
    print cBig, cSmall, dim
    return ind

# Call an external process to compute the digamma function (the
# ./digamma binary wraps the GNU Scientific Library routine).
import subprocess
def digamma(x):
    y = subprocess.Popen(["./digamma", str(x)],
                         stdout=subprocess.PIPE).communicate()[0]
    return float(y.strip())

# Generate some 2-dimensional data, put it into an index and then show
# the points retrieved.  This is all done as a function of the number
# of projections per bucket (k), the number of buckets to use for each
# index (l), and the number of LSH buckets (the T1 size, N).  Write out
# the data so we can plot it (in Matlab).
def GraphicalTest(k, l, N):
    numPoints = 1000
    CreateRandomLSHTestData(numPoints, 2)
    ind = index(k, l, .1, N)
    for i in range(0, numPoints):
        ind.InsertIntoTable(i, FindLSHTestData(i))
    i = 42
    r = dict(ind.Find(gLSHTestData[i, :]))  # Map each found id to its vote count
    fp = open('lshtestpoints.txt', 'w')
    for i in range(0, numPoints):
        if i in r:
            c = r[i]
        else:
            c = 0
        fp.write("%g %g %d\n" % (gLSHTestData[i, 0], gLSHTestData[i, 1], c))
    fp.close()
    return r
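
# A worked sketch of the dimension-doubling estimate used in
# TestDimensionality2 above: if the number of points inside radius r
# grows like r**d, then d = log(cBig/cSmall)/log(rBig/rSmall).  The
# counts below are made-up values chosen only to show the arithmetic.
def DemoDimensionDoubling():
    rSmall, rBig = 0.5, 1.0
    cSmall, cBig = 10.0, 80.0              # Hypothetical in-radius counts
    print math.log(cBig/cSmall)/math.log(rBig/rSmall)   # Prints 3.0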
# Run one LSH test.  Look for point 42 in the data.
def ExactTest():
    global gLSHTestData
    numPoints = 1000
    CreateRandomLSHTestData(numPoints, 2)
    ind = index(10, 2, .1, 100)
    for i in range(0, numPoints):
        ind.InsertIntoTable(i, FindLSHTestData(i))
    data = FindLSHTestData(42)
    res = ind.FindExact(data, FindLSHTestData)
    return res

# Create a file with the distances retrieved as a function of k and l.
# The first line is the exact result, showing all points in the dB.
# Successive lines are the results for an LSH index.
def TestRetrieval():
    numPoints = 100000
    dims = 3
    CreateRandomLSHTestData(numPoints, dims)
    qp = FindLSHTestData(0)*0.0
    fp = open('TestRetrieval.txt', 'w')
    for l in range(1, 5):
        for k in range(1, 6):
            for iter in range(1, 10):
                print "Building an index with l=" + str(l) + ", k=" + str(k)
                ind = index(k, l, .1, 100)     # Build a new index
                for i in range(0, numPoints):
                    ind.InsertIntoTable(i, FindLSHTestData(i))
                if k == 1 and l == 1:
                    matches = ind.FindAll(qp, FindLSHTestData)
                    fp.write(' '.join(map(lambda (i, d, c): str(d), matches)))
                    fp.write('\n')
                matches = ind.FindExact(qp, FindLSHTestData)
                fp.write(' '.join(map(lambda (i, d, c): str(d), matches)))
                # Fill the rest of the line with -1s
                fp.write(' ')
                fp.write(' '.join(map(str,
                    (-numpy.ones(numPoints - len(matches) + 1)).tolist())))
                fp.write('\n')
    fp.close()

# Save an LSH index to a pickle file.
def SaveIndex(filename, ind):
    try:
        fp = open(filename, 'wb')
        pickle.dump(ind, fp)
        fp.close()
        statinfo = os.stat(filename)
        if statinfo:
            print "Wrote out", statinfo.st_size, "bytes to", \
                filename
    except:
        print "Couldn't pickle index to file", filename
        traceback.print_exc(file=sys.stderr)

# Read an LSH index from a pickle file.
def LoadIndex(filename):
    try:
        fp = open(filename, 'rb')
        ind = pickle.load(fp)
        fp.close()
        return ind
    except:
        print "Couldn't read pickle file", filename
        traceback.print_exc(file=sys.stderr)
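
# A minimal round-trip sketch for SaveIndex/LoadIndex.  The filename
# 'demo.lsh' and the helper name DemoSaveLoad are arbitrary choices for
# this demo.
def DemoSaveLoad():
    ind = index(4, 2, 1.0, 100)
    ind.InsertIntoTable(0, numpy.array([.2, .3]))
    SaveIndex('demo.lsh', ind)
    ind2 = LoadIndex('demo.lsh')
    print ind2.GetAllIndices()             # Prints [0]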