Mercurial > hg > gpsynth
diff src/file_comparer.cpp @ 0:add35537fdbb tip
Initial import
author | irh <ian.r.hobson@gmail.com> |
---|---|
date | Thu, 25 Aug 2011 11:05:55 +0100 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/file_comparer.cpp Thu Aug 25 11:05:55 2011 +0100 @@ -0,0 +1,427 @@ +// Copyright 2011, Ian Hobson. +// +// This file is part of gpsynth. +// +// gpsynth is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// gpsynth is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with gpsynth in the file COPYING. +// If not, see http://www.gnu.org/licenses/. + +#include "file_comparer.hpp" + +#include "boost_ex.hpp" +#include "statistics.hpp" +#include "std_ex.hpp" + +#include "boost/bind.hpp" +#include "boost/algorithm/string/classification.hpp" +#include "boost/algorithm/string/split.hpp" +#include "boost/math/special_functions/fpclassify.hpp" + +#include <algorithm> +#include <cmath> +#include <cstddef> +#include <iostream> +#include <iterator> + +// maps the frequency range 20-20480Hz to the 0-1 range on a log scale +// with 0.1 = 1 octave +namespace { + +dsp::Value ConvertFrequency(dsp::Value frequency) { + frequency = stdx::Clamp(frequency, 20.0, 20480.0); + return (::log2(frequency / 440.0) * 10 + 44.5943) / 100.0; +} + +void ConvertFrequencyList(dsp::ValueList& frequencies) { + std::transform(frequencies.begin(), frequencies.end(), frequencies.begin(), + &ConvertFrequency); +} + +} // namespace + +namespace dsp { + +// The FeatureComparer class abstracts common code required by feature set +// comparisons. +template<typename T> +class FeatureComparer : public FeatureComparerInterface { +public: + typedef void (*ConversionFunction)(T&); + typedef const T& (FeatureExtractor::*FeatureFunction)(); + typedef typename T::value_type ValueT; + typedef typename T::const_iterator IteratorT; + +private: + // + FeatureExtractor* extractor_; + // A member function of FeatureExtractor that retrieves an audio feature + FeatureFunction feature_function_; + // Stores the target feature for comparison + T target_buffer_; + // Temporary use stored as member to prevent reallocations + T comparison_buffer_; + // Used if a feature data set needs to have a conversion applied to it + ConversionFunction converter_; + // the feature's ID + int id_; + // true if the feature should be scaled by the frame's energy + bool scale_by_energy_; + +public: + FeatureComparer(int id, + FeatureExtractor* extractor, + FeatureFunction feature_function, + ConversionFunction converter = NULL, + bool scale_by_energy = false) + : id_(id), + extractor_(extractor), + feature_function_(feature_function), + converter_(converter), + scale_by_energy_(scale_by_energy) + {} + + // copy constructor + FeatureComparer(const FeatureComparer& other) + : id_(other.id_), + extractor_(other.extractor_), + feature_function_(other.feature_function_), + converter_(other.converter_), + scale_by_energy_(other.scale_by_energy_), + target_buffer_(other.target_buffer_) // copy the target value buffer + // no need to copy comparison buffer + {} + + virtual void AnalyzeTarget() { + // retrieve the target feature + target_buffer_ = (extractor_->*feature_function_)(); + if (converter_ != NULL) { + converter_(target_buffer_); + } + if (scale_by_energy_) { + std::transform(target_buffer_.begin(), target_buffer_.end(), + extractor_->Energy().begin(), + target_buffer_.begin(), + std::multiplies<ValueT>()); + } + } + + virtual Value Compare() { + // get the values to compare + const ValueList& feature = (extractor_->*feature_function_)(); + ValueList* feature_buffer; + if (converter_ != NULL || scale_by_energy_) { + comparison_buffer_ = feature; + if (converter_ != NULL) { + converter_(comparison_buffer_); + } + if (scale_by_energy_) { + std::transform(comparison_buffer_.begin(), comparison_buffer_.end(), + extractor_->Energy().begin(), + comparison_buffer_.begin(), + std::multiplies<ValueT>()); + } + feature_buffer = &comparison_buffer_; + } else { + feature_buffer = const_cast<ValueList*>(&feature); + } + // call conversion function + + // find range of buffer to compare against target + IteratorT buffer_start = feature_buffer->begin(); + IteratorT buffer_end; + if (feature_buffer->size() > target_buffer_.size()) { + buffer_end = buffer_start + target_buffer_.size(); + } else { + buffer_end = feature_buffer->end(); + } + // return root mean square error of buffer differences + return std::sqrt(stats::MeanSquaredError(buffer_start, + buffer_end, + target_buffer_.begin())); + } + + virtual FeatureComparerInterface* Clone() const { + return new FeatureComparer<T>(*this); + } + + virtual void SetExtractor(FeatureExtractor* extractor) { + extractor_ = extractor; + } + + virtual int ID() const { return id_; } +}; + +// specializations for std::vector<ValueList> +template<> +void FeatureComparer<std::vector<ValueList> >::AnalyzeTarget() { + target_buffer_ = (extractor_->*feature_function_)(); + if (converter_ != NULL) { + converter_(target_buffer_); + } + // for vector features we can scale by energy on a frame by frame basis +} + +template<> +Value FeatureComparer<std::vector<ValueList> >::Compare() { + // get the values to compare + const std::vector<ValueList>& feature = (extractor_->*feature_function_)(); + std::vector<ValueList>* feature_buffer; + // call conversion function + if (converter_ != NULL) { + comparison_buffer_ = feature; + converter_(comparison_buffer_); + feature_buffer = &comparison_buffer_; + } else { + // this isn't pretty, but it's the only way to allow avoiding a copy + // when no conversion is taking place.. + feature_buffer = const_cast<std::vector<ValueList>*>(&feature); + } + // find how many frames to compare + std::size_t frames_to_compare; + if (feature_buffer->size() > target_buffer_.size()) { + frames_to_compare = target_buffer_.size(); + } else { + frames_to_compare = feature_buffer->size(); + } + // take average of RMSE over all frames + Value error = 0; + const ValueList& energy = extractor_->Energy(); + if (scale_by_energy_) { + for (std::size_t i = 0; i < frames_to_compare; ++i) { + Value frame_error; + frame_error = std::sqrt(stats::MeanSquaredError(feature_buffer->at(i), + target_buffer_[i])); + error += frame_error * energy[i]; + } + } else { + for (std::size_t i = 0; i < frames_to_compare; ++i) { + error += std::sqrt(stats::MeanSquaredError(feature_buffer->at(i), + target_buffer_[i])); + } + } + return error / frames_to_compare; +} + + +FileComparer::FileComparer(int window_size /* = 1024 */, + int hop_size /* = 256 */) +: extractor_("", window_size, hop_size), + target_duration_(0) +{ + EnableFeature(Feature::LogMagnitude); + EnableFeature(Feature::Pitch); +} + +FileComparer::FileComparer(const std::string& feature_list, + int window_size /* = 1024 */, + int hop_size /* = 256 */) +: extractor_("", window_size, hop_size), + target_duration_(0) +{ + feature_names_["pitch"] = Feature::Pitch; + feature_names_["energy"] = Feature::Energy; + feature_names_["mfccs"] = Feature::MFCCs; + feature_names_["dmfccs"] = Feature::DeltaMFCCs; + feature_names_["ddmfccs"] = Feature::DoubleDeltaMFCCs; + feature_names_["mag"] = Feature::Magnitude; + feature_names_["logmag"] = Feature::LogMagnitude; + feature_names_["centroid"] = Feature::SpectralCentroid; + feature_names_["spread"] = Feature::SpectralSpread; + feature_names_["flux"] = Feature::SpectralFlux; + EnableFeatures(feature_list); +} + +FileComparer::FileComparer(const FileComparer& other) +: target_file_(other.target_file_), + target_duration_(other.target_duration_), + extractor_("", + other.extractor_.WindowSize(), + other.extractor_.HopSize()) +{ + foreach (const FeatureComparerPtr& pointer, other.features_) { + // clone the feature comparer + features_.push_back(FeatureComparerPtr(pointer->Clone())); + features_.back()->SetExtractor(&extractor_); + } +} + +void FileComparer::SetFeatureExtractorSettings(int window_size, int hop_size) { + bool reload_target = false; + if (extractor_.WindowSize() != window_size) { + extractor_.SetWindowSize(window_size); + reload_target = true; + } + if (extractor_.HopSize() != hop_size) { + extractor_.SetHopSize(hop_size); + reload_target = true; + } + if (reload_target && !target_file_.empty()) { + SetTargetFile(target_file_); + } +} + +void FileComparer::SetTargetFile(const std::string& target_file) { + if (target_file_ != target_file) { + target_file_ = target_file; + extractor_.LoadFile(target_file); + target_duration_ = extractor_.Duration(); + std::for_each(features_.begin(), features_.end(), + boost::bind(&FeatureComparerInterface::AnalyzeTarget, _1)); + } +} + + +Value FileComparer::CompareFile(const std::string& file_path) { + extractor_.LoadFile(file_path); + Value error; + foreach (FeatureComparerPtr& feature, features_) { + error += feature->Compare(); + } + return error / features_.size(); +} + +namespace { + +// Helper function for creating FeatureComparers +template<typename T> +FeatureComparerPtr MakeFeatureComparer(FileComparer::Feature::ID id, + FeatureExtractor& extractor, + const T& (FeatureExtractor::*feature)(), + void (*converter)(T&) = NULL, + bool scale_by_energy = false) { + return FeatureComparerPtr(new FeatureComparer<T>(static_cast<int>(id), + &extractor, + feature, + converter, + scale_by_energy)); +} + +} // namespace + +void FileComparer::EnableFeature(Feature::ID feature_id, bool enable) { + // check if the requested feature is already enabled + for (std::vector<FeatureComparerPtr>::iterator feature = features_.begin(), + end = features_.end(); + feature != end; + ++feature) { + if ((*feature)->ID() == feature_id) { + // feature enabled, if enable == false then remove the feature + if (!enable) { + features_.erase(feature); + } + return; + } + } + // make the feature comparer + FeatureComparerPtr feature; + switch (feature_id) { + case Feature::Pitch: + feature = MakeFeatureComparer(feature_id, + extractor_, + &FeatureExtractor::Pitch, + ConvertFrequencyList, + true); + break; + case Feature::Energy: + feature = MakeFeatureComparer(feature_id, + extractor_, + &FeatureExtractor::Energy); + break; + case Feature::MFCCs: + feature = MakeFeatureComparer(feature_id, + extractor_, + &FeatureExtractor::MFCCs); + break; + case Feature::DeltaMFCCs: + feature = MakeFeatureComparer(feature_id, + extractor_, + &FeatureExtractor::DeltaMFCCs); + break; + case Feature::DoubleDeltaMFCCs: + feature = MakeFeatureComparer(feature_id, + extractor_, + &FeatureExtractor::DoubleDeltaMFCCs); + break; + case Feature::Magnitude: + feature = MakeFeatureComparer(feature_id, + extractor_, + &FeatureExtractor::MagnitudeSpectrum); + case Feature::LogMagnitude: + feature = MakeFeatureComparer(feature_id, + extractor_, + &FeatureExtractor::LogMagnitudeSpectrum); + break; + case Feature::SpectralCentroid: + feature = MakeFeatureComparer(feature_id, + extractor_, + &FeatureExtractor::SpectralCentroid, + ConvertFrequencyList, + true); + break; + case Feature::SpectralSpread: + feature = MakeFeatureComparer(feature_id, + extractor_, + &FeatureExtractor::SpectralSpread, + ConvertFrequencyList, + true); + break; + case Feature::SpectralFlux: + feature = MakeFeatureComparer(feature_id, + extractor_, + &FeatureExtractor::SpectralFlux); + break; + default: + throw std::runtime_error("FileComparer::EnableFeature - Invalid ID"); + } + // analyze the feature if we already have a target file loaded + if (!target_file_.empty()) { + feature->AnalyzeTarget(); + } + // store the feature + features_.push_back(feature); +} + +void FileComparer::EnableFeatures(const std::vector<Feature::ID>& features, + bool enable /* = true */) { + foreach (Feature::ID id, features) { + EnableFeature(id, enable); + } +} + +void FileComparer::EnableFeatures(std::string feature_name_list, + bool enable /* = true */) { + std::vector<std::string> feature_names; + // split the comma separated list of names + boost::algorithm::split(feature_names, + feature_name_list, + boost::algorithm::is_any_of(", "), + boost::algorithm::token_compress_on); + // enable each feature + foreach (const std::string& feature_name, feature_names) { + // check if the requested feature name is valid + if (feature_names_.find(feature_name) == feature_names_.end()) { + std::stringstream message; + message << "FileComparer::EnableFeatures - Unknown feature '" + << feature_name << "'"; + throw std::runtime_error(message.str()); + } + EnableFeature(feature_names_[feature_name], enable); + } +} + +void FileComparer::SetFeatures(const std::vector<Feature::ID>& features) { + features_.clear(); + EnableFeatures(features); +} + +} // dsp namespace