Mercurial > hg > svapp
comparison align/DTW.h @ 780:8fa98f89eda8
Add subsequence DTW (not yet in use)
author | Chris Cannam |
---|---|
date | Wed, 01 Jul 2020 15:34:46 +0100 |
parents | 8280f7a363d1 |
children | 4d10365aa6a9 |
Comparison legend: equal | deleted | inserted | replaced
779:5de2b710cfae | 780:8fa98f89eda8 |
---|---|
25 { | 25 { |
26 public: | 26 public: |
27 DTW(std::function<double(const Value &, const Value &)> distanceMetric) : | 27 DTW(std::function<double(const Value &, const Value &)> distanceMetric) : |
28 m_metric(distanceMetric) { } | 28 m_metric(distanceMetric) { } |
29 | 29 |
30 std::vector<size_t> alignSeries(std::vector<Value> s1, | 30 /** |
31 std::vector<Value> s2) { | 31 * Align the sequence s2 against the whole of the sequence s1, |
32 | 32 * returning the index into s1 for each element in s2. |
33 // Return the index into s2 for each element in s1 | 33 */ |
34 | 34 std::vector<size_t> alignSequences(std::vector<Value> s1, |
35 std::vector<size_t> alignment(s1.size(), 0); | 35 std::vector<Value> s2) { |
36 | 36 return align(s1, s2, false); |
37 if (s1.empty() || s2.empty()) { | 37 } |
38 return alignment; | 38 |
39 } | 39 /** |
40 | 40 * Align the sequence sub against the best-matching subsequence of |
41 auto costs = costSeries(s1, s2); | 41 * s, returning the index into s for each element in sub. |
42 | 42 */ |
43 #ifdef DEBUG_DTW | 43 std::vector<size_t> alignSubsequence(std::vector<Value> s, |
44 SVCERR << "Cost matrix:" << endl; | 44 std::vector<Value> sub) { |
45 for (auto v: costs) { | 45 return align(s, sub, true); |
46 for (auto x: v) { | |
47 SVCERR << x << " "; | |
48 } | |
49 SVCERR << "\n"; | |
50 } | |
51 #endif | |
52 | |
53 size_t j = s1.size() - 1; | |
54 size_t i = s2.size() - 1; | |
55 | |
56 while (j > 0 && i > 0) { | |
57 | |
58 alignment[j] = i; | |
59 | |
60 cost_t a = costs[j-1][i]; | |
61 cost_t b = costs[j][i-1]; | |
62 cost_t both = costs[j-1][i-1]; | |
63 | |
64 if (a < b) { | |
65 --j; | |
66 if (both <= a) { | |
67 --i; | |
68 } | |
69 } else { | |
70 --i; | |
71 if (both <= b) { | |
72 --j; | |
73 } | |
74 } | |
75 } | |
76 | |
77 return alignment; | |
78 } | 46 } |
79 | 47 |
80 private: | 48 private: |
81 std::function<double(const Value &, const Value &)> m_metric; | 49 std::function<double(const Value &, const Value &)> m_metric; |
82 | 50 |
100 } else { | 68 } else { |
101 return 0.0; | 69 return 0.0; |
102 } | 70 } |
103 } | 71 } |
104 | 72 |
105 std::vector<std::vector<cost_t>> costSeries(std::vector<Value> s1, | 73 std::vector<std::vector<cost_t>> costSequences(std::vector<Value> s1, |
106 std::vector<Value> s2) { | 74 std::vector<Value> s2, |
75 bool subsequence) { | |
107 | 76 |
108 std::vector<std::vector<cost_t>> costs | 77 std::vector<std::vector<cost_t>> costs |
109 (s1.size(), std::vector<cost_t>(s2.size(), 0.0)); | 78 (s1.size(), std::vector<cost_t>(s2.size(), 0.0)); |
110 | 79 |
111 for (size_t j = 0; j < s1.size(); ++j) { | 80 for (size_t j = 0; j < s1.size(); ++j) { |
112 for (size_t i = 0; i < s2.size(); ++i) { | 81 for (size_t i = 0; i < s2.size(); ++i) { |
113 cost_t c = m_metric(s1[j], s2[i]); | 82 cost_t c = m_metric(s1[j], s2[i]); |
114 costs[j][i] = choose | 83 if (i == 0 && subsequence) { |
115 ( | 84 costs[j][i] = c; |
116 { j > 0, | 85 } else { |
117 j > 0 ? c + costs[j-1][i] : 0.0 | 86 costs[j][i] = choose |
118 }, | 87 ( |
119 { i > 0, | 88 { j > 0, |
120 i > 0 ? c + costs[j][i-1] : 0.0 | 89 j > 0 ? c + costs[j-1][i] : 0.0 |
121 }, | 90 }, |
122 { j > 0 && i > 0, | 91 { i > 0, |
123 j > 0 && i > 0 ? c + costs[j-1][i-1] : 0.0 | 92 i > 0 ? c + costs[j][i-1] : 0.0 |
124 }); | 93 }, |
94 { j > 0 && i > 0, | |
95 j > 0 && i > 0 ? c + costs[j-1][i-1] : 0.0 | |
96 }); | |
97 } | |
125 } | 98 } |
126 } | 99 } |
127 | 100 |
128 return costs; | 101 return costs; |
102 } | |
103 | |
104 std::vector<size_t> align(const std::vector<Value> &s1, | |
105 const std::vector<Value> &s2, | |
106 bool subsequence) { | |
107 | |
108 // Return the index into s1 for each element in s2 | |
109 | |
110 std::vector<size_t> alignment(s2.size(), 0); | |
111 | |
112 if (s1.empty() || s2.empty()) { | |
113 return alignment; | |
114 } | |
115 | |
116 auto costs = costSequences(s1, s2, subsequence); | |
117 | |
118 #ifdef DEBUG_DTW | |
119 SVCERR << "Cost matrix:" << endl; | |
120 for (auto v: cost) { | |
121 for (auto x: v) { | |
122 SVCERR << x << " "; | |
123 } | |
124 SVCERR << "\n"; | |
125 } | |
126 #endif | |
127 | |
128 size_t j = s1.size() - 1; | |
129 size_t i = s2.size() - 1; | |
130 | |
131 if (subsequence) { | |
132 cost_t min = 0.0; | |
133 size_t minidx = 0; | |
134 for (size_t j = 0; j < s1.size(); ++j) { | |
135 if (j == 0 || costs[j][i] < min) { | |
136 min = costs[j][i]; | |
137 minidx = j; | |
138 } | |
139 } | |
140 j = minidx; | |
141 #ifdef DEBUG_DTW | |
142 SVCERR << "Lowest cost at end of subsequence = " << min | |
143 << " at index " << j << ", tracking back from there" << endl; | |
144 #endif | |
145 } | |
146 | |
147 while (i > 0 && (j > 0 || subsequence)) { | |
148 | |
149 alignment[i] = j; | |
150 | |
151 cost_t a = costs[j-1][i]; | |
152 cost_t b = costs[j][i-1]; | |
153 cost_t both = costs[j-1][i-1]; | |
154 | |
155 if (a < b) { | |
156 --j; | |
157 if (both <= a) { | |
158 --i; | |
159 } | |
160 } else { | |
161 --i; | |
162 if (both <= b) { | |
163 --j; | |
164 } | |
165 } | |
166 } | |
167 | |
168 if (subsequence) { | |
169 alignment[0] = j; | |
170 } | |
171 | |
172 return alignment; | |
129 } | 173 } |
130 }; | 174 }; |
131 | 175 |
132 class MagnitudeDTW | 176 class MagnitudeDTW |
133 { | 177 { |
134 public: | 178 public: |
135 MagnitudeDTW() : m_dtw(metric) { } | 179 MagnitudeDTW() : m_dtw(metric) { } |
136 | 180 |
137 std::vector<size_t> alignSeries(std::vector<double> s1, | 181 std::vector<size_t> alignSequences(std::vector<double> s1, |
138 std::vector<double> s2) { | 182 std::vector<double> s2) { |
139 return m_dtw.alignSeries(s1, s2); | 183 return m_dtw.alignSequences(s1, s2); |
184 } | |
185 | |
186 std::vector<size_t> alignSubsequence(std::vector<double> s, | |
187 std::vector<double> sub) { | |
188 return m_dtw.alignSubsequence(s, sub); | |
140 } | 189 } |
141 | 190 |
142 private: | 191 private: |
143 DTW<double> m_dtw; | 192 DTW<double> m_dtw; |
144 | 193 |
161 double distance; | 210 double distance; |
162 }; | 211 }; |
163 | 212 |
164 RiseFallDTW() : m_dtw(metric) { } | 213 RiseFallDTW() : m_dtw(metric) { } |
165 | 214 |
166 std::vector<size_t> alignSeries(std::vector<Value> s1, | 215 std::vector<size_t> alignSequences(std::vector<Value> s1, |
167 std::vector<Value> s2) { | 216 std::vector<Value> s2) { |
168 return m_dtw.alignSeries(s1, s2); | 217 return m_dtw.alignSequences(s1, s2); |
218 } | |
219 | |
220 std::vector<size_t> alignSubsequence(std::vector<Value> s, | |
221 std::vector<Value> sub) { | |
222 return m_dtw.alignSubsequence(s, sub); | |
169 } | 223 } |
170 | 224 |
171 private: | 225 private: |
172 DTW<Value> m_dtw; | 226 DTW<Value> m_dtw; |
173 | 227 |