comparison src/file_comparer.cpp @ 0:add35537fdbb tip

Initial import
author irh <ian.r.hobson@gmail.com>
date Thu, 25 Aug 2011 11:05:55 +0100
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:add35537fdbb
1 // Copyright 2011, Ian Hobson.
2 //
3 // This file is part of gpsynth.
4 //
5 // gpsynth is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // gpsynth is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 // GNU General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with gpsynth in the file COPYING.
17 // If not, see http://www.gnu.org/licenses/.
18
19 #include "file_comparer.hpp"
20
21 #include "boost_ex.hpp"
22 #include "statistics.hpp"
23 #include "std_ex.hpp"
24
25 #include "boost/bind.hpp"
26 #include "boost/algorithm/string/classification.hpp"
27 #include "boost/algorithm/string/split.hpp"
28 #include "boost/math/special_functions/fpclassify.hpp"
29
30 #include <algorithm>
31 #include <cmath>
32 #include <cstddef>
33 #include <iostream>
34 #include <iterator>
35
36 // maps the frequency range 20-20480Hz to the 0-1 range on a log scale
37 // with 0.1 = 1 octave
38 namespace {
39
40 dsp::Value ConvertFrequency(dsp::Value frequency) {
41 frequency = stdx::Clamp(frequency, 20.0, 20480.0);
42 return (::log2(frequency / 440.0) * 10 + 44.5943) / 100.0;
43 }
44
45 void ConvertFrequencyList(dsp::ValueList& frequencies) {
46 std::transform(frequencies.begin(), frequencies.end(), frequencies.begin(),
47 &ConvertFrequency);
48 }
49
50 } // namespace
51
52 namespace dsp {
53
54 // The FeatureComparer class abstracts common code required by feature set
55 // comparisons.
56 template<typename T>
57 class FeatureComparer : public FeatureComparerInterface {
58 public:
59 typedef void (*ConversionFunction)(T&);
60 typedef const T& (FeatureExtractor::*FeatureFunction)();
61 typedef typename T::value_type ValueT;
62 typedef typename T::const_iterator IteratorT;
63
64 private:
65 //
66 FeatureExtractor* extractor_;
67 // A member function of FeatureExtractor that retrieves an audio feature
68 FeatureFunction feature_function_;
69 // Stores the target feature for comparison
70 T target_buffer_;
71 // Temporary use stored as member to prevent reallocations
72 T comparison_buffer_;
73 // Used if a feature data set needs to have a conversion applied to it
74 ConversionFunction converter_;
75 // the feature's ID
76 int id_;
77 // true if the feature should be scaled by the frame's energy
78 bool scale_by_energy_;
79
80 public:
81 FeatureComparer(int id,
82 FeatureExtractor* extractor,
83 FeatureFunction feature_function,
84 ConversionFunction converter = NULL,
85 bool scale_by_energy = false)
86 : id_(id),
87 extractor_(extractor),
88 feature_function_(feature_function),
89 converter_(converter),
90 scale_by_energy_(scale_by_energy)
91 {}
92
93 // copy constructor
94 FeatureComparer(const FeatureComparer& other)
95 : id_(other.id_),
96 extractor_(other.extractor_),
97 feature_function_(other.feature_function_),
98 converter_(other.converter_),
99 scale_by_energy_(other.scale_by_energy_),
100 target_buffer_(other.target_buffer_) // copy the target value buffer
101 // no need to copy comparison buffer
102 {}
103
104 virtual void AnalyzeTarget() {
105 // retrieve the target feature
106 target_buffer_ = (extractor_->*feature_function_)();
107 if (converter_ != NULL) {
108 converter_(target_buffer_);
109 }
110 if (scale_by_energy_) {
111 std::transform(target_buffer_.begin(), target_buffer_.end(),
112 extractor_->Energy().begin(),
113 target_buffer_.begin(),
114 std::multiplies<ValueT>());
115 }
116 }
117
118 virtual Value Compare() {
119 // get the values to compare
120 const ValueList& feature = (extractor_->*feature_function_)();
121 ValueList* feature_buffer;
122 if (converter_ != NULL || scale_by_energy_) {
123 comparison_buffer_ = feature;
124 if (converter_ != NULL) {
125 converter_(comparison_buffer_);
126 }
127 if (scale_by_energy_) {
128 std::transform(comparison_buffer_.begin(), comparison_buffer_.end(),
129 extractor_->Energy().begin(),
130 comparison_buffer_.begin(),
131 std::multiplies<ValueT>());
132 }
133 feature_buffer = &comparison_buffer_;
134 } else {
135 feature_buffer = const_cast<ValueList*>(&feature);
136 }
137 // call conversion function
138
139 // find range of buffer to compare against target
140 IteratorT buffer_start = feature_buffer->begin();
141 IteratorT buffer_end;
142 if (feature_buffer->size() > target_buffer_.size()) {
143 buffer_end = buffer_start + target_buffer_.size();
144 } else {
145 buffer_end = feature_buffer->end();
146 }
147 // return root mean square error of buffer differences
148 return std::sqrt(stats::MeanSquaredError(buffer_start,
149 buffer_end,
150 target_buffer_.begin()));
151 }
152
153 virtual FeatureComparerInterface* Clone() const {
154 return new FeatureComparer<T>(*this);
155 }
156
157 virtual void SetExtractor(FeatureExtractor* extractor) {
158 extractor_ = extractor;
159 }
160
161 virtual int ID() const { return id_; }
162 };
163
164 // specializations for std::vector<ValueList>
165 template<>
166 void FeatureComparer<std::vector<ValueList> >::AnalyzeTarget() {
167 target_buffer_ = (extractor_->*feature_function_)();
168 if (converter_ != NULL) {
169 converter_(target_buffer_);
170 }
171 // for vector features we can scale by energy on a frame by frame basis
172 }
173
174 template<>
175 Value FeatureComparer<std::vector<ValueList> >::Compare() {
176 // get the values to compare
177 const std::vector<ValueList>& feature = (extractor_->*feature_function_)();
178 std::vector<ValueList>* feature_buffer;
179 // call conversion function
180 if (converter_ != NULL) {
181 comparison_buffer_ = feature;
182 converter_(comparison_buffer_);
183 feature_buffer = &comparison_buffer_;
184 } else {
185 // this isn't pretty, but it's the only way to allow avoiding a copy
186 // when no conversion is taking place..
187 feature_buffer = const_cast<std::vector<ValueList>*>(&feature);
188 }
189 // find how many frames to compare
190 std::size_t frames_to_compare;
191 if (feature_buffer->size() > target_buffer_.size()) {
192 frames_to_compare = target_buffer_.size();
193 } else {
194 frames_to_compare = feature_buffer->size();
195 }
196 // take average of RMSE over all frames
197 Value error = 0;
198 const ValueList& energy = extractor_->Energy();
199 if (scale_by_energy_) {
200 for (std::size_t i = 0; i < frames_to_compare; ++i) {
201 Value frame_error;
202 frame_error = std::sqrt(stats::MeanSquaredError(feature_buffer->at(i),
203 target_buffer_[i]));
204 error += frame_error * energy[i];
205 }
206 } else {
207 for (std::size_t i = 0; i < frames_to_compare; ++i) {
208 error += std::sqrt(stats::MeanSquaredError(feature_buffer->at(i),
209 target_buffer_[i]));
210 }
211 }
212 return error / frames_to_compare;
213 }
214
215
216 FileComparer::FileComparer(int window_size /* = 1024 */,
217 int hop_size /* = 256 */)
218 : extractor_("", window_size, hop_size),
219 target_duration_(0)
220 {
221 EnableFeature(Feature::LogMagnitude);
222 EnableFeature(Feature::Pitch);
223 }
224
225 FileComparer::FileComparer(const std::string& feature_list,
226 int window_size /* = 1024 */,
227 int hop_size /* = 256 */)
228 : extractor_("", window_size, hop_size),
229 target_duration_(0)
230 {
231 feature_names_["pitch"] = Feature::Pitch;
232 feature_names_["energy"] = Feature::Energy;
233 feature_names_["mfccs"] = Feature::MFCCs;
234 feature_names_["dmfccs"] = Feature::DeltaMFCCs;
235 feature_names_["ddmfccs"] = Feature::DoubleDeltaMFCCs;
236 feature_names_["mag"] = Feature::Magnitude;
237 feature_names_["logmag"] = Feature::LogMagnitude;
238 feature_names_["centroid"] = Feature::SpectralCentroid;
239 feature_names_["spread"] = Feature::SpectralSpread;
240 feature_names_["flux"] = Feature::SpectralFlux;
241 EnableFeatures(feature_list);
242 }
243
244 FileComparer::FileComparer(const FileComparer& other)
245 : target_file_(other.target_file_),
246 target_duration_(other.target_duration_),
247 extractor_("",
248 other.extractor_.WindowSize(),
249 other.extractor_.HopSize())
250 {
251 foreach (const FeatureComparerPtr& pointer, other.features_) {
252 // clone the feature comparer
253 features_.push_back(FeatureComparerPtr(pointer->Clone()));
254 features_.back()->SetExtractor(&extractor_);
255 }
256 }
257
258 void FileComparer::SetFeatureExtractorSettings(int window_size, int hop_size) {
259 bool reload_target = false;
260 if (extractor_.WindowSize() != window_size) {
261 extractor_.SetWindowSize(window_size);
262 reload_target = true;
263 }
264 if (extractor_.HopSize() != hop_size) {
265 extractor_.SetHopSize(hop_size);
266 reload_target = true;
267 }
268 if (reload_target && !target_file_.empty()) {
269 SetTargetFile(target_file_);
270 }
271 }
272
273 void FileComparer::SetTargetFile(const std::string& target_file) {
274 if (target_file_ != target_file) {
275 target_file_ = target_file;
276 extractor_.LoadFile(target_file);
277 target_duration_ = extractor_.Duration();
278 std::for_each(features_.begin(), features_.end(),
279 boost::bind(&FeatureComparerInterface::AnalyzeTarget, _1));
280 }
281 }
282
283
284 Value FileComparer::CompareFile(const std::string& file_path) {
285 extractor_.LoadFile(file_path);
286 Value error;
287 foreach (FeatureComparerPtr& feature, features_) {
288 error += feature->Compare();
289 }
290 return error / features_.size();
291 }
292
293 namespace {
294
295 // Helper function for creating FeatureComparers
296 template<typename T>
297 FeatureComparerPtr MakeFeatureComparer(FileComparer::Feature::ID id,
298 FeatureExtractor& extractor,
299 const T& (FeatureExtractor::*feature)(),
300 void (*converter)(T&) = NULL,
301 bool scale_by_energy = false) {
302 return FeatureComparerPtr(new FeatureComparer<T>(static_cast<int>(id),
303 &extractor,
304 feature,
305 converter,
306 scale_by_energy));
307 }
308
309 } // namespace
310
311 void FileComparer::EnableFeature(Feature::ID feature_id, bool enable) {
312 // check if the requested feature is already enabled
313 for (std::vector<FeatureComparerPtr>::iterator feature = features_.begin(),
314 end = features_.end();
315 feature != end;
316 ++feature) {
317 if ((*feature)->ID() == feature_id) {
318 // feature enabled, if enable == false then remove the feature
319 if (!enable) {
320 features_.erase(feature);
321 }
322 return;
323 }
324 }
325 // make the feature comparer
326 FeatureComparerPtr feature;
327 switch (feature_id) {
328 case Feature::Pitch:
329 feature = MakeFeatureComparer(feature_id,
330 extractor_,
331 &FeatureExtractor::Pitch,
332 ConvertFrequencyList,
333 true);
334 break;
335 case Feature::Energy:
336 feature = MakeFeatureComparer(feature_id,
337 extractor_,
338 &FeatureExtractor::Energy);
339 break;
340 case Feature::MFCCs:
341 feature = MakeFeatureComparer(feature_id,
342 extractor_,
343 &FeatureExtractor::MFCCs);
344 break;
345 case Feature::DeltaMFCCs:
346 feature = MakeFeatureComparer(feature_id,
347 extractor_,
348 &FeatureExtractor::DeltaMFCCs);
349 break;
350 case Feature::DoubleDeltaMFCCs:
351 feature = MakeFeatureComparer(feature_id,
352 extractor_,
353 &FeatureExtractor::DoubleDeltaMFCCs);
354 break;
355 case Feature::Magnitude:
356 feature = MakeFeatureComparer(feature_id,
357 extractor_,
358 &FeatureExtractor::MagnitudeSpectrum);
359 case Feature::LogMagnitude:
360 feature = MakeFeatureComparer(feature_id,
361 extractor_,
362 &FeatureExtractor::LogMagnitudeSpectrum);
363 break;
364 case Feature::SpectralCentroid:
365 feature = MakeFeatureComparer(feature_id,
366 extractor_,
367 &FeatureExtractor::SpectralCentroid,
368 ConvertFrequencyList,
369 true);
370 break;
371 case Feature::SpectralSpread:
372 feature = MakeFeatureComparer(feature_id,
373 extractor_,
374 &FeatureExtractor::SpectralSpread,
375 ConvertFrequencyList,
376 true);
377 break;
378 case Feature::SpectralFlux:
379 feature = MakeFeatureComparer(feature_id,
380 extractor_,
381 &FeatureExtractor::SpectralFlux);
382 break;
383 default:
384 throw std::runtime_error("FileComparer::EnableFeature - Invalid ID");
385 }
386 // analyze the feature if we already have a target file loaded
387 if (!target_file_.empty()) {
388 feature->AnalyzeTarget();
389 }
390 // store the feature
391 features_.push_back(feature);
392 }
393
394 void FileComparer::EnableFeatures(const std::vector<Feature::ID>& features,
395 bool enable /* = true */) {
396 foreach (Feature::ID id, features) {
397 EnableFeature(id, enable);
398 }
399 }
400
401 void FileComparer::EnableFeatures(std::string feature_name_list,
402 bool enable /* = true */) {
403 std::vector<std::string> feature_names;
404 // split the comma separated list of names
405 boost::algorithm::split(feature_names,
406 feature_name_list,
407 boost::algorithm::is_any_of(", "),
408 boost::algorithm::token_compress_on);
409 // enable each feature
410 foreach (const std::string& feature_name, feature_names) {
411 // check if the requested feature name is valid
412 if (feature_names_.find(feature_name) == feature_names_.end()) {
413 std::stringstream message;
414 message << "FileComparer::EnableFeatures - Unknown feature '"
415 << feature_name << "'";
416 throw std::runtime_error(message.str());
417 }
418 EnableFeature(feature_names_[feature_name], enable);
419 }
420 }
421
422 void FileComparer::SetFeatures(const std::vector<Feature::ID>& features) {
423 features_.clear();
424 EnableFeatures(features);
425 }
426
427 } // dsp namespace