Mercurial > hg > svapp
view align/TransformDTWAligner.cpp @ 776:32e66fcc4cb7 pitch-align
Make querying and setting the external alignment program or transform separate from selecting the alignment type - we need it to work that way for a clearer UI
author | Chris Cannam |
---|---|
date | Thu, 25 Jun 2020 09:32:01 +0100 |
parents | 699b5b130ea2 |
children | 8fa98f89eda8 |
line wrap: on
line source
/* -*- c-basic-offset: 4 indent-tabs-mode: nil -*-  vi:set ts=8 sts=4 sw=4: */

/*
    Sonic Visualiser
    An audio file viewer and annotation editor.
    Centre for Digital Music, Queen Mary, University of London.

    This program is free software; you can redistribute it and/or
    modify it under the terms of the GNU General Public License as
    published by the Free Software Foundation; either version 2 of the
    License, or (at your option) any later version.  See the file
    COPYING included with this distribution for more information.
*/

#include "TransformDTWAligner.h"
#include "DTW.h"

#include "data/model/SparseTimeValueModel.h"
#include "data/model/NoteModel.h"
#include "data/model/RangeSummarisableTimeValueModel.h"
#include "data/model/AlignmentModel.h"
#include "data/model/AggregateWaveModel.h"

#include "framework/Document.h"

#include "transform/ModelTransformerFactory.h"
#include "transform/FeatureExtractionModelTransformer.h"

#include <QSettings>
#include <QMutex>
#include <QMutexLocker>

#include <algorithm>
#include <utility>
#include <vector>

using std::vector;

// Default magnitude preprocessor: returns each feature value unchanged.
static TransformDTWAligner::MagnitudePreprocessor
identityMagnitudePreprocessor =
    [](double x) {
        return x;
    };

// Default rise/fall preprocessor: classifies the step from the previous
// value to the current one as None (equal values), Up or Down, carrying
// the absolute size of the step as its magnitude.
static TransformDTWAligner::RiseFallPreprocessor
identityRiseFallPreprocessor =
    [](double prev, double curr) {
        if (prev == curr) {
            return RiseFallDTW::Value({ RiseFallDTW::Direction::None, 0.0 });
        } else if (curr > prev) {
            return RiseFallDTW::Value({ RiseFallDTW::Direction::Up, curr - prev });
        } else {
            return RiseFallDTW::Value({ RiseFallDTW::Direction::Down, prev - curr });
        }
    };

// Class-wide mutex (one instance shared by all aligners). It is held
// around each DTW calculation so that only one runs at a time.
QMutex TransformDTWAligner::m_dtwMutex;

// Construct an aligner of the given DTW type, using the default
// (identity) preprocessors for both DTW flavours.
TransformDTWAligner::TransformDTWAligner(Document *doc,
                                         ModelId reference,
                                         ModelId toAlign,
                                         Transform transform,
                                         DTWType dtwType) :
    m_document(doc),
    m_reference(reference),
    m_toAlign(toAlign),
    m_transform(std::move(transform)),
    m_dtwType(dtwType),
    m_incomplete(true),
    m_magnitudePreprocessor(identityMagnitudePreprocessor),
    m_riseFallPreprocessor(identityRiseFallPreprocessor)
{
}

// Construct a Magnitude-type aligner whose transform output values are
// passed through the supplied preprocessor before DTW.
TransformDTWAligner::TransformDTWAligner(Document *doc,
                                         ModelId reference,
                                         ModelId toAlign,
                                         Transform transform,
                                         MagnitudePreprocessor outputPreprocessor) :
    m_document(doc),
    m_reference(reference),
    m_toAlign(toAlign),
    m_transform(std::move(transform)),
    m_dtwType(Magnitude),
    m_incomplete(true),
    m_magnitudePreprocessor(std::move(outputPreprocessor)),
    m_riseFallPreprocessor(identityRiseFallPreprocessor)
{
}

// Construct a RiseFall-type aligner whose successive transform output
// values are converted by the supplied preprocessor before DTW.
TransformDTWAligner::TransformDTWAligner(Document *doc,
                                         ModelId reference,
                                         ModelId toAlign,
                                         Transform transform,
                                         RiseFallPreprocessor outputPreprocessor) :
    m_document(doc),
    m_reference(reference),
    m_toAlign(toAlign),
    m_transform(std::move(transform)),
    m_dtwType(RiseFall),
    m_incomplete(true),
    m_magnitudePreprocessor(identityMagnitudePreprocessor),
    m_riseFallPreprocessor(std::move(outputPreprocessor))
{
}

// If no alignment was ever completed, detach the (partial) alignment
// from the model being aligned. Always release the two transform output
// models this aligner created.
TransformDTWAligner::~TransformDTWAligner()
{
    if (m_incomplete) {
        if (auto toAlign = ModelById::get(m_toAlign)) {
            toAlign->setAlignment({});
        }
    }

    ModelById::release(m_referenceOutputModel);
    ModelById::release(m_toAlignOutputModel);
}

// Currently unconditionally true.
bool
TransformDTWAligner::isAvailable()
{
    //!!! needs to be isAvailable(QString transformId)?
    return true;
}

// Start the alignment: run the transform on both the reference and the
// to-align models, attach a new (empty) AlignmentModel to the to-align
// model, and wait for the transform outputs to complete. If both outputs
// happen to be ready already, align immediately. Emits failed() if either
// transform cannot be started, or complete()/failed() if aligning here.
void
TransformDTWAligner::begin()
{
    auto reference =
        ModelById::getAs<RangeSummarisableTimeValueModel>(m_reference);
    auto toAlign =
        ModelById::getAs<RangeSummarisableTimeValueModel>(m_toAlign);

    // NOTE(review): this returns without emitting complete() or failed();
    // confirm that no caller relies on one of those signals arriving.
    if (!reference || !toAlign) return;

    SVCERR << "TransformDTWAligner[" << this << "]: begin(): aligning "
           << m_toAlign << " against reference " << m_reference << endl;

    ModelTransformerFactory *mtf = ModelTransformerFactory::getInstance();

    QString message;

    m_referenceOutputModel = mtf->transform(m_transform, m_reference, message);
    auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
    if (!referenceOutputModel) {
        SVCERR << "TransformDTWAligner[" << this << "]: begin(): ERROR: "
               << "Failed to create reference output model (no plugin?)"
               << endl;
        emit failed(m_toAlign, message);
        return;
    }

#ifdef DEBUG_TRANSFORM_DTW_ALIGNER
    SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
           << m_transform.getIdentifier()
           << " is running on reference model" << endl;
#endif

    message = "";

    m_toAlignOutputModel = mtf->transform(m_transform, m_toAlign, message);
    auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
    if (!toAlignOutputModel) {
        SVCERR << "TransformDTWAligner[" << this << "]: begin(): ERROR: "
               << "Failed to create toAlign output model (no plugin?)"
               << endl;
        emit failed(m_toAlign, message);
        return;
    }

#ifdef DEBUG_TRANSFORM_DTW_ALIGNER
    SVCERR << "TransformDTWAligner[" << this << "]: begin(): transform id "
           << m_transform.getIdentifier()
           << " is running on toAlign model" << endl;
#endif

    connect(referenceOutputModel.get(), SIGNAL(completionChanged(ModelId)),
            this, SLOT(completionChanged(ModelId)));
    connect(toAlignOutputModel.get(), SIGNAL(completionChanged(ModelId)),
            this, SLOT(completionChanged(ModelId)));

    auto alignmentModel = std::make_shared<AlignmentModel>
        (m_reference, m_toAlign, ModelId());
    m_alignmentModel = ModelById::add(alignmentModel);

    toAlign->setAlignment(m_alignmentModel);
    m_document->addNonDerivedModel(m_alignmentModel);

    // we wouldn't normally expect these to be true here, but...
    int completion = 0;
    if (referenceOutputModel->isReady(&completion) &&
        toAlignOutputModel->isReady(&completion)) {
        SVCERR << "TransformDTWAligner[" << this << "]: begin(): output models "
               << "are ready already! calling performAlignment" << endl;
        if (performAlignment()) {
            emit complete(m_alignmentModel);
        } else {
            emit failed(m_toAlign, tr("Failed to calculate alignment using DTW"));
        }
    }
}

// Slot connected to completionChanged() of both transform output models.
// Once both are ready, sets alignment completion to 95 and runs the DTW;
// until then, reports the lesser of the two completions scaled into 0-94.
// Does nothing once an alignment has been completed.
void
TransformDTWAligner::completionChanged(ModelId id)
{
    if (!m_incomplete) {
        return;
    }
#ifdef DEBUG_TRANSFORM_DTW_ALIGNER
    SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
           << "model " << id << endl;
#endif

    auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
    auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
    auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);

    if (!referenceOutputModel || !toAlignOutputModel || !alignmentModel) {
        return;
    }

    int referenceCompletion = 0, toAlignCompletion = 0;
    bool referenceReady = referenceOutputModel->isReady(&referenceCompletion);
    bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion);

    if (referenceReady && toAlignReady) {

        SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
               << "both models ready, calling performAlignment" << endl;

        alignmentModel->setCompletion(95);

        if (performAlignment()) {
            emit complete(m_alignmentModel);
        } else {
            emit failed(m_toAlign, tr("Alignment of transform outputs failed"));
        }

    } else {
#ifdef DEBUG_TRANSFORM_DTW_ALIGNER
        SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
               << "not ready yet: reference completion " << referenceCompletion
               << ", toAlign completion " << toAlignCompletion << endl;
#endif

        // Keep reported progress below the 95 set once DTW starts
        int completion = std::min(referenceCompletion, toAlignCompletion);
        completion = (completion * 94) / 100;
        alignmentModel->setCompletion(completion);
    }
}

// Dispatch to the DTW variant selected at construction.
bool
TransformDTWAligner::performAlignment()
{
    if (m_dtwType == Magnitude) {
        return performAlignmentMagnitude();
    } else {
        return performAlignmentRiseFall();
    }
}

// Extract the frame and value of every event in the given model, which
// must be a SparseTimeValueModel or a NoteModel. On success, replaces the
// contents of frames and values, sets resolution to the model's
// resolution, and returns true. For any other model type, logs an error
// and returns false, leaving all three out-parameters untouched.
bool
TransformDTWAligner::getValuesFrom(ModelId modelId,
                                   vector<sv_frame_t> &frames,
                                   vector<double> &values,
                                   sv_frame_t &resolution)
{
    EventVector events;

    if (auto model = ModelById::getAs<SparseTimeValueModel>(modelId)) {
        resolution = model->getResolution();
        events = model->getAllEvents();
    } else if (auto model = ModelById::getAs<NoteModel>(modelId)) {
        resolution = model->getResolution();
        events = model->getAllEvents();
    } else {
        SVCERR << "TransformDTWAligner::getValuesFrom: Type of model "
               << modelId << " is not supported" << endl;
        return false;
    }

    frames.clear();
    values.clear();
    frames.reserve(events.size());
    values.reserve(events.size());

    // Bind by reference: copying each Event is unnecessary
    for (const auto &e: events) {
        frames.push_back(e.getFrame());
        values.push_back(e.getValue());
    }

    return true;
}

// Build an alignment Path from the DTW result. alignment[i] is the index
// into otherFrames matched to reference index i. The path always starts
// with (0, 0); each subsequent point pairs the matched other-frame with
// the reference frame. Out-of-range indices are logged and skipped.
Path
TransformDTWAligner::makePath(const vector<size_t> &alignment,
                              const vector<sv_frame_t> &refFrames,
                              const vector<sv_frame_t> &otherFrames,
                              sv_samplerate_t sampleRate,
                              sv_frame_t resolution)
{
    Path path(sampleRate, resolution);

    path.add(PathPoint(0, 0));

    for (int i = 0; in_range_for(alignment, i); ++i) {

        // DTW returns "the index into s2 for each element in s1"
        sv_frame_t refFrame = refFrames[i];

        if (!in_range_for(otherFrames, alignment[i])) {
            SVCERR << "TransformDTWAligner::makePath: Internal error: "
                   << "DTW maps index " << i << " in reference frame vector "
                   << "(size " << refFrames.size() << ") onto index "
                   << alignment[i] << " in other frame vector "
                   << "(only size " << otherFrames.size() << ")" << endl;
            continue;
        }

        sv_frame_t alignedFrame = otherFrames[alignment[i]];
        path.add(PathPoint(alignedFrame, refFrame));
    }

    return path;
}

// Align using MagnitudeDTW over the preprocessed output values of the two
// transform output models. On success, stores the path in the alignment
// model, sets its completion to 100, marks this aligner complete, and
// returns true. Returns false if any model is missing or unsupported.
bool
TransformDTWAligner::performAlignmentMagnitude()
{
    auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
    if (!alignmentModel) {
        return false;
    }

    vector<sv_frame_t> refFrames, otherFrames;
    vector<double> refValues, otherValues;
    sv_frame_t resolution = 0;

    // NOTE(review): both calls write `resolution`, so the toAlign output
    // model's resolution is the one passed to makePath. Presumably both
    // outputs share it since they come from the same transform - confirm.
    if (!getValuesFrom(m_referenceOutputModel,
                       refFrames, refValues, resolution)) {
        return false;
    }

    if (!getValuesFrom(m_toAlignOutputModel,
                       otherFrames, otherValues, resolution)) {
        return false;
    }

    vector<double> s1, s2;
    s1.reserve(refValues.size());
    s2.reserve(otherValues.size());

    for (double v: refValues) {
        s1.push_back(m_magnitudePreprocessor(v));
    }
    for (double v: otherValues) {
        s2.push_back(m_magnitudePreprocessor(v));
    }

#ifdef DEBUG_TRANSFORM_DTW_ALIGNER
    SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
           << "Have " << s1.size() << " events from reference, "
           << s2.size() << " from toAlign" << endl;
#endif

    MagnitudeDTW dtw;
    vector<size_t> alignment;

    {
#ifdef DEBUG_TRANSFORM_DTW_ALIGNER
        SVCERR << "TransformDTWAligner[" << this
               << "]: serialising DTW to avoid over-allocation" << endl;
#endif
        // Class-wide lock: only one DTW calculation runs at a time
        QMutexLocker locker(&m_dtwMutex);
        alignment = dtw.alignSeries(s1, s2);
    }

#ifdef DEBUG_TRANSFORM_DTW_ALIGNER
    SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
           << "DTW produced " << alignment.size() << " points:" << endl;
    for (int i = 0; in_range_for(alignment, i) && i < 100; ++i) {
        SVCERR << alignment[i] << " ";
    }
    SVCERR << endl;
#endif

    alignmentModel->setPath(makePath(alignment,
                                     refFrames,
                                     otherFrames,
                                     alignmentModel->getSampleRate(),
                                     resolution));
    alignmentModel->setCompletion(100);

    SVCERR << "TransformDTWAligner[" << this
           << "]: performAlignmentMagnitude: Done" << endl;

    m_incomplete = false;
    return true;
}

// Align using RiseFallDTW. Each output value is paired with its
// predecessor (0.0 for the first value) and converted by the rise/fall
// preprocessor. On success, stores the path in the alignment model, sets
// its completion to 100, marks this aligner complete, and returns true.
// Returns false if any model is missing or unsupported.
bool
TransformDTWAligner::performAlignmentRiseFall()
{
    auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
    if (!alignmentModel) {
        return false;
    }

    vector<sv_frame_t> refFrames, otherFrames;
    vector<double> refValues, otherValues;
    sv_frame_t resolution = 0;

    // NOTE(review): as in performAlignmentMagnitude, the second call's
    // resolution (toAlign output model) is the one that is used.
    if (!getValuesFrom(m_referenceOutputModel,
                       refFrames, refValues, resolution)) {
        return false;
    }

    if (!getValuesFrom(m_toAlignOutputModel,
                       otherFrames, otherValues, resolution)) {
        return false;
    }

    auto preprocess =
        [this](const std::vector<double> &vv) {
            vector<RiseFallDTW::Value> s;
            s.reserve(vv.size());
            double prev = 0.0;
            for (auto curr: vv) {
                s.push_back(m_riseFallPreprocessor(prev, curr));
                prev = curr;
            }
            return s;
        };

    vector<RiseFallDTW::Value> s1 = preprocess(refValues);
    vector<RiseFallDTW::Value> s2 = preprocess(otherValues);

#ifdef DEBUG_TRANSFORM_DTW_ALIGNER
    SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
           << "Have " << s1.size() << " events from reference, "
           << s2.size() << " from toAlign" << endl;

    SVCERR << "Reference:" << endl;
    for (int i = 0; in_range_for(s1, i) && i < 100; ++i) {
        SVCERR << s1[i] << " ";
    }
    SVCERR << endl;

    SVCERR << "toAlign:" << endl;
    for (int i = 0; in_range_for(s2, i) && i < 100; ++i) {
        SVCERR << s2[i] << " ";
    }
    SVCERR << endl;
#endif

    RiseFallDTW dtw;
    vector<size_t> alignment;

    {
#ifdef DEBUG_TRANSFORM_DTW_ALIGNER
        SVCERR << "TransformDTWAligner[" << this
               << "]: serialising DTW to avoid over-allocation" << endl;
#endif
        // Class-wide lock: only one DTW calculation runs at a time
        QMutexLocker locker(&m_dtwMutex);
        alignment = dtw.alignSeries(s1, s2);
    }

#ifdef DEBUG_TRANSFORM_DTW_ALIGNER
    SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
           << "DTW produced " << alignment.size() << " points:" << endl;
    // in_range_for avoids the signed/unsigned comparison of int vs size()
    for (int i = 0; in_range_for(alignment, i) && i < 100; ++i) {
        SVCERR << alignment[i] << " ";
    }
    SVCERR << endl;
#endif

    alignmentModel->setPath(makePath(alignment,
                                     refFrames,
                                     otherFrames,
                                     alignmentModel->getSampleRate(),
                                     resolution));
    alignmentModel->setCompletion(100);

    SVCERR << "TransformDTWAligner[" << this
           << "]: performAlignmentRiseFall: Done" << endl;

    m_incomplete = false;
    return true;
}