# HG changeset patch # User Chris Cannam # Date 1587995996 -3600 # Node ID f32df46d0c8443dffcd093fcb6eb427b69e06165 # Parent 39808338e7714a514f2f3e9d1c8cf11f59ed1736 Flesh out DTW diff -r 39808338e771 -r f32df46d0c84 align/DTW.h --- a/align/DTW.h Mon Apr 27 14:59:50 2020 +0100 +++ b/align/DTW.h Mon Apr 27 14:59:56 2020 +0100 @@ -16,10 +16,58 @@ #define SV_DTW_H #include +#include +template class DTW { public: + DTW(std::function distanceMetric) : + m_metric(distanceMetric) { } + + std::vector alignSeries(std::vector s1, + std::vector s2) { + + // Return the index into s2 for each element in s1 + + std::vector alignment(s1.size(), 0); + + if (s1.empty() || s2.empty()) { + return alignment; + } + + auto costs = costSeries(s1, s2); + + size_t j = s1.size() - 1; + size_t i = s2.size() - 1; + + while (j > 0 && i > 0) { + + alignment[j] = i; + + cost_t a = costs[j-1][i]; + cost_t b = costs[j][i-1]; + cost_t both = costs[j-1][i-1]; + + if (a < b) { + --j; + if (both < a) { + --i; + } + } else { + --i; + if (both < b) { + --j; + } + } + } + + return alignment; + } + +private: + std::function m_metric; + typedef double cost_t; struct CostOption { @@ -27,18 +75,7 @@ cost_t cost; }; - enum class Direction { - None, - Up, - Down - }; - - struct Value { - Direction direction; - cost_t cost; - }; - - static cost_t choose(CostOption x, CostOption y, CostOption d) { + cost_t choose(CostOption x, CostOption y, CostOption d) { if (x.present && y.present) { if (!d.present) { throw std::logic_error("if x & y both exist, so must diagonal"); @@ -53,18 +90,89 @@ } } - static cost_t calculateCost(Value a, Value b) { - auto together = [](cost_t c1, cost_t c2) { + std::vector> costSeries(std::vector s1, + std::vector s2) { + + std::vector> costs + (s1.size(), std::vector(s2.size(), 0.0)); + + for (size_t j = 0; j < s1.size(); ++j) { + for (size_t i = 0; i < s2.size(); ++i) { + cost_t c = m_metric(s1[j], s2[i]); + costs[j][i] = choose + ( + { j > 0, + j > 0 ? c + costs[j-1][i] : 0.0 + }, + { i > 0, + i > 0 ? c + costs[j][i-1] : 0.0 + }, + { j > 0 && i > 0, + j > 0 && i > 0 ? c + costs[j-1][i-1] : 0.0 + }); + } + } + + return costs; + } +}; + +class MagnitudeDTW +{ +public: + MagnitudeDTW() : m_dtw(metric) { } + + std::vector alignSeries(std::vector s1, + std::vector s2) { + return m_dtw.alignSeries(s1, s2); + } + +private: + DTW m_dtw; + + static double metric(const double &a, const double &b) { + return std::abs(b - a); + } +}; + +class RiseFallDTW +{ +public: + enum class Direction { + None, + Up, + Down + }; + + struct Value { + Direction direction; + double distance; + }; + + RiseFallDTW() : m_dtw(metric) { } + + std::vector alignSeries(std::vector s1, + std::vector s2) { + return m_dtw.alignSeries(s1, s2); + } + +private: + DTW m_dtw; + + static double metric(const Value &a, const Value &b) { + + auto together = [](double c1, double c2) { auto diff = std::abs(c1 - c2); return (diff < 1.0 ? -1.0 : diff > 3.0 ? 1.0 : 0.0); }; - auto opposing = [](cost_t c1, cost_t c2) { + auto opposing = [](double c1, double c2) { auto diff = c1 + c2; return (diff < 2.0 ? 1.0 : 2.0); }; + if (a.direction == Direction::None || b.direction == Direction::None) { if (a.direction == b.direction) { return 0.0; @@ -73,18 +181,12 @@ } } else { if (a.direction == b.direction) { - return together (a.cost, b.cost); + return together (a.distance, b.distance); } else { - return opposing (a.cost, b.cost); + return opposing (a.distance, b.distance); } } } - - static std::vector> costSeries(std::vector s1, - std::vector s2) { - - } - }; #endif