Mercurial > hg > silvet
view src/EM.cpp @ 110:e282930cfca7
Add draft/intensive mode setting (determines whether to use shifts)
author | Chris Cannam |
---|---|
date | Tue, 06 May 2014 18:55:11 +0100 |
parents | 3e7e3c610fae |
children | 2169e7a448c5 |
line wrap: on
line source
/* -*- c-basic-offset: 4 indent-tabs-mode: nil -*-  vi:set ts=8 sts=4 sw=4: */

/*
  Silvet

  A Vamp plugin for note transcription.
  Centre for Digital Music, Queen Mary University of London.

  This program is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 2 of the
  License, or (at your option) any later version.  See the file
  COPYING included with this distribution for more information.
*/

#include "EM.h"

#include "data/include/templates.h"

#include <cstdlib>
#include <cmath>

#include <iostream>

#include <vector>

using std::vector;
using std::cerr;
using std::endl;

// Small positive floor used to seed every accumulator below, so that
// the estimate (a divisor) and the values fed to pow() and to the
// normalisers never start from exactly zero.
static const double epsilon = 1e-16;

/**
 * Construct the EM state from the compiled-in template constants.
 *
 * @param useShifts when true, (SILVET_TEMPLATE_MAX_SHIFT * 2 + 1)
 * shifted variants of each template are modelled and the shift
 * distribution is randomly initialised; when false a single shift is
 * used and its weight is fixed at 1.0.
 */
EM::EM(bool useShifts) :
    m_useShifts(useShifts),
    m_noteCount(SILVET_TEMPLATE_NOTE_COUNT),
    m_shiftCount(useShifts ? SILVET_TEMPLATE_MAX_SHIFT * 2 + 1 : 1),
    m_binCount(SILVET_TEMPLATE_HEIGHT),
    m_instrumentCount(SILVET_TEMPLATE_COUNT),
    m_pitchSparsity(1.1),
    m_sourceSparsity(1.3)
{
    m_lowestPitch = silvet_templates_lowest_note;
    m_highestPitch = silvet_templates_highest_note;

    // Pitch activations start from random values.
    m_pitches = V(m_noteCount);
    for (int n = 0; n < m_noteCount; ++n) {
        m_pitches[n] = drand48();
    }

    // Shift weights: random when shifts are in use, otherwise a
    // constant 1.0 so that they drop out of the products below.
    m_shifts = Grid(m_shiftCount);
    for (int f = 0; f < m_shiftCount; ++f) {
        m_shifts[f] = V(m_noteCount);
        for (int n = 0; n < m_noteCount; ++n) {
            if (m_useShifts) {
                m_shifts[f][n] = drand48();
            } else {
                m_shifts[f][n] = 1.0;
            }
        }
    }

    // Source (instrument) weights: 1.0 for notes inside the
    // instrument's template range, 0.0 for notes outside it.
    m_sources = Grid(m_instrumentCount);
    for (int i = 0; i < m_instrumentCount; ++i) {
        m_sources[i] = V(m_noteCount);
        for (int n = 0; n < m_noteCount; ++n) {
            m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0);
        }
    }

    m_estimate = V(m_binCount);
    m_q = V(m_binCount);
}

EM::~EM()
{
}

/**
 * Report the lowest and highest note recorded in the template data
 * for the given instrument.
 *
 * @param instrument index into silvet_templates
 * @param minPitch set to silvet_templates[instrument].lowest
 * @param maxPitch set to silvet_templates[instrument].highest
 */
void
EM::rangeFor(int instrument, int &minPitch, int &maxPitch)
{
    minPitch = silvet_templates[instrument].lowest;
    maxPitch = silvet_templates[instrument].highest;
}

/**
 * Return true if pitch lies within the inclusive range reported by
 * rangeFor() for the given instrument.
 */
bool
EM::inRange(int instrument, int pitch)
{
    int minPitch, maxPitch;
    rangeFor(instrument, minPitch, maxPitch);
    return (pitch >= minPitch && pitch <= maxPitch);
}

/**
 * Scale column in place so that its elements sum to 1.
 *
 * A column whose elements sum to exactly zero (for example an
 * all-zero input) is left unchanged: dividing by the zero sum would
 * produce NaN or infinity, which would then propagate through m_q
 * into every subsequent update of the model.
 */
void
EM::normaliseColumn(V &column)
{
    const int size = static_cast<int>(column.size());

    double sum = 0.0;
    for (int i = 0; i < size; ++i) {
        sum += column[i];
    }

    if (sum == 0.0) {
        // Nothing sensible to scale by; see function comment.
        return;
    }

    for (int i = 0; i < size; ++i) {
        column[i] /= sum;
    }
}

/**
 * Scale grid in place so that, for each second index j, the values
 * grid[i][j] summed over all rows i equal 1.
 *
 * An empty grid is a no-op, and any j whose sum is exactly zero is
 * left unscaled rather than divided by zero.
 *
 * NOTE(review): every row is assumed to be no longer than grid[0],
 * as the denominators are sized from the first row.
 */
void
EM::normaliseGrid(Grid &grid)
{
    if (grid.empty()) {
        return;
    }

    const int rows = static_cast<int>(grid.size());

    // Relies on V's sized constructor zero-initialising its elements
    // (true if V is a std::vector<double>) -- confirm against EM.h.
    V denominators(grid[0].size());

    for (int i = 0; i < rows; ++i) {
        const int cols = static_cast<int>(grid[i].size());
        for (int j = 0; j < cols; ++j) {
            denominators[j] += grid[i][j];
        }
    }

    for (int i = 0; i < rows; ++i) {
        const int cols = static_cast<int>(grid[i].size());
        for (int j = 0; j < cols; ++j) {
            if (denominators[j] != 0.0) {
                grid[i][j] /= denominators[j];
            }
        }
    }
}

/**
 * Run one EM iteration against the given input column.
 *
 * The column is taken by value because it is normalised before use;
 * the caller's data is not modified.
 */
void
EM::iterate(V column)
{
    normaliseColumn(column);
    expectation(column);
    maximisation(column);
}

/**
 * Return a pointer to the first bin of the template for the given
 * instrument and note at the given shift index.
 *
 * With shifts enabled the template row is offset by shift; otherwise
 * the shift argument is ignored and the row is offset by
 * SILVET_TEMPLATE_MAX_SHIFT.
 *
 * NOTE(review): this presumes each template row holds at least
 * m_binCount values beyond the largest offset used here -- confirm
 * against the layout in templates.h.
 */
const float *
EM::templateFor(int instrument, int note, int shift)
{
    if (m_useShifts) {
        return silvet_templates[instrument].data[note] + shift;
    } else {
        return silvet_templates[instrument].data[note] +
            SILVET_TEMPLATE_MAX_SHIFT;
    }
}

/**
 * Expectation step.
 *
 * Rebuilds m_estimate as epsilon plus the sum, over every instrument,
 * note and shift, of the corresponding template weighted by the
 * current pitch, source and shift values; then sets m_q to the
 * element-wise ratio of column to m_estimate.
 *
 * @param column input column; its first m_binCount elements are read
 */
void
EM::expectation(const V &column)
{
//    cerr << ".";

    for (int i = 0; i < m_binCount; ++i) {
        m_estimate[i] = epsilon;
    }

    for (int i = 0; i < m_instrumentCount; ++i) {
        for (int n = 0; n < m_noteCount; ++n) {
            for (int f = 0; f < m_shiftCount; ++f) {

                const float *w = templateFor(i, n, f);
                double pitch = m_pitches[n];
                double source = m_sources[i][n];
                double shift = m_shifts[f][n];

                for (int j = 0; j < m_binCount; ++j) {
                    m_estimate[j] += w[j] * pitch * source * shift;
                }
            }
        }
    }

    for (int i = 0; i < m_binCount; ++i) {
        m_q[i] = column[i] / m_estimate[i];
    }
}

/**
 * Maximisation step.
 *
 * Computes new pitch, shift and source values from the ratio m_q left
 * by expectation() and from the *current* m_pitches, m_shifts and
 * m_sources. All three updates read the old values; the new ones are
 * only committed at the very end.
 *
 * @param column unused here; the input enters only through m_q
 */
void
EM::maximisation(const V &column)
{
    // Pitches: accumulate over instruments and shifts for notes in
    // the overall template range, apply the sparsity exponent, then
    // normalise to sum to 1.
    V newPitches = m_pitches;

    for (int n = 0; n < m_noteCount; ++n) {
        newPitches[n] = epsilon;
        if (n >= m_lowestPitch && n <= m_highestPitch) {
            for (int i = 0; i < m_instrumentCount; ++i) {
                for (int f = 0; f < m_shiftCount; ++f) {
                    const float *w = templateFor(i, n, f);
                    double pitch = m_pitches[n];
                    double source = m_sources[i][n];
                    double shift = m_shifts[f][n];
                    for (int j = 0; j < m_binCount; ++j) {
                        newPitches[n] += w[j] * m_q[j] * pitch * source * shift;
                    }
                }
            }
        }
        if (m_pitchSparsity != 1.0) {
            newPitches[n] = pow(newPitches[n], m_pitchSparsity);
        }
    }
    normaliseColumn(newPitches);

    // Shifts: only re-estimated when shifts are in use; otherwise the
    // existing values are carried over unchanged.
    Grid newShifts = m_shifts;

    if (m_useShifts) {
        for (int f = 0; f < m_shiftCount; ++f) {
            for (int n = 0; n < m_noteCount; ++n) {
                newShifts[f][n] = epsilon;
                for (int i = 0; i < m_instrumentCount; ++i) {
                    const float *w = templateFor(i, n, f);
                    double pitch = m_pitches[n];
                    double source = m_sources[i][n];
                    double shift = m_shifts[f][n];
                    for (int j = 0; j < m_binCount; ++j) {
                        newShifts[f][n] += w[j] * m_q[j] * pitch * source * shift;
                    }
                }
            }
        }
        normaliseGrid(newShifts);
    }

    // Sources: accumulate over shifts for notes within each
    // instrument's own range, apply the sparsity exponent, then
    // normalise across instruments for each note.
    Grid newSources = m_sources;

    for (int i = 0; i < m_instrumentCount; ++i) {
        for (int n = 0; n < m_noteCount; ++n) {
            newSources[i][n] = epsilon;
            if (inRange(i, n)) {
                for (int f = 0; f < m_shiftCount; ++f) {
                    const float *w = templateFor(i, n, f);
                    double pitch = m_pitches[n];
                    double source = m_sources[i][n];
                    double shift = m_shifts[f][n];
                    for (int j = 0; j < m_binCount; ++j) {
                        newSources[i][n] += w[j] * m_q[j] * pitch * source * shift;
                    }
                }
            }
            if (m_sourceSparsity != 1.0) {
                newSources[i][n] = pow(newSources[i][n], m_sourceSparsity);
            }
        }
    }
    normaliseGrid(newSources);

    // Commit all three updates together.
    m_pitches = newPitches;
    m_shifts = newShifts;
    m_sources = newSources;
}