annotate src/EM.cpp @ 115:fbf9b824aaf3 bqvec-openmp

Report on last couple of tests
author Chris Cannam
date Wed, 07 May 2014 09:48:56 +0100
parents b2f0967cb8d1
children
rev   line source
Chris@34 1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
Chris@34 2
Chris@34 3 /*
Chris@34 4 Silvet
Chris@34 5
Chris@34 6 A Vamp plugin for note transcription.
Chris@34 7 Centre for Digital Music, Queen Mary University of London.
Chris@34 8
Chris@34 9 This program is free software; you can redistribute it and/or
Chris@34 10 modify it under the terms of the GNU General Public License as
Chris@34 11 published by the Free Software Foundation; either version 2 of the
Chris@34 12 License, or (at your option) any later version. See the file
Chris@34 13 COPYING included with this distribution for more information.
Chris@34 14 */
Chris@34 15
Chris@34 16 #include "EM.h"
Chris@34 17
Chris@34 18 #include "data/include/templates.h"
Chris@34 19
Chris@36 20 #include <cstdlib>
Chris@42 21 #include <cmath>
Chris@36 22
Chris@36 23 #include <iostream>
Chris@36 24
Chris@91 25 #include "VectorOps.h"
Chris@91 26 #include "Allocators.h"
Chris@36 27
Chris@36 28 using std::vector;
Chris@36 29 using std::cerr;
Chris@36 30 using std::endl;
Chris@36 31
Chris@91 32 using namespace breakfastquay;
Chris@91 33
Chris@35 34 static double epsilon = 1e-16;
Chris@35 35
Chris@114 36 bool EM::m_initialised = false;
Chris@114 37 double ****EM::m_templates = 0;
Chris@114 38
Chris@110 39 EM::EM(bool useShifts) :
Chris@45 40 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT),
Chris@110 41 m_shiftCount(useShifts ? SILVET_TEMPLATE_MAX_SHIFT * 2 + 1 : 1),
Chris@45 42 m_binCount(SILVET_TEMPLATE_HEIGHT),
Chris@91 43 m_sourceCount(SILVET_TEMPLATE_COUNT),
Chris@42 44 m_pitchSparsity(1.1),
Chris@83 45 m_sourceSparsity(1.3),
Chris@83 46 m_lowestPitch(silvet_templates_lowest_note),
Chris@83 47 m_highestPitch(silvet_templates_highest_note)
Chris@35 48 {
Chris@114 49 if (!m_initialised) {
Chris@114 50 cerr << "ERROR: You must call EM::initialise() before constructing any EM objects" << endl;
Chris@114 51 abort();
Chris@114 52 }
Chris@114 53
Chris@91 54 m_pitches = allocate<double>(m_noteCount);
Chris@100 55 m_updatePitches = allocate<double>(m_noteCount);
Chris@55 56 for (int n = 0; n < m_noteCount; ++n) {
Chris@55 57 m_pitches[n] = drand48();
Chris@55 58 }
Chris@114 59
Chris@114 60 m_sources = allocate_channels<double>(m_sourceCount, m_noteCount);
Chris@114 61 m_updateSources = allocate_channels<double>(m_sourceCount, m_noteCount);
Chris@114 62 for (int i = 0; i < m_sourceCount; ++i) {
Chris@114 63 for (int n = 0; n < m_noteCount; ++n) {
Chris@114 64 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0);
Chris@114 65 }
Chris@114 66 }
Chris@35 67
Chris@113 68 if (useShifts) {
Chris@113 69 m_shifts = allocate_channels<double>(m_shiftCount, m_noteCount);
Chris@113 70 m_updateShifts = allocate_channels<double>(m_shiftCount, m_noteCount);
Chris@113 71 for (int f = 0; f < m_shiftCount; ++f) {
Chris@113 72 for (int n = 0; n < m_noteCount; ++n) {
Chris@110 73 m_shifts[f][n] = drand48();
Chris@110 74 }
Chris@55 75 }
Chris@113 76 } else {
Chris@113 77 m_shifts = 0;
Chris@113 78 m_updateShifts = 0;
Chris@35 79 }
Chris@35 80
Chris@91 81 m_estimate = allocate<double>(m_binCount);
Chris@91 82 m_q = allocate<double>(m_binCount);
Chris@35 83 }
Chris@35 84
Chris@35 85 EM::~EM()
Chris@35 86 {
Chris@92 87 deallocate(m_q);
Chris@92 88 deallocate(m_estimate);
Chris@92 89 deallocate_channels(m_sources, m_sourceCount);
Chris@100 90 deallocate_channels(m_updateSources, m_sourceCount);
Chris@92 91 deallocate_channels(m_shifts, m_shiftCount);
Chris@100 92 deallocate_channels(m_updateShifts, m_shiftCount);
Chris@92 93 deallocate(m_pitches);
Chris@100 94 deallocate(m_updatePitches);
Chris@35 95 }
Chris@35 96
Chris@45 97 void
Chris@114 98 EM::initialise()
Chris@114 99 {
Chris@114 100 //!!! need mutex
Chris@114 101
Chris@114 102 if (m_initialised) return;
Chris@114 103 m_templates = new double ***[SILVET_TEMPLATE_COUNT];
Chris@114 104 for (int i = 0; i < SILVET_TEMPLATE_COUNT; ++i) {
Chris@114 105 m_templates[i] = new double **[SILVET_TEMPLATE_NOTE_COUNT];
Chris@114 106 for (int n = 0; n < SILVET_TEMPLATE_NOTE_COUNT; ++n) {
Chris@114 107 m_templates[i][n] = new double *[SILVET_TEMPLATE_MAX_SHIFT * 2 + 1];
Chris@114 108 for (int f = 0; f < SILVET_TEMPLATE_MAX_SHIFT * 2 + 1; ++f) {
Chris@114 109 m_templates[i][n][f] = allocate<double>(SILVET_TEMPLATE_HEIGHT);
Chris@114 110 const float *t = silvet_templates[i].data[n] + f;
Chris@114 111 v_convert(m_templates[i][n][f], t, SILVET_TEMPLATE_HEIGHT);
Chris@114 112 }
Chris@114 113 }
Chris@114 114 }
Chris@114 115 m_initialised = true;
Chris@114 116 }
Chris@114 117
Chris@114 118 void
Chris@45 119 EM::rangeFor(int instrument, int &minPitch, int &maxPitch)
Chris@45 120 {
Chris@55 121 minPitch = silvet_templates[instrument].lowest;
Chris@55 122 maxPitch = silvet_templates[instrument].highest;
Chris@45 123 }
Chris@45 124
Chris@35 125 bool
Chris@45 126 EM::inRange(int instrument, int pitch)
Chris@35 127 {
Chris@45 128 int minPitch, maxPitch;
Chris@45 129 rangeFor(instrument, minPitch, maxPitch);
Chris@45 130 return (pitch >= minPitch && pitch <= maxPitch);
Chris@35 131 }
Chris@35 132
Chris@36 133 void
Chris@92 134 EM::normaliseColumn(double *column, int size)
Chris@36 135 {
Chris@92 136 double sum = v_sum(column, size);
Chris@92 137 v_scale(column, 1.0 / sum, size);
Chris@36 138 }
Chris@36 139
Chris@36 140 void
Chris@92 141 EM::normaliseGrid(double **grid, int size1, int size2)
Chris@53 142 {
Chris@92 143 double *denominators = allocate_and_zero<double>(size2);
Chris@53 144
Chris@92 145 for (int i = 0; i < size1; ++i) {
Chris@92 146 for (int j = 0; j < size2; ++j) {
Chris@55 147 denominators[j] += grid[i][j];
Chris@53 148 }
Chris@53 149 }
Chris@53 150
Chris@92 151 for (int i = 0; i < size1; ++i) {
Chris@92 152 v_divide(grid[i], denominators, size2);
Chris@53 153 }
Chris@92 154
Chris@92 155 deallocate(denominators);
Chris@53 156 }
Chris@53 157
Chris@53 158 void
Chris@92 159 EM::iterate(const double *column)
Chris@36 160 {
Chris@92 161 double *norm = allocate<double>(m_binCount);
Chris@92 162 v_copy(norm, column, m_binCount);
Chris@92 163 normaliseColumn(norm, m_binCount);
Chris@92 164 expectation(norm);
Chris@92 165 maximisation(norm);
Chris@95 166 deallocate(norm);
Chris@36 167 }
Chris@36 168
Chris@88 169 const double *
Chris@55 170 EM::templateFor(int instrument, int note, int shift)
Chris@45 171 {
Chris@113 172 if (m_shifts) {
Chris@114 173 return m_templates[instrument][note][shift];
Chris@110 174 } else {
Chris@114 175 return m_templates[instrument][note][SILVET_TEMPLATE_MAX_SHIFT];
Chris@110 176 }
Chris@45 177 }
Chris@45 178
Chris@36 179 void
Chris@92 180 EM::expectation(const double *column)
Chris@36 181 {
Chris@62 182 // cerr << ".";
Chris@36 183
Chris@99 184 v_set(m_estimate, epsilon, m_binCount);
Chris@36 185
Chris@91 186 for (int i = 0; i < m_sourceCount; ++i) {
Chris@55 187 for (int n = 0; n < m_noteCount; ++n) {
Chris@83 188 const double pitch = m_pitches[n];
Chris@83 189 const double source = m_sources[i][n];
Chris@55 190 for (int f = 0; f < m_shiftCount; ++f) {
Chris@88 191 const double *w = templateFor(i, n, f);
Chris@113 192 const double shift = m_shifts ? m_shifts[f][n] : 1.0;
Chris@83 193 const double factor = pitch * source * shift;
Chris@111 194 v_add_with_gain(m_estimate, w, factor, m_binCount);
Chris@36 195 }
Chris@36 196 }
Chris@36 197 }
Chris@36 198
Chris@45 199 for (int i = 0; i < m_binCount; ++i) {
Chris@36 200 m_q[i] = column[i] / m_estimate[i];
Chris@36 201 }
Chris@36 202 }
Chris@36 203
Chris@36 204 void
Chris@92 205 EM::maximisation(const double *column)
Chris@36 206 {
Chris@100 207 v_set(m_updatePitches, epsilon, m_noteCount);
Chris@113 208
Chris@92 209 for (int i = 0; i < m_sourceCount; ++i) {
Chris@100 210 v_set(m_updateSources[i], epsilon, m_noteCount);
Chris@92 211 }
Chris@62 212
Chris@113 213 if (m_shifts) {
Chris@113 214 for (int i = 0; i < m_shiftCount; ++i) {
Chris@113 215 v_set(m_updateShifts[i], epsilon, m_noteCount);
Chris@113 216 }
Chris@113 217 }
Chris@113 218
Chris@94 219 double *contributions = allocate<double>(m_binCount);
Chris@36 220
Chris@55 221 for (int n = 0; n < m_noteCount; ++n) {
Chris@85 222
Chris@85 223 const double pitch = m_pitches[n];
Chris@85 224
Chris@85 225 for (int f = 0; f < m_shiftCount; ++f) {
Chris@85 226
Chris@113 227 const double shift = m_shifts ? m_shifts[f][n] : 1.0;
Chris@85 228
Chris@91 229 for (int i = 0; i < m_sourceCount; ++i) {
Chris@85 230
Chris@83 231 const double source = m_sources[i][n];
Chris@89 232 const double factor = pitch * source * shift;
Chris@88 233 const double *w = templateFor(i, n, f);
Chris@85 234
Chris@94 235 v_copy(contributions, w, m_binCount);
Chris@95 236 v_multiply(contributions, m_q, m_binCount);
Chris@94 237 v_scale(contributions, factor, m_binCount);
Chris@94 238
Chris@94 239 double total = v_sum(contributions, m_binCount);
Chris@94 240
Chris@86 241 if (n >= m_lowestPitch && n <= m_highestPitch) {
Chris@85 242
Chris@100 243 m_updatePitches[n] += total;
Chris@85 244
Chris@85 245 if (inRange(i, n)) {
Chris@100 246 m_updateSources[i][n] += total;
Chris@55 247 }
Chris@36 248 }
Chris@55 249
Chris@113 250 if (m_shifts) {
Chris@113 251 m_updateShifts[f][n] += total;
Chris@113 252 }
Chris@42 253 }
Chris@36 254 }
Chris@36 255 }
Chris@36 256
Chris@103 257 if (m_pitchSparsity != 1.0) {
Chris@103 258 for (int n = 0; n < m_noteCount; ++n) {
Chris@103 259 m_updatePitches[n] =
Chris@103 260 pow(m_updatePitches[n], m_pitchSparsity);
Chris@62 261 }
Chris@103 262 }
Chris@103 263
Chris@103 264 if (m_sourceSparsity != 1.0) {
Chris@103 265 for (int n = 0; n < m_noteCount; ++n) {
Chris@91 266 for (int i = 0; i < m_sourceCount; ++i) {
Chris@103 267 m_updateSources[i][n] =
Chris@103 268 pow(m_updateSources[i][n], m_sourceSparsity);
Chris@62 269 }
Chris@62 270 }
Chris@62 271 }
Chris@85 272
Chris@100 273 normaliseColumn(m_updatePitches, m_noteCount);
Chris@112 274 std::swap(m_pitches, m_updatePitches);
Chris@112 275
Chris@113 276 normaliseGrid(m_updateSources, m_sourceCount, m_noteCount);
Chris@113 277 std::swap(m_sources, m_updateSources);
Chris@113 278
Chris@113 279 if (m_shifts) {
Chris@112 280 normaliseGrid(m_updateShifts, m_shiftCount, m_noteCount);
Chris@112 281 std::swap(m_shifts, m_updateShifts);
Chris@112 282 }
Chris@36 283 }
Chris@36 284
Chris@36 285