/* -*- c-basic-offset: 4 indent-tabs-mode: nil -*-  vi:set ts=8 sts=4 sw=4: */

/*
    Sonic Visualiser
    An audio file viewer and annotation editor.
    Centre for Digital Music, Queen Mary, University of London.

    This program is free software; you can redistribute it and/or
    modify it under the terms of the GNU General Public License as
    published by the Free Software Foundation; either version 2 of the
    License, or (at your option) any later version.  See the file
    COPYING included with this distribution for more information.
*/

#ifndef SV_DTW_H
#define SV_DTW_H

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <ostream>
#include <stdexcept>
#include <utility>
#include <vector>

//#define DEBUG_DTW 1

/**
 * Dynamic time warping over sequences of an arbitrary element type.
 * The distance between two elements is supplied by the caller as a
 * function returning a double; it is evaluated once for every pair
 * of (s1 element, s2 element).
 */
template <typename Value>
class DTW
{
public:
    /**
     * Construct an aligner that uses distanceMetric to obtain the
     * local cost of matching one element against another.
     */
    DTW(std::function<double(const Value &, const Value &)> distanceMetric) :
        m_metric(std::move(distanceMetric)) { }

    /**
     * Align the sequence s2 against the whole of the sequence s1,
     * returning the index into s1 for each element in s2.
     */
    std::vector<size_t> alignSequences(std::vector<Value> s1,
                                       std::vector<Value> s2) {
        return align(s1, s2, false);
    }

    /**
     * Align the sequence sub against the best-matching subsequence of
     * s, returning the index into s for each element in sub.
     */
    std::vector<size_t> alignSubsequence(std::vector<Value> s,
                                         std::vector<Value> sub) {
        return align(s, sub, true);
    }

private:
    std::function<double(const Value &, const Value &)> m_metric;

    typedef double cost_t;

    /** A candidate predecessor cost; cost is only meaningful if present. */
    struct CostOption {
        bool present;
        cost_t cost;
    };

    /**
     * Pick the cumulative cost for a cell given its candidate
     * predecessors: x (previous s1 element), y (previous s2 element)
     * and d (diagonal). With both x and y available the cheapest of
     * the three is used; with only one available, that one is used;
     * with neither, the result is 0.0 (so the metric value at the
     * origin cell is not included in the cumulative cost).
     *
     * Throws std::logic_error if x and y are present but d is not.
     */
    static cost_t choose(CostOption x, CostOption y, CostOption d) {
        if (x.present && y.present) {
            if (!d.present) {
                throw std::logic_error("if x & y both exist, so must diagonal");
            }
            return std::min(std::min(x.cost, y.cost), d.cost);
        } else if (x.present) {
            return x.cost;
        } else if (y.present) {
            return y.cost;
        } else {
            return 0.0;
        }
    }

    /**
     * Build the cumulative cost matrix, indexed [index into s1]
     * [index into s2]. In subsequence mode the first element of s2
     * may start anywhere in s1, so the first column holds just the
     * local metric value rather than an accumulated cost.
     */
    std::vector<std::vector<cost_t>> costSequences(const std::vector<Value> &s1,
                                                   const std::vector<Value> &s2,
                                                   bool subsequence) {

        std::vector<std::vector<cost_t>> costs
            (s1.size(), std::vector<cost_t>(s2.size(), 0.0));

        for (size_t j = 0; j < s1.size(); ++j) {
            for (size_t i = 0; i < s2.size(); ++i) {
                cost_t c = m_metric(s1[j], s2[i]);
                if (i == 0 && subsequence) {
                    costs[j][i] = c;
                } else {
                    const bool haveS1Step = (j > 0);
                    const bool haveS2Step = (i > 0);
                    const bool haveDiagonal = haveS1Step && haveS2Step;
                    const CostOption viaS1 {
                        haveS1Step,
                        haveS1Step ? c + costs[j-1][i] : 0.0
                    };
                    const CostOption viaS2 {
                        haveS2Step,
                        haveS2Step ? c + costs[j][i-1] : 0.0
                    };
                    const CostOption viaDiagonal {
                        haveDiagonal,
                        haveDiagonal ? c + costs[j-1][i-1] : 0.0
                    };
                    costs[j][i] = choose(viaS1, viaS2, viaDiagonal);
                }
            }
        }

        return costs;
    }

    /**
     * Compute the cost matrix for s1 and s2 and trace the cheapest
     * path back through it. Returns one entry per element of s2; if
     * either sequence is empty, every entry is 0.
     */
    std::vector<size_t> align(const std::vector<Value> &s1,
                              const std::vector<Value> &s2,
                              bool subsequence) {

        // Return the index into s1 for each element in s2

        std::vector<size_t> alignment(s2.size(), 0);

        if (s1.empty() || s2.empty()) {
            return alignment;
        }

        auto costs = costSequences(s1, s2, subsequence);

#ifdef DEBUG_DTW
        SVCERR << "Cost matrix:" << endl;
        for (auto v: costs) {
            for (auto x: v) {
                SVCERR << x << " ";
            }
            SVCERR << "\n";
        }
#endif

        // Trace-back position: j indexes s1, i indexes s2
        size_t j = s1.size() - 1;
        size_t i = s2.size() - 1;

        if (subsequence) {
            // The match may end anywhere in s1: start tracing back
            // from the first row holding the lowest cost in the final
            // column.
            cost_t lowest = 0.0;
            size_t lowestIndex = 0;
            for (size_t k = 0; k < s1.size(); ++k) {
                if (k == 0 || costs[k][i] < lowest) {
                    lowest = costs[k][i];
                    lowestIndex = k;
                }
            }
            j = lowestIndex;
#ifdef DEBUG_DTW
            SVCERR << "Lowest cost at end of subsequence = " << lowest
                   << " at index " << j << ", tracking back from there" << endl;
#endif
        }

        while (i > 0 || j > 0) {

            alignment[i] = j;

            if (i == 0) {
                if (subsequence) {
                    // The first element of sub has been placed; the
                    // rest of s before it is not part of the match.
                    break;
                } else {
                    // NOTE(review): walking down this edge rewrites
                    // alignment[0] on each step and the loop exits
                    // before writing j == 0, so the value left behind
                    // is 1 - confirm that this is intended.
                    --j;
                    continue;
                }
            }

            if (j == 0) {
                --i;
                continue;
            }

            cost_t a = costs[j-1][i];
            cost_t b = costs[j][i-1];
            cost_t both = costs[j-1][i-1];

            // Step towards the cheaper neighbour, taking the diagonal
            // instead whenever it is no more expensive than that
            // neighbour.
            if (a < b) {
                --j;
                if (both <= a) {
                    --i;
                }
            } else {
                --i;
                if (both <= b) {
                    --j;
                }
            }
        }

        if (subsequence) {
            alignment[0] = j;
        }

        return alignment;
    }
};

/**
 * DTW over sequences of doubles, using the absolute difference
 * between two values as the distance.
 */
class MagnitudeDTW
{
public:
    MagnitudeDTW() : m_dtw(metric) { }

    /**
     * Align s2 against the whole of s1, returning the index into s1
     * for each element in s2.
     */
    std::vector<size_t> alignSequences(std::vector<double> s1,
                                       std::vector<double> s2) {
        return m_dtw.alignSequences(std::move(s1), std::move(s2));
    }

    /**
     * Align sub against the best-matching subsequence of s, returning
     * the index into s for each element in sub.
     */
    std::vector<size_t> alignSubsequence(std::vector<double> s,
                                         std::vector<double> sub) {
        return m_dtw.alignSubsequence(std::move(s), std::move(sub));
    }

private:
    DTW<double> m_dtw;

    static double metric(const double &a, const double &b) {
        return std::abs(b - a);
    }
};

/**
 * DTW over sequences of (direction, distance) values, with a distance
 * metric that depends on whether the two directions agree.
 */
class RiseFallDTW
{
public:
    enum class Direction {
        None,
        Up,
        Down
    };

    struct Value {
        Direction direction;
        double distance;
    };

    RiseFallDTW() : m_dtw(metric) { }

    /**
     * Align s2 against the whole of s1, returning the index into s1
     * for each element in s2.
     */
    std::vector<size_t> alignSequences(std::vector<Value> s1,
                                       std::vector<Value> s2) {
        return m_dtw.alignSequences(std::move(s1), std::move(s2));
    }

    /**
     * Align sub against the best-matching subsequence of s, returning
     * the index into s for each element in sub.
     */
    std::vector<size_t> alignSubsequence(std::vector<Value> s,
                                         std::vector<Value> sub) {
        return m_dtw.alignSubsequence(std::move(s), std::move(sub));
    }

private:
    DTW<Value> m_dtw;

    /**
     * Local cost of matching a against b:
     *  - either direction is None: 0.0 if both are None, else 1.0
     *  - same direction: -1.0 if the distances differ by less than
     *    1.0, 1.0 if they differ by more than 3.0, otherwise 0.0
     *  - opposite directions: 1.0 if the distances sum to less than
     *    2.0, otherwise 2.0
     */
    static double metric(const Value &a, const Value &b) {

        auto together = [](double c1, double c2) {
            auto diff = std::abs(c1 - c2);
            return (diff < 1.0 ? -1.0 :
                    diff > 3.0 ? 1.0 :
                    0.0);
        };
        auto opposing = [](double c1, double c2) {
            auto diff = c1 + c2;
            return (diff < 2.0 ? 1.0 :
                    2.0);
        };

        if (a.direction == Direction::None || b.direction == Direction::None) {
            if (a.direction == b.direction) {
                return 0.0;
            } else {
                return 1.0;
            }
        } else {
            if (a.direction == b.direction) {
                return together (a.distance, b.distance);
            } else {
                return opposing (a.distance, b.distance);
            }
        }
    }
};

/**
 * Write v as a direction marker ("=" for None, "+" for Up, "-" for
 * Down) immediately followed by its distance.
 */
inline std::ostream &operator<<(std::ostream &s, const RiseFallDTW::Value v) {
    return (s <<
            (v.direction == RiseFallDTW::Direction::None ? "=" :
             v.direction == RiseFallDTW::Direction::Up ? "+" : "-")
            << v.distance);
}

#endif