annotate align/TransformDTWAligner.cpp @ 781:b651dc5ff555

Add subsequence option all over the place
author Chris Cannam
date Thu, 16 Jul 2020 18:01:50 +0100
parents 8fa98f89eda8
children 700fc9e4852d
rev   line source
Chris@767 1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
Chris@767 2
Chris@767 3 /*
Chris@767 4 Sonic Visualiser
Chris@767 5 An audio file viewer and annotation editor.
Chris@767 6 Centre for Digital Music, Queen Mary, University of London.
Chris@767 7
Chris@767 8 This program is free software; you can redistribute it and/or
Chris@767 9 modify it under the terms of the GNU General Public License as
Chris@767 10 published by the Free Software Foundation; either version 2 of the
Chris@767 11 License, or (at your option) any later version. See the file
Chris@767 12 COPYING included with this distribution for more information.
Chris@767 13 */
Chris@767 14
Chris@767 15 #include "TransformDTWAligner.h"
Chris@767 16 #include "DTW.h"
Chris@767 17
Chris@767 18 #include "data/model/SparseTimeValueModel.h"
Chris@771 19 #include "data/model/NoteModel.h"
Chris@767 20 #include "data/model/RangeSummarisableTimeValueModel.h"
Chris@767 21 #include "data/model/AlignmentModel.h"
Chris@767 22 #include "data/model/AggregateWaveModel.h"
Chris@767 23
Chris@767 24 #include "framework/Document.h"
Chris@767 25
Chris@767 26 #include "transform/ModelTransformerFactory.h"
Chris@767 27 #include "transform/FeatureExtractionModelTransformer.h"
Chris@767 28
Chris@767 29 #include <QSettings>
Chris@767 30 #include <QMutex>
Chris@767 31 #include <QMutexLocker>
Chris@767 32
Chris@767 33 using std::vector;
Chris@767 34
Chris@771 35 static
Chris@771 36 TransformDTWAligner::MagnitudePreprocessor identityMagnitudePreprocessor =
Chris@771 37 [](double x) {
Chris@771 38 return x;
Chris@771 39 };
Chris@771 40
Chris@771 41 static
Chris@771 42 TransformDTWAligner::RiseFallPreprocessor identityRiseFallPreprocessor =
Chris@771 43 [](double prev, double curr) {
Chris@771 44 if (prev == curr) {
Chris@771 45 return RiseFallDTW::Value({ RiseFallDTW::Direction::None, 0.0 });
Chris@771 46 } else if (curr > prev) {
Chris@771 47 return RiseFallDTW::Value({ RiseFallDTW::Direction::Up, curr - prev });
Chris@771 48 } else {
Chris@771 49 return RiseFallDTW::Value({ RiseFallDTW::Direction::Down, prev - curr });
Chris@771 50 }
Chris@771 51 };
Chris@771 52
Chris@771 53 QMutex
Chris@771 54 TransformDTWAligner::m_dtwMutex;
Chris@771 55
Chris@767 56 TransformDTWAligner::TransformDTWAligner(Document *doc,
Chris@767 57 ModelId reference,
Chris@767 58 ModelId toAlign,
Chris@781 59 bool subsequence,
Chris@767 60 Transform transform,
Chris@767 61 DTWType dtwType) :
Chris@767 62 m_document(doc),
Chris@767 63 m_reference(reference),
Chris@767 64 m_toAlign(toAlign),
Chris@767 65 m_transform(transform),
Chris@767 66 m_dtwType(dtwType),
Chris@781 67 m_subsequence(subsequence),
Chris@768 68 m_incomplete(true),
Chris@771 69 m_magnitudePreprocessor(identityMagnitudePreprocessor),
Chris@771 70 m_riseFallPreprocessor(identityRiseFallPreprocessor)
Chris@768 71 {
Chris@768 72 }
Chris@768 73
Chris@768 74 TransformDTWAligner::TransformDTWAligner(Document *doc,
Chris@768 75 ModelId reference,
Chris@768 76 ModelId toAlign,
Chris@781 77 bool subsequence,
Chris@768 78 Transform transform,
Chris@771 79 MagnitudePreprocessor outputPreprocessor) :
Chris@768 80 m_document(doc),
Chris@768 81 m_reference(reference),
Chris@768 82 m_toAlign(toAlign),
Chris@768 83 m_transform(transform),
Chris@771 84 m_dtwType(Magnitude),
Chris@781 85 m_subsequence(subsequence),
Chris@768 86 m_incomplete(true),
Chris@771 87 m_magnitudePreprocessor(outputPreprocessor),
Chris@771 88 m_riseFallPreprocessor(identityRiseFallPreprocessor)
Chris@771 89 {
Chris@771 90 }
Chris@771 91
Chris@771 92 TransformDTWAligner::TransformDTWAligner(Document *doc,
Chris@771 93 ModelId reference,
Chris@771 94 ModelId toAlign,
Chris@781 95 bool subsequence,
Chris@771 96 Transform transform,
Chris@771 97 RiseFallPreprocessor outputPreprocessor) :
Chris@771 98 m_document(doc),
Chris@771 99 m_reference(reference),
Chris@771 100 m_toAlign(toAlign),
Chris@771 101 m_transform(transform),
Chris@771 102 m_dtwType(RiseFall),
Chris@781 103 m_subsequence(subsequence),
Chris@771 104 m_incomplete(true),
Chris@771 105 m_magnitudePreprocessor(identityMagnitudePreprocessor),
Chris@771 106 m_riseFallPreprocessor(outputPreprocessor)
Chris@767 107 {
Chris@767 108 }
Chris@767 109
Chris@767 110 TransformDTWAligner::~TransformDTWAligner()
Chris@767 111 {
Chris@767 112 if (m_incomplete) {
Chris@767 113 if (auto toAlign = ModelById::get(m_toAlign)) {
Chris@767 114 toAlign->setAlignment({});
Chris@767 115 }
Chris@767 116 }
Chris@767 117
Chris@767 118 ModelById::release(m_referenceOutputModel);
Chris@767 119 ModelById::release(m_toAlignOutputModel);
Chris@767 120 }
Chris@767 121
Chris@767 122 bool
Chris@767 123 TransformDTWAligner::isAvailable()
Chris@767 124 {
Chris@767 125 //!!! needs to be isAvailable(QString transformId)?
Chris@767 126 return true;
Chris@767 127 }
Chris@767 128
Chris@767 129 void
Chris@767 130 TransformDTWAligner::begin()
Chris@767 131 {
Chris@767 132 auto reference =
Chris@767 133 ModelById::getAs<RangeSummarisableTimeValueModel>(m_reference);
Chris@767 134 auto toAlign =
Chris@767 135 ModelById::getAs<RangeSummarisableTimeValueModel>(m_toAlign);
Chris@767 136
Chris@767 137 if (!reference || !toAlign) return;
Chris@767 138
Chris@767 139 SVCERR << "TransformDTWAligner[" << this << "]: begin(): aligning "
Chris@767 140 << m_toAlign << " against reference " << m_reference << endl;
Chris@767 141
Chris@767 142 ModelTransformerFactory *mtf = ModelTransformerFactory::getInstance();
Chris@767 143
Chris@767 144 QString message;
Chris@767 145
Chris@767 146 m_referenceOutputModel = mtf->transform(m_transform, m_reference, message);
Chris@767 147 auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
Chris@767 148 if (!referenceOutputModel) {
Chris@767 149 SVCERR << "Align::alignModel: ERROR: Failed to create reference output model (no plugin?)" << endl;
Chris@767 150 emit failed(m_toAlign, message);
Chris@767 151 return;
Chris@767 152 }
Chris@767 153
Chris@773 154 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 155 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
Chris@767 156 << m_transform.getIdentifier()
Chris@767 157 << " is running on reference model" << endl;
Chris@773 158 #endif
Chris@767 159
Chris@767 160 message = "";
Chris@767 161
Chris@767 162 m_toAlignOutputModel = mtf->transform(m_transform, m_toAlign, message);
Chris@767 163 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
Chris@767 164 if (!toAlignOutputModel) {
Chris@767 165 SVCERR << "Align::alignModel: ERROR: Failed to create toAlign output model (no plugin?)" << endl;
Chris@767 166 emit failed(m_toAlign, message);
Chris@767 167 return;
Chris@767 168 }
Chris@767 169
Chris@773 170 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 171 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
Chris@767 172 << m_transform.getIdentifier()
Chris@767 173 << " is running on toAlign model" << endl;
Chris@773 174 #endif
Chris@767 175
Chris@767 176 connect(referenceOutputModel.get(), SIGNAL(completionChanged(ModelId)),
Chris@767 177 this, SLOT(completionChanged(ModelId)));
Chris@767 178 connect(toAlignOutputModel.get(), SIGNAL(completionChanged(ModelId)),
Chris@767 179 this, SLOT(completionChanged(ModelId)));
Chris@767 180
Chris@767 181 auto alignmentModel = std::make_shared<AlignmentModel>
Chris@768 182 (m_reference, m_toAlign, ModelId());
Chris@767 183 m_alignmentModel = ModelById::add(alignmentModel);
Chris@767 184
Chris@767 185 toAlign->setAlignment(m_alignmentModel);
Chris@767 186 m_document->addNonDerivedModel(m_alignmentModel);
Chris@767 187
Chris@767 188 // we wouldn't normally expect these to be true here, but...
Chris@767 189 int completion = 0;
Chris@767 190 if (referenceOutputModel->isReady(&completion) &&
Chris@767 191 toAlignOutputModel->isReady(&completion)) {
Chris@767 192 SVCERR << "TransformDTWAligner[" << this << "]: begin(): output models "
Chris@767 193 << "are ready already! calling performAlignment" << endl;
Chris@767 194 if (performAlignment()) {
Chris@767 195 emit complete(m_alignmentModel);
Chris@767 196 } else {
Chris@767 197 emit failed(m_toAlign, tr("Failed to calculate alignment using DTW"));
Chris@767 198 }
Chris@767 199 }
Chris@767 200 }
Chris@767 201
Chris@767 202 void
Chris@767 203 TransformDTWAligner::completionChanged(ModelId id)
Chris@767 204 {
Chris@767 205 if (!m_incomplete) {
Chris@767 206 return;
Chris@767 207 }
Chris@773 208 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 209 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
Chris@767 210 << "model " << id << endl;
Chris@773 211 #endif
Chris@773 212
Chris@767 213 auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
Chris@767 214 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
Chris@768 215 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
Chris@767 216
Chris@768 217 if (!referenceOutputModel || !toAlignOutputModel || !alignmentModel) {
Chris@767 218 return;
Chris@767 219 }
Chris@767 220
Chris@767 221 int referenceCompletion = 0, toAlignCompletion = 0;
Chris@767 222 bool referenceReady = referenceOutputModel->isReady(&referenceCompletion);
Chris@767 223 bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion);
Chris@767 224
Chris@767 225 if (referenceReady && toAlignReady) {
Chris@767 226
Chris@767 227 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
Chris@771 228 << "both models ready, calling performAlignment" << endl;
Chris@767 229
Chris@768 230 alignmentModel->setCompletion(95);
Chris@767 231
Chris@767 232 if (performAlignment()) {
Chris@767 233 emit complete(m_alignmentModel);
Chris@767 234 } else {
Chris@767 235 emit failed(m_toAlign, tr("Alignment of transform outputs failed"));
Chris@767 236 }
Chris@767 237
Chris@767 238 } else {
Chris@773 239 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 240 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
Chris@767 241 << "not ready yet: reference completion " << referenceCompletion
Chris@767 242 << ", toAlign completion " << toAlignCompletion << endl;
Chris@773 243 #endif
Chris@773 244
Chris@768 245 int completion = std::min(referenceCompletion,
Chris@768 246 toAlignCompletion);
Chris@768 247 completion = (completion * 94) / 100;
Chris@768 248 alignmentModel->setCompletion(completion);
Chris@767 249 }
Chris@767 250 }
Chris@767 251
Chris@767 252 bool
Chris@767 253 TransformDTWAligner::performAlignment()
Chris@767 254 {
Chris@767 255 if (m_dtwType == Magnitude) {
Chris@767 256 return performAlignmentMagnitude();
Chris@767 257 } else {
Chris@767 258 return performAlignmentRiseFall();
Chris@767 259 }
Chris@767 260 }
Chris@767 261
Chris@767 262 bool
Chris@771 263 TransformDTWAligner::getValuesFrom(ModelId modelId,
Chris@771 264 vector<sv_frame_t> &frames,
Chris@771 265 vector<double> &values,
Chris@771 266 sv_frame_t &resolution)
Chris@767 267 {
Chris@771 268 EventVector events;
Chris@767 269
Chris@771 270 if (auto model = ModelById::getAs<SparseTimeValueModel>(modelId)) {
Chris@771 271 resolution = model->getResolution();
Chris@771 272 events = model->getAllEvents();
Chris@771 273 } else if (auto model = ModelById::getAs<NoteModel>(modelId)) {
Chris@771 274 resolution = model->getResolution();
Chris@771 275 events = model->getAllEvents();
Chris@771 276 } else {
Chris@771 277 SVCERR << "TransformDTWAligner::getValuesFrom: Type of model "
Chris@771 278 << modelId << " is not supported" << endl;
Chris@767 279 return false;
Chris@767 280 }
Chris@767 281
Chris@771 282 frames.clear();
Chris@771 283 values.clear();
Chris@771 284
Chris@771 285 for (auto e: events) {
Chris@771 286 frames.push_back(e.getFrame());
Chris@771 287 values.push_back(e.getValue());
Chris@771 288 }
Chris@771 289
Chris@771 290 return true;
Chris@771 291 }
Chris@771 292
Chris@771 293 Path
Chris@771 294 TransformDTWAligner::makePath(const vector<size_t> &alignment,
Chris@771 295 const vector<sv_frame_t> &refFrames,
Chris@771 296 const vector<sv_frame_t> &otherFrames,
Chris@771 297 sv_samplerate_t sampleRate,
Chris@771 298 sv_frame_t resolution)
Chris@771 299 {
Chris@771 300 Path path(sampleRate, resolution);
Chris@771 301
Chris@773 302 path.add(PathPoint(0, 0));
Chris@773 303
Chris@771 304 for (int i = 0; in_range_for(alignment, i); ++i) {
Chris@771 305
Chris@780 306 // DTW returns "the index into s1 for each element in s2"
Chris@780 307 sv_frame_t alignedFrame = otherFrames[i];
Chris@780 308
Chris@780 309 if (!in_range_for(refFrames, alignment[i])) {
Chris@771 310 SVCERR << "TransformDTWAligner::makePath: Internal error: "
Chris@780 311 << "DTW maps index " << i << " in other frame vector "
Chris@780 312 << "(size " << otherFrames.size() << ") onto index "
Chris@780 313 << alignment[i] << " in ref frame vector "
Chris@780 314 << "(only size " << refFrames.size() << ")" << endl;
Chris@771 315 continue;
Chris@771 316 }
Chris@771 317
Chris@780 318 sv_frame_t refFrame = refFrames[alignment[i]];
Chris@771 319 path.add(PathPoint(alignedFrame, refFrame));
Chris@771 320 }
Chris@771 321
Chris@771 322 return path;
Chris@771 323 }
Chris@771 324
Chris@771 325 bool
Chris@771 326 TransformDTWAligner::performAlignmentMagnitude()
Chris@771 327 {
Chris@771 328 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
Chris@767 329 if (!alignmentModel) {
Chris@767 330 return false;
Chris@767 331 }
Chris@771 332
Chris@771 333 vector<sv_frame_t> refFrames, otherFrames;
Chris@771 334 vector<double> refValues, otherValues;
Chris@771 335 sv_frame_t resolution = 0;
Chris@771 336
Chris@771 337 if (!getValuesFrom(m_referenceOutputModel,
Chris@771 338 refFrames, refValues, resolution)) {
Chris@771 339 return false;
Chris@771 340 }
Chris@771 341
Chris@771 342 if (!getValuesFrom(m_toAlignOutputModel,
Chris@771 343 otherFrames, otherValues, resolution)) {
Chris@771 344 return false;
Chris@771 345 }
Chris@767 346
Chris@767 347 vector<double> s1, s2;
Chris@771 348 for (double v: refValues) {
Chris@771 349 s1.push_back(m_magnitudePreprocessor(v));
Chris@771 350 }
Chris@771 351 for (double v: otherValues) {
Chris@771 352 s2.push_back(m_magnitudePreprocessor(v));
Chris@767 353 }
Chris@767 354
Chris@773 355 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 356 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
Chris@767 357 << "Have " << s1.size() << " events from reference, "
Chris@767 358 << s2.size() << " from toAlign" << endl;
Chris@773 359 #endif
Chris@771 360
Chris@767 361 MagnitudeDTW dtw;
Chris@767 362 vector<size_t> alignment;
Chris@767 363
Chris@767 364 {
Chris@773 365 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 366 SVCERR << "TransformDTWAligner[" << this
Chris@767 367 << "]: serialising DTW to avoid over-allocation" << endl;
Chris@773 368 #endif
Chris@771 369 QMutexLocker locker(&m_dtwMutex);
Chris@781 370 if (m_subsequence) {
Chris@781 371 alignment = dtw.alignSubsequence(s1, s2);
Chris@781 372 } else {
Chris@781 373 alignment = dtw.alignSequences(s1, s2);
Chris@781 374 }
Chris@767 375 }
Chris@767 376
Chris@773 377 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 378 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
Chris@767 379 << "DTW produced " << alignment.size() << " points:" << endl;
Chris@771 380 for (int i = 0; in_range_for(alignment, i) && i < 100; ++i) {
Chris@767 381 SVCERR << alignment[i] << " ";
Chris@767 382 }
Chris@767 383 SVCERR << endl;
Chris@773 384 #endif
Chris@767 385
Chris@771 386 alignmentModel->setPath(makePath(alignment,
Chris@771 387 refFrames,
Chris@771 388 otherFrames,
Chris@771 389 alignmentModel->getSampleRate(),
Chris@771 390 resolution));
Chris@768 391 alignmentModel->setCompletion(100);
Chris@767 392
Chris@771 393 SVCERR << "TransformDTWAligner[" << this
Chris@771 394 << "]: performAlignmentMagnitude: Done" << endl;
Chris@767 395
Chris@767 396 m_incomplete = false;
Chris@767 397 return true;
Chris@767 398 }
Chris@767 399
Chris@767 400 bool
Chris@767 401 TransformDTWAligner::performAlignmentRiseFall()
Chris@767 402 {
Chris@771 403 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
Chris@767 404 if (!alignmentModel) {
Chris@767 405 return false;
Chris@767 406 }
Chris@768 407
Chris@771 408 vector<sv_frame_t> refFrames, otherFrames;
Chris@771 409 vector<double> refValues, otherValues;
Chris@771 410 sv_frame_t resolution = 0;
Chris@771 411
Chris@771 412 if (!getValuesFrom(m_referenceOutputModel,
Chris@771 413 refFrames, refValues, resolution)) {
Chris@771 414 return false;
Chris@771 415 }
Chris@771 416
Chris@771 417 if (!getValuesFrom(m_toAlignOutputModel,
Chris@771 418 otherFrames, otherValues, resolution)) {
Chris@771 419 return false;
Chris@771 420 }
Chris@771 421
Chris@771 422 auto preprocess =
Chris@771 423 [this](const std::vector<double> &vv) {
Chris@768 424 vector<RiseFallDTW::Value> s;
Chris@768 425 double prev = 0.0;
Chris@771 426 for (auto curr: vv) {
Chris@771 427 s.push_back(m_riseFallPreprocessor(prev, curr));
Chris@771 428 prev = curr;
Chris@768 429 }
Chris@768 430 return s;
Chris@771 431 };
Chris@767 432
Chris@771 433 vector<RiseFallDTW::Value> s1 = preprocess(refValues);
Chris@771 434 vector<RiseFallDTW::Value> s2 = preprocess(otherValues);
Chris@767 435
Chris@773 436 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 437 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
Chris@767 438 << "Have " << s1.size() << " events from reference, "
Chris@767 439 << s2.size() << " from toAlign" << endl;
Chris@767 440
Chris@771 441 SVCERR << "Reference:" << endl;
Chris@771 442 for (int i = 0; in_range_for(s1, i) && i < 100; ++i) {
Chris@771 443 SVCERR << s1[i] << " ";
Chris@771 444 }
Chris@771 445 SVCERR << endl;
Chris@771 446
Chris@771 447 SVCERR << "toAlign:" << endl;
Chris@771 448 for (int i = 0; in_range_for(s2, i) && i < 100; ++i) {
Chris@771 449 SVCERR << s2[i] << " ";
Chris@771 450 }
Chris@771 451 SVCERR << endl;
Chris@773 452 #endif
Chris@773 453
Chris@767 454 RiseFallDTW dtw;
Chris@767 455 vector<size_t> alignment;
Chris@767 456
Chris@767 457 {
Chris@773 458 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 459 SVCERR << "TransformDTWAligner[" << this
Chris@767 460 << "]: serialising DTW to avoid over-allocation" << endl;
Chris@773 461 #endif
Chris@771 462 QMutexLocker locker(&m_dtwMutex);
Chris@781 463 if (m_subsequence) {
Chris@781 464 alignment = dtw.alignSubsequence(s1, s2);
Chris@781 465 } else {
Chris@781 466 alignment = dtw.alignSequences(s1, s2);
Chris@781 467 }
Chris@767 468 }
Chris@767 469
Chris@773 470 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 471 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
Chris@767 472 << "DTW produced " << alignment.size() << " points:" << endl;
Chris@767 473 for (int i = 0; i < alignment.size() && i < 100; ++i) {
Chris@767 474 SVCERR << alignment[i] << " ";
Chris@767 475 }
Chris@767 476 SVCERR << endl;
Chris@773 477 #endif
Chris@767 478
Chris@771 479 alignmentModel->setPath(makePath(alignment,
Chris@771 480 refFrames,
Chris@771 481 otherFrames,
Chris@771 482 alignmentModel->getSampleRate(),
Chris@771 483 resolution));
Chris@771 484
Chris@768 485 alignmentModel->setCompletion(100);
Chris@767 486
Chris@773 487 SVCERR << "TransformDTWAligner[" << this
Chris@773 488 << "]: performAlignmentRiseFall: Done" << endl;
Chris@767 489
Chris@767 490 m_incomplete = false;
Chris@767 491 return true;
Chris@767 492 }