Mercurial > hg > audiodb
comparison lshlib.cpp @ 292:d9a88cfd4ab6
Completed merge of lshlib back to current version of the trunk.
author | mas01mc |
---|---|
date | Tue, 29 Jul 2008 22:01:17 +0000 |
parents | |
children | 9fd5340faffd |
comparison
equal
deleted
inserted
replaced
291:63ae0dfc1767 | 292:d9a88cfd4ab6 |
---|---|
1 #include "lshlib.h" | |
2 | |
3 //#define __LSH_DUMP_CORE_TABLES__ | |
4 //#define USE_U_FUNCTIONS | |
5 //#define LSH_BLOCK_FULL_ROWS | |
6 | |
7 void err(char*s){cout << s << endl;exit(2);} | |
8 | |
9 Uns32T get_page_logn(){ | |
10 int pagesz = (int)sysconf(_SC_PAGESIZE); | |
11 return (Uns32T)log2((double)pagesz); | |
12 } | |
13 | |
// Round x up to the next multiple of 2^w; w is a bit count (an exponent),
// not a byte count, e.g. align_up(x, 12) rounds to 4096-byte boundaries.
unsigned align_up(unsigned x, unsigned w){
  const unsigned mask = (1u << w) - 1u;
  return (x + mask) & ~mask;
}
15 | |
16 void H::error(const char* a, const char* b, const char *sysFunc) { | |
17 cerr << a << ": " << b << endl; | |
18 if (sysFunc) { | |
19 perror(sysFunc); | |
20 } | |
21 exit(1); | |
22 } | |
23 | |
24 H::H(Uns32T kk, Uns32T mm, Uns32T dd, Uns32T NN, Uns32T CC): | |
25 #ifdef USE_U_FUNCTIONS | |
26 use_u_functions(true), | |
27 #else | |
28 use_u_functions(false), | |
29 #endif | |
30 bucketCount(0), | |
31 pointCount(0), | |
32 N(NN), | |
33 C(CC), | |
34 k(kk), | |
35 m(mm), | |
36 L(mm*(mm-1)/2), | |
37 d(dd) | |
38 { | |
39 Uns32T j; | |
40 cout << "file size: ~" << (((unsigned long long)L*N*C*sizeof(SerialElementT))/1000000UL) << "MB" << endl; | |
41 if(((unsigned long long)L*N*C*sizeof(SerialElementT))>4000000000UL) | |
42 error("Maximum size of LSH file exceded: 12*L*N*C > 4000MB"); | |
43 else if(((unsigned long long)N*C*sizeof(SerialElementT))>1000000000UL) | |
44 cout << "warning: hash tables exceed 1000MB." << endl; | |
45 | |
46 if(m<2){ | |
47 m=2; | |
48 L=1; // check value of L | |
49 cout << "warning: setting m=2, L=1" << endl; | |
50 } | |
51 if(use_u_functions && k%2){ | |
52 k++; // make sure k is even | |
53 cout << "warning: setting k even" << endl; | |
54 } | |
55 __initialize_data_structures(); | |
56 for(j=0; j<L; j++) | |
57 for(kk=0; kk<k; kk++) { | |
58 r1[j][kk]=__randr(); // random 1..2^29 | |
59 r2[j][kk]=__randr(); // random 1..2^29 | |
60 } | |
61 } | |
62 | |
63 // Post constructor initialization | |
64 void H::__initialize_data_structures(){ | |
65 H::P = UH_PRIME_DEFAULT; | |
66 | |
67 /* FIXME: don't use time(); instead use /dev/random or similar */ | |
68 /* FIXME: write out the seed somewhere, so that we can get | |
69 repeatability */ | |
70 #ifdef MT19937 | |
71 init_genrand(time(NULL)); | |
72 #else | |
73 srand(time(NULL)); // seed random number generator | |
74 #endif | |
75 Uns32T i,j; | |
76 H::h = new bucket**[ H::L ]; | |
77 H::r1 = new Uns32T*[ H::L ]; | |
78 H::r2 = new Uns32T*[ H::L ]; | |
79 assert( H::h && H::r1 && H::r2 ); // failure | |
80 for( j = 0 ; j < H::L ; j++ ){ | |
81 H::r1[ j ] = new Uns32T[ H::k ]; | |
82 H::r2[ j ] = new Uns32T[ H::k ]; | |
83 assert( H::r1[j] && H::r2[j] ); // failure | |
84 } | |
85 | |
86 for( j = 0 ; j < H::L ; j++ ){ | |
87 H::h[j] = new bucket*[ H::N ]; | |
88 assert( H::h[j] ); | |
89 for( i = 0 ; i < H::N ; i++) | |
90 H::h[j][i] = 0; | |
91 } | |
92 } | |
93 | |
94 // Destruct hash tables | |
95 H::~H(){ | |
96 Uns32T i,j; | |
97 for( j=0 ; j < H::L ; j++ ){ | |
98 delete[] H::r1[ j ]; | |
99 delete[] H::r2[ j ]; | |
100 for(i = 0; i< H::N ; i++) | |
101 delete H::h[ j ][ i ]; | |
102 delete[] H::h[ j ]; | |
103 } | |
104 delete[] H::r1; | |
105 delete[] H::r2; | |
106 delete[] H::h; | |
107 } | |
108 | |
109 | |
110 // make hash value \in Z | |
111 void H::__generate_hash_keys(Uns32T*g,Uns32T* r1, Uns32T* r2){ | |
112 H::t1 = __computeProductModDefaultPrime( g, r1, H::k ) % H::N; | |
113 H::t2 = __computeProductModDefaultPrime( g, r2, H::k ); | |
114 | |
115 } | |
116 | |
// Hard assertion, active regardless of NDEBUG: on failure prints the line and
// file to stderr and exits with status 1. The do/while(0) wrapper makes the
// expansion a single statement, so `if (c) CR_ASSERT(x); else ...` compiles.
#define CR_ASSERT(b) do{ if(!(b)){ fprintf(stderr, "ASSERT failed on line %d, file %s.\n", __LINE__, __FILE__); exit(1); } }while(0)
118 | |
119 // Computes (a.b) mod UH_PRIME_DEFAULT | |
120 inline Uns32T H::__computeProductModDefaultPrime(Uns32T *a, Uns32T *b, IntT size){ | |
121 LongUns64T h = 0; | |
122 | |
123 for(IntT i = 0; i < size; i++){ | |
124 h = h + (LongUns64T)a[i] * (LongUns64T)b[i]; | |
125 h = (h & TWO_TO_32_MINUS_1) + 5 * (h >> 32); | |
126 if (h >= UH_PRIME_DEFAULT) { | |
127 h = h - UH_PRIME_DEFAULT; | |
128 } | |
129 CR_ASSERT(h < UH_PRIME_DEFAULT); | |
130 } | |
131 return h; | |
132 } | |
133 | |
134 Uns32T H::bucket_insert_point(bucket **pp){ | |
135 Uns32T collisionCount = 0; | |
136 if(!*pp){ | |
137 *pp = new bucket(); | |
138 #ifdef LSH_BLOCK_FULL_ROWS | |
139 (*pp)->t2 = 0; // Use t2 as a collision counter for the row | |
140 (*pp)->next = new bucket(); | |
141 #endif | |
142 } | |
143 #ifdef LSH_BLOCK_FULL_ROWS | |
144 collisionCount = (*pp)->t2; | |
145 if(collisionCount < H::C){ // Block if row is full | |
146 (*pp)->t2++; // Increment collision counter | |
147 pointCount++; | |
148 collisionCount++; | |
149 __bucket_insert_point((*pp)->next); // First bucket holds collision count | |
150 } | |
151 #else | |
152 pointCount++; | |
153 __bucket_insert_point(*pp); // No collision count storage | |
154 #endif | |
155 return collisionCount; | |
156 } | |
157 | |
158 void H::__bucket_insert_point(bucket* p){ | |
159 if(p->t2 == IFLAG){ // initialization flag, is it in the domain of t2? | |
160 p->t2 = H::t2; | |
161 bucketCount++; // Record start of new point-locale collision chain | |
162 p->snext = new sbucket(); | |
163 __sbucket_insert_point(p->snext); | |
164 return; | |
165 } | |
166 | |
167 if(p->t2 == H::t2){ | |
168 __sbucket_insert_point(p->snext); | |
169 return; | |
170 } | |
171 | |
172 if(p->next){ | |
173 __bucket_insert_point(p->next); | |
174 } | |
175 | |
176 else{ | |
177 p->next = new bucket(); | |
178 __bucket_insert_point(p->next); | |
179 } | |
180 | |
181 } | |
182 | |
183 void H::__sbucket_insert_point(sbucket* p){ | |
184 if(p->pointID==IFLAG){ | |
185 p->pointID = H::p; | |
186 return; | |
187 } | |
188 | |
189 // Search for pointID | |
190 if(p->snext){ | |
191 __sbucket_insert_point(p->snext); | |
192 } | |
193 else{ | |
194 // Make new point collision bucket at end of list | |
195 p->snext = new sbucket(); | |
196 __sbucket_insert_point(p->snext); | |
197 } | |
198 } | |
199 | |
200 inline bucket** H::__get_bucket(int j){ | |
201 return *(h+j); | |
202 } | |
203 | |
204 // hash functions G | |
205 G::G(float ww, Uns32T kk,Uns32T mm, Uns32T dd, Uns32T NN, Uns32T CC, float r): | |
206 H(kk,mm,dd,NN,CC), | |
207 w(ww), | |
208 radius(r), | |
209 maxp(0), | |
210 calling_instance(0), | |
211 add_point_callback(0), | |
212 lshHeader(0) | |
213 { | |
214 Uns32T j; | |
215 #ifdef USE_U_FUNCTIONS | |
216 G::A = new float**[ H::m ]; // m x k x d random projectors | |
217 G::b = new float*[ H::m ]; // m x k random biases | |
218 #else | |
219 G::A = new float**[ H::L ]; // m x k x d random projectors | |
220 G::b = new float*[ H::L ]; // m x k random biases | |
221 #endif | |
222 G::g = new Uns32T*[ H::L ]; // L x k random projections | |
223 assert( G::g && G::A && G::b ); // failure | |
224 #ifdef USE_U_FUNCTIONS | |
225 // Use m \times u_i functions \in R^{(k/2) \times (d)} | |
226 // Combine to make L=m(m-1)/2 hash functions \in R^{k \times d} | |
227 for( j = 0; j < H::m ; j++ ){ // m functions u_i(v) | |
228 G::A[j] = new float*[ H::k/2 ]; // k/2 x d 2-stable distribution coefficients | |
229 G::b[j] = new float[ H::k/2 ]; // bias | |
230 assert( G::A[j] && G::b[j] ); // failure | |
231 for( kk = 0; kk < H::k/2 ; kk++ ){ | |
232 G::A[j][kk] = new float[ H::d ]; | |
233 assert( G::A[j][kk] ); // failure | |
234 for(Uns32T i = 0 ; i < H::d ; i++ ) | |
235 G::A[j][kk][i] = randn(); // Normal | |
236 G::b[j][kk] = ranf()*G::w; // Uniform | |
237 } | |
238 } | |
239 #else | |
240 // Use m \times u_i functions \in R^{k \times (d)} | |
241 // Combine to make L=m(m-1)/2 hash functions \in R^{k \times d} | |
242 for( j = 0; j < H::L ; j++ ){ // m functions u_i(v) | |
243 G::A[j] = new float*[ H::k ]; // k x d 2-stable distribution coefficients | |
244 G::b[j] = new float[ H::k ]; // bias | |
245 assert( G::A[j] && G::b[j] ); // failure | |
246 for( kk = 0; kk < H::k ; kk++ ){ | |
247 G::A[j][kk] = new float[ H::d ]; | |
248 assert( G::A[j][kk] ); // failure | |
249 for(Uns32T i = 0 ; i < H::d ; i++ ) | |
250 G::A[j][kk][i] = randn(); // Normal | |
251 G::b[j][kk] = ranf()*G::w; // Uniform | |
252 } | |
253 } | |
254 #endif | |
255 | |
256 for( j = 0 ; j < H::L ; j++ ){ // L functions g_j(u_a, u_b) a,b \in nchoosek(m,2) | |
257 G::g[j] = new Uns32T[ H::k ]; // k x 32-bit hash values, gj(v)=[x0 x1 ... xk-1] xk \in Z | |
258 assert( G::g[j] ); | |
259 } | |
260 | |
261 initialize_partial_functions(); // m partially evaluated hash functions | |
262 } | |
263 | |
264 // Serialize from file LSH constructor | |
265 // Read parameters from database file | |
266 // Load the hash functions, close the database | |
267 // Optionally load the LSH tables into head-allocated lists in core | |
268 G::G(char* filename, bool lshInCoreFlag): | |
269 calling_instance(0), | |
270 add_point_callback(0) | |
271 { | |
272 int dbfid = unserialize_lsh_header(filename); | |
273 unserialize_lsh_functions(dbfid); | |
274 initialize_partial_functions(); | |
275 | |
276 // Format1 only needs unserializing if specifically requested | |
277 if(!(lshHeader->flags&O2_SERIAL_FILEFORMAT2) && lshInCoreFlag){ | |
278 unserialize_lsh_hashtables_format1(dbfid); | |
279 } | |
280 | |
281 // Format2 always needs unserializing | |
282 if(lshHeader->flags&O2_SERIAL_FILEFORMAT2 && lshInCoreFlag){ | |
283 unserialize_lsh_hashtables_format2(dbfid); | |
284 } | |
285 | |
286 close(dbfid); | |
287 } | |
288 | |
289 void G::initialize_partial_functions(){ | |
290 | |
291 #ifdef USE_U_FUNCTIONS | |
292 uu = vector<vector<Uns32T> >(H::m); | |
293 for( Uns32T aa=0 ; aa < H::m ; aa++ ) | |
294 uu[aa] = vector<Uns32T>( H::k/2 ); | |
295 #else | |
296 uu = vector<vector<Uns32T> >(H::L); | |
297 for( Uns32T aa=0 ; aa < H::L ; aa++ ) | |
298 uu[aa] = vector<Uns32T>( H::k ); | |
299 #endif | |
300 } | |
301 | |
302 | |
303 // Generate z ~ N(0,1) | |
304 float G::randn(){ | |
305 // Box-Muller | |
306 float x1, x2; | |
307 do{ | |
308 x1 = ranf(); | |
309 } while (x1 == 0); // cannot take log of 0 | |
310 x2 = ranf(); | |
311 float z; | |
312 z = sqrtf(-2.0 * logf(x1)) * cosf(2.0 * M_PI * x2); | |
313 return z; | |
314 } | |
315 | |
316 float G::ranf(){ | |
317 #ifdef MT19937 | |
318 return (float) genrand_real2(); | |
319 #else | |
320 return (float)( (double)rand() / ((double)(RAND_MAX)+(double)(1)) ); | |
321 #endif | |
322 } | |
323 | |
324 // range is 1..2^29 | |
325 /* FIXME: that looks like an ... odd range. Still. */ | |
326 Uns32T H::__randr(){ | |
327 #ifdef MT19937 | |
328 return (Uns32T)((genrand_int32() >> 3) + 1); | |
329 #else | |
330 return (Uns32T) ((rand() >> 2) + 1); | |
331 #endif | |
332 } | |
333 | |
334 G::~G(){ | |
335 Uns32T j,kk; | |
336 #ifdef USE_U_FUNCTIONS | |
337 for( j = 0 ; j < H::m ; j++ ){ | |
338 for( kk = 0 ; kk < H::k/2 ; kk++ ) | |
339 delete[] A[j][kk]; | |
340 delete[] A[j]; | |
341 } | |
342 delete[] A; | |
343 for( j = 0 ; j < H::m ; j++ ) | |
344 delete[] b[j]; | |
345 delete[] b; | |
346 #else | |
347 for( j = 0 ; j < H::L ; j++ ){ | |
348 for( kk = 0 ; kk < H::k ; kk++ ) | |
349 delete[] A[j][kk]; | |
350 delete[] A[j]; | |
351 } | |
352 delete[] A; | |
353 for( j = 0 ; j < H::L ; j++ ) | |
354 delete[] b[j]; | |
355 delete[] b; | |
356 #endif | |
357 | |
358 for( j = 0 ; j < H::L ; j++ ) | |
359 delete[] g[j]; | |
360 delete[] g; | |
361 delete lshHeader; | |
362 } | |
363 | |
364 // Compute all hash functions for vector v | |
365 // #ifdef USE_U_FUNCTIONS use Combination of m \times h_i \in R^{(k/2) \times d} | |
366 // to make L \times g_j functions \in Z^k | |
367 void G::compute_hash_functions(vector<float>& v){ // v \in R^d | |
368 float iw = 1. / G::w; // hash bucket width | |
369 Uns32T aa, kk; | |
370 if( v.size() != H::d ) | |
371 error("v.size != H::d","","compute_hash_functions"); // check input vector dimensionality | |
372 double tmp = 0; | |
373 float *pA, *pb; | |
374 Uns32T *pg; | |
375 int dd; | |
376 vector<float>::iterator vi; | |
377 vector<Uns32T>::iterator ui; | |
378 | |
379 #ifdef USE_U_FUNCTIONS | |
380 Uns32T bb; | |
381 // Store m dot products to expand | |
382 for( aa=0; aa < H::m ; aa++ ){ | |
383 ui = uu[aa].begin(); | |
384 for( kk = 0 ; kk < H::k/2 ; kk++ ){ | |
385 pb = *( G::b + aa ) + kk; | |
386 pA = * ( * ( G::A + aa ) + kk ); | |
387 dd = H::d; | |
388 tmp = 0.; | |
389 vi = v.begin(); | |
390 while( dd-- ) | |
391 tmp += *pA++ * *vi++; // project | |
392 tmp += *pb; // translate | |
393 tmp *= iw; // scale | |
394 *ui++ = (Uns32T) floor(tmp); // floor | |
395 } | |
396 } | |
397 // Binomial combinations of functions u_{a,b} \in Z^{(k/2) \times d} | |
398 Uns32T j; | |
399 for( aa=0, j=0 ; aa < H::m-1 ; aa++ ) | |
400 for( bb = aa + 1 ; bb < H::m ; bb++, j++ ){ | |
401 pg= *( G::g + j ); // L \times functions g_j(v) \in Z^k | |
402 // u_1 \in Z^{(k/2) \times d} | |
403 ui = uu[aa].begin(); | |
404 kk=H::k/2; | |
405 while( kk-- ) | |
406 *pg++ = *ui++; // hash function g_j(v)=[x1 x2 ... x(k/2)]; xk \in Z | |
407 // u_2 \in Z^{(k/2) \times d} | |
408 ui = uu[bb].begin(); | |
409 kk=H::k/2; | |
410 while( kk--) | |
411 *pg++ = *ui++; // hash function g_j(v)=[x(k/2+1) x(k/2+2) ... xk]; xk \in Z | |
412 } | |
413 #else | |
414 for( aa=0; aa < H::L ; aa++ ){ | |
415 ui = uu[aa].begin(); | |
416 for( kk = 0 ; kk < H::k ; kk++ ){ | |
417 pb = *( G::b + aa ) + kk; | |
418 pA = * ( * ( G::A + aa ) + kk ); | |
419 dd = H::d; | |
420 tmp = 0.; | |
421 vi = v.begin(); | |
422 while( dd-- ) | |
423 tmp += *pA++ * *vi++; // project | |
424 tmp += *pb; // translate | |
425 tmp *= iw; // scale | |
426 *ui++ = (Uns32T) (floor(tmp)); // floor | |
427 } | |
428 } | |
429 // Compute hash functions | |
430 for( aa=0 ; aa < H::L ; aa++ ){ | |
431 pg= *( G::g + aa ); // L \times functions g_j(v) \in Z^k | |
432 // u_1 \in Z^{k \times d} | |
433 ui = uu[aa].begin(); | |
434 kk=H::k; | |
435 while( kk-- ) | |
436 *pg++ = *ui++; // hash function g_j(v)=[x1 x2 ... xk]; xk \in Z | |
437 } | |
438 #endif | |
439 | |
440 } | |
441 | |
442 | |
443 // single point insertion; inserted values are hash value and pointID | |
444 Uns32T G::insert_point(vector<float>& v, Uns32T pp){ | |
445 Uns32T collisionCount = 0; | |
446 H::p = pp; | |
447 if(pp>G::maxp) | |
448 G::maxp=pp; // Store highest pointID in database | |
449 compute_hash_functions( v ); | |
450 for(Uns32T j = 0 ; j < H::L ; j++ ){ // insertion | |
451 __generate_hash_keys( *( G::g + j ), *( H::r1 + j ), *( H::r2 + j ) ); | |
452 collisionCount += bucket_insert_point( *(h + j) + t1 ); | |
453 } | |
454 return collisionCount; | |
455 } | |
456 | |
457 | |
458 // batch insert for a point set | |
459 // inserted values are vector hash value and pointID starting at basePointID | |
460 void G::insert_point_set(vector<vector<float> >& vv, Uns32T basePointID){ | |
461 for(Uns32T point=0; point<vv.size(); point++) | |
462 insert_point(vv[point], basePointID+point); | |
463 } | |
464 | |
465 // point retrieval routine | |
466 void G::retrieve_point(vector<float>& v, Uns32T qpos, ReporterCallbackPtr add_point, void* caller){ | |
467 calling_instance = caller; | |
468 add_point_callback = add_point; | |
469 compute_hash_functions( v ); | |
470 for(Uns32T j = 0 ; j < H::L ; j++ ){ | |
471 __generate_hash_keys( *( G::g + j ), *( H::r1 + j ), *( H::r2 + j ) ); | |
472 if( bucket* bPtr = *(__get_bucket(j) + get_t1()) ) | |
473 #ifdef LSH_BLOCK_FULL_ROWS | |
474 bucket_chain_point( bPtr->next, qpos); | |
475 #else | |
476 bucket_chain_point( bPtr , qpos); | |
477 #endif | |
478 } | |
479 } | |
480 | |
481 void G::retrieve_point_set(vector<vector<float> >& vv, ReporterCallbackPtr add_point, void* caller){ | |
482 for(Uns32T qpos = 0 ; qpos < vv.size() ; qpos++ ) | |
483 retrieve_point(vv[qpos], qpos, add_point, caller); | |
484 } | |
485 | |
486 // export lsh tables to table structure on disk | |
487 // | |
488 // LSH TABLE STRUCTURE | |
489 // ---header 64 bytes --- | |
490 // [magic #tables #rows #cols elementSize databaseSize version flags dim #funs 0 0 0 0 0 0] | |
491 // | |
// ---random projections L x k x d float (m x k/2 x d with USE_U_FUNCTIONS) ---
// A[0][0][0] A[0][0][1] ... A[0][0][d-1]
// A[0][1][0] A[0][1][1] ... A[0][1][d-1]
// ...
// A[0][k-1][0] A[0][k-1][1] ... A[0][k-1][d-1]
// ...
// ...
// A[L-1][0][0] A[L-1][0][1] ... A[L-1][0][d-1]
// A[L-1][1][0] A[L-1][1][1] ... A[L-1][1][d-1]
// ...
// A[L-1][k-1][0] A[L-1][k-1][1] ... A[L-1][k-1][d-1]
503 // | |
504 // ---bias L x k float --- | |
505 // b[0][0] b[0][1] ... b[0][k-1] | |
506 // b[1][0] b[1][1] ... b[1][k-1] | |
507 // ... | |
508 // b[L-1][0] b[L-1][1] ... b[L-1][k-1] | |
509 // | |
// ---random r1 L x k Uns32T ---
511 // r1[0][0] r1[0][1] ... r1[0][k-1] | |
512 // r1[1][0] r1[1][1] ... r1[1][k-1] | |
513 // ... | |
514 // r1[L-1][0] r1[L-1][1] ... r1[L-1][k-1] | |
515 // | |
// ---random r2 L x k Uns32T ---
517 // r2[0][0] r2[0][1] ... r2[0][k-1] | |
518 // r2[1][0] r2[1][1] ... r2[1][k-1] | |
519 // ... | |
520 // r2[L-1][0] r2[L-1][1] ... r2[L-1][k-1] | |
521 // | |
522 // ---hash table 0: N x C x 8 --- | |
523 // [t2 pointID][t2 pointID]...[t2 pointID] | |
524 // [t2 pointID][t2 pointID]...[t2 pointID] | |
525 // ... | |
526 // [t2 pointID][t2 pointID]...[t2 pointID] | |
527 // | |
528 // ---hash table 1: N x C x 8 --- | |
529 // [t2 pointID][t2 pointID]...[t2 pointID] | |
530 // [t2 pointID][t2 pointID]...[t2 pointID] | |
531 // ... | |
532 // [t2 pointID][t2 pointID]...[t2 pointID] | |
533 // | |
534 // ... | |
535 // | |
536 // ---hash table L-1: N x C x 8 --- | |
537 // [t2 pointID][t2 pointID]...[t2 pointID] | |
538 // [t2 pointID][t2 pointID]...[t2 pointID] | |
539 // ... | |
540 // [t2 pointID][t2 pointID]...[t2 pointID] | |
541 // | |
542 | |
543 // Serial header constructors | |
544 SerialHeader::SerialHeader(){;} | |
545 SerialHeader::SerialHeader(float W, Uns32T L, Uns32T N, Uns32T C, Uns32T k, Uns32T d, float r, Uns32T p, Uns32T FMT): | |
546 lshMagic(O2_SERIAL_MAGIC), | |
547 binWidth(W), | |
548 numTables(L), | |
549 numRows(N), | |
550 numCols(C), | |
551 elementSize(O2_SERIAL_ELEMENT_SIZE), | |
552 version(O2_SERIAL_VERSION), | |
553 size(L * align_up(N * C * O2_SERIAL_ELEMENT_SIZE, get_page_logn()) // hash tables | |
554 + align_up(O2_SERIAL_HEADER_SIZE + // header + hash functions | |
555 L*k*( sizeof(float)*d+2*sizeof(Uns32T)+sizeof(float)),get_page_logn())), | |
556 flags(FMT), | |
557 dataDim(d), | |
558 numFuns(k), | |
559 radius(r), | |
560 maxp(p){;} // header | |
561 | |
562 float* G::get_serial_hashfunction_base(char* db){ | |
563 if(db&&lshHeader) | |
564 return (float*)(db+O2_SERIAL_HEADER_SIZE); | |
565 else return NULL; | |
566 } | |
567 | |
568 SerialElementT* G::get_serial_hashtable_base(char* db){ | |
569 if(db&&lshHeader) | |
570 return (SerialElementT*)(db+get_serial_hashtable_offset()); | |
571 else | |
572 return NULL; | |
573 } | |
574 | |
575 Uns32T G::get_serial_hashtable_offset(){ | |
576 if(lshHeader) | |
577 return align_up(O2_SERIAL_HEADER_SIZE + | |
578 L*lshHeader->numFuns*( sizeof(float)*lshHeader->dataDim+2*sizeof(Uns32T)+sizeof(float)),get_page_logn()); | |
579 else | |
580 return 0; | |
581 } | |
582 | |
583 void G::serialize(char* filename, Uns32T serialFormat){ | |
584 int dbfid; | |
585 char* db; | |
586 int dbIsNew=0; | |
587 | |
588 // Check requested serialFormat | |
589 if(!(serialFormat==O2_SERIAL_FILEFORMAT1 || serialFormat==O2_SERIAL_FILEFORMAT2)) | |
590 error("Unrecognized serial file format request: ", "serialize()"); | |
591 | |
592 // Test to see if file exists | |
593 if((dbfid = open (filename, O_RDONLY)) < 0) | |
594 // If it doesn't, then create the file (CREATE) | |
595 if(errno == ENOENT){ | |
596 // Create the file | |
597 std::cout << "Creating new serialized LSH database:" << filename << "..."; | |
598 std::cout.flush(); | |
599 serial_create(filename, serialFormat); | |
600 dbIsNew=1; | |
601 } | |
602 else | |
603 // The file can't be opened | |
604 error("Can't open the file", filename, "open"); | |
605 | |
606 // Load the on-disk header into core | |
607 dbfid = serial_open(filename, 1); // open for write | |
608 db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 1);// get database pointer | |
609 serial_get_header(db); // read header | |
610 serial_munmap(db, O2_SERIAL_HEADER_SIZE); // drop mmap | |
611 | |
612 // Check compatibility of core and disk data structures | |
613 if( !serial_can_merge(serialFormat) ) | |
614 error("Incompatible core and serial LSH, data structure dimensions mismatch."); | |
615 | |
616 // For new LSH databases write the hashfunctions | |
617 if(dbIsNew) | |
618 serialize_lsh_hashfunctions(dbfid); | |
619 // Write the hashtables in the requested format | |
620 if(serialFormat == O2_SERIAL_FILEFORMAT1) | |
621 serialize_lsh_hashtables_format1(dbfid, !dbIsNew); | |
622 else | |
623 serialize_lsh_hashtables_format2(dbfid, !dbIsNew); | |
624 | |
625 if(!dbIsNew){ | |
626 db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 1);// get database pointer | |
627 //serial_get_header(db); // read header | |
628 cout << "maxp = " << G::maxp << endl; | |
629 lshHeader->maxp=G::maxp; | |
630 // Default to FILEFORMAT1 | |
631 if(!(lshHeader->flags&O2_SERIAL_FILEFORMAT2)) | |
632 lshHeader->flags|=O2_SERIAL_FILEFORMAT2; | |
633 memcpy((char*)db, (char*)lshHeader, sizeof(SerialHeaderT)); | |
634 serial_munmap(db, O2_SERIAL_HEADER_SIZE); // drop mmap | |
635 } | |
636 | |
637 serial_close(dbfid); | |
638 } | |
639 | |
640 // Test to see if core structure and requested format is | |
641 // compatible with currently opened database | |
642 int G::serial_can_merge(Uns32T format){ | |
643 SerialHeaderT* that = lshHeader; | |
644 if( (format==O2_SERIAL_FILEFORMAT2 && !that->flags&O2_SERIAL_FILEFORMAT2) | |
645 || (format!=O2_SERIAL_FILEFORMAT2 && that->flags&O2_SERIAL_FILEFORMAT2) | |
646 || !( this->w == that->binWidth && | |
647 this->L == that->numTables && | |
648 this->N == that->numRows && | |
649 this->k == that->numFuns && | |
650 this->d == that->dataDim && | |
651 sizeof(SerialElementT) == that->elementSize && | |
652 this->radius == that->radius)){ | |
653 serial_print_header(format); | |
654 return 0; | |
655 } | |
656 else | |
657 return 1; | |
658 } | |
659 | |
660 // Used as an error message for serial_can_merge() | |
661 void G::serial_print_header(Uns32T format){ | |
662 std::cout << "Fc:" << format << " Fs:" << lshHeader->flags << endl; | |
663 std::cout << "Wc:" << w << " Ls:" << lshHeader->binWidth << endl; | |
664 std::cout << "Lc:" << L << " Ls:" << lshHeader->numTables << endl; | |
665 std::cout << "Nc:" << N << " Ns:" << lshHeader->numRows << endl; | |
666 std::cout << "kc:" << k << " ks:" << lshHeader->numFuns << endl; | |
667 std::cout << "dc:" << d << " ds:" << lshHeader->dataDim << endl; | |
668 std::cout << "sc:" << sizeof(SerialElementT) << " ss:" << lshHeader->elementSize << endl; | |
669 std::cout << "rc:" << this->radius << " rs:" << lshHeader->radius << endl; | |
670 } | |
671 | |
672 int G::serialize_lsh_hashfunctions(int fid){ | |
673 float* pf; | |
674 Uns32T *pu; | |
675 Uns32T x,y,z; | |
676 | |
677 db = serial_mmap(fid, get_serial_hashtable_offset(), 1);// get database pointer | |
678 pf = get_serial_hashfunction_base(db); | |
679 | |
680 // HASH FUNCTIONS | |
681 // Write the random projectors A[][][] | |
682 #ifdef USE_U_FUNCTIONS | |
683 for( x = 0 ; x < H::m ; x++ ) | |
684 for( y = 0 ; y < H::k/2 ; y++ ) | |
685 #else | |
686 for( x = 0 ; x < H::L ; x++ ) | |
687 for( y = 0 ; y < H::k ; y++ ) | |
688 #endif | |
689 for( z = 0 ; z < d ; z++ ) | |
690 *pf++ = A[x][y][z]; | |
691 | |
692 // Write the random biases b[][] | |
693 #ifdef USE_U_FUNCTIONS | |
694 for( x = 0 ; x < H::m ; x++ ) | |
695 for( y = 0 ; y < H::k/2 ; y++ ) | |
696 #else | |
697 for( x = 0 ; x < H::L ; x++ ) | |
698 for( y = 0 ; y < H::k ; y++ ) | |
699 #endif | |
700 *pf++=b[x][y]; | |
701 | |
702 pu = (Uns32T*)pf; | |
703 | |
704 // Write the Z projectors r1[][] | |
705 for( x = 0 ; x < H::L ; x++) | |
706 for( y = 0 ; y < H::k ; y++) | |
707 *pu++ = r1[x][y]; | |
708 | |
709 // Write the Z projectors r2[][] | |
710 for( x = 0 ; x < H::L ; x++) | |
711 for( y = 0; y < H::k ; y++) | |
712 *pu++ = r2[x][y]; | |
713 | |
714 serial_munmap(db, get_serial_hashtable_offset()); | |
715 return 1; | |
716 } | |
717 | |
718 int G::serialize_lsh_hashtables_format1(int fid, int merge){ | |
719 SerialElementT *pe, *pt; | |
720 Uns32T x,y; | |
721 | |
722 if( merge && !serial_can_merge(O2_SERIAL_FILEFORMAT1) ) | |
723 error("Cannot merge core and serial LSH, data structure dimensions mismatch."); | |
724 | |
725 Uns32T hashTableSize=sizeof(SerialElementT)*lshHeader->numRows*lshHeader->numCols; | |
726 Uns32T colCount, meanColCount, colCountN, maxColCount, minColCount; | |
727 // Write the hash tables | |
728 for( x = 0 ; x < H::L ; x++ ){ | |
729 std::cout << (merge ? "merging":"writing") << " hash table " << x << " FORMAT1..."; | |
730 std::cout.flush(); | |
731 // memory map a single hash table for sequential access | |
732 // Align each hash table to page boundary | |
733 char* dbtable = serial_mmap(fid, hashTableSize, 1, | |
734 align_up(get_serial_hashtable_offset()+x*hashTableSize, get_page_logn())); | |
735 if(madvise(dbtable, hashTableSize, MADV_SEQUENTIAL)<0) | |
736 error("could not advise hashtable memory","","madvise"); | |
737 | |
738 maxColCount=0; | |
739 minColCount=O2_SERIAL_MAX_COLS; | |
740 meanColCount=0; | |
741 colCountN=0; | |
742 pt=(SerialElementT*)dbtable; | |
743 for( y = 0 ; y < H::N ; y++ ){ | |
744 // Move disk pointer to beginning of row | |
745 pe=pt+y*lshHeader->numCols; | |
746 | |
747 colCount=0; | |
748 if(bucket* bPtr = h[x][y]) | |
749 if(merge) | |
750 #ifdef LSH_BLOCK_FULL_ROWS | |
751 serial_merge_hashtable_row_format1(pe, bPtr->next, colCount); // skip collision counter bucket | |
752 else | |
753 serial_write_hashtable_row_format1(pe, bPtr->next, colCount); // skip collision counter bucket | |
754 #else | |
755 serial_merge_hashtable_row_format1(pe, bPtr, colCount); | |
756 else | |
757 serial_write_hashtable_row_format1(pe, bPtr, colCount); | |
758 #endif | |
759 if(colCount){ | |
760 if(colCount<minColCount) | |
761 minColCount=colCount; | |
762 if(colCount>maxColCount) | |
763 maxColCount=colCount; | |
764 meanColCount+=colCount; | |
765 colCountN++; | |
766 } | |
767 } | |
768 if(colCountN) | |
769 std::cout << "#rows with collisions =" << colCountN << ", mean = " << meanColCount/(float)colCountN | |
770 << ", min = " << minColCount << ", max = " << maxColCount | |
771 << endl; | |
772 serial_munmap(dbtable, hashTableSize); | |
773 } | |
774 | |
775 // We're done writing | |
776 return 1; | |
777 } | |
778 | |
779 void G::serial_merge_hashtable_row_format1(SerialElementT* pr, bucket* b, Uns32T& colCount){ | |
780 while(b && b->t2!=IFLAG){ | |
781 SerialElementT*pe=pr; // reset disk pointer to beginning of row | |
782 serial_merge_element_format1(pe, b->snext, b->t2, colCount); | |
783 b=b->next; | |
784 } | |
785 } | |
786 | |
787 void G::serial_merge_element_format1(SerialElementT* pe, sbucket* sb, Uns32T t2, Uns32T& colCount){ | |
788 while(sb){ | |
789 if(colCount==lshHeader->numCols){ | |
790 std::cout << "!point-chain full " << endl; | |
791 return; | |
792 } | |
793 Uns32T c=0; | |
794 // Merge collision chains | |
795 while(c<lshHeader->numCols){ | |
796 if( (pe+c)->hashValue==IFLAG){ | |
797 (pe+c)->hashValue=t2; | |
798 (pe+c)->pointID=sb->pointID; | |
799 colCount=c+1; | |
800 if(c+1<lshHeader->numCols) | |
801 (pe+c+1)->hashValue=IFLAG; | |
802 break; | |
803 } | |
804 c++; | |
805 } | |
806 sb=sb->snext; | |
807 } | |
808 return; | |
809 } | |
810 | |
811 void G::serial_write_hashtable_row_format1(SerialElementT*& pe, bucket* b, Uns32T& colCount){ | |
812 pe->hashValue=IFLAG; | |
813 while(b && b->t2!=IFLAG){ | |
814 serial_write_element_format1(pe, b->snext, b->t2, colCount); | |
815 b=b->next; | |
816 } | |
817 } | |
818 | |
819 void G::serial_write_element_format1(SerialElementT*& pe, sbucket* sb, Uns32T t2, Uns32T& colCount){ | |
820 while(sb){ | |
821 if(colCount==lshHeader->numCols){ | |
822 std::cout << "!point-chain full " << endl; | |
823 return; | |
824 } | |
825 pe->hashValue=t2; | |
826 pe->pointID=sb->pointID; | |
827 pe++; | |
828 colCount++; | |
829 sb=sb->snext; | |
830 } | |
831 pe->hashValue=IFLAG; | |
832 return; | |
833 } | |
834 | |
835 int G::serialize_lsh_hashtables_format2(int fid, int merge){ | |
836 Uns32T x,y; | |
837 | |
838 if( merge && !serial_can_merge(O2_SERIAL_FILEFORMAT2) ) | |
839 error("Cannot merge core and serial LSH, data structure dimensions mismatch."); | |
840 | |
841 // We must pereform FORMAT1 merges in core | |
842 if(merge) | |
843 unserialize_lsh_hashtables_format2(fid); | |
844 | |
845 Uns32T colCount, meanColCount, colCountN, maxColCount, minColCount, t1; | |
846 lseek(fid, get_serial_hashtable_offset(), SEEK_SET); | |
847 | |
848 // Write the hash tables | |
849 for( x = 0 ; x < H::L ; x++ ){ | |
850 std::cout << (merge ? "merging":"writing") << " hash table " << x << " FORMAT2..."; | |
851 std::cout.flush(); | |
852 maxColCount=0; | |
853 minColCount=O2_SERIAL_MAX_COLS; | |
854 meanColCount=0; | |
855 colCountN=0; | |
856 for( y = 0 ; y < H::N ; y++ ){ | |
857 colCount=0; | |
858 if(bucket* bPtr = h[x][y]){ | |
859 t1 = y | O2_SERIAL_FLAGS_T1_BIT; | |
860 if( write(fid, &t1, sizeof(Uns32T)) != sizeof(Uns32T) ){ | |
861 close(fid); | |
862 error("write error in serial_write_hashtable_format2() [t1]"); | |
863 } | |
864 #ifdef LSH_BLOCK_FULL_ROWS | |
865 serial_write_hashtable_row_format2(fid, bPtr->next, colCount); // skip collision counter bucket | |
866 #else | |
867 serial_write_hashtable_row_format2(fid, bPtr, colCount); | |
868 #endif | |
869 } | |
870 if(colCount){ | |
871 if(colCount<minColCount) | |
872 minColCount=colCount; | |
873 if(colCount>maxColCount) | |
874 maxColCount=colCount; | |
875 meanColCount+=colCount; | |
876 colCountN++; | |
877 } | |
878 } | |
879 // Write END of table marker | |
880 t1 = O2_SERIAL_FLAGS_END_BIT; | |
881 if( write(fid, &t1, sizeof(Uns32T)) != sizeof(Uns32T) ){ | |
882 close(fid); | |
883 error("write error in serial_write_hashtable_format2() [end]"); | |
884 } | |
885 | |
886 if(colCountN) | |
887 std::cout << "#rows with collisions =" << colCountN << ", mean = " << meanColCount/(float)colCountN | |
888 << ", min = " << minColCount << ", max = " << maxColCount | |
889 << endl; | |
890 } | |
891 | |
892 // We're done writing | |
893 return 1; | |
894 } | |
895 | |
896 void G::serial_write_hashtable_row_format2(int fid, bucket* b, Uns32T& colCount){ | |
897 while(b && b->t2!=IFLAG){ | |
898 t2 = O2_SERIAL_FLAGS_T2_BIT; | |
899 if( write(fid, &t2, sizeof(Uns32T)) != sizeof(Uns32T) ){ | |
900 close(fid); | |
901 error("write error in serial_write_hashtable_row_format2()"); | |
902 } | |
903 t2 = b->t2; | |
904 if( write(fid, &t2, sizeof(Uns32T)) != sizeof(Uns32T) ){ | |
905 close(fid); | |
906 error("write error in serial_write_hashtable_row_format2()"); | |
907 } | |
908 serial_write_element_format2(fid, b->snext, colCount); | |
909 b=b->next; | |
910 } | |
911 } | |
912 | |
913 void G::serial_write_element_format2(int fid, sbucket* sb, Uns32T& colCount){ | |
914 while(sb){ | |
915 if(write(fid, &sb->pointID, sizeof(Uns32T))!=sizeof(Uns32T)){ | |
916 close(fid); | |
917 error("Write error in serial_write_element_format2()"); | |
918 } | |
919 colCount++; | |
920 sb=sb->snext; | |
921 } | |
922 } | |
923 | |
924 | |
925 int G::serial_create(char* filename, Uns32T FMT){ | |
926 return serial_create(filename, w, L, N, C, k, d, FMT); | |
927 } | |
928 | |
929 | |
930 int G::serial_create(char* filename, float binWidth, Uns32T numTables, Uns32T numRows, Uns32T numCols, | |
931 Uns32T numFuns, Uns32T dim, Uns32T FMT){ | |
932 | |
933 if(numTables > O2_SERIAL_MAX_TABLES || numRows > O2_SERIAL_MAX_ROWS | |
934 || numCols > O2_SERIAL_MAX_COLS || numFuns > O2_SERIAL_MAX_FUNS | |
935 || dim>O2_SERIAL_MAX_DIM){ | |
936 error("LSH parameters out of bounds for serialization"); | |
937 } | |
938 | |
939 int dbfid; | |
940 if ((dbfid = open (filename, O_RDWR|O_CREAT|O_EXCL, S_IRUSR|S_IWUSR|S_IRGRP|S_IWGRP|S_IROTH|S_IWOTH)) < 0) | |
941 error("Can't create serial file", filename, "open"); | |
942 get_lock(dbfid, 1); | |
943 | |
944 // Make header first to get size of serialized database | |
945 lshHeader = new SerialHeaderT(binWidth, numTables, numRows, numCols, numFuns, dim, radius, maxp, FMT); | |
946 | |
947 // go to the location corresponding to the last byte | |
948 if (lseek (dbfid, lshHeader->get_size() - 1, SEEK_SET) == -1) | |
949 error("lseek error in db file", "", "lseek"); | |
950 | |
951 // write a dummy byte at the last location | |
952 if (write (dbfid, "", 1) != 1) | |
953 error("write error", "", "write"); | |
954 | |
955 db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 1); | |
956 | |
957 memcpy (db, lshHeader, O2_SERIAL_HEADER_SIZE); | |
958 | |
959 serial_munmap(db, O2_SERIAL_HEADER_SIZE); | |
960 | |
961 close(dbfid); | |
962 | |
963 std::cout << "done initializing tables." << endl; | |
964 | |
965 return 1; | |
966 } | |
967 | |
968 char* G::serial_mmap(int dbfid, Uns32T memSize, Uns32T forWrite, off_t offset){ | |
969 if(forWrite){ | |
970 if ((db = (char*) mmap(0, memSize, PROT_READ | PROT_WRITE, | |
971 MAP_SHARED, dbfid, offset)) == (caddr_t) -1) | |
972 error("mmap error in request for writable serialized database", "", "mmap"); | |
973 } | |
974 else if ((db = (char*) mmap(0, memSize, PROT_READ, MAP_SHARED, dbfid, offset)) == (caddr_t) -1) | |
975 error("mmap error in read-only serialized database", "", "mmap"); | |
976 | |
977 return db; | |
978 } | |
979 | |
980 SerialHeaderT* G::serial_get_header(char* db){ | |
981 lshHeader = new SerialHeaderT(); | |
982 memcpy((char*)lshHeader, db, sizeof(SerialHeaderT)); | |
983 | |
984 if(lshHeader->lshMagic!=O2_SERIAL_MAGIC) | |
985 error("Not an LSH database file"); | |
986 | |
987 return lshHeader; | |
988 } | |
989 | |
990 void G::serial_munmap(char* db, Uns32T N){ | |
991 munmap(db, N); | |
992 } | |
993 | |
994 int G::serial_open(char* filename, int writeFlag){ | |
995 int dbfid; | |
996 if(writeFlag){ | |
997 if ((dbfid = open (filename, O_RDWR)) < 0) | |
998 error("Can't open serial file for read/write", filename, "open"); | |
999 get_lock(dbfid, writeFlag); | |
1000 } | |
1001 else{ | |
1002 if ((dbfid = open (filename, O_RDONLY)) < 0) | |
1003 error("Can't open serial file for read", filename, "open"); | |
1004 get_lock(dbfid, 0); | |
1005 } | |
1006 | |
1007 return dbfid; | |
1008 } | |
1009 | |
1010 void G::serial_close(int dbfid){ | |
1011 | |
1012 release_lock(dbfid); | |
1013 close(dbfid); | |
1014 } | |
1015 | |
1016 int G::unserialize_lsh_header(char* filename){ | |
1017 | |
1018 int dbfid; | |
1019 char* db; | |
1020 // Test to see if file exists | |
1021 if((dbfid = open (filename, O_RDONLY)) < 0) | |
1022 error("Can't open the file", filename, "open"); | |
1023 close(dbfid); | |
1024 dbfid = serial_open(filename, 0); // open for read | |
1025 db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 0);// get database pointer | |
1026 serial_get_header(db); // read header | |
1027 serial_munmap(db, O2_SERIAL_HEADER_SIZE); // drop mmap | |
1028 | |
1029 // Unserialize header parameters | |
1030 H::L = lshHeader->numTables; | |
1031 H::m = (Uns32T)( (1.0 + sqrt(1 + 8.0*(int)H::L)) / 2.0); | |
1032 H::N = lshHeader->numRows; | |
1033 H::C = lshHeader->numCols; | |
1034 H::k = lshHeader->numFuns; | |
1035 H::d = lshHeader->dataDim; | |
1036 G::w = lshHeader->binWidth; | |
1037 G::radius = lshHeader->radius; | |
1038 G::maxp = lshHeader->maxp; | |
1039 | |
1040 return dbfid; | |
1041 } | |
1042 | |
// unserialize the LSH parameters
// we leave the LSH tree on disk as a flat file
// it is this flat file that we search by memory mapping
//
// Loads the hash-function parameters into core from the region of dbfid that
// precedes the hash tables. That region is the first get_serial_hashtable_offset()
// bytes, mapped via serial_mmap's default offset (presumably 0 - confirm in lshlib.h).
// The floats starting at get_serial_hashfunction_base(db) are consumed in this order:
//   1. A coefficients: [outer][k or k/2][d]
//   2. b biases:       [outer][k or k/2]
//   3. then, reinterpreted as Uns32T, r1 [L][k] followed by r2 [L][k]
// "outer" is m under USE_U_FUNCTIONS and L otherwise.
// Also allocates g [L][k] and calls H::__initialize_data_structures() before
// r1/r2 are filled. The read order must match the writer exactly.
void G::unserialize_lsh_functions(int dbfid){
  Uns32T j, kk;
  float* pf;
  Uns32T* pu;

  // Load the hash functions into core
  char* db = serial_mmap(dbfid, get_serial_hashtable_offset(), 0);// get database pointer again

#ifdef USE_U_FUNCTIONS
  G::A = new float**[ H::m ]; // m x k/2 x d random projectors
  G::b = new float*[ H::m ]; // m x k/2 random biases
#else
  G::A = new float**[ H::L ]; // L x k x d random projectors
  G::b = new float*[ H::L ]; // L x k random biases
#endif
  G::g = new Uns32T*[ H::L ]; // L x k random projections
  // Standard operator new throws std::bad_alloc instead of returning null,
  // so these asserts cannot fire in a conforming build.
  assert(g&&A&&b); // failure

  pf = get_serial_hashfunction_base(db);

#ifdef USE_U_FUNCTIONS
  for( j = 0 ; j < H::m ; j++ ){ // m functions uj(v)
    G::A[j] = new float*[ H::k/2 ]; // k/2 x d 2-stable distribution coefficients
    G::b[j] = new float[ H::k/2 ]; // bias
    assert( G::A[j] && G::b[j] ); // failure
    for( kk = 0 ; kk < H::k/2 ; kk++ ){ // Normally distributed hash functions
#else
  for( j = 0 ; j < H::L ; j++ ){ // L functions gj(v)
    G::A[j] = new float*[ H::k ]; // k x d 2-stable distribution coefficients
    G::b[j] = new float[ H::k ]; // bias
    assert( G::A[j] && G::b[j] ); // failure
    for( kk = 0 ; kk < H::k ; kk++ ){ // Normally distributed hash functions
#endif
      G::A[j][kk] = new float[ H::d ];
      assert( G::A[j][kk] ); // failure
      for(Uns32T i = 0 ; i < H::d ; i++ )
        G::A[j][kk][i] = *pf++; // Normally distributed random vectors
    }
  }
  // Biases follow all of the A coefficients in the file
#ifdef USE_U_FUNCTIONS
  for( j = 0 ; j < H::m ; j++ ) // biases b
    for( kk = 0 ; kk < H::k/2 ; kk++ )
#else
  for( j = 0 ; j < H::L ; j++ ) // biases b
    for( kk = 0 ; kk < H::k ; kk++ )
#endif
      G::b[j][kk] = *pf++;

  // g is scratch space for computed hash values; nothing is read from the file here
  for( j = 0 ; j < H::L ; j++){ // 32-bit hash values, gj(v)=[x0 x1 ... xk-1] xk \in Z
    G::g[j] = new Uns32T[ H::k ];
    assert( G::g[j] );
  }


  // Must run before r1/r2 are written below (it is expected to allocate them - confirm)
  H::__initialize_data_structures();

  // The remaining words are unsigned integers, not floats
  pu = (Uns32T*)pf;
  for( j = 0 ; j < H::L ; j++ ) // Z projectors r1
    for( kk = 0 ; kk < H::k ; kk++ )
      H::r1[j][kk] = *pu++;

  for( j = 0 ; j < H::L ; j++ ) // Z projectors r2
    for( kk = 0 ; kk < H::k ; kk++ )
      H::r2[j][kk] = *pu++;

  serial_munmap(db, get_serial_hashtable_offset());
}
1113 | |
1114 void G::unserialize_lsh_hashtables_format1(int fid){ | |
1115 SerialElementT *pe, *pt; | |
1116 Uns32T x,y; | |
1117 Uns32T hashTableSize=sizeof(SerialElementT)*lshHeader->numRows*lshHeader->numCols; | |
1118 // Read the hash tables into core | |
1119 for( x = 0 ; x < H::L ; x++ ){ | |
1120 // memory map a single hash table | |
1121 // Align each hash table to page boundary | |
1122 char* dbtable = serial_mmap(fid, hashTableSize, 0, | |
1123 align_up(get_serial_hashtable_offset()+x*hashTableSize, get_page_logn())); | |
1124 if(madvise(dbtable, hashTableSize, MADV_SEQUENTIAL)<0) | |
1125 error("could not advise hashtable memory","","madvise"); | |
1126 pt=(SerialElementT*)dbtable; | |
1127 for( y = 0 ; y < H::N ; y++ ){ | |
1128 // Move disk pointer to beginning of row | |
1129 pe=pt+y*lshHeader->numCols; | |
1130 unserialize_hashtable_row_format1(pe, h[x]+y); | |
1131 #ifdef __LSH_DUMP_CORE_TABLES__ | |
1132 printf("S[%d,%d]", x, y); | |
1133 serial_bucket_dump(pe); | |
1134 printf("C[%d,%d]", x, y); | |
1135 dump_hashtable_row(h[x][y]); | |
1136 #endif | |
1137 } | |
1138 serial_munmap(dbtable, hashTableSize); | |
1139 } | |
1140 } | |
1141 | |
1142 void G::unserialize_hashtable_row_format1(SerialElementT* pe, bucket** b){ | |
1143 Uns32T colCount = 0; | |
1144 while(colCount!=lshHeader->numCols && pe->hashValue !=IFLAG){ | |
1145 H::p = pe->pointID; // current point ID | |
1146 t2 = pe->hashValue; | |
1147 bucket_insert_point(b); | |
1148 pe++; | |
1149 colCount++; | |
1150 } | |
1151 } | |
1152 | |
// Load all H::L format2 hash tables from fid into the in-core tables h[][].
// Format2 is a sequential stream of Uns32T tokens starting at
// get_serial_hashtable_offset():
//   - a word with O2_SERIAL_FLAGS_T1_BIT set starts a row; its remaining
//     bits (t1 ^ T1_BIT) are the row index y, which must be < H::N;
//   - the row body is parsed by unserialize_hashtable_row_format2(), which
//     returns the token that terminated it (next T1 row or END);
//   - a word with O2_SERIAL_FLAGS_END_BIT ends the current table
//     (an empty table is just a lone END word).
// Any malformed token closes fid and aborts through error().
void G::unserialize_lsh_hashtables_format2(int fid){
  Uns32T x=0,y=0; // x: current table, y: current row

  // Seek to hashtable base offset
  if(lseek(fid, get_serial_hashtable_offset(), SEEK_SET)!=get_serial_hashtable_offset()){
    close(fid);
    error("Seek error in unserialize_lsh_hashtables_format2");
  }

  // Read the hash tables into core (structure is given in header)
  while( x < H::L){
    // First token of a table: END (empty table) or the first T1 row marker
    if(read(fid, &(H::t1), sizeof(Uns32T))!=sizeof(Uns32T)){
      close(fid);
      error("Read error","unserialize_lsh_hashtables_format2()");
    }
    if(H::t1&O2_SERIAL_FLAGS_END_BIT)
      x++; // End of table
    else
      // y is not reset between tables, but every accepted y is < H::N, so
      // this loop only exits via the END-token break below (or error()).
      while(y < H::N){
        // Read a row and move file pointer to beginning of next row or table
        if(!(H::t1&O2_SERIAL_FLAGS_T1_BIT)){
          close(fid);
          error("State matchine error t1","unserialize_lsh_hashtables_format2()");
        }
        y = H::t1 ^ O2_SERIAL_FLAGS_T1_BIT; // strip the flag to get the row index
        if(y>=H::N){
          close(fid);
          error("Unserialized hashtable row pointer out of range","unserialize_lsh_hashtables_format2()");
        }
        // Returns the token that ended the row (already consumed from fid)
        Uns32T token = unserialize_hashtable_row_format2(fid, h[x]+y);

#ifdef __LSH_DUMP_CORE_TABLES__
        printf("C[%d,%d]", x, y);
        dump_hashtable_row(h[x][y]);
#endif
        // Check that token is valid
        if( !(token&O2_SERIAL_FLAGS_T1_BIT || token&O2_SERIAL_FLAGS_END_BIT) ){
          close(fid);
          error("State machine error end of row/table", "unserialize_lsh_hashtables_format2()");
        }
        // Check for end of table flag
        if(token&O2_SERIAL_FLAGS_END_BIT){
          x++;
          break;
        }
        // Check for new row flag: it becomes the row marker for the next iteration
        if(token&O2_SERIAL_FLAGS_T1_BIT)
          H::t1 = token;
      }
  }
}
1204 | |
// Parse the body of one format2 row from fid into the bucket chain *b.
// Expected stream (after the row's T1 marker, already consumed by the caller):
//   { T2_BIT, hashValue, pointID, pointID, ... }*  followed by T1|row or END
// Each group must contain at least one pointID. For every point, H::t2 (hash
// value) and H::p (point ID) are set before bucket_insert_point(b) is called.
// Returns the token that terminated the row: END, or a word with T1_BIT set.
// NOTE(review): a pointID equal to END_BIT or T2_BIT, or with T1_BIT set, is
// indistinguishable from a token and would be misparsed.
// NOTE(review): the two state-machine error() calls inside the loop do not
// close(fid), unlike the other error paths.
Uns32T G::unserialize_hashtable_row_format2(int fid, bucket** b){
  bool pointFound = false;
  if(read(fid, &(H::t2), sizeof(Uns32T)) != sizeof(Uns32T)){
    close(fid);
    error("Read error T2 token","unserialize_hashtable_row_format2");
  }
  // A row must start with a T2 group marker (END is also tolerated)
  if( !(H::t2==O2_SERIAL_FLAGS_END_BIT || H::t2==O2_SERIAL_FLAGS_T2_BIT)){
    close(fid);
    error("State machine error: expected E or T2");
  }
  // Loop over (T2, hashValue, points...) groups until a row/table terminator
  while(!(H::t2==O2_SERIAL_FLAGS_END_BIT || H::t2&O2_SERIAL_FLAGS_T1_BIT)){
    pointFound=false;
    // Check for T2 token
    if(H::t2!=O2_SERIAL_FLAGS_T2_BIT)
      error("State machine error T2 token", "unserialize_hashtable_row_format2()");
    // Read t2 value (the hash value shared by this group's points)
    if(read(fid, &(H::t2), sizeof(Uns32T)) != sizeof(Uns32T)){
      close(fid);
      error("Read error t2","unserialize_hashtable_row_format2");
    }
    if(read(fid, &(H::p), sizeof(Uns32T)) != sizeof(Uns32T)){
      close(fid);
      error("Read error H::p","unserialize_hashtable_row_format2");
    }
    // Consume point IDs until the next token (END, T1 row marker, or T2 group marker)
    while(!(H::p==O2_SERIAL_FLAGS_END_BIT || H::p&O2_SERIAL_FLAGS_T1_BIT || H::p==O2_SERIAL_FLAGS_T2_BIT )){
      pointFound=true;
      bucket_insert_point(b);
      if(read(fid, &(H::p), sizeof(Uns32T)) != sizeof(Uns32T)){
        close(fid);
        error("Read error H::p","unserialize_hashtable_row_format2");
      }
    }
    H::t2 = H::p; // Copy last found token to t2
    if(!pointFound)
      error("State machine error: point", "unserialize_hashtable_row_format2()");
  }
  return H::t2; // holds current token
}
1243 | |
1244 void G::dump_hashtable_row(bucket* p){ | |
1245 while(p && p->t2!=IFLAG){ | |
1246 sbucket* sbp = p->snext; | |
1247 while(sbp){ | |
1248 printf("(%0X,%u)", p->t2, sbp->pointID); | |
1249 fflush(stdout); | |
1250 sbp=sbp->snext; | |
1251 } | |
1252 p=p->next; | |
1253 } | |
1254 printf("\n"); | |
1255 } | |
1256 | |
1257 | |
1258 // G::serial_retrieve_point( ... ) | |
1259 // retrieves (pointID) from a serialized LSH database | |
1260 // | |
1261 // inputs: | |
1262 // filename - file name of serialized LSH database | |
1263 // vv - query point set | |
1264 // | |
1265 // outputs: | |
1266 // inserts retrieved points into add_point() callback method | |
1267 void G::serial_retrieve_point_set(char* filename, vector<vector<float> >& vv, ReporterCallbackPtr add_point, void* caller) | |
1268 { | |
1269 int dbfid = serial_open(filename, 0); // open for read | |
1270 char* dbheader = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 0);// get database pointer | |
1271 serial_get_header(dbheader); // read header | |
1272 serial_munmap(dbheader, O2_SERIAL_HEADER_SIZE); // drop header mmap | |
1273 | |
1274 if((lshHeader->flags & O2_SERIAL_FILEFORMAT2)){ | |
1275 close(dbfid); | |
1276 error("serial_retrieve_point_set is for SERIAL_FILEFORMAT1 only"); | |
1277 } | |
1278 | |
1279 // size of each hash table | |
1280 Uns32T hashTableSize=sizeof(SerialElementT)*lshHeader->numRows*lshHeader->numCols; | |
1281 calling_instance = caller; // class instance variable used in ...bucket_chain_point() | |
1282 add_point_callback = add_point; | |
1283 | |
1284 for(Uns32T j=0; j<L; j++){ | |
1285 // memory map a single hash table for random access | |
1286 char* db = serial_mmap(dbfid, hashTableSize, 0, | |
1287 align_up(get_serial_hashtable_offset()+j*hashTableSize,get_page_logn())); | |
1288 if(madvise(db, hashTableSize, MADV_RANDOM)<0) | |
1289 error("could not advise local hashtable memory","","madvise"); | |
1290 SerialElementT* pe = (SerialElementT*)db ; | |
1291 for(Uns32T qpos=0; qpos<vv.size(); qpos++){ | |
1292 compute_hash_functions(vv[qpos]); | |
1293 __generate_hash_keys(*(g+j),*(r1+j),*(r2+j)); | |
1294 serial_bucket_chain_point(pe+t1*lshHeader->numCols, qpos); // Point to correct row | |
1295 } | |
1296 serial_munmap(db, hashTableSize); // drop hashtable mmap | |
1297 } | |
1298 serial_close(dbfid); | |
1299 } | |
1300 | |
1301 void G::serial_retrieve_point(char* filename, vector<float>& v, Uns32T qpos, ReporterCallbackPtr add_point, void* caller){ | |
1302 int dbfid = serial_open(filename, 0); // open for read | |
1303 char* dbheader = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 0);// get database pointer | |
1304 serial_get_header(dbheader); // read header | |
1305 serial_munmap(dbheader, O2_SERIAL_HEADER_SIZE); // drop header mmap | |
1306 | |
1307 if((lshHeader->flags & O2_SERIAL_FILEFORMAT2)){ | |
1308 close(dbfid); | |
1309 error("serial_retrieve_point is for SERIAL_FILEFORMAT1 only"); | |
1310 } | |
1311 | |
1312 // size of each hash table | |
1313 Uns32T hashTableSize=sizeof(SerialElementT)*lshHeader->numRows*lshHeader->numCols; | |
1314 calling_instance = caller; | |
1315 add_point_callback = add_point; | |
1316 compute_hash_functions(v); | |
1317 for(Uns32T j=0; j<L; j++){ | |
1318 // memory map a single hash table for random access | |
1319 char* db = serial_mmap(dbfid, hashTableSize, 0, | |
1320 align_up(get_serial_hashtable_offset()+j*hashTableSize,get_page_logn())); | |
1321 if(madvise(db, hashTableSize, MADV_RANDOM)<0) | |
1322 error("could not advise local hashtable memory","","madvise"); | |
1323 SerialElementT* pe = (SerialElementT*)db ; | |
1324 __generate_hash_keys(*(g+j),*(r1+j),*(r2+j)); | |
1325 serial_bucket_chain_point(pe+t1*lshHeader->numCols, qpos); // Point to correct row | |
1326 serial_munmap(db, hashTableSize); // drop hashtable mmap | |
1327 } | |
1328 serial_close(dbfid); | |
1329 } | |
1330 | |
1331 void G::serial_dump_tables(char* filename){ | |
1332 int dbfid = serial_open(filename, 0); // open for read | |
1333 char* dbheader = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 0);// get database pointer | |
1334 serial_get_header(dbheader); // read header | |
1335 serial_munmap(dbheader, O2_SERIAL_HEADER_SIZE); // drop header mmap | |
1336 Uns32T hashTableSize=sizeof(SerialElementT)*lshHeader->numRows*lshHeader->numCols; | |
1337 for(Uns32T j=0; j<L; j++){ | |
1338 // memory map a single hash table for random access | |
1339 char* db = serial_mmap(dbfid, hashTableSize, 0, | |
1340 align_up(get_serial_hashtable_offset()+j*hashTableSize,get_page_logn())); | |
1341 if(madvise(db, hashTableSize, MADV_SEQUENTIAL)<0) | |
1342 error("could not advise local hashtable memory","","madvise"); | |
1343 SerialElementT* pe = (SerialElementT*)db ; | |
1344 printf("*********** TABLE %d ***************\n", j); | |
1345 fflush(stdout); | |
1346 int count=0; | |
1347 do{ | |
1348 printf("[%d,%d]", j, count++); | |
1349 fflush(stdout); | |
1350 serial_bucket_dump(pe); | |
1351 pe+=lshHeader->numCols; | |
1352 }while(pe<(SerialElementT*)db+lshHeader->numRows*lshHeader->numCols); | |
1353 } | |
1354 | |
1355 } | |
1356 | |
1357 void G::serial_bucket_dump(SerialElementT* pe){ | |
1358 SerialElementT* pend = pe+lshHeader->numCols; | |
1359 while( !(pe->hashValue==IFLAG || pe==pend ) ){ | |
1360 printf("(%0X,%u)",pe->hashValue,pe->pointID); | |
1361 pe++; | |
1362 } | |
1363 printf("\n"); | |
1364 fflush(stdout); | |
1365 } | |
1366 | |
1367 void G::serial_bucket_chain_point(SerialElementT* pe, Uns32T qpos){ | |
1368 SerialElementT* pend = pe+lshHeader->numCols; | |
1369 while( !(pe->hashValue==IFLAG || pe==pend ) ){ | |
1370 if(pe->hashValue==t2){ // new match | |
1371 add_point_callback(calling_instance, pe->pointID, qpos, radius); | |
1372 } | |
1373 pe++; | |
1374 } | |
1375 } | |
1376 | |
1377 void G::bucket_chain_point(bucket* p, Uns32T qpos){ | |
1378 if(!p || p->t2==IFLAG) | |
1379 return; | |
1380 if(p->t2==t2){ // match | |
1381 sbucket_chain_point(p->snext, qpos); // add to reporter | |
1382 } | |
1383 if(p->next){ | |
1384 bucket_chain_point(p->next, qpos); // recurse | |
1385 } | |
1386 } | |
1387 | |
1388 void G::sbucket_chain_point(sbucket* p, Uns32T qpos){ | |
1389 add_point_callback(calling_instance, p->pointID, qpos, radius); | |
1390 if(p->snext){ | |
1391 sbucket_chain_point(p->snext, qpos); | |
1392 } | |
1393 } | |
1394 | |
1395 void G::get_lock(int fd, bool exclusive) { | |
1396 struct flock lock; | |
1397 int status; | |
1398 lock.l_type = exclusive ? F_WRLCK : F_RDLCK; | |
1399 lock.l_whence = SEEK_SET; | |
1400 lock.l_start = 0; | |
1401 lock.l_len = 0; /* "the whole file" */ | |
1402 retry: | |
1403 do { | |
1404 status = fcntl(fd, F_SETLKW, &lock); | |
1405 } while (status != 0 && errno == EINTR); | |
1406 if (status) { | |
1407 if (errno == EAGAIN) { | |
1408 sleep(1); | |
1409 goto retry; | |
1410 } else { | |
1411 error("fcntl lock error", "", "fcntl"); | |
1412 } | |
1413 } | |
1414 } | |
1415 | |
1416 void G::release_lock(int fd) { | |
1417 struct flock lock; | |
1418 int status; | |
1419 | |
1420 lock.l_type = F_UNLCK; | |
1421 lock.l_whence = SEEK_SET; | |
1422 lock.l_start = 0; | |
1423 lock.l_len = 0; | |
1424 | |
1425 status = fcntl(fd, F_SETLKW, &lock); | |
1426 | |
1427 if (status) | |
1428 error("fcntl unlock error", "", "fcntl"); | |
1429 } |