comparison align/DTW.h @ 757:f32df46d0c84 pitch-align

Flesh out DTW
author Chris Cannam
date Mon, 27 Apr 2020 14:59:56 +0100
parents b27815194752
children 1d6cca5a5621
comparison
equal deleted inserted replaced
756:39808338e771 757:f32df46d0c84
14 14
15 #ifndef SV_DTW_H 15 #ifndef SV_DTW_H
16 #define SV_DTW_H 16 #define SV_DTW_H
17 17
#include <algorithm>
#include <cmath>
#include <functional>
#include <stdexcept>
#include <vector>
19 20
/**
 * Dynamic time warping between two series of values of the same type,
 * using a caller-supplied distance metric.
 */
template <typename Value>
class DTW
{
public:
    /**
     * Construct with the metric used to compare one element of the first
     * series against one element of the second. The accumulated cost of a
     * warping path is the sum of this metric over every cell on the path,
     * including the first and last.
     */
    DTW(std::function<double(const Value &, const Value &)> distanceMetric) :
        m_metric(distanceMetric) { }

    /**
     * Align s1 against s2 and return, for each element of s1, an index
     * into s2.
     *
     * Where the cheapest warping path matches one element of s1 to several
     * consecutive elements of s2, the earliest of those s2 indices is
     * returned. If either series is empty, the result is s1.size() zeros.
     */
    std::vector<size_t> alignSeries(const std::vector<Value> &s1,
                                    const std::vector<Value> &s2) {

        std::vector<size_t> alignment(s1.size(), 0);

        if (s1.empty() || s2.empty()) {
            return alignment;
        }

        auto costs = costSeries(s1, s2);

        size_t j = s1.size() - 1;
        size_t i = s2.size() - 1;

        // Trace the cheapest path back from the final cell all the way to
        // the origin. Within a row, i only ever decreases, so the last value
        // written to alignment[j] is the earliest s2 index on the path for
        // that row. Walking along the edges once j or i reaches 0 (rather
        // than stopping there) ensures the final cells are recorded too.
        while (true) {

            alignment[j] = i;

            if (j == 0 && i == 0) {
                break;
            }
            if (j == 0) {
                // First row: the only predecessor is to the left
                --i;
                continue;
            }
            if (i == 0) {
                // First column: the only predecessor is above
                --j;
                continue;
            }

            cost_t a = costs[j-1][i];
            cost_t b = costs[j][i-1];
            cost_t both = costs[j-1][i-1];

            // Step towards the cheapest predecessor. Ties between a single
            // step and the diagonal favour the single step; ties between
            // a and b favour decrementing i.
            if (a < b) {
                --j;
                if (both < a) {
                    --i;
                }
            } else {
                --i;
                if (both < b) {
                    --j;
                }
            }
        }

        return alignment;
    }

private:
    std::function<double(const Value &, const Value &)> m_metric;

    typedef double cost_t;

    /**
     * Return the accumulated-cost matrix, indexed [s1 index][s2 index].
     * costs[j][i] is the cost of the cheapest monotonic path from (0,0)
     * to (j,i), counting the metric at every cell on the path, including
     * the origin.
     */
    std::vector<std::vector<cost_t>> costSeries(const std::vector<Value> &s1,
                                                const std::vector<Value> &s2) {

        std::vector<std::vector<cost_t>> costs
            (s1.size(), std::vector<cost_t>(s2.size(), 0.0));

        for (size_t j = 0; j < s1.size(); ++j) {
            for (size_t i = 0; i < s2.size(); ++i) {

                cost_t c = m_metric(s1[j], s2[i]);

                if (j > 0 && i > 0) {
                    costs[j][i] = c + std::min({ costs[j-1][i],
                                                 costs[j][i-1],
                                                 costs[j-1][i-1] });
                } else if (j > 0) {
                    // First column: only reachable from the cell above
                    costs[j][i] = c + costs[j-1][i];
                } else if (i > 0) {
                    // First row: only reachable from the cell to the left
                    costs[j][i] = c + costs[j][i-1];
                } else {
                    // Origin: every path starts here, so its cost counts
                    costs[j][i] = c;
                }
            }
        }

        return costs;
    }
};
119
120 class MagnitudeDTW
121 {
122 public:
123 MagnitudeDTW() : m_dtw(metric) { }
124
125 std::vector<size_t> alignSeries(std::vector<double> s1,
126 std::vector<double> s2) {
127 return m_dtw.alignSeries(s1, s2);
128 }
129
130 private:
131 DTW<double> m_dtw;
132
133 static double metric(const double &a, const double &b) {
134 return std::abs(b - a);
135 }
136 };
137
138 class RiseFallDTW
139 {
140 public:
141 enum class Direction {
142 None,
143 Up,
144 Down
145 };
146
147 struct Value {
148 Direction direction;
149 double distance;
150 };
151
152 RiseFallDTW() : m_dtw(metric) { }
153
154 std::vector<size_t> alignSeries(std::vector<Value> s1,
155 std::vector<Value> s2) {
156 return m_dtw.alignSeries(s1, s2);
157 }
158
159 private:
160 DTW<Value> m_dtw;
161
162 static double metric(const Value &a, const Value &b) {
163
164 auto together = [](double c1, double c2) {
58 auto diff = std::abs(c1 - c2); 165 auto diff = std::abs(c1 - c2);
59 return (diff < 1.0 ? -1.0 : 166 return (diff < 1.0 ? -1.0 :
60 diff > 3.0 ? 1.0 : 167 diff > 3.0 ? 1.0 :
61 0.0); 168 0.0);
62 }; 169 };
63 auto opposing = [](cost_t c1, cost_t c2) { 170 auto opposing = [](double c1, double c2) {
64 auto diff = c1 + c2; 171 auto diff = c1 + c2;
65 return (diff < 2.0 ? 1.0 : 172 return (diff < 2.0 ? 1.0 :
66 2.0); 173 2.0);
67 }; 174 };
175
68 if (a.direction == Direction::None || b.direction == Direction::None) { 176 if (a.direction == Direction::None || b.direction == Direction::None) {
69 if (a.direction == b.direction) { 177 if (a.direction == b.direction) {
70 return 0.0; 178 return 0.0;
71 } else { 179 } else {
72 return 1.0; 180 return 1.0;
73 } 181 }
74 } else { 182 } else {
75 if (a.direction == b.direction) { 183 if (a.direction == b.direction) {
76 return together (a.cost, b.cost); 184 return together (a.distance, b.distance);
77 } else { 185 } else {
78 return opposing (a.cost, b.cost); 186 return opposing (a.distance, b.distance);
79 } 187 }
80 } 188 }
81 } 189 }
82
83 static std::vector<std::vector<cost_t>> costSeries(std::vector<Value> s1,
84 std::vector<Value> s2) {
85
86 }
87
88 }; 190 };
89 191
90 #endif 192 #endif