diff transform/FeatureExtractionModelTransformer.cpp @ 1740:fe3f7f8df3a3 by-id

More work on transformers
author Chris Cannam
date Wed, 26 Jun 2019 17:25:20 +0100
parents 565575463752
children b92bdcd4954b
line wrap: on
line diff
--- a/transform/FeatureExtractionModelTransformer.cpp	Wed Jun 26 14:59:09 2019 +0100
+++ b/transform/FeatureExtractionModelTransformer.cpp	Wed Jun 26 17:25:20 2019 +0100
@@ -650,7 +650,6 @@
 
     auto input = ModelById::getAs<DenseTimeValueModel>(getInputModel());
     if (!input) {
-        SVCERR << "FeatureExtractionModelTransformer::run: Input model not (no longer?) available, abandoning" << endl;
         abandon();
         return;
     }
@@ -794,22 +793,22 @@
 
             if (m_abandoned) break;
 
-            Vamp::Plugin::FeatureSet features = m_plugin->process
+            auto features = m_plugin->process
                 (buffers,
                  RealTime::frame2RealTime(blockFrame, sampleRate)
                  .toVampRealTime());
             
             if (m_abandoned) break;
 
-            for (int j = 0; j < (int)m_outputNos.size(); ++j) {
-                for (int fi = 0; fi < (int)features[m_outputNos[j]].size(); ++fi) {
-                    Vamp::Plugin::Feature feature = features[m_outputNos[j]][fi];
+            for (int j = 0; in_range_for(m_outputNos, j); ++j) {
+                for (int fi = 0; in_range_for(features[m_outputNos[j]], fi); ++fi) {
+                    auto feature = features[m_outputNos[j]][fi];
                     addFeature(j, blockFrame, feature);
                 }
             }
 
             if (blockFrame == contextStart || completion > prevCompletion) {
-                for (int j = 0; j < (int)m_outputNos.size(); ++j) {
+                for (int j = 0; in_range_for(m_outputNos, j); ++j) {
                     setCompletion(j, completion);
                 }
                 prevCompletion = completion;
@@ -820,11 +819,11 @@
         }
 
         if (!m_abandoned) {
-            Vamp::Plugin::FeatureSet features = m_plugin->getRemainingFeatures();
+            auto features = m_plugin->getRemainingFeatures();
 
-            for (int j = 0; j < (int)m_outputNos.size(); ++j) {
-                for (int fi = 0; fi < (int)features[m_outputNos[j]].size(); ++fi) {
-                    Vamp::Plugin::Feature feature = features[m_outputNos[j]][fi];
+            for (int j = 0; in_range_for(m_outputNos, j); ++j) {
+                for (int fi = 0; in_range_for(features[m_outputNos[j]], fi); ++fi) {
+                    auto feature = features[m_outputNos[j]][fi];
                     addFeature(j, blockFrame, feature);
                 }
             }
@@ -990,129 +989,116 @@
     // to create.
 
     ModelId outputId = m_outputs[n];
-    bool found = false;
+
+    if (isOutputType<SparseOneDimensionalModel>(n)) {
+
+        auto model = ModelById::getAs<SparseOneDimensionalModel>(outputId);
+        if (!model) return;
+        model->add(Event(frame, feature.label.c_str()));
+        
+    } else if (isOutputType<SparseTimeValueModel>(n)) {
+
+        auto model = ModelById::getAs<SparseTimeValueModel>(outputId);
+        if (!model) return;
+
+        for (int i = 0; in_range_for(feature.values, i); ++i) {
+
+            float value = feature.values[i];
+
+            QString label = feature.label.c_str();
+            if (feature.values.size() > 1) {
+                label = QString("[%1] %2").arg(i+1).arg(label);
+            }
+
+            auto targetModel = model;
+
+            if (m_needAdditionalModels[n] && i > 0) {
+                targetModel = ModelById::getAs<SparseTimeValueModel>
+                    (getAdditionalModel(n, i));
+                if (!targetModel) targetModel = model;
+            }
+
+            targetModel->add(Event(frame, value, label));
+        }
+
+    } else if (isOutputType<NoteModel>(n) || isOutputType<RegionModel>(n)) {
     
-    if (!found) {
-        auto model = ModelById::getAs<SparseOneDimensionalModel>(outputId);
-        if (model) {
-            found = true;
-            model->add(Event(frame, feature.label.c_str()));
+        int index = 0;
+
+        float value = 0.0;
+        if ((int)feature.values.size() > index) {
+            value = feature.values[index++];
         }
-    }
-    
-    if (!found) {
-        auto model = ModelById::getAs<SparseTimeValueModel>(outputId);
-        if (model) {
-            found = true;
 
-            for (int i = 0; in_range_for(feature.values, i); ++i) {
-
-                float value = feature.values[i];
-
-                QString label = feature.label.c_str();
-                if (feature.values.size() > 1) {
-                    label = QString("[%1] %2").arg(i+1).arg(label);
-                }
-
-                auto targetModel = model;
-
-                if (m_needAdditionalModels[n] && i > 0) {
-                    targetModel = ModelById::getAs<SparseTimeValueModel>
-                        (getAdditionalModel(n, i));
-                    if (!targetModel) targetModel = model;
-                }
-
-                targetModel->add(Event(frame, value, label));
+        sv_frame_t duration = 1;
+        if (feature.hasDuration) {
+            duration = RealTime::realTime2Frame(feature.duration, inputRate);
+        } else {
+            if (in_range_for(feature.values, index)) {
+                duration = lrintf(feature.values[index++]);
             }
         }
-    }
-    
-    if (!found) {
-        if (ModelById::getAs<NoteModel>(outputId) ||
-            ModelById::getAs<RegionModel>(outputId)) {
-            found = true;
 
-            int index = 0;
+        auto noteModel = ModelById::getAs<NoteModel>(outputId);
+        if (noteModel) {
 
-            float value = 0.0;
+            float velocity = 100;
             if ((int)feature.values.size() > index) {
-                value = feature.values[index++];
+                velocity = feature.values[index++];
             }
+            if (velocity < 0) velocity = 127;
+            if (velocity > 127) velocity = 127;
+            
+            noteModel->add(Event(frame, value, // value is pitch
+                                 duration,
+                                 velocity / 127.f,
+                                 feature.label.c_str()));
+        }
 
-            sv_frame_t duration = 1;
-            if (feature.hasDuration) {
-                duration = RealTime::realTime2Frame(feature.duration, inputRate);
-            } else {
-                if (in_range_for(feature.values, index)) {
-                    duration = lrintf(feature.values[index++]);
-                }
-            }
-
-            auto noteModel = ModelById::getAs<NoteModel>(outputId);
-            if (noteModel) {
-
-                float velocity = 100;
-                if ((int)feature.values.size() > index) {
-                    velocity = feature.values[index++];
-                }
-                if (velocity < 0) velocity = 127;
-                if (velocity > 127) velocity = 127;
-
-                noteModel->add(Event(frame, value, // value is pitch
-                                     duration,
-                                     velocity / 127.f,
-                                     feature.label.c_str()));
-            }
-
-            auto regionModel = ModelById::getAs<RegionModel>(outputId);
-            if (regionModel) {
-
-                if (feature.hasDuration && !feature.values.empty()) {
-
-                    for (int i = 0; in_range_for(feature.values, i); ++i) {
-
-                        float value = feature.values[i];
-
-                        QString label = feature.label.c_str();
-                        if (feature.values.size() > 1) {
-                            label = QString("[%1] %2").arg(i+1).arg(label);
-                        }
-
-                        regionModel->add(Event(frame,
-                                               value,
-                                               duration,
-                                               label));
+        auto regionModel = ModelById::getAs<RegionModel>(outputId);
+        if (regionModel) {
+            
+            if (feature.hasDuration && !feature.values.empty()) {
+                
+                for (int i = 0; in_range_for(feature.values, i); ++i) {
+                    
+                    float value = feature.values[i];
+                    
+                    QString label = feature.label.c_str();
+                    if (feature.values.size() > 1) {
+                        label = QString("[%1] %2").arg(i+1).arg(label);
                     }
-                } else {
-            
+                    
                     regionModel->add(Event(frame,
                                            value,
                                            duration,
-                                           feature.label.c_str()));
+                                           label));
                 }
+            } else {
+                
+                regionModel->add(Event(frame,
+                                       value,
+                                       duration,
+                                       feature.label.c_str()));
             }
         }
+
+    } else if (isOutputType<EditableDenseThreeDimensionalModel>(n)) {
+
+        auto model = ModelById::getAs
+            <EditableDenseThreeDimensionalModel>(outputId);
+        if (!model) return;
+        
+        DenseThreeDimensionalModel::Column values = feature.values;
+        
+        if (!feature.hasTimestamp && m_fixedRateFeatureNos[n] >= 0) {
+            model->setColumn(m_fixedRateFeatureNos[n], values);
+        } else {
+            model->setColumn(int(frame / model->getResolution()), values);
+        }
     }
 
-    if (!found) {
-        auto model = ModelById::getAs
-            <EditableDenseThreeDimensionalModel>(outputId);
-        if (model) {
-            found = true;
-        
-            DenseThreeDimensionalModel::Column values = feature.values;
-        
-            if (!feature.hasTimestamp && m_fixedRateFeatureNos[n] >= 0) {
-                model->setColumn(m_fixedRateFeatureNos[n], values);
-            } else {
-                model->setColumn(int(frame / model->getResolution()), values);
-            }
-        }
-    }
-
-    if (!found) {
-        SVDEBUG << "FeatureExtractionModelTransformer::addFeature: Unknown output model type!" << endl;
-    }
+    SVDEBUG << "FeatureExtractionModelTransformer::addFeature: Unknown output model type!" << endl;
 }
 
 void
@@ -1123,47 +1109,11 @@
               << completion << ")" << endl;
 #endif
 
-    ModelId outputId = m_outputs[n];
-    bool found = false;
-    
-    if (!found) {
-        auto model = ModelById::getAs<SparseOneDimensionalModel>(outputId);
-        if (model) {
-            found = true;
-            model->setCompletion(completion, true);
-        }
-    }
-
-    if (!found) {
-        auto model = ModelById::getAs<SparseTimeValueModel>(outputId);
-        if (model) {
-            found = true;
-            model->setCompletion(completion, true);
-        }
-    }
-
-    if (!found) {
-        auto model = ModelById::getAs<NoteModel>(outputId);
-        if (model) {
-            found = true;
-            model->setCompletion(completion, true);
-        }
-    }
-
-    if (!found) {
-        auto model = ModelById::getAs<RegionModel>(outputId);
-        if (model) {
-            found = true;
-            model->setCompletion(completion, true);
-        }
-    }
-
-    if (!found) {
-        auto model = ModelById::getAs<EditableDenseThreeDimensionalModel>(outputId);
-        if (model) {
-            found = true;
-            model->setCompletion(completion, true);
-        }
-    }
+    (void)
+        (setOutputCompletion<SparseOneDimensionalModel>(n, completion) ||
+         setOutputCompletion<SparseTimeValueModel>(n, completion) ||
+         setOutputCompletion<NoteModel>(n, completion) ||
+         setOutputCompletion<RegionModel>(n, completion) ||
+         setOutputCompletion<EditableDenseThreeDimensionalModel>(n, completion));
 }