Mercurial > hg > silvet
changeset 97:840c0d703bbb timing
Use single-precision floats throughout EM code
author | Chris Cannam |
---|---|
date | Tue, 06 May 2014 14:45:16 +0100 |
parents | f1116eb464f9 |
children | faeeb78badec |
files | data/include/templates.h src/EM.cpp src/EM.h src/Silvet.cpp src/Silvet.h yeti/scratch/generateTemplatesC.yeti |
diffstat | 6 files changed, 32 insertions(+), 31 deletions(-) [+] |
line wrap: on
line diff
--- a/data/include/templates.h Tue May 06 13:05:43 2014 +0100 +++ b/data/include/templates.h Tue May 06 14:45:16 2014 +0100 @@ -15,7 +15,7 @@ const char *name; int lowest; int highest; - double data[SILVET_TEMPLATE_NOTE_COUNT][SILVET_TEMPLATE_SIZE]; + float data[SILVET_TEMPLATE_NOTE_COUNT][SILVET_TEMPLATE_SIZE]; } silvet_template_t; static int silvet_templates_lowest_note = 15;
--- a/src/EM.cpp Tue May 06 13:05:43 2014 +0100 +++ b/src/EM.cpp Tue May 06 14:45:16 2014 +0100 @@ -28,7 +28,7 @@ using std::cerr; using std::endl; -static double epsilon = 1e-16; +static float epsilon = 1e-8; EM::EM() : m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), @@ -87,7 +87,7 @@ void EM::normaliseColumn(V &column) { - double sum = 0.0; + float sum = 0.0; for (int i = 0; i < (int)column.size(); ++i) { sum += column[i]; } @@ -115,14 +115,15 @@ } void -EM::iterate(V column) +EM::iterate(const vector<double> &column) { - normaliseColumn(column); - expectation(column); - maximisation(column); + V norm(column.begin(), column.end()); + normaliseColumn(norm); + expectation(norm); + maximisation(norm); } -const double * +const float * EM::templateFor(int instrument, int note, int shift) { return silvet_templates[instrument].data[note] + shift; @@ -139,12 +140,12 @@ for (int i = 0; i < m_instrumentCount; ++i) { for (int n = 0; n < m_noteCount; ++n) { - const double pitch = m_pitches[n]; - const double source = m_sources[i][n]; + const float pitch = m_pitches[n]; + const float source = m_sources[i][n]; for (int f = 0; f < m_shiftCount; ++f) { - const double *w = templateFor(i, n, f); - const double shift = m_shifts[f][n]; - const double factor = pitch * source * shift; + const float *w = templateFor(i, n, f); + const float shift = m_shifts[f][n]; + const float factor = pitch * source * shift; for (int j = 0; j < m_binCount; ++j) { m_estimate[j] += w[j] * factor; } @@ -166,17 +167,17 @@ for (int n = 0; n < m_noteCount; ++n) { - const double pitch = m_pitches[n]; + const float pitch = m_pitches[n]; for (int f = 0; f < m_shiftCount; ++f) { - const double shift = m_shifts[f][n]; + const float shift = m_shifts[f][n]; for (int i = 0; i < m_instrumentCount; ++i) { - const double source = m_sources[i][n]; - const double factor = pitch * source * shift; - const double *w = templateFor(i, n, f); + const float source = m_sources[i][n]; + const float factor = pitch * source * shift; + const float *w = templateFor(i, n, f); if (n >= m_lowestPitch && n <= m_highestPitch) {
--- a/src/EM.h Tue May 06 13:05:43 2014 +0100 +++ b/src/EM.h Tue May 06 14:45:16 2014 +0100 @@ -24,21 +24,21 @@ EM(); ~EM(); - void iterate(std::vector<double> column); + void iterate(const std::vector<double> &column); - const std::vector<double> &getEstimate() const { + const std::vector<float> &getEstimate() const { return m_estimate; } - const std::vector<double> &getPitchDistribution() const { + const std::vector<float> &getPitchDistribution() const { return m_pitches; } - const std::vector<std::vector<double> > &getSources() const { + const std::vector<std::vector<float> > &getSources() const { return m_sources; } private: - typedef std::vector<double> V; - typedef std::vector<std::vector<double> > Grid; + typedef std::vector<float> V; + typedef std::vector<std::vector<float> > Grid; V m_pitches; Grid m_shifts; @@ -52,8 +52,8 @@ const int m_binCount; const int m_instrumentCount; - const double m_pitchSparsity; - const double m_sourceSparsity; + const float m_pitchSparsity; + const float m_sourceSparsity; const int m_lowestPitch; const int m_highestPitch; @@ -63,7 +63,7 @@ void expectation(const V &column); void maximisation(const V &column); - const double *templateFor(int instrument, int note, int shift); + const float *templateFor(int instrument, int note, int shift); void rangeFor(int instrument, int &minPitch, int &maxPitch); bool inRange(int instrument, int pitch); };
--- a/src/Silvet.cpp Tue May 06 13:05:43 2014 +0100 +++ b/src/Silvet.cpp Tue May 06 14:45:16 2014 +0100 @@ -401,7 +401,7 @@ em.iterate(filtered[i]); } - vector<double> pitches = em.getPitchDistribution(); + vector<float> pitches = em.getPitchDistribution(); for (int j = 0; j < processingNotes; ++j) { pitches[j] *= sum; @@ -498,7 +498,7 @@ } Vamp::Plugin::FeatureList -Silvet::postProcess(const vector<double> &pitches) +Silvet::postProcess(const vector<float> &pitches) { vector<double> filtered;
--- a/src/Silvet.h Tue May 06 13:05:43 2014 +0100 +++ b/src/Silvet.h Tue May 06 14:45:16 2014 +0100 @@ -79,7 +79,7 @@ vector<map<int, double> > m_pianoRoll; Grid preProcess(const Grid &); - FeatureList postProcess(const vector<double> &); + FeatureList postProcess(const vector<float> &); FeatureSet transcribe(const Grid &); string noteName(int n) const;
--- a/yeti/scratch/generateTemplatesC.yeti Tue May 06 13:05:43 2014 +0100 +++ b/yeti/scratch/generateTemplatesC.yeti Tue May 06 14:45:16 2014 +0100 @@ -106,7 +106,7 @@ " const char *name;", " int lowest;", " int highest;", - " double data[SILVET_TEMPLATE_NOTE_COUNT][SILVET_TEMPLATE_SIZE];", + " float data[SILVET_TEMPLATE_NOTE_COUNT][SILVET_TEMPLATE_SIZE];", "} silvet_template_t;", "", "static int silvet_templates_lowest_note = \(overallLowest);",