diff src/EM.cpp @ 55:384338fa460d preshift

Support shifts as an additional dimension (as in the original model). Also return velocity.
author Chris Cannam
date Tue, 08 Apr 2014 13:30:32 +0100
parents a54df67e607e
children 3e7e3c610fae
line wrap: on
line diff
--- a/src/EM.cpp	Mon Apr 07 17:36:40 2014 +0100
+++ b/src/EM.cpp	Tue Apr 08 13:30:32 2014 +0100
@@ -33,28 +33,31 @@
 EM::EM() :
     m_noteCount(SILVET_TEMPLATE_NOTE_COUNT),
     m_shiftCount(SILVET_TEMPLATE_MAX_SHIFT * 2 + 1),
-    m_pitchCount(m_noteCount * m_shiftCount),
     m_binCount(SILVET_TEMPLATE_HEIGHT),
     m_instrumentCount(SILVET_TEMPLATE_COUNT),
     m_pitchSparsity(1.1),
     m_sourceSparsity(1.3)
 {
-    m_lowestPitch = 
-        silvet_templates_lowest_note * m_shiftCount;
-    m_highestPitch =
-        silvet_templates_highest_note * m_shiftCount + m_shiftCount - 1;
+    m_lowestPitch = silvet_templates_lowest_note;
+    m_highestPitch = silvet_templates_highest_note;
 
-    m_pitches = V(m_pitchCount);
+    m_pitches = V(m_noteCount);
+    for (int n = 0; n < m_noteCount; ++n) {
+        m_pitches[n] = drand48();
+    }
 
-    for (int n = 0; n < m_pitchCount; ++n) {
-        m_pitches[n] = drand48();
+    m_shifts = Grid(m_shiftCount);
+    for (int f = 0; f < m_shiftCount; ++f) {
+        m_shifts[f] = V(m_noteCount);
+        for (int n = 0; n < m_noteCount; ++n) {
+            m_shifts[f][n] = drand48();
+        }
     }
     
     m_sources = Grid(m_instrumentCount);
-    
-    for (int i = 0; i < m_instrumentCount; ++i) {
-        m_sources[i] = V(m_pitchCount);
-        for (int n = 0; n < m_pitchCount; ++n) {
+        for (int i = 0; i < m_instrumentCount; ++i) {
+        m_sources[i] = V(m_noteCount);
+        for (int n = 0; n < m_noteCount; ++n) {
             m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0);
         }
     }
@@ -70,9 +73,8 @@
 void
 EM::rangeFor(int instrument, int &minPitch, int &maxPitch)
 {
-    minPitch = silvet_templates[instrument].lowest * m_shiftCount;
-    maxPitch = silvet_templates[instrument].highest * m_shiftCount
-        + m_shiftCount - 1;
+    minPitch = silvet_templates[instrument].lowest;
+    maxPitch = silvet_templates[instrument].highest;
 }
 
 bool
@@ -84,7 +86,7 @@
 }
 
 void
-EM::normalise(V &column)
+EM::normaliseColumn(V &column)
 {
     double sum = 0.0;
     for (int i = 0; i < (int)column.size(); ++i) {
@@ -96,19 +98,19 @@
 }
 
 void
-EM::normaliseSources(Grid &sources)
+EM::normaliseGrid(Grid &grid)
 {
-    V denominators(sources[0].size());
+    V denominators(grid[0].size());
 
-    for (int i = 0; i < (int)sources.size(); ++i) {
-        for (int j = 0; j < (int)sources[i].size(); ++j) {
-            denominators[j] += sources[i][j];
+    for (int i = 0; i < (int)grid.size(); ++i) {
+        for (int j = 0; j < (int)grid[i].size(); ++j) {
+            denominators[j] += grid[i][j];
         }
     }
 
-    for (int i = 0; i < (int)sources.size(); ++i) {
-        for (int j = 0; j < (int)sources[i].size(); ++j) {
-            sources[i][j] /= denominators[j];
+    for (int i = 0; i < (int)grid.size(); ++i) {
+        for (int j = 0; j < (int)grid[i].size(); ++j) {
+            grid[i][j] /= denominators[j];
         }
     }
 }
@@ -116,16 +118,14 @@
 void
 EM::iterate(V column)
 {
-    normalise(column);
+    normaliseColumn(column);
     expectation(column);
     maximisation(column);
 }
 
 const float *
-EM::templateFor(int instrument, int pitch)
+EM::templateFor(int instrument, int note, int shift)
 {
-    int note = pitch / m_shiftCount;
-    int shift = pitch % m_shiftCount;
     return silvet_templates[instrument].data[note] + shift;
 }
 
@@ -139,12 +139,15 @@
     }
 
     for (int i = 0; i < m_instrumentCount; ++i) {
-        for (int n = 0; n < m_pitchCount; ++n) {
-            const float *w = templateFor(i, n);
-            double pitch = m_pitches[n];
-            double source = m_sources[i][n];
-            for (int j = 0; j < m_binCount; ++j) {
-                m_estimate[j] += w[j] * pitch * source;
+        for (int n = 0; n < m_noteCount; ++n) {
+            for (int f = 0; f < m_shiftCount; ++f) {
+                const float *w = templateFor(i, n, f);
+                double pitch = m_pitches[n];
+                double source = m_sources[i][n];
+                double shift = m_shifts[f][n];
+                for (int j = 0; j < m_binCount; ++j) {
+                    m_estimate[j] += w[j] * pitch * source * shift;
+                }
             }
         }
     }
@@ -159,15 +162,18 @@
 {
     V newPitches = m_pitches;
 
-    for (int n = 0; n < m_pitchCount; ++n) {
+    for (int n = 0; n < m_noteCount; ++n) {
         newPitches[n] = epsilon;
         if (n >= m_lowestPitch && n <= m_highestPitch) {
             for (int i = 0; i < m_instrumentCount; ++i) {
-                const float *w = templateFor(i, n);
-                double pitch = m_pitches[n];
-                double source = m_sources[i][n];
-                for (int j = 0; j < m_binCount; ++j) {
-                    newPitches[n] += w[j] * m_q[j] * pitch * source;
+                for (int f = 0; f < m_shiftCount; ++f) {
+                    const float *w = templateFor(i, n, f);
+                    double pitch = m_pitches[n];
+                    double source = m_sources[i][n];
+                    double shift = m_shifts[f][n];
+                    for (int j = 0; j < m_binCount; ++j) {
+                        newPitches[n] += w[j] * m_q[j] * pitch * source * shift;
+                    }
                 }
             }
         }
@@ -175,19 +181,40 @@
             newPitches[n] = pow(newPitches[n], m_pitchSparsity);
         }
     }
-    normalise(newPitches);
+    normaliseColumn(newPitches);
+
+    Grid newShifts = m_shifts;
+
+    for (int f = 0; f < m_shiftCount; ++f) {
+        for (int n = 0; n < m_noteCount; ++n) {
+            newShifts[f][n] = epsilon;
+            for (int i = 0; i < m_instrumentCount; ++i) {
+                const float *w = templateFor(i, n, f);
+                double pitch = m_pitches[n];
+                double source = m_sources[i][n];
+                double shift = m_shifts[f][n];
+                for (int j = 0; j < m_binCount; ++j) {
+                    newShifts[f][n] += w[j] * m_q[j] * pitch * source * shift;
+                }
+            }
+        }
+    }
+    normaliseGrid(newShifts);
 
     Grid newSources = m_sources;
 
     for (int i = 0; i < m_instrumentCount; ++i) {
-        for (int n = 0; n < m_pitchCount; ++n) {
+        for (int n = 0; n < m_noteCount; ++n) {
             newSources[i][n] = epsilon;
             if (inRange(i, n)) {
-                const float *w = templateFor(i, n);
-                double pitch = m_pitches[n];
-                double source = m_sources[i][n];
-                for (int j = 0; j < m_binCount; ++j) {
-                    newSources[i][n] += w[j] * m_q[j] * pitch * source;
+                for (int f = 0; f < m_shiftCount; ++f) {
+                    const float *w = templateFor(i, n, f);
+                    double pitch = m_pitches[n];
+                    double source = m_sources[i][n];
+                    double shift = m_shifts[f][n];
+                    for (int j = 0; j < m_binCount; ++j) {
+                        newSources[i][n] += w[j] * m_q[j] * pitch * source * shift;
+                    }
                 }
             }
             if (m_sourceSparsity != 1.0) {
@@ -195,34 +222,11 @@
             }
         }
     }
-    normaliseSources(newSources);
+    normaliseGrid(newSources);
 
     m_pitches = newPitches;
+    m_shifts = newShifts;
     m_sources = newSources;
 }
 
-void
-EM::report()
-{
-    vector<int> sounding;
-    for (int n = 0; n < m_pitchCount; ++n) {
-        if (m_pitches[n] > 0.05) {
-            sounding.push_back(n);
-        }
-    }
-    cerr << " sounding: ";
-    for (int i = 0; i < (int)sounding.size(); ++i) {
-        cerr << sounding[i] << " ";
-        int maxj = -1;
-        double maxs = 0.0;
-        for (int j = 0; j < m_instrumentCount; ++j) {
-            if (j == 0 || m_sources[j][sounding[i]] > maxs) {
-                maxj = j;
-                maxs = m_sources[j][sounding[i]];
-            }
-        }
-        cerr << silvet_templates[maxj].name << " ";
-    }
-    cerr << endl;
-}