diff plugin/transform/FeatureExtractionModelTransformer.cpp @ 384:6f6ab834449d spectrogram-cache-rejig

* Merge from trunk
author Chris Cannam
date Wed, 27 Feb 2008 11:59:42 +0000
parents aa8dbac62024
children
line wrap: on
line diff
--- a/plugin/transform/FeatureExtractionModelTransformer.cpp	Thu Nov 15 14:03:56 2007 +0000
+++ b/plugin/transform/FeatureExtractionModelTransformer.cpp	Wed Feb 27 11:59:42 2008 +0000
@@ -29,81 +29,130 @@
 #include "data/model/FFTModel.h"
 #include "data/model/WaveFileModel.h"
 
+#include "TransformFactory.h"
+
 #include <QMessageBox>
 
 #include <iostream>
 
-FeatureExtractionModelTransformer::FeatureExtractionModelTransformer(Model *inputModel,
-								   QString pluginId,
-                                                                   const ExecutionContext &context,
-                                                                   QString configurationXml,
-								   QString outputName) :
-    PluginTransformer(inputModel, context),
+FeatureExtractionModelTransformer::FeatureExtractionModelTransformer(Input in,
+                                                                     const Transform &transform) :
+    ModelTransformer(in, transform),
     m_plugin(0),
     m_descriptor(0),
     m_outputFeatureNo(0)
 {
-//    std::cerr << "FeatureExtractionModelTransformer::FeatureExtractionModelTransformer: plugin " << pluginId.toStdString() << ", outputName " << outputName.toStdString() << std::endl;
+//    std::cerr << "FeatureExtractionModelTransformer::FeatureExtractionModelTransformer: plugin " << pluginId.toStdString() << ", outputName " << m_transform.getOutput().toStdString() << std::endl;
+
+    QString pluginId = transform.getPluginIdentifier();
 
     FeatureExtractionPluginFactory *factory =
 	FeatureExtractionPluginFactory::instanceFor(pluginId);
 
     if (!factory) {
-	std::cerr << "FeatureExtractionModelTransformer: No factory available for plugin id \""
-		  << pluginId.toStdString() << "\"" << std::endl;
+        m_message = tr("No factory available for feature extraction plugin id \"%1\" (unknown plugin type, or internal error?)").arg(pluginId);
 	return;
     }
 
-    m_plugin = factory->instantiatePlugin(pluginId, m_input->getSampleRate());
+    DenseTimeValueModel *input = getConformingInput();
+    if (!input) {
+        m_message = tr("Input model for feature extraction plugin \"%1\" is of wrong type (internal error?)").arg(pluginId);
+        return;
+    }
 
+    m_plugin = factory->instantiatePlugin(pluginId, input->getSampleRate());
     if (!m_plugin) {
-	std::cerr << "FeatureExtractionModelTransformer: Failed to instantiate plugin \""
-		  << pluginId.toStdString() << "\"" << std::endl;
+        m_message = tr("Failed to instantiate plugin \"%1\"").arg(pluginId);
 	return;
     }
 
-    if (configurationXml != "") {
-        PluginXml(m_plugin).setParametersFromXml(configurationXml);
-    }
+    TransformFactory::getInstance()->makeContextConsistentWithPlugin
+        (m_transform, m_plugin);
 
-    DenseTimeValueModel *input = getInput();
-    if (!input) return;
+    TransformFactory::getInstance()->setPluginParameters
+        (m_transform, m_plugin);
 
     size_t channelCount = input->getChannelCount();
     if (m_plugin->getMaxChannelCount() < channelCount) {
 	channelCount = 1;
     }
     if (m_plugin->getMinChannelCount() > channelCount) {
-	std::cerr << "FeatureExtractionModelTransformer:: "
-		  << "Can't provide enough channels to plugin (plugin min "
-		  << m_plugin->getMinChannelCount() << ", max "
-		  << m_plugin->getMaxChannelCount() << ", input model has "
-		  << input->getChannelCount() << ")" << std::endl;
+        m_message = tr("Cannot provide enough channels to feature extraction plugin \"%1\" (plugin min is %2, max %3; input model has %4)")
+            .arg(pluginId)
+            .arg(m_plugin->getMinChannelCount())
+            .arg(m_plugin->getMaxChannelCount())
+            .arg(input->getChannelCount());
 	return;
     }
 
     std::cerr << "Initialising feature extraction plugin with channels = "
-              << channelCount << ", step = " << m_context.stepSize
-              << ", block = " << m_context.blockSize << std::endl;
+              << channelCount << ", step = " << m_transform.getStepSize()
+              << ", block = " << m_transform.getBlockSize() << std::endl;
 
     if (!m_plugin->initialise(channelCount,
-                              m_context.stepSize,
-                              m_context.blockSize)) {
-        std::cerr << "FeatureExtractionModelTransformer: Plugin "
-                  << m_plugin->getIdentifier() << " failed to initialise!" << std::endl;
-        return;
+                              m_transform.getStepSize(),
+                              m_transform.getBlockSize())) {
+
+        size_t pstep = m_transform.getStepSize();
+        size_t pblock = m_transform.getBlockSize();
+
+        m_transform.setStepSize(0);
+        m_transform.setBlockSize(0);
+        TransformFactory::getInstance()->makeContextConsistentWithPlugin
+            (m_transform, m_plugin);
+
+        if (m_transform.getStepSize() != pstep ||
+            m_transform.getBlockSize() != pblock) {
+            
+            if (!m_plugin->initialise(channelCount,
+                                      m_transform.getStepSize(),
+                                      m_transform.getBlockSize())) {
+
+                m_message = tr("Failed to initialise feature extraction plugin \"%1\"").arg(pluginId);
+                return;
+
+            } else {
+
+                m_message = tr("Feature extraction plugin \"%1\" rejected the given step and block sizes (%2 and %3); using plugin defaults (%4 and %5) instead")
+                    .arg(pluginId)
+                    .arg(pstep)
+                    .arg(pblock)
+                    .arg(m_transform.getStepSize())
+                    .arg(m_transform.getBlockSize());
+            }
+
+        } else {
+
+            m_message = tr("Failed to initialise feature extraction plugin \"%1\"").arg(pluginId);
+            return;
+        }
+    }
+
+    if (m_transform.getPluginVersion() != "") {
+        QString pv = QString("%1").arg(m_plugin->getPluginVersion());
+        if (pv != m_transform.getPluginVersion()) {
+            QString vm = tr("Transform was configured for version %1 of plugin \"%2\", but the plugin being used is version %3")
+                .arg(m_transform.getPluginVersion())
+                .arg(pluginId)
+                .arg(pv);
+            if (m_message != "") {
+                m_message = QString("%1; %2").arg(vm).arg(m_message);
+            } else {
+                m_message = vm;
+            }
+        }
     }
 
     Vamp::Plugin::OutputList outputs = m_plugin->getOutputDescriptors();
 
     if (outputs.empty()) {
-	std::cerr << "FeatureExtractionModelTransformer: Plugin \""
-		  << pluginId.toStdString() << "\" has no outputs" << std::endl;
+        m_message = tr("Plugin \"%1\" has no outputs").arg(pluginId);
 	return;
     }
     
     for (size_t i = 0; i < outputs.size(); ++i) {
-	if (outputName == "" || outputs[i].identifier == outputName.toStdString()) {
+	if (m_transform.getOutput() == "" ||
+            outputs[i].identifier == m_transform.getOutput().toStdString()) {
 	    m_outputFeatureNo = i;
 	    m_descriptor = new Vamp::Plugin::OutputDescriptor
 		(outputs[i]);
@@ -112,9 +161,9 @@
     }
 
     if (!m_descriptor) {
-	std::cerr << "FeatureExtractionModelTransformer: Plugin \""
-		  << pluginId.toStdString() << "\" has no output named \""
-		  << outputName.toStdString() << "\"" << std::endl;
+        m_message = tr("Plugin \"%1\" has no output named \"%2\"")
+            .arg(pluginId)
+            .arg(m_transform.getOutput());
 	return;
     }
 
@@ -138,7 +187,7 @@
         haveExtents = true;
     }
 
-    size_t modelRate = m_input->getSampleRate();
+    size_t modelRate = input->getSampleRate();
     size_t modelResolution = 1;
     
     switch (m_descriptor->sampleType) {
@@ -150,7 +199,7 @@
 	break;
 
     case Vamp::Plugin::OutputDescriptor::OneSamplePerStep:
-	modelResolution = m_context.stepSize;
+	modelResolution = m_transform.getStepSize();
 	break;
 
     case Vamp::Plugin::OutputDescriptor::FixedSampleRate:
@@ -217,7 +266,7 @@
         m_output = model;
     }
 
-    if (m_output) m_output->setSourceModel(m_input);
+    if (m_output) m_output->setSourceModel(input);
 }
 
 FeatureExtractionModelTransformer::~FeatureExtractionModelTransformer()
@@ -228,12 +277,12 @@
 }
 
 DenseTimeValueModel *
-FeatureExtractionModelTransformer::getInput()
+FeatureExtractionModelTransformer::getConformingInput()
 {
     DenseTimeValueModel *dtvm =
 	dynamic_cast<DenseTimeValueModel *>(getInputModel());
     if (!dtvm) {
-	std::cerr << "FeatureExtractionModelTransformer::getInput: WARNING: Input model is not conformable to DenseTimeValueModel" << std::endl;
+	std::cerr << "FeatureExtractionModelTransformer::getConformingInput: WARNING: Input model is not conformable to DenseTimeValueModel" << std::endl;
     }
     return dtvm;
 }
@@ -241,7 +290,7 @@
 void
 FeatureExtractionModelTransformer::run()
 {
-    DenseTimeValueModel *input = getInput();
+    DenseTimeValueModel *input = getConformingInput();
     if (!input) return;
 
     if (!m_output) return;
@@ -258,7 +307,7 @@
         sleep(1);
     }
 
-    size_t sampleRate = m_input->getSampleRate();
+    size_t sampleRate = input->getSampleRate();
 
     size_t channelCount = input->getChannelCount();
     if (m_plugin->getMaxChannelCount() < channelCount) {
@@ -267,9 +316,12 @@
 
     float **buffers = new float*[channelCount];
     for (size_t ch = 0; ch < channelCount; ++ch) {
-	buffers[ch] = new float[m_context.blockSize + 2];
+	buffers[ch] = new float[m_transform.getBlockSize() + 2];
     }
 
+    size_t stepSize = m_transform.getStepSize();
+    size_t blockSize = m_transform.getBlockSize();
+
     bool frequencyDomain = (m_plugin->getInputDomain() ==
                             Vamp::Plugin::FrequencyDomain);
     std::vector<FFTModel *> fftModels;
@@ -277,12 +329,12 @@
     if (frequencyDomain) {
         for (size_t ch = 0; ch < channelCount; ++ch) {
             FFTModel *model = new FFTModel
-                                  (getInput(),
-                                   channelCount == 1 ? m_context.channel : ch,
-                                   m_context.windowType,
-                                   m_context.blockSize,
-                                   m_context.stepSize,
-                                   m_context.blockSize,
+                                  (getConformingInput(),
+                                   channelCount == 1 ? m_input.getChannel() : ch,
+                                   m_transform.getWindowType(),
+                                   blockSize,
+                                   stepSize,
+                                   blockSize,
                                    false,
                                    StorageAdviser::PrecisionCritical);
             if (!model->isOK()) {
@@ -299,11 +351,17 @@
         }
     }
 
-    long startFrame = m_input->getStartFrame();
-    long   endFrame = m_input->getEndFrame();
+    long startFrame = m_input.getModel()->getStartFrame();
+    long   endFrame = m_input.getModel()->getEndFrame();
 
-    long contextStart = m_context.startFrame;
-    long contextDuration = m_context.duration;
+    RealTime contextStartRT = m_transform.getStartTime();
+    RealTime contextDurationRT = m_transform.getDuration();
+
+    long contextStart =
+        RealTime::realTime2Frame(contextStartRT, sampleRate);
+
+    long contextDuration =
+        RealTime::realTime2Frame(contextDurationRT, sampleRate);
 
     if (contextStart == 0 || contextStart < startFrame) {
         contextStart = startFrame;
@@ -325,7 +383,7 @@
     while (!m_abandoned) {
 
         if (frequencyDomain) {
-            if (blockFrame - int(m_context.blockSize)/2 >
+            if (blockFrame - int(blockSize)/2 >
                 contextStart + contextDuration) break;
         } else {
             if (blockFrame >= 
@@ -334,25 +392,24 @@
 
 //	std::cerr << "FeatureExtractionModelTransformer::run: blockFrame "
 //		  << blockFrame << ", endFrame " << endFrame << ", blockSize "
-//                  << m_context.blockSize << std::endl;
+//                  << blockSize << std::endl;
 
 	long completion =
-	    (((blockFrame - contextStart) / m_context.stepSize) * 99) /
-	    (contextDuration / m_context.stepSize);
+	    (((blockFrame - contextStart) / stepSize) * 99) /
+	    (contextDuration / stepSize);
 
-	// channelCount is either m_input->channelCount or 1
+	// channelCount is either m_input.getModel()->channelCount or 1
 
-        for (size_t ch = 0; ch < channelCount; ++ch) {
-            if (frequencyDomain) {
-                int column = (blockFrame - startFrame) / m_context.stepSize;
-                for (size_t i = 0; i <= m_context.blockSize/2; ++i) {
+        if (frequencyDomain) {
+            for (size_t ch = 0; ch < channelCount; ++ch) {
+                int column = (blockFrame - startFrame) / stepSize;
+                for (size_t i = 0; i <= blockSize/2; ++i) {
                     fftModels[ch]->getValuesAt
                         (column, i, buffers[ch][i*2], buffers[ch][i*2+1]);
                 }
-            } else {
-                getFrames(ch, channelCount, 
-                          blockFrame, m_context.blockSize, buffers[ch]);
-            }                
+            }
+        } else {
+            getFrames(channelCount, blockFrame, blockSize, buffers);
         }
 
 	Vamp::Plugin::FeatureSet features = m_plugin->process
@@ -369,7 +426,7 @@
 	    prevCompletion = completion;
 	}
 
-	blockFrame += m_context.stepSize;
+	blockFrame += stepSize;
     }
 
     if (m_abandoned) return;
@@ -392,15 +449,17 @@
 }
 
 void
-FeatureExtractionModelTransformer::getFrames(int channel, int channelCount,
-                                            long startFrame, long size,
-                                            float *buffer)
+FeatureExtractionModelTransformer::getFrames(int channelCount,
+                                             long startFrame, long size,
+                                             float **buffers)
 {
     long offset = 0;
 
     if (startFrame < 0) {
-        for (int i = 0; i < size && startFrame + i < 0; ++i) {
-            buffer[i] = 0.0f;
+        for (int c = 0; c < channelCount; ++c) {
+            for (int i = 0; i < size && startFrame + i < 0; ++i) {
+                buffers[c][i] = 0.0f;
+            }
         }
         offset = -startFrame;
         size -= offset;
@@ -408,30 +467,52 @@
         startFrame = 0;
     }
 
-    long got = getInput()->getData
-        ((channelCount == 1 ? m_context.channel : channel),
-         startFrame, size, buffer + offset);
+    DenseTimeValueModel *input = getConformingInput();
+    if (!input) return;
+    
+    long got = 0;
+
+    if (channelCount == 1) {
+
+        got = input->getData(m_input.getChannel(), startFrame, size,
+                             buffers[0] + offset);
+
+        if (m_input.getChannel() == -1 && input->getChannelCount() > 1) {
+            // use mean instead of sum, as plugin input
+            float cc = float(input->getChannelCount());
+            for (long i = 0; i < size; ++i) {
+                buffers[0][i + offset] /= cc;
+            }
+        }
+
+    } else {
+
+        float **writebuf = buffers;
+        if (offset > 0) {
+            writebuf = new float *[channelCount];
+            for (int i = 0; i < channelCount; ++i) {
+                writebuf[i] = buffers[i] + offset;
+            }
+        }
+
+        got = input->getData(0, channelCount-1, startFrame, size, writebuf);
+
+        if (writebuf != buffers) delete[] writebuf;
+    }
 
     while (got < size) {
-        buffer[offset + got] = 0.0;
+        for (int c = 0; c < channelCount; ++c) {
+            buffers[c][got + offset] = 0.0;
+        }
         ++got;
     }
-
-    if (m_context.channel == -1 && channelCount == 1 &&
-        getInput()->getChannelCount() > 1) {
-        // use mean instead of sum, as plugin input
-        int cc = getInput()->getChannelCount();
-        for (long i = 0; i < size; ++i) {
-            buffer[i] /= cc;
-        }
-    }
 }
 
 void
 FeatureExtractionModelTransformer::addFeature(size_t blockFrame,
 					     const Vamp::Plugin::Feature &feature)
 {
-    size_t inputRate = m_input->getSampleRate();
+    size_t inputRate = m_input.getModel()->getSampleRate();
 
 //    std::cerr << "FeatureExtractionModelTransformer::addFeature("
 //	      << blockFrame << ")" << std::endl;
@@ -470,8 +551,10 @@
 	
     if (binCount == 0) {
 
-	SparseOneDimensionalModel *model = getOutput<SparseOneDimensionalModel>();
+	SparseOneDimensionalModel *model =
+            getConformingOutput<SparseOneDimensionalModel>();
 	if (!model) return;
+
 	model->addPoint(SparseOneDimensionalModel::Point(frame, feature.label.c_str()));
 	
     } else if (binCount == 1) {
@@ -479,8 +562,10 @@
 	float value = 0.0;
 	if (feature.values.size() > 0) value = feature.values[0];
 
-	SparseTimeValueModel *model = getOutput<SparseTimeValueModel>();
+	SparseTimeValueModel *model =
+            getConformingOutput<SparseTimeValueModel>();
 	if (!model) return;
+
 	model->addPoint(SparseTimeValueModel::Point(frame, value, feature.label.c_str()));
 //        std::cerr << "SparseTimeValueModel::addPoint(" << frame << ", " << value << "), " << feature.label.c_str() << std::endl;
 
@@ -495,12 +580,15 @@
         
         float velocity = 100;
         if (feature.values.size() > 2) velocity = feature.values[2];
+        if (velocity < 0) velocity = 127;
+        if (velocity > 127) velocity = 127;
 
-        NoteModel *model = getOutput<NoteModel>();
+        NoteModel *model = getConformingOutput<NoteModel>();
         if (!model) return;
 
         model->addPoint(NoteModel::Point(frame, pitch,
                                          lrintf(duration),
+                                         velocity / 127.f,
                                          feature.label.c_str()));
 	
     } else {
@@ -508,7 +596,7 @@
 	DenseThreeDimensionalModel::Column values = feature.values;
 	
 	EditableDenseThreeDimensionalModel *model =
-            getOutput<EditableDenseThreeDimensionalModel>();
+            getConformingOutput<EditableDenseThreeDimensionalModel>();
 	if (!model) return;
 
 	model->setColumn(frame / model->getResolution(), values);
@@ -528,29 +616,32 @@
 
     if (binCount == 0) {
 
-	SparseOneDimensionalModel *model = getOutput<SparseOneDimensionalModel>();
+	SparseOneDimensionalModel *model =
+            getConformingOutput<SparseOneDimensionalModel>();
 	if (!model) return;
-	model->setCompletion(completion, m_context.updates);
+	model->setCompletion(completion, true); //!!!m_context.updates);
 
     } else if (binCount == 1) {
 
-	SparseTimeValueModel *model = getOutput<SparseTimeValueModel>();
+	SparseTimeValueModel *model =
+            getConformingOutput<SparseTimeValueModel>();
 	if (!model) return;
-	model->setCompletion(completion, m_context.updates);
+	model->setCompletion(completion, true); //!!!m_context.updates);
 
     } else if (m_descriptor->sampleType ==
 	       Vamp::Plugin::OutputDescriptor::VariableSampleRate) {
 
-	NoteModel *model = getOutput<NoteModel>();
+	NoteModel *model =
+            getConformingOutput<NoteModel>();
 	if (!model) return;
-	model->setCompletion(completion, m_context.updates);
+	model->setCompletion(completion, true); //!!!m_context.updates);
 
     } else {
 
 	EditableDenseThreeDimensionalModel *model =
-            getOutput<EditableDenseThreeDimensionalModel>();
+            getConformingOutput<EditableDenseThreeDimensionalModel>();
 	if (!model) return;
-	model->setCompletion(completion, m_context.updates);
+	model->setCompletion(completion, true); //!!!m_context.updates);
     }
 }