diff src/Silvet.cpp @ 184:9b9cdfccbd14 noteagent

Wire up note agent code -- results are not very good, so far
author Chris Cannam
date Wed, 28 May 2014 14:54:01 +0100
parents 825193ef09d2
children 78212f764251
line wrap: on
line diff
--- a/src/Silvet.cpp	Fri May 23 17:03:27 2014 +0100
+++ b/src/Silvet.cpp	Wed May 28 14:54:01 2014 +0100
@@ -19,6 +19,9 @@
 #include <cq/CQSpectrogram.h>
 
 #include "MedianFilter.h"
+#include "AgentFeederPoly.h"
+#include "NoteHypothesis.h"
+
 #include "constant-q-cpp/src/dsp/Resampler.h"
 
 #include <vector>
@@ -42,7 +45,8 @@
     m_hqMode(true),
     m_fineTuning(false),
     m_instrument(0),
-    m_colsPerSec(50)
+    m_colsPerSec(50),
+    m_agentFeeder(0)
 {
 }
 
@@ -53,6 +57,7 @@
     for (int i = 0; i < (int)m_postFilter.size(); ++i) {
         delete m_postFilter[i];
     }
+    delete m_agentFeeder;
 }
 
 string
@@ -353,6 +358,7 @@
 {
     delete m_resampler;
     delete m_cq;
+    delete m_agentFeeder;
 
     if (m_inputSampleRate != processingSampleRate) {
 	m_resampler = new Resampler(m_inputSampleRate, processingSampleRate);
@@ -393,15 +399,18 @@
     for (int i = 0; i < m_instruments[0].templateNoteCount; ++i) {
         m_postFilter.push_back(new MedianFilter<double>(3));
     }
-    m_pianoRoll.clear();
-    m_columnCount = 0;
+
+    m_columnCountIn = 0;
+    m_columnCountOut = 0;
     m_startTime = RealTime::zeroTime;
+
+    m_agentFeeder = new AgentFeederPoly<NoteHypothesis>();
 }
 
 Silvet::FeatureSet
 Silvet::process(const float *const *inputBuffers, Vamp::RealTime timestamp)
 {
-    if (m_columnCount == 0) {
+    if (m_columnCountIn == 0) {
         m_startTime = timestamp;
     }
     
@@ -423,7 +432,17 @@
 Silvet::getRemainingFeatures()
 {
     Grid cqout = m_cq->getRemainingOutput();
+
     FeatureSet fs = transcribe(cqout);
+
+    m_agentFeeder->finish();
+
+    FeatureList noteFeatures = obtainNotes();
+    for (FeatureList::const_iterator fi = noteFeatures.begin();
+         fi != noteFeatures.end(); ++fi) {
+        fs[m_notesOutputNo].push_back(*fi);
+    }
+
     return fs;
 }
 
@@ -453,7 +472,7 @@
     //!!! pitches or notes? [terminology]
     Grid localPitches(width, vector<double>(pack.templateNoteCount, 0.0));
 
-    bool wantShifts = m_hqMode && m_fineTuning;
+    bool wantShifts = m_hqMode;
     int shiftCount = 1;
     if (wantShifts) {
         shiftCount = pack.templateMaxShift * 2 + 1;
@@ -513,16 +532,13 @@
             for (int j = 0; j < pack.templateNoteCount; ++j) {
                 m_postFilter[j]->push(0.0);
             }
-            m_pianoRoll.push_back(map<int, double>());
-            if (wantShifts) {
-                m_pianoRollShifts.push_back(map<int, int>());
-            }
             continue;
         }
 
-        postProcess(localPitches[i], localBestShifts[i], wantShifts);
+        postProcess(localPitches[i], localBestShifts[i], 
+                    wantShifts, shiftCount);
         
-        FeatureList noteFeatures = noteTrack(shiftCount);
+        FeatureList noteFeatures = obtainNotes();
 
         for (FeatureList::const_iterator fi = noteFeatures.begin();
              fi != noteFeatures.end(); ++fi) {
@@ -556,13 +572,13 @@
 
     for (int i = 0; i < width; ++i) {
 
-        if (m_columnCount < latentColumns) {
-            ++m_columnCount;
+        if (m_columnCountIn < latentColumns) {
+            ++m_columnCountIn;
             continue;
         }
 
-        int prevSampleNo = (m_columnCount - 1) * m_cq->getColumnHop();
-        int sampleNo = m_columnCount * m_cq->getColumnHop();
+        int prevSampleNo = (m_columnCountIn - 1) * m_cq->getColumnHop();
+        int sampleNo = m_columnCountIn * m_cq->getColumnHop();
 
         bool select = (sampleNo / spacing != prevSampleNo / spacing);
 
@@ -611,7 +627,7 @@
             out.push_back(outCol);
         }
 
-        ++m_columnCount;
+        ++m_columnCountIn;
     }
 
     return out;
@@ -620,7 +636,8 @@
 void
 Silvet::postProcess(const vector<double> &pitches,
                     const vector<int> &bestShifts,
-                    bool wantShifts)
+                    bool wantShifts,
+                    int shiftCount)
 {
     const InstrumentPack &pack = m_instruments[m_instrument];
 
@@ -631,178 +648,82 @@
         filtered.push_back(m_postFilter[j]->get());
     }
 
-    // Threshold for level and reduce number of candidate pitches
+    double threshold = 1;
 
-    int polyphony = 5;
-
-    //!!! make this a parameter (was 4.8, try adjusting, compare levels against matlab code)
-    double threshold = 6;
-//    double threshold = 4.8;
-
-    typedef std::multimap<double, int> ValueIndexMap;
-
-    ValueIndexMap strengths;
+    double columnDuration = 1.0 / m_colsPerSec;
+    int postFilterLatency = int(m_postFilter[0]->getSize() / 2);
+    RealTime t = RealTime::fromSeconds
+        (columnDuration * (m_columnCountOut - postFilterLatency) + 0.02);
 
     for (int j = 0; j < pack.templateNoteCount; ++j) {
+
         double strength = filtered[j];
-        if (strength < threshold) continue;
-        strengths.insert(ValueIndexMap::value_type(strength, j));
+        if (strength < threshold) {
+            continue;
+        }
+
+        double freq;
+        if (wantShifts) {
+            freq = noteFrequency(j, bestShifts[j], shiftCount);
+        } else {
+            freq = noteFrequency(j, 0, shiftCount);
+        }
+
+        double confidence = strength / 50.0; //!!! 50.0 is an unexplained scale factor; the result is capped at 1.0 on the next line
+        if (confidence > 1.0) confidence = 1.0;
+
+        AgentHypothesis::Observation obs(freq, t, confidence);
+        m_agentFeeder->feed(obs);
     }
 
-    ValueIndexMap::const_iterator si = strengths.end();
-
-    map<int, double> active;
-    map<int, int> activeShifts;
-
-    while (int(active.size()) < polyphony && si != strengths.begin()) {
-
-        --si;
-
-        double strength = si->first;
-        int j = si->second;
-
-        active[j] = strength;
-
-        if (wantShifts) {
-            activeShifts[j] = bestShifts[j];
-        }
-    }
-
-    m_pianoRoll.push_back(active);
-
-    if (wantShifts) {
-        m_pianoRollShifts.push_back(activeShifts);
-    }
+    m_columnCountOut ++;
 }
 
 Vamp::Plugin::FeatureList
-Silvet::noteTrack(int shiftCount)
+Silvet::obtainNotes() // Returns one note feature per accepted hypothesis that is not yet recorded in m_emitted.
 {        
-    // Minimum duration pruning, and conversion to notes. We can only
-    // report notes that have just ended (i.e. that are absent in the
-    // latest active set but present in the prior set in the piano
-    // roll) -- any notes that ended earlier will have been reported
-    // already, and if they haven't ended, we don't know their
-    // duration.
-
-    int width = m_pianoRoll.size() - 1;
-
-    const map<int, double> &active = m_pianoRoll[width];
-
-    double columnDuration = 1.0 / m_colsPerSec;
-
-    // only keep notes >= 100ms or thereabouts
-    int durationThreshold = floor(0.1 / columnDuration); // columns
-    if (durationThreshold < 1) durationThreshold = 1;
-
     FeatureList noteFeatures;
 
-    if (width < durationThreshold + 1) {
+    typedef AgentFeederPoly<NoteHypothesis> NoteFeeder;
+    // NOTE(review): the cast below implies m_agentFeeder is declared with a more general pointer type (declaration not shown here) -- confirm.
+    NoteFeeder *feeder = dynamic_cast<NoteFeeder *>(m_agentFeeder);
+
+    if (!feeder) {
+        cerr << "INTERNAL ERROR: Feeder is not a poly-note-hypothesis-feeder!"
+             << endl;
         return noteFeatures;
     }
-    
-    //!!! try: repeated note detection? (look for change in first derivative of the pitch matrix)
 
-    for (map<int, double>::const_iterator ni = m_pianoRoll[width-1].begin();
-         ni != m_pianoRoll[width-1].end(); ++ni) {
+    std::set<NoteHypothesis> hh = feeder->getAcceptedHypotheses();
 
-        int note = ni->first;
-        
-        if (active.find(note) != active.end()) {
-            // the note is still playing
-            continue;
+    //!!! inefficient: takes the full accepted set into a local and copies every hypothesis (h), on each call
+    for (std::set<NoteHypothesis>::const_iterator hi = hh.begin();
+         hi != hh.end(); ++hi) { 
+
+        NoteHypothesis h(*hi);
+
+        if (m_emitted.find(h) != m_emitted.end()) {
+            continue; // already returned this one
         }
 
-        // the note was playing but just ended
-        int end = width;
-        int start = end-1;
+        m_emitted.insert(h);
 
-        while (m_pianoRoll[start].find(note) != m_pianoRoll[start].end()) {
-            --start;
-        }
-        ++start;
+        NoteHypothesis::Note n = h.getAveragedNote();
 
-        if ((end - start) < durationThreshold) {
-            continue;
-        }
+        int velocity = n.confidence * 127;
+        if (velocity > 127) velocity = 127; // capped at 127; there is no lower-bound check
 
-        emitNote(start, end, note, shiftCount, noteFeatures);
+        Feature f;
+        f.hasTimestamp = true;
+        f.hasDuration = true;
+        f.timestamp = n.time;
+        f.duration = n.duration;
+        f.values.clear();
+        f.values.push_back(n.freq); // values[0]: frequency of the averaged note
+        f.values.push_back(velocity); // values[1]: velocity computed above
+//        f.label = noteName(note, partShift, shiftCount); // disabled: note, partShift and shiftCount are not defined in this function
+        noteFeatures.push_back(f);
     }
 
-//    cerr << "returning " << noteFeatures.size() << " complete note(s) " << endl;
-
     return noteFeatures;
 }
-
-void
-Silvet::emitNote(int start, int end, int note, int shiftCount,
-                 FeatureList &noteFeatures)
-{
-    int partStart = start;
-    int partShift = 0;
-    int partVelocity = 0;
-
-    Feature f;
-    f.hasTimestamp = true;
-    f.hasDuration = true;
-
-    double columnDuration = 1.0 / m_colsPerSec;
-    int postFilterLatency = int(m_postFilter[0]->getSize() / 2);
-    int partThreshold = floor(0.05 / columnDuration);
-
-    for (int i = start; i != end; ++i) {
-        
-        double strength = m_pianoRoll[i][note];
-
-        int shift = 0;
-
-        if (shiftCount > 1) {
-
-            shift = m_pianoRollShifts[i][note];
-
-            if (i == partStart) {
-                partShift = shift;
-            }
-
-            if (i > partStart + partThreshold && shift != partShift) {
-                
-//                cerr << "i = " << i << ", partStart = " << partStart << ", shift = " << shift << ", partShift = " << partShift << endl;
-
-                // pitch has changed, emit an intermediate note
-                f.timestamp = RealTime::fromSeconds
-                    (columnDuration * (partStart - postFilterLatency) + 0.02);
-                f.duration = RealTime::fromSeconds
-                    (columnDuration * (i - partStart));
-                f.values.clear();
-                f.values.push_back
-                    (noteFrequency(note, partShift, shiftCount));
-                f.values.push_back(partVelocity);
-                f.label = noteName(note, partShift, shiftCount);
-                noteFeatures.push_back(f);
-                partStart = i;
-                partShift = shift;
-                partVelocity = 0;
-            }
-        }
-
-        int v = strength * 2;
-        if (v > 127) v = 127;
-
-        if (v > partVelocity) {
-            partVelocity = v;
-        }
-    }
-
-    if (end >= partStart + partThreshold) {
-        f.timestamp = RealTime::fromSeconds
-            (columnDuration * (partStart - postFilterLatency) + 0.02);
-        f.duration = RealTime::fromSeconds
-            (columnDuration * (end - partStart));
-        f.values.clear();
-        f.values.push_back
-            (noteFrequency(note, partShift, shiftCount));
-        f.values.push_back(partVelocity);
-        f.label = noteName(note, partShift, shiftCount);
-        noteFeatures.push_back(f);
-    }
-}