comparison src/Silvet.cpp @ 184:9b9cdfccbd14 noteagent

Wire up note agent code -- results are not very good, so far
author Chris Cannam
date Wed, 28 May 2014 14:54:01 +0100
parents 825193ef09d2
children 78212f764251
comparison
equal deleted inserted replaced
182:e1718e64a921 184:9b9cdfccbd14
17 #include "EM.h" 17 #include "EM.h"
18 18
19 #include <cq/CQSpectrogram.h> 19 #include <cq/CQSpectrogram.h>
20 20
21 #include "MedianFilter.h" 21 #include "MedianFilter.h"
22 #include "AgentFeederPoly.h"
23 #include "NoteHypothesis.h"
24
22 #include "constant-q-cpp/src/dsp/Resampler.h" 25 #include "constant-q-cpp/src/dsp/Resampler.h"
23 26
24 #include <vector> 27 #include <vector>
25 28
26 #include <cstdio> 29 #include <cstdio>
40 m_resampler(0), 43 m_resampler(0),
41 m_cq(0), 44 m_cq(0),
42 m_hqMode(true), 45 m_hqMode(true),
43 m_fineTuning(false), 46 m_fineTuning(false),
44 m_instrument(0), 47 m_instrument(0),
45 m_colsPerSec(50) 48 m_colsPerSec(50),
49 m_agentFeeder(0)
46 { 50 {
47 } 51 }
48 52
49 Silvet::~Silvet() 53 Silvet::~Silvet()
50 { 54 {
51 delete m_resampler; 55 delete m_resampler;
52 delete m_cq; 56 delete m_cq;
53 for (int i = 0; i < (int)m_postFilter.size(); ++i) { 57 for (int i = 0; i < (int)m_postFilter.size(); ++i) {
54 delete m_postFilter[i]; 58 delete m_postFilter[i];
55 } 59 }
60 delete m_agentFeeder;
56 } 61 }
57 62
58 string 63 string
59 Silvet::getIdentifier() const 64 Silvet::getIdentifier() const
60 { 65 {
351 void 356 void
352 Silvet::reset() 357 Silvet::reset()
353 { 358 {
354 delete m_resampler; 359 delete m_resampler;
355 delete m_cq; 360 delete m_cq;
361 delete m_agentFeeder;
356 362
357 if (m_inputSampleRate != processingSampleRate) { 363 if (m_inputSampleRate != processingSampleRate) {
358 m_resampler = new Resampler(m_inputSampleRate, processingSampleRate); 364 m_resampler = new Resampler(m_inputSampleRate, processingSampleRate);
359 } else { 365 } else {
360 m_resampler = 0; 366 m_resampler = 0;
391 } 397 }
392 m_postFilter.clear(); 398 m_postFilter.clear();
393 for (int i = 0; i < m_instruments[0].templateNoteCount; ++i) { 399 for (int i = 0; i < m_instruments[0].templateNoteCount; ++i) {
394 m_postFilter.push_back(new MedianFilter<double>(3)); 400 m_postFilter.push_back(new MedianFilter<double>(3));
395 } 401 }
396 m_pianoRoll.clear(); 402
397 m_columnCount = 0; 403 m_columnCountIn = 0;
404 m_columnCountOut = 0;
398 m_startTime = RealTime::zeroTime; 405 m_startTime = RealTime::zeroTime;
406
407 m_agentFeeder = new AgentFeederPoly<NoteHypothesis>();
399 } 408 }
400 409
401 Silvet::FeatureSet 410 Silvet::FeatureSet
402 Silvet::process(const float *const *inputBuffers, Vamp::RealTime timestamp) 411 Silvet::process(const float *const *inputBuffers, Vamp::RealTime timestamp)
403 { 412 {
404 if (m_columnCount == 0) { 413 if (m_columnCountIn == 0) {
405 m_startTime = timestamp; 414 m_startTime = timestamp;
406 } 415 }
407 416
408 vector<double> data; 417 vector<double> data;
409 for (int i = 0; i < m_blockSize; ++i) { 418 for (int i = 0; i < m_blockSize; ++i) {
421 430
422 Silvet::FeatureSet 431 Silvet::FeatureSet
423 Silvet::getRemainingFeatures() 432 Silvet::getRemainingFeatures()
424 { 433 {
425 Grid cqout = m_cq->getRemainingOutput(); 434 Grid cqout = m_cq->getRemainingOutput();
435
426 FeatureSet fs = transcribe(cqout); 436 FeatureSet fs = transcribe(cqout);
437
438 m_agentFeeder->finish();
439
440 FeatureList noteFeatures = obtainNotes();
441 for (FeatureList::const_iterator fi = noteFeatures.begin();
442 fi != noteFeatures.end(); ++fi) {
443 fs[m_notesOutputNo].push_back(*fi);
444 }
445
427 return fs; 446 return fs;
428 } 447 }
429 448
430 Silvet::FeatureSet 449 Silvet::FeatureSet
431 Silvet::transcribe(const Grid &cqout) 450 Silvet::transcribe(const Grid &cqout)
451 int iterations = m_hqMode ? 20 : 10; 470 int iterations = m_hqMode ? 20 : 10;
452 471
453 //!!! pitches or notes? [terminology] 472 //!!! pitches or notes? [terminology]
454 Grid localPitches(width, vector<double>(pack.templateNoteCount, 0.0)); 473 Grid localPitches(width, vector<double>(pack.templateNoteCount, 0.0));
455 474
456 bool wantShifts = m_hqMode && m_fineTuning; 475 bool wantShifts = m_hqMode;
457 int shiftCount = 1; 476 int shiftCount = 1;
458 if (wantShifts) { 477 if (wantShifts) {
459 shiftCount = pack.templateMaxShift * 2 + 1; 478 shiftCount = pack.templateMaxShift * 2 + 1;
460 } 479 }
461 480
511 if (!present[i]) { 530 if (!present[i]) {
512 // silent column 531 // silent column
513 for (int j = 0; j < pack.templateNoteCount; ++j) { 532 for (int j = 0; j < pack.templateNoteCount; ++j) {
514 m_postFilter[j]->push(0.0); 533 m_postFilter[j]->push(0.0);
515 } 534 }
516 m_pianoRoll.push_back(map<int, double>());
517 if (wantShifts) {
518 m_pianoRollShifts.push_back(map<int, int>());
519 }
520 continue; 535 continue;
521 } 536 }
522 537
523 postProcess(localPitches[i], localBestShifts[i], wantShifts); 538 postProcess(localPitches[i], localBestShifts[i],
539 wantShifts, shiftCount);
524 540
525 FeatureList noteFeatures = noteTrack(shiftCount); 541 FeatureList noteFeatures = obtainNotes();
526 542
527 for (FeatureList::const_iterator fi = noteFeatures.begin(); 543 for (FeatureList::const_iterator fi = noteFeatures.begin();
528 fi != noteFeatures.end(); ++fi) { 544 fi != noteFeatures.end(); ++fi) {
529 fs[m_notesOutputNo].push_back(*fi); 545 fs[m_notesOutputNo].push_back(*fi);
530 } 546 }
554 570
555 const InstrumentPack &pack = m_instruments[m_instrument]; 571 const InstrumentPack &pack = m_instruments[m_instrument];
556 572
557 for (int i = 0; i < width; ++i) { 573 for (int i = 0; i < width; ++i) {
558 574
559 if (m_columnCount < latentColumns) { 575 if (m_columnCountIn < latentColumns) {
560 ++m_columnCount; 576 ++m_columnCountIn;
561 continue; 577 continue;
562 } 578 }
563 579
564 int prevSampleNo = (m_columnCount - 1) * m_cq->getColumnHop(); 580 int prevSampleNo = (m_columnCountIn - 1) * m_cq->getColumnHop();
565 int sampleNo = m_columnCount * m_cq->getColumnHop(); 581 int sampleNo = m_columnCountIn * m_cq->getColumnHop();
566 582
567 bool select = (sampleNo / spacing != prevSampleNo / spacing); 583 bool select = (sampleNo / spacing != prevSampleNo / spacing);
568 584
569 if (select) { 585 if (select) {
570 vector<double> inCol = in[i]; 586 vector<double> inCol = in[i];
609 } 625 }
610 626
611 out.push_back(outCol); 627 out.push_back(outCol);
612 } 628 }
613 629
614 ++m_columnCount; 630 ++m_columnCountIn;
615 } 631 }
616 632
617 return out; 633 return out;
618 } 634 }
619 635
620 void 636 void
621 Silvet::postProcess(const vector<double> &pitches, 637 Silvet::postProcess(const vector<double> &pitches,
622 const vector<int> &bestShifts, 638 const vector<int> &bestShifts,
623 bool wantShifts) 639 bool wantShifts,
640 int shiftCount)
624 { 641 {
625 const InstrumentPack &pack = m_instruments[m_instrument]; 642 const InstrumentPack &pack = m_instruments[m_instrument];
626 643
627 vector<double> filtered; 644 vector<double> filtered;
628 645
629 for (int j = 0; j < pack.templateNoteCount; ++j) { 646 for (int j = 0; j < pack.templateNoteCount; ++j) {
630 m_postFilter[j]->push(pitches[j]); 647 m_postFilter[j]->push(pitches[j]);
631 filtered.push_back(m_postFilter[j]->get()); 648 filtered.push_back(m_postFilter[j]->get());
632 } 649 }
633 650
634 // Threshold for level and reduce number of candidate pitches 651 double threshold = 1;
635
636 int polyphony = 5;
637
638 //!!! make this a parameter (was 4.8, try adjusting, compare levels against matlab code)
639 double threshold = 6;
640 // double threshold = 4.8;
641
642 typedef std::multimap<double, int> ValueIndexMap;
643
644 ValueIndexMap strengths;
645
646 for (int j = 0; j < pack.templateNoteCount; ++j) {
647 double strength = filtered[j];
648 if (strength < threshold) continue;
649 strengths.insert(ValueIndexMap::value_type(strength, j));
650 }
651
652 ValueIndexMap::const_iterator si = strengths.end();
653
654 map<int, double> active;
655 map<int, int> activeShifts;
656
657 while (int(active.size()) < polyphony && si != strengths.begin()) {
658
659 --si;
660
661 double strength = si->first;
662 int j = si->second;
663
664 active[j] = strength;
665
666 if (wantShifts) {
667 activeShifts[j] = bestShifts[j];
668 }
669 }
670
671 m_pianoRoll.push_back(active);
672
673 if (wantShifts) {
674 m_pianoRollShifts.push_back(activeShifts);
675 }
676 }
677
678 Vamp::Plugin::FeatureList
679 Silvet::noteTrack(int shiftCount)
680 {
681 // Minimum duration pruning, and conversion to notes. We can only
682 // report notes that have just ended (i.e. that are absent in the
683 // latest active set but present in the prior set in the piano
684 // roll) -- any notes that ended earlier will have been reported
685 // already, and if they haven't ended, we don't know their
686 // duration.
687
688 int width = m_pianoRoll.size() - 1;
689
690 const map<int, double> &active = m_pianoRoll[width];
691
692 double columnDuration = 1.0 / m_colsPerSec;
693
694 // only keep notes >= 100ms or thereabouts
695 int durationThreshold = floor(0.1 / columnDuration); // columns
696 if (durationThreshold < 1) durationThreshold = 1;
697
698 FeatureList noteFeatures;
699
700 if (width < durationThreshold + 1) {
701 return noteFeatures;
702 }
703
704 //!!! try: repeated note detection? (look for change in first derivative of the pitch matrix)
705
706 for (map<int, double>::const_iterator ni = m_pianoRoll[width-1].begin();
707 ni != m_pianoRoll[width-1].end(); ++ni) {
708
709 int note = ni->first;
710
711 if (active.find(note) != active.end()) {
712 // the note is still playing
713 continue;
714 }
715
716 // the note was playing but just ended
717 int end = width;
718 int start = end-1;
719
720 while (m_pianoRoll[start].find(note) != m_pianoRoll[start].end()) {
721 --start;
722 }
723 ++start;
724
725 if ((end - start) < durationThreshold) {
726 continue;
727 }
728
729 emitNote(start, end, note, shiftCount, noteFeatures);
730 }
731
732 // cerr << "returning " << noteFeatures.size() << " complete note(s) " << endl;
733
734 return noteFeatures;
735 }
736
737 void
738 Silvet::emitNote(int start, int end, int note, int shiftCount,
739 FeatureList &noteFeatures)
740 {
741 int partStart = start;
742 int partShift = 0;
743 int partVelocity = 0;
744
745 Feature f;
746 f.hasTimestamp = true;
747 f.hasDuration = true;
748 652
749 double columnDuration = 1.0 / m_colsPerSec; 653 double columnDuration = 1.0 / m_colsPerSec;
750 int postFilterLatency = int(m_postFilter[0]->getSize() / 2); 654 int postFilterLatency = int(m_postFilter[0]->getSize() / 2);
751 int partThreshold = floor(0.05 / columnDuration); 655 RealTime t = RealTime::fromSeconds
752 656 (columnDuration * (m_columnCountOut - postFilterLatency) + 0.02);
753 for (int i = start; i != end; ++i) { 657
754 658 for (int j = 0; j < pack.templateNoteCount; ++j) {
755 double strength = m_pianoRoll[i][note]; 659
756 660 double strength = filtered[j];
757 int shift = 0; 661 if (strength < threshold) {
758 662 continue;
759 if (shiftCount > 1) { 663 }
760 664
761 shift = m_pianoRollShifts[i][note]; 665 double freq;
762 666 if (wantShifts) {
763 if (i == partStart) { 667 freq = noteFrequency(j, bestShifts[j], shiftCount);
764 partShift = shift; 668 } else {
765 } 669 freq = noteFrequency(j, 0, shiftCount);
766 670 }
767 if (i > partStart + partThreshold && shift != partShift) { 671
768 672 double confidence = strength / 50.0; //!!!???
769 // cerr << "i = " << i << ", partStart = " << partStart << ", shift = " << shift << ", partShift = " << partShift << endl; 673 if (confidence > 1.0) confidence = 1.0;
770 674
771 // pitch has changed, emit an intermediate note 675 AgentHypothesis::Observation obs(freq, t, confidence);
772 f.timestamp = RealTime::fromSeconds 676 m_agentFeeder->feed(obs);
773 (columnDuration * (partStart - postFilterLatency) + 0.02); 677 }
774 f.duration = RealTime::fromSeconds 678
775 (columnDuration * (i - partStart)); 679 m_columnCountOut ++;
776 f.values.clear(); 680 }
777 f.values.push_back 681
778 (noteFrequency(note, partShift, shiftCount)); 682 Vamp::Plugin::FeatureList
779 f.values.push_back(partVelocity); 683 Silvet::obtainNotes()
780 f.label = noteName(note, partShift, shiftCount); 684 {
781 noteFeatures.push_back(f); 685 FeatureList noteFeatures;
782 partStart = i; 686
783 partShift = shift; 687 typedef AgentFeederPoly<NoteHypothesis> NoteFeeder;
784 partVelocity = 0; 688
785 } 689 NoteFeeder *feeder = dynamic_cast<NoteFeeder *>(m_agentFeeder);
786 } 690
787 691 if (!feeder) {
788 int v = strength * 2; 692 cerr << "INTERNAL ERROR: Feeder is not a poly-note-hypothesis-feeder!"
789 if (v > 127) v = 127; 693 << endl;
790 694 return noteFeatures;
791 if (v > partVelocity) { 695 }
792 partVelocity = v; 696
793 } 697 std::set<NoteHypothesis> hh = feeder->getAcceptedHypotheses();
794 } 698
795 699 //!!! inefficient
796 if (end >= partStart + partThreshold) { 700 for (std::set<NoteHypothesis>::const_iterator hi = hh.begin();
797 f.timestamp = RealTime::fromSeconds 701 hi != hh.end(); ++hi) {
798 (columnDuration * (partStart - postFilterLatency) + 0.02); 702
799 f.duration = RealTime::fromSeconds 703 NoteHypothesis h(*hi);
800 (columnDuration * (end - partStart)); 704
705 if (m_emitted.find(h) != m_emitted.end()) {
706 continue; // already returned this one
707 }
708
709 m_emitted.insert(h);
710
711 NoteHypothesis::Note n = h.getAveragedNote();
712
713 int velocity = n.confidence * 127;
714 if (velocity > 127) velocity = 127;
715
716 Feature f;
717 f.hasTimestamp = true;
718 f.hasDuration = true;
719 f.timestamp = n.time;
720 f.duration = n.duration;
801 f.values.clear(); 721 f.values.clear();
802 f.values.push_back 722 f.values.push_back(n.freq);
803 (noteFrequency(note, partShift, shiftCount)); 723 f.values.push_back(velocity);
804 f.values.push_back(partVelocity); 724 // f.label = noteName(note, partShift, shiftCount);
805 f.label = noteName(note, partShift, shiftCount);
806 noteFeatures.push_back(f); 725 noteFeatures.push_back(f);
807 } 726 }
808 } 727
728 return noteFeatures;
729 }