mas01mc@292: // lshlib.h - a library for locality sensitive hashtable insertion and retrieval mas01mc@292: // mas01mc@292: // Author: Michael Casey mas01mc@292: // Copyright (c) 2008 Michael Casey, All Rights Reserved mas01mc@292: mas01mc@292: /* GNU GENERAL PUBLIC LICENSE mas01mc@292: Version 2, June 1991 mas01mc@292: See LICENSE.txt mas01mc@292: */ mas01mc@292: mas01mc@292: #ifndef __LSHLIB_H mas01mc@292: #define __LSHLIB_H mas01mc@292: mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #include mas01mc@764: #ifdef MT19937 mas01mc@764: #include "mt19937/mt19937ar.h" mas01mc@764: #endif mas01mc@764: #include "multiprobe.h" mas01mc@292: mas01mc@292: #define IntT int mas01mc@292: #define LongUns64T long long unsigned mas01mc@292: #define Uns32T unsigned mas01mc@292: #define Int32T int mas01mc@292: #define BooleanT int mas01mc@764: #define TRUE 1 mas01mc@764: #define FALSE 0 mas01mc@292: mas01mc@292: // A big number (>> max # of points) mas01mc@292: #define INDEX_START_EMPTY 1000000000U mas01mc@292: mas01mc@292: // 4294967291 = 2^32-5 mas01mc@292: #define UH_PRIME_DEFAULT 4294967291U mas01mc@292: mas01mc@292: // 2^29 mas01mc@292: #define MAX_HASH_RND 536870912U mas01mc@292: mas01mc@292: // 2^32-1 mas01mc@292: #define TWO_TO_32_MINUS_1 4294967295U mas01mc@292: mas01mc@292: #define O2_SERIAL_VERSION 1 // Sync with SVN version mas01mc@292: #define O2_SERIAL_HEADER_SIZE sizeof(SerialHeaderT) mas01mc@292: #define O2_SERIAL_ELEMENT_SIZE sizeof(SerialElementT) mas01mc@292: #define O2_SERIAL_MAX_TABLES (200) mas01mc@324: #define O2_SERIAL_MAX_ROWS (1000000000) mas01mc@324: #define O2_SERIAL_MAX_COLS (1000000) mas01mc@464: #define O2_SERIAL_MAX_DIM (20000) mas01mc@292: #define O2_SERIAL_MAX_FUNS (100) mas01mc@292: #define O2_SERIAL_MAX_BINWIDTH (200) mas01mc@296: #define O2_SERIAL_MAXFILESIZE (4000000000UL) mas01mc@292: mas01mc@292: // Flags for Serial Header mas01mc@324: #define O2_SERIAL_FILEFORMAT1 (0x1U) // Optimize disk format for on-disk search mas01mc@324: #define O2_SERIAL_FILEFORMAT2 (0x2U) // Optimize disk format for in-core search mas01mc@324: #define O2_SERIAL_COREFORMAT1 (0x4U) mas01mc@324: #define O2_SERIAL_COREFORMAT2 (0x8U) mas01mc@292: mas01mc@292: // Flags for serialization fileformat2: use high 3 bits of Uns32T mas01mc@324: #define O2_SERIAL_TOKEN_T1 (0xFFFFFFFCU) mas01mc@306: #define O2_SERIAL_TOKEN_T2 (0xFFFFFFFDU) mas01mc@306: #define O2_SERIAL_TOKEN_ENDTABLE (0xFFFFFFFEU) mas01mc@292: mas01mc@324: #define O2_INDEX_MAXSTR (256) mas01mc@309: mas01mc@764: unsigned align_up(unsigned x, unsigned w); mas01mc@292: mas01mc@292: #define O2_SERIAL_FUNCTIONS_SIZE (align_up(sizeof(float) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS * O2_SERIAL_MAX_DIM \ mas01mc@292: + sizeof(float) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS + \ mas01mc@292: + sizeof(Uns32T) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS * 2 \ mas01mc@292: + O2_SERIAL_HEADER_SIZE,get_page_logn())) mas01mc@292: mas01mc@292: #define O2_SERIAL_MAX_LSH_SIZE (O2_SERIAL_ELEMENT_SIZE * O2_SERIAL_MAX_TABLES \ mas01mc@292: * O2_SERIAL_MAX_ROWS * O2_SERIAL_MAX_COLS + O2_SERIAL_FUNCTIONS_SIZE) mas01mc@292: mas01mc@292: #define O2_SERIAL_MAGIC ('o'|'2'<<8|'l'<<16|'s'<<24) mas01mc@292: mas01mc@340: #define WRITE_UNS32(VAL, TOKENSTR) if( fwrite(VAL, sizeof(Uns32T), 1, dbFile) != 1 ){\ mas01mc@340: fclose(dbFile);error("write error in serial_write_format2",TOKENSTR);} mas01mc@340: mas01mc@514: //#define LSH_DUMP_CORE_TABLES // set to dump hashtables on load mas01mc@764: //#define _LSH_DEBUG_ // turn on debugging information mas01mc@514: //#define USE_U_FUNCTIONS // set to use partial hashfunction re-use mas01mc@514: mas01mc@514: // Backward-compatible CORE ARRAY lsh index mas01mc@514: #define LSH_CORE_ARRAY // Set to use arrays for hashtables rather than linked-lists mas01mc@514: #define LSH_LIST_HEAD_COUNTERS // Enable counters in hashtable list heads mas01mc@514: mas01mc@514: // Critical path logic mas01mc@514: #if defined LSH_CORE_ARRAY && !defined LSH_LIST_HEAD_COUNTERS mas01mc@514: #define LSH_LIST_HEAD_COUNTERS mas01mc@514: #endif mas01mc@514: mas01mc@514: #define LSH_CORE_ARRAY_BIT (0x80000000) // LSH_CORE_ARRAY test bit for list head mas01mc@514: mas01mc@764: #ifndef LSH_MULTI_PROBE_COUNT mas01mc@764: #define LSH_MULTI_PROBE_COUNT 1 // How many adjacent hash-buckets to probe in LSH retrieval mas01mc@764: #endif mas01mc@764: mas01mc@292: Uns32T get_page_logn(); mas01mc@292: mas01mc@764: using namespace std; mas01mc@764: mas01mc@292: // Disk table entry mas01mc@292: typedef class SerialElement SerialElementT; mas01mc@292: class SerialElement { mas01mc@292: public: mas01mc@292: Uns32T hashValue; mas01mc@292: Uns32T pointID; mas01mc@292: mas01mc@292: SerialElement(Uns32T h, Uns32T pID): mas01mc@292: hashValue(h), mas01mc@292: pointID(pID){} mas01mc@292: }; mas01mc@292: mas01mc@292: // Disk header mas01mc@292: typedef class SerialHeader SerialHeaderT; mas01mc@292: class SerialHeader { mas01mc@292: public: mas01mc@292: Uns32T lshMagic; // unique identifier for file header mas01mc@292: float binWidth; // hash-function bin width mas01mc@292: Uns32T numTables; // number of hash tables in file mas01mc@292: Uns32T numRows; // size of each hash table mas01mc@292: Uns32T numCols; // max collisions in each hash table mas01mc@292: Uns32T elementSize; // size of a hash bucket mas01mc@292: Uns32T version; // version number of file format mas01mc@292: Uns32T size; // total size of database (bytes) mas01mc@292: Uns32T flags; // 32 bits of useful information mas01mc@292: Uns32T dataDim; // vector dimensionality mas01mc@292: Uns32T numFuns; // number of independent hash functions mas01mc@292: float radius; // 32-bit floating point radius mas01mc@292: Uns32T maxp; // number of unique IDs in the database mas01mc@296: unsigned long long size_long; // long version of size mas01mc@296: Uns32T pointCount; // number of points in the database mas01mc@292: mas01mc@292: SerialHeader(); mas01mc@296: SerialHeader(float W, Uns32T L, Uns32T N, Uns32T C, Uns32T k, Uns32T d, float radius, Uns32T p, Uns32T FMT, Uns32T pointCount); mas01mc@292: mas01mc@292: float get_binWidth(){return binWidth;} mas01mc@292: Uns32T get_numTables(){return numTables;} mas01mc@292: Uns32T get_numRows(){return numRows;} mas01mc@292: Uns32T get_numCols(){return numCols;} mas01mc@292: Uns32T get_elementSize(){return elementSize;} mas01mc@292: Uns32T get_version(){return version;} mas01mc@292: Uns32T get_flags(){return flags;} mas01mc@296: unsigned long long get_size(){return size_long;} mas01mc@292: Uns32T get_dataDim(){return dataDim;} mas01mc@292: Uns32T get_numFuns(){return numFuns;} mas01mc@292: Uns32T get_maxp(){return maxp;} mas01mc@296: Uns32T get_pointCount(){return pointCount;} mas01mc@292: }; mas01mc@292: mas01mc@292: #define IFLAG 0xFFFFFFFF mas01mc@292: mas01mc@292: // Point-set collision bucket (sbucket). mas01mc@292: // sbuckets form a collision chain that identifies PointIDs falling in the same locale. mas01mc@292: // sbuckets are chained from a bucket containing the collision list's t2 identifier mas01mc@292: class sbucket { mas01mc@292: friend class bucket; mas01mc@292: friend class H; mas01mc@292: friend class G; mas01mc@292: mas01mc@292: public: mas01mc@292: class sbucket* snext; mas01mc@292: unsigned int pointID; mas01mc@292: mas01mc@292: sbucket(){ mas01mc@292: snext=0; mas01mc@292: pointID=IFLAG; mas01mc@292: } mas01mc@292: ~sbucket(){delete snext;} mas01mc@292: sbucket* get_snext(){return snext;} mas01mc@292: }; mas01mc@292: mas01mc@292: // bucket structure for a linked list of locales that collide with the same hash value t1 mas01mc@292: // different buckets represent different locales, collisions within a locale are chained mas01mc@292: // in sbuckets mas01mc@292: class bucket { mas01mc@292: friend class H; mas01mc@292: friend class G; mas01mc@292: bucket* next; mas01mc@344: union { mas01mc@344: sbucket* ptr; mas01mc@344: Uns32T numBuckets; mas01mc@344: } snext; mas01mc@292: public: mas01mc@292: unsigned int t2; mas01mc@292: bucket(){ mas01mc@292: next=0; mas01mc@344: snext.ptr=0; mas01mc@292: t2=IFLAG; mas01mc@292: } mas01mc@344: ~bucket(){delete next;delete snext.ptr;} mas01mc@292: bucket* get_next(){return next;} mas01mc@292: }; mas01mc@292: mas01mc@292: mas01mc@292: // The hash_functions for locality-sensitive hashing mas01mc@292: class H{ mas01mc@292: friend class G; mas01mc@292: private: mas01mc@293: mas01mc@293: float *** A; // m x k x d random projectors from R^d->R^k mas01mc@293: float ** b; // m x k uniform additive constants mas01mc@293: mas01mc@293: Uns32T ** g; // L x k random hash projections \in Z^k mas01mc@292: Uns32T** r1; // random ints for hashing mas01mc@292: Uns32T** r2; // random ints for hashing mas01mc@293: mas01mc@293: bucket*** h; // The LSH hash tables mas01mc@293: mas01mc@293: bool use_u_functions; // flag to optimize computation of hashes mas01mc@514: #ifdef USE_U_FUNCTIONS mas01mc@293: vector > uu; // Storage for m patial hash evaluations ( g_j = [u_a,u_b] ) mas01mc@514: #endif mas01mc@293: Uns32T maxp; // highest pointID stored in database mas01mc@293: Uns32T bucketCount; // count of number of point buckets allocated mas01mc@293: Uns32T pointCount; // count of number of points inserted mas01mc@296: Uns32T collisionCount; // number of points collided in a hash-table row mas01mc@764: Uns32T tablesPointCount; // count of points per hash table on load mas01mc@293: mas01mc@292: Uns32T t1; // first hash table key mas01mc@292: Uns32T t2; // second hash table key mas01mc@292: Uns32T P; // hash table prime number mas01mc@292: mas01mc@764: mas01mc@292: Uns32T N; // num rows per table mas01mc@292: Uns32T C; // num collision per row mas01mc@292: Uns32T k; // num projections per hash function mas01mc@292: Uns32T m; // ~sqrt num hash tables mas01mc@292: Uns32T L; // L = m*(m-1)/2, conversely, m = (1 + sqrt(1 + 8.0*L)) / 2.0 mas01mc@292: Uns32T d; // dimensions mas01mc@292: Uns32T p; // current point mas01mc@293: float w; // width of hash slots (relative to normalized feature space) mas01mc@293: float radius;// scaling coefficient for data (1./radius) mas01mc@292: mas01mc@764: MultiProbe* multiProbePtr; // Utility class for handling multi-probe queries mas01mc@764: float ** boundaryDistances; // Array of query bucket-boundary-distances per hashtable mas01mc@764: mas01mc@293: void initialize_data_structures(); mas01mc@293: void initialize_lsh_functions(); mas01mc@293: void initialize_partial_functions(); mas01mc@293: void __bucket_insert_point(bucket*); mas01mc@293: void __sbucket_insert_point(sbucket*); mas01mc@340: bucket** get_pointer_to_bucket_linked_list(bucket* rowPtr); mas01mc@293: Uns32T computeProductModDefaultPrime(Uns32T*,Uns32T*,IntT); mas01mc@293: Uns32T randr(); mas01mc@293: float randn(); mas01mc@293: float ranf(); mas01mc@293: bucket** get_bucket(int j); mas01mc@293: void error(const char* a, const char* b = "", const char *sysFunc = 0); mas01mc@293: mas01mc@293: public: mas01mc@293: mas01mc@293: H(); mas01mc@293: H(Uns32T k, Uns32T m, Uns32T d, Uns32T N, Uns32T C, float w, float r); mas01mc@474: virtual ~H(); mas01mc@292: mas01mc@293: float get_w(){return w;} mas01mc@293: float get_radius(){return radius;} mas01mc@293: Uns32T get_numRows(){return N;} mas01mc@293: Uns32T get_numCols(){return C;} mas01mc@293: Uns32T get_numFuns(){return k;} mas01mc@293: Uns32T get_numTables(){return L;} mas01mc@293: Uns32T get_dataDim(){return d;} mas01mc@293: Uns32T get_maxp(){return maxp;} mas01mc@292: mas01mc@292: Uns32T bucket_insert_point(bucket**); mas01mc@292: mas01mc@293: // Interface to hash functions mas01mc@293: void compute_hash_functions(vector& v); mas01mc@293: void generate_hash_keys(Uns32T* g, Uns32T* r1, Uns32T* r2); mas01mc@764: void generate_multiprobe_keys(Uns32T*g, Uns32T* r1, Uns32T* r2); mas01mc@293: Uns32T get_t1(){return t1;} // hash-key t1 mas01mc@293: Uns32T get_t2(){return t2;} // hash-key t2 mas01mc@292: }; mas01mc@292: mas01mc@293: // Typedef for point-reporting callback function. Used to collect points during LSH retrieval mas01mc@292: typedef void (*ReporterCallbackPtr)(void* objPtr, Uns32T pointID, Uns32T queryIndex, float squaredDistance); mas01mc@292: mas01mc@292: // Interface for indexing and retrieval mas01mc@292: class G: public H{ mas01mc@292: private: mas01mc@308: char* indexName; mas01mc@308: mas01mc@293: // LSH serial data structure file handling mas01mc@292: void get_lock(int fd, bool exclusive); mas01mc@292: void release_lock(int fd); mas01mc@292: int serial_create(char* filename, Uns32T FMT); mas01mc@292: int serial_create(char* filename, float binWidth, Uns32T nTables, Uns32T nRows, Uns32T nCols, Uns32T k, Uns32T d, Uns32T FMT); mas01mc@764: char* serial_mmap(int dbfid, Uns32T sz, Uns32T w, off_t offset = 0); mas01mc@764: void serial_munmap(char* db, Uns32T N); mas01mc@292: int serial_open(char* filename,int writeFlag); mas01mc@292: void serial_close(int dbfid); mas01mc@292: mas01mc@292: // Function to write hashfunctions to disk mas01mc@292: int serialize_lsh_hashfunctions(int fid); mas01mc@292: mas01mc@292: // Functions to write hashtables to disk in format1 (optimized for on-disk retrieval) mas01mc@292: int serialize_lsh_hashtables_format1(int fid, int merge); mas01mc@292: void serial_write_hashtable_row_format1(SerialElementT*& pe, bucket* h, Uns32T& colCount); mas01mc@292: void serial_write_element_format1(SerialElementT*& pe, sbucket* sb, Uns32T t2, Uns32T& colCount); mas01mc@292: void serial_merge_hashtable_row_format1(SerialElementT* pr, bucket* h, Uns32T& colCount); mas01mc@292: void serial_merge_element_format1(SerialElementT* pe, sbucket* sb, Uns32T t2, Uns32T& colCount); mas01mc@292: int serial_can_merge(Uns32T requestedFormat); // Test to see whether core and on-disk structures are compatible mas01mc@292: mas01mc@292: // Functions to write hashtables to disk in format2 (optimized for in-core retrieval) mas01mc@336: int serialize_lsh_hashtables_format2(FILE* dbFile, int merge); mas01mc@336: void serial_write_hashtable_row_format2(FILE* dbFile, bucket* h, Uns32T& colCount); mas01mc@336: void serial_write_element_format2(FILE* dbFile, sbucket* sb, Uns32T& colCount); mas01mc@340: Uns32T count_buckets_and_points_hashtable_row(bucket* bPtr); mas01mc@340: Uns32T count_points_hashtable_row(bucket* bPtr); mas01mc@292: mas01mc@292: // Functions to read serial header and hash functions (format1 and format2) mas01mc@292: int unserialize_lsh_header(char* filename); // read lsh header from disk into core mas01mc@292: void unserialize_lsh_functions(int fid); // read the lsh hash functions into core mas01mc@292: mas01mc@292: // Functions to read hashtables in format1 mas01mc@292: void unserialize_lsh_hashtables_format1(int fid); // read FORMAT1 hash tables into core (disk format) mas01mc@292: void unserialize_hashtable_row_format1(SerialElementT* pe, bucket** b); // read lsh hash table row into core mas01mc@292: mas01mc@292: // Functions to read hashtables in format2 mas01mc@340: void unserialize_lsh_hashtables_format2(FILE* dbFile, bool forMerge = 0); mas01mc@340: Uns32T unserialize_hashtable_row_format2(FILE* dbFile, bucket** b, Uns32T token=0); // to dynamic linked list mas01mc@340: Uns32T unserialize_hashtable_row_to_array(FILE* dbFile, bucket** b, Uns32T numElements); // to core array mas01mc@292: mas01mc@292: // Helper functions mas01mc@292: void serial_print_header(Uns32T requestedFormat); mas01mc@292: float* get_serial_hashfunction_base(char* db); mas01mc@292: SerialElementT* get_serial_hashtable_base(char* db); mas01mc@292: Uns32T get_serial_hashtable_offset(); // Size of SerialHeader + HashFunctions mas01mc@764: SerialHeaderT* serial_get_header(char* db); mas01mc@292: SerialHeaderT* lshHeader; mas01mc@292: mas01mc@292: // Core Retrieval/Inspections Functions mas01mc@292: void bucket_chain_point(bucket* p, Uns32T qpos); mas01mc@292: void sbucket_chain_point(sbucket* p, Uns32T qpos); mas01mc@292: void dump_hashtable_row(bucket* p); mas01mc@513: void dump_core_hashtable_array(Uns32T* p); mas01mc@292: mas01mc@292: // Serial (Format 1) Retrieval/Inspection Functions mas01mc@292: void serial_bucket_chain_point(SerialElementT* pe, Uns32T qpos); mas01mc@292: void serial_bucket_dump(SerialElementT* pe); mas01mc@292: mas01mc@340: // Core ARRAY Retrieval Functions mas01mc@340: void retrieve_from_core_hashtable_array(Uns32T* p, Uns32T qpos); mas01mc@340: mas01mc@340: mas01mc@293: // Callback Function for point reporting mas01mc@293: void* calling_instance; // store calling object instance for member-function callback mas01mc@324: ReporterCallbackPtr add_point_callback; // Pointer to the callback function mas01mc@292: mas01mc@292: public: mas01mc@292: G(char* lshFile, bool lshInCore = false); // unserialize constructor mas01mc@292: G(float w, Uns32T k,Uns32T m, Uns32T d, Uns32T N, Uns32T C, float r); // core constructor mas01mc@474: virtual ~G(); mas01mc@292: mas01mc@292: Uns32T insert_point(vector&, Uns32T pointID); mas01mc@292: void insert_point_set(vector >& vv, Uns32T basePointID); mas01mc@292: mas01mc@292: // point retrieval from core mas01mc@292: void retrieve_point(vector& v, Uns32T qpos, ReporterCallbackPtr, void* me=NULL); mas01mc@292: // point set retrieval from core mas01mc@292: void retrieve_point_set(vector >& vv, ReporterCallbackPtr, void* me=NULL); mas01mc@292: // serial point set retrieval mas01mc@292: void serial_retrieve_point_set(char* filename, vector >& vv, ReporterCallbackPtr, void* me=NULL); mas01mc@292: // serial point retrieval mas01mc@292: void serial_retrieve_point(char* filename, vector& vv, Uns32T qpos, ReporterCallbackPtr, void* me=NULL); mas01mc@292: mas01mc@292: void serialize(char* filename, Uns32T serialFormat = O2_SERIAL_FILEFORMAT1); // write hashfunctions and hashtables to disk mas01mc@292: mas01mc@292: SerialHeaderT* get_lshHeader(){return lshHeader;} mas01mc@292: void serial_dump_tables(char* filename); mas01mc@292: float get_mean_collision_rate(){ return (float) pointCount / bucketCount ; } mas01mc@308: char* get_indexName(){return indexName;} mas01mc@513: void dump_hashtables(); mas01mc@764: void dump_core_row(Uns32T n); mas01mc@764: void dump_disk_row(char*, Uns32T n); mas01mc@292: }; mas01mc@292: mas01mc@292: typedef class G LSH; mas01mc@292: mas01mc@292: mas01mc@292: mas01mc@292: #endif