mas01mc@740
|
1 import random, numpy, pickle, os, operator, traceback, sys, math
|
mas01mc@740
|
2
|
mas01mc@740
|
3
|
mas01mc@740
|
4 # Python implementation of Andoni's e2LSH. This version is fast because it
|
mas01mc@740
|
5 # uses Python hashes to implement the buckets. The numerics are handled
|
mas01mc@740
|
6 # by the numpy routine so this should be close to optimal in speed (although
|
mas01mc@740
|
7 # there is no control of the hash tables layout in memory.)
|
mas01mc@740
|
8
|
mas01mc@740
|
9 # Malcolm Slaney, Yahoo! Research, 2009
|
mas01mc@740
|
10
|
mas01mc@740
|
11 # Class that can be used to return random prime numbers. We need this
|
mas01mc@740
|
12 # for our hash functions.
|
mas01mc@740
|
class primes():
    # Class that can be used to return random prime numbers, needed for
    # the T1/T2 hash functions.
    # This list is from http://primes.utm.edu/lists/2small/0bit.html
    def __init__(self):
        # For each key i, every listed j makes 2**i - j a prime.
        p = {}
        p[8] = [5, 15, 17, 23, 27, 29, 33, 45, 57, 59]
        p[9] = [3, 9, 13, 21, 25, 33, 45, 49, 51, 55]
        p[10] = [3, 5, 11, 15, 27, 33, 41, 47, 53, 57]
        p[11] = [9, 19, 21, 31, 37, 45, 49, 51, 55, 61]
        p[12] = [3, 5, 17, 23, 39, 45, 47, 69, 75, 77]
        p[13] = [1, 13, 21, 25, 31, 45, 69, 75, 81, 91]
        p[14] = [3, 15, 21, 23, 35, 45, 51, 65, 83, 111]
        p[15] = [19, 49, 51, 55, 61, 75, 81, 115, 121, 135]
        p[16] = [15, 17, 39, 57, 87, 89, 99, 113, 117, 123]
        p[17] = [1, 9, 13, 31, 49, 61, 63, 85, 91, 99]
        p[18] = [5, 11, 17, 23, 33, 35, 41, 65, 75, 93]
        p[19] = [1, 19, 27, 31, 45, 57, 67, 69, 85, 87]
        p[20] = [3, 5, 17, 27, 59, 69, 129, 143, 153, 185]
        p[21] = [9, 19, 21, 55, 61, 69, 105, 111, 121, 129]
        p[22] = [3, 17, 27, 33, 57, 87, 105, 113, 117, 123]
        p[23] = [15, 21, 27, 37, 61, 69, 135, 147, 157, 159]
        p[24] = [3, 17, 33, 63, 75, 77, 89, 95, 117, 167]
        p[25] = [39, 49, 61, 85, 91, 115, 141, 159, 165, 183]
        p[26] = [5, 27, 45, 87, 101, 107, 111, 117, 125, 135]
        p[27] = [39, 79, 111, 115, 135, 187, 199, 219, 231, 235]
        p[28] = [57, 89, 95, 119, 125, 143, 165, 183, 213, 273]
        p[29] = [3, 33, 43, 63, 73, 75, 93, 99, 121, 133]
        p[30] = [35, 41, 83, 101, 105, 107, 135, 153, 161, 173]
        p[31] = [1, 19, 61, 69, 85, 99, 105, 151, 159, 171]

        # Expand the table into a flat list of the primes 2**k - m.
        self.primes = [2**k - m for k in p for m in p[k]]
        self.len = len(self.primes)

    # Return one prime drawn uniformly at random from the table.
    def sample(self):
        return random.choice(self.primes)
|
mas01mc@740
|
52
|
mas01mc@740
|
53
|
mas01mc@740
|
54 # This class just implements k-projections into k integers
|
mas01mc@740
|
55 # (after quantization) and then reducing that integer vector
|
mas01mc@740
|
56 # into a T1 and T2 hash. Data can either be entered into a
|
mas01mc@740
|
57 # table, or retrieved.
|
mas01mc@740
|
58 class lsh:
|
mas01mc@740
|
59 def __init__(self, k, w, N, t2size = (2**16 - 15)):
|
mas01mc@740
|
60 self.k = k # Number of projections
|
mas01mc@740
|
61 self.w = w # Bin width
|
mas01mc@740
|
62 self.N = N # Number of buckets
|
mas01mc@740
|
63 self.t2size = t2size
|
mas01mc@740
|
64 self.projections = None
|
mas01mc@740
|
65 self.buckets = {}
|
mas01mc@740
|
66
|
mas01mc@740
|
67 # Create the random constants needed for the projections.
|
mas01mc@740
|
68 def CreateProjections(self, dim):
|
mas01mc@740
|
69 self.dim = dim
|
mas01mc@740
|
70 p = primes()
|
mas01mc@740
|
71 self.projections = numpy.random.randn(self.k, self.dim)
|
mas01mc@740
|
72 self.bias = numpy.random.rand(self.k)
|
mas01mc@740
|
73 # Use the bias array just for its size (ignore contents)
|
mas01mc@740
|
74 self.t1hash = map(lambda x: p.sample(), self.bias)
|
mas01mc@740
|
75 self.t2hash = map(lambda x: p.sample(), self.bias)
|
mas01mc@740
|
76 if 0:
|
mas01mc@740
|
77 print "Dim is", self.dim
|
mas01mc@740
|
78 print 'Projections:\n', self.projections
|
mas01mc@740
|
79 print 'T1 hash:\n', self.t1hash
|
mas01mc@740
|
80 print 'T2 hash:\n', self.t2hash
|
mas01mc@740
|
81
|
mas01mc@740
|
82 # Actually create the t1 and t2 hashes for some data.
|
mas01mc@740
|
83 def CalculateHashes(self, data):
|
mas01mc@740
|
84 if self.projections == None:
|
mas01mc@740
|
85 self.CreateProjections(len(data))
|
mas01mc@740
|
86 bins = numpy.zeros(self.k, 'i')
|
mas01mc@740
|
87 for i in range(0,self.k):
|
mas01mc@740
|
88 bins[i] = numpy.sum(data * self.projections[i,:])/self.w \
|
mas01mc@740
|
89 + self.bias[i]
|
mas01mc@740
|
90 t1 = numpy.sum(bins * self.t1hash) % self.N
|
mas01mc@740
|
91 t2 = numpy.sum(bins * self.t2hash) % self.t2size
|
mas01mc@740
|
92 return t1, t2
|
mas01mc@740
|
93
|
mas01mc@740
|
94 def CalculateHashes2(self, data):
|
mas01mc@740
|
95 if self.projections == None:
|
mas01mc@740
|
96 self.CreateProjections(len(data))
|
mas01mc@740
|
97 bins = numpy.zeros(self.k, 'i')
|
mas01mc@740
|
98 for i in range(0,self.k):
|
mas01mc@740
|
99 bins[i] = numpy.sum(data * self.projections[i,:])/self.w \
|
mas01mc@740
|
100 + self.bias[i]
|
mas01mc@740
|
101 t1 = numpy.sum(bins * self.t1hash) % self.N
|
mas01mc@740
|
102 t2 = numpy.sum(bins * self.t2hash) % self.t2size
|
mas01mc@740
|
103 return t1, t2, bins
|
mas01mc@740
|
104
|
mas01mc@740
|
105 # Put some data into the hash bucket for this LSH projection
|
mas01mc@740
|
106 def InsertIntoTable(self, id, data):
|
mas01mc@740
|
107 (t1, t2) = self.CalculateHashes(data)
|
mas01mc@740
|
108 if t1 not in self.buckets:
|
mas01mc@740
|
109 self.buckets[t1] = {t2: [id]}
|
mas01mc@740
|
110 else:
|
mas01mc@740
|
111 if t2 not in self.buckets[t1]:
|
mas01mc@740
|
112 self.buckets[t1][t2] = [id]
|
mas01mc@740
|
113 else:
|
mas01mc@740
|
114 self.buckets[t1][t2].append(id)
|
mas01mc@740
|
115
|
mas01mc@740
|
116 # Find some data in the hash bucket. Return all the ids
|
mas01mc@740
|
117 # that we find for this T1-T2 pair.
|
mas01mc@740
|
118 def Find(self, data):
|
mas01mc@740
|
119 (t1, t2) = self.CalculateHashes(data)
|
mas01mc@740
|
120 if t1 not in self.buckets:
|
mas01mc@740
|
121 return []
|
mas01mc@740
|
122 row = self.buckets[t1]
|
mas01mc@740
|
123 if t2 not in row:
|
mas01mc@740
|
124 return []
|
mas01mc@740
|
125 return row[t2]
|
mas01mc@740
|
126
|
mas01mc@740
|
127 # Create a dictionary showing all the buckets an ID appears in
|
mas01mc@740
|
128 def CreateDictionary(self, theDictionary, prefix):
|
mas01mc@740
|
129 for b in self.buckets: # Over all buckets
|
mas01mc@740
|
130 w = prefix + str(b)
|
mas01mc@740
|
131 for c in self.buckets[b]:# Over all T2 hashes
|
mas01mc@740
|
132 for i in self.buckets[b][c]:#Over ids
|
mas01mc@740
|
133 if not i in theDictionary:
|
mas01mc@740
|
134 theDictionary[i] = [w]
|
mas01mc@740
|
135 else:
|
mas01mc@740
|
136 theDictionary[i] += w
|
mas01mc@740
|
137 return theDictionary
|
mas01mc@740
|
138
|
mas01mc@740
|
139
|
mas01mc@740
|
140
|
mas01mc@740
|
141 # Print some stats for these lsh buckets
|
mas01mc@740
|
142 def Stats(self):
|
mas01mc@740
|
143 maxCount = 0; sumCount = 0;
|
mas01mc@740
|
144 numCount = 0; bucketLens = [];
|
mas01mc@740
|
145 for b in self.buckets:
|
mas01mc@740
|
146 for c in self.buckets[b]:
|
mas01mc@740
|
147 l = len(self.buckets[b][c])
|
mas01mc@740
|
148 if l > maxCount:
|
mas01mc@740
|
149 maxCount = l
|
mas01mc@740
|
150 maxLoc = (b,c)
|
mas01mc@740
|
151 # print b,c,self.buckets[b][c]
|
mas01mc@740
|
152 sumCount += l
|
mas01mc@740
|
153 numCount += 1
|
mas01mc@740
|
154 bucketLens.append(l)
|
mas01mc@740
|
155 theValues = sorted(bucketLens)
|
mas01mc@740
|
156 med = theValues[(len(theValues)+1)/2-1]
|
mas01mc@740
|
157 print "Bucket Counts:"
|
mas01mc@740
|
158 print "\tTotal indexed points:", sumCount
|
mas01mc@740
|
159 print "\tT1 Buckets filled: %d/%d" % (len(self.buckets), self.N)
|
mas01mc@740
|
160 print "\tT2 Buckets used: %d/%d" % (numCount, self.N)
|
mas01mc@740
|
161 print "\tMaximum T2 chain length:", maxCount, "at", maxLoc
|
mas01mc@740
|
162 print "\tAverage T2 chain length:", float(sumCount)/numCount
|
mas01mc@740
|
163 print "\tMedian T2 chain length:", med
|
mas01mc@740
|
164
|
mas01mc@740
|
165 # Get a list of all IDs that are contained in these hash buckets
|
mas01mc@740
|
166 def GetAllIndices(self):
|
mas01mc@740
|
167 theList = []
|
mas01mc@740
|
168 for b in self.buckets:
|
mas01mc@740
|
169 for c in self.buckets[b]:
|
mas01mc@740
|
170 theList += self.buckets[b][c]
|
mas01mc@740
|
171 return theList
|
mas01mc@740
|
172
|
mas01mc@740
|
173 # Put some data into the hash table, see how many collisions we get.
|
mas01mc@740
|
174 def Test(self, n):
|
mas01mc@740
|
175 self.buckets = {}
|
mas01mc@740
|
176 self.projections = None
|
mas01mc@740
|
177 d = numpy.array([.2,.3])
|
mas01mc@740
|
178 for i in range(0,n):
|
mas01mc@740
|
179 self.InsertIntoTable(i, d+i)
|
mas01mc@740
|
180 for i in range(0,n):
|
mas01mc@740
|
181 r = self.Find(d+i)
|
mas01mc@740
|
182 matches = sum(map(lambda x: x==i, r))
|
mas01mc@740
|
183 if matches == 0:
|
mas01mc@740
|
184 print "Couldn't find item", i
|
mas01mc@740
|
185 elif matches == 1:
|
mas01mc@740
|
186 pass
|
mas01mc@740
|
187 if len(r) > 1:
|
mas01mc@740
|
188 print "Found big bin for", i,":", r
|
mas01mc@740
|
189
|
mas01mc@740
|
190
|
mas01mc@740
|
191 # Put together several LSH projections to form an index. The only
|
mas01mc@740
|
192 # new parameter is the number of groups of projections (one LSH class
|
mas01mc@740
|
193 # object per group.)
|
mas01mc@740
|
194 class index:
|
mas01mc@740
|
195 def __init__(self, k, l, w, N):
|
mas01mc@740
|
196 self.k = k;
|
mas01mc@740
|
197 self.l = l
|
mas01mc@740
|
198 self.w = w
|
mas01mc@740
|
199 self.N = N
|
mas01mc@740
|
200 self.projections = []
|
mas01mc@740
|
201 for i in range(0,l): # Create all LSH buckets
|
mas01mc@740
|
202 self.projections.append(lsh(k, w, N))
|
mas01mc@740
|
203 # Insert some data into all LSH buckets
|
mas01mc@740
|
204 def InsertIntoTable(self, id, data):
|
mas01mc@740
|
205 for p in self.projections:
|
mas01mc@740
|
206 p.InsertIntoTable(id, data)
|
mas01mc@740
|
207 # Find some data in all the LSH buckets.
|
mas01mc@740
|
208 def Find(self, data):
|
mas01mc@740
|
209 items = []
|
mas01mc@740
|
210 for p in self.projections:
|
mas01mc@740
|
211 items += p.Find(data) # Concatenate
|
mas01mc@740
|
212 # print "Find results are:", items
|
mas01mc@740
|
213 results = {}
|
mas01mc@740
|
214 for item in items:
|
mas01mc@740
|
215 results.setdefault(item, 0)
|
mas01mc@740
|
216 results[item] += 1
|
mas01mc@740
|
217 s = sorted(results.items(), key=operator.itemgetter(1), \
|
mas01mc@740
|
218 reverse=True)
|
mas01mc@740
|
219 return s
|
mas01mc@740
|
220
|
mas01mc@740
|
221 # Return a list of results: (id, distance**2, count)
|
mas01mc@740
|
222 def FindExact(self, data, GetData):
|
mas01mc@740
|
223 s = self.Find(data)
|
mas01mc@740
|
224 # print "Intermediate results are:", s
|
mas01mc@740
|
225 d = map(lambda (id,count): (id,((GetData(id)-data)**2).sum(),count), s)
|
mas01mc@740
|
226 ds = sorted(d, key=operator.itemgetter(1))
|
mas01mc@740
|
227 return ds
|
mas01mc@740
|
228
|
mas01mc@740
|
229 # Do an exhaustive distance calculation looking for all points and their distance.
|
mas01mc@740
|
230 # Return a list of results: (id, distance**2, count)
|
mas01mc@740
|
231 def FindAll(self, query, GetData):
|
mas01mc@740
|
232 s = []
|
mas01mc@740
|
233 allIDs = self.GetAllIndices()
|
mas01mc@740
|
234 for id in allIDs:
|
mas01mc@740
|
235 dist = ((GetData(id)-query)**2).sum()
|
mas01mc@740
|
236 s.append((id, dist, 0))
|
mas01mc@740
|
237 # print "Intermediate results are:", s
|
mas01mc@740
|
238 # d = map(lambda (id,count): (id,((GetData(id)-data)**2).sum(),count), s)
|
mas01mc@740
|
239 ds = sorted(s, key=operator.itemgetter(1))
|
mas01mc@740
|
240 return ds
|
mas01mc@740
|
241
|
mas01mc@740
|
242 # Return the number of points that are closer than radius to the query
|
mas01mc@740
|
243 def CountInsideRadius(self, data, GetData, radius):
|
mas01mc@740
|
244 matches = self.FindExact(data, GetData)
|
mas01mc@740
|
245 # print "CountInsideRadius found",len(matches),"matches"
|
mas01mc@740
|
246 radius2 = radius**2
|
mas01mc@740
|
247 count = sum(map(lambda (id,distance,count): distance<radius2, matches))
|
mas01mc@740
|
248 return count
|
mas01mc@740
|
249
|
mas01mc@740
|
250 # Put some data into the hash tables.
|
mas01mc@740
|
251 def Test(self, n):
|
mas01mc@740
|
252 d = numpy.array([.2,.3])
|
mas01mc@740
|
253 for i in range(0,n):
|
mas01mc@740
|
254 self.InsertIntoTable(i, d+i)
|
mas01mc@740
|
255 for i in range(0,n):
|
mas01mc@740
|
256 r = self.Find(d+i)
|
mas01mc@740
|
257 print r
|
mas01mc@740
|
258
|
mas01mc@740
|
259 # Print the statistics of each hash table.
|
mas01mc@740
|
260 def Stats(self):
|
mas01mc@740
|
261 for i in range(0, len(self.projections)):
|
mas01mc@740
|
262 p = self.projections[i]
|
mas01mc@740
|
263 print "Buckets", i,
|
mas01mc@740
|
264 p.Stats()
|
mas01mc@740
|
265
|
mas01mc@740
|
266 # Get al the IDs that are part of this index. Just check one hash
|
mas01mc@740
|
267 def GetAllIndices(self):
|
mas01mc@740
|
268 if self.projections and len(self.projections) > 0:
|
mas01mc@740
|
269 p = self.projections[0]
|
mas01mc@740
|
270 return p.GetAllIndices()
|
mas01mc@740
|
271 return None
|
mas01mc@740
|
272
|
mas01mc@740
|
273 # Return the buckets (t1 and t2 hashes) associated with a data point
|
mas01mc@740
|
274 def GetBuckets(data):
|
mas01mc@740
|
275 b = []
|
mas01mc@740
|
276 for p in self.projections:
|
mas01mc@740
|
277 h = p.CalculateHashes(data)
|
mas01mc@740
|
278 b += h
|
mas01mc@740
|
279
|
mas01mc@740
|
280 # Create a list ordered by ID listing which buckets are used for each ID
|
mas01mc@740
|
281 def CreateDictionary():
|
mas01mc@740
|
282 theDictionary = {}
|
mas01mc@740
|
283 prefixes = 'abcdefghijklmnopqrstuvwxyz'
|
mas01mc@740
|
284 pi = 0
|
mas01mc@740
|
285 for p in self.projections:
|
mas01mc@740
|
286 prefix = 'W'
|
mas01mc@740
|
287 pc = pi
|
mas01mc@740
|
288 while pc > 0: # Create unique ID for theis bucket
|
mas01mc@740
|
289 prefix += prefixes[pc%len(prefixes)]
|
mas01mc@740
|
290 pc /= len(prefixes)
|
mas01mc@740
|
291 theDictionary = p.CreateDictionary(theDictionary,\
|
mas01mc@740
|
292 prefix)
|
mas01mc@740
|
293 pi += 1
|
mas01mc@740
|
294 return theDictionary
|
mas01mc@740
|
295
|
mas01mc@740
|
296 # Use the expression in "Analysis of Minimum Distances in High-Dimensional
|
mas01mc@740
|
297 # Musical Spaces" to calculate the underlying dimensionality of the data
|
mas01mc@740
|
298 # For a random selection of ids, find the nearest neighbors and use this
|
mas01mc@740
|
299 # to calculate the dimensionality of the data.
|
mas01mc@740
|
300 def MeasureDimensionality(self,allData,N):
|
mas01mc@740
|
301 allIDs = self.GetAllIndices()
|
mas01mc@740
|
302 sampleIDs = random.sample(allIDs, N)
|
mas01mc@740
|
303 L = 0.0; S=0.0
|
mas01mc@740
|
304 for id in sampleIDs:
|
mas01mc@740
|
305 res = self.FindExact(allData[id,:], lambda i:allData[i, :])
|
mas01mc@740
|
306 if len(res) > 1:
|
mas01mc@740
|
307 (nnid, dist, count) = res[1]
|
mas01mc@740
|
308 S += dist
|
mas01mc@740
|
309 L += math.log(dist)
|
mas01mc@740
|
310 else:
|
mas01mc@740
|
311 N -= 1
|
mas01mc@740
|
312 print "S="+str(S), "L="+str(L), "N="+str(N)
|
mas01mc@740
|
313 if N > 1:
|
mas01mc@740
|
314 x = math.log(S/N) - L/N # Equation 17
|
mas01mc@740
|
315 d = 2*InvertFunction(x, lambda y:math.log(y)-digamma(y))
|
mas01mc@740
|
316 print d
|
mas01mc@740
|
317 return d
|
mas01mc@740
|
318 else:
|
mas01mc@740
|
319 return 0
|
mas01mc@740
|
320
|
mas01mc@740
|
321 # Only works for monotonic functions... Uses geometric midpoint to reduce
|
mas01mc@740
|
322 # the search range, looking for the function output that equals the given value
|
mas01mc@740
|
323 # Test with:
|
mas01mc@740
|
324 # lsh.InvertFunction(2,math.sqrt)
|
mas01mc@740
|
325 # lsh.InvertFunction(2,lambda x:1.0/x)
|
mas01mc@740
|
326 # Needed for inverting the gamma function in the MeasureDimensionality method.
|
mas01mc@740
|
327 def InvertFunction(x, func):
|
mas01mc@740
|
328 min = 0.0001; max = 1000;
|
mas01mc@740
|
329 if func(min) < func(max):
|
mas01mc@740
|
330 sign = 1
|
mas01mc@740
|
331 else:
|
mas01mc@740
|
332 sign = -1
|
mas01mc@740
|
333 print "Looking for Y() =", str(x), "d'=", sign
|
mas01mc@740
|
334 while min + 1e-7 < max:
|
mas01mc@740
|
335 mid = math.sqrt(min*max)
|
mas01mc@740
|
336 Y = func(mid)
|
mas01mc@740
|
337 # print min, mid, Y, max
|
mas01mc@740
|
338 if sign*Y > sign*x:
|
mas01mc@740
|
339 max = mid
|
mas01mc@740
|
340 else:
|
mas01mc@740
|
341 min = mid
|
mas01mc@740
|
342 return mid
|
mas01mc@740
|
343
|
mas01mc@740
|
344 ##### A bunch of routines used to generate data we can use to test
|
mas01mc@740
|
345 # this LSH implementation.
|
mas01mc@740
|
346
|
mas01mc@740
|
# Shared test-data matrix used by the test routines below.
global gLSHTestData
gLSHTestData = []
|
mas01mc@740
|
349
|
mas01mc@740
|
350 # Find a point in the array of data. (Needed so FindExact can get the
|
mas01mc@740
|
351 # data it needs.)
|
mas01mc@740
|
def FindLSHTestData(id):
    # Return row "id" of the global test-data matrix, or None when the id
    # is out of range.  (Needed so FindExact can get the data it needs.)
    global gLSHTestData
    # BUG FIX: also reject negative ids -- Python's negative indexing
    # would otherwise silently return a row from the end of the array.
    if 0 <= id < gLSHTestData.shape[0]:
        return gLSHTestData[id,:]
    return None
|
mas01mc@740
|
357
|
mas01mc@740
|
# Fill the test array with uniform random data between -1 and 1
|
mas01mc@740
|
def CreateRandomLSHTestData(numPoints, dim):
    # Fill the global test array with uniform random data between -1 and 1
    # (the old comment said 0..1, but the scaling below maps to [-1, 1]).
    global gLSHTestData
    gLSHTestData = (numpy.random.rand(numPoints, dim)-.5)*2.0
|
mas01mc@740
|
363
|
mas01mc@740
|
364 # Fill the test array with a regular grid of points between -1 and 1
|
mas01mc@740
|
def CreateRegularLSHTestData(numDivs):
    # Fill the test array with a regular grid of points between -1 and 1.
    # BUG FIX: declare the global -- the original assigned a function
    # local, so the module-level data was never updated.
    global gLSHTestData
    gLSHTestData = numpy.zeros(((2*numDivs+1)**2, 2))
    i = 0
    for x in range(-numDivs, numDivs+1):
        for y in range(-numDivs, numDivs+1):
            # BUG FIX: the original divided by the undefined name "divs".
            gLSHTestData[i,0] = x/float(numDivs)
            gLSHTestData[i,1] = y/float(numDivs)
            i += 1
|
mas01mc@740
|
374
|
mas01mc@740
|
375 # Use Nearest Neighbor properties to calculate dimensionality.
|
mas01mc@740
|
def TestDimensionality(N):
    # Build an LSH index over random 3-d data and use nearest-neighbor
    # statistics (via MeasureDimensionality) to estimate dimensionality.
    nPoints = 100000
    nProjections = 10
    CreateRandomLSHTestData(nPoints, 3)
    ind = index(nProjections, 2, .1, 100)
    for pt in range(nPoints):
        ind.InsertIntoTable(pt, FindLSHTestData(pt))
    ind.MeasureDimensionality(gLSHTestData, N)
|
mas01mc@740
|
384
|
mas01mc@740
|
385 # Use Dimension Doubling to measure the dimensionality of a random
|
mas01mc@740
|
386 # set of data. Generate some data (either random Gaussian or a grid)
|
mas01mc@740
|
387 # Then count the number of points that fall within the given radius of this query.
|
mas01mc@740
|
388 def TestDimensionality2():
|
mas01mc@740
|
389 global gLSHTestData
|
mas01mc@740
|
390 binWidth = .5
|
mas01mc@740
|
391 if True:
|
mas01mc@740
|
392 numPoints = 100000
|
mas01mc@740
|
393 CreateRandomLSHTestData(numPoints, 3)
|
mas01mc@740
|
394 else:
|
mas01mc@740
|
395 CreateRegularLSHTestData(100)
|
mas01mc@740
|
396 numPoints = gLSHTestData.shape[0]
|
mas01mc@740
|
397 k = 4; l = 2; N = 1000
|
mas01mc@740
|
398 ind = index(k, l, binWidth, N)
|
mas01mc@740
|
399 for i in range(0,numPoints):
|
mas01mc@740
|
400 ind.InsertIntoTable(i, gLSHTestData[i,:])
|
mas01mc@740
|
401 rBig = binWidth/8.0
|
mas01mc@740
|
402 rSmall = rBig/2.0
|
mas01mc@740
|
403 cBig = 0.0; cSmall = 0.0
|
mas01mc@740
|
404 for id in random.sample(ind.GetAllIndices(), 2):
|
mas01mc@740
|
405 qp = FindLSHTestData(id)
|
mas01mc@740
|
406 cBig += ind.CountInsideRadius(qp, FindLSHTestData, rBig)
|
mas01mc@740
|
407 cSmall += ind.CountInsideRadius(qp, FindLSHTestData, rSmall)
|
mas01mc@740
|
408 if cBig > cSmall and cSmall > 0:
|
mas01mc@740
|
409 dim = math.log(cBig/cSmall)/math.log(rBig/rSmall)
|
mas01mc@740
|
410 else:
|
mas01mc@740
|
411 dim = 0
|
mas01mc@740
|
412 print cBig, cSmall, dim
|
mas01mc@740
|
413 return ind
|
mas01mc@740
|
414
|
mas01mc@740
|
415 # Call an external process to compute the digamma function (from the GNU Scientific Library)
|
mas01mc@740
|
# Call an external process to compute the digamma function (from the
# GNU Scientific Library).
import subprocess
def digamma(x):
    # Run the helper binary and parse its single-line float output.
    proc = subprocess.Popen(["./digamma", str(x)], stdout=subprocess.PIPE)
    output = proc.communicate()[0]
    return float(output.strip())
|
mas01mc@740
|
420
|
mas01mc@740
|
421
|
mas01mc@740
|
422 # Generate some 2-dimensional data, put it into an index and then
|
mas01mc@740
|
423 # show the points retrieved. This is all done as a function of number
|
mas01mc@740
|
424 # of projections per bucket, number of buckets to use for each index, and
|
mas01mc@740
|
425 # the number of LSH bucket (the T1 size). Write out the data so we can
|
mas01mc@740
|
426 # plot it (in Matlab)
|
mas01mc@740
|
def GraphicalTest(k, l, N):
    # Generate some 2-d-plottable data, put it into an index, and write
    # out each point with the retrieval count for query point 42 so the
    # result can be plotted (e.g. in Matlab).  Returns the raw Find result.
    global gLSHTestData
    numPoints = 1000
    CreateRandomLSHTestData(numPoints, 3)
    ind = index(k, l, .1, N)
    for i in range(0, numPoints):
        ind.InsertIntoTable(i, FindLSHTestData(i))
    data = gLSHTestData
    i = 42
    r = ind.Find(data[i,:])
    # index.Find returns (id, count) tuples; build a dict for the lookup.
    # BUG FIX: the original tested "i in r" against the tuple list, which
    # never matched, so every point was written with count 0.
    counts = dict(r)
    fp = open('lshtestpoints.txt','w')
    for i in range(0, numPoints):
        if i in counts:
            c = counts[i]
        else:
            c = 0
        fp.write("%g %g %d\n" % (data[i,0], data[i,1], c))
    fp.close()
    return r
|
mas01mc@740
|
446
|
mas01mc@740
|
447
|
mas01mc@740
|
448 # Run one LSH test. Look for point 42 in the data.
|
mas01mc@740
|
def ExactTest():
    # Run one LSH test: build an index over 1000 random 2-d points and
    # retrieve the exact (distance-sorted) neighbors of point 42.
    global gLSHTestData
    nPoints = 1000
    CreateRandomLSHTestData(nPoints, 2)
    ind = index(10, 2, .1, 100)
    for pt in range(nPoints):
        ind.InsertIntoTable(pt, FindLSHTestData(pt))
    query = FindLSHTestData(42)
    return ind.FindExact(query, FindLSHTestData)
|
mas01mc@740
|
459
|
mas01mc@740
|
460 # Create a file with distances retrieved as a function of k.
|
mas01mc@740
|
461 # First line is the exact result, showing all points in the dB.
|
mas01mc@740
|
462 # Successive lines are results for an LSH index.
|
mas01mc@740
|
463 def TestRetrieval():
|
mas01mc@740
|
464 dims = 3
|
mas01mc@740
|
465 numPoints = 100000
|
mas01mc@740
|
466 CreateRandomLSHTestData(numPoints, 3)
|
mas01mc@740
|
467 qp = FindLSHTestData(0)*0.0
|
mas01mc@740
|
468 fp = open('TestRetrieval.txt','w')
|
mas01mc@740
|
469 for l in range(1,5):
|
mas01mc@740
|
470 for k in range(1,6):
|
mas01mc@740
|
471 for iter in range(1,10):
|
mas01mc@740
|
472 print "Building an index with l="+str(l)+", k="+str(k)
|
mas01mc@740
|
473 ind = index(k, l, .1, 100) # Build new index
|
mas01mc@740
|
474 for i in range(0,numPoints):
|
mas01mc@740
|
475 ind.InsertIntoTable(i, FindLSHTestData(i))
|
mas01mc@740
|
476 if k == 1 and l == 1:
|
mas01mc@740
|
477 matches = ind.FindAll(qp, FindLSHTestData)
|
mas01mc@740
|
478 fp.write(' '.join(map(lambda (i,d,c): str(d), matches)))
|
mas01mc@740
|
479 fp.write('\n')
|
mas01mc@740
|
480 matches = ind.FindExact(qp, FindLSHTestData)
|
mas01mc@740
|
481 fp.write(' '.join(map(lambda (i,d,c): str(d), matches)))
|
mas01mc@740
|
482 # Fill rest of the results with -1
|
mas01mc@740
|
483 fp.write(' '.join(map(str, (-numpy.ones((1,numPoints-len(matches)+1))).tolist())))
|
mas01mc@740
|
484 fp.write('\n')
|
mas01mc@740
|
485 fp.close()
|
mas01mc@740
|
486
|
mas01mc@740
|
487
|
mas01mc@740
|
488 # Save an LSH index to a pickle file.
|
mas01mc@740
|
489 def SaveIndex(filename, ind):
|
mas01mc@740
|
490 try:
|
mas01mc@740
|
491 fp = open(filename, 'w')
|
mas01mc@740
|
492 pickle.dump(ind, fp)
|
mas01mc@740
|
493 fp.close()
|
mas01mc@740
|
494 statinfo = os.stat(filename,)
|
mas01mc@740
|
495 if statinfo:
|
mas01mc@740
|
496 print "Wrote out", statinfo.st_size, "bytes to", \
|
mas01mc@740
|
497 filename
|
mas01mc@740
|
498 except:
|
mas01mc@740
|
499 print "Couldn't pickle index to file", filename
|
mas01mc@740
|
500 traceback.print_exc(file=sys.stderr)
|
mas01mc@740
|
501
|
mas01mc@740
|
502 # Read an LSH index from a pickle file.
|
mas01mc@740
|
503 def LoadIndex(filename):
|
mas01mc@740
|
504 try:
|
mas01mc@740
|
505 fp = open(filename, 'r')
|
mas01mc@740
|
506 ind = pickle.load(fp)
|
mas01mc@740
|
507 fp.close()
|
mas01mc@740
|
508 return ind
|
mas01mc@740
|
509 except:
|
mas01mc@740
|
510 print "Couldn't read pickle file", filename
|
mas01mc@740
|
511 traceback.print_exc(file=sys.stderr)
|
mas01mc@740
|
512
|
mas01mc@740
|
513
|
mas01mc@740
|
514
|