Mercurial > hg > svapp
diff align/TransformDTWAligner.cpp @ 767:dd742e566e60 pitch-align
Make a start on further alignment methods
author | Chris Cannam |
---|---|
date | Thu, 21 May 2020 16:21:57 +0100 |
parents | |
children | 1b1960009be6 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/align/TransformDTWAligner.cpp Thu May 21 16:21:57 2020 +0100 @@ -0,0 +1,390 @@ +/* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */ + +/* + Sonic Visualiser + An audio file viewer and annotation editor. + Centre for Digital Music, Queen Mary, University of London. + + This program is free software; you can redistribute it and/or + modify it under the terms of the GNU General Public License as + published by the Free Software Foundation; either version 2 of the + License, or (at your option) any later version. See the file + COPYING included with this distribution for more information. +*/ + +#include "TransformDTWAligner.h" +#include "DTW.h" + +#include "data/model/SparseTimeValueModel.h" +#include "data/model/RangeSummarisableTimeValueModel.h" +#include "data/model/AlignmentModel.h" +#include "data/model/AggregateWaveModel.h" + +#include "framework/Document.h" + +#include "transform/ModelTransformerFactory.h" +#include "transform/FeatureExtractionModelTransformer.h" + +#include <QSettings> +#include <QMutex> +#include <QMutexLocker> + +using std::vector; + +TransformDTWAligner::TransformDTWAligner(Document *doc, + ModelId reference, + ModelId toAlign, + Transform transform, + DTWType dtwType) : + m_document(doc), + m_reference(reference), + m_toAlign(toAlign), + m_referenceTransformComplete(false), + m_toAlignTransformComplete(false), + m_transform(transform), + m_dtwType(dtwType), + m_incomplete(true) +{ +} + +TransformDTWAligner::~TransformDTWAligner() +{ + if (m_incomplete) { + if (auto toAlign = ModelById::get(m_toAlign)) { + toAlign->setAlignment({}); + } + } + + ModelById::release(m_referenceOutputModel); + ModelById::release(m_toAlignOutputModel); + ModelById::release(m_alignmentProgressModel); +} + +bool +TransformDTWAligner::isAvailable() +{ + //!!! needs to be isAvailable(QString transformId)? + return true; +} + +void +TransformDTWAligner::begin() +{ + auto reference = + ModelById::getAs<RangeSummarisableTimeValueModel>(m_reference); + auto toAlign = + ModelById::getAs<RangeSummarisableTimeValueModel>(m_toAlign); + + if (!reference || !toAlign) return; + + SVCERR << "TransformDTWAligner[" << this << "]: begin(): aligning " + << m_toAlign << " against reference " << m_reference << endl; + + ModelTransformerFactory *mtf = ModelTransformerFactory::getInstance(); + + QString message; + + m_referenceOutputModel = mtf->transform(m_transform, m_reference, message); + auto referenceOutputModel = ModelById::get(m_referenceOutputModel); + if (!referenceOutputModel) { + SVCERR << "Align::alignModel: ERROR: Failed to create reference output model (no plugin?)" << endl; + emit failed(m_toAlign, message); + return; + } + + SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id " + << m_transform.getIdentifier() + << " is running on reference model" << endl; + + message = ""; + + m_toAlignOutputModel = mtf->transform(m_transform, m_toAlign, message); + auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel); + if (!toAlignOutputModel) { + SVCERR << "Align::alignModel: ERROR: Failed to create toAlign output model (no plugin?)" << endl; + emit failed(m_toAlign, message); + return; + } + + SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id " + << m_transform.getIdentifier() + << " is running on toAlign model" << endl; + + connect(referenceOutputModel.get(), SIGNAL(completionChanged(ModelId)), + this, SLOT(completionChanged(ModelId))); + connect(toAlignOutputModel.get(), SIGNAL(completionChanged(ModelId)), + this, SLOT(completionChanged(ModelId))); + + auto alignmentProgressModel = std::make_shared<SparseTimeValueModel> + (reference->getSampleRate(), m_transform.getStepSize(), false); + alignmentProgressModel->setCompletion(0); + m_alignmentProgressModel = ModelById::add(alignmentProgressModel); + + auto alignmentModel = std::make_shared<AlignmentModel> + (m_reference, m_toAlign, m_alignmentProgressModel); + m_alignmentModel = ModelById::add(alignmentModel); + + toAlign->setAlignment(m_alignmentModel); + m_document->addNonDerivedModel(m_alignmentModel); + + // we wouldn't normally expect these to be true here, but... + int completion = 0; + if (referenceOutputModel->isReady(&completion) && + toAlignOutputModel->isReady(&completion)) { + SVCERR << "TransformDTWAligner[" << this << "]: begin(): output models " + << "are ready already! calling performAlignment" << endl; + if (performAlignment()) { + emit complete(m_alignmentModel); + } else { + emit failed(m_toAlign, tr("Failed to calculate alignment using DTW")); + } + } +} + +void +TransformDTWAligner::completionChanged(ModelId id) +{ + if (!m_incomplete) { + return; + } + + SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " + << "model " << id << endl; + + auto referenceOutputModel = ModelById::get(m_referenceOutputModel); + auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel); + + if (!referenceOutputModel || !toAlignOutputModel) { + return; + } + + int referenceCompletion = 0, toAlignCompletion = 0; + bool referenceReady = referenceOutputModel->isReady(&referenceCompletion); + bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion); + + auto alignmentProgressModel = + ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel); + + if (referenceReady && toAlignReady) { + + SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " + << "ready, calling performAlignment" << endl; + + if (alignmentProgressModel) { + alignmentProgressModel->setCompletion(95); + } + + if (performAlignment()) { + emit complete(m_alignmentModel); + } else { + emit failed(m_toAlign, tr("Alignment of transform outputs failed")); + } + + } else { + + SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " + << "not ready yet: reference completion " << referenceCompletion + << ", toAlign completion " << toAlignCompletion << endl; + + if (alignmentProgressModel) { + int completion = std::min(referenceCompletion, + toAlignCompletion); + completion = (completion * 94) / 100; + alignmentProgressModel->setCompletion(completion); + } + } +} + +bool +TransformDTWAligner::performAlignment() +{ + if (m_dtwType == Magnitude) { + return performAlignmentMagnitude(); + } else { + return performAlignmentRiseFall(); + } +} + +bool +TransformDTWAligner::performAlignmentMagnitude() +{ + auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel> + (m_referenceOutputModel); + auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel> + (m_toAlignOutputModel); + auto alignmentModel = ModelById::getAs<AlignmentModel> + (m_alignmentModel); + + if (!referenceOutputSTVM || !toAlignOutputSTVM) { + //!!! what? + return false; + } + + if (!alignmentModel) { + return false; + } + + vector<double> s1, s2; + + { + auto events = referenceOutputSTVM->getAllEvents(); + for (auto e: events) { + s1.push_back(e.getValue()); + } + events = toAlignOutputSTVM->getAllEvents(); + for (auto e: events) { + s2.push_back(e.getValue()); + } + } + + SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: " + << "Have " << s1.size() << " events from reference, " + << s2.size() << " from toAlign" << endl; + + MagnitudeDTW dtw; + vector<size_t> alignment; + + { + SVCERR << "TransformDTWAligner[" << this + << "]: serialising DTW to avoid over-allocation" << endl; + static QMutex mutex; + QMutexLocker locker(&mutex); + + alignment = dtw.alignSeries(s1, s2); + } + + SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: " + << "DTW produced " << alignment.size() << " points:" << endl; + for (int i = 0; i < alignment.size() && i < 100; ++i) { + SVCERR << alignment[i] << " "; + } + SVCERR << endl; + + auto alignmentProgressModel = + ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel); + if (alignmentProgressModel) { + alignmentProgressModel->setCompletion(100); + } + + // clear the alignment progress model + alignmentModel->setPathFrom(ModelId()); + + sv_frame_t resolution = referenceOutputSTVM->getResolution(); + sv_frame_t sourceFrame = 0; + + Path path(referenceOutputSTVM->getSampleRate(), resolution); + + for (size_t m: alignment) { + path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution)); + sourceFrame += resolution; + } + + alignmentModel->setPath(path); + + SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: Done" + << endl; + + m_incomplete = false; + return true; +} + +bool +TransformDTWAligner::performAlignmentRiseFall() +{ + auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel> + (m_referenceOutputModel); + auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel> + (m_toAlignOutputModel); + auto alignmentModel = ModelById::getAs<AlignmentModel> + (m_alignmentModel); + + if (!referenceOutputSTVM || !toAlignOutputSTVM) { + //!!! what? + return false; + } + + if (!alignmentModel) { + return false; + } + + vector<RiseFallDTW::Value> s1, s2; + double prev1 = 0.0, prev2 = 0.0; + + { + auto events = referenceOutputSTVM->getAllEvents(); + for (auto e: events) { + double v = e.getValue(); + //!!! the original does this using MIDI pitch for the + //!!! pYin transform... rework with a lambda passed in + //!!! for modification maybe? + factor out s1/s2 of course + if (v > prev1) { + s1.push_back({ RiseFallDTW::Direction::Up, v - prev1 }); + } else { + s1.push_back({ RiseFallDTW::Direction::Down, prev1 - v }); + } + prev1 = v; + } + events = toAlignOutputSTVM->getAllEvents(); + for (auto e: events) { + double v = e.getValue(); + //!!! as above + if (v > prev2) { + s2.push_back({ RiseFallDTW::Direction::Up, v - prev2 }); + } else { + s2.push_back({ RiseFallDTW::Direction::Down, prev2 - v }); + } + prev2 = v; + } + } + + SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: " + << "Have " << s1.size() << " events from reference, " + << s2.size() << " from toAlign" << endl; + + RiseFallDTW dtw; + + vector<size_t> alignment; + + { + SVCERR << "TransformDTWAligner[" << this + << "]: serialising DTW to avoid over-allocation" << endl; + static QMutex mutex; + QMutexLocker locker(&mutex); + + alignment = dtw.alignSeries(s1, s2); + } + + SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: " + << "DTW produced " << alignment.size() << " points:" << endl; + for (int i = 0; i < alignment.size() && i < 100; ++i) { + SVCERR << alignment[i] << " "; + } + SVCERR << endl; + + auto alignmentProgressModel = + ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel); + if (alignmentProgressModel) { + alignmentProgressModel->setCompletion(100); + } + + // clear the alignment progress model + alignmentModel->setPathFrom(ModelId()); + + sv_frame_t resolution = referenceOutputSTVM->getResolution(); + sv_frame_t sourceFrame = 0; + + Path path(referenceOutputSTVM->getSampleRate(), resolution); + + for (size_t m: alignment) { + path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution)); + sourceFrame += resolution; + } + + alignmentModel->setPath(path); + + SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: Done" + << endl; + + m_incomplete = false; + return true; +}