Mercurial > hg > svapp
comparison align/TransformDTWAligner.cpp @ 771:1d6cca5a5621 pitch-align
Allow use of proper sparse models (i.e. retaining event time info) in alignment; use this to switch to note alignment, which is what we have most recently been doing in the external program. Note that this is not currently producing correct results, though.
author | Chris Cannam |
---|---|
date | Fri, 29 May 2020 17:39:02 +0100 |
parents | a316cb6fed81 |
children | 699b5b130ea2 |
comparison
equal
deleted
inserted
replaced
770:486add472c3f | 771:1d6cca5a5621 |
---|---|
14 | 14 |
15 #include "TransformDTWAligner.h" | 15 #include "TransformDTWAligner.h" |
16 #include "DTW.h" | 16 #include "DTW.h" |
17 | 17 |
18 #include "data/model/SparseTimeValueModel.h" | 18 #include "data/model/SparseTimeValueModel.h" |
19 #include "data/model/NoteModel.h" | |
19 #include "data/model/RangeSummarisableTimeValueModel.h" | 20 #include "data/model/RangeSummarisableTimeValueModel.h" |
20 #include "data/model/AlignmentModel.h" | 21 #include "data/model/AlignmentModel.h" |
21 #include "data/model/AggregateWaveModel.h" | 22 #include "data/model/AggregateWaveModel.h" |
22 | 23 |
23 #include "framework/Document.h" | 24 #include "framework/Document.h" |
28 #include <QSettings> | 29 #include <QSettings> |
29 #include <QMutex> | 30 #include <QMutex> |
30 #include <QMutexLocker> | 31 #include <QMutexLocker> |
31 | 32 |
32 using std::vector; | 33 using std::vector; |
34 | |
35 static | |
36 TransformDTWAligner::MagnitudePreprocessor identityMagnitudePreprocessor = | |
37 [](double x) { | |
38 return x; | |
39 }; | |
40 | |
41 static | |
42 TransformDTWAligner::RiseFallPreprocessor identityRiseFallPreprocessor = | |
43 [](double prev, double curr) { | |
44 if (prev == curr) { | |
45 return RiseFallDTW::Value({ RiseFallDTW::Direction::None, 0.0 }); | |
46 } else if (curr > prev) { | |
47 return RiseFallDTW::Value({ RiseFallDTW::Direction::Up, curr - prev }); | |
48 } else { | |
49 return RiseFallDTW::Value({ RiseFallDTW::Direction::Down, prev - curr }); | |
50 } | |
51 }; | |
52 | |
53 QMutex | |
54 TransformDTWAligner::m_dtwMutex; | |
33 | 55 |
34 TransformDTWAligner::TransformDTWAligner(Document *doc, | 56 TransformDTWAligner::TransformDTWAligner(Document *doc, |
35 ModelId reference, | 57 ModelId reference, |
36 ModelId toAlign, | 58 ModelId toAlign, |
37 Transform transform, | 59 Transform transform, |
40 m_reference(reference), | 62 m_reference(reference), |
41 m_toAlign(toAlign), | 63 m_toAlign(toAlign), |
42 m_transform(transform), | 64 m_transform(transform), |
43 m_dtwType(dtwType), | 65 m_dtwType(dtwType), |
44 m_incomplete(true), | 66 m_incomplete(true), |
45 m_outputPreprocessor([](double x) { return x; }) | 67 m_magnitudePreprocessor(identityMagnitudePreprocessor), |
68 m_riseFallPreprocessor(identityRiseFallPreprocessor) | |
46 { | 69 { |
47 } | 70 } |
48 | 71 |
49 TransformDTWAligner::TransformDTWAligner(Document *doc, | 72 TransformDTWAligner::TransformDTWAligner(Document *doc, |
50 ModelId reference, | 73 ModelId reference, |
51 ModelId toAlign, | 74 ModelId toAlign, |
52 Transform transform, | 75 Transform transform, |
53 DTWType dtwType, | 76 MagnitudePreprocessor outputPreprocessor) : |
54 std::function<double(double)> | |
55 outputPreprocessor) : | |
56 m_document(doc), | 77 m_document(doc), |
57 m_reference(reference), | 78 m_reference(reference), |
58 m_toAlign(toAlign), | 79 m_toAlign(toAlign), |
59 m_transform(transform), | 80 m_transform(transform), |
60 m_dtwType(dtwType), | 81 m_dtwType(Magnitude), |
61 m_incomplete(true), | 82 m_incomplete(true), |
62 m_outputPreprocessor(outputPreprocessor) | 83 m_magnitudePreprocessor(outputPreprocessor), |
84 m_riseFallPreprocessor(identityRiseFallPreprocessor) | |
85 { | |
86 } | |
87 | |
88 TransformDTWAligner::TransformDTWAligner(Document *doc, | |
89 ModelId reference, | |
90 ModelId toAlign, | |
91 Transform transform, | |
92 RiseFallPreprocessor outputPreprocessor) : | |
93 m_document(doc), | |
94 m_reference(reference), | |
95 m_toAlign(toAlign), | |
96 m_transform(transform), | |
97 m_dtwType(RiseFall), | |
98 m_incomplete(true), | |
99 m_magnitudePreprocessor(identityMagnitudePreprocessor), | |
100 m_riseFallPreprocessor(outputPreprocessor) | |
63 { | 101 { |
64 } | 102 } |
65 | 103 |
66 TransformDTWAligner::~TransformDTWAligner() | 104 TransformDTWAligner::~TransformDTWAligner() |
67 { | 105 { |
155 TransformDTWAligner::completionChanged(ModelId id) | 193 TransformDTWAligner::completionChanged(ModelId id) |
156 { | 194 { |
157 if (!m_incomplete) { | 195 if (!m_incomplete) { |
158 return; | 196 return; |
159 } | 197 } |
160 | 198 /* |
161 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " | 199 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " |
162 << "model " << id << endl; | 200 << "model " << id << endl; |
163 | 201 */ |
164 auto referenceOutputModel = ModelById::get(m_referenceOutputModel); | 202 auto referenceOutputModel = ModelById::get(m_referenceOutputModel); |
165 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel); | 203 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel); |
166 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel); | 204 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel); |
167 | 205 |
168 if (!referenceOutputModel || !toAlignOutputModel || !alignmentModel) { | 206 if (!referenceOutputModel || !toAlignOutputModel || !alignmentModel) { |
174 bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion); | 212 bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion); |
175 | 213 |
176 if (referenceReady && toAlignReady) { | 214 if (referenceReady && toAlignReady) { |
177 | 215 |
178 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " | 216 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " |
179 << "ready, calling performAlignment" << endl; | 217 << "both models ready, calling performAlignment" << endl; |
180 | 218 |
181 alignmentModel->setCompletion(95); | 219 alignmentModel->setCompletion(95); |
182 | 220 |
183 if (performAlignment()) { | 221 if (performAlignment()) { |
184 emit complete(m_alignmentModel); | 222 emit complete(m_alignmentModel); |
185 } else { | 223 } else { |
186 emit failed(m_toAlign, tr("Alignment of transform outputs failed")); | 224 emit failed(m_toAlign, tr("Alignment of transform outputs failed")); |
187 } | 225 } |
188 | 226 |
189 } else { | 227 } else { |
190 | 228 /* |
191 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " | 229 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " |
192 << "not ready yet: reference completion " << referenceCompletion | 230 << "not ready yet: reference completion " << referenceCompletion |
193 << ", toAlign completion " << toAlignCompletion << endl; | 231 << ", toAlign completion " << toAlignCompletion << endl; |
194 | 232 */ |
195 int completion = std::min(referenceCompletion, | 233 int completion = std::min(referenceCompletion, |
196 toAlignCompletion); | 234 toAlignCompletion); |
197 completion = (completion * 94) / 100; | 235 completion = (completion * 94) / 100; |
198 alignmentModel->setCompletion(completion); | 236 alignmentModel->setCompletion(completion); |
199 } | 237 } |
208 return performAlignmentRiseFall(); | 246 return performAlignmentRiseFall(); |
209 } | 247 } |
210 } | 248 } |
211 | 249 |
212 bool | 250 bool |
251 TransformDTWAligner::getValuesFrom(ModelId modelId, | |
252 vector<sv_frame_t> &frames, | |
253 vector<double> &values, | |
254 sv_frame_t &resolution) | |
255 { | |
256 EventVector events; | |
257 | |
258 if (auto model = ModelById::getAs<SparseTimeValueModel>(modelId)) { | |
259 resolution = model->getResolution(); | |
260 events = model->getAllEvents(); | |
261 } else if (auto model = ModelById::getAs<NoteModel>(modelId)) { | |
262 resolution = model->getResolution(); | |
263 events = model->getAllEvents(); | |
264 } else { | |
265 SVCERR << "TransformDTWAligner::getValuesFrom: Type of model " | |
266 << modelId << " is not supported" << endl; | |
267 return false; | |
268 } | |
269 | |
270 frames.clear(); | |
271 values.clear(); | |
272 | |
273 for (auto e: events) { | |
274 frames.push_back(e.getFrame()); | |
275 values.push_back(e.getValue()); | |
276 } | |
277 | |
278 return true; | |
279 } | |
280 | |
281 Path | |
282 TransformDTWAligner::makePath(const vector<size_t> &alignment, | |
283 const vector<sv_frame_t> &refFrames, | |
284 const vector<sv_frame_t> &otherFrames, | |
285 sv_samplerate_t sampleRate, | |
286 sv_frame_t resolution) | |
287 { | |
288 Path path(sampleRate, resolution); | |
289 | |
290 for (int i = 0; in_range_for(alignment, i); ++i) { | |
291 | |
292 // DTW returns "the index into s2 for each element in s1" | |
293 sv_frame_t refFrame = refFrames[i]; | |
294 | |
295 if (!in_range_for(otherFrames, alignment[i])) { | |
296 SVCERR << "TransformDTWAligner::makePath: Internal error: " | |
297 << "DTW maps index " << i << " in reference frame vector " | |
298 << "(size " << refFrames.size() << ") onto index " | |
299 << alignment[i] << " in other frame vector " | |
300 << "(only size " << otherFrames.size() << ")" << endl; | |
301 continue; | |
302 } | |
303 | |
304 sv_frame_t alignedFrame = otherFrames[alignment[i]]; | |
305 path.add(PathPoint(alignedFrame, refFrame)); | |
306 } | |
307 | |
308 return path; | |
309 } | |
310 | |
311 bool | |
213 TransformDTWAligner::performAlignmentMagnitude() | 312 TransformDTWAligner::performAlignmentMagnitude() |
214 { | 313 { |
215 auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel> | 314 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel); |
216 (m_referenceOutputModel); | |
217 auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel> | |
218 (m_toAlignOutputModel); | |
219 auto alignmentModel = ModelById::getAs<AlignmentModel> | |
220 (m_alignmentModel); | |
221 | |
222 if (!referenceOutputSTVM || !toAlignOutputSTVM) { | |
223 //!!! what? | |
224 return false; | |
225 } | |
226 | |
227 if (!alignmentModel) { | 315 if (!alignmentModel) { |
228 return false; | 316 return false; |
229 } | 317 } |
318 | |
319 vector<sv_frame_t> refFrames, otherFrames; | |
320 vector<double> refValues, otherValues; | |
321 sv_frame_t resolution = 0; | |
322 | |
323 if (!getValuesFrom(m_referenceOutputModel, | |
324 refFrames, refValues, resolution)) { | |
325 return false; | |
326 } | |
327 | |
328 if (!getValuesFrom(m_toAlignOutputModel, | |
329 otherFrames, otherValues, resolution)) { | |
330 return false; | |
331 } | |
230 | 332 |
231 vector<double> s1, s2; | 333 vector<double> s1, s2; |
232 | 334 for (double v: refValues) { |
233 { | 335 s1.push_back(m_magnitudePreprocessor(v)); |
234 auto events = referenceOutputSTVM->getAllEvents(); | 336 } |
235 for (auto e: events) { | 337 for (double v: otherValues) { |
236 s1.push_back(m_outputPreprocessor(e.getValue())); | 338 s2.push_back(m_magnitudePreprocessor(v)); |
237 } | |
238 events = toAlignOutputSTVM->getAllEvents(); | |
239 for (auto e: events) { | |
240 s2.push_back(m_outputPreprocessor(e.getValue())); | |
241 } | |
242 } | 339 } |
243 | 340 |
244 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: " | 341 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: " |
245 << "Have " << s1.size() << " events from reference, " | 342 << "Have " << s1.size() << " events from reference, " |
246 << s2.size() << " from toAlign" << endl; | 343 << s2.size() << " from toAlign" << endl; |
247 | 344 |
248 MagnitudeDTW dtw; | 345 MagnitudeDTW dtw; |
249 vector<size_t> alignment; | 346 vector<size_t> alignment; |
250 | 347 |
251 { | 348 { |
252 SVCERR << "TransformDTWAligner[" << this | 349 SVCERR << "TransformDTWAligner[" << this |
253 << "]: serialising DTW to avoid over-allocation" << endl; | 350 << "]: serialising DTW to avoid over-allocation" << endl; |
254 static QMutex mutex; | 351 QMutexLocker locker(&m_dtwMutex); |
255 QMutexLocker locker(&mutex); | |
256 | |
257 alignment = dtw.alignSeries(s1, s2); | 352 alignment = dtw.alignSeries(s1, s2); |
258 } | 353 } |
259 | 354 |
260 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: " | 355 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: " |
261 << "DTW produced " << alignment.size() << " points:" << endl; | 356 << "DTW produced " << alignment.size() << " points:" << endl; |
262 for (int i = 0; i < alignment.size() && i < 100; ++i) { | 357 for (int i = 0; in_range_for(alignment, i) && i < 100; ++i) { |
263 SVCERR << alignment[i] << " "; | 358 SVCERR << alignment[i] << " "; |
264 } | 359 } |
265 SVCERR << endl; | 360 SVCERR << endl; |
266 | 361 |
362 alignmentModel->setPath(makePath(alignment, | |
363 refFrames, | |
364 otherFrames, | |
365 alignmentModel->getSampleRate(), | |
366 resolution)); | |
267 alignmentModel->setCompletion(100); | 367 alignmentModel->setCompletion(100); |
268 | 368 |
269 sv_frame_t resolution = referenceOutputSTVM->getResolution(); | 369 SVCERR << "TransformDTWAligner[" << this |
270 sv_frame_t sourceFrame = 0; | 370 << "]: performAlignmentMagnitude: Done" << endl; |
271 | |
272 Path path(referenceOutputSTVM->getSampleRate(), resolution); | |
273 | |
274 for (size_t m: alignment) { | |
275 path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution)); | |
276 sourceFrame += resolution; | |
277 } | |
278 | |
279 alignmentModel->setPath(path); | |
280 | |
281 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: Done" | |
282 << endl; | |
283 | 371 |
284 m_incomplete = false; | 372 m_incomplete = false; |
285 return true; | 373 return true; |
286 } | 374 } |
287 | 375 |
288 bool | 376 bool |
289 TransformDTWAligner::performAlignmentRiseFall() | 377 TransformDTWAligner::performAlignmentRiseFall() |
290 { | 378 { |
291 auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel> | 379 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel); |
292 (m_referenceOutputModel); | |
293 auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel> | |
294 (m_toAlignOutputModel); | |
295 auto alignmentModel = ModelById::getAs<AlignmentModel> | |
296 (m_alignmentModel); | |
297 | |
298 if (!referenceOutputSTVM || !toAlignOutputSTVM) { | |
299 //!!! what? | |
300 return false; | |
301 } | |
302 | |
303 if (!alignmentModel) { | 380 if (!alignmentModel) { |
304 return false; | 381 return false; |
305 } | 382 } |
306 | 383 |
307 auto convertEvents = | 384 vector<sv_frame_t> refFrames, otherFrames; |
308 [this](const EventVector &ee) { | 385 vector<double> refValues, otherValues; |
386 sv_frame_t resolution = 0; | |
387 | |
388 if (!getValuesFrom(m_referenceOutputModel, | |
389 refFrames, refValues, resolution)) { | |
390 return false; | |
391 } | |
392 | |
393 if (!getValuesFrom(m_toAlignOutputModel, | |
394 otherFrames, otherValues, resolution)) { | |
395 return false; | |
396 } | |
397 | |
398 auto preprocess = | |
399 [this](const std::vector<double> &vv) { | |
309 vector<RiseFallDTW::Value> s; | 400 vector<RiseFallDTW::Value> s; |
310 double prev = 0.0; | 401 double prev = 0.0; |
311 for (auto e: ee) { | 402 for (auto curr: vv) { |
312 double v = m_outputPreprocessor(e.getValue()); | 403 s.push_back(m_riseFallPreprocessor(prev, curr)); |
313 if (v == prev || s.empty()) { | 404 prev = curr; |
314 s.push_back({ RiseFallDTW::Direction::None, 0.0 }); | |
315 } else if (v > prev) { | |
316 s.push_back({ RiseFallDTW::Direction::Up, v - prev }); | |
317 } else { | |
318 s.push_back({ RiseFallDTW::Direction::Down, prev - v }); | |
319 } | |
320 } | 405 } |
321 return s; | 406 return s; |
322 }; | 407 }; |
323 | 408 |
324 vector<RiseFallDTW::Value> s1 = | 409 vector<RiseFallDTW::Value> s1 = preprocess(refValues); |
325 convertEvents(referenceOutputSTVM->getAllEvents()); | 410 vector<RiseFallDTW::Value> s2 = preprocess(otherValues); |
326 | |
327 vector<RiseFallDTW::Value> s2 = | |
328 convertEvents(toAlignOutputSTVM->getAllEvents()); | |
329 | 411 |
330 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: " | 412 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: " |
331 << "Have " << s1.size() << " events from reference, " | 413 << "Have " << s1.size() << " events from reference, " |
332 << s2.size() << " from toAlign" << endl; | 414 << s2.size() << " from toAlign" << endl; |
333 | 415 |
416 SVCERR << "Reference:" << endl; | |
417 for (int i = 0; in_range_for(s1, i) && i < 100; ++i) { | |
418 SVCERR << s1[i] << " "; | |
419 } | |
420 SVCERR << endl; | |
421 | |
422 SVCERR << "toAlign:" << endl; | |
423 for (int i = 0; in_range_for(s2, i) && i < 100; ++i) { | |
424 SVCERR << s2[i] << " "; | |
425 } | |
426 SVCERR << endl; | |
427 | |
334 RiseFallDTW dtw; | 428 RiseFallDTW dtw; |
335 | |
336 vector<size_t> alignment; | 429 vector<size_t> alignment; |
337 | 430 |
338 { | 431 { |
339 SVCERR << "TransformDTWAligner[" << this | 432 SVCERR << "TransformDTWAligner[" << this |
340 << "]: serialising DTW to avoid over-allocation" << endl; | 433 << "]: serialising DTW to avoid over-allocation" << endl; |
341 static QMutex mutex; | 434 QMutexLocker locker(&m_dtwMutex); |
342 QMutexLocker locker(&mutex); | |
343 | |
344 alignment = dtw.alignSeries(s1, s2); | 435 alignment = dtw.alignSeries(s1, s2); |
345 } | 436 } |
346 | 437 |
347 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: " | 438 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: " |
348 << "DTW produced " << alignment.size() << " points:" << endl; | 439 << "DTW produced " << alignment.size() << " points:" << endl; |
349 for (int i = 0; i < alignment.size() && i < 100; ++i) { | 440 for (int i = 0; i < alignment.size() && i < 100; ++i) { |
350 SVCERR << alignment[i] << " "; | 441 SVCERR << alignment[i] << " "; |
351 } | 442 } |
352 SVCERR << endl; | 443 SVCERR << endl; |
353 | 444 |
445 alignmentModel->setPath(makePath(alignment, | |
446 refFrames, | |
447 otherFrames, | |
448 alignmentModel->getSampleRate(), | |
449 resolution)); | |
450 | |
354 alignmentModel->setCompletion(100); | 451 alignmentModel->setCompletion(100); |
355 | |
356 sv_frame_t resolution = referenceOutputSTVM->getResolution(); | |
357 sv_frame_t sourceFrame = 0; | |
358 | |
359 Path path(referenceOutputSTVM->getSampleRate(), resolution); | |
360 | |
361 for (size_t m: alignment) { | |
362 path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution)); | |
363 sourceFrame += resolution; | |
364 } | |
365 | |
366 alignmentModel->setPath(path); | |
367 | 452 |
368 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: Done" | 453 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: Done" |
369 << endl; | 454 << endl; |
370 | 455 |
371 m_incomplete = false; | 456 m_incomplete = false; |