Mercurial > hg > audiodb
comparison query.cpp @ 425:d65410f4bb85 api-inversion
Begin pushing query-refinement information through explicitly.
Create and initialize an adb_query_refine_t struct in audioDB::query,
and pass it through to various callees, which can then use it instead of
implicitly reading the corresponding audioDB member variables (radius,
absolute/relative thresholds, timesTol, sequenceHop).
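This allows some routines to be converted into ordinary static C
functions. Begin doing so: for example, audioDB::powers_acceptable becomes
the static function audiodb_powers_acceptable, and the inline member
dot_product_points is removed in favour of calls to audiodb_dot_product.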
author | mas01cr |
---|---|
date | Wed, 24 Dec 2008 10:55:16 +0000 |
parents | c6046dd80570 |
children | 4a22a0bdf9a9 |
comparison legend: equal | deleted | inserted | replaced
424:c6046dd80570 | 425:d65410f4bb85 |
---|---|
2 #include "reporter.h" | 2 #include "reporter.h" |
3 | 3 |
4 #include "audioDB-internals.h" | 4 #include "audioDB-internals.h" |
5 #include "accumulators.h" | 5 #include "accumulators.h" |
6 | 6 |
7 bool audioDB::powers_acceptable(double p1, double p2) { | 7 static bool audiodb_powers_acceptable(adb_query_refine_t *r, double p1, double p2) { |
8 if (use_absolute_threshold) { | 8 if (r->flags & ADB_REFINE_ABSOLUTE_THRESHOLD) { |
9 if ((p1 < absolute_threshold) || (p2 < absolute_threshold)) { | 9 if ((p1 < r->absolute_threshold) || (p2 < r->absolute_threshold)) { |
10 return false; | 10 return false; |
11 } | 11 } |
12 } | 12 } |
13 if (use_relative_threshold) { | 13 if (r->flags & ADB_REFINE_RELATIVE_THRESHOLD) { |
14 if (fabs(p1-p2) > fabs(relative_threshold)) { | 14 if (fabs(p1-p2) > fabs(r->relative_threshold)) { |
15 return false; | 15 return false; |
16 } | 16 } |
17 } | 17 } |
18 return true; | 18 return true; |
19 } | 19 } |
20 | 20 |
21 void audioDB::query(const char* dbName, const char* inFile, adb__queryResponse *adbQueryResponse) { | 21 void audioDB::query(const char* dbName, const char* inFile, adb__queryResponse *adbQueryResponse) { |
22 | |
23 adb_query_refine_t refine; | |
24 refine.flags = 0; | |
25 /* FIXME: trackFile / ADB_REFINE_KEYLIST */ | |
26 if(radius) { | |
27 refine.flags |= ADB_REFINE_RADIUS; | |
28 refine.radius = radius; | |
29 } | |
30 if(use_absolute_threshold) { | |
31 refine.flags |= ADB_REFINE_ABSOLUTE_THRESHOLD; | |
32 refine.absolute_threshold = absolute_threshold; | |
33 } | |
34 if(use_relative_threshold) { | |
35 refine.flags |= ADB_REFINE_RELATIVE_THRESHOLD; | |
36 refine.relative_threshold = relative_threshold; | |
37 } | |
38 if(usingTimes) { | |
39 refine.flags |= ADB_REFINE_DURATION_RATIO; | |
40 refine.duration_ratio = timesTol; | |
41 } | |
42 /* FIXME: not sure about this any more; maybe it belongs in query_id */ | |
43 if(sequenceHop != 1) { | |
44 refine.flags |= ADB_REFINE_HOP_SIZE; | |
45 refine.hopsize = sequenceHop; | |
46 } | |
47 | |
22 // init database tables and dbH first | 48 // init database tables and dbH first |
23 if(query_from_key) | 49 if(query_from_key) |
24 initTables(dbName); | 50 initTables(dbName); |
25 else | 51 else |
26 initTables(dbName, inFile); | 52 initTables(dbName, inFile); |
44 break; | 70 break; |
45 case O2_SEQUENCE_QUERY: | 71 case O2_SEQUENCE_QUERY: |
46 if(no_unit_norming) | 72 if(no_unit_norming) |
47 normalizedDistance = false; | 73 normalizedDistance = false; |
48 accumulator = new PerTrackAccumulator<adb_result_dist_lt>(pointNN, trackNN); | 74 accumulator = new PerTrackAccumulator<adb_result_dist_lt>(pointNN, trackNN); |
49 if(radius == 0) { | 75 if(!(refine.flags & ADB_REFINE_RADIUS)) { |
50 reporter = new trackAveragingReporter< std::less< NNresult > >(pointNN, trackNN, dbH->numFiles); | 76 reporter = new trackAveragingReporter< std::less< NNresult > >(pointNN, trackNN, dbH->numFiles); |
51 } else { | 77 } else { |
52 if(index_exists(dbName, radius, sequenceLength)){ | 78 if(index_exists(dbName, radius, sequenceLength)){ |
53 char* indexName = index_get_name(dbName, radius, sequenceLength); | 79 char* indexName = index_get_name(dbName, radius, sequenceLength); |
54 lsh = index_allocate(indexName, false); | 80 lsh = index_allocate(indexName, false); |
61 break; | 87 break; |
62 case O2_N_SEQUENCE_QUERY: | 88 case O2_N_SEQUENCE_QUERY: |
63 if(no_unit_norming) | 89 if(no_unit_norming) |
64 normalizedDistance = false; | 90 normalizedDistance = false; |
65 accumulator = new PerTrackAccumulator<adb_result_dist_lt>(pointNN, trackNN); | 91 accumulator = new PerTrackAccumulator<adb_result_dist_lt>(pointNN, trackNN); |
66 if(radius == 0) { | 92 if(!(refine.flags & ADB_REFINE_RADIUS)) { |
67 reporter = new trackSequenceQueryNNReporter< std::less < NNresult > >(pointNN, trackNN, dbH->numFiles); | 93 reporter = new trackSequenceQueryNNReporter< std::less < NNresult > >(pointNN, trackNN, dbH->numFiles); |
68 } else { | 94 } else { |
69 if(index_exists(dbName, radius, sequenceLength)){ | 95 if(index_exists(dbName, radius, sequenceLength)){ |
70 char* indexName = index_get_name(dbName, radius, sequenceLength); | 96 char* indexName = index_get_name(dbName, radius, sequenceLength); |
71 lsh = index_allocate(indexName, false); | 97 lsh = index_allocate(indexName, false); |
76 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, dbH->numFiles); | 102 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, dbH->numFiles); |
77 } | 103 } |
78 break; | 104 break; |
79 case O2_ONE_TO_ONE_N_SEQUENCE_QUERY : | 105 case O2_ONE_TO_ONE_N_SEQUENCE_QUERY : |
80 accumulator = new NearestAccumulator<adb_result_dist_lt>(); | 106 accumulator = new NearestAccumulator<adb_result_dist_lt>(); |
81 if(radius == 0) { | 107 if(!(refine.flags & ADB_REFINE_RADIUS)) { |
82 error("query-type not yet supported"); | 108 error("query-type not yet supported"); |
83 } else { | 109 } else { |
84 reporter = new trackSequenceQueryRadNNReporterOneToOne(pointNN,trackNN, dbH->numFiles); | 110 reporter = new trackSequenceQueryRadNNReporterOneToOne(pointNN,trackNN, dbH->numFiles); |
85 } | 111 } |
86 break; | 112 break; |
87 default: | 113 default: |
88 error("unrecognized queryType in query()"); | 114 error("unrecognized queryType in query()"); |
89 } | 115 } |
90 | 116 |
91 // Test for index (again) here | 117 // Test for index (again) here |
92 if(radius && index_exists(dbName, radius, sequenceLength)){ | 118 if((refine.flags & ADB_REFINE_RADIUS) && index_exists(dbName, radius, sequenceLength)){ |
93 VERB_LOG(1, "Calling indexed query on database %s, radius=%f, sequenceLength=%d\n", dbName, radius, sequenceLength); | 119 VERB_LOG(1, "Calling indexed query on database %s, radius=%f, sequenceLength=%d\n", dbName, radius, sequenceLength); |
94 index_query_loop(dbName, query_from_key_index); | 120 index_query_loop(&refine, dbName, query_from_key_index); |
95 } | 121 } |
96 else{ | 122 else{ |
97 VERB_LOG(1, "Calling brute-force query on database %s\n", dbName); | 123 VERB_LOG(1, "Calling brute-force query on database %s\n", dbName); |
98 query_loop(dbName, query_from_key_index); | 124 query_loop(&refine, query_from_key_index); |
99 } | 125 } |
100 | 126 |
101 adb_query_results_t *rs = accumulator->get_points(); | 127 adb_query_results_t *rs = accumulator->get_points(); |
102 for(unsigned int k = 0; k < rs->nresults; k++) { | 128 for(unsigned int k = 0; k < rs->nresults; k++) { |
103 adb_result_t r = rs->results[k]; | 129 adb_result_t r = rs->results[k]; |
548 // A reporter has been allocated | 574 // A reporter has been allocated |
549 // | 575 // |
550 // Postconditions: | 576 // Postconditions: |
551 // reporter contains the points and distances that meet the reporter constraints | 577 // reporter contains the points and distances that meet the reporter constraints |
552 | 578 |
553 void audioDB::query_loop_points(double* query, double* qnPtr, double* qpPtr, double meanQdur, Uns32T numVectors){ | 579 void audioDB::query_loop_points(double* query, double* qnPtr, double* qpPtr, double meanQdur, Uns32T numVectors, adb_query_refine_t *refine){ |
554 unsigned int dbVectors; | 580 unsigned int dbVectors; |
555 double *sNorm = 0, *snPtr, *sPower = 0, *spPtr = 0; | 581 double *sNorm = 0, *snPtr, *sPower = 0, *spPtr = 0; |
556 double *meanDBdur = 0; | 582 double *meanDBdur = 0; |
557 | 583 |
558 // check pre-conditions | 584 // check pre-conditions |
617 trackIndexOffset=trackOffset/dbH->dim; // num vectors offset | 643 trackIndexOffset=trackOffset/dbH->dim; // num vectors offset |
618 } | 644 } |
619 Uns32T qPos = usingQueryPoint?0:pp.qpos;// index for query point | 645 Uns32T qPos = usingQueryPoint?0:pp.qpos;// index for query point |
620 Uns32T sPos = trackIndexOffset+pp.spos; // index into l2norm table | 646 Uns32T sPos = trackIndexOffset+pp.spos; // index into l2norm table |
621 // Test power thresholds before computing distance | 647 // Test power thresholds before computing distance |
622 if( ( !usingPower || powers_acceptable(qpPtr[qPos], sPower[sPos])) && | 648 if( ( !usingPower || audiodb_powers_acceptable(refine, qpPtr[qPos], sPower[sPos])) && |
623 ( qPos<numVectors-sequenceLength+1 && pp.spos<trackTable[pp.trackID]-sequenceLength+1 ) ){ | 649 ( qPos<numVectors-sequenceLength+1 && pp.spos<trackTable[pp.trackID]-sequenceLength+1 ) ){ |
624 // Non-large ADB track data is loaded inside power test for efficiency | 650 // Non-large ADB track data is loaded inside power test for efficiency |
625 if( !(dbH->flags & O2_FLAG_LARGE_ADB) && (currentTrack!=pp.trackID) ){ | 651 if( !(dbH->flags & O2_FLAG_LARGE_ADB) && (currentTrack!=pp.trackID) ){ |
626 // On currentTrack change, allocate and load track data | 652 // On currentTrack change, allocate and load track data |
627 currentTrack=pp.trackID; | 653 currentTrack=pp.trackID; |
628 lseek(dbfid, dbH->dataOffset + trackOffset * sizeof(double), SEEK_SET); | 654 lseek(dbfid, dbH->dataOffset + trackOffset * sizeof(double), SEEK_SET); |
629 read_data(dbfid, currentTrack, &data_buffer, &data_buffer_size); | 655 read_data(dbfid, currentTrack, &data_buffer, &data_buffer_size); |
630 } | 656 } |
631 // Compute distance | 657 // Compute distance |
632 dist = dot_product_points(query+qPos*dbH->dim, data_buffer+pp.spos*dbH->dim, dbH->dim*sequenceLength); | 658 dist = audiodb_dot_product(query+qPos*dbH->dim, data_buffer+pp.spos*dbH->dim, dbH->dim*sequenceLength); |
633 double qn = qnPtr[qPos]; | 659 double qn = qnPtr[qPos]; |
634 double sn = sNorm[sPos]; | 660 double sn = sNorm[sPos]; |
635 if(normalizedDistance) | 661 if(normalizedDistance) |
636 dist = 2 - (2/(qn*sn))*dist; | 662 dist = 2 - (2/(qn*sn))*dist; |
637 else | 663 else |
654 SAFE_DELETE_ARRAY(sNorm); | 680 SAFE_DELETE_ARRAY(sNorm); |
655 SAFE_DELETE_ARRAY(sPower); | 681 SAFE_DELETE_ARRAY(sPower); |
656 SAFE_DELETE_ARRAY(meanDBdur); | 682 SAFE_DELETE_ARRAY(meanDBdur); |
657 } | 683 } |
658 | 684 |
659 // A completely unprotected dot-product method | 685 void audioDB::query_loop(adb_query_refine_t *refine, Uns32T queryIndex) { |
660 // Caller is responsible for ensuring that memory is within bounds | |
661 inline double audioDB::dot_product_points(double* q, double* p, Uns32T L){ | |
662 double dist = 0.0; | |
663 while(L--) | |
664 dist += *q++ * *p++; | |
665 return dist; | |
666 } | |
667 | |
668 void audioDB::query_loop(const char* dbName, Uns32T queryIndex) { | |
669 | 686 |
670 unsigned int numVectors; | 687 unsigned int numVectors; |
671 double *query, *query_data; | 688 double *query, *query_data; |
672 double *qNorm, *qnPtr, *qPower = 0, *qpPtr = 0; | 689 double *qNorm, *qnPtr, *qPower = 0, *qpPtr = 0; |
673 double meanQdur; | 690 double meanQdur; |
743 | 760 |
744 VERB_LOG(7,"%u.%jd.%u | ", track, (intmax_t) trackIndexOffset, trackTable[track]); | 761 VERB_LOG(7,"%u.%jd.%u | ", track, (intmax_t) trackIndexOffset, trackTable[track]); |
745 | 762 |
746 initialize_arrays(track, numVectors, query, data_buffer, D, DD); | 763 initialize_arrays(track, numVectors, query, data_buffer, D, DD); |
747 | 764 |
748 if(usingTimes) { | 765 if(refine->flags & ADB_REFINE_DURATION_RATIO) { |
749 VERB_LOG(3,"meanQdur=%f meanDBdur=%f\n", meanQdur, meanDBdur[track]); | 766 VERB_LOG(3,"meanQdur=%f meanDBdur=%f\n", meanQdur, meanDBdur[track]); |
750 } | 767 } |
751 | 768 |
752 if((!usingTimes) || fabs(meanDBdur[track]-meanQdur) < meanQdur*timesTol) { | 769 if((!(refine->flags & ADB_REFINE_DURATION_RATIO)) || fabs(meanDBdur[track]-meanQdur) < meanQdur*refine->duration_ratio) { |
753 if(usingTimes) { | 770 if(refine->flags & ADB_REFINE_DURATION_RATIO) { |
754 VERB_LOG(3,"within duration tolerance.\n"); | 771 VERB_LOG(3,"within duration tolerance.\n"); |
755 } | 772 } |
756 | 773 |
757 // Search for minimum distance by shingles (concatenated vectors) | 774 // Search for minimum distance by shingles (concatenated vectors) |
758 for(j = 0; j <= numVectors - wL; j += HOP_SIZE) { | 775 for(j = 0; j <= numVectors - wL; j += HOP_SIZE) { |
765 thisDist = qnPtr[j]*qnPtr[j]+sNorm[trackIndexOffset+k]*sNorm[trackIndexOffset+k] - 2*DD[j][k]; | 782 thisDist = qnPtr[j]*qnPtr[j]+sNorm[trackIndexOffset+k]*sNorm[trackIndexOffset+k] - 2*DD[j][k]; |
766 else | 783 else |
767 thisDist = DD[j][k]; | 784 thisDist = DD[j][k]; |
768 | 785 |
769 // Power test | 786 // Power test |
770 if ((!usingPower) || powers_acceptable(qpPtr[j], sPower[trackIndexOffset + k])) { | 787 if ((!usingPower) || audiodb_powers_acceptable(refine, qpPtr[j], sPower[trackIndexOffset + k])) { |
771 // radius test | 788 // radius test |
772 if((!radius) || thisDist <= (radius+O2_DISTANCE_TOLERANCE)) { | 789 if((!(refine->flags & ADB_REFINE_RADIUS)) || |
790 thisDist <= (refine->radius+O2_DISTANCE_TOLERANCE)) { | |
773 adb_result_t r; | 791 adb_result_t r; |
774 r.key = fileTable + track * O2_FILETABLE_ENTRY_SIZE; | 792 r.key = fileTable + track * O2_FILETABLE_ENTRY_SIZE; |
775 r.dist = thisDist; | 793 r.dist = thisDist; |
776 r.qpos = usingQueryPoint ? queryPoint : j; | 794 r.qpos = usingQueryPoint ? queryPoint : j; |
777 r.ipos = k; | 795 r.ipos = k; |
830 } | 848 } |
831 X += dim; | 849 X += dim; |
832 } | 850 } |
833 VERB_LOG(2, "done.\n"); | 851 VERB_LOG(2, "done.\n"); |
834 } | 852 } |
835 | |
836 |