view multiprobe.cpp @ 517:807c8be7dd45 multiprobeLSH

Completed multiprobe framework for LSH. Requires testing.
author mas01mc
date Sun, 25 Jan 2009 06:10:38 +0000
parents 2a7bad47a4a7
children ca1ee92c359c
line wrap: on
line source
/*
 * MultiProbe C++ class
 *
 * Given an LSH structure, perform lookup by probing nearby hash-function locations
 *
 * Implementation using C++ STL
 *
 * Reference:
 * Qin Lv, William Josephson, Zhe Wang, Moses Charikar and Kai Li,
 * "Multi-Probe LSH: Efficient Indexing for High-Dimensional Similarity
 * Search", Proc. Intl. Conf. VLDB, 2007
 *
 *
 * Copyright (C) 2009 Michael Casey, Dartmouth College, All Rights Reserved
 * License: GNU Public License 2.0
 *
 */

#include "multiprobe.h"

#include <limits>


#define _TEST_MP_LSH

// Default constructor.
// Allocates an empty perturbation-set heap and an empty output list.
// The sorted distance table is left unallocated (null) and the boundary
// count is zero until makeSortedDistFuns() is called for a query.
MultiProbe::MultiProbe():
  minHeap(new min_heap_of_perturbation_set()),
  outSets(new vector_of_perturbation_set()),
  distFuns(0),
  numHashBoundaries(0)
{
}

// Destructor: releases the three heap-allocated members.
// Each of them is created with scalar new (minHeap and outSets in the
// constructor, distFuns in makeSortedDistFuns), so scalar delete is the
// matching form. Deleting a null distFuns is a harmless no-op.
MultiProbe::~MultiProbe(){
  delete distFuns;
  distFuns = 0;

  delete outSets;
  outSets = 0;

  delete minHeap;
  minHeap = 0;
}

// Generate the optimal T perturbation sets for current query
// pre-conditions:
//   an LSH structure was initialized and passed to constructor
//   a query vector was passed to lsh->compute_hash_functions()
//   the query-to-boundary distances are stored in x[hashFunIndex]
//
// post-conditions:
//
//  generates an ordered list of perturbation sets (stored as an array of sets)
//  these are indexes into pi_j=(i,delta) pairs representing x_i(delta) in sort order z_j
//  data structures are cleared and reset to zeros thereby making them re-entrant
//
void MultiProbe::generatePerturbationSets(int T, vector<double>& x){
  perturbation_set ai,as,ae;

  makeSortedDistFuns(x);

  ai.insert(1); // Initialize for this query
  minHeap->push(make_pair(perturbation_set(ai), score(ai))); // unique instance stored in mhe

  for(int i = 0 ; i < T ; i++ ){
    do{
      min_heap_element mhe = minHeap->top();
      minHeap->pop();
      ai = mhe.first;
      cout << "ai: ";
      dump(ai);
      as = perturbation_set(ai);
      shift(as);
      cout << "as: ";
      dump(as);
      minHeap->push(make_pair(perturbation_set(as), score(as)));
      ae = perturbation_set(ai);
      expand(ae);
      cout << "ae: ";
      dump(ae);
      minHeap->push(make_pair(perturbation_set(ae), score(ae)));
    }while(!valid(ai));
    outSets->push_back(ai); // Ordered list of perturbation sets
  }
}

// Debug helper: print the elements of a on one line of cout, in the
// set's iteration order, each followed by a single space.
void MultiProbe::dump(perturbation_set& a){
  for( perturbation_set::iterator pos = a.begin(); pos != a.end(); ++pos ){
    cout << *pos << " ";
  }
  cout << "\n";
}

// Given the set a, replace its largest element by that element plus one.
// The set is modified in place and returned by reference.
// An empty set is returned unchanged.
perturbation_set& MultiProbe::shift(perturbation_set& a){
  if( a.empty() )
    return a;
  perturbation_set::iterator last = a.end();
  --last;
  const int val = *last + 1;
  a.erase(last);   // 'last' is invalid from here on and must not be reused
  a.insert(val);
  return a;
}

// Given the set a, add a new element one greater than the current max.
// The set is modified in place and returned by reference.
// An empty set has no max and is returned unchanged.
perturbation_set& MultiProbe::expand(perturbation_set& a){
  if( !a.empty() )
    a.insert(*a.rbegin() + 1);
  return a;
}

// Build the sorted table of (distance, (index, delta)) entries for a query.
// x is assumed to hold 2M distances; entry i gets delta = (-1)^i, i.e.
// +1 for even i and -1 for odd i. Any table from a previous query is
// released first. The table is then sorted with the element type's
// operator< (presumably a std::pair, so ascending by distance first).
void MultiProbe::makeSortedDistFuns(vector<double>& x){
  numHashBoundaries = x.size(); // x.size() == 2M

  delete distFuns;
  distFuns = new std::vector<sorted_distance_functions>(numHashBoundaries);
  std::vector<sorted_distance_functions>& table = *distFuns;

  int idx = 0;
  while( idx != numHashBoundaries ){
    int delta = +1;
    if( idx % 2 )
      delta = -1;
    table[idx] = make_pair(x[idx], make_pair(idx, delta));
    ++idx;
  }

  sort( table.begin(), table.end() );
}

// For a given perturbation set, the score is the sum of squares of the
// corresponding distances in the sorted distance table.
//
// Set elements are 1-based (generatePerturbationSets seeds the search
// with {1} and valid() accepts elements up to numHashBoundaries), whereas
// the table is a 0-based vector, so element j refers to entry j-1.
//
// returns:
//   0.0 for an empty set
//   +infinity if any element has no table entry (element < 1, element
//   beyond the table, or no table built yet); such a set can never be
//   valid, and an infinite score ranks it after every scorable set
//   otherwise the sum of squared distances
double MultiProbe::score(perturbation_set& a){
  typedef std::vector<sorted_distance_functions>::size_type table_index;
  const table_index tableSize = distFuns ? distFuns->size() : 0;

  double total = 0.0;
  for( perturbation_set::iterator it = a.begin(); it != a.end(); ++it ){
    const int j = *it;
    if( j < 1 || static_cast<table_index>(j) > tableSize )
      return std::numeric_limits<double>::infinity();
    const double dist = (*distFuns)[j - 1].first;
    total += dist*dist;
  }
  return total;
}

// Validity test for a perturbation set, with 2M == numHashBoundaries.
//
// The set is rejected as soon as one element j is found for which either
//   j > 2M, or
//   the element 2M + 1 - j is also a member of the set
// i.e. a valid set has at most one of the two elements {j, 2M + 1 - j}
// for every j. A set with no such element (including the empty set) is
// valid.
bool MultiProbe::valid(perturbation_set& a){
  for( perturbation_set::iterator pos = a.begin(); pos != a.end(); ++pos ){
    const int j = *pos;
    if( j > numHashBoundaries )
      return false;
    if( a.find( numHashBoundaries + 1 - j ) != a.end() )
      return false;
  }
  return true;
}

#ifdef _TEST_MP_LSH
// Stand-alone smoke test: asks for the first two perturbation sets of a
// four-boundary query. No command-line arguments are used.
int main(){
  // Construct in place. Copy-initialising from a temporary could invoke the
  // copy constructor of a class that owns raw pointers, which would lead
  // to the same memory being deleted twice.
  MultiProbe mp;
  vector<double> x(4);
  x[0]=0.1;
  x[1]=0.9;
  x[2]=0.2;
  x[3]=0.8;
  mp.generatePerturbationSets(2, x);
  return 0;
}
#endif