annotate align/DTW.h @ 765:eae885290abc

Profiling points and comment
author Chris Cannam
date Thu, 14 May 2020 16:38:37 +0100
parents f32df46d0c84
children 1d6cca5a5621
rev   line source
Chris@755 1 /* -*- c-basic-offset: 4 indent-tabs-mode: nil -*- vi:set ts=8 sts=4 sw=4: */
Chris@755 2
Chris@755 3 /*
Chris@755 4 Sonic Visualiser
Chris@755 5 An audio file viewer and annotation editor.
Chris@755 6 Centre for Digital Music, Queen Mary, University of London.
Chris@755 7
Chris@755 8 This program is free software; you can redistribute it and/or
Chris@755 9 modify it under the terms of the GNU General Public License as
Chris@755 10 published by the Free Software Foundation; either version 2 of the
Chris@755 11 License, or (at your option) any later version. See the file
Chris@755 12 COPYING included with this distribution for more information.
Chris@755 13 */
Chris@755 14
Chris@755 15 #ifndef SV_DTW_H
Chris@755 16 #define SV_DTW_H
Chris@755 17
#include <algorithm>
#include <cmath>
#include <functional>
#include <stdexcept>
#include <vector>
Chris@755 20
/**
 * Generic dynamic time warping (DTW) aligner.
 *
 * Aligns two series of an arbitrary Value type using a caller-supplied
 * distance metric. An accumulated-cost table is built with a three-way
 * recurrence (step in s1, step in s2, or step in both), and the
 * alignment is then recovered by walking that table backwards from its
 * final cell.
 */
template <typename Value>
class DTW
{
public:
    /**
     * Construct an aligner that uses distanceMetric(s1[j], s2[i]) as the
     * local cost of matching s1[j] with s2[i]. Lower means a better
     * match. Nothing in this class requires the metric to be
     * non-negative.
     */
    DTW(std::function<double(const Value &, const Value &)> distanceMetric) :
        m_metric(distanceMetric) { }

    /**
     * Align s1 against s2.
     *
     * Returns a vector of s1.size() entries; entry j is the index into s2
     * that s1[j] is aligned with. While walking the warping path back
     * from the last cell, alignment[j] is overwritten with each s2 index
     * visited at that j, so it normally ends up as the lowest such index.
     * The walk stops as soon as either index reaches 0, and any entry not
     * written by then is left at 0. If either series is empty, the
     * result is s1.size() zeros.
     *
     * The series are taken by const reference so that large inputs are
     * not copied.
     */
    std::vector<size_t> alignSeries(const std::vector<Value> &s1,
                                    const std::vector<Value> &s2) {

        std::vector<size_t> alignment(s1.size(), 0);

        if (s1.empty() || s2.empty()) {
            return alignment;
        }

        const auto costs = costSeries(s1, s2);

        size_t j = s1.size() - 1;
        size_t i = s2.size() - 1;

        // Step back towards the origin, each time moving to the
        // predecessor with the lowest accumulated cost. On exact ties a
        // single-axis step is preferred over the diagonal.
        // NOTE(review): with that tie rule, runs of equal values can
        // align off-diagonal (e.g. {0,0,0} vs itself gives {0,0,1});
        // confirm this is intended before relying on it.
        while (j > 0 && i > 0) {

            alignment[j] = i;

            const cost_t a = costs[j-1][i];     // previous s1 element
            const cost_t b = costs[j][i-1];     // previous s2 element
            const cost_t both = costs[j-1][i-1]; // diagonal

            if (a < b) {
                --j;
                if (both < a) {
                    --i;
                }
            } else {
                --i;
                if (both < b) {
                    --j;
                }
            }
        }

        return alignment;
    }

private:
    std::function<double(const Value &, const Value &)> m_metric;

    typedef double cost_t;

    // One candidate predecessor cell; present is false when that cell
    // would lie outside the table.
    struct CostOption {
        bool present;
        cost_t cost;
    };

    /**
     * Return the cheapest of the available predecessor options: x is the
     * step from the previous s1 element, y from the previous s2 element,
     * d the diagonal. Returns 0.0 when none is present (the origin cell).
     * Throws std::logic_error if both x and y exist but d does not, which
     * cannot happen for a rectangular table.
     */
    cost_t choose(CostOption x, CostOption y, CostOption d) {
        if (x.present && y.present) {
            if (!d.present) {
                throw std::logic_error("if x & y both exist, so must diagonal");
            }
            return std::min(std::min(x.cost, y.cost), d.cost);
        } else if (x.present) {
            return x.cost;
        } else if (y.present) {
            return y.cost;
        } else {
            return 0.0;
        }
    }

    /**
     * Build the accumulated-cost table (s1.size() rows by s2.size()
     * columns). costs[j][i] is m_metric(s1[j], s2[i]) plus the cheapest
     * reachable predecessor cell.
     *
     * The origin cell is left at 0.0 rather than its local distance.
     * Every path starts there, so this is the same offset for every cell
     * and does not change which path is cheapest.
     */
    std::vector<std::vector<cost_t>> costSeries(const std::vector<Value> &s1,
                                                const std::vector<Value> &s2) {

        std::vector<std::vector<cost_t>> costs
            (s1.size(), std::vector<cost_t>(s2.size(), 0.0));

        for (size_t j = 0; j < s1.size(); ++j) {
            for (size_t i = 0; i < s2.size(); ++i) {
                const cost_t c = m_metric(s1[j], s2[i]);
                const bool hasPrevJ = (j > 0);
                const bool hasPrevI = (i > 0);
                costs[j][i] = choose
                    ({ hasPrevJ, hasPrevJ ? c + costs[j-1][i] : 0.0 },
                     { hasPrevI, hasPrevI ? c + costs[j][i-1] : 0.0 },
                     { hasPrevJ && hasPrevI,
                       hasPrevJ && hasPrevI ? c + costs[j-1][i-1] : 0.0 });
            }
        }

        return costs;
    }
};
Chris@757 119
Chris@757 120 class MagnitudeDTW
Chris@757 121 {
Chris@757 122 public:
Chris@757 123 MagnitudeDTW() : m_dtw(metric) { }
Chris@757 124
Chris@757 125 std::vector<size_t> alignSeries(std::vector<double> s1,
Chris@757 126 std::vector<double> s2) {
Chris@757 127 return m_dtw.alignSeries(s1, s2);
Chris@757 128 }
Chris@757 129
Chris@757 130 private:
Chris@757 131 DTW<double> m_dtw;
Chris@757 132
Chris@757 133 static double metric(const double &a, const double &b) {
Chris@757 134 return std::abs(b - a);
Chris@757 135 }
Chris@757 136 };
Chris@757 137
Chris@757 138 class RiseFallDTW
Chris@757 139 {
Chris@757 140 public:
Chris@757 141 enum class Direction {
Chris@757 142 None,
Chris@757 143 Up,
Chris@757 144 Down
Chris@757 145 };
Chris@757 146
Chris@757 147 struct Value {
Chris@757 148 Direction direction;
Chris@757 149 double distance;
Chris@757 150 };
Chris@757 151
Chris@757 152 RiseFallDTW() : m_dtw(metric) { }
Chris@757 153
Chris@757 154 std::vector<size_t> alignSeries(std::vector<Value> s1,
Chris@757 155 std::vector<Value> s2) {
Chris@757 156 return m_dtw.alignSeries(s1, s2);
Chris@757 157 }
Chris@757 158
Chris@757 159 private:
Chris@757 160 DTW<Value> m_dtw;
Chris@757 161
Chris@757 162 static double metric(const Value &a, const Value &b) {
Chris@757 163
Chris@757 164 auto together = [](double c1, double c2) {
Chris@755 165 auto diff = std::abs(c1 - c2);
Chris@755 166 return (diff < 1.0 ? -1.0 :
Chris@755 167 diff > 3.0 ? 1.0 :
Chris@755 168 0.0);
Chris@755 169 };
Chris@757 170 auto opposing = [](double c1, double c2) {
Chris@755 171 auto diff = c1 + c2;
Chris@755 172 return (diff < 2.0 ? 1.0 :
Chris@755 173 2.0);
Chris@755 174 };
Chris@757 175
Chris@755 176 if (a.direction == Direction::None || b.direction == Direction::None) {
Chris@755 177 if (a.direction == b.direction) {
Chris@755 178 return 0.0;
Chris@755 179 } else {
Chris@755 180 return 1.0;
Chris@755 181 }
Chris@755 182 } else {
Chris@755 183 if (a.direction == b.direction) {
Chris@757 184 return together (a.distance, b.distance);
Chris@755 185 } else {
Chris@757 186 return opposing (a.distance, b.distance);
Chris@755 187 }
Chris@755 188 }
Chris@755 189 }
Chris@755 190 };
Chris@755 191
Chris@755 192 #endif