comparison align/TransformDTWAligner.cpp @ 778:83a7b10b7415

Merge from branch pitch-align
author Chris Cannam
date Fri, 26 Jun 2020 13:48:52 +0100
parents 699b5b130ea2
children 8fa98f89eda8
comparison
equal deleted inserted replaced
774:7bded7599874 778:83a7b10b7415
1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
2
3 /*
4 Sonic Visualiser
5 An audio file viewer and annotation editor.
6 Centre for Digital Music, Queen Mary, University of London.
7
8 This program is free software; you can redistribute it and/or
9 modify it under the terms of the GNU General Public License as
10 published by the Free Software Foundation; either version 2 of the
11 License, or (at your option) any later version. See the file
12 COPYING included with this distribution for more information.
13 */
14
15 #include "TransformDTWAligner.h"
16 #include "DTW.h"
17
18 #include "data/model/SparseTimeValueModel.h"
19 #include "data/model/NoteModel.h"
20 #include "data/model/RangeSummarisableTimeValueModel.h"
21 #include "data/model/AlignmentModel.h"
22 #include "data/model/AggregateWaveModel.h"
23
24 #include "framework/Document.h"
25
26 #include "transform/ModelTransformerFactory.h"
27 #include "transform/FeatureExtractionModelTransformer.h"
28
29 #include <QSettings>
30 #include <QMutex>
31 #include <QMutexLocker>
32
33 using std::vector;
34
35 static
36 TransformDTWAligner::MagnitudePreprocessor identityMagnitudePreprocessor =
37 [](double x) {
38 return x;
39 };
40
41 static
42 TransformDTWAligner::RiseFallPreprocessor identityRiseFallPreprocessor =
43 [](double prev, double curr) {
44 if (prev == curr) {
45 return RiseFallDTW::Value({ RiseFallDTW::Direction::None, 0.0 });
46 } else if (curr > prev) {
47 return RiseFallDTW::Value({ RiseFallDTW::Direction::Up, curr - prev });
48 } else {
49 return RiseFallDTW::Value({ RiseFallDTW::Direction::Down, prev - curr });
50 }
51 };
52
53 QMutex
54 TransformDTWAligner::m_dtwMutex;
55
56 TransformDTWAligner::TransformDTWAligner(Document *doc,
57 ModelId reference,
58 ModelId toAlign,
59 Transform transform,
60 DTWType dtwType) :
61 m_document(doc),
62 m_reference(reference),
63 m_toAlign(toAlign),
64 m_transform(transform),
65 m_dtwType(dtwType),
66 m_incomplete(true),
67 m_magnitudePreprocessor(identityMagnitudePreprocessor),
68 m_riseFallPreprocessor(identityRiseFallPreprocessor)
69 {
70 }
71
72 TransformDTWAligner::TransformDTWAligner(Document *doc,
73 ModelId reference,
74 ModelId toAlign,
75 Transform transform,
76 MagnitudePreprocessor outputPreprocessor) :
77 m_document(doc),
78 m_reference(reference),
79 m_toAlign(toAlign),
80 m_transform(transform),
81 m_dtwType(Magnitude),
82 m_incomplete(true),
83 m_magnitudePreprocessor(outputPreprocessor),
84 m_riseFallPreprocessor(identityRiseFallPreprocessor)
85 {
86 }
87
88 TransformDTWAligner::TransformDTWAligner(Document *doc,
89 ModelId reference,
90 ModelId toAlign,
91 Transform transform,
92 RiseFallPreprocessor outputPreprocessor) :
93 m_document(doc),
94 m_reference(reference),
95 m_toAlign(toAlign),
96 m_transform(transform),
97 m_dtwType(RiseFall),
98 m_incomplete(true),
99 m_magnitudePreprocessor(identityMagnitudePreprocessor),
100 m_riseFallPreprocessor(outputPreprocessor)
101 {
102 }
103
104 TransformDTWAligner::~TransformDTWAligner()
105 {
106 if (m_incomplete) {
107 if (auto toAlign = ModelById::get(m_toAlign)) {
108 toAlign->setAlignment({});
109 }
110 }
111
112 ModelById::release(m_referenceOutputModel);
113 ModelById::release(m_toAlignOutputModel);
114 }
115
116 bool
117 TransformDTWAligner::isAvailable()
118 {
119 //!!! needs to be isAvailable(QString transformId)?
120 return true;
121 }
122
123 void
124 TransformDTWAligner::begin()
125 {
126 auto reference =
127 ModelById::getAs<RangeSummarisableTimeValueModel>(m_reference);
128 auto toAlign =
129 ModelById::getAs<RangeSummarisableTimeValueModel>(m_toAlign);
130
131 if (!reference || !toAlign) return;
132
133 SVCERR << "TransformDTWAligner[" << this << "]: begin(): aligning "
134 << m_toAlign << " against reference " << m_reference << endl;
135
136 ModelTransformerFactory *mtf = ModelTransformerFactory::getInstance();
137
138 QString message;
139
140 m_referenceOutputModel = mtf->transform(m_transform, m_reference, message);
141 auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
142 if (!referenceOutputModel) {
143 SVCERR << "Align::alignModel: ERROR: Failed to create reference output model (no plugin?)" << endl;
144 emit failed(m_toAlign, message);
145 return;
146 }
147
148 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
149 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
150 << m_transform.getIdentifier()
151 << " is running on reference model" << endl;
152 #endif
153
154 message = "";
155
156 m_toAlignOutputModel = mtf->transform(m_transform, m_toAlign, message);
157 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
158 if (!toAlignOutputModel) {
159 SVCERR << "Align::alignModel: ERROR: Failed to create toAlign output model (no plugin?)" << endl;
160 emit failed(m_toAlign, message);
161 return;
162 }
163
164 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
165 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
166 << m_transform.getIdentifier()
167 << " is running on toAlign model" << endl;
168 #endif
169
170 connect(referenceOutputModel.get(), SIGNAL(completionChanged(ModelId)),
171 this, SLOT(completionChanged(ModelId)));
172 connect(toAlignOutputModel.get(), SIGNAL(completionChanged(ModelId)),
173 this, SLOT(completionChanged(ModelId)));
174
175 auto alignmentModel = std::make_shared<AlignmentModel>
176 (m_reference, m_toAlign, ModelId());
177 m_alignmentModel = ModelById::add(alignmentModel);
178
179 toAlign->setAlignment(m_alignmentModel);
180 m_document->addNonDerivedModel(m_alignmentModel);
181
182 // we wouldn't normally expect these to be true here, but...
183 int completion = 0;
184 if (referenceOutputModel->isReady(&completion) &&
185 toAlignOutputModel->isReady(&completion)) {
186 SVCERR << "TransformDTWAligner[" << this << "]: begin(): output models "
187 << "are ready already! calling performAlignment" << endl;
188 if (performAlignment()) {
189 emit complete(m_alignmentModel);
190 } else {
191 emit failed(m_toAlign, tr("Failed to calculate alignment using DTW"));
192 }
193 }
194 }
195
196 void
197 TransformDTWAligner::completionChanged(ModelId id)
198 {
199 if (!m_incomplete) {
200 return;
201 }
202 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
203 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
204 << "model " << id << endl;
205 #endif
206
207 auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
208 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
209 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
210
211 if (!referenceOutputModel || !toAlignOutputModel || !alignmentModel) {
212 return;
213 }
214
215 int referenceCompletion = 0, toAlignCompletion = 0;
216 bool referenceReady = referenceOutputModel->isReady(&referenceCompletion);
217 bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion);
218
219 if (referenceReady && toAlignReady) {
220
221 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
222 << "both models ready, calling performAlignment" << endl;
223
224 alignmentModel->setCompletion(95);
225
226 if (performAlignment()) {
227 emit complete(m_alignmentModel);
228 } else {
229 emit failed(m_toAlign, tr("Alignment of transform outputs failed"));
230 }
231
232 } else {
233 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
234 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
235 << "not ready yet: reference completion " << referenceCompletion
236 << ", toAlign completion " << toAlignCompletion << endl;
237 #endif
238
239 int completion = std::min(referenceCompletion,
240 toAlignCompletion);
241 completion = (completion * 94) / 100;
242 alignmentModel->setCompletion(completion);
243 }
244 }
245
246 bool
247 TransformDTWAligner::performAlignment()
248 {
249 if (m_dtwType == Magnitude) {
250 return performAlignmentMagnitude();
251 } else {
252 return performAlignmentRiseFall();
253 }
254 }
255
256 bool
257 TransformDTWAligner::getValuesFrom(ModelId modelId,
258 vector<sv_frame_t> &frames,
259 vector<double> &values,
260 sv_frame_t &resolution)
261 {
262 EventVector events;
263
264 if (auto model = ModelById::getAs<SparseTimeValueModel>(modelId)) {
265 resolution = model->getResolution();
266 events = model->getAllEvents();
267 } else if (auto model = ModelById::getAs<NoteModel>(modelId)) {
268 resolution = model->getResolution();
269 events = model->getAllEvents();
270 } else {
271 SVCERR << "TransformDTWAligner::getValuesFrom: Type of model "
272 << modelId << " is not supported" << endl;
273 return false;
274 }
275
276 frames.clear();
277 values.clear();
278
279 for (auto e: events) {
280 frames.push_back(e.getFrame());
281 values.push_back(e.getValue());
282 }
283
284 return true;
285 }
286
287 Path
288 TransformDTWAligner::makePath(const vector<size_t> &alignment,
289 const vector<sv_frame_t> &refFrames,
290 const vector<sv_frame_t> &otherFrames,
291 sv_samplerate_t sampleRate,
292 sv_frame_t resolution)
293 {
294 Path path(sampleRate, resolution);
295
296 path.add(PathPoint(0, 0));
297
298 for (int i = 0; in_range_for(alignment, i); ++i) {
299
300 // DTW returns "the index into s2 for each element in s1"
301 sv_frame_t refFrame = refFrames[i];
302
303 if (!in_range_for(otherFrames, alignment[i])) {
304 SVCERR << "TransformDTWAligner::makePath: Internal error: "
305 << "DTW maps index " << i << " in reference frame vector "
306 << "(size " << refFrames.size() << ") onto index "
307 << alignment[i] << " in other frame vector "
308 << "(only size " << otherFrames.size() << ")" << endl;
309 continue;
310 }
311
312 sv_frame_t alignedFrame = otherFrames[alignment[i]];
313 path.add(PathPoint(alignedFrame, refFrame));
314 }
315
316 return path;
317 }
318
319 bool
320 TransformDTWAligner::performAlignmentMagnitude()
321 {
322 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
323 if (!alignmentModel) {
324 return false;
325 }
326
327 vector<sv_frame_t> refFrames, otherFrames;
328 vector<double> refValues, otherValues;
329 sv_frame_t resolution = 0;
330
331 if (!getValuesFrom(m_referenceOutputModel,
332 refFrames, refValues, resolution)) {
333 return false;
334 }
335
336 if (!getValuesFrom(m_toAlignOutputModel,
337 otherFrames, otherValues, resolution)) {
338 return false;
339 }
340
341 vector<double> s1, s2;
342 for (double v: refValues) {
343 s1.push_back(m_magnitudePreprocessor(v));
344 }
345 for (double v: otherValues) {
346 s2.push_back(m_magnitudePreprocessor(v));
347 }
348
349 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
350 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
351 << "Have " << s1.size() << " events from reference, "
352 << s2.size() << " from toAlign" << endl;
353 #endif
354
355 MagnitudeDTW dtw;
356 vector<size_t> alignment;
357
358 {
359 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
360 SVCERR << "TransformDTWAligner[" << this
361 << "]: serialising DTW to avoid over-allocation" << endl;
362 #endif
363 QMutexLocker locker(&m_dtwMutex);
364 alignment = dtw.alignSeries(s1, s2);
365 }
366
367 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
368 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
369 << "DTW produced " << alignment.size() << " points:" << endl;
370 for (int i = 0; in_range_for(alignment, i) && i < 100; ++i) {
371 SVCERR << alignment[i] << " ";
372 }
373 SVCERR << endl;
374 #endif
375
376 alignmentModel->setPath(makePath(alignment,
377 refFrames,
378 otherFrames,
379 alignmentModel->getSampleRate(),
380 resolution));
381 alignmentModel->setCompletion(100);
382
383 SVCERR << "TransformDTWAligner[" << this
384 << "]: performAlignmentMagnitude: Done" << endl;
385
386 m_incomplete = false;
387 return true;
388 }
389
390 bool
391 TransformDTWAligner::performAlignmentRiseFall()
392 {
393 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
394 if (!alignmentModel) {
395 return false;
396 }
397
398 vector<sv_frame_t> refFrames, otherFrames;
399 vector<double> refValues, otherValues;
400 sv_frame_t resolution = 0;
401
402 if (!getValuesFrom(m_referenceOutputModel,
403 refFrames, refValues, resolution)) {
404 return false;
405 }
406
407 if (!getValuesFrom(m_toAlignOutputModel,
408 otherFrames, otherValues, resolution)) {
409 return false;
410 }
411
412 auto preprocess =
413 [this](const std::vector<double> &vv) {
414 vector<RiseFallDTW::Value> s;
415 double prev = 0.0;
416 for (auto curr: vv) {
417 s.push_back(m_riseFallPreprocessor(prev, curr));
418 prev = curr;
419 }
420 return s;
421 };
422
423 vector<RiseFallDTW::Value> s1 = preprocess(refValues);
424 vector<RiseFallDTW::Value> s2 = preprocess(otherValues);
425
426 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
427 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
428 << "Have " << s1.size() << " events from reference, "
429 << s2.size() << " from toAlign" << endl;
430
431 SVCERR << "Reference:" << endl;
432 for (int i = 0; in_range_for(s1, i) && i < 100; ++i) {
433 SVCERR << s1[i] << " ";
434 }
435 SVCERR << endl;
436
437 SVCERR << "toAlign:" << endl;
438 for (int i = 0; in_range_for(s2, i) && i < 100; ++i) {
439 SVCERR << s2[i] << " ";
440 }
441 SVCERR << endl;
442 #endif
443
444 RiseFallDTW dtw;
445 vector<size_t> alignment;
446
447 {
448 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
449 SVCERR << "TransformDTWAligner[" << this
450 << "]: serialising DTW to avoid over-allocation" << endl;
451 #endif
452 QMutexLocker locker(&m_dtwMutex);
453 alignment = dtw.alignSeries(s1, s2);
454 }
455
456 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
457 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
458 << "DTW produced " << alignment.size() << " points:" << endl;
459 for (int i = 0; i < alignment.size() && i < 100; ++i) {
460 SVCERR << alignment[i] << " ";
461 }
462 SVCERR << endl;
463 #endif
464
465 alignmentModel->setPath(makePath(alignment,
466 refFrames,
467 otherFrames,
468 alignmentModel->getSampleRate(),
469 resolution));
470
471 alignmentModel->setCompletion(100);
472
473 SVCERR << "TransformDTWAligner[" << this
474 << "]: performAlignmentRiseFall: Done" << endl;
475
476 m_incomplete = false;
477 return true;
478 }