view lshlib.h @ 770:c54bc2ffbf92 tip

update tags
author convert-repo
date Fri, 16 Dec 2011 11:34:01 +0000
parents 77f7bc99dfd6
children
line wrap: on
line source
// lshlib.h - a library for locality sensitive hashtable insertion and retrieval
//
// Author: Michael Casey
// Copyright (c) 2008 Michael Casey, All Rights Reserved

/*  		    GNU GENERAL PUBLIC LICENSE
		       Version 2, June 1991
		       See LICENSE.txt
*/

#ifndef __LSHLIB_H
#define __LSHLIB_H

#include <vector>
#include <queue>
#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <string.h>
#include <iostream>
#include <fstream>
#include <math.h>
#include <sys/time.h>
#include <assert.h>
#include <float.h>
#include <signal.h>
#include <time.h>
#include <limits.h>
#include <errno.h>
#ifdef MT19937
#include "mt19937/mt19937ar.h"
#endif
#include "multiprobe.h"

#define IntT int
#define LongUns64T long long unsigned
#define Uns32T unsigned
#define Int32T int
#define BooleanT int
#define TRUE 1
#define FALSE 0

// A big number (>> max #  of points)
#define INDEX_START_EMPTY 1000000000U

// 4294967291 = 2^32-5
#define UH_PRIME_DEFAULT 4294967291U

// 2^29
#define MAX_HASH_RND 536870912U

// 2^32-1
#define TWO_TO_32_MINUS_1 4294967295U

#define O2_SERIAL_VERSION 1 // Sync with SVN version
#define O2_SERIAL_HEADER_SIZE sizeof(SerialHeaderT)
#define O2_SERIAL_ELEMENT_SIZE sizeof(SerialElementT)
#define O2_SERIAL_MAX_TABLES (200)
#define O2_SERIAL_MAX_ROWS (1000000000)
#define O2_SERIAL_MAX_COLS (1000000)
#define O2_SERIAL_MAX_DIM (20000)
#define O2_SERIAL_MAX_FUNS (100)
#define O2_SERIAL_MAX_BINWIDTH (200)
#define O2_SERIAL_MAXFILESIZE (4000000000UL)

// Flags for Serial Header
#define O2_SERIAL_FILEFORMAT1 (0x1U)       // Optimize disk format for on-disk search
#define O2_SERIAL_FILEFORMAT2 (0x2U)       // Optimize disk format for in-core search
#define O2_SERIAL_COREFORMAT1 (0x4U)
#define O2_SERIAL_COREFORMAT2 (0x8U)

// Flags for serialization fileformat2: use high 3 bits of Uns32T
#define O2_SERIAL_TOKEN_T1 (0xFFFFFFFCU)
#define O2_SERIAL_TOKEN_T2 (0xFFFFFFFDU)
#define O2_SERIAL_TOKEN_ENDTABLE (0xFFFFFFFEU)

#define O2_INDEX_MAXSTR (256)

unsigned align_up(unsigned x, unsigned w);

#define O2_SERIAL_FUNCTIONS_SIZE (align_up(sizeof(float) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS * O2_SERIAL_MAX_DIM \
+ sizeof(float) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS + \
+ sizeof(Uns32T) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS * 2 \
+ O2_SERIAL_HEADER_SIZE,get_page_logn()))

#define O2_SERIAL_MAX_LSH_SIZE (O2_SERIAL_ELEMENT_SIZE * O2_SERIAL_MAX_TABLES \
			    * O2_SERIAL_MAX_ROWS * O2_SERIAL_MAX_COLS + O2_SERIAL_FUNCTIONS_SIZE)

#define O2_SERIAL_MAGIC ('o'|'2'<<8|'l'<<16|'s'<<24)

#define WRITE_UNS32(VAL, TOKENSTR) if( fwrite(VAL, sizeof(Uns32T), 1, dbFile) != 1 ){\
  fclose(dbFile);error("write error in serial_write_format2",TOKENSTR);}	

//#define LSH_DUMP_CORE_TABLES  // set to dump hashtables on load
//#define _LSH_DEBUG_             // turn on debugging information
//#define USE_U_FUNCTIONS       // set to use partial hashfunction re-use

// Backward-compatible CORE ARRAY lsh index
#define LSH_CORE_ARRAY  // Set to use arrays for hashtables rather than linked-lists
#define LSH_LIST_HEAD_COUNTERS // Enable counters in hashtable list heads

// Critical path logic
#if defined LSH_CORE_ARRAY && !defined LSH_LIST_HEAD_COUNTERS
#define LSH_LIST_HEAD_COUNTERS
#endif

#define LSH_CORE_ARRAY_BIT (0x80000000) //  LSH_CORE_ARRAY test bit for list head

#ifndef LSH_MULTI_PROBE_COUNT
#define LSH_MULTI_PROBE_COUNT 1 // How many adjacent hash-buckets to probe in LSH retrieval
#endif

Uns32T get_page_logn();

using namespace std;

// Disk table entry
typedef class SerialElement SerialElementT;
class SerialElement {
public:
  Uns32T hashValue;
  Uns32T pointID;
  
  SerialElement(Uns32T h, Uns32T pID): 
    hashValue(h), 
    pointID(pID){}
};

// Disk header
typedef class SerialHeader SerialHeaderT;
class SerialHeader {
public:
  Uns32T lshMagic;    // unique identifier for file header
  float  binWidth;    // hash-function bin width
  Uns32T numTables;   // number of hash tables in file
  Uns32T numRows;     // size of each hash table
  Uns32T numCols;     // max collisions in each hash table
  Uns32T elementSize; // size of a hash bucket
  Uns32T version;     // version number of file format
  Uns32T size;        // total size of database (bytes)
  Uns32T flags;       // 32 bits of useful information
  Uns32T dataDim;     // vector dimensionality
  Uns32T numFuns;     // number of independent hash functions
  float radius;       // 32-bit floating point radius
  Uns32T maxp;        // number of unique IDs in the database
  unsigned long long size_long; // long version of size
  Uns32T pointCount;     // number of points in the database

  SerialHeader();
  SerialHeader(float W, Uns32T L, Uns32T N, Uns32T C, Uns32T k, Uns32T d, float radius, Uns32T p, Uns32T FMT, Uns32T pointCount);

  float get_binWidth(){return binWidth;}  
  Uns32T get_numTables(){return numTables;}
  Uns32T get_numRows(){return numRows;}
  Uns32T get_numCols(){return numCols;}
  Uns32T get_elementSize(){return elementSize;}
  Uns32T get_version(){return version;}
  Uns32T get_flags(){return flags;}
  unsigned long long get_size(){return size_long;}
  Uns32T get_dataDim(){return dataDim;}
  Uns32T get_numFuns(){return numFuns;}
  Uns32T get_maxp(){return maxp;}
  Uns32T get_pointCount(){return pointCount;}
};

#define IFLAG 0xFFFFFFFF

// Point-set collision bucket (sbucket). 
// sbuckets form a collision chain that identifies PointIDs falling in the same locale.
// sbuckets are chained from a bucket containing the collision list's t2 identifier
class sbucket {
  friend class bucket;
  friend class H;
  friend class G;

 public:
  class sbucket* snext;
  unsigned int pointID;

  sbucket(){
    snext=0;
    pointID=IFLAG;
  }
  ~sbucket(){delete snext;}
  sbucket* get_snext(){return snext;}
};

// bucket structure for a linked list of locales that collide with the same hash value t1
// different buckets represent different locales, collisions within a locale are chained
// in sbuckets
class bucket {
  friend class H;
  friend class G;
  bucket* next;
  union {
    sbucket* ptr;
    Uns32T numBuckets;
  } snext;
 public:
  unsigned int t2;  
  bucket(){
    next=0;
    snext.ptr=0;
    t2=IFLAG;
  }
  ~bucket(){delete next;delete snext.ptr;}
  bucket* get_next(){return next;}
};


// The hash_functions for locality-sensitive hashing
class H{
  friend class G;
 private:

  float *** A; // m x k x d random projectors from R^d->R^k
  float ** b;  // m x k uniform additive constants

  Uns32T ** g; // L x k random hash projections \in Z^k    
  Uns32T** r1;     // random ints for hashing
  Uns32T** r2;     // random ints for hashing

  bucket*** h;      // The LSH hash tables

  bool use_u_functions; // flag to optimize computation of hashes
#ifdef USE_U_FUNCTIONS
  vector<vector<Uns32T> > uu; // Storage for m patial hash evaluations ( g_j = [u_a,u_b] )
#endif
  Uns32T maxp; // highest pointID stored in database
  Uns32T bucketCount;  // count of number of point buckets allocated
  Uns32T pointCount;    // count of number of points inserted
  Uns32T collisionCount; // number of points collided in a hash-table row
  Uns32T tablesPointCount; // count of points per hash table on load

  Uns32T t1;       // first hash table key
  Uns32T t2;       // second hash table key
  Uns32T P;        // hash table prime number


  Uns32T N; // num rows per table
  Uns32T C; // num collision per row
  Uns32T k; // num projections per hash function
  Uns32T m; // ~sqrt num hash tables
  Uns32T L; // L = m*(m-1)/2, conversely, m = (1 + sqrt(1 + 8.0*L)) / 2.0
  Uns32T d; // dimensions
  Uns32T p; // current point
  float w;     // width of hash slots (relative to normalized feature space)
  float radius;// scaling coefficient for data (1./radius)

  MultiProbe* multiProbePtr; // Utility class for handling multi-probe queries
  float ** boundaryDistances; // Array of query bucket-boundary-distances per hashtable

  void initialize_data_structures();
  void initialize_lsh_functions();
  void initialize_partial_functions();
  void __bucket_insert_point(bucket*);
  void __sbucket_insert_point(sbucket*);
  bucket** get_pointer_to_bucket_linked_list(bucket* rowPtr);
  Uns32T computeProductModDefaultPrime(Uns32T*,Uns32T*,IntT);
  Uns32T randr();
  float randn();
  float ranf();
  bucket** get_bucket(int j);
  void error(const char* a, const char* b = "", const char *sysFunc = 0);

 public:

  H();
  H(Uns32T k, Uns32T m, Uns32T d, Uns32T N, Uns32T C, float w, float r);
  virtual ~H();

  float get_w(){return w;}
  float get_radius(){return radius;}  
  Uns32T get_numRows(){return N;}
  Uns32T get_numCols(){return C;}
  Uns32T get_numFuns(){return k;}
  Uns32T get_numTables(){return L;}
  Uns32T get_dataDim(){return d;}
  Uns32T get_maxp(){return maxp;}  

  Uns32T bucket_insert_point(bucket**);

  // Interface to hash functions
  void compute_hash_functions(vector<float>& v);
  void generate_hash_keys(Uns32T* g, Uns32T* r1, Uns32T* r2);
  void generate_multiprobe_keys(Uns32T*g, Uns32T* r1, Uns32T* r2);
  Uns32T get_t1(){return t1;} // hash-key t1
  Uns32T get_t2(){return t2;} // hash-key t2
};

// Typedef for point-reporting callback function. Used to collect points during LSH retrieval
typedef void (*ReporterCallbackPtr)(void* objPtr, Uns32T pointID, Uns32T queryIndex, float squaredDistance);

// Interface for indexing and retrieval
class G: public H{
 private:
  char* indexName;

  // LSH serial data structure file handling
  void get_lock(int fd, bool exclusive);
  void release_lock(int fd);
  int serial_create(char* filename, Uns32T FMT);
  int serial_create(char* filename, float binWidth, Uns32T nTables, Uns32T nRows, Uns32T nCols, Uns32T k, Uns32T d, Uns32T FMT);
  char* serial_mmap(int dbfid, Uns32T sz, Uns32T w, off_t offset = 0);
  void serial_munmap(char* db, Uns32T N);
  int serial_open(char* filename,int writeFlag);
  void serial_close(int dbfid);

  // Function to write hashfunctions to disk
  int serialize_lsh_hashfunctions(int fid);

  // Functions to write hashtables to disk in format1 (optimized for on-disk retrieval)
  int serialize_lsh_hashtables_format1(int fid, int merge);
  void serial_write_hashtable_row_format1(SerialElementT*& pe, bucket* h, Uns32T& colCount);
  void serial_write_element_format1(SerialElementT*& pe, sbucket* sb, Uns32T t2, Uns32T& colCount);
  void serial_merge_hashtable_row_format1(SerialElementT* pr, bucket* h, Uns32T& colCount);
  void serial_merge_element_format1(SerialElementT* pe, sbucket* sb, Uns32T t2, Uns32T& colCount);
  int serial_can_merge(Uns32T requestedFormat); // Test to see whether core and on-disk structures are compatible

  // Functions to write hashtables to disk in format2 (optimized for in-core retrieval)
  int serialize_lsh_hashtables_format2(FILE* dbFile, int merge);
  void serial_write_hashtable_row_format2(FILE* dbFile, bucket* h, Uns32T& colCount);
  void serial_write_element_format2(FILE* dbFile, sbucket* sb, Uns32T& colCount);
  Uns32T count_buckets_and_points_hashtable_row(bucket* bPtr);
  Uns32T count_points_hashtable_row(bucket* bPtr);

  // Functions to read serial header and hash functions (format1 and format2)
  int unserialize_lsh_header(char* filename);            // read lsh header from disk into core
  void unserialize_lsh_functions(int fid);               // read the lsh hash functions into core

  // Functions to read hashtables in format1
  void unserialize_lsh_hashtables_format1(int fid);       // read FORMAT1 hash tables into core (disk format)
  void unserialize_hashtable_row_format1(SerialElementT* pe, bucket** b); // read lsh hash table row into core

  // Functions to read hashtables in format2
  void unserialize_lsh_hashtables_format2(FILE* dbFile, bool forMerge = 0);
  Uns32T unserialize_hashtable_row_format2(FILE* dbFile, bucket** b, Uns32T token=0); // to dynamic linked list
  Uns32T unserialize_hashtable_row_to_array(FILE* dbFile, bucket** b, Uns32T numElements); // to core array

  // Helper functions
  void serial_print_header(Uns32T requestedFormat);
  float* get_serial_hashfunction_base(char* db);
  SerialElementT* get_serial_hashtable_base(char* db);  
  Uns32T get_serial_hashtable_offset();                   // Size of SerialHeader + HashFunctions
  SerialHeaderT* serial_get_header(char* db);
  SerialHeaderT* lshHeader;

  // Core Retrieval/Inspections Functions
  void bucket_chain_point(bucket* p, Uns32T qpos);
  void sbucket_chain_point(sbucket* p, Uns32T qpos);  
  void dump_hashtable_row(bucket* p);
  void dump_core_hashtable_array(Uns32T* p);

  // Serial (Format 1) Retrieval/Inspection Functions
  void serial_bucket_chain_point(SerialElementT* pe, Uns32T qpos);
  void serial_bucket_dump(SerialElementT* pe);

  // Core ARRAY Retrieval Functions
  void retrieve_from_core_hashtable_array(Uns32T* p, Uns32T qpos);


  // Callback Function for point reporting
  void* calling_instance; // store calling object instance for member-function callback
  ReporterCallbackPtr add_point_callback; // Pointer to the callback function

 public:
  G(char* lshFile, bool lshInCore = false); // unserialize constructor
  G(float w, Uns32T k,Uns32T m, Uns32T d, Uns32T N, Uns32T C, float r); // core constructor
  virtual ~G();

  Uns32T insert_point(vector<float>&, Uns32T pointID);
  void insert_point_set(vector<vector<float> >& vv, Uns32T basePointID);

  // point retrieval from core
  void retrieve_point(vector<float>& v, Uns32T qpos, ReporterCallbackPtr, void* me=NULL);
  // point set retrieval from core
  void retrieve_point_set(vector<vector<float> >& vv, ReporterCallbackPtr, void* me=NULL);
  // serial point set retrieval
  void serial_retrieve_point_set(char* filename, vector<vector<float> >& vv, ReporterCallbackPtr, void* me=NULL);
  // serial point retrieval
  void serial_retrieve_point(char* filename, vector<float>& vv, Uns32T qpos, ReporterCallbackPtr, void* me=NULL);

  void serialize(char* filename, Uns32T serialFormat = O2_SERIAL_FILEFORMAT1); // write hashfunctions and hashtables to disk

  SerialHeaderT* get_lshHeader(){return lshHeader;}
  void serial_dump_tables(char* filename);
  float get_mean_collision_rate(){ return (float) pointCount / bucketCount ; }
  char* get_indexName(){return indexName;}
  void dump_hashtables();
  void dump_core_row(Uns32T n);
  void dump_disk_row(char*, Uns32T n);
};

typedef class G LSH;

 

#endif