mas01cr@509
|
1 extern "C" {
|
mas01cr@509
|
2 #include "audioDB_API.h"
|
mas01cr@509
|
3 }
|
mas01cr@509
|
4 #include "audioDB-internals.h"
|
mas01cr@589
|
5 #include "lshlib.h"
|
mas01cr@509
|
6
|
mas01cr@509
|
7 /*
|
mas01cr@509
|
8 * Routines and datastructures which are specific to indexed queries.
|
mas01cr@509
|
9 */
|
mas01cr@509
|
10 typedef struct adb_qcallback {
|
mas01cr@509
|
11 adb_t *adb;
|
mas01cr@509
|
12 adb_qstate_internal_t *qstate;
|
mas01cr@509
|
13 } adb_qcallback_t;
|
mas01cr@509
|
14
|
mas01cr@509
|
15 // return true if indexed query performed else return false
|
mas01cr@509
|
16 int audiodb_index_init_query(adb_t *adb, const adb_query_spec_t *spec, adb_qstate_internal_t *qstate, bool corep) {
|
mas01cr@509
|
17
|
mas01cr@509
|
18 uint32_t sequence_length = spec->qid.sequence_length;
|
mas01cr@509
|
19 double radius = spec->refine.radius;
|
mas01cr@509
|
20 if(!(audiodb_index_exists(adb->path, radius, sequence_length)))
|
mas01cr@509
|
21 return false;
|
mas01cr@509
|
22
|
mas01cr@509
|
23 char *indexName = audiodb_index_get_name(adb->path, radius, sequence_length);
|
mas01cr@509
|
24 if(!indexName) {
|
mas01cr@509
|
25 return false;
|
mas01cr@509
|
26 }
|
mas01cr@509
|
27
|
mas01cr@509
|
28 qstate->lsh = audiodb_index_allocate(adb, indexName, corep);
|
mas01cr@509
|
29
|
mas01cr@509
|
30 /* FIXME: it would be nice if the LSH library didn't make me do
|
mas01cr@509
|
31 * this. */
|
mas01cr@509
|
32 if((!corep) && (qstate->lsh->get_lshHeader()->flags & O2_SERIAL_FILEFORMAT2)) {
|
mas01cr@509
|
33 delete qstate->lsh;
|
mas01cr@509
|
34 qstate->lsh = audiodb_index_allocate(adb, indexName, true);
|
mas01mc@513
|
35 #ifdef LSH_DUMP_CORE_TABLES
|
mas01mc@513
|
36 qstate->lsh->dump_hashtables();
|
mas01mc@513
|
37 #endif
|
mas01cr@509
|
38 }
|
mas01cr@509
|
39
|
mas01cr@509
|
40 delete[] indexName;
|
mas01cr@509
|
41 return true;
|
mas01cr@509
|
42 }
|
mas01cr@509
|
43
|
mas01cr@589
|
44 void audiodb_index_add_point_approximate(void *user_data, uint32_t pointID, uint32_t qpos, float dist) {
|
mas01cr@509
|
45 adb_qcallback_t *data = (adb_qcallback_t *) user_data;
|
mas01cr@509
|
46 adb_t *adb = data->adb;
|
mas01cr@509
|
47 adb_qstate_internal_t *qstate = data->qstate;
|
mas01mc@534
|
48 uint32_t trackID = audiodb_index_to_track_id(adb, pointID);
|
mas01mc@534
|
49 uint32_t spos = audiodb_index_to_track_pos(adb, trackID, pointID);
|
mas01cr@509
|
50 std::set<std::string>::iterator keys_end = qstate->allowed_keys->end();
|
mas01cr@509
|
51 if(qstate->allowed_keys->find((*adb->keys)[trackID]) != keys_end) {
|
mas01cr@509
|
52 adb_result_t r;
|
mas01cr@509
|
53 r.key = (*adb->keys)[trackID].c_str();
|
mas01cr@509
|
54 r.dist = dist;
|
mas01cr@509
|
55 r.qpos = qpos;
|
mas01cr@509
|
56 r.ipos = spos;
|
mas01cr@509
|
57 qstate->accumulator->add_point(&r);
|
mas01cr@509
|
58 }
|
mas01cr@509
|
59 }
|
mas01cr@509
|
60
|
mas01cr@509
|
61 // Maintain a queue of points to pass to audiodb_query_queue_loop()
|
mas01cr@509
|
62 // for exact evaluation
|
mas01cr@589
|
63 void audiodb_index_add_point_exact(void *user_data, uint32_t pointID, uint32_t qpos, float dist) {
|
mas01cr@509
|
64 adb_qcallback_t *data = (adb_qcallback_t *) user_data;
|
mas01cr@509
|
65 adb_t *adb = data->adb;
|
mas01cr@509
|
66 adb_qstate_internal_t *qstate = data->qstate;
|
mas01mc@534
|
67 uint32_t trackID = audiodb_index_to_track_id(adb, pointID);
|
mas01mc@534
|
68 uint32_t spos = audiodb_index_to_track_pos(adb, trackID, pointID);
|
mas01cr@509
|
69 std::set<std::string>::iterator keys_end = qstate->allowed_keys->end();
|
mas01cr@509
|
70 if(qstate->allowed_keys->find((*adb->keys)[trackID]) != keys_end) {
|
mas01cr@509
|
71 PointPair p(trackID, qpos, spos);
|
mas01cr@509
|
72 qstate->exact_evaluation_queue->push(p);
|
mas01cr@509
|
73 }
|
mas01cr@509
|
74 }
|
mas01cr@509
|
75
|
mas01cr@509
|
76 // return -1 on error
|
mas01cr@509
|
77 // return 0: if index does not exist
|
mas01cr@509
|
78 // return nqv: if index exists
|
mas01cr@509
|
79 int audiodb_index_query_loop(adb_t *adb, const adb_query_spec_t *spec, adb_qstate_internal_t *qstate) {
|
mas01mc@534
|
80 if(adb->header->flags>>28)
|
mas01mc@534
|
81 cerr << "WARNING: Database created using deprecated LSH_N_POINT_BITS coding: REBUILD INDEXES..." << endl;
|
mas01mc@534
|
82
|
mas01cr@509
|
83 double *query = 0, *query_data = 0;
|
mas01cr@509
|
84 adb_qpointers_internal_t qpointers = {0};
|
mas01cr@509
|
85
|
mas01cr@509
|
86 adb_qcallback_t callback_data;
|
mas01cr@509
|
87 callback_data.adb = adb;
|
mas01cr@509
|
88 callback_data.qstate = qstate;
|
mas01cr@509
|
89
|
mas01cr@509
|
90 void (*add_point_func)(void *, uint32_t, uint32_t, float);
|
mas01cr@509
|
91
|
mas01cr@509
|
92 uint32_t sequence_length = spec->qid.sequence_length;
|
mas01cr@509
|
93 bool normalized = (spec->params.distance == ADB_DISTANCE_EUCLIDEAN_NORMED);
|
mas01cr@509
|
94 double radius = spec->refine.radius;
|
mas01cr@509
|
95 bool use_absolute_threshold = spec->refine.flags & ADB_REFINE_ABSOLUTE_THRESHOLD;
|
mas01cr@509
|
96 double absolute_threshold = spec->refine.absolute_threshold;
|
mas01cr@509
|
97
|
mas01cr@509
|
98 if(spec->qid.flags & ADB_QID_FLAG_ALLOW_FALSE_POSITIVES) {
|
mas01cr@509
|
99 add_point_func = &audiodb_index_add_point_approximate;
|
mas01cr@509
|
100 } else {
|
mas01cr@509
|
101 qstate->exact_evaluation_queue = new std::priority_queue<PointPair>;
|
mas01cr@509
|
102 add_point_func = &audiodb_index_add_point_exact;
|
mas01cr@509
|
103 }
|
mas01cr@509
|
104
|
mas01cr@509
|
105 /* FIXME: this hardwired lsh_in_core is here to allow for a
|
mas01cr@509
|
106 * transition period while the need for the argument is worked
|
mas01cr@509
|
107 * through. Hopefully it will disappear again eventually. */
|
mas01cr@509
|
108 bool lsh_in_core = true;
|
mas01cr@509
|
109
|
mas01cr@509
|
110 if(!audiodb_index_init_query(adb, spec, qstate, lsh_in_core)) {
|
mas01cr@509
|
111 return 0;
|
mas01cr@509
|
112 }
|
mas01cr@509
|
113
|
mas01cr@509
|
114 char *database = audiodb_index_get_name(adb->path, radius, sequence_length);
|
mas01cr@509
|
115 if(!database) {
|
mas01cr@509
|
116 return -1;
|
mas01cr@509
|
117 }
|
mas01cr@509
|
118
|
mas01cr@509
|
119 if(audiodb_query_spec_qpointers(adb, spec, &query_data, &query, &qpointers)) {
|
mas01cr@509
|
120 delete [] database;
|
mas01cr@509
|
121 return -1;
|
mas01cr@509
|
122 }
|
mas01cr@509
|
123
|
mas01mc@534
|
124 uint32_t Nq = qpointers.nvectors - sequence_length + 1;
|
mas01cr@509
|
125 std::vector<std::vector<float> > *vv = audiodb_index_initialize_shingles(Nq, adb->header->dim, sequence_length);
|
mas01cr@509
|
126
|
mas01cr@509
|
127 // Construct shingles from query features
|
mas01cr@509
|
128 for(uint32_t pointID = 0; pointID < Nq; pointID++) {
|
mas01cr@509
|
129 audiodb_index_make_shingle(vv, pointID, query, adb->header->dim, sequence_length);
|
mas01cr@509
|
130 }
|
mas01cr@509
|
131
|
mas01cr@509
|
132 // Normalize query vectors
|
mas01cr@509
|
133 int vcount = audiodb_index_norm_shingles(vv, qpointers.l2norm, qpointers.power, adb->header->dim, sequence_length, radius, normalized, use_absolute_threshold, absolute_threshold);
|
mas01cr@509
|
134 if(vcount == -1) {
|
mas01cr@509
|
135 audiodb_index_delete_shingles(vv);
|
mas01cr@509
|
136 delete [] database;
|
mas01cr@509
|
137 return -1;
|
mas01cr@509
|
138 }
|
mas01cr@509
|
139 uint32_t numVecsAboveThreshold = vcount;
|
mas01cr@509
|
140
|
mas01cr@509
|
141 // Nq contains number of inspected points in query file,
|
mas01cr@509
|
142 // numVecsAboveThreshold is number of points with power >= absolute_threshold
|
mas01cr@509
|
143 double *qpp = qpointers.power; // Keep original qpPtr for possible exact evaluation
|
mas01cr@509
|
144 if(!(spec->qid.flags & ADB_QID_FLAG_EXHAUSTIVE) && numVecsAboveThreshold) {
|
mas01cr@509
|
145 if((qstate->lsh->get_lshHeader()->flags & O2_SERIAL_FILEFORMAT2) || lsh_in_core) {
|
mas01cr@509
|
146 qstate->lsh->retrieve_point((*vv)[0], spec->qid.sequence_start, add_point_func, &callback_data);
|
mas01cr@509
|
147 } else {
|
mas01cr@509
|
148 qstate->lsh->serial_retrieve_point(database, (*vv)[0], spec->qid.sequence_start, add_point_func, &callback_data);
|
mas01cr@509
|
149 }
|
mas01cr@509
|
150 } else if(numVecsAboveThreshold) {
|
mas01cr@509
|
151 for(uint32_t pointID = 0; pointID < Nq; pointID++) {
|
mas01cr@509
|
152 if(!use_absolute_threshold || (use_absolute_threshold && (*qpp++ >= absolute_threshold))) {
|
mas01cr@509
|
153 if((qstate->lsh->get_lshHeader()->flags & O2_SERIAL_FILEFORMAT2) || lsh_in_core) {
|
mas01cr@509
|
154 qstate->lsh->retrieve_point((*vv)[pointID], pointID, add_point_func, &callback_data);
|
mas01cr@509
|
155 } else {
|
mas01cr@509
|
156 qstate->lsh->serial_retrieve_point(database, (*vv)[pointID], pointID, add_point_func, &callback_data);
|
mas01cr@509
|
157 }
|
mas01cr@509
|
158 }
|
mas01cr@509
|
159 }
|
mas01cr@509
|
160 }
|
mas01cr@509
|
161 audiodb_index_delete_shingles(vv);
|
mas01cr@509
|
162
|
mas01cr@509
|
163 if(!(spec->qid.flags & ADB_QID_FLAG_ALLOW_FALSE_POSITIVES)) {
|
mas01cr@509
|
164 audiodb_query_queue_loop(adb, spec, qstate, query, &qpointers);
|
mas01cr@509
|
165 }
|
mas01cr@509
|
166
|
mas01cr@509
|
167 // Clean up
|
mas01cr@509
|
168 if(query_data)
|
mas01cr@509
|
169 delete[] query_data;
|
mas01cr@509
|
170 if(qpointers.l2norm_data)
|
mas01cr@509
|
171 delete[] qpointers.l2norm_data;
|
mas01cr@509
|
172 if(qpointers.power_data)
|
mas01cr@509
|
173 delete[] qpointers.power_data;
|
mas01cr@509
|
174 if(qpointers.mean_duration)
|
mas01cr@509
|
175 delete[] qpointers.mean_duration;
|
mas01cr@509
|
176 if(database)
|
mas01cr@509
|
177 delete[] database;
|
mas01cr@509
|
178 if(qstate->lsh != adb->cached_lsh)
|
mas01cr@509
|
179 delete qstate->lsh;
|
mas01cr@509
|
180
|
mas01cr@509
|
181 return Nq;
|
mas01cr@509
|
182 }
|