Mercurial > hg > audiodb
changeset 248:5682c7d7444b
Added a new query type, nsequence: it reports the n nearest neighbours for each track in a sequence-averaging query.
Next up is the same for the radius search (still accessed via nsequence, but with -R set to a non-zero value).
author | mas01mc |
---|---|
date | Sun, 17 Feb 2008 14:56:02 +0000 |
parents | 0d99b008fd6b |
children | 1da9a9ed55a3 |
files | audioDB.cpp audioDB.h gengetopt.in query.cpp reporter.h |
diffstat | 5 files changed, 124 insertions(+), 1 deletions(-) [+] |
line wrap: on
line diff
--- a/audioDB.cpp Mon Dec 17 16:17:45 2007 +0000 +++ b/audioDB.cpp Sun Feb 17 14:56:02 2008 +0000 @@ -306,6 +306,8 @@ queryType=O2_POINT_QUERY; else if(strncmp(args_info.QUERY_arg, "sequence", MAXSTR)==0) queryType=O2_SEQUENCE_QUERY; + else if(strncmp(args_info.QUERY_arg, "nsequence", MAXSTR)==0) + queryType=O2_N_SEQUENCE_QUERY; else error("unsupported query type",args_info.QUERY_arg);
--- a/audioDB.h Mon Dec 17 16:17:45 2007 +0000 +++ b/audioDB.h Sun Feb 17 14:56:02 2008 +0000 @@ -54,6 +54,7 @@ #define O2_DEFAULT_POINTNN (10U) #define O2_DEFAULT_TRACKNN (10U) +//#define O2_DEFAULTDBSIZE (4000000000) // 4GB table size #define O2_DEFAULTDBSIZE (2000000000) // 2GB table size #define O2_MAXFILES (20000U) @@ -75,6 +76,8 @@ #define O2_POINT_QUERY (0x4U) #define O2_SEQUENCE_QUERY (0x8U) #define O2_TRACK_QUERY (0x10U) +#define O2_N_SEQUENCE_QUERY (0x20U) + // Error Codes #define O2_ERR_KEYNOTFOUND (0xFFFFFF00)
--- a/gengetopt.in Mon Dec 17 16:17:45 2007 +0000 +++ b/gengetopt.in Sun Feb 17 14:56:02 2008 +0000 @@ -33,7 +33,7 @@ section "Database Search" sectiondesc="Thse commands control the retrieval behaviour.\n" -option "QUERY" Q "content-based search on --database using --features as a query. Optionally restrict the search to those tracks identified in a --keyList." values="point","track","sequence" typestr="searchtype" dependon="database" dependon="features" optional +option "QUERY" Q "content-based search on --database using --features as a query. Optionally restrict the search to those tracks identified in a --keyList." values="point","track","sequence", "nsequence" typestr="searchtype" dependon="database" dependon="features" optional option "qpoint" p "ordinal position of query start point in --features file." int typestr="position" default="0" optional option "exhaustive" e "exhaustive search: iterate through all query vectors in search. Overrides --qpoint." flag off optional hidden option "pointnn" n "number of point nearest neighbours to use in retrieval." int typestr="numpoints" default="10" optional
--- a/query.cpp Mon Dec 17 16:17:45 2007 +0000 +++ b/query.cpp Sun Feb 17 14:56:02 2008 +0000 @@ -37,6 +37,13 @@ r = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles); } break; + case O2_N_SEQUENCE_QUERY : + if(radius == 0) { + r = new trackSequenceQueryNNReporter<std::less < NNresult > >(pointNN, trackNN, dbH->numFiles); + } else { + r = new trackSequenceQueryRadReporter(trackNN, dbH->numFiles); + } + break; default: error("unrecognized queryType in query()"); }
--- a/reporter.h Mon Dec 17 16:17:45 2007 +0000 +++ b/reporter.h Sun Feb 17 14:56:02 2008 +0000 @@ -285,3 +285,114 @@ // FIXME } } + + +template <class T> class trackSequenceQueryNNReporter : public Reporter { + public: + trackSequenceQueryNNReporter(unsigned int pointNN, unsigned int trackNN, unsigned int numFiles); + ~trackSequenceQueryNNReporter(); + void add_point(unsigned int trackID, unsigned int qpos, unsigned int spos, double dist); + void report(char *fileTable, adb__queryResponse *adbQueryResponse); + private: + unsigned int pointNN; + unsigned int trackNN; + unsigned int numFiles; + std::priority_queue< NNresult, std::vector< NNresult>, T > *queues; +}; + +template <class T> trackSequenceQueryNNReporter<T>::trackSequenceQueryNNReporter(unsigned int pointNN, unsigned int trackNN, unsigned int numFiles) + : pointNN(pointNN), trackNN(trackNN), numFiles(numFiles) { + queues = new std::priority_queue< NNresult, std::vector< NNresult>, T >[numFiles]; +} + +template <class T> trackSequenceQueryNNReporter<T>::~trackSequenceQueryNNReporter() { + delete [] queues; +} + +template <class T> void trackSequenceQueryNNReporter<T>::add_point(unsigned int trackID, unsigned int qpos, unsigned int spos, double dist) { + if (!isnan(dist)) { + NNresult r; + r.trackID = trackID; + r.qpos = qpos; + r.spos = spos; + r.dist = dist; + queues[trackID].push(r); + if(queues[trackID].size() > pointNN) { + queues[trackID].pop(); + } + } +} + +template <class T> void trackSequenceQueryNNReporter<T>::report(char *fileTable, adb__queryResponse *adbQueryResponse) { + std::priority_queue < NNresult, std::vector< NNresult>, T> result; + std::priority_queue< NNresult, std::vector< NNresult>, std::greater<NNresult> > *point_queues = new std::priority_queue< NNresult, std::vector< NNresult>, std::greater<NNresult> >[numFiles]; + + for (int i = numFiles-1; i >= 0; i--) { + unsigned int size = queues[i].size(); + if (size > 0) { + NNresult r; + double dist = 0; + NNresult oldr = queues[i].top(); 
+ for (unsigned int j = 0; j < size; j++) { + r = queues[i].top(); + dist += r.dist; + point_queues[i].push(r); + queues[i].pop(); + if (r.dist == oldr.dist) { + r.qpos = oldr.qpos; + r.spos = oldr.spos; + } else { + oldr = r; + } + } + dist /= size; + r.dist = dist; // trackID, qpos and spos are magically right already. + result.push(r); + if (result.size() > trackNN) { + result.pop(); + } + } + } + + NNresult r; + std::vector<NNresult> v; + unsigned int size = result.size(); + for(unsigned int k = 0; k < size; k++) { + r = result.top(); + v.push_back(r); + result.pop(); + } + std::vector<NNresult>::reverse_iterator rit; + + if(adbQueryResponse==0) { + for(rit = v.rbegin(); rit < v.rend(); rit++) { + r = *rit; + std::cout << fileTable + r.trackID*O2_FILETABLESIZE << std::endl; + for(int k=0; k < (int)pointNN; k++){ + NNresult rk = point_queues[r.trackID].top(); + std::cout << rk.dist << " " << rk.qpos << " " << rk.spos << std::endl; + point_queues[r.trackID].pop(); + } + } + } else { + adbQueryResponse->result.__sizeRlist=size; + adbQueryResponse->result.__sizeDist=size; + adbQueryResponse->result.__sizeQpos=size; + adbQueryResponse->result.__sizeSpos=size; + adbQueryResponse->result.Rlist= new char*[size]; + adbQueryResponse->result.Dist = new double[size]; + adbQueryResponse->result.Qpos = new unsigned int[size]; + adbQueryResponse->result.Spos = new unsigned int[size]; + unsigned int k = 0; + for(rit = v.rbegin(); rit < v.rend(); rit++, k++) { + r = *rit; + adbQueryResponse->result.Rlist[k] = new char[O2_MAXFILESTR]; + adbQueryResponse->result.Dist[k] = r.dist; + adbQueryResponse->result.Qpos[k] = r.qpos; + adbQueryResponse->result.Spos[k] = r.spos; + snprintf(adbQueryResponse->result.Rlist[k], O2_MAXFILESTR, "%s", fileTable+r.trackID*O2_FILETABLESIZE); + } + } + // clean up + delete[] point_queues; +}