Mercurial > hg > auditok
comparison tests/test_AudioReader.py @ 330:9665dc53c394
Rename test_AudioDataSource.py as test_AudioReader.py and refactor it
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Wed, 23 Oct 2019 21:24:33 +0100 |
parents | |
children | 8220dfaa03c6 |
comparison
equal
deleted
inserted
replaced
329:e0c7ae720cc6 | 330:9665dc53c394 |
---|---|
1 """ | |
2 @author: Amine Sehili <amine.sehili@gmail.com> | |
3 September 2015 | |
4 | |
5 """ | |
6 | |
7 import unittest | |
8 from functools import partial | |
9 import sys | |
10 import wave | |
11 from genty import genty, genty_dataset | |
12 from auditok import ( | |
13 dataset, | |
14 ADSFactory, | |
15 AudioDataSource, | |
16 AudioReader, | |
17 Recorder, | |
18 BufferAudioSource, | |
19 WaveAudioSource, | |
20 DuplicateArgument, | |
21 ) | |
22 | |
23 | |
class TestADSFactoryFileAudioSource(unittest.TestCase):
    """Tests for AudioDataSource objects built by ``ADSFactory.ads`` on top
    of a WAV file source.

    Data read through the AudioDataSource (optionally with max_time,
    record and/or hop_size behaviors) is compared against data read
    directly from an independent source opened on the same file.
    """

    def setUp(self):
        self.audio_source = WaveAudioSource(
            filename=dataset.one_to_six_arabic_16000_mono_bc_noise
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _read_blocks(ads):
        """Call ``ads.read()`` until it returns None and return the blocks.

        Does not open or close *ads*; the caller is responsible for that.
        """
        blocks = []
        while True:
            block = ads.read()
            if block is None:
                break
            blocks.append(block)
        return blocks

    @staticmethod
    def _read_reference_samples(nb_samples):
        """Return the first *nb_samples* samples of the test file, read with
        a fresh WaveAudioSource independent of the source under test.
        """
        audio_source = WaveAudioSource(
            filename=dataset.one_to_six_arabic_16000_mono_bc_noise
        )
        audio_source.open()
        try:
            return audio_source.read(nb_samples)
        finally:
            audio_source.close()

    @staticmethod
    def _open_reference_buffer_source(ads):
        """Load the whole test file into an opened BufferAudioSource.

        The buffer source uses the audio parameters of *ads*. Its
        ``position`` can then be set by hand to emulate overlapping reads.
        The caller must close the returned source.
        """
        fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r")
        try:
            wave_data = fp.readframes(fp.getnframes())
        finally:
            fp.close()
        audio_source = BufferAudioSource(
            wave_data, ads.sampling_rate, ads.sample_width, ads.channels
        )
        audio_source.open()
        return audio_source

    def _assert_rewind_replays_overlapping_blocks(
        self, ads, block_size, hop_size
    ):
        """Read *ads* to exhaustion, rewind it, then check every replayed
        block against ``block_size`` samples read from the file at offset
        ``j * hop_size``. Closes *ads* when done.
        """
        ads.open()
        nb_blocks = len(self._read_blocks(ads))
        ads.rewind()

        audio_source = self._open_reference_buffer_source(ads)
        for j in range(nb_blocks):
            expected = audio_source.read(block_size)
            self.assertEqual(
                ads.read(),
                expected,
                "Unexpected block (N={0}) read from OverlapADS".format(j),
            )
            audio_source.position = (j + 1) * hop_size

        ads.close()
        audio_source.close()

    def _assert_overlap_read_limit(self, ads, max_time, block_size, hop_size):
        """Check the amount of distinct data delivered by an overlapping,
        time-limited *ads*.

        Two consecutive full blocks share ``block_size - hop_size`` samples.
        That shared part is subtracted from every block but the first, so
        the accumulated total counts each delivered file sample once. The
        total must equal the number of samples covered by *max_time*.
        """
        # Bytes shared by two consecutive (full) blocks.
        cache_size = (block_size - hop_size) * ads.sample_width * ads.channels
        total_samples = round(ads.sampling_rate * max_time)
        expected_read_bytes = total_samples * ads.sw * ads.channels

        ads.open()
        blocks = self._read_blocks(ads)
        ads.close()

        # Starting at cache_size makes the first block count in full.
        total_read = cache_size + sum(
            len(block) - cache_size for block in blocks
        )

        err_msg = "Wrong data length read from LimiterADS, expected: {0}, "
        err_msg += "found: {1}"
        self.assertEqual(
            total_read,
            expected_read_bytes,
            err_msg.format(expected_read_bytes, total_read),
        )

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_ADS_type(self):
        ads = ADSFactory.ads(audio_source=self.audio_source)

        err_msg = "wrong type for ads object, expected: 'AudioDataSource', "
        err_msg += "found: {0}"
        self.assertIsInstance(ads, AudioDataSource, err_msg.format(type(ads)))

    def test_default_block_size(self):
        ads = ADSFactory.ads(audio_source=self.audio_source)
        size = ads.block_size
        self.assertEqual(
            size,
            160,
            "Wrong default block_size, expected: 160, found: {0}".format(size),
        )

    def test_block_size(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, block_size=512)
        size = ads.block_size
        self.assertEqual(
            size,
            512,
            "Wrong block_size, expected: 512, found: {0}".format(size),
        )

        # with alias keyword
        ads = ADSFactory.ads(audio_source=self.audio_source, bs=160)
        size = ads.block_size
        self.assertEqual(
            size,
            160,
            "Wrong block_size, expected: 160, found: {0}".format(size),
        )

    def test_block_duration(self):
        # 10 ms at 16 kHz -> 160 samples
        ads = ADSFactory.ads(audio_source=self.audio_source, block_dur=0.01)
        size = ads.block_size
        self.assertEqual(
            size,
            160,
            "Wrong block_size, expected: 160, found: {0}".format(size),
        )

        # with alias keyword: 25 ms -> 400 samples
        ads = ADSFactory.ads(audio_source=self.audio_source, bd=0.025)
        size = ads.block_size
        self.assertEqual(
            size,
            400,
            "Wrong block_size, expected: 400, found: {0}".format(size),
        )

    def test_hop_duration(self):
        # 20 ms blocks with a 10 ms hop -> hop_size of 160 samples
        ads = ADSFactory.ads(
            audio_source=self.audio_source, block_dur=0.02, hop_dur=0.01
        )
        size = ads.hop_size
        self.assertEqual(
            size, 160, "Wrong hop_size, expected: 160, found: {0}".format(size)
        )

        # with block_dur alias keyword 'bd': 15 ms hop -> 240 samples
        ads = ADSFactory.ads(
            audio_source=self.audio_source, bd=0.025, hop_dur=0.015
        )
        size = ads.hop_size
        self.assertEqual(
            size, 240, "Wrong hop_size, expected: 240, found: {0}".format(size)
        )

    def test_sampling_rate(self):
        ads = ADSFactory.ads(audio_source=self.audio_source)
        srate = ads.sampling_rate
        self.assertEqual(
            srate,
            16000,
            "Wrong sampling rate, expected: 16000, found: {0}".format(srate),
        )

    def test_sample_width(self):
        ads = ADSFactory.ads(audio_source=self.audio_source)
        swidth = ads.sample_width
        self.assertEqual(
            swidth,
            2,
            "Wrong sample width, expected: 2, found: {0}".format(swidth),
        )

    def test_channels(self):
        ads = ADSFactory.ads(audio_source=self.audio_source)
        channels = ads.channels
        self.assertEqual(
            channels,
            1,
            "Wrong number of channels, expected: 1, found: {0}".format(
                channels
            ),
        )

    def test_read(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, block_size=256)
        ads.open()
        ads_data = ads.read()
        ads.close()

        expected = self._read_reference_samples(256)
        self.assertEqual(ads_data, expected, "Unexpected data read from ads")

    def test_Limiter_Deco_read(self):
        # read a maximum of 0.75 seconds from audio source
        ads = ADSFactory.ads(audio_source=self.audio_source, max_time=0.75)
        ads.open()
        ads_data = b"".join(self._read_blocks(ads))
        ads.close()

        expected = self._read_reference_samples(int(16000 * 0.75))
        self.assertEqual(
            ads_data, expected, "Unexpected data read from LimiterADS"
        )

    def test_Limiter_Deco_read_limit(self):
        # read a maximum of 1.191 seconds from audio source
        max_time = 1.191
        ads = ADSFactory.ads(audio_source=self.audio_source, max_time=max_time)
        # Without overlap, the limiter must deliver exactly the samples
        # covered by max_time (the last block may be shorter).
        total_samples = round(ads.sampling_rate * max_time)
        expected_read_bytes = total_samples * ads.sw * ads.channels

        ads.open()
        total_read = sum(len(block) for block in self._read_blocks(ads))
        ads.close()

        err_msg = "Wrong data length read from LimiterADS, expected: {0}, "
        err_msg += "found: {1}"
        self.assertEqual(
            total_read,
            expected_read_bytes,
            err_msg.format(expected_read_bytes, total_read),
        )

    def test_Recorder_Deco_read(self):
        ads = ADSFactory.ads(
            audio_source=self.audio_source, record=True, block_size=500
        )

        # read at most 10 blocks
        ads_data = []
        ads.open()
        for _ in range(10):
            block = ads.read()
            if block is None:
                break
            ads_data.append(block)
        ads.close()

        expected = self._read_reference_samples(500 * 10)
        self.assertEqual(
            b"".join(ads_data),
            expected,
            "Unexpected data read from RecorderADS",
        )

    def test_Recorder_Deco_is_rewindable(self):
        ads = ADSFactory.ads(audio_source=self.audio_source, record=True)
        self.assertTrue(
            ads.rewindable, "RecorderADS.is_rewindable should return True"
        )

    def test_Recorder_Deco_rewind_and_read(self):
        ads = ADSFactory.ads(
            audio_source=self.audio_source, record=True, block_size=320
        )

        ads.open()
        for _ in range(10):
            ads.read()
        ads.rewind()

        # read all available data after rewind: the 10 recorded blocks
        # are expected to be replayed
        ads_data = b"".join(self._read_blocks(ads))
        ads.close()

        expected = self._read_reference_samples(320 * 10)
        self.assertEqual(
            ads_data,
            expected,
            "Unexpected data read from RecorderADS",
        )

    def test_Overlap_Deco_read(self):
        # Use arbitrary valid block_size and hop_size
        block_size = 1714
        hop_size = 313

        ads = ADSFactory.ads(
            audio_source=self.audio_source,
            block_size=block_size,
            hop_size=hop_size,
        )

        # Read all available overlapping blocks
        ads.open()
        ads_data = self._read_blocks(ads)
        ads.close()

        # Compare all blocks read from OverlapADS to those read from an
        # audio source with a manual position setting
        audio_source = self._open_reference_buffer_source(ads)
        for i, block in enumerate(ads_data):
            expected = audio_source.read(block_size)
            self.assertEqual(
                block,
                expected,
                "Unexpected block (N={0}) read from OverlapADS".format(i),
            )
            audio_source.position = (i + 1) * hop_size
        audio_source.close()

    def test_Limiter_Overlap_Deco_read(self):
        block_size = 256
        hop_size = 200

        ads = ADSFactory.ads(
            audio_source=self.audio_source,
            max_time=0.50,
            block_size=block_size,
            hop_size=hop_size,
        )

        # Read all available overlapping blocks
        ads.open()
        ads_data = self._read_blocks(ads)
        ads.close()

        # Compare the length of every block read from OverlapADS to the
        # length of data read from the same offset in the file
        audio_source = self._open_reference_buffer_source(ads)
        bytes_per_sample = ads.sw * ads.ch
        for i, block in enumerate(ads_data):
            expected = audio_source.read(len(block) // bytes_per_sample)
            self.assertEqual(
                len(block),
                len(expected),
                "Unexpected block (N={0}) read from OverlapADS".format(i),
            )
            audio_source.position = (i + 1) * hop_size
        audio_source.close()

    def test_Limiter_Overlap_Deco_read_limit(self):
        block_size = 313
        hop_size = 207
        max_time = 1.932
        ads = ADSFactory.ads(
            audio_source=self.audio_source,
            max_time=max_time,
            block_size=block_size,
            hop_size=hop_size,
        )
        self._assert_overlap_read_limit(ads, max_time, block_size, hop_size)

    def test_Recorder_Overlap_Deco_is_rewindable(self):
        ads = ADSFactory.ads(
            audio_source=self.audio_source,
            block_size=320,
            hop_size=160,
            record=True,
        )
        self.assertTrue(
            ads.rewindable, "RecorderADS.is_rewindable should return True"
        )

    def test_Recorder_Overlap_Deco_rewind_and_read(self):
        # Use arbitrary valid block_size and hop_size
        block_size = 1600
        hop_size = 400
        ads = ADSFactory.ads(
            audio_source=self.audio_source,
            block_size=block_size,
            hop_size=hop_size,
            record=True,
        )
        self._assert_rewind_replays_overlapping_blocks(
            ads, block_size, hop_size
        )

    def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self):
        # Use arbitrary valid block_size and hop_size
        block_size = 1600
        hop_size = 400
        ads = ADSFactory.ads(
            audio_source=self.audio_source,
            max_time=1.50,
            block_size=block_size,
            hop_size=hop_size,
            record=True,
        )
        self._assert_rewind_replays_overlapping_blocks(
            ads, block_size, hop_size
        )

    def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_limit(self):
        # Use arbitrary valid block_size and hop_size
        block_size = 1000
        hop_size = 200
        max_time = 1.317
        ads = ADSFactory.ads(
            audio_source=self.audio_source,
            max_time=max_time,
            block_size=block_size,
            hop_size=hop_size,
            record=True,
        )
        self._assert_overlap_read_limit(ads, max_time, block_size, hop_size)
576 | |
577 | |
class TestADSFactoryBufferAudioSource(unittest.TestCase):
    """Tests for AudioDataSource objects built by ``ADSFactory.ads`` on top
    of an in-memory data buffer."""

    def setUp(self):
        # 32 bytes = 16 mono samples of 2 bytes, i.e. 1 second at 16 Hz
        self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
        self.ads = ADSFactory.ads(
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            block_size=4,
        )

    def test_ADS_BAS_sampling_rate(self):
        srate = self.ads.sampling_rate
        self.assertEqual(
            srate,
            16,
            "Wrong sampling rate, expected: 16, found: {0}".format(srate),
        )

    def test_ADS_BAS_get_sample_width(self):
        swidth = self.ads.sample_width
        self.assertEqual(
            swidth,
            2,
            "Wrong sample width, expected: 2, found: {0}".format(swidth),
        )

    def test_ADS_BAS_get_channels(self):
        channels = self.ads.channels
        self.assertEqual(
            channels,
            1,
            "Wrong number of channels, expected: 1, found: {0}".format(
                channels
            ),
        )

    def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self):
        # Use arbitrary valid block_size and hop_size
        block_size = 5
        hop_size = 4

        ads = ADSFactory.ads(
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            max_time=0.80,
            block_size=block_size,
            hop_size=hop_size,
            record=True,
        )

        # Read all available overlapping blocks, counting them
        ads.open()
        nb_blocks = 0
        while ads.read() is not None:
            nb_blocks += 1

        ads.rewind()

        # Reference source on the same buffer; its position is set by hand
        audio_source = BufferAudioSource(
            self.signal, ads.sampling_rate, ads.sample_width, ads.channels
        )
        audio_source.open()

        # After rewind, each replayed block must match block_size samples
        # read from offset j * hop_size
        for j in range(nb_blocks):
            expected = audio_source.read(block_size)
            block = ads.read()
            self.assertEqual(
                block,
                expected,
                "Unexpected block '{}' (N={}) read from OverlapADS".format(
                    block, j
                ),
            )
            audio_source.position = (j + 1) * hop_size

        ads.close()
        audio_source.close()
668 | |
669 | |
class TestADSFactoryAlias(unittest.TestCase):
    """Tests for the short keyword aliases accepted by ``ADSFactory.ads``.

    The aliases covered are sr, sw, ch, bs, bd, hs, hd, fn, db, mt and rec.
    Passing an alias together with its full name, or passing both a size
    and a duration for the same quantity, must raise DuplicateArgument.
    A hop size greater than the block size must raise ValueError.
    """

    def setUp(self):
        # 32 bytes = 16 mono samples of 2 bytes, i.e. 1 second at 16 Hz
        self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"

    def test_sampling_rate_alias(self):
        ads = ADSFactory.ads(
            data_buffer=self.signal,
            sr=16,
            sample_width=2,
            channels=1,
            block_dur=0.5,
        )
        srate = ads.sampling_rate
        self.assertEqual(
            srate,
            16,
            "Wrong sampling rate, expected: 16, found: {0}".format(srate),
        )

    def test_sampling_rate_duplicate(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            sr=16,
            sampling_rate=16,
            sample_width=2,
            channels=1,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_sample_width_alias(self):
        ads = ADSFactory.ads(
            data_buffer=self.signal,
            sampling_rate=16,
            sw=2,
            channels=1,
            block_dur=0.5,
        )
        swidth = ads.sample_width
        self.assertEqual(
            swidth,
            2,
            "Wrong sample width, expected: 2, found: {0}".format(swidth),
        )

    def test_sample_width_duplicate(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            sampling_rate=16,
            sw=2,
            sample_width=2,
            channels=1,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_channels_alias(self):
        ads = ADSFactory.ads(
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            ch=1,
            block_dur=4,
        )
        channels = ads.channels
        self.assertEqual(
            channels,
            1,
            "Wrong number of channels, expected: 1, found: {0}".format(
                channels
            ),
        )

    def test_channels_duplicate(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            ch=1,
            channels=1,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_block_size_alias(self):
        ads = ADSFactory.ads(
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            bs=8,
        )
        size = ads.block_size
        self.assertEqual(
            size,
            8,
            "Wrong block_size using bs alias, expected: 8, found: {0}".format(
                size
            ),
        )

    def test_block_size_duplicate(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            bs=4,
            block_size=4,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_block_duration_alias(self):
        ads = ADSFactory.ads(
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            bd=0.75,
        )
        # 0.75 s * 16 Hz = 12 samples
        size = ads.block_size
        err_msg = "Wrong block_size set with a block_dur alias 'bd', "
        err_msg += "expected: 12, found: {0}"
        self.assertEqual(size, 12, err_msg.format(size))

    def test_block_duration_duplicate(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            bd=4,
            block_dur=4,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_block_size_duration_duplicate(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            bd=4,
            bs=12,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_hop_duration_alias(self):
        ads = ADSFactory.ads(
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            bd=0.75,
            hd=0.5,
        )
        # 0.5 s * 16 Hz = 8 samples
        size = ads.hop_size
        self.assertEqual(
            size,
            8,
            "Wrong hop_size using hd alias, expected: 8, found: {0}".format(
                size
            ),
        )

    def test_hop_duration_duplicate(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            bd=0.75,
            hd=0.5,
            hop_dur=0.5,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_hop_size_duration_duplicate(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            bs=8,
            hs=4,
            hd=1,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_hop_size_greater_than_block_size(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            bs=4,
            hs=8,
        )
        self.assertRaises(ValueError, func)

    def test_filename_duplicate(self):
        func = partial(
            ADSFactory.ads,
            fn=dataset.one_to_six_arabic_16000_mono_bc_noise,
            filename=dataset.one_to_six_arabic_16000_mono_bc_noise,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_data_buffer_duplicate(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            db=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_max_time_alias(self):
        ads = ADSFactory.ads(
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            mt=10,
            block_dur=0.5,
        )
        self.assertEqual(
            ads.max_read,
            10,
            "Wrong AudioDataSource.max_read, expected: 10, found: {}".format(
                ads.max_read
            ),
        )

    def test_max_time_duplicate(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            mt=True,
            max_time=True,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_record_alias(self):
        ads = ADSFactory.ads(
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            rec=True,
            block_dur=0.5,
        )
        self.assertTrue(
            ads.rewindable, "AudioDataSource.rewindable expected to be True"
        )

    def test_record_duplicate(self):
        func = partial(
            ADSFactory.ads,
            data_buffer=self.signal,
            sampling_rate=16,
            sample_width=2,
            channels=1,
            rec=True,
            record=True,
        )
        self.assertRaises(DuplicateArgument, func)

    def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_alias(self):
        # Use arbitrary valid block_size and hop_size
        block_size = 5
        hop_size = 4

        ads = ADSFactory.ads(
            db=self.signal,
            sr=16,
            sw=2,
            ch=1,
            mt=0.80,
            bs=block_size,
            hs=hop_size,
            rec=True,
        )

        # Read all available overlapping blocks, counting them
        ads.open()
        nb_blocks = 0
        while ads.read() is not None:
            nb_blocks += 1

        ads.rewind()

        # Reference source on the same buffer; its position is set by hand
        audio_source = BufferAudioSource(
            self.signal, ads.sampling_rate, ads.sample_width, ads.channels
        )
        audio_source.open()

        # Compare all blocks replayed by the AudioDataSource to those read
        # from the reference source at offset j * hop_size
        for j in range(nb_blocks):
            expected = audio_source.read(block_size)
            block = ads.read()
            self.assertEqual(
                block,
                expected,
                "Unexpected block (N={0}) read from OverlapADS".format(j),
            )
            audio_source.position = (j + 1) * hop_size

        ads.close()
        audio_source.close()
1003 | |
1004 | |
1005 def _read_all_data(reader): | |
1006 blocks = [] | |
1007 while True: | |
1008 data = reader.read() | |
1009 if data is None: | |
1010 break | |
1011 blocks.append(data) | |
1012 return b"".join(blocks) | |
1013 | |
1014 | |
@genty
class TestAudioReader(unittest.TestCase):
    """Tests for AudioReader and its Recorder alias.

    Each WAV test file under ``tests/data`` has a sibling ``.raw`` file
    holding the expected audio bytes.
    """

    # TODO move all tests here when backward compatibility
    # with ADSFactory is dropped

    def _get_test_files(self, file_id):
        """Return the (wav, raw) file paths of the test data for *file_id*."""
        input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id)
        input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id)
        return input_wav, input_raw

    def _read_raw(self, input_raw, size=-1):
        """Return up to *size* bytes of *input_raw* (all bytes by default)."""
        with open(input_raw, "rb") as fp:
            return fp.read(size)

    def _check_recorder(self, reader, expected):
        """Open *reader* and check that it returns *expected*.

        The check is repeated after each of 3 rewinds, also comparing with
        ``reader.data``. The reader is always closed, even on failure.
        """
        reader.open()
        try:
            data = _read_all_data(reader)
            self.assertEqual(data, expected)

            # rewind many times
            for _ in range(3):
                reader.rewind()
                data = _read_all_data(reader)
                self.assertEqual(data, expected)
                self.assertEqual(data, reader.data)
        finally:
            reader.close()

    @genty_dataset(
        mono=("mono_400", 0.5, 16000),
        multichannel=("3channel_400-800-1600", 0.5, 16000 * 3),
    )
    def test_Limiter(self, file_id, max_read, size):
        input_wav, input_raw = self._get_test_files(file_id)
        expected = self._read_raw(input_raw, size)

        reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read)
        reader.open()
        try:
            data = _read_all_data(reader)
        finally:
            reader.close()
        self.assertEqual(data, expected)

    @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",))
    def test_Recorder(self, file_id):
        input_wav, input_raw = self._get_test_files(file_id)
        expected = self._read_raw(input_raw)
        reader = AudioReader(input_wav, block_dur=0.1, record=True)
        self._check_recorder(reader, expected)

    @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",))
    def test_Recorder_alias(self, file_id):
        input_wav, input_raw = self._get_test_files(file_id)
        expected = self._read_raw(input_raw)
        reader = Recorder(input_wav, block_dur=0.1)
        self._check_recorder(reader, expected)
1076 | |
1077 | |
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()