diff lshlib.h @ 292:d9a88cfd4ab6

Completed merge of lshlib back to current version of the trunk.
author mas01mc
date Tue, 29 Jul 2008 22:01:17 +0000
parents
children 9fd5340faffd
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/lshlib.h	Tue Jul 29 22:01:17 2008 +0000
@@ -0,0 +1,336 @@
+// lshlib.h - a library for locality sensitive hashtable insertion and retrieval
+//
+// Author: Michael Casey
+// Copyright (c) 2008 Michael Casey, All Rights Reserved
+
+/*  		    GNU GENERAL PUBLIC LICENSE
+		       Version 2, June 1991
+		       See LICENSE.txt
+*/
+
+#ifndef __LSHLIB_H
+#define __LSHLIB_H
+
+#include <vector>
+#include <queue>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+#include <fcntl.h>
+#include <string.h>
+#include <iostream>
+#include <fstream>
+#include <math.h>
+#include <sys/time.h>
+#include <assert.h>
+#include <float.h>
+#include <signal.h>
+#include <time.h>
+#include <limits.h>
+#include <errno.h>
+#ifdef MT19937
+#include "mt19937/mt19937ar.h"
+#endif
+
+#define IntT int
+#define LongUns64T long long unsigned
+#define Uns32T unsigned
+#define Int32T int
+#define BooleanT int
+#define TRUE 1
+#define FALSE 0
+
+// A big number (>> max #  of points)
+#define INDEX_START_EMPTY 1000000000U
+
+// 4294967291 = 2^32-5
+#define UH_PRIME_DEFAULT 4294967291U
+
+// 2^29
+#define MAX_HASH_RND 536870912U
+
+// 2^32-1
+#define TWO_TO_32_MINUS_1 4294967295U
+
+#define O2_SERIAL_VERSION 1 // Sync with SVN version
+#define O2_SERIAL_HEADER_SIZE sizeof(SerialHeaderT)
+#define O2_SERIAL_ELEMENT_SIZE sizeof(SerialElementT)
+#define O2_SERIAL_MAX_TABLES (200)
+#define O2_SERIAL_MAX_ROWS (1000000)
+#define O2_SERIAL_MAX_COLS (1000)
+#define O2_SERIAL_MAX_DIM (2000)
+#define O2_SERIAL_MAX_FUNS (100)
+#define O2_SERIAL_MAX_BINWIDTH (200)
+
+// Flags for Serial Header
+#define O2_SERIAL_FILEFORMAT1 (0x1U)       // Optimize for on-disk search
+#define O2_SERIAL_FILEFORMAT2 (0x2U)       // Optimize for in-core search
+
+// Flags for serialization fileformat2: use high 3 bits of Uns32T
+#define O2_SERIAL_FLAGS_T1_BIT (0x80000000U)
+#define O2_SERIAL_FLAGS_T2_BIT (0x40000000U)
+#define O2_SERIAL_FLAGS_END_BIT (0x20000000U)
+
+unsigned align_up(unsigned x, unsigned w);
+
+#define O2_SERIAL_FUNCTIONS_SIZE (align_up(sizeof(float) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS * O2_SERIAL_MAX_DIM \
++ sizeof(float) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS + \
++ sizeof(Uns32T) * O2_SERIAL_MAX_TABLES * O2_SERIAL_MAX_FUNS * 2 \
++ O2_SERIAL_HEADER_SIZE,get_page_logn()))
+
+#define O2_SERIAL_MAX_LSH_SIZE (O2_SERIAL_ELEMENT_SIZE * O2_SERIAL_MAX_TABLES \
+			    * O2_SERIAL_MAX_ROWS * O2_SERIAL_MAX_COLS + O2_SERIAL_FUNCTIONS_SIZE)
+
+#define O2_SERIAL_MAGIC ('o'|'2'<<8|'l'<<16|'s'<<24)
+
+using namespace std;
+
+Uns32T get_page_logn();
+
+// Disk table entry
+typedef class SerialElement SerialElementT;
+class SerialElement {
+public:
+  Uns32T hashValue;
+  Uns32T pointID;
+  
+  SerialElement(Uns32T h, Uns32T pID): 
+    hashValue(h), 
+    pointID(pID){}
+};
+
+// Disk header
+typedef class SerialHeader SerialHeaderT;
+class SerialHeader {
+public:
+  Uns32T lshMagic;    // unique identifier for file header
+  float  binWidth;    // hash-function bin width
+  Uns32T numTables;   // number of hash tables in file
+  Uns32T numRows;     // size of each hash table
+  Uns32T numCols;     // max collisions in each hash table
+  Uns32T elementSize; // size of a hash bucket
+  Uns32T version;     // version number of file format
+  Uns32T size;        // total size of database (bytes)
+  Uns32T flags;       // 32 bits of useful information
+  Uns32T dataDim;     // vector dimensionality
+  Uns32T numFuns;     // number of independent hash functions
+  float radius;       // 32-bit floating point radius
+  Uns32T maxp;        // number of unique IDs in the database
+  Uns32T value_14;    // spare value
+  Uns32T value_15;    // spare value
+  Uns32T value_16;    // spare value
+
+  SerialHeader();
+  SerialHeader(float W, Uns32T L, Uns32T N, Uns32T C, Uns32T k, Uns32T d, float radius, Uns32T p, Uns32T FMT);
+
+  float get_binWidth(){return binWidth;}  
+  Uns32T get_numTables(){return numTables;}
+  Uns32T get_numRows(){return numRows;}
+  Uns32T get_numCols(){return numCols;}
+  Uns32T get_elementSize(){return elementSize;}
+  Uns32T get_version(){return version;}
+  Uns32T get_flags(){return flags;}
+  Uns32T get_size(){return size;}
+  Uns32T get_dataDim(){return dataDim;}
+  Uns32T get_numFuns(){return numFuns;}
+  Uns32T get_maxp(){return maxp;}
+};
+
+#define IFLAG 0xFFFFFFFF
+
+// Point-set collision bucket (sbucket). 
+// sbuckets form a collision chain that identifies PointIDs falling in the same locale.
+// sbuckets are chained from a bucket containing the collision list's t2 identifier
+class sbucket {
+  friend class bucket;
+  friend class H;
+  friend class G;
+
+ public:
+  class sbucket* snext;
+  unsigned int pointID;
+
+  sbucket(){
+    snext=0;
+    pointID=IFLAG;
+  }
+  ~sbucket(){delete snext;}
+  sbucket* get_snext(){return snext;}
+};
+
+// bucket structure for a linked list of locales that collide with the same hash value t1
+// different buckets represent different locales, collisions within a locale are chained
+// in sbuckets
+class bucket {
+  friend class H;
+  friend class G;
+  bucket* next;
+  sbucket* snext;
+ public:
+  unsigned int t2;  
+  bucket(){
+    next=0;
+    snext=0;
+    t2=IFLAG;
+  }
+  ~bucket(){delete next;delete snext;}
+  bucket* get_next(){return next;}
+};
+
+
+// The hash_functions for locality-sensitive hashing
+class H{
+  friend class G;
+ private:
+  bucket*** h;      // hash functions
+  Uns32T** r1;     // random ints for hashing
+  Uns32T** r2;     // random ints for hashing
+  Uns32T t1;       // first hash table key
+  Uns32T t2;       // second hash table key
+  Uns32T P;        // hash table prime number
+  bool use_u_functions; // flag to optimize computation of hashes
+  Uns32T bucketCount;  // count of number of point buckets allocated
+  Uns32T pointCount;    // count of number of points inserted
+
+  void __initialize_data_structures();
+  void __bucket_insert_point(bucket*);
+  void __sbucket_insert_point(sbucket*);
+  Uns32T __computeProductModDefaultPrime(Uns32T*,Uns32T*,IntT);
+  Uns32T __randr();
+  bucket** __get_bucket(int j);
+  void __generate_hash_keys(Uns32T*,Uns32T*,Uns32T*);
+  void error(const char* a, const char* b = "", const char *sysFunc = 0);
+
+ public:
+  Uns32T N; // num rows per table
+  Uns32T C; // num collision per row
+  Uns32T k; // num projections per hash function
+  Uns32T m; // ~sqrt num hash tables
+  Uns32T L; // L = m*(m-1)/2, conversely, m = (1 + sqrt(1 + 8.0*L)) / 2.0
+  Uns32T d; // dimensions
+  Uns32T p; // current point
+
+  H(){;}
+  H(Uns32T k, Uns32T m, Uns32T d, Uns32T N, Uns32T C);
+  ~H();
+
+
+  Uns32T bucket_insert_point(bucket**);
+
+  Uns32T get_t1(){return t1;}
+  Uns32T get_t2(){return t2;}
+
+};
+
+
+typedef void (*ReporterCallbackPtr)(void* objPtr, Uns32T pointID, Uns32T queryIndex, float squaredDistance);
+
+// Interface for indexing and retrieval
+class G: public H{
+ private:
+  float *** A; // m x k x d random projectors from R^d->R^k
+  float ** b;  // m x k uniform additive constants
+  Uns32T ** g; // L x k random hash projections \in Z^k    
+  float w;     // width of hash slots (relative to normalized feature space)
+  float radius;// scaling coefficient for data (1./radius)
+  vector<vector<Uns32T> > uu; // Storage for m patial hash evaluations
+  Uns32T maxp; // highest pointID stored in database
+  void* calling_instance; // store calling object instance for member-function callback
+  void (*add_point_callback)(void*, Uns32T, Uns32T, float);
+
+  void initialize_partial_functions();
+
+  void get_lock(int fd, bool exclusive);
+  void release_lock(int fd);
+  int serial_create(char* filename, Uns32T FMT);
+  int serial_create(char* filename, float binWidth, Uns32T nTables, Uns32T nRows, Uns32T nCols, Uns32T k, Uns32T d, Uns32T FMT);
+  char* serial_mmap(int dbfid, Uns32T sz, Uns32T w, off_t offset = 0);
+  void serial_munmap(char* db, Uns32T N);
+  int serial_open(char* filename,int writeFlag);
+  void serial_close(int dbfid);
+
+  // Function to write hashfunctions to disk
+  int serialize_lsh_hashfunctions(int fid);
+
+  // Functions to write hashtables to disk in format1 (optimized for on-disk retrieval)
+  int serialize_lsh_hashtables_format1(int fid, int merge);
+  void serial_write_hashtable_row_format1(SerialElementT*& pe, bucket* h, Uns32T& colCount);
+  void serial_write_element_format1(SerialElementT*& pe, sbucket* sb, Uns32T t2, Uns32T& colCount);
+  void serial_merge_hashtable_row_format1(SerialElementT* pr, bucket* h, Uns32T& colCount);
+  void serial_merge_element_format1(SerialElementT* pe, sbucket* sb, Uns32T t2, Uns32T& colCount);
+  int serial_can_merge(Uns32T requestedFormat); // Test to see whether core and on-disk structures are compatible
+
+  // Functions to write hashtables to disk in format2 (optimized for in-core retrieval)
+  int serialize_lsh_hashtables_format2(int fid, int merge);
+  void serial_write_hashtable_row_format2(int fid, bucket* h, Uns32T& colCount);
+  void serial_write_element_format2(int fid, sbucket* sb, Uns32T& colCount);
+
+  // Functions to read serial header and hash functions (format1 and format2)
+  int unserialize_lsh_header(char* filename);            // read lsh header from disk into core
+  void unserialize_lsh_functions(int fid);               // read the lsh hash functions into core
+
+  // Functions to read hashtables in format1
+  void unserialize_lsh_hashtables_format1(int fid);       // read FORMAT1 hash tables into core (disk format)
+  void unserialize_hashtable_row_format1(SerialElementT* pe, bucket** b); // read lsh hash table row into core
+
+  // Functions to read hashtables in format2
+  void unserialize_lsh_hashtables_format2(int fid);       // read FORMAT2 hash tables into core (core format)
+  Uns32T unserialize_hashtable_row_format2(int fid, bucket** b); // read lsh hash table row into core
+
+  // Helper functions
+  void serial_print_header(Uns32T requestedFormat);
+  float* get_serial_hashfunction_base(char* db);
+  SerialElementT* get_serial_hashtable_base(char* db);  
+  Uns32T get_serial_hashtable_offset();                   // Size of SerialHeader + HashFunctions
+  SerialHeaderT* serial_get_header(char* db);
+  SerialHeaderT* lshHeader;
+
+  // Core Retrieval/Inspections Functions
+  void bucket_chain_point(bucket* p, Uns32T qpos);
+  void sbucket_chain_point(sbucket* p, Uns32T qpos);  
+  void dump_hashtable_row(bucket* p);
+
+  // Serial (Format 1) Retrieval/Inspection Functions
+  void serial_bucket_chain_point(SerialElementT* pe, Uns32T qpos);
+  void serial_bucket_dump(SerialElementT* pe);
+
+  // Hash functions
+  void compute_hash_functions(vector<float>& v);
+  float randn();
+  float ranf();
+
+  char* db;    // pointer to serialized structure
+
+ public:
+  G(char* lshFile, bool lshInCore = false); // unserialize constructor
+  G(float w, Uns32T k,Uns32T m, Uns32T d, Uns32T N, Uns32T C, float r); // core constructor
+  ~G();
+
+  Uns32T insert_point(vector<float>&, Uns32T pointID);
+  void insert_point_set(vector<vector<float> >& vv, Uns32T basePointID);
+
+  // point retrieval from core
+  void retrieve_point(vector<float>& v, Uns32T qpos, ReporterCallbackPtr, void* me=NULL);
+  // point set retrieval from core
+  void retrieve_point_set(vector<vector<float> >& vv, ReporterCallbackPtr, void* me=NULL);
+  // serial point set retrieval
+  void serial_retrieve_point_set(char* filename, vector<vector<float> >& vv, ReporterCallbackPtr, void* me=NULL);
+  // serial point retrieval
+  void serial_retrieve_point(char* filename, vector<float>& vv, Uns32T qpos, ReporterCallbackPtr, void* me=NULL);
+
+  void serialize(char* filename, Uns32T serialFormat = O2_SERIAL_FILEFORMAT1); // write hashfunctions and hashtables to disk
+
+  SerialHeaderT* get_lshHeader(){return lshHeader;}
+  float get_radius(){return radius;}  
+  Uns32T get_maxp(){return maxp;}
+  void serial_dump_tables(char* filename);
+  float get_mean_collision_rate(){ return (float) pointCount / bucketCount ; }
+};
+
+typedef class G LSH;
+
+ 
+
+#endif