Mercurial > hg > svapp
comparison align/DTW.h @ 757:f32df46d0c84 pitch-align
Flesh out DTW
author | Chris Cannam |
---|---|
date | Mon, 27 Apr 2020 14:59:56 +0100 |
parents | b27815194752 |
children | 1d6cca5a5621 |
comparison
equal
deleted
inserted
replaced
756:39808338e771 | 757:f32df46d0c84 |
---|---|
14 | 14 |
15 #ifndef SV_DTW_H | 15 #ifndef SV_DTW_H |
16 #define SV_DTW_H | 16 #define SV_DTW_H |
17 | 17 |
18 #include <vector> | 18 #include <vector> |
19 #include <functional> | |
19 | 20 |
21 template <typename Value> | |
20 class DTW | 22 class DTW |
21 { | 23 { |
22 public: | 24 public: |
25 DTW(std::function<double(const Value &, const Value &)> distanceMetric) : | |
26 m_metric(distanceMetric) { } | |
27 | |
28 std::vector<size_t> alignSeries(std::vector<Value> s1, | |
29 std::vector<Value> s2) { | |
30 | |
31 // Return the index into s2 for each element in s1 | |
32 | |
33 std::vector<size_t> alignment(s1.size(), 0); | |
34 | |
35 if (s1.empty() || s2.empty()) { | |
36 return alignment; | |
37 } | |
38 | |
39 auto costs = costSeries(s1, s2); | |
40 | |
41 size_t j = s1.size() - 1; | |
42 size_t i = s2.size() - 1; | |
43 | |
44 while (j > 0 && i > 0) { | |
45 | |
46 alignment[j] = i; | |
47 | |
48 cost_t a = costs[j-1][i]; | |
49 cost_t b = costs[j][i-1]; | |
50 cost_t both = costs[j-1][i-1]; | |
51 | |
52 if (a < b) { | |
53 --j; | |
54 if (both < a) { | |
55 --i; | |
56 } | |
57 } else { | |
58 --i; | |
59 if (both < b) { | |
60 --j; | |
61 } | |
62 } | |
63 } | |
64 | |
65 return alignment; | |
66 } | |
67 | |
68 private: | |
69 std::function<double(const Value &, const Value &)> m_metric; | |
70 | |
23 typedef double cost_t; | 71 typedef double cost_t; |
24 | 72 |
25 struct CostOption { | 73 struct CostOption { |
26 bool present; | 74 bool present; |
27 cost_t cost; | 75 cost_t cost; |
28 }; | 76 }; |
29 | 77 |
30 enum class Direction { | 78 cost_t choose(CostOption x, CostOption y, CostOption d) { |
31 None, | |
32 Up, | |
33 Down | |
34 }; | |
35 | |
36 struct Value { | |
37 Direction direction; | |
38 cost_t cost; | |
39 }; | |
40 | |
41 static cost_t choose(CostOption x, CostOption y, CostOption d) { | |
42 if (x.present && y.present) { | 79 if (x.present && y.present) { |
43 if (!d.present) { | 80 if (!d.present) { |
44 throw std::logic_error("if x & y both exist, so must diagonal"); | 81 throw std::logic_error("if x & y both exist, so must diagonal"); |
45 } | 82 } |
46 return std::min(std::min(x.cost, y.cost), d.cost); | 83 return std::min(std::min(x.cost, y.cost), d.cost); |
51 } else { | 88 } else { |
52 return 0.0; | 89 return 0.0; |
53 } | 90 } |
54 } | 91 } |
55 | 92 |
56 static cost_t calculateCost(Value a, Value b) { | 93 std::vector<std::vector<cost_t>> costSeries(std::vector<Value> s1, |
57 auto together = [](cost_t c1, cost_t c2) { | 94 std::vector<Value> s2) { |
95 | |
96 std::vector<std::vector<cost_t>> costs | |
97 (s1.size(), std::vector<cost_t>(s2.size(), 0.0)); | |
98 | |
99 for (size_t j = 0; j < s1.size(); ++j) { | |
100 for (size_t i = 0; i < s2.size(); ++i) { | |
101 cost_t c = m_metric(s1[j], s2[i]); | |
102 costs[j][i] = choose | |
103 ( | |
104 { j > 0, | |
105 j > 0 ? c + costs[j-1][i] : 0.0 | |
106 }, | |
107 { i > 0, | |
108 i > 0 ? c + costs[j][i-1] : 0.0 | |
109 }, | |
110 { j > 0 && i > 0, | |
111 j > 0 && i > 0 ? c + costs[j-1][i-1] : 0.0 | |
112 }); | |
113 } | |
114 } | |
115 | |
116 return costs; | |
117 } | |
118 }; | |
119 | |
120 class MagnitudeDTW | |
121 { | |
122 public: | |
123 MagnitudeDTW() : m_dtw(metric) { } | |
124 | |
125 std::vector<size_t> alignSeries(std::vector<double> s1, | |
126 std::vector<double> s2) { | |
127 return m_dtw.alignSeries(s1, s2); | |
128 } | |
129 | |
130 private: | |
131 DTW<double> m_dtw; | |
132 | |
133 static double metric(const double &a, const double &b) { | |
134 return std::abs(b - a); | |
135 } | |
136 }; | |
137 | |
138 class RiseFallDTW | |
139 { | |
140 public: | |
141 enum class Direction { | |
142 None, | |
143 Up, | |
144 Down | |
145 }; | |
146 | |
147 struct Value { | |
148 Direction direction; | |
149 double distance; | |
150 }; | |
151 | |
152 RiseFallDTW() : m_dtw(metric) { } | |
153 | |
154 std::vector<size_t> alignSeries(std::vector<Value> s1, | |
155 std::vector<Value> s2) { | |
156 return m_dtw.alignSeries(s1, s2); | |
157 } | |
158 | |
159 private: | |
160 DTW<Value> m_dtw; | |
161 | |
162 static double metric(const Value &a, const Value &b) { | |
163 | |
164 auto together = [](double c1, double c2) { | |
58 auto diff = std::abs(c1 - c2); | 165 auto diff = std::abs(c1 - c2); |
59 return (diff < 1.0 ? -1.0 : | 166 return (diff < 1.0 ? -1.0 : |
60 diff > 3.0 ? 1.0 : | 167 diff > 3.0 ? 1.0 : |
61 0.0); | 168 0.0); |
62 }; | 169 }; |
63 auto opposing = [](cost_t c1, cost_t c2) { | 170 auto opposing = [](double c1, double c2) { |
64 auto diff = c1 + c2; | 171 auto diff = c1 + c2; |
65 return (diff < 2.0 ? 1.0 : | 172 return (diff < 2.0 ? 1.0 : |
66 2.0); | 173 2.0); |
67 }; | 174 }; |
175 | |
68 if (a.direction == Direction::None || b.direction == Direction::None) { | 176 if (a.direction == Direction::None || b.direction == Direction::None) { |
69 if (a.direction == b.direction) { | 177 if (a.direction == b.direction) { |
70 return 0.0; | 178 return 0.0; |
71 } else { | 179 } else { |
72 return 1.0; | 180 return 1.0; |
73 } | 181 } |
74 } else { | 182 } else { |
75 if (a.direction == b.direction) { | 183 if (a.direction == b.direction) { |
76 return together (a.cost, b.cost); | 184 return together (a.distance, b.distance); |
77 } else { | 185 } else { |
78 return opposing (a.cost, b.cost); | 186 return opposing (a.distance, b.distance); |
79 } | 187 } |
80 } | 188 } |
81 } | 189 } |
82 | |
83 static std::vector<std::vector<cost_t>> costSeries(std::vector<Value> s1, | |
84 std::vector<Value> s2) { | |
85 | |
86 } | |
87 | |
88 }; | 190 }; |
89 | 191 |
90 #endif | 192 #endif |