comparison align/TransformDTWAligner.cpp @ 767:dd742e566e60 pitch-align

Make a start on further alignment methods
author Chris Cannam
date Thu, 21 May 2020 16:21:57 +0100
parents
children 1b1960009be6
comparison
equal deleted inserted replaced
761:6429a164b7e1 767:dd742e566e60
1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
2
3 /*
4 Sonic Visualiser
5 An audio file viewer and annotation editor.
6 Centre for Digital Music, Queen Mary, University of London.
7
8 This program is free software; you can redistribute it and/or
9 modify it under the terms of the GNU General Public License as
10 published by the Free Software Foundation; either version 2 of the
11 License, or (at your option) any later version. See the file
12 COPYING included with this distribution for more information.
13 */
14
15 #include "TransformDTWAligner.h"
16 #include "DTW.h"
17
18 #include "data/model/SparseTimeValueModel.h"
19 #include "data/model/RangeSummarisableTimeValueModel.h"
20 #include "data/model/AlignmentModel.h"
21 #include "data/model/AggregateWaveModel.h"
22
23 #include "framework/Document.h"
24
25 #include "transform/ModelTransformerFactory.h"
26 #include "transform/FeatureExtractionModelTransformer.h"
27
28 #include <QSettings>
29 #include <QMutex>
30 #include <QMutexLocker>
31
32 using std::vector;
33
34 TransformDTWAligner::TransformDTWAligner(Document *doc,
35 ModelId reference,
36 ModelId toAlign,
37 Transform transform,
38 DTWType dtwType) :
39 m_document(doc),
40 m_reference(reference),
41 m_toAlign(toAlign),
42 m_referenceTransformComplete(false),
43 m_toAlignTransformComplete(false),
44 m_transform(transform),
45 m_dtwType(dtwType),
46 m_incomplete(true)
47 {
48 }
49
50 TransformDTWAligner::~TransformDTWAligner()
51 {
52 if (m_incomplete) {
53 if (auto toAlign = ModelById::get(m_toAlign)) {
54 toAlign->setAlignment({});
55 }
56 }
57
58 ModelById::release(m_referenceOutputModel);
59 ModelById::release(m_toAlignOutputModel);
60 ModelById::release(m_alignmentProgressModel);
61 }
62
63 bool
64 TransformDTWAligner::isAvailable()
65 {
66 //!!! needs to be isAvailable(QString transformId)?
67 return true;
68 }
69
70 void
71 TransformDTWAligner::begin()
72 {
73 auto reference =
74 ModelById::getAs<RangeSummarisableTimeValueModel>(m_reference);
75 auto toAlign =
76 ModelById::getAs<RangeSummarisableTimeValueModel>(m_toAlign);
77
78 if (!reference || !toAlign) return;
79
80 SVCERR << "TransformDTWAligner[" << this << "]: begin(): aligning "
81 << m_toAlign << " against reference " << m_reference << endl;
82
83 ModelTransformerFactory *mtf = ModelTransformerFactory::getInstance();
84
85 QString message;
86
87 m_referenceOutputModel = mtf->transform(m_transform, m_reference, message);
88 auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
89 if (!referenceOutputModel) {
90 SVCERR << "Align::alignModel: ERROR: Failed to create reference output model (no plugin?)" << endl;
91 emit failed(m_toAlign, message);
92 return;
93 }
94
95 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
96 << m_transform.getIdentifier()
97 << " is running on reference model" << endl;
98
99 message = "";
100
101 m_toAlignOutputModel = mtf->transform(m_transform, m_toAlign, message);
102 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
103 if (!toAlignOutputModel) {
104 SVCERR << "Align::alignModel: ERROR: Failed to create toAlign output model (no plugin?)" << endl;
105 emit failed(m_toAlign, message);
106 return;
107 }
108
109 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
110 << m_transform.getIdentifier()
111 << " is running on toAlign model" << endl;
112
113 connect(referenceOutputModel.get(), SIGNAL(completionChanged(ModelId)),
114 this, SLOT(completionChanged(ModelId)));
115 connect(toAlignOutputModel.get(), SIGNAL(completionChanged(ModelId)),
116 this, SLOT(completionChanged(ModelId)));
117
118 auto alignmentProgressModel = std::make_shared<SparseTimeValueModel>
119 (reference->getSampleRate(), m_transform.getStepSize(), false);
120 alignmentProgressModel->setCompletion(0);
121 m_alignmentProgressModel = ModelById::add(alignmentProgressModel);
122
123 auto alignmentModel = std::make_shared<AlignmentModel>
124 (m_reference, m_toAlign, m_alignmentProgressModel);
125 m_alignmentModel = ModelById::add(alignmentModel);
126
127 toAlign->setAlignment(m_alignmentModel);
128 m_document->addNonDerivedModel(m_alignmentModel);
129
130 // we wouldn't normally expect these to be true here, but...
131 int completion = 0;
132 if (referenceOutputModel->isReady(&completion) &&
133 toAlignOutputModel->isReady(&completion)) {
134 SVCERR << "TransformDTWAligner[" << this << "]: begin(): output models "
135 << "are ready already! calling performAlignment" << endl;
136 if (performAlignment()) {
137 emit complete(m_alignmentModel);
138 } else {
139 emit failed(m_toAlign, tr("Failed to calculate alignment using DTW"));
140 }
141 }
142 }
143
144 void
145 TransformDTWAligner::completionChanged(ModelId id)
146 {
147 if (!m_incomplete) {
148 return;
149 }
150
151 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
152 << "model " << id << endl;
153
154 auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
155 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
156
157 if (!referenceOutputModel || !toAlignOutputModel) {
158 return;
159 }
160
161 int referenceCompletion = 0, toAlignCompletion = 0;
162 bool referenceReady = referenceOutputModel->isReady(&referenceCompletion);
163 bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion);
164
165 auto alignmentProgressModel =
166 ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel);
167
168 if (referenceReady && toAlignReady) {
169
170 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
171 << "ready, calling performAlignment" << endl;
172
173 if (alignmentProgressModel) {
174 alignmentProgressModel->setCompletion(95);
175 }
176
177 if (performAlignment()) {
178 emit complete(m_alignmentModel);
179 } else {
180 emit failed(m_toAlign, tr("Alignment of transform outputs failed"));
181 }
182
183 } else {
184
185 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
186 << "not ready yet: reference completion " << referenceCompletion
187 << ", toAlign completion " << toAlignCompletion << endl;
188
189 if (alignmentProgressModel) {
190 int completion = std::min(referenceCompletion,
191 toAlignCompletion);
192 completion = (completion * 94) / 100;
193 alignmentProgressModel->setCompletion(completion);
194 }
195 }
196 }
197
198 bool
199 TransformDTWAligner::performAlignment()
200 {
201 if (m_dtwType == Magnitude) {
202 return performAlignmentMagnitude();
203 } else {
204 return performAlignmentRiseFall();
205 }
206 }
207
208 bool
209 TransformDTWAligner::performAlignmentMagnitude()
210 {
211 auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel>
212 (m_referenceOutputModel);
213 auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel>
214 (m_toAlignOutputModel);
215 auto alignmentModel = ModelById::getAs<AlignmentModel>
216 (m_alignmentModel);
217
218 if (!referenceOutputSTVM || !toAlignOutputSTVM) {
219 //!!! what?
220 return false;
221 }
222
223 if (!alignmentModel) {
224 return false;
225 }
226
227 vector<double> s1, s2;
228
229 {
230 auto events = referenceOutputSTVM->getAllEvents();
231 for (auto e: events) {
232 s1.push_back(e.getValue());
233 }
234 events = toAlignOutputSTVM->getAllEvents();
235 for (auto e: events) {
236 s2.push_back(e.getValue());
237 }
238 }
239
240 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: "
241 << "Have " << s1.size() << " events from reference, "
242 << s2.size() << " from toAlign" << endl;
243
244 MagnitudeDTW dtw;
245 vector<size_t> alignment;
246
247 {
248 SVCERR << "TransformDTWAligner[" << this
249 << "]: serialising DTW to avoid over-allocation" << endl;
250 static QMutex mutex;
251 QMutexLocker locker(&mutex);
252
253 alignment = dtw.alignSeries(s1, s2);
254 }
255
256 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: "
257 << "DTW produced " << alignment.size() << " points:" << endl;
258 for (int i = 0; i < alignment.size() && i < 100; ++i) {
259 SVCERR << alignment[i] << " ";
260 }
261 SVCERR << endl;
262
263 auto alignmentProgressModel =
264 ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel);
265 if (alignmentProgressModel) {
266 alignmentProgressModel->setCompletion(100);
267 }
268
269 // clear the alignment progress model
270 alignmentModel->setPathFrom(ModelId());
271
272 sv_frame_t resolution = referenceOutputSTVM->getResolution();
273 sv_frame_t sourceFrame = 0;
274
275 Path path(referenceOutputSTVM->getSampleRate(), resolution);
276
277 for (size_t m: alignment) {
278 path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution));
279 sourceFrame += resolution;
280 }
281
282 alignmentModel->setPath(path);
283
284 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: Done"
285 << endl;
286
287 m_incomplete = false;
288 return true;
289 }
290
291 bool
292 TransformDTWAligner::performAlignmentRiseFall()
293 {
294 auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel>
295 (m_referenceOutputModel);
296 auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel>
297 (m_toAlignOutputModel);
298 auto alignmentModel = ModelById::getAs<AlignmentModel>
299 (m_alignmentModel);
300
301 if (!referenceOutputSTVM || !toAlignOutputSTVM) {
302 //!!! what?
303 return false;
304 }
305
306 if (!alignmentModel) {
307 return false;
308 }
309
310 vector<RiseFallDTW::Value> s1, s2;
311 double prev1 = 0.0, prev2 = 0.0;
312
313 {
314 auto events = referenceOutputSTVM->getAllEvents();
315 for (auto e: events) {
316 double v = e.getValue();
317 //!!! the original does this using MIDI pitch for the
318 //!!! pYin transform... rework with a lambda passed in
319 //!!! for modification maybe? + factor out s1/s2 of course
320 if (v > prev1) {
321 s1.push_back({ RiseFallDTW::Direction::Up, v - prev1 });
322 } else {
323 s1.push_back({ RiseFallDTW::Direction::Down, prev1 - v });
324 }
325 prev1 = v;
326 }
327 events = toAlignOutputSTVM->getAllEvents();
328 for (auto e: events) {
329 double v = e.getValue();
330 //!!! as above
331 if (v > prev2) {
332 s2.push_back({ RiseFallDTW::Direction::Up, v - prev2 });
333 } else {
334 s2.push_back({ RiseFallDTW::Direction::Down, prev2 - v });
335 }
336 prev2 = v;
337 }
338 }
339
340 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: "
341 << "Have " << s1.size() << " events from reference, "
342 << s2.size() << " from toAlign" << endl;
343
344 RiseFallDTW dtw;
345
346 vector<size_t> alignment;
347
348 {
349 SVCERR << "TransformDTWAligner[" << this
350 << "]: serialising DTW to avoid over-allocation" << endl;
351 static QMutex mutex;
352 QMutexLocker locker(&mutex);
353
354 alignment = dtw.alignSeries(s1, s2);
355 }
356
357 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: "
358 << "DTW produced " << alignment.size() << " points:" << endl;
359 for (int i = 0; i < alignment.size() && i < 100; ++i) {
360 SVCERR << alignment[i] << " ";
361 }
362 SVCERR << endl;
363
364 auto alignmentProgressModel =
365 ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel);
366 if (alignmentProgressModel) {
367 alignmentProgressModel->setCompletion(100);
368 }
369
370 // clear the alignment progress model
371 alignmentModel->setPathFrom(ModelId());
372
373 sv_frame_t resolution = referenceOutputSTVM->getResolution();
374 sv_frame_t sourceFrame = 0;
375
376 Path path(referenceOutputSTVM->getSampleRate(), resolution);
377
378 for (size_t m: alignment) {
379 path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution));
380 sourceFrame += resolution;
381 }
382
383 alignmentModel->setPath(path);
384
385 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: Done"
386 << endl;
387
388 m_incomplete = false;
389 return true;
390 }