Mercurial > hg > audiodb
changeset 754:9bd13c7819ae mkc_lsh_update
Adding mkc_lsh_update branch, trunk candidate with improved LSH: merged trunk 1095 and branch multiprobe_lsh
author | mas01mc |
---|---|
date | Thu, 25 Nov 2010 13:42:40 +0000 |
parents | fbf16508421f |
children | 9b75573be3b9 |
files | Makefile bindings/python/pyadbmodule.c lshlib.cpp lshlib.h multiprobe.cpp multiprobe.h |
diffstat | 6 files changed, 863 insertions(+), 176 deletions(-) [+] |
line wrap: on
line diff
--- a/Makefile Sat Nov 20 15:32:58 2010 +0000 +++ b/Makefile Thu Nov 25 13:42:40 2010 +0000 @@ -16,7 +16,7 @@ INCLUDEDIR=$(PREFIX)/include MANDIR=$(PREFIX)/share/man -LIBOBJS=lock.o pointpair.o create.o open.o power.o l2norm.o insert.o status.o query.o dump.o close.o index-utils.o query-indexed.o liszt.o retrieve.o lshlib.o sample.o +LIBOBJS=lock.o pointpair.o create.o open.o power.o l2norm.o insert.o status.o query.o dump.o close.o index-utils.o query-indexed.o liszt.o retrieve.o lshlib.o multiprobe.o sample.o OBJS=$(LIBOBJS) index.o soap.o cmdline.o audioDB.o common.o EXECUTABLE=audioDB
--- a/bindings/python/pyadbmodule.c Sat Nov 20 15:32:58 2010 +0000 +++ b/bindings/python/pyadbmodule.c Thu Nov 25 13:42:40 2010 +0000 @@ -335,7 +335,7 @@ adb_t *current_db; adb_query_spec_t *spec; adb_query_results_t *result; - int ok, exhaustive, falsePositives; + int ok, exhaustive=0, falsePositives=0; uint32_t i; const char *key; const char *accuMode = "db"; @@ -688,10 +688,10 @@ free(ins); // free the malloced adb_datum_t structure though if (!outgoing){ - PyErr_SetString(PyExc_TypeError, "Failed to convert retrieved datum to C-Array"); + PyErr_SetString(PyExc_TypeError, "Failed to convert retrieved datum to PyArray"); return NULL; } - // Apprently Python automatically INCREFs the data pointer, so we don't have to call + // Apparently Python automatically INCREFs the data pointer, so we don't have to call // audiodb_free_datum(current_db, ins); return outgoing;
--- a/lshlib.cpp Sat Nov 20 15:32:58 2010 +0000 +++ b/lshlib.cpp Thu Nov 25 13:42:40 2010 +0000 @@ -1,39 +1,14 @@ -#include <vector> -#include <queue> -#include <stdio.h> -#include <stdlib.h> -#include <sys/types.h> -#include <sys/stat.h> -#if defined(WIN32) -#include <sys/locking.h> -#include <io.h> -#include <windows.h> -#endif -#include <fcntl.h> -#include <string.h> -#include <iostream> -#include <fstream> -#include <math.h> -#include <sys/time.h> -#include <assert.h> -#include <float.h> -#include <signal.h> -#include <time.h> -#include <limits.h> -#include <errno.h> -#ifdef MT19937 -#include "mt19937/mt19937ar.h" -#endif - #include "lshlib.h" #define getpagesize() (64*1024) -Uns32T get_page_logn() { - int pagesz = (int) getpagesize(); - return (Uns32T) log2((double) pagesz); +Uns32T get_page_logn(){ + int pagesz = (int)sysconf(_SC_PAGESIZE); + return (Uns32T)log2((double)pagesz); } +unsigned align_up(unsigned x, unsigned w) { return (((x) + ((1<<w)-1)) & ~((1<<w)-1)); } + void H::error(const char* a, const char* b, const char *sysFunc) { cerr << a << ": " << b << endl; if (sysFunc) { @@ -42,7 +17,10 @@ exit(1); } -H::H(){ +H::H(): + multiProbePtr(new MultiProbe()), + boundaryDistances(0) +{ // Delay initialization of lsh functions until we know the parameters } @@ -62,7 +40,9 @@ L((mm*(mm-1))/2), d(dd), w(ww), - radius(rr) + radius(rr), + multiProbePtr(new MultiProbe()), + boundaryDistances(0) { if(m<2){ @@ -164,6 +144,16 @@ // Storage for whole or partial function evaluation depending on USE_U_FUNCTIONS H::initialize_partial_functions(); + + // MultiProbe distance functions, there are 2*k per hashtable + H::boundaryDistances = new float*[ H::L ]; // L x 2k boundary distances + assert( H::boundaryDistances ); // failure + for( j = 0; j < H::L ; j++ ){ // 2*k functions x_i(q) + H::boundaryDistances[j] = new float[ 2*H::k ]; + assert( H::boundaryDistances[j] ); // failure + for( kk = 0; kk < 2*H::k ; kk++ ) + H::boundaryDistances[j][kk] = 0.0f; // initialize with zeros + } } void H::initialize_partial_functions(){ @@ -261,6 +251,12 @@ delete[] H::r1; delete[] H::r2; delete[] H::h; + + // MultiProbe cleanup + for( j = 0 ; j < H::L ; j++ ) + delete[] H::boundaryDistances[j]; + delete[] H::boundaryDistances; + delete multiProbePtr; } @@ -273,7 +269,7 @@ if( v.size() != H::d ) error("v.size != H::d","","compute_hash_functions"); // check input vector dimensionality double tmp = 0; - float *pA, *pb; + float *pA, *pb, *bd; Uns32T *pg; int dd; vector<float>::iterator vi; @@ -316,6 +312,7 @@ #else for( aa=0; aa < H::L ; aa++ ){ pg= *( H::g + aa ); // L \times functions g_j(v) \in Z^k + bd= *( H::boundaryDistances + aa); for( kk = 0 ; kk != H::k ; kk++ ){ pb = *( H::b + aa ) + kk; pA = * ( * ( H::A + aa ) + kk ); @@ -325,8 +322,14 @@ while( dd-- ) tmp += *pA++ * *vi++; // project tmp += *pb; // translate - tmp *= iw; // scale - *pg++ = (Uns32T) (floor(tmp)); // hash function g_j(v)=[x1 x2 ... xk]; xk \in Z + tmp *= iw; // scale + tmp = floor(tmp); // handle negative values + while(tmp<0) // wrap around 0 to N + tmp += H::N; + *pg = (Uns32T) tmp; // hash function g_j(v)=[x1 x2 ... xk]; xk \in Z + *bd = (tmp - *pg++);//*w; // boundary distance -1 + *(bd+1) = (1.0f - *bd); //*w; // boundary distance +1 + bd+=2; } } #endif @@ -338,6 +341,35 @@ H::t2 = computeProductModDefaultPrime( g, r2, H::k ); } +// make hash value by purturbating the given hash functions +// according the the boundary distances of the current query +void H::generate_multiprobe_keys(Uns32T*g, Uns32T* r1, Uns32T* r2){ + assert(!multiProbePtr->empty()); // Test this for now, until all is stable + Uns32T* mpg = new Uns32T[H::k]; // temporary array storage + + // Copy the hash bucket identifiers + Uns32T* mpgPtr = mpg; + Uns32T kk = H::k; + while(kk--) + *mpgPtr++ = *g++; + + // Retrieve the next purturbation set + perturbation_set ps = multiProbePtr->getNextPerturbationSet(); + perturbation_set::iterator it = ps.begin(); + + // Perturbate the hash functions g + while( it != ps.end() ){ + *(mpg + multiProbePtr->getIndex(it)) += multiProbePtr->getBoundary(it); + it++; + } + + H::t1 = computeProductModDefaultPrime( mpg, r1, H::k ) % H::N; + H::t2 = computeProductModDefaultPrime( mpg, r2, H::k ); + + delete[] mpg; // free up temporary storage +} + + #define CR_ASSERT(b){if(!(b)){fprintf(stderr, "ASSERT failed on line %d, file %s.\n", __LINE__, __FILE__); exit(1);}} // Computes (a.b) mod UH_PRIME_DEFAULT @@ -399,18 +431,23 @@ __sbucket_insert_point(p->snext.ptr); return; } + + // insert bucket before current bucket + if(H::t2 < p->t2){ + bucket* tmp = new bucket(); + // copy current bucket contents into new bucket + tmp->next = p->next; + tmp->t2 = p->t2; + tmp->snext.ptr = p->snext.ptr; + p->next = tmp; + p->t2 = IFLAG; + p->snext.ptr=0; + __bucket_insert_point(p); + return; + } - if(p->next){ - // Construct list in t2 order - if(H::t2 < p->next->t2){ - bucket* tmp = new bucket(); - tmp->next = p->next; - p->next = tmp; - __bucket_insert_point(tmp); - } - else - __bucket_insert_point(p->next); - } + if(p->next) + __bucket_insert_point(p->next); else { p->next = new bucket(); __bucket_insert_point(p->next); @@ -471,6 +508,7 @@ { FILE* dbFile = 0; int dbfid = unserialize_lsh_header(filename); + indexName = new char[O2_INDEX_MAXSTR]; strncpy(indexName, filename, O2_INDEX_MAXSTR); // COPY THE CONTENTS TO THE NEW POINTER H::initialize_lsh_functions(); // Base-class data-structure initialization @@ -525,21 +563,29 @@ // point retrieval routine void G::retrieve_point(vector<float>& v, Uns32T qpos, ReporterCallbackPtr add_point, void* caller){ + // assert(LSH_MULTI_PROBE_COUNT); calling_instance = caller; add_point_callback = add_point; H::compute_hash_functions( v ); for(Uns32T j = 0 ; j < H::L ; j++ ){ - H::generate_hash_keys( *( H::g + j ), *( H::r1 + j ), *( H::r2 + j ) ); - if( bucket* bPtr = *(get_bucket(j) + get_t1()) ) { + // MultiProbe loop + multiProbePtr->generatePerturbationSets( *( H::boundaryDistances + j ) , 2*H::k, (unsigned)LSH_MULTI_PROBE_COUNT); + for(Uns32T multiProbeIdx = 0 ; multiProbeIdx < multiProbePtr->size()+1 ; multiProbeIdx++ ){ + if(!multiProbeIdx) + H::generate_hash_keys( *( H::g + j ), *( H::r1 + j ), *( H::r2 + j ) ); + else + H::generate_multiprobe_keys( *( H::g + j ), *( H::r1 + j ), *( H::r2 + j ) ); + if( bucket* bPtr = *(get_bucket(j) + get_t1()) ) { #ifdef LSH_LIST_HEAD_COUNTERS - if(bPtr->t2&LSH_CORE_ARRAY_BIT) { - retrieve_from_core_hashtable_array((Uns32T*)(bPtr->next), qpos); - } else { - bucket_chain_point( bPtr->next, qpos); + if(bPtr->t2&LSH_CORE_ARRAY_BIT) { + retrieve_from_core_hashtable_array((Uns32T*)(bPtr->next), qpos); + } else { + bucket_chain_point( bPtr->next, qpos); + } +#else + bucket_chain_point( bPtr , qpos); +#endif } -#else - bucket_chain_point( bPtr , qpos); -#endif } } } @@ -617,6 +663,7 @@ // // T1 - T1 hash token // t1 - t1 hash key (t1 range 0..2^29-1) +// %buckets+points% numElements in row for ARRAY encoding // T2 - T2 token // t2 - t2 hash key (range 1..2^32-6) // p - point identifier (range 0..2^32-1) @@ -628,7 +675,7 @@ // {...}^L - repeat argument L times // // FORMAT2 Regular expression: -// { [T1 t1 [T2 t2 p+]+ ]* E }^L +// { [T1 t1 %buckets+points% [T2 t2 p+]+ ]* E }^L // // Serial header constructors @@ -681,6 +728,7 @@ void G::serialize(char* filename, Uns32T serialFormat){ int dbfid; + char* db; int dbIsNew=0; FILE* dbFile = 0; // Check requested serialFormat @@ -704,7 +752,9 @@ // Load the on-disk header into core dbfid = serial_open(filename, 1); // open for write - serial_get_header(dbfid); // read header + db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 1);// get database pointer + serial_get_header(db); // read header + serial_munmap(db, O2_SERIAL_HEADER_SIZE); // drop mmap // Check compatibility of core and disk data structures if( !serial_can_merge(serialFormat) ) @@ -725,12 +775,15 @@ } if(!dbIsNew) { + db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 1);// get database pointer + //serial_get_header(db); // read header cout << "maxp = " << H::maxp << endl; lshHeader->maxp=H::maxp; // Default to FILEFORMAT1 if(!(lshHeader->flags&O2_SERIAL_FILEFORMAT2)) lshHeader->flags|=O2_SERIAL_FILEFORMAT1; - serial_write_header(dbfid, lshHeader); + memcpy((char*)db, (char*)lshHeader, sizeof(SerialHeaderT)); + serial_munmap(db, O2_SERIAL_HEADER_SIZE); // drop mmap } serial_close(dbfid); if(dbFile){ @@ -776,8 +829,9 @@ Uns32T *pu; Uns32T x,y,z; - void *db = calloc(get_serial_hashtable_offset() - O2_SERIAL_HEADER_SIZE, 1); - pf = (float *) db; + char* db = serial_mmap(fid, get_serial_hashtable_offset(), 1);// get database pointer + pf = get_serial_hashfunction_base(db); + // HASH FUNCTIONS // Write the random projectors A[][][] #ifdef USE_U_FUNCTIONS @@ -812,32 +866,10 @@ for( y = 0; y < H::k ; y++) *pu++ = H::r2[x][y]; - off_t cur = lseek(fid, 0, SEEK_CUR); - lseek(fid, O2_SERIAL_HEADER_SIZE, SEEK_SET); - write(fid, db, get_serial_hashtable_offset() - O2_SERIAL_HEADER_SIZE); - lseek(fid, cur, SEEK_SET); - - free(db); - + serial_munmap(db, get_serial_hashtable_offset()); return 1; } -void G::serial_get_table(int fd, int nth, void *buf, size_t count) { - off_t cur = lseek(fd, 0, SEEK_CUR); - /* FIXME: if hashTableSize isn't bigger than a page, this loses. */ - lseek(fd, align_up(get_serial_hashtable_offset() + nth * count, get_page_logn()), SEEK_SET); - read(fd, buf, count); - lseek(fd, cur, SEEK_SET); -} - -void G::serial_write_table(int fd, int nth, void *buf, size_t count) { - off_t cur = lseek(fd, 0, SEEK_CUR); - /* FIXME: see the comment in serial_get_table() */ - lseek(fd, align_up(get_serial_hashtable_offset() + nth * count, get_page_logn()), SEEK_SET); - write(fd, buf, count); - lseek(fd, cur, SEEK_SET); -} - int G::serialize_lsh_hashtables_format1(int fid, int merge){ SerialElementT *pe, *pt; Uns32T x,y; @@ -845,19 +877,27 @@ if( merge && !serial_can_merge(O2_SERIAL_FILEFORMAT1) ) error("Cannot merge core and serial LSH, data structure dimensions mismatch."); + Uns32T hashTableSize=sizeof(SerialElementT)*lshHeader->numRows*lshHeader->numCols; Uns32T colCount, meanColCount, colCountN, maxColCount, minColCount; - size_t hashTableSize = sizeof(SerialElementT) * lshHeader->numRows * lshHeader->numCols; - pt = (SerialElementT *) malloc(hashTableSize); // Write the hash tables for( x = 0 ; x < H::L ; x++ ){ std::cout << (merge ? "merging":"writing") << " hash table " << x << " FORMAT1..."; std::cout.flush(); - // read a hash table's data from disk - serial_get_table(fid, x, pt, hashTableSize); + // memory map a single hash table for sequential access + // Align each hash table to page boundary + char* dbtable = serial_mmap(fid, hashTableSize, 1, + align_up(get_serial_hashtable_offset()+x*hashTableSize, get_page_logn())); +#ifdef __CYGWIN__ + // No madvise in CYGWIN +#else + if(madvise(dbtable, hashTableSize, MADV_SEQUENTIAL)<0) + error("could not advise hashtable memory","","madvise"); +#endif maxColCount=0; minColCount=O2_SERIAL_MAX_COLS; meanColCount=0; colCountN=0; + pt=(SerialElementT*)dbtable; for( y = 0 ; y < H::N ; y++ ){ // Move disk pointer to beginning of row pe=pt+y*lshHeader->numCols; @@ -889,11 +929,10 @@ std::cout << "#rows with collisions =" << colCountN << ", mean = " << meanColCount/(float)colCountN << ", min = " << minColCount << ", max = " << maxColCount << endl; - serial_write_table(fid, x, pt, hashTableSize); + serial_munmap(dbtable, hashTableSize); } // We're done writing - free(pt); return 1; } @@ -977,6 +1016,7 @@ minColCount=O2_SERIAL_MAX_COLS; meanColCount=0; colCountN=0; + H::tablesPointCount = 0; for( y = 0 ; y < H::N ; y++ ){ colCount=0; if(bucket* bPtr = h[x][y]){ @@ -1014,12 +1054,13 @@ meanColCount+=colCount; colCountN++; } + H::tablesPointCount+=colCount; } // Write END of table marker t1 = O2_SERIAL_TOKEN_ENDTABLE; WRITE_UNS32(&t1,"[end]"); if(colCountN) - std::cout << "#rows with collisions =" << colCountN << ", mean = " << meanColCount/(float)colCountN + std::cout << "#points: " << H::tablesPointCount << " #rows with collisions =" << colCountN << ", mean = " << meanColCount/(float)colCountN << ", min = " << minColCount << ", max = " << maxColCount << endl; } @@ -1064,17 +1105,29 @@ } void G::serial_write_hashtable_row_format2(FILE* dbFile, bucket* b, Uns32T& colCount){ +#ifdef _LSH_DEBUG_ + Uns32T last_t2 = 0; +#endif while(b && b->t2!=IFLAG){ if(!b->snext.ptr){ fclose(dbFile); error("Empty collision chain in serial_write_hashtable_row_format2()"); } t2 = O2_SERIAL_TOKEN_T2; + if( fwrite(&t2, sizeof(Uns32T), 1, dbFile) != 1 ){ fclose(dbFile); error("write error in serial_write_hashtable_row_format2()"); } t2 = b->t2; +#ifdef _LSH_DEBUG_ + if(t2 < last_t2){ + fclose(dbFile); + error("t2<last_t2 in serial_write_hashtable_row_format2()"); + } + last_t2 = t2; +#endif + if( fwrite(&t2, sizeof(Uns32T), 1, dbFile) != 1 ){ fclose(dbFile); error("write error in serial_write_hashtable_row_format2()"); @@ -1136,8 +1189,14 @@ // write a dummy byte at the last location if (write (dbfid, "", 1) != 1) error("write error", "", "write"); + + char* db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 1); + + memcpy (db, lshHeader, O2_SERIAL_HEADER_SIZE); + + serial_munmap(db, O2_SERIAL_HEADER_SIZE); + - serial_write_header(dbfid, lshHeader); close(dbfid); std::cout << "done initializing tables." << endl; @@ -1145,25 +1204,31 @@ return 1; } -SerialHeaderT* G::serial_get_header(int fd) { - off_t cur = lseek(fd, 0, SEEK_CUR); +char* G::serial_mmap(int dbfid, Uns32T memSize, Uns32T forWrite, off_t offset){ + char* db; + if(forWrite){ + if ((db = (char*) mmap(0, memSize, PROT_READ | PROT_WRITE, + MAP_SHARED, dbfid, offset)) == (caddr_t) -1) + error("mmap error in request for writable serialized database", "", "mmap"); + } + else if ((db = (char*) mmap(0, memSize, PROT_READ, MAP_SHARED, dbfid, offset)) == (caddr_t) -1) + error("mmap error in read-only serialized database", "", "mmap"); + + return db; +} + +SerialHeaderT* G::serial_get_header(char* db){ lshHeader = new SerialHeaderT(); - lseek(fd, 0, SEEK_SET); - if(read(fd, lshHeader, sizeof(SerialHeaderT)) != (ssize_t) (sizeof(SerialHeaderT))) - error("Bad return from read"); + memcpy((char*)lshHeader, db, sizeof(SerialHeaderT)); if(lshHeader->lshMagic!=O2_SERIAL_MAGIC) error("Not an LSH database file"); - lseek(fd, cur, SEEK_SET); + return lshHeader; } -void G::serial_write_header(int fd, SerialHeaderT *header) { - off_t cur = lseek(fd, 0, SEEK_CUR); - lseek(fd, 0, SEEK_SET); - if(write(fd, header, sizeof(SerialHeaderT)) != (ssize_t) (sizeof(SerialHeaderT))) - error("Bad return from write"); - lseek(fd, cur, SEEK_SET); +void G::serial_munmap(char* db, Uns32T N){ + munmap(db, N); } int G::serial_open(char* filename, int writeFlag){ @@ -1189,14 +1254,16 @@ } int G::unserialize_lsh_header(char* filename){ - int dbfid; + char* db; // Test to see if file exists if((dbfid = open (filename, O_RDONLY)) < 0) error("Can't open the file", filename, "open"); close(dbfid); dbfid = serial_open(filename, 0); // open for read - serial_get_header(dbfid); + db = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 0);// get database pointer + serial_get_header(db); // read header + serial_munmap(db, O2_SERIAL_HEADER_SIZE); // drop mmap // Unserialize header parameters H::L = lshHeader->numTables; @@ -1222,55 +1289,59 @@ Uns32T* pu; // Load the hash functions into core - off_t cur = lseek(dbfid, 0, SEEK_CUR); - void *db = malloc(get_serial_hashtable_offset() - O2_SERIAL_HEADER_SIZE); - lseek(dbfid, O2_SERIAL_HEADER_SIZE, SEEK_SET); - read(dbfid, db, get_serial_hashtable_offset() - O2_SERIAL_HEADER_SIZE); - lseek(dbfid, cur, SEEK_SET); - pf = (float *)db; + char* db = serial_mmap(dbfid, get_serial_hashtable_offset(), 0);// get database pointer again + + pf = get_serial_hashfunction_base(db); #ifdef USE_U_FUNCTIONS for( j = 0 ; j < H::m ; j++ ){ // L functions gj(v) for( kk = 0 ; kk < H::k/2 ; kk++ ){ // Normally distributed hash functions #else - for( j = 0 ; j < H::L ; j++ ){ // L functions gj(v) - for( kk = 0 ; kk < H::k ; kk++ ){ // Normally distributed hash functions + for( j = 0 ; j < H::L ; j++ ){ // L functions gj(v) + for( kk = 0 ; kk < H::k ; kk++ ){ // Normally distributed hash functions #endif - for(Uns32T i = 0 ; i < H::d ; i++ ) - H::A[j][kk][i] = *pf++; // Normally distributed random vectors - } - } + for(Uns32T i = 0 ; i < H::d ; i++ ) + H::A[j][kk][i] = *pf++; // Normally distributed random vectors + } + } #ifdef USE_U_FUNCTIONS - for( j = 0 ; j < H::m ; j++ ) // biases b - for( kk = 0 ; kk < H::k/2 ; kk++ ) + for( j = 0 ; j < H::m ; j++ ) // biases b + for( kk = 0 ; kk < H::k/2 ; kk++ ) #else - for( j = 0 ; j < H::L ; j++ ) // biases b - for( kk = 0 ; kk < H::k ; kk++ ) + for( j = 0 ; j < H::L ; j++ ) // biases b + for( kk = 0 ; kk < H::k ; kk++ ) #endif - H::b[j][kk] = *pf++; + H::b[j][kk] = *pf++; - pu = (Uns32T*)pf; - for( j = 0 ; j < H::L ; j++ ) // Z projectors r1 - for( kk = 0 ; kk < H::k ; kk++ ) - H::r1[j][kk] = *pu++; - - for( j = 0 ; j < H::L ; j++ ) // Z projectors r2 - for( kk = 0 ; kk < H::k ; kk++ ) - H::r2[j][kk] = *pu++; + pu = (Uns32T*)pf; + for( j = 0 ; j < H::L ; j++ ) // Z projectors r1 + for( kk = 0 ; kk < H::k ; kk++ ) + H::r1[j][kk] = *pu++; + + for( j = 0 ; j < H::L ; j++ ) // Z projectors r2 + for( kk = 0 ; kk < H::k ; kk++ ) + H::r2[j][kk] = *pu++; - free(db); + serial_munmap(db, get_serial_hashtable_offset()); } void G::unserialize_lsh_hashtables_format1(int fid){ SerialElementT *pe, *pt; Uns32T x,y; Uns32T hashTableSize=sizeof(SerialElementT)*lshHeader->numRows*lshHeader->numCols; - pt = (SerialElementT *) malloc(hashTableSize); // Read the hash tables into core for( x = 0 ; x < H::L ; x++ ){ // memory map a single hash table // Align each hash table to page boundary - serial_get_table(fid, x, pt, hashTableSize); + char* dbtable = serial_mmap(fid, hashTableSize, 0, + align_up(get_serial_hashtable_offset()+x*hashTableSize, get_page_logn())); +#ifdef __CYGWIN__ + // No madvise in CYGWIN +#else + if(madvise(dbtable, hashTableSize, MADV_SEQUENTIAL)<0) + error("could not advise hashtable memory","","madvise"); +#endif + pt=(SerialElementT*)dbtable; for( y = 0 ; y < H::N ; y++ ){ // Move disk pointer to beginning of row pe=pt+y*lshHeader->numCols; @@ -1282,8 +1353,8 @@ dump_hashtable_row(h[x][y]); #endif } + serial_munmap(dbtable, hashTableSize); } - free(pt); } void G::unserialize_hashtable_row_format1(SerialElementT* pe, bucket** b){ @@ -1299,7 +1370,12 @@ void G::unserialize_lsh_hashtables_format2(FILE* dbFile, bool forMerge){ Uns32T x=0,y=0; - +#ifdef _LSH_DEBUG_ + cout << "Loading hashtables..." << endl; + cout << "header pointCount = " << pointCount << endl; + cout << "forMerge = " << forMerge << endl; + Uns32T sumTablesPointCount = 0; +#endif // Seek to hashtable base offset if(fseek(dbFile, get_serial_hashtable_offset(), SEEK_SET)){ fclose(dbFile); @@ -1308,6 +1384,7 @@ // Read the hash tables into core (structure is given in header) while( x < H::L){ + tablesPointCount=0; if(fread(&(H::t1), sizeof(Uns32T), 1, dbFile) != 1){ fclose(dbFile); error("Read error","unserialize_lsh_hashtables_format2()"); @@ -1337,21 +1414,26 @@ fclose(dbFile); error("Read error: numElements","unserialize_lsh_hashtables_format2()"); } - + + /* +#ifdef _LSH_DEBUG_ + cout << "[" << x << "," << y << "] numElements(disk) = " << numElements; +#endif + */ // BACKWARD COMPATIBILITY: check to see if T2 or END token was read if(numElements==O2_SERIAL_TOKEN_T2 || numElements==O2_SERIAL_TOKEN_ENDTABLE ){ forMerge=true; // Force use of dynamic linked list core format token = numElements; } - if(forMerge) + if(forMerge) // FOR INDEXING use dynamic linked list data structure // Use linked list CORE format token = unserialize_hashtable_row_format2(dbFile, h[x]+y, token); - else + else // FOR QUERY use static array data structure // Use ARRAY CORE format with numElements counter token = unserialize_hashtable_row_to_array(dbFile, h[x]+y, numElements); #else - token = unserialize_hashtable_row_format2(dbFile, h[x]+y); + token = unserialize_hashtable_row_format2(dbFile, h[x]+y); #endif // Check that token is valid if( !(token==O2_SERIAL_TOKEN_T1 || token==O2_SERIAL_TOKEN_ENDTABLE) ){ @@ -1367,7 +1449,14 @@ if(token==O2_SERIAL_TOKEN_T1) H::t1 = token; } +#ifdef _LSH_DEBUG_ + cout << "[T " << x-1 << "] pointCount = " << tablesPointCount << endl; + sumTablesPointCount+=tablesPointCount; +#endif } +#ifdef _LSH_DEBUG_ + cout << "TOTAL pointCount = " << sumTablesPointCount << endl; +#endif #ifdef LSH_DUMP_CORE_TABLES dump_hashtables(); #endif @@ -1407,6 +1496,7 @@ while(!(H::p==O2_SERIAL_TOKEN_ENDTABLE || H::p==O2_SERIAL_TOKEN_T1 || H::p==O2_SERIAL_TOKEN_T2 )){ pointFound=true; bucket_insert_point(b); + tablesPointCount++; if(fread(&(H::p), sizeof(Uns32T), 1, dbFile) != 1){ fclose(dbFile); error("Read error H::p","unserialize_hashtable_row_format2"); @@ -1434,7 +1524,7 @@ // // We store the values of numPoints and numBuckets in separate fields of the first bucket // rowPtr->t2 // numPoints -// (Uns32T)(rowPtr->snext) // numBuckets +// (rowPtr->snext.numBuckets) // numBuckets // // We cast the rowPtr->next pointer to (Uns32*) malloc(numElements*sizeof(Uns32T) + sizeof(bucket*)) // To get to the fist bucket, we use @@ -1459,6 +1549,7 @@ secondPtr=ap++;\ *secondPtr=0;\ numPoints++;\ + numSingletons++;\ }\ if(numPointsThisBucket>1){\ *firstPtr |= ( (numPointsThisBucket-1) & 0x3 ) << SKIP_BITS_LEFT_SHIFT_MSB;\ @@ -1470,6 +1561,7 @@ Uns32T numPoints = 0; Uns32T* firstPtr = 0; Uns32T* secondPtr = 0; + Uns32T numSingletons = 0; // Count single point puckets because we encode them with 2 points (for skip) // Initialize new row if(!*rowPP){ @@ -1480,7 +1572,7 @@ #endif } bucket* rowPtr = *rowPP; - + Uns32T last_t2 = 0; READ_UNS32T(&(H::t2),"t2"); TEST_TOKEN(!(H::t2==O2_SERIAL_TOKEN_ENDTABLE || H::t2==O2_SERIAL_TOKEN_T2), "expected E or T2"); // Because we encode points in 16-point blocks, we sometimes allocate repeated t2 elements @@ -1492,12 +1584,26 @@ secondPtr = 0; // reset second-point pointer TEST_TOKEN(H::t2!=O2_SERIAL_TOKEN_T2, "expected T2"); READ_UNS32T(&(H::t2), "Read error t2"); + if(H::t2<last_t2) + cout << "last_t2=" << last_t2 << ", t2=" << H::t2 << endl; + TEST_TOKEN(H::t2<last_t2, "t2 tokens not in ascending order"); + last_t2 = H::t2; + /* +#ifdef _LSH_DEBUG_ + cout << "+" << H::t2 << "+"; +#endif + */ *ap++ = H::t2; // Insert t2 value into array numBuckets++; READ_UNS32T(&(H::p), "Read error H::p"); while(!(H::p==O2_SERIAL_TOKEN_ENDTABLE || H::p==O2_SERIAL_TOKEN_T1 || H::p==O2_SERIAL_TOKEN_T2 )){ if(numPointsThisBucket==MAX_POINTS_IN_BUCKET_CORE_ARRAY){ ENCODE_POINT_SKIP_BITS; + /* +#ifdef _LSH_DEBUG_ + cout << "*" << H::t2 << "*"; +#endif + */ *ap++ = H::t2; // Extra element numBuckets++; // record this as a new bucket numPointsThisBucket=0; // reset bucket point counter @@ -1508,6 +1614,12 @@ else if( numPointsThisBucket == 2 ) secondPtr = ap; // store pointer to first point to insert skip bits later on numPoints++; + /* +#ifdef _LSH_DEBUG_ + cout << "(" << H::p << ":" << numPoints << ")"; +#endif + */ + *ap++ = H::p; READ_UNS32T(&(H::p), "Read error H::p"); } @@ -1526,6 +1638,12 @@ // Allocate a new dynamic list head at the end of the array bucket** listPtr = reinterpret_cast<bucket**> (ap); *listPtr = 0; + /* +#ifdef _LSH_DEBUG_ + cout << " numBuckets=" << numBuckets << " numPoints=" << numPoints - numSingletons << " numElements(array) " << numBuckets+numPoints - numSingletons << " " << endl; +#endif + */ + H::tablesPointCount += numPoints - numSingletons; // Return current token return H::t2; // return H::t2 which holds current token [E or T1] } @@ -1537,7 +1655,7 @@ // Retrieval is performed by generating a hash key query_t2 for query point q // We identify the row that t2 is stored in using a secondary hash t1, this row is the entry // point for retrieve_from_core_hashtable_array -#define SKIP_BITS (0xC0000000) +#define SKIP_BITS (0xC0000000U) void G::retrieve_from_core_hashtable_array(Uns32T* p, Uns32T qpos){ Uns32T skip; Uns32T t2; @@ -1554,9 +1672,9 @@ p2 = *p++; skip = (( p1 & SKIP_BITS ) >> SKIP_BITS_RIGHT_SHIFT_LSB) + (( p2 & SKIP_BITS ) >> SKIP_BITS_RIGHT_SHIFT_MSB); if( t2 == H::t2 ){ - add_point_callback(calling_instance, p1 ^ (p1 & SKIP_BITS), qpos, radius); + add_point_callback(calling_instance, p1 & ~SKIP_BITS, qpos, radius); if(skip--){ - add_point_callback(calling_instance, p2 ^ (p2 & SKIP_BITS), qpos, radius); + add_point_callback(calling_instance, p2 & ~SKIP_BITS, qpos, radius); while(skip-- ) add_point_callback(calling_instance, *p++, qpos, radius); } @@ -1590,6 +1708,125 @@ } } + void G::dump_core_row(Uns32T n){ + if(!(n<H::N)){ + printf("ROW OUT OF RANGE:%d (MAX:%d)\n", n, H::N-1); + return; + } + for(Uns32T j = 0 ; j < H::L ; j++ ){ + bucket* bPtr = h[j][n]; + if(bPtr){ + printf("C[%d,%d]", j, n); +#ifdef LSH_LIST_HEAD_COUNTERS + printf("[numBuckets=%d]",bPtr->snext.numBuckets); + if(bPtr->t2&LSH_CORE_ARRAY_BIT) { + dump_core_hashtable_array((Uns32T*)(bPtr->next)); + } + else { + dump_hashtable_row(bPtr->next); + } +#else + dump_hashtable_row(bPtr); +#endif + printf("\n"); + } + } + } + + void G::dump_disk_row(char* filename, Uns32T n){ + int dbfid = unserialize_lsh_header(filename); + if(dbfid<0){ + cerr << "Could not read header from " << filename << endl; + return; + } + FILE* dbFile = 0; + dbFile = fdopen(dbfid, "rb"); + if(!dbFile){ + cerr << "Could not create FILE pointer from file:" << filename << " with fid:" << dbfid << endl; + close(dbfid); + return; + } + + Uns32T x=0,y=0; + + // Seek to hashtable base offset + if(fseek(dbFile, get_serial_hashtable_offset(), SEEK_SET)){ + fclose(dbFile); + error("fSeek error in unserialize_lsh_hashtables_format2"); + } + Uns32T token = 0; + Uns32T pointID; + + // Read the hash tables into core (structure is given in header) + while( x < H::L){ + y=0; + if(fread(&token, sizeof(Uns32T), 1, dbFile) != 1){ + fclose(dbFile); + error("Read error T1","unserialize_lsh_hashtables_format2()"); + } + while(token != O2_SERIAL_TOKEN_ENDTABLE){ + if(token == O2_SERIAL_TOKEN_T1){ + if(fread(&token, sizeof(Uns32T), 1, dbFile) != 1){ + fclose(dbFile); + error("Read error t1","unserialize_lsh_hashtables_format2()"); + } + y=token; + if(y==n){ + printf("D[%d,%d]", x, y); + if(fread(&token, sizeof(Uns32T), 1, dbFile) != 1){ + fclose(dbFile); + error("Read error 2","unserialize_lsh_hashtables_format2()"); + } + printf("[numElements=%d]", token); + if(fread(&token, sizeof(Uns32T), 1, dbFile) != 1){ + fclose(dbFile); + error("Read error 3","unserialize_lsh_hashtables_format2()"); + } + while(!(token==O2_SERIAL_TOKEN_ENDTABLE || token==O2_SERIAL_TOKEN_T1)){ + // Check for T2 token + if(token!=O2_SERIAL_TOKEN_T2){ + printf("t2=%d",token); + fclose(dbFile); + error("State machine error T2 token", "unserialize_hashtable_row_format2()"); + } + // Read t2 value + if(fread(&token, sizeof(Uns32T), 1, dbFile) != 1){ + fclose(dbFile); + error("Read error t2","unserialize_hashtable_row_format2"); + } + if(fread(&pointID, sizeof(Uns32T), 1, dbFile) != 1){ + fclose(dbFile); + error("Read error pointID","unserialize_hashtable_row_format2"); + } + while(!(pointID==O2_SERIAL_TOKEN_ENDTABLE || pointID==O2_SERIAL_TOKEN_T1 || pointID==O2_SERIAL_TOKEN_T2 )){ + printf("(%0X,%u)", token, pointID); + if(fread(&pointID, sizeof(Uns32T), 1, dbFile) != 1){ + fclose(dbFile); + error("Read error H::p","unserialize_hashtable_row_format2"); + } + } + token = pointID; // Copy last found token + } + printf("\n"); + } + else{ // gobble up rest of row + while(!(token==O2_SERIAL_TOKEN_T1 || token==O2_SERIAL_TOKEN_ENDTABLE)){ + if(fread(&token, sizeof(Uns32T), 1, dbFile) != 1){ + fclose(dbFile); + error("Read error 4","unserialize_lsh_hashtables_format2()"); + } + } + } + } + } + if(token==O2_SERIAL_TOKEN_ENDTABLE){ + x++; + } + } + close(dbfid); + } + + void G::dump_core_hashtable_array(Uns32T* p){ Uns32T skip; Uns32T t2; @@ -1601,11 +1838,11 @@ p1 = *p++; p2 = *p++; skip = (( p1 & SKIP_BITS ) >> SKIP_BITS_RIGHT_SHIFT_LSB) + (( p2 & SKIP_BITS ) >> SKIP_BITS_RIGHT_SHIFT_MSB); - printf("(%0x, %0x)", t2, p1 ^ (p1 & SKIP_BITS)); + printf("(%0X, %u)", t2, p1 & ~SKIP_BITS); if(skip--){ - printf("(%0x, %0x)", t2, p2 ^ (p2 & SKIP_BITS)); + printf("(%0X, %u)", t2, p2 & ~SKIP_BITS); while(skip-- ) - printf("(%0x, %0x)", t2, *p++); + printf("(%0X, %u)", t2, *p++); } }while( *p != LSH_CORE_ARRAY_END_ROW_TOKEN ); } @@ -1636,7 +1873,9 @@ void G::serial_retrieve_point_set(char* filename, vector<vector<float> >& vv, ReporterCallbackPtr add_point, void* caller) { int dbfid = serial_open(filename, 0); // open for read - serial_get_header(dbfid); + char* dbheader = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 0);// get database pointer + serial_get_header(dbheader); // read header + serial_munmap(dbheader, O2_SERIAL_HEADER_SIZE); // drop header mmap if((lshHeader->flags & O2_SERIAL_FILEFORMAT2)){ serial_close(dbfid); @@ -1648,23 +1887,32 @@ calling_instance = caller; // class instance variable used in ...bucket_chain_point() add_point_callback = add_point; - SerialElementT *pe = (SerialElementT *)malloc(hashTableSize); for(Uns32T j=0; j<L; j++){ - // read a single hash table for random access - serial_get_table(dbfid, j, pe, hashTableSize); + // memory map a single hash table for random access + char* db = serial_mmap(dbfid, hashTableSize, 0, + align_up(get_serial_hashtable_offset()+j*hashTableSize,get_page_logn())); +#ifdef __CYGWIN__ + // No madvise in CYGWIN +#else + if(madvise(db, hashTableSize, MADV_RANDOM)<0) + error("could not advise local hashtable memory","","madvise"); +#endif + SerialElementT* pe = (SerialElementT*)db ; for(Uns32T qpos=0; qpos<vv.size(); qpos++){ H::compute_hash_functions(vv[qpos]); H::generate_hash_keys(*(g+j),*(r1+j),*(r2+j)); serial_bucket_chain_point(pe+t1*lshHeader->numCols, qpos); // Point to correct row } + serial_munmap(db, hashTableSize); // drop hashtable mmap } - free(pe); serial_close(dbfid); } void G::serial_retrieve_point(char* filename, vector<float>& v, Uns32T qpos, ReporterCallbackPtr add_point, void* caller){ int dbfid = serial_open(filename, 0); // open for read - serial_get_header(dbfid); + char* dbheader = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 0);// get database pointer + serial_get_header(dbheader); // read header + serial_munmap(dbheader, O2_SERIAL_HEADER_SIZE); // drop header mmap if((lshHeader->flags & O2_SERIAL_FILEFORMAT2)){ serial_close(dbfid); @@ -1676,27 +1924,41 @@ calling_instance = caller; add_point_callback = add_point; H::compute_hash_functions(v); - - SerialElementT *pe = (SerialElementT *)malloc(hashTableSize); for(Uns32T j=0; j<L; j++){ - // read a single hash table for random access - serial_get_table(dbfid, j, pe, hashTableSize); + // memory map a single hash table for random access + char* db = serial_mmap(dbfid, hashTableSize, 0, + align_up(get_serial_hashtable_offset()+j*hashTableSize,get_page_logn())); +#ifdef __CYGWIN__ + // No madvise in CYGWIN +#else + if(madvise(db, hashTableSize, MADV_RANDOM)<0) + error("could not advise local hashtable memory","","madvise"); +#endif + SerialElementT* pe = (SerialElementT*)db ; H::generate_hash_keys(*(g+j),*(r1+j),*(r2+j)); serial_bucket_chain_point(pe+t1*lshHeader->numCols, qpos); // Point to correct row - } - free(pe); + serial_munmap(db, hashTableSize); // drop hashtable mmap + } serial_close(dbfid); } void G::serial_dump_tables(char* filename){ int dbfid = serial_open(filename, 0); // open for read - serial_get_header(dbfid); + char* dbheader = serial_mmap(dbfid, O2_SERIAL_HEADER_SIZE, 0);// get database pointer + serial_get_header(dbheader); // read header + serial_munmap(dbheader, O2_SERIAL_HEADER_SIZE); // drop header mmap Uns32T hashTableSize=sizeof(SerialElementT)*lshHeader->numRows*lshHeader->numCols; - SerialElementT *db = (SerialElementT *)malloc(hashTableSize); for(Uns32T j=0; j<L; j++){ - // read a single hash table for random access - serial_get_table(dbfid, j, db, hashTableSize); - SerialElementT *pe = db; + // memory map a single hash table for random access + char* db = serial_mmap(dbfid, hashTableSize, 0, + align_up(get_serial_hashtable_offset()+j*hashTableSize,get_page_logn())); +#ifdef __CYGWIN__ + // No madvise in CYGWIN +#else + if(madvise(db, hashTableSize, MADV_SEQUENTIAL)<0) + error("could not advise local hashtable memory","","madvise"); +#endif + SerialElementT* pe = (SerialElementT*)db ; printf("*********** TABLE %d ***************\n", j); fflush(stdout); int count=0; @@ -1707,7 +1969,6 @@ pe+=lshHeader->numCols; }while(pe<(SerialElementT*)db+lshHeader->numRows*lshHeader->numCols); } - free(db); } void G::serial_bucket_dump(SerialElementT* pe){
--- a/lshlib.h Sat Nov 20 15:32:58 2010 +0000 +++ b/lshlib.h Thu Nov 25 13:42:40 2010 +0000 @@ -11,13 +11,37 @@ #ifndef __LSHLIB_H #define __LSHLIB_H -using namespace std; +#include <vector> +#include <queue> +#include <stdio.h> +#include <stdlib.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <fcntl.h> +#include <string.h> +#include <iostream> +#include <fstream> +#include <math.h> +#include <sys/time.h> +#include <assert.h> +#include <float.h> +#include <signal.h> +#include <time.h> +#include <limits.h> +#include <errno.h> +#ifdef MT19937 +#include "mt19937/mt19937ar.h" +#endif +#include "multiprobe.h" #define IntT int #define LongUns64T long long unsigned #define Uns32T unsigned #define Int32T int #define BooleanT int +#define TRUE 1 +#define FALSE 0 // A big number (>> max # of points) #define INDEX_START_EMPTY 1000000000U @@ -55,7 +79,7 @@ #define O2_INDEX_MAXSTR (256) -#define align_up(x,w) (((x) + ((1<<w)-1)) & ~((1<<w)-1)) +unsigned align_up(unsigned x, unsigned w); #define O2_SERIAL_FUNCTIONS_SIZE (align_up(sizeof(float) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS * O2_SERIAL_MAX_DIM \ + sizeof(float) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS + \ @@ -71,6 +95,7 @@ fclose(dbFile);error("write error in serial_write_format2",TOKENSTR);} //#define LSH_DUMP_CORE_TABLES // set to dump hashtables on load +//#define _LSH_DEBUG_ // turn on debugging information //#define USE_U_FUNCTIONS // set to use partial hashfunction re-use // Backward-compatible CORE ARRAY lsh index @@ -84,8 +109,14 @@ #define LSH_CORE_ARRAY_BIT (0x80000000) // LSH_CORE_ARRAY test bit for list head +#ifndef LSH_MULTI_PROBE_COUNT +#define LSH_MULTI_PROBE_COUNT 1 // How many adjacent hash-buckets to probe in LSH retrieval +#endif + Uns32T get_page_logn(); +using namespace std; + // Disk table entry typedef class SerialElement SerialElementT; class SerialElement { @@ -202,11 +233,13 @@ Uns32T bucketCount; // count of number of point buckets allocated Uns32T pointCount; // count of number of points inserted Uns32T collisionCount; // number of points collided in a hash-table row + Uns32T tablesPointCount; // count of points per hash table on load Uns32T t1; // first hash table key Uns32T t2; // second hash table key Uns32T P; // hash table prime number + Uns32T N; // num rows per table Uns32T C; // num collision per row Uns32T k; // num projections per hash function @@ -217,6 +250,9 @@ float w; // width of hash slots (relative to normalized feature space) float radius;// scaling coefficient for data (1./radius) + MultiProbe* multiProbePtr; // Utility class for handling multi-probe queries + float ** boundaryDistances; // Array of query bucket-boundary-distances per hashtable + void initialize_data_structures(); void initialize_lsh_functions(); void initialize_partial_functions(); @@ -250,6 +286,7 @@ // Interface to hash functions void compute_hash_functions(vector<float>& v); void generate_hash_keys(Uns32T* g, Uns32T* r1, Uns32T* r2); + void generate_multiprobe_keys(Uns32T*g, Uns32T* r1, Uns32T* r2); Uns32T get_t1(){return t1;} // hash-key t1 Uns32T get_t2(){return t2;} // hash-key t2 }; @@ -267,6 +304,8 @@ void release_lock(int fd); int serial_create(char* filename, Uns32T FMT); int serial_create(char* filename, float binWidth, Uns32T nTables, Uns32T nRows, Uns32T nCols, Uns32T k, Uns32T d, Uns32T FMT); + char* serial_mmap(int dbfid, Uns32T sz, Uns32T w, off_t offset = 0); + void serial_munmap(char* db, Uns32T N); int serial_open(char* filename,int writeFlag); void serial_close(int dbfid); @@ -306,10 +345,7 @@ float* get_serial_hashfunction_base(char* db); SerialElementT* get_serial_hashtable_base(char* db); Uns32T get_serial_hashtable_offset(); // Size of SerialHeader + HashFunctions - SerialHeaderT* serial_get_header(int fd); - void serial_write_header(int fd, SerialHeaderT *header); - void serial_get_table(int, int, void *, size_t); - void serial_write_table(int, int, void *, size_t); + SerialHeaderT* serial_get_header(char* db); SerialHeaderT* lshHeader; // Core Retrieval/Inspections Functions @@ -354,7 +390,8 @@ float get_mean_collision_rate(){ return (float) pointCount / bucketCount ; } char* get_indexName(){return indexName;} void dump_hashtables(); - + void dump_core_row(Uns32T n); + void dump_disk_row(char*, Uns32T n); }; typedef class G LSH;
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/multiprobe.cpp Thu Nov 25 13:42:40 2010 +0000 @@ -0,0 +1,263 @@ +/* + * MultiProbe C++ class + * + * Given a vector of LSH boundary distances for a query, + * perform lookup by probing nearby hash-function locations + * + * Implementation using C++ STL + * + * Reference: + * Qin Lv, William Josephson, Zhe Wang, Moses Charikar and Kai Li, + * "Multi-Probe LSH: Efficient Indexing for High-Dimensional Similarity + * Search", Proc. Intl. Conf. VLDB, 2007 + * + * + * Copyright (C) 2009 Michael Casey, Dartmouth College, All Rights Reserved + * License: GNU Public License 2.0 + * + */ + +#include "multiprobe.h" + +//#define _TEST_MP_LSH + +bool operator> (const min_heap_element& a, const min_heap_element& b){ + return a.score > b.score; +} + +bool operator< (const min_heap_element& a, const min_heap_element& b){ + return a.score < b.score; +} + +bool operator>(const sorted_distance_functions& a, const sorted_distance_functions& b){ + return a.first > b.first; +} + +bool operator<(const sorted_distance_functions& a, const sorted_distance_functions& b){ + return a.first < b.first; +} + +MinHeapElement::MinHeapElement(perturbation_set a, float s): + perturbs(a), + score(s) +{ + +} + +MinHeapElement::~MinHeapElement(){;} + +MultiProbe::MultiProbe(): + minHeap(0), + outSets(0), + distFuns(0), + numHashBoundaries(0) +{ + +} + +MultiProbe::~MultiProbe(){ + cleanup(); +} + +void MultiProbe::initialize(){ + minHeap = new min_heap_of_perturbation_set(); + outSets = new min_heap_of_perturbation_set(); +} + +void MultiProbe::cleanup(){ + delete minHeap; + minHeap = 0; + delete outSets; + outSets = 0; + delete distFuns; + distFuns = 0; +} + +size_t MultiProbe::size(){ + return outSets->size(); +} + +bool MultiProbe::empty(){ + return !outSets->size(); +} + + +void MultiProbe::generatePerturbationSets(vector<float>& x, unsigned T){ + cleanup(); // Make re-entrant + initialize(); + makeSortedDistFuns(x); + algorithm1(T); +} + +// overloading to support efficient array use without initial copy +void MultiProbe::generatePerturbationSets(float* x, unsigned N, unsigned T){ + cleanup(); // Make re-entrant + initialize(); + makeSortedDistFuns(x, N); + algorithm1(T); +} + +// Generate the optimal T perturbation sets for current query +// pre-conditions: +// an LSH structure was initialized and passed to constructor +// a query vector was passed to lsh->compute_hash_functions() +// the query-to-boundary distances are stored in x[hashFunIndex] +// +// post-conditions: +// generates an ordered list of perturbation sets (stored as an array of sets) +// these are indexes into pi_j=(i,delta) pairs representing x_i(delta) in sort order z_j +// data structures are cleared and reset to zeros thereby making them re-entrant +// +void MultiProbe::algorithm1(unsigned T){ + perturbation_set ai,as,ae; + float ai_score; + ai.insert(0); // Initialize for this query + minHeap->push(min_heap_element(ai, score(ai))); // unique instance stored in mhe + + min_heap_element mhe = minHeap->top(); + + if(T>distFuns->size()) + T = distFuns->size(); + for(unsigned i = 0 ; i != T ; i++ ){ + do{ + mhe = minHeap->top(); + ai = mhe.perturbs; + ai_score = mhe.score; + minHeap->pop(); + as=ai; + shift(as); + minHeap->push(min_heap_element(as, score(as))); + ae=ai; + expand(ae); + minHeap->push(min_heap_element(ae, score(ae))); + }while(!valid(ai)); + outSets->push(mhe); // Ordered list of perturbation sets + } +} + +void MultiProbe::dump(perturbation_set a){ + perturbation_set::iterator it = a.begin(); + while(it != a.end()){ + cout << "[" << (*distFuns)[*it].second.first << "," << (*distFuns)[*it].second.second << "]" << " " + << (*distFuns)[*it].first << *it << ", "; + it++; + } + cout << "(" << score(a) << ")"; + cout << endl; +} + +// Given the set a, add 1 to last element of the set +inline perturbation_set& MultiProbe::shift(perturbation_set& a){ + perturbation_set::iterator it = a.end(); + int val = *(--it) + 1; + a.erase(it); + a.insert(it,val); + return a; +} + +// Given the set a, add a new element one greater than the max +inline perturbation_set& MultiProbe::expand(perturbation_set& a){ + perturbation_set::reverse_iterator ri = a.rbegin(); + a.insert(*ri+1); + return a; +} + +// Take the list of distances (x) assuming list len is 2M and +// delta = (-1)^i, i = { 0 .. 2M-1 } +void MultiProbe::makeSortedDistFuns(vector<float>& x){ + numHashBoundaries = x.size(); // x.size() == 2M + delete distFuns; + distFuns = new std::vector<sorted_distance_functions>(numHashBoundaries); + for(unsigned i = 0; i != numHashBoundaries ; i++ ) + (*distFuns)[i] = make_pair(x[i], make_pair(i, i%2?1:-1)); + // SORT + sort( distFuns->begin(), distFuns->end() ); +} + +// Float array version of above +void MultiProbe::makeSortedDistFuns(float* x, unsigned N){ + numHashBoundaries = N; // x.size() == 2M + delete distFuns; + distFuns = new std::vector<sorted_distance_functions>(numHashBoundaries); + for(unsigned i = 0; i != numHashBoundaries ; i++ ) + (*distFuns)[i] = make_pair(x[i], make_pair(i, i%2?1:-1)); + // SORT + sort( distFuns->begin(), distFuns->end() ); +} + +// For a given perturbation set, the score is the +// sum of squares of corresponding distances in x +float MultiProbe::score(perturbation_set& a){ + //assert(!a.empty()); + float score = 0.0, tmp; + perturbation_set::iterator it; + it = a.begin(); + do{ + tmp = (*distFuns)[*it].first; + score += tmp*tmp; + }while( ++it != a.end() ); + return score; +} + +// A valid set must have at most one +// of the two elements {j, 2M + 1 - j} for every j +// +// A perturbation set containing an element > 2M is invalid +bool MultiProbe::valid(perturbation_set& a){ + int j; + perturbation_set::iterator it = a.begin(); + while( it != a.end() ){ + j = *it; + it++; + if( ( (unsigned)j > numHashBoundaries ) || ( a.find( numHashBoundaries - j - 1 ) != a.end() ) ) + return false; + } + return true; +} + +int MultiProbe::getIndex(perturbation_set::iterator it){ + return (*distFuns)[*it].second.first; +} + +int MultiProbe::getBoundary(perturbation_set::iterator it){ + return (*distFuns)[*it].second.second; +} + +// copy return next perturbation_set +perturbation_set MultiProbe::getNextPerturbationSet(){ + perturbation_set s = outSets->top().perturbs; + outSets->pop(); + return s; +} + +// Test routine: generate 100 random boundary distance pairs +// call generatePerturbationSets on these distances +// dump output for inspection +#ifdef _TEST_MP_LSH +int main(const int argc, const char* argv[]){ + int N_SAMPS = 100; // Number of random samples + int W = 4; // simulated hash-bucket size + int N_ITER = 100; // How many re-entrant iterations + unsigned T = 10; // Number of multi-probe sets to generate + + MultiProbe mp= MultiProbe(); + vector<float> x(N_SAMPS); + + srand((unsigned)time(0)); + + // Test re-entrance on single instance + for(int j = 0; j< N_ITER ; j++){ + cout << "********** ITERATION " << j << " **********" << endl; + cout.flush(); + for (int i = 0 ; i != x.size()/2 ; i++ ){ + x[2*i] = W*(rand()/(RAND_MAX+1.0)); + x[2*i+1] = W - x[2*i]; + } + // Generate multi-probe sets + mp.generatePerturbationSets(x, T); + // Output contents of multi-probe sets + while(!mp.empty()) + mp.dump(mp.getNextPerturbationSet()); + } +} +#endif
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/multiprobe.h Thu Nov 25 13:42:40 2010 +0000 @@ -0,0 +1,126 @@ +/* + * MultiProbe C++ class + * + * Given a vector of LSH boundary distances for a query, + * perform lookup by probing nearby hash-function locations + * + * Implementation using C++ STL + * + * Reference: + * Qin Lv, William Josephson, Zhe Wang, Moses Charikar and Kai Li, + * "Multi-Probe LSH: Efficient Indexing for High-Dimensional Similarity + * Search", Proc. Intl. Conf. VLDB, 2007 + * + * + * Copyright (C) 2009 Michael Casey, Dartmouth College, All Rights Reserved + * License: GNU Public License 2.0 + * + */ + +#ifndef __LSH_MULTIPROBE_ +#define __LSH_MULTIPROBE_ + +#include <functional> +#include <queue> +#include <vector> +#include <set> +#include <algorithm> +#include <iostream> + +using namespace std; + +typedef set<int > perturbation_set ; + +typedef class MinHeapElement{ +public: + perturbation_set perturbs; + float score; + MinHeapElement(perturbation_set a, float s); + virtual ~MinHeapElement(); +} min_heap_element; + +typedef priority_queue<min_heap_element, + vector<min_heap_element>, + greater<min_heap_element> + > min_heap_of_perturbation_set ; + +typedef pair<float, pair<int, int> > sorted_distance_functions ; + +class MultiProbe{ +protected: + min_heap_of_perturbation_set* minHeap; + min_heap_of_perturbation_set* outSets; + vector<sorted_distance_functions>* distFuns; + unsigned numHashBoundaries; + + // data structure initialization and cleanup + void initialize(); + void cleanup(); + + // perturbation set operations + perturbation_set& shift(perturbation_set&); + perturbation_set& expand(perturbation_set&); + float score(perturbation_set&); + bool valid(perturbation_set&); + void makeSortedDistFuns(vector<float> &); + void makeSortedDistFuns(float* x, unsigned N); + + // perturbation set generation algorithm + void algorithm1(unsigned T); + +public: + MultiProbe(); + ~MultiProbe(); + + // generation of perturbation sets + void generatePerturbationSets(vector<float>& vectorOfBounaryDistances, unsigned numSetsToGenerate); + void generatePerturbationSets(float* arrayOfBoundaryDistances, unsigned numDistances, unsigned numSetsToGenerate); + perturbation_set getNextPerturbationSet(); + void dump(perturbation_set); + size_t size(); // Number of perturbation sets are in the output queue + bool empty(); // predicate for empty MultiProbe set + int getIndex(perturbation_set::iterator it); // return index of hash function for given set entry + int getBoundary(perturbation_set::iterator it); // return boundary {-1,+1} for given set entry +}; + + +/* NOTES: + +Reference: +Qin Lv, William Josephson, Zhe Wang, Moses Charikar and Kai Li, +"Multi-Probe LSH: Efficient Indexing for High-Dimensional Similarity +Search", Proc. Intl. Conf. VLDB, 2007 + + i = 1..M (number of hash functions used) + f_i(q) = a_i * q + b_i // projection to real line + h_i(q) = floor( ( a_i * q + b_i ) / w ) // hash slot + delta \in {-1, +1} + x_i( delta ) = distance of query to left or right boundary + Delta = [delta_1 delta_2 ... delta_M] + score(Delta) = sum_{i=1}_M x_i(delta_i).^2 + + z = sort(x(delta_i), increasing) // i = 1..2M delta={+1,-1} + p_j = (i, delta) if z_j == x_i(delta) + + A_k is index into p_j + + Multi-probe algorithm (after Lv et al. 2007) + --------------------------------------------- + + A0 = {1} + minHeap.insert(A0, score(A0)) + + for i = 1 to T do // T = number of perturbation sets + repeat + Ai = minHeap.extractMin() + As = shift(Ai) + minHeap.insert(As, score(As)) + Ae = expand(Ai) + minHeap.insert(Ae, score(Ae)) + until valid(Ai) + output Ai + end for + + */ + +#endif // __LSH_MULTIPROBE_