Mercurial > hg > gpsynth
view src/file_comparer.cpp @ 0:add35537fdbb tip
Initial import
author | irh <ian.r.hobson@gmail.com> |
---|---|
date | Thu, 25 Aug 2011 11:05:55 +0100 |
parents | |
children |
line wrap: on
line source
// Copyright 2011, Ian Hobson. // // This file is part of gpsynth. // // gpsynth is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // gpsynth is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with gpsynth in the file COPYING. // If not, see http://www.gnu.org/licenses/. #include "file_comparer.hpp" #include "boost_ex.hpp" #include "statistics.hpp" #include "std_ex.hpp" #include "boost/bind.hpp" #include "boost/algorithm/string/classification.hpp" #include "boost/algorithm/string/split.hpp" #include "boost/math/special_functions/fpclassify.hpp" #include <algorithm> #include <cmath> #include <cstddef> #include <iostream> #include <iterator> // maps the frequency range 20-20480Hz to the 0-1 range on a log scale // with 0.1 = 1 octave namespace { dsp::Value ConvertFrequency(dsp::Value frequency) { frequency = stdx::Clamp(frequency, 20.0, 20480.0); return (::log2(frequency / 440.0) * 10 + 44.5943) / 100.0; } void ConvertFrequencyList(dsp::ValueList& frequencies) { std::transform(frequencies.begin(), frequencies.end(), frequencies.begin(), &ConvertFrequency); } } // namespace namespace dsp { // The FeatureComparer class abstracts common code required by feature set // comparisons. template<typename T> class FeatureComparer : public FeatureComparerInterface { public: typedef void (*ConversionFunction)(T&); typedef const T& (FeatureExtractor::*FeatureFunction)(); typedef typename T::value_type ValueT; typedef typename T::const_iterator IteratorT; private: // FeatureExtractor* extractor_; // A member function of FeatureExtractor that retrieves an audio feature FeatureFunction feature_function_; // Stores the target feature for comparison T target_buffer_; // Temporary use stored as member to prevent reallocations T comparison_buffer_; // Used if a feature data set needs to have a conversion applied to it ConversionFunction converter_; // the feature's ID int id_; // true if the feature should be scaled by the frame's energy bool scale_by_energy_; public: FeatureComparer(int id, FeatureExtractor* extractor, FeatureFunction feature_function, ConversionFunction converter = NULL, bool scale_by_energy = false) : id_(id), extractor_(extractor), feature_function_(feature_function), converter_(converter), scale_by_energy_(scale_by_energy) {} // copy constructor FeatureComparer(const FeatureComparer& other) : id_(other.id_), extractor_(other.extractor_), feature_function_(other.feature_function_), converter_(other.converter_), scale_by_energy_(other.scale_by_energy_), target_buffer_(other.target_buffer_) // copy the target value buffer // no need to copy comparison buffer {} virtual void AnalyzeTarget() { // retrieve the target feature target_buffer_ = (extractor_->*feature_function_)(); if (converter_ != NULL) { converter_(target_buffer_); } if (scale_by_energy_) { std::transform(target_buffer_.begin(), target_buffer_.end(), extractor_->Energy().begin(), target_buffer_.begin(), std::multiplies<ValueT>()); } } virtual Value Compare() { // get the values to compare const ValueList& feature = (extractor_->*feature_function_)(); ValueList* feature_buffer; if (converter_ != NULL || scale_by_energy_) { comparison_buffer_ = feature; if (converter_ != NULL) { converter_(comparison_buffer_); } if (scale_by_energy_) { std::transform(comparison_buffer_.begin(), comparison_buffer_.end(), extractor_->Energy().begin(), comparison_buffer_.begin(), std::multiplies<ValueT>()); } feature_buffer = &comparison_buffer_; } else { feature_buffer = const_cast<ValueList*>(&feature); } // call conversion function // find range of buffer to compare against target IteratorT buffer_start = feature_buffer->begin(); IteratorT buffer_end; if (feature_buffer->size() > target_buffer_.size()) { buffer_end = buffer_start + target_buffer_.size(); } else { buffer_end = feature_buffer->end(); } // return root mean square error of buffer differences return std::sqrt(stats::MeanSquaredError(buffer_start, buffer_end, target_buffer_.begin())); } virtual FeatureComparerInterface* Clone() const { return new FeatureComparer<T>(*this); } virtual void SetExtractor(FeatureExtractor* extractor) { extractor_ = extractor; } virtual int ID() const { return id_; } }; // specializations for std::vector<ValueList> template<> void FeatureComparer<std::vector<ValueList> >::AnalyzeTarget() { target_buffer_ = (extractor_->*feature_function_)(); if (converter_ != NULL) { converter_(target_buffer_); } // for vector features we can scale by energy on a frame by frame basis } template<> Value FeatureComparer<std::vector<ValueList> >::Compare() { // get the values to compare const std::vector<ValueList>& feature = (extractor_->*feature_function_)(); std::vector<ValueList>* feature_buffer; // call conversion function if (converter_ != NULL) { comparison_buffer_ = feature; converter_(comparison_buffer_); feature_buffer = &comparison_buffer_; } else { // this isn't pretty, but it's the only way to allow avoiding a copy // when no conversion is taking place.. feature_buffer = const_cast<std::vector<ValueList>*>(&feature); } // find how many frames to compare std::size_t frames_to_compare; if (feature_buffer->size() > target_buffer_.size()) { frames_to_compare = target_buffer_.size(); } else { frames_to_compare = feature_buffer->size(); } // take average of RMSE over all frames Value error = 0; const ValueList& energy = extractor_->Energy(); if (scale_by_energy_) { for (std::size_t i = 0; i < frames_to_compare; ++i) { Value frame_error; frame_error = std::sqrt(stats::MeanSquaredError(feature_buffer->at(i), target_buffer_[i])); error += frame_error * energy[i]; } } else { for (std::size_t i = 0; i < frames_to_compare; ++i) { error += std::sqrt(stats::MeanSquaredError(feature_buffer->at(i), target_buffer_[i])); } } return error / frames_to_compare; } FileComparer::FileComparer(int window_size /* = 1024 */, int hop_size /* = 256 */) : extractor_("", window_size, hop_size), target_duration_(0) { EnableFeature(Feature::LogMagnitude); EnableFeature(Feature::Pitch); } FileComparer::FileComparer(const std::string& feature_list, int window_size /* = 1024 */, int hop_size /* = 256 */) : extractor_("", window_size, hop_size), target_duration_(0) { feature_names_["pitch"] = Feature::Pitch; feature_names_["energy"] = Feature::Energy; feature_names_["mfccs"] = Feature::MFCCs; feature_names_["dmfccs"] = Feature::DeltaMFCCs; feature_names_["ddmfccs"] = Feature::DoubleDeltaMFCCs; feature_names_["mag"] = Feature::Magnitude; feature_names_["logmag"] = Feature::LogMagnitude; feature_names_["centroid"] = Feature::SpectralCentroid; feature_names_["spread"] = Feature::SpectralSpread; feature_names_["flux"] = Feature::SpectralFlux; EnableFeatures(feature_list); } FileComparer::FileComparer(const FileComparer& other) : target_file_(other.target_file_), target_duration_(other.target_duration_), extractor_("", other.extractor_.WindowSize(), other.extractor_.HopSize()) { foreach (const FeatureComparerPtr& pointer, other.features_) { // clone the feature comparer features_.push_back(FeatureComparerPtr(pointer->Clone())); features_.back()->SetExtractor(&extractor_); } } void FileComparer::SetFeatureExtractorSettings(int window_size, int hop_size) { bool reload_target = false; if (extractor_.WindowSize() != window_size) { extractor_.SetWindowSize(window_size); reload_target = true; } if (extractor_.HopSize() != hop_size) { extractor_.SetHopSize(hop_size); reload_target = true; } if (reload_target && !target_file_.empty()) { SetTargetFile(target_file_); } } void FileComparer::SetTargetFile(const std::string& target_file) { if (target_file_ != target_file) { target_file_ = target_file; extractor_.LoadFile(target_file); target_duration_ = extractor_.Duration(); std::for_each(features_.begin(), features_.end(), boost::bind(&FeatureComparerInterface::AnalyzeTarget, _1)); } } Value FileComparer::CompareFile(const std::string& file_path) { extractor_.LoadFile(file_path); Value error; foreach (FeatureComparerPtr& feature, features_) { error += feature->Compare(); } return error / features_.size(); } namespace { // Helper function for creating FeatureComparers template<typename T> FeatureComparerPtr MakeFeatureComparer(FileComparer::Feature::ID id, FeatureExtractor& extractor, const T& (FeatureExtractor::*feature)(), void (*converter)(T&) = NULL, bool scale_by_energy = false) { return FeatureComparerPtr(new FeatureComparer<T>(static_cast<int>(id), &extractor, feature, converter, scale_by_energy)); } } // namespace void FileComparer::EnableFeature(Feature::ID feature_id, bool enable) { // check if the requested feature is already enabled for (std::vector<FeatureComparerPtr>::iterator feature = features_.begin(), end = features_.end(); feature != end; ++feature) { if ((*feature)->ID() == feature_id) { // feature enabled, if enable == false then remove the feature if (!enable) { features_.erase(feature); } return; } } // make the feature comparer FeatureComparerPtr feature; switch (feature_id) { case Feature::Pitch: feature = MakeFeatureComparer(feature_id, extractor_, &FeatureExtractor::Pitch, ConvertFrequencyList, true); break; case Feature::Energy: feature = MakeFeatureComparer(feature_id, extractor_, &FeatureExtractor::Energy); break; case Feature::MFCCs: feature = MakeFeatureComparer(feature_id, extractor_, &FeatureExtractor::MFCCs); break; case Feature::DeltaMFCCs: feature = MakeFeatureComparer(feature_id, extractor_, &FeatureExtractor::DeltaMFCCs); break; case Feature::DoubleDeltaMFCCs: feature = MakeFeatureComparer(feature_id, extractor_, &FeatureExtractor::DoubleDeltaMFCCs); break; case Feature::Magnitude: feature = MakeFeatureComparer(feature_id, extractor_, &FeatureExtractor::MagnitudeSpectrum); case Feature::LogMagnitude: feature = MakeFeatureComparer(feature_id, extractor_, &FeatureExtractor::LogMagnitudeSpectrum); break; case Feature::SpectralCentroid: feature = MakeFeatureComparer(feature_id, extractor_, &FeatureExtractor::SpectralCentroid, ConvertFrequencyList, true); break; case Feature::SpectralSpread: feature = MakeFeatureComparer(feature_id, extractor_, &FeatureExtractor::SpectralSpread, ConvertFrequencyList, true); break; case Feature::SpectralFlux: feature = MakeFeatureComparer(feature_id, extractor_, &FeatureExtractor::SpectralFlux); break; default: throw std::runtime_error("FileComparer::EnableFeature - Invalid ID"); } // analyze the feature if we already have a target file loaded if (!target_file_.empty()) { feature->AnalyzeTarget(); } // store the feature features_.push_back(feature); } void FileComparer::EnableFeatures(const std::vector<Feature::ID>& features, bool enable /* = true */) { foreach (Feature::ID id, features) { EnableFeature(id, enable); } } void FileComparer::EnableFeatures(std::string feature_name_list, bool enable /* = true */) { std::vector<std::string> feature_names; // split the comma separated list of names boost::algorithm::split(feature_names, feature_name_list, boost::algorithm::is_any_of(", "), boost::algorithm::token_compress_on); // enable each feature foreach (const std::string& feature_name, feature_names) { // check if the requested feature name is valid if (feature_names_.find(feature_name) == feature_names_.end()) { std::stringstream message; message << "FileComparer::EnableFeatures - Unknown feature '" << feature_name << "'"; throw std::runtime_error(message.str()); } EnableFeature(feature_names_[feature_name], enable); } } void FileComparer::SetFeatures(const std::vector<Feature::ID>& features) { features_.clear(); EnableFeatures(features); } } // dsp namespace