annotate align/DTW.h @ 786:1089d65c585d tip

Divert some debug output away from stderr
author Chris Cannam
date Fri, 14 Aug 2020 10:46:44 +0100
parents 4d10365aa6a9
children
rev   line source
Chris@755 1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
Chris@755 2
Chris@755 3 /*
Chris@755 4 Sonic Visualiser
Chris@755 5 An audio file viewer and annotation editor.
Chris@755 6 Centre for Digital Music, Queen Mary, University of London.
Chris@755 7
Chris@755 8 This program is free software; you can redistribute it and/or
Chris@755 9 modify it under the terms of the GNU General Public License as
Chris@755 10 published by the Free Software Foundation; either version 2 of the
Chris@755 11 License, or (at your option) any later version. See the file
Chris@755 12 COPYING included with this distribution for more information.
Chris@755 13 */
Chris@755 14
Chris@755 15 #ifndef SV_DTW_H
Chris@755 16 #define SV_DTW_H
Chris@755 17
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <ostream>
#include <stdexcept>
#include <vector>
Chris@755 20
Chris@772 21 //#define DEBUG_DTW 1
Chris@772 22
/**
 * Dynamic time warping aligner for sequences of an arbitrary Value
 * type. The local cost of matching one element against another is
 * obtained from a distance metric supplied at construction.
 */
template <typename Value>
class DTW
{
public:
    /**
     * Construct an aligner that calls distanceMetric(a, b) to obtain
     * the cost of matching element a of the first sequence against
     * element b of the second.
     */
    DTW(std::function<double(const Value &, const Value &)> distanceMetric) :
        m_metric(distanceMetric) { }

    /**
     * Align the sequence s2 against the whole of the sequence s1,
     * returning the index into s1 for each element in s2. If either
     * sequence is empty, the result is s2.size() zeros.
     */
    std::vector<size_t> alignSequences(const std::vector<Value> &s1,
                                       const std::vector<Value> &s2) const {
        return align(s1, s2, false);
    }

    /**
     * Align the sequence sub against the best-matching subsequence of
     * s, returning the index into s for each element in sub. If
     * either sequence is empty, the result is sub.size() zeros.
     */
    std::vector<size_t> alignSubsequence(const std::vector<Value> &s,
                                         const std::vector<Value> &sub) const {
        return align(s, sub, true);
    }

private:
    std::function<double(const Value &, const Value &)> m_metric;

    using cost_t = double;

    /**
     * A candidate predecessor cost: "present" is false where the
     * predecessor cell would lie outside the cost matrix.
     */
    struct CostOption {
        bool present;
        cost_t cost;
    };

    /**
     * Pick the cheapest of the available predecessor options x, y and
     * d (the diagonal). Returns 0.0 if neither x nor y is present.
     * Throws std::logic_error if x and y are both present but d is
     * not, as that combination cannot arise within a matrix.
     */
    cost_t choose(CostOption x, CostOption y, CostOption d) const {
        if (x.present && y.present) {
            if (!d.present) {
                throw std::logic_error("if x & y both exist, so must diagonal");
            }
            return std::min(std::min(x.cost, y.cost), d.cost);
        } else if (x.present) {
            return x.cost;
        } else if (y.present) {
            return y.cost;
        } else {
            return 0.0;
        }
    }

    /**
     * Build the accumulated cost matrix, indexed as costs[j][i] with
     * j an index into s1 and i an index into s2. In subsequence mode
     * the i == 0 column holds only the local metric value, so that a
     * path may begin at any position in s1 without penalty.
     *
     * NOTE: in whole-sequence mode the origin cell is assigned 0.0
     * rather than its own metric value, because choose() returns 0.0
     * when no predecessor exists. Every path traced by align() in
     * that mode runs back to the origin, so this is a common offset.
     */
    std::vector<std::vector<cost_t>>
    costSequences(const std::vector<Value> &s1,
                  const std::vector<Value> &s2,
                  bool subsequence) const {

        std::vector<std::vector<cost_t>> costs
            (s1.size(), std::vector<cost_t>(s2.size(), 0.0));

        for (size_t j = 0; j < s1.size(); ++j) {
            for (size_t i = 0; i < s2.size(); ++i) {
                cost_t c = m_metric(s1[j], s2[i]);
                if (i == 0 && subsequence) {
                    costs[j][i] = c;
                } else {
                    costs[j][i] = choose
                        (
                            { j > 0,
                              j > 0 ? c + costs[j-1][i] : 0.0
                            },
                            { i > 0,
                              i > 0 ? c + costs[j][i-1] : 0.0
                            },
                            { j > 0 && i > 0,
                              j > 0 && i > 0 ? c + costs[j-1][i-1] : 0.0
                            });
                }
            }
        }

        return costs;
    }

    /**
     * Compute the cost matrix for s1 against s2 and trace the
     * cheapest path backwards through it, returning the index into
     * s1 for each element in s2. Where one element of s2 is matched
     * by several consecutive elements of s1, the lowest of those
     * indices is the one reported.
     */
    std::vector<size_t> align(const std::vector<Value> &s1,
                              const std::vector<Value> &s2,
                              bool subsequence) const {

        std::vector<size_t> alignment(s2.size(), 0);

        if (s1.empty() || s2.empty()) {
            return alignment;
        }

        auto costs = costSequences(s1, s2, subsequence);

#ifdef DEBUG_DTW
        SVCERR << "Cost matrix:" << endl;
        for (const auto &v: costs) {
            for (auto x: v) {
                SVCERR << x << " ";
            }
            SVCERR << "\n";
        }
#endif

        // Start from the final element of s2, and (by default) the
        // final element of s1
        size_t j = s1.size() - 1;
        size_t i = s2.size() - 1;

        if (subsequence) {
            // The match may end anywhere in s1: begin at the cheapest
            // cell in the final column, preferring the earliest such
            // cell if there is a tie
            cost_t min = costs[0][i];
            size_t minidx = 0;
            for (size_t k = 1; k < s1.size(); ++k) {
                if (costs[k][i] < min) {
                    min = costs[k][i];
                    minidx = k;
                }
            }
            j = minidx;
#ifdef DEBUG_DTW
            SVCERR << "Lowest cost at end of subsequence = " << min
                   << " at index " << j << ", tracking back from there" << endl;
#endif
        }

        while (i > 0 || j > 0) {

            // Later (lower-j) visits to the same i overwrite earlier
            // ones, so each entry ends up with the lowest j visited
            alignment[i] = j;

            if (i == 0) {
                if (subsequence) {
                    // The match may start anywhere in s1
                    break;
                } else {
                    --j;
                    continue;
                }
            }

            if (j == 0) {
                --i;
                continue;
            }

            cost_t a = costs[j-1][i];
            cost_t b = costs[j][i-1];
            cost_t both = costs[j-1][i-1];

            // Step toward the cheaper neighbour, taking the diagonal
            // instead whenever it is no more expensive
            if (a < b) {
                --j;
                if (both <= a) {
                    --i;
                }
            } else {
                --i;
                if (both <= b) {
                    --j;
                }
            }
        }

        // The loop always finishes with i == 0, but it exits before
        // recording the cell it finished on. Record it here in both
        // modes: in whole-sequence mode j is 0 at this point, and
        // without this the first entry would be left at 1 whenever
        // the path ran down the first column.
        alignment[0] = j;

        return alignment;
    }
};
Chris@757 189
Chris@757 190 class MagnitudeDTW
Chris@757 191 {
Chris@757 192 public:
Chris@757 193 MagnitudeDTW() : m_dtw(metric) { }
Chris@757 194
Chris@780 195 std::vector<size_t> alignSequences(std::vector<double> s1,
Chris@780 196 std::vector<double> s2) {
Chris@780 197 return m_dtw.alignSequences(s1, s2);
Chris@780 198 }
Chris@780 199
Chris@780 200 std::vector<size_t> alignSubsequence(std::vector<double> s,
Chris@780 201 std::vector<double> sub) {
Chris@780 202 return m_dtw.alignSubsequence(s, sub);
Chris@757 203 }
Chris@757 204
Chris@757 205 private:
Chris@757 206 DTW<double> m_dtw;
Chris@757 207
Chris@757 208 static double metric(const double &a, const double &b) {
Chris@757 209 return std::abs(b - a);
Chris@757 210 }
Chris@757 211 };
Chris@757 212
Chris@757 213 class RiseFallDTW
Chris@757 214 {
Chris@757 215 public:
Chris@757 216 enum class Direction {
Chris@757 217 None,
Chris@757 218 Up,
Chris@757 219 Down
Chris@757 220 };
Chris@757 221
Chris@757 222 struct Value {
Chris@757 223 Direction direction;
Chris@757 224 double distance;
Chris@757 225 };
Chris@757 226
Chris@757 227 RiseFallDTW() : m_dtw(metric) { }
Chris@757 228
Chris@780 229 std::vector<size_t> alignSequences(std::vector<Value> s1,
Chris@780 230 std::vector<Value> s2) {
Chris@780 231 return m_dtw.alignSequences(s1, s2);
Chris@780 232 }
Chris@780 233
Chris@780 234 std::vector<size_t> alignSubsequence(std::vector<Value> s,
Chris@780 235 std::vector<Value> sub) {
Chris@780 236 return m_dtw.alignSubsequence(s, sub);
Chris@757 237 }
Chris@757 238
Chris@757 239 private:
Chris@757 240 DTW<Value> m_dtw;
Chris@757 241
Chris@757 242 static double metric(const Value &a, const Value &b) {
Chris@757 243
Chris@757 244 auto together = [](double c1, double c2) {
Chris@755 245 auto diff = std::abs(c1 - c2);
Chris@755 246 return (diff < 1.0 ? -1.0 :
Chris@755 247 diff > 3.0 ? 1.0 :
Chris@755 248 0.0);
Chris@755 249 };
Chris@757 250 auto opposing = [](double c1, double c2) {
Chris@755 251 auto diff = c1 + c2;
Chris@755 252 return (diff < 2.0 ? 1.0 :
Chris@755 253 2.0);
Chris@755 254 };
Chris@757 255
Chris@755 256 if (a.direction == Direction::None || b.direction == Direction::None) {
Chris@755 257 if (a.direction == b.direction) {
Chris@755 258 return 0.0;
Chris@755 259 } else {
Chris@755 260 return 1.0;
Chris@755 261 }
Chris@755 262 } else {
Chris@755 263 if (a.direction == b.direction) {
Chris@757 264 return together (a.distance, b.distance);
Chris@755 265 } else {
Chris@757 266 return opposing (a.distance, b.distance);
Chris@755 267 }
Chris@755 268 }
Chris@755 269 }
Chris@755 270 };
Chris@755 271
Chris@771 272 inline std::ostream &operator<<(std::ostream &s, const RiseFallDTW::Value v) {
Chris@771 273 return (s <<
Chris@771 274 (v.direction == RiseFallDTW::Direction::None ? "=" :
Chris@771 275 v.direction == RiseFallDTW::Direction::Up ? "+" : "-")
Chris@771 276 << v.distance);
Chris@771 277 }
Chris@771 278
Chris@755 279 #endif