comparison src/EM.cpp @ 113:c4eae816bdb3 bqvec-openmp

Simplify slightly, make HQ mode the default
author Chris Cannam
date Wed, 07 May 2014 09:08:52 +0100
parents 2169e7a448c5
children b2f0967cb8d1 6890dea115c3
comparison
equal deleted inserted replaced
112:2169e7a448c5 113:c4eae816bdb3
32 using namespace breakfastquay; 32 using namespace breakfastquay;
33 33
34 static double epsilon = 1e-16; 34 static double epsilon = 1e-16;
35 35
36 EM::EM(bool useShifts) : 36 EM::EM(bool useShifts) :
37 m_useShifts(useShifts),
38 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), 37 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT),
39 m_shiftCount(useShifts ? SILVET_TEMPLATE_MAX_SHIFT * 2 + 1 : 1), 38 m_shiftCount(useShifts ? SILVET_TEMPLATE_MAX_SHIFT * 2 + 1 : 1),
40 m_binCount(SILVET_TEMPLATE_HEIGHT), 39 m_binCount(SILVET_TEMPLATE_HEIGHT),
41 m_sourceCount(SILVET_TEMPLATE_COUNT), 40 m_sourceCount(SILVET_TEMPLATE_COUNT),
42 m_pitchSparsity(1.1), 41 m_pitchSparsity(1.1),
48 m_updatePitches = allocate<double>(m_noteCount); 47 m_updatePitches = allocate<double>(m_noteCount);
49 for (int n = 0; n < m_noteCount; ++n) { 48 for (int n = 0; n < m_noteCount; ++n) {
50 m_pitches[n] = drand48(); 49 m_pitches[n] = drand48();
51 } 50 }
52 51
53 m_shifts = allocate_channels<double>(m_shiftCount, m_noteCount); 52 if (useShifts) {
54 m_updateShifts = allocate_channels<double>(m_shiftCount, m_noteCount); 53 m_shifts = allocate_channels<double>(m_shiftCount, m_noteCount);
55 for (int f = 0; f < m_shiftCount; ++f) { 54 m_updateShifts = allocate_channels<double>(m_shiftCount, m_noteCount);
56 for (int n = 0; n < m_noteCount; ++n) { 55 for (int f = 0; f < m_shiftCount; ++f) {
57 if (m_useShifts) { 56 for (int n = 0; n < m_noteCount; ++n) {
58 m_shifts[f][n] = drand48(); 57 m_shifts[f][n] = drand48();
59 } else { 58 }
60 m_shifts[f][n] = 1.0; 59 }
61 } 60 } else {
62 } 61 m_shifts = 0;
62 m_updateShifts = 0;
63 } 63 }
64 64
65 m_sources = allocate_channels<double>(m_sourceCount, m_noteCount); 65 m_sources = allocate_channels<double>(m_sourceCount, m_noteCount);
66 m_updateSources = allocate_channels<double>(m_sourceCount, m_noteCount); 66 m_updateSources = allocate_channels<double>(m_sourceCount, m_noteCount);
67 for (int i = 0; i < m_sourceCount; ++i) { 67 for (int i = 0; i < m_sourceCount; ++i) {
138 } 138 }
139 139
140 const double * 140 const double *
141 EM::templateFor(int instrument, int note, int shift) 141 EM::templateFor(int instrument, int note, int shift)
142 { 142 {
143 if (m_useShifts) { 143 if (m_shifts) {
144 return silvet_templates[instrument].data[note] + shift; 144 return silvet_templates[instrument].data[note] + shift;
145 } else { 145 } else {
146 return silvet_templates[instrument].data[note] + 146 return silvet_templates[instrument].data[note] +
147 SILVET_TEMPLATE_MAX_SHIFT; 147 SILVET_TEMPLATE_MAX_SHIFT;
148 } 148 }
159 for (int n = 0; n < m_noteCount; ++n) { 159 for (int n = 0; n < m_noteCount; ++n) {
160 const double pitch = m_pitches[n]; 160 const double pitch = m_pitches[n];
161 const double source = m_sources[i][n]; 161 const double source = m_sources[i][n];
162 for (int f = 0; f < m_shiftCount; ++f) { 162 for (int f = 0; f < m_shiftCount; ++f) {
163 const double *w = templateFor(i, n, f); 163 const double *w = templateFor(i, n, f);
164 const double shift = m_shifts[f][n]; 164 const double shift = m_shifts ? m_shifts[f][n] : 1.0;
165 const double factor = pitch * source * shift; 165 const double factor = pitch * source * shift;
166 v_add_with_gain(m_estimate, w, factor, m_binCount); 166 v_add_with_gain(m_estimate, w, factor, m_binCount);
167 } 167 }
168 } 168 }
169 } 169 }
175 175
176 void 176 void
177 EM::maximisation(const double *column) 177 EM::maximisation(const double *column)
178 { 178 {
179 v_set(m_updatePitches, epsilon, m_noteCount); 179 v_set(m_updatePitches, epsilon, m_noteCount);
180 for (int i = 0; i < m_shiftCount; ++i) { 180
181 v_set(m_updateShifts[i], epsilon, m_noteCount);
182 }
183 for (int i = 0; i < m_sourceCount; ++i) { 181 for (int i = 0; i < m_sourceCount; ++i) {
184 v_set(m_updateSources[i], epsilon, m_noteCount); 182 v_set(m_updateSources[i], epsilon, m_noteCount);
185 } 183 }
186 184
185 if (m_shifts) {
186 for (int i = 0; i < m_shiftCount; ++i) {
187 v_set(m_updateShifts[i], epsilon, m_noteCount);
188 }
189 }
190
187 double *contributions = allocate<double>(m_binCount); 191 double *contributions = allocate<double>(m_binCount);
188 192
189 for (int n = 0; n < m_noteCount; ++n) { 193 for (int n = 0; n < m_noteCount; ++n) {
190 194
191 const double pitch = m_pitches[n]; 195 const double pitch = m_pitches[n];
192 196
193 for (int f = 0; f < m_shiftCount; ++f) { 197 for (int f = 0; f < m_shiftCount; ++f) {
194 198
195 const double shift = m_shifts[f][n]; 199 const double shift = m_shifts ? m_shifts[f][n] : 1.0;
196 200
197 for (int i = 0; i < m_sourceCount; ++i) { 201 for (int i = 0; i < m_sourceCount; ++i) {
198 202
199 const double source = m_sources[i][n]; 203 const double source = m_sources[i][n];
200 const double factor = pitch * source * shift; 204 const double factor = pitch * source * shift;
213 if (inRange(i, n)) { 217 if (inRange(i, n)) {
214 m_updateSources[i][n] += total; 218 m_updateSources[i][n] += total;
215 } 219 }
216 } 220 }
217 221
218 m_updateShifts[f][n] += total; 222 if (m_shifts) {
223 m_updateShifts[f][n] += total;
224 }
219 } 225 }
220 } 226 }
221 } 227 }
222 228
223 if (m_pitchSparsity != 1.0) { 229 if (m_pitchSparsity != 1.0) {
237 } 243 }
238 244
239 normaliseColumn(m_updatePitches, m_noteCount); 245 normaliseColumn(m_updatePitches, m_noteCount);
240 std::swap(m_pitches, m_updatePitches); 246 std::swap(m_pitches, m_updatePitches);
241 247
242 if (m_useShifts) { 248 normaliseGrid(m_updateSources, m_sourceCount, m_noteCount);
249 std::swap(m_sources, m_updateSources);
250
251 if (m_shifts) {
243 normaliseGrid(m_updateShifts, m_shiftCount, m_noteCount); 252 normaliseGrid(m_updateShifts, m_shiftCount, m_noteCount);
244 std::swap(m_shifts, m_updateShifts); 253 std::swap(m_shifts, m_updateShifts);
245 } 254 }
246 255 }
247 normaliseGrid(m_updateSources, m_sourceCount, m_noteCount); 256
248 std::swap(m_sources, m_updateSources); 257
249 }
250
251