diff src/EM.cpp @ 36:74b77a4d6552

Fill out the EM logic
author Chris Cannam
date Fri, 04 Apr 2014 17:48:06 +0100
parents 461d94ed3816
children 947996aac974
line wrap: on
line diff
--- a/src/EM.cpp	Fri Apr 04 14:38:40 2014 +0100
+++ b/src/EM.cpp	Fri Apr 04 17:48:06 2014 +0100
@@ -17,6 +17,16 @@
 
 #include "data/include/templates.h"
 
+#include <cstdlib>
+
+#include <iostream>
+
+#include <vector>
+
+using std::vector;
+using std::cerr;
+using std::endl;
+
 static double epsilon = 1e-16;
 
 EM::EM() :
@@ -24,6 +34,8 @@
     m_bins(SILVET_TEMPLATE_HEIGHT),
     m_instruments(SILVET_TEMPLATE_COUNT)
 {
+    cerr << "init!" << endl;
+
     m_lowest = 0;
     m_highest = m_notes - 1;
 
@@ -38,7 +50,7 @@
 
     m_pitches = V(m_notes);
 
-    for (int n = 0; n < m_notes; ++i) {
+    for (int n = 0; n < m_notes; ++n) {
         m_pitches[n] = drand48();
     }
     
@@ -51,11 +63,10 @@
         }
     }
 
+    m_estimate = V(m_bins);
     m_q = V(m_bins);
-    
-    for (int w = 0; w < m_bins; ++w) {
-        m_q[w] = epsilon;
-    }
+
+    cerr << "(init done)" << endl;
 }
 
 EM::~EM()
@@ -69,3 +80,104 @@
             note <= silvet_templates[instrument].highest);
 }
 
+void
+EM::normalise(V &column)
+{
+    double sum = 0.0;
+    for (int i = 0; i < (int)column.size(); ++i) {
+        sum += column[i];
+    }
+    for (int i = 0; i < (int)column.size(); ++i) {
+        column[i] /= sum;
+    }
+}
+
+void
+EM::iterate(V column)
+{
+    normalise(column);
+    expectation(column);
+    maximisation(column);
+}
+
+void
+EM::expectation(const V &column)
+{
+    cerr << ".";
+
+    for (int i = 0; i < m_bins; ++i) {
+        m_estimate[i] = epsilon;
+    }
+
+    for (int i = 0; i < m_instruments; ++i) {
+        for (int n = 0; n < m_notes; ++n) {
+            float *w = silvet_templates[i].data[n];
+            double pitch = m_pitches[n];
+            double source = m_sources[i][n];
+            for (int j = 0; j < m_bins; ++j) {
+                m_estimate[j] += w[j] * pitch * source;
+            }
+        }
+    }
+
+    for (int i = 0; i < m_bins; ++i) {
+        m_q[i] = column[i] / m_estimate[i];
+    }
+}
+
+void
+EM::maximisation(const V &column)
+{
+    V newPitches = m_pitches;
+
+    for (int n = 0; n < m_notes; ++n) {
+        newPitches[n] = epsilon;
+        if (n >= m_lowest && n <= m_highest) {
+            for (int i = 0; i < m_instruments; ++i) {
+                float *w = silvet_templates[i].data[n];
+                double pitch = m_pitches[n];
+                double source = m_sources[i][n];
+                for (int j = 0; j < m_bins; ++j) {
+                    newPitches[n] += w[j] * m_q[j] * pitch * source;
+                }
+            }
+        }
+    }
+    normalise(newPitches);
+
+    Grid newSources = m_sources;
+
+    for (int i = 0; i < m_instruments; ++i) {
+        for (int n = 0; n < m_notes; ++n) {
+            newSources[i][n] = epsilon;
+            if (inRange(i, n)) {
+                float *w = silvet_templates[i].data[n];
+                for (int j = 0; j < m_bins; ++j) {
+                    newSources[i][n] +=
+                        w[j] * m_q[j] * m_pitches[n] * m_sources[i][n];
+                }
+            }
+        }
+        normalise(newSources[i]);
+    }
+
+    m_pitches = newPitches;
+    m_sources = newSources;
+}
+
+void
+EM::report()
+{
+    vector<int> sounding;
+    for (int n = 0; n < m_notes; ++n) {
+        if (m_pitches[n] > 0.05) {
+            sounding.push_back(n);
+        }
+    }
+    cerr << " sounding: ";
+    for (int i = 0; i < (int)sounding.size(); ++i) {
+        cerr << sounding[i] << " ";
+    }
+    cerr << endl;
+}
+