diff plugin/transform/FeatureExtractionModelTransformer.cpp @ 350:d7c41483af8f

* Merge from transforms branch -- switch over to using Transform object properly
author Chris Cannam
date Fri, 07 Dec 2007 16:47:31 +0000
parents 277006c62fea
children 399ea254afd6
line wrap: on
line diff
--- a/plugin/transform/FeatureExtractionModelTransformer.cpp	Fri Nov 30 17:31:09 2007 +0000
+++ b/plugin/transform/FeatureExtractionModelTransformer.cpp	Fri Dec 07 16:47:31 2007 +0000
@@ -29,21 +29,22 @@
 #include "data/model/FFTModel.h"
 #include "data/model/WaveFileModel.h"
 
+#include "TransformFactory.h"
+
 #include <QMessageBox>
 
 #include <iostream>
 
-FeatureExtractionModelTransformer::FeatureExtractionModelTransformer(Model *inputModel,
-								   QString pluginId,
-                                                                   const ExecutionContext &context,
-                                                                   QString configurationXml,
-								   QString outputName) :
-    PluginTransformer(inputModel, context),
+FeatureExtractionModelTransformer::FeatureExtractionModelTransformer(Input in,
+                                                                     const Transform &transform) :
+    ModelTransformer(in, transform),
     m_plugin(0),
     m_descriptor(0),
     m_outputFeatureNo(0)
 {
-//    std::cerr << "FeatureExtractionModelTransformer::FeatureExtractionModelTransformer: plugin " << pluginId.toStdString() << ", outputName " << outputName.toStdString() << std::endl;
+//    std::cerr << "FeatureExtractionModelTransformer::FeatureExtractionModelTransformer: plugin " << pluginId.toStdString() << ", outputName " << m_transform.getOutput().toStdString() << std::endl;
+
+    QString pluginId = transform.getPluginIdentifier();
 
     FeatureExtractionPluginFactory *factory =
 	FeatureExtractionPluginFactory::instanceFor(pluginId);
@@ -54,22 +55,24 @@
 	return;
     }
 
-    m_plugin = factory->instantiatePlugin(pluginId, m_input->getSampleRate());
+    DenseTimeValueModel *input = getConformingInput();
+    if (!input) {
+        std::cerr << "FeatureExtractionModelTransformer: Input model not conformable" << std::endl;
+        return;
+    }
 
+    m_plugin = factory->instantiatePlugin(pluginId, input->getSampleRate());
     if (!m_plugin) {
 	std::cerr << "FeatureExtractionModelTransformer: Failed to instantiate plugin \""
 		  << pluginId.toStdString() << "\"" << std::endl;
 	return;
     }
 
-    m_context.makeConsistentWithPlugin(m_plugin);
+    TransformFactory::getInstance()->makeContextConsistentWithPlugin
+        (m_transform, m_plugin);
 
-    if (configurationXml != "") {
-        PluginXml(m_plugin).setParametersFromXml(configurationXml);
-    }
-
-    DenseTimeValueModel *input = getInput();
-    if (!input) return;
+    TransformFactory::getInstance()->setPluginParameters
+        (m_transform, m_plugin);
 
     size_t channelCount = input->getChannelCount();
     if (m_plugin->getMaxChannelCount() < channelCount) {
@@ -85,14 +88,14 @@
     }
 
     std::cerr << "Initialising feature extraction plugin with channels = "
-              << channelCount << ", step = " << m_context.stepSize
-              << ", block = " << m_context.blockSize << std::endl;
+              << channelCount << ", step = " << m_transform.getStepSize()
+              << ", block = " << m_transform.getBlockSize() << std::endl;
 
     if (!m_plugin->initialise(channelCount,
-                              m_context.stepSize,
-                              m_context.blockSize)) {
+                              m_transform.getStepSize(),
+                              m_transform.getBlockSize())) {
         std::cerr << "FeatureExtractionModelTransformer: Plugin "
-                  << m_plugin->getIdentifier() << " failed to initialise!" << std::endl;
+                  << pluginId.toStdString() << " failed to initialise!" << std::endl;
         return;
     }
 
@@ -105,7 +108,8 @@
     }
     
     for (size_t i = 0; i < outputs.size(); ++i) {
-	if (outputName == "" || outputs[i].identifier == outputName.toStdString()) {
+	if (m_transform.getOutput() == "" ||
+            outputs[i].identifier == m_transform.getOutput().toStdString()) {
 	    m_outputFeatureNo = i;
 	    m_descriptor = new Vamp::Plugin::OutputDescriptor
 		(outputs[i]);
@@ -116,7 +120,7 @@
     if (!m_descriptor) {
 	std::cerr << "FeatureExtractionModelTransformer: Plugin \""
 		  << pluginId.toStdString() << "\" has no output named \""
-		  << outputName.toStdString() << "\"" << std::endl;
+		  << m_transform.getOutput().toStdString() << "\"" << std::endl;
 	return;
     }
 
@@ -140,7 +144,7 @@
         haveExtents = true;
     }
 
-    size_t modelRate = m_input->getSampleRate();
+    size_t modelRate = input->getSampleRate();
     size_t modelResolution = 1;
     
     switch (m_descriptor->sampleType) {
@@ -152,7 +156,7 @@
 	break;
 
     case Vamp::Plugin::OutputDescriptor::OneSamplePerStep:
-	modelResolution = m_context.stepSize;
+	modelResolution = m_transform.getStepSize();
 	break;
 
     case Vamp::Plugin::OutputDescriptor::FixedSampleRate:
@@ -219,7 +223,7 @@
         m_output = model;
     }
 
-    if (m_output) m_output->setSourceModel(m_input);
+    if (m_output) m_output->setSourceModel(input);
 }
 
 FeatureExtractionModelTransformer::~FeatureExtractionModelTransformer()
@@ -230,12 +234,12 @@
 }
 
 DenseTimeValueModel *
-FeatureExtractionModelTransformer::getInput()
+FeatureExtractionModelTransformer::getConformingInput()
 {
     DenseTimeValueModel *dtvm =
 	dynamic_cast<DenseTimeValueModel *>(getInputModel());
     if (!dtvm) {
-	std::cerr << "FeatureExtractionModelTransformer::getInput: WARNING: Input model is not conformable to DenseTimeValueModel" << std::endl;
+	std::cerr << "FeatureExtractionModelTransformer::getConformingInput: WARNING: Input model is not conformable to DenseTimeValueModel" << std::endl;
     }
     return dtvm;
 }
@@ -243,7 +247,7 @@
 void
 FeatureExtractionModelTransformer::run()
 {
-    DenseTimeValueModel *input = getInput();
+    DenseTimeValueModel *input = getConformingInput();
     if (!input) return;
 
     if (!m_output) return;
@@ -260,7 +264,7 @@
         sleep(1);
     }
 
-    size_t sampleRate = m_input->getSampleRate();
+    size_t sampleRate = input->getSampleRate();
 
     size_t channelCount = input->getChannelCount();
     if (m_plugin->getMaxChannelCount() < channelCount) {
@@ -269,9 +273,12 @@
 
     float **buffers = new float*[channelCount];
     for (size_t ch = 0; ch < channelCount; ++ch) {
-	buffers[ch] = new float[m_context.blockSize + 2];
+	buffers[ch] = new float[m_transform.getBlockSize() + 2];
     }
 
+    size_t stepSize = m_transform.getStepSize();
+    size_t blockSize = m_transform.getBlockSize();
+
     bool frequencyDomain = (m_plugin->getInputDomain() ==
                             Vamp::Plugin::FrequencyDomain);
     std::vector<FFTModel *> fftModels;
@@ -279,12 +286,12 @@
     if (frequencyDomain) {
         for (size_t ch = 0; ch < channelCount; ++ch) {
             FFTModel *model = new FFTModel
-                                  (getInput(),
-                                   channelCount == 1 ? m_context.channel : ch,
-                                   m_context.windowType,
-                                   m_context.blockSize,
-                                   m_context.stepSize,
-                                   m_context.blockSize,
+                                  (getConformingInput(),
+                                   channelCount == 1 ? m_input.getChannel() : ch,
+                                   m_transform.getWindowType(),
+                                   blockSize,
+                                   stepSize,
+                                   blockSize,
                                    false,
                                    StorageAdviser::PrecisionCritical);
             if (!model->isOK()) {
@@ -301,11 +308,17 @@
         }
     }
 
-    long startFrame = m_input->getStartFrame();
-    long   endFrame = m_input->getEndFrame();
+    long startFrame = m_input.getModel()->getStartFrame();
+    long   endFrame = m_input.getModel()->getEndFrame();
 
-    long contextStart = m_context.startFrame;
-    long contextDuration = m_context.duration;
+    RealTime contextStartRT = m_transform.getStartTime();
+    RealTime contextDurationRT = m_transform.getDuration();
+
+    long contextStart =
+        RealTime::realTime2Frame(contextStartRT, sampleRate);
+
+    long contextDuration =
+        RealTime::realTime2Frame(contextDurationRT, sampleRate);
 
     if (contextStart == 0 || contextStart < startFrame) {
         contextStart = startFrame;
@@ -327,7 +340,7 @@
     while (!m_abandoned) {
 
         if (frequencyDomain) {
-            if (blockFrame - int(m_context.blockSize)/2 >
+            if (blockFrame - int(blockSize)/2 >
                 contextStart + contextDuration) break;
         } else {
             if (blockFrame >= 
@@ -336,24 +349,24 @@
 
 //	std::cerr << "FeatureExtractionModelTransformer::run: blockFrame "
 //		  << blockFrame << ", endFrame " << endFrame << ", blockSize "
-//                  << m_context.blockSize << std::endl;
+//                  << blockSize << std::endl;
 
 	long completion =
-	    (((blockFrame - contextStart) / m_context.stepSize) * 99) /
-	    (contextDuration / m_context.stepSize);
+	    (((blockFrame - contextStart) / stepSize) * 99) /
+	    (contextDuration / stepSize);
 
-	// channelCount is either m_input->channelCount or 1
+	// channelCount is either m_input.getModel()->channelCount or 1
 
         for (size_t ch = 0; ch < channelCount; ++ch) {
             if (frequencyDomain) {
-                int column = (blockFrame - startFrame) / m_context.stepSize;
-                for (size_t i = 0; i <= m_context.blockSize/2; ++i) {
+                int column = (blockFrame - startFrame) / stepSize;
+                for (size_t i = 0; i <= blockSize/2; ++i) {
                     fftModels[ch]->getValuesAt
                         (column, i, buffers[ch][i*2], buffers[ch][i*2+1]);
                 }
             } else {
                 getFrames(ch, channelCount, 
-                          blockFrame, m_context.blockSize, buffers[ch]);
+                          blockFrame, blockSize, buffers[ch]);
             }                
         }
 
@@ -371,7 +384,7 @@
 	    prevCompletion = completion;
 	}
 
-	blockFrame += m_context.stepSize;
+	blockFrame += stepSize;
     }
 
     if (m_abandoned) return;
@@ -410,8 +423,11 @@
         startFrame = 0;
     }
 
-    long got = getInput()->getData
-        ((channelCount == 1 ? m_context.channel : channel),
+    DenseTimeValueModel *input = getConformingInput();
+    if (!input) return;
+
+    long got = input->getData
+        ((channelCount == 1 ? m_input.getChannel() : channel),
          startFrame, size, buffer + offset);
 
     while (got < size) {
@@ -419,10 +435,10 @@
         ++got;
     }
 
-    if (m_context.channel == -1 && channelCount == 1 &&
-        getInput()->getChannelCount() > 1) {
+    if (m_input.getChannel() == -1 && channelCount == 1 &&
+        input->getChannelCount() > 1) {
         // use mean instead of sum, as plugin input
-        int cc = getInput()->getChannelCount();
+        int cc = input->getChannelCount();
         for (long i = 0; i < size; ++i) {
             buffer[i] /= cc;
         }
@@ -433,7 +449,7 @@
 FeatureExtractionModelTransformer::addFeature(size_t blockFrame,
 					     const Vamp::Plugin::Feature &feature)
 {
-    size_t inputRate = m_input->getSampleRate();
+    size_t inputRate = m_input.getModel()->getSampleRate();
 
 //    std::cerr << "FeatureExtractionModelTransformer::addFeature("
 //	      << blockFrame << ")" << std::endl;
@@ -472,8 +488,10 @@
 	
     if (binCount == 0) {
 
-	SparseOneDimensionalModel *model = getOutput<SparseOneDimensionalModel>();
+	SparseOneDimensionalModel *model =
+            getConformingOutput<SparseOneDimensionalModel>();
 	if (!model) return;
+
 	model->addPoint(SparseOneDimensionalModel::Point(frame, feature.label.c_str()));
 	
     } else if (binCount == 1) {
@@ -481,8 +499,10 @@
 	float value = 0.0;
 	if (feature.values.size() > 0) value = feature.values[0];
 
-	SparseTimeValueModel *model = getOutput<SparseTimeValueModel>();
+	SparseTimeValueModel *model =
+            getConformingOutput<SparseTimeValueModel>();
 	if (!model) return;
+
 	model->addPoint(SparseTimeValueModel::Point(frame, value, feature.label.c_str()));
 //        std::cerr << "SparseTimeValueModel::addPoint(" << frame << ", " << value << "), " << feature.label.c_str() << std::endl;
 
@@ -500,7 +520,7 @@
         if (velocity < 0) velocity = 127;
         if (velocity > 127) velocity = 127;
 
-        NoteModel *model = getOutput<NoteModel>();
+        NoteModel *model = getConformingOutput<NoteModel>();
         if (!model) return;
 
         model->addPoint(NoteModel::Point(frame, pitch,
@@ -513,7 +533,7 @@
 	DenseThreeDimensionalModel::Column values = feature.values;
 	
 	EditableDenseThreeDimensionalModel *model =
-            getOutput<EditableDenseThreeDimensionalModel>();
+            getConformingOutput<EditableDenseThreeDimensionalModel>();
 	if (!model) return;
 
 	model->setColumn(frame / model->getResolution(), values);
@@ -533,29 +553,32 @@
 
     if (binCount == 0) {
 
-	SparseOneDimensionalModel *model = getOutput<SparseOneDimensionalModel>();
+	SparseOneDimensionalModel *model =
+            getConformingOutput<SparseOneDimensionalModel>();
 	if (!model) return;
-	model->setCompletion(completion, m_context.updates);
+	model->setCompletion(completion, true); //!!!m_context.updates);
 
     } else if (binCount == 1) {
 
-	SparseTimeValueModel *model = getOutput<SparseTimeValueModel>();
+	SparseTimeValueModel *model =
+            getConformingOutput<SparseTimeValueModel>();
 	if (!model) return;
-	model->setCompletion(completion, m_context.updates);
+	model->setCompletion(completion, true); //!!!m_context.updates);
 
     } else if (m_descriptor->sampleType ==
 	       Vamp::Plugin::OutputDescriptor::VariableSampleRate) {
 
-	NoteModel *model = getOutput<NoteModel>();
+	NoteModel *model =
+            getConformingOutput<NoteModel>();
 	if (!model) return;
-	model->setCompletion(completion, m_context.updates);
+	model->setCompletion(completion, true); //!!!m_context.updates);
 
     } else {
 
 	EditableDenseThreeDimensionalModel *model =
-            getOutput<EditableDenseThreeDimensionalModel>();
+            getConformingOutput<EditableDenseThreeDimensionalModel>();
 	if (!model) return;
-	model->setCompletion(completion, m_context.updates);
+	model->setCompletion(completion, true); //!!!m_context.updates);
     }
 }