annotate src/EM.cpp @ 116:91bb029a847a timing

Reorder the calculations to match the series of vector operations in the most recent bqvec code, just in case it's the order of vector calculations that is saving the time rather than the avoidance of std::vector
author Chris Cannam
date Wed, 07 May 2014 09:57:19 +0100
parents a6e136aaa202
children
rev   line source
Chris@34 1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
Chris@34 2
Chris@34 3 /*
Chris@34 4 Silvet
Chris@34 5
Chris@34 6 A Vamp plugin for note transcription.
Chris@34 7 Centre for Digital Music, Queen Mary University of London.
Chris@34 8
Chris@34 9 This program is free software; you can redistribute it and/or
Chris@34 10 modify it under the terms of the GNU General Public License as
Chris@34 11 published by the Free Software Foundation; either version 2 of the
Chris@34 12 License, or (at your option) any later version. See the file
Chris@34 13 COPYING included with this distribution for more information.
Chris@34 14 */
Chris@34 15
Chris@34 16 #include "EM.h"
Chris@34 17
Chris@34 18 #include "data/include/templates.h"
Chris@34 19
Chris@36 20 #include <cstdlib>
Chris@42 21 #include <cmath>
Chris@36 22
Chris@36 23 #include <iostream>
Chris@36 24
Chris@36 25 #include <vector>
Chris@36 26
Chris@36 27 using std::vector;
Chris@36 28 using std::cerr;
Chris@36 29 using std::endl;
Chris@36 30
// Tiny positive constant used as a floor in this file: expectation() seeds
// every bin of the estimate with it before accumulating, and maximisation()
// uses it as the starting value of every accumulator.  It is never written
// to, so it is declared const.
static const double epsilon = 1e-16;
Chris@35 32
Chris@35 33 EM::EM() :
Chris@45 34 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT),
Chris@45 35 m_shiftCount(SILVET_TEMPLATE_MAX_SHIFT * 2 + 1),
Chris@45 36 m_binCount(SILVET_TEMPLATE_HEIGHT),
Chris@45 37 m_instrumentCount(SILVET_TEMPLATE_COUNT),
Chris@42 38 m_pitchSparsity(1.1),
Chris@83 39 m_sourceSparsity(1.3),
Chris@83 40 m_lowestPitch(silvet_templates_lowest_note),
Chris@83 41 m_highestPitch(silvet_templates_highest_note)
Chris@35 42 {
Chris@55 43 m_pitches = V(m_noteCount);
Chris@55 44 for (int n = 0; n < m_noteCount; ++n) {
Chris@55 45 m_pitches[n] = drand48();
Chris@55 46 }
Chris@35 47
Chris@55 48 m_shifts = Grid(m_shiftCount);
Chris@55 49 for (int f = 0; f < m_shiftCount; ++f) {
Chris@55 50 m_shifts[f] = V(m_noteCount);
Chris@55 51 for (int n = 0; n < m_noteCount; ++n) {
Chris@55 52 m_shifts[f][n] = drand48();
Chris@55 53 }
Chris@35 54 }
Chris@35 55
Chris@45 56 m_sources = Grid(m_instrumentCount);
Chris@55 57 for (int i = 0; i < m_instrumentCount; ++i) {
Chris@55 58 m_sources[i] = V(m_noteCount);
Chris@55 59 for (int n = 0; n < m_noteCount; ++n) {
Chris@35 60 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0);
Chris@35 61 }
Chris@35 62 }
Chris@35 63
Chris@45 64 m_estimate = V(m_binCount);
Chris@45 65 m_q = V(m_binCount);
Chris@35 66 }
Chris@35 67
Chris@35 68 EM::~EM()
Chris@35 69 {
Chris@35 70 }
Chris@35 71
Chris@45 72 void
Chris@45 73 EM::rangeFor(int instrument, int &minPitch, int &maxPitch)
Chris@45 74 {
Chris@55 75 minPitch = silvet_templates[instrument].lowest;
Chris@55 76 maxPitch = silvet_templates[instrument].highest;
Chris@45 77 }
Chris@45 78
Chris@35 79 bool
Chris@45 80 EM::inRange(int instrument, int pitch)
Chris@35 81 {
Chris@45 82 int minPitch, maxPitch;
Chris@45 83 rangeFor(instrument, minPitch, maxPitch);
Chris@45 84 return (pitch >= minPitch && pitch <= maxPitch);
Chris@35 85 }
Chris@35 86
Chris@36 87 void
Chris@55 88 EM::normaliseColumn(V &column)
Chris@36 89 {
Chris@36 90 double sum = 0.0;
Chris@36 91 for (int i = 0; i < (int)column.size(); ++i) {
Chris@36 92 sum += column[i];
Chris@36 93 }
Chris@36 94 for (int i = 0; i < (int)column.size(); ++i) {
Chris@36 95 column[i] /= sum;
Chris@36 96 }
Chris@36 97 }
Chris@36 98
Chris@36 99 void
Chris@55 100 EM::normaliseGrid(Grid &grid)
Chris@53 101 {
Chris@55 102 V denominators(grid[0].size());
Chris@53 103
Chris@55 104 for (int i = 0; i < (int)grid.size(); ++i) {
Chris@55 105 for (int j = 0; j < (int)grid[i].size(); ++j) {
Chris@55 106 denominators[j] += grid[i][j];
Chris@53 107 }
Chris@53 108 }
Chris@53 109
Chris@55 110 for (int i = 0; i < (int)grid.size(); ++i) {
Chris@55 111 for (int j = 0; j < (int)grid[i].size(); ++j) {
Chris@55 112 grid[i][j] /= denominators[j];
Chris@53 113 }
Chris@53 114 }
Chris@53 115 }
Chris@53 116
Chris@53 117 void
Chris@36 118 EM::iterate(V column)
Chris@36 119 {
Chris@55 120 normaliseColumn(column);
Chris@36 121 expectation(column);
Chris@36 122 maximisation(column);
Chris@36 123 }
Chris@36 124
Chris@88 125 const double *
Chris@55 126 EM::templateFor(int instrument, int note, int shift)
Chris@45 127 {
Chris@45 128 return silvet_templates[instrument].data[note] + shift;
Chris@45 129 }
Chris@45 130
Chris@36 131 void
Chris@36 132 EM::expectation(const V &column)
Chris@36 133 {
Chris@62 134 // cerr << ".";
Chris@36 135
Chris@45 136 for (int i = 0; i < m_binCount; ++i) {
Chris@36 137 m_estimate[i] = epsilon;
Chris@36 138 }
Chris@36 139
Chris@45 140 for (int i = 0; i < m_instrumentCount; ++i) {
Chris@55 141 for (int n = 0; n < m_noteCount; ++n) {
Chris@83 142 const double pitch = m_pitches[n];
Chris@83 143 const double source = m_sources[i][n];
Chris@55 144 for (int f = 0; f < m_shiftCount; ++f) {
Chris@88 145 const double *w = templateFor(i, n, f);
Chris@83 146 const double shift = m_shifts[f][n];
Chris@83 147 const double factor = pitch * source * shift;
Chris@55 148 for (int j = 0; j < m_binCount; ++j) {
Chris@83 149 m_estimate[j] += w[j] * factor;
Chris@55 150 }
Chris@36 151 }
Chris@36 152 }
Chris@36 153 }
Chris@36 154
Chris@45 155 for (int i = 0; i < m_binCount; ++i) {
Chris@36 156 m_q[i] = column[i] / m_estimate[i];
Chris@36 157 }
Chris@36 158 }
Chris@36 159
Chris@36 160 void
Chris@36 161 EM::maximisation(const V &column)
Chris@36 162 {
Chris@89 163 V newPitches(m_noteCount, epsilon);
Chris@89 164 Grid newShifts(m_shiftCount, V(m_noteCount, epsilon));
Chris@89 165 Grid newSources(m_instrumentCount, V(m_noteCount, epsilon));
Chris@36 166
Chris@116 167 V contributions(m_binCount);
Chris@116 168
Chris@55 169 for (int n = 0; n < m_noteCount; ++n) {
Chris@85 170
Chris@85 171 const double pitch = m_pitches[n];
Chris@85 172
Chris@85 173 for (int f = 0; f < m_shiftCount; ++f) {
Chris@85 174
Chris@85 175 const double shift = m_shifts[f][n];
Chris@85 176
Chris@45 177 for (int i = 0; i < m_instrumentCount; ++i) {
Chris@85 178
Chris@83 179 const double source = m_sources[i][n];
Chris@89 180 const double factor = pitch * source * shift;
Chris@88 181 const double *w = templateFor(i, n, f);
Chris@85 182
Chris@116 183 for (int j = 0; j < m_binCount; ++j) {
Chris@116 184 contributions[j] = w[j];
Chris@116 185 }
Chris@116 186 for (int j = 0; j < m_binCount; ++j) {
Chris@116 187 contributions[j] *= m_q.at(j);
Chris@116 188 }
Chris@116 189 for (int j = 0; j < m_binCount; ++j) {
Chris@116 190 contributions[j] *= factor;
Chris@116 191 }
Chris@116 192
Chris@116 193 double total = 0.0;
Chris@116 194 for (int j = 0; j < m_binCount; ++j) {
Chris@116 195 total += contributions.at(j);
Chris@116 196 }
Chris@116 197
Chris@86 198 if (n >= m_lowestPitch && n <= m_highestPitch) {
Chris@85 199
Chris@116 200 newPitches[n] += total;
Chris@116 201
Chris@85 202 if (inRange(i, n)) {
Chris@116 203 newSources[i][n] += total;
Chris@55 204 }
Chris@36 205 }
Chris@86 206
Chris@116 207 newShifts[f][n] += total;
Chris@36 208 }
Chris@36 209 }
Chris@85 210 }
Chris@85 211
Chris@85 212 for (int n = 0; n < m_noteCount; ++n) {
Chris@42 213 if (m_pitchSparsity != 1.0) {
Chris@42 214 newPitches[n] = pow(newPitches[n], m_pitchSparsity);
Chris@42 215 }
Chris@85 216 if (m_sourceSparsity != 1.0) {
Chris@55 217 for (int i = 0; i < m_instrumentCount; ++i) {
Chris@42 218 newSources[i][n] = pow(newSources[i][n], m_sourceSparsity);
Chris@42 219 }
Chris@36 220 }
Chris@36 221 }
Chris@85 222
Chris@85 223 normaliseColumn(newPitches);
Chris@85 224 normaliseGrid(newShifts);
Chris@55 225 normaliseGrid(newSources);
Chris@36 226
Chris@36 227 m_pitches = newPitches;
Chris@55 228 m_shifts = newShifts;
Chris@36 229 m_sources = newSources;
Chris@36 230 }
Chris@36 231
Chris@36 232