Mercurial > hg > svapp
comparison align/DTW.h @ 778:83a7b10b7415
Merge from branch pitch-align
author | Chris Cannam |
---|---|
date | Fri, 26 Jun 2020 13:48:52 +0100 |
parents | 8280f7a363d1 |
children | 8fa98f89eda8 |
comparison
equal
deleted
inserted
replaced
774:7bded7599874 | 778:83a7b10b7415 |
---|---|
16 #define SV_DTW_H | 16 #define SV_DTW_H |
17 | 17 |
18 #include <vector> | 18 #include <vector> |
19 #include <functional> | 19 #include <functional> |
20 | 20 |
| | 21 //#define DEBUG_DTW 1 |
| | 22 |
21 template <typename Value> | 23 template <typename Value> |
22 class DTW | 24 class DTW |
23 { | 25 { |
24 public: | 26 public: |
25 DTW(std::function<double(const Value &, const Value &)> distanceMetric) : | 27 DTW(std::function<double(const Value &, const Value &)> distanceMetric) : |
36 return alignment; | 38 return alignment; |
37 } | 39 } |
38 | 40 |
39 auto costs = costSeries(s1, s2); | 41 auto costs = costSeries(s1, s2); |
40 | 42 |
| | 43 #ifdef DEBUG_DTW |
| | 44 SVCERR << "Cost matrix:" << endl; |
| | 45 for (auto v: costs) { |
| | 46 for (auto x: v) { |
| | 47 SVCERR << x << " "; |
| | 48 } |
| | 49 SVCERR << "\n"; |
| | 50 } |
| | 51 #endif |
| | 52 |
41 size_t j = s1.size() - 1; | 53 size_t j = s1.size() - 1; |
42 size_t i = s2.size() - 1; | 54 size_t i = s2.size() - 1; |
43 | 55 |
44 while (j > 0 && i > 0) { | 56 while (j > 0 && i > 0) { |
45 | 57 |
49 cost_t b = costs[j][i-1]; | 61 cost_t b = costs[j][i-1]; |
50 cost_t both = costs[j-1][i-1]; | 62 cost_t both = costs[j-1][i-1]; |
51 | 63 |
52 if (a < b) { | 64 if (a < b) { |
53 --j; | 65 --j; |
54 if (both < a) { | 66 if (both <= a) { |
55 --i; | 67 --i; |
56 } | 68 } |
57 } else { | 69 } else { |
58 --i; | 70 --i; |
59 if (both < b) { | 71 if (both <= b) { |
60 --j; | 72 --j; |
61 } | 73 } |
62 } | 74 } |
63 } | 75 } |
64 | 76 |
187 } | 199 } |
188 } | 200 } |
189 } | 201 } |
190 }; | 202 }; |
191 | 203 |
| | 204 inline std::ostream &operator<<(std::ostream &s, const RiseFallDTW::Value v) { |
| | 205 return (s << |
| | 206 (v.direction == RiseFallDTW::Direction::None ? "=" : |
| | 207 v.direction == RiseFallDTW::Direction::Up ? "+" : "-") |
| | 208 << v.distance); |
| | 209 } |
| | 210 |
192 #endif | 211 #endif |