Mercurial > hg > svapp
changeset 757:f32df46d0c84 pitch-align
Flesh out DTW
author | Chris Cannam |
---|---|
date | Mon, 27 Apr 2020 14:59:56 +0100 |
parents | 39808338e771 |
children | 6429a164b7e1 |
files | align/DTW.h |
diffstat | 1 files changed, 125 insertions(+), 23 deletions(-) [+] |
line wrap: on
line diff
--- a/align/DTW.h Mon Apr 27 14:59:50 2020 +0100 +++ b/align/DTW.h Mon Apr 27 14:59:56 2020 +0100 @@ -16,10 +16,58 @@ #define SV_DTW_H #include <vector> +#include <functional> +template <typename Value> class DTW { public: + DTW(std::function<double(const Value &, const Value &)> distanceMetric) : + m_metric(distanceMetric) { } + + std::vector<size_t> alignSeries(std::vector<Value> s1, + std::vector<Value> s2) { + + // Return the index into s2 for each element in s1 + + std::vector<size_t> alignment(s1.size(), 0); + + if (s1.empty() || s2.empty()) { + return alignment; + } + + auto costs = costSeries(s1, s2); + + size_t j = s1.size() - 1; + size_t i = s2.size() - 1; + + while (j > 0 && i > 0) { + + alignment[j] = i; + + cost_t a = costs[j-1][i]; + cost_t b = costs[j][i-1]; + cost_t both = costs[j-1][i-1]; + + if (a < b) { + --j; + if (both < a) { + --i; + } + } else { + --i; + if (both < b) { + --j; + } + } + } + + return alignment; + } + +private: + std::function<double(const Value &, const Value &)> m_metric; + typedef double cost_t; struct CostOption { @@ -27,18 +75,7 @@ cost_t cost; }; - enum class Direction { - None, - Up, - Down - }; - - struct Value { - Direction direction; - cost_t cost; - }; - - static cost_t choose(CostOption x, CostOption y, CostOption d) { + cost_t choose(CostOption x, CostOption y, CostOption d) { if (x.present && y.present) { if (!d.present) { throw std::logic_error("if x & y both exist, so must diagonal"); @@ -53,18 +90,89 @@ } } - static cost_t calculateCost(Value a, Value b) { - auto together = [](cost_t c1, cost_t c2) { + std::vector<std::vector<cost_t>> costSeries(std::vector<Value> s1, + std::vector<Value> s2) { + + std::vector<std::vector<cost_t>> costs + (s1.size(), std::vector<cost_t>(s2.size(), 0.0)); + + for (size_t j = 0; j < s1.size(); ++j) { + for (size_t i = 0; i < s2.size(); ++i) { + cost_t c = m_metric(s1[j], s2[i]); + costs[j][i] = choose + ( + { j > 0, + j > 0 ? c + costs[j-1][i] : 0.0 + }, + { i > 0, + i > 0 ? c + costs[j][i-1] : 0.0 + }, + { j > 0 && i > 0, + j > 0 && i > 0 ? c + costs[j-1][i-1] : 0.0 + }); + } + } + + return costs; + } +}; + +class MagnitudeDTW +{ +public: + MagnitudeDTW() : m_dtw(metric) { } + + std::vector<size_t> alignSeries(std::vector<double> s1, + std::vector<double> s2) { + return m_dtw.alignSeries(s1, s2); + } + +private: + DTW<double> m_dtw; + + static double metric(const double &a, const double &b) { + return std::abs(b - a); + } +}; + +class RiseFallDTW +{ +public: + enum class Direction { + None, + Up, + Down + }; + + struct Value { + Direction direction; + double distance; + }; + + RiseFallDTW() : m_dtw(metric) { } + + std::vector<size_t> alignSeries(std::vector<Value> s1, + std::vector<Value> s2) { + return m_dtw.alignSeries(s1, s2); + } + +private: + DTW<Value> m_dtw; + + static double metric(const Value &a, const Value &b) { + + auto together = [](double c1, double c2) { auto diff = std::abs(c1 - c2); return (diff < 1.0 ? -1.0 : diff > 3.0 ? 1.0 : 0.0); }; - auto opposing = [](cost_t c1, cost_t c2) { + auto opposing = [](double c1, double c2) { auto diff = c1 + c2; return (diff < 2.0 ? 1.0 : 2.0); }; + if (a.direction == Direction::None || b.direction == Direction::None) { if (a.direction == b.direction) { return 0.0; @@ -73,18 +181,12 @@ } } else { if (a.direction == b.direction) { - return together (a.cost, b.cost); + return together (a.distance, b.distance); } else { - return opposing (a.cost, b.cost); + return opposing (a.distance, b.distance); } } } - - static std::vector<std::vector<cost_t>> costSeries(std::vector<Value> s1, - std::vector<Value> s2) { - - } - }; #endif