changeset 36:74b77a4d6552

Fill out the EM logic
author Chris Cannam
date Fri, 04 Apr 2014 17:48:06 +0100
parents 461d94ed3816
children 947996aac974
files Makefile.inc src/EM.cpp src/EM.h src/Silvet.cpp
diffstat 4 files changed, 136 insertions(+), 16 deletions(-) [+]
line wrap: on
line diff
--- a/Makefile.inc	Fri Apr 04 14:38:40 2014 +0100
+++ b/Makefile.inc	Fri Apr 04 17:48:06 2014 +0100
@@ -19,8 +19,8 @@
 
 PLUGIN	:= silvet$(PLUGIN_EXT)
 
-VAMP_HEADERS := $(SRC_DIR)/Silvet.h
-VAMP_SOURCES := $(SRC_DIR)/Silvet.cpp $(SRC_DIR)/libmain.cpp
+VAMP_HEADERS := $(SRC_DIR)/Silvet.h $(SRC_DIR)/EM.h
+VAMP_SOURCES := $(SRC_DIR)/Silvet.cpp $(SRC_DIR)/EM.cpp $(SRC_DIR)/libmain.cpp
 
 CQ_HEADERS   := $(CQ_DIR)/CQKernel.h $(CQ_DIR)/ConstantQ.h $(CQ_DIR)/CQInterpolated.h
 CQ_SOURCES   := $(CQ_DIR)/CQKernel.cpp $(CQ_DIR)/ConstantQ.cpp $(CQ_DIR)/CQInterpolated.cpp
@@ -47,16 +47,16 @@
 
 # DO NOT DELETE
 
-src/Silvet.o: src/Silvet.h data/include/templates.h data/include/bassoon.h
-src/Silvet.o: data/include/cello.h data/include/clarinet.h
-src/Silvet.o: data/include/flute.h data/include/guitar.h data/include/horn.h
-src/Silvet.o: data/include/oboe.h data/include/tenorsax.h
-src/Silvet.o: data/include/violin.h data/include/piano-maps-SptkBGCl.h
-src/Silvet.o: data/include/piano1.h data/include/piano2.h
-src/Silvet.o: data/include/piano3.h
+src/Silvet.o: src/Silvet.h src/EM.h
 src/Silvet.o: constant-q-cpp/cpp-qm-dsp/CQInterpolated.h
 src/Silvet.o: constant-q-cpp/cpp-qm-dsp/ConstantQ.h
 src/Silvet.o: constant-q-cpp/cpp-qm-dsp/CQKernel.h
+src/EM.o: src/EM.h data/include/templates.h data/include/bassoon.h
+src/EM.o: data/include/cello.h data/include/clarinet.h data/include/flute.h
+src/EM.o: data/include/guitar.h data/include/horn.h data/include/oboe.h
+src/EM.o: data/include/tenorsax.h data/include/violin.h
+src/EM.o: data/include/piano-maps-SptkBGCl.h data/include/piano1.h
+src/EM.o: data/include/piano2.h data/include/piano3.h
 src/libmain.o: src/Silvet.h
 constant-q-cpp/cpp-qm-dsp/CQKernel.o: constant-q-cpp/cpp-qm-dsp/CQKernel.h
 constant-q-cpp/cpp-qm-dsp/ConstantQ.o: constant-q-cpp/cpp-qm-dsp/ConstantQ.h
--- a/src/EM.cpp	Fri Apr 04 14:38:40 2014 +0100
+++ b/src/EM.cpp	Fri Apr 04 17:48:06 2014 +0100
@@ -17,6 +17,16 @@
 
 #include "data/include/templates.h"
 
+#include <cstdlib>
+
+#include <iostream>
+
+#include <vector>
+
+using std::vector;
+using std::cerr;
+using std::endl;
+
 static double epsilon = 1e-16;
 
 EM::EM() :
@@ -24,6 +34,8 @@
     m_bins(SILVET_TEMPLATE_HEIGHT),
     m_instruments(SILVET_TEMPLATE_COUNT)
 {
+    cerr << "init!" << endl;
+
     m_lowest = 0;
     m_highest = m_notes - 1;
 
@@ -38,7 +50,7 @@
 
     m_pitches = V(m_notes);
 
-    for (int n = 0; n < m_notes; ++i) {
+    for (int n = 0; n < m_notes; ++n) {
         m_pitches[n] = drand48();
     }
     
@@ -51,11 +63,10 @@
         }
     }
 
+    m_estimate = V(m_bins);
     m_q = V(m_bins);
-    
-    for (int w = 0; w < m_bins; ++w) {
-        m_q[w] = epsilon;
-    }
+
+    cerr << "(init done)" << endl;
 }
 
 EM::~EM()
@@ -69,3 +80,104 @@
             note <= silvet_templates[instrument].highest);
 }
 
+void
+EM::normalise(V &column)
+{
+    double sum = 0.0;
+    for (int i = 0; i < (int)column.size(); ++i) {
+        sum += column[i];
+    }
+    for (int i = 0; i < (int)column.size(); ++i) {
+        column[i] /= sum;
+    }
+}
+
+void
+EM::iterate(V column)
+{
+    normalise(column);
+    expectation(column);
+    maximisation(column);
+}
+
+void
+EM::expectation(const V &column)
+{
+    cerr << ".";
+
+    for (int i = 0; i < m_bins; ++i) {
+        m_estimate[i] = epsilon;
+    }
+
+    for (int i = 0; i < m_instruments; ++i) {
+        for (int n = 0; n < m_notes; ++n) {
+            float *w = silvet_templates[i].data[n];
+            double pitch = m_pitches[n];
+            double source = m_sources[i][n];
+            for (int j = 0; j < m_bins; ++j) {
+                m_estimate[j] += w[j] * pitch * source;
+            }
+        }
+    }
+
+    for (int i = 0; i < m_bins; ++i) {
+        m_q[i] = column[i] / m_estimate[i];
+    }
+}
+
+void
+EM::maximisation(const V &column)
+{
+    V newPitches = m_pitches;
+
+    for (int n = 0; n < m_notes; ++n) {
+        newPitches[n] = epsilon;
+        if (n >= m_lowest && n <= m_highest) {
+            for (int i = 0; i < m_instruments; ++i) {
+                float *w = silvet_templates[i].data[n];
+                double pitch = m_pitches[n];
+                double source = m_sources[i][n];
+                for (int j = 0; j < m_bins; ++j) {
+                    newPitches[n] += w[j] * m_q[j] * pitch * source;
+                }
+            }
+        }
+    }
+    normalise(newPitches);
+
+    Grid newSources = m_sources;
+
+    for (int i = 0; i < m_instruments; ++i) {
+        for (int n = 0; n < m_notes; ++n) {
+            newSources[i][n] = epsilon;
+            if (inRange(i, n)) {
+                float *w = silvet_templates[i].data[n];
+                for (int j = 0; j < m_bins; ++j) {
+                    newSources[i][n] +=
+                        w[j] * m_q[j] * m_pitches[n] * m_sources[i][n];
+                }
+            }
+        }
+        normalise(newSources[i]);
+    }
+
+    m_pitches = newPitches;
+    m_sources = newSources;
+}
+
+void
+EM::report()
+{
+    vector<int> sounding;
+    for (int n = 0; n < m_notes; ++n) {
+        if (m_pitches[n] > 0.05) {
+            sounding.push_back(n);
+        }
+    }
+    cerr << " sounding: ";
+    for (int i = 0; i < (int)sounding.size(); ++i) {
+        cerr << sounding[i] << " ";
+    }
+    cerr << endl;
+}
+
--- a/src/EM.h	Fri Apr 04 14:38:40 2014 +0100
+++ b/src/EM.h	Fri Apr 04 17:48:06 2014 +0100
@@ -24,7 +24,8 @@
     EM();
     ~EM();
 
-    void iterate(const std::vector<double> &column);
+    void iterate(std::vector<double> column);
+    void report();
 
 private:
     typedef std::vector<double> V;
@@ -32,7 +33,9 @@
 
     V m_pitches;
     Grid m_sources;
-    Grid m_q;
+
+    V m_estimate;
+    V m_q;
     
     int m_notes;
     int m_bins;
@@ -41,6 +44,10 @@
     int m_lowest;
     int m_highest;
 
+    void normalise(V &column);
+    void expectation(const V &column);
+    void maximisation(const V &column);
+
     bool inRange(int instrument, int note);
 };
 
--- a/src/Silvet.cpp	Fri Apr 04 14:38:40 2014 +0100
+++ b/src/Silvet.cpp	Fri Apr 04 17:48:06 2014 +0100
@@ -301,6 +301,7 @@
             em.iterate(filtered[i]);
         }
         //!!! now do something with the results from em!
+        em.report();
     }
 
     return fs;