Mercurial > hg > audiodb
comparison lshlib.h @ 292:d9a88cfd4ab6
Completed merge of lshlib back to current version of the trunk.
author | mas01mc |
---|---|
date | Tue, 29 Jul 2008 22:01:17 +0000 |
parents | |
children | 9fd5340faffd |
comparison
equal
deleted
inserted
replaced
291:63ae0dfc1767 | 292:d9a88cfd4ab6 |
---|---|
1 // lshlib.h - a library for locality sensitive hashtable insertion and retrieval | |
2 // | |
3 // Author: Michael Casey | |
4 // Copyright (c) 2008 Michael Casey, All Rights Reserved | |
5 | |
6 /* GNU GENERAL PUBLIC LICENSE | |
7 Version 2, June 1991 | |
8 See LICENSE.txt | |
9 */ | |
10 | |
11 #ifndef __LSHLIB_H | |
12 #define __LSHLIB_H | |
13 | |
14 #include <vector> | |
15 #include <queue> | |
16 #include <stdio.h> | |
17 #include <stdlib.h> | |
18 #include <sys/types.h> | |
19 #include <sys/stat.h> | |
20 #include <sys/mman.h> | |
21 #include <fcntl.h> | |
22 #include <string.h> | |
23 #include <iostream> | |
24 #include <fstream> | |
25 #include <math.h> | |
26 #include <sys/time.h> | |
27 #include <assert.h> | |
28 #include <float.h> | |
29 #include <signal.h> | |
30 #include <time.h> | |
31 #include <limits.h> | |
32 #include <errno.h> | |
33 #ifdef MT19937 | |
34 #include "mt19937/mt19937ar.h" | |
35 #endif | |
36 | |
// Basic type aliases (preprocessor macros) used throughout this header.
#define IntT        int
#define Int32T      int
#define Uns32T      unsigned
#define LongUns64T  long long unsigned
#define BooleanT    int

// Integer truth values (presumably the values stored in a BooleanT).
#define TRUE  1
#define FALSE 0
44 | |
// A big number (>> max # of points)
#define INDEX_START_EMPTY 1000000000U

// 2^32-5 = 4294967291
#define UH_PRIME_DEFAULT 4294967291U

// 2^29 = 536870912
#define MAX_HASH_RND 0x20000000U

// 2^32-1 = 4294967295
#define TWO_TO_32_MINUS_1 0xFFFFFFFFU
56 | |
// Serialized file-format version
#define O2_SERIAL_VERSION       1 // Sync with SVN version

// Sizes of the on-disk header and of one on-disk table element
#define O2_SERIAL_HEADER_SIZE   sizeof(SerialHeaderT)
#define O2_SERIAL_ELEMENT_SIZE  sizeof(SerialElementT)

// Upper bounds used when sizing the serialized structure
#define O2_SERIAL_MAX_TABLES    (200)
#define O2_SERIAL_MAX_ROWS      (1000000)
#define O2_SERIAL_MAX_COLS      (1000)
#define O2_SERIAL_MAX_DIM       (2000)
#define O2_SERIAL_MAX_FUNS      (100)
#define O2_SERIAL_MAX_BINWIDTH  (200)

// Flags for Serial Header
#define O2_SERIAL_FILEFORMAT1   (0x1U) // Optimize for on-disk search
#define O2_SERIAL_FILEFORMAT2   (0x2U) // Optimize for in-core search

// Flags for serialization fileformat2: use high 3 bits of Uns32T
#define O2_SERIAL_FLAGS_T1_BIT  (0x80000000U)
#define O2_SERIAL_FLAGS_T2_BIT  (0x40000000U)
#define O2_SERIAL_FLAGS_END_BIT (0x20000000U)
75 | |
// Declared here, defined elsewhere; used by O2_SERIAL_FUNCTIONS_SIZE below.
unsigned align_up(unsigned x, unsigned w);

// Size of the header-plus-hash-functions region at the maximum supported
// configuration, passed through align_up() with get_page_logn() as the second
// argument. The three array terms presumably correspond to the float
// projections, the float offsets and the two Uns32T random-integer tables
// (TODO confirm against the serializer).
#define O2_SERIAL_FUNCTIONS_SIZE (align_up( \
      sizeof(float)  * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS * O2_SERIAL_MAX_DIM \
    + sizeof(float)  * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS \
    + sizeof(Uns32T) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS * 2 \
    + O2_SERIAL_HEADER_SIZE, \
    get_page_logn()))

// Maximum size of a serialized structure: all table elements plus the
// header/hash-function region.
#define O2_SERIAL_MAX_LSH_SIZE (O2_SERIAL_ELEMENT_SIZE * O2_SERIAL_MAX_TABLES \
    * O2_SERIAL_MAX_ROWS * O2_SERIAL_MAX_COLS + O2_SERIAL_FUNCTIONS_SIZE)

// File magic number: the characters 'o','2','l','s' packed low byte first.
#define O2_SERIAL_MAGIC (('o') | ('2' << 8) | ('l' << 16) | ('s' << 24))
87 | |
88 using namespace std; | |
89 | |
90 Uns32T get_page_logn(); | |
91 | |
92 // Disk table entry | |
93 typedef class SerialElement SerialElementT; | |
94 class SerialElement { | |
95 public: | |
96 Uns32T hashValue; | |
97 Uns32T pointID; | |
98 | |
99 SerialElement(Uns32T h, Uns32T pID): | |
100 hashValue(h), | |
101 pointID(pID){} | |
102 }; | |
103 | |
104 // Disk header | |
105 typedef class SerialHeader SerialHeaderT; | |
106 class SerialHeader { | |
107 public: | |
108 Uns32T lshMagic; // unique identifier for file header | |
109 float binWidth; // hash-function bin width | |
110 Uns32T numTables; // number of hash tables in file | |
111 Uns32T numRows; // size of each hash table | |
112 Uns32T numCols; // max collisions in each hash table | |
113 Uns32T elementSize; // size of a hash bucket | |
114 Uns32T version; // version number of file format | |
115 Uns32T size; // total size of database (bytes) | |
116 Uns32T flags; // 32 bits of useful information | |
117 Uns32T dataDim; // vector dimensionality | |
118 Uns32T numFuns; // number of independent hash functions | |
119 float radius; // 32-bit floating point radius | |
120 Uns32T maxp; // number of unique IDs in the database | |
121 Uns32T value_14; // spare value | |
122 Uns32T value_15; // spare value | |
123 Uns32T value_16; // spare value | |
124 | |
125 SerialHeader(); | |
126 SerialHeader(float W, Uns32T L, Uns32T N, Uns32T C, Uns32T k, Uns32T d, float radius, Uns32T p, Uns32T FMT); | |
127 | |
128 float get_binWidth(){return binWidth;} | |
129 Uns32T get_numTables(){return numTables;} | |
130 Uns32T get_numRows(){return numRows;} | |
131 Uns32T get_numCols(){return numCols;} | |
132 Uns32T get_elementSize(){return elementSize;} | |
133 Uns32T get_version(){return version;} | |
134 Uns32T get_flags(){return flags;} | |
135 Uns32T get_size(){return size;} | |
136 Uns32T get_dataDim(){return dataDim;} | |
137 Uns32T get_numFuns(){return numFuns;} | |
138 Uns32T get_maxp(){return maxp;} | |
139 }; | |
140 | |
// Marker stored in pointID / t2 fields that have not been assigned yet.
#define IFLAG 0xFFFFFFFF

// Point-set collision bucket (sbucket).
// sbuckets form a collision chain that identifies PointIDs falling in the same locale.
// sbuckets are chained from a bucket containing the collision list's t2 identifier
class sbucket {
  friend class bucket;
  friend class H;
  friend class G;

public:
  class sbucket* snext;   // next element of the collision chain, 0 at the end
  unsigned int pointID;   // IFLAG until a point is stored

  sbucket(): snext(0), pointID(IFLAG) {}

  // Destroys this node and every node chained after it.
  // The chain is torn down with a loop rather than by recursive `delete snext`,
  // so the stack depth stays constant however long the chain is.
  ~sbucket(){
    sbucket* node = snext;
    snext = 0;
    while(node){
      sbucket* following = node->snext;
      node->snext = 0;      // detach first so deleting node does not recurse
      delete node;
      node = following;
    }
  }

  // Returns the next element of the chain (0 if this is the last one).
  sbucket* get_snext() const {return snext;}
};
162 | |
163 // bucket structure for a linked list of locales that collide with the same hash value t1 | |
164 // different buckets represent different locales, collisions within a locale are chained | |
165 // in sbuckets | |
166 class bucket { | |
167 friend class H; | |
168 friend class G; | |
169 bucket* next; | |
170 sbucket* snext; | |
171 public: | |
172 unsigned int t2; | |
173 bucket(){ | |
174 next=0; | |
175 snext=0; | |
176 t2=IFLAG; | |
177 } | |
178 ~bucket(){delete next;delete snext;} | |
179 bucket* get_next(){return next;} | |
180 }; | |
181 | |
182 | |
183 // The hash_functions for locality-sensitive hashing | |
184 class H{ | |
185 friend class G; | |
186 private: | |
187 bucket*** h; // hash functions | |
188 Uns32T** r1; // random ints for hashing | |
189 Uns32T** r2; // random ints for hashing | |
190 Uns32T t1; // first hash table key | |
191 Uns32T t2; // second hash table key | |
192 Uns32T P; // hash table prime number | |
193 bool use_u_functions; // flag to optimize computation of hashes | |
194 Uns32T bucketCount; // count of number of point buckets allocated | |
195 Uns32T pointCount; // count of number of points inserted | |
196 | |
197 void __initialize_data_structures(); | |
198 void __bucket_insert_point(bucket*); | |
199 void __sbucket_insert_point(sbucket*); | |
200 Uns32T __computeProductModDefaultPrime(Uns32T*,Uns32T*,IntT); | |
201 Uns32T __randr(); | |
202 bucket** __get_bucket(int j); | |
203 void __generate_hash_keys(Uns32T*,Uns32T*,Uns32T*); | |
204 void error(const char* a, const char* b = "", const char *sysFunc = 0); | |
205 | |
206 public: | |
207 Uns32T N; // num rows per table | |
208 Uns32T C; // num collision per row | |
209 Uns32T k; // num projections per hash function | |
210 Uns32T m; // ~sqrt num hash tables | |
211 Uns32T L; // L = m*(m-1)/2, conversely, m = (1 + sqrt(1 + 8.0*L)) / 2.0 | |
212 Uns32T d; // dimensions | |
213 Uns32T p; // current point | |
214 | |
215 H(){;} | |
216 H(Uns32T k, Uns32T m, Uns32T d, Uns32T N, Uns32T C); | |
217 ~H(); | |
218 | |
219 | |
220 Uns32T bucket_insert_point(bucket**); | |
221 | |
222 Uns32T get_t1(){return t1;} | |
223 Uns32T get_t2(){return t2;} | |
224 | |
225 }; | |
226 | |
227 | |
228 typedef void (*ReporterCallbackPtr)(void* objPtr, Uns32T pointID, Uns32T queryIndex, float squaredDistance); | |
229 | |
230 // Interface for indexing and retrieval | |
231 class G: public H{ | |
232 private: | |
233 float *** A; // m x k x d random projectors from R^d->R^k | |
234 float ** b; // m x k uniform additive constants | |
235 Uns32T ** g; // L x k random hash projections \in Z^k | |
236 float w; // width of hash slots (relative to normalized feature space) | |
237 float radius;// scaling coefficient for data (1./radius) | |
238 vector<vector<Uns32T> > uu; // Storage for m patial hash evaluations | |
239 Uns32T maxp; // highest pointID stored in database | |
240 void* calling_instance; // store calling object instance for member-function callback | |
241 void (*add_point_callback)(void*, Uns32T, Uns32T, float); | |
242 | |
243 void initialize_partial_functions(); | |
244 | |
245 void get_lock(int fd, bool exclusive); | |
246 void release_lock(int fd); | |
247 int serial_create(char* filename, Uns32T FMT); | |
248 int serial_create(char* filename, float binWidth, Uns32T nTables, Uns32T nRows, Uns32T nCols, Uns32T k, Uns32T d, Uns32T FMT); | |
249 char* serial_mmap(int dbfid, Uns32T sz, Uns32T w, off_t offset = 0); | |
250 void serial_munmap(char* db, Uns32T N); | |
251 int serial_open(char* filename,int writeFlag); | |
252 void serial_close(int dbfid); | |
253 | |
254 // Function to write hashfunctions to disk | |
255 int serialize_lsh_hashfunctions(int fid); | |
256 | |
257 // Functions to write hashtables to disk in format1 (optimized for on-disk retrieval) | |
258 int serialize_lsh_hashtables_format1(int fid, int merge); | |
259 void serial_write_hashtable_row_format1(SerialElementT*& pe, bucket* h, Uns32T& colCount); | |
260 void serial_write_element_format1(SerialElementT*& pe, sbucket* sb, Uns32T t2, Uns32T& colCount); | |
261 void serial_merge_hashtable_row_format1(SerialElementT* pr, bucket* h, Uns32T& colCount); | |
262 void serial_merge_element_format1(SerialElementT* pe, sbucket* sb, Uns32T t2, Uns32T& colCount); | |
263 int serial_can_merge(Uns32T requestedFormat); // Test to see whether core and on-disk structures are compatible | |
264 | |
265 // Functions to write hashtables to disk in format2 (optimized for in-core retrieval) | |
266 int serialize_lsh_hashtables_format2(int fid, int merge); | |
267 void serial_write_hashtable_row_format2(int fid, bucket* h, Uns32T& colCount); | |
268 void serial_write_element_format2(int fid, sbucket* sb, Uns32T& colCount); | |
269 | |
270 // Functions to read serial header and hash functions (format1 and format2) | |
271 int unserialize_lsh_header(char* filename); // read lsh header from disk into core | |
272 void unserialize_lsh_functions(int fid); // read the lsh hash functions into core | |
273 | |
274 // Functions to read hashtables in format1 | |
275 void unserialize_lsh_hashtables_format1(int fid); // read FORMAT1 hash tables into core (disk format) | |
276 void unserialize_hashtable_row_format1(SerialElementT* pe, bucket** b); // read lsh hash table row into core | |
277 | |
278 // Functions to read hashtables in format2 | |
279 void unserialize_lsh_hashtables_format2(int fid); // read FORMAT2 hash tables into core (core format) | |
280 Uns32T unserialize_hashtable_row_format2(int fid, bucket** b); // read lsh hash table row into core | |
281 | |
282 // Helper functions | |
283 void serial_print_header(Uns32T requestedFormat); | |
284 float* get_serial_hashfunction_base(char* db); | |
285 SerialElementT* get_serial_hashtable_base(char* db); | |
286 Uns32T get_serial_hashtable_offset(); // Size of SerialHeader + HashFunctions | |
287 SerialHeaderT* serial_get_header(char* db); | |
288 SerialHeaderT* lshHeader; | |
289 | |
290 // Core Retrieval/Inspections Functions | |
291 void bucket_chain_point(bucket* p, Uns32T qpos); | |
292 void sbucket_chain_point(sbucket* p, Uns32T qpos); | |
293 void dump_hashtable_row(bucket* p); | |
294 | |
295 // Serial (Format 1) Retrieval/Inspection Functions | |
296 void serial_bucket_chain_point(SerialElementT* pe, Uns32T qpos); | |
297 void serial_bucket_dump(SerialElementT* pe); | |
298 | |
299 // Hash functions | |
300 void compute_hash_functions(vector<float>& v); | |
301 float randn(); | |
302 float ranf(); | |
303 | |
304 char* db; // pointer to serialized structure | |
305 | |
306 public: | |
307 G(char* lshFile, bool lshInCore = false); // unserialize constructor | |
308 G(float w, Uns32T k,Uns32T m, Uns32T d, Uns32T N, Uns32T C, float r); // core constructor | |
309 ~G(); | |
310 | |
311 Uns32T insert_point(vector<float>&, Uns32T pointID); | |
312 void insert_point_set(vector<vector<float> >& vv, Uns32T basePointID); | |
313 | |
314 // point retrieval from core | |
315 void retrieve_point(vector<float>& v, Uns32T qpos, ReporterCallbackPtr, void* me=NULL); | |
316 // point set retrieval from core | |
317 void retrieve_point_set(vector<vector<float> >& vv, ReporterCallbackPtr, void* me=NULL); | |
318 // serial point set retrieval | |
319 void serial_retrieve_point_set(char* filename, vector<vector<float> >& vv, ReporterCallbackPtr, void* me=NULL); | |
320 // serial point retrieval | |
321 void serial_retrieve_point(char* filename, vector<float>& vv, Uns32T qpos, ReporterCallbackPtr, void* me=NULL); | |
322 | |
323 void serialize(char* filename, Uns32T serialFormat = O2_SERIAL_FILEFORMAT1); // write hashfunctions and hashtables to disk | |
324 | |
325 SerialHeaderT* get_lshHeader(){return lshHeader;} | |
326 float get_radius(){return radius;} | |
327 Uns32T get_maxp(){return maxp;} | |
328 void serial_dump_tables(char* filename); | |
329 float get_mean_collision_rate(){ return (float) pointCount / bucketCount ; } | |
330 }; | |
331 | |
332 typedef class G LSH; | |
333 | |
334 | |
335 | |
336 #endif |