changeset 248:5682c7d7444b

Added a new query type, nsequence, which reports the n nearest neighbours for each track in a sequence-averaging query. Next up is the same for the radius search (still accessed via nsequence, but with -R set to a non-zero value).
author mas01mc
date Sun, 17 Feb 2008 14:56:02 +0000
parents 0d99b008fd6b
children 1da9a9ed55a3
files audioDB.cpp audioDB.h gengetopt.in query.cpp reporter.h
diffstat 5 files changed, 124 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/audioDB.cpp	Mon Dec 17 16:17:45 2007 +0000
+++ b/audioDB.cpp	Sun Feb 17 14:56:02 2008 +0000
@@ -306,6 +306,8 @@
       queryType=O2_POINT_QUERY;
     else if(strncmp(args_info.QUERY_arg, "sequence", MAXSTR)==0)
       queryType=O2_SEQUENCE_QUERY;
+    else if(strncmp(args_info.QUERY_arg, "nsequence", MAXSTR)==0)
+      queryType=O2_N_SEQUENCE_QUERY;
     else
       error("unsupported query type",args_info.QUERY_arg);
     
--- a/audioDB.h	Mon Dec 17 16:17:45 2007 +0000
+++ b/audioDB.h	Sun Feb 17 14:56:02 2008 +0000
@@ -54,6 +54,7 @@
 #define O2_DEFAULT_POINTNN (10U)
 #define O2_DEFAULT_TRACKNN  (10U)
 
+//#define O2_DEFAULTDBSIZE (4000000000) // 4GB table size
 #define O2_DEFAULTDBSIZE (2000000000) // 2GB table size
 
 #define O2_MAXFILES (20000U)
@@ -75,6 +76,8 @@
 #define O2_POINT_QUERY (0x4U)
 #define O2_SEQUENCE_QUERY (0x8U)
 #define O2_TRACK_QUERY (0x10U)
+// Query type chosen when the QUERY argument is "nsequence".
+#define O2_N_SEQUENCE_QUERY (0x20U)
+
 
 // Error Codes
 #define O2_ERR_KEYNOTFOUND (0xFFFFFF00)
--- a/gengetopt.in	Mon Dec 17 16:17:45 2007 +0000
+++ b/gengetopt.in	Sun Feb 17 14:56:02 2008 +0000
@@ -33,7 +33,7 @@
 
 section "Database Search" sectiondesc="Thse commands control the retrieval behaviour.\n"
 
-option "QUERY" Q "content-based search on --database using --features as a query. Optionally restrict the search to those tracks identified in a --keyList." values="point","track","sequence" typestr="searchtype" dependon="database" dependon="features" optional
+option "QUERY" Q "content-based search on --database using --features as a query. Optionally restrict the search to those tracks identified in a --keyList." values="point","track","sequence", "nsequence" typestr="searchtype" dependon="database" dependon="features" optional
 option "qpoint" p "ordinal position of query start point in --features file." int typestr="position" default="0" optional
 option "exhaustive" e "exhaustive search: iterate through all query vectors in search. Overrides --qpoint." flag off optional hidden
 option "pointnn" n "number of point nearest neighbours to use in retrieval." int typestr="numpoints" default="10" optional
--- a/query.cpp	Mon Dec 17 16:17:45 2007 +0000
+++ b/query.cpp	Sun Feb 17 14:56:02 2008 +0000
@@ -37,6 +37,13 @@
       r = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles);
     }
     break;
+  case O2_N_SEQUENCE_QUERY :
+    if(radius == 0) {
+      r = new trackSequenceQueryNNReporter<std::less < NNresult > >(pointNN, trackNN, dbH->numFiles);
+    } else {
+      r = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles);
+    }
+    break;
   default:
     error("unrecognized queryType in query()");
   }  
--- a/reporter.h	Mon Dec 17 16:17:45 2007 +0000
+++ b/reporter.h	Sun Feb 17 14:56:02 2008 +0000
@@ -285,3 +285,114 @@
     // FIXME
   }
 }
+
+
+// Reporter for sequence queries that keeps, for every track, a bounded
+// queue of candidate points (at most pointNN per track) and reports up to
+// trackNN tracks together with their retained points.
+//
+// T is the comparator used by the per-track and result priority queues; it
+// decides which entry pop() discards once a queue exceeds its limit.
+//
+// The object owns a heap-allocated array of numFiles queues, so it must not
+// be copied: a copy would share the array and both destructors would
+// delete[] it.
+template <class T> class trackSequenceQueryNNReporter : public Reporter {
+ public:
+  trackSequenceQueryNNReporter(unsigned int pointNN, unsigned int trackNN, unsigned int numFiles);
+  ~trackSequenceQueryNNReporter();
+  void add_point(unsigned int trackID, unsigned int qpos, unsigned int spos, double dist);
+  void report(char *fileTable, adb__queryResponse *adbQueryResponse);
+ private:
+  // Copying is disabled (declared, never defined) because `queues` is an
+  // owning raw pointer released in the destructor.
+  trackSequenceQueryNNReporter(const trackSequenceQueryNNReporter &);
+  trackSequenceQueryNNReporter &operator=(const trackSequenceQueryNNReporter &);
+
+  unsigned int pointNN;   // maximum number of points retained per track
+  unsigned int trackNN;   // maximum number of tracks reported
+  unsigned int numFiles;  // length of the `queues` array
+  std::priority_queue< NNresult, std::vector< NNresult>, T > *queues;  // one queue per track, indexed by trackID
+};
+
+// Stores the point/track limits and allocates one empty priority queue per
+// track (numFiles queues in total). The array is released by the destructor.
+template <class T> trackSequenceQueryNNReporter<T>::trackSequenceQueryNNReporter(unsigned int pointNN, unsigned int trackNN, unsigned int numFiles)
+  : pointNN(pointNN),
+    trackNN(trackNN),
+    numFiles(numFiles),
+    queues(new std::priority_queue< NNresult, std::vector< NNresult>, T >[numFiles]) {
+}
+
+// Frees the array of per-track queues that the constructor allocated with new[].
+template <class T> trackSequenceQueryNNReporter<T>::~trackSequenceQueryNNReporter() {
+  delete [] queues;
+}
+
+// Records one candidate match for a track.
+//
+// trackID selects the per-track queue (used directly as an array index, so it
+// must be below the numFiles given to the constructor); qpos and spos are the
+// query and source positions stored with the match; dist is its distance.
+// A NaN distance is ignored. After insertion the queue is trimmed back to at
+// most pointNN entries by discarding the element the comparator T ranks on top.
+template <class T> void trackSequenceQueryNNReporter<T>::add_point(unsigned int trackID, unsigned int qpos, unsigned int spos, double dist) {
+  if (isnan(dist)) {
+    return;
+  }
+
+  NNresult candidate;
+  candidate.trackID = trackID;
+  candidate.qpos = qpos;
+  candidate.spos = spos;
+  candidate.dist = dist;
+
+  std::priority_queue< NNresult, std::vector< NNresult>, T > &trackQueue = queues[trackID];
+  trackQueue.push(candidate);
+  if (trackQueue.size() > pointNN) {
+    trackQueue.pop();
+  }
+}
+
+// Reports the tracks collected through add_point(), ranked by the mean
+// distance of the points retained for each track.
+//
+// For every track that holds at least one point, the retained distances are
+// averaged and the track is entered into a result queue that is trimmed to at
+// most trackNN entries. The per-track queues are drained in the process, so
+// the reporter holds no points once this function returns.
+//
+// fileTable: base of the track-name table; the name of a track is read at
+//   fileTable + trackID*O2_FILETABLESIZE.
+// adbQueryResponse: when 0, results are written to std::cout as the track
+//   name followed by one "dist qpos spos" line per retained point. Otherwise
+//   the Rlist/Dist/Qpos/Spos arrays of the response are allocated with new[]
+//   and filled with one entry per reported track; nothing in this function
+//   frees them (presumably the caller releases them -- confirm).
+template <class T> void trackSequenceQueryNNReporter<T>::report(char *fileTable, adb__queryResponse *adbQueryResponse) {
+  typedef std::priority_queue< NNresult, std::vector< NNresult>, std::greater<NNresult> > PointQueue;
+
+  std::priority_queue < NNresult, std::vector< NNresult>, T> result;
+  // Per-track copies of the individual points, kept for the text report.
+  // A vector releases them automatically on every exit path.
+  std::vector< PointQueue > point_queues(numFiles);
+
+  for (int i = numFiles-1; i >= 0; i--) {
+    unsigned int size = queues[i].size();
+    if (size > 0) {
+      NNresult r;
+      double dist = 0;
+      NNresult oldr = queues[i].top();
+      for (unsigned int j = 0; j < size; j++) {
+        r = queues[i].top();
+        dist += r.dist;
+        // Store the point before its positions are possibly overwritten below.
+        point_queues[i].push(r);
+        queues[i].pop();
+        // Within a run of equal distances, keep the positions of the first
+        // point popped in that run.
+        if (r.dist == oldr.dist) {
+          r.qpos = oldr.qpos;
+          r.spos = oldr.spos;
+        } else {
+          oldr = r;
+        }
+      }
+      dist /= size;
+      r.dist = dist; // trackID, qpos and spos are carried over from the last point popped.
+      result.push(r);
+      if (result.size() > trackNN) {
+        result.pop();
+      }
+    }
+  }
+
+  // Drain the result queue; the entries are emitted below in reverse pop order.
+  NNresult r;
+  std::vector<NNresult> v;
+  unsigned int size = result.size();
+  for(unsigned int k = 0; k < size; k++) {
+    r = result.top();
+    v.push_back(r);
+    result.pop();
+  }
+  std::vector<NNresult>::reverse_iterator rit;
+
+  if(adbQueryResponse==0) {
+    for(rit = v.rbegin(); rit < v.rend(); rit++) {
+      r = *rit;
+      std::cout << fileTable + r.trackID*O2_FILETABLESIZE << std::endl;
+      PointQueue &points = point_queues[r.trackID];
+      // A track may have kept fewer than pointNN points; calling top()/pop()
+      // on an empty queue is undefined, so stop when the queue runs out.
+      for(unsigned int k = 0; k < pointNN && !points.empty(); k++) {
+        NNresult rk = points.top();
+        std::cout << rk.dist << " " << rk.qpos << " " << rk.spos << std::endl;
+        points.pop();
+      }
+    }
+  } else {
+    adbQueryResponse->result.__sizeRlist=size;
+    adbQueryResponse->result.__sizeDist=size;
+    adbQueryResponse->result.__sizeQpos=size;
+    adbQueryResponse->result.__sizeSpos=size;
+    adbQueryResponse->result.Rlist= new char*[size];
+    adbQueryResponse->result.Dist = new double[size];
+    adbQueryResponse->result.Qpos = new unsigned int[size];
+    adbQueryResponse->result.Spos = new unsigned int[size];
+    unsigned int k = 0;
+    for(rit = v.rbegin(); rit < v.rend(); rit++, k++) {
+      r = *rit;
+      adbQueryResponse->result.Rlist[k] = new char[O2_MAXFILESTR];
+      adbQueryResponse->result.Dist[k] = r.dist;
+      adbQueryResponse->result.Qpos[k] = r.qpos;
+      adbQueryResponse->result.Spos[k] = r.spos;
+      snprintf(adbQueryResponse->result.Rlist[k], O2_MAXFILESTR, "%s", fileTable+r.trackID*O2_FILETABLESIZE);
+    }
+  }
+}