Mercurial > hg > svapp
comparison align/TransformDTWAligner.cpp @ 778:83a7b10b7415
Merge from branch pitch-align
author | Chris Cannam |
---|---|
date | Fri, 26 Jun 2020 13:48:52 +0100 |
parents | 699b5b130ea2 |
children | 8fa98f89eda8 |
comparison
equal
deleted
inserted
replaced
774:7bded7599874 | 778:83a7b10b7415 |
---|---|
1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */ | |
2 | |
3 /* | |
4 Sonic Visualiser | |
5 An audio file viewer and annotation editor. | |
6 Centre for Digital Music, Queen Mary, University of London. | |
7 | |
8 This program is free software; you can redistribute it and/or | |
9 modify it under the terms of the GNU General Public License as | |
10 published by the Free Software Foundation; either version 2 of the | |
11 License, or (at your option) any later version. See the file | |
12 COPYING included with this distribution for more information. | |
13 */ | |
14 | |
#include "TransformDTWAligner.h"
#include "DTW.h"

#include "data/model/SparseTimeValueModel.h"
#include "data/model/NoteModel.h"
#include "data/model/RangeSummarisableTimeValueModel.h"
#include "data/model/AlignmentModel.h"
#include "data/model/AggregateWaveModel.h"

#include "framework/Document.h"

#include "transform/ModelTransformerFactory.h"
#include "transform/FeatureExtractionModelTransformer.h"

#include <QSettings>
#include <QMutex>
#include <QMutexLocker>

#include <algorithm>
#include <utility>
#include <vector>

using std::vector;
34 | |
35 static | |
36 TransformDTWAligner::MagnitudePreprocessor identityMagnitudePreprocessor = | |
37 [](double x) { | |
38 return x; | |
39 }; | |
40 | |
41 static | |
42 TransformDTWAligner::RiseFallPreprocessor identityRiseFallPreprocessor = | |
43 [](double prev, double curr) { | |
44 if (prev == curr) { | |
45 return RiseFallDTW::Value({ RiseFallDTW::Direction::None, 0.0 }); | |
46 } else if (curr > prev) { | |
47 return RiseFallDTW::Value({ RiseFallDTW::Direction::Up, curr - prev }); | |
48 } else { | |
49 return RiseFallDTW::Value({ RiseFallDTW::Direction::Down, prev - curr }); | |
50 } | |
51 }; | |
52 | |
53 QMutex | |
54 TransformDTWAligner::m_dtwMutex; | |
55 | |
56 TransformDTWAligner::TransformDTWAligner(Document *doc, | |
57 ModelId reference, | |
58 ModelId toAlign, | |
59 Transform transform, | |
60 DTWType dtwType) : | |
61 m_document(doc), | |
62 m_reference(reference), | |
63 m_toAlign(toAlign), | |
64 m_transform(transform), | |
65 m_dtwType(dtwType), | |
66 m_incomplete(true), | |
67 m_magnitudePreprocessor(identityMagnitudePreprocessor), | |
68 m_riseFallPreprocessor(identityRiseFallPreprocessor) | |
69 { | |
70 } | |
71 | |
72 TransformDTWAligner::TransformDTWAligner(Document *doc, | |
73 ModelId reference, | |
74 ModelId toAlign, | |
75 Transform transform, | |
76 MagnitudePreprocessor outputPreprocessor) : | |
77 m_document(doc), | |
78 m_reference(reference), | |
79 m_toAlign(toAlign), | |
80 m_transform(transform), | |
81 m_dtwType(Magnitude), | |
82 m_incomplete(true), | |
83 m_magnitudePreprocessor(outputPreprocessor), | |
84 m_riseFallPreprocessor(identityRiseFallPreprocessor) | |
85 { | |
86 } | |
87 | |
88 TransformDTWAligner::TransformDTWAligner(Document *doc, | |
89 ModelId reference, | |
90 ModelId toAlign, | |
91 Transform transform, | |
92 RiseFallPreprocessor outputPreprocessor) : | |
93 m_document(doc), | |
94 m_reference(reference), | |
95 m_toAlign(toAlign), | |
96 m_transform(transform), | |
97 m_dtwType(RiseFall), | |
98 m_incomplete(true), | |
99 m_magnitudePreprocessor(identityMagnitudePreprocessor), | |
100 m_riseFallPreprocessor(outputPreprocessor) | |
101 { | |
102 } | |
103 | |
104 TransformDTWAligner::~TransformDTWAligner() | |
105 { | |
106 if (m_incomplete) { | |
107 if (auto toAlign = ModelById::get(m_toAlign)) { | |
108 toAlign->setAlignment({}); | |
109 } | |
110 } | |
111 | |
112 ModelById::release(m_referenceOutputModel); | |
113 ModelById::release(m_toAlignOutputModel); | |
114 } | |
115 | |
116 bool | |
117 TransformDTWAligner::isAvailable() | |
118 { | |
119 //!!! needs to be isAvailable(QString transformId)? | |
120 return true; | |
121 } | |
122 | |
123 void | |
124 TransformDTWAligner::begin() | |
125 { | |
126 auto reference = | |
127 ModelById::getAs<RangeSummarisableTimeValueModel>(m_reference); | |
128 auto toAlign = | |
129 ModelById::getAs<RangeSummarisableTimeValueModel>(m_toAlign); | |
130 | |
131 if (!reference || !toAlign) return; | |
132 | |
133 SVCERR << "TransformDTWAligner[" << this << "]: begin(): aligning " | |
134 << m_toAlign << " against reference " << m_reference << endl; | |
135 | |
136 ModelTransformerFactory *mtf = ModelTransformerFactory::getInstance(); | |
137 | |
138 QString message; | |
139 | |
140 m_referenceOutputModel = mtf->transform(m_transform, m_reference, message); | |
141 auto referenceOutputModel = ModelById::get(m_referenceOutputModel); | |
142 if (!referenceOutputModel) { | |
143 SVCERR << "Align::alignModel: ERROR: Failed to create reference output model (no plugin?)" << endl; | |
144 emit failed(m_toAlign, message); | |
145 return; | |
146 } | |
147 | |
148 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER | |
149 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id " | |
150 << m_transform.getIdentifier() | |
151 << " is running on reference model" << endl; | |
152 #endif | |
153 | |
154 message = ""; | |
155 | |
156 m_toAlignOutputModel = mtf->transform(m_transform, m_toAlign, message); | |
157 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel); | |
158 if (!toAlignOutputModel) { | |
159 SVCERR << "Align::alignModel: ERROR: Failed to create toAlign output model (no plugin?)" << endl; | |
160 emit failed(m_toAlign, message); | |
161 return; | |
162 } | |
163 | |
164 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER | |
165 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id " | |
166 << m_transform.getIdentifier() | |
167 << " is running on toAlign model" << endl; | |
168 #endif | |
169 | |
170 connect(referenceOutputModel.get(), SIGNAL(completionChanged(ModelId)), | |
171 this, SLOT(completionChanged(ModelId))); | |
172 connect(toAlignOutputModel.get(), SIGNAL(completionChanged(ModelId)), | |
173 this, SLOT(completionChanged(ModelId))); | |
174 | |
175 auto alignmentModel = std::make_shared<AlignmentModel> | |
176 (m_reference, m_toAlign, ModelId()); | |
177 m_alignmentModel = ModelById::add(alignmentModel); | |
178 | |
179 toAlign->setAlignment(m_alignmentModel); | |
180 m_document->addNonDerivedModel(m_alignmentModel); | |
181 | |
182 // we wouldn't normally expect these to be true here, but... | |
183 int completion = 0; | |
184 if (referenceOutputModel->isReady(&completion) && | |
185 toAlignOutputModel->isReady(&completion)) { | |
186 SVCERR << "TransformDTWAligner[" << this << "]: begin(): output models " | |
187 << "are ready already! calling performAlignment" << endl; | |
188 if (performAlignment()) { | |
189 emit complete(m_alignmentModel); | |
190 } else { | |
191 emit failed(m_toAlign, tr("Failed to calculate alignment using DTW")); | |
192 } | |
193 } | |
194 } | |
195 | |
196 void | |
197 TransformDTWAligner::completionChanged(ModelId id) | |
198 { | |
199 if (!m_incomplete) { | |
200 return; | |
201 } | |
202 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER | |
203 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " | |
204 << "model " << id << endl; | |
205 #endif | |
206 | |
207 auto referenceOutputModel = ModelById::get(m_referenceOutputModel); | |
208 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel); | |
209 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel); | |
210 | |
211 if (!referenceOutputModel || !toAlignOutputModel || !alignmentModel) { | |
212 return; | |
213 } | |
214 | |
215 int referenceCompletion = 0, toAlignCompletion = 0; | |
216 bool referenceReady = referenceOutputModel->isReady(&referenceCompletion); | |
217 bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion); | |
218 | |
219 if (referenceReady && toAlignReady) { | |
220 | |
221 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " | |
222 << "both models ready, calling performAlignment" << endl; | |
223 | |
224 alignmentModel->setCompletion(95); | |
225 | |
226 if (performAlignment()) { | |
227 emit complete(m_alignmentModel); | |
228 } else { | |
229 emit failed(m_toAlign, tr("Alignment of transform outputs failed")); | |
230 } | |
231 | |
232 } else { | |
233 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER | |
234 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " | |
235 << "not ready yet: reference completion " << referenceCompletion | |
236 << ", toAlign completion " << toAlignCompletion << endl; | |
237 #endif | |
238 | |
239 int completion = std::min(referenceCompletion, | |
240 toAlignCompletion); | |
241 completion = (completion * 94) / 100; | |
242 alignmentModel->setCompletion(completion); | |
243 } | |
244 } | |
245 | |
246 bool | |
247 TransformDTWAligner::performAlignment() | |
248 { | |
249 if (m_dtwType == Magnitude) { | |
250 return performAlignmentMagnitude(); | |
251 } else { | |
252 return performAlignmentRiseFall(); | |
253 } | |
254 } | |
255 | |
256 bool | |
257 TransformDTWAligner::getValuesFrom(ModelId modelId, | |
258 vector<sv_frame_t> &frames, | |
259 vector<double> &values, | |
260 sv_frame_t &resolution) | |
261 { | |
262 EventVector events; | |
263 | |
264 if (auto model = ModelById::getAs<SparseTimeValueModel>(modelId)) { | |
265 resolution = model->getResolution(); | |
266 events = model->getAllEvents(); | |
267 } else if (auto model = ModelById::getAs<NoteModel>(modelId)) { | |
268 resolution = model->getResolution(); | |
269 events = model->getAllEvents(); | |
270 } else { | |
271 SVCERR << "TransformDTWAligner::getValuesFrom: Type of model " | |
272 << modelId << " is not supported" << endl; | |
273 return false; | |
274 } | |
275 | |
276 frames.clear(); | |
277 values.clear(); | |
278 | |
279 for (auto e: events) { | |
280 frames.push_back(e.getFrame()); | |
281 values.push_back(e.getValue()); | |
282 } | |
283 | |
284 return true; | |
285 } | |
286 | |
287 Path | |
288 TransformDTWAligner::makePath(const vector<size_t> &alignment, | |
289 const vector<sv_frame_t> &refFrames, | |
290 const vector<sv_frame_t> &otherFrames, | |
291 sv_samplerate_t sampleRate, | |
292 sv_frame_t resolution) | |
293 { | |
294 Path path(sampleRate, resolution); | |
295 | |
296 path.add(PathPoint(0, 0)); | |
297 | |
298 for (int i = 0; in_range_for(alignment, i); ++i) { | |
299 | |
300 // DTW returns "the index into s2 for each element in s1" | |
301 sv_frame_t refFrame = refFrames[i]; | |
302 | |
303 if (!in_range_for(otherFrames, alignment[i])) { | |
304 SVCERR << "TransformDTWAligner::makePath: Internal error: " | |
305 << "DTW maps index " << i << " in reference frame vector " | |
306 << "(size " << refFrames.size() << ") onto index " | |
307 << alignment[i] << " in other frame vector " | |
308 << "(only size " << otherFrames.size() << ")" << endl; | |
309 continue; | |
310 } | |
311 | |
312 sv_frame_t alignedFrame = otherFrames[alignment[i]]; | |
313 path.add(PathPoint(alignedFrame, refFrame)); | |
314 } | |
315 | |
316 return path; | |
317 } | |
318 | |
319 bool | |
320 TransformDTWAligner::performAlignmentMagnitude() | |
321 { | |
322 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel); | |
323 if (!alignmentModel) { | |
324 return false; | |
325 } | |
326 | |
327 vector<sv_frame_t> refFrames, otherFrames; | |
328 vector<double> refValues, otherValues; | |
329 sv_frame_t resolution = 0; | |
330 | |
331 if (!getValuesFrom(m_referenceOutputModel, | |
332 refFrames, refValues, resolution)) { | |
333 return false; | |
334 } | |
335 | |
336 if (!getValuesFrom(m_toAlignOutputModel, | |
337 otherFrames, otherValues, resolution)) { | |
338 return false; | |
339 } | |
340 | |
341 vector<double> s1, s2; | |
342 for (double v: refValues) { | |
343 s1.push_back(m_magnitudePreprocessor(v)); | |
344 } | |
345 for (double v: otherValues) { | |
346 s2.push_back(m_magnitudePreprocessor(v)); | |
347 } | |
348 | |
349 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER | |
350 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: " | |
351 << "Have " << s1.size() << " events from reference, " | |
352 << s2.size() << " from toAlign" << endl; | |
353 #endif | |
354 | |
355 MagnitudeDTW dtw; | |
356 vector<size_t> alignment; | |
357 | |
358 { | |
359 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER | |
360 SVCERR << "TransformDTWAligner[" << this | |
361 << "]: serialising DTW to avoid over-allocation" << endl; | |
362 #endif | |
363 QMutexLocker locker(&m_dtwMutex); | |
364 alignment = dtw.alignSeries(s1, s2); | |
365 } | |
366 | |
367 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER | |
368 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: " | |
369 << "DTW produced " << alignment.size() << " points:" << endl; | |
370 for (int i = 0; in_range_for(alignment, i) && i < 100; ++i) { | |
371 SVCERR << alignment[i] << " "; | |
372 } | |
373 SVCERR << endl; | |
374 #endif | |
375 | |
376 alignmentModel->setPath(makePath(alignment, | |
377 refFrames, | |
378 otherFrames, | |
379 alignmentModel->getSampleRate(), | |
380 resolution)); | |
381 alignmentModel->setCompletion(100); | |
382 | |
383 SVCERR << "TransformDTWAligner[" << this | |
384 << "]: performAlignmentMagnitude: Done" << endl; | |
385 | |
386 m_incomplete = false; | |
387 return true; | |
388 } | |
389 | |
390 bool | |
391 TransformDTWAligner::performAlignmentRiseFall() | |
392 { | |
393 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel); | |
394 if (!alignmentModel) { | |
395 return false; | |
396 } | |
397 | |
398 vector<sv_frame_t> refFrames, otherFrames; | |
399 vector<double> refValues, otherValues; | |
400 sv_frame_t resolution = 0; | |
401 | |
402 if (!getValuesFrom(m_referenceOutputModel, | |
403 refFrames, refValues, resolution)) { | |
404 return false; | |
405 } | |
406 | |
407 if (!getValuesFrom(m_toAlignOutputModel, | |
408 otherFrames, otherValues, resolution)) { | |
409 return false; | |
410 } | |
411 | |
412 auto preprocess = | |
413 [this](const std::vector<double> &vv) { | |
414 vector<RiseFallDTW::Value> s; | |
415 double prev = 0.0; | |
416 for (auto curr: vv) { | |
417 s.push_back(m_riseFallPreprocessor(prev, curr)); | |
418 prev = curr; | |
419 } | |
420 return s; | |
421 }; | |
422 | |
423 vector<RiseFallDTW::Value> s1 = preprocess(refValues); | |
424 vector<RiseFallDTW::Value> s2 = preprocess(otherValues); | |
425 | |
426 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER | |
427 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: " | |
428 << "Have " << s1.size() << " events from reference, " | |
429 << s2.size() << " from toAlign" << endl; | |
430 | |
431 SVCERR << "Reference:" << endl; | |
432 for (int i = 0; in_range_for(s1, i) && i < 100; ++i) { | |
433 SVCERR << s1[i] << " "; | |
434 } | |
435 SVCERR << endl; | |
436 | |
437 SVCERR << "toAlign:" << endl; | |
438 for (int i = 0; in_range_for(s2, i) && i < 100; ++i) { | |
439 SVCERR << s2[i] << " "; | |
440 } | |
441 SVCERR << endl; | |
442 #endif | |
443 | |
444 RiseFallDTW dtw; | |
445 vector<size_t> alignment; | |
446 | |
447 { | |
448 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER | |
449 SVCERR << "TransformDTWAligner[" << this | |
450 << "]: serialising DTW to avoid over-allocation" << endl; | |
451 #endif | |
452 QMutexLocker locker(&m_dtwMutex); | |
453 alignment = dtw.alignSeries(s1, s2); | |
454 } | |
455 | |
456 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER | |
457 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: " | |
458 << "DTW produced " << alignment.size() << " points:" << endl; | |
459 for (int i = 0; i < alignment.size() && i < 100; ++i) { | |
460 SVCERR << alignment[i] << " "; | |
461 } | |
462 SVCERR << endl; | |
463 #endif | |
464 | |
465 alignmentModel->setPath(makePath(alignment, | |
466 refFrames, | |
467 otherFrames, | |
468 alignmentModel->getSampleRate(), | |
469 resolution)); | |
470 | |
471 alignmentModel->setCompletion(100); | |
472 | |
473 SVCERR << "TransformDTWAligner[" << this | |
474 << "]: performAlignmentRiseFall: Done" << endl; | |
475 | |
476 m_incomplete = false; | |
477 return true; | |
478 } |