diff audioDB.cpp @ 17:6d899df0cfe4

added Euclidean distance for sequences with -R (--radus) (via dot product of unit norm vectors), re-worked L2-norm behaviour, fixed a load of bugs there, fixed shingle norming. Cosine dist sequence match not working now because of L2 norm behaviour
author mas01mc
date Fri, 10 Aug 2007 04:52:33 +0000
parents c633f3819a49
children 999c9c216565
line wrap: on
line diff
--- a/audioDB.cpp	Thu Aug 02 10:47:20 2007 +0000
+++ b/audioDB.cpp	Fri Aug 10 04:52:33 2007 +0000
@@ -135,7 +135,8 @@
   isClient(0),
   isServer(0),
   port(0),
-  timesTol(0.1){
+  timesTol(0.1),
+  radius(0){
   
   if(processArgs(argc, argv)<0){
     printf("No command found.\n");
@@ -226,6 +227,17 @@
     }
   }
 
+  if(args_info.radius_given){
+    radius=args_info.radius_arg;
+    if(radius<=0 || radius>1000000000){
+      cerr << "Warning: radius out of range" << endl;
+      exit(1);
+    }
+    else 
+      if(verbosity>3)
+	cerr << "Setting radius to " << radius << endl;
+  }
+  
   if(args_info.SERVER_given){
     command=COM_SERVER;
     port=args_info.SERVER_arg;
@@ -356,7 +368,7 @@
    
 
    segNN=args_info.resultlength_arg;
-   if(segNN<1 || segNN >1000)
+   if(segNN<1 || segNN >10000)
      error("resultlength out of range: 1 <= resultlength <= 1000");
 
 	         
@@ -905,8 +917,10 @@
   if(!dbH)
     initTables(dbName,0);
   
-  for(unsigned k=0; k<dbH->numFiles; k++)
+  for(unsigned k=0, j=0; k<dbH->numFiles; k++){
     cout << fileTable+k*O2_FILETABLESIZE << " " << segTable[k] << endl;
+    j+=segTable[k];
+  }
 
   status(dbName);
 }
@@ -930,7 +944,10 @@
     pointQuery(dbName, inFile, adbQueryResult);
     break;
   case O2_FLAG_SEQUENCE_QUERY:
-    segSequenceQuery(dbName, inFile, adbQueryResult);
+    if(radius==0)
+      segSequenceQuery(dbName, inFile, adbQueryResult);
+    else
+      segSequenceQueryEuc(dbName, inFile, adbQueryResult);
     break;
   case O2_FLAG_SEG_QUERY:
     segPointQuery(dbName, inFile, adbQueryResult);
@@ -1466,7 +1483,7 @@
     cerr << "processedSegs: " << processedSegs << endl;
   SILENCE_THRESH/=processedSegs;
   USE_THRESH=1; // Turn thresholding on
-  DIFF_THRESH=SILENCE_THRESH/=2; // 50% of the mean shingle power
+  DIFF_THRESH=SILENCE_THRESH/2; // 50% of the mean shingle power
   SILENCE_THRESH/=10; // 10% of the mean shingle power is SILENCE
   
   w=sequenceLength-1;
@@ -1719,12 +1736,17 @@
 	// Calculate the mean of the N-Best matches
 	thisDist=0.0;
 	for(m=0; m<pointNN; m++)
-	  thisDist+=distances[m];
+	  if(distances[m]<0.000001){ // Stop rubbish songs getting good scores
+	    thisDist=0.0;
+	    break;
+	  }
+	  else
+	    thisDist+=distances[m];
 	thisDist/=pointNN;
 	
 	// Let's see the distances then...
 	if(verbosity>3)
-	  cerr << "d[" << fileTable+seg*O2_FILETABLESIZE << "]=" << thisDist << endl;
+	  cerr << fileTable+seg*O2_FILETABLESIZE << " " << thisDist << endl;
 
 	// All the seg stuff goes here
 	n=segNN;
@@ -1823,6 +1845,475 @@
 
 }
 
+// NBest matched filter distance between query and target segs
+// efficient implementation
+// outputs average of N minimum matched filter distances
+void audioDB::segSequenceQueryEuc(const char* dbName, const char* inFile, adb__queryResult *adbQueryResult){
+  
+  initTables(dbName, inFile);
+  
+  // For each input vector, find the closest pointNN matching output vectors and report
+  // we use stdout in this stub version
+  unsigned numVectors = (statbuf.st_size-sizeof(int))/(sizeof(double)*dbH->dim);
+  unsigned numSegs = dbH->numFiles;
+  
+  double* query = (double*)(indata+sizeof(int));
+  double* data = dataBuf;
+  double* queryCopy = 0;
+
+  double qMeanL2;
+  double* sMeanL2;
+
+  unsigned USE_THRESH=0;
+  double SILENCE_THRESH=0;
+  double DIFF_THRESH=0;
+
+  if(!(dbH->flags & O2_FLAG_L2NORM) )
+    error("Database must be L2 normed for sequence query","use -l2norm");
+  
+  if(verbosity>1)
+    cerr << "performing norms ... "; cerr.flush();
+  unsigned dbVectors = dbH->length/(sizeof(double)*dbH->dim);
+  // Make a copy of the query
+  queryCopy = new double[numVectors*dbH->dim];
+  memcpy(queryCopy, query, numVectors*dbH->dim*sizeof(double));
+  qNorm = new double[numVectors];
+  sNorm = new double[dbVectors];
+  sMeanL2=new double[dbH->numFiles];
+  assert(qNorm&&sNorm&&queryCopy&&sMeanL2&&sequenceLength);    
+  unitNorm(queryCopy, dbH->dim, numVectors, qNorm);
+  query = queryCopy;
+  // Make norm measurements relative to sequenceLength
+  unsigned w = sequenceLength-1;
+  unsigned i,j;
+  double* ps;
+  double tmp1,tmp2;
+  // Copy the L2 norm values to core to avoid disk random access later on
+  memcpy(sNorm, l2normTable, dbVectors*sizeof(double));
+  double* snPtr = sNorm;
+  for(i=0; i<dbH->numFiles; i++){
+    if(segTable[i]>=sequenceLength){
+      tmp1=*snPtr;
+      j=1;
+      w=sequenceLength-1;
+      while(w--)
+	*snPtr+=snPtr[j++];
+      ps = snPtr+1;
+      w=segTable[i]-sequenceLength; // +1 - 1
+      while(w--){
+	tmp2=*ps;
+	*ps=*(ps-1)-tmp1+*(ps+sequenceLength-1);
+	tmp1=tmp2;
+	ps++;
+      }
+      ps = snPtr;
+      w=segTable[i]-sequenceLength+1;
+      while(w--){
+	*ps=sqrt(*ps);
+	ps++;
+      }
+    }
+    snPtr+=segTable[i];
+  }
+  
+  double* pn = sMeanL2;
+  w=dbH->numFiles;
+  while(w--)
+    *pn++=0.0;
+  ps=sNorm;
+  unsigned processedSegs=0;
+  for(i=0; i<dbH->numFiles; i++){
+    if(segTable[i]>sequenceLength-1){
+      w = segTable[i]-sequenceLength;
+      pn = sMeanL2+i;
+      *pn=0;
+      while(w--)
+	if(*ps>0)
+	  *pn+=*ps++;
+      *pn/=segTable[i]-sequenceLength;
+      SILENCE_THRESH+=*pn;
+      processedSegs++;
+    }
+    ps = sNorm + segTable[i];
+  }
+  if(verbosity>1)
+    cerr << "processedSegs: " << processedSegs << endl;
+
+    
+  SILENCE_THRESH/=processedSegs;
+  USE_THRESH=1; // Turn thresholding on
+  DIFF_THRESH=SILENCE_THRESH; // 50% of the mean shingle power
+  SILENCE_THRESH/=5; // 20% of the mean shingle power is SILENCE
+  if(verbosity>4)
+    cerr << "silence thresh: " << SILENCE_THRESH;
+  w=sequenceLength-1;
+  i=1;
+  tmp1=*qNorm;
+  while(w--)
+    *qNorm+=qNorm[i++];
+  ps = qNorm+1;
+  w=numVectors-sequenceLength; // +1 -1
+  while(w--){
+    tmp2=*ps;
+    *ps=*(ps-1)-tmp1+*(ps+sequenceLength-1);
+    tmp1=tmp2;
+    ps++;
+  }
+  ps = qNorm;
+  qMeanL2 = 0;
+  w=numVectors-sequenceLength+1;
+  while(w--){
+    *ps=sqrt(*ps);
+    qMeanL2+=*ps++;
+  }
+  qMeanL2 /= numVectors-sequenceLength+1;
+
+  if(verbosity>1)
+    cerr << "done." << endl;    
+  
+  
+  if(verbosity>1)
+    cerr << "matching segs..." << endl;
+  
+  assert(pointNN>0 && pointNN<=O2_MAXNN);
+  assert(segNN>0 && segNN<=O2_MAXNN);
+  
+  // Make temporary dynamic memory for results
+  double segDistances[segNN];
+  unsigned segIDs[segNN];
+  unsigned segQIndexes[segNN];
+  unsigned segSIndexes[segNN];
+  
+  double distances[pointNN];
+  unsigned qIndexes[pointNN];
+  unsigned sIndexes[pointNN];
+  
+
+  unsigned k,l,m,n,seg,segOffset=0, HOP_SIZE=sequenceHop, wL=sequenceLength;
+  double thisDist;
+  double oneOverWL=1.0/wL;
+  
+  for(k=0; k<pointNN; k++){
+    distances[k]=0.0;
+    qIndexes[k]=~0;
+    sIndexes[k]=~0;    
+  }
+  
+  for(k=0; k<segNN; k++){
+    segDistances[k]=0.0;
+    segQIndexes[k]=~0;
+    segSIndexes[k]=~0;
+    segIDs[k]=~0;
+  }
+
+  // Timestamp and durations processing
+  double meanQdur = 0;
+  double* timesdata = 0;
+  double* meanDBdur = 0;
+  
+  if(usingTimes && !(dbH->flags & O2_FLAG_TIMES)){
+    cerr << "warning: ignoring query timestamps for non-timestamped database" << endl;
+    usingTimes=0;
+  }
+  
+  else if(!usingTimes && (dbH->flags & O2_FLAG_TIMES))
+    cerr << "warning: no timestamps given for query. Ignoring database timestamps." << endl;
+  
+  else if(usingTimes && (dbH->flags & O2_FLAG_TIMES)){
+    timesdata = new double[numVectors];
+    assert(timesdata);
+    insertTimeStamps(numVectors, timesFile, timesdata);
+    // Calculate durations of points
+    for(k=0; k<numVectors-1; k++){
+      timesdata[k]=timesdata[k+1]-timesdata[k];
+      meanQdur+=timesdata[k];
+    }
+    meanQdur/=k;
+    if(verbosity>1)
+      cerr << "mean query file duration: " << meanQdur << endl;
+    meanDBdur = new double[dbH->numFiles];
+    assert(meanDBdur);
+    for(k=0; k<dbH->numFiles; k++){
+      meanDBdur[k]=0.0;
+      for(j=0; j<segTable[k]-1 ; j++)
+	meanDBdur[k]+=timesTable[j+1]-timesTable[j];
+      meanDBdur[k]/=j;
+    }
+  }
+
+  if(usingQueryPoint)
+    if(queryPoint>numVectors || queryPoint>numVectors-wL+1)
+      error("queryPoint > numVectors-wL+1 in query");
+    else{
+      if(verbosity>1)
+	cerr << "query point: " << queryPoint << endl; cerr.flush();
+      query=query+queryPoint*dbH->dim;
+      qNorm=qNorm+queryPoint;
+      numVectors=wL;
+    }
+  
+  double ** D = 0;    // Differences query and target 
+  double ** DD = 0;   // Matched filter distance
+
+  D = new double*[numVectors];
+  assert(D);
+  DD = new double*[numVectors];
+  assert(DD);
+
+  gettimeofday(&tv1, NULL); 
+  processedSegs=0;
+  unsigned successfulSegs=0;
+
+  double* qp;
+  double* sp;
+  double* dp;
+  double diffL2;
+
+  // build segment offset table
+  unsigned *segOffsetTable = new unsigned[dbH->numFiles];
+  unsigned cumSeg=0;
+  unsigned segIndexOffset;
+  for(k=0; k<dbH->numFiles;k++){
+    segOffsetTable[k]=cumSeg;
+    cumSeg+=segTable[k]*dbH->dim;
+  }
+
+  char nextKey [MAXSTR];
+
+  // chi^2 statistics
+  double sampleCount = 0;
+  double sampleSum = 0;
+  double logSampleSum = 0;
+  double minSample = 1e9;
+  double maxSample = 0;
+
+  // Track loop 
+  for(processedSegs=0, seg=0 ; processedSegs < dbH->numFiles ; seg++, processedSegs++){
+
+    // get segID from file if using a control file
+    if(segFile){
+      if(!segFile->eof()){
+	segFile->getline(nextKey,MAXSTR);
+	seg=getKeyPos(nextKey);
+      }
+      else
+	break;
+    }
+
+    segOffset=segOffsetTable[seg];     // numDoubles offset
+    segIndexOffset=segOffset/dbH->dim; // numVectors offset
+
+    if(sequenceLength<segTable[seg]){  // test for short sequences
+      
+      if(verbosity>7)
+	cerr << seg << "." << segIndexOffset << "." << segTable[seg] << " | ";cerr.flush();
+		
+      // Sum products matrix
+      for(j=0; j<numVectors;j++){
+	D[j]=new double[segTable[seg]]; 
+	assert(D[j]);
+
+      }
+
+      // Matched filter matrix
+      for(j=0; j<numVectors;j++){
+	DD[j]=new double[segTable[seg]];
+	assert(DD[j]);
+      }
+
+      double tmp;
+      // Dot product
+      for(j=0; j<numVectors; j++)
+	for(k=0; k<segTable[seg]; k++){
+	  qp=query+j*dbH->dim;
+	  sp=dataBuf+segOffset+k*dbH->dim;
+	  DD[j][k]=0.0; // Initialize matched filter array
+	  dp=&D[j][k];  // point to correlation cell j,k
+	  *dp=0.0;      // initialize correlation cell
+	  l=dbH->dim;         // size of vectors
+	  while(l--)
+	    *dp+=*qp++**sp++;
+	}
+  
+      // Matched Filter
+      // HOP SIZE == 1
+      double* spd;
+      if(HOP_SIZE==1){ // HOP_SIZE = shingleHop
+	for(w=0; w<wL; w++)
+	  for(j=0; j<numVectors-w; j++){ 
+	    sp=DD[j];
+	    spd=D[j+w]+w;
+	    k=segTable[seg]-w;
+	    while(k--)
+	      *sp+++=*spd++;
+	  }
+      }
+
+      else{ // HOP_SIZE != 1
+	for(w=0; w<wL; w++)
+	  for(j=0; j<numVectors-w; j+=HOP_SIZE){
+	    sp=DD[j];
+	    spd=D[j+w]+w;
+	    for(k=0; k<segTable[seg]-w; k+=HOP_SIZE){
+	      *sp+=*spd;
+	      sp+=HOP_SIZE;
+	      spd+=HOP_SIZE;
+	    }
+	  }
+      }
+      
+      if(verbosity>3 && usingTimes){
+	cerr << "meanQdur=" << meanQdur << " meanDBdur=" << meanDBdur[seg] << endl;
+	cerr.flush();
+      }
+
+      if(!usingTimes || 
+	 (usingTimes 
+	  && fabs(meanDBdur[seg]-meanQdur)<meanQdur*timesTol)){
+
+	if(verbosity>3 && usingTimes){
+	  cerr << "within duration tolerance." << endl;
+	  cerr.flush();
+	}
+
+	// Search for minimum distance by shingles (concatenated vectors)
+	for(j=0;j<numVectors-wL;j+=HOP_SIZE)
+	  for(k=0;k<segTable[seg]-wL;k+=HOP_SIZE){
+	    thisDist=2-(2/(qNorm[j]*sNorm[segIndexOffset+k]))*DD[j][k];
+	    if(verbosity>10)
+	      cerr << thisDist << " " << qNorm[j] << " " << sNorm[segIndexOffset+k] << endl;
+	    // Gather chi^2 statistics
+	    if(thisDist<minSample)
+	      minSample=thisDist;
+	    else if(thisDist>maxSample)
+	      maxSample=thisDist;
+	    if(thisDist>1e-9){
+	      sampleCount++;
+	      sampleSum+=thisDist;
+	      logSampleSum+=log(thisDist);
+	    }
+
+	    diffL2 = fabs(qNorm[j] - sNorm[segIndexOffset+k]);
+	    // Power test
+	    if(!USE_THRESH || 
+	       // Threshold on mean L2 of Q and S sequences
+	       (USE_THRESH && qNorm[j]>SILENCE_THRESH && sNorm[segIndexOffset+k]>SILENCE_THRESH && 
+		// Are both query and target windows above mean energy?
+		(qNorm[j]>qMeanL2*.25 && sNorm[segIndexOffset+k]>sMeanL2[seg]*.25))) // &&  diffL2 < DIFF_THRESH )))
+	      thisDist=thisDist; // Computed above
+	    else
+	      thisDist=1000000.0;
+	    if(thisDist>=0 && thisDist<=radius){
+	      distances[0]++; // increment count
+	      break; // only need one seg point per query point
+	    }
+	  }
+	// How many points were below threshold ?
+	thisDist=distances[0];
+	
+	// Let's see the distances then...
+	if(verbosity>3)
+	  cerr << fileTable+seg*O2_FILETABLESIZE << " " << thisDist << endl;
+
+	// All the seg stuff goes here
+	n=segNN;
+	while(n--){
+	  if(thisDist>segDistances[n]){
+	    if((n==0 || thisDist<=segDistances[n-1])){
+	      // Copy all values above up the queue
+	      for( l=segNN-1 ; l > n ; l--){
+		segDistances[l]=segDistances[l-1];
+		segQIndexes[l]=segQIndexes[l-1];
+		segSIndexes[l]=segSIndexes[l-1];
+		segIDs[l]=segIDs[l-1];
+	      }
+	      segDistances[n]=thisDist;
+	      segQIndexes[n]=qIndexes[0];
+	      segSIndexes[n]=sIndexes[0];
+	      successfulSegs++;
+	      segIDs[n]=seg;
+	      break;
+	    }
+	  }
+	  else
+	    break;
+	}
+      } // Duration match
+            
+      // Clean up current seg
+      if(D!=NULL){
+	for(j=0; j<numVectors; j++)
+	  delete[] D[j];
+      }
+
+      if(DD!=NULL){
+	for(j=0; j<numVectors; j++)
+	  delete[] DD[j];
+      }
+    }
+    // per-seg reset array values
+    for(unsigned k=0; k<pointNN; k++){
+      distances[k]=0.0;
+      qIndexes[k]=~0;
+      sIndexes[k]=~0;    
+    }
+  }
+
+  gettimeofday(&tv2,NULL);
+  if(verbosity>1){
+    cerr << endl << "processed segs :" << processedSegs << " matched segments: " << successfulSegs << " elapsed time:" 
+	 << ( tv2.tv_sec*1000 + tv2.tv_usec/1000 ) - ( tv1.tv_sec*1000+tv1.tv_usec/1000 ) << " msec" << endl;
+    cerr << "sampleCount: " << sampleCount << " sampleSum: " << sampleSum << " logSampleSum: " << logSampleSum 
+	 << " minSample: " << minSample << " maxSample: " << maxSample << endl;
+  }
+  
+  if(adbQueryResult==0){
+    if(verbosity>1)
+      cerr<<endl;
+    // Output answer
+    // Loop over nearest neighbours
+    for(k=0; k < min(segNN,successfulSegs); k++)
+      cout << fileTable+segIDs[k]*O2_FILETABLESIZE << " " << segDistances[k] << endl;
+  }
+  else{ // Process Web Services Query
+    int listLen = min(segNN, processedSegs);
+    adbQueryResult->__sizeRlist=listLen;
+    adbQueryResult->__sizeDist=listLen;
+    adbQueryResult->__sizeQpos=listLen;
+    adbQueryResult->__sizeSpos=listLen;
+    adbQueryResult->Rlist= new char*[listLen];
+    adbQueryResult->Dist = new double[listLen];
+    adbQueryResult->Qpos = new int[listLen];
+    adbQueryResult->Spos = new int[listLen];
+    for(k=0; k<adbQueryResult->__sizeRlist; k++){
+      adbQueryResult->Rlist[k]=new char[O2_MAXFILESTR];
+      adbQueryResult->Dist[k]=segDistances[k];
+      adbQueryResult->Qpos[k]=segQIndexes[k];
+      adbQueryResult->Spos[k]=segSIndexes[k];
+      sprintf(adbQueryResult->Rlist[k], "%s", fileTable+segIDs[k]*O2_FILETABLESIZE);
+    }
+  }
+
+
+  // Clean up
+  if(segOffsetTable)
+    delete[] segOffsetTable;
+  if(queryCopy)
+    delete[] queryCopy;
+  //if(qNorm)
+  //delete qNorm;
+  if(D)
+    delete[] D;
+  if(DD)
+    delete[] DD;
+  if(timesdata)
+    delete[] timesdata;
+  if(meanDBdur)
+    delete[] meanDBdur;
+
+
+}
+
 void audioDB::normalize(double* X, int dim, int n){
   unsigned c = n*dim;
   double minval,maxval,v,*p;
@@ -1872,15 +2363,17 @@
       L2+=*p**p;
       p++;
     }
-    L2=sqrt(L2);
+    /*    L2=sqrt(L2);*/
     if(qNorm)
       *qNorm++=L2;
+    /*
     oneOverL2 = 1.0/L2;
     d=dim;
     while(d--){
       *X*=oneOverL2;
       X++;
-    }
+    */
+    X+=dim;
   }
   if(verbosity>2)
     cerr << "done..." << endl;
@@ -1913,13 +2406,16 @@
       *l2ptr+=*p**p;
       p++;
     }
-    *l2ptr=sqrt(*l2ptr);
-    oneOverL2 = 1.0/(*l2ptr++);
-    d=dim;
-    while(d--){
+    l2ptr++;
+    /*
+      oneOverL2 = 1.0/(*l2ptr++);
+      d=dim;
+      while(d--){
       *X*=oneOverL2;
       X++;
-    }
+      }
+    */
+    X+=dim;
   }
   unsigned offset;
   if(append)
@@ -1928,7 +2424,7 @@
     offset=0;
   memcpy(l2normTable+offset, l2buf, n*sizeof(double));
   if(l2buf)
-    delete l2buf;
+    delete[] l2buf;
   if(verbosity>2)
     cerr << "done..." << endl;
 }
@@ -2038,5 +2534,3 @@
 int main(const unsigned argc, char* const argv[]){
   audioDB(argc, argv);
 }
-
-