annotate src/EM.cpp @ 167:416b555df3b2 finetune

More on returning fine tuning (but we're treating different shifts of the same pitch as different notes at the moment which is not right)
author Chris Cannam
date Tue, 20 May 2014 17:49:07 +0100
parents 629c9525b815
children 237d41a0f69d
rev   line source
Chris@34 1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
Chris@34 2
Chris@34 3 /*
Chris@34 4 Silvet
Chris@34 5
Chris@34 6 A Vamp plugin for note transcription.
Chris@34 7 Centre for Digital Music, Queen Mary University of London.
Chris@34 8
Chris@34 9 This program is free software; you can redistribute it and/or
Chris@34 10 modify it under the terms of the GNU General Public License as
Chris@34 11 published by the Free Software Foundation; either version 2 of the
Chris@34 12 License, or (at your option) any later version. See the file
Chris@34 13 COPYING included with this distribution for more information.
Chris@34 14 */
Chris@34 15
Chris@34 16 #include "EM.h"
Chris@34 17
Chris@36 18 #include <cstdlib>
Chris@42 19 #include <cmath>
Chris@36 20
Chris@36 21 #include <iostream>
Chris@36 22
Chris@91 23 #include "VectorOps.h"
Chris@91 24 #include "Allocators.h"
Chris@161 25 #include "Instruments.h"
Chris@36 26
Chris@36 27 using std::vector;
Chris@36 28 using std::cerr;
Chris@36 29 using std::endl;
Chris@36 30
Chris@91 31 using namespace breakfastquay;
Chris@91 32
// Floor value written into m_estimate and into every update buffer before
// accumulation, so that later divisions by those values (m_q in
// expectation(), the normalisations in maximisation()) never divide by an
// exact zero. It is never modified, so it is const.
static const float epsilon = 1e-10f;
Chris@35 34
Chris@161 35 EM::EM(const InstrumentPack *pack, bool useShifts) :
Chris@161 36 m_pack(pack),
Chris@161 37 m_noteCount(pack->templateNoteCount),
Chris@161 38 m_shiftCount(useShifts ? pack->templateMaxShift * 2 + 1 : 1),
Chris@161 39 m_binCount(pack->templateHeight),
Chris@161 40 m_sourceCount(pack->templates.size()),
Chris@42 41 m_pitchSparsity(1.1),
Chris@150 42 //!!! note: slightly less source sparsity might help; also
Chris@150 43 //!!! consider a modest shift sparsity e.g. 1.1
Chris@83 44 m_sourceSparsity(1.3),
Chris@161 45 m_lowestPitch(pack->lowestNote),
Chris@161 46 m_highestPitch(pack->highestNote)
Chris@35 47 {
Chris@151 48 m_pitches = allocate<float>(m_noteCount);
Chris@151 49 m_updatePitches = allocate<float>(m_noteCount);
Chris@55 50 for (int n = 0; n < m_noteCount; ++n) {
Chris@55 51 m_pitches[n] = drand48();
Chris@55 52 }
Chris@35 53
Chris@113 54 if (useShifts) {
Chris@151 55 m_shifts = allocate_channels<float>(m_shiftCount, m_noteCount);
Chris@151 56 m_updateShifts = allocate_channels<float>(m_shiftCount, m_noteCount);
Chris@113 57 for (int f = 0; f < m_shiftCount; ++f) {
Chris@113 58 for (int n = 0; n < m_noteCount; ++n) {
Chris@110 59 m_shifts[f][n] = drand48();
Chris@110 60 }
Chris@55 61 }
Chris@113 62 } else {
Chris@113 63 m_shifts = 0;
Chris@113 64 m_updateShifts = 0;
Chris@35 65 }
Chris@35 66
Chris@151 67 m_sources = allocate_channels<float>(m_sourceCount, m_noteCount);
Chris@151 68 m_updateSources = allocate_channels<float>(m_sourceCount, m_noteCount);
Chris@91 69 for (int i = 0; i < m_sourceCount; ++i) {
Chris@55 70 for (int n = 0; n < m_noteCount; ++n) {
Chris@35 71 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0);
Chris@35 72 }
Chris@35 73 }
Chris@35 74
Chris@151 75 m_estimate = allocate<float>(m_binCount);
Chris@151 76 m_q = allocate<float>(m_binCount);
Chris@35 77 }
Chris@35 78
Chris@35 79 EM::~EM()
Chris@35 80 {
Chris@92 81 deallocate(m_q);
Chris@92 82 deallocate(m_estimate);
Chris@92 83 deallocate_channels(m_sources, m_sourceCount);
Chris@100 84 deallocate_channels(m_updateSources, m_sourceCount);
Chris@92 85 deallocate_channels(m_shifts, m_shiftCount);
Chris@100 86 deallocate_channels(m_updateShifts, m_shiftCount);
Chris@92 87 deallocate(m_pitches);
Chris@100 88 deallocate(m_updatePitches);
Chris@35 89 }
Chris@35 90
Chris@45 91 void
Chris@45 92 EM::rangeFor(int instrument, int &minPitch, int &maxPitch)
Chris@45 93 {
Chris@161 94 minPitch = m_pack->templates[instrument].lowestNote;
Chris@161 95 maxPitch = m_pack->templates[instrument].highestNote;
Chris@45 96 }
Chris@45 97
Chris@35 98 bool
Chris@45 99 EM::inRange(int instrument, int pitch)
Chris@35 100 {
Chris@45 101 int minPitch, maxPitch;
Chris@45 102 rangeFor(instrument, minPitch, maxPitch);
Chris@45 103 return (pitch >= minPitch && pitch <= maxPitch);
Chris@35 104 }
Chris@35 105
Chris@36 106 void
Chris@151 107 EM::normaliseColumn(float *column, int size)
Chris@36 108 {
Chris@151 109 float sum = v_sum(column, size);
Chris@92 110 v_scale(column, 1.0 / sum, size);
Chris@36 111 }
Chris@36 112
Chris@36 113 void
Chris@151 114 EM::normaliseGrid(float **grid, int size1, int size2)
Chris@53 115 {
Chris@151 116 float *denominators = allocate_and_zero<float>(size2);
Chris@122 117
Chris@92 118 for (int i = 0; i < size1; ++i) {
Chris@122 119 for (int j = 0; j < size2; ++j) {
Chris@122 120 denominators[j] += grid[i][j];
Chris@122 121 }
Chris@53 122 }
Chris@122 123
Chris@122 124 for (int i = 0; i < size1; ++i) {
Chris@122 125 v_divide(grid[i], denominators, size2);
Chris@122 126 }
Chris@122 127
Chris@122 128 deallocate(denominators);
Chris@53 129 }
Chris@53 130
Chris@53 131 void
Chris@92 132 EM::iterate(const double *column)
Chris@36 133 {
Chris@151 134 float *norm = allocate<float>(m_binCount);
Chris@151 135 v_convert(norm, column, m_binCount);
Chris@92 136 normaliseColumn(norm, m_binCount);
Chris@92 137 expectation(norm);
Chris@92 138 maximisation(norm);
Chris@95 139 deallocate(norm);
Chris@36 140 }
Chris@36 141
Chris@151 142 const float *
Chris@55 143 EM::templateFor(int instrument, int note, int shift)
Chris@45 144 {
Chris@161 145 const float *base = m_pack->templates.at(instrument).data.at(note).data();
Chris@113 146 if (m_shifts) {
Chris@161 147 return base + shift;
Chris@110 148 } else {
Chris@161 149 return base + m_pack->templateMaxShift;
Chris@110 150 }
Chris@45 151 }
Chris@45 152
Chris@36 153 void
Chris@151 154 EM::expectation(const float *column)
Chris@36 155 {
Chris@62 156 // cerr << ".";
Chris@36 157
Chris@99 158 v_set(m_estimate, epsilon, m_binCount);
Chris@36 159
Chris@130 160 for (int f = 0; f < m_shiftCount; ++f) {
Chris@130 161
Chris@151 162 const float *shiftIn = m_shifts ? m_shifts[f] : 0;
Chris@130 163
Chris@130 164 for (int i = 0; i < m_sourceCount; ++i) {
Chris@130 165
Chris@151 166 const float *sourceIn = m_sources[i];
Chris@130 167
Chris@130 168 int lowest, highest;
Chris@130 169 rangeFor(i, lowest, highest);
Chris@130 170
Chris@130 171 for (int n = lowest; n <= highest; ++n) {
Chris@130 172
Chris@151 173 const float source = sourceIn[n];
Chris@151 174 const float shift = shiftIn ? shiftIn[n] : 1.0;
Chris@151 175 const float pitch = m_pitches[n];
Chris@130 176
Chris@151 177 const float factor = pitch * source * shift;
Chris@151 178 const float *w = templateFor(i, n, f);
Chris@130 179
Chris@111 180 v_add_with_gain(m_estimate, w, factor, m_binCount);
Chris@36 181 }
Chris@36 182 }
Chris@36 183 }
Chris@36 184
Chris@45 185 for (int i = 0; i < m_binCount; ++i) {
Chris@36 186 m_q[i] = column[i] / m_estimate[i];
Chris@36 187 }
Chris@164 188
Chris@164 189 /*
Chris@164 190 double l2norm = 0.0;
Chris@164 191
Chris@164 192 for (int i = 0; i < m_binCount; ++i) {
Chris@164 193 l2norm += (column[i] - m_estimate[i]) * (column[i] - m_estimate[i]);
Chris@164 194 }
Chris@164 195
Chris@164 196 l2norm = sqrt(l2norm);
Chris@164 197 cerr << "l2norm = " << l2norm << endl;
Chris@164 198 */
Chris@36 199 }
Chris@36 200
Chris@36 201 void
Chris@151 202 EM::maximisation(const float *column)
Chris@36 203 {
Chris@100 204 v_set(m_updatePitches, epsilon, m_noteCount);
Chris@113 205
Chris@92 206 for (int i = 0; i < m_sourceCount; ++i) {
Chris@100 207 v_set(m_updateSources[i], epsilon, m_noteCount);
Chris@92 208 }
Chris@62 209
Chris@113 210 if (m_shifts) {
Chris@113 211 for (int i = 0; i < m_shiftCount; ++i) {
Chris@113 212 v_set(m_updateShifts[i], epsilon, m_noteCount);
Chris@113 213 }
Chris@113 214 }
Chris@113 215
Chris@151 216 float *contributions = allocate<float>(m_binCount);
Chris@36 217
Chris@130 218 for (int f = 0; f < m_shiftCount; ++f) {
Chris@85 219
Chris@151 220 const float *shiftIn = m_shifts ? m_shifts[f] : 0;
Chris@151 221 float *shiftOut = m_shifts ? m_updateShifts[f] : 0;
Chris@85 222
Chris@130 223 for (int i = 0; i < m_sourceCount; ++i) {
Chris@85 224
Chris@151 225 const float *sourceIn = m_sources[i];
Chris@151 226 float *sourceOut = m_updateSources[i];
Chris@85 227
Chris@130 228 int lowest, highest;
Chris@130 229 rangeFor(i, lowest, highest);
Chris@85 230
Chris@130 231 for (int n = lowest; n <= highest; ++n) {
Chris@130 232
Chris@151 233 const float shift = shiftIn ? shiftIn[n] : 1.0;
Chris@151 234 const float source = sourceIn[n];
Chris@151 235 const float pitch = m_pitches[n];
Chris@130 236
Chris@151 237 const float factor = pitch * source * shift;
Chris@151 238 const float *w = templateFor(i, n, f);
Chris@85 239
Chris@94 240 v_copy(contributions, w, m_binCount);
Chris@95 241 v_multiply(contributions, m_q, m_binCount);
Chris@94 242
Chris@151 243 float total = factor * v_sum(contributions, m_binCount);
Chris@94 244
Chris@130 245 m_updatePitches[n] += total;
Chris@130 246 sourceOut[n] += total;
Chris@85 247
Chris@130 248 if (shiftOut) {
Chris@130 249 shiftOut[n] += total;
Chris@113 250 }
Chris@42 251 }
Chris@36 252 }
Chris@36 253 }
Chris@36 254
Chris@103 255 if (m_pitchSparsity != 1.0) {
Chris@103 256 for (int n = 0; n < m_noteCount; ++n) {
Chris@103 257 m_updatePitches[n] =
Chris@103 258 pow(m_updatePitches[n], m_pitchSparsity);
Chris@62 259 }
Chris@103 260 }
Chris@103 261
Chris@103 262 if (m_sourceSparsity != 1.0) {
Chris@130 263 for (int i = 0; i < m_sourceCount; ++i) {
Chris@130 264 for (int n = 0; n < m_noteCount; ++n) {
Chris@103 265 m_updateSources[i][n] =
Chris@103 266 pow(m_updateSources[i][n], m_sourceSparsity);
Chris@62 267 }
Chris@62 268 }
Chris@62 269 }
Chris@85 270
Chris@100 271 normaliseColumn(m_updatePitches, m_noteCount);
Chris@112 272 std::swap(m_pitches, m_updatePitches);
Chris@112 273
Chris@113 274 normaliseGrid(m_updateSources, m_sourceCount, m_noteCount);
Chris@113 275 std::swap(m_sources, m_updateSources);
Chris@113 276
Chris@113 277 if (m_shifts) {
Chris@112 278 normaliseGrid(m_updateShifts, m_shiftCount, m_noteCount);
Chris@112 279 std::swap(m_shifts, m_updateShifts);
Chris@112 280 }
Chris@36 281 }
Chris@36 282
Chris@36 283