Mercurial > hg > svapp
diff align/TransformDTWAligner.cpp @ 771:1d6cca5a5621 pitch-align
Allow use of proper sparse models (i.e. retaining event time info) in alignment; use this to switch to note alignment, which is what we have most recently been doing in the external program. Not currently producing correct results, though
author | Chris Cannam |
---|---|
date | Fri, 29 May 2020 17:39:02 +0100 |
parents | a316cb6fed81 |
children | 699b5b130ea2 |
line wrap: on
line diff
--- a/align/TransformDTWAligner.cpp Thu May 28 17:52:19 2020 +0100 +++ b/align/TransformDTWAligner.cpp Fri May 29 17:39:02 2020 +0100 @@ -16,6 +16,7 @@ #include "DTW.h" #include "data/model/SparseTimeValueModel.h" +#include "data/model/NoteModel.h" #include "data/model/RangeSummarisableTimeValueModel.h" #include "data/model/AlignmentModel.h" #include "data/model/AggregateWaveModel.h" @@ -31,6 +32,27 @@ using std::vector; +static +TransformDTWAligner::MagnitudePreprocessor identityMagnitudePreprocessor = + [](double x) { + return x; + }; + +static +TransformDTWAligner::RiseFallPreprocessor identityRiseFallPreprocessor = + [](double prev, double curr) { + if (prev == curr) { + return RiseFallDTW::Value({ RiseFallDTW::Direction::None, 0.0 }); + } else if (curr > prev) { + return RiseFallDTW::Value({ RiseFallDTW::Direction::Up, curr - prev }); + } else { + return RiseFallDTW::Value({ RiseFallDTW::Direction::Down, prev - curr }); + } + }; + +QMutex +TransformDTWAligner::m_dtwMutex; + TransformDTWAligner::TransformDTWAligner(Document *doc, ModelId reference, ModelId toAlign, @@ -42,7 +64,8 @@ m_transform(transform), m_dtwType(dtwType), m_incomplete(true), - m_outputPreprocessor([](double x) { return x; }) + m_magnitudePreprocessor(identityMagnitudePreprocessor), + m_riseFallPreprocessor(identityRiseFallPreprocessor) { } @@ -50,16 +73,31 @@ ModelId reference, ModelId toAlign, Transform transform, - DTWType dtwType, - std::function<double(double)> - outputPreprocessor) : + MagnitudePreprocessor outputPreprocessor) : m_document(doc), m_reference(reference), m_toAlign(toAlign), m_transform(transform), - m_dtwType(dtwType), + m_dtwType(Magnitude), m_incomplete(true), - m_outputPreprocessor(outputPreprocessor) + m_magnitudePreprocessor(outputPreprocessor), + m_riseFallPreprocessor(identityRiseFallPreprocessor) +{ +} + +TransformDTWAligner::TransformDTWAligner(Document *doc, + ModelId reference, + ModelId toAlign, + Transform transform, + RiseFallPreprocessor outputPreprocessor) : + m_document(doc), + m_reference(reference), + m_toAlign(toAlign), + m_transform(transform), + m_dtwType(RiseFall), + m_incomplete(true), + m_magnitudePreprocessor(identityMagnitudePreprocessor), + m_riseFallPreprocessor(outputPreprocessor) { } @@ -157,10 +195,10 @@ if (!m_incomplete) { return; } - +/* SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " << "model " << id << endl; - +*/ auto referenceOutputModel = ModelById::get(m_referenceOutputModel); auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel); auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel); @@ -176,7 +214,7 @@ if (referenceReady && toAlignReady) { SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " - << "ready, calling performAlignment" << endl; + << "both models ready, calling performAlignment" << endl; alignmentModel->setCompletion(95); @@ -187,11 +225,11 @@ } } else { - +/* SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: " << "not ready yet: reference completion " << referenceCompletion << ", toAlign completion " << toAlignCompletion << endl; - +*/ int completion = std::min(referenceCompletion, toAlignCompletion); completion = (completion * 94) / 100; @@ -210,76 +248,126 @@ } bool -TransformDTWAligner::performAlignmentMagnitude() +TransformDTWAligner::getValuesFrom(ModelId modelId, + vector<sv_frame_t> &frames, + vector<double> &values, + sv_frame_t &resolution) { - auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel> - (m_referenceOutputModel); - auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel> - (m_toAlignOutputModel); - auto alignmentModel = ModelById::getAs<AlignmentModel> - (m_alignmentModel); + EventVector events; - if (!referenceOutputSTVM || !toAlignOutputSTVM) { - //!!! what? + if (auto model = ModelById::getAs<SparseTimeValueModel>(modelId)) { + resolution = model->getResolution(); + events = model->getAllEvents(); + } else if (auto model = ModelById::getAs<NoteModel>(modelId)) { + resolution = model->getResolution(); + events = model->getAllEvents(); + } else { + SVCERR << "TransformDTWAligner::getValuesFrom: Type of model " + << modelId << " is not supported" << endl; return false; } + frames.clear(); + values.clear(); + + for (auto e: events) { + frames.push_back(e.getFrame()); + values.push_back(e.getValue()); + } + + return true; +} + +Path +TransformDTWAligner::makePath(const vector<size_t> &alignment, + const vector<sv_frame_t> &refFrames, + const vector<sv_frame_t> &otherFrames, + sv_samplerate_t sampleRate, + sv_frame_t resolution) +{ + Path path(sampleRate, resolution); + + for (int i = 0; in_range_for(alignment, i); ++i) { + + // DTW returns "the index into s2 for each element in s1" + sv_frame_t refFrame = refFrames[i]; + + if (!in_range_for(otherFrames, alignment[i])) { + SVCERR << "TransformDTWAligner::makePath: Internal error: " + << "DTW maps index " << i << " in reference frame vector " + << "(size " << refFrames.size() << ") onto index " + << alignment[i] << " in other frame vector " + << "(only size " << otherFrames.size() << ")" << endl; + continue; + } + + sv_frame_t alignedFrame = otherFrames[alignment[i]]; + path.add(PathPoint(alignedFrame, refFrame)); + } + + return path; +} + +bool +TransformDTWAligner::performAlignmentMagnitude() +{ + auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel); if (!alignmentModel) { return false; } + + vector<sv_frame_t> refFrames, otherFrames; + vector<double> refValues, otherValues; + sv_frame_t resolution = 0; + + if (!getValuesFrom(m_referenceOutputModel, + refFrames, refValues, resolution)) { + return false; + } + + if (!getValuesFrom(m_toAlignOutputModel, + otherFrames, otherValues, resolution)) { + return false; + } vector<double> s1, s2; - - { - auto events = referenceOutputSTVM->getAllEvents(); - for (auto e: events) { - s1.push_back(m_outputPreprocessor(e.getValue())); - } - events = toAlignOutputSTVM->getAllEvents(); - for (auto e: events) { - s2.push_back(m_outputPreprocessor(e.getValue())); - } + for (double v: refValues) { + s1.push_back(m_magnitudePreprocessor(v)); + } + for (double v: otherValues) { + s2.push_back(m_magnitudePreprocessor(v)); } SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: " << "Have " << s1.size() << " events from reference, " << s2.size() << " from toAlign" << endl; - + MagnitudeDTW dtw; vector<size_t> alignment; { SVCERR << "TransformDTWAligner[" << this << "]: serialising DTW to avoid over-allocation" << endl; - static QMutex mutex; - QMutexLocker locker(&mutex); - + QMutexLocker locker(&m_dtwMutex); alignment = dtw.alignSeries(s1, s2); } SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: " << "DTW produced " << alignment.size() << " points:" << endl; - for (int i = 0; i < alignment.size() && i < 100; ++i) { + for (int i = 0; in_range_for(alignment, i) && i < 100; ++i) { SVCERR << alignment[i] << " "; } SVCERR << endl; + alignmentModel->setPath(makePath(alignment, + refFrames, + otherFrames, + alignmentModel->getSampleRate(), + resolution)); alignmentModel->setCompletion(100); - sv_frame_t resolution = referenceOutputSTVM->getResolution(); - sv_frame_t sourceFrame = 0; - - Path path(referenceOutputSTVM->getSampleRate(), resolution); - - for (size_t m: alignment) { - path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution)); - sourceFrame += resolution; - } - - alignmentModel->setPath(path); - - SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: Done" - << endl; + SVCERR << "TransformDTWAligner[" << this + << "]: performAlignmentMagnitude: Done" << endl; m_incomplete = false; return true; @@ -288,59 +376,62 @@ bool TransformDTWAligner::performAlignmentRiseFall() { - auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel> - (m_referenceOutputModel); - auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel> - (m_toAlignOutputModel); - auto alignmentModel = ModelById::getAs<AlignmentModel> - (m_alignmentModel); - - if (!referenceOutputSTVM || !toAlignOutputSTVM) { - //!!! what? - return false; - } - + auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel); if (!alignmentModel) { return false; } - auto convertEvents = - [this](const EventVector &ee) { + vector<sv_frame_t> refFrames, otherFrames; + vector<double> refValues, otherValues; + sv_frame_t resolution = 0; + + if (!getValuesFrom(m_referenceOutputModel, + refFrames, refValues, resolution)) { + return false; + } + + if (!getValuesFrom(m_toAlignOutputModel, + otherFrames, otherValues, resolution)) { + return false; + } + + auto preprocess = + [this](const std::vector<double> &vv) { vector<RiseFallDTW::Value> s; double prev = 0.0; - for (auto e: ee) { - double v = m_outputPreprocessor(e.getValue()); - if (v == prev || s.empty()) { - s.push_back({ RiseFallDTW::Direction::None, 0.0 }); - } else if (v > prev) { - s.push_back({ RiseFallDTW::Direction::Up, v - prev }); - } else { - s.push_back({ RiseFallDTW::Direction::Down, prev - v }); - } + for (auto curr: vv) { + s.push_back(m_riseFallPreprocessor(prev, curr)); + prev = curr; } return s; - }; + }; - vector<RiseFallDTW::Value> s1 = - convertEvents(referenceOutputSTVM->getAllEvents()); - - vector<RiseFallDTW::Value> s2 = - convertEvents(toAlignOutputSTVM->getAllEvents()); + vector<RiseFallDTW::Value> s1 = preprocess(refValues); + vector<RiseFallDTW::Value> s2 = preprocess(otherValues); SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: " << "Have " << s1.size() << " events from reference, " << s2.size() << " from toAlign" << endl; + SVCERR << "Reference:" << endl; + for (int i = 0; in_range_for(s1, i) && i < 100; ++i) { + SVCERR << s1[i] << " "; + } + SVCERR << endl; + + SVCERR << "toAlign:" << endl; + for (int i = 0; in_range_for(s2, i) && i < 100; ++i) { + SVCERR << s2[i] << " "; + } + SVCERR << endl; + RiseFallDTW dtw; - vector<size_t> alignment; { SVCERR << "TransformDTWAligner[" << this << "]: serialising DTW to avoid over-allocation" << endl; - static QMutex mutex; - QMutexLocker locker(&mutex); - + QMutexLocker locker(&m_dtwMutex); alignment = dtw.alignSeries(s1, s2); } @@ -351,20 +442,14 @@ } SVCERR << endl; + alignmentModel->setPath(makePath(alignment, + refFrames, + otherFrames, + alignmentModel->getSampleRate(), + resolution)); + alignmentModel->setCompletion(100); - sv_frame_t resolution = referenceOutputSTVM->getResolution(); - sv_frame_t sourceFrame = 0; - - Path path(referenceOutputSTVM->getSampleRate(), resolution); - - for (size_t m: alignment) { - path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution)); - sourceFrame += resolution; - } - - alignmentModel->setPath(path); - SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: Done" << endl;