changeset 97:840c0d703bbb timing

Use single-precision floats throughout EM code
author Chris Cannam
date Tue, 06 May 2014 14:45:16 +0100
parents f1116eb464f9
children faeeb78badec
files data/include/templates.h src/EM.cpp src/EM.h src/Silvet.cpp src/Silvet.h yeti/scratch/generateTemplatesC.yeti
diffstat 6 files changed, 32 insertions(+), 31 deletions(-) [+]
line wrap: on
line diff
--- a/data/include/templates.h	Tue May 06 13:05:43 2014 +0100
+++ b/data/include/templates.h	Tue May 06 14:45:16 2014 +0100
@@ -15,7 +15,7 @@
     const char *name;
     int lowest;
     int highest;
-    double data[SILVET_TEMPLATE_NOTE_COUNT][SILVET_TEMPLATE_SIZE];
+    float data[SILVET_TEMPLATE_NOTE_COUNT][SILVET_TEMPLATE_SIZE];
 } silvet_template_t;
 
 static int silvet_templates_lowest_note = 15;
--- a/src/EM.cpp	Tue May 06 13:05:43 2014 +0100
+++ b/src/EM.cpp	Tue May 06 14:45:16 2014 +0100
@@ -28,7 +28,7 @@
 using std::cerr;
 using std::endl;
 
-static double epsilon = 1e-16;
+static float epsilon = 1e-8;
 
 EM::EM() :
     m_noteCount(SILVET_TEMPLATE_NOTE_COUNT),
@@ -87,7 +87,7 @@
 void
 EM::normaliseColumn(V &column)
 {
-    double sum = 0.0;
+    float sum = 0.0;
     for (int i = 0; i < (int)column.size(); ++i) {
         sum += column[i];
     }
@@ -115,14 +115,15 @@
 }
 
 void
-EM::iterate(V column)
+EM::iterate(const vector<double> &column)
 {
-    normaliseColumn(column);
-    expectation(column);
-    maximisation(column);
+    V norm(column.begin(), column.end());
+    normaliseColumn(norm);
+    expectation(norm);
+    maximisation(norm);
 }
 
-const double *
+const float *
 EM::templateFor(int instrument, int note, int shift)
 {
     return silvet_templates[instrument].data[note] + shift;
@@ -139,12 +140,12 @@
 
     for (int i = 0; i < m_instrumentCount; ++i) {
         for (int n = 0; n < m_noteCount; ++n) {
-            const double pitch = m_pitches[n];
-            const double source = m_sources[i][n];
+            const float pitch = m_pitches[n];
+            const float source = m_sources[i][n];
             for (int f = 0; f < m_shiftCount; ++f) {
-                const double *w = templateFor(i, n, f);
-                const double shift = m_shifts[f][n];
-                const double factor = pitch * source * shift;
+                const float *w = templateFor(i, n, f);
+                const float shift = m_shifts[f][n];
+                const float factor = pitch * source * shift;
                 for (int j = 0; j < m_binCount; ++j) {
                     m_estimate[j] += w[j] * factor;
                 }
@@ -166,17 +167,17 @@
 
     for (int n = 0; n < m_noteCount; ++n) {
 
-        const double pitch = m_pitches[n];
+        const float pitch = m_pitches[n];
 
         for (int f = 0; f < m_shiftCount; ++f) {
 
-            const double shift = m_shifts[f][n];
+            const float shift = m_shifts[f][n];
 
             for (int i = 0; i < m_instrumentCount; ++i) {
 
-                const double source = m_sources[i][n];
-                const double factor = pitch * source * shift;
-                const double *w = templateFor(i, n, f);
+                const float source = m_sources[i][n];
+                const float factor = pitch * source * shift;
+                const float *w = templateFor(i, n, f);
 
                 if (n >= m_lowestPitch && n <= m_highestPitch) {
 
--- a/src/EM.h	Tue May 06 13:05:43 2014 +0100
+++ b/src/EM.h	Tue May 06 14:45:16 2014 +0100
@@ -24,21 +24,21 @@
     EM();
     ~EM();
 
-    void iterate(std::vector<double> column);
+    void iterate(const std::vector<double> &column);
 
-    const std::vector<double> &getEstimate() const { 
+    const std::vector<float> &getEstimate() const { 
 	return m_estimate;
     }
-    const std::vector<double> &getPitchDistribution() const {
+    const std::vector<float> &getPitchDistribution() const {
 	return m_pitches;
     }
-    const std::vector<std::vector<double> > &getSources() const {
+    const std::vector<std::vector<float> > &getSources() const {
 	return m_sources; 
     }
 
 private:
-    typedef std::vector<double> V;
-    typedef std::vector<std::vector<double> > Grid;
+    typedef std::vector<float> V;
+    typedef std::vector<std::vector<float> > Grid;
 
     V m_pitches;
     Grid m_shifts;
@@ -52,8 +52,8 @@
     const int m_binCount;
     const int m_instrumentCount;
     
-    const double m_pitchSparsity;
-    const double m_sourceSparsity;
+    const float m_pitchSparsity;
+    const float m_sourceSparsity;
 
     const int m_lowestPitch;
     const int m_highestPitch;
@@ -63,7 +63,7 @@
     void expectation(const V &column);
     void maximisation(const V &column);
 
-    const double *templateFor(int instrument, int note, int shift);
+    const float *templateFor(int instrument, int note, int shift);
     void rangeFor(int instrument, int &minPitch, int &maxPitch);
     bool inRange(int instrument, int pitch);
 };
--- a/src/Silvet.cpp	Tue May 06 13:05:43 2014 +0100
+++ b/src/Silvet.cpp	Tue May 06 14:45:16 2014 +0100
@@ -401,7 +401,7 @@
             em.iterate(filtered[i]);
         }
 
-        vector<double> pitches = em.getPitchDistribution();
+        vector<float> pitches = em.getPitchDistribution();
         
         for (int j = 0; j < processingNotes; ++j) {
             pitches[j] *= sum;
@@ -498,7 +498,7 @@
 }
     
 Vamp::Plugin::FeatureList
-Silvet::postProcess(const vector<double> &pitches)        
+Silvet::postProcess(const vector<float> &pitches)
 {        
     vector<double> filtered;
 
--- a/src/Silvet.h	Tue May 06 13:05:43 2014 +0100
+++ b/src/Silvet.h	Tue May 06 14:45:16 2014 +0100
@@ -79,7 +79,7 @@
     vector<map<int, double> > m_pianoRoll;
 
     Grid preProcess(const Grid &);
-    FeatureList postProcess(const vector<double> &);
+    FeatureList postProcess(const vector<float> &);
     FeatureSet transcribe(const Grid &);
 
     string noteName(int n) const;
--- a/yeti/scratch/generateTemplatesC.yeti	Tue May 06 13:05:43 2014 +0100
+++ b/yeti/scratch/generateTemplatesC.yeti	Tue May 06 14:45:16 2014 +0100
@@ -106,7 +106,7 @@
         "    const char *name;",
         "    int lowest;",
         "    int highest;",
-        "    double data[SILVET_TEMPLATE_NOTE_COUNT][SILVET_TEMPLATE_SIZE];",
+        "    float data[SILVET_TEMPLATE_NOTE_COUNT][SILVET_TEMPLATE_SIZE];",
         "} silvet_template_t;",
         "",
         "static int silvet_templates_lowest_note = \(overallLowest);",