Mercurial > hg > auditok
diff tests/test_AudioReader.py @ 400:323d59b404a2
Use pytest instead of genty
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Sat, 25 May 2024 21:54:13 +0200 |
parents | 8220dfaa03c6 |
children | 996948ada980 |
line wrap: on
line diff
--- a/tests/test_AudioReader.py Fri May 24 21:30:34 2024 +0200 +++ b/tests/test_AudioReader.py Sat May 25 21:54:13 2024 +0200 @@ -1,14 +1,7 @@ -""" -@author: Amine Sehili <amine.sehili@gmail.com> -September 2015 - -""" - -import unittest +import pytest from functools import partial import sys import wave -from genty import genty, genty_dataset from auditok import ( dataset, ADSFactory, @@ -21,78 +14,63 @@ ) -class TestADSFactoryFileAudioSource(unittest.TestCase): - def setUp(self): +class TestADSFactoryFileAudioSource: + def setup_method(self): self.audio_source = WaveAudioSource( filename=dataset.one_to_six_arabic_16000_mono_bc_noise ) def test_ADS_type(self): - ads = ADSFactory.ads(audio_source=self.audio_source) - - err_msg = "wrong type for ads object, expected: 'AudioDataSource', " - err_msg += "found: {0}" - self.assertIsInstance( - ads, AudioDataSource, err_msg.format(type(ads)), + err_msg = ( + "wrong type for ads object, expected: 'AudioDataSource', found: {0}" ) + assert isinstance(ads, AudioDataSource), err_msg.format(type(ads)) def test_default_block_size(self): ads = ADSFactory.ads(audio_source=self.audio_source) size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong default block_size, expected: 160, found: {0}".format(size), - ) + assert ( + size == 160 + ), "Wrong default block_size, expected: 160, found: {0}".format(size) def test_block_size(self): ads = ADSFactory.ads(audio_source=self.audio_source, block_size=512) size = ads.block_size - self.assertEqual( - size, - 512, - "Wrong block_size, expected: 512, found: {0}".format(size), - ) + assert ( + size == 512 + ), "Wrong block_size, expected: 512, found: {0}".format(size) # with alias keyword ads = ADSFactory.ads(audio_source=self.audio_source, bs=160) size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong block_size, expected: 160, found: {0}".format(size), - ) + assert ( + size == 160 + ), "Wrong block_size, expected: 160, found: {0}".format(size) def test_block_duration(self): - ads = ADSFactory.ads( audio_source=self.audio_source, block_dur=0.01 ) # 10 ms size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong block_size, expected: 160, found: {0}".format(size), - ) + assert ( + size == 160 + ), "Wrong block_size, expected: 160, found: {0}".format(size) # with alias keyword ads = ADSFactory.ads(audio_source=self.audio_source, bd=0.025) # 25 ms size = ads.block_size - self.assertEqual( - size, - 400, - "Wrong block_size, expected: 400, found: {0}".format(size), - ) + assert ( + size == 400 + ), "Wrong block_size, expected: 400, found: {0}".format(size) def test_hop_duration(self): - ads = ADSFactory.ads( audio_source=self.audio_source, block_dur=0.02, hop_dur=0.01 ) # 10 ms size = ads.hop_size - self.assertEqual( - size, 160, "Wrong hop_size, expected: 160, found: {0}".format(size) + assert size == 160, "Wrong hop_size, expected: 160, found: {0}".format( + size ) # with alias keyword @@ -100,47 +78,33 @@ audio_source=self.audio_source, bd=0.025, hop_dur=0.015 ) # 15 ms size = ads.hop_size - self.assertEqual( - size, - 240, - "Wrong block_size, expected: 240, found: {0}".format(size), - ) + assert ( + size == 240 + ), "Wrong block_size, expected: 240, found: {0}".format(size) def test_sampling_rate(self): ads = ADSFactory.ads(audio_source=self.audio_source) - srate = ads.sampling_rate - self.assertEqual( - srate, - 16000, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) + assert ( + srate == 16000 + ), "Wrong sampling rate, expected: 16000, found: {0}".format(srate) def test_sample_width(self): ads = ADSFactory.ads(audio_source=self.audio_source) - swidth = ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) + assert ( + swidth == 2 + ), "Wrong sample width, expected: 2, found: {0}".format(swidth) def test_channels(self): ads = ADSFactory.ads(audio_source=self.audio_source) - channels = ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) + assert ( + channels == 1 + ), "Wrong number of channels, expected: 1, found: {0}".format(channels) def test_read(self): ads = ADSFactory.ads(audio_source=self.audio_source, block_size=256) - ads.open() ads_data = ads.read() ads.close() @@ -152,14 +116,11 @@ audio_source_data = audio_source.read(256) audio_source.close() - self.assertEqual( - ads_data, audio_source_data, "Unexpected data read from ads" - ) + assert ads_data == audio_source_data, "Unexpected data read from ads" def test_Limiter_Deco_read(self): # read a maximum of 0.75 seconds from audio source ads = ADSFactory.ads(audio_source=self.audio_source, max_time=0.75) - ads_data = [] ads.open() while True: @@ -177,9 +138,9 @@ audio_source_data = audio_source.read(int(16000 * 0.75)) audio_source.close() - self.assertEqual( - ads_data, audio_source_data, "Unexpected data read from LimiterADS" - ) + assert ( + ads_data == audio_source_data + ), "Unexpected data read from LimiterADS" def test_Limiter_Deco_read_limit(self): # read a maximum of 1.191 seconds from audio source @@ -189,9 +150,7 @@ total_samples_with_overlap = ( nb_full_blocks * ads.block_size + last_block_size ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) + expected_read_bytes = total_samples_with_overlap * ads.sw * ads.channels total_read = 0 ads.open() @@ -204,19 +163,17 @@ total_read += len(block) ads.close() - err_msg = "Wrong data length read from LimiterADS, expected: {0}, " - err_msg += "found: {1}" - self.assertEqual( - total_read, - expected_read_bytes, - err_msg.format(expected_read_bytes, total_read), + err_msg = ( + "Wrong data length read from LimiterADS, expected: {0}, found: {1}" + ) + assert total_read == expected_read_bytes, err_msg.format( + expected_read_bytes, total_read ) def test_Recorder_Deco_read(self): ads = ADSFactory.ads( audio_source=self.audio_source, record=True, block_size=500 ) - ads_data = [] ads.open() for i in range(10): @@ -234,24 +191,18 @@ audio_source_data = audio_source.read(500 * 10) audio_source.close() - self.assertEqual( - ads_data, - audio_source_data, - "Unexpected data read from RecorderADS", - ) + assert ( + ads_data == audio_source_data + ), "Unexpected data read from RecorderADS" def test_Recorder_Deco_is_rewindable(self): ads = ADSFactory.ads(audio_source=self.audio_source, record=True) - - self.assertTrue( - ads.rewindable, "RecorderADS.is_rewindable should return True" - ) + assert ads.rewindable, "RecorderADS.is_rewindable should return True" def test_Recorder_Deco_rewind_and_read(self): ads = ADSFactory.ads( audio_source=self.audio_source, record=True, block_size=320 ) - ads.open() for i in range(10): ads.read() @@ -275,14 +226,11 @@ audio_source_data = audio_source.read(320 * 10) audio_source.close() - self.assertEqual( - ads_data, - audio_source_data, - "Unexpected data read from RecorderADS", - ) + assert ( + ads_data == audio_source_data + ), "Unexpected data read from RecorderADS" def test_Overlap_Deco_read(self): - # Use arbitrary valid block_size and hop_size block_size = 1714 hop_size = 313 @@ -312,24 +260,17 @@ ) audio_source.open() - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting + # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting for i, block in enumerate(ads_data): - tmp = audio_source.read(block_size) - - self.assertEqual( - block, - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - + assert ( + block == tmp + ), "Unexpected block (N={0}) read from OverlapADS".format(i) audio_source.position = (i + 1) * hop_size audio_source.close() def test_Limiter_Overlap_Deco_read(self): - block_size = 256 hop_size = 200 @@ -359,21 +300,17 @@ ) audio_source.open() - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting + # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting for i, block in enumerate(ads_data): tmp = audio_source.read(len(block) // (ads.sw * ads.ch)) - self.assertEqual( - len(block), - len(tmp), - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) + assert len(block) == len( + tmp + ), "Unexpected block (N={0}) read from OverlapADS".format(i) audio_source.position = (i + 1) * hop_size audio_source.close() def test_Limiter_Overlap_Deco_read_limit(self): - block_size = 313 hop_size = 207 ads = ADSFactory.ads( @@ -392,9 +329,7 @@ total_samples_with_overlap = ( first_read_size + next_read_size * nb_next_blocks + last_block_size ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) + expected_read_bytes = total_samples_with_overlap * ads.sw * ads.channels cache_size = (block_size - hop_size) * ads.sample_width * ads.channels total_read = cache_size @@ -409,12 +344,11 @@ total_read += len(block) - cache_size ads.close() - err_msg = "Wrong data length read from LimiterADS, expected: {0}, " - err_msg += "found: {1}" - self.assertEqual( - total_read, - expected_read_bytes, - err_msg.format(expected_read_bytes, total_read), + err_msg = ( + "Wrong data length read from LimiterADS, expected: {0}, found: {1}" + ) + assert total_read == expected_read_bytes, err_msg.format( + expected_read_bytes, total_read ) def test_Recorder_Overlap_Deco_is_rewindable(self): @@ -424,12 +358,9 @@ hop_size=160, record=True, ) - self.assertTrue( - ads.rewindable, "RecorderADS.is_rewindable should return True" - ) + assert ads.rewindable, "RecorderADS.is_rewindable should return True" def test_Recorder_Overlap_Deco_rewind_and_read(self): - # Use arbitrary valid block_size and hop_size block_size = 1600 hop_size = 400 @@ -461,24 +392,18 @@ ) audio_source.open() - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting + # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting for j in range(i): - tmp = audio_source.read(block_size) - - self.assertEqual( - ads.read(), - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) + assert ( + ads.read() == tmp + ), "Unexpected block (N={0}) read from OverlapADS".format(i) audio_source.position = (j + 1) * hop_size ads.close() audio_source.close() def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): - # Use arbitrary valid block_size and hop_size block_size = 1600 hop_size = 400 @@ -511,24 +436,18 @@ ) audio_source.open() - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting + # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting for j in range(i): - tmp = audio_source.read(block_size) - - self.assertEqual( - ads.read(), - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) + assert ( + ads.read() == tmp + ), "Unexpected block (N={0}) read from OverlapADS".format(i) audio_source.position = (j + 1) * hop_size ads.close() audio_source.close() def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_limit(self): - # Use arbitrary valid block_size and hop_size block_size = 1000 hop_size = 200 @@ -549,9 +468,7 @@ total_samples_with_overlap = ( first_read_size + next_read_size * nb_next_blocks + last_block_size ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) + expected_read_bytes = total_samples_with_overlap * ads.sw * ads.channels cache_size = (block_size - hop_size) * ads.sample_width * ads.channels total_read = cache_size @@ -566,17 +483,16 @@ total_read += len(block) - cache_size ads.close() - err_msg = "Wrong data length read from LimiterADS, expected: {0}, " - err_msg += "found: {1}" - self.assertEqual( - total_read, - expected_read_bytes, - err_msg.format(expected_read_bytes, total_read), + err_msg = ( + "Wrong data length read from LimiterADS, expected: {0}, found: {1}" + ) + assert total_read == expected_read_bytes, err_msg.format( + expected_read_bytes, total_read ) -class TestADSFactoryBufferAudioSource(unittest.TestCase): - def setUp(self): +class TestADSFactoryBufferAudioSource: + def setup_method(self): self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" self.ads = ADSFactory.ads( data_buffer=self.signal, @@ -588,32 +504,23 @@ def test_ADS_BAS_sampling_rate(self): srate = self.ads.sampling_rate - self.assertEqual( - srate, - 16, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) + assert ( + srate == 16 + ), "Wrong sampling rate, expected: 16000, found: {0}".format(srate) def test_ADS_BAS_sample_width(self): swidth = self.ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) + assert ( + swidth == 2 + ), "Wrong sample width, expected: 2, found: {0}".format(swidth) def test_ADS_BAS_channels(self): channels = self.ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) + assert ( + channels == 1 + ), "Wrong number of channels, expected: 1, found: {0}".format(channels) def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): - # Use arbitrary valid block_size and hop_size block_size = 5 hop_size = 4 @@ -646,20 +553,14 @@ ) audio_source.open() - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting + # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting for j in range(i): - tmp = audio_source.read(block_size) - block = ads.read() - - self.assertEqual( - block, - tmp, - "Unexpected block '{}' (N={}) read from OverlapADS".format( - block, i - ), + assert ( + block == tmp + ), "Unexpected block '{}' (N={}) read from OverlapADS".format( + block, i ) audio_source.position = (j + 1) * hop_size @@ -667,8 +568,8 @@ audio_source.close() -class TestADSFactoryAlias(unittest.TestCase): - def setUp(self): +class TestADSFactoryAlias: + def setup_method(self): self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" def test_sampling_rate_alias(self): @@ -680,11 +581,9 @@ block_dur=0.5, ) srate = ads.sampling_rate - self.assertEqual( - srate, - 16, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) + assert ( + srate == 16 + ), "Wrong sampling rate, expected: 16000, found: {0}".format(srate) def test_sampling_rate_duplicate(self): func = partial( @@ -695,7 +594,8 @@ sample_width=2, channels=1, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_sample_width_alias(self): ads = ADSFactory.ads( @@ -706,11 +606,9 @@ block_dur=0.5, ) swidth = ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) + assert ( + swidth == 2 + ), "Wrong sample width, expected: 2, found: {0}".format(swidth) def test_sample_width_duplicate(self): func = partial( @@ -721,7 +619,8 @@ sample_width=2, channels=1, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_channels_alias(self): ads = ADSFactory.ads( @@ -732,13 +631,9 @@ block_dur=4, ) channels = ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) + assert ( + channels == 1 + ), "Wrong number of channels, expected: 1, found: {0}".format(channels) def test_channels_duplicate(self): func = partial( @@ -749,7 +644,8 @@ ch=1, channels=1, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_block_size_alias(self): ads = ADSFactory.ads( @@ -760,12 +656,10 @@ bs=8, ) size = ads.block_size - self.assertEqual( - size, - 8, - "Wrong block_size using bs alias, expected: 8, found: {0}".format( - size - ), + assert ( + size == 8 + ), "Wrong block_size using bs alias, expected: 8, found: {0}".format( + size ) def test_block_size_duplicate(self): @@ -778,7 +672,8 @@ bs=4, block_size=4, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_block_duration_alias(self): ads = ADSFactory.ads( @@ -788,13 +683,9 @@ channels=1, bd=0.75, ) - # 0.75 ms = 0.75 * 16 = 12 size = ads.block_size - err_msg = "Wrong block_size set with a block_dur alias 'bd', " - err_msg += "expected: 8, found: {0}" - self.assertEqual( - size, 12, err_msg.format(size), - ) + err_msg = "Wrong block_size set with a block_dur alias 'bd', expected: 8, found: {0}" + assert size == 12, err_msg.format(size) def test_block_duration_duplicate(self): func = partial( @@ -806,7 +697,8 @@ bd=4, block_dur=4, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_block_size_duration_duplicate(self): func = partial( @@ -818,10 +710,10 @@ bd=4, bs=12, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_hop_duration_alias(self): - ads = ADSFactory.ads( data_buffer=self.signal, sampling_rate=16, @@ -831,16 +723,13 @@ hd=0.5, ) size = ads.hop_size - self.assertEqual( - size, - 8, - "Wrong block_size using bs alias, expected: 8, found: {0}".format( - size - ), + assert ( + size == 8 + ), "Wrong block_size using bs alias, expected: 8, found: {0}".format( + size ) def test_hop_duration_duplicate(self): - func = partial( ADSFactory.ads, data_buffer=self.signal, @@ -851,7 +740,8 @@ hd=0.5, hop_dur=0.5, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_hop_size_duration_duplicate(self): func = partial( @@ -864,7 +754,8 @@ hs=4, hd=1, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_hop_size_greater_than_block_size(self): func = partial( @@ -876,16 +767,17 @@ bs=4, hs=8, ) - self.assertRaises(ValueError, func) + with pytest.raises(ValueError): + func() def test_filename_duplicate(self): - func = partial( ADSFactory.ads, fn=dataset.one_to_six_arabic_16000_mono_bc_noise, filename=dataset.one_to_six_arabic_16000_mono_bc_noise, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_data_buffer_duplicate(self): func = partial( @@ -896,7 +788,8 @@ sample_width=2, channels=1, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_max_time_alias(self): ads = ADSFactory.ads( @@ -907,12 +800,10 @@ mt=10, block_dur=0.5, ) - self.assertEqual( - ads.max_read, - 10, - "Wrong AudioDataSource.max_read, expected: 10, found: {}".format( - ads.max_read - ), + assert ( + ads.max_read == 10 + ), "Wrong AudioDataSource.max_read, expected: 10, found: {}".format( + ads.max_read ) def test_max_time_duplicate(self): @@ -925,8 +816,8 @@ mt=True, max_time=True, ) - - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_record_alias(self): ads = ADSFactory.ads( @@ -937,9 +828,7 @@ rec=True, block_dur=0.5, ) - self.assertTrue( - ads.rewindable, "AudioDataSource.rewindable expected to be True" - ) + assert ads.rewindable, "AudioDataSource.rewindable expected to be True" def test_record_duplicate(self): func = partial( @@ -951,10 +840,10 @@ rec=True, record=True, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_alias(self): - # Use arbitrary valid block_size and hop_size block_size = 5 hop_size = 4 @@ -987,16 +876,13 @@ ) audio_source.open() - # Compare all blocks read from AudioDataSource to those read - # from an audio source with manual position definition + # Compare all blocks read from AudioDataSource to those read from an audio source with manual position definition for j in range(i): tmp = audio_source.read(block_size) block = ads.read() - self.assertEqual( - block, - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) + assert ( + block == tmp + ), "Unexpected block (N={0}) read from OverlapADS".format(i) audio_source.position = (j + 1) * hop_size ads.close() audio_source.close() @@ -1012,68 +898,78 @@ return b"".join(blocks) -@genty -class TestAudioReader(unittest.TestCase): +@pytest.mark.parametrize( + "file_id, max_read, size", + [ + ("mono_400", 0.5, 16000), # mono + ("3channel_400-800-1600", 0.5, 16000 * 3), # multichannel + ], + ids=["mono", "multichannel"], +) +def test_Limiter(file_id, max_read, size): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read(size) - # TODO move all tests here when backward compatibility - # with ADSFactory is dropped + reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read) + reader.open() + data = _read_all_data(reader) + reader.close() + assert data == expected - @genty_dataset( - mono=("mono_400", 0.5, 16000), - multichannel=("3channel_400-800-1600", 0.5, 16000 * 3), - ) - def test_Limiter(self, file_id, max_read, size): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read(size) - reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read) - reader.open() +@pytest.mark.parametrize( + "file_id", + [ + "mono_400", # mono + "3channel_400-800-1600", # multichannel + ], + ids=["mono", "multichannel"], +) +def test_Recorder(file_id): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read() + + reader = AudioReader(input_wav, block_dur=0.1, record=True) + reader.open() + data = _read_all_data(reader) + assert data == expected + + # rewind many times + for _ in range(3): + reader.rewind() data = _read_all_data(reader) - reader.close() - self.assertEqual(data, expected) + assert data == expected + assert data == reader.data + reader.close() - @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) - def test_Recorder(self, file_id): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read() - reader = AudioReader(input_wav, block_dur=0.1, record=True) - reader.open() +@pytest.mark.parametrize( + "file_id", + [ + "mono_400", # mono + "3channel_400-800-1600", # multichannel + ], + ids=["mono", "multichannel"], +) +def test_Recorder_alias(file_id): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read() + + reader = Recorder(input_wav, block_dur=0.1) + reader.open() + data = _read_all_data(reader) + assert data == expected + + # rewind many times + for _ in range(3): + reader.rewind() data = _read_all_data(reader) - self.assertEqual(data, expected) - - # rewind many times - for _ in range(3): - reader.rewind() - data = _read_all_data(reader) - self.assertEqual(data, expected) - self.assertEqual(data, reader.data) - reader.close() - - @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) - def test_Recorder_alias(self, file_id): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read() - - reader = Recorder(input_wav, block_dur=0.1) - reader.open() - data = _read_all_data(reader) - self.assertEqual(data, expected) - - # rewind many times - for _ in range(3): - reader.rewind() - data = _read_all_data(reader) - self.assertEqual(data, expected) - self.assertEqual(data, reader.data) - reader.close() - - -if __name__ == "__main__": - unittest.main() + assert data == expected + assert data == reader.data + reader.close()