changeset 757:f32df46d0c84 pitch-align

Flesh out DTW
author Chris Cannam
date Mon, 27 Apr 2020 14:59:56 +0100
parents 39808338e771
children 6429a164b7e1
files align/DTW.h
diffstat 1 files changed, 125 insertions(+), 23 deletions(-) [+]
line wrap: on
line diff
--- a/align/DTW.h	Mon Apr 27 14:59:50 2020 +0100
+++ b/align/DTW.h	Mon Apr 27 14:59:56 2020 +0100
@@ -16,10 +16,58 @@
 #define SV_DTW_H
 
 #include <vector>
+#include <functional>
 
+template <typename Value>
 class DTW
 {
 public:
+    DTW(std::function<double(const Value &, const Value &)> distanceMetric) :
+        m_metric(distanceMetric) { }
+
+    std::vector<size_t> alignSeries(std::vector<Value> s1,
+                                    std::vector<Value> s2) {
+
+        // Return the index into s2 for each element in s1
+        
+        std::vector<size_t> alignment(s1.size(), 0);
+
+        if (s1.empty() || s2.empty()) {
+            return alignment;
+        }
+
+        auto costs = costSeries(s1, s2);
+
+        size_t j = s1.size() - 1;
+        size_t i = s2.size() - 1;
+
+        while (j > 0 && i > 0) {
+
+            alignment[j] = i;
+            
+            cost_t a = costs[j-1][i];
+            cost_t b = costs[j][i-1];
+            cost_t both = costs[j-1][i-1];
+
+            if (a < b) {
+                --j;
+                if (both < a) {
+                    --i;
+                }
+            } else {
+                --i;
+                if (both < b) {
+                    --j;
+                }
+            }
+        }
+
+        return alignment;
+    }
+
+private:
+    std::function<double(const Value &, const Value &)> m_metric;
+    
     typedef double cost_t;
 
     struct CostOption {
@@ -27,18 +75,7 @@
         cost_t cost;
     };
 
-    enum class Direction {
-        None,
-        Up,
-        Down
-    };
-    
-    struct Value {
-        Direction direction;
-        cost_t cost;
-    };
-
-    static cost_t choose(CostOption x, CostOption y, CostOption d) {
+    cost_t choose(CostOption x, CostOption y, CostOption d) {
         if (x.present && y.present) {
             if (!d.present) {
                 throw std::logic_error("if x & y both exist, so must diagonal");
@@ -53,18 +90,89 @@
         }
     }
 
-    static cost_t calculateCost(Value a, Value b) {
-        auto together = [](cost_t c1, cost_t c2) {
+    std::vector<std::vector<cost_t>> costSeries(std::vector<Value> s1,
+                                                std::vector<Value> s2) {
+
+        std::vector<std::vector<cost_t>> costs
+            (s1.size(), std::vector<cost_t>(s2.size(), 0.0));
+
+        for (size_t j = 0; j < s1.size(); ++j) {
+            for (size_t i = 0; i < s2.size(); ++i) {
+                cost_t c = m_metric(s1[j], s2[i]);
+                costs[j][i] = choose
+                    (
+                        { j > 0,
+                          j > 0 ? c + costs[j-1][i] : 0.0
+                        },
+                        { i > 0,
+                          i > 0 ? c + costs[j][i-1] : 0.0
+                        },
+                        { j > 0 && i > 0,
+                          j > 0 && i > 0 ? c + costs[j-1][i-1] : 0.0
+                        });
+            }
+        }
+
+        return costs;
+    }
+};
+
+class MagnitudeDTW
+{
+public:
+    MagnitudeDTW() : m_dtw(metric) { }
+
+    std::vector<size_t> alignSeries(std::vector<double> s1,
+                                    std::vector<double> s2) {
+        return m_dtw.alignSeries(s1, s2);
+    }
+
+private:
+    DTW<double> m_dtw;
+
+    static double metric(const double &a, const double &b) {
+        return std::abs(b - a);
+    }
+};
+
+class RiseFallDTW
+{
+public:
+    enum class Direction {
+        None,
+        Up,
+        Down
+    };
+    
+    struct Value {
+        Direction direction;
+        double distance;
+    };
+
+    RiseFallDTW() : m_dtw(metric) { }
+
+    std::vector<size_t> alignSeries(std::vector<Value> s1,
+                                    std::vector<Value> s2) {
+        return m_dtw.alignSeries(s1, s2);
+    }
+
+private:
+    DTW<Value> m_dtw;
+
+    static double metric(const Value &a, const Value &b) {
+        
+        auto together = [](double c1, double c2) {
                             auto diff = std::abs(c1 - c2);
                             return (diff < 1.0 ? -1.0 :
                                     diff > 3.0 ?  1.0 :
                                     0.0);
                         };
-        auto opposing = [](cost_t c1, cost_t c2) {
+        auto opposing = [](double c1, double c2) {
                             auto diff = c1 + c2;
                             return (diff < 2.0 ? 1.0 :
                                     2.0);
                         };
+
         if (a.direction == Direction::None || b.direction == Direction::None) {
             if (a.direction == b.direction) {
                 return 0.0;
@@ -73,18 +181,12 @@
             }
         } else {
             if (a.direction == b.direction) {
-                return together (a.cost, b.cost);
+                return together (a.distance, b.distance);
             } else {
-                return opposing (a.cost, b.cost);
+                return opposing (a.distance, b.distance);
             }
         }
     }
-
-    static std::vector<std::vector<cost_t>> costSeries(std::vector<Value> s1,
-                                                       std::vector<Value> s2) {
-        
-    }
-    
 };
 
 #endif