mas01mc@764
|
1 /*
|
mas01mc@764
|
2 * MultiProbe C++ class
|
mas01mc@764
|
3 *
|
mas01mc@764
|
4 * Given a vector of LSH boundary distances for a query,
|
mas01mc@764
|
5 * perform lookup by probing nearby hash-function locations
|
mas01mc@764
|
6 *
|
mas01mc@764
|
7 * Implementation using C++ STL
|
mas01mc@764
|
8 *
|
mas01mc@764
|
9 * Reference:
|
mas01mc@764
|
10 * Qin Lv, William Josephson, Zhe Wang, Moses Charikar and Kai Li,
|
mas01mc@764
|
11 * "Multi-Probe LSH: Efficient Indexing for High-Dimensional Similarity
|
mas01mc@764
|
12 * Search", Proc. Intl. Conf. VLDB, 2007
|
mas01mc@764
|
13 *
|
mas01mc@764
|
14 *
|
mas01mc@764
|
15 * Copyright (C) 2009 Michael Casey, Dartmouth College, All Rights Reserved
|
mas01mc@764
|
16 * License: GNU Public License 2.0
|
mas01mc@764
|
17 *
|
mas01mc@764
|
18 */
|
mas01mc@764
|
19
|
mas01mc@764
|
20 #include "multiprobe.h"
|
mas01mc@764
|
21
|
mas01mc@764
|
22 //#define _TEST_MP_LSH
|
mas01mc@764
|
23
|
mas01mc@764
|
24 bool operator> (const min_heap_element& a, const min_heap_element& b){
|
mas01mc@764
|
25 return a.score > b.score;
|
mas01mc@764
|
26 }
|
mas01mc@764
|
27
|
mas01mc@764
|
28 bool operator< (const min_heap_element& a, const min_heap_element& b){
|
mas01mc@764
|
29 return a.score < b.score;
|
mas01mc@764
|
30 }
|
mas01mc@764
|
31
|
mas01mc@764
|
32 bool operator>(const sorted_distance_functions& a, const sorted_distance_functions& b){
|
mas01mc@764
|
33 return a.first > b.first;
|
mas01mc@764
|
34 }
|
mas01mc@764
|
35
|
mas01mc@764
|
36 bool operator<(const sorted_distance_functions& a, const sorted_distance_functions& b){
|
mas01mc@764
|
37 return a.first < b.first;
|
mas01mc@764
|
38 }
|
mas01mc@764
|
39
|
mas01mc@764
|
40 MinHeapElement::MinHeapElement(perturbation_set a, float s):
|
mas01mc@764
|
41 perturbs(a),
|
mas01mc@764
|
42 score(s)
|
mas01mc@764
|
43 {
|
mas01mc@764
|
44
|
mas01mc@764
|
45 }
|
mas01mc@764
|
46
|
mas01mc@764
|
47 MinHeapElement::~MinHeapElement(){;}
|
mas01mc@764
|
48
|
mas01mc@764
|
49 MultiProbe::MultiProbe():
|
mas01mc@764
|
50 minHeap(0),
|
mas01mc@764
|
51 outSets(0),
|
mas01mc@764
|
52 distFuns(0),
|
mas01mc@764
|
53 numHashBoundaries(0)
|
mas01mc@764
|
54 {
|
mas01mc@764
|
55
|
mas01mc@764
|
56 }
|
mas01mc@764
|
57
|
mas01mc@764
|
58 MultiProbe::~MultiProbe(){
|
mas01mc@764
|
59 cleanup();
|
mas01mc@764
|
60 }
|
mas01mc@764
|
61
|
mas01mc@764
|
62 void MultiProbe::initialize(){
|
mas01mc@764
|
63 minHeap = new min_heap_of_perturbation_set();
|
mas01mc@764
|
64 outSets = new min_heap_of_perturbation_set();
|
mas01mc@764
|
65 }
|
mas01mc@764
|
66
|
mas01mc@764
|
67 void MultiProbe::cleanup(){
|
mas01mc@764
|
68 delete minHeap;
|
mas01mc@764
|
69 minHeap = 0;
|
mas01mc@764
|
70 delete outSets;
|
mas01mc@764
|
71 outSets = 0;
|
mas01mc@764
|
72 delete distFuns;
|
mas01mc@764
|
73 distFuns = 0;
|
mas01mc@764
|
74 }
|
mas01mc@764
|
75
|
mas01mc@764
|
76 size_t MultiProbe::size(){
|
mas01mc@764
|
77 return outSets->size();
|
mas01mc@764
|
78 }
|
mas01mc@764
|
79
|
mas01mc@764
|
80 bool MultiProbe::empty(){
|
mas01mc@764
|
81 return !outSets->size();
|
mas01mc@764
|
82 }
|
mas01mc@764
|
83
|
mas01mc@764
|
84
|
mas01mc@764
|
85 void MultiProbe::generatePerturbationSets(vector<float>& x, unsigned T){
|
mas01mc@764
|
86 cleanup(); // Make re-entrant
|
mas01mc@764
|
87 initialize();
|
mas01mc@764
|
88 makeSortedDistFuns(x);
|
mas01mc@764
|
89 algorithm1(T);
|
mas01mc@764
|
90 }
|
mas01mc@764
|
91
|
mas01mc@764
|
92 // overloading to support efficient array use without initial copy
|
mas01mc@764
|
93 void MultiProbe::generatePerturbationSets(float* x, unsigned N, unsigned T){
|
mas01mc@764
|
94 cleanup(); // Make re-entrant
|
mas01mc@764
|
95 initialize();
|
mas01mc@764
|
96 makeSortedDistFuns(x, N);
|
mas01mc@764
|
97 algorithm1(T);
|
mas01mc@764
|
98 }
|
mas01mc@764
|
99
|
mas01mc@764
|
100 // Generate the optimal T perturbation sets for current query
|
mas01mc@764
|
101 // pre-conditions:
|
mas01mc@764
|
102 // an LSH structure was initialized and passed to constructor
|
mas01mc@764
|
103 // a query vector was passed to lsh->compute_hash_functions()
|
mas01mc@764
|
104 // the query-to-boundary distances are stored in x[hashFunIndex]
|
mas01mc@764
|
105 //
|
mas01mc@764
|
106 // post-conditions:
|
mas01mc@764
|
107 // generates an ordered list of perturbation sets (stored as an array of sets)
|
mas01mc@764
|
108 // these are indexes into pi_j=(i,delta) pairs representing x_i(delta) in sort order z_j
|
mas01mc@764
|
109 // data structures are cleared and reset to zeros thereby making them re-entrant
|
mas01mc@764
|
110 //
|
mas01mc@764
|
111 void MultiProbe::algorithm1(unsigned T){
|
mas01mc@764
|
112 perturbation_set ai,as,ae;
|
mas01mc@764
|
113 float ai_score;
|
mas01mc@764
|
114 ai.insert(0); // Initialize for this query
|
mas01mc@764
|
115 minHeap->push(min_heap_element(ai, score(ai))); // unique instance stored in mhe
|
mas01mc@764
|
116
|
mas01mc@764
|
117 min_heap_element mhe = minHeap->top();
|
mas01mc@764
|
118
|
mas01mc@764
|
119 if(T>distFuns->size())
|
mas01mc@764
|
120 T = distFuns->size();
|
mas01mc@764
|
121 for(unsigned i = 0 ; i != T ; i++ ){
|
mas01mc@764
|
122 do{
|
mas01mc@764
|
123 mhe = minHeap->top();
|
mas01mc@764
|
124 ai = mhe.perturbs;
|
mas01mc@764
|
125 ai_score = mhe.score;
|
mas01mc@764
|
126 minHeap->pop();
|
mas01mc@764
|
127 as=ai;
|
mas01mc@764
|
128 shift(as);
|
mas01mc@764
|
129 minHeap->push(min_heap_element(as, score(as)));
|
mas01mc@764
|
130 ae=ai;
|
mas01mc@764
|
131 expand(ae);
|
mas01mc@764
|
132 minHeap->push(min_heap_element(ae, score(ae)));
|
mas01mc@764
|
133 }while(!valid(ai));
|
mas01mc@764
|
134 outSets->push(mhe); // Ordered list of perturbation sets
|
mas01mc@764
|
135 }
|
mas01mc@764
|
136 }
|
mas01mc@764
|
137
|
mas01mc@764
|
138 void MultiProbe::dump(perturbation_set a){
|
mas01mc@764
|
139 perturbation_set::iterator it = a.begin();
|
mas01mc@764
|
140 while(it != a.end()){
|
mas01mc@764
|
141 cout << "[" << (*distFuns)[*it].second.first << "," << (*distFuns)[*it].second.second << "]" << " "
|
mas01mc@764
|
142 << (*distFuns)[*it].first << *it << ", ";
|
mas01mc@764
|
143 it++;
|
mas01mc@764
|
144 }
|
mas01mc@764
|
145 cout << "(" << score(a) << ")";
|
mas01mc@764
|
146 cout << endl;
|
mas01mc@764
|
147 }
|
mas01mc@764
|
148
|
mas01mc@764
|
149 // Given the set a, add 1 to last element of the set
|
mas01mc@764
|
150 inline perturbation_set& MultiProbe::shift(perturbation_set& a){
|
mas01mc@764
|
151 perturbation_set::iterator it = a.end();
|
mas01mc@764
|
152 int val = *(--it) + 1;
|
mas01mc@764
|
153 a.erase(it);
|
mas01mc@764
|
154 a.insert(it,val);
|
mas01mc@764
|
155 return a;
|
mas01mc@764
|
156 }
|
mas01mc@764
|
157
|
mas01mc@764
|
158 // Given the set a, add a new element one greater than the max
|
mas01mc@764
|
159 inline perturbation_set& MultiProbe::expand(perturbation_set& a){
|
mas01mc@764
|
160 perturbation_set::reverse_iterator ri = a.rbegin();
|
mas01mc@764
|
161 a.insert(*ri+1);
|
mas01mc@764
|
162 return a;
|
mas01mc@764
|
163 }
|
mas01mc@764
|
164
|
mas01mc@764
|
165 // Take the list of distances (x) assuming list len is 2M and
|
mas01mc@764
|
166 // delta = (-1)^i, i = { 0 .. 2M-1 }
|
mas01mc@764
|
167 void MultiProbe::makeSortedDistFuns(vector<float>& x){
|
mas01mc@764
|
168 numHashBoundaries = x.size(); // x.size() == 2M
|
mas01mc@764
|
169 delete distFuns;
|
mas01mc@764
|
170 distFuns = new std::vector<sorted_distance_functions>(numHashBoundaries);
|
mas01mc@764
|
171 for(unsigned i = 0; i != numHashBoundaries ; i++ )
|
mas01mc@764
|
172 (*distFuns)[i] = make_pair(x[i], make_pair(i, i%2?1:-1));
|
mas01mc@764
|
173 // SORT
|
mas01mc@764
|
174 sort( distFuns->begin(), distFuns->end() );
|
mas01mc@764
|
175 }
|
mas01mc@764
|
176
|
mas01mc@764
|
177 // Float array version of above
|
mas01mc@764
|
178 void MultiProbe::makeSortedDistFuns(float* x, unsigned N){
|
mas01mc@764
|
179 numHashBoundaries = N; // x.size() == 2M
|
mas01mc@764
|
180 delete distFuns;
|
mas01mc@764
|
181 distFuns = new std::vector<sorted_distance_functions>(numHashBoundaries);
|
mas01mc@764
|
182 for(unsigned i = 0; i != numHashBoundaries ; i++ )
|
mas01mc@764
|
183 (*distFuns)[i] = make_pair(x[i], make_pair(i, i%2?1:-1));
|
mas01mc@764
|
184 // SORT
|
mas01mc@764
|
185 sort( distFuns->begin(), distFuns->end() );
|
mas01mc@764
|
186 }
|
mas01mc@764
|
187
|
mas01mc@764
|
188 // For a given perturbation set, the score is the
|
mas01mc@764
|
189 // sum of squares of corresponding distances in x
|
mas01mc@764
|
190 float MultiProbe::score(perturbation_set& a){
|
mas01mc@764
|
191 //assert(!a.empty());
|
mas01mc@764
|
192 float score = 0.0, tmp;
|
mas01mc@764
|
193 perturbation_set::iterator it;
|
mas01mc@764
|
194 it = a.begin();
|
mas01mc@764
|
195 do{
|
mas01mc@764
|
196 tmp = (*distFuns)[*it].first;
|
mas01mc@764
|
197 score += tmp*tmp;
|
mas01mc@764
|
198 }while( ++it != a.end() );
|
mas01mc@764
|
199 return score;
|
mas01mc@764
|
200 }
|
mas01mc@764
|
201
|
mas01mc@764
|
202 // A valid set must have at most one
|
mas01mc@764
|
203 // of the two elements {j, 2M + 1 - j} for every j
|
mas01mc@764
|
204 //
|
mas01mc@764
|
205 // A perturbation set containing an element > 2M is invalid
|
mas01mc@764
|
206 bool MultiProbe::valid(perturbation_set& a){
|
mas01mc@764
|
207 int j;
|
mas01mc@764
|
208 perturbation_set::iterator it = a.begin();
|
mas01mc@764
|
209 while( it != a.end() ){
|
mas01mc@764
|
210 j = *it;
|
mas01mc@764
|
211 it++;
|
mas01mc@764
|
212 if( ( (unsigned)j > numHashBoundaries ) || ( a.find( numHashBoundaries - j - 1 ) != a.end() ) )
|
mas01mc@764
|
213 return false;
|
mas01mc@764
|
214 }
|
mas01mc@764
|
215 return true;
|
mas01mc@764
|
216 }
|
mas01mc@764
|
217
|
mas01mc@764
|
218 int MultiProbe::getIndex(perturbation_set::iterator it){
|
mas01mc@764
|
219 return (*distFuns)[*it].second.first;
|
mas01mc@764
|
220 }
|
mas01mc@764
|
221
|
mas01mc@764
|
222 int MultiProbe::getBoundary(perturbation_set::iterator it){
|
mas01mc@764
|
223 return (*distFuns)[*it].second.second;
|
mas01mc@764
|
224 }
|
mas01mc@764
|
225
|
mas01mc@764
|
226 // copy return next perturbation_set
|
mas01mc@764
|
227 perturbation_set MultiProbe::getNextPerturbationSet(){
|
mas01mc@764
|
228 perturbation_set s = outSets->top().perturbs;
|
mas01mc@764
|
229 outSets->pop();
|
mas01mc@764
|
230 return s;
|
mas01mc@764
|
231 }
|
mas01mc@764
|
232
|
mas01mc@764
|
// Test routine: fill N_SAMPS random boundary distances (N_SAMPS/2
// complementary pairs that each sum to W), call generatePerturbationSets on
// them, and dump the resulting sets for inspection. This is repeated N_ITER
// times on a single instance to exercise re-entrance.
#ifdef _TEST_MP_LSH
int main(){
  const int N_SAMPS = 100;  // Number of random samples (N_SAMPS/2 pairs)
  const int W = 4;          // simulated hash-bucket size
  const int N_ITER = 100;   // How many re-entrant iterations
  const unsigned T = 10;    // Number of multi-probe sets to generate

  // Direct-initialise: MultiProbe owns raw pointers, so avoid any copy.
  MultiProbe mp;
  vector<float> x(N_SAMPS);

  srand((unsigned)time(0));

  // Test re-entrance on single instance
  for(int j = 0; j < N_ITER; j++){
    cout << "********** ITERATION " << j << " **********" << endl;
    cout.flush();
    for(size_t i = 0; i != x.size()/2; i++){
      x[2*i] = W*(rand()/(RAND_MAX+1.0));
      x[2*i+1] = W - x[2*i];
    }
    // Generate multi-probe sets
    mp.generatePerturbationSets(x, T);
    // Output contents of multi-probe sets
    while(!mp.empty())
      mp.dump(mp.getNextPerturbationSet());
  }
  return 0;
}
#endif
|