import random, numpy, pickle, os, operator, traceback, sys, math


# Python implementation of Andoni's e2LSH.  This version is fast because it
# uses Python hashes to implement the buckets.  The numerics are handled
# by the numpy routines, so this should be close to optimal in speed
# (although there is no control over the hash-table layout in memory).

# Malcolm Slaney, Yahoo! Research, 2009

# Class that can be used to return random prime numbers.  We need these
# for our hash functions.
class primes():
    # This list is from http://primes.utm.edu/lists/2small/0bit.html
    def __init__(self):
        p = {}
        # For each key i, every 2**i - j (j from the list) is prime.
        p[8] = [5, 15, 17, 23, 27, 29, 33, 45, 57, 59]
        p[9] = [3, 9, 13, 21, 25, 33, 45, 49, 51, 55]
        p[10] = [3, 5, 11, 15, 27, 33, 41, 47, 53, 57]
        p[11] = [9, 19, 21, 31, 37, 45, 49, 51, 55, 61]
        p[12] = [3, 5, 17, 23, 39, 45, 47, 69, 75, 77]
        p[13] = [1, 13, 21, 25, 31, 45, 69, 75, 81, 91]
        p[14] = [3, 15, 21, 23, 35, 45, 51, 65, 83, 111]
        p[15] = [19, 49, 51, 55, 61, 75, 81, 115, 121, 135]
        p[16] = [15, 17, 39, 57, 87, 89, 99, 113, 117, 123]
        p[17] = [1, 9, 13, 31, 49, 61, 63, 85, 91, 99]
        p[18] = [5, 11, 17, 23, 33, 35, 41, 65, 75, 93]
        p[19] = [1, 19, 27, 31, 45, 57, 67, 69, 85, 87]
        p[20] = [3, 5, 17, 27, 59, 69, 129, 143, 153, 185]
        p[21] = [9, 19, 21, 55, 61, 69, 105, 111, 121, 129]
        p[22] = [3, 17, 27, 33, 57, 87, 105, 113, 117, 123]
        p[23] = [15, 21, 27, 37, 61, 69, 135, 147, 157, 159]
        p[24] = [3, 17, 33, 63, 75, 77, 89, 95, 117, 167]
        p[25] = [39, 49, 61, 85, 91, 115, 141, 159, 165, 183]
        p[26] = [5, 27, 45, 87, 101, 107, 111, 117, 125, 135]
        p[27] = [39, 79, 111, 115, 135, 187, 199, 219, 231, 235]
        p[28] = [57, 89, 95, 119, 125, 143, 165, 183, 213, 273]
        p[29] = [3, 33, 43, 63, 73, 75, 93, 99, 121, 133]
        p[30] = [35, 41, 83, 101, 105, 107, 135, 153, 161, 173]
        p[31] = [1, 19, 61, 69, 85, 99, 105, 151, 159, 171]

        primes = []
        for k in p:
            for m in p[k]:
                primes.append(2**k - m)
        self.primes = primes
        self.len = len(primes)

    def sample(self):
        return self.primes[int(random.random()*self.len)]


# This class implements k projections into k integers (after quantization)
# and then reduces that integer vector to a T1 and T2 hash.  Data can either
# be entered into the table or retrieved.
class lsh:
    def __init__(self, k, w, N, t2size=(2**16 - 15)):
        self.k = k              # Number of projections
        self.w = w              # Bin width
        self.N = N              # Number of buckets
        self.t2size = t2size
        self.projections = None
        self.buckets = {}
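    # A sketch of the scheme implemented below (our summary, not part of the
    # original comments): each point x is quantized by the k random
    # projections v[i],
    #     bins[i] = int(numpy.dot(x, v[i])/w + bias[i]),
    # and the integer vector bins is folded into two hashes,
    #     t1 = sum(bins*t1hash) % N        # picks the hash-table row
    #     t2 = sum(bins*t2hash) % t2size   # fingerprint stored in that row
    # where t1hash and t2hash are vectors of random primes.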
    # Create the random constants needed for the projections.
    def CreateProjections(self, dim):
        self.dim = dim
        p = primes()
        self.projections = numpy.random.randn(self.k, self.dim)
        self.bias = numpy.random.rand(self.k)
        # Use the bias array just for its size (its contents are ignored).
        self.t1hash = map(lambda x: p.sample(), self.bias)
        self.t2hash = map(lambda x: p.sample(), self.bias)
        if 0:                   # Debugging output
            print "Dim is", self.dim
            print 'Projections:\n', self.projections
            print 'T1 hash:\n', self.t1hash
            print 'T2 hash:\n', self.t2hash

    # Actually create the t1 and t2 hashes for some data.
    def CalculateHashes(self, data):
        if self.projections is None:
            self.CreateProjections(len(data))
        bins = numpy.zeros(self.k, 'i')
        for i in range(0, self.k):
            bins[i] = numpy.sum(data * self.projections[i,:])/self.w \
                + self.bias[i]
        t1 = numpy.sum(bins * self.t1hash) % self.N
        t2 = numpy.sum(bins * self.t2hash) % self.t2size
        return t1, t2

    # Like CalculateHashes, but also return the raw bin vector.
    def CalculateHashes2(self, data):
        if self.projections is None:
            self.CreateProjections(len(data))
        bins = numpy.zeros(self.k, 'i')
        for i in range(0, self.k):
            bins[i] = numpy.sum(data * self.projections[i,:])/self.w \
                + self.bias[i]
        t1 = numpy.sum(bins * self.t1hash) % self.N
        t2 = numpy.sum(bins * self.t2hash) % self.t2size
        return t1, t2, bins

    # Put some data into the hash bucket for this LSH projection.
    def InsertIntoTable(self, id, data):
        (t1, t2) = self.CalculateHashes(data)
        if t1 not in self.buckets:
            self.buckets[t1] = {t2: [id]}
        else:
            if t2 not in self.buckets[t1]:
                self.buckets[t1][t2] = [id]
            else:
                self.buckets[t1][t2].append(id)
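    # A minimal usage sketch for a single table (hypothetical values, our
    # addition):
    #     h = lsh(k=4, w=0.5, N=1000)
    #     h.InsertIntoTable('a', numpy.array([0.1, 0.2]))
    #     h.Find(numpy.array([0.1, 0.2]))   # -> ['a']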
    # Find some data in the hash bucket.  Return all the ids
    # that we find for this T1-T2 pair.
    def Find(self, data):
        (t1, t2) = self.CalculateHashes(data)
        if t1 not in self.buckets:
            return []
        row = self.buckets[t1]
        if t2 not in row:
            return []
        return row[t2]

    # Create a dictionary showing all the buckets an ID appears in.
    def CreateDictionary(self, theDictionary, prefix):
        for b in self.buckets:                  # Over all buckets
            w = prefix + str(b)
            for c in self.buckets[b]:           # Over all T2 hashes
                for i in self.buckets[b][c]:    # Over ids
                    if i not in theDictionary:
                        theDictionary[i] = [w]
                    else:
                        theDictionary[i].append(w)
        return theDictionary

    # Print some stats for these lsh buckets.
    def Stats(self):
        maxCount = 0; sumCount = 0
        numCount = 0; bucketLens = []
        maxLoc = None
        for b in self.buckets:
            for c in self.buckets[b]:
                l = len(self.buckets[b][c])
                if l > maxCount:
                    maxCount = l
                    maxLoc = (b, c)
                    # print b, c, self.buckets[b][c]
                sumCount += l
                numCount += 1
                bucketLens.append(l)
        theValues = sorted(bucketLens)
        med = theValues[(len(theValues)+1)/2-1]
        print "Bucket Counts:"
        print "\tTotal indexed points:", sumCount
        print "\tT1 Buckets filled: %d/%d" % (len(self.buckets), self.N)
        print "\tT2 Buckets used: %d/%d" % (numCount, self.N)
        print "\tMaximum T2 chain length:", maxCount, "at", maxLoc
        print "\tAverage T2 chain length:", float(sumCount)/numCount
        print "\tMedian T2 chain length:", med

    # Get a list of all IDs that are contained in these hash buckets.
    def GetAllIndices(self):
        theList = []
        for b in self.buckets:
            for c in self.buckets[b]:
                theList += self.buckets[b][c]
        return theList

    # Put some data into the hash table and see how many collisions we get.
    def Test(self, n):
        self.buckets = {}
        self.projections = None
        d = numpy.array([.2, .3])
        for i in range(0, n):
            self.InsertIntoTable(i, d + i)
        for i in range(0, n):
            r = self.Find(d + i)
            matches = sum(map(lambda x: x == i, r))
            if matches == 0:
                print "Couldn't find item", i
            elif matches == 1:
                pass
            if len(r) > 1:
                print "Found big bin for", i, ":", r


# Put together several LSH projections to form an index.  The only
# new parameter is the number of groups of projections (one lsh class
# object per group).
class index:
    def __init__(self, k, l, w, N):
        self.k = k
        self.l = l
        self.w = w
        self.N = N
        self.projections = []
        for i in range(0, l):       # Create all LSH tables
            self.projections.append(lsh(k, w, N))

    # Insert some data into all LSH buckets.
    def InsertIntoTable(self, id, data):
        for p in self.projections:
            p.InsertIntoTable(id, data)
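    # A sketch of how the combined index is used (hypothetical values, our
    # addition):
    #     ind = index(k=10, l=4, w=0.1, N=100)
    #     ind.InsertIntoTable(17, point)
    #     ind.Find(query)   # -> [(17, 4), ...]: id 17 was found in 4 tables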
    # Find some data in all the LSH buckets.  Return a list of
    # (id, count) pairs, with the ids found in the most tables first.
    def Find(self, data):
        items = []
        for p in self.projections:
            items += p.Find(data)           # Concatenate
        # print "Find results are:", items
        results = {}
        for item in items:
            results.setdefault(item, 0)
            results[item] += 1
        s = sorted(results.items(), key=operator.itemgetter(1), \
            reverse=True)
        return s

    # Return a list of results: (id, distance**2, count).
    def FindExact(self, data, GetData):
        s = self.Find(data)
        # print "Intermediate results are:", s
        d = map(lambda (id,count): (id, ((GetData(id)-data)**2).sum(), count), s)
        ds = sorted(d, key=operator.itemgetter(1))
        return ds

    # Do an exhaustive distance calculation, looking at all points and their
    # distances.  Return a list of results: (id, distance**2, count).
    def FindAll(self, query, GetData):
        s = []
        allIDs = self.GetAllIndices()
        for id in allIDs:
            dist = ((GetData(id)-query)**2).sum()
            s.append((id, dist, 0))
        # print "Intermediate results are:", s
        ds = sorted(s, key=operator.itemgetter(1))
        return ds

    # Return the number of points that are closer than radius to the query.
    def CountInsideRadius(self, data, GetData, radius):
        matches = self.FindExact(data, GetData)
        # print "CountInsideRadius found", len(matches), "matches"
        radius2 = radius**2
        count = sum(map(lambda (id,distance,count): distance < radius2,
            matches))
        return count

    # Return a list of all the IDs stored in this index.
    def GetAllIndices(self):
        if len(self.projections) > 0:
            p = self.projections[0]
            return p.GetAllIndices()
        return None

    # Return the buckets (t1 and t2 hashes) associated with a data point.
    def GetBuckets(self, data):
        b = []
        for p in self.projections:
            h = p.CalculateHashes(data)
            b += h
        return b

    # Create a dictionary, ordered by ID, listing which buckets are used
    # for each ID.
    def CreateDictionary(self):
        theDictionary = {}
        prefixes = 'abcdefghijklmnopqrstuvwxyz'
        pi = 0
        for p in self.projections:
            prefix = 'W'
            pc = pi
            while pc > 0:               # Create a unique ID for this table
                prefix += prefixes[pc % len(prefixes)]
                pc /= len(prefixes)
            theDictionary = p.CreateDictionary(theDictionary, \
                prefix)
            pi += 1
        return theDictionary

    # Use the expression in "Analysis of Minimum Distances in High-Dimensional
    # Musical Spaces" to calculate the underlying dimensionality of the data.
    # For a random selection of ids, find the nearest neighbors and use them
    # to estimate the dimensionality of the data.
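    # A sketch of the estimate (our paraphrase of the paper's Equation 17,
    # where digamma is the logarithmic derivative of the gamma function):
    # given nearest-neighbor squared distances d_i over N samples, let
    #     S = sum_i d_i,   L = sum_i log(d_i),
    # solve  log(S/N) - L/N = log(y) - digamma(y)  for y numerically, and
    # the dimensionality estimate is d = 2*y.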
    def MeasureDimensionality(self, allData, N):
        allIDs = self.GetAllIndices()
        sampleIDs = random.sample(allIDs, N)
        L = 0.0; S = 0.0
        for id in sampleIDs:
            res = self.FindExact(allData[id,:], lambda i: allData[i,:])
            if len(res) > 1:
                (nnid, dist, count) = res[1]    # res[0] is the query itself
                S += dist
                L += math.log(dist)
            else:
                N -= 1
        print "S="+str(S), "L="+str(L), "N="+str(N)
        if N > 1:
            x = math.log(S/N) - L/N             # Equation 17
            d = 2*InvertFunction(x, lambda y: math.log(y) - digamma(y))
            print d
            return d
        else:
            return 0


# Only works for monotonic functions...  Uses the geometric midpoint to
# reduce the search range, looking for the input whose function output
# equals the given value.  Test with:
#   InvertFunction(2, math.sqrt)
#   InvertFunction(2, lambda x: 1.0/x)
# Needed for inverting the log/digamma expression in the
# MeasureDimensionality method.
def InvertFunction(x, func):
    min = 0.0001; max = 1000
    if func(min) < func(max):
        sign = 1
    else:
        sign = -1
    print "Looking for Y() =", str(x), "d'=", sign
    while min + 1e-7 < max:
        mid = math.sqrt(min*max)
        Y = func(mid)
        # print min, mid, Y, max
        if sign*Y > sign*x:
            max = mid
        else:
            min = mid
    return mid


##### A bunch of routines used to generate data we can use to test
# this LSH implementation.

global gLSHTestData
gLSHTestData = []

# Find a point in the array of test data.  (Needed so FindExact can get the
# data it needs.)
def FindLSHTestData(id):
    global gLSHTestData
    if id < gLSHTestData.shape[0]:
        return gLSHTestData[id,:]
    return None

# Fill the test array with uniform random data between -1 and 1.
def CreateRandomLSHTestData(numPoints, dim):
    global gLSHTestData
    gLSHTestData = (numpy.random.rand(numPoints, dim) - .5)*2.0

# Fill the test array with a regular grid of points between -1 and 1.
def CreateRegularLSHTestData(numDivs):
    global gLSHTestData
    gLSHTestData = numpy.zeros(((2*numDivs+1)**2, 2))
    i = 0
    for x in range(-numDivs, numDivs+1):
        for y in range(-numDivs, numDivs+1):
            gLSHTestData[i,0] = x/float(numDivs)
            gLSHTestData[i,1] = y/float(numDivs)
            i += 1

# Use nearest-neighbor properties to calculate the dimensionality.
def TestDimensionality(N):
    numPoints = 100000
    k = 10
    CreateRandomLSHTestData(numPoints, 3)
    ind = index(k, 2, .1, 100)
    for i in range(0, numPoints):
        ind.InsertIntoTable(i, FindLSHTestData(i))
    ind.MeasureDimensionality(gLSHTestData, N)
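# A sketch of the dimension-doubling idea used below (our summary): if the
# data has intrinsic dimension d, the number of points inside a ball of
# radius r grows as C(r) ~ r**d, so counts at two radii give the estimate
#     d ~ log(cBig/cSmall) / log(rBig/rSmall).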
# Use dimension doubling to measure the dimensionality of a random set of
# data.  Generate some data (either uniform random points or a grid), then
# count the number of points that fall within the given radii of a query.
def TestDimensionality2():
    global gLSHTestData
    binWidth = .5
    if True:
        numPoints = 100000
        CreateRandomLSHTestData(numPoints, 3)
    else:
        CreateRegularLSHTestData(100)
        numPoints = gLSHTestData.shape[0]
    k = 4; l = 2; N = 1000
    ind = index(k, l, binWidth, N)
    for i in range(0, numPoints):
        ind.InsertIntoTable(i, gLSHTestData[i,:])
    rBig = binWidth/8.0
    rSmall = rBig/2.0
    cBig = 0.0; cSmall = 0.0
    for id in random.sample(ind.GetAllIndices(), 2):
        qp = FindLSHTestData(id)
        cBig += ind.CountInsideRadius(qp, FindLSHTestData, rBig)
        cSmall += ind.CountInsideRadius(qp, FindLSHTestData, rSmall)
    if cBig > cSmall and cSmall > 0:
        dim = math.log(cBig/cSmall)/math.log(rBig/rSmall)
    else:
        dim = 0
    print cBig, cSmall, dim
    return ind

# Call an external process to compute the digamma function (the "./digamma"
# binary is assumed to wrap the GNU Scientific Library's implementation).
import subprocess
def digamma(x):
    y = subprocess.Popen(["./digamma", str(x)],
        stdout=subprocess.PIPE).communicate()[0]
    return float(y.strip())


# Generate some two-dimensional data, put it into an index, and then show
# the points retrieved.  This is all done as a function of the number of
# projections per bucket (k), the number of buckets to use for each index
# (l), and the number of LSH buckets (the T1 size, N).  Write out the data
# so we can plot it (in Matlab).
def GraphicalTest(k, l, N):
    global gLSHTestData
    numPoints = 1000
    CreateRandomLSHTestData(numPoints, 2)
    ind = index(k, l, .1, N)
    for i in range(0, numPoints):
        ind.InsertIntoTable(i, FindLSHTestData(i))
    data = gLSHTestData
    i = 42
    r = dict(ind.Find(data[i,:]))   # Map retrieved ids to their vote counts
    fp = open('lshtestpoints.txt', 'w')
    for i in range(0, numPoints):
        if i in r:
            c = r[i]
        else:
            c = 0
        fp.write("%g %g %d\n" % (data[i,0], data[i,1], c))
    fp.close()
    return r


# Run one LSH test.  Look for point 42 in the data.
def ExactTest():
    global gLSHTestData
    numPoints = 1000
    CreateRandomLSHTestData(numPoints, 2)
    ind = index(10, 2, .1, 100)
    for i in range(0, numPoints):
        ind.InsertIntoTable(i, FindLSHTestData(i))
    data = FindLSHTestData(42)
    res = ind.FindExact(data, FindLSHTestData)
    return res
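# For example (hypothetical output, our addition), ExactTest() typically
# returns something like
#     [(42, 0.0, 2), (981, 0.00012, 1), ...]
# with item 42 first at squared distance zero, found in both of the
# index's two tables.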
# Create a file with the distances retrieved as a function of k and l.
# The first line is the exact result, showing all points in the database.
# Successive lines are the results for an LSH index.
def TestRetrieval():
    dims = 3
    numPoints = 100000
    CreateRandomLSHTestData(numPoints, dims)
    qp = FindLSHTestData(0)*0.0         # Query at the origin
    fp = open('TestRetrieval.txt', 'w')
    for l in range(1, 5):
        for k in range(1, 6):
            for iter in range(1, 10):
                print "Building an index with l="+str(l)+", k="+str(k)
                ind = index(k, l, .1, 100)      # Build a new index
                for i in range(0, numPoints):
                    ind.InsertIntoTable(i, FindLSHTestData(i))
                if k == 1 and l == 1:
                    matches = ind.FindAll(qp, FindLSHTestData)
                    fp.write(' '.join(map(lambda (i,d,c): str(d), matches)))
                    fp.write('\n')
                matches = ind.FindExact(qp, FindLSHTestData)
                fp.write(' '.join(map(lambda (i,d,c): str(d), matches)))
                # Fill the rest of the line with -1s.
                fp.write(' ')
                fp.write(' '.join(map(str,
                    (-numpy.ones(numPoints-len(matches)+1)).tolist())))
                fp.write('\n')
    fp.close()


# Save an LSH index to a pickle file.
def SaveIndex(filename, ind):
    try:
        fp = open(filename, 'wb')
        pickle.dump(ind, fp)
        fp.close()
        statinfo = os.stat(filename)
        if statinfo:
            print "Wrote out", statinfo.st_size, "bytes to", \
                filename
    except:
        print "Couldn't pickle index to file", filename
        traceback.print_exc(file=sys.stderr)

# Read an LSH index from a pickle file.
def LoadIndex(filename):
    try:
        fp = open(filename, 'rb')
        ind = pickle.load(fp)
        fp.close()
        return ind
    except:
        print "Couldn't read pickle file", filename
        traceback.print_exc(file=sys.stderr)
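# A minimal smoke test (our addition, not part of the original file): build
# a small index over random 2-D points and query for item 42, as ExactTest
# does.
if __name__ == '__main__':
    res = ExactTest()
    print "Nearest matches to point 42 (id, distance**2, votes):"
    print res[0:5]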