Mercurial > hg > audiodb
comparison multiprobe.cpp @ 754:9bd13c7819ae mkc_lsh_update
Adding mkc_lsh_update branch, trunk candidate with improved LSH: merged trunk 1095 and branch multiprobe_lsh
author | mas01mc |
---|---|
date | Thu, 25 Nov 2010 13:42:40 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
747:fbf16508421f | 754:9bd13c7819ae |
---|---|
1 /* | |
2 * MultiProbe C++ class | |
3 * | |
4 * Given a vector of LSH boundary distances for a query, | |
5 * perform lookup by probing nearby hash-function locations | |
6 * | |
7 * Implementation using C++ STL | |
8 * | |
9 * Reference: | |
10 * Qin Lv, William Josephson, Zhe Wang, Moses Charikar and Kai Li, | |
11 * "Multi-Probe LSH: Efficient Indexing for High-Dimensional Similarity | |
12 * Search", Proc. Intl. Conf. VLDB, 2007 | |
13 * | |
14 * | |
15 * Copyright (C) 2009 Michael Casey, Dartmouth College, All Rights Reserved | |
16 * License: GNU Public License 2.0 | |
17 * | |
18 */ | |
19 | |
20 #include "multiprobe.h" | |
21 | |
22 //#define _TEST_MP_LSH | |
23 | |
24 bool operator> (const min_heap_element& a, const min_heap_element& b){ | |
25 return a.score > b.score; | |
26 } | |
27 | |
28 bool operator< (const min_heap_element& a, const min_heap_element& b){ | |
29 return a.score < b.score; | |
30 } | |
31 | |
32 bool operator>(const sorted_distance_functions& a, const sorted_distance_functions& b){ | |
33 return a.first > b.first; | |
34 } | |
35 | |
36 bool operator<(const sorted_distance_functions& a, const sorted_distance_functions& b){ | |
37 return a.first < b.first; | |
38 } | |
39 | |
40 MinHeapElement::MinHeapElement(perturbation_set a, float s): | |
41 perturbs(a), | |
42 score(s) | |
43 { | |
44 | |
45 } | |
46 | |
47 MinHeapElement::~MinHeapElement(){;} | |
48 | |
49 MultiProbe::MultiProbe(): | |
50 minHeap(0), | |
51 outSets(0), | |
52 distFuns(0), | |
53 numHashBoundaries(0) | |
54 { | |
55 | |
56 } | |
57 | |
58 MultiProbe::~MultiProbe(){ | |
59 cleanup(); | |
60 } | |
61 | |
62 void MultiProbe::initialize(){ | |
63 minHeap = new min_heap_of_perturbation_set(); | |
64 outSets = new min_heap_of_perturbation_set(); | |
65 } | |
66 | |
67 void MultiProbe::cleanup(){ | |
68 delete minHeap; | |
69 minHeap = 0; | |
70 delete outSets; | |
71 outSets = 0; | |
72 delete distFuns; | |
73 distFuns = 0; | |
74 } | |
75 | |
76 size_t MultiProbe::size(){ | |
77 return outSets->size(); | |
78 } | |
79 | |
80 bool MultiProbe::empty(){ | |
81 return !outSets->size(); | |
82 } | |
83 | |
84 | |
85 void MultiProbe::generatePerturbationSets(vector<float>& x, unsigned T){ | |
86 cleanup(); // Make re-entrant | |
87 initialize(); | |
88 makeSortedDistFuns(x); | |
89 algorithm1(T); | |
90 } | |
91 | |
92 // overloading to support efficient array use without initial copy | |
93 void MultiProbe::generatePerturbationSets(float* x, unsigned N, unsigned T){ | |
94 cleanup(); // Make re-entrant | |
95 initialize(); | |
96 makeSortedDistFuns(x, N); | |
97 algorithm1(T); | |
98 } | |
99 | |
100 // Generate the optimal T perturbation sets for current query | |
101 // pre-conditions: | |
102 // an LSH structure was initialized and passed to constructor | |
103 // a query vector was passed to lsh->compute_hash_functions() | |
104 // the query-to-boundary distances are stored in x[hashFunIndex] | |
105 // | |
106 // post-conditions: | |
107 // generates an ordered list of perturbation sets (stored as an array of sets) | |
108 // these are indexes into pi_j=(i,delta) pairs representing x_i(delta) in sort order z_j | |
109 // data structures are cleared and reset to zeros thereby making them re-entrant | |
110 // | |
111 void MultiProbe::algorithm1(unsigned T){ | |
112 perturbation_set ai,as,ae; | |
113 float ai_score; | |
114 ai.insert(0); // Initialize for this query | |
115 minHeap->push(min_heap_element(ai, score(ai))); // unique instance stored in mhe | |
116 | |
117 min_heap_element mhe = minHeap->top(); | |
118 | |
119 if(T>distFuns->size()) | |
120 T = distFuns->size(); | |
121 for(unsigned i = 0 ; i != T ; i++ ){ | |
122 do{ | |
123 mhe = minHeap->top(); | |
124 ai = mhe.perturbs; | |
125 ai_score = mhe.score; | |
126 minHeap->pop(); | |
127 as=ai; | |
128 shift(as); | |
129 minHeap->push(min_heap_element(as, score(as))); | |
130 ae=ai; | |
131 expand(ae); | |
132 minHeap->push(min_heap_element(ae, score(ae))); | |
133 }while(!valid(ai)); | |
134 outSets->push(mhe); // Ordered list of perturbation sets | |
135 } | |
136 } | |
137 | |
138 void MultiProbe::dump(perturbation_set a){ | |
139 perturbation_set::iterator it = a.begin(); | |
140 while(it != a.end()){ | |
141 cout << "[" << (*distFuns)[*it].second.first << "," << (*distFuns)[*it].second.second << "]" << " " | |
142 << (*distFuns)[*it].first << *it << ", "; | |
143 it++; | |
144 } | |
145 cout << "(" << score(a) << ")"; | |
146 cout << endl; | |
147 } | |
148 | |
149 // Given the set a, add 1 to last element of the set | |
150 inline perturbation_set& MultiProbe::shift(perturbation_set& a){ | |
151 perturbation_set::iterator it = a.end(); | |
152 int val = *(--it) + 1; | |
153 a.erase(it); | |
154 a.insert(it,val); | |
155 return a; | |
156 } | |
157 | |
158 // Given the set a, add a new element one greater than the max | |
159 inline perturbation_set& MultiProbe::expand(perturbation_set& a){ | |
160 perturbation_set::reverse_iterator ri = a.rbegin(); | |
161 a.insert(*ri+1); | |
162 return a; | |
163 } | |
164 | |
165 // Take the list of distances (x) assuming list len is 2M and | |
166 // delta = (-1)^i, i = { 0 .. 2M-1 } | |
167 void MultiProbe::makeSortedDistFuns(vector<float>& x){ | |
168 numHashBoundaries = x.size(); // x.size() == 2M | |
169 delete distFuns; | |
170 distFuns = new std::vector<sorted_distance_functions>(numHashBoundaries); | |
171 for(unsigned i = 0; i != numHashBoundaries ; i++ ) | |
172 (*distFuns)[i] = make_pair(x[i], make_pair(i, i%2?1:-1)); | |
173 // SORT | |
174 sort( distFuns->begin(), distFuns->end() ); | |
175 } | |
176 | |
177 // Float array version of above | |
178 void MultiProbe::makeSortedDistFuns(float* x, unsigned N){ | |
179 numHashBoundaries = N; // x.size() == 2M | |
180 delete distFuns; | |
181 distFuns = new std::vector<sorted_distance_functions>(numHashBoundaries); | |
182 for(unsigned i = 0; i != numHashBoundaries ; i++ ) | |
183 (*distFuns)[i] = make_pair(x[i], make_pair(i, i%2?1:-1)); | |
184 // SORT | |
185 sort( distFuns->begin(), distFuns->end() ); | |
186 } | |
187 | |
188 // For a given perturbation set, the score is the | |
189 // sum of squares of corresponding distances in x | |
190 float MultiProbe::score(perturbation_set& a){ | |
191 //assert(!a.empty()); | |
192 float score = 0.0, tmp; | |
193 perturbation_set::iterator it; | |
194 it = a.begin(); | |
195 do{ | |
196 tmp = (*distFuns)[*it].first; | |
197 score += tmp*tmp; | |
198 }while( ++it != a.end() ); | |
199 return score; | |
200 } | |
201 | |
202 // A valid set must have at most one | |
203 // of the two elements {j, 2M + 1 - j} for every j | |
204 // | |
205 // A perturbation set containing an element > 2M is invalid | |
206 bool MultiProbe::valid(perturbation_set& a){ | |
207 int j; | |
208 perturbation_set::iterator it = a.begin(); | |
209 while( it != a.end() ){ | |
210 j = *it; | |
211 it++; | |
212 if( ( (unsigned)j > numHashBoundaries ) || ( a.find( numHashBoundaries - j - 1 ) != a.end() ) ) | |
213 return false; | |
214 } | |
215 return true; | |
216 } | |
217 | |
218 int MultiProbe::getIndex(perturbation_set::iterator it){ | |
219 return (*distFuns)[*it].second.first; | |
220 } | |
221 | |
222 int MultiProbe::getBoundary(perturbation_set::iterator it){ | |
223 return (*distFuns)[*it].second.second; | |
224 } | |
225 | |
226 // copy return next perturbation_set | |
227 perturbation_set MultiProbe::getNextPerturbationSet(){ | |
228 perturbation_set s = outSets->top().perturbs; | |
229 outSets->pop(); | |
230 return s; | |
231 } | |
232 | |
// Test routine, compiled only when _TEST_MP_LSH is defined:
//   fill a vector with random boundary-distance pairs (d, W - d),
//   call generatePerturbationSets on these distances,
//   dump every resulting perturbation set for inspection.
// The same MultiProbe instance is reused for all iterations to exercise
// repeated use.
#ifdef _TEST_MP_LSH
int main(const int argc, const char* argv[]){
  const int N_SAMPS = 100;  // Number of random samples
  const int W = 4;          // simulated hash-bucket size
  const int N_ITER = 100;   // How many re-entrant iterations
  const unsigned T = 10;    // Number of multi-probe sets to generate

  MultiProbe mp;
  vector<float> x(N_SAMPS);

  srand((unsigned)time(0));

  for(int iter = 0; iter < N_ITER; iter++){
    cout << "********** ITERATION " << iter << " **********" << endl;
    cout.flush();

    // Each pair of neighbouring entries sums to W
    const size_t numPairs = x.size()/2;
    for(size_t p = 0; p != numPairs; p++){
      x[2*p] = W*(rand()/(RAND_MAX+1.0));
      x[2*p+1] = W - x[2*p];
    }

    // Generate multi-probe sets
    mp.generatePerturbationSets(x, T);

    // Output contents of multi-probe sets
    while(!mp.empty()){
      mp.dump(mp.getNextPerturbationSet());
    }
  }
}
#endif