Mercurial > hg > audiodb
comparison query.cpp @ 458:913a95f06998 api-inversion
Start using the query state structure.
Actually using it means moving it around in the source code a little
bit, thanks to entanglements. It'll all be alright on the night. Now
the accumulator, allowed_keys and exact_evaluation_queue are all part of
the query state, and can therefore be passed around with minimal effort
(and deleted in the appropriate place).
Now a whole bunch of static methods (the callbacks, basically) in
index.cpp can be rewritten as plain old C functions. The callbacks need
both an adb_t and a query state structure to function (the adb_t to get
at things like lsh_n_point_bits and the track->key table; the qstate to
get at the accumulator and the allowed_keys set).
Rearrange audioDB::query a little bit, and mark the beginning and the
end of the putative audiodb_query_spec() API function implementation.
author | mas01cr |
---|---|
date | Sun, 28 Dec 2008 18:44:08 +0000 |
parents | 93ce12fe2f76 |
children | fcc6f7c4856b |
comparison
equal
deleted
inserted
replaced
457:823bca1e10f5 | 458:913a95f06998 |
---|---|
147 if(!(qspec.refine.flags & ADB_REFINE_RADIUS)) { | 147 if(!(qspec.refine.flags & ADB_REFINE_RADIUS)) { |
148 reporter = new trackAveragingReporter< std::less< NNresult > >(pointNN, trackNN, dbH->numFiles); | 148 reporter = new trackAveragingReporter< std::less< NNresult > >(pointNN, trackNN, dbH->numFiles); |
149 } else if (index_exists(adb->path, qspec.refine.radius, qspec.qid.sequence_length)) { | 149 } else if (index_exists(adb->path, qspec.refine.radius, qspec.qid.sequence_length)) { |
150 char* indexName = index_get_name(adb->path, qspec.refine.radius, qspec.qid.sequence_length); | 150 char* indexName = index_get_name(adb->path, qspec.refine.radius, qspec.qid.sequence_length); |
151 lsh = index_allocate(indexName, false); | 151 lsh = index_allocate(indexName, false); |
152 reporter = new trackSequenceQueryRadReporter(trackNN, index_to_trackID(lsh->get_maxp(), lsh_n_point_bits)+1); | 152 reporter = new trackSequenceQueryRadReporter(trackNN, audiodb_index_to_track_id(lsh->get_maxp(), audiodb_lsh_n_point_bits(adb))+1); |
153 delete[] indexName; | 153 delete[] indexName; |
154 } else { | 154 } else { |
155 reporter = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles); | 155 reporter = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles); |
156 } | 156 } |
157 break; | 157 break; |
159 if(!(qspec.refine.flags & ADB_REFINE_RADIUS)) { | 159 if(!(qspec.refine.flags & ADB_REFINE_RADIUS)) { |
160 reporter = new trackSequenceQueryNNReporter< std::less < NNresult > >(pointNN, trackNN, dbH->numFiles); | 160 reporter = new trackSequenceQueryNNReporter< std::less < NNresult > >(pointNN, trackNN, dbH->numFiles); |
161 } else if (index_exists(adb->path, qspec.refine.radius, qspec.qid.sequence_length)){ | 161 } else if (index_exists(adb->path, qspec.refine.radius, qspec.qid.sequence_length)){ |
162 char* indexName = index_get_name(adb->path, qspec.refine.radius, qspec.qid.sequence_length); | 162 char* indexName = index_get_name(adb->path, qspec.refine.radius, qspec.qid.sequence_length); |
163 lsh = index_allocate(indexName, false); | 163 lsh = index_allocate(indexName, false); |
164 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, index_to_trackID(lsh->get_maxp(), lsh_n_point_bits)+1); | 164 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, audiodb_index_to_track_id(lsh->get_maxp(), audiodb_lsh_n_point_bits(adb))+1); |
165 delete[] indexName; | 165 delete[] indexName; |
166 } else { | 166 } else { |
167 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, dbH->numFiles); | 167 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, dbH->numFiles); |
168 } | 168 } |
169 break; | 169 break; |
177 break; | 177 break; |
178 default: | 178 default: |
179 error("unrecognized queryType"); | 179 error("unrecognized queryType"); |
180 } | 180 } |
181 | 181 |
182 // keyKeyPos requires dbH to be initialized | 182 /* Somewhere around here is where the implementation of |
183 if(query_from_key && (!key || (query_from_key_index = audiodb_key_index(adb, key)) == (uint32_t) -1)) | 183 * audiodb_query_spec() starts. */ |
184 error("Query key not found", key); | 184 |
185 adb_qstate_internal_t qstate; | |
186 qstate.allowed_keys = new std::set<std::string>; | |
187 if(qspec.refine.flags & ADB_REFINE_INCLUDE_KEYLIST) { | |
188 for(unsigned int k = 0; k < qspec.refine.include.nkeys; k++) { | |
189 qstate.allowed_keys->insert(qspec.refine.include.keys[k]); | |
190 } | |
191 } else { | |
192 for(unsigned int k = 0; k < adb->header->numFiles; k++) { | |
193 qstate.allowed_keys->insert((*adb->keys)[k]); | |
194 } | |
195 } | |
196 if(qspec.refine.flags & ADB_REFINE_EXCLUDE_KEYLIST) { | |
197 for(unsigned int k = 0; k < qspec.refine.exclude.nkeys; k++) { | |
198 qstate.allowed_keys->erase(qspec.refine.exclude.keys[k]); | |
199 } | |
200 } | |
185 | 201 |
186 switch(qspec.params.distance) { | 202 switch(qspec.params.distance) { |
187 case ADB_DISTANCE_DOT_PRODUCT: | 203 case ADB_DISTANCE_DOT_PRODUCT: |
188 switch(qspec.params.accumulation) { | 204 switch(qspec.params.accumulation) { |
189 case ADB_ACCUMULATION_DB: | 205 case ADB_ACCUMULATION_DB: |
190 accumulator = new DBAccumulator<adb_result_dist_gt>(qspec.params.npoints); | 206 qstate.accumulator = new DBAccumulator<adb_result_dist_gt>(qspec.params.npoints); |
191 break; | 207 break; |
192 case ADB_ACCUMULATION_PER_TRACK: | 208 case ADB_ACCUMULATION_PER_TRACK: |
193 accumulator = new PerTrackAccumulator<adb_result_dist_gt>(qspec.params.npoints, qspec.params.ntracks); | 209 qstate.accumulator = new PerTrackAccumulator<adb_result_dist_gt>(qspec.params.npoints, qspec.params.ntracks); |
194 break; | 210 break; |
195 case ADB_ACCUMULATION_ONE_TO_ONE: | 211 case ADB_ACCUMULATION_ONE_TO_ONE: |
196 accumulator = new NearestAccumulator<adb_result_dist_gt>(); | 212 qstate.accumulator = new NearestAccumulator<adb_result_dist_gt>(); |
197 break; | 213 break; |
198 default: | 214 default: |
199 error("unknown accumulation"); | 215 error("unknown accumulation"); |
200 } | 216 } |
201 break; | 217 break; |
202 case ADB_DISTANCE_EUCLIDEAN_NORMED: | 218 case ADB_DISTANCE_EUCLIDEAN_NORMED: |
203 case ADB_DISTANCE_EUCLIDEAN: | 219 case ADB_DISTANCE_EUCLIDEAN: |
204 switch(qspec.params.accumulation) { | 220 switch(qspec.params.accumulation) { |
205 case ADB_ACCUMULATION_DB: | 221 case ADB_ACCUMULATION_DB: |
206 accumulator = new DBAccumulator<adb_result_dist_lt>(qspec.params.npoints); | 222 qstate.accumulator = new DBAccumulator<adb_result_dist_lt>(qspec.params.npoints); |
207 break; | 223 break; |
208 case ADB_ACCUMULATION_PER_TRACK: | 224 case ADB_ACCUMULATION_PER_TRACK: |
209 accumulator = new PerTrackAccumulator<adb_result_dist_lt>(qspec.params.npoints, qspec.params.ntracks); | 225 qstate.accumulator = new PerTrackAccumulator<adb_result_dist_lt>(qspec.params.npoints, qspec.params.ntracks); |
210 break; | 226 break; |
211 case ADB_ACCUMULATION_ONE_TO_ONE: | 227 case ADB_ACCUMULATION_ONE_TO_ONE: |
212 accumulator = new NearestAccumulator<adb_result_dist_lt>(); | 228 qstate.accumulator = new NearestAccumulator<adb_result_dist_lt>(); |
213 break; | 229 break; |
214 default: | 230 default: |
215 error("unknown accumulation"); | 231 error("unknown accumulation"); |
216 } | 232 } |
217 break; | 233 break; |
220 } | 236 } |
221 | 237 |
222 // Test for index (again) here | 238 // Test for index (again) here |
223 if((qspec.refine.flags & ADB_REFINE_RADIUS) && index_exists(adb->path, qspec.refine.radius, qspec.qid.sequence_length)){ | 239 if((qspec.refine.flags & ADB_REFINE_RADIUS) && index_exists(adb->path, qspec.refine.radius, qspec.qid.sequence_length)){ |
224 VERB_LOG(1, "Calling indexed query on database %s, radius=%f, sequence_length=%d\n", adb->path, qspec.refine.radius, qspec.qid.sequence_length); | 240 VERB_LOG(1, "Calling indexed query on database %s, radius=%f, sequence_length=%d\n", adb->path, qspec.refine.radius, qspec.qid.sequence_length); |
225 index_query_loop(adb, &qspec); | 241 index_query_loop(adb, &qspec, &qstate); |
226 } | 242 } |
227 else{ | 243 else{ |
228 VERB_LOG(1, "Calling brute-force query on database %s\n", dbName); | 244 VERB_LOG(1, "Calling brute-force query on database %s\n", dbName); |
229 if(query_loop(adb, &qspec)) { | 245 if(query_loop(adb, &qspec, &qstate)) { |
230 error("query_loop failed"); | 246 error("query_loop failed"); |
231 } | 247 } |
232 } | 248 } |
233 | 249 |
234 adb_query_results_t *rs = accumulator->get_points(); | 250 adb_query_results_t *rs = qstate.accumulator->get_points(); |
251 | |
252 delete qstate.accumulator; | |
253 delete qstate.allowed_keys; | |
254 | |
255 /* End of audiodb_query_spec() function */ | |
256 | |
235 for(unsigned int k = 0; k < rs->nresults; k++) { | 257 for(unsigned int k = 0; k < rs->nresults; k++) { |
236 adb_result_t r = rs->results[k]; | 258 adb_result_t r = rs->results[k]; |
237 reporter->add_point(audiodb_key_index(adb, r.key), r.qpos, r.ipos, r.dist); | 259 reporter->add_point(audiodb_key_index(adb, r.key), r.qpos, r.ipos, r.dist); |
238 } | 260 } |
239 | 261 |
604 // A reporter has been allocated | 626 // A reporter has been allocated |
605 // | 627 // |
606 // Postconditions: | 628 // Postconditions: |
607 // reporter contains the points and distances that meet the reporter constraints | 629 // reporter contains the points and distances that meet the reporter constraints |
608 | 630 |
609 void audioDB::query_loop_points(adb_t *adb, adb_query_spec_t *spec, double* query, adb_qpointers_internal_t *qpointers) { | 631 void audioDB::query_loop_points(adb_t *adb, adb_query_spec_t *spec, adb_qstate_internal_t *qstate, double *query, adb_qpointers_internal_t *qpointers) { |
610 adb_qpointers_internal_t dbpointers = {0}; | 632 adb_qpointers_internal_t dbpointers = {0}; |
611 | 633 |
612 uint32_t sequence_length = spec->qid.sequence_length; | 634 uint32_t sequence_length = spec->qid.sequence_length; |
613 bool power_refine = spec->refine.flags & (ADB_REFINE_ABSOLUTE_THRESHOLD|ADB_REFINE_RELATIVE_THRESHOLD); | 635 bool power_refine = spec->refine.flags & (ADB_REFINE_ABSOLUTE_THRESHOLD|ADB_REFINE_RELATIVE_THRESHOLD); |
614 | 636 |
615 if(exact_evaluation_queue->size() == 0) { | 637 if(qstate->exact_evaluation_queue->size() == 0) { |
616 return; | 638 return; |
617 } | 639 } |
618 | 640 |
619 // Compute database info. FIXME: we more than likely don't need | 641 // Compute database info. FIXME: we more than likely don't need |
620 // very much of the database so write a new function to build these | 642 // very much of the database so write a new function to build these |
636 size_t data_buffer_size = 0; | 658 size_t data_buffer_size = 0; |
637 double *data_buffer = 0; | 659 double *data_buffer = 0; |
638 Uns32T trackOffset = 0; | 660 Uns32T trackOffset = 0; |
639 Uns32T trackIndexOffset = 0; | 661 Uns32T trackIndexOffset = 0; |
640 Uns32T currentTrack = 0x80000000; // Initialize with a value outside of track index range | 662 Uns32T currentTrack = 0x80000000; // Initialize with a value outside of track index range |
641 Uns32T npairs = exact_evaluation_queue->size(); | 663 Uns32T npairs = qstate->exact_evaluation_queue->size(); |
642 while(npairs--){ | 664 while(npairs--){ |
643 PointPair pp = exact_evaluation_queue->top(); | 665 PointPair pp = qstate->exact_evaluation_queue->top(); |
644 // Large ADB track data must be loaded here for sPower | 666 // Large ADB track data must be loaded here for sPower |
645 if(adb->header->flags & O2_FLAG_LARGE_ADB) { | 667 if(adb->header->flags & O2_FLAG_LARGE_ADB) { |
646 trackOffset=0; | 668 trackOffset=0; |
647 trackIndexOffset=0; | 669 trackIndexOffset=0; |
648 if(currentTrack!=pp.trackID){ | 670 if(currentTrack!=pp.trackID){ |
701 adb_result_t r; | 723 adb_result_t r; |
702 r.key = (*adb->keys)[pp.trackID].c_str(); | 724 r.key = (*adb->keys)[pp.trackID].c_str(); |
703 r.dist = dist; | 725 r.dist = dist; |
704 r.qpos = pp.qpos; | 726 r.qpos = pp.qpos; |
705 r.ipos = pp.spos; | 727 r.ipos = pp.spos; |
706 accumulator->add_point(&r); | 728 qstate->accumulator->add_point(&r); |
707 } | 729 } |
708 } | 730 } |
709 exact_evaluation_queue->pop(); | 731 qstate->exact_evaluation_queue->pop(); |
710 } | 732 } |
711 // Cleanup | 733 // Cleanup |
712 SAFE_DELETE_ARRAY(dbpointers.l2norm_data); | 734 SAFE_DELETE_ARRAY(dbpointers.l2norm_data); |
713 SAFE_DELETE_ARRAY(dbpointers.power_data); | 735 SAFE_DELETE_ARRAY(dbpointers.power_data); |
714 SAFE_DELETE_ARRAY(dbpointers.mean_duration); | 736 SAFE_DELETE_ARRAY(dbpointers.mean_duration); |
715 } | 737 delete qstate->exact_evaluation_queue; |
716 | 738 } |
717 int audioDB::query_loop(adb_t *adb, adb_query_spec_t *spec) { | 739 |
740 int audioDB::query_loop(adb_t *adb, adb_query_spec_t *spec, adb_qstate_internal_t *qstate) { | |
718 | 741 |
719 double *query, *query_data; | 742 double *query, *query_data; |
720 adb_qpointers_internal_t qpointers = {0}, dbpointers = {0}; | 743 adb_qpointers_internal_t qpointers = {0}, dbpointers = {0}; |
721 | 744 |
722 bool power_refine = spec->refine.flags & (ADB_REFINE_ABSOLUTE_THRESHOLD|ADB_REFINE_RELATIVE_THRESHOLD); | 745 bool power_refine = spec->refine.flags & (ADB_REFINE_ABSOLUTE_THRESHOLD|ADB_REFINE_RELATIVE_THRESHOLD); |
723 | |
724 std::set<std::string> keys; | |
725 if(spec->refine.flags & ADB_REFINE_INCLUDE_KEYLIST) { | |
726 for(unsigned int k = 0; k < spec->refine.include.nkeys; k++) { | |
727 keys.insert(spec->refine.include.keys[k]); | |
728 } | |
729 } else { | |
730 for(unsigned int k = 0; k < adb->header->numFiles; k++) { | |
731 keys.insert((*adb->keys)[k]); | |
732 } | |
733 } | |
734 if(spec->refine.flags & ADB_REFINE_EXCLUDE_KEYLIST) { | |
735 for(unsigned int k = 0; k < spec->refine.exclude.nkeys; k++) { | |
736 keys.erase(spec->refine.exclude.keys[k]); | |
737 } | |
738 } | |
739 | 746 |
740 if(adb->header->flags & O2_FLAG_LARGE_ADB) { | 747 if(adb->header->flags & O2_FLAG_LARGE_ADB) { |
741 /* FIXME: actually it would be nice to support this mode of | 748 /* FIXME: actually it would be nice to support this mode of |
742 * operation, but for now... */ | 749 * operation, but for now... */ |
743 return 1; | 750 return 1; |
764 // Track loop | 771 // Track loop |
765 size_t data_buffer_size = 0; | 772 size_t data_buffer_size = 0; |
766 double *data_buffer = 0; | 773 double *data_buffer = 0; |
767 lseek(adb->fd, adb->header->dataOffset, SEEK_SET); | 774 lseek(adb->fd, adb->header->dataOffset, SEEK_SET); |
768 | 775 |
776 std::set<std::string>::iterator keys_end = qstate->allowed_keys->end(); | |
769 for(track = 0; track < adb->header->numFiles; track++) { | 777 for(track = 0; track < adb->header->numFiles; track++) { |
770 unsigned t = track; | 778 unsigned t = track; |
771 | 779 |
772 while (keys.find((*adb->keys)[track]) == keys.end()) { | 780 while (qstate->allowed_keys->find((*adb->keys)[track]) == keys_end) { |
773 track++; | 781 track++; |
774 if(track == adb->header->numFiles) { | 782 if(track == adb->header->numFiles) { |
775 goto loop_finish; | 783 goto loop_finish; |
776 } | 784 } |
777 } | 785 } |
820 r.qpos = j; | 828 r.qpos = j; |
821 } else { | 829 } else { |
822 r.qpos = spec->qid.sequence_start; | 830 r.qpos = spec->qid.sequence_start; |
823 } | 831 } |
824 r.ipos = k; | 832 r.ipos = k; |
825 accumulator->add_point(&r); | 833 qstate->accumulator->add_point(&r); |
826 } | 834 } |
827 } | 835 } |
828 } | 836 } |
829 } | 837 } |
830 } // Duration match | 838 } // Duration match |