Mercurial > hg > svapp
view align/DTW.h @ 786:1089d65c585d tip
Divert some debug output away from stderr
author | Chris Cannam |
---|---|
date | Fri, 14 Aug 2020 10:46:44 +0100 |
parents | 4d10365aa6a9 |
children |
line wrap: on
line source
/* -*- c-basic-offset: 4 indent-tabs-mode: nil -*-  vi:set ts=8 sts=4 sw=4: */

/*
    Sonic Visualiser
    An audio file viewer and annotation editor.
    Centre for Digital Music, Queen Mary, University of London.

    This program is free software; you can redistribute it and/or
    modify it under the terms of the GNU General Public License as
    published by the Free Software Foundation; either version 2 of the
    License, or (at your option) any later version.  See the file
    COPYING included with this distribution for more information.
*/

#ifndef SV_DTW_H
#define SV_DTW_H

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <ostream>
#include <stdexcept>
#include <vector>

//#define DEBUG_DTW 1

/**
 * Dynamic time warping over sequences of an arbitrary Value type,
 * using a caller-supplied distance metric between pairs of values.
 */
template <typename Value>
class DTW
{
public:
    /**
     * Construct a DTW aligner that uses distanceMetric to obtain the
     * cost of matching one value against another.
     */
    DTW(std::function<double(const Value &, const Value &)> distanceMetric) :
        m_metric(distanceMetric) { }

    /**
     * Align the sequence s2 against the whole of the sequence s1,
     * returning the index into s1 for each element in s2. Where an
     * element of s2 is matched against several consecutive elements
     * of s1, the lowest of those indices is returned. If either
     * sequence is empty, the result contains s2.size() zeros.
     */
    std::vector<size_t> alignSequences(const std::vector<Value> &s1,
                                       const std::vector<Value> &s2) const {
        return align(s1, s2, false);
    }

    /**
     * Align the sequence sub against the best-matching subsequence of
     * s, returning the index into s for each element in sub. If
     * either sequence is empty, the result contains sub.size() zeros.
     */
    std::vector<size_t> alignSubsequence(const std::vector<Value> &s,
                                         const std::vector<Value> &sub) const {
        return align(s, sub, true);
    }

private:
    std::function<double(const Value &, const Value &)> m_metric;

    typedef double cost_t;

    /** A candidate predecessor cost, which may or may not exist. */
    struct CostOption {
        bool present;
        cost_t cost;
    };

    /**
     * Pick the cheapest of the available predecessor options: x (the
     * previous element of the first sequence), y (the previous
     * element of the second sequence) and d (the diagonal). Returns
     * 0.0 if neither x nor y is present, which is the case only for
     * the origin cell. Throws std::logic_error if x and y are both
     * present but d is not.
     */
    static cost_t choose(CostOption x, CostOption y, CostOption d) {
        if (x.present && y.present) {
            if (!d.present) {
                throw std::logic_error("if x & y both exist, so must diagonal");
            }
            return std::min(std::min(x.cost, y.cost), d.cost);
        } else if (x.present) {
            return x.cost;
        } else if (y.present) {
            return y.cost;
        } else {
            return 0.0;
        }
    }

    /**
     * Build the accumulated cost matrix, indexed as [index into s1]
     * [index into s2].
     *
     * In subsequence mode the first column holds the bare metric
     * value, so that a match may begin at any element of s1 without
     * penalty. In whole-sequence mode the origin cell is given zero
     * cost (see choose()); every whole-sequence path passes through
     * that cell, so this offsets all paths equally.
     */
    std::vector<std::vector<cost_t>> costSequences(const std::vector<Value> &s1,
                                                   const std::vector<Value> &s2,
                                                   bool subsequence) const {

        std::vector<std::vector<cost_t>> costs
            (s1.size(), std::vector<cost_t>(s2.size(), 0.0));

        for (size_t j = 0; j < s1.size(); ++j) {
            for (size_t i = 0; i < s2.size(); ++i) {
                cost_t c = m_metric(s1[j], s2[i]);
                if (i == 0 && subsequence) {
                    costs[j][i] = c;
                } else {
                    costs[j][i] = choose
                        (
                            { j > 0,
                              j > 0 ? c + costs[j-1][i] : 0.0 },
                            { i > 0,
                              i > 0 ? c + costs[j][i-1] : 0.0 },
                            { j > 0 && i > 0,
                              j > 0 && i > 0 ? c + costs[j-1][i-1] : 0.0 });
                }
            }
        }

        return costs;
    }

    /**
     * Compute the cost matrix for s1 and s2 and track back through
     * it, returning the index into s1 for each element in s2.
     */
    std::vector<size_t> align(const std::vector<Value> &s1,
                              const std::vector<Value> &s2,
                              bool subsequence) const {

        // Return the index into s1 for each element in s2

        std::vector<size_t> alignment(s2.size(), 0);

        if (s1.empty() || s2.empty()) {
            return alignment;
        }

        auto costs = costSequences(s1, s2, subsequence);

#ifdef DEBUG_DTW
        // NOTE(review): SVCERR and endl are not declared in this
        // header; presumably the including file provides them when
        // DEBUG_DTW is enabled - confirm
        SVCERR << "Cost matrix:" << endl;
        for (const auto &row: costs) {
            for (auto x: row) {
                SVCERR << x << " ";
            }
            SVCERR << "\n";
        }
#endif

        size_t j = s1.size() - 1;
        size_t i = s2.size() - 1;

        if (subsequence) {
            // The match may end anywhere in s1: start tracking back
            // from the first row that has the lowest accumulated cost
            // in the final column
            cost_t min = 0.0;
            size_t minidx = 0;
            for (size_t k = 0; k < s1.size(); ++k) {
                if (k == 0 || costs[k][i] < min) {
                    min = costs[k][i];
                    minidx = k;
                }
            }
            j = minidx;
#ifdef DEBUG_DTW
            SVCERR << "Lowest cost at end of subsequence = " << min
                   << " at index " << j << ", tracking back from there"
                   << endl;
#endif
        }

        while (i > 0 || j > 0) {

            // Later (lower-j) visits to the same i overwrite earlier
            // ones, so each element ends up with its lowest index
            alignment[i] = j;

            if (i == 0) {
                if (subsequence) {
                    // The match may start anywhere in s1
                    break;
                } else {
                    --j;
                    continue;
                }
            }

            if (j == 0) {
                --i;
                continue;
            }

            cost_t a = costs[j-1][i];
            cost_t b = costs[j][i-1];
            cost_t both = costs[j-1][i-1];

            // Step toward the cheaper neighbour, taking the diagonal
            // instead whenever it is no more expensive
            if (a < b) {
                --j;
                if (both <= a) {
                    --i;
                }
            } else {
                --i;
                if (both <= b) {
                    --j;
                }
            }
        }

        // The loop always finishes with i == 0, but it exits without
        // recording the final cell when it arrives at the origin.
        // Record it here for both modes, so that the first element
        // reports its lowest index like every other element does.
        alignment[0] = j;

        return alignment;
    }
};

/**
 * DTW over sequences of doubles, using the absolute difference
 * between two values as the distance between them.
 */
class MagnitudeDTW
{
public:
    MagnitudeDTW() : m_dtw(metric) { }

    /**
     * Align the sequence s2 against the whole of the sequence s1,
     * returning the index into s1 for each element in s2.
     */
    std::vector<size_t> alignSequences(const std::vector<double> &s1,
                                       const std::vector<double> &s2) {
        return m_dtw.alignSequences(s1, s2);
    }

    /**
     * Align the sequence sub against the best-matching subsequence of
     * s, returning the index into s for each element in sub.
     */
    std::vector<size_t> alignSubsequence(const std::vector<double> &s,
                                         const std::vector<double> &sub) {
        return m_dtw.alignSubsequence(s, sub);
    }

private:
    DTW<double> m_dtw;

    static double metric(const double &a, const double &b) {
        return std::abs(b - a);
    }
};

/**
 * DTW over sequences of values that each consist of a direction (up,
 * down, or none) and a distance.
 */
class RiseFallDTW
{
public:
    enum class Direction {
        None, Up, Down
    };

    struct Value {
        Direction direction;
        double distance;
    };

    RiseFallDTW() : m_dtw(metric) { }

    /**
     * Align the sequence s2 against the whole of the sequence s1,
     * returning the index into s1 for each element in s2.
     */
    std::vector<size_t> alignSequences(const std::vector<Value> &s1,
                                       const std::vector<Value> &s2) {
        return m_dtw.alignSequences(s1, s2);
    }

    /**
     * Align the sequence sub against the best-matching subsequence of
     * s, returning the index into s for each element in sub.
     */
    std::vector<size_t> alignSubsequence(const std::vector<Value> &s,
                                         const std::vector<Value> &sub) {
        return m_dtw.alignSubsequence(s, sub);
    }

private:
    DTW<Value> m_dtw;

    /**
     * Distance between two values. If either has no direction, the
     * result is 0.0 when both have none and 1.0 otherwise. Values
     * with the same direction score -1.0, 0.0 or 1.0 depending on
     * how far apart their distances are; values with differing
     * directions score 1.0 or 2.0 depending on the sum of their
     * distances. Note that the result can be negative.
     */
    static double metric(const Value &a, const Value &b) {

        auto together = [](double c1, double c2) {
            auto diff = std::abs(c1 - c2);
            return (diff < 1.0 ? -1.0 :
                    diff > 3.0 ? 1.0 :
                    0.0);
        };
        auto opposing = [](double c1, double c2) {
            auto diff = c1 + c2;
            return (diff < 2.0 ? 1.0 :
                    2.0);
        };

        if (a.direction == Direction::None || b.direction == Direction::None) {
            if (a.direction == b.direction) {
                return 0.0;
            } else {
                return 1.0;
            }
        } else {
            if (a.direction == b.direction) {
                return together (a.distance, b.distance);
            } else {
                return opposing (a.distance, b.distance);
            }
        }
    }
};

/**
 * Write a RiseFallDTW::Value to a stream as a direction symbol ("="
 * for none, "+" for up, "-" for down) followed by its distance.
 */
inline std::ostream &operator<<(std::ostream &s, const RiseFallDTW::Value &v) {
    return (s <<
            (v.direction == RiseFallDTW::Direction::None ? "=" :
             v.direction == RiseFallDTW::Direction::Up ? "+" : "-")
            << v.distance);
}

#endif