diff query.cpp @ 292:d9a88cfd4ab6

Completed merge of lshlib back to current version of the trunk.
author mas01mc
date Tue, 29 Jul 2008 22:01:17 +0000
parents 210b2f661b88
children 896679d8cc39
line wrap: on
line diff
--- a/query.cpp	Tue Jul 22 20:09:31 2008 +0000
+++ b/query.cpp	Tue Jul 29 22:01:17 2008 +0000
@@ -1,5 +1,4 @@
 #include "audioDB.h"
-
 #include "reporter.h"
 
 bool audioDB::powers_acceptable(double p1, double p2) {
@@ -17,50 +16,88 @@
 }
 
 void audioDB::query(const char* dbName, const char* inFile, adb__queryResponse *adbQueryResponse) {
-  initTables(dbName, inFile);
-  Reporter *r = 0;
+  // init database tables and dbH first
+  if(query_from_key)
+    initTables(dbName);
+  else
+    initTables(dbName, inFile);
+
+  // getKeyPos requires dbH to be initialized
+  if(query_from_key && (!key || (query_from_key_index = getKeyPos((char*)key))==O2_ERR_KEYNOTFOUND))
+    error("Query key not found :",key);  
+  
   switch (queryType) {
   case O2_POINT_QUERY:
     sequenceLength = 1;
     normalizedDistance = false;
-    r = new pointQueryReporter<std::greater < NNresult > >(pointNN);
+    reporter = new pointQueryReporter< std::greater < NNresult > >(pointNN);
     break;
   case O2_TRACK_QUERY:
     sequenceLength = 1;
     normalizedDistance = false;
-    r = new trackAveragingReporter<std::greater < NNresult > >(pointNN, trackNN, dbH->numFiles);
+    reporter = new trackAveragingReporter< std::greater< NNresult > >(pointNN, trackNN, dbH->numFiles);
     break;
-  case O2_SEQUENCE_QUERY:
+  case O2_SEQUENCE_QUERY:    
+    if(no_unit_norming)
+      normalizedDistance = false;
     if(radius == 0) {
-      r = new trackAveragingReporter<std::less < NNresult > >(pointNN, trackNN, dbH->numFiles);
+      reporter = new trackAveragingReporter< std::less< NNresult > >(pointNN, trackNN, dbH->numFiles);
     } else {
-      r = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles);
+      if(index_exists(dbName, radius, sequenceLength)){
+	char* indexName = index_get_name(dbName, radius, sequenceLength);
+	lsh = new LSH(indexName);
+	assert(lsh);
+	reporter = new trackSequenceQueryRadReporter(trackNN, index_to_trackID(lsh->get_maxp())+1);
+	delete lsh;
+	delete[] indexName;
+      }
+      else
+	reporter = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles);
     }
     break;
-  case O2_N_SEQUENCE_QUERY :
+  case O2_N_SEQUENCE_QUERY:
+    if(no_unit_norming)
+      normalizedDistance = false;
     if(radius == 0) {
-      r = new trackSequenceQueryNNReporter<std::less < NNresult > >(pointNN, trackNN, dbH->numFiles);
+      reporter = new trackSequenceQueryNNReporter< std::less < NNresult > >(pointNN, trackNN, dbH->numFiles);
     } else {
-      r = new trackSequenceQueryRadNNReporter(pointNN,trackNN, dbH->numFiles);
+      if(index_exists(dbName, radius, sequenceLength)){
+	char* indexName = index_get_name(dbName, radius, sequenceLength);
+	lsh = new LSH(indexName);
+	assert(lsh);
+	reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, index_to_trackID(lsh->get_maxp())+1);
+	delete lsh;
+	delete[] indexName;
+      }
+      else
+	reporter = new trackSequenceQueryRadNNReporter(pointNN,trackNN, dbH->numFiles);
     }
     break;
   case O2_ONE_TO_ONE_N_SEQUENCE_QUERY :
     if(radius == 0) {
       error("query-type not yet supported");
     } else {
-      r = new trackSequenceQueryRadNNReporterOneToOne(pointNN,trackNN, dbH->numFiles);
+      reporter = new trackSequenceQueryRadNNReporterOneToOne(pointNN,trackNN, dbH->numFiles);
     }
     break;
   default:
     error("unrecognized queryType in query()");
   }  
-  query_loop(dbName, inFile, r);
-  r->report(fileTable, adbQueryResponse);
-  delete r;
+
+  // Test for index (again) here
+  if(radius && index_exists(dbName, radius, sequenceLength)) 
+    index_query_loop(dbName, query_from_key_index);
+  else
+    query_loop(dbName, query_from_key_index);
+
+  reporter->report(fileTable, adbQueryResponse);
 }
 
 // return ordinal position of key in keyTable
+// this should really be a STL hash map search
 unsigned audioDB::getKeyPos(char* key){  
+  if(!dbH)
+    error("dbH not initialized","getKeyPos");
   for(unsigned k=0; k<dbH->numFiles; k++)
     if(strncmp(fileTable + k*O2_FILETABLE_ENTRY_SIZE, key, strlen(key))==0)
       return k;
@@ -155,8 +192,8 @@
         sp = DD[j];
         spd = D[j+w] + w;
         k = trackTable[track] - w;
-        while(k--)
-          *sp++ += *spd++;
+	while(k--)
+	  *sp++ += *spd++;
       }
     }
   } else { // HOP_SIZE != 1
@@ -211,7 +248,7 @@
 // actual query point's information.  -- CSR, 2007-12-05
 void audioDB::set_up_query(double **qp, double **vqp, double **qnp, double **vqnp, double **qpp, double **vqpp, double *mqdp, unsigned *nvp) {
   *nvp = (statbuf.st_size - sizeof(int)) / (dbH->dim * sizeof(double));
-  
+
   if(!(dbH->flags & O2_FLAG_L2NORM)) {
     error("Database must be L2 normed for sequence query","use -L2NORM");
   }
@@ -285,6 +322,92 @@
   }
 }
 
+// Build the query from a track already stored in the database (selected by queryIndex)
+// instead of from a feature file; constructs the same outputs as set_up_query(...):
+//   qp / qnp / qpp    : query feature vectors, L2-norm coefficients and (if usingPower) power coefficients
+//   vqp / vqnp / vqpp : the same buffers, advanced to queryPoint when usingQueryPoint is set
+//   mqdp              : mean query duration (written only when usingTimes)
+//   nvp               : number of query vectors (set to sequenceLength when usingQueryPoint is set)
+void audioDB::set_up_query_from_key(double **qp, double **vqp, double **qnp, double **vqnp, double **qpp, double **vqpp, double *mqdp, unsigned *nvp, Uns32T queryIndex) {
+  if(!trackTable)
+    error("trackTable not initialized","set_up_query_from_key");
+
+  if(!(dbH->flags & O2_FLAG_L2NORM))
+    error("Database must be L2 normed for sequence query","use -L2NORM");
+
+  if(dbH->flags & O2_FLAG_POWER)
+    usingPower = true;
+  if(dbH->flags & O2_FLAG_TIMES)
+    usingTimes = true;
+
+  *nvp = trackTable[queryIndex];
+  if(*nvp < sequenceLength)
+    error("Query shorter than requested sequence length", "maybe use -l");
+
+  VERB_LOG(1, "performing norms... ");
+
+  // Read query feature vectors from database
+  *qp = NULL;
+  lseek(dbfid, dbH->dataOffset + trackOffsetTable[queryIndex] * sizeof(double), SEEK_SET);
+  size_t allocatedSize = 0;
+  read_data(queryIndex, qp, &allocatedSize);
+  // Consistency check on allocated memory and query feature size
+  if(*nvp*sizeof(double)*dbH->dim != allocatedSize)
+    error("Query memory allocation failed consitency check","set_up_query_from_key");
+
+  Uns32T trackIndexOffset = trackOffsetTable[queryIndex]/dbH->dim; // Convert num data elements to num vectors
+  // Copy L2 norm partial-sum coefficients. The allocation is deliberately not wrapped in assert():
+  // under NDEBUG the whole expression would be compiled out, leaving *qnp unset for the memcpy below.
+  *qnp = new double[*nvp];
+  memcpy(*qnp, l2normTable+trackIndexOffset, *nvp*sizeof(double));
+  sequence_sum(*qnp, *nvp, sequenceLength);
+  sequence_sqrt(*qnp, *nvp, sequenceLength);
+
+  if( usingPower ){
+    // Copy Power partial-sum coefficients (allocation kept outside assert() for the same NDEBUG reason)
+    *qpp = new double[*nvp];
+    memcpy(*qpp, powerTable+trackIndexOffset, *nvp*sizeof(double));
+    sequence_sum(*qpp, *nvp, sequenceLength);
+    sequence_average(*qpp, *nvp, sequenceLength);
+  }
+
+  if (usingTimes) {
+    *mqdp = 0.0;
+    double *querydurs = new double[*nvp];
+    double *timesdata = new double[*nvp*2];
+    assert(querydurs && timesdata);
+    // The loop reads a (start,end) pair per vector, so copy 2 * *nvp doubles (only *nvp were copied before).
+    // NOTE(review): the doubled offset assumes timesTable stores one such pair per database vector -- confirm.
+    memcpy(timesdata, timesTable+2*trackIndexOffset, 2*(*nvp)*sizeof(double));
+    for(unsigned int k = 0; k < *nvp; k++) {
+      querydurs[k] = timesdata[2*k+1] - timesdata[2*k];
+      *mqdp += querydurs[k];
+    }
+    *mqdp /= *nvp;
+    VERB_LOG(1, "mean query file duration: %f\n", *mqdp);
+    delete [] querydurs;
+    delete [] timesdata;
+  }
+  // Defaults, for exhaustive search (!usingQueryPoint)
+  *vqp = *qp;
+  *vqnp = *qnp;
+  *vqpp = *qpp;
+
+  if(usingQueryPoint) {
+    if(queryPoint > *nvp || queryPoint > *nvp - sequenceLength + 1) {
+      error("queryPoint > numVectors-wL+1 in query");
+    } else {
+      VERB_LOG(1, "query point: %u\n", queryPoint);
+      *vqp = *qp + queryPoint * dbH->dim;
+      *vqnp = *qnp + queryPoint;
+      if (usingPower)
+        *vqpp = *qpp + queryPoint;
+      *nvp = sequenceLength;
+    }
+  }
+}
+
+
 // FIXME: this is not the right name; we're not actually setting up
 // the database, but copying various bits of it out of mmap()ed tables
 // in order to reduce seeks.
@@ -341,14 +464,98 @@
   *vspp = *spp;
 }
 
-void audioDB::query_loop(const char* dbName, const char* inFile, Reporter *reporter) {
+// query_loop_points()
+//
+// Using the PointPairs held in exact_evaluation_queue, compute the distance measure for each PointPair
+// and insert the result into the current reporter.
+//
+// Preconditions:
+// A query inFile has been opened with set_up_query(...) and query pointers initialized
+// The database contains some points
+// An exact_evaluation_queue has been allocated and populated
+// A reporter has been allocated
+//
+// Postconditions:
+// The reporter contains the points and distances that meet the reporter constraints; the queue has been emptied
+
+void audioDB::query_loop_points(double* query, double* qnPtr, double* qpPtr, double meanQdur, Uns32T numVectors){
+  unsigned int dbVectors;
+  double *sNorm, *snPtr, *sPower = 0, *spPtr = 0;
+  double *meanDBdur = 0;
+  // check pre-conditions
+  assert(exact_evaluation_queue&&reporter);
+  if(!exact_evaluation_queue->size()) // Exit if no points to evaluate
+    return;
+
+  // Compute database info
+  // FIXME: we more than likely don't need very much of the database
+  // so make a new method to build these values per-track or, even better, per-point
+  set_up_db(&sNorm, &snPtr, &sPower, &spPtr, &meanDBdur, &dbVectors);
+
+  VERB_LOG(1, "matching points...");
+  assert(pointNN>0 && pointNN<=O2_MAXNN);
+  assert(trackNN>0 && trackNN<=O2_MAXNN);
+
+  // We are guaranteed that the order of points is sorted by:
+  // qpos, trackID, spos
+  // so we can be relatively efficient in initialization of track data.
+  // Here we assume that points don't overlap, so we will use exhaustive dot
+  // product evaluation over the sequence
+  size_t data_buffer_size = 0;
+  double *data_buffer = 0;
+  Uns32T trackOffset, trackIndexOffset;
+  Uns32T currentTrack = 0x80000000; // Initialize with a value outside of track index range
+  Uns32T npairs = exact_evaluation_queue->size();
+  while(npairs--){
+    PointPair pp = exact_evaluation_queue->top();
+    trackOffset=trackOffsetTable[pp.trackID]; // num data elements offset
+    trackIndexOffset=trackOffset/dbH->dim;    // num vectors offset
+    // Range-check both window start positions before they are used as indices. Comparing against
+    // (length - sequenceLength) only after checking length >= sequenceLength avoids unsigned wrap-around,
+    // which would otherwise accept every position in a query or track shorter than sequenceLength.
+    if(numVectors >= sequenceLength && (usingQueryPoint?0:pp.qpos) <= numVectors-sequenceLength &&
+       trackTable[pp.trackID] >= sequenceLength && pp.spos <= trackTable[pp.trackID]-sequenceLength &&
+       (!(usingPower) || powers_acceptable(qpPtr[usingQueryPoint?0:pp.qpos], sPower[trackIndexOffset+pp.spos]))){
+      if(currentTrack!=pp.trackID){
+        currentTrack=pp.trackID;
+        lseek(dbfid, dbH->dataOffset + trackOffset * sizeof(double), SEEK_SET);
+        read_data(currentTrack, &data_buffer, &data_buffer_size);
+      }
+      double dist = dot_product_points(query+(usingQueryPoint?0:pp.qpos*dbH->dim), data_buffer+pp.spos*dbH->dim, dbH->dim*sequenceLength);
+      if(normalizedDistance)
+        dist = 2-(2/(qnPtr[usingQueryPoint?0:pp.qpos]*sNorm[trackIndexOffset+pp.spos]))*dist;
+      else if(no_unit_norming)
+        dist = qnPtr[usingQueryPoint?0:pp.qpos]*qnPtr[usingQueryPoint?0:pp.qpos]+sNorm[trackIndexOffset+pp.spos]*sNorm[trackIndexOffset+pp.spos] - 2*dist;
+      // otherwise dist is left as the plain dot product
+      if((!radius) || dist <= (radius+O2_DISTANCE_TOLERANCE))
+        reporter->add_point(pp.trackID, pp.qpos, pp.spos, dist);
+    }
+    exact_evaluation_queue->pop();
+  }
+  // NOTE(review): data_buffer and the set_up_db() outputs (sNorm, sPower, meanDBdur) are not released here;
+  // presumably they are heap allocations, in which case they leak on every call -- confirm and free with the matching deallocator.
+}
+
+// Unchecked dot product of two arrays of L doubles.
+// No bounds checking is done: the caller must guarantee that q and p each hold at least L elements.
+inline double audioDB::dot_product_points(double* q, double* p, Uns32T  L){
+  double sum = 0.0;
+  for(Uns32T i = 0; i < L; i++)
+    sum += q[i] * p[i];
+  return sum;
+}
+
+void audioDB::query_loop(const char* dbName, Uns32T queryIndex) {
   
   unsigned int numVectors;
   double *query, *query_data;
   double *qNorm, *qnPtr, *qPower = 0, *qpPtr = 0;
   double meanQdur;
 
-  set_up_query(&query_data, &query, &qNorm, &qnPtr, &qPower, &qpPtr, &meanQdur, &numVectors);
+  if(query_from_key)
+    set_up_query_from_key(&query_data, &query, &qNorm, &qnPtr, &qPower, &qpPtr, &meanQdur, &numVectors, queryIndex);
+  else
+    set_up_query(&query_data, &query, &qNorm, &qnPtr, &qPower, &qpPtr, &meanQdur, &numVectors);
 
   unsigned int dbVectors;
   double *sNorm, *snPtr, *sPower = 0, *spPtr = 0;
@@ -365,21 +572,12 @@
   double **D = 0;    // Differences query and target 
   double **DD = 0;   // Matched filter distance
 
-  D = new double*[numVectors];
+  D = new double*[numVectors]; // pre-allocate 
   DD = new double*[numVectors];
 
   gettimeofday(&tv1, NULL); 
   unsigned processedTracks = 0;
-
-  // build track offset table
-  off_t *trackOffsetTable = new off_t[dbH->numFiles];
-  unsigned cumTrack=0;
   off_t trackIndexOffset;
-  for(k = 0; k < dbH->numFiles; k++){
-    trackOffsetTable[k] = cumTrack;
-    cumTrack += trackTable[k] * dbH->dim;
-  }
-
   char nextKey[MAXSTR];
 
   // Track loop 
@@ -403,6 +601,18 @@
       }
     }
 
+    // skip identity on query_from_key
+    if( query_from_key && (track == queryIndex) ) {
+      if(queryIndex!=dbH->numFiles-1){
+	track++;
+	trackOffset = trackOffsetTable[track];
+	lseek(dbfid, dbH->dataOffset + trackOffset * sizeof(double), SEEK_SET);
+      }
+      else{
+	break;
+      }
+    }
+
     trackIndexOffset=trackOffset/dbH->dim; // numVectors offset
 
     read_data(track, &data_buffer, &data_buffer_size);
@@ -425,15 +635,18 @@
 	for(j = 0; j <= numVectors - wL; j += HOP_SIZE) {
 	  for(k = 0; k <= trackTable[track] - wL; k += HOP_SIZE) {
             double thisDist;
-            if(normalizedDistance) {
+            if(normalizedDistance) 
               thisDist = 2-(2/(qnPtr[j]*sNorm[trackIndexOffset+k]))*DD[j][k];
-            } else {
-              thisDist = DD[j][k];
-            }
+	    else 
+	      if(no_unit_norming)
+		thisDist = qnPtr[j]*qnPtr[j]+sNorm[trackIndexOffset+k]*sNorm[trackIndexOffset+k] - 2*DD[j][k];
+	      else
+		thisDist = DD[j][k];
+
 	    // Power test
 	    if ((!usingPower) || powers_acceptable(qpPtr[j], sPower[trackIndexOffset + k])) {
               // radius test
-              if((!radius) || thisDist < radius) {
+              if((!radius) || thisDist <= (radius+O2_DISTANCE_TOLERANCE)) {
                 reporter->add_point(track, usingQueryPoint ? queryPoint : j, k, thisDist);
               }
             }
@@ -452,8 +665,6 @@
            (tv1.tv_sec*1000 + tv1.tv_usec/1000))
 
   // Clean up
-  if(trackOffsetTable)
-    delete[] trackOffsetTable;
   if(query_data)
     delete[] query_data;
   if(qNorm)
@@ -493,3 +704,5 @@
   }
   VERB_LOG(2, "done.\n");
 }
+
+