Mercurial > hg > silvet
comparison src/Silvet.cpp @ 184:9b9cdfccbd14 noteagent
Wire up note agent code -- results are not very good, so far
author | Chris Cannam |
---|---|
date | Wed, 28 May 2014 14:54:01 +0100 |
parents | 825193ef09d2 |
children | 78212f764251 |
comparison
equal
deleted
inserted
replaced
182:e1718e64a921 | 184:9b9cdfccbd14 |
---|---|
17 #include "EM.h" | 17 #include "EM.h" |
18 | 18 |
19 #include <cq/CQSpectrogram.h> | 19 #include <cq/CQSpectrogram.h> |
20 | 20 |
21 #include "MedianFilter.h" | 21 #include "MedianFilter.h" |
22 #include "AgentFeederPoly.h" | |
23 #include "NoteHypothesis.h" | |
24 | |
22 #include "constant-q-cpp/src/dsp/Resampler.h" | 25 #include "constant-q-cpp/src/dsp/Resampler.h" |
23 | 26 |
24 #include <vector> | 27 #include <vector> |
25 | 28 |
26 #include <cstdio> | 29 #include <cstdio> |
40 m_resampler(0), | 43 m_resampler(0), |
41 m_cq(0), | 44 m_cq(0), |
42 m_hqMode(true), | 45 m_hqMode(true), |
43 m_fineTuning(false), | 46 m_fineTuning(false), |
44 m_instrument(0), | 47 m_instrument(0), |
45 m_colsPerSec(50) | 48 m_colsPerSec(50), |
49 m_agentFeeder(0) | |
46 { | 50 { |
47 } | 51 } |
48 | 52 |
49 Silvet::~Silvet() | 53 Silvet::~Silvet() |
50 { | 54 { |
51 delete m_resampler; | 55 delete m_resampler; |
52 delete m_cq; | 56 delete m_cq; |
53 for (int i = 0; i < (int)m_postFilter.size(); ++i) { | 57 for (int i = 0; i < (int)m_postFilter.size(); ++i) { |
54 delete m_postFilter[i]; | 58 delete m_postFilter[i]; |
55 } | 59 } |
60 delete m_agentFeeder; | |
56 } | 61 } |
57 | 62 |
58 string | 63 string |
59 Silvet::getIdentifier() const | 64 Silvet::getIdentifier() const |
60 { | 65 { |
351 void | 356 void |
352 Silvet::reset() | 357 Silvet::reset() |
353 { | 358 { |
354 delete m_resampler; | 359 delete m_resampler; |
355 delete m_cq; | 360 delete m_cq; |
361 delete m_agentFeeder; | |
356 | 362 |
357 if (m_inputSampleRate != processingSampleRate) { | 363 if (m_inputSampleRate != processingSampleRate) { |
358 m_resampler = new Resampler(m_inputSampleRate, processingSampleRate); | 364 m_resampler = new Resampler(m_inputSampleRate, processingSampleRate); |
359 } else { | 365 } else { |
360 m_resampler = 0; | 366 m_resampler = 0; |
391 } | 397 } |
392 m_postFilter.clear(); | 398 m_postFilter.clear(); |
393 for (int i = 0; i < m_instruments[0].templateNoteCount; ++i) { | 399 for (int i = 0; i < m_instruments[0].templateNoteCount; ++i) { |
394 m_postFilter.push_back(new MedianFilter<double>(3)); | 400 m_postFilter.push_back(new MedianFilter<double>(3)); |
395 } | 401 } |
396 m_pianoRoll.clear(); | 402 |
397 m_columnCount = 0; | 403 m_columnCountIn = 0; |
404 m_columnCountOut = 0; | |
398 m_startTime = RealTime::zeroTime; | 405 m_startTime = RealTime::zeroTime; |
406 | |
407 m_agentFeeder = new AgentFeederPoly<NoteHypothesis>(); | |
399 } | 408 } |
400 | 409 |
401 Silvet::FeatureSet | 410 Silvet::FeatureSet |
402 Silvet::process(const float *const *inputBuffers, Vamp::RealTime timestamp) | 411 Silvet::process(const float *const *inputBuffers, Vamp::RealTime timestamp) |
403 { | 412 { |
404 if (m_columnCount == 0) { | 413 if (m_columnCountIn == 0) { |
405 m_startTime = timestamp; | 414 m_startTime = timestamp; |
406 } | 415 } |
407 | 416 |
408 vector<double> data; | 417 vector<double> data; |
409 for (int i = 0; i < m_blockSize; ++i) { | 418 for (int i = 0; i < m_blockSize; ++i) { |
421 | 430 |
422 Silvet::FeatureSet | 431 Silvet::FeatureSet |
423 Silvet::getRemainingFeatures() | 432 Silvet::getRemainingFeatures() |
424 { | 433 { |
425 Grid cqout = m_cq->getRemainingOutput(); | 434 Grid cqout = m_cq->getRemainingOutput(); |
435 | |
426 FeatureSet fs = transcribe(cqout); | 436 FeatureSet fs = transcribe(cqout); |
437 | |
438 m_agentFeeder->finish(); | |
439 | |
440 FeatureList noteFeatures = obtainNotes(); | |
441 for (FeatureList::const_iterator fi = noteFeatures.begin(); | |
442 fi != noteFeatures.end(); ++fi) { | |
443 fs[m_notesOutputNo].push_back(*fi); | |
444 } | |
445 | |
427 return fs; | 446 return fs; |
428 } | 447 } |
429 | 448 |
430 Silvet::FeatureSet | 449 Silvet::FeatureSet |
431 Silvet::transcribe(const Grid &cqout) | 450 Silvet::transcribe(const Grid &cqout) |
451 int iterations = m_hqMode ? 20 : 10; | 470 int iterations = m_hqMode ? 20 : 10; |
452 | 471 |
453 //!!! pitches or notes? [terminology] | 472 //!!! pitches or notes? [terminology] |
454 Grid localPitches(width, vector<double>(pack.templateNoteCount, 0.0)); | 473 Grid localPitches(width, vector<double>(pack.templateNoteCount, 0.0)); |
455 | 474 |
456 bool wantShifts = m_hqMode && m_fineTuning; | 475 bool wantShifts = m_hqMode; |
457 int shiftCount = 1; | 476 int shiftCount = 1; |
458 if (wantShifts) { | 477 if (wantShifts) { |
459 shiftCount = pack.templateMaxShift * 2 + 1; | 478 shiftCount = pack.templateMaxShift * 2 + 1; |
460 } | 479 } |
461 | 480 |
511 if (!present[i]) { | 530 if (!present[i]) { |
512 // silent column | 531 // silent column |
513 for (int j = 0; j < pack.templateNoteCount; ++j) { | 532 for (int j = 0; j < pack.templateNoteCount; ++j) { |
514 m_postFilter[j]->push(0.0); | 533 m_postFilter[j]->push(0.0); |
515 } | 534 } |
516 m_pianoRoll.push_back(map<int, double>()); | |
517 if (wantShifts) { | |
518 m_pianoRollShifts.push_back(map<int, int>()); | |
519 } | |
520 continue; | 535 continue; |
521 } | 536 } |
522 | 537 |
523 postProcess(localPitches[i], localBestShifts[i], wantShifts); | 538 postProcess(localPitches[i], localBestShifts[i], |
539 wantShifts, shiftCount); | |
524 | 540 |
525 FeatureList noteFeatures = noteTrack(shiftCount); | 541 FeatureList noteFeatures = obtainNotes(); |
526 | 542 |
527 for (FeatureList::const_iterator fi = noteFeatures.begin(); | 543 for (FeatureList::const_iterator fi = noteFeatures.begin(); |
528 fi != noteFeatures.end(); ++fi) { | 544 fi != noteFeatures.end(); ++fi) { |
529 fs[m_notesOutputNo].push_back(*fi); | 545 fs[m_notesOutputNo].push_back(*fi); |
530 } | 546 } |
554 | 570 |
555 const InstrumentPack &pack = m_instruments[m_instrument]; | 571 const InstrumentPack &pack = m_instruments[m_instrument]; |
556 | 572 |
557 for (int i = 0; i < width; ++i) { | 573 for (int i = 0; i < width; ++i) { |
558 | 574 |
559 if (m_columnCount < latentColumns) { | 575 if (m_columnCountIn < latentColumns) { |
560 ++m_columnCount; | 576 ++m_columnCountIn; |
561 continue; | 577 continue; |
562 } | 578 } |
563 | 579 |
564 int prevSampleNo = (m_columnCount - 1) * m_cq->getColumnHop(); | 580 int prevSampleNo = (m_columnCountIn - 1) * m_cq->getColumnHop(); |
565 int sampleNo = m_columnCount * m_cq->getColumnHop(); | 581 int sampleNo = m_columnCountIn * m_cq->getColumnHop(); |
566 | 582 |
567 bool select = (sampleNo / spacing != prevSampleNo / spacing); | 583 bool select = (sampleNo / spacing != prevSampleNo / spacing); |
568 | 584 |
569 if (select) { | 585 if (select) { |
570 vector<double> inCol = in[i]; | 586 vector<double> inCol = in[i]; |
609 } | 625 } |
610 | 626 |
611 out.push_back(outCol); | 627 out.push_back(outCol); |
612 } | 628 } |
613 | 629 |
614 ++m_columnCount; | 630 ++m_columnCountIn; |
615 } | 631 } |
616 | 632 |
617 return out; | 633 return out; |
618 } | 634 } |
619 | 635 |
620 void | 636 void |
621 Silvet::postProcess(const vector<double> &pitches, | 637 Silvet::postProcess(const vector<double> &pitches, |
622 const vector<int> &bestShifts, | 638 const vector<int> &bestShifts, |
623 bool wantShifts) | 639 bool wantShifts, |
640 int shiftCount) | |
624 { | 641 { |
625 const InstrumentPack &pack = m_instruments[m_instrument]; | 642 const InstrumentPack &pack = m_instruments[m_instrument]; |
626 | 643 |
627 vector<double> filtered; | 644 vector<double> filtered; |
628 | 645 |
629 for (int j = 0; j < pack.templateNoteCount; ++j) { | 646 for (int j = 0; j < pack.templateNoteCount; ++j) { |
630 m_postFilter[j]->push(pitches[j]); | 647 m_postFilter[j]->push(pitches[j]); |
631 filtered.push_back(m_postFilter[j]->get()); | 648 filtered.push_back(m_postFilter[j]->get()); |
632 } | 649 } |
633 | 650 |
634 // Threshold for level and reduce number of candidate pitches | 651 double threshold = 1; |
635 | |
636 int polyphony = 5; | |
637 | |
638 //!!! make this a parameter (was 4.8, try adjusting, compare levels against matlab code) | |
639 double threshold = 6; | |
640 // double threshold = 4.8; | |
641 | |
642 typedef std::multimap<double, int> ValueIndexMap; | |
643 | |
644 ValueIndexMap strengths; | |
645 | |
646 for (int j = 0; j < pack.templateNoteCount; ++j) { | |
647 double strength = filtered[j]; | |
648 if (strength < threshold) continue; | |
649 strengths.insert(ValueIndexMap::value_type(strength, j)); | |
650 } | |
651 | |
652 ValueIndexMap::const_iterator si = strengths.end(); | |
653 | |
654 map<int, double> active; | |
655 map<int, int> activeShifts; | |
656 | |
657 while (int(active.size()) < polyphony && si != strengths.begin()) { | |
658 | |
659 --si; | |
660 | |
661 double strength = si->first; | |
662 int j = si->second; | |
663 | |
664 active[j] = strength; | |
665 | |
666 if (wantShifts) { | |
667 activeShifts[j] = bestShifts[j]; | |
668 } | |
669 } | |
670 | |
671 m_pianoRoll.push_back(active); | |
672 | |
673 if (wantShifts) { | |
674 m_pianoRollShifts.push_back(activeShifts); | |
675 } | |
676 } | |
677 | |
678 Vamp::Plugin::FeatureList | |
679 Silvet::noteTrack(int shiftCount) | |
680 { | |
681 // Minimum duration pruning, and conversion to notes. We can only | |
682 // report notes that have just ended (i.e. that are absent in the | |
683 // latest active set but present in the prior set in the piano | |
684 // roll) -- any notes that ended earlier will have been reported | |
685 // already, and if they haven't ended, we don't know their | |
686 // duration. | |
687 | |
688 int width = m_pianoRoll.size() - 1; | |
689 | |
690 const map<int, double> &active = m_pianoRoll[width]; | |
691 | |
692 double columnDuration = 1.0 / m_colsPerSec; | |
693 | |
694 // only keep notes >= 100ms or thereabouts | |
695 int durationThreshold = floor(0.1 / columnDuration); // columns | |
696 if (durationThreshold < 1) durationThreshold = 1; | |
697 | |
698 FeatureList noteFeatures; | |
699 | |
700 if (width < durationThreshold + 1) { | |
701 return noteFeatures; | |
702 } | |
703 | |
704 //!!! try: repeated note detection? (look for change in first derivative of the pitch matrix) | |
705 | |
706 for (map<int, double>::const_iterator ni = m_pianoRoll[width-1].begin(); | |
707 ni != m_pianoRoll[width-1].end(); ++ni) { | |
708 | |
709 int note = ni->first; | |
710 | |
711 if (active.find(note) != active.end()) { | |
712 // the note is still playing | |
713 continue; | |
714 } | |
715 | |
716 // the note was playing but just ended | |
717 int end = width; | |
718 int start = end-1; | |
719 | |
720 while (m_pianoRoll[start].find(note) != m_pianoRoll[start].end()) { | |
721 --start; | |
722 } | |
723 ++start; | |
724 | |
725 if ((end - start) < durationThreshold) { | |
726 continue; | |
727 } | |
728 | |
729 emitNote(start, end, note, shiftCount, noteFeatures); | |
730 } | |
731 | |
732 // cerr << "returning " << noteFeatures.size() << " complete note(s) " << endl; | |
733 | |
734 return noteFeatures; | |
735 } | |
736 | |
737 void | |
738 Silvet::emitNote(int start, int end, int note, int shiftCount, | |
739 FeatureList &noteFeatures) | |
740 { | |
741 int partStart = start; | |
742 int partShift = 0; | |
743 int partVelocity = 0; | |
744 | |
745 Feature f; | |
746 f.hasTimestamp = true; | |
747 f.hasDuration = true; | |
748 | 652 |
749 double columnDuration = 1.0 / m_colsPerSec; | 653 double columnDuration = 1.0 / m_colsPerSec; |
750 int postFilterLatency = int(m_postFilter[0]->getSize() / 2); | 654 int postFilterLatency = int(m_postFilter[0]->getSize() / 2); |
751 int partThreshold = floor(0.05 / columnDuration); | 655 RealTime t = RealTime::fromSeconds |
752 | 656 (columnDuration * (m_columnCountOut - postFilterLatency) + 0.02); |
753 for (int i = start; i != end; ++i) { | 657 |
754 | 658 for (int j = 0; j < pack.templateNoteCount; ++j) { |
755 double strength = m_pianoRoll[i][note]; | 659 |
756 | 660 double strength = filtered[j]; |
757 int shift = 0; | 661 if (strength < threshold) { |
758 | 662 continue; |
759 if (shiftCount > 1) { | 663 } |
760 | 664 |
761 shift = m_pianoRollShifts[i][note]; | 665 double freq; |
762 | 666 if (wantShifts) { |
763 if (i == partStart) { | 667 freq = noteFrequency(j, bestShifts[j], shiftCount); |
764 partShift = shift; | 668 } else { |
765 } | 669 freq = noteFrequency(j, 0, shiftCount); |
766 | 670 } |
767 if (i > partStart + partThreshold && shift != partShift) { | 671 |
768 | 672 double confidence = strength / 50.0; //!!!??? |
769 // cerr << "i = " << i << ", partStart = " << partStart << ", shift = " << shift << ", partShift = " << partShift << endl; | 673 if (confidence > 1.0) confidence = 1.0; |
770 | 674 |
771 // pitch has changed, emit an intermediate note | 675 AgentHypothesis::Observation obs(freq, t, confidence); |
772 f.timestamp = RealTime::fromSeconds | 676 m_agentFeeder->feed(obs); |
773 (columnDuration * (partStart - postFilterLatency) + 0.02); | 677 } |
774 f.duration = RealTime::fromSeconds | 678 |
775 (columnDuration * (i - partStart)); | 679 m_columnCountOut ++; |
776 f.values.clear(); | 680 } |
777 f.values.push_back | 681 |
778 (noteFrequency(note, partShift, shiftCount)); | 682 Vamp::Plugin::FeatureList |
779 f.values.push_back(partVelocity); | 683 Silvet::obtainNotes() |
780 f.label = noteName(note, partShift, shiftCount); | 684 { |
781 noteFeatures.push_back(f); | 685 FeatureList noteFeatures; |
782 partStart = i; | 686 |
783 partShift = shift; | 687 typedef AgentFeederPoly<NoteHypothesis> NoteFeeder; |
784 partVelocity = 0; | 688 |
785 } | 689 NoteFeeder *feeder = dynamic_cast<NoteFeeder *>(m_agentFeeder); |
786 } | 690 |
787 | 691 if (!feeder) { |
788 int v = strength * 2; | 692 cerr << "INTERNAL ERROR: Feeder is not a poly-note-hypothesis-feeder!" |
789 if (v > 127) v = 127; | 693 << endl; |
790 | 694 return noteFeatures; |
791 if (v > partVelocity) { | 695 } |
792 partVelocity = v; | 696 |
793 } | 697 std::set<NoteHypothesis> hh = feeder->getAcceptedHypotheses(); |
794 } | 698 |
795 | 699 //!!! inefficient |
796 if (end >= partStart + partThreshold) { | 700 for (std::set<NoteHypothesis>::const_iterator hi = hh.begin(); |
797 f.timestamp = RealTime::fromSeconds | 701 hi != hh.end(); ++hi) { |
798 (columnDuration * (partStart - postFilterLatency) + 0.02); | 702 |
799 f.duration = RealTime::fromSeconds | 703 NoteHypothesis h(*hi); |
800 (columnDuration * (end - partStart)); | 704 |
705 if (m_emitted.find(h) != m_emitted.end()) { | |
706 continue; // already returned this one | |
707 } | |
708 | |
709 m_emitted.insert(h); | |
710 | |
711 NoteHypothesis::Note n = h.getAveragedNote(); | |
712 | |
713 int velocity = n.confidence * 127; | |
714 if (velocity > 127) velocity = 127; | |
715 | |
716 Feature f; | |
717 f.hasTimestamp = true; | |
718 f.hasDuration = true; | |
719 f.timestamp = n.time; | |
720 f.duration = n.duration; | |
801 f.values.clear(); | 721 f.values.clear(); |
802 f.values.push_back | 722 f.values.push_back(n.freq); |
803 (noteFrequency(note, partShift, shiftCount)); | 723 f.values.push_back(velocity); |
804 f.values.push_back(partVelocity); | 724 // f.label = noteName(note, partShift, shiftCount); |
805 f.label = noteName(note, partShift, shiftCount); | |
806 noteFeatures.push_back(f); | 725 noteFeatures.push_back(f); |
807 } | 726 } |
808 } | 727 |
728 return noteFeatures; | |
729 } |