annotate align/TransformDTWAligner.cpp @ 778:83a7b10b7415

Merge from branch pitch-align
author Chris Cannam
date Fri, 26 Jun 2020 13:48:52 +0100
parents 699b5b130ea2
children 8fa98f89eda8
rev   line source
Chris@767 1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
Chris@767 2
Chris@767 3 /*
Chris@767 4 Sonic Visualiser
Chris@767 5 An audio file viewer and annotation editor.
Chris@767 6 Centre for Digital Music, Queen Mary, University of London.
Chris@767 7
Chris@767 8 This program is free software; you can redistribute it and/or
Chris@767 9 modify it under the terms of the GNU General Public License as
Chris@767 10 published by the Free Software Foundation; either version 2 of the
Chris@767 11 License, or (at your option) any later version. See the file
Chris@767 12 COPYING included with this distribution for more information.
Chris@767 13 */
Chris@767 14
Chris@767 15 #include "TransformDTWAligner.h"
Chris@767 16 #include "DTW.h"
Chris@767 17
Chris@767 18 #include "data/model/SparseTimeValueModel.h"
Chris@771 19 #include "data/model/NoteModel.h"
Chris@767 20 #include "data/model/RangeSummarisableTimeValueModel.h"
Chris@767 21 #include "data/model/AlignmentModel.h"
Chris@767 22 #include "data/model/AggregateWaveModel.h"
Chris@767 23
Chris@767 24 #include "framework/Document.h"
Chris@767 25
Chris@767 26 #include "transform/ModelTransformerFactory.h"
Chris@767 27 #include "transform/FeatureExtractionModelTransformer.h"
Chris@767 28
Chris@767 29 #include <QSettings>
Chris@767 30 #include <QMutex>
Chris@767 31 #include <QMutexLocker>
Chris@767 32
Chris@767 33 using std::vector;
Chris@767 34
Chris@771 35 static
Chris@771 36 TransformDTWAligner::MagnitudePreprocessor identityMagnitudePreprocessor =
Chris@771 37 [](double x) {
Chris@771 38 return x;
Chris@771 39 };
Chris@771 40
Chris@771 41 static
Chris@771 42 TransformDTWAligner::RiseFallPreprocessor identityRiseFallPreprocessor =
Chris@771 43 [](double prev, double curr) {
Chris@771 44 if (prev == curr) {
Chris@771 45 return RiseFallDTW::Value({ RiseFallDTW::Direction::None, 0.0 });
Chris@771 46 } else if (curr > prev) {
Chris@771 47 return RiseFallDTW::Value({ RiseFallDTW::Direction::Up, curr - prev });
Chris@771 48 } else {
Chris@771 49 return RiseFallDTW::Value({ RiseFallDTW::Direction::Down, prev - curr });
Chris@771 50 }
Chris@771 51 };
Chris@771 52
Chris@771 53 QMutex
Chris@771 54 TransformDTWAligner::m_dtwMutex;
Chris@771 55
Chris@767 56 TransformDTWAligner::TransformDTWAligner(Document *doc,
Chris@767 57 ModelId reference,
Chris@767 58 ModelId toAlign,
Chris@767 59 Transform transform,
Chris@767 60 DTWType dtwType) :
Chris@767 61 m_document(doc),
Chris@767 62 m_reference(reference),
Chris@767 63 m_toAlign(toAlign),
Chris@767 64 m_transform(transform),
Chris@767 65 m_dtwType(dtwType),
Chris@768 66 m_incomplete(true),
Chris@771 67 m_magnitudePreprocessor(identityMagnitudePreprocessor),
Chris@771 68 m_riseFallPreprocessor(identityRiseFallPreprocessor)
Chris@768 69 {
Chris@768 70 }
Chris@768 71
Chris@768 72 TransformDTWAligner::TransformDTWAligner(Document *doc,
Chris@768 73 ModelId reference,
Chris@768 74 ModelId toAlign,
Chris@768 75 Transform transform,
Chris@771 76 MagnitudePreprocessor outputPreprocessor) :
Chris@768 77 m_document(doc),
Chris@768 78 m_reference(reference),
Chris@768 79 m_toAlign(toAlign),
Chris@768 80 m_transform(transform),
Chris@771 81 m_dtwType(Magnitude),
Chris@768 82 m_incomplete(true),
Chris@771 83 m_magnitudePreprocessor(outputPreprocessor),
Chris@771 84 m_riseFallPreprocessor(identityRiseFallPreprocessor)
Chris@771 85 {
Chris@771 86 }
Chris@771 87
Chris@771 88 TransformDTWAligner::TransformDTWAligner(Document *doc,
Chris@771 89 ModelId reference,
Chris@771 90 ModelId toAlign,
Chris@771 91 Transform transform,
Chris@771 92 RiseFallPreprocessor outputPreprocessor) :
Chris@771 93 m_document(doc),
Chris@771 94 m_reference(reference),
Chris@771 95 m_toAlign(toAlign),
Chris@771 96 m_transform(transform),
Chris@771 97 m_dtwType(RiseFall),
Chris@771 98 m_incomplete(true),
Chris@771 99 m_magnitudePreprocessor(identityMagnitudePreprocessor),
Chris@771 100 m_riseFallPreprocessor(outputPreprocessor)
Chris@767 101 {
Chris@767 102 }
Chris@767 103
Chris@767 104 TransformDTWAligner::~TransformDTWAligner()
Chris@767 105 {
Chris@767 106 if (m_incomplete) {
Chris@767 107 if (auto toAlign = ModelById::get(m_toAlign)) {
Chris@767 108 toAlign->setAlignment({});
Chris@767 109 }
Chris@767 110 }
Chris@767 111
Chris@767 112 ModelById::release(m_referenceOutputModel);
Chris@767 113 ModelById::release(m_toAlignOutputModel);
Chris@767 114 }
Chris@767 115
Chris@767 116 bool
Chris@767 117 TransformDTWAligner::isAvailable()
Chris@767 118 {
Chris@767 119 //!!! needs to be isAvailable(QString transformId)?
Chris@767 120 return true;
Chris@767 121 }
Chris@767 122
Chris@767 123 void
Chris@767 124 TransformDTWAligner::begin()
Chris@767 125 {
Chris@767 126 auto reference =
Chris@767 127 ModelById::getAs<RangeSummarisableTimeValueModel>(m_reference);
Chris@767 128 auto toAlign =
Chris@767 129 ModelById::getAs<RangeSummarisableTimeValueModel>(m_toAlign);
Chris@767 130
Chris@767 131 if (!reference || !toAlign) return;
Chris@767 132
Chris@767 133 SVCERR << "TransformDTWAligner[" << this << "]: begin(): aligning "
Chris@767 134 << m_toAlign << " against reference " << m_reference << endl;
Chris@767 135
Chris@767 136 ModelTransformerFactory *mtf = ModelTransformerFactory::getInstance();
Chris@767 137
Chris@767 138 QString message;
Chris@767 139
Chris@767 140 m_referenceOutputModel = mtf->transform(m_transform, m_reference, message);
Chris@767 141 auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
Chris@767 142 if (!referenceOutputModel) {
Chris@767 143 SVCERR << "Align::alignModel: ERROR: Failed to create reference output model (no plugin?)" << endl;
Chris@767 144 emit failed(m_toAlign, message);
Chris@767 145 return;
Chris@767 146 }
Chris@767 147
Chris@773 148 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 149 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
Chris@767 150 << m_transform.getIdentifier()
Chris@767 151 << " is running on reference model" << endl;
Chris@773 152 #endif
Chris@767 153
Chris@767 154 message = "";
Chris@767 155
Chris@767 156 m_toAlignOutputModel = mtf->transform(m_transform, m_toAlign, message);
Chris@767 157 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
Chris@767 158 if (!toAlignOutputModel) {
Chris@767 159 SVCERR << "Align::alignModel: ERROR: Failed to create toAlign output model (no plugin?)" << endl;
Chris@767 160 emit failed(m_toAlign, message);
Chris@767 161 return;
Chris@767 162 }
Chris@767 163
Chris@773 164 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 165 SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
Chris@767 166 << m_transform.getIdentifier()
Chris@767 167 << " is running on toAlign model" << endl;
Chris@773 168 #endif
Chris@767 169
Chris@767 170 connect(referenceOutputModel.get(), SIGNAL(completionChanged(ModelId)),
Chris@767 171 this, SLOT(completionChanged(ModelId)));
Chris@767 172 connect(toAlignOutputModel.get(), SIGNAL(completionChanged(ModelId)),
Chris@767 173 this, SLOT(completionChanged(ModelId)));
Chris@767 174
Chris@767 175 auto alignmentModel = std::make_shared<AlignmentModel>
Chris@768 176 (m_reference, m_toAlign, ModelId());
Chris@767 177 m_alignmentModel = ModelById::add(alignmentModel);
Chris@767 178
Chris@767 179 toAlign->setAlignment(m_alignmentModel);
Chris@767 180 m_document->addNonDerivedModel(m_alignmentModel);
Chris@767 181
Chris@767 182 // we wouldn't normally expect these to be true here, but...
Chris@767 183 int completion = 0;
Chris@767 184 if (referenceOutputModel->isReady(&completion) &&
Chris@767 185 toAlignOutputModel->isReady(&completion)) {
Chris@767 186 SVCERR << "TransformDTWAligner[" << this << "]: begin(): output models "
Chris@767 187 << "are ready already! calling performAlignment" << endl;
Chris@767 188 if (performAlignment()) {
Chris@767 189 emit complete(m_alignmentModel);
Chris@767 190 } else {
Chris@767 191 emit failed(m_toAlign, tr("Failed to calculate alignment using DTW"));
Chris@767 192 }
Chris@767 193 }
Chris@767 194 }
Chris@767 195
Chris@767 196 void
Chris@767 197 TransformDTWAligner::completionChanged(ModelId id)
Chris@767 198 {
Chris@767 199 if (!m_incomplete) {
Chris@767 200 return;
Chris@767 201 }
Chris@773 202 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 203 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
Chris@767 204 << "model " << id << endl;
Chris@773 205 #endif
Chris@773 206
Chris@767 207 auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
Chris@767 208 auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
Chris@768 209 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
Chris@767 210
Chris@768 211 if (!referenceOutputModel || !toAlignOutputModel || !alignmentModel) {
Chris@767 212 return;
Chris@767 213 }
Chris@767 214
Chris@767 215 int referenceCompletion = 0, toAlignCompletion = 0;
Chris@767 216 bool referenceReady = referenceOutputModel->isReady(&referenceCompletion);
Chris@767 217 bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion);
Chris@767 218
Chris@767 219 if (referenceReady && toAlignReady) {
Chris@767 220
Chris@767 221 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
Chris@771 222 << "both models ready, calling performAlignment" << endl;
Chris@767 223
Chris@768 224 alignmentModel->setCompletion(95);
Chris@767 225
Chris@767 226 if (performAlignment()) {
Chris@767 227 emit complete(m_alignmentModel);
Chris@767 228 } else {
Chris@767 229 emit failed(m_toAlign, tr("Alignment of transform outputs failed"));
Chris@767 230 }
Chris@767 231
Chris@767 232 } else {
Chris@773 233 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 234 SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
Chris@767 235 << "not ready yet: reference completion " << referenceCompletion
Chris@767 236 << ", toAlign completion " << toAlignCompletion << endl;
Chris@773 237 #endif
Chris@773 238
Chris@768 239 int completion = std::min(referenceCompletion,
Chris@768 240 toAlignCompletion);
Chris@768 241 completion = (completion * 94) / 100;
Chris@768 242 alignmentModel->setCompletion(completion);
Chris@767 243 }
Chris@767 244 }
Chris@767 245
Chris@767 246 bool
Chris@767 247 TransformDTWAligner::performAlignment()
Chris@767 248 {
Chris@767 249 if (m_dtwType == Magnitude) {
Chris@767 250 return performAlignmentMagnitude();
Chris@767 251 } else {
Chris@767 252 return performAlignmentRiseFall();
Chris@767 253 }
Chris@767 254 }
Chris@767 255
Chris@767 256 bool
Chris@771 257 TransformDTWAligner::getValuesFrom(ModelId modelId,
Chris@771 258 vector<sv_frame_t> &frames,
Chris@771 259 vector<double> &values,
Chris@771 260 sv_frame_t &resolution)
Chris@767 261 {
Chris@771 262 EventVector events;
Chris@767 263
Chris@771 264 if (auto model = ModelById::getAs<SparseTimeValueModel>(modelId)) {
Chris@771 265 resolution = model->getResolution();
Chris@771 266 events = model->getAllEvents();
Chris@771 267 } else if (auto model = ModelById::getAs<NoteModel>(modelId)) {
Chris@771 268 resolution = model->getResolution();
Chris@771 269 events = model->getAllEvents();
Chris@771 270 } else {
Chris@771 271 SVCERR << "TransformDTWAligner::getValuesFrom: Type of model "
Chris@771 272 << modelId << " is not supported" << endl;
Chris@767 273 return false;
Chris@767 274 }
Chris@767 275
Chris@771 276 frames.clear();
Chris@771 277 values.clear();
Chris@771 278
Chris@771 279 for (auto e: events) {
Chris@771 280 frames.push_back(e.getFrame());
Chris@771 281 values.push_back(e.getValue());
Chris@771 282 }
Chris@771 283
Chris@771 284 return true;
Chris@771 285 }
Chris@771 286
Chris@771 287 Path
Chris@771 288 TransformDTWAligner::makePath(const vector<size_t> &alignment,
Chris@771 289 const vector<sv_frame_t> &refFrames,
Chris@771 290 const vector<sv_frame_t> &otherFrames,
Chris@771 291 sv_samplerate_t sampleRate,
Chris@771 292 sv_frame_t resolution)
Chris@771 293 {
Chris@771 294 Path path(sampleRate, resolution);
Chris@771 295
Chris@773 296 path.add(PathPoint(0, 0));
Chris@773 297
Chris@771 298 for (int i = 0; in_range_for(alignment, i); ++i) {
Chris@771 299
Chris@771 300 // DTW returns "the index into s2 for each element in s1"
Chris@771 301 sv_frame_t refFrame = refFrames[i];
Chris@771 302
Chris@771 303 if (!in_range_for(otherFrames, alignment[i])) {
Chris@771 304 SVCERR << "TransformDTWAligner::makePath: Internal error: "
Chris@771 305 << "DTW maps index " << i << " in reference frame vector "
Chris@771 306 << "(size " << refFrames.size() << ") onto index "
Chris@771 307 << alignment[i] << " in other frame vector "
Chris@771 308 << "(only size " << otherFrames.size() << ")" << endl;
Chris@771 309 continue;
Chris@771 310 }
Chris@771 311
Chris@771 312 sv_frame_t alignedFrame = otherFrames[alignment[i]];
Chris@771 313 path.add(PathPoint(alignedFrame, refFrame));
Chris@771 314 }
Chris@771 315
Chris@771 316 return path;
Chris@771 317 }
Chris@771 318
Chris@771 319 bool
Chris@771 320 TransformDTWAligner::performAlignmentMagnitude()
Chris@771 321 {
Chris@771 322 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
Chris@767 323 if (!alignmentModel) {
Chris@767 324 return false;
Chris@767 325 }
Chris@771 326
Chris@771 327 vector<sv_frame_t> refFrames, otherFrames;
Chris@771 328 vector<double> refValues, otherValues;
Chris@771 329 sv_frame_t resolution = 0;
Chris@771 330
Chris@771 331 if (!getValuesFrom(m_referenceOutputModel,
Chris@771 332 refFrames, refValues, resolution)) {
Chris@771 333 return false;
Chris@771 334 }
Chris@771 335
Chris@771 336 if (!getValuesFrom(m_toAlignOutputModel,
Chris@771 337 otherFrames, otherValues, resolution)) {
Chris@771 338 return false;
Chris@771 339 }
Chris@767 340
Chris@767 341 vector<double> s1, s2;
Chris@771 342 for (double v: refValues) {
Chris@771 343 s1.push_back(m_magnitudePreprocessor(v));
Chris@771 344 }
Chris@771 345 for (double v: otherValues) {
Chris@771 346 s2.push_back(m_magnitudePreprocessor(v));
Chris@767 347 }
Chris@767 348
Chris@773 349 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 350 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
Chris@767 351 << "Have " << s1.size() << " events from reference, "
Chris@767 352 << s2.size() << " from toAlign" << endl;
Chris@773 353 #endif
Chris@771 354
Chris@767 355 MagnitudeDTW dtw;
Chris@767 356 vector<size_t> alignment;
Chris@767 357
Chris@767 358 {
Chris@773 359 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 360 SVCERR << "TransformDTWAligner[" << this
Chris@767 361 << "]: serialising DTW to avoid over-allocation" << endl;
Chris@773 362 #endif
Chris@771 363 QMutexLocker locker(&m_dtwMutex);
Chris@767 364 alignment = dtw.alignSeries(s1, s2);
Chris@767 365 }
Chris@767 366
Chris@773 367 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 368 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
Chris@767 369 << "DTW produced " << alignment.size() << " points:" << endl;
Chris@771 370 for (int i = 0; in_range_for(alignment, i) && i < 100; ++i) {
Chris@767 371 SVCERR << alignment[i] << " ";
Chris@767 372 }
Chris@767 373 SVCERR << endl;
Chris@773 374 #endif
Chris@767 375
Chris@771 376 alignmentModel->setPath(makePath(alignment,
Chris@771 377 refFrames,
Chris@771 378 otherFrames,
Chris@771 379 alignmentModel->getSampleRate(),
Chris@771 380 resolution));
Chris@768 381 alignmentModel->setCompletion(100);
Chris@767 382
Chris@771 383 SVCERR << "TransformDTWAligner[" << this
Chris@771 384 << "]: performAlignmentMagnitude: Done" << endl;
Chris@767 385
Chris@767 386 m_incomplete = false;
Chris@767 387 return true;
Chris@767 388 }
Chris@767 389
Chris@767 390 bool
Chris@767 391 TransformDTWAligner::performAlignmentRiseFall()
Chris@767 392 {
Chris@771 393 auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
Chris@767 394 if (!alignmentModel) {
Chris@767 395 return false;
Chris@767 396 }
Chris@768 397
Chris@771 398 vector<sv_frame_t> refFrames, otherFrames;
Chris@771 399 vector<double> refValues, otherValues;
Chris@771 400 sv_frame_t resolution = 0;
Chris@771 401
Chris@771 402 if (!getValuesFrom(m_referenceOutputModel,
Chris@771 403 refFrames, refValues, resolution)) {
Chris@771 404 return false;
Chris@771 405 }
Chris@771 406
Chris@771 407 if (!getValuesFrom(m_toAlignOutputModel,
Chris@771 408 otherFrames, otherValues, resolution)) {
Chris@771 409 return false;
Chris@771 410 }
Chris@771 411
Chris@771 412 auto preprocess =
Chris@771 413 [this](const std::vector<double> &vv) {
Chris@768 414 vector<RiseFallDTW::Value> s;
Chris@768 415 double prev = 0.0;
Chris@771 416 for (auto curr: vv) {
Chris@771 417 s.push_back(m_riseFallPreprocessor(prev, curr));
Chris@771 418 prev = curr;
Chris@768 419 }
Chris@768 420 return s;
Chris@771 421 };
Chris@767 422
Chris@771 423 vector<RiseFallDTW::Value> s1 = preprocess(refValues);
Chris@771 424 vector<RiseFallDTW::Value> s2 = preprocess(otherValues);
Chris@767 425
Chris@773 426 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 427 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
Chris@767 428 << "Have " << s1.size() << " events from reference, "
Chris@767 429 << s2.size() << " from toAlign" << endl;
Chris@767 430
Chris@771 431 SVCERR << "Reference:" << endl;
Chris@771 432 for (int i = 0; in_range_for(s1, i) && i < 100; ++i) {
Chris@771 433 SVCERR << s1[i] << " ";
Chris@771 434 }
Chris@771 435 SVCERR << endl;
Chris@771 436
Chris@771 437 SVCERR << "toAlign:" << endl;
Chris@771 438 for (int i = 0; in_range_for(s2, i) && i < 100; ++i) {
Chris@771 439 SVCERR << s2[i] << " ";
Chris@771 440 }
Chris@771 441 SVCERR << endl;
Chris@773 442 #endif
Chris@773 443
Chris@767 444 RiseFallDTW dtw;
Chris@767 445 vector<size_t> alignment;
Chris@767 446
Chris@767 447 {
Chris@773 448 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@767 449 SVCERR << "TransformDTWAligner[" << this
Chris@767 450 << "]: serialising DTW to avoid over-allocation" << endl;
Chris@773 451 #endif
Chris@771 452 QMutexLocker locker(&m_dtwMutex);
Chris@767 453 alignment = dtw.alignSeries(s1, s2);
Chris@767 454 }
Chris@767 455
Chris@773 456 #ifdef DEBUG_TRANSFORM_DTW_ALIGNER
Chris@769 457 SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
Chris@767 458 << "DTW produced " << alignment.size() << " points:" << endl;
Chris@767 459 for (int i = 0; i < alignment.size() && i < 100; ++i) {
Chris@767 460 SVCERR << alignment[i] << " ";
Chris@767 461 }
Chris@767 462 SVCERR << endl;
Chris@773 463 #endif
Chris@767 464
Chris@771 465 alignmentModel->setPath(makePath(alignment,
Chris@771 466 refFrames,
Chris@771 467 otherFrames,
Chris@771 468 alignmentModel->getSampleRate(),
Chris@771 469 resolution));
Chris@771 470
Chris@768 471 alignmentModel->setCompletion(100);
Chris@767 472
Chris@773 473 SVCERR << "TransformDTWAligner[" << this
Chris@773 474 << "]: performAlignmentRiseFall: Done" << endl;
Chris@767 475
Chris@767 476 m_incomplete = false;
Chris@767 477 return true;
Chris@767 478 }