Mercurial > hg > audiodb
comparison query.cpp @ 431:8632cd387e24 api-inversion
Punishment gluttony.
Continue teasing out vague orthogonalities by beginning the task of
using an adb_query_parameters_t. This does have the benefit of making
the distance calculation clearer, and we begin to see the shape of a
putative audiodb_query() emerging from the shrapnel of audioDB::query.
(Only the general shape; the detail is still a long, long way away).
author | mas01cr |
---|---|
date | Wed, 24 Dec 2008 10:55:40 +0000 |
parents | 2d14d21f826b |
children | 681837f7c903 |
comparison
equal
deleted
inserted
replaced
430:2d14d21f826b | 431:8632cd387e24 |
---|---|
18 return true; | 18 return true; |
19 } | 19 } |
20 | 20 |
21 void audioDB::query(const char* dbName, const char* inFile, adb__queryResponse *adbQueryResponse) { | 21 void audioDB::query(const char* dbName, const char* inFile, adb__queryResponse *adbQueryResponse) { |
22 | 22 |
23 // init database tables and dbH first | |
24 if(query_from_key) | |
25 initTables(dbName); | |
26 else | |
27 initTables(dbName, inFile); | |
28 | |
23 adb_query_refine_t refine; | 29 adb_query_refine_t refine; |
30 adb_query_parameters_t params; | |
24 refine.flags = 0; | 31 refine.flags = 0; |
25 /* FIXME: trackFile / ADB_REFINE_KEYLIST */ | 32 /* FIXME: trackFile / ADB_REFINE_KEYLIST */ |
26 if(radius) { | 33 if(radius) { |
27 refine.flags |= ADB_REFINE_RADIUS; | 34 refine.flags |= ADB_REFINE_RADIUS; |
28 refine.radius = radius; | 35 refine.radius = radius; |
43 if(sequenceHop != 1) { | 50 if(sequenceHop != 1) { |
44 refine.flags |= ADB_REFINE_HOP_SIZE; | 51 refine.flags |= ADB_REFINE_HOP_SIZE; |
45 refine.hopsize = sequenceHop; | 52 refine.hopsize = sequenceHop; |
46 } | 53 } |
47 | 54 |
48 // init database tables and dbH first | 55 switch(queryType) { |
49 if(query_from_key) | |
50 initTables(dbName); | |
51 else | |
52 initTables(dbName, inFile); | |
53 | |
54 // keyKeyPos requires dbH to be initialized | |
55 if(query_from_key && (!key || (query_from_key_index = audiodb_key_index(adb, key)) == (uint32_t) -1)) | |
56 error("Query key not found", key); | |
57 | |
58 switch (queryType) { | |
59 case O2_POINT_QUERY: | 56 case O2_POINT_QUERY: |
60 sequenceLength = 1; | 57 sequenceLength = 1; |
61 normalizedDistance = false; | 58 params.accumulation = ADB_ACCUMULATION_DB; |
59 params.distance = ADB_DISTANCE_DOT_PRODUCT; | |
60 params.npoints = pointNN; | |
61 params.ntracks = 0; | |
62 reporter = new pointQueryReporter< std::greater < NNresult > >(pointNN); | 62 reporter = new pointQueryReporter< std::greater < NNresult > >(pointNN); |
63 accumulator = new DBAccumulator<adb_result_dist_gt>(pointNN); | |
64 break; | 63 break; |
65 case O2_TRACK_QUERY: | 64 case O2_TRACK_QUERY: |
66 sequenceLength = 1; | 65 sequenceLength = 1; |
67 normalizedDistance = false; | 66 params.accumulation = ADB_ACCUMULATION_PER_TRACK; |
67 params.distance = ADB_DISTANCE_DOT_PRODUCT; | |
68 params.npoints = pointNN; | |
69 params.ntracks = trackNN; | |
68 reporter = new trackAveragingReporter< std::greater< NNresult > >(pointNN, trackNN, dbH->numFiles); | 70 reporter = new trackAveragingReporter< std::greater< NNresult > >(pointNN, trackNN, dbH->numFiles); |
69 accumulator = new PerTrackAccumulator<adb_result_dist_gt>(pointNN, trackNN); | |
70 break; | 71 break; |
71 case O2_SEQUENCE_QUERY: | 72 case O2_SEQUENCE_QUERY: |
72 if(no_unit_norming) | 73 case O2_N_SEQUENCE_QUERY: |
73 normalizedDistance = false; | 74 params.accumulation = ADB_ACCUMULATION_PER_TRACK; |
74 accumulator = new PerTrackAccumulator<adb_result_dist_lt>(pointNN, trackNN); | 75 params.distance = no_unit_norming ? ADB_DISTANCE_EUCLIDEAN : ADB_DISTANCE_EUCLIDEAN_NORMED; |
75 if(!(refine.flags & ADB_REFINE_RADIUS)) { | 76 params.npoints = pointNN; |
76 reporter = new trackAveragingReporter< std::less< NNresult > >(pointNN, trackNN, dbH->numFiles); | 77 params.ntracks = trackNN; |
77 } else { | 78 switch(queryType) { |
78 if(index_exists(dbName, radius, sequenceLength)){ | 79 case O2_SEQUENCE_QUERY: |
80 if(!(refine.flags & ADB_REFINE_RADIUS)) { | |
81 reporter = new trackAveragingReporter< std::less< NNresult > >(pointNN, trackNN, dbH->numFiles); | |
82 } else if (index_exists(dbName, radius, sequenceLength)) { | |
79 char* indexName = index_get_name(dbName, radius, sequenceLength); | 83 char* indexName = index_get_name(dbName, radius, sequenceLength); |
80 lsh = index_allocate(indexName, false); | 84 lsh = index_allocate(indexName, false); |
81 reporter = new trackSequenceQueryRadReporter(trackNN, index_to_trackID(lsh->get_maxp(), lsh_n_point_bits)+1); | 85 reporter = new trackSequenceQueryRadReporter(trackNN, index_to_trackID(lsh->get_maxp(), lsh_n_point_bits)+1); |
82 delete[] indexName; | 86 delete[] indexName; |
83 } | 87 } else { |
84 else | |
85 reporter = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles); | 88 reporter = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles); |
86 } | 89 } |
87 break; | 90 break; |
88 case O2_N_SEQUENCE_QUERY: | 91 case O2_N_SEQUENCE_QUERY: |
89 if(no_unit_norming) | 92 if(!(refine.flags & ADB_REFINE_RADIUS)) { |
90 normalizedDistance = false; | 93 reporter = new trackSequenceQueryNNReporter< std::less < NNresult > >(pointNN, trackNN, dbH->numFiles); |
91 accumulator = new PerTrackAccumulator<adb_result_dist_lt>(pointNN, trackNN); | 94 } else if (index_exists(dbName, radius, sequenceLength)){ |
92 if(!(refine.flags & ADB_REFINE_RADIUS)) { | |
93 reporter = new trackSequenceQueryNNReporter< std::less < NNresult > >(pointNN, trackNN, dbH->numFiles); | |
94 } else { | |
95 if(index_exists(dbName, radius, sequenceLength)){ | |
96 char* indexName = index_get_name(dbName, radius, sequenceLength); | 95 char* indexName = index_get_name(dbName, radius, sequenceLength); |
97 lsh = index_allocate(indexName, false); | 96 lsh = index_allocate(indexName, false); |
98 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, index_to_trackID(lsh->get_maxp(), lsh_n_point_bits)+1); | 97 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, index_to_trackID(lsh->get_maxp(), lsh_n_point_bits)+1); |
99 delete[] indexName; | 98 delete[] indexName; |
100 } | 99 } else { |
101 else | |
102 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, dbH->numFiles); | 100 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, dbH->numFiles); |
101 } | |
102 break; | |
103 } | 103 } |
104 break; | 104 break; |
105 case O2_ONE_TO_ONE_N_SEQUENCE_QUERY : | 105 case O2_ONE_TO_ONE_N_SEQUENCE_QUERY: |
106 accumulator = new NearestAccumulator<adb_result_dist_lt>(); | 106 params.accumulation = ADB_ACCUMULATION_ONE_TO_ONE; |
107 if(!(refine.flags & ADB_REFINE_RADIUS)) { | 107 params.distance = ADB_DISTANCE_EUCLIDEAN_NORMED; |
108 error("query-type not yet supported"); | 108 params.npoints = 0; |
109 } else { | 109 params.ntracks = 0; |
110 reporter = new trackSequenceQueryRadNNReporterOneToOne(pointNN,trackNN, dbH->numFiles); | |
111 } | |
112 break; | 110 break; |
113 default: | 111 default: |
114 error("unrecognized queryType in query()"); | 112 error("unrecognized queryType"); |
115 } | 113 } |
116 | 114 |
115 // keyKeyPos requires dbH to be initialized | |
116 if(query_from_key && (!key || (query_from_key_index = audiodb_key_index(adb, key)) == (uint32_t) -1)) | |
117 error("Query key not found", key); | |
118 | |
119 switch(params.distance) { | |
120 case ADB_DISTANCE_DOT_PRODUCT: | |
121 switch(params.accumulation) { | |
122 case ADB_ACCUMULATION_DB: | |
123 accumulator = new DBAccumulator<adb_result_dist_gt>(params.npoints); | |
124 break; | |
125 case ADB_ACCUMULATION_PER_TRACK: | |
126 accumulator = new PerTrackAccumulator<adb_result_dist_gt>(params.npoints, params.ntracks); | |
127 break; | |
128 case ADB_ACCUMULATION_ONE_TO_ONE: | |
129 accumulator = new NearestAccumulator<adb_result_dist_gt>(); | |
130 break; | |
131 default: | |
132 error("unknown accumulation"); | |
133 } | |
134 break; | |
135 case ADB_DISTANCE_EUCLIDEAN_NORMED: | |
136 case ADB_DISTANCE_EUCLIDEAN: | |
137 switch(params.accumulation) { | |
138 case ADB_ACCUMULATION_DB: | |
139 accumulator = new DBAccumulator<adb_result_dist_lt>(params.npoints); | |
140 break; | |
141 case ADB_ACCUMULATION_PER_TRACK: | |
142 accumulator = new PerTrackAccumulator<adb_result_dist_lt>(params.npoints, params.ntracks); | |
143 break; | |
144 case ADB_ACCUMULATION_ONE_TO_ONE: | |
145 accumulator = new NearestAccumulator<adb_result_dist_lt>(); | |
146 break; | |
147 default: | |
148 error("unknown accumulation"); | |
149 } | |
150 break; | |
151 default: | |
152 error("unknown distance function"); | |
153 } | |
154 | |
117 // Test for index (again) here | 155 // Test for index (again) here |
118 if((refine.flags & ADB_REFINE_RADIUS) && index_exists(dbName, radius, sequenceLength)){ | 156 if((refine.flags & ADB_REFINE_RADIUS) && index_exists(dbName, radius, sequenceLength)){ |
119 VERB_LOG(1, "Calling indexed query on database %s, radius=%f, sequenceLength=%d\n", dbName, radius, sequenceLength); | 157 VERB_LOG(1, "Calling indexed query on database %s, radius=%f, sequenceLength=%d\n", dbName, radius, sequenceLength); |
120 index_query_loop(&refine, dbName, query_from_key_index); | 158 index_query_loop(&params, &refine, dbName, query_from_key_index); |
121 } | 159 } |
122 else{ | 160 else{ |
123 VERB_LOG(1, "Calling brute-force query on database %s\n", dbName); | 161 VERB_LOG(1, "Calling brute-force query on database %s\n", dbName); |
124 query_loop(&refine, query_from_key_index); | 162 query_loop(&params, &refine, query_from_key_index); |
125 } | 163 } |
126 | 164 |
127 adb_query_results_t *rs = accumulator->get_points(); | 165 adb_query_results_t *rs = accumulator->get_points(); |
128 for(unsigned int k = 0; k < rs->nresults; k++) { | 166 for(unsigned int k = 0; k < rs->nresults; k++) { |
129 adb_result_t r = rs->results[k]; | 167 adb_result_t r = rs->results[k]; |
513 // A reporter has been allocated | 551 // A reporter has been allocated |
514 // | 552 // |
515 // Postconditions: | 553 // Postconditions: |
516 // reporter contains the points and distances that meet the reporter constraints | 554 // reporter contains the points and distances that meet the reporter constraints |
517 | 555 |
518 void audioDB::query_loop_points(double* query, double* qnPtr, double* qpPtr, double meanQdur, Uns32T numVectors, adb_query_refine_t *refine){ | 556 void audioDB::query_loop_points(double* query, double* qnPtr, double* qpPtr, double meanQdur, Uns32T numVectors, adb_query_parameters_t *params, adb_query_refine_t *refine){ |
519 unsigned int dbVectors; | 557 unsigned int dbVectors; |
520 double *sNorm = 0, *snPtr, *sPower = 0, *spPtr = 0; | 558 double *sNorm = 0, *snPtr, *sPower = 0, *spPtr = 0; |
521 double *meanDBdur = 0; | 559 double *meanDBdur = 0; |
522 | 560 |
523 // check pre-conditions | 561 // check pre-conditions |
530 // so make a new method to build these values per-track or, even better, per-point | 568 // so make a new method to build these values per-track or, even better, per-point |
531 if( !( dbH->flags & O2_FLAG_LARGE_ADB) ) | 569 if( !( dbH->flags & O2_FLAG_LARGE_ADB) ) |
532 set_up_db(&sNorm, &snPtr, &sPower, &spPtr, &meanDBdur, &dbVectors); | 570 set_up_db(&sNorm, &snPtr, &sPower, &spPtr, &meanDBdur, &dbVectors); |
533 | 571 |
534 VERB_LOG(1, "matching points..."); | 572 VERB_LOG(1, "matching points..."); |
535 | |
536 assert(pointNN>0 && pointNN<=O2_MAXNN); | |
537 assert(trackNN>0 && trackNN<=O2_MAXNN); | |
538 | 573 |
539 // We are guaranteed that the order of points is sorted by: | 574 // We are guaranteed that the order of points is sorted by: |
540 // trackID, spos, qpos | 575 // trackID, spos, qpos |
541 // so we can be relatively efficient in initialization of track data. | 576 // so we can be relatively efficient in initialization of track data. |
542 // Here we assume that points don't overlap, so we will use exhaustive dot | 577 // Here we assume that points don't overlap, so we will use exhaustive dot |
595 } | 630 } |
596 // Compute distance | 631 // Compute distance |
597 dist = audiodb_dot_product(query+qPos*dbH->dim, data_buffer+pp.spos*dbH->dim, dbH->dim*sequenceLength); | 632 dist = audiodb_dot_product(query+qPos*dbH->dim, data_buffer+pp.spos*dbH->dim, dbH->dim*sequenceLength); |
598 double qn = qnPtr[qPos]; | 633 double qn = qnPtr[qPos]; |
599 double sn = sNorm[sPos]; | 634 double sn = sNorm[sPos]; |
600 if(normalizedDistance) | 635 switch(params->distance) { |
636 case ADB_DISTANCE_EUCLIDEAN_NORMED: | |
601 dist = 2 - (2/(qn*sn))*dist; | 637 dist = 2 - (2/(qn*sn))*dist; |
602 else | 638 break; |
603 if(no_unit_norming) | 639 case ADB_DISTANCE_EUCLIDEAN: |
604 dist = qn*qn + sn*sn - 2*dist; | 640 dist = qn*qn + sn*sn - 2*dist; |
605 // else | 641 break; |
606 // dist = dist; | 642 } |
607 if((!radius) || dist <= (O2_LSH_EXACT_MULT*radius+O2_DISTANCE_TOLERANCE)) { | 643 if((!radius) || dist <= (O2_LSH_EXACT_MULT*radius+O2_DISTANCE_TOLERANCE)) { |
608 adb_result_t r; | 644 adb_result_t r; |
609 r.key = fileTable + pp.trackID * O2_FILETABLE_ENTRY_SIZE; | 645 r.key = fileTable + pp.trackID * O2_FILETABLE_ENTRY_SIZE; |
610 r.dist = dist; | 646 r.dist = dist; |
611 r.qpos = pp.qpos; | 647 r.qpos = pp.qpos; |
619 SAFE_DELETE_ARRAY(sNorm); | 655 SAFE_DELETE_ARRAY(sNorm); |
620 SAFE_DELETE_ARRAY(sPower); | 656 SAFE_DELETE_ARRAY(sPower); |
621 SAFE_DELETE_ARRAY(meanDBdur); | 657 SAFE_DELETE_ARRAY(meanDBdur); |
622 } | 658 } |
623 | 659 |
624 void audioDB::query_loop(adb_query_refine_t *refine, Uns32T queryIndex) { | 660 void audioDB::query_loop(adb_query_parameters_t *params, adb_query_refine_t *refine, Uns32T queryIndex) { |
625 | 661 |
626 unsigned int numVectors; | 662 unsigned int numVectors; |
627 double *query, *query_data; | 663 double *query, *query_data; |
628 double *qNorm, *qnPtr, *qPower = 0, *qpPtr = 0; | 664 double *qNorm, *qnPtr, *qPower = 0, *qpPtr = 0; |
629 double meanQdur; | 665 double meanQdur; |
642 | 678 |
643 set_up_db(&sNorm, &snPtr, &sPower, &spPtr, &meanDBdur, &dbVectors); | 679 set_up_db(&sNorm, &snPtr, &sPower, &spPtr, &meanDBdur, &dbVectors); |
644 | 680 |
645 VERB_LOG(1, "matching tracks..."); | 681 VERB_LOG(1, "matching tracks..."); |
646 | 682 |
647 assert(pointNN>0 && pointNN<=O2_MAXNN); | |
648 assert(trackNN>0 && trackNN<=O2_MAXNN); | |
649 | |
650 unsigned j,k,track,trackOffset=0, HOP_SIZE=sequenceHop, wL=sequenceLength; | 683 unsigned j,k,track,trackOffset=0, HOP_SIZE=sequenceHop, wL=sequenceLength; |
651 double **D = 0; // Differences query and target | 684 double **D = 0; // Differences query and target |
652 double **DD = 0; // Matched filter distance | 685 double **DD = 0; // Matched filter distance |
653 | 686 |
654 D = new double*[numVectors]; // pre-allocate | 687 D = new double*[numVectors]; // pre-allocate |
714 } | 747 } |
715 | 748 |
716 // Search for minimum distance by shingles (concatenated vectors) | 749 // Search for minimum distance by shingles (concatenated vectors) |
717 for(j = 0; j <= numVectors - wL; j += HOP_SIZE) { | 750 for(j = 0; j <= numVectors - wL; j += HOP_SIZE) { |
718 for(k = 0; k <= trackTable[track] - wL; k += HOP_SIZE) { | 751 for(k = 0; k <= trackTable[track] - wL; k += HOP_SIZE) { |
719 double thisDist; | 752 double thisDist = 0; |
720 if(normalizedDistance) | 753 switch(params->distance) { |
754 case ADB_DISTANCE_EUCLIDEAN_NORMED: | |
721 thisDist = 2-(2/(qnPtr[j]*sNorm[trackIndexOffset+k]))*DD[j][k]; | 755 thisDist = 2-(2/(qnPtr[j]*sNorm[trackIndexOffset+k]))*DD[j][k]; |
722 else | 756 break; |
723 if(no_unit_norming) | 757 case ADB_DISTANCE_EUCLIDEAN: |
724 thisDist = qnPtr[j]*qnPtr[j]+sNorm[trackIndexOffset+k]*sNorm[trackIndexOffset+k] - 2*DD[j][k]; | 758 thisDist = qnPtr[j]*qnPtr[j]+sNorm[trackIndexOffset+k]*sNorm[trackIndexOffset+k] - 2*DD[j][k]; |
725 else | 759 break; |
726 thisDist = DD[j][k]; | 760 case ADB_DISTANCE_DOT_PRODUCT: |
727 | 761 thisDist = DD[j][k]; |
762 break; | |
763 } | |
728 // Power test | 764 // Power test |
729 if ((!usingPower) || audiodb_powers_acceptable(refine, qpPtr[j], sPower[trackIndexOffset + k])) { | 765 if ((!usingPower) || audiodb_powers_acceptable(refine, qpPtr[j], sPower[trackIndexOffset + k])) { |
730 // radius test | 766 // radius test |
731 if((!(refine->flags & ADB_REFINE_RADIUS)) || | 767 if((!(refine->flags & ADB_REFINE_RADIUS)) || |
732 thisDist <= (refine->radius+O2_DISTANCE_TOLERANCE)) { | 768 thisDist <= (refine->radius+O2_DISTANCE_TOLERANCE)) { |