annotate align/DTW.h @ 778:83a7b10b7415

Merge from branch pitch-align
author Chris Cannam
date Fri, 26 Jun 2020 13:48:52 +0100
parents 8280f7a363d1
children 8fa98f89eda8
rev   line source
Chris@755 1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
Chris@755 2
Chris@755 3 /*
Chris@755 4 Sonic Visualiser
Chris@755 5 An audio file viewer and annotation editor.
Chris@755 6 Centre for Digital Music, Queen Mary, University of London.
Chris@755 7
Chris@755 8 This program is free software; you can redistribute it and/or
Chris@755 9 modify it under the terms of the GNU General Public License as
Chris@755 10 published by the Free Software Foundation; either version 2 of the
Chris@755 11 License, or (at your option) any later version. See the file
Chris@755 12 COPYING included with this distribution for more information.
Chris@755 13 */
Chris@755 14
Chris@755 15 #ifndef SV_DTW_H
Chris@755 16 #define SV_DTW_H
Chris@755 17
#include <algorithm>
#include <cmath>
#include <functional>
#include <ostream>
#include <stdexcept>
#include <vector>
Chris@755 20
Chris@772 21 //#define DEBUG_DTW 1
Chris@772 22
/**
 * Dynamic time warping between two series of values of type Value.
 *
 * A caller-supplied metric gives the cost of matching one value
 * against another. The accumulated cost of the cheapest monotonic
 * path through the s1 x s2 grid is computed, and that path is traced
 * back to obtain an alignment of s1 onto s2.
 */
template <typename Value>
class DTW
{
public:
    /**
     * Construct an aligner using the given metric, which returns the
     * cost of matching its two arguments. Costs are summed along the
     * path, so the metric may return negative values to reward good
     * matches.
     */
    DTW(std::function<double(const Value &, const Value &)> distanceMetric) :
        m_metric(distanceMetric) { }

    /**
     * Align s1 against s2.
     *
     * Returns a vector the same length as s1 whose element j is the
     * index into s2 that s1[j] is aligned with. Where the warping path
     * matches s1[j] against several consecutive elements of s2, the
     * lowest of those indices is reported. If s2 is empty every entry
     * is 0; if s1 is empty the result is empty.
     */
    std::vector<size_t> alignSeries(const std::vector<Value> &s1,
                                    const std::vector<Value> &s2) const {

        std::vector<size_t> alignment(s1.size(), 0);

        if (s1.empty() || s2.empty()) {
            return alignment;
        }

        auto costs = costSeries(s1, s2);

#ifdef DEBUG_DTW
        SVCERR << "Cost matrix:" << endl;
        for (const auto &row: costs) {
            for (auto x: row) {
                SVCERR << x << " ";
            }
            SVCERR << "\n";
        }
#endif

        // Trace the cheapest path back from the final cell (j indexes
        // s1, i indexes s2). Each visit overwrites alignment[j], so it
        // ends up holding the lowest s2 index visited in row j.
        size_t j = s1.size() - 1;
        size_t i = s2.size() - 1;

        while (j > 0 && i > 0) {

            alignment[j] = i;

            cost_t a = costs[j-1][i];      // step back in s1 only
            cost_t b = costs[j][i-1];      // step back in s2 only
            cost_t both = costs[j-1][i-1]; // diagonal step

            // Take the diagonal whenever it is no worse than the cheaper
            // of the two single steps
            if (a < b) {
                --j;
                if (both <= a) {
                    --i;
                }
            } else {
                --i;
                if (both <= b) {
                    --j;
                }
            }
        }

        // We are now on row 0 or column 0, and the rest of the path runs
        // along that edge to (0, 0), so the lowest s2 index for the
        // current row is 0. The loop never records this final cell: after
        // a horizontal step onto column 0 it would leave alignment[j] at
        // 1. Rows below j were never written and are already 0.
        alignment[j] = 0;

        return alignment;
    }

private:
    typedef double cost_t;

    std::function<double(const Value &, const Value &)> m_metric;

    // A candidate predecessor cell: whether it exists in the grid, and
    // its accumulated cost if it does
    struct CostOption {
        bool present;
        cost_t cost;
    };

    // Return the cheapest of the options that are present, or 0.0 if
    // none is (the origin cell). x and y are the single-step
    // predecessors, d the diagonal one.
    static cost_t choose(CostOption x, CostOption y, CostOption d) {
        if (x.present && y.present) {
            if (!d.present) {
                throw std::logic_error("if x & y both exist, so must diagonal");
            }
            return std::min(std::min(x.cost, y.cost), d.cost);
        } else if (x.present) {
            return x.cost;
        } else if (y.present) {
            return y.cost;
        } else {
            return 0.0;
        }
    }

    // Build the accumulated-cost matrix. costs[j][i] is the metric for
    // (s1[j], s2[i]) plus the cheapest accumulated cost among whichever
    // of (j-1, i), (j, i-1) and (j-1, i-1) exist. The metric is added
    // outside choose() so the origin cell also includes its own match
    // cost instead of being fixed at 0.
    std::vector<std::vector<cost_t>> costSeries(const std::vector<Value> &s1,
                                                const std::vector<Value> &s2) const {

        std::vector<std::vector<cost_t>> costs
            (s1.size(), std::vector<cost_t>(s2.size(), 0.0));

        for (size_t j = 0; j < s1.size(); ++j) {
            for (size_t i = 0; i < s2.size(); ++i) {
                cost_t c = m_metric(s1[j], s2[i]);
                costs[j][i] = c + choose
                    ({ j > 0,
                       j > 0 ? costs[j-1][i] : 0.0 },
                     { i > 0,
                       i > 0 ? costs[j][i-1] : 0.0 },
                     { j > 0 && i > 0,
                       j > 0 && i > 0 ? costs[j-1][i-1] : 0.0 });
            }
        }

        return costs;
    }
};
Chris@757 131
Chris@757 132 class MagnitudeDTW
Chris@757 133 {
Chris@757 134 public:
Chris@757 135 MagnitudeDTW() : m_dtw(metric) { }
Chris@757 136
Chris@757 137 std::vector<size_t> alignSeries(std::vector<double> s1,
Chris@757 138 std::vector<double> s2) {
Chris@757 139 return m_dtw.alignSeries(s1, s2);
Chris@757 140 }
Chris@757 141
Chris@757 142 private:
Chris@757 143 DTW<double> m_dtw;
Chris@757 144
Chris@757 145 static double metric(const double &a, const double &b) {
Chris@757 146 return std::abs(b - a);
Chris@757 147 }
Chris@757 148 };
Chris@757 149
Chris@757 150 class RiseFallDTW
Chris@757 151 {
Chris@757 152 public:
Chris@757 153 enum class Direction {
Chris@757 154 None,
Chris@757 155 Up,
Chris@757 156 Down
Chris@757 157 };
Chris@757 158
Chris@757 159 struct Value {
Chris@757 160 Direction direction;
Chris@757 161 double distance;
Chris@757 162 };
Chris@757 163
Chris@757 164 RiseFallDTW() : m_dtw(metric) { }
Chris@757 165
Chris@757 166 std::vector<size_t> alignSeries(std::vector<Value> s1,
Chris@757 167 std::vector<Value> s2) {
Chris@757 168 return m_dtw.alignSeries(s1, s2);
Chris@757 169 }
Chris@757 170
Chris@757 171 private:
Chris@757 172 DTW<Value> m_dtw;
Chris@757 173
Chris@757 174 static double metric(const Value &a, const Value &b) {
Chris@757 175
Chris@757 176 auto together = [](double c1, double c2) {
Chris@755 177 auto diff = std::abs(c1 - c2);
Chris@755 178 return (diff < 1.0 ? -1.0 :
Chris@755 179 diff > 3.0 ? 1.0 :
Chris@755 180 0.0);
Chris@755 181 };
Chris@757 182 auto opposing = [](double c1, double c2) {
Chris@755 183 auto diff = c1 + c2;
Chris@755 184 return (diff < 2.0 ? 1.0 :
Chris@755 185 2.0);
Chris@755 186 };
Chris@757 187
Chris@755 188 if (a.direction == Direction::None || b.direction == Direction::None) {
Chris@755 189 if (a.direction == b.direction) {
Chris@755 190 return 0.0;
Chris@755 191 } else {
Chris@755 192 return 1.0;
Chris@755 193 }
Chris@755 194 } else {
Chris@755 195 if (a.direction == b.direction) {
Chris@757 196 return together (a.distance, b.distance);
Chris@755 197 } else {
Chris@757 198 return opposing (a.distance, b.distance);
Chris@755 199 }
Chris@755 200 }
Chris@755 201 }
Chris@755 202 };
Chris@755 203
Chris@771 204 inline std::ostream &operator<<(std::ostream &s, const RiseFallDTW::Value v) {
Chris@771 205 return (s <<
Chris@771 206 (v.direction == RiseFallDTW::Direction::None ? "=" :
Chris@771 207 v.direction == RiseFallDTW::Direction::Up ? "+" : "-")
Chris@771 208 << v.distance);
Chris@771 209 }
Chris@771 210
Chris@755 211 #endif