changeset 45:69c438d4b9d3

* Pick up default sample rate and channel count from first audio file (formerly they were hardcoded to 44100 and 1...)
author Chris Cannam
date Mon, 18 Oct 2010 14:17:48 +0100
parents aa521baace07
children 4d07f61dba3f
files runner/FeatureExtractionManager.cpp runner/FeatureExtractionManager.h runner/main.cpp
diffstat 3 files changed, 168 insertions(+), 73 deletions(-) [+]
line wrap: on
line diff
--- a/runner/FeatureExtractionManager.cpp	Mon Oct 18 14:16:17 2010 +0100
+++ b/runner/FeatureExtractionManager.cpp	Mon Oct 18 14:17:48 2010 +0100
@@ -59,7 +59,7 @@
     m_blockSize(16384),
     m_defaultSampleRate(0),
     m_sampleRate(0),
-    m_channels(1)
+    m_channels(0)
 {
 }
 
@@ -69,6 +69,9 @@
          pi != m_plugins.end(); ++pi) {
         delete pi->first;
     }
+    foreach (AudioFileReader *r, m_readyReaders) {
+        delete r;
+    }
 }
 
 void FeatureExtractionManager::setChannels(int channels)
@@ -379,29 +382,22 @@
     return addFeatureExtractor(transform, writers);
 }
 
-void FeatureExtractionManager::extractFeatures(QString audioSource)
+void FeatureExtractionManager::addSource(QString audioSource)
 {
-    if (m_plugins.empty()) return;
-
-    testOutputFiles(audioSource);
-
-    ProgressPrinter retrievalProgress("Retrieving audio data...");
-
-    FileSource source(audioSource, &retrievalProgress);
-    if (!source.isAvailable()) {
-        cerr << "ERROR: File or URL \"" << audioSource.toStdString()
-             << "\" could not be located" << endl;
-        throw FileNotFound(audioSource);
-    }
-    
-    source.waitForData();
-    
     if (QFileInfo(audioSource).suffix().toLower() == "m3u") {
+        ProgressPrinter retrievalProgress("Opening playlist file...");
+        FileSource source(audioSource, &retrievalProgress);
+        if (!source.isAvailable()) {
+            cerr << "ERROR: File or URL \"" << audioSource.toStdString()
+                 << "\" could not be located" << endl;
+            throw FileNotFound(audioSource);
+        }
+        source.waitForData();
         PlaylistFileReader reader(source);
         if (reader.isOK()) {
             vector<QString> files = reader.load();
             for (int i = 0; i < (int)files.size(); ++i) {
-                extractFeatures(files[i]);
+                addSource(files[i]);
             }
             return;
         } else {
@@ -411,33 +407,106 @@
         }
     }
 
+    std::cerr << "Have audio source: \"" << audioSource.toStdString() << "\"" << std::endl;
+
+    // We don't actually do anything with it here, unless it's the
+    // first audio source and we need it to establish default channel
+    // count and sample rate
+
+    if (m_channels == 0 || m_defaultSampleRate == 0) {
+
+        ProgressPrinter retrievalProgress("Determining default rate and channel count from first input file...");
+
+        FileSource source(audioSource, &retrievalProgress);
+        if (!source.isAvailable()) {
+            cerr << "ERROR: File or URL \"" << audioSource.toStdString()
+                 << "\" could not be located" << endl;
+            throw FileNotFound(audioSource);
+        }
+    
+        source.waitForData();
+
+        // Open to determine validity, channel count, sample rate only
+        // (then close, and open again later with actual desired rate &c)
+
+        AudioFileReader *reader =
+            AudioFileReaderFactory::createReader(source, 0, &retrievalProgress);
+    
+        if (!reader) {
+            throw FailedToOpenFile(audioSource);
+        }
+
+        retrievalProgress.done();
+
+        cerr << "File or URL \"" << audioSource.toStdString() << "\" opened successfully" << endl;
+
+        if (m_channels == 0) {
+            m_channels = reader->getChannelCount();
+            cerr << "Taking default channel count of "
+                 << reader->getChannelCount() << " from file" << endl;
+        }
+
+        if (m_defaultSampleRate == 0) {
+            m_defaultSampleRate = reader->getNativeRate();
+            cerr << "Taking default sample rate of "
+                 << reader->getNativeRate() << "Hz from file" << endl;
+            cerr << "(Note: Default may be overridden by transforms)" << endl;
+        }
+
+        m_readyReaders[audioSource] = reader;
+    }
+}
+
+void FeatureExtractionManager::extractFeatures(QString audioSource)
+{
+    if (m_plugins.empty()) return;
+
+    testOutputFiles(audioSource);
+
     if (m_sampleRate == 0) {
-        cerr << "ERROR: Internal error in FeatureExtractionManager::extractFeatures: Plugin list is non-empty, but no sample rate set" << endl;
-        exit(1);
+        throw FileOperationFailed
+            (audioSource, "internal error: have sources and plugins, but no sample rate");
+    }
+    if (m_channels == 0) {
+        throw FileOperationFailed
+            (audioSource, "internal error: have sources and plugins, but no channel count");
     }
 
-    AudioFileReader *reader =
-        AudioFileReaderFactory::createReader(source, m_sampleRate, &retrievalProgress);
-    
+    AudioFileReader *reader = 0;
+
+    if (m_readyReaders.contains(audioSource)) {
+        reader = m_readyReaders[audioSource];
+        m_readyReaders.remove(audioSource);
+        if (reader->getChannelCount() != m_channels ||
+            reader->getSampleRate() != m_sampleRate) {
+            // can't use this; open it again
+            delete reader;
+            reader = 0;
+        }
+    }
+    if (!reader) {
+        ProgressPrinter retrievalProgress("Retrieving audio data...");
+        FileSource source(audioSource, &retrievalProgress);
+        source.waitForData();
+        reader = AudioFileReaderFactory::createReader
+            (source, m_sampleRate, &retrievalProgress);
+        retrievalProgress.done();
+    }
+
     if (!reader) {
         throw FailedToOpenFile(audioSource);
     }
 
-    size_t channels = reader->getChannelCount();
+    cerr << "Audio file \"" << audioSource.toStdString() << "\": "
+         << reader->getChannelCount() << "ch at " 
+         << reader->getNativeRate() << "Hz" << endl;
+    if (reader->getChannelCount() != m_channels ||
+        reader->getNativeRate() != m_sampleRate) {
+        cerr << "NOTE: File will be mixed or resampled for processing: "
+             << m_channels << "ch at " 
+             << m_sampleRate << "Hz" << endl;
+    }
 
-    retrievalProgress.done();
-
-    cerr << "Opened " << channels << "-channel file or URL \"" << audioSource.toStdString() << "\"" << endl;
-
-    // reject file if it has too few channels
-    if ((int)channels < m_channels) {
-        delete reader;
-        throw FileOperationFailed
-            (audioSource,
-             QString("read sufficient channels (found %1, require %2)")
-             .arg(channels).arg(m_channels));
-    }
-    
     // allocate audio buffers
     float **data = new float *[m_channels];
     for (int c = 0; c < m_channels; ++c) {
--- a/runner/FeatureExtractionManager.h	Mon Oct 18 14:16:17 2010 +0100
+++ b/runner/FeatureExtractionManager.h	Mon Oct 18 14:17:48 2010 +0100
@@ -20,6 +20,8 @@
 #include <set>
 #include <string>
 
+#include <QMap>
+
 #include <vamp-hostsdk/Plugin.h>
 #include <vamp-hostsdk/PluginSummarisingAdapter.h>
 #include <transform/Transform.h>
@@ -31,6 +33,7 @@
 using std::map;
 
 class FeatureWriter;
+class AudioFileReader;
 
 class FeatureExtractionManager
 {
@@ -54,6 +57,7 @@
     bool addDefaultFeatureExtractor(TransformId transformId,
                                     const vector<FeatureWriter*> &writers);
 
+    void addSource(QString audioSource);
     void extractFeatures(QString audioSource);
 
 private:
@@ -105,6 +109,8 @@
     int m_defaultSampleRate;
     int m_sampleRate;
     int m_channels;
+
+    QMap<QString, AudioFileReader *> m_readyReaders;
     
     void print(Transform transform) const;
 };
--- a/runner/main.cpp	Mon Oct 18 14:16:17 2010 +0100
+++ b/runner/main.cpp	Mon Oct 18 14:17:48 2010 +0100
@@ -23,6 +23,7 @@
 #include <QString>
 #include <QFileInfo>
 #include <QDir>
+#include <QSet>
 
 using std::cout;
 using std::cerr;
@@ -587,17 +588,6 @@
         }
     }
     
-    // the manager dictates the sample rate and number of channels
-    // to work at - files with too few channels are rejected,
-    // too many channels are handled as usual by the Vamp plugin
-
-    //!!! Review this: although we probably do want to fix the channel
-    // count here, we don't necessarily want to fix the rate: it's
-    // specified in the Transform file.
-
-    manager.setDefaultSampleRate(44100);
-    manager.setChannels(1);
-    
     vector<FeatureWriter *> writers;
 
     for (set<string>::const_iterator i = requestedWriterTags.begin();
@@ -682,27 +672,6 @@
         }
     }
 
-    bool haveFeatureExtractor = false;
-    
-    for (set<string>::const_iterator i = requestedTransformFiles.begin();
-         i != requestedTransformFiles.end(); ++i) {
-        if (manager.addFeatureExtractorFromFile(i->c_str(), writers)) {
-            haveFeatureExtractor = true;
-        }
-    }
-
-    for (set<string>::const_iterator i = requestedDefaultTransforms.begin();
-         i != requestedDefaultTransforms.end(); ++i) {
-        if (manager.addDefaultFeatureExtractor(i->c_str(), writers)) {
-            haveFeatureExtractor = true;
-        }
-    }
-
-    if (!haveFeatureExtractor) {
-        cerr << myname.toStdString() << ": no feature extractors added" << endl;
-        exit(2);
-    }
-
     QStringList sources;
     if (!recursive) {
         sources = otherArgs;
@@ -721,17 +690,16 @@
     }
 
     bool good = true;
+    QSet<QString> badSources;
 
     for (QStringList::const_iterator i = sources.begin();
          i != sources.end(); ++i) {
-        std::cerr << "Extracting features for: \"" << i->toStdString() << "\"" << std::endl;
         try {
-            manager.extractFeatures(*i);
+            manager.addSource(*i);
         } catch (const std::exception &e) {
+            badSources.insert(*i);
             cerr << "ERROR: Failed to process file \"" << i->toStdString()
                  << "\": " << e.what() << endl;
-            cerr << "NOTE: If you want to continue with processing any further files after an" << endl
-                 << "error like this, use the --force option" << endl;
             if (force) {
                 // print a note only if we have more files to process
                 QStringList::const_iterator j = i;
@@ -739,11 +707,63 @@
                     cerr << "NOTE: \"--force\" option was provided, continuing (more errors may occur)" << endl;
                 }
             } else {
+                cerr << "NOTE: If you want to continue with processing any further files after an" << endl
+                     << "error like this, use the --force option" << endl;
                 good = false;
                 break;
             }
         }
     }
+
+    if (good) {
+    
+        bool haveFeatureExtractor = false;
+    
+        for (set<string>::const_iterator i = requestedTransformFiles.begin();
+             i != requestedTransformFiles.end(); ++i) {
+            if (manager.addFeatureExtractorFromFile(i->c_str(), writers)) {
+                haveFeatureExtractor = true;
+            }
+        }
+
+        for (set<string>::const_iterator i = requestedDefaultTransforms.begin();
+             i != requestedDefaultTransforms.end(); ++i) {
+            if (manager.addDefaultFeatureExtractor(i->c_str(), writers)) {
+                haveFeatureExtractor = true;
+            }
+        }
+
+        if (!haveFeatureExtractor) {
+            cerr << myname.toStdString() << ": no feature extractors added" << endl;
+            good = false;
+        }
+    }
+
+    if (good) {
+        for (QStringList::const_iterator i = sources.begin();
+             i != sources.end(); ++i) {
+            if (badSources.contains(*i)) continue;
+            std::cerr << "Extracting features for: \"" << i->toStdString() << "\"" << std::endl;
+            try {
+                manager.extractFeatures(*i);
+            } catch (const std::exception &e) {
+                cerr << "ERROR: Feature extraction failed for \"" << i->toStdString()
+                     << "\": " << e.what() << endl;
+                if (force) {
+                    // print a note only if we have more files to process
+                    QStringList::const_iterator j = i;
+                    if (++j != sources.end()) {
+                        cerr << "NOTE: \"--force\" option was provided, continuing (more errors may occur)" << endl;
+                    }
+                } else {
+                    cerr << "NOTE: If you want to continue with processing any further files after an" << endl
+                         << "error like this, use the --force option" << endl;
+                    good = false;
+                    break;
+                }
+            }
+        }
+    }
     
     for (int i = 0; i < writers.size(); ++i) delete writers[i];