diff query.cpp @ 768:b9dbe4611dde

Adding Kullback-Leibler divergence as alternate distance function
author mas01mc
date Sat, 15 Oct 2011 17:28:07 +0000
parents ddf08008d45b
children
line wrap: on
line diff
--- a/query.cpp	Thu Jun 02 16:31:44 2011 +0000
+++ b/query.cpp	Sat Oct 15 17:28:07 2011 +0000
@@ -55,6 +55,7 @@
     break;
   case ADB_DISTANCE_EUCLIDEAN_NORMED:
   case ADB_DISTANCE_EUCLIDEAN:
+  case ADB_DISTANCE_KULLBACK_LEIBLER_DIVERGENCE:
     switch(qspec->params.accumulation) {
     case ADB_ACCUMULATION_DB:
       qstate.accumulator = new DBAccumulator<adb_result_dist_lt>(qspec->params.npoints);
@@ -108,6 +109,10 @@
 static void audiodb_initialize_arrays(adb_t *adb, const adb_query_spec_t *spec, int track, unsigned int numVectors, double *query, double *data_buffer, double **D, double **DD) {
   unsigned int j, k, l, w;
   double *dp, *qp, *sp;
+  double a,b, tmp1;
+#ifdef SYMMETRIC_KL
+  double tmp2;
+#endif
 
   const unsigned wL = spec->qid.sequence_length;
 
@@ -127,8 +132,27 @@
       dp = &D[j][k];  // point to correlation cell j,k
       *dp = 0.0;      // initialize correlation cell
       l = adb->header->dim;         // size of vectors
-      while(l--)
-        *dp += *qp++ * *sp++;
+      if (spec->params.distance!=ADB_DISTANCE_KULLBACK_LEIBLER_DIVERGENCE){
+	while(l--)
+	  *dp += *qp++ * *sp++;
+      }
+      else{ // KL
+	while(l--){
+	  a = *qp++;
+	  b = *sp++;    
+	  tmp1 = a * log( a / b );
+	  if(isnan(tmp1))
+	    tmp1=0.0;
+#ifdef SYMMETRIC_KL
+	  tmp2 = b * log( b / a );
+	  if(isnan(tmp2))
+	    tmp2=0.0;
+	  *dp += ( tmp1 + tmp2 ) / 2.0;
+#else
+	  *dp += tmp1;
+#endif
+	}
+      }
     }
   
   double* spd;
@@ -461,7 +485,11 @@
     if( ( (!power_refine) || audiodb_powers_acceptable(&spec->refine, qpointers->power[qPos], dbpointers.power[sPos])) &&
 	( qPos<qpointers->nvectors-sequence_length+1 && sPos<(*adb->track_lengths)[pp.trackID]-sequence_length+1 ) ){
       // Compute distance
-      dist = audiodb_dot_product(query + qPos*adb->header->dim, dbdata + sPos*adb->header->dim, adb->header->dim*sequence_length);
+      dist = 1.0e9;
+      if (spec->params.distance==ADB_DISTANCE_EUCLIDEAN_NORMED || spec->params.distance==ADB_DISTANCE_EUCLIDEAN)
+	dist = audiodb_dot_product(query + qPos*adb->header->dim, dbdata + sPos*adb->header->dim, adb->header->dim*sequence_length);
+      else if(spec->params.distance==ADB_DISTANCE_KULLBACK_LEIBLER_DIVERGENCE)
+	dist = audiodb_kullback_leibler(query + qPos*adb->header->dim, dbdata + sPos*adb->header->dim, adb->header->dim*sequence_length);
       double qn = qpointers->l2norm[qPos];
       double sn = dbpointers.l2norm[sPos];
       switch(spec->params.distance) {