Mercurial > hg > audiodb
diff query.cpp @ 768:b9dbe4611dde
Adding Kullback-Leibler divergence as alternate distance function
author | mas01mc |
---|---|
date | Sat, 15 Oct 2011 17:28:07 +0000 |
parents | ddf08008d45b |
children |
line wrap: on
line diff
--- a/query.cpp Thu Jun 02 16:31:44 2011 +0000 +++ b/query.cpp Sat Oct 15 17:28:07 2011 +0000 @@ -55,6 +55,7 @@ break; case ADB_DISTANCE_EUCLIDEAN_NORMED: case ADB_DISTANCE_EUCLIDEAN: + case ADB_DISTANCE_KULLBACK_LEIBLER_DIVERGENCE: switch(qspec->params.accumulation) { case ADB_ACCUMULATION_DB: qstate.accumulator = new DBAccumulator<adb_result_dist_lt>(qspec->params.npoints); @@ -108,6 +109,10 @@ static void audiodb_initialize_arrays(adb_t *adb, const adb_query_spec_t *spec, int track, unsigned int numVectors, double *query, double *data_buffer, double **D, double **DD) { unsigned int j, k, l, w; double *dp, *qp, *sp; + double a,b, tmp1; +#ifdef SYMMETRIC_KL + double tmp2; +#endif const unsigned wL = spec->qid.sequence_length; @@ -127,8 +132,27 @@ dp = &D[j][k]; // point to correlation cell j,k *dp = 0.0; // initialize correlation cell l = adb->header->dim; // size of vectors - while(l--) - *dp += *qp++ * *sp++; + if (spec->params.distance!=ADB_DISTANCE_KULLBACK_LEIBLER_DIVERGENCE){ + while(l--) + *dp += *qp++ * *sp++; + } + else{ // KL + while(l--){ + a = *qp++; + b = *sp++; + tmp1 = a * log( a / b ); + if(isnan(tmp1)) + tmp1=0.0; +#ifdef SYMMETRIC_KL + tmp2 = b * log( b / a ); + if(isnan(tmp2)) + tmp2=0.0; + *dp += ( tmp1 + tmp2 ) / 2.0; +#else + *dp += tmp1; +#endif + } + } } double* spd; @@ -461,7 +485,11 @@ if( ( (!power_refine) || audiodb_powers_acceptable(&spec->refine, qpointers->power[qPos], dbpointers.power[sPos])) && ( qPos<qpointers->nvectors-sequence_length+1 && sPos<(*adb->track_lengths)[pp.trackID]-sequence_length+1 ) ){ // Compute distance - dist = audiodb_dot_product(query + qPos*adb->header->dim, dbdata + sPos*adb->header->dim, adb->header->dim*sequence_length); + dist = 1.0e9; + if (spec->params.distance==ADB_DISTANCE_EUCLIDEAN_NORMED || spec->params.distance==ADB_DISTANCE_EUCLIDEAN) + dist = audiodb_dot_product(query + qPos*adb->header->dim, dbdata + sPos*adb->header->dim, adb->header->dim*sequence_length); + else if(spec->params.distance==ADB_DISTANCE_KULLBACK_LEIBLER_DIVERGENCE) + dist = audiodb_kullback_leibler(query + qPos*adb->header->dim, dbdata + sPos*adb->header->dim, adb->header->dim*sequence_length); double qn = qpointers->l2norm[qPos]; double sn = dbpointers.l2norm[sPos]; switch(spec->params.distance) {