view src/file_comparer.cpp @ 0:add35537fdbb tip

Initial import
author irh <ian.r.hobson@gmail.com>
date Thu, 25 Aug 2011 11:05:55 +0100
parents
children
line wrap: on
line source
//  Copyright 2011, Ian Hobson.
//
//  This file is part of gpsynth.
//
//  gpsynth is free software: you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation, either version 3 of the License, or
//  (at your option) any later version.
//
//  gpsynth is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License
//  along with gpsynth in the file COPYING. 
//  If not, see http://www.gnu.org/licenses/.

#include "file_comparer.hpp"

#include "boost_ex.hpp"
#include "statistics.hpp"
#include "std_ex.hpp"

#include "boost/bind.hpp"
#include "boost/algorithm/string/classification.hpp"
#include "boost/algorithm/string/split.hpp"
#include "boost/math/special_functions/fpclassify.hpp"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iostream>
#include <iterator>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

namespace {

// Maps a frequency in Hz onto a logarithmic 0-1 scale: 20Hz -> 0,
// 20480Hz -> 1, and every octave spans 0.1 (20-20480Hz is ten octaves).
// Frequencies outside 20-20480Hz are clamped to that range first.
// 44.5943 is 10 * log2(440 / 20), the offset that moves 20Hz (measured in
// octaves relative to A440, times 10) up to 0.
dsp::Value ConvertFrequency(dsp::Value frequency) {
  const dsp::Value clamped = stdx::Clamp(frequency, 20.0, 20480.0);
  return (::log2(clamped / 440.0) * 10 + 44.5943) / 100.0;
}

// Applies ConvertFrequency to every value of `frequencies`, in place.
void ConvertFrequencyList(dsp::ValueList& frequencies) {
  for (dsp::ValueList::iterator value = frequencies.begin();
       value != frequencies.end();
       ++value) {
    *value = ConvertFrequency(*value);
  }
}

} // namespace

namespace dsp {

// FeatureComparer holds the code shared by all feature comparisons. It stores
// the target file's feature data (AnalyzeTarget) and measures how far the same
// feature of the file currently loaded in the extractor is from it (Compare).
// T is the container type returned by the FeatureExtractor member function
// that produces the feature.
template<typename T>
class FeatureComparer : public FeatureComparerInterface {
public:
  // Optional in-place conversion applied to extracted feature data
  typedef void (*ConversionFunction)(T&);
  // FeatureExtractor member function that retrieves the feature data
  typedef const T& (FeatureExtractor::*FeatureFunction)();
  typedef typename T::value_type ValueT;
  typedef typename T::const_iterator IteratorT;

private:
  // The extractor that feature data is retrieved from. It is never deleted
  // here, and copies share the same pointer (see SetExtractor to rebind it).
  FeatureExtractor* extractor_;
  // A member function of FeatureExtractor that retrieves an audio feature
  FeatureFunction feature_function_;
  // Stores the target feature for comparison
  T target_buffer_;
  // Scratch space used by Compare(), kept as a member to avoid reallocations
  T comparison_buffer_;
  // Used if a feature data set needs to have a conversion applied to it
  ConversionFunction converter_;
  // the feature's ID
  int id_;
  // true if the feature should be scaled by the frame's energy
  bool scale_by_energy_;

  // Multiplies each value of `buffer` by the energy of the corresponding
  // frame of the file currently loaded in the extractor.
  // NOTE(review): assumes Energy() holds at least buffer.size() values.
  void ScaleByEnergy(T& buffer) {
    std::transform(buffer.begin(), buffer.end(),
                   extractor_->Energy().begin(),
                   buffer.begin(),
                   std::multiplies<ValueT>());
  }

public:
  // Initializers are listed in member declaration order.
  FeatureComparer(int id,
                  FeatureExtractor* extractor,
                  FeatureFunction feature_function,
                  ConversionFunction converter = NULL,
                  bool scale_by_energy = false)
  : extractor_(extractor),
    feature_function_(feature_function),
    converter_(converter),
    id_(id),
    scale_by_energy_(scale_by_energy)
  {}

  // Copies everything except the scratch comparison buffer, which carries no
  // state between calls to Compare().
  FeatureComparer(const FeatureComparer& other)
  : extractor_(other.extractor_),
    feature_function_(other.feature_function_),
    target_buffer_(other.target_buffer_),
    converter_(other.converter_),
    id_(other.id_),
    scale_by_energy_(other.scale_by_energy_)
  {}

  // Retrieves the feature from the extractor (which must hold the target
  // file), applies the conversion and energy scaling if configured, and
  // stores the result as the comparison target.
  virtual void AnalyzeTarget() {
    target_buffer_ = (extractor_->*feature_function_)();
    if (converter_ != NULL) {
      converter_(target_buffer_);
    }
    if (scale_by_energy_) {
      ScaleByEnergy(target_buffer_);
    }
  }

  // Returns the root mean square error between the loaded file's feature
  // (converted/scaled the same way as the target) and the stored target,
  // over the values present in both.
  virtual Value Compare() {
    // get the values to compare
    const T& feature = (extractor_->*feature_function_)();
    // points at the data to compare; const, so no const_cast is needed when
    // the extractor's data is used directly
    const T* feature_buffer = &feature;
    if (converter_ != NULL || scale_by_energy_) {
      // work on a copy so the extractor's data is left untouched
      comparison_buffer_ = feature;
      if (converter_ != NULL) {
        converter_(comparison_buffer_);
      }
      if (scale_by_energy_) {
        ScaleByEnergy(comparison_buffer_);
      }
      feature_buffer = &comparison_buffer_;
    }
    // only compare values present in both the buffer and the target
    IteratorT buffer_start = feature_buffer->begin();
    IteratorT buffer_end = feature_buffer->end();
    if (feature_buffer->size() > target_buffer_.size()) {
      buffer_end = buffer_start + target_buffer_.size();
    }
    // return root mean square error of buffer differences
    return std::sqrt(stats::MeanSquaredError(buffer_start,
                                             buffer_end,
                                             target_buffer_.begin()));
  }

  virtual FeatureComparerInterface* Clone() const {
    return new FeatureComparer<T>(*this);
  }

  virtual void SetExtractor(FeatureExtractor* extractor) {
    extractor_ = extractor;
  }

  virtual int ID() const { return id_; }
};

// AnalyzeTarget() specialization for per-frame vector features
// (std::vector<ValueList>). The target is copied and optionally converted,
// but not scaled by energy here; for these features energy weighting is
// applied frame by frame when comparing.
template<>
void FeatureComparer<std::vector<ValueList> >::AnalyzeTarget() {
  // copy the target's per-frame feature vectors out of the extractor
  target_buffer_ = (extractor_->*feature_function_)();
  // nothing more to do unless a conversion was requested
  if (!converter_) {
    return;
  }
  converter_(target_buffer_);
}

// Compare() specialization for features holding one ValueList per frame
// (std::vector<ValueList>). The RMSE of each pair of frames is computed, and
// the result is the average over the frames present in both the target and
// the currently loaded file. If scale_by_energy_ is set, each frame's RMSE is
// multiplied by the energy of that frame of the currently loaded file before
// averaging.
template<>
Value FeatureComparer<std::vector<ValueList> >::Compare() {
  // get the values to compare
  const std::vector<ValueList>& feature = (extractor_->*feature_function_)();
  // points at the data to compare; const, so the extractor's data can be used
  // directly without a const_cast when no conversion takes place
  const std::vector<ValueList>* feature_buffer = &feature;
  if (converter_ != NULL) {
    // convert a copy so the extractor's data is left untouched
    comparison_buffer_ = feature;
    converter_(comparison_buffer_);
    feature_buffer = &comparison_buffer_;
  }
  // only frames present in both the target and the loaded file are compared
  const std::size_t frames_to_compare =
      std::min(feature_buffer->size(), target_buffer_.size());
  if (frames_to_compare == 0) {
    // no overlapping frames: report no error instead of computing 0 / 0,
    // which would yield NaN and poison any sum or ordering of errors
    return 0;
  }
  // take average of RMSE over all compared frames
  Value error = 0;
  if (scale_by_energy_) {
    // energy is only fetched when it's actually needed
    const ValueList& energy = extractor_->Energy();
    for (std::size_t i = 0; i < frames_to_compare; ++i) {
      const Value frame_error =
          std::sqrt(stats::MeanSquaredError((*feature_buffer)[i],
                                            target_buffer_[i]));
      error += frame_error * energy[i];
    }
  } else {
    for (std::size_t i = 0; i < frames_to_compare; ++i) {
      error += std::sqrt(stats::MeanSquaredError((*feature_buffer)[i],
                                                 target_buffer_[i]));
    }
  }
  return error / frames_to_compare;
}

  
FileComparer::FileComparer(int window_size /* = 1024 */,
                           int hop_size /* = 256 */)
: extractor_("", window_size, hop_size),
  target_duration_(0)
{
  EnableFeature(Feature::LogMagnitude);
  EnableFeature(Feature::Pitch);
}
  
FileComparer::FileComparer(const std::string& feature_list,
                           int window_size /* = 1024 */,
                           int hop_size /* = 256 */)
: extractor_("", window_size, hop_size),
  target_duration_(0)
{
  feature_names_["pitch"] = Feature::Pitch;
  feature_names_["energy"] = Feature::Energy;
  feature_names_["mfccs"] = Feature::MFCCs;
  feature_names_["dmfccs"] = Feature::DeltaMFCCs;
  feature_names_["ddmfccs"] = Feature::DoubleDeltaMFCCs;
  feature_names_["mag"] = Feature::Magnitude;
  feature_names_["logmag"] = Feature::LogMagnitude;
  feature_names_["centroid"] = Feature::SpectralCentroid;
  feature_names_["spread"] = Feature::SpectralSpread;
  feature_names_["flux"] = Feature::SpectralFlux;
  EnableFeatures(feature_list);
}
  
FileComparer::FileComparer(const FileComparer& other)
: target_file_(other.target_file_),
  target_duration_(other.target_duration_),
  extractor_("",
             other.extractor_.WindowSize(),
             other.extractor_.HopSize())
{
  foreach (const FeatureComparerPtr& pointer, other.features_) {
    // clone the feature comparer
    features_.push_back(FeatureComparerPtr(pointer->Clone()));
    features_.back()->SetExtractor(&extractor_);
  }
}
  
void FileComparer::SetFeatureExtractorSettings(int window_size, int hop_size) {
  bool reload_target = false;
  if (extractor_.WindowSize() != window_size) {
    extractor_.SetWindowSize(window_size);
    reload_target = true;
  }
  if (extractor_.HopSize() != hop_size) {
    extractor_.SetHopSize(hop_size);
    reload_target = true;
  }
  if (reload_target && !target_file_.empty()) {
    SetTargetFile(target_file_);
  }
}

void FileComparer::SetTargetFile(const std::string& target_file) {
  if (target_file_ != target_file) {
    target_file_ = target_file;
    extractor_.LoadFile(target_file);
    target_duration_ = extractor_.Duration();
    std::for_each(features_.begin(), features_.end(),
                  boost::bind(&FeatureComparerInterface::AnalyzeTarget, _1));
  }
}


// Loads `file_path` into the extractor and returns the mean of the errors
// reported by the enabled feature comparers against the target. Returns 0 if
// no features are enabled.
Value FileComparer::CompareFile(const std::string& file_path) {
  extractor_.LoadFile(file_path);
  if (features_.empty()) {
    // avoid dividing by zero when there is nothing to compare
    return 0;
  }
  // must start at zero: the errors are accumulated below
  Value error = 0;
  foreach (FeatureComparerPtr& feature, features_) {
    error += feature->Compare();
  }
  return error / features_.size();
}
  
namespace {

// Creates a FeatureComparer<T> for feature `id` that reads its data from
// `extractor` through the member function `feature`. The comparer optionally
// applies `converter` and energy scaling. The result is wrapped in a
// FeatureComparerPtr. T is deduced from the feature function's return type.
template<typename T>
FeatureComparerPtr MakeFeatureComparer(FileComparer::Feature::ID id,
                                       FeatureExtractor& extractor,
                                       const T& (FeatureExtractor::*feature)(),
                                       void (*converter)(T&) = NULL,
                                       bool scale_by_energy = false) {
  typedef FeatureComparer<T> ComparerType;
  const int comparer_id = static_cast<int>(id);
  return FeatureComparerPtr(new ComparerType(comparer_id, &extractor, feature,
                                             converter, scale_by_energy));
}

} // namespace
  
// Enables (enable == true) or disables (enable == false) the comparer for
// `feature_id`. Enabling an already enabled feature, or disabling one that
// isn't enabled, does nothing. A newly enabled feature analyzes the target
// right away if a target file has been set.
// Throws std::runtime_error if `feature_id` is unknown and is being enabled.
void FileComparer::EnableFeature(Feature::ID feature_id, bool enable) {
  // check if the requested feature is already enabled
  for (std::vector<FeatureComparerPtr>::iterator feature = features_.begin(),
       end = features_.end();
       feature != end;
       ++feature) {
    if ((*feature)->ID() == feature_id) {
      // feature enabled, if enable == false then remove the feature
      if (!enable) {
        features_.erase(feature);
      }
      return;
    }
  }
  // the feature isn't enabled, so there's nothing to disable
  if (!enable) {
    return;
  }
  // make the feature comparer
  FeatureComparerPtr feature;
  switch (feature_id) {
    case Feature::Pitch:
      feature = MakeFeatureComparer(feature_id,
                                    extractor_,
                                    &FeatureExtractor::Pitch,
                                    ConvertFrequencyList,
                                    true);
      break;
    case Feature::Energy:
      feature = MakeFeatureComparer(feature_id,
                                    extractor_,
                                    &FeatureExtractor::Energy);
      break;
    case Feature::MFCCs:
      feature = MakeFeatureComparer(feature_id,
                                    extractor_,
                                    &FeatureExtractor::MFCCs);
      break;
    case Feature::DeltaMFCCs:
      feature = MakeFeatureComparer(feature_id,
                                    extractor_,
                                    &FeatureExtractor::DeltaMFCCs);
      break;
    case Feature::DoubleDeltaMFCCs:
      feature = MakeFeatureComparer(feature_id,
                                    extractor_,
                                    &FeatureExtractor::DoubleDeltaMFCCs);
      break;
    case Feature::Magnitude:
      feature = MakeFeatureComparer(feature_id,
                                    extractor_,
                                    &FeatureExtractor::MagnitudeSpectrum);
      break;
    case Feature::LogMagnitude:
      feature = MakeFeatureComparer(feature_id,
                                    extractor_,
                                    &FeatureExtractor::LogMagnitudeSpectrum);
      break;
    case Feature::SpectralCentroid:
      feature = MakeFeatureComparer(feature_id,
                                    extractor_,
                                    &FeatureExtractor::SpectralCentroid,
                                    ConvertFrequencyList,
                                    true);
      break;
    case Feature::SpectralSpread:
      feature = MakeFeatureComparer(feature_id,
                                    extractor_,
                                    &FeatureExtractor::SpectralSpread,
                                    ConvertFrequencyList,
                                    true);
      break;
    case Feature::SpectralFlux:
      feature = MakeFeatureComparer(feature_id,
                                    extractor_,
                                    &FeatureExtractor::SpectralFlux);
      break;
    default:
      throw std::runtime_error("FileComparer::EnableFeature - Invalid ID");
  }
  // analyze the feature if we already have a target file. The extractor may
  // currently hold a file loaded by CompareFile(), so reload the target first
  // to make sure the target's data is what gets analyzed.
  if (!target_file_.empty()) {
    extractor_.LoadFile(target_file_);
    feature->AnalyzeTarget();
  }
  // store the feature
  features_.push_back(feature);
}

// Enables (or, if `enable` is false, disables) each feature ID in `features`,
// in order, via EnableFeature().
void FileComparer::EnableFeatures(const std::vector<Feature::ID>& features,
                                  bool enable /* = true */) {
  typedef std::vector<Feature::ID>::const_iterator FeatureIterator;
  for (FeatureIterator id = features.begin(); id != features.end(); ++id) {
    EnableFeature(*id, enable);
  }
}
  
void FileComparer::EnableFeatures(std::string feature_name_list,
                                  bool enable /* = true */) {
  std::vector<std::string> feature_names;
  // split the comma separated list of names
  boost::algorithm::split(feature_names,
                          feature_name_list,
                          boost::algorithm::is_any_of(", "),
                          boost::algorithm::token_compress_on);
  // enable each feature
  foreach (const std::string& feature_name, feature_names) {
    // check if the requested feature name is valid
    if (feature_names_.find(feature_name) == feature_names_.end()) {
      std::stringstream message;
      message << "FileComparer::EnableFeatures - Unknown feature '"
              << feature_name << "'";
      throw std::runtime_error(message.str());
    }
    EnableFeature(feature_names_[feature_name], enable);
  }
}
  
// Replaces the current feature set with exactly the features in `features`.
void FileComparer::SetFeatures(const std::vector<Feature::ID>& features) {
  // drop every existing comparer, then enable the requested features
  features_.clear();
  const bool enable = true;
  EnableFeatures(features, enable);
}

} // dsp namespace