diff src/EM.cpp @ 151:fc06b6f33021

double -> float in EM (to test)
author Chris Cannam
date Wed, 14 May 2014 19:38:36 +0100
parents d2bc51cc7f57
children 6003a9af43af
line wrap: on
line diff
--- a/src/EM.cpp	Wed May 14 18:09:06 2014 +0100
+++ b/src/EM.cpp	Wed May 14 19:38:36 2014 +0100
@@ -31,7 +31,7 @@
 
 using namespace breakfastquay;
 
-static double epsilon = 1e-16;
+static float epsilon = 1e-10;
 
 EM::EM(bool useShifts) :
     m_noteCount(SILVET_TEMPLATE_NOTE_COUNT),
@@ -45,15 +45,15 @@
     m_lowestPitch(silvet_templates_lowest_note),
     m_highestPitch(silvet_templates_highest_note)
 {
-    m_pitches = allocate<double>(m_noteCount);
-    m_updatePitches = allocate<double>(m_noteCount);
+    m_pitches = allocate<float>(m_noteCount);
+    m_updatePitches = allocate<float>(m_noteCount);
     for (int n = 0; n < m_noteCount; ++n) {
         m_pitches[n] = drand48();
     }
 
     if (useShifts) {
-        m_shifts = allocate_channels<double>(m_shiftCount, m_noteCount);
-        m_updateShifts = allocate_channels<double>(m_shiftCount, m_noteCount);
+        m_shifts = allocate_channels<float>(m_shiftCount, m_noteCount);
+        m_updateShifts = allocate_channels<float>(m_shiftCount, m_noteCount);
         for (int f = 0; f < m_shiftCount; ++f) {
             for (int n = 0; n < m_noteCount; ++n) {
                 m_shifts[f][n] = drand48();
@@ -64,16 +64,16 @@
         m_updateShifts = 0;
     }
     
-    m_sources = allocate_channels<double>(m_sourceCount, m_noteCount);
-    m_updateSources = allocate_channels<double>(m_sourceCount, m_noteCount);
+    m_sources = allocate_channels<float>(m_sourceCount, m_noteCount);
+    m_updateSources = allocate_channels<float>(m_sourceCount, m_noteCount);
     for (int i = 0; i < m_sourceCount; ++i) {
         for (int n = 0; n < m_noteCount; ++n) {
             m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0);
         }
     }
 
-    m_estimate = allocate<double>(m_binCount);
-    m_q = allocate<double>(m_binCount);
+    m_estimate = allocate<float>(m_binCount);
+    m_q = allocate<float>(m_binCount);
 }
 
 EM::~EM()
@@ -104,16 +104,16 @@
 }
 
 void
-EM::normaliseColumn(double *column, int size)
+EM::normaliseColumn(float *column, int size)
 {
-    double sum = v_sum(column, size);
+    float sum = v_sum(column, size);
     v_scale(column, 1.0 / sum, size);
 }
 
 void
-EM::normaliseGrid(double **grid, int size1, int size2)
+EM::normaliseGrid(float **grid, int size1, int size2)
 {
-    double *denominators = allocate_and_zero<double>(size2);
+    float *denominators = allocate_and_zero<float>(size2);
 
     for (int i = 0; i < size1; ++i) {
         for (int j = 0; j < size2; ++j) {
@@ -131,15 +131,15 @@
 void
 EM::iterate(const double *column)
 {
-    double *norm = allocate<double>(m_binCount);
-    v_copy(norm, column, m_binCount);
+    float *norm = allocate<float>(m_binCount);
+    v_convert(norm, column, m_binCount);
     normaliseColumn(norm, m_binCount);
     expectation(norm);
     maximisation(norm);
     deallocate(norm);
 }
 
-const double *
+const float *
 EM::templateFor(int instrument, int note, int shift)
 {
     if (m_shifts) {
@@ -151,7 +151,7 @@
 }
 
 void
-EM::expectation(const double *column)
+EM::expectation(const float *column)
 {
 //    cerr << ".";
 
@@ -159,23 +159,23 @@
 
     for (int f = 0; f < m_shiftCount; ++f) {
 
-        const double *shiftIn = m_shifts ? m_shifts[f] : 0;
+        const float *shiftIn = m_shifts ? m_shifts[f] : 0;
 
         for (int i = 0; i < m_sourceCount; ++i) {
 
-            const double *sourceIn = m_sources[i];
+            const float *sourceIn = m_sources[i];
 
             int lowest, highest;
             rangeFor(i, lowest, highest);
 
             for (int n = lowest; n <= highest; ++n) {
 
-                const double source = sourceIn[n];
-                const double shift = shiftIn ? shiftIn[n] : 1.0;
-                const double pitch = m_pitches[n];
+                const float source = sourceIn[n];
+                const float shift = shiftIn ? shiftIn[n] : 1.0;
+                const float pitch = m_pitches[n];
 
-                const double factor = pitch * source * shift;
-                const double *w = templateFor(i, n, f);
+                const float factor = pitch * source * shift;
+                const float *w = templateFor(i, n, f);
 
                 v_add_with_gain(m_estimate, w, factor, m_binCount);
             }
@@ -191,7 +191,7 @@
 }
 
 void
-EM::maximisation(const double *column)
+EM::maximisation(const float *column)
 {
     v_set(m_updatePitches, epsilon, m_noteCount);
 
@@ -205,34 +205,34 @@
         }
     }
 
-    double *contributions = allocate<double>(m_binCount);
+    float *contributions = allocate<float>(m_binCount);
 
     for (int f = 0; f < m_shiftCount; ++f) {
 
-        const double *shiftIn = m_shifts ? m_shifts[f] : 0;
-        double *shiftOut = m_shifts ? m_updateShifts[f] : 0;
+        const float *shiftIn = m_shifts ? m_shifts[f] : 0;
+        float *shiftOut = m_shifts ? m_updateShifts[f] : 0;
 
         for (int i = 0; i < m_sourceCount; ++i) {
 
-            const double *sourceIn = m_sources[i];
-            double *sourceOut = m_updateSources[i];
+            const float *sourceIn = m_sources[i];
+            float *sourceOut = m_updateSources[i];
 
             int lowest, highest;
             rangeFor(i, lowest, highest);
 
             for (int n = lowest; n <= highest; ++n) {
 
-                const double shift = shiftIn ? shiftIn[n] : 1.0;
-                const double source = sourceIn[n];
-                const double pitch = m_pitches[n];
+                const float shift = shiftIn ? shiftIn[n] : 1.0;
+                const float source = sourceIn[n];
+                const float pitch = m_pitches[n];
 
-                const double factor = pitch * source * shift;
-                const double *w = templateFor(i, n, f);
+                const float factor = pitch * source * shift;
+                const float *w = templateFor(i, n, f);
 
                 v_copy(contributions, w, m_binCount);
                 v_multiply(contributions, m_q, m_binCount);
 
-                double total = factor * v_sum(contributions, m_binCount);
+                float total = factor * v_sum(contributions, m_binCount);
 
                 m_updatePitches[n] += total;
                 sourceOut[n] += total;