annotate multiprobe.cpp @ 757:ee612b7bd922 mkc_lsh_update

added queryFromData(), passes equivalence test with queryFromKey
author mas01mc
date Fri, 26 Nov 2010 08:05:48 +0000
parents 9bd13c7819ae
children
rev   line source
mas01mc@754 1 /*
mas01mc@754 2 * MultiProbe C++ class
mas01mc@754 3 *
mas01mc@754 4 * Given a vector of LSH boundary distances for a query,
mas01mc@754 5 * perform lookup by probing nearby hash-function locations
mas01mc@754 6 *
mas01mc@754 7 * Implementation using C++ STL
mas01mc@754 8 *
mas01mc@754 9 * Reference:
mas01mc@754 10 * Qin Lv, William Josephson, Zhe Wang, Moses Charikar and Kai Li,
mas01mc@754 11 * "Multi-Probe LSH: Efficient Indexing for High-Dimensional Similarity
mas01mc@754 12 * Search", Proc. Intl. Conf. VLDB, 2007
mas01mc@754 13 *
mas01mc@754 14 *
mas01mc@754 15 * Copyright (C) 2009 Michael Casey, Dartmouth College, All Rights Reserved
mas01mc@754 16 * License: GNU Public License 2.0
mas01mc@754 17 *
mas01mc@754 18 */
mas01mc@754 19
mas01mc@754 20 #include "multiprobe.h"
mas01mc@754 21
mas01mc@754 22 //#define _TEST_MP_LSH
mas01mc@754 23
mas01mc@754 24 bool operator> (const min_heap_element& a, const min_heap_element& b){
mas01mc@754 25 return a.score > b.score;
mas01mc@754 26 }
mas01mc@754 27
mas01mc@754 28 bool operator< (const min_heap_element& a, const min_heap_element& b){
mas01mc@754 29 return a.score < b.score;
mas01mc@754 30 }
mas01mc@754 31
mas01mc@754 32 bool operator>(const sorted_distance_functions& a, const sorted_distance_functions& b){
mas01mc@754 33 return a.first > b.first;
mas01mc@754 34 }
mas01mc@754 35
mas01mc@754 36 bool operator<(const sorted_distance_functions& a, const sorted_distance_functions& b){
mas01mc@754 37 return a.first < b.first;
mas01mc@754 38 }
mas01mc@754 39
mas01mc@754 40 MinHeapElement::MinHeapElement(perturbation_set a, float s):
mas01mc@754 41 perturbs(a),
mas01mc@754 42 score(s)
mas01mc@754 43 {
mas01mc@754 44
mas01mc@754 45 }
mas01mc@754 46
mas01mc@754 47 MinHeapElement::~MinHeapElement(){;}
mas01mc@754 48
mas01mc@754 49 MultiProbe::MultiProbe():
mas01mc@754 50 minHeap(0),
mas01mc@754 51 outSets(0),
mas01mc@754 52 distFuns(0),
mas01mc@754 53 numHashBoundaries(0)
mas01mc@754 54 {
mas01mc@754 55
mas01mc@754 56 }
mas01mc@754 57
mas01mc@754 58 MultiProbe::~MultiProbe(){
mas01mc@754 59 cleanup();
mas01mc@754 60 }
mas01mc@754 61
mas01mc@754 62 void MultiProbe::initialize(){
mas01mc@754 63 minHeap = new min_heap_of_perturbation_set();
mas01mc@754 64 outSets = new min_heap_of_perturbation_set();
mas01mc@754 65 }
mas01mc@754 66
mas01mc@754 67 void MultiProbe::cleanup(){
mas01mc@754 68 delete minHeap;
mas01mc@754 69 minHeap = 0;
mas01mc@754 70 delete outSets;
mas01mc@754 71 outSets = 0;
mas01mc@754 72 delete distFuns;
mas01mc@754 73 distFuns = 0;
mas01mc@754 74 }
mas01mc@754 75
mas01mc@754 76 size_t MultiProbe::size(){
mas01mc@754 77 return outSets->size();
mas01mc@754 78 }
mas01mc@754 79
mas01mc@754 80 bool MultiProbe::empty(){
mas01mc@754 81 return !outSets->size();
mas01mc@754 82 }
mas01mc@754 83
mas01mc@754 84
mas01mc@754 85 void MultiProbe::generatePerturbationSets(vector<float>& x, unsigned T){
mas01mc@754 86 cleanup(); // Make re-entrant
mas01mc@754 87 initialize();
mas01mc@754 88 makeSortedDistFuns(x);
mas01mc@754 89 algorithm1(T);
mas01mc@754 90 }
mas01mc@754 91
mas01mc@754 92 // overloading to support efficient array use without initial copy
mas01mc@754 93 void MultiProbe::generatePerturbationSets(float* x, unsigned N, unsigned T){
mas01mc@754 94 cleanup(); // Make re-entrant
mas01mc@754 95 initialize();
mas01mc@754 96 makeSortedDistFuns(x, N);
mas01mc@754 97 algorithm1(T);
mas01mc@754 98 }
mas01mc@754 99
mas01mc@754 100 // Generate the optimal T perturbation sets for current query
mas01mc@754 101 // pre-conditions:
mas01mc@754 102 // an LSH structure was initialized and passed to constructor
mas01mc@754 103 // a query vector was passed to lsh->compute_hash_functions()
mas01mc@754 104 // the query-to-boundary distances are stored in x[hashFunIndex]
mas01mc@754 105 //
mas01mc@754 106 // post-conditions:
mas01mc@754 107 // generates an ordered list of perturbation sets (stored as an array of sets)
mas01mc@754 108 // these are indexes into pi_j=(i,delta) pairs representing x_i(delta) in sort order z_j
mas01mc@754 109 // data structures are cleared and reset to zeros thereby making them re-entrant
mas01mc@754 110 //
mas01mc@754 111 void MultiProbe::algorithm1(unsigned T){
mas01mc@754 112 perturbation_set ai,as,ae;
mas01mc@754 113 float ai_score;
mas01mc@754 114 ai.insert(0); // Initialize for this query
mas01mc@754 115 minHeap->push(min_heap_element(ai, score(ai))); // unique instance stored in mhe
mas01mc@754 116
mas01mc@754 117 min_heap_element mhe = minHeap->top();
mas01mc@754 118
mas01mc@754 119 if(T>distFuns->size())
mas01mc@754 120 T = distFuns->size();
mas01mc@754 121 for(unsigned i = 0 ; i != T ; i++ ){
mas01mc@754 122 do{
mas01mc@754 123 mhe = minHeap->top();
mas01mc@754 124 ai = mhe.perturbs;
mas01mc@754 125 ai_score = mhe.score;
mas01mc@754 126 minHeap->pop();
mas01mc@754 127 as=ai;
mas01mc@754 128 shift(as);
mas01mc@754 129 minHeap->push(min_heap_element(as, score(as)));
mas01mc@754 130 ae=ai;
mas01mc@754 131 expand(ae);
mas01mc@754 132 minHeap->push(min_heap_element(ae, score(ae)));
mas01mc@754 133 }while(!valid(ai));
mas01mc@754 134 outSets->push(mhe); // Ordered list of perturbation sets
mas01mc@754 135 }
mas01mc@754 136 }
mas01mc@754 137
mas01mc@754 138 void MultiProbe::dump(perturbation_set a){
mas01mc@754 139 perturbation_set::iterator it = a.begin();
mas01mc@754 140 while(it != a.end()){
mas01mc@754 141 cout << "[" << (*distFuns)[*it].second.first << "," << (*distFuns)[*it].second.second << "]" << " "
mas01mc@754 142 << (*distFuns)[*it].first << *it << ", ";
mas01mc@754 143 it++;
mas01mc@754 144 }
mas01mc@754 145 cout << "(" << score(a) << ")";
mas01mc@754 146 cout << endl;
mas01mc@754 147 }
mas01mc@754 148
mas01mc@754 149 // Given the set a, add 1 to last element of the set
mas01mc@754 150 inline perturbation_set& MultiProbe::shift(perturbation_set& a){
mas01mc@754 151 perturbation_set::iterator it = a.end();
mas01mc@754 152 int val = *(--it) + 1;
mas01mc@754 153 a.erase(it);
mas01mc@754 154 a.insert(it,val);
mas01mc@754 155 return a;
mas01mc@754 156 }
mas01mc@754 157
mas01mc@754 158 // Given the set a, add a new element one greater than the max
mas01mc@754 159 inline perturbation_set& MultiProbe::expand(perturbation_set& a){
mas01mc@754 160 perturbation_set::reverse_iterator ri = a.rbegin();
mas01mc@754 161 a.insert(*ri+1);
mas01mc@754 162 return a;
mas01mc@754 163 }
mas01mc@754 164
mas01mc@754 165 // Take the list of distances (x) assuming list len is 2M and
mas01mc@754 166 // delta = (-1)^i, i = { 0 .. 2M-1 }
mas01mc@754 167 void MultiProbe::makeSortedDistFuns(vector<float>& x){
mas01mc@754 168 numHashBoundaries = x.size(); // x.size() == 2M
mas01mc@754 169 delete distFuns;
mas01mc@754 170 distFuns = new std::vector<sorted_distance_functions>(numHashBoundaries);
mas01mc@754 171 for(unsigned i = 0; i != numHashBoundaries ; i++ )
mas01mc@754 172 (*distFuns)[i] = make_pair(x[i], make_pair(i, i%2?1:-1));
mas01mc@754 173 // SORT
mas01mc@754 174 sort( distFuns->begin(), distFuns->end() );
mas01mc@754 175 }
mas01mc@754 176
mas01mc@754 177 // Float array version of above
mas01mc@754 178 void MultiProbe::makeSortedDistFuns(float* x, unsigned N){
mas01mc@754 179 numHashBoundaries = N; // x.size() == 2M
mas01mc@754 180 delete distFuns;
mas01mc@754 181 distFuns = new std::vector<sorted_distance_functions>(numHashBoundaries);
mas01mc@754 182 for(unsigned i = 0; i != numHashBoundaries ; i++ )
mas01mc@754 183 (*distFuns)[i] = make_pair(x[i], make_pair(i, i%2?1:-1));
mas01mc@754 184 // SORT
mas01mc@754 185 sort( distFuns->begin(), distFuns->end() );
mas01mc@754 186 }
mas01mc@754 187
mas01mc@754 188 // For a given perturbation set, the score is the
mas01mc@754 189 // sum of squares of corresponding distances in x
mas01mc@754 190 float MultiProbe::score(perturbation_set& a){
mas01mc@754 191 //assert(!a.empty());
mas01mc@754 192 float score = 0.0, tmp;
mas01mc@754 193 perturbation_set::iterator it;
mas01mc@754 194 it = a.begin();
mas01mc@754 195 do{
mas01mc@754 196 tmp = (*distFuns)[*it].first;
mas01mc@754 197 score += tmp*tmp;
mas01mc@754 198 }while( ++it != a.end() );
mas01mc@754 199 return score;
mas01mc@754 200 }
mas01mc@754 201
mas01mc@754 202 // A valid set must have at most one
mas01mc@754 203 // of the two elements {j, 2M + 1 - j} for every j
mas01mc@754 204 //
mas01mc@754 205 // A perturbation set containing an element > 2M is invalid
mas01mc@754 206 bool MultiProbe::valid(perturbation_set& a){
mas01mc@754 207 int j;
mas01mc@754 208 perturbation_set::iterator it = a.begin();
mas01mc@754 209 while( it != a.end() ){
mas01mc@754 210 j = *it;
mas01mc@754 211 it++;
mas01mc@754 212 if( ( (unsigned)j > numHashBoundaries ) || ( a.find( numHashBoundaries - j - 1 ) != a.end() ) )
mas01mc@754 213 return false;
mas01mc@754 214 }
mas01mc@754 215 return true;
mas01mc@754 216 }
mas01mc@754 217
mas01mc@754 218 int MultiProbe::getIndex(perturbation_set::iterator it){
mas01mc@754 219 return (*distFuns)[*it].second.first;
mas01mc@754 220 }
mas01mc@754 221
mas01mc@754 222 int MultiProbe::getBoundary(perturbation_set::iterator it){
mas01mc@754 223 return (*distFuns)[*it].second.second;
mas01mc@754 224 }
mas01mc@754 225
mas01mc@754 226 // copy return next perturbation_set
mas01mc@754 227 perturbation_set MultiProbe::getNextPerturbationSet(){
mas01mc@754 228 perturbation_set s = outSets->top().perturbs;
mas01mc@754 229 outSets->pop();
mas01mc@754 230 return s;
mas01mc@754 231 }
mas01mc@754 232
// Test routine: draw N_SAMPS/2 random boundary-distance pairs
// (x[2i], x[2i+1]) that sum to W, then call generatePerturbationSets on them.
// The resulting sets are dumped for inspection. This repeats N_ITER times on
// one instance to exercise reuse across queries.
#ifdef _TEST_MP_LSH
int main(){
  const int N_SAMPS = 100; // Number of random samples (even: two per hash function)
  const int W = 4;         // simulated hash-bucket size
  const int N_ITER = 100;  // How many re-entrant iterations
  const unsigned T = 10;   // Number of multi-probe sets to generate

  // Construct directly: MultiProbe owns raw pointers, so avoid copying a temporary
  MultiProbe mp;
  vector<float> x(N_SAMPS);

  srand(static_cast<unsigned>(time(0)));

  // Test re-entrance on single instance
  for(int j = 0; j < N_ITER; ++j){
    cout << "********** ITERATION " << j << " **********" << endl;
    for(size_t i = 0; i != x.size()/2; ++i){
      x[2*i] = W*(rand()/(RAND_MAX+1.0));
      x[2*i+1] = W - x[2*i];
    }
    // Generate multi-probe sets
    mp.generatePerturbationSets(x, T);
    // Output contents of multi-probe sets
    while(!mp.empty())
      mp.dump(mp.getNextPerturbationSet());
  }
  return 0;
}
#endif