Mercurial > hg > audiodb
comparison lshlib.cpp @ 293:9fd5340faffd
Refactored LSH interface to separate hashfunctions and parameters from insertion/retrieval/serialization
author | mas01mc |
---|---|
date | Wed, 30 Jul 2008 15:22:22 +0000 |
parents | d9a88cfd4ab6 |
children | 071a108580a4 |
comparison
equal
deleted
inserted
replaced
292:d9a88cfd4ab6 | 293:9fd5340faffd |
---|---|
19 perror(sysFunc); | 19 perror(sysFunc); |
20 } | 20 } |
21 exit(1); | 21 exit(1); |
22 } | 22 } |
23 | 23 |
24 H::H(Uns32T kk, Uns32T mm, Uns32T dd, Uns32T NN, Uns32T CC): | 24 H::H(){ |
25 // Delay initialization of lsh functions until we know the parameters | |
26 } | |
27 | |
28 H::H(Uns32T kk, Uns32T mm, Uns32T dd, Uns32T NN, Uns32T CC, float ww, float rr): | |
25 #ifdef USE_U_FUNCTIONS | 29 #ifdef USE_U_FUNCTIONS |
26 use_u_functions(true), | 30 use_u_functions(true), |
27 #else | 31 #else |
28 use_u_functions(false), | 32 use_u_functions(false), |
29 #endif | 33 #endif |
34 maxp(0), | |
30 bucketCount(0), | 35 bucketCount(0), |
31 pointCount(0), | 36 pointCount(0), |
32 N(NN), | 37 N(NN), |
33 C(CC), | 38 C(CC), |
34 k(kk), | 39 k(kk), |
35 m(mm), | 40 m(mm), |
36 L(mm*(mm-1)/2), | 41 L((mm*(mm-1))/2), |
37 d(dd) | 42 d(dd), |
43 w(ww), | |
44 radius(rr) | |
38 { | 45 { |
39 Uns32T j; | 46 |
47 if(m<2){ | |
48 m=2; | |
49 L=1; // check value of L | |
50 cout << "warning: setting m=2, L=1" << endl; | |
51 } | |
52 if(use_u_functions && k%2){ | |
53 k++; // make sure k is even | |
54 cout << "warning: setting k even" << endl; | |
55 } | |
56 | |
40 cout << "file size: ~" << (((unsigned long long)L*N*C*sizeof(SerialElementT))/1000000UL) << "MB" << endl; | 57 cout << "file size: ~" << (((unsigned long long)L*N*C*sizeof(SerialElementT))/1000000UL) << "MB" << endl; |
41 if(((unsigned long long)L*N*C*sizeof(SerialElementT))>4000000000UL) | 58 if(((unsigned long long)L*N*C*sizeof(SerialElementT))>4000000000UL) |
42 error("Maximum size of LSH file exceded: 12*L*N*C > 4000MB"); | 59 error("Maximum size of LSH file exceded: 12*L*N*C > 4000MB"); |
43 else if(((unsigned long long)N*C*sizeof(SerialElementT))>1000000000UL) | 60 else if(((unsigned long long)N*C*sizeof(SerialElementT))>1000000000UL) |
44 cout << "warning: hash tables exceed 1000MB." << endl; | 61 cout << "warning: hash tables exceed 1000MB." << endl; |
45 | 62 |
46 if(m<2){ | 63 // We have the necessary parameters, so construct hashfunction datastructures |
47 m=2; | 64 initialize_lsh_functions(); |
48 L=1; // check value of L | 65 } |
49 cout << "warning: setting m=2, L=1" << endl; | 66 |
50 } | 67 void H::initialize_lsh_functions(){ |
51 if(use_u_functions && k%2){ | |
52 k++; // make sure k is even | |
53 cout << "warning: setting k even" << endl; | |
54 } | |
55 __initialize_data_structures(); | |
56 for(j=0; j<L; j++) | |
57 for(kk=0; kk<k; kk++) { | |
58 r1[j][kk]=__randr(); // random 1..2^29 | |
59 r2[j][kk]=__randr(); // random 1..2^29 | |
60 } | |
61 } | |
62 | |
63 // Post constructor initialization | |
64 void H::__initialize_data_structures(){ | |
65 H::P = UH_PRIME_DEFAULT; | 68 H::P = UH_PRIME_DEFAULT; |
66 | 69 |
67 /* FIXME: don't use time(); instead use /dev/random or similar */ | 70 /* FIXME: don't use time(); instead use /dev/random or similar */ |
68 /* FIXME: write out the seed somewhere, so that we can get | 71 /* FIXME: write out the seed somewhere, so that we can get |
69 repeatability */ | 72 repeatability */ |
70 #ifdef MT19937 | 73 #ifdef MT19937 |
71 init_genrand(time(NULL)); | 74 init_genrand(time(NULL)); |
72 #else | 75 #else |
73 srand(time(NULL)); // seed random number generator | 76 srand(time(NULL)); // seed random number generator |
74 #endif | 77 #endif |
75 Uns32T i,j; | 78 Uns32T i,j, kk; |
79 #ifdef USE_U_FUNCTIONS | |
80 H::A = new float**[ H::m ]; // m x k x d random projectors | |
81 H::b = new float*[ H::m ]; // m x k random biases | |
82 #else | |
83 H::A = new float**[ H::L ]; // m x k x d random projectors | |
84 H::b = new float*[ H::L ]; // m x k random biases | |
85 #endif | |
86 H::g = new Uns32T*[ H::L ]; // L x k random projections | |
87 assert( H::g && H::A && H::b ); // failure | |
88 #ifdef USE_U_FUNCTIONS | |
89 // Use m \times u_i functions \in R^{(k/2) \times (d)} | |
90 // Combine to make L=m(m-1)/2 hash functions \in R^{k \times d} | |
91 for( j = 0; j < H::m ; j++ ){ // m functions u_i(v) | |
92 H::A[j] = new float*[ H::k/2 ]; // k/2 x d 2-stable distribution coefficients | |
93 H::b[j] = new float[ H::k/2 ]; // bias | |
94 assert( H::A[j] && H::b[j] ); // failure | |
95 for( kk = 0; kk < H::k/2 ; kk++ ){ | |
96 H::A[j][kk] = new float[ H::d ]; | |
97 assert( H::A[j][kk] ); // failure | |
98 for(Uns32T i = 0 ; i < H::d ; i++ ) | |
99 H::A[j][kk][i] = H::randn(); // Normal | |
100 H::b[j][kk] = H::ranf()*H::w; // Uniform | |
101 } | |
102 } | |
103 #else | |
104 // Use m \times u_i functions \in R^{k \times (d)} | |
105 // Combine to make L=m(m-1)/2 hash functions \in R^{k \times d} | |
106 for( j = 0; j < H::L ; j++ ){ // m functions u_i(v) | |
107 H::A[j] = new float*[ H::k ]; // k x d 2-stable distribution coefficients | |
108 H::b[j] = new float[ H::k ]; // bias | |
109 assert( H::A[j] && H::b[j] ); // failure | |
110 for( kk = 0; kk < H::k ; kk++ ){ | |
111 H::A[j][kk] = new float[ H::d ]; | |
112 assert( H::A[j][kk] ); // failure | |
113 for(Uns32T i = 0 ; i < H::d ; i++ ) | |
114 H::A[j][kk][i] = H::randn(); // Normal | |
115 H::b[j][kk] = H::ranf()*H::w; // Uniform | |
116 } | |
117 } | |
118 #endif | |
119 | |
120 // Storage for LSH hash function output (Uns32T) | |
121 for( j = 0 ; j < H::L ; j++ ){ // L functions g_j(u_a, u_b) a,b \in nchoosek(m,2) | |
122 H::g[j] = new Uns32T[ H::k ]; // k x 32-bit hash values, gj(v)=[x0 x1 ... xk-1] xk \in Z | |
123 assert( H::g[j] ); | |
124 } | |
125 | |
126 // LSH Hash tables | |
76 H::h = new bucket**[ H::L ]; | 127 H::h = new bucket**[ H::L ]; |
77 H::r1 = new Uns32T*[ H::L ]; | 128 assert( H::h ); |
78 H::r2 = new Uns32T*[ H::L ]; | |
79 assert( H::h && H::r1 && H::r2 ); // failure | |
80 for( j = 0 ; j < H::L ; j++ ){ | |
81 H::r1[ j ] = new Uns32T[ H::k ]; | |
82 H::r2[ j ] = new Uns32T[ H::k ]; | |
83 assert( H::r1[j] && H::r2[j] ); // failure | |
84 } | |
85 | |
86 for( j = 0 ; j < H::L ; j++ ){ | 129 for( j = 0 ; j < H::L ; j++ ){ |
87 H::h[j] = new bucket*[ H::N ]; | 130 H::h[j] = new bucket*[ H::N ]; |
88 assert( H::h[j] ); | 131 assert( H::h[j] ); |
89 for( i = 0 ; i < H::N ; i++) | 132 for( i = 0 ; i < H::N ; i++) |
90 H::h[j][i] = 0; | 133 H::h[j][i] = 0; |
91 } | 134 } |
92 } | 135 |
93 | 136 // Standard hash functions |
94 // Destruct hash tables | 137 H::r1 = new Uns32T*[ H::L ]; |
95 H::~H(){ | 138 H::r2 = new Uns32T*[ H::L ]; |
96 Uns32T i,j; | 139 assert( H::r1 && H::r2 ); // failure |
97 for( j=0 ; j < H::L ; j++ ){ | 140 for( j = 0 ; j < H::L ; j++ ){ |
98 delete[] H::r1[ j ]; | 141 H::r1[ j ] = new Uns32T[ H::k ]; |
99 delete[] H::r2[ j ]; | 142 H::r2[ j ] = new Uns32T[ H::k ]; |
100 for(i = 0; i< H::N ; i++) | 143 assert( H::r1[j] && H::r2[j] ); // failure |
101 delete H::h[ j ][ i ]; | 144 for( i = 0; i<H::k; i++){ |
102 delete[] H::h[ j ]; | 145 H::r1[j][i] = randr(); |
103 } | 146 H::r2[j][i] = randr(); |
104 delete[] H::r1; | 147 } |
105 delete[] H::r2; | 148 } |
106 delete[] H::h; | 149 |
107 } | 150 // Storage for whole or partial function evaluation depending on USE_U_FUNCTIONS |
108 | 151 H::initialize_partial_functions(); |
109 | 152 } |
110 // make hash value \in Z | 153 |
111 void H::__generate_hash_keys(Uns32T*g,Uns32T* r1, Uns32T* r2){ | 154 void H::initialize_partial_functions(){ |
112 H::t1 = __computeProductModDefaultPrime( g, r1, H::k ) % H::N; | 155 |
113 H::t2 = __computeProductModDefaultPrime( g, r2, H::k ); | |
114 | |
115 } | |
116 | |
117 #define CR_ASSERT(b){if(!(b)){fprintf(stderr, "ASSERT failed on line %d, file %s.\n", __LINE__, __FILE__); exit(1);}} | |
118 | |
119 // Computes (a.b) mod UH_PRIME_DEFAULT | |
120 inline Uns32T H::__computeProductModDefaultPrime(Uns32T *a, Uns32T *b, IntT size){ | |
121 LongUns64T h = 0; | |
122 | |
123 for(IntT i = 0; i < size; i++){ | |
124 h = h + (LongUns64T)a[i] * (LongUns64T)b[i]; | |
125 h = (h & TWO_TO_32_MINUS_1) + 5 * (h >> 32); | |
126 if (h >= UH_PRIME_DEFAULT) { | |
127 h = h - UH_PRIME_DEFAULT; | |
128 } | |
129 CR_ASSERT(h < UH_PRIME_DEFAULT); | |
130 } | |
131 return h; | |
132 } | |
133 | |
134 Uns32T H::bucket_insert_point(bucket **pp){ | |
135 Uns32T collisionCount = 0; | |
136 if(!*pp){ | |
137 *pp = new bucket(); | |
138 #ifdef LSH_BLOCK_FULL_ROWS | |
139 (*pp)->t2 = 0; // Use t2 as a collision counter for the row | |
140 (*pp)->next = new bucket(); | |
141 #endif | |
142 } | |
143 #ifdef LSH_BLOCK_FULL_ROWS | |
144 collisionCount = (*pp)->t2; | |
145 if(collisionCount < H::C){ // Block if row is full | |
146 (*pp)->t2++; // Increment collision counter | |
147 pointCount++; | |
148 collisionCount++; | |
149 __bucket_insert_point((*pp)->next); // First bucket holds collision count | |
150 } | |
151 #else | |
152 pointCount++; | |
153 __bucket_insert_point(*pp); // No collision count storage | |
154 #endif | |
155 return collisionCount; | |
156 } | |
157 | |
158 void H::__bucket_insert_point(bucket* p){ | |
159 if(p->t2 == IFLAG){ // initialization flag, is it in the domain of t2? | |
160 p->t2 = H::t2; | |
161 bucketCount++; // Record start of new point-locale collision chain | |
162 p->snext = new sbucket(); | |
163 __sbucket_insert_point(p->snext); | |
164 return; | |
165 } | |
166 | |
167 if(p->t2 == H::t2){ | |
168 __sbucket_insert_point(p->snext); | |
169 return; | |
170 } | |
171 | |
172 if(p->next){ | |
173 __bucket_insert_point(p->next); | |
174 } | |
175 | |
176 else{ | |
177 p->next = new bucket(); | |
178 __bucket_insert_point(p->next); | |
179 } | |
180 | |
181 } | |
182 | |
183 void H::__sbucket_insert_point(sbucket* p){ | |
184 if(p->pointID==IFLAG){ | |
185 p->pointID = H::p; | |
186 return; | |
187 } | |
188 | |
189 // Search for pointID | |
190 if(p->snext){ | |
191 __sbucket_insert_point(p->snext); | |
192 } | |
193 else{ | |
194 // Make new point collision bucket at end of list | |
195 p->snext = new sbucket(); | |
196 __sbucket_insert_point(p->snext); | |
197 } | |
198 } | |
199 | |
200 inline bucket** H::__get_bucket(int j){ | |
201 return *(h+j); | |
202 } | |
203 | |
204 // hash functions G | |
205 G::G(float ww, Uns32T kk,Uns32T mm, Uns32T dd, Uns32T NN, Uns32T CC, float r): | |
206 H(kk,mm,dd,NN,CC), | |
207 w(ww), | |
208 radius(r), | |
209 maxp(0), | |
210 calling_instance(0), | |
211 add_point_callback(0), | |
212 lshHeader(0) | |
213 { | |
214 Uns32T j; | |
215 #ifdef USE_U_FUNCTIONS | 156 #ifdef USE_U_FUNCTIONS |
216 G::A = new float**[ H::m ]; // m x k x d random projectors | 157 H::uu = vector<vector<Uns32T> >(H::m); |
217 G::b = new float*[ H::m ]; // m x k random biases | |
218 #else | |
219 G::A = new float**[ H::L ]; // m x k x d random projectors | |
220 G::b = new float*[ H::L ]; // m x k random biases | |
221 #endif | |
222 G::g = new Uns32T*[ H::L ]; // L x k random projections | |
223 assert( G::g && G::A && G::b ); // failure | |
224 #ifdef USE_U_FUNCTIONS | |
225 // Use m \times u_i functions \in R^{(k/2) \times (d)} | |
226 // Combine to make L=m(m-1)/2 hash functions \in R^{k \times d} | |
227 for( j = 0; j < H::m ; j++ ){ // m functions u_i(v) | |
228 G::A[j] = new float*[ H::k/2 ]; // k/2 x d 2-stable distribution coefficients | |
229 G::b[j] = new float[ H::k/2 ]; // bias | |
230 assert( G::A[j] && G::b[j] ); // failure | |
231 for( kk = 0; kk < H::k/2 ; kk++ ){ | |
232 G::A[j][kk] = new float[ H::d ]; | |
233 assert( G::A[j][kk] ); // failure | |
234 for(Uns32T i = 0 ; i < H::d ; i++ ) | |
235 G::A[j][kk][i] = randn(); // Normal | |
236 G::b[j][kk] = ranf()*G::w; // Uniform | |
237 } | |
238 } | |
239 #else | |
240 // Use m \times u_i functions \in R^{k \times (d)} | |
241 // Combine to make L=m(m-1)/2 hash functions \in R^{k \times d} | |
242 for( j = 0; j < H::L ; j++ ){ // m functions u_i(v) | |
243 G::A[j] = new float*[ H::k ]; // k x d 2-stable distribution coefficients | |
244 G::b[j] = new float[ H::k ]; // bias | |
245 assert( G::A[j] && G::b[j] ); // failure | |
246 for( kk = 0; kk < H::k ; kk++ ){ | |
247 G::A[j][kk] = new float[ H::d ]; | |
248 assert( G::A[j][kk] ); // failure | |
249 for(Uns32T i = 0 ; i < H::d ; i++ ) | |
250 G::A[j][kk][i] = randn(); // Normal | |
251 G::b[j][kk] = ranf()*G::w; // Uniform | |
252 } | |
253 } | |
254 #endif | |
255 | |
256 for( j = 0 ; j < H::L ; j++ ){ // L functions g_j(u_a, u_b) a,b \in nchoosek(m,2) | |
257 G::g[j] = new Uns32T[ H::k ]; // k x 32-bit hash values, gj(v)=[x0 x1 ... xk-1] xk \in Z | |
258 assert( G::g[j] ); | |
259 } | |
260 | |
261 initialize_partial_functions(); // m partially evaluated hash functions | |
262 } | |
263 | |
264 // Serialize from file LSH constructor | |
265 // Read parameters from database file | |
266 // Load the hash functions, close the database | |
267 // Optionally load the LSH tables into heap-allocated lists in core | |
268 G::G(char* filename, bool lshInCoreFlag): | |
269 calling_instance(0), | |
270 add_point_callback(0) | |
271 { | |
272 int dbfid = unserialize_lsh_header(filename); | |
273 unserialize_lsh_functions(dbfid); | |
274 initialize_partial_functions(); | |
275 | |
276 // Format1 only needs unserializing if specifically requested | |
277 if(!(lshHeader->flags&O2_SERIAL_FILEFORMAT2) && lshInCoreFlag){ | |
278 unserialize_lsh_hashtables_format1(dbfid); | |
279 } | |
280 | |
281 // Format2 always needs unserializing | |
282 if(lshHeader->flags&O2_SERIAL_FILEFORMAT2 && lshInCoreFlag){ | |
283 unserialize_lsh_hashtables_format2(dbfid); | |
284 } | |
285 | |
286 close(dbfid); | |
287 } | |
288 | |
289 void G::initialize_partial_functions(){ | |
290 | |
291 #ifdef USE_U_FUNCTIONS | |
292 uu = vector<vector<Uns32T> >(H::m); | |
293 for( Uns32T aa=0 ; aa < H::m ; aa++ ) | 158 for( Uns32T aa=0 ; aa < H::m ; aa++ ) |
294 uu[aa] = vector<Uns32T>( H::k/2 ); | 159 H::uu[aa] = vector<Uns32T>( H::k/2 ); |
295 #else | 160 #else |
296 uu = vector<vector<Uns32T> >(H::L); | 161 H::uu = vector<vector<Uns32T> >(H::L); |
297 for( Uns32T aa=0 ; aa < H::L ; aa++ ) | 162 for( Uns32T aa=0 ; aa < H::L ; aa++ ) |
298 uu[aa] = vector<Uns32T>( H::k ); | 163 H::uu[aa] = vector<Uns32T>( H::k ); |
299 #endif | 164 #endif |
300 } | 165 } |
301 | 166 |
302 | 167 |
303 // Generate z ~ N(0,1) | 168 // Generate z ~ N(0,1) |
304 float G::randn(){ | 169 float H::randn(){ |
305 // Box-Muller | 170 // Box-Muller |
306 float x1, x2; | 171 float x1, x2; |
307 do{ | 172 do{ |
308 x1 = ranf(); | 173 x1 = ranf(); |
309 } while (x1 == 0); // cannot take log of 0 | 174 } while (x1 == 0); // cannot take log of 0 |
311 float z; | 176 float z; |
312 z = sqrtf(-2.0 * logf(x1)) * cosf(2.0 * M_PI * x2); | 177 z = sqrtf(-2.0 * logf(x1)) * cosf(2.0 * M_PI * x2); |
313 return z; | 178 return z; |
314 } | 179 } |
315 | 180 |
316 float G::ranf(){ | 181 float H::ranf(){ |
317 #ifdef MT19937 | 182 #ifdef MT19937 |
318 return (float) genrand_real2(); | 183 return (float) genrand_real2(); |
319 #else | 184 #else |
320 return (float)( (double)rand() / ((double)(RAND_MAX)+(double)(1)) ); | 185 return (float)( (double)rand() / ((double)(RAND_MAX)+(double)(1)) ); |
321 #endif | 186 #endif |
322 } | 187 } |
323 | 188 |
324 // range is 1..2^29 | 189 // range is 1..2^29 |
325 /* FIXME: that looks like an ... odd range. Still. */ | 190 /* FIXME: that looks like an ... odd range. Still. */ |
326 Uns32T H::__randr(){ | 191 Uns32T H::randr(){ |
327 #ifdef MT19937 | 192 #ifdef MT19937 |
328 return (Uns32T)((genrand_int32() >> 3) + 1); | 193 return (Uns32T)((genrand_int32() >> 3) + 1); |
329 #else | 194 #else |
330 return (Uns32T) ((rand() >> 2) + 1); | 195 return (Uns32T) ((rand() >> 2) + 1); |
331 #endif | 196 #endif |
332 } | 197 } |
333 | 198 |
334 G::~G(){ | 199 // Destruct hash tables |
335 Uns32T j,kk; | 200 H::~H(){ |
201 Uns32T i,j,kk; | |
336 #ifdef USE_U_FUNCTIONS | 202 #ifdef USE_U_FUNCTIONS |
337 for( j = 0 ; j < H::m ; j++ ){ | 203 for( j = 0 ; j < H::m ; j++ ){ |
338 for( kk = 0 ; kk < H::k/2 ; kk++ ) | 204 for( kk = 0 ; kk < H::k/2 ; kk++ ) |
339 delete[] A[j][kk]; | 205 delete[] A[j][kk]; |
340 delete[] A[j]; | 206 delete[] A[j]; |
356 #endif | 222 #endif |
357 | 223 |
358 for( j = 0 ; j < H::L ; j++ ) | 224 for( j = 0 ; j < H::L ; j++ ) |
359 delete[] g[j]; | 225 delete[] g[j]; |
360 delete[] g; | 226 delete[] g; |
361 delete lshHeader; | 227 for( j=0 ; j < H::L ; j++ ){ |
362 } | 228 delete[] H::r1[ j ]; |
229 delete[] H::r2[ j ]; | |
230 for(i = 0; i< H::N ; i++) | |
231 delete H::h[ j ][ i ]; | |
232 delete[] H::h[ j ]; | |
233 } | |
234 delete[] H::r1; | |
235 delete[] H::r2; | |
236 delete[] H::h; | |
237 } | |
238 | |
363 | 239 |
364 // Compute all hash functions for vector v | 240 // Compute all hash functions for vector v |
365 // #ifdef USE_U_FUNCTIONS use Combination of m \times h_i \in R^{(k/2) \times d} | 241 // #ifdef USE_U_FUNCTIONS use Combination of m \times h_i \in R^{(k/2) \times d} |
366 // to make L \times g_j functions \in Z^k | 242 // to make L \times g_j functions \in Z^k |
367 void G::compute_hash_functions(vector<float>& v){ // v \in R^d | 243 void H::compute_hash_functions(vector<float>& v){ // v \in R^d |
368 float iw = 1. / G::w; // hash bucket width | 244 float iw = 1. / H::w; // hash bucket width |
369 Uns32T aa, kk; | 245 Uns32T aa, kk; |
370 if( v.size() != H::d ) | 246 if( v.size() != H::d ) |
371 error("v.size != H::d","","compute_hash_functions"); // check input vector dimensionality | 247 error("v.size != H::d","","compute_hash_functions"); // check input vector dimensionality |
372 double tmp = 0; | 248 double tmp = 0; |
373 float *pA, *pb; | 249 float *pA, *pb; |
378 | 254 |
379 #ifdef USE_U_FUNCTIONS | 255 #ifdef USE_U_FUNCTIONS |
380 Uns32T bb; | 256 Uns32T bb; |
381 // Store m dot products to expand | 257 // Store m dot products to expand |
382 for( aa=0; aa < H::m ; aa++ ){ | 258 for( aa=0; aa < H::m ; aa++ ){ |
383 ui = uu[aa].begin(); | 259 ui = H::uu[aa].begin(); |
384 for( kk = 0 ; kk < H::k/2 ; kk++ ){ | 260 for( kk = 0 ; kk < H::k/2 ; kk++ ){ |
385 pb = *( G::b + aa ) + kk; | 261 pb = *( H::b + aa ) + kk; |
386 pA = * ( * ( G::A + aa ) + kk ); | 262 pA = * ( * ( H::A + aa ) + kk ); |
387 dd = H::d; | 263 dd = H::d; |
388 tmp = 0.; | 264 tmp = 0.; |
389 vi = v.begin(); | 265 vi = v.begin(); |
390 while( dd-- ) | 266 while( dd-- ) |
391 tmp += *pA++ * *vi++; // project | 267 tmp += *pA++ * *vi++; // project |
396 } | 272 } |
397 // Binomial combinations of functions u_{a,b} \in Z^{(k/2) \times d} | 273 // Binomial combinations of functions u_{a,b} \in Z^{(k/2) \times d} |
398 Uns32T j; | 274 Uns32T j; |
399 for( aa=0, j=0 ; aa < H::m-1 ; aa++ ) | 275 for( aa=0, j=0 ; aa < H::m-1 ; aa++ ) |
400 for( bb = aa + 1 ; bb < H::m ; bb++, j++ ){ | 276 for( bb = aa + 1 ; bb < H::m ; bb++, j++ ){ |
401 pg= *( G::g + j ); // L \times functions g_j(v) \in Z^k | 277 pg= *( H::g + j ); // L \times functions g_j(v) \in Z^k |
402 // u_1 \in Z^{(k/2) \times d} | 278 // u_1 \in Z^{(k/2) \times d} |
403 ui = uu[aa].begin(); | 279 ui = H::uu[aa].begin(); |
404 kk=H::k/2; | 280 kk=H::k/2; |
405 while( kk-- ) | 281 while( kk-- ) |
406 *pg++ = *ui++; // hash function g_j(v)=[x1 x2 ... x(k/2)]; xk \in Z | 282 *pg++ = *ui++; // hash function g_j(v)=[x1 x2 ... x(k/2)]; xk \in Z |
407 // u_2 \in Z^{(k/2) \times d} | 283 // u_2 \in Z^{(k/2) \times d} |
408 ui = uu[bb].begin(); | 284 ui = H::uu[bb].begin(); |
409 kk=H::k/2; | 285 kk=H::k/2; |
410 while( kk--) | 286 while( kk--) |
411 *pg++ = *ui++; // hash function g_j(v)=[x(k/2+1) x(k/2+2) ... xk]; xk \in Z | 287 *pg++ = *ui++; // hash function g_j(v)=[x(k/2+1) x(k/2+2) ... xk]; xk \in Z |
412 } | 288 } |
413 #else | 289 #else |
414 for( aa=0; aa < H::L ; aa++ ){ | 290 for( aa=0; aa < H::L ; aa++ ){ |
415 ui = uu[aa].begin(); | 291 ui = H::uu[aa].begin(); |
416 for( kk = 0 ; kk < H::k ; kk++ ){ | 292 for( kk = 0 ; kk < H::k ; kk++ ){ |
417 pb = *( G::b + aa ) + kk; | 293 pb = *( H::b + aa ) + kk; |
418 pA = * ( * ( G::A + aa ) + kk ); | 294 pA = * ( * ( H::A + aa ) + kk ); |
419 dd = H::d; | 295 dd = H::d; |
420 tmp = 0.; | 296 tmp = 0.; |
421 vi = v.begin(); | 297 vi = v.begin(); |
422 while( dd-- ) | 298 while( dd-- ) |
423 tmp += *pA++ * *vi++; // project | 299 tmp += *pA++ * *vi++; // project |
426 *ui++ = (Uns32T) (floor(tmp)); // floor | 302 *ui++ = (Uns32T) (floor(tmp)); // floor |
427 } | 303 } |
428 } | 304 } |
429 // Compute hash functions | 305 // Compute hash functions |
430 for( aa=0 ; aa < H::L ; aa++ ){ | 306 for( aa=0 ; aa < H::L ; aa++ ){ |
431 pg= *( G::g + aa ); // L \times functions g_j(v) \in Z^k | 307 pg= *( H::g + aa ); // L \times functions g_j(v) \in Z^k |
432 // u_1 \in Z^{k \times d} | 308 // u_1 \in Z^{k \times d} |
433 ui = uu[aa].begin(); | 309 ui = H::uu[aa].begin(); |
434 kk=H::k; | 310 kk=H::k; |
435 while( kk-- ) | 311 while( kk-- ) |
436 *pg++ = *ui++; // hash function g_j(v)=[x1 x2 ... xk]; xk \in Z | 312 *pg++ = *ui++; // hash function g_j(v)=[x1 x2 ... xk]; xk \in Z |
437 } | 313 } |
438 #endif | 314 #endif |
439 | 315 } |
440 } | 316 |
441 | 317 // make hash value \in Z |
318 void H::generate_hash_keys(Uns32T*g, Uns32T* r1, Uns32T* r2){ | |
319 H::t1 = computeProductModDefaultPrime( g, r1, H::k ) % H::N; | |
320 H::t2 = computeProductModDefaultPrime( g, r2, H::k ); | |
321 } | |
322 | |
323 #define CR_ASSERT(b){if(!(b)){fprintf(stderr, "ASSERT failed on line %d, file %s.\n", __LINE__, __FILE__); exit(1);}} | |
324 | |
325 // Computes (a.b) mod UH_PRIME_DEFAULT | |
326 inline Uns32T H::computeProductModDefaultPrime(Uns32T *a, Uns32T *b, IntT size){ | |
327 LongUns64T h = 0; | |
328 | |
329 for(IntT i = 0; i < size; i++){ | |
330 h = h + (LongUns64T)a[i] * (LongUns64T)b[i]; | |
331 h = (h & TWO_TO_32_MINUS_1) + 5 * (h >> 32); | |
332 if (h >= UH_PRIME_DEFAULT) { | |
333 h = h - UH_PRIME_DEFAULT; | |
334 } | |
335 CR_ASSERT(h < UH_PRIME_DEFAULT); | |
336 } | |
337 return h; | |
338 } | |
339 | |
340 Uns32T H::bucket_insert_point(bucket **pp){ | |
341 Uns32T collisionCount = 0; | |
342 if(!*pp){ | |
343 *pp = new bucket(); | |
344 #ifdef LSH_BLOCK_FULL_ROWS | |
345 (*pp)->t2 = 0; // Use t2 as a collision counter for the row | |
346 (*pp)->next = new bucket(); | |
347 #endif | |
348 } | |
349 #ifdef LSH_BLOCK_FULL_ROWS | |
350 collisionCount = (*pp)->t2; | |
351 if(collisionCount < H::C){ // Block if row is full | |
352 (*pp)->t2++; // Increment collision counter | |
353 pointCount++; | |
354 collisionCount++; | |
355 __bucket_insert_point((*pp)->next); // First bucket holds collision count | |
356 } | |
357 #else | |
358 pointCount++; | |
359 __bucket_insert_point(*pp); // No collision count storage | |
360 #endif | |
361 return collisionCount; | |
362 } | |
363 | |
364 void H::__bucket_insert_point(bucket* p){ | |
365 if(p->t2 == IFLAG){ // initialization flag, is it in the domain of t2? | |
366 p->t2 = H::t2; | |
367 bucketCount++; // Record start of new point-locale collision chain | |
368 p->snext = new sbucket(); | |
369 __sbucket_insert_point(p->snext); | |
370 return; | |
371 } | |
372 | |
373 if(p->t2 == H::t2){ | |
374 __sbucket_insert_point(p->snext); | |
375 return; | |
376 } | |
377 | |
378 if(p->next){ | |
379 __bucket_insert_point(p->next); | |
380 } | |
381 | |
382 else{ | |
383 p->next = new bucket(); | |
384 __bucket_insert_point(p->next); | |
385 } | |
386 | |
387 } | |
388 | |
389 void H::__sbucket_insert_point(sbucket* p){ | |
390 if(p->pointID==IFLAG){ | |
391 p->pointID = H::p; | |
392 return; | |
393 } | |
394 | |
395 // Search for pointID | |
396 if(p->snext){ | |
397 __sbucket_insert_point(p->snext); | |
398 } | |
399 else{ | |
400 // Make new point collision bucket at end of list | |
401 p->snext = new sbucket(); | |
402 __sbucket_insert_point(p->snext); | |
403 } | |
404 } | |
405 | |
406 inline bucket** H::get_bucket(int j){ | |
407 return *(h+j); | |
408 } | |
409 | |
410 // Interface to Locality Sensitive Hashing G | |
411 G::G(float ww, Uns32T kk,Uns32T mm, Uns32T dd, Uns32T NN, Uns32T CC, float rr): | |
412 H(kk,mm,dd,NN,CC,ww,rr), // constructor to initialize data structures | |
413 lshHeader(0), | |
414 calling_instance(0), | |
415 add_point_callback(0) | |
416 { | |
417 | |
418 } | |
419 | |
420 // Serialize from file LSH constructor | |
421 // Read parameters from database file | |
422 // Load the hash functions, close the database | |
423 // Optionally load the LSH tables into heap-allocated lists in core | |
424 G::G(char* filename, bool lshInCoreFlag): | |
425 H(), // default base-class constructor call delays data-structure initialization | |
426 lshHeader(0), | |
427 calling_instance(0), | |
428 add_point_callback(0) | |
429 { | |
430 int dbfid = unserialize_lsh_header(filename); | |
431 | |
432 H::initialize_lsh_functions(); // Base-class data-structure initialization | |
433 unserialize_lsh_functions(dbfid); // populate with on-disk hashfunction values | |
434 | |
435 // Format1 only needs unserializing if specifically requested | |
436 if(!(lshHeader->flags&O2_SERIAL_FILEFORMAT2) && lshInCoreFlag){ | |
437 unserialize_lsh_hashtables_format1(dbfid); | |
438 } | |
439 | |
440 // Format2 always needs unserializing | |
441 if(lshHeader->flags&O2_SERIAL_FILEFORMAT2 && lshInCoreFlag){ | |
442 unserialize_lsh_hashtables_format2(dbfid); | |
443 } | |
444 | |
445 close(dbfid);} | |
446 | |
447 G::~G(){ | |
448 delete lshHeader; | |
449 } | |
442 | 450 |
443 // single point insertion; inserted values are hash value and pointID | 451 // single point insertion; inserted values are hash value and pointID |
444 Uns32T G::insert_point(vector<float>& v, Uns32T pp){ | 452 Uns32T G::insert_point(vector<float>& v, Uns32T pp){ |
445 Uns32T collisionCount = 0; | 453 Uns32T collisionCount = 0; |
446 H::p = pp; | 454 H::p = pp; |
447 if(pp>G::maxp) | 455 if(pp>H::maxp) |
448 G::maxp=pp; // Store highest pointID in database | 456 H::maxp=pp; // Store highest pointID in database |
449 compute_hash_functions( v ); | 457 H::compute_hash_functions( v ); |
450 for(Uns32T j = 0 ; j < H::L ; j++ ){ // insertion | 458 for(Uns32T j = 0 ; j < H::L ; j++ ){ // insertion |
451 __generate_hash_keys( *( G::g + j ), *( H::r1 + j ), *( H::r2 + j ) ); | 459 H::generate_hash_keys( *( H::g + j ), *( H::r1 + j ), *( H::r2 + j ) ); |
452 collisionCount += bucket_insert_point( *(h + j) + t1 ); | 460 collisionCount += bucket_insert_point( *(h + j) + t1 ); |
453 } | 461 } |
454 return collisionCount; | 462 return collisionCount; |
455 } | 463 } |
456 | 464 |
464 | 472 |
465 // point retrieval routine | 473 // point retrieval routine |
466 void G::retrieve_point(vector<float>& v, Uns32T qpos, ReporterCallbackPtr add_point, void* caller){ | 474 void G::retrieve_point(vector<float>& v, Uns32T qpos, ReporterCallbackPtr add_point, void* caller){ |
467 calling_instance = caller; | 475 calling_instance = caller; |
468 add_point_callback = add_point; | 476 add_point_callback = add_point; |
469 compute_hash_functions( v ); | 477 H::compute_hash_functions( v ); |
470 for(Uns32T j = 0 ; j < H::L ; j++ ){ | 478 for(Uns32T j = 0 ; j < H::L ; j++ ){ |
471 __generate_hash_keys( *( G::g + j ), *( H::r1 + j ), *( H::r2 + j ) ); | 479 H::generate_hash_keys( *( H::g + j ), *( H::r1 + j ), *( H::r2 + j ) ); |
472 if( bucket* bPtr = *(__get_bucket(j) + get_t1()) ) | 480 if( bucket* bPtr = *(get_bucket(j) + get_t1()) ) |
473 #ifdef LSH_BLOCK_FULL_ROWS | 481 #ifdef LSH_BLOCK_FULL_ROWS |
474 bucket_chain_point( bPtr->next, qpos); | 482 bucket_chain_point( bPtr->next, qpos); |
475 #else | 483 #else |
476 bucket_chain_point( bPtr , qpos); | 484 bucket_chain_point( bPtr , qpos); |
477 #endif | 485 #endif |
517 // r2[0][0] r2[0][1] ... r2[0][k-1] | 525 // r2[0][0] r2[0][1] ... r2[0][k-1] |
518 // r2[1][0] r2[1][1] ... r2[1][k-1] | 526 // r2[1][0] r2[1][1] ... r2[1][k-1] |
519 // ... | 527 // ... |
520 // r2[L-1][0] r2[L-1][1] ... r2[L-1][k-1] | 528 // r2[L-1][0] r2[L-1][1] ... r2[L-1][k-1] |
521 // | 529 // |
530 // ******* HASHTABLES FORMAT1 (optimized for LSH_ON_DISK retrieval) ******* | |
522 // ---hash table 0: N x C x 8 --- | 531 // ---hash table 0: N x C x 8 --- |
523 // [t2 pointID][t2 pointID]...[t2 pointID] | 532 // [t2 pointID][t2 pointID]...[t2 pointID] |
524 // [t2 pointID][t2 pointID]...[t2 pointID] | 533 // [t2 pointID][t2 pointID]...[t2 pointID] |
525 // ... | 534 // ... |
526 // [t2 pointID][t2 pointID]...[t2 pointID] | 535 // [t2 pointID][t2 pointID]...[t2 pointID] |
537 // [t2 pointID][t2 pointID]...[t2 pointID] | 546 // [t2 pointID][t2 pointID]...[t2 pointID] |
538 // [t2 pointID][t2 pointID]...[t2 pointID] | 547 // [t2 pointID][t2 pointID]...[t2 pointID] |
539 // ... | 548 // ... |
540 // [t2 pointID][t2 pointID]...[t2 pointID] | 549 // [t2 pointID][t2 pointID]...[t2 pointID] |
541 // | 550 // |
551 // ******* HASHTABLES FORMAT2 (optimized for LSH_IN_CORE retrieval) ******* | |
552 // | |
553 // State machine controlled by regular expression. | |
554 // legend: | |
555 // | |
556 // O2_SERIAL_FLAGS_T1_BIT = 0x80000000U | |
557 // O2_SERIAL_FLAGS_T2_BIT = 0x40000000U | |
558 // O2_SERIAL_FLAGS_END_BIT = 0x20000000U | |
559 // | |
560 // T1(t1) - T1 hash token containing t1 hash key with O2_SERIAL_FLAGS_T1_BIT set (t1 range 0..2^29-1) | |
561 // T2 - T2 hash token with O2_SERIAL_FLAGS_T2_BIT set | |
562 // t2 - t2 hash key (range 1..2^32-6) | |
563 // p - point identifier (range 0..2^32-1) | |
564 // E - end hash table token with O2_SERIAL_FLAGS_END_BIT set | |
565 // {...} required arguments | |
566 // [...] optional arguments | |
567 // * - match zero or more occurrences | |
568 // + - match one or more occurrences | |
569 // {...}^L - repeat argument L times | |
570 // | |
571 // FORMAT2 Regular expression: | |
572 // { [T1(t1) T2 t2 p+ [T2 t2 p+]* ]* E. }^L | |
573 // | |
542 | 574 |
543 // Serial header constructors | 575 // Serial header constructors |
544 SerialHeader::SerialHeader(){;} | 576 SerialHeader::SerialHeader(){;} |
545 SerialHeader::SerialHeader(float W, Uns32T L, Uns32T N, Uns32T C, Uns32T k, Uns32T d, float r, Uns32T p, Uns32T FMT): | 577 SerialHeader::SerialHeader(float W, Uns32T L, Uns32T N, Uns32T C, Uns32T k, Uns32T d, float r, Uns32T p, Uns32T FMT): |
546 lshMagic(O2_SERIAL_MAGIC), | 578 lshMagic(O2_SERIAL_MAGIC), |
623 serialize_lsh_hashtables_format2(dbfid, !dbIsNew); | 655 serialize_lsh_hashtables_format2(dbfid, !dbIsNew); |
624 | 656 |
625 if(!dbIsNew){ | 657 if(!dbIsNew){ |
626 db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 1);// get database pointer | 658 db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 1);// get database pointer |
627 //serial_get_header(db); // read header | 659 //serial_get_header(db); // read header |
628 cout << "maxp = " << G::maxp << endl; | 660 cout << "maxp = " << H::maxp << endl; |
629 lshHeader->maxp=G::maxp; | 661 lshHeader->maxp=H::maxp; |
630 // Default to FILEFORMAT1 | 662 // Default to FILEFORMAT1 |
631 if(!(lshHeader->flags&O2_SERIAL_FILEFORMAT2)) | 663 if(!(lshHeader->flags&O2_SERIAL_FILEFORMAT2)) |
632 lshHeader->flags|=O2_SERIAL_FILEFORMAT2; | 664 lshHeader->flags|=O2_SERIAL_FILEFORMAT2; |
633 memcpy((char*)db, (char*)lshHeader, sizeof(SerialHeaderT)); | 665 memcpy((char*)db, (char*)lshHeader, sizeof(SerialHeaderT)); |
634 serial_munmap(db, O2_SERIAL_HEADER_SIZE); // drop mmap | 666 serial_munmap(db, O2_SERIAL_HEADER_SIZE); // drop mmap |
672 int G::serialize_lsh_hashfunctions(int fid){ | 704 int G::serialize_lsh_hashfunctions(int fid){ |
673 float* pf; | 705 float* pf; |
674 Uns32T *pu; | 706 Uns32T *pu; |
675 Uns32T x,y,z; | 707 Uns32T x,y,z; |
676 | 708 |
677 db = serial_mmap(fid, get_serial_hashtable_offset(), 1);// get database pointer | 709 char* db = serial_mmap(fid, get_serial_hashtable_offset(), 1);// get database pointer |
678 pf = get_serial_hashfunction_base(db); | 710 pf = get_serial_hashfunction_base(db); |
679 | 711 |
680 // HASH FUNCTIONS | 712 // HASH FUNCTIONS |
681 // Write the random projectors A[][][] | 713 // Write the random projectors A[][][] |
682 #ifdef USE_U_FUNCTIONS | 714 #ifdef USE_U_FUNCTIONS |
685 #else | 717 #else |
686 for( x = 0 ; x < H::L ; x++ ) | 718 for( x = 0 ; x < H::L ; x++ ) |
687 for( y = 0 ; y < H::k ; y++ ) | 719 for( y = 0 ; y < H::k ; y++ ) |
688 #endif | 720 #endif |
689 for( z = 0 ; z < d ; z++ ) | 721 for( z = 0 ; z < d ; z++ ) |
690 *pf++ = A[x][y][z]; | 722 *pf++ = H::A[x][y][z]; |
691 | 723 |
692 // Write the random biases b[][] | 724 // Write the random biases b[][] |
693 #ifdef USE_U_FUNCTIONS | 725 #ifdef USE_U_FUNCTIONS |
694 for( x = 0 ; x < H::m ; x++ ) | 726 for( x = 0 ; x < H::m ; x++ ) |
695 for( y = 0 ; y < H::k/2 ; y++ ) | 727 for( y = 0 ; y < H::k/2 ; y++ ) |
696 #else | 728 #else |
697 for( x = 0 ; x < H::L ; x++ ) | 729 for( x = 0 ; x < H::L ; x++ ) |
698 for( y = 0 ; y < H::k ; y++ ) | 730 for( y = 0 ; y < H::k ; y++ ) |
699 #endif | 731 #endif |
700 *pf++=b[x][y]; | 732 *pf++ = H::b[x][y]; |
701 | 733 |
702 pu = (Uns32T*)pf; | 734 pu = (Uns32T*)pf; |
703 | 735 |
704 // Write the Z projectors r1[][] | 736 // Write the Z projectors r1[][] |
705 for( x = 0 ; x < H::L ; x++) | 737 for( x = 0 ; x < H::L ; x++) |
706 for( y = 0 ; y < H::k ; y++) | 738 for( y = 0 ; y < H::k ; y++) |
707 *pu++ = r1[x][y]; | 739 *pu++ = H::r1[x][y]; |
708 | 740 |
709 // Write the Z projectors r2[][] | 741 // Write the Z projectors r2[][] |
710 for( x = 0 ; x < H::L ; x++) | 742 for( x = 0 ; x < H::L ; x++) |
711 for( y = 0; y < H::k ; y++) | 743 for( y = 0; y < H::k ; y++) |
712 *pu++ = r2[x][y]; | 744 *pu++ = H::r2[x][y]; |
713 | 745 |
714 serial_munmap(db, get_serial_hashtable_offset()); | 746 serial_munmap(db, get_serial_hashtable_offset()); |
715 return 1; | 747 return 1; |
716 } | 748 } |
717 | 749 |
950 | 982 |
951 // write a dummy byte at the last location | 983 // write a dummy byte at the last location |
952 if (write (dbfid, "", 1) != 1) | 984 if (write (dbfid, "", 1) != 1) |
953 error("write error", "", "write"); | 985 error("write error", "", "write"); |
954 | 986 |
955 db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 1); | 987 char* db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 1); |
956 | 988 |
957 memcpy (db, lshHeader, O2_SERIAL_HEADER_SIZE); | 989 memcpy (db, lshHeader, O2_SERIAL_HEADER_SIZE); |
958 | 990 |
959 serial_munmap(db, O2_SERIAL_HEADER_SIZE); | 991 serial_munmap(db, O2_SERIAL_HEADER_SIZE); |
960 | 992 |
964 | 996 |
965 return 1; | 997 return 1; |
966 } | 998 } |
967 | 999 |
968 char* G::serial_mmap(int dbfid, Uns32T memSize, Uns32T forWrite, off_t offset){ | 1000 char* G::serial_mmap(int dbfid, Uns32T memSize, Uns32T forWrite, off_t offset){ |
1001 char* db; | |
969 if(forWrite){ | 1002 if(forWrite){ |
970 if ((db = (char*) mmap(0, memSize, PROT_READ | PROT_WRITE, | 1003 if ((db = (char*) mmap(0, memSize, PROT_READ | PROT_WRITE, |
971 MAP_SHARED, dbfid, offset)) == (caddr_t) -1) | 1004 MAP_SHARED, dbfid, offset)) == (caddr_t) -1) |
972 error("mmap error in request for writable serialized database", "", "mmap"); | 1005 error("mmap error in request for writable serialized database", "", "mmap"); |
973 } | 1006 } |
1031 H::m = (Uns32T)( (1.0 + sqrt(1 + 8.0*(int)H::L)) / 2.0); | 1064 H::m = (Uns32T)( (1.0 + sqrt(1 + 8.0*(int)H::L)) / 2.0); |
1032 H::N = lshHeader->numRows; | 1065 H::N = lshHeader->numRows; |
1033 H::C = lshHeader->numCols; | 1066 H::C = lshHeader->numCols; |
1034 H::k = lshHeader->numFuns; | 1067 H::k = lshHeader->numFuns; |
1035 H::d = lshHeader->dataDim; | 1068 H::d = lshHeader->dataDim; |
1036 G::w = lshHeader->binWidth; | 1069 H::w = lshHeader->binWidth; |
1037 G::radius = lshHeader->radius; | 1070 H::radius = lshHeader->radius; |
1038 G::maxp = lshHeader->maxp; | 1071 H::maxp = lshHeader->maxp; |
1039 | 1072 |
1040 return dbfid; | 1073 return dbfid; |
1041 } | 1074 } |
1042 | 1075 |
1043 // unserialize the LSH parameters | 1076 // unserialize the LSH parameters |
1049 Uns32T* pu; | 1082 Uns32T* pu; |
1050 | 1083 |
1051 // Load the hash functions into core | 1084 // Load the hash functions into core |
1052 char* db = serial_mmap(dbfid, get_serial_hashtable_offset(), 0);// get database pointer again | 1085 char* db = serial_mmap(dbfid, get_serial_hashtable_offset(), 0);// get database pointer again |
1053 | 1086 |
1054 #ifdef USE_U_FUNCTIONS | |
1055 G::A = new float**[ H::m ]; // m x k x d random projectors | |
1056 G::b = new float*[ H::m ]; // m x k random biases | |
1057 #else | |
1058 G::A = new float**[ H::L ]; // m x k x d random projectors | |
1059 G::b = new float*[ H::L ]; // m x k random biases | |
1060 #endif | |
1061 G::g = new Uns32T*[ H::L ]; // L x k random projections | |
1062 assert(g&&A&&b); // failure | |
1063 | |
1064 pf = get_serial_hashfunction_base(db); | 1087 pf = get_serial_hashfunction_base(db); |
1065 | 1088 |
1066 #ifdef USE_U_FUNCTIONS | 1089 #ifdef USE_U_FUNCTIONS |
1067 for( j = 0 ; j < H::m ; j++ ){ // L functions gj(v) | 1090 for( j = 0 ; j < H::m ; j++ ){ // L functions gj(v) |
1068 G::A[j] = new float*[ H::k/2 ]; // k x d 2-stable distribution coefficients | |
1069 G::b[j] = new float[ H::k/2 ]; // bias | |
1070 assert( G::A[j] && G::b[j] ); // failure | |
1071 for( kk = 0 ; kk < H::k/2 ; kk++ ){ // Normally distributed hash functions | 1091 for( kk = 0 ; kk < H::k/2 ; kk++ ){ // Normally distributed hash functions |
1072 #else | 1092 #else |
1073 for( j = 0 ; j < H::L ; j++ ){ // L functions gj(v) | 1093 for( j = 0 ; j < H::L ; j++ ){ // L functions gj(v) |
1074 G::A[j] = new float*[ H::k ]; // k x d 2-stable distribution coefficients | |
1075 G::b[j] = new float[ H::k ]; // bias | |
1076 assert( G::A[j] && G::b[j] ); // failure | |
1077 for( kk = 0 ; kk < H::k ; kk++ ){ // Normally distributed hash functions | 1094 for( kk = 0 ; kk < H::k ; kk++ ){ // Normally distributed hash functions |
1078 #endif | 1095 #endif |
1079 G::A[j][kk] = new float[ H::d ]; | |
1080 assert( G::A[j][kk] ); // failure | |
1081 for(Uns32T i = 0 ; i < H::d ; i++ ) | 1096 for(Uns32T i = 0 ; i < H::d ; i++ ) |
1082 G::A[j][kk][i] = *pf++; // Normally distributed random vectors | 1097 H::A[j][kk][i] = *pf++; // Normally distributed random vectors |
1083 } | 1098 } |
1084 } | 1099 } |
1085 #ifdef USE_U_FUNCTIONS | 1100 #ifdef USE_U_FUNCTIONS |
1086 for( j = 0 ; j < H::m ; j++ ) // biases b | 1101 for( j = 0 ; j < H::m ; j++ ) // biases b |
1087 for( kk = 0 ; kk < H::k/2 ; kk++ ) | 1102 for( kk = 0 ; kk < H::k/2 ; kk++ ) |
1088 #else | 1103 #else |
1089 for( j = 0 ; j < H::L ; j++ ) // biases b | 1104 for( j = 0 ; j < H::L ; j++ ) // biases b |
1090 for( kk = 0 ; kk < H::k ; kk++ ) | 1105 for( kk = 0 ; kk < H::k ; kk++ ) |
1091 #endif | 1106 #endif |
1092 G::b[j][kk] = *pf++; | 1107 H::b[j][kk] = *pf++; |
1093 | |
1094 for( j = 0 ; j < H::L ; j++){ // 32-bit hash values, gj(v)=[x0 x1 ... xk-1] xk \in Z | |
1095 G::g[j] = new Uns32T[ H::k ]; | |
1096 assert( G::g[j] ); | |
1097 } | |
1098 | |
1099 | |
1100 H::__initialize_data_structures(); | |
1101 | 1108 |
1102 pu = (Uns32T*)pf; | 1109 pu = (Uns32T*)pf; |
1103 for( j = 0 ; j < H::L ; j++ ) // Z projectors r1 | 1110 for( j = 0 ; j < H::L ; j++ ) // Z projectors r1 |
1104 for( kk = 0 ; kk < H::k ; kk++ ) | 1111 for( kk = 0 ; kk < H::k ; kk++ ) |
1105 H::r1[j][kk] = *pu++; | 1112 H::r1[j][kk] = *pu++; |
1106 | 1113 |
1107 for( j = 0 ; j < H::L ; j++ ) // Z projectors r2 | 1114 for( j = 0 ; j < H::L ; j++ ) // Z projectors r2 |
1108 for( kk = 0 ; kk < H::k ; kk++ ) | 1115 for( kk = 0 ; kk < H::k ; kk++ ) |
1109 H::r2[j][kk] = *pu++; | 1116 H::r2[j][kk] = *pu++; |
1110 | 1117 |
1111 serial_munmap(db, get_serial_hashtable_offset()); | 1118 serial_munmap(db, get_serial_hashtable_offset()); |
1112 } | 1119 } |
1113 | 1120 |
1114 void G::unserialize_lsh_hashtables_format1(int fid){ | 1121 void G::unserialize_lsh_hashtables_format1(int fid){ |
1115 SerialElementT *pe, *pt; | 1122 SerialElementT *pe, *pt; |
1116 Uns32T x,y; | 1123 Uns32T x,y; |
1287 align_up(get_serial_hashtable_offset()+j*hashTableSize,get_page_logn())); | 1294 align_up(get_serial_hashtable_offset()+j*hashTableSize,get_page_logn())); |
1288 if(madvise(db, hashTableSize, MADV_RANDOM)<0) | 1295 if(madvise(db, hashTableSize, MADV_RANDOM)<0) |
1289 error("could not advise local hashtable memory","","madvise"); | 1296 error("could not advise local hashtable memory","","madvise"); |
1290 SerialElementT* pe = (SerialElementT*)db ; | 1297 SerialElementT* pe = (SerialElementT*)db ; |
1291 for(Uns32T qpos=0; qpos<vv.size(); qpos++){ | 1298 for(Uns32T qpos=0; qpos<vv.size(); qpos++){ |
1292 compute_hash_functions(vv[qpos]); | 1299 H::compute_hash_functions(vv[qpos]); |
1293 __generate_hash_keys(*(g+j),*(r1+j),*(r2+j)); | 1300 H::generate_hash_keys(*(g+j),*(r1+j),*(r2+j)); |
1294 serial_bucket_chain_point(pe+t1*lshHeader->numCols, qpos); // Point to correct row | 1301 serial_bucket_chain_point(pe+t1*lshHeader->numCols, qpos); // Point to correct row |
1295 } | 1302 } |
1296 serial_munmap(db, hashTableSize); // drop hashtable mmap | 1303 serial_munmap(db, hashTableSize); // drop hashtable mmap |
1297 } | 1304 } |
1298 serial_close(dbfid); | 1305 serial_close(dbfid); |
1311 | 1318 |
1312 // size of each hash table | 1319 // size of each hash table |
1313 Uns32T hashTableSize=sizeof(SerialElementT)*lshHeader->numRows*lshHeader->numCols; | 1320 Uns32T hashTableSize=sizeof(SerialElementT)*lshHeader->numRows*lshHeader->numCols; |
1314 calling_instance = caller; | 1321 calling_instance = caller; |
1315 add_point_callback = add_point; | 1322 add_point_callback = add_point; |
1316 compute_hash_functions(v); | 1323 H::compute_hash_functions(v); |
1317 for(Uns32T j=0; j<L; j++){ | 1324 for(Uns32T j=0; j<L; j++){ |
1318 // memory map a single hash table for random access | 1325 // memory map a single hash table for random access |
1319 char* db = serial_mmap(dbfid, hashTableSize, 0, | 1326 char* db = serial_mmap(dbfid, hashTableSize, 0, |
1320 align_up(get_serial_hashtable_offset()+j*hashTableSize,get_page_logn())); | 1327 align_up(get_serial_hashtable_offset()+j*hashTableSize,get_page_logn())); |
1321 if(madvise(db, hashTableSize, MADV_RANDOM)<0) | 1328 if(madvise(db, hashTableSize, MADV_RANDOM)<0) |
1322 error("could not advise local hashtable memory","","madvise"); | 1329 error("could not advise local hashtable memory","","madvise"); |
1323 SerialElementT* pe = (SerialElementT*)db ; | 1330 SerialElementT* pe = (SerialElementT*)db ; |
1324 __generate_hash_keys(*(g+j),*(r1+j),*(r2+j)); | 1331 H::generate_hash_keys(*(g+j),*(r1+j),*(r2+j)); |
1325 serial_bucket_chain_point(pe+t1*lshHeader->numCols, qpos); // Point to correct row | 1332 serial_bucket_chain_point(pe+t1*lshHeader->numCols, qpos); // Point to correct row |
1326 serial_munmap(db, hashTableSize); // drop hashtable mmap | 1333 serial_munmap(db, hashTableSize); // drop hashtable mmap |
1327 } | 1334 } |
1328 serial_close(dbfid); | 1335 serial_close(dbfid); |
1329 } | 1336 } |