annotate align/TransformDTWAligner.cpp @ 785:e136dd3bb5c6

Permit setting the default alignment preference
author Chris Cannam
date Wed, 05 Aug 2020 16:05:51 +0100
parents 700fc9e4852d
children
rev   line source
Chris@767 1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
Chris@767 2
Chris@767 3 /*
Chris@767 4 Sonic Visualiser
Chris@767 5 An audio file viewer and annotation editor.
Chris@767 6 Centre for Digital Music, Queen Mary, University of London.
Chris@767 7
Chris@767 8 This program is free software; you can redistribute it and/or
Chris@767 9 modify it under the terms of the GNU General Public License as
Chris@767 10 published by the Free Software Foundation; either version 2 of the
Chris@767 11 License, or (at your option) any later version. See the file
Chris@767 12 COPYING included with this distribution for more information.
Chris@767 13 */
Chris@767 14
Chris@767 15 #include "TransformDTWAligner.h"
Chris@767 16 #include "DTW.h"
Chris@767 17
Chris@767 18 #include "data/model/SparseTimeValueModel.h"
Chris@771 19 #include "data/model/NoteModel.h"
Chris@767 20 #include "data/model/RangeSummarisableTimeValueModel.h"
Chris@767 21 #include "data/model/AlignmentModel.h"
Chris@767 22 #include "data/model/AggregateWaveModel.h"
Chris@767 23
Chris@767 24 #include "framework/Document.h"
Chris@767 25
Chris@767 26 #include "transform/ModelTransformerFactory.h"
Chris@767 27 #include "transform/FeatureExtractionModelTransformer.h"
Chris@767 28
Chris@767 29 #include <QSettings>
Chris@767 30 #include <QMutex>
Chris@767 31 #include <QMutexLocker>
Chris@767 32
Chris@767 33 using std::vector;
Chris@767 34
Chris@771 35 static
Chris@771 36 TransformDTWAligner::MagnitudePreprocessor identityMagnitudePreprocessor =
Chris@771 37 [](double x) {
Chris@771 38 return x;
Chris@771 39 };
Chris@771 40
Chris@771 41 static
Chris@771 42 TransformDTWAligner::RiseFallPreprocessor identityRiseFallPreprocessor =
Chris@771 43 [](double prev, double curr) {
Chris@771 44 if (prev == curr) {
Chris@771 45 return RiseFallDTW::Value({ RiseFallDTW::Direction::None, 0.0 });
Chris@771 46 } else if (curr > prev) {
Chris@771 47 return RiseFallDTW::Value({ RiseFallDTW::Direction::Up, curr - prev });
Chris@771 48 } else {
Chris@771 49 return RiseFallDTW::Value({ RiseFallDTW::Direction::Down, prev - curr });
Chris@771 50 }
Chris@771 51 };
Chris@771 52
Chris@771 53 QMutex
Chris@771 54 TransformDTWAligner::m_dtwMutex;
Chris@771 55
Chris@767 56 TransformDTWAligner::TransformDTWAligner(Document *doc,
Chris@767 57 ModelId reference,
Chris@767 58 ModelId toAlign,
Chris@781 59 bool subsequence,
Chris@767 60 Transform transform,
Chris@767 61 DTWType dtwType) :
Chris@767 62 m_document(doc),
Chris@767 63 m_reference(reference),
Chris@767 64 m_toAlign(toAlign),
Chris@767 65 m_transform(transform),
Chris@767 66 m_dtwType(dtwType),
Chris@781 67 m_subsequence(subsequence),
Chris@768 68 m_incomplete(true),
Chris@771 69 m_magnitudePreprocessor(identityMagnitudePreprocessor),
Chris@771 70 m_riseFallPreprocessor(identityRiseFallPreprocessor)
Chris@768 71 {
Chris@768 72 }
Chris@768 73
Chris@768 74 TransformDTWAligner::TransformDTWAligner(Document *doc,
Chris@768 75 ModelId reference,
Chris@768 76 ModelId toAlign,
Chris@781 77 bool subsequence,
Chris@768 78 Transform transform,
Chris@771 79 MagnitudePreprocessor outputPreprocessor) :
Chris@768 80 m_document(doc),
Chris@768 81 m_reference(reference),
Chris@768 82 m_toAlign(toAlign),
Chris@768 83 m_transform(transform),
Chris@771 84 m_dtwType(Magnitude),
Chris@781 85 m_subsequence(subsequence),
Chris@768 86 m_incomplete(true),
Chris@771 87 m_magnitudePreprocessor(outputPreprocessor),
Chris@771 88 m_riseFallPreprocessor(identityRiseFallPreprocessor)
Chris@771 89 {
Chris@771 90 }
Chris@771 91
Chris@771 92 TransformDTWAligner::TransformDTWAligner(Document *doc,
Chris@771 93 ModelId reference,
Chris@771 94 ModelId toAlign,
Chris@781 95 bool subsequence,
Chris@771 96 Transform transform,
Chris@771 97 RiseFallPreprocessor outputPreprocessor) :
Chris@771 98 m_document(doc),
Chris@771 99 m_reference(reference),
Chris@771 100 m_toAlign(toAlign),
Chris@771 101 m_transform(transform),
Chris@771 102 m_dtwType(RiseFall),
Chris@781 103 m_subsequence(subsequence),
Chris@771 104 m_incomplete(true),
Chris@771 105 m_magnitudePreprocessor(identityMagnitudePreprocessor),
Chris@771 106 m_riseFallPreprocessor(outputPreprocessor)
Chris@767 107 {
Chris@767 108 }
Chris@767 109
Chris@767 110 TransformDTWAligner::~TransformDTWAligner()
Chris@767 111 {
Chris@767 112 if (m_incomplete) {
Chris@767 113 if (auto toAlign = ModelById::get(m_toAlign)) {
Chris@767 114 toAlign->setAlignment({});
Chris@767 115 }
Chris@767 116 }
Chris@767 117
Chris@767 118 ModelById::release(m_referenceOutputModel);
Chris@767 119 ModelById::release(m_toAlignOutputModel);
Chris@767 120 }
Chris@767 121
Chris@767 122 bool
Chris@767 123 TransformDTWAligner::isAvailable()
Chris@767 124 {
Chris@767 125 //!!! needs to be isAvailable(QString transformId)?
Chris@767 126 return true;
Chris@767 127 }
Chris@767 128
Chris@767 129 void
Chris@767 130 TransformDTWAligner::begin()
Chris@767 131 {
Chris@767 132 auto reference =
Chris@767 133 ModelById::getAs<RangeSummarisableTimeValueModel>(m_reference);
Chris@767 134 auto toAlign =
Chris@767 135 ModelById::getAs<RangeSummarisableTimeValueModel>(m_toAlign);
Chris@767 136
Chris@767 137 if (!reference || !toAlign) return;
Chris@767 138
Chris@767 139 SVCERR << "TransformDTWAligner[" << this << "]: begin(): aligning "
Chris@767 140 << m_toAlign << " against reference " << m_reference << endl;
Chris@767 141
Chris@767 142 ModelTransformerFactory *mtf = ModelTransformerFactory::getInstance();
Chris@767 143
Chris@767 144 QString message;
Chris@767 145
Chris@767 146 m_referenceOutputModel = mtf->transform(m_transform, m_reference, message);
Chris@767 147 auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
Chris@767 148 if (!referenceOutputModel) {
Chris@767 149 SVCERR << "Align::alignModel: ERROR: Failed to create reference output model (no plugin?)" << endl;
Chris@767 150 emit failed(m_toAlign, message);
Chris@767 151 return;
Chris@767 152 }
Chris@767 153
Chris@773 154 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 155 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
Chris@767 156 << m_transform.getIdentifier()
Chris@767 157 << " is running on reference model" << endl;
Chris@773 158 #endif
Chris@767 159
Chris@767 160 message = "";
Chris@767 161
Chris@767 162 m_toAlignOutputModel = mtf->transform(m_transform, m_toAlign, message);
Chris@767 163 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
Chris@767 164 if (!toAlignOutputModel) {
Chris@767 165 SVCERR << "Align::alignModel: ERROR: Failed to create toAlign output model (no plugin?)" << endl;
Chris@767 166 emit failed(m_toAlign, message);
Chris@767 167 return;
Chris@767 168 }
Chris@767 169
Chris@773 170 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 171 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
Chris@767 172 << m_transform.getIdentifier()
Chris@767 173 << " is running on toAlign model" << endl;
Chris@773 174 #endif
Chris@767 175
Chris@767 176 connect(referenceOutputModel.get(), SIGNAL(completionChanged(ModelId)),
Chris@767 177 this, SLOT(completionChanged(ModelId)));
Chris@767 178 connect(toAlignOutputModel.get(), SIGNAL(completionChanged(ModelId)),
Chris@767 179 this, SLOT(completionChanged(ModelId)));
Chris@767 180
Chris@767 181 auto alignmentModel = std::make_shared<AlignmentModel>
Chris@768 182 (m_reference, m_toAlign, ModelId());
Chris@767 183 m_alignmentModel = ModelById::add(alignmentModel);
Chris@767 184
Chris@767 185 toAlign->setAlignment(m_alignmentModel);
Chris@767 186 m_document->addNonDerivedModel(m_alignmentModel);
Chris@767 187
Chris@767 188 // we wouldn't normally expect these to be true here, but...
Chris@767 189 int completion = 0;
Chris@767 190 if (referenceOutputModel->isReady(&completion) &&
Chris@767 191 toAlignOutputModel->isReady(&completion)) {
Chris@767 192 SVCERR << "TransformDTWAligner[" << this << "]: begin(): output models "
Chris@767 193 << "are ready already! calling performAlignment" << endl;
Chris@767 194 if (performAlignment()) {
Chris@767 195 emit complete(m_alignmentModel);
Chris@767 196 } else {
Chris@767 197 emit failed(m_toAlign, tr("Failed to calculate alignment using DTW"));
Chris@767 198 }
Chris@767 199 }
Chris@767 200 }
Chris@767 201
Chris@767 202 void
Chris@782 203 TransformDTWAligner::completionChanged(ModelId
Chris@782 204 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@782 205 id
Chris@782 206 #endif
Chris@782 207 )
Chris@767 208 {
Chris@767 209 if (!m_incomplete) {
Chris@767 210 return;
Chris@767 211 }
Chris@773 212 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 213 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
Chris@767 214 << "model " << id << endl;
Chris@773 215 #endif
Chris@773 216
Chris@767 217 auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
Chris@767 218 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
Chris@768 219 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
Chris@767 220
Chris@768 221 if (!referenceOutputModel || !toAlignOutputModel || !alignmentModel) {
Chris@767 222 return;
Chris@767 223 }
Chris@767 224
Chris@767 225 int referenceCompletion = 0, toAlignCompletion = 0;
Chris@767 226 bool referenceReady = referenceOutputModel->isReady(&referenceCompletion);
Chris@767 227 bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion);
Chris@767 228
Chris@767 229 if (referenceReady && toAlignReady) {
Chris@767 230
Chris@767 231 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
Chris@771 232 << "both models ready, calling performAlignment" << endl;
Chris@767 233
Chris@768 234 alignmentModel->setCompletion(95);
Chris@767 235
Chris@767 236 if (performAlignment()) {
Chris@767 237 emit complete(m_alignmentModel);
Chris@767 238 } else {
Chris@767 239 emit failed(m_toAlign, tr("Alignment of transform outputs failed"));
Chris@767 240 }
Chris@767 241
Chris@767 242 } else {
Chris@773 243 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 244 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
Chris@767 245 << "not ready yet: reference completion " << referenceCompletion
Chris@767 246 << ", toAlign completion " << toAlignCompletion << endl;
Chris@773 247 #endif
Chris@773 248
Chris@768 249 int completion = std::min(referenceCompletion,
Chris@768 250 toAlignCompletion);
Chris@768 251 completion = (completion * 94) / 100;
Chris@768 252 alignmentModel->setCompletion(completion);
Chris@767 253 }
Chris@767 254 }
Chris@767 255
Chris@767 256 bool
Chris@767 257 TransformDTWAligner::performAlignment()
Chris@767 258 {
Chris@767 259 if (m_dtwType == Magnitude) {
Chris@767 260 return performAlignmentMagnitude();
Chris@767 261 } else {
Chris@767 262 return performAlignmentRiseFall();
Chris@767 263 }
Chris@767 264 }
Chris@767 265
Chris@767 266 bool
Chris@771 267 TransformDTWAligner::getValuesFrom(ModelId modelId,
Chris@771 268 vector<sv_frame_t> &frames,
Chris@771 269 vector<double> &values,
Chris@771 270 sv_frame_t &resolution)
Chris@767 271 {
Chris@771 272 EventVector events;
Chris@767 273
Chris@771 274 if (auto model = ModelById::getAs<SparseTimeValueModel>(modelId)) {
Chris@771 275 resolution = model->getResolution();
Chris@771 276 events = model->getAllEvents();
Chris@771 277 } else if (auto model = ModelById::getAs<NoteModel>(modelId)) {
Chris@771 278 resolution = model->getResolution();
Chris@771 279 events = model->getAllEvents();
Chris@771 280 } else {
Chris@771 281 SVCERR << "TransformDTWAligner::getValuesFrom: Type of model "
Chris@771 282 << modelId << " is not supported" << endl;
Chris@767 283 return false;
Chris@767 284 }
Chris@767 285
Chris@771 286 frames.clear();
Chris@771 287 values.clear();
Chris@771 288
Chris@771 289 for (auto e: events) {
Chris@771 290 frames.push_back(e.getFrame());
Chris@771 291 values.push_back(e.getValue());
Chris@771 292 }
Chris@771 293
Chris@771 294 return true;
Chris@771 295 }
Chris@771 296
Chris@771 297 Path
Chris@771 298 TransformDTWAligner::makePath(const vector<size_t> &alignment,
Chris@771 299 const vector<sv_frame_t> &refFrames,
Chris@771 300 const vector<sv_frame_t> &otherFrames,
Chris@771 301 sv_samplerate_t sampleRate,
Chris@771 302 sv_frame_t resolution)
Chris@771 303 {
Chris@782 304 Path path(sampleRate, int(resolution));
Chris@771 305
Chris@773 306 path.add(PathPoint(0, 0));
Chris@773 307
Chris@771 308 for (int i = 0; in_range_for(alignment, i); ++i) {
Chris@771 309
Chris@780 310 // DTW returns "the index into s1 for each element in s2"
Chris@780 311 sv_frame_t alignedFrame = otherFrames[i];
Chris@780 312
Chris@780 313 if (!in_range_for(refFrames, alignment[i])) {
Chris@771 314 SVCERR << "TransformDTWAligner::makePath: Internal error: "
Chris@780 315 << "DTW maps index " << i << " in other frame vector "
Chris@780 316 << "(size " << otherFrames.size() << ") onto index "
Chris@780 317 << alignment[i] << " in ref frame vector "
Chris@780 318 << "(only size " << refFrames.size() << ")" << endl;
Chris@771 319 continue;
Chris@771 320 }
Chris@771 321
Chris@780 322 sv_frame_t refFrame = refFrames[alignment[i]];
Chris@771 323 path.add(PathPoint(alignedFrame, refFrame));
Chris@771 324 }
Chris@771 325
Chris@771 326 return path;
Chris@771 327 }
Chris@771 328
Chris@771 329 bool
Chris@771 330 TransformDTWAligner::performAlignmentMagnitude()
Chris@771 331 {
Chris@771 332 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
Chris@767 333 if (!alignmentModel) {
Chris@767 334 return false;
Chris@767 335 }
Chris@771 336
Chris@771 337 vector<sv_frame_t> refFrames, otherFrames;
Chris@771 338 vector<double> refValues, otherValues;
Chris@771 339 sv_frame_t resolution = 0;
Chris@771 340
Chris@771 341 if (!getValuesFrom(m_referenceOutputModel,
Chris@771 342 refFrames, refValues, resolution)) {
Chris@771 343 return false;
Chris@771 344 }
Chris@771 345
Chris@771 346 if (!getValuesFrom(m_toAlignOutputModel,
Chris@771 347 otherFrames, otherValues, resolution)) {
Chris@771 348 return false;
Chris@771 349 }
Chris@767 350
Chris@767 351 vector<double> s1, s2;
Chris@771 352 for (double v: refValues) {
Chris@771 353 s1.push_back(m_magnitudePreprocessor(v));
Chris@771 354 }
Chris@771 355 for (double v: otherValues) {
Chris@771 356 s2.push_back(m_magnitudePreprocessor(v));
Chris@767 357 }
Chris@767 358
Chris@773 359 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 360 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
Chris@767 361 << "Have " << s1.size() << " events from reference, "
Chris@767 362 << s2.size() << " from toAlign" << endl;
Chris@773 363 #endif
Chris@771 364
Chris@767 365 MagnitudeDTW dtw;
Chris@767 366 vector<size_t> alignment;
Chris@767 367
Chris@767 368 {
Chris@773 369 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 370 SVCERR << "TransformDTWAligner[" << this
Chris@767 371 << "]: serialising DTW to avoid over-allocation" << endl;
Chris@773 372 #endif
Chris@771 373 QMutexLocker locker(&m_dtwMutex);
Chris@781 374 if (m_subsequence) {
Chris@781 375 alignment = dtw.alignSubsequence(s1, s2);
Chris@781 376 } else {
Chris@781 377 alignment = dtw.alignSequences(s1, s2);
Chris@781 378 }
Chris@767 379 }
Chris@767 380
Chris@773 381 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 382 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
Chris@767 383 << "DTW produced " << alignment.size() << " points:" << endl;
Chris@771 384 for (int i = 0; in_range_for(alignment, i) && i < 100; ++i) {
Chris@767 385 SVCERR << alignment[i] << " ";
Chris@767 386 }
Chris@767 387 SVCERR << endl;
Chris@773 388 #endif
Chris@767 389
Chris@771 390 alignmentModel->setPath(makePath(alignment,
Chris@771 391 refFrames,
Chris@771 392 otherFrames,
Chris@771 393 alignmentModel->getSampleRate(),
Chris@771 394 resolution));
Chris@768 395 alignmentModel->setCompletion(100);
Chris@767 396
Chris@771 397 SVCERR << "TransformDTWAligner[" << this
Chris@771 398 << "]: performAlignmentMagnitude: Done" << endl;
Chris@767 399
Chris@767 400 m_incomplete = false;
Chris@767 401 return true;
Chris@767 402 }
Chris@767 403
Chris@767 404 bool
Chris@767 405 TransformDTWAligner::performAlignmentRiseFall()
Chris@767 406 {
Chris@771 407 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
Chris@767 408 if (!alignmentModel) {
Chris@767 409 return false;
Chris@767 410 }
Chris@768 411
Chris@771 412 vector<sv_frame_t> refFrames, otherFrames;
Chris@771 413 vector<double> refValues, otherValues;
Chris@771 414 sv_frame_t resolution = 0;
Chris@771 415
Chris@771 416 if (!getValuesFrom(m_referenceOutputModel,
Chris@771 417 refFrames, refValues, resolution)) {
Chris@771 418 return false;
Chris@771 419 }
Chris@771 420
Chris@771 421 if (!getValuesFrom(m_toAlignOutputModel,
Chris@771 422 otherFrames, otherValues, resolution)) {
Chris@771 423 return false;
Chris@771 424 }
Chris@771 425
Chris@771 426 auto preprocess =
Chris@771 427 [this](const std::vector<double> &vv) {
Chris@768 428 vector<RiseFallDTW::Value> s;
Chris@768 429 double prev = 0.0;
Chris@771 430 for (auto curr: vv) {
Chris@771 431 s.push_back(m_riseFallPreprocessor(prev, curr));
Chris@771 432 prev = curr;
Chris@768 433 }
Chris@768 434 return s;
Chris@771 435 };
Chris@767 436
Chris@771 437 vector<RiseFallDTW::Value> s1 = preprocess(refValues);
Chris@771 438 vector<RiseFallDTW::Value> s2 = preprocess(otherValues);
Chris@767 439
Chris@773 440 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 441 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
Chris@767 442 << "Have " << s1.size() << " events from reference, "
Chris@767 443 << s2.size() << " from toAlign" << endl;
Chris@767 444
Chris@771 445 SVCERR << "Reference:" << endl;
Chris@771 446 for (int i = 0; in_range_for(s1, i) && i < 100; ++i) {
Chris@771 447 SVCERR << s1[i] << " ";
Chris@771 448 }
Chris@771 449 SVCERR << endl;
Chris@771 450
Chris@771 451 SVCERR << "toAlign:" << endl;
Chris@771 452 for (int i = 0; in_range_for(s2, i) && i < 100; ++i) {
Chris@771 453 SVCERR << s2[i] << " ";
Chris@771 454 }
Chris@771 455 SVCERR << endl;
Chris@773 456 #endif
Chris@773 457
Chris@767 458 RiseFallDTW dtw;
Chris@767 459 vector<size_t> alignment;
Chris@767 460
Chris@767 461 {
Chris@773 462 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 463 SVCERR << "TransformDTWAligner[" << this
Chris@767 464 << "]: serialising DTW to avoid over-allocation" << endl;
Chris@773 465 #endif
Chris@771 466 QMutexLocker locker(&m_dtwMutex);
Chris@781 467 if (m_subsequence) {
Chris@781 468 alignment = dtw.alignSubsequence(s1, s2);
Chris@781 469 } else {
Chris@781 470 alignment = dtw.alignSequences(s1, s2);
Chris@781 471 }
Chris@767 472 }
Chris@767 473
Chris@773 474 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 475 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
Chris@767 476 << "DTW produced " << alignment.size() << " points:" << endl;
Chris@767 477 for (int i = 0; i < alignment.size() && i < 100; ++i) {
Chris@767 478 SVCERR << alignment[i] << " ";
Chris@767 479 }
Chris@767 480 SVCERR << endl;
Chris@773 481 #endif
Chris@767 482
Chris@771 483 alignmentModel->setPath(makePath(alignment,
Chris@771 484 refFrames,
Chris@771 485 otherFrames,
Chris@771 486 alignmentModel->getSampleRate(),
Chris@771 487 resolution));
Chris@771 488
Chris@768 489 alignmentModel->setCompletion(100);
Chris@767 490
Chris@773 491 SVCERR << "TransformDTWAligner[" << this
Chris@773 492 << "]: performAlignmentRiseFall: Done" << endl;
Chris@767 493
Chris@767 494 m_incomplete = false;
Chris@767 495 return true;
Chris@767 496 }