comparison tests/test_AudioDataSourceFactory.py @ 2:edee860b9f61

First release on Github
author Amine Sehili <amine.sehili@gmail.com>
date Thu, 17 Sep 2015 22:01:30 +0200
parents
children 252d698ae642
comparison
equal deleted inserted replaced
1:78ba0ead5f9f 2:edee860b9f61
1 '''
2 @author: Amine Sehili <amine.sehili@gmail.com>
3 September 2015
4
5 '''
6
7 import unittest
8 from auditok import dataset, ADSFactory, BufferAudioSource, WaveAudioSource
9 import wave
10 from Crypto.Cipher.AES import block_size
11
12
class TestADSFactoryFileAudioSource(unittest.TestCase):
    """Tests for ``ADSFactory.ads`` wrapping a ``WaveAudioSource``.

    Every test reads ``dataset.one_to_six_arabic_16000_mono_bc_noise``; the
    property tests below expect it to be 16000 Hz, 2-byte samples, mono.
    """

    def setUp(self):
        self.audio_source = WaveAudioSource(filename=dataset.one_to_six_arabic_16000_mono_bc_noise)

    # ------------------------------------------------------------------
    # Helpers (prefixed with "_" so that unittest does not collect them)
    # ------------------------------------------------------------------

    @staticmethod
    def _drain(ads, max_blocks=None):
        """Read blocks from an already opened ``ads`` until ``read()`` returns None.

        :param ads: an opened audio data source.
        :param max_blocks: stop after this many blocks (no limit if None).
        :returns: the list of blocks read, in reading order.
        """
        blocks = []
        while max_blocks is None or len(blocks) < max_blocks:
            block = ads.read()
            if block is None:
                break
            blocks.append(block)
        return blocks

    def _read_all_blocks(self, ads):
        """Open ``ads``, read every available block and close it, even on error."""
        ads.open()
        try:
            return self._drain(ads)
        finally:
            ads.close()

    @staticmethod
    def _read_from_file(nb_samples):
        """Return the first ``nb_samples`` samples of the dataset file, read directly."""
        audio_source = WaveAudioSource(filename=dataset.one_to_six_arabic_16000_mono_bc_noise)
        audio_source.open()
        try:
            return audio_source.read(nb_samples)
        finally:
            audio_source.close()

    @staticmethod
    def _nb_bytes(ads, nb_samples):
        """Convert a number of samples to bytes using the audio parameters of ``ads``."""
        return nb_samples * ads.get_sample_width() * ads.get_channels()

    @staticmethod
    def _round_up(value, multiple):
        """Round ``value`` up to the next multiple of ``multiple`` (expected > 0)."""
        remainder = value % multiple
        if remainder > 0:
            value += multiple - remainder
        return value

    def _assert_overlapping_blocks(self, ads, blocks, block_size, hop_size):
        """Check ``blocks`` against overlapping windows read from the dataset file.

        The reference is the whole file loaded into a ``BufferAudioSource``
        configured like ``ads``. Block N is compared with ``block_size``
        samples read after ``set_position(N * hop_size)``.
        """
        fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r")
        try:
            wave_data = fp.readframes(fp.getnframes())
        finally:
            fp.close()

        audio_source = BufferAudioSource(wave_data, ads.get_sampling_rate(),
                                         ads.get_sample_width(), ads.get_channels())
        audio_source.open()
        try:
            for i, block in enumerate(blocks):
                expected = audio_source.read(block_size)
                self.assertEqual(block, expected, "Unexpected block (N={0}) read from OverlapADS".format(i))
                # the next window starts hop_size samples after the current one
                audio_source.set_position((i + 1) * hop_size)
        finally:
            audio_source.close()

    def _check_overlap_rewind_and_read(self, block_size, hop_size, **ads_kwargs):
        """Read all blocks, rewind, and check the replayed blocks against the file.

        ``ads_kwargs`` are extra arguments for ``ADSFactory.ads`` (e.g. ``max_time``).
        """
        ads = ADSFactory.ads(audio_source=self.audio_source, block_size=block_size,
                             hop_size=hop_size, record=True, **ads_kwargs)
        ads.open()
        try:
            # first pass: only count how many blocks are available
            nb_blocks = len(self._drain(ads))
            ads.rewind()
            # second pass: replay the same number of blocks from the recording
            replayed = [ads.read() for _ in range(nb_blocks)]
            self._assert_overlapping_blocks(ads, replayed, block_size, hop_size)
        finally:
            ads.close()

    def _check_limiter_overlap_read_limit(self, max_time, block_size, hop_size, **ads_kwargs):
        """Check the amount of new data read by a Limiter + Overlap ADS.

        With both decorators, N blocks of actual data are read: one full block
        of ``block_size`` samples, then N - 1 hops of ``hop_size`` new samples.
        The total may therefore slightly exceed the size derived from ``max_time``.
        """
        ads = ADSFactory.ads(audio_source=self.audio_source, max_time=max_time,
                             block_size=block_size, hop_size=hop_size, **ads_kwargs)

        # theoretical size to reach, in bytes
        desired_size = self._nb_bytes(ads, int(ads.get_sampling_rate() * max_time))
        block_size_bytes = self._nb_bytes(ads, block_size)
        hop_size_bytes = self._nb_bytes(ads, hop_size)
        # first full block + enough hops to cover the remainder
        expected_size = block_size_bytes + self._round_up(desired_size - block_size_bytes,
                                                          hop_size_bytes)

        # Consecutive blocks overlap by (block_size - hop_size) samples: only
        # count the new data of each block. Starting from cache_size makes the
        # first block contribute its full length.
        cache_size = self._nb_bytes(ads, block_size - hop_size)
        total_read = cache_size + sum(len(block) - cache_size
                                      for block in self._read_all_blocks(ads))

        self.assertEqual(total_read, expected_size, "Wrong data length read from LimiterADS, expected: {0}, found: {1}".format(expected_size, total_read))

    # ------------------------------------------------------------------
    # Basic AudioDataSource
    # ------------------------------------------------------------------

    def test_ADS_type(self):
        ads = ADSFactory.ads(audio_source=self.audio_source)

        self.assertIsInstance(ads, ADSFactory.AudioDataSource,
                              msg="wrong type for ads object, expected: 'ADSFactory.AudioDataSource', found: {0}".format(type(ads)))

    def test_default_block_size(self):
        ads = ADSFactory.ads(audio_source=self.audio_source)

        size = ads.get_block_size()
        self.assertEqual(size, 160, "Wrong default block_size, expected: 160, found: {0}".format(size))

    def test_block_size(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, block_size=512)

        size = ads.get_block_size()
        self.assertEqual(size, 512, "Wrong block_size, expected: 512, found: {0}".format(size))

    def test_sampling_rate(self):
        ads = ADSFactory.ads(audio_source=self.audio_source)

        srate = ads.get_sampling_rate()
        self.assertEqual(srate, 16000, "Wrong sampling rate, expected: 16000, found: {0}".format(srate))

    def test_sample_width(self):
        ads = ADSFactory.ads(audio_source=self.audio_source)

        swidth = ads.get_sample_width()
        self.assertEqual(swidth, 2, "Wrong sample width, expected: 2, found: {0}".format(swidth))

    def test_channels(self):
        ads = ADSFactory.ads(audio_source=self.audio_source)

        channels = ads.get_channels()
        self.assertEqual(channels, 1, "Wrong number of channels, expected: 1, found: {0}".format(channels))

    def test_read(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, block_size=256)

        ads.open()
        try:
            ads_data = ads.read()
        finally:
            ads.close()

        self.assertEqual(ads_data, self._read_from_file(256), "Unexpected data read from ads")

    # ------------------------------------------------------------------
    # Limiter decorator
    # ------------------------------------------------------------------

    def test_Limiter_Deco_type(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, max_time=1)

        self.assertIsInstance(ads, ADSFactory.LimiterADS,
                              msg="wrong type for ads object, expected: 'ADSFactory.LimiterADS', found: {0}".format(type(ads)))

    def test_Limiter_Deco_read(self):
        # read a maximum of 0.75 seconds from audio source
        ads = ADSFactory.ads(audio_source=self.audio_source, max_time=0.75)

        ads_data = b"".join(self._read_all_blocks(ads))

        self.assertEqual(ads_data, self._read_from_file(int(16000 * 0.75)), "Unexpected data read from LimiterADS")

    def test_Limiter_Deco_read_limit(self):
        # read a maximum of 1.191 seconds from audio source
        max_time = 1.191
        ads = ADSFactory.ads(audio_source=self.audio_source, max_time=max_time)

        # The desired duration in bytes is:
        #     int(max_time * sampling_rate) * sample_width * nb_channels
        # The Limiter deco reads whole blocks, getting as close as possible
        # to that size: it reads N blocks of block_size where
        #     (N - 1) * block_size < desired size, AND
        #     N * block_size >= desired size
        # i.e. the desired size rounded up to a multiple of the block size.
        desired_size = self._nb_bytes(ads, int(ads.get_sampling_rate() * max_time))
        block_size_bytes = self._nb_bytes(ads, ads.get_block_size())
        expected_size = self._round_up(desired_size, block_size_bytes)

        total_read = sum(len(block) for block in self._read_all_blocks(ads))

        self.assertEqual(total_read, expected_size, "Wrong data length read from LimiterADS, expected: {0}, found: {1}".format(expected_size, total_read))

    # ------------------------------------------------------------------
    # Recorder decorator
    # ------------------------------------------------------------------

    def test_Recorder_Deco_type(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, record=True)

        self.assertIsInstance(ads, ADSFactory.RecorderADS,
                              msg="wrong type for ads object, expected: 'ADSFactory.RecorderADS', found: {0}".format(type(ads)))

    def test_Recorder_Deco_read(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, record=True, block_size=500)

        ads.open()
        try:
            # read at most 10 blocks
            ads_data = b"".join(self._drain(ads, max_blocks=10))
        finally:
            ads.close()

        self.assertEqual(ads_data, self._read_from_file(500 * 10), "Unexpected data read from RecorderADS")

    def test_Recorder_Deco_is_rewindable(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, record=True)

        self.assertTrue(ads.is_rewindable(), "RecorderADS.is_rewindable should return True")

    def test_Recorder_Deco_rewind(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, record=True, block_size=320)

        ads.open()
        try:
            ads.read()
            ads.rewind()
            self.assertIsInstance(ads.get_audio_source(), BufferAudioSource,
                                  "After rewind RecorderADS.get_audio_source should be an instance of BufferAudioSource")
        finally:
            ads.close()

    def test_Recorder_Deco_rewind_and_read(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, record=True, block_size=320)

        ads.open()
        try:
            for _ in range(10):
                ads.read()
            ads.rewind()
            # after rewind, only the recorded data is available
            ads_data = b"".join(self._drain(ads))
        finally:
            ads.close()

        self.assertEqual(ads_data, self._read_from_file(320 * 10), "Unexpected data read from RecorderADS")

    # ------------------------------------------------------------------
    # Overlap decorator (alone and combined)
    # ------------------------------------------------------------------

    def test_Overlap_Deco_type(self):
        # an OverlapADS is obtained if a valid hop_size is given
        ads = ADSFactory.ads(audio_source=self.audio_source, block_size=256, hop_size=128)

        self.assertIsInstance(ads, ADSFactory.OverlapADS,
                              msg="wrong type for ads object, expected: 'ADSFactory.OverlapADS', found: {0}".format(type(ads)))

    def test_Overlap_Deco_read(self):
        # use arbitrary valid block_size and hop_size
        block_size = 1714
        hop_size = 313

        ads = ADSFactory.ads(audio_source=self.audio_source, block_size=block_size, hop_size=hop_size)
        blocks = self._read_all_blocks(ads)

        self._assert_overlapping_blocks(ads, blocks, block_size, hop_size)

    def test_Limiter_Overlap_Deco_type(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, max_time=1, block_size=256, hop_size=128)

        self.assertIsInstance(ads, ADSFactory.OverlapADS,
                              msg="wrong type for ads object, expected: 'ADSFactory.OverlapADS', found: {0}".format(type(ads)))

        self.assertIsInstance(ads.ads, ADSFactory.LimiterADS,
                              msg="wrong type for ads object, expected: 'ADSFactory.LimiterADS', found: {0}".format(type(ads.ads)))

    def test_Limiter_Overlap_Deco_read(self):
        block_size = 256
        hop_size = 200

        ads = ADSFactory.ads(audio_source=self.audio_source, max_time=0.50, block_size=block_size, hop_size=hop_size)
        blocks = self._read_all_blocks(ads)

        self._assert_overlapping_blocks(ads, blocks, block_size, hop_size)

    def test_Limiter_Overlap_Deco_read_limit(self):
        self._check_limiter_overlap_read_limit(max_time=1.932, block_size=313, hop_size=207)

    def test_Recorder_Overlap_Deco_type(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, block_size=256, hop_size=128, record=True)

        self.assertIsInstance(ads, ADSFactory.OverlapADS,
                              msg="wrong type for ads object, expected: 'ADSFactory.OverlapADS', found: {0}".format(type(ads)))

        self.assertIsInstance(ads.ads, ADSFactory.RecorderADS,
                              msg="wrong type for ads object, expected: 'ADSFactory.RecorderADS', found: {0}".format(type(ads.ads)))

    def test_Recorder_Overlap_Deco_is_rewindable(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, block_size=320, hop_size=160, record=True)
        self.assertTrue(ads.is_rewindable(), "RecorderADS.is_rewindable should return True")

    def test_Recorder_Overlap_Deco_rewind_and_read(self):
        # use arbitrary valid block_size and hop_size
        self._check_overlap_rewind_and_read(block_size=1600, hop_size=400)

    def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self):
        # use arbitrary valid block_size and hop_size
        self._check_overlap_rewind_and_read(block_size=1600, hop_size=400, max_time=1.50)

    def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_limit(self):
        self._check_limiter_overlap_read_limit(max_time=1.317, block_size=1000, hop_size=200, record=True)
490 i = 0
491 while True:
492 block = ads.read()
493 if block is None:
494 break
495 i += 1
496 total_read += len(block) - cache_size
497
498 ads.close()
499 self.assertEqual(total_read, expected_size, "Wrong data length read from LimiterADS, expected: {0}, found: {1}".format(expected_size, total_read))
500
class TestADSFactoryBufferAudioSource(unittest.TestCase):
    """Tests for ``ADSFactory.ads`` built from an in-memory data buffer.

    The 32-character signal, read as 2-byte mono samples at 16 Hz, holds
    16 samples, i.e. one second of "audio".
    """

    def setUp(self):
        self.signal = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
        self.ads = ADSFactory.ads(data_buffer=self.signal, sampling_rate=16,
                                  sample_width=2, channels=1)

    def test_ADS_BAS_type(self):
        self.assertIsInstance(self.ads.get_audio_source(), BufferAudioSource,
                              "ads should be an instance of BufferAudioSource")

    def test_ADS_BAS_sampling_rate(self):
        srate = self.ads.get_sampling_rate()
        self.assertEqual(srate, 16, "Wrong sampling rate, expected: 16, found: {0}".format(srate))

    def test_ADS_BAS_get_sample_width(self):
        swidth = self.ads.get_sample_width()
        self.assertEqual(swidth, 2, "Wrong sample width, expected: 2, found: {0}".format(swidth))

    def test_ADS_BAS_get_channels(self):
        channels = self.ads.get_channels()
        self.assertEqual(channels, 1, "Wrong number of channels, expected: 1, found: {0}".format(channels))

    def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self):
        # use arbitrary valid block_size and hop_size
        block_size = 5
        hop_size = 4

        ads = ADSFactory.ads(data_buffer=self.signal, sampling_rate=16,
                             sample_width=2, channels=1, max_time=0.80,
                             block_size=block_size, hop_size=hop_size,
                             record=True)

        ads.open()
        try:
            # first pass: only count how many blocks are available
            nb_blocks = 0
            while ads.read() is not None:
                nb_blocks += 1

            ads.rewind()

            # reference: the same signal, read window by window
            audio_source = BufferAudioSource(self.signal, ads.get_sampling_rate(),
                                             ads.get_sample_width(), ads.get_channels())
            audio_source.open()
            try:
                # Compare each replayed block with the reference window
                # starting at j * hop_size.
                for j in range(nb_blocks):
                    expected = audio_source.read(block_size)
                    block = ads.read()
                    self.assertEqual(block, expected, "Unexpected block (N={0}) read from OverlapADS".format(j))
                    audio_source.set_position((j + 1) * hop_size)
            finally:
                audio_source.close()
        finally:
            ads.close()
557
558 tmp = audio_source.read(block_size)
559
560 block = ads.read()
561
562 self.assertEqual(block, tmp, "Unexpected block (N={0}) read from OverlapADS".format(i))
563 audio_source.set_position((j+1) * hop_size)
564
565 ads.close()
566 audio_source.close()
567
568
569 if __name__ == "__main__":
570 #import sys;sys.argv = ['', 'Test.testName']
571 unittest.main()