diff align/TransformDTWAligner.cpp @ 768:1b1960009be6 pitch-align

Provide callback for output preprocessing before DTW, use it for freq-pitch conversion; use direct setting of completion on alignment models instead of creating fake outputs for completion only
author Chris Cannam
date Fri, 22 May 2020 17:17:44 +0100
parents dd742e566e60
children a316cb6fed81
line wrap: on
line diff
--- a/align/TransformDTWAligner.cpp	Thu May 21 16:21:57 2020 +0100
+++ b/align/TransformDTWAligner.cpp	Fri May 22 17:17:44 2020 +0100
@@ -39,11 +39,27 @@
     m_document(doc),
     m_reference(reference),
     m_toAlign(toAlign),
-    m_referenceTransformComplete(false),
-    m_toAlignTransformComplete(false),
     m_transform(transform),
     m_dtwType(dtwType),
-    m_incomplete(true)
+    m_incomplete(true),
+    m_outputPreprocessor([](double x) { return x; })
+{
+}
+
+TransformDTWAligner::TransformDTWAligner(Document *doc,
+                                         ModelId reference,
+                                         ModelId toAlign,
+                                         Transform transform,
+                                         DTWType dtwType,
+                                         std::function<double(double)>
+                                         outputPreprocessor) :
+    m_document(doc),
+    m_reference(reference),
+    m_toAlign(toAlign),
+    m_transform(transform),
+    m_dtwType(dtwType),
+    m_incomplete(true),
+    m_outputPreprocessor(outputPreprocessor)
 {
 }
 
@@ -57,7 +73,6 @@
     
     ModelById::release(m_referenceOutputModel);
     ModelById::release(m_toAlignOutputModel);
-    ModelById::release(m_alignmentProgressModel);
 }
 
 bool
@@ -114,14 +129,9 @@
             this, SLOT(completionChanged(ModelId)));
     connect(toAlignOutputModel.get(), SIGNAL(completionChanged(ModelId)),
             this, SLOT(completionChanged(ModelId)));
-
-    auto alignmentProgressModel = std::make_shared<SparseTimeValueModel>
-        (reference->getSampleRate(), m_transform.getStepSize(), false);
-    alignmentProgressModel->setCompletion(0);
-    m_alignmentProgressModel = ModelById::add(alignmentProgressModel);
     
     auto alignmentModel = std::make_shared<AlignmentModel>
-        (m_reference, m_toAlign, m_alignmentProgressModel);
+        (m_reference, m_toAlign, ModelId());
     m_alignmentModel = ModelById::add(alignmentModel);
     
     toAlign->setAlignment(m_alignmentModel);
@@ -153,8 +163,9 @@
 
     auto referenceOutputModel = ModelById::get(m_referenceOutputModel);
     auto toAlignOutputModel = ModelById::get(m_toAlignOutputModel);
+    auto alignmentModel = ModelById::getAs<AlignmentModel>(m_alignmentModel);
 
-    if (!referenceOutputModel || !toAlignOutputModel) {
+    if (!referenceOutputModel || !toAlignOutputModel || !alignmentModel) {
         return;
     }
 
@@ -162,17 +173,12 @@
     bool referenceReady = referenceOutputModel->isReady(&referenceCompletion);
     bool toAlignReady = toAlignOutputModel->isReady(&toAlignCompletion);
 
-    auto alignmentProgressModel =
-        ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel);
-    
     if (referenceReady && toAlignReady) {
 
         SVCERR << "TransformDTWAligner[" << this << "]: completionChanged: "
                << "ready, calling performAlignment" << endl;
 
-        if (alignmentProgressModel) {
-            alignmentProgressModel->setCompletion(95);
-        }
+        alignmentModel->setCompletion(95);
         
         if (performAlignment()) {
             emit complete(m_alignmentModel);
@@ -186,12 +192,10 @@
                << "not ready yet: reference completion " << referenceCompletion
                << ", toAlign completion " << toAlignCompletion << endl;
 
-        if (alignmentProgressModel) {
-            int completion = std::min(referenceCompletion,
-                                      toAlignCompletion);
-            completion = (completion * 94) / 100;
-            alignmentProgressModel->setCompletion(completion);
-        }
+        int completion = std::min(referenceCompletion,
+                                  toAlignCompletion);
+        completion = (completion * 94) / 100;
+        alignmentModel->setCompletion(completion);
     }
 }
 
@@ -229,11 +233,11 @@
     {
         auto events = referenceOutputSTVM->getAllEvents();
         for (auto e: events) {
-            s1.push_back(e.getValue());
+            s1.push_back(m_outputPreprocessor(e.getValue()));
         }
         events = toAlignOutputSTVM->getAllEvents();
         for (auto e: events) {
-            s2.push_back(e.getValue());
+            s2.push_back(m_outputPreprocessor(e.getValue()));
         }
     }
 
@@ -260,14 +264,7 @@
     }
     SVCERR << endl;
 
-    auto alignmentProgressModel =
-        ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel);
-    if (alignmentProgressModel) {
-        alignmentProgressModel->setCompletion(100);
-    }
-    
-    // clear the alignment progress model
-    alignmentModel->setPathFrom(ModelId());
+    alignmentModel->setCompletion(100);
 
     sv_frame_t resolution = referenceOutputSTVM->getResolution();
     sv_frame_t sourceFrame = 0;
@@ -306,36 +303,29 @@
     if (!alignmentModel) {
         return false;
     }
+
+    auto convertEvents =
+        [this](const EventVector &ee) {
+            vector<RiseFallDTW::Value> s;
+            double prev = 0.0;
+            for (auto e: ee) {
+                double v = m_outputPreprocessor(e.getValue());
+                if (v == prev || s.empty()) {
+                    s.push_back({ RiseFallDTW::Direction::None, 0.0 });
+                } else if (v > prev) {
+                    s.push_back({ RiseFallDTW::Direction::Up, v - prev });
+                } else {
+                    s.push_back({ RiseFallDTW::Direction::Down, prev - v });
+                }
+            }
+            return s;
+        };
     
-    vector<RiseFallDTW::Value> s1, s2;
-    double prev1 = 0.0, prev2 = 0.0;
+    vector<RiseFallDTW::Value> s1 =
+        convertEvents(referenceOutputSTVM->getAllEvents());
 
-    {
-        auto events = referenceOutputSTVM->getAllEvents();
-        for (auto e: events) {
-            double v = e.getValue();
-            //!!! the original does this using MIDI pitch for the
-            //!!! pYin transform... rework with a lambda passed in
-            //!!! for modification maybe? + factor out s1/s2 of course
-            if (v > prev1) {
-                s1.push_back({ RiseFallDTW::Direction::Up, v - prev1 });
-            } else {
-                s1.push_back({ RiseFallDTW::Direction::Down, prev1 - v });
-            }
-            prev1 = v;
-        }
-        events = toAlignOutputSTVM->getAllEvents();
-        for (auto e: events) {
-            double v = e.getValue();
-            //!!! as above
-            if (v > prev2) {
-                s2.push_back({ RiseFallDTW::Direction::Up, v - prev2 });
-            } else {
-                s2.push_back({ RiseFallDTW::Direction::Down, prev2 - v });
-            }
-            prev2 = v;
-        }
-    }
+    vector<RiseFallDTW::Value> s2 =
+        convertEvents(toAlignOutputSTVM->getAllEvents());
 
     SVCERR << "TransformDTWAligner[" << this << "]: performAlignment: "
            << "Have " << s1.size() << " events from reference, "
@@ -361,14 +351,7 @@
     }
     SVCERR << endl;
 
-    auto alignmentProgressModel =
-        ModelById::getAs<SparseTimeValueModel>(m_alignmentProgressModel);
-    if (alignmentProgressModel) {
-        alignmentProgressModel->setCompletion(100);
-    }
-    
-    // clear the alignment progress model
-    alignmentModel->setPathFrom(ModelId());
+    alignmentModel->setCompletion(100);
 
     sv_frame_t resolution = referenceOutputSTVM->getResolution();
     sv_frame_t sourceFrame = 0;