annotate query-indexed.cpp @ 755:37c2b9cce23a multiprobeLSH

Adding mkc_lsh_update branch, trunk candidate with improved LSH: merged trunk 1095 and branch multiprobe_lsh
author mas01mc
date Thu, 25 Nov 2010 13:42:40 +0000
parents eb5dd50dd7d1
children
rev   line source
mas01cr@509 1 extern "C" {
mas01cr@509 2 #include "audioDB_API.h"
mas01cr@509 3 }
mas01cr@509 4 #include "audioDB-internals.h"
mas01cr@509 5
mas01cr@509 6 /*
mas01cr@509 7 * Routines and datastructures which are specific to indexed queries.
mas01cr@509 8 */
mas01cr@509 9 typedef struct adb_qcallback {
mas01cr@509 10 adb_t *adb;
mas01cr@509 11 adb_qstate_internal_t *qstate;
mas01cr@509 12 } adb_qcallback_t;
mas01cr@509 13
mas01cr@509 14 // return true if indexed query performed else return false
mas01cr@509 15 int audiodb_index_init_query(adb_t *adb, const adb_query_spec_t *spec, adb_qstate_internal_t *qstate, bool corep) {
mas01cr@509 16
mas01cr@509 17 uint32_t sequence_length = spec->qid.sequence_length;
mas01cr@509 18 double radius = spec->refine.radius;
mas01cr@509 19 if(!(audiodb_index_exists(adb->path, radius, sequence_length)))
mas01cr@509 20 return false;
mas01cr@509 21
mas01cr@509 22 char *indexName = audiodb_index_get_name(adb->path, radius, sequence_length);
mas01cr@509 23 if(!indexName) {
mas01cr@509 24 return false;
mas01cr@509 25 }
mas01cr@509 26
mas01cr@509 27 qstate->lsh = audiodb_index_allocate(adb, indexName, corep);
mas01cr@509 28
mas01cr@509 29 /* FIXME: it would be nice if the LSH library didn't make me do
mas01cr@509 30 * this. */
mas01cr@509 31 if((!corep) && (qstate->lsh->get_lshHeader()->flags & O2_SERIAL_FILEFORMAT2)) {
mas01cr@509 32 delete qstate->lsh;
mas01cr@509 33 qstate->lsh = audiodb_index_allocate(adb, indexName, true);
mas01mc@513 34 #ifdef LSH_DUMP_CORE_TABLES
mas01mc@513 35 qstate->lsh->dump_hashtables();
mas01mc@513 36 #endif
mas01cr@509 37 }
mas01cr@509 38
mas01cr@509 39 delete[] indexName;
mas01cr@509 40 return true;
mas01cr@509 41 }
mas01cr@509 42
mas01cr@509 43 void audiodb_index_add_point_approximate(void *user_data, Uns32T pointID, Uns32T qpos, float dist) {
mas01cr@509 44 adb_qcallback_t *data = (adb_qcallback_t *) user_data;
mas01cr@509 45 adb_t *adb = data->adb;
mas01cr@509 46 adb_qstate_internal_t *qstate = data->qstate;
mas01mc@532 47 uint32_t trackID = audiodb_index_to_track_id(adb, pointID);
mas01mc@532 48 uint32_t spos = audiodb_index_to_track_pos(adb, trackID, pointID);
mas01cr@509 49 std::set<std::string>::iterator keys_end = qstate->allowed_keys->end();
mas01cr@509 50 if(qstate->allowed_keys->find((*adb->keys)[trackID]) != keys_end) {
mas01cr@509 51 adb_result_t r;
mas01cr@509 52 r.key = (*adb->keys)[trackID].c_str();
mas01cr@509 53 r.dist = dist;
mas01cr@509 54 r.qpos = qpos;
mas01cr@509 55 r.ipos = spos;
mas01cr@509 56 qstate->accumulator->add_point(&r);
mas01cr@509 57 }
mas01cr@509 58 }
mas01cr@509 59
mas01cr@509 60 // Maintain a queue of points to pass to audiodb_query_queue_loop()
mas01cr@509 61 // for exact evaluation
mas01cr@509 62 void audiodb_index_add_point_exact(void *user_data, Uns32T pointID, Uns32T qpos, float dist) {
mas01cr@509 63 adb_qcallback_t *data = (adb_qcallback_t *) user_data;
mas01cr@509 64 adb_t *adb = data->adb;
mas01cr@509 65 adb_qstate_internal_t *qstate = data->qstate;
mas01mc@532 66 uint32_t trackID = audiodb_index_to_track_id(adb, pointID);
mas01mc@532 67 uint32_t spos = audiodb_index_to_track_pos(adb, trackID, pointID);
mas01cr@509 68 std::set<std::string>::iterator keys_end = qstate->allowed_keys->end();
mas01cr@509 69 if(qstate->allowed_keys->find((*adb->keys)[trackID]) != keys_end) {
mas01cr@509 70 PointPair p(trackID, qpos, spos);
mas01mc@529 71 if(qstate->set->find(p)==qstate->set->end()){
mas01mc@529 72 qstate->set->insert(p);
mas01mc@529 73 qstate->exact_evaluation_queue->push(p);
mas01mc@529 74 }
mas01cr@509 75 }
mas01cr@509 76 }
mas01cr@509 77
mas01cr@509 78 // return -1 on error
mas01cr@509 79 // return 0: if index does not exist
mas01cr@509 80 // return nqv: if index exists
mas01cr@509 81 int audiodb_index_query_loop(adb_t *adb, const adb_query_spec_t *spec, adb_qstate_internal_t *qstate) {
mas01mc@533 82 if(adb->header->flags>>28)
mas01mc@533 83 cerr << "WARNING: Database created using deprecated LSH_N_POINT_BITS coding: REBUILD INDEXES..." << endl;
mas01mc@533 84
mas01cr@509 85 double *query = 0, *query_data = 0;
mas01cr@509 86 adb_qpointers_internal_t qpointers = {0};
mas01cr@509 87
mas01cr@509 88 adb_qcallback_t callback_data;
mas01cr@509 89 callback_data.adb = adb;
mas01cr@509 90 callback_data.qstate = qstate;
mas01cr@509 91
mas01cr@509 92 void (*add_point_func)(void *, uint32_t, uint32_t, float);
mas01cr@509 93
mas01cr@509 94 uint32_t sequence_length = spec->qid.sequence_length;
mas01cr@509 95 bool normalized = (spec->params.distance == ADB_DISTANCE_EUCLIDEAN_NORMED);
mas01cr@509 96 double radius = spec->refine.radius;
mas01cr@509 97 bool use_absolute_threshold = spec->refine.flags & ADB_REFINE_ABSOLUTE_THRESHOLD;
mas01cr@509 98 double absolute_threshold = spec->refine.absolute_threshold;
mas01cr@509 99
mas01cr@509 100 if(spec->qid.flags & ADB_QID_FLAG_ALLOW_FALSE_POSITIVES) {
mas01cr@509 101 add_point_func = &audiodb_index_add_point_approximate;
mas01cr@509 102 } else {
mas01mc@530 103 qstate->exact_evaluation_queue = new std::priority_queue<PointPair, vector<PointPair>, greater<PointPair> >;
mas01mc@529 104 qstate->set = new std::set<PointPair, less<PointPair> >;
mas01cr@509 105 add_point_func = &audiodb_index_add_point_exact;
mas01cr@509 106 }
mas01cr@509 107
mas01cr@509 108 /* FIXME: this hardwired lsh_in_core is here to allow for a
mas01cr@509 109 * transition period while the need for the argument is worked
mas01cr@509 110 * through. Hopefully it will disappear again eventually. */
mas01cr@509 111 bool lsh_in_core = true;
mas01cr@509 112
mas01cr@509 113 if(!audiodb_index_init_query(adb, spec, qstate, lsh_in_core)) {
mas01cr@509 114 return 0;
mas01cr@509 115 }
mas01cr@509 116
mas01cr@509 117 char *database = audiodb_index_get_name(adb->path, radius, sequence_length);
mas01cr@509 118 if(!database) {
mas01cr@509 119 return -1;
mas01cr@509 120 }
mas01cr@509 121
mas01cr@509 122 if(audiodb_query_spec_qpointers(adb, spec, &query_data, &query, &qpointers)) {
mas01cr@509 123 delete [] database;
mas01cr@509 124 return -1;
mas01cr@509 125 }
mas01cr@509 126
mas01mc@532 127 uint32_t Nq = qpointers.nvectors - sequence_length + 1;
mas01cr@509 128 std::vector<std::vector<float> > *vv = audiodb_index_initialize_shingles(Nq, adb->header->dim, sequence_length);
mas01cr@509 129
mas01cr@509 130 // Construct shingles from query features
mas01cr@509 131 for(uint32_t pointID = 0; pointID < Nq; pointID++) {
mas01cr@509 132 audiodb_index_make_shingle(vv, pointID, query, adb->header->dim, sequence_length);
mas01cr@509 133 }
mas01cr@509 134
mas01cr@509 135 // Normalize query vectors
mas01cr@509 136 int vcount = audiodb_index_norm_shingles(vv, qpointers.l2norm, qpointers.power, adb->header->dim, sequence_length, radius, normalized, use_absolute_threshold, absolute_threshold);
mas01cr@509 137 if(vcount == -1) {
mas01cr@509 138 audiodb_index_delete_shingles(vv);
mas01cr@509 139 delete [] database;
mas01cr@509 140 return -1;
mas01cr@509 141 }
mas01cr@509 142 uint32_t numVecsAboveThreshold = vcount;
mas01cr@509 143
mas01cr@509 144 // Nq contains number of inspected points in query file,
mas01cr@509 145 // numVecsAboveThreshold is number of points with power >= absolute_threshold
mas01cr@509 146 double *qpp = qpointers.power; // Keep original qpPtr for possible exact evaluation
mas01cr@509 147 if(!(spec->qid.flags & ADB_QID_FLAG_EXHAUSTIVE) && numVecsAboveThreshold) {
mas01cr@509 148 if((qstate->lsh->get_lshHeader()->flags & O2_SERIAL_FILEFORMAT2) || lsh_in_core) {
mas01cr@509 149 qstate->lsh->retrieve_point((*vv)[0], spec->qid.sequence_start, add_point_func, &callback_data);
mas01cr@509 150 } else {
mas01cr@509 151 qstate->lsh->serial_retrieve_point(database, (*vv)[0], spec->qid.sequence_start, add_point_func, &callback_data);
mas01cr@509 152 }
mas01cr@509 153 } else if(numVecsAboveThreshold) {
mas01cr@509 154 for(uint32_t pointID = 0; pointID < Nq; pointID++) {
mas01cr@509 155 if(!use_absolute_threshold || (use_absolute_threshold && (*qpp++ >= absolute_threshold))) {
mas01cr@509 156 if((qstate->lsh->get_lshHeader()->flags & O2_SERIAL_FILEFORMAT2) || lsh_in_core) {
mas01cr@509 157 qstate->lsh->retrieve_point((*vv)[pointID], pointID, add_point_func, &callback_data);
mas01cr@509 158 } else {
mas01cr@509 159 qstate->lsh->serial_retrieve_point(database, (*vv)[pointID], pointID, add_point_func, &callback_data);
mas01cr@509 160 }
mas01cr@509 161 }
mas01cr@509 162 }
mas01cr@509 163 }
mas01cr@509 164 audiodb_index_delete_shingles(vv);
mas01cr@509 165
mas01cr@509 166 if(!(spec->qid.flags & ADB_QID_FLAG_ALLOW_FALSE_POSITIVES)) {
mas01cr@509 167 audiodb_query_queue_loop(adb, spec, qstate, query, &qpointers);
mas01cr@509 168 }
mas01cr@509 169
mas01cr@509 170 // Clean up
mas01cr@509 171 if(query_data)
mas01cr@509 172 delete[] query_data;
mas01cr@509 173 if(qpointers.l2norm_data)
mas01cr@509 174 delete[] qpointers.l2norm_data;
mas01cr@509 175 if(qpointers.power_data)
mas01cr@509 176 delete[] qpointers.power_data;
mas01cr@509 177 if(qpointers.mean_duration)
mas01cr@509 178 delete[] qpointers.mean_duration;
mas01cr@509 179 if(database)
mas01cr@509 180 delete[] database;
mas01cr@509 181 if(qstate->lsh != adb->cached_lsh)
mas01cr@509 182 delete qstate->lsh;
mas01cr@509 183
mas01cr@509 184 return Nq;
mas01cr@509 185 }