diff align/DTW.h @ 780:8fa98f89eda8

Add subsequence DTW (not yet in use)
author Chris Cannam
date Wed, 01 Jul 2020 15:34:46 +0100
parents 8280f7a363d1
children 4d10365aa6a9
line wrap: on
line diff
--- a/align/DTW.h	Wed Jul 01 11:41:07 2020 +0100
+++ b/align/DTW.h	Wed Jul 01 15:34:46 2020 +0100
@@ -27,54 +27,22 @@
     DTW(std::function<double(const Value &, const Value &)> distanceMetric) :
         m_metric(distanceMetric) { }
 
-    std::vector<size_t> alignSeries(std::vector<Value> s1,
-                                    std::vector<Value> s2) {
+    /** Align s2 against the whole of s1. The result has one entry
+     *  per element of s2, each an index into s1. */
+    std::vector<size_t> alignSequences(std::vector<Value> s1,
+                                       std::vector<Value> s2) {
+        // Whole-sequence mode, i.e. no subsequence matching
+        const bool subsequence = false;
+        return align(s1, s2, subsequence);
+    }
 
-        // Return the index into s2 for each element in s1
-        
-        std::vector<size_t> alignment(s1.size(), 0);
-
-        if (s1.empty() || s2.empty()) {
-            return alignment;
-        }
-
-        auto costs = costSeries(s1, s2);
-
-#ifdef DEBUG_DTW
-        SVCERR << "Cost matrix:" << endl;
-        for (auto v: costs) {
-            for (auto x: v) {
-                SVCERR << x << " ";
-            }
-            SVCERR << "\n";
-        }
-#endif
-        
-        size_t j = s1.size() - 1;
-        size_t i = s2.size() - 1;
-
-        while (j > 0 && i > 0) {
-
-            alignment[j] = i;
-            
-            cost_t a = costs[j-1][i];
-            cost_t b = costs[j][i-1];
-            cost_t both = costs[j-1][i-1];
-
-            if (a < b) {
-                --j;
-                if (both <= a) {
-                    --i;
-                }
-            } else {
-                --i;
-                if (both <= b) {
-                    --j;
-                }
-            }
-        }
-
-        return alignment;
+    /** Align sub against the best-matching subsequence of s. The
+     *  result has one entry per element of sub, each an index into s. */
+    std::vector<size_t> alignSubsequence(std::vector<Value> s,
+                                         std::vector<Value> sub) {
+        // Subsequence mode: match within s rather than against all of it
+        const bool subsequence = true;
+        return align(s, sub, subsequence);
     }
 
 private:
@@ -102,8 +70,9 @@
         }
     }
 
-    std::vector<std::vector<cost_t>> costSeries(std::vector<Value> s1,
-                                                std::vector<Value> s2) {
+    std::vector<std::vector<cost_t>> costSequences(std::vector<Value> s1,
+                                                   std::vector<Value> s2,
+                                                   bool subsequence) {
 
         std::vector<std::vector<cost_t>> costs
             (s1.size(), std::vector<cost_t>(s2.size(), 0.0));
@@ -111,22 +80,97 @@
         for (size_t j = 0; j < s1.size(); ++j) {
             for (size_t i = 0; i < s2.size(); ++i) {
                 cost_t c = m_metric(s1[j], s2[i]);
-                costs[j][i] = choose
-                    (
-                        { j > 0,
-                          j > 0 ? c + costs[j-1][i] : 0.0
-                        },
-                        { i > 0,
-                          i > 0 ? c + costs[j][i-1] : 0.0
-                        },
-                        { j > 0 && i > 0,
-                          j > 0 && i > 0 ? c + costs[j-1][i-1] : 0.0
-                        });
+                if (i == 0 && subsequence) {
+                    costs[j][i] = c;
+                } else {
+                    costs[j][i] = choose
+                        (
+                            { j > 0,
+                              j > 0 ? c + costs[j-1][i] : 0.0
+                            },
+                            { i > 0,
+                              i > 0 ? c + costs[j][i-1] : 0.0
+                            },
+                            { j > 0 && i > 0,
+                              j > 0 && i > 0 ? c + costs[j-1][i-1] : 0.0
+                            });
+                }
             }
         }
 
         return costs;
     }
+
+    /**
+     * Return the index into s1 for each element in s2 (s2.size()
+     * zeros if either is empty). If subsequence is true, the path
+     * through s1 may start and end at any element of s1.
+     */
+    std::vector<size_t> align(const std::vector<Value> &s1,
+                              const std::vector<Value> &s2,
+                              bool subsequence) {
+        std::vector<size_t> alignment(s2.size(), 0);
+        if (s1.empty() || s2.empty()) {
+            return alignment;
+        }
+
+        auto costs = costSequences(s1, s2, subsequence);
+
+#ifdef DEBUG_DTW
+        SVCERR << "Cost matrix:" << endl;
+        for (const auto &v: costs) {
+            for (auto x: v) {
+                SVCERR << x << " ";
+            }
+            SVCERR << "\n";
+        }
+#endif
+
+        size_t j = s1.size() - 1;
+        size_t i = s2.size() - 1;
+
+        if (subsequence) {
+            cost_t min = costs[0][i];
+            size_t minidx = 0;
+            for (size_t k = 1; k < s1.size(); ++k) {
+                if (costs[k][i] < min) {
+                    min = costs[k][i];
+                    minidx = k;
+                }
+            }
+            j = minidx;
+#ifdef DEBUG_DTW
+            SVCERR << "Lowest cost at end of subsequence = " << min
+                   << " at index " << j << ", tracking back from there" << endl;
+#endif
+        }
+
+        // Stop once either index is zero: costs[j-1] and costs[j][i-1]
+        // must not be read out of range, and a path that has reached
+        // j == 0 stays there, matching the initial zeros in alignment.
+        while (i > 0 && j > 0) {
+            alignment[i] = j;
+            cost_t a = costs[j-1][i];
+            cost_t b = costs[j][i-1];
+            cost_t both = costs[j-1][i-1];
+            if (a < b) {
+                --j;
+                if (both <= a) {
+                    --i;
+                }
+            } else {
+                --i;
+                if (both <= b) {
+                    --j;
+                }
+            }
+        }
+        // A subsequence match may begin at any index of s1, not just 0
+        if (subsequence) {
+            alignment[0] = j;
+        }
+        return alignment;
+    }
 };
 
 class MagnitudeDTW
@@ -134,9 +178,14 @@
 public:
     MagnitudeDTW() : m_dtw(metric) { }
 
-    std::vector<size_t> alignSeries(std::vector<double> s1,
-                                    std::vector<double> s2) {
-        return m_dtw.alignSeries(s1, s2);
+    /** Forward s1 and s2 to m_dtw.alignSequences and return its result. */
+    auto alignSequences(std::vector<double> s1, std::vector<double> s2)
+        -> std::vector<size_t> {
+        return m_dtw.alignSequences(s1, s2);
+    }
+
+    /** Forward s and sub to m_dtw.alignSubsequence and return its result. */
+    auto alignSubsequence(std::vector<double> s, std::vector<double> sub)
+        -> std::vector<size_t> {
+        return m_dtw.alignSubsequence(s, sub);
     }
 
 private:
@@ -163,9 +212,14 @@
 
     RiseFallDTW() : m_dtw(metric) { }
 
-    std::vector<size_t> alignSeries(std::vector<Value> s1,
-                                    std::vector<Value> s2) {
-        return m_dtw.alignSeries(s1, s2);
+    /** Forward s1 and s2 to m_dtw.alignSequences and return its result. */
+    auto alignSequences(std::vector<Value> s1, std::vector<Value> s2)
+        -> std::vector<size_t> {
+        return m_dtw.alignSequences(s1, s2);
+    }
+
+    /** Forward s and sub to m_dtw.alignSubsequence and return its result. */
+    auto alignSubsequence(std::vector<Value> s, std::vector<Value> sub)
+        -> std::vector<size_t> {
+        return m_dtw.alignSubsequence(s, sub);
     }
 
 private: