Mercurial > hg > gpsynth
comparison src/file_comparer.cpp @ 0:add35537fdbb tip
Initial import
author | irh <ian.r.hobson@gmail.com> |
---|---|
date | Thu, 25 Aug 2011 11:05:55 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:add35537fdbb |
---|---|
1 // Copyright 2011, Ian Hobson. | |
2 // | |
3 // This file is part of gpsynth. | |
4 // | |
5 // gpsynth is free software: you can redistribute it and/or modify | |
6 // it under the terms of the GNU General Public License as published by | |
7 // the Free Software Foundation, either version 3 of the License, or | |
8 // (at your option) any later version. | |
9 // | |
10 // gpsynth is distributed in the hope that it will be useful, | |
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of | |
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
13 // GNU General Public License for more details. | |
14 // | |
15 // You should have received a copy of the GNU General Public License | |
16 // along with gpsynth in the file COPYING. | |
17 // If not, see http://www.gnu.org/licenses/. | |
18 | |
19 #include "file_comparer.hpp" | |
20 | |
21 #include "boost_ex.hpp" | |
22 #include "statistics.hpp" | |
23 #include "std_ex.hpp" | |
24 | |
25 #include "boost/bind.hpp" | |
26 #include "boost/algorithm/string/classification.hpp" | |
27 #include "boost/algorithm/string/split.hpp" | |
28 #include "boost/math/special_functions/fpclassify.hpp" | |
29 | |
30 #include <algorithm> | |
31 #include <cmath> | |
32 #include <cstddef> | |
33 #include <iostream> | |
34 #include <iterator> | |
35 | |
36 // maps the frequency range 20-20480Hz to the 0-1 range on a log scale | |
37 // with 0.1 = 1 octave | |
38 namespace { | |
39 | |
40 dsp::Value ConvertFrequency(dsp::Value frequency) { | |
41 frequency = stdx::Clamp(frequency, 20.0, 20480.0); | |
42 return (::log2(frequency / 440.0) * 10 + 44.5943) / 100.0; | |
43 } | |
44 | |
45 void ConvertFrequencyList(dsp::ValueList& frequencies) { | |
46 std::transform(frequencies.begin(), frequencies.end(), frequencies.begin(), | |
47 &ConvertFrequency); | |
48 } | |
49 | |
50 } // namespace | |
51 | |
52 namespace dsp { | |
53 | |
54 // The FeatureComparer class abstracts common code required by feature set | |
55 // comparisons. | |
56 template<typename T> | |
57 class FeatureComparer : public FeatureComparerInterface { | |
58 public: | |
59 typedef void (*ConversionFunction)(T&); | |
60 typedef const T& (FeatureExtractor::*FeatureFunction)(); | |
61 typedef typename T::value_type ValueT; | |
62 typedef typename T::const_iterator IteratorT; | |
63 | |
64 private: | |
65 // | |
66 FeatureExtractor* extractor_; | |
67 // A member function of FeatureExtractor that retrieves an audio feature | |
68 FeatureFunction feature_function_; | |
69 // Stores the target feature for comparison | |
70 T target_buffer_; | |
71 // Temporary use stored as member to prevent reallocations | |
72 T comparison_buffer_; | |
73 // Used if a feature data set needs to have a conversion applied to it | |
74 ConversionFunction converter_; | |
75 // the feature's ID | |
76 int id_; | |
77 // true if the feature should be scaled by the frame's energy | |
78 bool scale_by_energy_; | |
79 | |
80 public: | |
81 FeatureComparer(int id, | |
82 FeatureExtractor* extractor, | |
83 FeatureFunction feature_function, | |
84 ConversionFunction converter = NULL, | |
85 bool scale_by_energy = false) | |
86 : id_(id), | |
87 extractor_(extractor), | |
88 feature_function_(feature_function), | |
89 converter_(converter), | |
90 scale_by_energy_(scale_by_energy) | |
91 {} | |
92 | |
93 // copy constructor | |
94 FeatureComparer(const FeatureComparer& other) | |
95 : id_(other.id_), | |
96 extractor_(other.extractor_), | |
97 feature_function_(other.feature_function_), | |
98 converter_(other.converter_), | |
99 scale_by_energy_(other.scale_by_energy_), | |
100 target_buffer_(other.target_buffer_) // copy the target value buffer | |
101 // no need to copy comparison buffer | |
102 {} | |
103 | |
104 virtual void AnalyzeTarget() { | |
105 // retrieve the target feature | |
106 target_buffer_ = (extractor_->*feature_function_)(); | |
107 if (converter_ != NULL) { | |
108 converter_(target_buffer_); | |
109 } | |
110 if (scale_by_energy_) { | |
111 std::transform(target_buffer_.begin(), target_buffer_.end(), | |
112 extractor_->Energy().begin(), | |
113 target_buffer_.begin(), | |
114 std::multiplies<ValueT>()); | |
115 } | |
116 } | |
117 | |
118 virtual Value Compare() { | |
119 // get the values to compare | |
120 const ValueList& feature = (extractor_->*feature_function_)(); | |
121 ValueList* feature_buffer; | |
122 if (converter_ != NULL || scale_by_energy_) { | |
123 comparison_buffer_ = feature; | |
124 if (converter_ != NULL) { | |
125 converter_(comparison_buffer_); | |
126 } | |
127 if (scale_by_energy_) { | |
128 std::transform(comparison_buffer_.begin(), comparison_buffer_.end(), | |
129 extractor_->Energy().begin(), | |
130 comparison_buffer_.begin(), | |
131 std::multiplies<ValueT>()); | |
132 } | |
133 feature_buffer = &comparison_buffer_; | |
134 } else { | |
135 feature_buffer = const_cast<ValueList*>(&feature); | |
136 } | |
137 // call conversion function | |
138 | |
139 // find range of buffer to compare against target | |
140 IteratorT buffer_start = feature_buffer->begin(); | |
141 IteratorT buffer_end; | |
142 if (feature_buffer->size() > target_buffer_.size()) { | |
143 buffer_end = buffer_start + target_buffer_.size(); | |
144 } else { | |
145 buffer_end = feature_buffer->end(); | |
146 } | |
147 // return root mean square error of buffer differences | |
148 return std::sqrt(stats::MeanSquaredError(buffer_start, | |
149 buffer_end, | |
150 target_buffer_.begin())); | |
151 } | |
152 | |
153 virtual FeatureComparerInterface* Clone() const { | |
154 return new FeatureComparer<T>(*this); | |
155 } | |
156 | |
157 virtual void SetExtractor(FeatureExtractor* extractor) { | |
158 extractor_ = extractor; | |
159 } | |
160 | |
161 virtual int ID() const { return id_; } | |
162 }; | |
163 | |
164 // specializations for std::vector<ValueList> | |
165 template<> | |
166 void FeatureComparer<std::vector<ValueList> >::AnalyzeTarget() { | |
167 target_buffer_ = (extractor_->*feature_function_)(); | |
168 if (converter_ != NULL) { | |
169 converter_(target_buffer_); | |
170 } | |
171 // for vector features we can scale by energy on a frame by frame basis | |
172 } | |
173 | |
174 template<> | |
175 Value FeatureComparer<std::vector<ValueList> >::Compare() { | |
176 // get the values to compare | |
177 const std::vector<ValueList>& feature = (extractor_->*feature_function_)(); | |
178 std::vector<ValueList>* feature_buffer; | |
179 // call conversion function | |
180 if (converter_ != NULL) { | |
181 comparison_buffer_ = feature; | |
182 converter_(comparison_buffer_); | |
183 feature_buffer = &comparison_buffer_; | |
184 } else { | |
185 // this isn't pretty, but it's the only way to allow avoiding a copy | |
186 // when no conversion is taking place.. | |
187 feature_buffer = const_cast<std::vector<ValueList>*>(&feature); | |
188 } | |
189 // find how many frames to compare | |
190 std::size_t frames_to_compare; | |
191 if (feature_buffer->size() > target_buffer_.size()) { | |
192 frames_to_compare = target_buffer_.size(); | |
193 } else { | |
194 frames_to_compare = feature_buffer->size(); | |
195 } | |
196 // take average of RMSE over all frames | |
197 Value error = 0; | |
198 const ValueList& energy = extractor_->Energy(); | |
199 if (scale_by_energy_) { | |
200 for (std::size_t i = 0; i < frames_to_compare; ++i) { | |
201 Value frame_error; | |
202 frame_error = std::sqrt(stats::MeanSquaredError(feature_buffer->at(i), | |
203 target_buffer_[i])); | |
204 error += frame_error * energy[i]; | |
205 } | |
206 } else { | |
207 for (std::size_t i = 0; i < frames_to_compare; ++i) { | |
208 error += std::sqrt(stats::MeanSquaredError(feature_buffer->at(i), | |
209 target_buffer_[i])); | |
210 } | |
211 } | |
212 return error / frames_to_compare; | |
213 } | |
214 | |
215 | |
216 FileComparer::FileComparer(int window_size /* = 1024 */, | |
217 int hop_size /* = 256 */) | |
218 : extractor_("", window_size, hop_size), | |
219 target_duration_(0) | |
220 { | |
221 EnableFeature(Feature::LogMagnitude); | |
222 EnableFeature(Feature::Pitch); | |
223 } | |
224 | |
225 FileComparer::FileComparer(const std::string& feature_list, | |
226 int window_size /* = 1024 */, | |
227 int hop_size /* = 256 */) | |
228 : extractor_("", window_size, hop_size), | |
229 target_duration_(0) | |
230 { | |
231 feature_names_["pitch"] = Feature::Pitch; | |
232 feature_names_["energy"] = Feature::Energy; | |
233 feature_names_["mfccs"] = Feature::MFCCs; | |
234 feature_names_["dmfccs"] = Feature::DeltaMFCCs; | |
235 feature_names_["ddmfccs"] = Feature::DoubleDeltaMFCCs; | |
236 feature_names_["mag"] = Feature::Magnitude; | |
237 feature_names_["logmag"] = Feature::LogMagnitude; | |
238 feature_names_["centroid"] = Feature::SpectralCentroid; | |
239 feature_names_["spread"] = Feature::SpectralSpread; | |
240 feature_names_["flux"] = Feature::SpectralFlux; | |
241 EnableFeatures(feature_list); | |
242 } | |
243 | |
244 FileComparer::FileComparer(const FileComparer& other) | |
245 : target_file_(other.target_file_), | |
246 target_duration_(other.target_duration_), | |
247 extractor_("", | |
248 other.extractor_.WindowSize(), | |
249 other.extractor_.HopSize()) | |
250 { | |
251 foreach (const FeatureComparerPtr& pointer, other.features_) { | |
252 // clone the feature comparer | |
253 features_.push_back(FeatureComparerPtr(pointer->Clone())); | |
254 features_.back()->SetExtractor(&extractor_); | |
255 } | |
256 } | |
257 | |
258 void FileComparer::SetFeatureExtractorSettings(int window_size, int hop_size) { | |
259 bool reload_target = false; | |
260 if (extractor_.WindowSize() != window_size) { | |
261 extractor_.SetWindowSize(window_size); | |
262 reload_target = true; | |
263 } | |
264 if (extractor_.HopSize() != hop_size) { | |
265 extractor_.SetHopSize(hop_size); | |
266 reload_target = true; | |
267 } | |
268 if (reload_target && !target_file_.empty()) { | |
269 SetTargetFile(target_file_); | |
270 } | |
271 } | |
272 | |
273 void FileComparer::SetTargetFile(const std::string& target_file) { | |
274 if (target_file_ != target_file) { | |
275 target_file_ = target_file; | |
276 extractor_.LoadFile(target_file); | |
277 target_duration_ = extractor_.Duration(); | |
278 std::for_each(features_.begin(), features_.end(), | |
279 boost::bind(&FeatureComparerInterface::AnalyzeTarget, _1)); | |
280 } | |
281 } | |
282 | |
283 | |
284 Value FileComparer::CompareFile(const std::string& file_path) { | |
285 extractor_.LoadFile(file_path); | |
286 Value error; | |
287 foreach (FeatureComparerPtr& feature, features_) { | |
288 error += feature->Compare(); | |
289 } | |
290 return error / features_.size(); | |
291 } | |
292 | |
293 namespace { | |
294 | |
295 // Helper function for creating FeatureComparers | |
296 template<typename T> | |
297 FeatureComparerPtr MakeFeatureComparer(FileComparer::Feature::ID id, | |
298 FeatureExtractor& extractor, | |
299 const T& (FeatureExtractor::*feature)(), | |
300 void (*converter)(T&) = NULL, | |
301 bool scale_by_energy = false) { | |
302 return FeatureComparerPtr(new FeatureComparer<T>(static_cast<int>(id), | |
303 &extractor, | |
304 feature, | |
305 converter, | |
306 scale_by_energy)); | |
307 } | |
308 | |
309 } // namespace | |
310 | |
311 void FileComparer::EnableFeature(Feature::ID feature_id, bool enable) { | |
312 // check if the requested feature is already enabled | |
313 for (std::vector<FeatureComparerPtr>::iterator feature = features_.begin(), | |
314 end = features_.end(); | |
315 feature != end; | |
316 ++feature) { | |
317 if ((*feature)->ID() == feature_id) { | |
318 // feature enabled, if enable == false then remove the feature | |
319 if (!enable) { | |
320 features_.erase(feature); | |
321 } | |
322 return; | |
323 } | |
324 } | |
325 // make the feature comparer | |
326 FeatureComparerPtr feature; | |
327 switch (feature_id) { | |
328 case Feature::Pitch: | |
329 feature = MakeFeatureComparer(feature_id, | |
330 extractor_, | |
331 &FeatureExtractor::Pitch, | |
332 ConvertFrequencyList, | |
333 true); | |
334 break; | |
335 case Feature::Energy: | |
336 feature = MakeFeatureComparer(feature_id, | |
337 extractor_, | |
338 &FeatureExtractor::Energy); | |
339 break; | |
340 case Feature::MFCCs: | |
341 feature = MakeFeatureComparer(feature_id, | |
342 extractor_, | |
343 &FeatureExtractor::MFCCs); | |
344 break; | |
345 case Feature::DeltaMFCCs: | |
346 feature = MakeFeatureComparer(feature_id, | |
347 extractor_, | |
348 &FeatureExtractor::DeltaMFCCs); | |
349 break; | |
350 case Feature::DoubleDeltaMFCCs: | |
351 feature = MakeFeatureComparer(feature_id, | |
352 extractor_, | |
353 &FeatureExtractor::DoubleDeltaMFCCs); | |
354 break; | |
355 case Feature::Magnitude: | |
356 feature = MakeFeatureComparer(feature_id, | |
357 extractor_, | |
358 &FeatureExtractor::MagnitudeSpectrum); | |
359 case Feature::LogMagnitude: | |
360 feature = MakeFeatureComparer(feature_id, | |
361 extractor_, | |
362 &FeatureExtractor::LogMagnitudeSpectrum); | |
363 break; | |
364 case Feature::SpectralCentroid: | |
365 feature = MakeFeatureComparer(feature_id, | |
366 extractor_, | |
367 &FeatureExtractor::SpectralCentroid, | |
368 ConvertFrequencyList, | |
369 true); | |
370 break; | |
371 case Feature::SpectralSpread: | |
372 feature = MakeFeatureComparer(feature_id, | |
373 extractor_, | |
374 &FeatureExtractor::SpectralSpread, | |
375 ConvertFrequencyList, | |
376 true); | |
377 break; | |
378 case Feature::SpectralFlux: | |
379 feature = MakeFeatureComparer(feature_id, | |
380 extractor_, | |
381 &FeatureExtractor::SpectralFlux); | |
382 break; | |
383 default: | |
384 throw std::runtime_error("FileComparer::EnableFeature - Invalid ID"); | |
385 } | |
386 // analyze the feature if we already have a target file loaded | |
387 if (!target_file_.empty()) { | |
388 feature->AnalyzeTarget(); | |
389 } | |
390 // store the feature | |
391 features_.push_back(feature); | |
392 } | |
393 | |
394 void FileComparer::EnableFeatures(const std::vector<Feature::ID>& features, | |
395 bool enable /* = true */) { | |
396 foreach (Feature::ID id, features) { | |
397 EnableFeature(id, enable); | |
398 } | |
399 } | |
400 | |
401 void FileComparer::EnableFeatures(std::string feature_name_list, | |
402 bool enable /* = true */) { | |
403 std::vector<std::string> feature_names; | |
404 // split the comma separated list of names | |
405 boost::algorithm::split(feature_names, | |
406 feature_name_list, | |
407 boost::algorithm::is_any_of(", "), | |
408 boost::algorithm::token_compress_on); | |
409 // enable each feature | |
410 foreach (const std::string& feature_name, feature_names) { | |
411 // check if the requested feature name is valid | |
412 if (feature_names_.find(feature_name) == feature_names_.end()) { | |
413 std::stringstream message; | |
414 message << "FileComparer::EnableFeatures - Unknown feature '" | |
415 << feature_name << "'"; | |
416 throw std::runtime_error(message.str()); | |
417 } | |
418 EnableFeature(feature_names_[feature_name], enable); | |
419 } | |
420 } | |
421 | |
422 void FileComparer::SetFeatures(const std::vector<Feature::ID>& features) { | |
423 features_.clear(); | |
424 EnableFeatures(features); | |
425 } | |
426 | |
427 } // dsp namespace |