Mercurial > hg > silvet
diff src/EM.cpp @ 36:74b77a4d6552
Fill out the EM logic
author | Chris Cannam |
---|---|
date | Fri, 04 Apr 2014 17:48:06 +0100 |
parents | 461d94ed3816 |
children | 947996aac974 |
line wrap: on
line diff
--- a/src/EM.cpp Fri Apr 04 14:38:40 2014 +0100 +++ b/src/EM.cpp Fri Apr 04 17:48:06 2014 +0100 @@ -17,6 +17,16 @@ #include "data/include/templates.h" +#include <cstdlib> + +#include <iostream> + +#include <vector> + +using std::vector; +using std::cerr; +using std::endl; + static double epsilon = 1e-16; EM::EM() : @@ -24,6 +34,8 @@ m_bins(SILVET_TEMPLATE_HEIGHT), m_instruments(SILVET_TEMPLATE_COUNT) { + cerr << "init!" << endl; + m_lowest = 0; m_highest = m_notes - 1; @@ -38,7 +50,7 @@ m_pitches = V(m_notes); - for (int n = 0; n < m_notes; ++i) { + for (int n = 0; n < m_notes; ++n) { m_pitches[n] = drand48(); } @@ -51,11 +63,10 @@ } } + m_estimate = V(m_bins); m_q = V(m_bins); - - for (int w = 0; w < m_bins; ++w) { - m_q[w] = epsilon; - } + + cerr << "(init done)" << endl; } EM::~EM() @@ -69,3 +80,104 @@ note <= silvet_templates[instrument].highest); } +void +EM::normalise(V &column) +{ + double sum = 0.0; + for (int i = 0; i < (int)column.size(); ++i) { + sum += column[i]; + } + for (int i = 0; i < (int)column.size(); ++i) { + column[i] /= sum; + } +} + +void +EM::iterate(V column) +{ + normalise(column); + expectation(column); + maximisation(column); +} + +void +EM::expectation(const V &column) +{ + cerr << "."; + + for (int i = 0; i < m_bins; ++i) { + m_estimate[i] = epsilon; + } + + for (int i = 0; i < m_instruments; ++i) { + for (int n = 0; n < m_notes; ++n) { + float *w = silvet_templates[i].data[n]; + double pitch = m_pitches[n]; + double source = m_sources[i][n]; + for (int j = 0; j < m_bins; ++j) { + m_estimate[j] += w[j] * pitch * source; + } + } + } + + for (int i = 0; i < m_bins; ++i) { + m_q[i] = column[i] / m_estimate[i]; + } +} + +void +EM::maximisation(const V &column) +{ + V newPitches = m_pitches; + + for (int n = 0; n < m_notes; ++n) { + newPitches[n] = epsilon; + if (n >= m_lowest && n <= m_highest) { + for (int i = 0; i < m_instruments; ++i) { + float *w = silvet_templates[i].data[n]; + double pitch = m_pitches[n]; + double source = m_sources[i][n]; + for (int j = 0; j < m_bins; ++j) { + newPitches[n] += w[j] * m_q[j] * pitch * source; + } + } + } + } + normalise(newPitches); + + Grid newSources = m_sources; + + for (int i = 0; i < m_instruments; ++i) { + for (int n = 0; n < m_notes; ++n) { + newSources[i][n] = epsilon; + if (inRange(i, n)) { + float *w = silvet_templates[i].data[n]; + for (int j = 0; j < m_bins; ++j) { + newSources[i][n] += + w[j] * m_q[j] * m_pitches[n] * m_sources[i][n]; + } + } + } + normalise(newSources[i]); + } + + m_pitches = newPitches; + m_sources = newSources; +} + +void +EM::report() +{ + vector<int> sounding; + for (int n = 0; n < m_notes; ++n) { + if (m_pitches[n] > 0.05) { + sounding.push_back(n); + } + } + cerr << " sounding: "; + for (int i = 0; i < (int)sounding.size(); ++i) { + cerr << sounding[i] << " "; + } + cerr << endl; +} +