annotate lsh.py @ 755:37c2b9cce23a multiprobeLSH

Adding mkc_lsh_update branch, trunk candidate with improved LSH: merged trunk 1095 and branch multiprobe_lsh
author mas01mc
date Thu, 25 Nov 2010 13:42:40 +0000
parents 50a7fd50578f
children
rev   line source
mas01mc@740 1 import random, numpy, pickle, os, operator, traceback, sys, math
mas01mc@740 2
mas01mc@740 3
mas01mc@740 4 # Python implementation of Andoni's e2LSH. This version is fast because it
mas01mc@740 5 # uses Python hashes to implement the buckets. The numerics are handled
mas01mc@740 6 # by the numpy routine so this should be close to optimal in speed (although
mas01mc@740 7 # there is no control of the hash tables layout in memory.)
mas01mc@740 8
mas01mc@740 9 # Malcolm Slaney, Yahoo! Research, 2009
mas01mc@740 10
mas01mc@740 11 # Class that can be used to return random prime numbers. We need this
mas01mc@740 12 # for our hash functions.
# Class that can be used to return random prime numbers. We need this
# for our hash functions.
class primes():
    # This list is from http://primes.utm.edu/lists/2small/0bit.html
    def __init__(self):
        """Build a flat list of primes of the form 2**k - m, 8 <= k <= 31."""
        p = {}
        # These numbers, 2**i - j are primes
        p[8] = [5, 15, 17, 23, 27, 29, 33, 45, 57, 59]
        p[9] = [3, 9, 13, 21, 25, 33, 45, 49, 51, 55]
        p[10] = [3, 5, 11, 15, 27, 33, 41, 47, 53, 57]
        p[11] = [9, 19, 21, 31, 37, 45, 49, 51, 55, 61]
        p[12] = [3, 5, 17, 23, 39, 45, 47, 69, 75, 77]
        p[13] = [1, 13, 21, 25, 31, 45, 69, 75, 81, 91]
        p[14] = [3, 15, 21, 23, 35, 45, 51, 65, 83, 111]
        p[15] = [19, 49, 51, 55, 61, 75, 81, 115, 121, 135]
        p[16] = [15, 17, 39, 57, 87, 89, 99, 113, 117, 123]
        p[17] = [1, 9, 13, 31, 49, 61, 63, 85, 91, 99]
        p[18] = [5, 11, 17, 23, 33, 35, 41, 65, 75, 93]
        p[19] = [1, 19, 27, 31, 45, 57, 67, 69, 85, 87]
        p[20] = [3, 5, 17, 27, 59, 69, 129, 143, 153, 185]
        p[21] = [9, 19, 21, 55, 61, 69, 105, 111, 121, 129]
        p[22] = [3, 17, 27, 33, 57, 87, 105, 113, 117, 123]
        p[23] = [15, 21, 27, 37, 61, 69, 135, 147, 157, 159]
        p[24] = [3, 17, 33, 63, 75, 77, 89, 95, 117, 167]
        p[25] = [39, 49, 61, 85, 91, 115, 141, 159, 165, 183]
        p[26] = [5, 27, 45, 87, 101, 107, 111, 117, 125, 135]
        p[27] = [39, 79, 111, 115, 135, 187, 199, 219, 231, 235]
        p[28] = [57, 89, 95, 119, 125, 143, 165, 183, 213, 273]
        p[29] = [3, 33, 43, 63, 73, 75, 93, 99, 121, 133]
        p[30] = [35, 41, 83, 101, 105, 107, 135, 153, 161, 173]
        p[31] = [1, 19, 61, 69, 85, 99, 105, 151, 159, 171]

        # Expand the {exponent: offsets} table into the actual primes.
        primes = []
        for k in p:
            for m in p[k]:
                primes.append(2**k - m)
        self.primes = primes
        self.len = len(primes)       # kept for backward compatibility

    def sample(self):
        """Return one of the precomputed primes, uniformly at random."""
        # random.choice is clearer (and cannot suffer an off-by-one) compared
        # to the original int(random.random()*self.len) indexing.
        return random.choice(self.primes)
mas01mc@740 52
mas01mc@740 53
mas01mc@740 54 # This class just implements k-projections into k integers
mas01mc@740 55 # (after quantization) and then reducing that integer vector
mas01mc@740 56 # into a T1 and T2 hash. Data can either be entered into a
mas01mc@740 57 # table, or retrieved.
mas01mc@740 58 class lsh:
mas01mc@740 59 def __init__(self, k, w, N, t2size = (2**16 - 15)):
mas01mc@740 60 self.k = k # Number of projections
mas01mc@740 61 self.w = w # Bin width
mas01mc@740 62 self.N = N # Number of buckets
mas01mc@740 63 self.t2size = t2size
mas01mc@740 64 self.projections = None
mas01mc@740 65 self.buckets = {}
mas01mc@740 66
mas01mc@740 67 # Create the random constants needed for the projections.
mas01mc@740 68 def CreateProjections(self, dim):
mas01mc@740 69 self.dim = dim
mas01mc@740 70 p = primes()
mas01mc@740 71 self.projections = numpy.random.randn(self.k, self.dim)
mas01mc@740 72 self.bias = numpy.random.rand(self.k)
mas01mc@740 73 # Use the bias array just for its size (ignore contents)
mas01mc@740 74 self.t1hash = map(lambda x: p.sample(), self.bias)
mas01mc@740 75 self.t2hash = map(lambda x: p.sample(), self.bias)
mas01mc@740 76 if 0:
mas01mc@740 77 print "Dim is", self.dim
mas01mc@740 78 print 'Projections:\n', self.projections
mas01mc@740 79 print 'T1 hash:\n', self.t1hash
mas01mc@740 80 print 'T2 hash:\n', self.t2hash
mas01mc@740 81
mas01mc@740 82 # Actually create the t1 and t2 hashes for some data.
mas01mc@740 83 def CalculateHashes(self, data):
mas01mc@740 84 if self.projections == None:
mas01mc@740 85 self.CreateProjections(len(data))
mas01mc@740 86 bins = numpy.zeros(self.k, 'i')
mas01mc@740 87 for i in range(0,self.k):
mas01mc@740 88 bins[i] = numpy.sum(data * self.projections[i,:])/self.w \
mas01mc@740 89 + self.bias[i]
mas01mc@740 90 t1 = numpy.sum(bins * self.t1hash) % self.N
mas01mc@740 91 t2 = numpy.sum(bins * self.t2hash) % self.t2size
mas01mc@740 92 return t1, t2
mas01mc@740 93
mas01mc@740 94 def CalculateHashes2(self, data):
mas01mc@740 95 if self.projections == None:
mas01mc@740 96 self.CreateProjections(len(data))
mas01mc@740 97 bins = numpy.zeros(self.k, 'i')
mas01mc@740 98 for i in range(0,self.k):
mas01mc@740 99 bins[i] = numpy.sum(data * self.projections[i,:])/self.w \
mas01mc@740 100 + self.bias[i]
mas01mc@740 101 t1 = numpy.sum(bins * self.t1hash) % self.N
mas01mc@740 102 t2 = numpy.sum(bins * self.t2hash) % self.t2size
mas01mc@740 103 return t1, t2, bins
mas01mc@740 104
mas01mc@740 105 # Put some data into the hash bucket for this LSH projection
mas01mc@740 106 def InsertIntoTable(self, id, data):
mas01mc@740 107 (t1, t2) = self.CalculateHashes(data)
mas01mc@740 108 if t1 not in self.buckets:
mas01mc@740 109 self.buckets[t1] = {t2: [id]}
mas01mc@740 110 else:
mas01mc@740 111 if t2 not in self.buckets[t1]:
mas01mc@740 112 self.buckets[t1][t2] = [id]
mas01mc@740 113 else:
mas01mc@740 114 self.buckets[t1][t2].append(id)
mas01mc@740 115
mas01mc@740 116 # Find some data in the hash bucket. Return all the ids
mas01mc@740 117 # that we find for this T1-T2 pair.
mas01mc@740 118 def Find(self, data):
mas01mc@740 119 (t1, t2) = self.CalculateHashes(data)
mas01mc@740 120 if t1 not in self.buckets:
mas01mc@740 121 return []
mas01mc@740 122 row = self.buckets[t1]
mas01mc@740 123 if t2 not in row:
mas01mc@740 124 return []
mas01mc@740 125 return row[t2]
mas01mc@740 126
mas01mc@740 127 # Create a dictionary showing all the buckets an ID appears in
mas01mc@740 128 def CreateDictionary(self, theDictionary, prefix):
mas01mc@740 129 for b in self.buckets: # Over all buckets
mas01mc@740 130 w = prefix + str(b)
mas01mc@740 131 for c in self.buckets[b]:# Over all T2 hashes
mas01mc@740 132 for i in self.buckets[b][c]:#Over ids
mas01mc@740 133 if not i in theDictionary:
mas01mc@740 134 theDictionary[i] = [w]
mas01mc@740 135 else:
mas01mc@740 136 theDictionary[i] += w
mas01mc@740 137 return theDictionary
mas01mc@740 138
mas01mc@740 139
mas01mc@740 140
mas01mc@740 141 # Print some stats for these lsh buckets
mas01mc@740 142 def Stats(self):
mas01mc@740 143 maxCount = 0; sumCount = 0;
mas01mc@740 144 numCount = 0; bucketLens = [];
mas01mc@740 145 for b in self.buckets:
mas01mc@740 146 for c in self.buckets[b]:
mas01mc@740 147 l = len(self.buckets[b][c])
mas01mc@740 148 if l > maxCount:
mas01mc@740 149 maxCount = l
mas01mc@740 150 maxLoc = (b,c)
mas01mc@740 151 # print b,c,self.buckets[b][c]
mas01mc@740 152 sumCount += l
mas01mc@740 153 numCount += 1
mas01mc@740 154 bucketLens.append(l)
mas01mc@740 155 theValues = sorted(bucketLens)
mas01mc@740 156 med = theValues[(len(theValues)+1)/2-1]
mas01mc@740 157 print "Bucket Counts:"
mas01mc@740 158 print "\tTotal indexed points:", sumCount
mas01mc@740 159 print "\tT1 Buckets filled: %d/%d" % (len(self.buckets), self.N)
mas01mc@740 160 print "\tT2 Buckets used: %d/%d" % (numCount, self.N)
mas01mc@740 161 print "\tMaximum T2 chain length:", maxCount, "at", maxLoc
mas01mc@740 162 print "\tAverage T2 chain length:", float(sumCount)/numCount
mas01mc@740 163 print "\tMedian T2 chain length:", med
mas01mc@740 164
mas01mc@740 165 # Get a list of all IDs that are contained in these hash buckets
mas01mc@740 166 def GetAllIndices(self):
mas01mc@740 167 theList = []
mas01mc@740 168 for b in self.buckets:
mas01mc@740 169 for c in self.buckets[b]:
mas01mc@740 170 theList += self.buckets[b][c]
mas01mc@740 171 return theList
mas01mc@740 172
mas01mc@740 173 # Put some data into the hash table, see how many collisions we get.
mas01mc@740 174 def Test(self, n):
mas01mc@740 175 self.buckets = {}
mas01mc@740 176 self.projections = None
mas01mc@740 177 d = numpy.array([.2,.3])
mas01mc@740 178 for i in range(0,n):
mas01mc@740 179 self.InsertIntoTable(i, d+i)
mas01mc@740 180 for i in range(0,n):
mas01mc@740 181 r = self.Find(d+i)
mas01mc@740 182 matches = sum(map(lambda x: x==i, r))
mas01mc@740 183 if matches == 0:
mas01mc@740 184 print "Couldn't find item", i
mas01mc@740 185 elif matches == 1:
mas01mc@740 186 pass
mas01mc@740 187 if len(r) > 1:
mas01mc@740 188 print "Found big bin for", i,":", r
mas01mc@740 189
mas01mc@740 190
mas01mc@740 191 # Put together several LSH projections to form an index. The only
mas01mc@740 192 # new parameter is the number of groups of projections (one LSH class
mas01mc@740 193 # object per group.)
mas01mc@740 194 class index:
mas01mc@740 195 def __init__(self, k, l, w, N):
mas01mc@740 196 self.k = k;
mas01mc@740 197 self.l = l
mas01mc@740 198 self.w = w
mas01mc@740 199 self.N = N
mas01mc@740 200 self.projections = []
mas01mc@740 201 for i in range(0,l): # Create all LSH buckets
mas01mc@740 202 self.projections.append(lsh(k, w, N))
mas01mc@740 203 # Insert some data into all LSH buckets
mas01mc@740 204 def InsertIntoTable(self, id, data):
mas01mc@740 205 for p in self.projections:
mas01mc@740 206 p.InsertIntoTable(id, data)
mas01mc@740 207 # Find some data in all the LSH buckets.
mas01mc@740 208 def Find(self, data):
mas01mc@740 209 items = []
mas01mc@740 210 for p in self.projections:
mas01mc@740 211 items += p.Find(data) # Concatenate
mas01mc@740 212 # print "Find results are:", items
mas01mc@740 213 results = {}
mas01mc@740 214 for item in items:
mas01mc@740 215 results.setdefault(item, 0)
mas01mc@740 216 results[item] += 1
mas01mc@740 217 s = sorted(results.items(), key=operator.itemgetter(1), \
mas01mc@740 218 reverse=True)
mas01mc@740 219 return s
mas01mc@740 220
mas01mc@740 221 # Return a list of results: (id, distance**2, count)
mas01mc@740 222 def FindExact(self, data, GetData):
mas01mc@740 223 s = self.Find(data)
mas01mc@740 224 # print "Intermediate results are:", s
mas01mc@740 225 d = map(lambda (id,count): (id,((GetData(id)-data)**2).sum(),count), s)
mas01mc@740 226 ds = sorted(d, key=operator.itemgetter(1))
mas01mc@740 227 return ds
mas01mc@740 228
mas01mc@740 229 # Do an exhaustive distance calculation looking for all points and their distance.
mas01mc@740 230 # Return a list of results: (id, distance**2, count)
mas01mc@740 231 def FindAll(self, query, GetData):
mas01mc@740 232 s = []
mas01mc@740 233 allIDs = self.GetAllIndices()
mas01mc@740 234 for id in allIDs:
mas01mc@740 235 dist = ((GetData(id)-query)**2).sum()
mas01mc@740 236 s.append((id, dist, 0))
mas01mc@740 237 # print "Intermediate results are:", s
mas01mc@740 238 # d = map(lambda (id,count): (id,((GetData(id)-data)**2).sum(),count), s)
mas01mc@740 239 ds = sorted(s, key=operator.itemgetter(1))
mas01mc@740 240 return ds
mas01mc@740 241
mas01mc@740 242 # Return the number of points that are closer than radius to the query
mas01mc@740 243 def CountInsideRadius(self, data, GetData, radius):
mas01mc@740 244 matches = self.FindExact(data, GetData)
mas01mc@740 245 # print "CountInsideRadius found",len(matches),"matches"
mas01mc@740 246 radius2 = radius**2
mas01mc@740 247 count = sum(map(lambda (id,distance,count): distance<radius2, matches))
mas01mc@740 248 return count
mas01mc@740 249
mas01mc@740 250 # Put some data into the hash tables.
mas01mc@740 251 def Test(self, n):
mas01mc@740 252 d = numpy.array([.2,.3])
mas01mc@740 253 for i in range(0,n):
mas01mc@740 254 self.InsertIntoTable(i, d+i)
mas01mc@740 255 for i in range(0,n):
mas01mc@740 256 r = self.Find(d+i)
mas01mc@740 257 print r
mas01mc@740 258
mas01mc@740 259 # Print the statistics of each hash table.
mas01mc@740 260 def Stats(self):
mas01mc@740 261 for i in range(0, len(self.projections)):
mas01mc@740 262 p = self.projections[i]
mas01mc@740 263 print "Buckets", i,
mas01mc@740 264 p.Stats()
mas01mc@740 265
mas01mc@740 266 # Get al the IDs that are part of this index. Just check one hash
mas01mc@740 267 def GetAllIndices(self):
mas01mc@740 268 if self.projections and len(self.projections) > 0:
mas01mc@740 269 p = self.projections[0]
mas01mc@740 270 return p.GetAllIndices()
mas01mc@740 271 return None
mas01mc@740 272
mas01mc@740 273 # Return the buckets (t1 and t2 hashes) associated with a data point
mas01mc@740 274 def GetBuckets(data):
mas01mc@740 275 b = []
mas01mc@740 276 for p in self.projections:
mas01mc@740 277 h = p.CalculateHashes(data)
mas01mc@740 278 b += h
mas01mc@740 279
mas01mc@740 280 # Create a list ordered by ID listing which buckets are used for each ID
mas01mc@740 281 def CreateDictionary():
mas01mc@740 282 theDictionary = {}
mas01mc@740 283 prefixes = 'abcdefghijklmnopqrstuvwxyz'
mas01mc@740 284 pi = 0
mas01mc@740 285 for p in self.projections:
mas01mc@740 286 prefix = 'W'
mas01mc@740 287 pc = pi
mas01mc@740 288 while pc > 0: # Create unique ID for theis bucket
mas01mc@740 289 prefix += prefixes[pc%len(prefixes)]
mas01mc@740 290 pc /= len(prefixes)
mas01mc@740 291 theDictionary = p.CreateDictionary(theDictionary,\
mas01mc@740 292 prefix)
mas01mc@740 293 pi += 1
mas01mc@740 294 return theDictionary
mas01mc@740 295
mas01mc@740 296 # Use the expression in "Analysis of Minimum Distances in High-Dimensional
mas01mc@740 297 # Musical Spaces" to calculate the underlying dimensionality of the data
mas01mc@740 298 # For a random selection of ids, find the nearest neighbors and use this
mas01mc@740 299 # to calculate the dimensionality of the data.
mas01mc@740 300 def MeasureDimensionality(self,allData,N):
mas01mc@740 301 allIDs = self.GetAllIndices()
mas01mc@740 302 sampleIDs = random.sample(allIDs, N)
mas01mc@740 303 L = 0.0; S=0.0
mas01mc@740 304 for id in sampleIDs:
mas01mc@740 305 res = self.FindExact(allData[id,:], lambda i:allData[i, :])
mas01mc@740 306 if len(res) > 1:
mas01mc@740 307 (nnid, dist, count) = res[1]
mas01mc@740 308 S += dist
mas01mc@740 309 L += math.log(dist)
mas01mc@740 310 else:
mas01mc@740 311 N -= 1
mas01mc@740 312 print "S="+str(S), "L="+str(L), "N="+str(N)
mas01mc@740 313 if N > 1:
mas01mc@740 314 x = math.log(S/N) - L/N # Equation 17
mas01mc@740 315 d = 2*InvertFunction(x, lambda y:math.log(y)-digamma(y))
mas01mc@740 316 print d
mas01mc@740 317 return d
mas01mc@740 318 else:
mas01mc@740 319 return 0
mas01mc@740 320
mas01mc@740 321 # Only works for monotonic functions... Uses geometric midpoint to reduce
mas01mc@740 322 # the search range, looking for the function output that equals the given value
mas01mc@740 323 # Test with:
mas01mc@740 324 # lsh.InvertFunction(2,math.sqrt)
mas01mc@740 325 # lsh.InvertFunction(2,lambda x:1.0/x)
mas01mc@740 326 # Needed for inverting the gamma function in the MeasureDimensionality method.
# Only works for monotonic functions... Uses geometric midpoint to reduce
# the search range, looking for the function output that equals the given value
# Test with:
#   lsh.InvertFunction(2, math.sqrt)
#   lsh.InvertFunction(2, lambda x: 1.0/x)
# Needed for inverting the gamma function in the MeasureDimensionality method.
def InvertFunction(x, func):
    """Return y in [1e-4, 1000] with func(y) ~= x, for monotonic func."""
    # Renamed from min/max: those names shadow the builtins.
    lo = 0.0001; hi = 1000
    # The sign of the slope decides which half-interval to keep.
    if func(lo) < func(hi):
        sign = 1
    else:
        sign = -1
    # Same trace output as the original print statement, via write() so the
    # module also parses under Python 3.
    sys.stdout.write("Looking for Y() = %s d'= %s\n" % (x, sign))
    mid = lo        # was unbound if the loop body never ran
    while lo + 1e-7 < hi:
        mid = math.sqrt(lo * hi)    # geometric midpoint of the bracket
        Y = func(mid)
        if sign * Y > sign * x:
            hi = mid
        else:
            lo = mid
    return mid
mas01mc@740 343
mas01mc@740 344 ##### A bunch of routines used to generate data we can use to test
mas01mc@740 345 # this LSH implementation.
mas01mc@740 346
# Module-level array of test points, filled in by the Create*LSHTestData
# helpers below and read back by FindLSHTestData.
global gLSHTestData     # NOTE(review): "global" at module scope is a no-op
gLSHTestData = []
mas01mc@740 349
mas01mc@740 350 # Find a point in the array of data. (Needed so FindExact can get the
mas01mc@740 351 # data it needs.)
# Find a point in the array of data. (Needed so FindExact can get the
# data it needs.)
def FindLSHTestData(id):
    """Return row *id* of the global test-data array, or None when out of range."""
    global gLSHTestData
    inRange = id < gLSHTestData.shape[0]
    return gLSHTestData[id, :] if inRange else None
mas01mc@740 357
mas01mc@740 358 # Fill the test array with uniform random data between 0 and 1
# Fill the test array with uniform random data between -1 and 1.
# (The original comment said "0 and 1", but the code scales to [-1, 1).)
def CreateRandomLSHTestData(numPoints, dim):
    """Fill gLSHTestData with a numPoints x dim array of samples in [-1, 1).

    Also returns the array, so callers and tests need not reach for the
    global. The redundant "gLSHTestData = []" reset was dropped: the
    assignment on the next line replaced it unconditionally.
    """
    global gLSHTestData
    gLSHTestData = (numpy.random.rand(numPoints, dim) - .5) * 2.0
    return gLSHTestData
mas01mc@740 363
mas01mc@740 364 # Fill the test array with a regular grid of points between -1 and 1
# Fill the test array with a regular grid of points between -1 and 1
def CreateRegularLSHTestData(numDivs):
    """Fill gLSHTestData with a (2*numDivs+1)**2 x 2 grid spanning [-1, 1]^2.

    Bug fixes versus the original: the ``global`` declaration was missing
    (so only a local was assigned and the module array never changed), the
    divisor was the undefined name ``divs`` instead of numDivs, and a dead
    local numPoints was computed. The array is now also returned.
    """
    global gLSHTestData
    gLSHTestData = numpy.zeros(((2*numDivs+1)**2, 2))
    i = 0
    for x in range(-numDivs, numDivs+1):
        for y in range(-numDivs, numDivs+1):
            gLSHTestData[i, 0] = x / float(numDivs)
            gLSHTestData[i, 1] = y / float(numDivs)
            i += 1
    return gLSHTestData
mas01mc@740 374
mas01mc@740 375 # Use Nearest Neighbor properties to calculate dimensionality.
# Use Nearest Neighbor properties to calculate dimensionality.
def TestDimensionality(N):
    """Index 100k random 3-D points, then estimate dimensionality from N samples."""
    numPoints = 100000
    k = 10
    CreateRandomLSHTestData(numPoints, 3)
    ind = index(k, 2, .1, 100)
    for pointID in range(numPoints):
        ind.InsertIntoTable(pointID, FindLSHTestData(pointID))
    ind.MeasureDimensionality(gLSHTestData, N)
mas01mc@740 384
mas01mc@740 385 # Use Dimension Doubling to measure the dimensionality of a random
mas01mc@740 386 # set of data. Generate some data (either random Gaussian or a grid)
mas01mc@740 387 # Then count the number of points that fall within the given radius of this query.
mas01mc@740 388 def TestDimensionality2():
mas01mc@740 389 global gLSHTestData
mas01mc@740 390 binWidth = .5
mas01mc@740 391 if True:
mas01mc@740 392 numPoints = 100000
mas01mc@740 393 CreateRandomLSHTestData(numPoints, 3)
mas01mc@740 394 else:
mas01mc@740 395 CreateRegularLSHTestData(100)
mas01mc@740 396 numPoints = gLSHTestData.shape[0]
mas01mc@740 397 k = 4; l = 2; N = 1000
mas01mc@740 398 ind = index(k, l, binWidth, N)
mas01mc@740 399 for i in range(0,numPoints):
mas01mc@740 400 ind.InsertIntoTable(i, gLSHTestData[i,:])
mas01mc@740 401 rBig = binWidth/8.0
mas01mc@740 402 rSmall = rBig/2.0
mas01mc@740 403 cBig = 0.0; cSmall = 0.0
mas01mc@740 404 for id in random.sample(ind.GetAllIndices(), 2):
mas01mc@740 405 qp = FindLSHTestData(id)
mas01mc@740 406 cBig += ind.CountInsideRadius(qp, FindLSHTestData, rBig)
mas01mc@740 407 cSmall += ind.CountInsideRadius(qp, FindLSHTestData, rSmall)
mas01mc@740 408 if cBig > cSmall and cSmall > 0:
mas01mc@740 409 dim = math.log(cBig/cSmall)/math.log(rBig/rSmall)
mas01mc@740 410 else:
mas01mc@740 411 dim = 0
mas01mc@740 412 print cBig, cSmall, dim
mas01mc@740 413 return ind
mas01mc@740 414
mas01mc@740 415 # Call an external process to compute the digamma function (from the GNU Scientific Library)
# Call an external process to compute the digamma function (from the GNU Scientific Library)
import subprocess
def digamma(x):
    # NOTE(review): relies on a "./digamma" helper binary in the current
    # working directory -- confirm it is built/shipped with the project.
    # communicate() reads the child's whole stdout and waits for exit.
    y = subprocess.Popen( ["./digamma", str(x)], stdout=subprocess.PIPE).communicate()[0]
    return float(y.strip())
mas01mc@740 420
mas01mc@740 421
mas01mc@740 422 # Generate some 2-dimensional data, put it into an index and then
mas01mc@740 423 # show the points retrieved. This is all done as a function of number
mas01mc@740 424 # of projections per bucket, number of buckets to use for each index, and
mas01mc@740 425 # the number of LSH bucket (the T1 size). Write out the data so we can
mas01mc@740 426 # plot it (in Matlab)
# Generate some 2-dimensional data, put it into an index and then
# show the points retrieved. This is all done as a function of number
# of projections per bucket, number of buckets to use for each index, and
# the number of LSH bucket (the T1 size). Write out the data so we can
# plot it (in Matlab)
def GraphicalTest(k, l, N):
    global gLSHTestData
    numPoints = 1000
    CreateRandomLSHTestData(numPoints, 3)
    ind = index(k, l, .1, N)
    for i in range(0, numPoints):
        ind.InsertIntoTable(i, FindLSHTestData(i))
    data = gLSHTestData
    i = 42
    r = ind.Find(data[i, :])
    # BUG FIX: ind.Find returns a list of (id, count) tuples, so the old
    # "if i in r" could never be true (an int never equals a tuple) and
    # "r[i]" indexed by list position rather than by id. A dict gives the
    # vote count per id directly.
    counts = dict(r)
    fp = open('lshtestpoints.txt', 'w')
    for i in range(0, numPoints):
        c = counts.get(i, 0)
        fp.write("%g %g %d\n" % (data[i, 0], data[i, 1], c))
    fp.close()
    return r
mas01mc@740 446
mas01mc@740 447
mas01mc@740 448 # Run one LSH test. Look for point 42 in the data.
# Run one LSH test. Look for point 42 in the data.
def ExactTest():
    """Build a small index over random 2-D data and query it for point 42."""
    global gLSHTestData
    numPoints = 1000
    CreateRandomLSHTestData(numPoints, 2)
    ind = index(10, 2, .1, 100)
    for pointID in range(numPoints):
        ind.InsertIntoTable(pointID, FindLSHTestData(pointID))
    query = FindLSHTestData(42)
    return ind.FindExact(query, FindLSHTestData)
mas01mc@740 459
mas01mc@740 460 # Create a file with distances retrieved as a function of k.
mas01mc@740 461 # First line is the exact result, showing all points in the dB.
mas01mc@740 462 # Successive lines are results for an LSH index.
mas01mc@740 463 def TestRetrieval():
mas01mc@740 464 dims = 3
mas01mc@740 465 numPoints = 100000
mas01mc@740 466 CreateRandomLSHTestData(numPoints, 3)
mas01mc@740 467 qp = FindLSHTestData(0)*0.0
mas01mc@740 468 fp = open('TestRetrieval.txt','w')
mas01mc@740 469 for l in range(1,5):
mas01mc@740 470 for k in range(1,6):
mas01mc@740 471 for iter in range(1,10):
mas01mc@740 472 print "Building an index with l="+str(l)+", k="+str(k)
mas01mc@740 473 ind = index(k, l, .1, 100) # Build new index
mas01mc@740 474 for i in range(0,numPoints):
mas01mc@740 475 ind.InsertIntoTable(i, FindLSHTestData(i))
mas01mc@740 476 if k == 1 and l == 1:
mas01mc@740 477 matches = ind.FindAll(qp, FindLSHTestData)
mas01mc@740 478 fp.write(' '.join(map(lambda (i,d,c): str(d), matches)))
mas01mc@740 479 fp.write('\n')
mas01mc@740 480 matches = ind.FindExact(qp, FindLSHTestData)
mas01mc@740 481 fp.write(' '.join(map(lambda (i,d,c): str(d), matches)))
mas01mc@740 482 # Fill rest of the results with -1
mas01mc@740 483 fp.write(' '.join(map(str, (-numpy.ones((1,numPoints-len(matches)+1))).tolist())))
mas01mc@740 484 fp.write('\n')
mas01mc@740 485 fp.close()
mas01mc@740 486
mas01mc@740 487
mas01mc@740 488 # Save an LSH index to a pickle file.
mas01mc@740 489 def SaveIndex(filename, ind):
mas01mc@740 490 try:
mas01mc@740 491 fp = open(filename, 'w')
mas01mc@740 492 pickle.dump(ind, fp)
mas01mc@740 493 fp.close()
mas01mc@740 494 statinfo = os.stat(filename,)
mas01mc@740 495 if statinfo:
mas01mc@740 496 print "Wrote out", statinfo.st_size, "bytes to", \
mas01mc@740 497 filename
mas01mc@740 498 except:
mas01mc@740 499 print "Couldn't pickle index to file", filename
mas01mc@740 500 traceback.print_exc(file=sys.stderr)
mas01mc@740 501
mas01mc@740 502 # Read an LSH index from a pickle file.
mas01mc@740 503 def LoadIndex(filename):
mas01mc@740 504 try:
mas01mc@740 505 fp = open(filename, 'r')
mas01mc@740 506 ind = pickle.load(fp)
mas01mc@740 507 fp.close()
mas01mc@740 508 return ind
mas01mc@740 509 except:
mas01mc@740 510 print "Couldn't read pickle file", filename
mas01mc@740 511 traceback.print_exc(file=sys.stderr)
mas01mc@740 512
mas01mc@740 513
mas01mc@740 514