comparison align/TransformDTWAligner.cpp @ 771:1d6cca5a5621 pitch-align

Allow use of proper sparse models (i.e. retaining event time info) in alignment; use this to switch to note alignment, which is what we have most recently been doing in the external program. Not currently producing correct results, though.
author Chris Cannam
date Fri, 29 May 2020 17:39:02 +0100
parents a316cb6fed81
children 699b5b130ea2
comparison
equal deleted inserted replaced
770:486add472c3f 771:1d6cca5a5621
14 14
15 #include "TransformDTWAligner.h" 15 #include "TransformDTWAligner.h"
16 #include "DTW.h" 16 #include "DTW.h"
17 17
18 #include "data/model/SparseTimeValueModel.h" 18 #include "data/model/SparseTimeValueModel.h"
19 #include "data/model/NoteModel.h"
19 #include "data/model/RangeSummarisableTimeValueModel.h" 20 #include "data/model/RangeSummarisableTimeValueModel.h"
20 #include "data/model/AlignmentModel.h" 21 #include "data/model/AlignmentModel.h"
21 #include "data/model/AggregateWaveModel.h" 22 #include "data/model/AggregateWaveModel.h"
22 23
23 #include "framework/Document.h" 24 #include "framework/Document.h"
28 #include <QSettings> 29 #include <QSettings>
29 #include <QMutex> 30 #include <QMutex>
30 #include <QMutexLocker> 31 #include <QMutexLocker>
31 32
32 using std::vector; 33 using std::vector;
34
35 static
36 TransformDTWAligner::MagnitudePreprocessor identityMagnitudePreprocessor =
37 [](double x) {
38 return x;
39 };
40
41 static
42 TransformDTWAligner::RiseFallPreprocessor identityRiseFallPreprocessor =
43 [](double prev, double curr) {
44 if (prev == curr) {
45 return RiseFallDTW::Value({ RiseFallDTW::Direction::None, 0.0 });
46 } else if (curr > prev) {
47 return RiseFallDTW::Value({ RiseFallDTW::Direction::Up, curr - prev });
48 } else {
49 return RiseFallDTW::Value({ RiseFallDTW::Direction::Down, prev - curr });
50 }
51 };
52
53 QMutex
54 TransformDTWAligner::m_dtwMutex;
33 55
34 TransformDTWAligner::TransformDTWAligner(Document *doc, 56 TransformDTWAligner::TransformDTWAligner(Document *doc,
35 ModelId reference, 57 ModelId reference,
36 ModelId toAlign, 58 ModelId toAlign,
37 Transform transform, 59 Transform transform,
40 m_reference(reference), 62 m_reference(reference),
41 m_toAlign(toAlign), 63 m_toAlign(toAlign),
42 m_transform(transform), 64 m_transform(transform),
43 m_dtwType(dtwType), 65 m_dtwType(dtwType),
44 m_incomplete(true), 66 m_incomplete(true),
45 m_outputPreprocessor([](double x) { return x; }) 67 m_magnitudePreprocessor(identityMagnitudePreprocessor),
68 m_riseFallPreprocessor(identityRiseFallPreprocessor)
46 { 69 {
47 } 70 }
48 71
49 TransformDTWAligner::TransformDTWAligner(Document *doc, 72 TransformDTWAligner::TransformDTWAligner(Document *doc,
50 ModelId reference, 73 ModelId reference,
51 ModelId toAlign, 74 ModelId toAlign,
52 Transform transform, 75 Transform transform,
53 DTWType dtwType, 76 MagnitudePreprocessor outputPreprocessor) :
54 std::function<double(double)>
55 outputPreprocessor) :
56 m_document(doc), 77 m_document(doc),
57 m_reference(reference), 78 m_reference(reference),
58 m_toAlign(toAlign), 79 m_toAlign(toAlign),
59 m_transform(transform), 80 m_transform(transform),
60 m_dtwType(dtwType), 81 m_dtwType(Magnitude),
61 m_incomplete(true), 82 m_incomplete(true),
62 m_outputPreprocessor(outputPreprocessor) 83 m_magnitudePreprocessor(outputPreprocessor),
84 m_riseFallPreprocessor(identityRiseFallPreprocessor)
85 {
86 }
87
88 TransformDTWAligner::TransformDTWAligner(Document *doc,
89 ModelId reference,
90 ModelId toAlign,
91 Transform transform,
92 RiseFallPreprocessor outputPreprocessor) :
93 m_document(doc),
94 m_reference(reference),
95 m_toAlign(toAlign),
96 m_transform(transform),
97 m_dtwType(RiseFall),
98 m_incomplete(true),
99 m_magnitudePreprocessor(identityMagnitudePreprocessor),
100 m_riseFallPreprocessor(outputPreprocessor)
63 { 101 {
64 } 102 }
65 103
66 TransformDTWAligner::~TransformDTWAligner() 104 TransformDTWAligner::~TransformDTWAligner()
67 { 105 {
155 TransformDTWAligner::completionChanged(ModelId id) 193 TransformDTWAligner::completionChanged(ModelId id)
156 { 194 {
157 if (!m_incomplete) { 195 if (!m_incomplete) {
158 return; 196 return;
159 } 197 }
160 198 /*
161 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " 199 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
162 << "model " << id << endl; 200 << "model " << id << endl;
163 201 */
164 auto referenceOutputModel = ModelById::get(m_referenceOutputModel); 202 auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
165 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel); 203 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
166 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel); 204 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
167 205
168 if (!referenceOutputModel || !toAlignOutputModel || !alignmentModel) { 206 if (!referenceOutputModel || !toAlignOutputModel || !alignmentModel) {
174 bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion); 212 bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion);
175 213
176 if (referenceReady && toAlignReady) { 214 if (referenceReady && toAlignReady) {
177 215
178 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " 216 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
179 << "ready, calling performAlignment" << endl; 217 << "both models ready, calling performAlignment" << endl;
180 218
181 alignmentModel->setCompletion(95); 219 alignmentModel->setCompletion(95);
182 220
183 if (performAlignment()) { 221 if (performAlignment()) {
184 emit complete(m_alignmentModel); 222 emit complete(m_alignmentModel);
185 } else { 223 } else {
186 emit failed(m_toAlign, tr("Alignment of transform outputs failed")); 224 emit failed(m_toAlign, tr("Alignment of transform outputs failed"));
187 } 225 }
188 226
189 } else { 227 } else {
190 228 /*
191 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " 229 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
192 << "not ready yet: reference completion " << referenceCompletion 230 << "not ready yet: reference completion " << referenceCompletion
193 << ", toAlign completion " << toAlignCompletion << endl; 231 << ", toAlign completion " << toAlignCompletion << endl;
194 232 */
195 int completion = std::min(referenceCompletion, 233 int completion = std::min(referenceCompletion,
196 toAlignCompletion); 234 toAlignCompletion);
197 completion = (completion * 94) / 100; 235 completion = (completion * 94) / 100;
198 alignmentModel->setCompletion(completion); 236 alignmentModel->setCompletion(completion);
199 } 237 }
208 return performAlignmentRiseFall(); 246 return performAlignmentRiseFall();
209 } 247 }
210 } 248 }
211 249
212 bool 250 bool
251 TransformDTWAligner::getValuesFrom(ModelId modelId,
252 vector<sv_frame_t> &frames,
253 vector<double> &values,
254 sv_frame_t &resolution)
255 {
256 EventVector events;
257
258 if (auto model = ModelById::getAs<SparseTimeValueModel>(modelId)) {
259 resolution = model->getResolution();
260 events = model->getAllEvents();
261 } else if (auto model = ModelById::getAs<NoteModel>(modelId)) {
262 resolution = model->getResolution();
263 events = model->getAllEvents();
264 } else {
265 SVCERR << "TransformDTWAligner::getValuesFrom: Type of model "
266 << modelId << " is not supported" << endl;
267 return false;
268 }
269
270 frames.clear();
271 values.clear();
272
273 for (auto e: events) {
274 frames.push_back(e.getFrame());
275 values.push_back(e.getValue());
276 }
277
278 return true;
279 }
280
281 Path
282 TransformDTWAligner::makePath(const vector<size_t> &alignment,
283 const vector<sv_frame_t> &refFrames,
284 const vector<sv_frame_t> &otherFrames,
285 sv_samplerate_t sampleRate,
286 sv_frame_t resolution)
287 {
288 Path path(sampleRate, resolution);
289
290 for (int i = 0; in_range_for(alignment, i); ++i) {
291
292 // DTW returns "the index into s2 for each element in s1"
293 sv_frame_t refFrame = refFrames[i];
294
295 if (!in_range_for(otherFrames, alignment[i])) {
296 SVCERR << "TransformDTWAligner::makePath: Internal error: "
297 << "DTW maps index " << i << " in reference frame vector "
298 << "(size " << refFrames.size() << ") onto index "
299 << alignment[i] << " in other frame vector "
300 << "(only size " << otherFrames.size() << ")" << endl;
301 continue;
302 }
303
304 sv_frame_t alignedFrame = otherFrames[alignment[i]];
305 path.add(PathPoint(alignedFrame, refFrame));
306 }
307
308 return path;
309 }
310
311 bool
213 TransformDTWAligner::performAlignmentMagnitude() 312 TransformDTWAligner::performAlignmentMagnitude()
214 { 313 {
215 auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel> 314 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
216 (m_referenceOutputModel);
217 auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel>
218 (m_toAlignOutputModel);
219 auto alignmentModel = ModelById::getAs<AlignmentModel>
220 (m_alignmentModel);
221
222 if (!referenceOutputSTVM || !toAlignOutputSTVM) {
223 //!!! what?
224 return false;
225 }
226
227 if (!alignmentModel) { 315 if (!alignmentModel) {
228 return false; 316 return false;
229 } 317 }
318
319 vector<sv_frame_t> refFrames, otherFrames;
320 vector<double> refValues, otherValues;
321 sv_frame_t resolution = 0;
322
323 if (!getValuesFrom(m_referenceOutputModel,
324 refFrames, refValues, resolution)) {
325 return false;
326 }
327
328 if (!getValuesFrom(m_toAlignOutputModel,
329 otherFrames, otherValues, resolution)) {
330 return false;
331 }
230 332
231 vector<double> s1, s2; 333 vector<double> s1, s2;
232 334 for (double v: refValues) {
233 { 335 s1.push_back(m_magnitudePreprocessor(v));
234 auto events = referenceOutputSTVM->getAllEvents(); 336 }
235 for (auto e: events) { 337 for (double v: otherValues) {
236 s1.push_back(m_outputPreprocessor(e.getValue())); 338 s2.push_back(m_magnitudePreprocessor(v));
237 }
238 events = toAlignOutputSTVM->getAllEvents();
239 for (auto e: events) {
240 s2.push_back(m_outputPreprocessor(e.getValue()));
241 }
242 } 339 }
243 340
244 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: " 341 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
245 << "Have " << s1.size() << " events from reference, " 342 << "Have " << s1.size() << " events from reference, "
246 << s2.size() << " from toAlign" << endl; 343 << s2.size() << " from toAlign" << endl;
247 344
248 MagnitudeDTW dtw; 345 MagnitudeDTW dtw;
249 vector<size_t> alignment; 346 vector<size_t> alignment;
250 347
251 { 348 {
252 SVCERR << "TransformDTWAligner[" << this 349 SVCERR << "TransformDTWAligner[" << this
253 << "]: serialising DTW to avoid over-allocation" << endl; 350 << "]: serialising DTW to avoid over-allocation" << endl;
254 static QMutex mutex; 351 QMutexLocker locker(&m_dtwMutex);
255 QMutexLocker locker(&mutex);
256
257 alignment = dtw.alignSeries(s1, s2); 352 alignment = dtw.alignSeries(s1, s2);
258 } 353 }
259 354
260 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: " 355 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
261 << "DTW produced " << alignment.size() << " points:" << endl; 356 << "DTW produced " << alignment.size() << " points:" << endl;
262 for (int i = 0; i < alignment.size() && i < 100; ++i) { 357 for (int i = 0; in_range_for(alignment, i) && i < 100; ++i) {
263 SVCERR << alignment[i] << " "; 358 SVCERR << alignment[i] << " ";
264 } 359 }
265 SVCERR << endl; 360 SVCERR << endl;
266 361
362 alignmentModel->setPath(makePath(alignment,
363 refFrames,
364 otherFrames,
365 alignmentModel->getSampleRate(),
366 resolution));
267 alignmentModel->setCompletion(100); 367 alignmentModel->setCompletion(100);
268 368
269 sv_frame_t resolution = referenceOutputSTVM->getResolution(); 369 SVCERR << "TransformDTWAligner[" << this
270 sv_frame_t sourceFrame = 0; 370 << "]: performAlignmentMagnitude: Done" << endl;
271
272 Path path(referenceOutputSTVM->getSampleRate(), resolution);
273
274 for (size_t m: alignment) {
275 path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution));
276 sourceFrame += resolution;
277 }
278
279 alignmentModel->setPath(path);
280
281 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: Done"
282 << endl;
283 371
284 m_incomplete = false; 372 m_incomplete = false;
285 return true; 373 return true;
286 } 374 }
287 375
288 bool 376 bool
289 TransformDTWAligner::performAlignmentRiseFall() 377 TransformDTWAligner::performAlignmentRiseFall()
290 { 378 {
291 auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel> 379 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
292 (m_referenceOutputModel);
293 auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel>
294 (m_toAlignOutputModel);
295 auto alignmentModel = ModelById::getAs<AlignmentModel>
296 (m_alignmentModel);
297
298 if (!referenceOutputSTVM || !toAlignOutputSTVM) {
299 //!!! what?
300 return false;
301 }
302
303 if (!alignmentModel) { 380 if (!alignmentModel) {
304 return false; 381 return false;
305 } 382 }
306 383
307 auto convertEvents = 384 vector<sv_frame_t> refFrames, otherFrames;
308 [this](const EventVector &ee) { 385 vector<double> refValues, otherValues;
386 sv_frame_t resolution = 0;
387
388 if (!getValuesFrom(m_referenceOutputModel,
389 refFrames, refValues, resolution)) {
390 return false;
391 }
392
393 if (!getValuesFrom(m_toAlignOutputModel,
394 otherFrames, otherValues, resolution)) {
395 return false;
396 }
397
398 auto preprocess =
399 [this](const std::vector<double> &vv) {
309 vector<RiseFallDTW::Value> s; 400 vector<RiseFallDTW::Value> s;
310 double prev = 0.0; 401 double prev = 0.0;
311 for (auto e: ee) { 402 for (auto curr: vv) {
312 double v = m_outputPreprocessor(e.getValue()); 403 s.push_back(m_riseFallPreprocessor(prev, curr));
313 if (v == prev || s.empty()) { 404 prev = curr;
314 s.push_back({ RiseFallDTW::Direction::None, 0.0 });
315 } else if (v > prev) {
316 s.push_back({ RiseFallDTW::Direction::Up, v - prev });
317 } else {
318 s.push_back({ RiseFallDTW::Direction::Down, prev - v });
319 }
320 } 405 }
321 return s; 406 return s;
322 }; 407 };
323 408
324 vector<RiseFallDTW::Value> s1 = 409 vector<RiseFallDTW::Value> s1 = preprocess(refValues);
325 convertEvents(referenceOutputSTVM->getAllEvents()); 410 vector<RiseFallDTW::Value> s2 = preprocess(otherValues);
326
327 vector<RiseFallDTW::Value> s2 =
328 convertEvents(toAlignOutputSTVM->getAllEvents());
329 411
330 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: " 412 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
331 << "Have " << s1.size() << " events from reference, " 413 << "Have " << s1.size() << " events from reference, "
332 << s2.size() << " from toAlign" << endl; 414 << s2.size() << " from toAlign" << endl;
333 415
416 SVCERR << "Reference:" << endl;
417 for (int i = 0; in_range_for(s1, i) && i < 100; ++i) {
418 SVCERR << s1[i] << " ";
419 }
420 SVCERR << endl;
421
422 SVCERR << "toAlign:" << endl;
423 for (int i = 0; in_range_for(s2, i) && i < 100; ++i) {
424 SVCERR << s2[i] << " ";
425 }
426 SVCERR << endl;
427
334 RiseFallDTW dtw; 428 RiseFallDTW dtw;
335
336 vector<size_t> alignment; 429 vector<size_t> alignment;
337 430
338 { 431 {
339 SVCERR << "TransformDTWAligner[" << this 432 SVCERR << "TransformDTWAligner[" << this
340 << "]: serialising DTW to avoid over-allocation" << endl; 433 << "]: serialising DTW to avoid over-allocation" << endl;
341 static QMutex mutex; 434 QMutexLocker locker(&m_dtwMutex);
342 QMutexLocker locker(&mutex);
343
344 alignment = dtw.alignSeries(s1, s2); 435 alignment = dtw.alignSeries(s1, s2);
345 } 436 }
346 437
347 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: " 438 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
348 << "DTW produced " << alignment.size() << " points:" << endl; 439 << "DTW produced " << alignment.size() << " points:" << endl;
349 for (int i = 0; i < alignment.size() && i < 100; ++i) { 440 for (int i = 0; i < alignment.size() && i < 100; ++i) {
350 SVCERR << alignment[i] << " "; 441 SVCERR << alignment[i] << " ";
351 } 442 }
352 SVCERR << endl; 443 SVCERR << endl;
353 444
445 alignmentModel->setPath(makePath(alignment,
446 refFrames,
447 otherFrames,
448 alignmentModel->getSampleRate(),
449 resolution));
450
354 alignmentModel->setCompletion(100); 451 alignmentModel->setCompletion(100);
355
356 sv_frame_t resolution = referenceOutputSTVM->getResolution();
357 sv_frame_t sourceFrame = 0;
358
359 Path path(referenceOutputSTVM->getSampleRate(), resolution);
360
361 for (size_t m: alignment) {
362 path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution));
363 sourceFrame += resolution;
364 }
365
366 alignmentModel->setPath(path);
367 452
368 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: Done" 453 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: Done"
369 << endl; 454 << endl;
370 455
371 m_incomplete = false; 456 m_incomplete = false;