Mercurial > hg > silvet
comparison src/EM.cpp @ 113:c4eae816bdb3 bqvec-openmp
Simplify slightly, make HQ mode the default
author | Chris Cannam |
---|---|
date | Wed, 07 May 2014 09:08:52 +0100 |
parents | 2169e7a448c5 |
children | b2f0967cb8d1 6890dea115c3 |
comparison
equal
deleted
inserted
replaced
112:2169e7a448c5 | 113:c4eae816bdb3 |
---|---|
32 using namespace breakfastquay; | 32 using namespace breakfastquay; |
33 | 33 |
34 static double epsilon = 1e-16; | 34 static double epsilon = 1e-16; |
35 | 35 |
36 EM::EM(bool useShifts) : | 36 EM::EM(bool useShifts) : |
37 m_useShifts(useShifts), | |
38 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), | 37 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), |
39 m_shiftCount(useShifts ? SILVET_TEMPLATE_MAX_SHIFT * 2 + 1 : 1), | 38 m_shiftCount(useShifts ? SILVET_TEMPLATE_MAX_SHIFT * 2 + 1 : 1), |
40 m_binCount(SILVET_TEMPLATE_HEIGHT), | 39 m_binCount(SILVET_TEMPLATE_HEIGHT), |
41 m_sourceCount(SILVET_TEMPLATE_COUNT), | 40 m_sourceCount(SILVET_TEMPLATE_COUNT), |
42 m_pitchSparsity(1.1), | 41 m_pitchSparsity(1.1), |
48 m_updatePitches = allocate<double>(m_noteCount); | 47 m_updatePitches = allocate<double>(m_noteCount); |
49 for (int n = 0; n < m_noteCount; ++n) { | 48 for (int n = 0; n < m_noteCount; ++n) { |
50 m_pitches[n] = drand48(); | 49 m_pitches[n] = drand48(); |
51 } | 50 } |
52 | 51 |
53 m_shifts = allocate_channels<double>(m_shiftCount, m_noteCount); | 52 if (useShifts) { |
54 m_updateShifts = allocate_channels<double>(m_shiftCount, m_noteCount); | 53 m_shifts = allocate_channels<double>(m_shiftCount, m_noteCount); |
55 for (int f = 0; f < m_shiftCount; ++f) { | 54 m_updateShifts = allocate_channels<double>(m_shiftCount, m_noteCount); |
56 for (int n = 0; n < m_noteCount; ++n) { | 55 for (int f = 0; f < m_shiftCount; ++f) { |
57 if (m_useShifts) { | 56 for (int n = 0; n < m_noteCount; ++n) { |
58 m_shifts[f][n] = drand48(); | 57 m_shifts[f][n] = drand48(); |
59 } else { | 58 } |
60 m_shifts[f][n] = 1.0; | 59 } |
61 } | 60 } else { |
62 } | 61 m_shifts = 0; |
62 m_updateShifts = 0; | |
63 } | 63 } |
64 | 64 |
65 m_sources = allocate_channels<double>(m_sourceCount, m_noteCount); | 65 m_sources = allocate_channels<double>(m_sourceCount, m_noteCount); |
66 m_updateSources = allocate_channels<double>(m_sourceCount, m_noteCount); | 66 m_updateSources = allocate_channels<double>(m_sourceCount, m_noteCount); |
67 for (int i = 0; i < m_sourceCount; ++i) { | 67 for (int i = 0; i < m_sourceCount; ++i) { |
138 } | 138 } |
139 | 139 |
140 const double * | 140 const double * |
141 EM::templateFor(int instrument, int note, int shift) | 141 EM::templateFor(int instrument, int note, int shift) |
142 { | 142 { |
143 if (m_useShifts) { | 143 if (m_shifts) { |
144 return silvet_templates[instrument].data[note] + shift; | 144 return silvet_templates[instrument].data[note] + shift; |
145 } else { | 145 } else { |
146 return silvet_templates[instrument].data[note] + | 146 return silvet_templates[instrument].data[note] + |
147 SILVET_TEMPLATE_MAX_SHIFT; | 147 SILVET_TEMPLATE_MAX_SHIFT; |
148 } | 148 } |
159 for (int n = 0; n < m_noteCount; ++n) { | 159 for (int n = 0; n < m_noteCount; ++n) { |
160 const double pitch = m_pitches[n]; | 160 const double pitch = m_pitches[n]; |
161 const double source = m_sources[i][n]; | 161 const double source = m_sources[i][n]; |
162 for (int f = 0; f < m_shiftCount; ++f) { | 162 for (int f = 0; f < m_shiftCount; ++f) { |
163 const double *w = templateFor(i, n, f); | 163 const double *w = templateFor(i, n, f); |
164 const double shift = m_shifts[f][n]; | 164 const double shift = m_shifts ? m_shifts[f][n] : 1.0; |
165 const double factor = pitch * source * shift; | 165 const double factor = pitch * source * shift; |
166 v_add_with_gain(m_estimate, w, factor, m_binCount); | 166 v_add_with_gain(m_estimate, w, factor, m_binCount); |
167 } | 167 } |
168 } | 168 } |
169 } | 169 } |
175 | 175 |
176 void | 176 void |
177 EM::maximisation(const double *column) | 177 EM::maximisation(const double *column) |
178 { | 178 { |
179 v_set(m_updatePitches, epsilon, m_noteCount); | 179 v_set(m_updatePitches, epsilon, m_noteCount); |
180 for (int i = 0; i < m_shiftCount; ++i) { | 180 |
181 v_set(m_updateShifts[i], epsilon, m_noteCount); | |
182 } | |
183 for (int i = 0; i < m_sourceCount; ++i) { | 181 for (int i = 0; i < m_sourceCount; ++i) { |
184 v_set(m_updateSources[i], epsilon, m_noteCount); | 182 v_set(m_updateSources[i], epsilon, m_noteCount); |
185 } | 183 } |
186 | 184 |
185 if (m_shifts) { | |
186 for (int i = 0; i < m_shiftCount; ++i) { | |
187 v_set(m_updateShifts[i], epsilon, m_noteCount); | |
188 } | |
189 } | |
190 | |
187 double *contributions = allocate<double>(m_binCount); | 191 double *contributions = allocate<double>(m_binCount); |
188 | 192 |
189 for (int n = 0; n < m_noteCount; ++n) { | 193 for (int n = 0; n < m_noteCount; ++n) { |
190 | 194 |
191 const double pitch = m_pitches[n]; | 195 const double pitch = m_pitches[n]; |
192 | 196 |
193 for (int f = 0; f < m_shiftCount; ++f) { | 197 for (int f = 0; f < m_shiftCount; ++f) { |
194 | 198 |
195 const double shift = m_shifts[f][n]; | 199 const double shift = m_shifts ? m_shifts[f][n] : 1.0; |
196 | 200 |
197 for (int i = 0; i < m_sourceCount; ++i) { | 201 for (int i = 0; i < m_sourceCount; ++i) { |
198 | 202 |
199 const double source = m_sources[i][n]; | 203 const double source = m_sources[i][n]; |
200 const double factor = pitch * source * shift; | 204 const double factor = pitch * source * shift; |
213 if (inRange(i, n)) { | 217 if (inRange(i, n)) { |
214 m_updateSources[i][n] += total; | 218 m_updateSources[i][n] += total; |
215 } | 219 } |
216 } | 220 } |
217 | 221 |
218 m_updateShifts[f][n] += total; | 222 if (m_shifts) { |
223 m_updateShifts[f][n] += total; | |
224 } | |
219 } | 225 } |
220 } | 226 } |
221 } | 227 } |
222 | 228 |
223 if (m_pitchSparsity != 1.0) { | 229 if (m_pitchSparsity != 1.0) { |
237 } | 243 } |
238 | 244 |
239 normaliseColumn(m_updatePitches, m_noteCount); | 245 normaliseColumn(m_updatePitches, m_noteCount); |
240 std::swap(m_pitches, m_updatePitches); | 246 std::swap(m_pitches, m_updatePitches); |
241 | 247 |
242 if (m_useShifts) { | 248 normaliseGrid(m_updateSources, m_sourceCount, m_noteCount); |
249 std::swap(m_sources, m_updateSources); | |
250 | |
251 if (m_shifts) { | |
243 normaliseGrid(m_updateShifts, m_shiftCount, m_noteCount); | 252 normaliseGrid(m_updateShifts, m_shiftCount, m_noteCount); |
244 std::swap(m_shifts, m_updateShifts); | 253 std::swap(m_shifts, m_updateShifts); |
245 } | 254 } |
246 | 255 } |
247 normaliseGrid(m_updateSources, m_sourceCount, m_noteCount); | 256 |
248 std::swap(m_sources, m_updateSources); | 257 |
249 } | |
250 | |
251 |