comparison lshlib.h @ 292:d9a88cfd4ab6

Completed merge of lshlib back to current version of the trunk.
author mas01mc
date Tue, 29 Jul 2008 22:01:17 +0000
parents
children 9fd5340faffd
comparison
equal deleted inserted replaced
291:63ae0dfc1767 292:d9a88cfd4ab6
1 // lshlib.h - a library for locality sensitive hashtable insertion and retrieval
2 //
3 // Author: Michael Casey
4 // Copyright (c) 2008 Michael Casey, All Rights Reserved
5
6 /* GNU GENERAL PUBLIC LICENSE
7 Version 2, June 1991
8 See LICENSE.txt
9 */
10
11 #ifndef __LSHLIB_H
12 #define __LSHLIB_H
13
14 #include <vector>
15 #include <queue>
16 #include <stdio.h>
17 #include <stdlib.h>
18 #include <sys/types.h>
19 #include <sys/stat.h>
20 #include <sys/mman.h>
21 #include <fcntl.h>
22 #include <string.h>
23 #include <iostream>
24 #include <fstream>
25 #include <math.h>
26 #include <sys/time.h>
27 #include <assert.h>
28 #include <float.h>
29 #include <signal.h>
30 #include <time.h>
31 #include <limits.h>
32 #include <errno.h>
33 #ifdef MT19937
34 #include "mt19937/mt19937ar.h"
35 #endif
36
// Portable-ish integer type names used throughout the library.
#define IntT int
#define LongUns64T long long unsigned
#define Uns32T unsigned
#define Int32T int
#define BooleanT int
#define TRUE 1
#define FALSE 0

// A big number (>> max # of points)
#define INDEX_START_EMPTY 1000000000U

// 4294967291 = 2^32-5
#define UH_PRIME_DEFAULT 4294967291U

// 2^29
#define MAX_HASH_RND 536870912U

// 2^32-1
#define TWO_TO_32_MINUS_1 4294967295U

#define O2_SERIAL_VERSION 1 // Sync with SVN version
// Parenthesized so the macros expand safely inside larger expressions.
#define O2_SERIAL_HEADER_SIZE (sizeof(SerialHeaderT))
#define O2_SERIAL_ELEMENT_SIZE (sizeof(SerialElementT))
#define O2_SERIAL_MAX_TABLES (200)
#define O2_SERIAL_MAX_ROWS (1000000)
#define O2_SERIAL_MAX_COLS (1000)
#define O2_SERIAL_MAX_DIM (2000)
#define O2_SERIAL_MAX_FUNS (100)
#define O2_SERIAL_MAX_BINWIDTH (200)

// Flags for Serial Header
#define O2_SERIAL_FILEFORMAT1 (0x1U) // Optimize for on-disk search
#define O2_SERIAL_FILEFORMAT2 (0x2U) // Optimize for in-core search

// Flags for serialization fileformat2: use high 3 bits of Uns32T
#define O2_SERIAL_FLAGS_T1_BIT (0x80000000U)
#define O2_SERIAL_FLAGS_T2_BIT (0x40000000U)
#define O2_SERIAL_FLAGS_END_BIT (0x20000000U)

unsigned align_up(unsigned x, unsigned w);

// Bytes reserved for the serial header plus the hash-function parameters,
// rounded with align_up(..., get_page_logn()).
// The three terms presumably hold the float projection coefficients
// (tables x funs x dim), the float offsets (tables x funs) and two Uns32T
// random vectors per function. TODO confirm against serialize_lsh_hashfunctions.
// (Previously written with a stray "+ +", which only parsed as a unary plus.)
#define O2_SERIAL_FUNCTIONS_SIZE (align_up(sizeof(float) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS * O2_SERIAL_MAX_DIM \
                                           + sizeof(float) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS \
                                           + sizeof(Uns32T) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS * 2 \
                                           + O2_SERIAL_HEADER_SIZE, get_page_logn()))

// Upper bound on a serialized LSH file: every table fully populated plus the
// header/function region above.
#define O2_SERIAL_MAX_LSH_SIZE (O2_SERIAL_ELEMENT_SIZE * O2_SERIAL_MAX_TABLES \
                                * O2_SERIAL_MAX_ROWS * O2_SERIAL_MAX_COLS + O2_SERIAL_FUNCTIONS_SIZE)

// File identifier: the bytes 'o','2','l','s' packed little-end first.
#define O2_SERIAL_MAGIC (('o') | ('2' << 8) | ('l' << 16) | ('s' << 24))

using namespace std;

Uns32T get_page_logn();
91
92 // Disk table entry
93 typedef class SerialElement SerialElementT;
94 class SerialElement {
95 public:
96 Uns32T hashValue;
97 Uns32T pointID;
98
99 SerialElement(Uns32T h, Uns32T pID):
100 hashValue(h),
101 pointID(pID){}
102 };
103
104 // Disk header
105 typedef class SerialHeader SerialHeaderT;
106 class SerialHeader {
107 public:
108 Uns32T lshMagic; // unique identifier for file header
109 float binWidth; // hash-function bin width
110 Uns32T numTables; // number of hash tables in file
111 Uns32T numRows; // size of each hash table
112 Uns32T numCols; // max collisions in each hash table
113 Uns32T elementSize; // size of a hash bucket
114 Uns32T version; // version number of file format
115 Uns32T size; // total size of database (bytes)
116 Uns32T flags; // 32 bits of useful information
117 Uns32T dataDim; // vector dimensionality
118 Uns32T numFuns; // number of independent hash functions
119 float radius; // 32-bit floating point radius
120 Uns32T maxp; // number of unique IDs in the database
121 Uns32T value_14; // spare value
122 Uns32T value_15; // spare value
123 Uns32T value_16; // spare value
124
125 SerialHeader();
126 SerialHeader(float W, Uns32T L, Uns32T N, Uns32T C, Uns32T k, Uns32T d, float radius, Uns32T p, Uns32T FMT);
127
128 float get_binWidth(){return binWidth;}
129 Uns32T get_numTables(){return numTables;}
130 Uns32T get_numRows(){return numRows;}
131 Uns32T get_numCols(){return numCols;}
132 Uns32T get_elementSize(){return elementSize;}
133 Uns32T get_version(){return version;}
134 Uns32T get_flags(){return flags;}
135 Uns32T get_size(){return size;}
136 Uns32T get_dataDim(){return dataDim;}
137 Uns32T get_numFuns(){return numFuns;}
138 Uns32T get_maxp(){return maxp;}
139 };
140
// Sentinel meaning "unset" for point IDs and t2 keys.
#define IFLAG 0xFFFFFFFF

// Point-set collision bucket (sbucket).
// sbuckets form a collision chain that identifies PointIDs falling in the same locale.
// sbuckets are chained from a bucket containing the collision list's t2 identifier.
class sbucket {
  friend class bucket;
  friend class H;
  friend class G;

 public:
  class sbucket* snext; // next node in the collision chain (0 terminates)
  unsigned int pointID; // IFLAG until a point is stored

  sbucket(){
    snext=0;
    pointID=IFLAG;
  }

  // Destroys every node chained after this one.
  // The chain is walked iteratively; a recursive "delete snext" would use one
  // stack frame per node and could overflow the stack on long collision chains.
  ~sbucket(){
    sbucket* cur = snext;
    while(cur){
      sbucket* following = cur->snext;
      cur->snext = 0; // detach so cur's own destructor does not recurse
      delete cur;
      cur = following;
    }
  }

  sbucket* get_snext(){return snext;}
};
162
163 // bucket structure for a linked list of locales that collide with the same hash value t1
164 // different buckets represent different locales, collisions within a locale are chained
165 // in sbuckets
166 class bucket {
167 friend class H;
168 friend class G;
169 bucket* next;
170 sbucket* snext;
171 public:
172 unsigned int t2;
173 bucket(){
174 next=0;
175 snext=0;
176 t2=IFLAG;
177 }
178 ~bucket(){delete next;delete snext;}
179 bucket* get_next(){return next;}
180 };
181
182
183 // The hash_functions for locality-sensitive hashing
184 class H{
185 friend class G;
186 private:
187 bucket*** h; // hash functions
188 Uns32T** r1; // random ints for hashing
189 Uns32T** r2; // random ints for hashing
190 Uns32T t1; // first hash table key
191 Uns32T t2; // second hash table key
192 Uns32T P; // hash table prime number
193 bool use_u_functions; // flag to optimize computation of hashes
194 Uns32T bucketCount; // count of number of point buckets allocated
195 Uns32T pointCount; // count of number of points inserted
196
197 void __initialize_data_structures();
198 void __bucket_insert_point(bucket*);
199 void __sbucket_insert_point(sbucket*);
200 Uns32T __computeProductModDefaultPrime(Uns32T*,Uns32T*,IntT);
201 Uns32T __randr();
202 bucket** __get_bucket(int j);
203 void __generate_hash_keys(Uns32T*,Uns32T*,Uns32T*);
204 void error(const char* a, const char* b = "", const char *sysFunc = 0);
205
206 public:
207 Uns32T N; // num rows per table
208 Uns32T C; // num collision per row
209 Uns32T k; // num projections per hash function
210 Uns32T m; // ~sqrt num hash tables
211 Uns32T L; // L = m*(m-1)/2, conversely, m = (1 + sqrt(1 + 8.0*L)) / 2.0
212 Uns32T d; // dimensions
213 Uns32T p; // current point
214
215 H(){;}
216 H(Uns32T k, Uns32T m, Uns32T d, Uns32T N, Uns32T C);
217 ~H();
218
219
220 Uns32T bucket_insert_point(bucket**);
221
222 Uns32T get_t1(){return t1;}
223 Uns32T get_t2(){return t2;}
224
225 };
226
227
228 typedef void (*ReporterCallbackPtr)(void* objPtr, Uns32T pointID, Uns32T queryIndex, float squaredDistance);
229
230 // Interface for indexing and retrieval
231 class G: public H{
232 private:
233 float *** A; // m x k x d random projectors from R^d->R^k
234 float ** b; // m x k uniform additive constants
235 Uns32T ** g; // L x k random hash projections \in Z^k
236 float w; // width of hash slots (relative to normalized feature space)
237 float radius;// scaling coefficient for data (1./radius)
238 vector<vector<Uns32T> > uu; // Storage for m patial hash evaluations
239 Uns32T maxp; // highest pointID stored in database
240 void* calling_instance; // store calling object instance for member-function callback
241 void (*add_point_callback)(void*, Uns32T, Uns32T, float);
242
243 void initialize_partial_functions();
244
245 void get_lock(int fd, bool exclusive);
246 void release_lock(int fd);
247 int serial_create(char* filename, Uns32T FMT);
248 int serial_create(char* filename, float binWidth, Uns32T nTables, Uns32T nRows, Uns32T nCols, Uns32T k, Uns32T d, Uns32T FMT);
249 char* serial_mmap(int dbfid, Uns32T sz, Uns32T w, off_t offset = 0);
250 void serial_munmap(char* db, Uns32T N);
251 int serial_open(char* filename,int writeFlag);
252 void serial_close(int dbfid);
253
254 // Function to write hashfunctions to disk
255 int serialize_lsh_hashfunctions(int fid);
256
257 // Functions to write hashtables to disk in format1 (optimized for on-disk retrieval)
258 int serialize_lsh_hashtables_format1(int fid, int merge);
259 void serial_write_hashtable_row_format1(SerialElementT*& pe, bucket* h, Uns32T& colCount);
260 void serial_write_element_format1(SerialElementT*& pe, sbucket* sb, Uns32T t2, Uns32T& colCount);
261 void serial_merge_hashtable_row_format1(SerialElementT* pr, bucket* h, Uns32T& colCount);
262 void serial_merge_element_format1(SerialElementT* pe, sbucket* sb, Uns32T t2, Uns32T& colCount);
263 int serial_can_merge(Uns32T requestedFormat); // Test to see whether core and on-disk structures are compatible
264
265 // Functions to write hashtables to disk in format2 (optimized for in-core retrieval)
266 int serialize_lsh_hashtables_format2(int fid, int merge);
267 void serial_write_hashtable_row_format2(int fid, bucket* h, Uns32T& colCount);
268 void serial_write_element_format2(int fid, sbucket* sb, Uns32T& colCount);
269
270 // Functions to read serial header and hash functions (format1 and format2)
271 int unserialize_lsh_header(char* filename); // read lsh header from disk into core
272 void unserialize_lsh_functions(int fid); // read the lsh hash functions into core
273
274 // Functions to read hashtables in format1
275 void unserialize_lsh_hashtables_format1(int fid); // read FORMAT1 hash tables into core (disk format)
276 void unserialize_hashtable_row_format1(SerialElementT* pe, bucket** b); // read lsh hash table row into core
277
278 // Functions to read hashtables in format2
279 void unserialize_lsh_hashtables_format2(int fid); // read FORMAT2 hash tables into core (core format)
280 Uns32T unserialize_hashtable_row_format2(int fid, bucket** b); // read lsh hash table row into core
281
282 // Helper functions
283 void serial_print_header(Uns32T requestedFormat);
284 float* get_serial_hashfunction_base(char* db);
285 SerialElementT* get_serial_hashtable_base(char* db);
286 Uns32T get_serial_hashtable_offset(); // Size of SerialHeader + HashFunctions
287 SerialHeaderT* serial_get_header(char* db);
288 SerialHeaderT* lshHeader;
289
290 // Core Retrieval/Inspections Functions
291 void bucket_chain_point(bucket* p, Uns32T qpos);
292 void sbucket_chain_point(sbucket* p, Uns32T qpos);
293 void dump_hashtable_row(bucket* p);
294
295 // Serial (Format 1) Retrieval/Inspection Functions
296 void serial_bucket_chain_point(SerialElementT* pe, Uns32T qpos);
297 void serial_bucket_dump(SerialElementT* pe);
298
299 // Hash functions
300 void compute_hash_functions(vector<float>& v);
301 float randn();
302 float ranf();
303
304 char* db; // pointer to serialized structure
305
306 public:
307 G(char* lshFile, bool lshInCore = false); // unserialize constructor
308 G(float w, Uns32T k,Uns32T m, Uns32T d, Uns32T N, Uns32T C, float r); // core constructor
309 ~G();
310
311 Uns32T insert_point(vector<float>&, Uns32T pointID);
312 void insert_point_set(vector<vector<float> >& vv, Uns32T basePointID);
313
314 // point retrieval from core
315 void retrieve_point(vector<float>& v, Uns32T qpos, ReporterCallbackPtr, void* me=NULL);
316 // point set retrieval from core
317 void retrieve_point_set(vector<vector<float> >& vv, ReporterCallbackPtr, void* me=NULL);
318 // serial point set retrieval
319 void serial_retrieve_point_set(char* filename, vector<vector<float> >& vv, ReporterCallbackPtr, void* me=NULL);
320 // serial point retrieval
321 void serial_retrieve_point(char* filename, vector<float>& vv, Uns32T qpos, ReporterCallbackPtr, void* me=NULL);
322
323 void serialize(char* filename, Uns32T serialFormat = O2_SERIAL_FILEFORMAT1); // write hashfunctions and hashtables to disk
324
325 SerialHeaderT* get_lshHeader(){return lshHeader;}
326 float get_radius(){return radius;}
327 Uns32T get_maxp(){return maxp;}
328 void serial_dump_tables(char* filename);
329 float get_mean_collision_rate(){ return (float) pointCount / bucketCount ; }
330 };
331
// Public name for the LSH indexing/retrieval class G.
class G;
typedef G LSH;
333
334
335
336 #endif