diff align/TransformDTWAligner.cpp @ 771:1d6cca5a5621 pitch-align

Allow use of proper sparse models (i.e. retaining event time info) in alignment; use this to switch to note alignment, which is what we have most recently been doing in the external program. Note that this is not currently producing correct results, though.
author Chris Cannam
date Fri, 29 May 2020 17:39:02 +0100
parents a316cb6fed81
children 699b5b130ea2
line wrap: on
line diff
--- a/align/TransformDTWAligner.cpp	Thu May 28 17:52:19 2020 +0100
+++ b/align/TransformDTWAligner.cpp	Fri May 29 17:39:02 2020 +0100
@@ -16,6 +16,7 @@
 #include "DTW.h"
 
 #include "data/model/SparseTimeValueModel.h"
+#include "data/model/NoteModel.h"
 #include "data/model/RangeSummarisableTimeValueModel.h"
 #include "data/model/AlignmentModel.h"
 #include "data/model/AggregateWaveModel.h"
@@ -31,6 +32,27 @@
 
 using std::vector;
 
+static
+TransformDTWAligner::MagnitudePreprocessor identityMagnitudePreprocessor =
+    [](double x) {
+        return x;
+    };
+
+static
+TransformDTWAligner::RiseFallPreprocessor identityRiseFallPreprocessor =
+    [](double prev, double curr) {
+        if (prev == curr) {
+            return RiseFallDTW::Value({ RiseFallDTW::Direction::None, 0.0 });
+        } else if (curr > prev) {
+            return RiseFallDTW::Value({ RiseFallDTW::Direction::Up, curr - prev });
+        } else {
+            return RiseFallDTW::Value({ RiseFallDTW::Direction::Down, prev - curr });
+        }
+    };
+
+QMutex
+TransformDTWAligner::m_dtwMutex;
+
 TransformDTWAligner::TransformDTWAligner(Document *doc,
                                          ModelId reference,
                                          ModelId toAlign,
@@ -42,7 +64,8 @@
     m_transform(transform),
     m_dtwType(dtwType),
     m_incomplete(true),
-    m_outputPreprocessor([](double x) { return x; })
+    m_magnitudePreprocessor(identityMagnitudePreprocessor),
+    m_riseFallPreprocessor(identityRiseFallPreprocessor)
 {
 }
 
@@ -50,16 +73,31 @@
                                          ModelId reference,
                                          ModelId toAlign,
                                          Transform transform,
-                                         DTWType dtwType,
-                                         std::function<double(double)>
-                                         outputPreprocessor) :
+                                         MagnitudePreprocessor outputPreprocessor) :
     m_document(doc),
     m_reference(reference),
     m_toAlign(toAlign),
     m_transform(transform),
-    m_dtwType(dtwType),
+    m_dtwType(Magnitude),
     m_incomplete(true),
-    m_outputPreprocessor(outputPreprocessor)
+    m_magnitudePreprocessor(outputPreprocessor),
+    m_riseFallPreprocessor(identityRiseFallPreprocessor)
+{
+}
+
+TransformDTWAligner::TransformDTWAligner(Document *doc,
+                                         ModelId reference,
+                                         ModelId toAlign,
+                                         Transform transform,
+                                         RiseFallPreprocessor outputPreprocessor) :
+    m_document(doc),
+    m_reference(reference),
+    m_toAlign(toAlign),
+    m_transform(transform),
+    m_dtwType(RiseFall),
+    m_incomplete(true),
+    m_magnitudePreprocessor(identityMagnitudePreprocessor),
+    m_riseFallPreprocessor(outputPreprocessor)
 {
 }
 
@@ -157,10 +195,10 @@
     if (!m_incomplete) {
         return;
     }
-
+/*
     SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
            << "model " << id << endl;
-
+*/
     auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
     auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
     auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
@@ -176,7 +214,7 @@
     if (referenceReady && toAlignReady) {
 
         SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
-               << "ready, calling performAlignment" << endl;
+               << "both models ready, calling performAlignment" << endl;
 
         alignmentModel->setCompletion(95);
         
@@ -187,11 +225,11 @@
         }
 
     } else {
-
+/*
         SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
                << "not ready yet: reference completion " << referenceCompletion
                << ", toAlign completion " << toAlignCompletion << endl;
-
+*/
         int completion = std::min(referenceCompletion,
                                   toAlignCompletion);
         completion = (completion * 94) / 100;
@@ -210,76 +248,126 @@
 }
 
 bool
-TransformDTWAligner::performAlignmentMagnitude()
+TransformDTWAligner::getValuesFrom(ModelId modelId,
+                                   vector<sv_frame_t> &frames,
+                                   vector<double> &values,
+                                   sv_frame_t &resolution)
 {
-    auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel>
-        (m_referenceOutputModel);
-    auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel>
-        (m_toAlignOutputModel);
-    auto alignmentModel = ModelById::getAs<AlignmentModel>
-        (m_alignmentModel);
+    EventVector events;
 
-    if (!referenceOutputSTVM || !toAlignOutputSTVM) {
-        //!!! what?
+    if (auto model = ModelById::getAs<SparseTimeValueModel>(modelId)) {
+        resolution = model->getResolution();
+        events = model->getAllEvents();
+    } else if (auto model = ModelById::getAs<NoteModel>(modelId)) {
+        resolution = model->getResolution();
+        events = model->getAllEvents();
+    } else {
+        SVCERR << "TransformDTWAligner::getValuesFrom: Type of model "
+               << modelId << " is not supported" << endl;
         return false;
     }
 
+    frames.clear();
+    values.clear();
+
+    for (auto e: events) {
+        frames.push_back(e.getFrame());
+        values.push_back(e.getValue());
+    }
+
+    return true;
+}
+
+Path
+TransformDTWAligner::makePath(const vector<size_t> &alignment,
+                              const vector<sv_frame_t> &refFrames,
+                              const vector<sv_frame_t> &otherFrames,
+                              sv_samplerate_t sampleRate,
+                              sv_frame_t resolution)
+{
+    Path path(sampleRate, resolution);
+
+    for (int i = 0; in_range_for(alignment, i); ++i) {
+
+        // DTW returns "the index into s2 for each element in s1"
+        sv_frame_t refFrame = refFrames[i];
+
+        if (!in_range_for(otherFrames, alignment[i])) {
+            SVCERR << "TransformDTWAligner::makePath: Internal error: "
+                   << "DTW maps index " << i << " in reference frame vector "
+                   << "(size " << refFrames.size() << ") onto index "
+                   << alignment[i] << " in other frame vector "
+                   << "(only size " << otherFrames.size() << ")" << endl;
+            continue;
+        }
+            
+        sv_frame_t alignedFrame = otherFrames[alignment[i]];
+        path.add(PathPoint(alignedFrame, refFrame));
+    }
+
+    return path;
+}
+
+bool
+TransformDTWAligner::performAlignmentMagnitude()
+{
+    auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
     if (!alignmentModel) {
         return false;
     }
+
+    vector<sv_frame_t> refFrames, otherFrames;
+    vector<double> refValues, otherValues;
+    sv_frame_t resolution = 0;
+
+    if (!getValuesFrom(m_referenceOutputModel,
+                       refFrames, refValues, resolution)) {
+        return false;
+    }
+
+    if (!getValuesFrom(m_toAlignOutputModel,
+                       otherFrames, otherValues, resolution)) {
+        return false;
+    }
     
     vector<double> s1, s2;
-
-    {
-        auto events = referenceOutputSTVM->getAllEvents();
-        for (auto e: events) {
-            s1.push_back(m_outputPreprocessor(e.getValue()));
-        }
-        events = toAlignOutputSTVM->getAllEvents();
-        for (auto e: events) {
-            s2.push_back(m_outputPreprocessor(e.getValue()));
-        }
+    for (double v: refValues) {
+        s1.push_back(m_magnitudePreprocessor(v));
+    }
+    for (double v: otherValues) {
+        s2.push_back(m_magnitudePreprocessor(v));
     }
 
     SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
            << "Have " << s1.size() << " events from reference, "
            << s2.size() << " from toAlign" << endl;
-
+    
     MagnitudeDTW dtw;
     vector<size_t> alignment;
 
     {
         SVCERR << "TransformDTWAligner[" << this
                << "]: serialising DTW to avoid over-allocation" << endl;
-        static QMutex mutex;
-        QMutexLocker locker(&mutex);
-
+        QMutexLocker locker(&m_dtwMutex);
         alignment = dtw.alignSeries(s1, s2);
     }
 
     SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: "
            << "DTW produced " << alignment.size() << " points:" << endl;
-    for (int i = 0; i < alignment.size() && i < 100; ++i) {
+    for (int i = 0; in_range_for(alignment, i) && i < 100; ++i) {
         SVCERR << alignment[i] << " ";
     }
     SVCERR << endl;
 
+    alignmentModel->setPath(makePath(alignment,
+                                     refFrames,
+                                     otherFrames,
+                                     alignmentModel->getSampleRate(),
+                                     resolution));
     alignmentModel->setCompletion(100);
 
-    sv_frame_t resolution = referenceOutputSTVM->getResolution();
-    sv_frame_t sourceFrame = 0;
-    
-    Path path(referenceOutputSTVM->getSampleRate(), resolution);
-    
-    for (size_t m: alignment) {
-        path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution));
-        sourceFrame += resolution;
-    }
-
-    alignmentModel->setPath(path);
-
-    SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentMagnitude: Done"
-           << endl;
+    SVCERR << "TransformDTWAligner[" << this
+           << "]: performAlignmentMagnitude: Done" << endl;
 
     m_incomplete = false;
     return true;
@@ -288,59 +376,62 @@
 bool
 TransformDTWAligner::performAlignmentRiseFall()
 {
-    auto referenceOutputSTVM = ModelById::getAs<SparseTimeValueModel>
-        (m_referenceOutputModel);
-    auto toAlignOutputSTVM = ModelById::getAs<SparseTimeValueModel>
-        (m_toAlignOutputModel);
-    auto alignmentModel = ModelById::getAs<AlignmentModel>
-        (m_alignmentModel);
-
-    if (!referenceOutputSTVM || !toAlignOutputSTVM) {
-        //!!! what?
-        return false;
-    }
-
+    auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
     if (!alignmentModel) {
         return false;
     }
 
-    auto convertEvents =
-        [this](const EventVector &ee) {
+    vector<sv_frame_t> refFrames, otherFrames;
+    vector<double> refValues, otherValues;
+    sv_frame_t resolution = 0;
+
+    if (!getValuesFrom(m_referenceOutputModel,
+                       refFrames, refValues, resolution)) {
+        return false;
+    }
+
+    if (!getValuesFrom(m_toAlignOutputModel,
+                       otherFrames, otherValues, resolution)) {
+        return false;
+    }
+    
+    auto preprocess =
+        [this](const std::vector<double> &vv) {
             vector<RiseFallDTW::Value> s;
             double prev = 0.0;
-            for (auto e: ee) {
-                double v = m_outputPreprocessor(e.getValue());
-                if (v == prev || s.empty()) {
-                    s.push_back({ RiseFallDTW::Direction::None, 0.0 });
-                } else if (v > prev) {
-                    s.push_back({ RiseFallDTW::Direction::Up, v - prev });
-                } else {
-                    s.push_back({ RiseFallDTW::Direction::Down, prev - v });
-                }
+            for (auto curr: vv) {
+                s.push_back(m_riseFallPreprocessor(prev, curr));
+                prev = curr;
             }
             return s;
-        };
+        }; 
     
-    vector<RiseFallDTW::Value> s1 =
-        convertEvents(referenceOutputSTVM->getAllEvents());
-
-    vector<RiseFallDTW::Value> s2 =
-        convertEvents(toAlignOutputSTVM->getAllEvents());
+    vector<RiseFallDTW::Value> s1 = preprocess(refValues);
+    vector<RiseFallDTW::Value> s2 = preprocess(otherValues);
 
     SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: "
            << "Have " << s1.size() << " events from reference, "
            << s2.size() << " from toAlign" << endl;
 
+    SVCERR << "Reference:" << endl;
+    for (int i = 0; in_range_for(s1, i) && i < 100; ++i) {
+        SVCERR << s1[i] << " ";
+    }
+    SVCERR << endl;
+
+    SVCERR << "toAlign:" << endl;
+    for (int i = 0; in_range_for(s2, i) && i < 100; ++i) {
+        SVCERR << s2[i] << " ";
+    }
+    SVCERR << endl;
+
     RiseFallDTW dtw;
-    
     vector<size_t> alignment;
 
     {
         SVCERR << "TransformDTWAligner[" << this
                << "]: serialising DTW to avoid over-allocation" << endl;
-        static QMutex mutex;
-        QMutexLocker locker(&mutex);
-
+        QMutexLocker locker(&m_dtwMutex);
         alignment = dtw.alignSeries(s1, s2);
     }
 
@@ -351,20 +442,14 @@
     }
     SVCERR << endl;
 
+    alignmentModel->setPath(makePath(alignment,
+                                     refFrames,
+                                     otherFrames,
+                                     alignmentModel->getSampleRate(),
+                                     resolution));
+
     alignmentModel->setCompletion(100);
 
-    sv_frame_t resolution = referenceOutputSTVM->getResolution();
-    sv_frame_t sourceFrame = 0;
-    
-    Path path(referenceOutputSTVM->getSampleRate(), resolution);
-    
-    for (size_t m: alignment) {
-        path.add(PathPoint(sourceFrame, sv_frame_t(m) * resolution));
-        sourceFrame += resolution;
-    }
-
-    alignmentModel->setPath(path);
-
     SVCERR << "TransformDTWAligner[" << this << "]: performAlignmentRiseFall: Done"
            << endl;