diff src/file_comparer.cpp @ 0:add35537fdbb tip

Initial import
author irh <ian.r.hobson@gmail.com>
date Thu, 25 Aug 2011 11:05:55 +0100
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/file_comparer.cpp	Thu Aug 25 11:05:55 2011 +0100
@@ -0,0 +1,427 @@
+//  Copyright 2011, Ian Hobson.
+//
+//  This file is part of gpsynth.
+//
+//  gpsynth is free software: you can redistribute it and/or modify
+//  it under the terms of the GNU General Public License as published by
+//  the Free Software Foundation, either version 3 of the License, or
+//  (at your option) any later version.
+//
+//  gpsynth is distributed in the hope that it will be useful,
+//  but WITHOUT ANY WARRANTY; without even the implied warranty of
+//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+//  GNU General Public License for more details.
+//
+//  You should have received a copy of the GNU General Public License
+//  along with gpsynth in the file COPYING. 
+//  If not, see http://www.gnu.org/licenses/.
+
+#include "file_comparer.hpp"
+
+#include "boost_ex.hpp"
+#include "statistics.hpp"
+#include "std_ex.hpp"
+
+#include "boost/bind.hpp"
+#include "boost/algorithm/string/classification.hpp"
+#include "boost/algorithm/string/split.hpp"
+#include "boost/math/special_functions/fpclassify.hpp"
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <iostream>
+#include <iterator>
+
+// maps the frequency range 20-20480Hz to the 0-1 range on a log scale
+// with 0.1 = 1 octave
+namespace {
+ 
+// Converts a frequency in Hz to a position on a logarithmic 0-1 scale.
+// The input is first clamped to 20-20480Hz (ten octaves), so each octave
+// covers 0.1 of the output range.
+dsp::Value ConvertFrequency(dsp::Value frequency) {
+  // restrict the input to the supported range
+  const dsp::Value clamped = stdx::Clamp(frequency, 20.0, 20480.0);
+  // distance from 440Hz in tenths of an octave
+  const double offset_from_reference = ::log2(clamped / 440.0) * 10;
+  // 44.5943 moves 20Hz to zero; dividing by 100 places 20480Hz at one
+  return (offset_from_reference + 44.5943) / 100.0;
+}
+  
+// Replaces every frequency in the list with its ConvertFrequency() value.
+void ConvertFrequencyList(dsp::ValueList& frequencies) {
+  for (dsp::ValueList::iterator value = frequencies.begin(),
+       end = frequencies.end();
+       value != end;
+       ++value) {
+    *value = ConvertFrequency(*value);
+  }
+}
+  
+} // namespace
+
+namespace dsp {
+
+// The FeatureComparer class abstracts common code required by feature set
+// comparisons.
+//
+// T is the container type returned by the FeatureExtractor member function
+// that provides the feature. The generic AnalyzeTarget() and Compare() treat T
+// as a flat sequence of values; std::vector<ValueList> has its own
+// specializations of both, defined after the class.
+template<typename T>
+class FeatureComparer : public FeatureComparerInterface {
+public:
+  // Optional in-place conversion applied to a feature buffer
+  typedef void (*ConversionFunction)(T&);
+  // Member function of FeatureExtractor that returns the feature
+  typedef const T& (FeatureExtractor::*FeatureFunction)();
+  typedef typename T::value_type ValueT;
+  typedef typename T::const_iterator IteratorT;
+
+private:
+  // The extractor that features are read from; never deleted by this class
+  FeatureExtractor* extractor_;
+  // A member function of FeatureExtractor that retrieves an audio feature
+  FeatureFunction feature_function_;
+  // Stores the target feature for comparison
+  T target_buffer_;
+  // Temporary use stored as member to prevent reallocations
+  T comparison_buffer_;
+  // Used if a feature data set needs to have a conversion applied to it
+  ConversionFunction converter_;
+  // the feature's ID
+  int id_;
+  // true if the feature should be scaled by the frame's energy
+  bool scale_by_energy_;
+
+  // Applies the optional conversion function and then the optional energy
+  // scaling to buffer, in place. Shared by AnalyzeTarget() and Compare() so
+  // that the target and the compared feature are always prepared identically.
+  void ApplyConversions(T& buffer) {
+    if (converter_ != NULL) {
+      converter_(buffer);
+    }
+    if (scale_by_energy_) {
+      // NOTE(review): assumes Energy() holds at least as many values as
+      // buffer - confirm against FeatureExtractor
+      std::transform(buffer.begin(), buffer.end(),
+                     extractor_->Energy().begin(),
+                     buffer.begin(),
+                     std::multiplies<ValueT>());
+    }
+  }
+
+public:
+  // id:               value reported by ID()
+  // extractor:        extractor the feature is read from
+  // feature_function: FeatureExtractor member that returns the feature
+  // converter:        optional in-place conversion, NULL for none
+  // scale_by_energy:  true to scale the feature by the extractor's Energy()
+  //
+  // The initializers are listed in member declaration order, which is the
+  // order in which they are executed.
+  FeatureComparer(int id,
+                  FeatureExtractor* extractor,
+                  FeatureFunction feature_function,
+                  ConversionFunction converter = NULL,
+                  bool scale_by_energy = false)
+  : extractor_(extractor),
+    feature_function_(feature_function),
+    converter_(converter),
+    id_(id),
+    scale_by_energy_(scale_by_energy)
+  {}
+
+  // copy constructor
+  FeatureComparer(const FeatureComparer& other)
+  : extractor_(other.extractor_),
+    feature_function_(other.feature_function_),
+    target_buffer_(other.target_buffer_), // copy the target value buffer
+    // no need to copy comparison buffer
+    converter_(other.converter_),
+    id_(other.id_),
+    scale_by_energy_(other.scale_by_energy_)
+  {}
+
+  // Stores the extractor's current feature, after conversion and energy
+  // scaling, as the target that Compare() measures against.
+  virtual void AnalyzeTarget() {
+    target_buffer_ = (extractor_->*feature_function_)();
+    ApplyConversions(target_buffer_);
+  }
+
+  // Returns the root mean square error between the extractor's current
+  // feature and the stored target. If the current feature is longer than the
+  // target, only its first target_buffer_.size() values are compared.
+  virtual Value Compare() {
+    // get the values to compare
+    const T& feature = (extractor_->*feature_function_)();
+    // read the extractor's buffer directly unless it has to be modified, in
+    // which case work on a copy held in comparison_buffer_
+    const T* feature_buffer = &feature;
+    if (converter_ != NULL || scale_by_energy_) {
+      comparison_buffer_ = feature;
+      ApplyConversions(comparison_buffer_);
+      feature_buffer = &comparison_buffer_;
+    }
+
+    // find range of buffer to compare against target
+    IteratorT buffer_start = feature_buffer->begin();
+    IteratorT buffer_end = feature_buffer->end();
+    if (feature_buffer->size() > target_buffer_.size()) {
+      buffer_end = buffer_start + target_buffer_.size();
+    }
+    // return root mean square error of buffer differences
+    return std::sqrt(stats::MeanSquaredError(buffer_start,
+                                             buffer_end,
+                                             target_buffer_.begin()));
+  }
+
+  // Returns a heap-allocated copy of this comparer; the caller takes
+  // ownership of the returned pointer.
+  virtual FeatureComparerInterface* Clone() const {
+    return new FeatureComparer<T>(*this);
+  }
+
+  // Changes the extractor that later AnalyzeTarget()/Compare() calls read.
+  virtual void SetExtractor(FeatureExtractor* extractor) {
+    extractor_ = extractor;
+  }
+
+  virtual int ID() const { return id_; }
+};
+
+// specializations for std::vector<ValueList>
+// Stores a copy of the extractor's current feature frames as the target,
+// running the conversion function over the copy when one is set.
+template<>
+void FeatureComparer<std::vector<ValueList> >::AnalyzeTarget() {
+  target_buffer_ = (extractor_->*feature_function_)();
+  // no energy scaling here: for vector features it can be applied on a
+  // frame by frame basis
+  if (converter_ == NULL) {
+    return;
+  }
+  converter_(target_buffer_);
+}
+
+// Compares the extractor's current feature frames against the stored target
+// frames and returns the per-frame root mean square error averaged over the
+// compared frames. Only as many frames as both buffers contain are compared.
+// When energy scaling is enabled each frame's error is multiplied by that
+// frame's energy before averaging.
+//
+// NOTE(review): with no frames to compare the result is 0 / 0 (presumably NaN
+// for a floating point Value) - confirm that callers expect this.
+template<>
+Value FeatureComparer<std::vector<ValueList> >::Compare() {
+  typedef std::vector<ValueList> FrameList;
+  // get the values to compare
+  const FrameList& feature = (extractor_->*feature_function_)();
+  // read the extractor's frames directly when no conversion is needed;
+  // otherwise convert a copy held in comparison_buffer_
+  const FrameList* feature_buffer = &feature;
+  if (converter_ != NULL) {
+    comparison_buffer_ = feature;
+    converter_(comparison_buffer_);
+    feature_buffer = &comparison_buffer_;
+  }
+  // find how many frames to compare
+  const std::size_t frames_to_compare = std::min(feature_buffer->size(),
+                                                 target_buffer_.size());
+  // take average of RMSE over all frames
+  Value error = 0;
+  if (scale_by_energy_) {
+    // only ask the extractor for the energy when it is actually used
+    // NOTE(review): assumes Energy() has at least frames_to_compare values -
+    // confirm against FeatureExtractor
+    const ValueList& energy = extractor_->Energy();
+    for (std::size_t i = 0; i < frames_to_compare; ++i) {
+      const Value frame_error =
+        std::sqrt(stats::MeanSquaredError((*feature_buffer)[i],
+                                          target_buffer_[i]));
+      error += frame_error * energy[i];
+    }
+  } else {
+    for (std::size_t i = 0; i < frames_to_compare; ++i) {
+      error += std::sqrt(stats::MeanSquaredError((*feature_buffer)[i],
+                                                 target_buffer_[i]));
+    }
+  }
+  return error / frames_to_compare;
+}
+
+  
+namespace {
+
+// Fills names with the textual feature names accepted by
+// FileComparer::EnableFeatures(std::string, bool), mapped to their IDs.
+// Shared by the constructors so that every FileComparer can resolve names.
+template<typename NameMap>
+void RegisterFeatureNames(NameMap& names) {
+  names["pitch"] = FileComparer::Feature::Pitch;
+  names["energy"] = FileComparer::Feature::Energy;
+  names["mfccs"] = FileComparer::Feature::MFCCs;
+  names["dmfccs"] = FileComparer::Feature::DeltaMFCCs;
+  names["ddmfccs"] = FileComparer::Feature::DoubleDeltaMFCCs;
+  names["mag"] = FileComparer::Feature::Magnitude;
+  names["logmag"] = FileComparer::Feature::LogMagnitude;
+  names["centroid"] = FileComparer::Feature::SpectralCentroid;
+  names["spread"] = FileComparer::Feature::SpectralSpread;
+  names["flux"] = FileComparer::Feature::SpectralFlux;
+}
+
+} // namespace
+
+// Creates a comparer with no target file and the LogMagnitude and Pitch
+// features enabled.
+FileComparer::FileComparer(int window_size /* = 1024 */,
+                           int hop_size /* = 256 */)
+: extractor_("", window_size, hop_size),
+  target_duration_(0)
+{
+  // the name table is needed here too, otherwise a later
+  // EnableFeatures(std::string) call would reject every feature name
+  RegisterFeatureNames(feature_names_);
+  EnableFeature(Feature::LogMagnitude);
+  EnableFeature(Feature::Pitch);
+}
+
+// Creates a comparer with no target file and the features named in
+// feature_list enabled. feature_list is a list of feature names separated by
+// commas and/or spaces; an unknown name makes EnableFeatures() throw
+// std::runtime_error.
+FileComparer::FileComparer(const std::string& feature_list,
+                           int window_size /* = 1024 */,
+                           int hop_size /* = 256 */)
+: extractor_("", window_size, hop_size),
+  target_duration_(0)
+{
+  RegisterFeatureNames(feature_names_);
+  EnableFeatures(feature_list);
+}
+  
+// Copy constructor. The copy gets its own extractor, built with the same
+// window and hop sizes but with no file loaded, and clones of the other
+// comparer's features (including their analyzed targets) that are re-pointed
+// at the copy's extractor.
+FileComparer::FileComparer(const FileComparer& other)
+: target_file_(other.target_file_),
+  target_duration_(other.target_duration_),
+  extractor_("",
+             other.extractor_.WindowSize(),
+             other.extractor_.HopSize())
+{
+  // copy the name table as well, otherwise EnableFeatures(std::string) would
+  // reject every feature name on the copy
+  feature_names_ = other.feature_names_;
+  foreach (const FeatureComparerPtr& pointer, other.features_) {
+    // clone the feature comparer
+    features_.push_back(FeatureComparerPtr(pointer->Clone()));
+    // the clone still refers to the other object's extractor
+    features_.back()->SetExtractor(&extractor_);
+  }
+}
+  
+// Updates the extractor's window and hop sizes. If either value changes and
+// a target file has been set, the target is loaded and analyzed again so that
+// the stored target features match the new settings.
+void FileComparer::SetFeatureExtractorSettings(int window_size, int hop_size) {
+  bool reload_target = false;
+  if (extractor_.WindowSize() != window_size) {
+    extractor_.SetWindowSize(window_size);
+    reload_target = true;
+  }
+  if (extractor_.HopSize() != hop_size) {
+    extractor_.SetHopSize(hop_size);
+    reload_target = true;
+  }
+  if (reload_target && !target_file_.empty()) {
+    // SetTargetFile() does nothing when given the path already stored in
+    // target_file_, so forget the stored path first to force the reload.
+    // Work from a copy because the member is cleared before the call.
+    const std::string target_file = target_file_;
+    target_file_.clear();
+    SetTargetFile(target_file);
+  }
+}
+
+// Makes target_file the comparison target: the file is loaded into the
+// extractor, its duration is recorded and every enabled feature analyzes it.
+// Nothing happens if target_file is already the current target.
+void FileComparer::SetTargetFile(const std::string& target_file) {
+  if (target_file_ == target_file) {
+    return;
+  }
+  target_file_ = target_file;
+  extractor_.LoadFile(target_file);
+  target_duration_ = extractor_.Duration();
+  foreach (FeatureComparerPtr& feature, features_) {
+    feature->AnalyzeTarget();
+  }
+}
+
+
+// Loads file_path into the extractor and returns the mean of the errors
+// reported by the enabled features' Compare() functions.
+//
+// NOTE(review): with no features enabled the result is 0 / 0 (presumably NaN
+// for a floating point Value) - confirm that callers expect this.
+Value FileComparer::CompareFile(const std::string& file_path) {
+  extractor_.LoadFile(file_path);
+  // must start at zero: the sum was previously accumulated into an
+  // uninitialized variable
+  Value error = 0;
+  foreach (FeatureComparerPtr& feature, features_) {
+    error += feature->Compare();
+  }
+  return error / features_.size();
+}
+  
+namespace {
+
+// Helper for creating FeatureComparers: allocates a FeatureComparer<T> bound
+// to the given extractor and feature accessor and hands it back wrapped in a
+// FeatureComparerPtr. T is deduced from the accessor's return type.
+template<typename T>
+FeatureComparerPtr MakeFeatureComparer(FileComparer::Feature::ID id,
+                                       FeatureExtractor& extractor,
+                                       const T& (FeatureExtractor::*feature)(),
+                                       void (*converter)(T&) = NULL,
+                                       bool scale_by_energy = false) {
+  // FeatureComparer stores its ID as a plain int
+  const int feature_id = static_cast<int>(id);
+  FeatureComparerPtr comparer(new FeatureComparer<T>(feature_id,
+                                                     &extractor,
+                                                     feature,
+                                                     converter,
+                                                     scale_by_energy));
+  return comparer;
+}
+
+} // namespace
+  
+// Enables (enable == true) or disables (enable == false) a single feature.
+// Enabling a feature that is already enabled, or disabling one that is not,
+// does nothing. A newly enabled feature analyzes the target straight away
+// when a target file has been set.
+// Throws std::runtime_error if feature_id is not a known feature.
+void FileComparer::EnableFeature(Feature::ID feature_id, bool enable) {
+  // check if the requested feature is already enabled
+  for (std::vector<FeatureComparerPtr>::iterator feature = features_.begin(),
+       end = features_.end();
+       feature != end;
+       ++feature) {
+    if ((*feature)->ID() == feature_id) {
+      // feature enabled, if enable == false then remove the feature
+      if (!enable) {
+        // returning immediately keeps the invalidated iterators unused
+        features_.erase(feature);
+      }
+      return;
+    }
+  }
+  // the feature isn't enabled, so a request to disable it has nothing to do
+  // (it must not fall through and create the feature)
+  if (!enable) {
+    return;
+  }
+  // make the feature comparer
+  FeatureComparerPtr feature;
+  switch (feature_id) {
+    case Feature::Pitch:
+      feature = MakeFeatureComparer(feature_id,
+                                    extractor_,
+                                    &FeatureExtractor::Pitch,
+                                    ConvertFrequencyList,
+                                    true);
+      break;
+    case Feature::Energy:
+      feature = MakeFeatureComparer(feature_id,
+                                    extractor_,
+                                    &FeatureExtractor::Energy);
+      break;
+    case Feature::MFCCs:
+      feature = MakeFeatureComparer(feature_id,
+                                    extractor_,
+                                    &FeatureExtractor::MFCCs);
+      break;
+    case Feature::DeltaMFCCs:
+      feature = MakeFeatureComparer(feature_id,
+                                    extractor_,
+                                    &FeatureExtractor::DeltaMFCCs);
+      break;
+    case Feature::DoubleDeltaMFCCs:
+      feature = MakeFeatureComparer(feature_id,
+                                    extractor_,
+                                    &FeatureExtractor::DoubleDeltaMFCCs);
+      break;
+    case Feature::Magnitude:
+      feature = MakeFeatureComparer(feature_id,
+                                    extractor_,
+                                    &FeatureExtractor::MagnitudeSpectrum);
+      // without this break the comparer was replaced by a LogMagnitude one
+      break;
+    case Feature::LogMagnitude:
+      feature = MakeFeatureComparer(feature_id,
+                                    extractor_,
+                                    &FeatureExtractor::LogMagnitudeSpectrum);
+      break;
+    case Feature::SpectralCentroid:
+      feature = MakeFeatureComparer(feature_id,
+                                    extractor_,
+                                    &FeatureExtractor::SpectralCentroid,
+                                    ConvertFrequencyList,
+                                    true);
+      break;
+    case Feature::SpectralSpread:
+      feature = MakeFeatureComparer(feature_id,
+                                    extractor_,
+                                    &FeatureExtractor::SpectralSpread,
+                                    ConvertFrequencyList,
+                                    true);
+      break;
+    case Feature::SpectralFlux:
+      feature = MakeFeatureComparer(feature_id,
+                                    extractor_,
+                                    &FeatureExtractor::SpectralFlux);
+      break;
+    default:
+      throw std::runtime_error("FileComparer::EnableFeature - Invalid ID");
+  }
+  // analyze the feature if we already have a target file loaded
+  // NOTE(review): this reads whatever the extractor currently holds; after a
+  // CompareFile() call that looks like the compared file rather than the
+  // target - confirm whether the target needs to be reloaded first
+  if (!target_file_.empty()) {
+    feature->AnalyzeTarget();
+  }
+  // store the feature
+  features_.push_back(feature);
+}
+
+// Calls EnableFeature() for each ID in features, in order, passing the same
+// enable flag every time.
+void FileComparer::EnableFeatures(const std::vector<Feature::ID>& features,
+                                  bool enable /* = true */) {
+  for (std::vector<Feature::ID>::const_iterator id = features.begin(),
+       end = features.end();
+       id != end;
+       ++id) {
+    EnableFeature(*id, enable);
+  }
+}
+  
+// Enables or disables the features named in feature_name_list, a list of
+// feature names separated by commas and/or spaces (e.g. "pitch, logmag").
+// An empty list, or leading/trailing separators, are accepted.
+// Throws std::runtime_error if a name is not a known feature.
+void FileComparer::EnableFeatures(std::string feature_name_list,
+                                  bool enable /* = true */) {
+  std::vector<std::string> feature_names;
+  // split the comma separated list of names
+  boost::algorithm::split(feature_names,
+                          feature_name_list,
+                          boost::algorithm::is_any_of(", "),
+                          boost::algorithm::token_compress_on);
+  // enable each feature
+  foreach (const std::string& feature_name, feature_names) {
+    // token compression merges runs of separators, but an empty list or a
+    // separator at either end still produces an empty token; it names no
+    // feature, so skip it rather than report it as unknown
+    if (feature_name.empty()) {
+      continue;
+    }
+    // check if the requested feature name is valid
+    if (feature_names_.find(feature_name) == feature_names_.end()) {
+      std::stringstream message;
+      message << "FileComparer::EnableFeatures - Unknown feature '"
+              << feature_name << "'";
+      throw std::runtime_error(message.str());
+    }
+    EnableFeature(feature_names_[feature_name], enable);
+  }
+}
+  
+// Replaces the enabled feature set: every current feature comparer is
+// discarded and the features listed in `features` are enabled instead.
+void FileComparer::SetFeatures(const std::vector<Feature::ID>& features) {
+  features_.clear();
+  // spell out the default value of the enable flag
+  EnableFeatures(features, true);
+}
+
+} // dsp namespace