Mercurial > hg > svapp
comparison align/TransformDTWAligner.cpp @ 767:dd742e566e60 pitch-align
Make a start on further alignment methods
author | Chris Cannam |
---|---|
date | Thu, 21 May 2020 16:21:57 +0100 |
parents | |
children | 1b1960009be6 |
comparison
equal
deleted
inserted
replaced
761:6429a164b7e1 | 767:dd742e566e60 |
---|---|
1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */ | |
2 | |
3 /* | |
4 Sonic Visualiser | |
5 An audio file viewer and annotation editor. | |
6 Centre for Digital Music, Queen Mary, University of London. | |
7 | |
8 This program is free software; you can redistribute it and/or | |
9 modify it under the terms of the GNU General Public License as | |
10 published by the Free Software Foundation; either version 2 of the | |
11 License, or (at your option) any later version. See the file | |
12 COPYING included with this distribution for more information. | |
13 */ | |
14 | |
15 #include "TransformDTWAligner.h" | |
16 #include "DTW.h" | |
17 | |
18 #include "data/model/SparseTimeValueModel.h" | |
19 #include "data/model/RangeSummarisableTimeValueModel.h" | |
20 #include "data/model/AlignmentModel.h" | |
21 #include "data/model/AggregateWaveModel.h" | |
22 | |
23 #include "framework/Document.h" | |
24 | |
25 #include "transform/ModelTransformerFactory.h" | |
26 #include "transform/FeatureExtractionModelTransformer.h" | |
27 | |
28 #include <QSettings> | |
29 #include <QMutex> | |
30 #include <QMutexLocker> | |
31 | |
32 using std::vector; | |
33 | |
34 TransformDTWAligner::TransformDTWAligner(Document *doc, | |
35 ModelId reference, | |
36 ModelId toAlign, | |
37 Transform transform, | |
38 DTWType dtwType) : | |
39 m_document(doc), | |
40 m_reference(reference), | |
41 m_toAlign(toAlign), | |
42 m_referenceTransformComplete(false), | |
43 m_toAlignTransformComplete(false), | |
44 m_transform(transform), | |
45 m_dtwType(dtwType), | |
46 m_incomplete(true) | |
47 { | |
48 } | |
49 | |
50 TransformDTWAligner::~TransformDTWAligner() | |
51 { | |
52 if (m_incomplete) { | |
53 if (auto toAlign = ModelById::get(m_toAlign)) { | |
54 toAlign->setAlignment({}); | |
55 } | |
56 } | |
57 | |
58 ModelById::release(m_referenceOutputModel); | |
59 ModelById::release(m_toAlignOutputModel); | |
60 ModelById::release(m_alignmentProgressModel); | |
61 } | |
62 | |
63 bool | |
64 TransformDTWAligner::isAvailable() | |
65 { | |
66 //!!! needs to be isAvailable(QString transformId)? | |
67 return true; | |
68 } | |
69 | |
70 void | |
71 TransformDTWAligner::begin() | |
72 { | |
73 auto reference = | |
74 ModelById::getAs<RangeSummarisableTimeValueModel>(m_reference); | |
75 auto toAlign = | |
76 ModelById::getAs<RangeSummarisableTimeValueModel>(m_toAlign); | |
77 | |
78 if (!reference || !toAlign) return; | |
79 | |
80 SVCERR << "TransformDTWAligner[" << this << "]: begin(): aligning " | |
81 << m_toAlign << " against reference " << m_reference << endl; | |
82 | |
83 ModelTransformerFactory *mtf = ModelTransformerFactory::getInstance(); | |
84 | |
85 QString message; | |
86 | |
87 m_referenceOutputModel = mtf->transform(m_transform, m_reference, message); | |
88 auto referenceOutputModel = ModelById::get(m_referenceOutputModel); | |
89 if (!referenceOutputModel) { | |
90 SVCERR << "Align::alignModel: ERROR: Failed to create reference output model (no plugin?)" << endl; | |
91 emit failed(m_toAlign, message); | |
92 return; | |
93 } | |
94 | |
95 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id " | |
96 << m_transform.getIdentifier() | |
97 << " is running on reference model" << endl; | |
98 | |
99 message = ""; | |
100 | |
101 m_toAlignOutputModel = mtf->transform(m_transform, m_toAlign, message); | |
102 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel); | |
103 if (!toAlignOutputModel) { | |
104 SVCERR << "Align::alignModel: ERROR: Failed to create toAlign output model (no plugin?)" << endl; | |
105 emit failed(m_toAlign, message); | |
106 return; | |
107 } | |
108 | |
109 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id " | |
110 << m_transform.getIdentifier() | |
111 << " is running on toAlign model" << endl; | |
112 | |
113 connect(referenceOutputModel.get(), SIGNAL(completionChanged(ModelId)), | |
114 this, SLOT(completionChanged(ModelId))); | |
115 connect(toAlignOutputModel.get(), SIGNAL(completionChanged(ModelId)), | |
116 this, SLOT(completionChanged(ModelId))); | |
117 | |
118 auto alignmentProgressModel = std::make_shared<SparseTimeValueModel> | |
119 (reference->getSampleRate(), m_transform.getStepSize(), false); | |
120 alignmentProgressModel->setCompletion(0); | |
121 m_alignmentProgressModel = ModelById::add(alignmentProgressModel); | |
122 | |
123 auto alignmentModel = std::make_shared<AlignmentModel> | |
124 (m_reference, m_toAlign, m_alignmentProgressModel); | |
125 m_alignmentModel = ModelById::add(alignmentModel); | |
126 | |
127 toAlign->setAlignment(m_alignmentModel); | |
128 m_document->addNonDerivedModel(m_alignmentModel); | |
129 | |
130 // we wouldn't normally expect these to be true here, but... | |
131 int completion = 0; | |
132 if (referenceOutputModel->isReady(&completion) && | |
133 toAlignOutputModel->isReady(&completion)) { | |
134 SVCERR << "TransformDTWAligner[" << this << "]: begin(): output models " | |
135 << "are ready already! calling performAlignment" << endl; | |
136 if (performAlignment()) { | |
137 emit complete(m_alignmentModel); | |
138 } else { | |
139 emit failed(m_toAlign, tr("Failed to calculate alignment using DTW")); | |
140 } | |
141 } | |
142 } | |
143 | |
144 void | |
145 TransformDTWAligner::completionChanged(ModelId id) | |
146 { | |
147 if (!m_incomplete) { | |
148 return; | |
149 } | |
150 | |
151 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " | |
152 << "model " << id << endl; | |
153 | |
154 auto referenceOutputModel = ModelById::get(m_referenceOutputModel); | |
155 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel); | |
156 | |
157 if (!referenceOutputModel || !toAlignOutputModel) { | |
158 return; | |
159 } | |
160 | |
161 int referenceCompletion = 0, toAlignCompletion = 0; | |
162 bool referenceReady = referenceOutputModel->isReady(&referenceCompletion); | |
163 bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion); | |
164 | |
165 auto alignmentProgressModel = | |
166 ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel); | |
167 | |
168 if (referenceReady && toAlignReady) { | |
169 | |
170 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " | |
171 << "ready, calling performAlignment" << endl; | |
172 | |
173 if (alignmentProgressModel) { | |
174 alignmentProgressModel->setCompletion(95); | |
175 } | |
176 | |
177 if (performAlignment()) { | |
178 emit complete(m_alignmentModel); | |
179 } else { | |
180 emit failed(m_toAlign, tr("Alignment of transform outputs failed")); | |
181 } | |
182 | |
183 } else { | |
184 | |
185 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " | |
186 << "not ready yet: reference completion " << referenceCompletion | |
187 << ", toAlign completion " << toAlignCompletion << endl; | |
188 | |
189 if (alignmentProgressModel) { | |
190 int completion = std::min(referenceCompletion, | |
191 toAlignCompletion); | |
192 completion = (completion * 94) / 100; | |
193 alignmentProgressModel->setCompletion(completion); | |
194 } | |
195 } | |
196 } | |
197 | |
198 bool | |
199 TransformDTWAligner::performAlignment() | |
200 { | |
201 if (m_dtwType == Magnitude) { | |
202 return performAlignmentMagnitude(); | |
203 } else { | |
204 return performAlignmentRiseFall(); | |
205 } | |
206 } | |
207 | |
208 bool | |
209 TransformDTWAligner::performAlignmentMagnitude() | |
210 { | |
211 auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel> | |
212 (m_referenceOutputModel); | |
213 auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel> | |
214 (m_toAlignOutputModel); | |
215 auto alignmentModel = ModelById::getAs<AlignmentModel> | |
216 (m_alignmentModel); | |
217 | |
218 if (!referenceOutputSTVM || !toAlignOutputSTVM) { | |
219 //!!! what? | |
220 return false; | |
221 } | |
222 | |
223 if (!alignmentModel) { | |
224 return false; | |
225 } | |
226 | |
227 vector<double> s1, s2; | |
228 | |
229 { | |
230 auto events = referenceOutputSTVM->getAllEvents(); | |
231 for (auto e: events) { | |
232 s1.push_back(e.getValue()); | |
233 } | |
234 events = toAlignOutputSTVM->getAllEvents(); | |
235 for (auto e: events) { | |
236 s2.push_back(e.getValue()); | |
237 } | |
238 } | |
239 | |
240 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: " | |
241 << "Have " << s1.size() << " events from reference, " | |
242 << s2.size() << " from toAlign" << endl; | |
243 | |
244 MagnitudeDTW dtw; | |
245 vector<size_t> alignment; | |
246 | |
247 { | |
248 SVCERR << "TransformDTWAligner[" << this | |
249 << "]: serialising DTW to avoid over-allocation" << endl; | |
250 static QMutex mutex; | |
251 QMutexLocker locker(&mutex); | |
252 | |
253 alignment = dtw.alignSeries(s1, s2); | |
254 } | |
255 | |
256 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: " | |
257 << "DTW produced " << alignment.size() << " points:" << endl; | |
258 for (int i = 0; i < alignment.size() && i < 100; ++i) { | |
259 SVCERR << alignment[i] << " "; | |
260 } | |
261 SVCERR << endl; | |
262 | |
263 auto alignmentProgressModel = | |
264 ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel); | |
265 if (alignmentProgressModel) { | |
266 alignmentProgressModel->setCompletion(100); | |
267 } | |
268 | |
269 // clear the alignment progress model | |
270 alignmentModel->setPathFrom(ModelId()); | |
271 | |
272 sv_frame_t resolution = referenceOutputSTVM->getResolution(); | |
273 sv_frame_t sourceFrame = 0; | |
274 | |
275 Path path(referenceOutputSTVM->getSampleRate(), resolution); | |
276 | |
277 for (size_t m: alignment) { | |
278 path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution)); | |
279 sourceFrame += resolution; | |
280 } | |
281 | |
282 alignmentModel->setPath(path); | |
283 | |
284 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: Done" | |
285 << endl; | |
286 | |
287 m_incomplete = false; | |
288 return true; | |
289 } | |
290 | |
291 bool | |
292 TransformDTWAligner::performAlignmentRiseFall() | |
293 { | |
294 auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel> | |
295 (m_referenceOutputModel); | |
296 auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel> | |
297 (m_toAlignOutputModel); | |
298 auto alignmentModel = ModelById::getAs<AlignmentModel> | |
299 (m_alignmentModel); | |
300 | |
301 if (!referenceOutputSTVM || !toAlignOutputSTVM) { | |
302 //!!! what? | |
303 return false; | |
304 } | |
305 | |
306 if (!alignmentModel) { | |
307 return false; | |
308 } | |
309 | |
310 vector<RiseFallDTW::Value> s1, s2; | |
311 double prev1 = 0.0, prev2 = 0.0; | |
312 | |
313 { | |
314 auto events = referenceOutputSTVM->getAllEvents(); | |
315 for (auto e: events) { | |
316 double v = e.getValue(); | |
317 //!!! the original does this using MIDI pitch for the | |
318 //!!! pYin transform... rework with a lambda passed in | |
319 //!!! for modification maybe? + factor out s1/s2 of course | |
320 if (v > prev1) { | |
321 s1.push_back({ RiseFallDTW::Direction::Up, v - prev1 }); | |
322 } else { | |
323 s1.push_back({ RiseFallDTW::Direction::Down, prev1 - v }); | |
324 } | |
325 prev1 = v; | |
326 } | |
327 events = toAlignOutputSTVM->getAllEvents(); | |
328 for (auto e: events) { | |
329 double v = e.getValue(); | |
330 //!!! as above | |
331 if (v > prev2) { | |
332 s2.push_back({ RiseFallDTW::Direction::Up, v - prev2 }); | |
333 } else { | |
334 s2.push_back({ RiseFallDTW::Direction::Down, prev2 - v }); | |
335 } | |
336 prev2 = v; | |
337 } | |
338 } | |
339 | |
340 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: " | |
341 << "Have " << s1.size() << " events from reference, " | |
342 << s2.size() << " from toAlign" << endl; | |
343 | |
344 RiseFallDTW dtw; | |
345 | |
346 vector<size_t> alignment; | |
347 | |
348 { | |
349 SVCERR << "TransformDTWAligner[" << this | |
350 << "]: serialising DTW to avoid over-allocation" << endl; | |
351 static QMutex mutex; | |
352 QMutexLocker locker(&mutex); | |
353 | |
354 alignment = dtw.alignSeries(s1, s2); | |
355 } | |
356 | |
357 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: " | |
358 << "DTW produced " << alignment.size() << " points:" << endl; | |
359 for (int i = 0; i < alignment.size() && i < 100; ++i) { | |
360 SVCERR << alignment[i] << " "; | |
361 } | |
362 SVCERR << endl; | |
363 | |
364 auto alignmentProgressModel = | |
365 ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel); | |
366 if (alignmentProgressModel) { | |
367 alignmentProgressModel->setCompletion(100); | |
368 } | |
369 | |
370 // clear the alignment progress model | |
371 alignmentModel->setPathFrom(ModelId()); | |
372 | |
373 sv_frame_t resolution = referenceOutputSTVM->getResolution(); | |
374 sv_frame_t sourceFrame = 0; | |
375 | |
376 Path path(referenceOutputSTVM->getSampleRate(), resolution); | |
377 | |
378 for (size_t m: alignment) { | |
379 path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution)); | |
380 sourceFrame += resolution; | |
381 } | |
382 | |
383 alignmentModel->setPath(path); | |
384 | |
385 SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: Done" | |
386 << endl; | |
387 | |
388 m_incomplete = false; | |
389 return true; | |
390 } |