comparison query.cpp @ 292:d9a88cfd4ab6

Completed merge of lshlib back to current version of the trunk.
author mas01mc
date Tue, 29 Jul 2008 22:01:17 +0000
parents 210b2f661b88
children 896679d8cc39
comparison
equal deleted inserted replaced
291:63ae0dfc1767 292:d9a88cfd4ab6
1 #include "audioDB.h" 1 #include "audioDB.h"
2
3 #include "reporter.h" 2 #include "reporter.h"
4 3
5 bool audioDB::powers_acceptable(double p1, double p2) { 4 bool audioDB::powers_acceptable(double p1, double p2) {
6 if (use_absolute_threshold) { 5 if (use_absolute_threshold) {
7 if ((p1 < absolute_threshold) || (p2 < absolute_threshold)) { 6 if ((p1 < absolute_threshold) || (p2 < absolute_threshold)) {
15 } 14 }
16 return true; 15 return true;
17 } 16 }
18 17
19 void audioDB::query(const char* dbName, const char* inFile, adb__queryResponse *adbQueryResponse) { 18 void audioDB::query(const char* dbName, const char* inFile, adb__queryResponse *adbQueryResponse) {
20 initTables(dbName, inFile); 19 // init database tables and dbH first
21 Reporter *r = 0; 20 if(query_from_key)
21 initTables(dbName);
22 else
23 initTables(dbName, inFile);
24
25 // keyKeyPos requires dbH to be initialized
26 if(query_from_key && (!key || (query_from_key_index = getKeyPos((char*)key))==O2_ERR_KEYNOTFOUND))
27 error("Query key not found :",key);
28
22 switch (queryType) { 29 switch (queryType) {
23 case O2_POINT_QUERY: 30 case O2_POINT_QUERY:
24 sequenceLength = 1; 31 sequenceLength = 1;
25 normalizedDistance = false; 32 normalizedDistance = false;
26 r = new pointQueryReporter<std::greater < NNresult > >(pointNN); 33 reporter = new pointQueryReporter< std::greater < NNresult > >(pointNN);
27 break; 34 break;
28 case O2_TRACK_QUERY: 35 case O2_TRACK_QUERY:
29 sequenceLength = 1; 36 sequenceLength = 1;
30 normalizedDistance = false; 37 normalizedDistance = false;
31 r = new trackAveragingReporter<std::greater < NNresult > >(pointNN, trackNN, dbH->numFiles); 38 reporter = new trackAveragingReporter< std::greater< NNresult > >(pointNN, trackNN, dbH->numFiles);
32 break; 39 break;
33 case O2_SEQUENCE_QUERY: 40 case O2_SEQUENCE_QUERY:
41 if(no_unit_norming)
42 normalizedDistance = false;
34 if(radius == 0) { 43 if(radius == 0) {
35 r = new trackAveragingReporter<std::less < NNresult > >(pointNN, trackNN, dbH->numFiles); 44 reporter = new trackAveragingReporter< std::less< NNresult > >(pointNN, trackNN, dbH->numFiles);
36 } else { 45 } else {
37 r = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles); 46 if(index_exists(dbName, radius, sequenceLength)){
47 char* indexName = index_get_name(dbName, radius, sequenceLength);
48 lsh = new LSH(indexName);
49 assert(lsh);
50 reporter = new trackSequenceQueryRadReporter(trackNN, index_to_trackID(lsh->get_maxp())+1);
51 delete lsh;
52 delete[] indexName;
53 }
54 else
55 reporter = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles);
38 } 56 }
39 break; 57 break;
40 case O2_N_SEQUENCE_QUERY : 58 case O2_N_SEQUENCE_QUERY:
59 if(no_unit_norming)
60 normalizedDistance = false;
41 if(radius == 0) { 61 if(radius == 0) {
42 r = new trackSequenceQueryNNReporter<std::less < NNresult > >(pointNN, trackNN, dbH->numFiles); 62 reporter = new trackSequenceQueryNNReporter< std::less < NNresult > >(pointNN, trackNN, dbH->numFiles);
43 } else { 63 } else {
44 r = new trackSequenceQueryRadNNReporter(pointNN,trackNN, dbH->numFiles); 64 if(index_exists(dbName, radius, sequenceLength)){
65 char* indexName = index_get_name(dbName, radius, sequenceLength);
66 lsh = new LSH(indexName);
67 assert(lsh);
68 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, index_to_trackID(lsh->get_maxp())+1);
69 delete lsh;
70 delete[] indexName;
71 }
72 else
73 reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, dbH->numFiles);
45 } 74 }
46 break; 75 break;
47 case O2_ONE_TO_ONE_N_SEQUENCE_QUERY : 76 case O2_ONE_TO_ONE_N_SEQUENCE_QUERY :
48 if(radius == 0) { 77 if(radius == 0) {
49 error("query-type not yet supported"); 78 error("query-type not yet supported");
50 } else { 79 } else {
51 r = new trackSequenceQueryRadNNReporterOneToOne(pointNN,trackNN, dbH->numFiles); 80 reporter = new trackSequenceQueryRadNNReporterOneToOne(pointNN,trackNN, dbH->numFiles);
52 } 81 }
53 break; 82 break;
54 default: 83 default:
55 error("unrecognized queryType in query()"); 84 error("unrecognized queryType in query()");
56 } 85 }
57 query_loop(dbName, inFile, r); 86
58 r->report(fileTable, adbQueryResponse); 87 // Test for index (again) here
59 delete r; 88 if(radius && index_exists(dbName, radius, sequenceLength))
89 index_query_loop(dbName, query_from_key_index);
90 else
91 query_loop(dbName, query_from_key_index);
92
93 reporter->report(fileTable, adbQueryResponse);
60 } 94 }
61 95
62 // return ordinal position of key in keyTable 96 // return ordinal position of key in keyTable
97 // this should really be a STL hash map search
63 unsigned audioDB::getKeyPos(char* key){ 98 unsigned audioDB::getKeyPos(char* key){
99 if(!dbH)
100 error("dbH not initialized","getKeyPos");
64 for(unsigned k=0; k<dbH->numFiles; k++) 101 for(unsigned k=0; k<dbH->numFiles; k++)
65 if(strncmp(fileTable + k*O2_FILETABLE_ENTRY_SIZE, key, strlen(key))==0) 102 if(strncmp(fileTable + k*O2_FILETABLE_ENTRY_SIZE, key, strlen(key))==0)
66 return k; 103 return k;
67 error("Key not found",key); 104 error("Key not found",key);
68 return O2_ERR_KEYNOTFOUND; 105 return O2_ERR_KEYNOTFOUND;
153 for(w = 0; w < wL; w++) { 190 for(w = 0; w < wL; w++) {
154 for(j = 0; j < numVectors - w; j++) { 191 for(j = 0; j < numVectors - w; j++) {
155 sp = DD[j]; 192 sp = DD[j];
156 spd = D[j+w] + w; 193 spd = D[j+w] + w;
157 k = trackTable[track] - w; 194 k = trackTable[track] - w;
158 while(k--) 195 while(k--)
159 *sp++ += *spd++; 196 *sp++ += *spd++;
160 } 197 }
161 } 198 }
162 } else { // HOP_SIZE != 1 199 } else { // HOP_SIZE != 1
163 for(w = 0; w < wL; w++) { 200 for(w = 0; w < wL; w++) {
164 for(j = 0; j < numVectors - w; j += HOP_SIZE) { 201 for(j = 0; j < numVectors - w; j += HOP_SIZE) {
209 // are pointers to the query, norm and power vectors; the names 246 // are pointers to the query, norm and power vectors; the names
210 // starting with "v" are things that will end up pointing to the 247 // starting with "v" are things that will end up pointing to the
211 // actual query point's information. -- CSR, 2007-12-05 248 // actual query point's information. -- CSR, 2007-12-05
212 void audioDB::set_up_query(double **qp, double **vqp, double **qnp, double **vqnp, double **qpp, double **vqpp, double *mqdp, unsigned *nvp) { 249 void audioDB::set_up_query(double **qp, double **vqp, double **qnp, double **vqnp, double **qpp, double **vqpp, double *mqdp, unsigned *nvp) {
213 *nvp = (statbuf.st_size - sizeof(int)) / (dbH->dim * sizeof(double)); 250 *nvp = (statbuf.st_size - sizeof(int)) / (dbH->dim * sizeof(double));
214 251
215 if(!(dbH->flags & O2_FLAG_L2NORM)) { 252 if(!(dbH->flags & O2_FLAG_L2NORM)) {
216 error("Database must be L2 normed for sequence query","use -L2NORM"); 253 error("Database must be L2 normed for sequence query","use -L2NORM");
217 } 254 }
218 255
219 if(*nvp < sequenceLength) { 256 if(*nvp < sequenceLength) {
283 *nvp = sequenceLength; 320 *nvp = sequenceLength;
284 } 321 }
285 } 322 }
286 } 323 }
287 324
325 // Does the same as set_up_query(...) but from database features instead of from a file
326 // Constructs the same outputs as set_up_query
327 void audioDB::set_up_query_from_key(double **qp, double **vqp, double **qnp, double **vqnp, double **qpp, double **vqpp, double *mqdp, unsigned *nvp, Uns32T queryIndex) {
328 if(!trackTable)
329 error("trackTable not initialized","set_up_query_from_key");
330
331 if(!(dbH->flags & O2_FLAG_L2NORM)) {
332 error("Database must be L2 normed for sequence query","use -L2NORM");
333 }
334
335 if(dbH->flags & O2_FLAG_POWER)
336 usingPower = true;
337
338 if(dbH->flags & O2_FLAG_TIMES)
339 usingTimes = true;
340
341 *nvp = trackTable[queryIndex];
342 if(*nvp < sequenceLength) {
343 error("Query shorter than requested sequence length", "maybe use -l");
344 }
345
346 VERB_LOG(1, "performing norms... ");
347
348 // Read query feature vectors from database
349 *qp = NULL;
350 lseek(dbfid, dbH->dataOffset + trackOffsetTable[queryIndex] * sizeof(double), SEEK_SET);
351 size_t allocatedSize = 0;
352 read_data(queryIndex, qp, &allocatedSize);
353 // Consistency check on allocated memory and query feature size
354 if(*nvp*sizeof(double)*dbH->dim != allocatedSize)
355 error("Query memory allocation failed consitency check","set_up_query_from_key");
356
357 Uns32T trackIndexOffset = trackOffsetTable[queryIndex]/dbH->dim; // Convert num data elements to num vectors
358 // Copy L2 norm partial-sum coefficients
359 assert(*qnp = new double[*nvp]);
360 memcpy(*qnp, l2normTable+trackIndexOffset, *nvp*sizeof(double));
361 sequence_sum(*qnp, *nvp, sequenceLength);
362 sequence_sqrt(*qnp, *nvp, sequenceLength);
363
364 if( usingPower ){
365 // Copy Power partial-sum coefficients
366 assert(*qpp = new double[*nvp]);
367 memcpy(*qpp, powerTable+trackIndexOffset, *nvp*sizeof(double));
368 sequence_sum(*qpp, *nvp, sequenceLength);
369 sequence_average(*qpp, *nvp, sequenceLength);
370 }
371
372 if (usingTimes) {
373 unsigned int k;
374 *mqdp = 0.0;
375 double *querydurs = new double[*nvp];
376 double *timesdata = new double[*nvp*2];
377 assert(querydurs && timesdata);
378 memcpy(timesdata, timesTable+trackIndexOffset, *nvp*sizeof(double));
379 for(k = 0; k < *nvp; k++) {
380 querydurs[k] = timesdata[2*k+1] - timesdata[2*k];
381 *mqdp += querydurs[k];
382 }
383 *mqdp /= k;
384
385 VERB_LOG(1, "mean query file duration: %f\n", *mqdp);
386
387 delete [] querydurs;
388 delete [] timesdata;
389 }
390 // Defaults, for exhaustive search (!usingQueryPoint)
391 *vqp = *qp;
392 *vqnp = *qnp;
393 *vqpp = *qpp;
394
395 if(usingQueryPoint) {
396 if(queryPoint > *nvp || queryPoint > *nvp - sequenceLength + 1) {
397 error("queryPoint > numVectors-wL+1 in query");
398 } else {
399 VERB_LOG(1, "query point: %u\n", queryPoint);
400 *vqp = *qp + queryPoint * dbH->dim;
401 *vqnp = *qnp + queryPoint;
402 if (usingPower) {
403 *vqpp = *qpp + queryPoint;
404 }
405 *nvp = sequenceLength;
406 }
407 }
408 }
409
410
288 // FIXME: this is not the right name; we're not actually setting up 411 // FIXME: this is not the right name; we're not actually setting up
289 // the database, but copying various bits of it out of mmap()ed tables 412 // the database, but copying various bits of it out of mmap()ed tables
290 // in order to reduce seeks. 413 // in order to reduce seeks.
291 void audioDB::set_up_db(double **snp, double **vsnp, double **spp, double **vspp, double **mddp, unsigned int *dvp) { 414 void audioDB::set_up_db(double **snp, double **vsnp, double **spp, double **vspp, double **mddp, unsigned int *dvp) {
292 *dvp = dbH->length / (dbH->dim * sizeof(double)); 415 *dvp = dbH->length / (dbH->dim * sizeof(double));
339 462
340 *vsnp = *snp; 463 *vsnp = *snp;
341 *vspp = *spp; 464 *vspp = *spp;
342 } 465 }
343 466
344 void audioDB::query_loop(const char* dbName, const char* inFile, Reporter *reporter) { 467 // query_points()
468 //
469 // using PointPairs held in the exact_evaluation_queue compute squared distance for each PointPair
470 // and insert result into the current reporter.
471 //
472 // Preconditions:
473 // A query inFile has been opened with setup_query(...) and query pointers initialized
474 // The database contains some points
475 // An exact_evaluation_queue has been allocated and populated
476 // A reporter has been allocated
477 //
478 // Postconditions:
479 // reporter contains the points and distances that meet the reporter constraints
480
481 void audioDB::query_loop_points(double* query, double* qnPtr, double* qpPtr, double meanQdur, Uns32T numVectors){
482 unsigned int dbVectors;
483 double *sNorm, *snPtr, *sPower = 0, *spPtr = 0;
484 double *meanDBdur = 0;
485
486 // check pre-conditions
487 assert(exact_evaluation_queue&&reporter);
488 if(!exact_evaluation_queue->size()) // Exit if no points to evaluate
489 return;
490
491 // Compute database info
492 // FIXME: we more than likely don't need very much of the database
493 // so make a new method to build these values per-track or, even better, per-point
494 set_up_db(&sNorm, &snPtr, &sPower, &spPtr, &meanDBdur, &dbVectors);
495
496 VERB_LOG(1, "matching points...");
497
498 assert(pointNN>0 && pointNN<=O2_MAXNN);
499 assert(trackNN>0 && trackNN<=O2_MAXNN);
500
501 // We are guaranteed that the order of points is sorted by:
502 // qpos, trackID, spos
503 // so we can be relatively efficient in initialization of track data.
504 // Here we assume that points don't overlap, so we will use exhaustive dot
505 // product evaluation over the sequence
506 double dist;
507 size_t data_buffer_size = 0;
508 double *data_buffer = 0;
509 Uns32T trackOffset;
510 Uns32T trackIndexOffset;
511 Uns32T currentTrack = 0x80000000; // Initialize with a value outside of track index range
512 Uns32T npairs = exact_evaluation_queue->size();
513 while(npairs--){
514 PointPair pp = exact_evaluation_queue->top();
515 trackOffset=trackOffsetTable[pp.trackID]; // num data elements offset
516 trackIndexOffset=trackOffset/dbH->dim; // num vectors offset
517 if((!(usingPower) || powers_acceptable(qpPtr[usingQueryPoint?0:pp.qpos], sPower[trackIndexOffset+pp.spos])) &&
518 ((usingQueryPoint?0:pp.qpos) < numVectors-sequenceLength+1 && pp.spos < trackTable[pp.trackID]-sequenceLength+1)){
519 if(currentTrack!=pp.trackID){
520 currentTrack=pp.trackID;
521 lseek(dbfid, dbH->dataOffset + trackOffset * sizeof(double), SEEK_SET);
522 read_data(currentTrack, &data_buffer, &data_buffer_size);
523 }
524 dist = dot_product_points(query+(usingQueryPoint?0:pp.qpos*dbH->dim), data_buffer+pp.spos*dbH->dim, dbH->dim*sequenceLength);
525 if(normalizedDistance)
526 dist = 2-(2/(qnPtr[usingQueryPoint?0:pp.qpos]*sNorm[trackIndexOffset+pp.spos]))*dist;
527 else
528 if(no_unit_norming)
529 dist = qnPtr[usingQueryPoint?0:pp.qpos]*qnPtr[usingQueryPoint?0:pp.qpos]+sNorm[trackIndexOffset+pp.spos]*sNorm[trackIndexOffset+pp.spos] - 2*dist;
530 // else
531 // dist = dist;
532 if((!radius) || dist <= (radius+O2_DISTANCE_TOLERANCE))
533 reporter->add_point(pp.trackID, pp.qpos, pp.spos, dist);
534 }
535 exact_evaluation_queue->pop();
536 }
537 }
538
539 // A completely unprotected dot-product method
540 // Caller is responsible for ensuring that memory is within bounds
541 inline double audioDB::dot_product_points(double* q, double* p, Uns32T L){
542 double dist = 0.0;
543 while(L--)
544 dist += *q++ * *p++;
545 return dist;
546 }
547
548 void audioDB::query_loop(const char* dbName, Uns32T queryIndex) {
345 549
346 unsigned int numVectors; 550 unsigned int numVectors;
347 double *query, *query_data; 551 double *query, *query_data;
348 double *qNorm, *qnPtr, *qPower = 0, *qpPtr = 0; 552 double *qNorm, *qnPtr, *qPower = 0, *qpPtr = 0;
349 double meanQdur; 553 double meanQdur;
350 554
351 set_up_query(&query_data, &query, &qNorm, &qnPtr, &qPower, &qpPtr, &meanQdur, &numVectors); 555 if(query_from_key)
556 set_up_query_from_key(&query_data, &query, &qNorm, &qnPtr, &qPower, &qpPtr, &meanQdur, &numVectors, queryIndex);
557 else
558 set_up_query(&query_data, &query, &qNorm, &qnPtr, &qPower, &qpPtr, &meanQdur, &numVectors);
352 559
353 unsigned int dbVectors; 560 unsigned int dbVectors;
354 double *sNorm, *snPtr, *sPower = 0, *spPtr = 0; 561 double *sNorm, *snPtr, *sPower = 0, *spPtr = 0;
355 double *meanDBdur = 0; 562 double *meanDBdur = 0;
356 563
363 570
364 unsigned j,k,track,trackOffset=0, HOP_SIZE=sequenceHop, wL=sequenceLength; 571 unsigned j,k,track,trackOffset=0, HOP_SIZE=sequenceHop, wL=sequenceLength;
365 double **D = 0; // Differences query and target 572 double **D = 0; // Differences query and target
366 double **DD = 0; // Matched filter distance 573 double **DD = 0; // Matched filter distance
367 574
368 D = new double*[numVectors]; 575 D = new double*[numVectors]; // pre-allocate
369 DD = new double*[numVectors]; 576 DD = new double*[numVectors];
370 577
371 gettimeofday(&tv1, NULL); 578 gettimeofday(&tv1, NULL);
372 unsigned processedTracks = 0; 579 unsigned processedTracks = 0;
373
374 // build track offset table
375 off_t *trackOffsetTable = new off_t[dbH->numFiles];
376 unsigned cumTrack=0;
377 off_t trackIndexOffset; 580 off_t trackIndexOffset;
378 for(k = 0; k < dbH->numFiles; k++){
379 trackOffsetTable[k] = cumTrack;
380 cumTrack += trackTable[k] * dbH->dim;
381 }
382
383 char nextKey[MAXSTR]; 581 char nextKey[MAXSTR];
384 582
385 // Track loop 583 // Track loop
386 size_t data_buffer_size = 0; 584 size_t data_buffer_size = 0;
387 double *data_buffer = 0; 585 double *data_buffer = 0;
401 } else { 599 } else {
402 break; 600 break;
403 } 601 }
404 } 602 }
405 603
604 // skip identity on query_from_key
605 if( query_from_key && (track == queryIndex) ) {
606 if(queryIndex!=dbH->numFiles-1){
607 track++;
608 trackOffset = trackOffsetTable[track];
609 lseek(dbfid, dbH->dataOffset + trackOffset * sizeof(double), SEEK_SET);
610 }
611 else{
612 break;
613 }
614 }
615
406 trackIndexOffset=trackOffset/dbH->dim; // numVectors offset 616 trackIndexOffset=trackOffset/dbH->dim; // numVectors offset
407 617
408 read_data(track, &data_buffer, &data_buffer_size); 618 read_data(track, &data_buffer, &data_buffer_size);
409 if(sequenceLength <= trackTable[track]) { // test for short sequences 619 if(sequenceLength <= trackTable[track]) { // test for short sequences
410 620
423 633
424 // Search for minimum distance by shingles (concatenated vectors) 634 // Search for minimum distance by shingles (concatenated vectors)
425 for(j = 0; j <= numVectors - wL; j += HOP_SIZE) { 635 for(j = 0; j <= numVectors - wL; j += HOP_SIZE) {
426 for(k = 0; k <= trackTable[track] - wL; k += HOP_SIZE) { 636 for(k = 0; k <= trackTable[track] - wL; k += HOP_SIZE) {
427 double thisDist; 637 double thisDist;
428 if(normalizedDistance) { 638 if(normalizedDistance)
429 thisDist = 2-(2/(qnPtr[j]*sNorm[trackIndexOffset+k]))*DD[j][k]; 639 thisDist = 2-(2/(qnPtr[j]*sNorm[trackIndexOffset+k]))*DD[j][k];
430 } else { 640 else
431 thisDist = DD[j][k]; 641 if(no_unit_norming)
432 } 642 thisDist = qnPtr[j]*qnPtr[j]+sNorm[trackIndexOffset+k]*sNorm[trackIndexOffset+k] - 2*DD[j][k];
643 else
644 thisDist = DD[j][k];
645
433 // Power test 646 // Power test
434 if ((!usingPower) || powers_acceptable(qpPtr[j], sPower[trackIndexOffset + k])) { 647 if ((!usingPower) || powers_acceptable(qpPtr[j], sPower[trackIndexOffset + k])) {
435 // radius test 648 // radius test
436 if((!radius) || thisDist < radius) { 649 if((!radius) || thisDist <= (radius+O2_DISTANCE_TOLERANCE)) {
437 reporter->add_point(track, usingQueryPoint ? queryPoint : j, k, thisDist); 650 reporter->add_point(track, usingQueryPoint ? queryPoint : j, k, thisDist);
438 } 651 }
439 } 652 }
440 } 653 }
441 } 654 }
450 VERB_LOG(1,"elapsed time: %ld msec\n", 663 VERB_LOG(1,"elapsed time: %ld msec\n",
451 (tv2.tv_sec*1000 + tv2.tv_usec/1000) - 664 (tv2.tv_sec*1000 + tv2.tv_usec/1000) -
452 (tv1.tv_sec*1000 + tv1.tv_usec/1000)) 665 (tv1.tv_sec*1000 + tv1.tv_usec/1000))
453 666
454 // Clean up 667 // Clean up
455 if(trackOffsetTable)
456 delete[] trackOffsetTable;
457 if(query_data) 668 if(query_data)
458 delete[] query_data; 669 delete[] query_data;
459 if(qNorm) 670 if(qNorm)
460 delete[] qNorm; 671 delete[] qNorm;
461 if(sNorm) 672 if(sNorm)
491 } 702 }
492 X += dim; 703 X += dim;
493 } 704 }
494 VERB_LOG(2, "done.\n"); 705 VERB_LOG(2, "done.\n");
495 } 706 }
707
708