comparison tests/test_AudioReader.py @ 400:323d59b404a2

Use pytest instead of genty
author Amine Sehili <amine.sehili@gmail.com>
date Sat, 25 May 2024 21:54:13 +0200
parents 8220dfaa03c6
children 996948ada980
comparison
equal deleted inserted replaced
399:08f893725d23 400:323d59b404a2
1 """ 1 import pytest
2 @author: Amine Sehili <amine.sehili@gmail.com>
3 September 2015
4
5 """
6
7 import unittest
8 from functools import partial 2 from functools import partial
9 import sys 3 import sys
10 import wave 4 import wave
11 from genty import genty, genty_dataset
12 from auditok import ( 5 from auditok import (
13 dataset, 6 dataset,
14 ADSFactory, 7 ADSFactory,
15 AudioDataSource, 8 AudioDataSource,
16 AudioReader, 9 AudioReader,
19 WaveAudioSource, 12 WaveAudioSource,
20 DuplicateArgument, 13 DuplicateArgument,
21 ) 14 )
22 15
23 16
24 class TestADSFactoryFileAudioSource(unittest.TestCase): 17 class TestADSFactoryFileAudioSource:
25 def setUp(self): 18 def setup_method(self):
26 self.audio_source = WaveAudioSource( 19 self.audio_source = WaveAudioSource(
27 filename=dataset.one_to_six_arabic_16000_mono_bc_noise 20 filename=dataset.one_to_six_arabic_16000_mono_bc_noise
28 ) 21 )
29 22
30 def test_ADS_type(self): 23 def test_ADS_type(self):
31
32 ads = ADSFactory.ads(audio_source=self.audio_source) 24 ads = ADSFactory.ads(audio_source=self.audio_source)
33 25 err_msg = (
34 err_msg = "wrong type for ads object, expected: 'AudioDataSource', " 26 "wrong type for ads object, expected: 'AudioDataSource', found: {0}"
35 err_msg += "found: {0}" 27 )
36 self.assertIsInstance( 28 assert isinstance(ads, AudioDataSource), err_msg.format(type(ads))
37 ads, AudioDataSource, err_msg.format(type(ads)),
38 )
39 29
40 def test_default_block_size(self): 30 def test_default_block_size(self):
41 ads = ADSFactory.ads(audio_source=self.audio_source) 31 ads = ADSFactory.ads(audio_source=self.audio_source)
42 size = ads.block_size 32 size = ads.block_size
43 self.assertEqual( 33 assert (
44 size, 34 size == 160
45 160, 35 ), "Wrong default block_size, expected: 160, found: {0}".format(size)
46 "Wrong default block_size, expected: 160, found: {0}".format(size),
47 )
48 36
49 def test_block_size(self): 37 def test_block_size(self):
50 ads = ADSFactory.ads(audio_source=self.audio_source, block_size=512) 38 ads = ADSFactory.ads(audio_source=self.audio_source, block_size=512)
51 size = ads.block_size 39 size = ads.block_size
52 self.assertEqual( 40 assert (
53 size, 41 size == 512
54 512, 42 ), "Wrong block_size, expected: 512, found: {0}".format(size)
55 "Wrong block_size, expected: 512, found: {0}".format(size),
56 )
57 43
58 # with alias keyword 44 # with alias keyword
59 ads = ADSFactory.ads(audio_source=self.audio_source, bs=160) 45 ads = ADSFactory.ads(audio_source=self.audio_source, bs=160)
60 size = ads.block_size 46 size = ads.block_size
61 self.assertEqual( 47 assert (
62 size, 48 size == 160
63 160, 49 ), "Wrong block_size, expected: 160, found: {0}".format(size)
64 "Wrong block_size, expected: 160, found: {0}".format(size),
65 )
66 50
67 def test_block_duration(self): 51 def test_block_duration(self):
68
69 ads = ADSFactory.ads( 52 ads = ADSFactory.ads(
70 audio_source=self.audio_source, block_dur=0.01 53 audio_source=self.audio_source, block_dur=0.01
71 ) # 10 ms 54 ) # 10 ms
72 size = ads.block_size 55 size = ads.block_size
73 self.assertEqual( 56 assert (
74 size, 57 size == 160
75 160, 58 ), "Wrong block_size, expected: 160, found: {0}".format(size)
76 "Wrong block_size, expected: 160, found: {0}".format(size),
77 )
78 59
79 # with alias keyword 60 # with alias keyword
80 ads = ADSFactory.ads(audio_source=self.audio_source, bd=0.025) # 25 ms 61 ads = ADSFactory.ads(audio_source=self.audio_source, bd=0.025) # 25 ms
81 size = ads.block_size 62 size = ads.block_size
82 self.assertEqual( 63 assert (
83 size, 64 size == 400
84 400, 65 ), "Wrong block_size, expected: 400, found: {0}".format(size)
85 "Wrong block_size, expected: 400, found: {0}".format(size),
86 )
87 66
88 def test_hop_duration(self): 67 def test_hop_duration(self):
89
90 ads = ADSFactory.ads( 68 ads = ADSFactory.ads(
91 audio_source=self.audio_source, block_dur=0.02, hop_dur=0.01 69 audio_source=self.audio_source, block_dur=0.02, hop_dur=0.01
92 ) # 10 ms 70 ) # 10 ms
93 size = ads.hop_size 71 size = ads.hop_size
94 self.assertEqual( 72 assert size == 160, "Wrong hop_size, expected: 160, found: {0}".format(
95 size, 160, "Wrong hop_size, expected: 160, found: {0}".format(size) 73 size
96 ) 74 )
97 75
98 # with alias keyword 76 # with alias keyword
99 ads = ADSFactory.ads( 77 ads = ADSFactory.ads(
100 audio_source=self.audio_source, bd=0.025, hop_dur=0.015 78 audio_source=self.audio_source, bd=0.025, hop_dur=0.015
101 ) # 15 ms 79 ) # 15 ms
102 size = ads.hop_size 80 size = ads.hop_size
103 self.assertEqual( 81 assert (
104 size, 82 size == 240
105 240, 83 ), "Wrong block_size, expected: 240, found: {0}".format(size)
106 "Wrong block_size, expected: 240, found: {0}".format(size),
107 )
108 84
109 def test_sampling_rate(self): 85 def test_sampling_rate(self):
110 ads = ADSFactory.ads(audio_source=self.audio_source) 86 ads = ADSFactory.ads(audio_source=self.audio_source)
111
112 srate = ads.sampling_rate 87 srate = ads.sampling_rate
113 self.assertEqual( 88 assert (
114 srate, 89 srate == 16000
115 16000, 90 ), "Wrong sampling rate, expected: 16000, found: {0}".format(srate)
116 "Wrong sampling rate, expected: 16000, found: {0}".format(srate),
117 )
118 91
119 def test_sample_width(self): 92 def test_sample_width(self):
120 ads = ADSFactory.ads(audio_source=self.audio_source) 93 ads = ADSFactory.ads(audio_source=self.audio_source)
121
122 swidth = ads.sample_width 94 swidth = ads.sample_width
123 self.assertEqual( 95 assert (
124 swidth, 96 swidth == 2
125 2, 97 ), "Wrong sample width, expected: 2, found: {0}".format(swidth)
126 "Wrong sample width, expected: 2, found: {0}".format(swidth),
127 )
128 98
129 def test_channels(self): 99 def test_channels(self):
130 ads = ADSFactory.ads(audio_source=self.audio_source) 100 ads = ADSFactory.ads(audio_source=self.audio_source)
131
132 channels = ads.channels 101 channels = ads.channels
133 self.assertEqual( 102 assert (
134 channels, 103 channels == 1
135 1, 104 ), "Wrong number of channels, expected: 1, found: {0}".format(channels)
136 "Wrong number of channels, expected: 1, found: {0}".format(
137 channels
138 ),
139 )
140 105
141 def test_read(self): 106 def test_read(self):
142 ads = ADSFactory.ads(audio_source=self.audio_source, block_size=256) 107 ads = ADSFactory.ads(audio_source=self.audio_source, block_size=256)
143
144 ads.open() 108 ads.open()
145 ads_data = ads.read() 109 ads_data = ads.read()
146 ads.close() 110 ads.close()
147 111
148 audio_source = WaveAudioSource( 112 audio_source = WaveAudioSource(
150 ) 114 )
151 audio_source.open() 115 audio_source.open()
152 audio_source_data = audio_source.read(256) 116 audio_source_data = audio_source.read(256)
153 audio_source.close() 117 audio_source.close()
154 118
155 self.assertEqual( 119 assert ads_data == audio_source_data, "Unexpected data read from ads"
156 ads_data, audio_source_data, "Unexpected data read from ads"
157 )
158 120
159 def test_Limiter_Deco_read(self): 121 def test_Limiter_Deco_read(self):
160 # read a maximum of 0.75 seconds from audio source 122 # read a maximum of 0.75 seconds from audio source
161 ads = ADSFactory.ads(audio_source=self.audio_source, max_time=0.75) 123 ads = ADSFactory.ads(audio_source=self.audio_source, max_time=0.75)
162
163 ads_data = [] 124 ads_data = []
164 ads.open() 125 ads.open()
165 while True: 126 while True:
166 block = ads.read() 127 block = ads.read()
167 if block is None: 128 if block is None:
175 ) 136 )
176 audio_source.open() 137 audio_source.open()
177 audio_source_data = audio_source.read(int(16000 * 0.75)) 138 audio_source_data = audio_source.read(int(16000 * 0.75))
178 audio_source.close() 139 audio_source.close()
179 140
180 self.assertEqual( 141 assert (
181 ads_data, audio_source_data, "Unexpected data read from LimiterADS" 142 ads_data == audio_source_data
182 ) 143 ), "Unexpected data read from LimiterADS"
183 144
184 def test_Limiter_Deco_read_limit(self): 145 def test_Limiter_Deco_read_limit(self):
185 # read a maximum of 1.191 seconds from audio source 146 # read a maximum of 1.191 seconds from audio source
186 ads = ADSFactory.ads(audio_source=self.audio_source, max_time=1.191) 147 ads = ADSFactory.ads(audio_source=self.audio_source, max_time=1.191)
187 total_samples = round(ads.sampling_rate * 1.191) 148 total_samples = round(ads.sampling_rate * 1.191)
188 nb_full_blocks, last_block_size = divmod(total_samples, ads.block_size) 149 nb_full_blocks, last_block_size = divmod(total_samples, ads.block_size)
189 total_samples_with_overlap = ( 150 total_samples_with_overlap = (
190 nb_full_blocks * ads.block_size + last_block_size 151 nb_full_blocks * ads.block_size + last_block_size
191 ) 152 )
192 expected_read_bytes = ( 153 expected_read_bytes = total_samples_with_overlap * ads.sw * ads.channels
193 total_samples_with_overlap * ads.sw * ads.channels
194 )
195 154
196 total_read = 0 155 total_read = 0
197 ads.open() 156 ads.open()
198 i = 0 157 i = 0
199 while True: 158 while True:
202 break 161 break
203 i += 1 162 i += 1
204 total_read += len(block) 163 total_read += len(block)
205 164
206 ads.close() 165 ads.close()
207 err_msg = "Wrong data length read from LimiterADS, expected: {0}, " 166 err_msg = (
208 err_msg += "found: {1}" 167 "Wrong data length read from LimiterADS, expected: {0}, found: {1}"
209 self.assertEqual( 168 )
210 total_read, 169 assert total_read == expected_read_bytes, err_msg.format(
211 expected_read_bytes, 170 expected_read_bytes, total_read
212 err_msg.format(expected_read_bytes, total_read),
213 ) 171 )
214 172
215 def test_Recorder_Deco_read(self): 173 def test_Recorder_Deco_read(self):
216 ads = ADSFactory.ads( 174 ads = ADSFactory.ads(
217 audio_source=self.audio_source, record=True, block_size=500 175 audio_source=self.audio_source, record=True, block_size=500
218 ) 176 )
219
220 ads_data = [] 177 ads_data = []
221 ads.open() 178 ads.open()
222 for i in range(10): 179 for i in range(10):
223 block = ads.read() 180 block = ads.read()
224 if block is None: 181 if block is None:
232 ) 189 )
233 audio_source.open() 190 audio_source.open()
234 audio_source_data = audio_source.read(500 * 10) 191 audio_source_data = audio_source.read(500 * 10)
235 audio_source.close() 192 audio_source.close()
236 193
237 self.assertEqual( 194 assert (
238 ads_data, 195 ads_data == audio_source_data
239 audio_source_data, 196 ), "Unexpected data read from RecorderADS"
240 "Unexpected data read from RecorderADS",
241 )
242 197
243 def test_Recorder_Deco_is_rewindable(self): 198 def test_Recorder_Deco_is_rewindable(self):
244 ads = ADSFactory.ads(audio_source=self.audio_source, record=True) 199 ads = ADSFactory.ads(audio_source=self.audio_source, record=True)
245 200 assert ads.rewindable, "RecorderADS.is_rewindable should return True"
246 self.assertTrue(
247 ads.rewindable, "RecorderADS.is_rewindable should return True"
248 )
249 201
250 def test_Recorder_Deco_rewind_and_read(self): 202 def test_Recorder_Deco_rewind_and_read(self):
251 ads = ADSFactory.ads( 203 ads = ADSFactory.ads(
252 audio_source=self.audio_source, record=True, block_size=320 204 audio_source=self.audio_source, record=True, block_size=320
253 ) 205 )
254
255 ads.open() 206 ads.open()
256 for i in range(10): 207 for i in range(10):
257 ads.read() 208 ads.read()
258 209
259 ads.rewind() 210 ads.rewind()
273 ) 224 )
274 audio_source.open() 225 audio_source.open()
275 audio_source_data = audio_source.read(320 * 10) 226 audio_source_data = audio_source.read(320 * 10)
276 audio_source.close() 227 audio_source.close()
277 228
278 self.assertEqual( 229 assert (
279 ads_data, 230 ads_data == audio_source_data
280 audio_source_data, 231 ), "Unexpected data read from RecorderADS"
281 "Unexpected data read from RecorderADS",
282 )
283 232
284 def test_Overlap_Deco_read(self): 233 def test_Overlap_Deco_read(self):
285
286 # Use arbitrary valid block_size and hop_size 234 # Use arbitrary valid block_size and hop_size
287 block_size = 1714 235 block_size = 1714
288 hop_size = 313 236 hop_size = 313
289 237
290 ads = ADSFactory.ads( 238 ads = ADSFactory.ads(
310 audio_source = BufferAudioSource( 258 audio_source = BufferAudioSource(
311 wave_data, ads.sampling_rate, ads.sample_width, ads.channels 259 wave_data, ads.sampling_rate, ads.sample_width, ads.channels
312 ) 260 )
313 audio_source.open() 261 audio_source.open()
314 262
315 # Compare all blocks read from OverlapADS to those read 263 # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting
316 # from an audio source with a manual position setting
317 for i, block in enumerate(ads_data): 264 for i, block in enumerate(ads_data):
318
319 tmp = audio_source.read(block_size) 265 tmp = audio_source.read(block_size)
320 266 assert (
321 self.assertEqual( 267 block == tmp
322 block, 268 ), "Unexpected block (N={0}) read from OverlapADS".format(i)
323 tmp,
324 "Unexpected block (N={0}) read from OverlapADS".format(i),
325 )
326
327 audio_source.position = (i + 1) * hop_size 269 audio_source.position = (i + 1) * hop_size
328 270
329 audio_source.close() 271 audio_source.close()
330 272
331 def test_Limiter_Overlap_Deco_read(self): 273 def test_Limiter_Overlap_Deco_read(self):
332
333 block_size = 256 274 block_size = 256
334 hop_size = 200 275 hop_size = 200
335 276
336 ads = ADSFactory.ads( 277 ads = ADSFactory.ads(
337 audio_source=self.audio_source, 278 audio_source=self.audio_source,
357 audio_source = BufferAudioSource( 298 audio_source = BufferAudioSource(
358 wave_data, ads.sampling_rate, ads.sample_width, ads.channels 299 wave_data, ads.sampling_rate, ads.sample_width, ads.channels
359 ) 300 )
360 audio_source.open() 301 audio_source.open()
361 302
362 # Compare all blocks read from OverlapADS to those read 303 # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting
363 # from an audio source with a manual position setting
364 for i, block in enumerate(ads_data): 304 for i, block in enumerate(ads_data):
365 tmp = audio_source.read(len(block) // (ads.sw * ads.ch)) 305 tmp = audio_source.read(len(block) // (ads.sw * ads.ch))
366 self.assertEqual( 306 assert len(block) == len(
367 len(block), 307 tmp
368 len(tmp), 308 ), "Unexpected block (N={0}) read from OverlapADS".format(i)
369 "Unexpected block (N={0}) read from OverlapADS".format(i),
370 )
371 audio_source.position = (i + 1) * hop_size 309 audio_source.position = (i + 1) * hop_size
372 310
373 audio_source.close() 311 audio_source.close()
374 312
375 def test_Limiter_Overlap_Deco_read_limit(self): 313 def test_Limiter_Overlap_Deco_read_limit(self):
376
377 block_size = 313 314 block_size = 313
378 hop_size = 207 315 hop_size = 207
379 ads = ADSFactory.ads( 316 ads = ADSFactory.ads(
380 audio_source=self.audio_source, 317 audio_source=self.audio_source,
381 max_time=1.932, 318 max_time=1.932,
390 (total_samples - first_read_size), next_read_size 327 (total_samples - first_read_size), next_read_size
391 ) 328 )
392 total_samples_with_overlap = ( 329 total_samples_with_overlap = (
393 first_read_size + next_read_size * nb_next_blocks + last_block_size 330 first_read_size + next_read_size * nb_next_blocks + last_block_size
394 ) 331 )
395 expected_read_bytes = ( 332 expected_read_bytes = total_samples_with_overlap * ads.sw * ads.channels
396 total_samples_with_overlap * ads.sw * ads.channels
397 )
398 333
399 cache_size = (block_size - hop_size) * ads.sample_width * ads.channels 334 cache_size = (block_size - hop_size) * ads.sample_width * ads.channels
400 total_read = cache_size 335 total_read = cache_size
401 336
402 ads.open() 337 ads.open()
407 break 342 break
408 i += 1 343 i += 1
409 total_read += len(block) - cache_size 344 total_read += len(block) - cache_size
410 345
411 ads.close() 346 ads.close()
412 err_msg = "Wrong data length read from LimiterADS, expected: {0}, " 347 err_msg = (
413 err_msg += "found: {1}" 348 "Wrong data length read from LimiterADS, expected: {0}, found: {1}"
414 self.assertEqual( 349 )
415 total_read, 350 assert total_read == expected_read_bytes, err_msg.format(
416 expected_read_bytes, 351 expected_read_bytes, total_read
417 err_msg.format(expected_read_bytes, total_read),
418 ) 352 )
419 353
420 def test_Recorder_Overlap_Deco_is_rewindable(self): 354 def test_Recorder_Overlap_Deco_is_rewindable(self):
421 ads = ADSFactory.ads( 355 ads = ADSFactory.ads(
422 audio_source=self.audio_source, 356 audio_source=self.audio_source,
423 block_size=320, 357 block_size=320,
424 hop_size=160, 358 hop_size=160,
425 record=True, 359 record=True,
426 ) 360 )
427 self.assertTrue( 361 assert ads.rewindable, "RecorderADS.is_rewindable should return True"
428 ads.rewindable, "RecorderADS.is_rewindable should return True"
429 )
430 362
431 def test_Recorder_Overlap_Deco_rewind_and_read(self): 363 def test_Recorder_Overlap_Deco_rewind_and_read(self):
432
433 # Use arbitrary valid block_size and hop_size 364 # Use arbitrary valid block_size and hop_size
434 block_size = 1600 365 block_size = 1600
435 hop_size = 400 366 hop_size = 400
436 367
437 ads = ADSFactory.ads( 368 ads = ADSFactory.ads(
459 audio_source = BufferAudioSource( 390 audio_source = BufferAudioSource(
460 wave_data, ads.sampling_rate, ads.sample_width, ads.channels 391 wave_data, ads.sampling_rate, ads.sample_width, ads.channels
461 ) 392 )
462 audio_source.open() 393 audio_source.open()
463 394
464 # Compare all blocks read from OverlapADS to those read 395 # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting
465 # from an audio source with a manual position setting
466 for j in range(i): 396 for j in range(i):
467
468 tmp = audio_source.read(block_size) 397 tmp = audio_source.read(block_size)
469 398 assert (
470 self.assertEqual( 399 ads.read() == tmp
471 ads.read(), 400 ), "Unexpected block (N={0}) read from OverlapADS".format(i)
472 tmp,
473 "Unexpected block (N={0}) read from OverlapADS".format(i),
474 )
475 audio_source.position = (j + 1) * hop_size 401 audio_source.position = (j + 1) * hop_size
476 402
477 ads.close() 403 ads.close()
478 audio_source.close() 404 audio_source.close()
479 405
480 def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): 406 def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self):
481
482 # Use arbitrary valid block_size and hop_size 407 # Use arbitrary valid block_size and hop_size
483 block_size = 1600 408 block_size = 1600
484 hop_size = 400 409 hop_size = 400
485 410
486 ads = ADSFactory.ads( 411 ads = ADSFactory.ads(
509 audio_source = BufferAudioSource( 434 audio_source = BufferAudioSource(
510 wave_data, ads.sampling_rate, ads.sample_width, ads.channels 435 wave_data, ads.sampling_rate, ads.sample_width, ads.channels
511 ) 436 )
512 audio_source.open() 437 audio_source.open()
513 438
514 # Compare all blocks read from OverlapADS to those read 439 # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting
515 # from an audio source with a manual position setting
516 for j in range(i): 440 for j in range(i):
517
518 tmp = audio_source.read(block_size) 441 tmp = audio_source.read(block_size)
519 442 assert (
520 self.assertEqual( 443 ads.read() == tmp
521 ads.read(), 444 ), "Unexpected block (N={0}) read from OverlapADS".format(i)
522 tmp,
523 "Unexpected block (N={0}) read from OverlapADS".format(i),
524 )
525 audio_source.position = (j + 1) * hop_size 445 audio_source.position = (j + 1) * hop_size
526 446
527 ads.close() 447 ads.close()
528 audio_source.close() 448 audio_source.close()
529 449
530 def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_limit(self): 450 def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_limit(self):
531
532 # Use arbitrary valid block_size and hop_size 451 # Use arbitrary valid block_size and hop_size
533 block_size = 1000 452 block_size = 1000
534 hop_size = 200 453 hop_size = 200
535 454
536 ads = ADSFactory.ads( 455 ads = ADSFactory.ads(
547 (total_samples - first_read_size), next_read_size 466 (total_samples - first_read_size), next_read_size
548 ) 467 )
549 total_samples_with_overlap = ( 468 total_samples_with_overlap = (
550 first_read_size + next_read_size * nb_next_blocks + last_block_size 469 first_read_size + next_read_size * nb_next_blocks + last_block_size
551 ) 470 )
552 expected_read_bytes = ( 471 expected_read_bytes = total_samples_with_overlap * ads.sw * ads.channels
553 total_samples_with_overlap * ads.sw * ads.channels
554 )
555 472
556 cache_size = (block_size - hop_size) * ads.sample_width * ads.channels 473 cache_size = (block_size - hop_size) * ads.sample_width * ads.channels
557 total_read = cache_size 474 total_read = cache_size
558 475
559 ads.open() 476 ads.open()
564 break 481 break
565 i += 1 482 i += 1
566 total_read += len(block) - cache_size 483 total_read += len(block) - cache_size
567 484
568 ads.close() 485 ads.close()
569 err_msg = "Wrong data length read from LimiterADS, expected: {0}, " 486 err_msg = (
570 err_msg += "found: {1}" 487 "Wrong data length read from LimiterADS, expected: {0}, found: {1}"
571 self.assertEqual( 488 )
572 total_read, 489 assert total_read == expected_read_bytes, err_msg.format(
573 expected_read_bytes, 490 expected_read_bytes, total_read
574 err_msg.format(expected_read_bytes, total_read), 491 )
575 ) 492
576 493
577 494 class TestADSFactoryBufferAudioSource:
578 class TestADSFactoryBufferAudioSource(unittest.TestCase): 495 def setup_method(self):
579 def setUp(self):
580 self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" 496 self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
581 self.ads = ADSFactory.ads( 497 self.ads = ADSFactory.ads(
582 data_buffer=self.signal, 498 data_buffer=self.signal,
583 sampling_rate=16, 499 sampling_rate=16,
584 sample_width=2, 500 sample_width=2,
586 block_size=4, 502 block_size=4,
587 ) 503 )
588 504
589 def test_ADS_BAS_sampling_rate(self): 505 def test_ADS_BAS_sampling_rate(self):
590 srate = self.ads.sampling_rate 506 srate = self.ads.sampling_rate
591 self.assertEqual( 507 assert (
592 srate, 508 srate == 16
593 16, 509 ), "Wrong sampling rate, expected: 16000, found: {0}".format(srate)
594 "Wrong sampling rate, expected: 16000, found: {0}".format(srate),
595 )
596 510
597 def test_ADS_BAS_sample_width(self): 511 def test_ADS_BAS_sample_width(self):
598 swidth = self.ads.sample_width 512 swidth = self.ads.sample_width
599 self.assertEqual( 513 assert (
600 swidth, 514 swidth == 2
601 2, 515 ), "Wrong sample width, expected: 2, found: {0}".format(swidth)
602 "Wrong sample width, expected: 2, found: {0}".format(swidth),
603 )
604 516
605 def test_ADS_BAS_channels(self): 517 def test_ADS_BAS_channels(self):
606 channels = self.ads.channels 518 channels = self.ads.channels
607 self.assertEqual( 519 assert (
608 channels, 520 channels == 1
609 1, 521 ), "Wrong number of channels, expected: 1, found: {0}".format(channels)
610 "Wrong number of channels, expected: 1, found: {0}".format(
611 channels
612 ),
613 )
614 522
615 def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): 523 def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self):
616
617 # Use arbitrary valid block_size and hop_size 524 # Use arbitrary valid block_size and hop_size
618 block_size = 5 525 block_size = 5
619 hop_size = 4 526 hop_size = 4
620 527
621 ads = ADSFactory.ads( 528 ads = ADSFactory.ads(
644 audio_source = BufferAudioSource( 551 audio_source = BufferAudioSource(
645 self.signal, ads.sampling_rate, ads.sample_width, ads.channels 552 self.signal, ads.sampling_rate, ads.sample_width, ads.channels
646 ) 553 )
647 audio_source.open() 554 audio_source.open()
648 555
649 # Compare all blocks read from OverlapADS to those read 556 # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting
650 # from an audio source with a manual position setting
651 for j in range(i): 557 for j in range(i):
652
653 tmp = audio_source.read(block_size) 558 tmp = audio_source.read(block_size)
654 559 block = ads.read()
655 block = ads.read() 560 assert (
656 561 block == tmp
657 self.assertEqual( 562 ), "Unexpected block '{}' (N={}) read from OverlapADS".format(
658 block, 563 block, i
659 tmp,
660 "Unexpected block '{}' (N={}) read from OverlapADS".format(
661 block, i
662 ),
663 ) 564 )
664 audio_source.position = (j + 1) * hop_size 565 audio_source.position = (j + 1) * hop_size
665 566
666 ads.close() 567 ads.close()
667 audio_source.close() 568 audio_source.close()
668 569
669 570
670 class TestADSFactoryAlias(unittest.TestCase): 571 class TestADSFactoryAlias:
671 def setUp(self): 572 def setup_method(self):
672 self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" 573 self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
673 574
674 def test_sampling_rate_alias(self): 575 def test_sampling_rate_alias(self):
675 ads = ADSFactory.ads( 576 ads = ADSFactory.ads(
676 data_buffer=self.signal, 577 data_buffer=self.signal,
678 sample_width=2, 579 sample_width=2,
679 channels=1, 580 channels=1,
680 block_dur=0.5, 581 block_dur=0.5,
681 ) 582 )
682 srate = ads.sampling_rate 583 srate = ads.sampling_rate
683 self.assertEqual( 584 assert (
684 srate, 585 srate == 16
685 16, 586 ), "Wrong sampling rate, expected: 16000, found: {0}".format(srate)
686 "Wrong sampling rate, expected: 16000, found: {0}".format(srate),
687 )
688 587
689 def test_sampling_rate_duplicate(self): 588 def test_sampling_rate_duplicate(self):
690 func = partial( 589 func = partial(
691 ADSFactory.ads, 590 ADSFactory.ads,
692 data_buffer=self.signal, 591 data_buffer=self.signal,
693 sr=16, 592 sr=16,
694 sampling_rate=16, 593 sampling_rate=16,
695 sample_width=2, 594 sample_width=2,
696 channels=1, 595 channels=1,
697 ) 596 )
698 self.assertRaises(DuplicateArgument, func) 597 with pytest.raises(DuplicateArgument):
598 func()
699 599
700 def test_sample_width_alias(self): 600 def test_sample_width_alias(self):
701 ads = ADSFactory.ads( 601 ads = ADSFactory.ads(
702 data_buffer=self.signal, 602 data_buffer=self.signal,
703 sampling_rate=16, 603 sampling_rate=16,
704 sw=2, 604 sw=2,
705 channels=1, 605 channels=1,
706 block_dur=0.5, 606 block_dur=0.5,
707 ) 607 )
708 swidth = ads.sample_width 608 swidth = ads.sample_width
709 self.assertEqual( 609 assert (
710 swidth, 610 swidth == 2
711 2, 611 ), "Wrong sample width, expected: 2, found: {0}".format(swidth)
712 "Wrong sample width, expected: 2, found: {0}".format(swidth),
713 )
714 612
715 def test_sample_width_duplicate(self): 613 def test_sample_width_duplicate(self):
716 func = partial( 614 func = partial(
717 ADSFactory.ads, 615 ADSFactory.ads,
718 data_buffer=self.signal, 616 data_buffer=self.signal,
719 sampling_rate=16, 617 sampling_rate=16,
720 sw=2, 618 sw=2,
721 sample_width=2, 619 sample_width=2,
722 channels=1, 620 channels=1,
723 ) 621 )
724 self.assertRaises(DuplicateArgument, func) 622 with pytest.raises(DuplicateArgument):
623 func()
725 624
726 def test_channels_alias(self): 625 def test_channels_alias(self):
727 ads = ADSFactory.ads( 626 ads = ADSFactory.ads(
728 data_buffer=self.signal, 627 data_buffer=self.signal,
729 sampling_rate=16, 628 sampling_rate=16,
730 sample_width=2, 629 sample_width=2,
731 ch=1, 630 ch=1,
732 block_dur=4, 631 block_dur=4,
733 ) 632 )
734 channels = ads.channels 633 channels = ads.channels
735 self.assertEqual( 634 assert (
736 channels, 635 channels == 1
737 1, 636 ), "Wrong number of channels, expected: 1, found: {0}".format(channels)
738 "Wrong number of channels, expected: 1, found: {0}".format(
739 channels
740 ),
741 )
742 637
743 def test_channels_duplicate(self): 638 def test_channels_duplicate(self):
744 func = partial( 639 func = partial(
745 ADSFactory.ads, 640 ADSFactory.ads,
746 data_buffer=self.signal, 641 data_buffer=self.signal,
747 sampling_rate=16, 642 sampling_rate=16,
748 sample_width=2, 643 sample_width=2,
749 ch=1, 644 ch=1,
750 channels=1, 645 channels=1,
751 ) 646 )
752 self.assertRaises(DuplicateArgument, func) 647 with pytest.raises(DuplicateArgument):
648 func()
753 649
754 def test_block_size_alias(self): 650 def test_block_size_alias(self):
755 ads = ADSFactory.ads( 651 ads = ADSFactory.ads(
756 data_buffer=self.signal, 652 data_buffer=self.signal,
757 sampling_rate=16, 653 sampling_rate=16,
758 sample_width=2, 654 sample_width=2,
759 channels=1, 655 channels=1,
760 bs=8, 656 bs=8,
761 ) 657 )
762 size = ads.block_size 658 size = ads.block_size
763 self.assertEqual( 659 assert (
764 size, 660 size == 8
765 8, 661 ), "Wrong block_size using bs alias, expected: 8, found: {0}".format(
766 "Wrong block_size using bs alias, expected: 8, found: {0}".format( 662 size
767 size
768 ),
769 ) 663 )
770 664
771 def test_block_size_duplicate(self): 665 def test_block_size_duplicate(self):
772 func = partial( 666 func = partial(
773 ADSFactory.ads, 667 ADSFactory.ads,
776 sample_width=2, 670 sample_width=2,
777 channels=1, 671 channels=1,
778 bs=4, 672 bs=4,
779 block_size=4, 673 block_size=4,
780 ) 674 )
781 self.assertRaises(DuplicateArgument, func) 675 with pytest.raises(DuplicateArgument):
676 func()
782 677
783 def test_block_duration_alias(self): 678 def test_block_duration_alias(self):
784 ads = ADSFactory.ads( 679 ads = ADSFactory.ads(
785 data_buffer=self.signal, 680 data_buffer=self.signal,
786 sampling_rate=16, 681 sampling_rate=16,
787 sample_width=2, 682 sample_width=2,
788 channels=1, 683 channels=1,
789 bd=0.75, 684 bd=0.75,
790 ) 685 )
791 # 0.75 ms = 0.75 * 16 = 12
792 size = ads.block_size 686 size = ads.block_size
793 err_msg = "Wrong block_size set with a block_dur alias 'bd', " 687 err_msg = "Wrong block_size set with a block_dur alias 'bd', expected: 8, found: {0}"
794 err_msg += "expected: 8, found: {0}" 688 assert size == 12, err_msg.format(size)
795 self.assertEqual(
796 size, 12, err_msg.format(size),
797 )
798 689
799 def test_block_duration_duplicate(self): 690 def test_block_duration_duplicate(self):
800 func = partial( 691 func = partial(
801 ADSFactory.ads, 692 ADSFactory.ads,
802 data_buffer=self.signal, 693 data_buffer=self.signal,
804 sample_width=2, 695 sample_width=2,
805 channels=1, 696 channels=1,
806 bd=4, 697 bd=4,
807 block_dur=4, 698 block_dur=4,
808 ) 699 )
809 self.assertRaises(DuplicateArgument, func) 700 with pytest.raises(DuplicateArgument):
701 func()
810 702
811 def test_block_size_duration_duplicate(self): 703 def test_block_size_duration_duplicate(self):
812 func = partial( 704 func = partial(
813 ADSFactory.ads, 705 ADSFactory.ads,
814 data_buffer=self.signal, 706 data_buffer=self.signal,
816 sample_width=2, 708 sample_width=2,
817 channels=1, 709 channels=1,
818 bd=4, 710 bd=4,
819 bs=12, 711 bs=12,
820 ) 712 )
821 self.assertRaises(DuplicateArgument, func) 713 with pytest.raises(DuplicateArgument):
714 func()
822 715
823 def test_hop_duration_alias(self): 716 def test_hop_duration_alias(self):
824
825 ads = ADSFactory.ads( 717 ads = ADSFactory.ads(
826 data_buffer=self.signal, 718 data_buffer=self.signal,
827 sampling_rate=16, 719 sampling_rate=16,
828 sample_width=2, 720 sample_width=2,
829 channels=1, 721 channels=1,
830 bd=0.75, 722 bd=0.75,
831 hd=0.5, 723 hd=0.5,
832 ) 724 )
833 size = ads.hop_size 725 size = ads.hop_size
834 self.assertEqual( 726 assert (
835 size, 727 size == 8
836 8, 728 ), "Wrong block_size using bs alias, expected: 8, found: {0}".format(
837 "Wrong block_size using bs alias, expected: 8, found: {0}".format( 729 size
838 size
839 ),
840 ) 730 )
841 731
842 def test_hop_duration_duplicate(self): 732 def test_hop_duration_duplicate(self):
843
844 func = partial( 733 func = partial(
845 ADSFactory.ads, 734 ADSFactory.ads,
846 data_buffer=self.signal, 735 data_buffer=self.signal,
847 sampling_rate=16, 736 sampling_rate=16,
848 sample_width=2, 737 sample_width=2,
849 channels=1, 738 channels=1,
850 bd=0.75, 739 bd=0.75,
851 hd=0.5, 740 hd=0.5,
852 hop_dur=0.5, 741 hop_dur=0.5,
853 ) 742 )
854 self.assertRaises(DuplicateArgument, func) 743 with pytest.raises(DuplicateArgument):
744 func()
855 745
856 def test_hop_size_duration_duplicate(self): 746 def test_hop_size_duration_duplicate(self):
857 func = partial( 747 func = partial(
858 ADSFactory.ads, 748 ADSFactory.ads,
859 data_buffer=self.signal, 749 data_buffer=self.signal,
862 channels=1, 752 channels=1,
863 bs=8, 753 bs=8,
864 hs=4, 754 hs=4,
865 hd=1, 755 hd=1,
866 ) 756 )
867 self.assertRaises(DuplicateArgument, func) 757 with pytest.raises(DuplicateArgument):
758 func()
868 759
869 def test_hop_size_greater_than_block_size(self): 760 def test_hop_size_greater_than_block_size(self):
870 func = partial( 761 func = partial(
871 ADSFactory.ads, 762 ADSFactory.ads,
872 data_buffer=self.signal, 763 data_buffer=self.signal,
874 sample_width=2, 765 sample_width=2,
875 channels=1, 766 channels=1,
876 bs=4, 767 bs=4,
877 hs=8, 768 hs=8,
878 ) 769 )
879 self.assertRaises(ValueError, func) 770 with pytest.raises(ValueError):
771 func()
880 772
881 def test_filename_duplicate(self): 773 def test_filename_duplicate(self):
882
883 func = partial( 774 func = partial(
884 ADSFactory.ads, 775 ADSFactory.ads,
885 fn=dataset.one_to_six_arabic_16000_mono_bc_noise, 776 fn=dataset.one_to_six_arabic_16000_mono_bc_noise,
886 filename=dataset.one_to_six_arabic_16000_mono_bc_noise, 777 filename=dataset.one_to_six_arabic_16000_mono_bc_noise,
887 ) 778 )
888 self.assertRaises(DuplicateArgument, func) 779 with pytest.raises(DuplicateArgument):
780 func()
889 781
890 def test_data_buffer_duplicate(self): 782 def test_data_buffer_duplicate(self):
891 func = partial( 783 func = partial(
892 ADSFactory.ads, 784 ADSFactory.ads,
893 data_buffer=self.signal, 785 data_buffer=self.signal,
894 db=self.signal, 786 db=self.signal,
895 sampling_rate=16, 787 sampling_rate=16,
896 sample_width=2, 788 sample_width=2,
897 channels=1, 789 channels=1,
898 ) 790 )
899 self.assertRaises(DuplicateArgument, func) 791 with pytest.raises(DuplicateArgument):
792 func()
900 793
901 def test_max_time_alias(self): 794 def test_max_time_alias(self):
902 ads = ADSFactory.ads( 795 ads = ADSFactory.ads(
903 data_buffer=self.signal, 796 data_buffer=self.signal,
904 sampling_rate=16, 797 sampling_rate=16,
905 sample_width=2, 798 sample_width=2,
906 channels=1, 799 channels=1,
907 mt=10, 800 mt=10,
908 block_dur=0.5, 801 block_dur=0.5,
909 ) 802 )
910 self.assertEqual( 803 assert (
911 ads.max_read, 804 ads.max_read == 10
912 10, 805 ), "Wrong AudioDataSource.max_read, expected: 10, found: {}".format(
913 "Wrong AudioDataSource.max_read, expected: 10, found: {}".format( 806 ads.max_read
914 ads.max_read
915 ),
916 ) 807 )
917 808
918 def test_max_time_duplicate(self): 809 def test_max_time_duplicate(self):
919 func = partial( 810 func = partial(
920 ADSFactory.ads, 811 ADSFactory.ads,
923 sample_width=2, 814 sample_width=2,
924 channels=1, 815 channels=1,
925 mt=True, 816 mt=True,
926 max_time=True, 817 max_time=True,
927 ) 818 )
928 819 with pytest.raises(DuplicateArgument):
929 self.assertRaises(DuplicateArgument, func) 820 func()
930 821
931 def test_record_alias(self): 822 def test_record_alias(self):
932 ads = ADSFactory.ads( 823 ads = ADSFactory.ads(
933 data_buffer=self.signal, 824 data_buffer=self.signal,
934 sampling_rate=16, 825 sampling_rate=16,
935 sample_width=2, 826 sample_width=2,
936 channels=1, 827 channels=1,
937 rec=True, 828 rec=True,
938 block_dur=0.5, 829 block_dur=0.5,
939 ) 830 )
940 self.assertTrue( 831 assert ads.rewindable, "AudioDataSource.rewindable expected to be True"
941 ads.rewindable, "AudioDataSource.rewindable expected to be True"
942 )
943 832
944 def test_record_duplicate(self): 833 def test_record_duplicate(self):
945 func = partial( 834 func = partial(
946 ADSFactory.ads, 835 ADSFactory.ads,
947 data_buffer=self.signal, 836 data_buffer=self.signal,
949 sample_width=2, 838 sample_width=2,
950 channels=1, 839 channels=1,
951 rec=True, 840 rec=True,
952 record=True, 841 record=True,
953 ) 842 )
954 self.assertRaises(DuplicateArgument, func) 843 with pytest.raises(DuplicateArgument):
844 func()
955 845
956 def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_alias(self): 846 def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_alias(self):
957
958 # Use arbitrary valid block_size and hop_size 847 # Use arbitrary valid block_size and hop_size
959 block_size = 5 848 block_size = 5
960 hop_size = 4 849 hop_size = 4
961 850
962 ads = ADSFactory.ads( 851 ads = ADSFactory.ads(
985 audio_source = BufferAudioSource( 874 audio_source = BufferAudioSource(
986 self.signal, ads.sampling_rate, ads.sample_width, ads.channels 875 self.signal, ads.sampling_rate, ads.sample_width, ads.channels
987 ) 876 )
988 audio_source.open() 877 audio_source.open()
989 878
990 # Compare all blocks read from AudioDataSource to those read 879 # Compare all blocks read from AudioDataSource to those read from an audio source with manual position definition
991 # from an audio source with manual position definition
992 for j in range(i): 880 for j in range(i):
993 tmp = audio_source.read(block_size) 881 tmp = audio_source.read(block_size)
994 block = ads.read() 882 block = ads.read()
995 self.assertEqual( 883 assert (
996 block, 884 block == tmp
997 tmp, 885 ), "Unexpected block (N={0}) read from OverlapADS".format(i)
998 "Unexpected block (N={0}) read from OverlapADS".format(i),
999 )
1000 audio_source.position = (j + 1) * hop_size 886 audio_source.position = (j + 1) * hop_size
1001 ads.close() 887 ads.close()
1002 audio_source.close() 888 audio_source.close()
1003 889
1004 890
1010 break 896 break
1011 blocks.append(data) 897 blocks.append(data)
1012 return b"".join(blocks) 898 return b"".join(blocks)
1013 899
1014 900
1015 @genty 901 @pytest.mark.parametrize(
1016 class TestAudioReader(unittest.TestCase): 902 "file_id, max_read, size",
1017 903 [
1018 # TODO move all tests here when backward compatibility 904 ("mono_400", 0.5, 16000), # mono
1019 # with ADSFactory is dropped 905 ("3channel_400-800-1600", 0.5, 16000 * 3), # multichannel
1020 906 ],
1021 @genty_dataset( 907 ids=["mono", "multichannel"],
1022 mono=("mono_400", 0.5, 16000), 908 )
1023 multichannel=("3channel_400-800-1600", 0.5, 16000 * 3), 909 def test_Limiter(file_id, max_read, size):
1024 ) 910 input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id)
1025 def test_Limiter(self, file_id, max_read, size): 911 input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id)
1026 input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) 912 with open(input_raw, "rb") as fp:
1027 input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) 913 expected = fp.read(size)
1028 with open(input_raw, "rb") as fp: 914
1029 expected = fp.read(size) 915 reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read)
1030 916 reader.open()
1031 reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read) 917 data = _read_all_data(reader)
1032 reader.open() 918 reader.close()
919 assert data == expected
920
921
922 @pytest.mark.parametrize(
923 "file_id",
924 [
925 "mono_400", # mono
926 "3channel_400-800-1600", # multichannel
927 ],
928 ids=["mono", "multichannel"],
929 )
930 def test_Recorder(file_id):
931 input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id)
932 input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id)
933 with open(input_raw, "rb") as fp:
934 expected = fp.read()
935
936 reader = AudioReader(input_wav, block_dur=0.1, record=True)
937 reader.open()
938 data = _read_all_data(reader)
939 assert data == expected
940
941 # rewind many times
942 for _ in range(3):
943 reader.rewind()
1033 data = _read_all_data(reader) 944 data = _read_all_data(reader)
1034 reader.close() 945 assert data == expected
1035 self.assertEqual(data, expected) 946 assert data == reader.data
1036 947 reader.close()
1037 @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) 948
1038 def test_Recorder(self, file_id): 949
1039 input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) 950 @pytest.mark.parametrize(
1040 input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) 951 "file_id",
1041 with open(input_raw, "rb") as fp: 952 [
1042 expected = fp.read() 953 "mono_400", # mono
1043 954 "3channel_400-800-1600", # multichannel
1044 reader = AudioReader(input_wav, block_dur=0.1, record=True) 955 ],
1045 reader.open() 956 ids=["mono", "multichannel"],
957 )
958 def test_Recorder_alias(file_id):
959 input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id)
960 input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id)
961 with open(input_raw, "rb") as fp:
962 expected = fp.read()
963
964 reader = Recorder(input_wav, block_dur=0.1)
965 reader.open()
966 data = _read_all_data(reader)
967 assert data == expected
968
969 # rewind many times
970 for _ in range(3):
971 reader.rewind()
1046 data = _read_all_data(reader) 972 data = _read_all_data(reader)
1047 self.assertEqual(data, expected) 973 assert data == expected
1048 974 assert data == reader.data
1049 # rewind many times 975 reader.close()
1050 for _ in range(3):
1051 reader.rewind()
1052 data = _read_all_data(reader)
1053 self.assertEqual(data, expected)
1054 self.assertEqual(data, reader.data)
1055 reader.close()
1056
1057 @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",))
1058 def test_Recorder_alias(self, file_id):
1059 input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id)
1060 input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id)
1061 with open(input_raw, "rb") as fp:
1062 expected = fp.read()
1063
1064 reader = Recorder(input_wav, block_dur=0.1)
1065 reader.open()
1066 data = _read_all_data(reader)
1067 self.assertEqual(data, expected)
1068
1069 # rewind many times
1070 for _ in range(3):
1071 reader.rewind()
1072 data = _read_all_data(reader)
1073 self.assertEqual(data, expected)
1074 self.assertEqual(data, reader.data)
1075 reader.close()
1076
1077
1078 if __name__ == "__main__":
1079 unittest.main()