Mercurial > hg > audiodb
comparison query.cpp @ 469:d3afc91d205d api-inversion
Move audioDB::query over to audioDB.cpp
At the same time, remove all the abstraction violations in
audioDB::query, which came in two flavours: (1) use of dbH->numFiles,
which is dealt with by getting the database status instead (this will
eventually be unnecessary; it is needed now only because reporters are
implemented in terms of vectors indexed by ID); and (2) use of fileTable
in the reporters' report functions, which is dealt with by passing in
the adb instead.
For now, to actually implement reporting, we continue to use definitions
from audioDB-internals.h; maybe someday we will be clean and shiny.
author | mas01cr |
---|---|
date | Wed, 31 Dec 2008 15:44:16 +0000 |
parents | 4dbd7917bf9e |
children | 0f96ad351990 |
comparison
equal
deleted
inserted
replaced
468:4dbd7917bf9e | 469:d3afc91d205d |
---|---|
1 #include "audioDB.h" | 1 #include "audioDB.h" |
2 #include "reporter.h" | |
3 | |
4 #include "audioDB-internals.h" | 2 #include "audioDB-internals.h" |
5 #include "accumulators.h" | 3 #include "accumulators.h" |
6 | 4 |
7 bool audiodb_powers_acceptable(adb_query_refine_t *r, double p1, double p2) { | 5 bool audiodb_powers_acceptable(adb_query_refine_t *r, double p1, double p2) { |
8 if (r->flags & ADB_REFINE_ABSOLUTE_THRESHOLD) { | 6 if (r->flags & ADB_REFINE_ABSOLUTE_THRESHOLD) { |
16 } | 14 } |
17 } | 15 } |
18 return true; | 16 return true; |
19 } | 17 } |
20 | 18 |
21 void audioDB::query(const char* dbName, const char* inFile, adb__queryResponse *adbQueryResponse) { | 19 adb_query_results_t *audiodb_query_spec(adb_t *adb, adb_query_spec_t *qspec) { |
22 | 20 adb_qstate_internal_t qstate = {0}; |
23 // init database tables and dbH first | |
24 if(query_from_key) | |
25 initTables(dbName); | |
26 else | |
27 initTables(dbName, inFile); | |
28 | |
29 adb_query_spec_t qspec; | |
30 adb_datum_t datum = {0}; | |
31 | |
32 qspec.refine.flags = 0; | |
33 if(trackFile) { | |
34 qspec.refine.flags |= ADB_REFINE_INCLUDE_KEYLIST; | |
35 std::vector<const char *> v; | |
36 char *k = new char[MAXSTR]; | |
37 trackFile->getline(k, MAXSTR); | |
38 while(!trackFile->eof()) { | |
39 v.push_back(k); | |
40 k = new char[MAXSTR]; | |
41 trackFile->getline(k, MAXSTR); | |
42 } | |
43 delete [] k; | |
44 qspec.refine.include.nkeys = v.size(); | |
45 qspec.refine.include.keys = new const char *[qspec.refine.include.nkeys]; | |
46 for(unsigned int k = 0; k < qspec.refine.include.nkeys; k++) { | |
47 qspec.refine.include.keys[k] = v[k]; | |
48 } | |
49 } | |
50 if(query_from_key) { | |
51 qspec.refine.flags |= ADB_REFINE_EXCLUDE_KEYLIST; | |
52 qspec.refine.exclude.nkeys = 1; | |
53 qspec.refine.exclude.keys = &key; | |
54 } | |
55 if(radius) { | |
56 qspec.refine.flags |= ADB_REFINE_RADIUS; | |
57 qspec.refine.radius = radius; | |
58 } | |
59 if(use_absolute_threshold) { | |
60 qspec.refine.flags |= ADB_REFINE_ABSOLUTE_THRESHOLD; | |
61 qspec.refine.absolute_threshold = absolute_threshold; | |
62 } | |
63 if(use_relative_threshold) { | |
64 qspec.refine.flags |= ADB_REFINE_RELATIVE_THRESHOLD; | |
65 qspec.refine.relative_threshold = relative_threshold; | |
66 } | |
67 if(usingTimes) { | |
68 qspec.refine.flags |= ADB_REFINE_DURATION_RATIO; | |
69 qspec.refine.duration_ratio = timesTol; | |
70 } | |
71 /* FIXME: not sure about this any more; maybe it belongs in | |
72 query_id? Or maybe we just don't need a flag for it? */ | |
73 qspec.refine.hopsize = sequenceHop; | |
74 if(sequenceHop != 1) { | |
75 qspec.refine.flags |= ADB_REFINE_HOP_SIZE; | |
76 } | |
77 | |
78 if(query_from_key) { | |
79 datum.key = key; | |
80 } else { | |
81 int fd; | |
82 struct stat st; | |
83 | |
84 /* FIXME: around here there are all sorts of hideous leaks. */ | |
85 fd = open(inFile, O_RDONLY); | |
86 if(fd < 0) { | |
87 error("failed to open feature file", inFile); | |
88 } | |
89 fstat(fd, &st); | |
90 read(fd, &datum.dim, sizeof(uint32_t)); | |
91 datum.nvectors = (st.st_size - sizeof(uint32_t)) / (datum.dim * sizeof(double)); | |
92 datum.data = (double *) malloc(st.st_size - sizeof(uint32_t)); | |
93 read(fd, datum.data, st.st_size - sizeof(uint32_t)); | |
94 close(fd); | |
95 if(usingPower) { | |
96 uint32_t one; | |
97 fd = open(powerFileName, O_RDONLY); | |
98 if(fd < 0) { | |
99 error("failed to open power file", powerFileName); | |
100 } | |
101 read(fd, &one, sizeof(uint32_t)); | |
102 if(one != 1) { | |
103 error("malformed power file dimensionality", powerFileName); | |
104 } | |
105 datum.power = (double *) malloc(datum.nvectors * sizeof(double)); | |
106 if(read(fd, datum.power, datum.nvectors * sizeof(double)) != (ssize_t) (datum.nvectors * sizeof(double))) { | |
107 error("malformed power file", powerFileName); | |
108 } | |
109 close(fd); | |
110 } | |
111 if(usingTimes) { | |
112 datum.times = (double *) malloc(2 * datum.nvectors * sizeof(double)); | |
113 insertTimeStamps(datum.nvectors, timesFile, datum.times); | |
114 } | |
115 } | |
116 | |
117 qspec.qid.datum = &datum; | |
118 qspec.qid.sequence_length = sequenceLength; | |
119 qspec.qid.flags = 0; | |
120 qspec.qid.flags |= usingQueryPoint ? 0 : ADB_QID_FLAG_EXHAUSTIVE; | |
121 qspec.qid.flags |= lsh_exact ? 0 : ADB_QID_FLAG_ALLOW_FALSE_POSITIVES; | |
122 qspec.qid.sequence_start = queryPoint; | |
123 | |
124 switch(queryType) { | |
125 case O2_POINT_QUERY: | |
126 qspec.qid.sequence_length = 1; | |
127 qspec.params.accumulation = ADB_ACCUMULATION_DB; | |
128 qspec.params.distance = ADB_DISTANCE_DOT_PRODUCT; | |
129 qspec.params.npoints = pointNN; | |
130 qspec.params.ntracks = 0; | |
131 reporter = new pointQueryReporter< std::greater < NNresult > >(pointNN); | |
132 break; | |
133 case O2_TRACK_QUERY: | |
134 qspec.qid.sequence_length = 1; | |
135 qspec.params.accumulation = ADB_ACCUMULATION_PER_TRACK; | |
136 qspec.params.distance = ADB_DISTANCE_DOT_PRODUCT; | |
137 qspec.params.npoints = pointNN; | |
138 qspec.params.ntracks = trackNN; | |
139 reporter = new trackAveragingReporter< std::greater< NNresult > >(pointNN, trackNN, dbH->numFiles); | |
140 break; | |
141 case O2_SEQUENCE_QUERY: | |
142 case O2_N_SEQUENCE_QUERY: | |
143 qspec.params.accumulation = ADB_ACCUMULATION_PER_TRACK; | |
144 qspec.params.distance = no_unit_norming ? ADB_DISTANCE_EUCLIDEAN : ADB_DISTANCE_EUCLIDEAN_NORMED; | |
145 qspec.params.npoints = pointNN; | |
146 qspec.params.ntracks = trackNN; | |
147 switch(queryType) { | |
148 case O2_SEQUENCE_QUERY: | |
149 if(!(qspec.refine.flags & ADB_REFINE_RADIUS)) { | |
150 reporter = new trackAveragingReporter< std::less< NNresult > >(pointNN, trackNN, dbH->numFiles); | |
151 } else { | |
152 reporter = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles); | |
153 } | |
154 break; | |
155 case O2_N_SEQUENCE_QUERY: | |
156 if(!(qspec.refine.flags & ADB_REFINE_RADIUS)) { | |
157 reporter = new trackSequenceQueryNNReporter< std::less < NNresult > >(pointNN, trackNN, dbH->numFiles); | |
158 } else { | |
159 reporter = new trackSequenceQueryRadNNReporter(pointNN, trackNN, dbH->numFiles); | |
160 } | |
161 break; | |
162 } | |
163 break; | |
164 case O2_ONE_TO_ONE_N_SEQUENCE_QUERY: | |
165 qspec.params.accumulation = ADB_ACCUMULATION_ONE_TO_ONE; | |
166 qspec.params.distance = ADB_DISTANCE_EUCLIDEAN_NORMED; | |
167 qspec.params.npoints = 0; | |
168 qspec.params.ntracks = 0; | |
169 break; | |
170 default: | |
171 error("unrecognized queryType"); | |
172 } | |
173 | |
174 /* Somewhere around here is where the implementation of | |
175 * audiodb_query_spec() starts. */ | |
176 | |
177 adb_qstate_internal_t qstate; | |
178 qstate.allowed_keys = new std::set<std::string>; | 21 qstate.allowed_keys = new std::set<std::string>; |
179 if(qspec.refine.flags & ADB_REFINE_INCLUDE_KEYLIST) { | 22 adb_query_results_t *results; |
180 for(unsigned int k = 0; k < qspec.refine.include.nkeys; k++) { | 23 if(qspec->refine.flags & ADB_REFINE_INCLUDE_KEYLIST) { |
181 qstate.allowed_keys->insert(qspec.refine.include.keys[k]); | 24 for(unsigned int k = 0; k < qspec->refine.include.nkeys; k++) { |
25 qstate.allowed_keys->insert(qspec->refine.include.keys[k]); | |
182 } | 26 } |
183 } else { | 27 } else { |
184 for(unsigned int k = 0; k < adb->header->numFiles; k++) { | 28 for(unsigned int k = 0; k < adb->header->numFiles; k++) { |
185 qstate.allowed_keys->insert((*adb->keys)[k]); | 29 qstate.allowed_keys->insert((*adb->keys)[k]); |
186 } | 30 } |
187 } | 31 } |
188 if(qspec.refine.flags & ADB_REFINE_EXCLUDE_KEYLIST) { | 32 if(qspec->refine.flags & ADB_REFINE_EXCLUDE_KEYLIST) { |
189 for(unsigned int k = 0; k < qspec.refine.exclude.nkeys; k++) { | 33 for(unsigned int k = 0; k < qspec->refine.exclude.nkeys; k++) { |
190 qstate.allowed_keys->erase(qspec.refine.exclude.keys[k]); | 34 qstate.allowed_keys->erase(qspec->refine.exclude.keys[k]); |
191 } | 35 } |
192 } | 36 } |
193 | 37 |
194 switch(qspec.params.distance) { | 38 switch(qspec->params.distance) { |
195 case ADB_DISTANCE_DOT_PRODUCT: | 39 case ADB_DISTANCE_DOT_PRODUCT: |
196 switch(qspec.params.accumulation) { | 40 switch(qspec->params.accumulation) { |
197 case ADB_ACCUMULATION_DB: | 41 case ADB_ACCUMULATION_DB: |
198 qstate.accumulator = new DBAccumulator<adb_result_dist_gt>(qspec.params.npoints); | 42 qstate.accumulator = new DBAccumulator<adb_result_dist_gt>(qspec->params.npoints); |
199 break; | 43 break; |
200 case ADB_ACCUMULATION_PER_TRACK: | 44 case ADB_ACCUMULATION_PER_TRACK: |
201 qstate.accumulator = new PerTrackAccumulator<adb_result_dist_gt>(qspec.params.npoints, qspec.params.ntracks); | 45 qstate.accumulator = new PerTrackAccumulator<adb_result_dist_gt>(qspec->params.npoints, qspec->params.ntracks); |
202 break; | 46 break; |
203 case ADB_ACCUMULATION_ONE_TO_ONE: | 47 case ADB_ACCUMULATION_ONE_TO_ONE: |
204 qstate.accumulator = new NearestAccumulator<adb_result_dist_gt>(); | 48 qstate.accumulator = new NearestAccumulator<adb_result_dist_gt>(); |
205 break; | 49 break; |
206 default: | 50 default: |
207 error("unknown accumulation"); | 51 goto error; |
208 } | 52 } |
209 break; | 53 break; |
210 case ADB_DISTANCE_EUCLIDEAN_NORMED: | 54 case ADB_DISTANCE_EUCLIDEAN_NORMED: |
211 case ADB_DISTANCE_EUCLIDEAN: | 55 case ADB_DISTANCE_EUCLIDEAN: |
212 switch(qspec.params.accumulation) { | 56 switch(qspec->params.accumulation) { |
213 case ADB_ACCUMULATION_DB: | 57 case ADB_ACCUMULATION_DB: |
214 qstate.accumulator = new DBAccumulator<adb_result_dist_lt>(qspec.params.npoints); | 58 qstate.accumulator = new DBAccumulator<adb_result_dist_lt>(qspec->params.npoints); |
215 break; | 59 break; |
216 case ADB_ACCUMULATION_PER_TRACK: | 60 case ADB_ACCUMULATION_PER_TRACK: |
217 qstate.accumulator = new PerTrackAccumulator<adb_result_dist_lt>(qspec.params.npoints, qspec.params.ntracks); | 61 qstate.accumulator = new PerTrackAccumulator<adb_result_dist_lt>(qspec->params.npoints, qspec->params.ntracks); |
218 break; | 62 break; |
219 case ADB_ACCUMULATION_ONE_TO_ONE: | 63 case ADB_ACCUMULATION_ONE_TO_ONE: |
220 qstate.accumulator = new NearestAccumulator<adb_result_dist_lt>(); | 64 qstate.accumulator = new NearestAccumulator<adb_result_dist_lt>(); |
221 break; | 65 break; |
222 default: | 66 default: |
223 error("unknown accumulation"); | 67 goto error; |
224 } | 68 } |
225 break; | 69 break; |
226 default: | 70 default: |
227 error("unknown distance function"); | 71 goto error; |
228 } | 72 } |
229 | 73 |
230 // Test for index (again) here | 74 if((qspec->refine.flags & ADB_REFINE_RADIUS) && audiodb_index_exists(adb->path, qspec->refine.radius, qspec->qid.sequence_length)) { |
231 if((qspec.refine.flags & ADB_REFINE_RADIUS) && audiodb_index_exists(adb->path, qspec.refine.radius, qspec.qid.sequence_length)){ | 75 if(audiodb_index_query_loop(adb, qspec, &qstate) < 0) { |
232 VERB_LOG(1, "Calling indexed query on database %s, radius=%f, sequence_length=%d\n", adb->path, qspec.refine.radius, qspec.qid.sequence_length); | 76 goto error; |
233 if(audiodb_index_query_loop(adb, &qspec, &qstate) < 0) { | |
234 error("index_query_loop failed"); | |
235 } | 77 } |
236 } else { | 78 } else { |
237 VERB_LOG(1, "Calling brute-force query on database %s\n", dbName); | 79 if(audiodb_query_loop(adb, qspec, &qstate)) { |
238 if(audiodb_query_loop(adb, &qspec, &qstate)) { | 80 goto error; |
239 error("audiodb_query_loop failed"); | 81 } |
240 } | 82 } |
241 } | 83 |
242 | 84 results = qstate.accumulator->get_points(); |
243 adb_query_results_t *rs = qstate.accumulator->get_points(); | |
244 | 85 |
245 delete qstate.accumulator; | 86 delete qstate.accumulator; |
246 delete qstate.allowed_keys; | 87 delete qstate.allowed_keys; |
247 | 88 |
248 /* End of audiodb_query_spec() function */ | 89 return results; |
249 | 90 |
250 for(unsigned int k = 0; k < rs->nresults; k++) { | 91 error: |
251 adb_result_t r = rs->results[k]; | 92 if(qstate.accumulator) |
252 reporter->add_point(audiodb_key_index(adb, r.key), r.qpos, r.ipos, r.dist); | 93 delete qstate.accumulator; |
253 } | 94 if(qstate.allowed_keys) |
254 audiodb_query_free_results(adb, &qspec, rs); | 95 delete qstate.allowed_keys; |
255 | 96 return NULL; |
256 reporter->report(fileTable, adbQueryResponse); | |
257 } | 97 } |
258 | 98 |
259 int audiodb_query_free_results(adb_t *adb, adb_query_spec_t *spec, adb_query_results_t *rs) { | 99 int audiodb_query_free_results(adb_t *adb, adb_query_spec_t *spec, adb_query_results_t *rs) { |
260 free(rs->results); | 100 free(rs->results); |
261 free(rs); | 101 free(rs); |
352 read_or_goto_error(trkfid, *data_buffer_p, track_size); | 192 read_or_goto_error(trkfid, *data_buffer_p, track_size); |
353 return 0; | 193 return 0; |
354 | 194 |
355 error: | 195 error: |
356 return 1; | 196 return 1; |
357 } | |
358 | |
359 void audioDB::insertTimeStamps(unsigned numVectors, std::ifstream *timesFile, double *timesdata) { | |
360 assert(usingTimes); | |
361 | |
362 unsigned numtimes = 0; | |
363 | |
364 if(!timesFile->is_open()) { | |
365 error("problem opening times file on timestamped database", timesFileName); | |
366 } | |
367 | |
368 double timepoint, next; | |
369 *timesFile >> timepoint; | |
370 if (timesFile->eof()) { | |
371 error("no entries in times file", timesFileName); | |
372 } | |
373 numtimes++; | |
374 do { | |
375 *timesFile >> next; | |
376 if (timesFile->eof()) { | |
377 break; | |
378 } | |
379 numtimes++; | |
380 timesdata[0] = timepoint; | |
381 timepoint = (timesdata[1] = next); | |
382 timesdata += 2; | |
383 } while (numtimes < numVectors + 1); | |
384 | |
385 if (numtimes < numVectors + 1) { | |
386 error("too few timepoints in times file", timesFileName); | |
387 } | |
388 | |
389 *timesFile >> next; | |
390 if (!timesFile->eof()) { | |
391 error("too many timepoints in times file", timesFileName); | |
392 } | |
393 } | 197 } |
394 | 198 |
395 int audiodb_track_id_datum(adb_t *adb, uint32_t track_id, adb_datum_t *d) { | 199 int audiodb_track_id_datum(adb_t *adb, uint32_t track_id, adb_datum_t *d) { |
396 off_t track_offset = (*adb->track_offsets)[track_id]; | 200 off_t track_offset = (*adb->track_offsets)[track_id]; |
397 if(adb->header->flags & O2_FLAG_LARGE_ADB) { | 201 if(adb->header->flags & O2_FLAG_LARGE_ADB) { |