comparison query.cpp @ 539:06ed85832c3b multiprobeLSH

Optimized the query_loop_points inner loop for memcpy and I/O efficiency. Uses sparse seeks and reads to perform scattered reads across the data set. The current version does not cache the fid between open calls to the same trackID.
author mas01mc
date Sat, 07 Feb 2009 01:20:05 +0000
parents ddf763553175
children 52d82badc544
comparison
equal deleted inserted replaced
538:02e0a9ecfd0f 539:06ed85832c3b
196 196
197 error: 197 error:
198 return 1; 198 return 1;
199 } 199 }
200 200
201 int audiodb_track_id_datum(adb_t *adb, uint32_t track_id, adb_datum_t *d) { 201 int audiodb_track_id_datum(adb_t *adb, uint32_t track_id, adb_datum_t *d, off_t vector_offset=0, size_t num_vectors=0) {
202 off_t track_offset = (*adb->track_offsets)[track_id]; 202 off_t track_offset = (*adb->track_offsets)[track_id];
203 if(adb->header->flags & ADB_HEADER_FLAG_REFERENCES) { 203 if(adb->header->flags & ADB_HEADER_FLAG_REFERENCES) {
204 /* create a reference/insert, then use adb_insert_create_datum() */ 204 /* create a reference/insert, then use adb_insert_create_datum() */
205 adb_reference_t reference = {0}; 205 adb_reference_t reference = {0};
206 char features[ADB_MAXSTR], power[ADB_MAXSTR], times[ADB_MAXSTR]; 206 char features[ADB_MAXSTR], power[ADB_MAXSTR], times[ADB_MAXSTR];
215 if(adb->header->flags & ADB_HEADER_FLAG_TIMES) { 215 if(adb->header->flags & ADB_HEADER_FLAG_TIMES) {
216 lseek(adb->fd, adb->header->timesTableOffset + track_id * ADB_FILETABLE_ENTRY_SIZE, SEEK_SET); 216 lseek(adb->fd, adb->header->timesTableOffset + track_id * ADB_FILETABLE_ENTRY_SIZE, SEEK_SET);
217 read_or_goto_error(adb->fd, times, ADB_MAXSTR); 217 read_or_goto_error(adb->fd, times, ADB_MAXSTR);
218 reference.times = times; 218 reference.times = times;
219 } 219 }
220 return audiodb_insert_create_datum(&reference, d); 220 return audiodb_insert_create_datum(&reference, d, vector_offset*adb->header->dim*sizeof(double), num_vectors*adb->header->dim*sizeof(double));
221 } else { 221 } else {
222 /* initialize from sources of data that we already have */ 222 /* initialize from sources of data that we already have */
223 d->nvectors = (*adb->track_lengths)[track_id]; 223 if(num_vectors)
224 d->nvectors = num_vectors;
225 else
226 d->nvectors = (*adb->track_lengths)[track_id];
224 d->dim = adb->header->dim; 227 d->dim = adb->header->dim;
225 d->key = (*adb->keys)[track_id].c_str(); 228 d->key = (*adb->keys)[track_id].c_str();
226 /* read out stuff from the database tables */ 229 /* read out stuff from the database tables */
227 d->data = (double *) malloc(d->nvectors * d->dim * sizeof(double)); 230 d->data = (double *) malloc(d->nvectors * d->dim * sizeof(double));
228 lseek(adb->fd, adb->header->dataOffset + track_offset, SEEK_SET); 231 lseek(adb->fd, adb->header->dataOffset + track_offset + vector_offset*d->dim*sizeof(double), SEEK_SET);
229 read_or_goto_error(adb->fd, d->data, d->nvectors * d->dim * sizeof(double)); 232 read_or_goto_error(adb->fd, d->data, d->nvectors * d->dim * sizeof(double));
230 if(adb->header->flags & ADB_HEADER_FLAG_POWER) { 233 if(adb->header->flags & ADB_HEADER_FLAG_POWER) {
231 d->power = (double *) malloc(d->nvectors * sizeof(double)); 234 d->power = (double *) malloc(d->nvectors * sizeof(double));
232 lseek(adb->fd, adb->header->powerTableOffset + track_offset / d->dim, SEEK_SET); 235 lseek(adb->fd, adb->header->powerTableOffset + track_offset / d->dim + vector_offset*sizeof(double), SEEK_SET);
233 read_or_goto_error(adb->fd, d->power, d->nvectors * sizeof(double)); 236 read_or_goto_error(adb->fd, d->power, d->nvectors * sizeof(double));
234 } 237 }
235 if(adb->header->flags & ADB_HEADER_FLAG_TIMES) { 238 if(adb->header->flags & ADB_HEADER_FLAG_TIMES) {
236 d->times = (double *) malloc(2 * d->nvectors * sizeof(double)); 239 d->times = (double *) malloc(2 * d->nvectors * sizeof(double));
237 lseek(adb->fd, adb->header->timesTableOffset + track_offset / d->dim, SEEK_SET); 240 lseek(adb->fd, adb->header->timesTableOffset + track_offset / d->dim + 2 * vector_offset*sizeof(double), SEEK_SET);
238 read_or_goto_error(adb->fd, d->times, 2 * d->nvectors * sizeof(double)); 241 read_or_goto_error(adb->fd, d->times, 2 * d->nvectors * sizeof(double));
239 } 242 }
240 return 0; 243 return 0;
241 } 244 }
242 error: 245 error:
283 int audiodb_datum_qpointers_partial(adb_datum_t *d, uint32_t sequence_length, double **vector_data, 286 int audiodb_datum_qpointers_partial(adb_datum_t *d, uint32_t sequence_length, double **vector_data,
284 double **vector, adb_qpointers_internal_t *qpointers, 287 double **vector, adb_qpointers_internal_t *qpointers,
285 adb_qstate_internal_t *qstate){ 288 adb_qstate_internal_t *qstate){
286 uint32_t nvectors = d->nvectors; 289 uint32_t nvectors = d->nvectors;
287 qpointers->nvectors = nvectors; 290 qpointers->nvectors = nvectors;
288 std::priority_queue<PointPair, std::vector<PointPair>, greater<PointPair> > ppairs(*qstate->exact_evaluation_queue); 291
289 292 PointPair pp = (*qstate->exact_evaluation_queue).top();
290 size_t vector_size = nvectors * sizeof(double) * d->dim;
291
292 if(d->power)
293 qpointers->power_data = new double[vector_size / d->dim];
294
295 uint32_t seq_len_dbl = sequence_length*sizeof(double);
296 PointPair pp = ppairs.top();
297 uint32_t tid = pp.trackID;
298
299 while( !ppairs.empty() && pp.trackID==tid){
300 uint32_t spos = pp.spos;
301 #ifdef _LSH_DEBUG_ 293 #ifdef _LSH_DEBUG_
302 cout << "tid=" << pp.trackID << " qpos=" << pp.qpos << " spos=" << pp.spos << endl; 294 cout << "tid=" << pp.trackID << " qpos=" << pp.qpos << " spos=" << pp.spos << endl;
303 cout.flush(); 295 cout.flush();
304 #endif 296 #endif
305 297
306 if(d->power) { 298 if(d->power) {
307 memcpy(qpointers->power_data+spos, d->power+spos, seq_len_dbl); 299 //memcpy(qpointers->power_data, d->power, seq_len_dbl);
308 audiodb_sequence_sum(qpointers->power_data+spos, sequence_length, sequence_length); 300 audiodb_sequence_sum(d->power, sequence_length, sequence_length);
309 audiodb_sequence_average(qpointers->power_data+spos, sequence_length, sequence_length); 301 audiodb_sequence_average(d->power, sequence_length, sequence_length);
310 } 302 }
311 ppairs.pop(); 303
312 if(!ppairs.empty())
313 pp = ppairs.top();
314 }
315
316 if(d->times) { 304 if(d->times) {
317 qpointers->mean_duration = new double[1]; 305 qpointers->mean_duration = new double[1];
318 *qpointers->mean_duration = 0; 306 *qpointers->mean_duration = 0;
319 for(unsigned int k = 0; k < nvectors; k++) { 307 for(unsigned int k = 0; k < nvectors; k++) {
320 *qpointers->mean_duration += d->times[2*k+1] - d->times[2*k]; 308 *qpointers->mean_duration += d->times[2*k+1] - d->times[2*k];
321 } 309 }
322 *qpointers->mean_duration /= nvectors; 310 *qpointers->mean_duration /= nvectors;
323 } 311 }
324 312
325 *vector = d->data; 313 *vector = d->data;
326 *vector_data = d->data; 314 *vector_data = d->data;
327 qpointers->l2norm = 0 ; 315 qpointers->l2norm = 0 ;
328 qpointers->power = qpointers->power_data; 316 qpointers->power = d->power;
329 return 0; 317 return 0;
330 } 318 }
331 319
332 int audiodb_query_spec_qpointers(adb_t *adb, const adb_query_spec_t *spec, double **vector_data, double **vector, adb_qpointers_internal_t *qpointers) { 320 int audiodb_query_spec_qpointers(adb_t *adb, const adb_query_spec_t *spec, double **vector_data, double **vector, adb_qpointers_internal_t *qpointers) {
333 adb_datum_t *datum; 321 adb_datum_t *datum;
494 * don't overlap, so we will use exhaustive dot product evaluation 482 * don't overlap, so we will use exhaustive dot product evaluation
495 * (instead of memoization of partial sums, as in query_loop()). 483 * (instead of memoization of partial sums, as in query_loop()).
496 */ 484 */
497 double dist; 485 double dist;
498 double *dbdata = 0, *dbdata_pointer; 486 double *dbdata = 0, *dbdata_pointer;
499 Uns32T currentTrack = 0x80000000; // KLUDGE: Initialize with a value outside of track index range
500 Uns32T npairs = qstate->exact_evaluation_queue->size(); 487 Uns32T npairs = qstate->exact_evaluation_queue->size();
501 #ifdef _LSH_DEBUG_ 488 #ifdef _LSH_DEBUG_
502 cout << "Num vector pairs to evaluate: " << npairs << "..." << endl; 489 cout << "Num vector pairs to evaluate: " << npairs << "..." << endl;
503 cout.flush(); 490 cout.flush();
504 #endif 491 #endif
505 adb_datum_t d = {0}; 492 adb_datum_t d = {0};
506 while(npairs--) { 493 while(npairs--) {
507 PointPair pp = qstate->exact_evaluation_queue->top(); 494 PointPair pp = qstate->exact_evaluation_queue->top();
508 if(currentTrack != pp.trackID) { 495 maybe_delete_array(dbpointers.mean_duration);
509 maybe_delete_array(dbpointers.power_data); 496 if(audiodb_track_id_datum(adb, pp.trackID, &d, pp.spos, sequence_length)) {
510 maybe_delete_array(dbpointers.mean_duration); 497 delete qstate->exact_evaluation_queue;
511 currentTrack = pp.trackID; 498 delete qstate->set;
499 return 1;
500 }
501
502 if(audiodb_datum_qpointers_partial(&d, sequence_length, &dbdata, &dbdata_pointer, &dbpointers, qstate)) {
503 delete qstate->exact_evaluation_queue;
504 delete qstate->set;
512 audiodb_free_datum(&d); 505 audiodb_free_datum(&d);
513 if(audiodb_track_id_datum(adb, pp.trackID, &d)) { 506 return 1;
514 delete qstate->exact_evaluation_queue; 507 }
515 delete qstate->set;
516 return 1;
517 }
518 508
519 if(audiodb_datum_qpointers_partial(&d, sequence_length, &dbdata, &dbdata_pointer, &dbpointers, qstate)) {
520 delete qstate->exact_evaluation_queue;
521 delete qstate->set;
522 audiodb_free_datum(&d);
523 return 1;
524 }
525 }
526 Uns32T qPos = (spec->qid.flags & ADB_QID_FLAG_EXHAUSTIVE) ? pp.qpos : 0; 509 Uns32T qPos = (spec->qid.flags & ADB_QID_FLAG_EXHAUSTIVE) ? pp.qpos : 0;
527 Uns32T sPos = pp.spos; // index into l2norm table
528 // Test power thresholds before computing distance 510 // Test power thresholds before computing distance
529 if( ( (!power_refine) || audiodb_powers_acceptable(&spec->refine, qpointers->power[qPos], dbpointers.power[sPos])) && 511 if( ( (!power_refine) || audiodb_powers_acceptable(&spec->refine, qpointers->power[qPos], dbpointers.power[0])) &&
530 ( qPos<qpointers->nvectors-sequence_length+1 && sPos<(*adb->track_lengths)[pp.trackID]-sequence_length+1 ) ){ 512 ( qPos<qpointers->nvectors-sequence_length+1 && pp.spos<(*adb->track_lengths)[pp.trackID]-sequence_length+1 ) ){
531 // Compute distance 513 // Compute distance
532 dist = audiodb_dot_product(query + qPos*adb->header->dim, dbdata + sPos*adb->header->dim, adb->header->dim*sequence_length); 514 dist = audiodb_dot_product(query + qPos*adb->header->dim, dbdata, adb->header->dim*sequence_length);
533 double qn = audiodb_dot_product(query + qPos*adb->header->dim, query + qPos*adb->header->dim, adb->header->dim*sequence_length); 515 double qn = audiodb_dot_product(query + qPos*adb->header->dim, query + qPos*adb->header->dim, adb->header->dim*sequence_length);
534 double sn = audiodb_dot_product(dbdata + sPos*adb->header->dim, dbdata + sPos*adb->header->dim, adb->header->dim*sequence_length); 516 double sn = audiodb_dot_product(dbdata, dbdata, adb->header->dim*sequence_length);
535 qn = sqrt(qn); 517 qn = sqrt(qn);
536 sn = sqrt(sn); 518 sn = sqrt(sn);
537 switch(spec->params.distance) { 519 switch(spec->params.distance) {
538 case ADB_DISTANCE_EUCLIDEAN_NORMED: 520 case ADB_DISTANCE_EUCLIDEAN_NORMED:
539 dist = 2 - (2/(qn*sn))*dist; 521 dist = 2 - (2/(qn*sn))*dist;
551 r.ipos = pp.spos; 533 r.ipos = pp.spos;
552 qstate->accumulator->add_point(&r); 534 qstate->accumulator->add_point(&r);
553 } 535 }
554 } 536 }
555 qstate->exact_evaluation_queue->pop(); 537 qstate->exact_evaluation_queue->pop();
538 audiodb_free_datum(&d);
556 } 539 }
557 540
558 // Cleanup 541 // Cleanup
559 audiodb_free_datum(&d);
560 // maybe_delete_array(dbdata); 542 // maybe_delete_array(dbdata);
561 //maybe_delete_array(dbpointers.l2norm_data); 543 //maybe_delete_array(dbpointers.l2norm_data);
562 maybe_delete_array(dbpointers.power_data); 544 //maybe_delete_array(dbpointers.power_data);
563 maybe_delete_array(dbpointers.mean_duration); 545 maybe_delete_array(dbpointers.mean_duration);
564 delete qstate->exact_evaluation_queue; 546 delete qstate->exact_evaluation_queue;
565 delete qstate->set; 547 delete qstate->set;
566 return 0; 548 return 0;
567 } 549 }