Mercurial > hg > auditok
changeset 400:323d59b404a2
Use pytest instead of genty
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Sat, 25 May 2024 21:54:13 +0200 |
parents | 08f893725d23 |
children | b11c51a0eade |
files | tests/test_AudioReader.py tests/test_AudioSource.py tests/test_StreamTokenizer.py tests/test_cmdline_util.py tests/test_core.py tests/test_io.py tests/test_plotting.py tests/test_signal.py tests/test_util.py tests/test_workers.py |
diffstat | 10 files changed, 3804 insertions(+), 4263 deletions(-) [+] |
line wrap: on
line diff
--- a/tests/test_AudioReader.py Fri May 24 21:30:34 2024 +0200 +++ b/tests/test_AudioReader.py Sat May 25 21:54:13 2024 +0200 @@ -1,14 +1,7 @@ -""" -@author: Amine Sehili <amine.sehili@gmail.com> -September 2015 - -""" - -import unittest +import pytest from functools import partial import sys import wave -from genty import genty, genty_dataset from auditok import ( dataset, ADSFactory, @@ -21,78 +14,63 @@ ) -class TestADSFactoryFileAudioSource(unittest.TestCase): - def setUp(self): +class TestADSFactoryFileAudioSource: + def setup_method(self): self.audio_source = WaveAudioSource( filename=dataset.one_to_six_arabic_16000_mono_bc_noise ) def test_ADS_type(self): - ads = ADSFactory.ads(audio_source=self.audio_source) - - err_msg = "wrong type for ads object, expected: 'AudioDataSource', " - err_msg += "found: {0}" - self.assertIsInstance( - ads, AudioDataSource, err_msg.format(type(ads)), + err_msg = ( + "wrong type for ads object, expected: 'AudioDataSource', found: {0}" ) + assert isinstance(ads, AudioDataSource), err_msg.format(type(ads)) def test_default_block_size(self): ads = ADSFactory.ads(audio_source=self.audio_source) size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong default block_size, expected: 160, found: {0}".format(size), - ) + assert ( + size == 160 + ), "Wrong default block_size, expected: 160, found: {0}".format(size) def test_block_size(self): ads = ADSFactory.ads(audio_source=self.audio_source, block_size=512) size = ads.block_size - self.assertEqual( - size, - 512, - "Wrong block_size, expected: 512, found: {0}".format(size), - ) + assert ( + size == 512 + ), "Wrong block_size, expected: 512, found: {0}".format(size) # with alias keyword ads = ADSFactory.ads(audio_source=self.audio_source, bs=160) size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong block_size, expected: 160, found: {0}".format(size), - ) + assert ( + size == 160 + ), "Wrong block_size, expected: 160, found: {0}".format(size) def test_block_duration(self): - ads = ADSFactory.ads( audio_source=self.audio_source, block_dur=0.01 ) # 10 ms size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong block_size, expected: 160, found: {0}".format(size), - ) + assert ( + size == 160 + ), "Wrong block_size, expected: 160, found: {0}".format(size) # with alias keyword ads = ADSFactory.ads(audio_source=self.audio_source, bd=0.025) # 25 ms size = ads.block_size - self.assertEqual( - size, - 400, - "Wrong block_size, expected: 400, found: {0}".format(size), - ) + assert ( + size == 400 + ), "Wrong block_size, expected: 400, found: {0}".format(size) def test_hop_duration(self): - ads = ADSFactory.ads( audio_source=self.audio_source, block_dur=0.02, hop_dur=0.01 ) # 10 ms size = ads.hop_size - self.assertEqual( - size, 160, "Wrong hop_size, expected: 160, found: {0}".format(size) + assert size == 160, "Wrong hop_size, expected: 160, found: {0}".format( + size ) # with alias keyword @@ -100,47 +78,33 @@ audio_source=self.audio_source, bd=0.025, hop_dur=0.015 ) # 15 ms size = ads.hop_size - self.assertEqual( - size, - 240, - "Wrong block_size, expected: 240, found: {0}".format(size), - ) + assert ( + size == 240 + ), "Wrong block_size, expected: 240, found: {0}".format(size) def test_sampling_rate(self): ads = ADSFactory.ads(audio_source=self.audio_source) - srate = ads.sampling_rate - self.assertEqual( - srate, - 16000, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) + assert ( + srate == 16000 + ), "Wrong sampling rate, expected: 16000, found: {0}".format(srate) def test_sample_width(self): ads = ADSFactory.ads(audio_source=self.audio_source) - swidth = ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) + assert ( + swidth == 2 + ), "Wrong sample width, expected: 2, found: {0}".format(swidth) def test_channels(self): ads = ADSFactory.ads(audio_source=self.audio_source) - channels = ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) + assert ( + channels == 1 + ), "Wrong number of channels, expected: 1, found: {0}".format(channels) def test_read(self): ads = ADSFactory.ads(audio_source=self.audio_source, block_size=256) - ads.open() ads_data = ads.read() ads.close() @@ -152,14 +116,11 @@ audio_source_data = audio_source.read(256) audio_source.close() - self.assertEqual( - ads_data, audio_source_data, "Unexpected data read from ads" - ) + assert ads_data == audio_source_data, "Unexpected data read from ads" def test_Limiter_Deco_read(self): # read a maximum of 0.75 seconds from audio source ads = ADSFactory.ads(audio_source=self.audio_source, max_time=0.75) - ads_data = [] ads.open() while True: @@ -177,9 +138,9 @@ audio_source_data = audio_source.read(int(16000 * 0.75)) audio_source.close() - self.assertEqual( - ads_data, audio_source_data, "Unexpected data read from LimiterADS" - ) + assert ( + ads_data == audio_source_data + ), "Unexpected data read from LimiterADS" def test_Limiter_Deco_read_limit(self): # read a maximum of 1.191 seconds from audio source @@ -189,9 +150,7 @@ total_samples_with_overlap = ( nb_full_blocks * ads.block_size + last_block_size ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) + expected_read_bytes = total_samples_with_overlap * ads.sw * ads.channels total_read = 0 ads.open() @@ -204,19 +163,17 @@ total_read += len(block) ads.close() - err_msg = "Wrong data length read from LimiterADS, expected: {0}, " - err_msg += "found: {1}" - self.assertEqual( - total_read, - expected_read_bytes, - err_msg.format(expected_read_bytes, total_read), + err_msg = ( + "Wrong data length read from LimiterADS, expected: {0}, found: {1}" + ) + assert total_read == expected_read_bytes, err_msg.format( + expected_read_bytes, total_read ) def test_Recorder_Deco_read(self): ads = ADSFactory.ads( audio_source=self.audio_source, record=True, block_size=500 ) - ads_data = [] ads.open() for i in range(10): @@ -234,24 +191,18 @@ audio_source_data = audio_source.read(500 * 10) audio_source.close() - self.assertEqual( - ads_data, - audio_source_data, - "Unexpected data read from RecorderADS", - ) + assert ( + ads_data == audio_source_data + ), "Unexpected data read from RecorderADS" def test_Recorder_Deco_is_rewindable(self): ads = ADSFactory.ads(audio_source=self.audio_source, record=True) - - self.assertTrue( - ads.rewindable, "RecorderADS.is_rewindable should return True" - ) + assert ads.rewindable, "RecorderADS.is_rewindable should return True" def test_Recorder_Deco_rewind_and_read(self): ads = ADSFactory.ads( audio_source=self.audio_source, record=True, block_size=320 ) - ads.open() for i in range(10): ads.read() @@ -275,14 +226,11 @@ audio_source_data = audio_source.read(320 * 10) audio_source.close() - self.assertEqual( - ads_data, - audio_source_data, - "Unexpected data read from RecorderADS", - ) + assert ( + ads_data == audio_source_data + ), "Unexpected data read from RecorderADS" def test_Overlap_Deco_read(self): - # Use arbitrary valid block_size and hop_size block_size = 1714 hop_size = 313 @@ -312,24 +260,17 @@ ) audio_source.open() - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting + # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting for i, block in enumerate(ads_data): - tmp = audio_source.read(block_size) - - self.assertEqual( - block, - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - + assert ( + block == tmp + ), "Unexpected block (N={0}) read from OverlapADS".format(i) audio_source.position = (i + 1) * hop_size audio_source.close() def test_Limiter_Overlap_Deco_read(self): - block_size = 256 hop_size = 200 @@ -359,21 +300,17 @@ ) audio_source.open() - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting + # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting for i, block in enumerate(ads_data): tmp = audio_source.read(len(block) // (ads.sw * ads.ch)) - self.assertEqual( - len(block), - len(tmp), - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) + assert len(block) == len( + tmp + ), "Unexpected block (N={0}) read from OverlapADS".format(i) audio_source.position = (i + 1) * hop_size audio_source.close() def test_Limiter_Overlap_Deco_read_limit(self): - block_size = 313 hop_size = 207 ads = ADSFactory.ads( @@ -392,9 +329,7 @@ total_samples_with_overlap = ( first_read_size + next_read_size * nb_next_blocks + last_block_size ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) + expected_read_bytes = total_samples_with_overlap * ads.sw * ads.channels cache_size = (block_size - hop_size) * ads.sample_width * ads.channels total_read = cache_size @@ -409,12 +344,11 @@ total_read += len(block) - cache_size ads.close() - err_msg = "Wrong data length read from LimiterADS, expected: {0}, " - err_msg += "found: {1}" - self.assertEqual( - total_read, - expected_read_bytes, - err_msg.format(expected_read_bytes, total_read), + err_msg = ( + "Wrong data length read from LimiterADS, expected: {0}, found: {1}" + ) + assert total_read == expected_read_bytes, err_msg.format( + expected_read_bytes, total_read ) def test_Recorder_Overlap_Deco_is_rewindable(self): @@ -424,12 +358,9 @@ hop_size=160, record=True, ) - self.assertTrue( - ads.rewindable, "RecorderADS.is_rewindable should return True" - ) + assert ads.rewindable, "RecorderADS.is_rewindable should return True" def test_Recorder_Overlap_Deco_rewind_and_read(self): - # Use arbitrary valid block_size and hop_size block_size = 1600 hop_size = 400 @@ -461,24 +392,18 @@ ) audio_source.open() - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting + # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting for j in range(i): - tmp = audio_source.read(block_size) - - self.assertEqual( - ads.read(), - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) + assert ( + ads.read() == tmp + ), "Unexpected block (N={0}) read from OverlapADS".format(i) audio_source.position = (j + 1) * hop_size ads.close() audio_source.close() def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): - # Use arbitrary valid block_size and hop_size block_size = 1600 hop_size = 400 @@ -511,24 +436,18 @@ ) audio_source.open() - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting + # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting for j in range(i): - tmp = audio_source.read(block_size) - - self.assertEqual( - ads.read(), - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) + assert ( + ads.read() == tmp + ), "Unexpected block (N={0}) read from OverlapADS".format(i) audio_source.position = (j + 1) * hop_size ads.close() audio_source.close() def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_limit(self): - # Use arbitrary valid block_size and hop_size block_size = 1000 hop_size = 200 @@ -549,9 +468,7 @@ total_samples_with_overlap = ( first_read_size + next_read_size * nb_next_blocks + last_block_size ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) + expected_read_bytes = total_samples_with_overlap * ads.sw * ads.channels cache_size = (block_size - hop_size) * ads.sample_width * ads.channels total_read = cache_size @@ -566,17 +483,16 @@ total_read += len(block) - cache_size ads.close() - err_msg = "Wrong data length read from LimiterADS, expected: {0}, " - err_msg += "found: {1}" - self.assertEqual( - total_read, - expected_read_bytes, - err_msg.format(expected_read_bytes, total_read), + err_msg = ( + "Wrong data length read from LimiterADS, expected: {0}, found: {1}" + ) + assert total_read == expected_read_bytes, err_msg.format( + expected_read_bytes, total_read ) -class TestADSFactoryBufferAudioSource(unittest.TestCase): - def setUp(self): +class TestADSFactoryBufferAudioSource: + def setup_method(self): self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" self.ads = ADSFactory.ads( data_buffer=self.signal, @@ -588,32 +504,23 @@ def test_ADS_BAS_sampling_rate(self): srate = self.ads.sampling_rate - self.assertEqual( - srate, - 16, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) + assert ( + srate == 16 + ), "Wrong sampling rate, expected: 16000, found: {0}".format(srate) def test_ADS_BAS_sample_width(self): swidth = self.ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) + assert ( + swidth == 2 + ), "Wrong sample width, expected: 2, found: {0}".format(swidth) def test_ADS_BAS_channels(self): channels = self.ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) + assert ( + channels == 1 + ), "Wrong number of channels, expected: 1, found: {0}".format(channels) def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): - # Use arbitrary valid block_size and hop_size block_size = 5 hop_size = 4 @@ -646,20 +553,14 @@ ) audio_source.open() - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting + # Compare all blocks read from OverlapADS to those read from an audio source with a manual position setting for j in range(i): - tmp = audio_source.read(block_size) - block = ads.read() - - self.assertEqual( - block, - tmp, - "Unexpected block '{}' (N={}) read from OverlapADS".format( - block, i - ), + assert ( + block == tmp + ), "Unexpected block '{}' (N={}) read from OverlapADS".format( + block, i ) audio_source.position = (j + 1) * hop_size @@ -667,8 +568,8 @@ audio_source.close() -class TestADSFactoryAlias(unittest.TestCase): - def setUp(self): +class TestADSFactoryAlias: + def setup_method(self): self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" def test_sampling_rate_alias(self): @@ -680,11 +581,9 @@ block_dur=0.5, ) srate = ads.sampling_rate - self.assertEqual( - srate, - 16, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) + assert ( + srate == 16 + ), "Wrong sampling rate, expected: 16000, found: {0}".format(srate) def test_sampling_rate_duplicate(self): func = partial( @@ -695,7 +594,8 @@ sample_width=2, channels=1, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_sample_width_alias(self): ads = ADSFactory.ads( @@ -706,11 +606,9 @@ block_dur=0.5, ) swidth = ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) + assert ( + swidth == 2 + ), "Wrong sample width, expected: 2, found: {0}".format(swidth) def test_sample_width_duplicate(self): func = partial( @@ -721,7 +619,8 @@ sample_width=2, channels=1, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_channels_alias(self): ads = ADSFactory.ads( @@ -732,13 +631,9 @@ block_dur=4, ) channels = ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) + assert ( + channels == 1 + ), "Wrong number of channels, expected: 1, found: {0}".format(channels) def test_channels_duplicate(self): func = partial( @@ -749,7 +644,8 @@ ch=1, channels=1, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_block_size_alias(self): ads = ADSFactory.ads( @@ -760,12 +656,10 @@ bs=8, ) size = ads.block_size - self.assertEqual( - size, - 8, - "Wrong block_size using bs alias, expected: 8, found: {0}".format( - size - ), + assert ( + size == 8 + ), "Wrong block_size using bs alias, expected: 8, found: {0}".format( + size ) def test_block_size_duplicate(self): @@ -778,7 +672,8 @@ bs=4, block_size=4, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_block_duration_alias(self): ads = ADSFactory.ads( @@ -788,13 +683,9 @@ channels=1, bd=0.75, ) - # 0.75 ms = 0.75 * 16 = 12 size = ads.block_size - err_msg = "Wrong block_size set with a block_dur alias 'bd', " - err_msg += "expected: 8, found: {0}" - self.assertEqual( - size, 12, err_msg.format(size), - ) + err_msg = "Wrong block_size set with a block_dur alias 'bd', expected: 8, found: {0}" + assert size == 12, err_msg.format(size) def test_block_duration_duplicate(self): func = partial( @@ -806,7 +697,8 @@ bd=4, block_dur=4, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_block_size_duration_duplicate(self): func = partial( @@ -818,10 +710,10 @@ bd=4, bs=12, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_hop_duration_alias(self): - ads = ADSFactory.ads( data_buffer=self.signal, sampling_rate=16, @@ -831,16 +723,13 @@ hd=0.5, ) size = ads.hop_size - self.assertEqual( - size, - 8, - "Wrong block_size using bs alias, expected: 8, found: {0}".format( - size - ), + assert ( + size == 8 + ), "Wrong block_size using bs alias, expected: 8, found: {0}".format( + size ) def test_hop_duration_duplicate(self): - func = partial( ADSFactory.ads, data_buffer=self.signal, @@ -851,7 +740,8 @@ hd=0.5, hop_dur=0.5, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_hop_size_duration_duplicate(self): func = partial( @@ -864,7 +754,8 @@ hs=4, hd=1, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_hop_size_greater_than_block_size(self): func = partial( @@ -876,16 +767,17 @@ bs=4, hs=8, ) - self.assertRaises(ValueError, func) + with pytest.raises(ValueError): + func() def test_filename_duplicate(self): - func = partial( ADSFactory.ads, fn=dataset.one_to_six_arabic_16000_mono_bc_noise, filename=dataset.one_to_six_arabic_16000_mono_bc_noise, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_data_buffer_duplicate(self): func = partial( @@ -896,7 +788,8 @@ sample_width=2, channels=1, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_max_time_alias(self): ads = ADSFactory.ads( @@ -907,12 +800,10 @@ mt=10, block_dur=0.5, ) - self.assertEqual( - ads.max_read, - 10, - "Wrong AudioDataSource.max_read, expected: 10, found: {}".format( - ads.max_read - ), + assert ( + ads.max_read == 10 + ), "Wrong AudioDataSource.max_read, expected: 10, found: {}".format( + ads.max_read ) def test_max_time_duplicate(self): @@ -925,8 +816,8 @@ mt=True, max_time=True, ) - - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_record_alias(self): ads = ADSFactory.ads( @@ -937,9 +828,7 @@ rec=True, block_dur=0.5, ) - self.assertTrue( - ads.rewindable, "AudioDataSource.rewindable expected to be True" - ) + assert ads.rewindable, "AudioDataSource.rewindable expected to be True" def test_record_duplicate(self): func = partial( @@ -951,10 +840,10 @@ rec=True, record=True, ) - self.assertRaises(DuplicateArgument, func) + with pytest.raises(DuplicateArgument): + func() def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_alias(self): - # Use arbitrary valid block_size and hop_size block_size = 5 hop_size = 4 @@ -987,16 +876,13 @@ ) audio_source.open() - # Compare all blocks read from AudioDataSource to those read - # from an audio source with manual position definition + # Compare all blocks read from AudioDataSource to those read from an audio source with manual position definition for j in range(i): tmp = audio_source.read(block_size) block = ads.read() - self.assertEqual( - block, - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) + assert ( + block == tmp + ), "Unexpected block (N={0}) read from OverlapADS".format(i) audio_source.position = (j + 1) * hop_size ads.close() audio_source.close() @@ -1012,68 +898,78 @@ return b"".join(blocks) -@genty -class TestAudioReader(unittest.TestCase): +@pytest.mark.parametrize( + "file_id, max_read, size", + [ + ("mono_400", 0.5, 16000), # mono + ("3channel_400-800-1600", 0.5, 16000 * 3), # multichannel + ], + ids=["mono", "multichannel"], +) +def test_Limiter(file_id, max_read, size): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read(size) - # TODO move all tests here when backward compatibility - # with ADSFactory is dropped + reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read) + reader.open() + data = _read_all_data(reader) + reader.close() + assert data == expected - @genty_dataset( - mono=("mono_400", 0.5, 16000), - multichannel=("3channel_400-800-1600", 0.5, 16000 * 3), - ) - def test_Limiter(self, file_id, max_read, size): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read(size) - reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read) - reader.open() +@pytest.mark.parametrize( + "file_id", + [ + "mono_400", # mono + "3channel_400-800-1600", # multichannel + ], + ids=["mono", "multichannel"], +) +def test_Recorder(file_id): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read() + + reader = AudioReader(input_wav, block_dur=0.1, record=True) + reader.open() + data = _read_all_data(reader) + assert data == expected + + # rewind many times + for _ in range(3): + reader.rewind() data = _read_all_data(reader) - reader.close() - self.assertEqual(data, expected) + assert data == expected + assert data == reader.data + reader.close() - @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) - def test_Recorder(self, file_id): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read() - reader = AudioReader(input_wav, block_dur=0.1, record=True) - reader.open() +@pytest.mark.parametrize( + "file_id", + [ + "mono_400", # mono + "3channel_400-800-1600", # multichannel + ], + ids=["mono", "multichannel"], +) +def test_Recorder_alias(file_id): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read() + + reader = Recorder(input_wav, block_dur=0.1) + reader.open() + data = _read_all_data(reader) + assert data == expected + + # rewind many times + for _ in range(3): + reader.rewind() data = _read_all_data(reader) - self.assertEqual(data, expected) - - # rewind many times - for _ in range(3): - reader.rewind() - data = _read_all_data(reader) - self.assertEqual(data, expected) - self.assertEqual(data, reader.data) - reader.close() - - @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) - def test_Recorder_alias(self, file_id): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read() - - reader = Recorder(input_wav, block_dur=0.1) - reader.open() - data = _read_all_data(reader) - self.assertEqual(data, expected) - - # rewind many times - for _ in range(3): - reader.rewind() - data = _read_all_data(reader) - self.assertEqual(data, expected) - self.assertEqual(data, reader.data) - reader.close() - - -if __name__ == "__main__": - unittest.main() + assert data == expected + assert data == reader.data + reader.close()
--- a/tests/test_AudioSource.py Fri May 24 21:30:34 2024 +0200 +++ b/tests/test_AudioSource.py Sat May 25 21:54:13 2024 +0200 @@ -1,9 +1,9 @@ """ @author: Amine Sehili <amine.sehili@gmail.com> """ + from array import array -import unittest -from genty import genty, genty_dataset +import pytest from auditok.io import ( AudioParameterError, BufferAudioSource, @@ -24,202 +24,166 @@ yield data -@genty -class TestAudioSource(unittest.TestCase): +@pytest.mark.parametrize( + "file_suffix, frequencies", + [ + ("mono_400Hz", (400,)), # mono + ("3channel_400-800-1600Hz", (400, 800, 1600)), # multichannel + ], + ids=["mono", "multichannel"], +) +def test_BufferAudioSource_read_all(file_suffix, frequencies): + file = "tests/data/test_16KHZ_{}.raw".format(file_suffix) + with open(file, "rb") as fp: + expected = fp.read() + channels = len(frequencies) + audio_source = BufferAudioSource(expected, 16000, 2, channels) + audio_source.open() + data = audio_source.read(None) + assert data == expected + audio_source.rewind() + data = audio_source.read(-10) + assert data == expected + audio_source.close() - # TODO when use_channel is None, return samples from all channels - @genty_dataset( - mono=("mono_400Hz", (400,)), - multichannel=("3channel_400-800-1600Hz", (400, 800, 1600)), - ) - def test_BufferAudioSource_read_all(self, file_suffix, frequencies): - file = "tests/data/test_16KHZ_{}.raw".format(file_suffix) - with open(file, "rb") as fp: - expected = fp.read() - channels = len(frequencies) - audio_source = BufferAudioSource(expected, 16000, 2, channels) - audio_source.open() - data = audio_source.read(None) - self.assertEqual(data, expected) - audio_source.rewind() - data = audio_source.read(-10) - self.assertEqual(data, expected) - audio_source.close() +@pytest.mark.parametrize( + "file_suffix, frequencies", + [ + ("mono_400Hz", (400,)), # mono + ("3channel_400-800-1600Hz", (400, 800, 1600)), # multichannel + ], + ids=["mono", "multichannel"], +) +def test_RawAudioSource(file_suffix, frequencies): + file = "tests/data/test_16KHZ_{}.raw".format(file_suffix) + channels = len(frequencies) + audio_source = RawAudioSource(file, 16000, 2, channels) + audio_source.open() + data_read_all = b"".join(audio_source_read_all_gen(audio_source)) + audio_source.close() + mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] + fmt = FORMAT[audio_source.sample_width] + expected = array(fmt, _sample_generator(*mono_channels)).tobytes() - @genty_dataset( - mono=("mono_400Hz", (400,)), - multichannel=("3channel_400-800-1600Hz", (400, 800, 1600)), - ) - def test_RawAudioSource(self, file_suffix, frequencies): - file = "tests/data/test_16KHZ_{}.raw".format(file_suffix) - channels = len(frequencies) - audio_source = RawAudioSource(file, 16000, 2, channels) - audio_source.open() - data_read_all = b"".join(audio_source_read_all_gen(audio_source)) - audio_source.close() - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - fmt = FORMAT[audio_source.sample_width] - expected = array(fmt, _sample_generator(*mono_channels)).tobytes() + assert data_read_all == expected - self.assertEqual(data_read_all, expected) + # assert read all data with None + audio_source = RawAudioSource(file, 16000, 2, channels) + audio_source.open() + data_read_all = audio_source.read(None) + audio_source.close() + assert data_read_all == expected - # assert read all data with None - audio_source = RawAudioSource(file, 16000, 2, channels) - audio_source.open() - data_read_all = audio_source.read(None) - audio_source.close() - self.assertEqual(data_read_all, expected) + # assert read all data with a negative size + audio_source = RawAudioSource(file, 16000, 2, channels) + audio_source.open() + data_read_all = audio_source.read(-10) + audio_source.close() + assert data_read_all == expected - # assert read all data with a negative size - audio_source = RawAudioSource(file, 16000, 2, channels) - audio_source.open() - data_read_all = audio_source.read(-10) - audio_source.close() - self.assertEqual(data_read_all, expected) - @genty_dataset( - mono=("mono_400Hz", (400,)), - multichannel=("3channel_400-800-1600Hz", (400, 800, 1600)), - ) - def test_WaveAudioSource(self, file_suffix, frequencies): - file = "tests/data/test_16KHZ_{}.wav".format(file_suffix) - audio_source = WaveAudioSource(file) - audio_source.open() - data = b"".join(audio_source_read_all_gen(audio_source)) - audio_source.close() - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - fmt = FORMAT[audio_source.sample_width] - expected = array(fmt, _sample_generator(*mono_channels)).tobytes() +@pytest.mark.parametrize( + "file_suffix, frequencies", + [ + ("mono_400Hz", (400,)), # mono + ("3channel_400-800-1600Hz", (400, 800, 1600)), # multichannel + ], + ids=["mono", "multichannel"], +) +def test_WaveAudioSource(file_suffix, frequencies): + file = "tests/data/test_16KHZ_{}.wav".format(file_suffix) + audio_source = WaveAudioSource(file) + audio_source.open() + data = b"".join(audio_source_read_all_gen(audio_source)) + audio_source.close() + mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] + fmt = FORMAT[audio_source.sample_width] + expected = array(fmt, _sample_generator(*mono_channels)).tobytes() - self.assertEqual(data, expected) + assert data == expected - # assert read all data with None - audio_source = WaveAudioSource(file) - audio_source.open() - data_read_all = audio_source.read(None) - audio_source.close() - self.assertEqual(data_read_all, expected) + # assert read all data with None + audio_source = WaveAudioSource(file) + audio_source.open() + data_read_all = audio_source.read(None) + audio_source.close() + assert data_read_all == expected - # assert read all data with a negative size - audio_source = WaveAudioSource(file) - audio_source.open() - data_read_all = audio_source.read(-10) - audio_source.close() - self.assertEqual(data_read_all, expected) + # assert read all data with a negative size + audio_source = WaveAudioSource(file) + audio_source.open() + data_read_all = audio_source.read(-10) + audio_source.close() + assert data_read_all == expected -@genty -class TestBufferAudioSource_SR10_SW1_CH1(unittest.TestCase): - def setUp(self): +class TestBufferAudioSource_SR10_SW1_CH1: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): self.data = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" self.audio_source = BufferAudioSource( data=self.data, sampling_rate=10, sample_width=1, channels=1 ) self.audio_source.open() - - def tearDown(self): + yield self.audio_source.close() def test_sr10_sw1_ch1_read_1(self): block = self.audio_source.read(1) exp = b"A" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr10_sw1_ch1_read_6(self): block = self.audio_source.read(6) exp = b"ABCDEF" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr10_sw1_ch1_read_multiple(self): block = self.audio_source.read(1) exp = b"A" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(6) exp = b"BCDEFG" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(13) exp = b"HIJKLMNOPQRST" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(9999) exp = b"UVWXYZ012345" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr10_sw1_ch1_read_all(self): block = self.audio_source.read(9999) - self.assertEqual( - block, - self.data, - msg="wrong block, expected: {}, found: {} ".format( - self.data, block - ), - ) + assert block == self.data block = self.audio_source.read(1) - self.assertEqual( - block, - None, - msg="wrong block, expected: {}, found: {} ".format(None, block), - ) + assert block is None def test_sr10_sw1_ch1_sampling_rate(self): srate = self.audio_source.sampling_rate - self.assertEqual( - srate, - 10, - msg="wrong sampling rate, expected: 10, found: {0} ".format(srate), - ) + assert srate == 10 def test_sr10_sw1_ch1_sample_width(self): swidth = self.audio_source.sample_width - self.assertEqual( - swidth, - 1, - msg="wrong sample width, expected: 1, found: {0} ".format(swidth), - ) + assert swidth == 1 def test_sr10_sw1_ch1_channels(self): channels = self.audio_source.channels - self.assertEqual( - channels, - 1, - msg="wrong number of channels, expected: 1, found: {0} ".format( - channels - ), - ) + assert channels == 1 - @genty_dataset( - empty=([], 0, 0, 0), - zero=([0], 0, 0, 0), - five=([5], 5, 0.5, 500), - multiple=([5, 20], 25, 2.5, 2500), + @pytest.mark.parametrize( + "block_sizes, expected_sample, expected_second, expected_ms", + [ + ([], 0, 0, 0), # empty + ([0], 0, 0, 0), # zero + ([5], 5, 0.5, 500), # five + ([5, 20], 25, 2.5, 2500), # multiple + ], + ids=["empty", "zero", "five", "multiple"], ) def test_position( self, block_sizes, expected_sample, expected_second, expected_ms @@ -227,38 +191,24 @@ for block_size in block_sizes: self.audio_source.read(block_size) position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset( - zero=(0, 0, 0, 0), - one=(1, 1, 0.1, 100), - ten=(10, 10, 1, 1000), - negative_1=(-1, 31, 3.1, 3100), - negative_2=(-7, 25, 2.5, 2500), + @pytest.mark.parametrize( + "position, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (1, 1, 0.1, 100), # one + (10, 10, 1, 1000), # ten + (-1, 31, 3.1, 3100), # negative_1 + (-7, 25, 2.5, 2500), # negative_2 + ], + ids=["zero", "one", "ten", "negative_1", "negative_2"], ) def test_position_setter( self, position, expected_sample, expected_second, expected_ms @@ -266,38 +216,24 @@ self.audio_source.position = position position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset( - zero=(0, 0, 0, 0), - one=(0.1, 1, 0.1, 100), - ten=(1, 10, 1, 1000), - negative_1=(-0.1, 31, 3.1, 3100), - negative_2=(-0.7, 25, 2.5, 2500), + @pytest.mark.parametrize( + "position_s, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (0.1, 1, 0.1, 100), # one + (1, 10, 1, 1000), # ten + (-0.1, 31, 3.1, 3100), # negative_1 + (-0.7, 25, 2.5, 2500), # negative_2 + ], + ids=["zero", "one", "ten", "negative_1", "negative_2"], ) def test_position_s_setter( self, position_s, expected_sample, expected_second, expected_ms @@ -305,38 +241,24 @@ self.audio_source.position_s = position_s position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset( - zero=(0, 0, 0, 0), - one=(100, 1, 0.1, 100), - ten=(1000, 10, 1, 1000), - negative_1=(-100, 31, 3.1, 3100), - negative_2=(-700, 25, 2.5, 2500), + @pytest.mark.parametrize( + "position_ms, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (100, 1, 0.1, 100), # one + (1000, 10, 1, 1000), # ten + (-100, 31, 3.1, 3100), # negative_1 + (-700, 25, 2.5, 2500), # negative_2 + ], + ids=["zero", "one", "ten", "negative_1", "negative_2"], ) def test_position_ms_setter( self, position_ms, expected_sample, expected_second, expected_ms @@ -344,222 +266,157 @@ self.audio_source.position_ms = position_ms position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset(positive=((100,)), negative=(-100,)) + @pytest.mark.parametrize( + "position", + [ + 100, # positive + -100, # negative + ], + ids=["positive", "negative"], + ) def test_position_setter_out_of_range(self, position): - with self.assertRaises(IndexError): + with pytest.raises(IndexError): self.audio_source.position = position - @genty_dataset(positive=((100,)), negative=(-100,)) + @pytest.mark.parametrize( + "position_s", + [ + 100, # positive + -100, # negative + ], + ids=["positive", "negative"], + ) def test_position_s_setter_out_of_range(self, position_s): - with self.assertRaises(IndexError): + with pytest.raises(IndexError): self.audio_source.position_s = position_s - @genty_dataset(positive=((10000,)), negative=(-10000,)) + @pytest.mark.parametrize( + "position_ms", + [ + 10000, # positive + -10000, # negative + ], + ids=["positive", "negative"], + ) def test_position_ms_setter_out_of_range(self, position_ms): - with self.assertRaises(IndexError): + with pytest.raises(IndexError): self.audio_source.position_ms = position_ms def test_sr10_sw1_ch1_initial_position_s_0(self): tp = self.audio_source.position_s - self.assertEqual( - tp, - 0.0, - msg="wrong time position, expected: 0.0, found: {0} ".format(tp), - ) + assert tp == 0.0 def test_sr10_sw1_ch1_position_s_1_after_read(self): srate = self.audio_source.sampling_rate # read one second self.audio_source.read(srate) tp = self.audio_source.position_s - self.assertEqual( - tp, - 1.0, - msg="wrong time position, expected: 1.0, found: {0} ".format(tp), - ) + assert tp == 1.0 def test_sr10_sw1_ch1_position_s_2_5(self): # read 2.5 seconds self.audio_source.read(25) tp = self.audio_source.position_s - self.assertEqual( - tp, - 2.5, - msg="wrong time position, expected: 2.5, found: {0} ".format(tp), - ) + assert tp == 2.5 def test_sr10_sw1_ch1_position_s_0(self): self.audio_source.read(10) self.audio_source.position_s = 0 tp = self.audio_source.position_s - self.assertEqual( - tp, - 0.0, - msg="wrong time position, expected: 0.0, found: {0} ".format(tp), - ) + assert tp == 0.0 def test_sr10_sw1_ch1_position_s_1(self): self.audio_source.position_s = 1 tp = self.audio_source.position_s - self.assertEqual( - tp, - 1.0, - msg="wrong time position, expected: 1.0, found: {0} ".format(tp), - ) + assert tp == 1.0 def test_sr10_sw1_ch1_rewind(self): self.audio_source.read(10) self.audio_source.rewind() tp = self.audio_source.position - self.assertEqual( - tp, 0, msg="wrong position, expected: 0.0, found: {0} ".format(tp) - ) + assert tp == 0 def test_sr10_sw1_ch1_read_closed(self): self.audio_source.close() - with self.assertRaises(Exception): + with pytest.raises(Exception): self.audio_source.read(1) -@genty -class TestBufferAudioSource_SR16_SW2_CH1(unittest.TestCase): - def setUp(self): +class TestBufferAudioSource_SR16_SW2_CH1: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): self.data = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" self.audio_source = BufferAudioSource( data=self.data, sampling_rate=16, sample_width=2, channels=1 ) self.audio_source.open() - - def tearDown(self): + yield self.audio_source.close() def test_sr16_sw2_ch1_read_1(self): block = self.audio_source.read(1) exp = b"AB" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr16_sw2_ch1_read_6(self): block = self.audio_source.read(6) exp = b"ABCDEFGHIJKL" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr16_sw2_ch1_read_multiple(self): block = self.audio_source.read(1) exp = b"AB" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(6) exp = b"CDEFGHIJKLMN" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(5) exp = b"OPQRSTUVWX" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(9999) exp = b"YZ012345" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr16_sw2_ch1_read_all(self): block = self.audio_source.read(9999) - self.assertEqual( - block, - self.data, - msg="wrong block, expected: {0}, found: {1} ".format( - self.data, block - ), - ) + assert block == self.data block = self.audio_source.read(1) - self.assertEqual( - block, - None, - msg="wrong block, expected: {0}, found: {1} ".format(None, block), - ) + assert block is None def test_sr16_sw2_ch1_sampling_rate(self): srate = self.audio_source.sampling_rate - self.assertEqual( - srate, - 16, - msg="wrong sampling rate, expected: 10, found: {0} ".format(srate), - ) + assert srate == 16 def test_sr16_sw2_ch1_sample_width(self): swidth = self.audio_source.sample_width - self.assertEqual( - swidth, - 2, - msg="wrong sample width, expected: 1, found: {0} ".format(swidth), - ) + assert swidth == 2 def test_sr16_sw2_ch1_channels(self): + channels = self.audio_source.channels + assert channels == 1 - channels = self.audio_source.channels - self.assertEqual( - channels, - 1, - msg="wrong number of channels, expected: 1, found: {0} ".format( - channels - ), - ) - - @genty_dataset( - empty=([], 0, 0, 0), - zero=([0], 0, 0, 0), - two=([2], 2, 2 / 16, int(2000 / 16)), - eleven=([11], 11, 11 / 16, int(11 * 1000 / 16)), - multiple=([4, 8], 12, 0.75, 750), + @pytest.mark.parametrize( + "block_sizes, expected_sample, expected_second, expected_ms", + [ + ([], 0, 0, 0), # empty + ([0], 0, 0, 0), # zero + ([2], 2, 2 / 16, int(2000 / 16)), # two + ([11], 11, 11 / 16, int(11 * 1000 / 16)), # eleven + ([4, 8], 12, 0.75, 750), # multiple + ], + ids=["empty", "zero", "two", "eleven", "multiple"], ) def test_position( self, block_sizes, expected_sample, expected_second, expected_ms @@ -567,46 +424,30 @@ for block_size in block_sizes: self.audio_source.read(block_size) position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms def test_sr16_sw2_ch1_read_position_0(self): self.audio_source.read(10) self.audio_source.position = 0 pos = self.audio_source.position - self.assertEqual( - pos, 0, msg="wrong position, expected: 0, found: {0} ".format(pos) - ) + assert pos == 0 - @genty_dataset( - zero=(0, 0, 0, 0), - one=(1, 1, 1 / 16, int(1000 / 16)), - ten=(10, 10, 10 / 16, int(10000 / 16)), - negative_1=(-1, 15, 15 / 16, int(15000 / 16)), - negative_2=(-7, 9, 9 / 16, int(9000 / 16)), + @pytest.mark.parametrize( + "position, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (1, 1, 1 / 16, int(1000 / 16)), # one + (10, 10, 10 / 16, int(10000 / 16)), # ten + (-1, 15, 15 / 16, int(15000 / 16)), # negative_1 + (-7, 9, 9 / 16, int(9000 / 16)), # negative_2 + ], + ids=["zero", "one", "ten", "negative_1", "negative_2"], ) def test_position_setter( self, position, expected_sample, expected_second, expected_ms @@ -614,39 +455,25 @@ self.audio_source.position = position position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset( - zero=(0, 0, 0, 0), - one=(0.1, 1, 1 / 16, int(1000 / 16)), - two=(1 / 8, 2, 1 / 8, int(1 / 8 * 1000)), - twelve=(0.75, 12, 0.75, 750), - negative_1=(-0.1, 15, 15 / 16, int(15000 / 16)), - negative_2=(-0.7, 5, 5 / 16, int(5000 / 16)), + @pytest.mark.parametrize( + "position_s, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (0.1, 1, 1 / 16, int(1000 / 16)), # one + (1 / 8, 2, 1 / 8, int(1 / 8 * 1000)), # two + (0.75, 12, 0.75, 750), # twelve + (-0.1, 15, 15 / 16, int(15000 / 16)), # negative_1 + (-0.7, 5, 5 / 16, int(5000 / 16)), # negative_2 + ], + ids=["zero", "one", "two", "twelve", "negative_1", "negative_2"], ) def test_position_s_setter( self, position_s, expected_sample, expected_second, expected_ms @@ -654,39 +481,25 @@ self.audio_source.position_s = position_s position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset( - zero=(0, 0, 0, 0), - one=(100, 1, 1 / 16, int(1000 / 16)), - ten=(1000, 16, 1, 1000), - negative_1=(-100, 15, 15 / 16, int(15 * 1000 / 16)), - negative_2=(-500, 8, 0.5, 500), - negative_3=(-700, 5, 5 / 16, int(5 * 1000 / 16)), + @pytest.mark.parametrize( + "position_ms, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (100, 1, 1 / 16, int(1000 / 16)), # one + (1000, 16, 1, 1000), # ten + (-100, 15, 15 / 16, int(15 * 1000 / 16)), # negative_1 + (-500, 8, 0.5, 500), # negative_2 + (-700, 5, 5 / 16, int(5 * 1000 / 16)), # negative_3 + ], + ids=["zero", "one", "ten", "negative_1", "negative_2", "negative_3"], ) def test_position_ms_setter( self, position_ms, expected_sample, expected_second, expected_ms @@ -694,266 +507,162 @@ self.audio_source.position_ms = position_ms position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms def test_sr16_sw2_ch1_rewind(self): self.audio_source.read(10) self.audio_source.rewind() tp = self.audio_source.position - self.assertEqual( - tp, 0, msg="wrong position, expected: 0.0, found: {0} ".format(tp) - ) + assert tp == 0 -class TestBufferAudioSource_SR11_SW4_CH1(unittest.TestCase): - def setUp(self): +class TestBufferAudioSource_SR11_SW4_CH1: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): self.data = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789abcdefgh" self.audio_source = BufferAudioSource( data=self.data, sampling_rate=11, sample_width=4, channels=1 ) self.audio_source.open() - - def tearDown(self): + yield self.audio_source.close() def test_sr11_sw4_ch1_read_1(self): block = self.audio_source.read(1) exp = b"ABCD" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr11_sw4_ch1_read_6(self): block = self.audio_source.read(6) exp = b"ABCDEFGHIJKLMNOPQRSTUVWX" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr11_sw4_ch1_read_multiple(self): block = self.audio_source.read(1) exp = b"ABCD" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(6) exp = b"EFGHIJKLMNOPQRSTUVWXYZ01" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(3) exp = b"23456789abcd" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(9999) exp = b"efgh" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr11_sw4_ch1_read_all(self): block = self.audio_source.read(9999) - self.assertEqual( - block, - self.data, - msg="wrong block, expected: {0}, found: {1} ".format( - self.data, block - ), - ) + assert block == self.data block = self.audio_source.read(1) - self.assertEqual( - block, - None, - msg="wrong block, expected: {0}, found: {1} ".format(None, block), - ) + assert block is None def test_sr11_sw4_ch1_sampling_rate(self): srate = self.audio_source.sampling_rate - self.assertEqual( - srate, - 11, - msg="wrong sampling rate, expected: 10, found: {0} ".format(srate), - ) + assert srate == 11 def test_sr11_sw4_ch1_sample_width(self): swidth = self.audio_source.sample_width - self.assertEqual( - swidth, - 4, - msg="wrong sample width, expected: 1, found: {0} ".format(swidth), - ) + assert swidth == 4 def test_sr11_sw4_ch1_channels(self): channels = self.audio_source.channels - self.assertEqual( - channels, - 1, - msg="wrong number of channels, expected: 1, found: {0} ".format( - channels - ), - ) + assert channels == 1 def test_sr11_sw4_ch1_intial_position_0(self): pos = self.audio_source.position - self.assertEqual( - pos, 0, msg="wrong position, expected: 0, found: {0} ".format(pos) - ) + assert pos == 0 def test_sr11_sw4_ch1_position_5(self): self.audio_source.read(5) pos = self.audio_source.position - self.assertEqual( - pos, 5, msg="wrong position, expected: 5, found: {0} ".format(pos) - ) + assert pos == 5 def test_sr11_sw4_ch1_position_9(self): self.audio_source.read(5) self.audio_source.read(4) pos = self.audio_source.position - self.assertEqual( - pos, 9, msg="wrong position, expected: 5, found: {0} ".format(pos) - ) + assert pos == 9 def test_sr11_sw4_ch1_position_0(self): self.audio_source.read(10) self.audio_source.position = 0 pos = self.audio_source.position - self.assertEqual( - pos, 0, msg="wrong position, expected: 0, found: {0} ".format(pos) - ) + assert pos == 0 def test_sr11_sw4_ch1_position_10(self): self.audio_source.position = 10 pos = self.audio_source.position - self.assertEqual( - pos, - 10, - msg="wrong position, expected: 10, found: {0} ".format(pos), - ) + assert pos == 10 def test_sr11_sw4_ch1_initial_position_s_0(self): tp = self.audio_source.position_s - self.assertEqual( - tp, - 0.0, - msg="wrong time position, expected: 0.0, found: {0} ".format(tp), - ) + assert tp == 0.0 def test_sr11_sw4_ch1_position_s_1_after_read(self): srate = self.audio_source.sampling_rate # read one second self.audio_source.read(srate) tp = self.audio_source.position_s - self.assertEqual( - tp, - 1.0, - msg="wrong time position, expected: 1.0, found: {0} ".format(tp), - ) + assert tp == 1.0 def test_sr11_sw4_ch1_position_s_0_63(self): # read 2.5 seconds self.audio_source.read(7) tp = self.audio_source.position_s - self.assertAlmostEqual( - tp, - 0.636363636364, - msg="wrong time position, expected: 0.636363636364, " - "found: {0} ".format(tp), - ) + assert tp, pytest.approx(0.636363636364) def test_sr11_sw4_ch1_position_s_0(self): self.audio_source.read(10) self.audio_source.position_s = 0 tp = self.audio_source.position_s - self.assertEqual( - tp, - 0.0, - msg="wrong time position, expected: 0.0, found: {0} ".format(tp), - ) + assert tp == 0.0 def test_sr11_sw4_ch1_position_s_1(self): self.audio_source.position_s = 1 tp = self.audio_source.position_s - self.assertEqual( - tp, - 1.0, - msg="wrong time position, expected: 1.0, found: {0} ".format(tp), - ) + assert tp == 1.0 def test_sr11_sw4_ch1_rewind(self): self.audio_source.read(10) self.audio_source.rewind() tp = self.audio_source.position - self.assertEqual( - tp, 0, msg="wrong position, expected: 0.0, found: {0} ".format(tp) - ) + assert tp == 0 -class TestBufferAudioSourceCreationException(unittest.TestCase): +class TestBufferAudioSourceCreationException: def test_wrong_sample_width_value(self): - with self.assertRaises(AudioParameterError) as audio_param_err: + with pytest.raises(AudioParameterError) as audio_param_err: _ = BufferAudioSource( data=b"ABCDEFGHI", sampling_rate=9, sample_width=3, channels=1 ) - self.assertEqual( - "Sample width must be one of: 1, 2 or 4 (bytes)", - str(audio_param_err.exception), + assert ( + str(audio_param_err.value) + == "Sample width must be one of: 1, 2 or 4 (bytes)" ) def test_wrong_data_buffer_size(self): - with self.assertRaises(AudioParameterError) as audio_param_err: + with pytest.raises(AudioParameterError) as audio_param_err: _ = BufferAudioSource( data=b"ABCDEFGHI", sampling_rate=8, sample_width=2, channels=1 ) - self.assertEqual( - "The length of audio data must be an integer " - "multiple of `sample_width * channels`", - str(audio_param_err.exception), + assert ( + str(audio_param_err.value) + == "The length of audio data must be an integer multiple of `sample_width * channels`" ) -class TestAudioSourceProperties(unittest.TestCase): +class TestAudioSourceProperties: def test_read_properties(self): data = b"" sampling_rate = 8000 @@ -963,9 +672,9 @@ data, sampling_rate, sample_width, channels ) - self.assertEqual(a_source.sampling_rate, sampling_rate) - self.assertEqual(a_source.sample_width, sample_width) - self.assertEqual(a_source.channels, channels) + assert a_source.sampling_rate == sampling_rate + assert a_source.sample_width == sample_width + assert a_source.channels == channels def test_set_readonly_properties_exception(self): data = b"" @@ -976,13 +685,13 @@ data, sampling_rate, sample_width, channels ) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): a_source.sampling_rate = 16000 a_source.sample_width = 1 a_source.channels = 2 -class TestAudioSourceShortProperties(unittest.TestCase): +class TestAudioSourceShortProperties: def test_read_short_properties(self): data = b"" sampling_rate = 8000 @@ -992,9 +701,9 @@ data, sampling_rate, sample_width, channels ) - self.assertEqual(a_source.sr, sampling_rate) - self.assertEqual(a_source.sw, sample_width) - self.assertEqual(a_source.ch, channels) + assert a_source.sr == sampling_rate + assert a_source.sw == sample_width + assert a_source.ch == channels def test_set_readonly_short_properties_exception(self): data = b"" @@ -1005,11 +714,7 @@ data, sampling_rate, sample_width, channels ) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): a_source.sr = 16000 a_source.sw = 1 a_source.ch = 2 - - -if __name__ == "__main__": - unittest.main()
--- a/tests/test_StreamTokenizer.py Fri May 24 21:30:34 2024 +0200 +++ b/tests/test_StreamTokenizer.py Sat May 25 21:54:13 2024 +0200 @@ -1,10 +1,4 @@ -""" -@author: Amine Sehili <amine.sehili@gmail.com> -September 2015 - -""" - -import unittest +import pytest from auditok import StreamTokenizer, StringDataSource, DataValidator @@ -13,1017 +7,612 @@ return frame == "A" -class TestStreamTokenizerInitParams(unittest.TestCase): - def setUp(self): - self.A_validator = AValidator() +@pytest.fixture +def validator(): + return AValidator() - # Completely deactivate init_min and init_max_silence - # The tokenizer will only rely on the other parameters - # Note that if init_min = 0, the value of init_max_silence - # will have no effect - def test_init_min_0_init_max_silence_0(self): - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=20, - max_continuous_silence=4, - init_min=0, - init_max_silence=0, - mode=0, - ) +def test_init_min_0_init_max_silence_0(validator): + tokenizer = StreamTokenizer( + validator, + min_length=5, + max_length=20, + max_continuous_silence=4, + init_min=0, + init_max_silence=0, + mode=0, + ) - data_source = StringDataSource("aAaaaAaAaaAaAaaaaaaaAAAAAAAA") - # ^ ^ ^ ^ - # 2 16 20 27 - tokens = tokenizer.tokenize(data_source) + data_source = StringDataSource("aAaaaAaAaaAaAaaaaaaaAAAAAAAA") + # ^ ^ ^ ^ + # 2 16 20 27 + tokens = tokenizer.tokenize(data_source) - self.assertEqual( - len(tokens), - 2, - msg="wrong number of tokens, expected: 2, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2 = tokens[0], tokens[1] + assert ( + len(tokens) == 2 + ), f"wrong number of tokens, expected: 2, found: {len(tokens)}" + tok1, tok2 = tokens[0], tokens[1] - # tok1[0]: data - # tok1[1]: start frame (included) - # tok1[2]: end frame (included) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AaaaAaAaaAaAaaaa" + ), f"wrong data for token 1, expected: 'AaaaAaAaaAaAaaaa', found: {data}" + assert ( + start == 1 + ), f"wrong start frame for token 1, expected: 1, found: {start}" + assert end == 16, f"wrong end frame for token 1, expected: 16, found: {end}" - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AaaaAaAaaAaAaaaa", - msg=( - "wrong data for token 1, expected: 'AaaaAaAaaAaAaaaa', " - "found: {0} " - ).format(data), - ) - self.assertEqual( - start, - 1, - msg=( - "wrong start frame for token 1, expected: 1, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 16, - msg=( - "wrong end frame for token 1, expected: 16, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAAAAA" + ), f"wrong data for token 2, expected: 'AAAAAAAA', found: {data}" + assert ( + start == 20 + ), f"wrong start frame for token 2, expected: 20, found: {start}" + assert end == 27, f"wrong end frame for token 2, expected: 27, found: {end}" - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAAAA', found: {0} " - ).format(data), - ) - self.assertEqual( - start, - 20, - msg=( - "wrong start frame for token 2, expected: 20, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 27, - msg=( - "wrong end frame for token 2, expected: 27, found: {0} " - ).format(end), - ) - # A valid token is considered as so iff the tokenizer encounters - # at least valid frames (init_min = 3) between witch there - # are at most 0 consecutive non valid frames (init_max_silence = 0) - # The tokenizer will only rely on the other parameters - # In other words, a valid token must start with 3 valid frames - def test_init_min_3_init_max_silence_0(self): +def test_init_min_3_init_max_silence_0(validator): + tokenizer = StreamTokenizer( + validator, + min_length=5, + max_length=20, + max_continuous_silence=4, + init_min=3, + init_max_silence=0, + mode=0, + ) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=20, - max_continuous_silence=4, - init_min=3, - init_max_silence=0, - mode=0, - ) + data_source = StringDataSource("aAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaaAAAAA") + # ^ ^ ^ ^ + # 18 30 33 37 - data_source = StringDataSource( - "aAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaaAAAAA" - ) - # ^ ^ ^ ^ - # 18 30 33 37 + tokens = tokenizer.tokenize(data_source) - tokens = tokenizer.tokenize(data_source) + assert ( + len(tokens) == 2 + ), f"wrong number of tokens, expected: 2, found: {len(tokens)}" + tok1, tok2 = tokens[0], tokens[1] - self.assertEqual( - len(tokens), - 2, - msg="wrong number of tokens, expected: 2, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2 = tokens[0], tokens[1] + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAAAAAAaaaa" + ), f"wrong data for token 1, expected: 'AAAAAAAAAaaaa', found: '{data}'" + assert ( + start == 18 + ), f"wrong start frame for token 1, expected: 18, found: {start}" + assert end == 30, f"wrong end frame for token 1, expected: 30, found: {end}" - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAAAAAAaaaa", - msg=( - "wrong data for token 1, expected: 'AAAAAAAAAaaaa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 18, - msg=( - "wrong start frame for token 1, expected: 18, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 30, - msg=( - "wrong end frame for token 1, expected: 30, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAA" + ), f"wrong data for token 2, expected: 'AAAAA', found: '{data}'" + assert ( + start == 33 + ), f"wrong start frame for token 2, expected: 33, found: {start}" + assert end == 37, f"wrong end frame for token 2, expected: 37, found: {end}" - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 33, - msg=( - "wrong start frame for token 2, expected: 33, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 37, - msg=( - "wrong end frame for token 2, expected: 37, found: {0} " - ).format(end), - ) - # A valid token is considered iff the tokenizer encounters - # at least valid frames (init_min = 3) between witch there - # are at most 2 consecutive non valid frames (init_max_silence = 2) - def test_init_min_3_init_max_silence_2(self): +def test_init_min_3_init_max_silence_2(validator): + tokenizer = StreamTokenizer( + validator, + min_length=5, + max_length=20, + max_continuous_silence=4, + init_min=3, + init_max_silence=2, + mode=0, + ) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=20, - max_continuous_silence=4, - init_min=3, - init_max_silence=2, - mode=0, - ) + data_source = StringDataSource("aAaaaAaAaaAaAaaaaaaAAAAAAAAAaaaaaaaAAAAA") + # ^ ^ ^ ^ ^ ^ + # 5 16 19 31 35 39 + tokens = tokenizer.tokenize(data_source) - data_source = StringDataSource( - "aAaaaAaAaaAaAaaaaaaAAAAAAAAAaaaaaaaAAAAA" - ) - # ^ ^ ^ ^ ^ ^ - # 5 16 19 31 35 39 - tokens = tokenizer.tokenize(data_source) + assert ( + len(tokens) == 3 + ), f"wrong number of tokens, expected: 3, found: {len(tokens)}" + tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] - self.assertEqual( - len(tokens), - 3, - msg="wrong number of tokens, expected: 3, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AaAaaAaAaaaa" + ), f"wrong data for token 1, expected: 'AaAaaAaA', found: '{data}'" + assert ( + start == 5 + ), f"wrong start frame for token 1, expected: 5, found: {start}" + assert end == 16, f"wrong end frame for token 1, expected: 16, found: {end}" - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AaAaaAaAaaaa", - msg=( - "wrong data for token 1, expected: 'AaAaaAaA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 5, - msg=( - "wrong start frame for token 1, expected: 5, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 16, - msg=( - "wrong end frame for token 1, expected: 16, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAAAAAAaaaa" + ), f"wrong data for token 2, expected: 'AAAAAAAAAaaaa', found: '{data}'" + assert ( + start == 19 + ), f"wrong start frame for token 2, expected: 19, found: {start}" + assert end == 31, f"wrong end frame for token 2, expected: 31, found: {end}" - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAAAAAAaaaa", - msg=( - "wrong data for token 2, expected: 'AAAAAAAAAaaaa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 19, - msg=( - "wrong start frame for token 2, expected: 19, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 31, - msg=( - "wrong end frame for token 2, expected: 31, found: {0} " - ).format(end), - ) + data = "".join(tok3[0]) + start = tok3[1] + end = tok3[2] + assert ( + data == "AAAAA" + ), f"wrong data for token 3, expected: 'AAAAA', found: '{data}'" + assert ( + start == 35 + ), f"wrong start frame for token 3, expected: 35, found: {start}" + assert end == 39, f"wrong end frame for token 3, expected: 39, found: {end}" - data = "".join(tok3[0]) - start = tok3[1] - end = tok3[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 3, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 35, - msg=( - "wrong start frame for token 2, expected: 35, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 39, - msg=( - "wrong end frame for token 2, expected: 39, found: {0} " - ).format(end), - ) +@pytest.fixture +def tokenizer_min_max_length(validator): + return StreamTokenizer( + validator, + min_length=6, + max_length=20, + max_continuous_silence=2, + init_min=3, + init_max_silence=3, + mode=0, + ) -class TestStreamTokenizerMinMaxLength(unittest.TestCase): - def setUp(self): - self.A_validator = AValidator() - def test_min_length_6_init_max_length_20(self): +def test_min_length_6_init_max_length_20(tokenizer_min_max_length): + data_source = StringDataSource("aAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaAAAAA") + # ^ ^ ^ ^ + # 1 14 18 28 - tokenizer = StreamTokenizer( - self.A_validator, - min_length=6, - max_length=20, - max_continuous_silence=2, - init_min=3, - init_max_silence=3, - mode=0, - ) + tokens = tokenizer_min_max_length.tokenize(data_source) - data_source = StringDataSource("aAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaAAAAA") - # ^ ^ ^ ^ - # 1 14 18 28 + assert ( + len(tokens) == 2 + ), f"wrong number of tokens, expected: 2, found: {len(tokens)}" + tok1, tok2 = tokens[0], tokens[1] - tokens = tokenizer.tokenize(data_source) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AaaaAaAaaAaAaa" + ), f"wrong data for token 1, expected: 'AaaaAaAaaAaAaa', found: '{data}'" + assert ( + start == 1 + ), f"wrong start frame for token 1, expected: 1, found: {start}" + assert end == 14, f"wrong end frame for token 1, expected: 14, found: {end}" - self.assertEqual( - len(tokens), - 2, - msg="wrong number of tokens, expected: 2, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2 = tokens[0], tokens[1] + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAAAAAAaa" + ), f"wrong data for token 2, expected: 'AAAAAAAAAaa', found: '{data}'" + assert ( + start == 18 + ), f"wrong start frame for token 2, expected: 18, found: {start}" + assert end == 28, f"wrong end frame for token 2, expected: 28, found: {end}" - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AaaaAaAaaAaAaa", - msg=( - "wrong data for token 1, expected: 'AaaaAaAaaAaAaa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 1, - msg=( - "wrong start frame for token 1, expected: 1, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 14, - msg=( - "wrong end frame for token 1, expected: 14, found: {0} " - ).format(end), - ) - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAAAAAAaa", - msg=( - "wrong data for token 2, expected: 'AAAAAAAAAaa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 18, - msg=( - "wrong start frame for token 2, expected: 18, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 28, - msg=( - "wrong end frame for token 2, expected: 28, found: {0} " - ).format(end), - ) +@pytest.fixture +def tokenizer_min_max_length_1_1(validator): + return StreamTokenizer( + validator, + min_length=1, + max_length=1, + max_continuous_silence=0, + init_min=0, + init_max_silence=0, + mode=0, + ) - def test_min_length_1_init_max_length_1(self): - tokenizer = StreamTokenizer( - self.A_validator, - min_length=1, - max_length=1, - max_continuous_silence=0, - init_min=0, - init_max_silence=0, - mode=0, - ) +def test_min_length_1_init_max_length_1(tokenizer_min_max_length_1_1): + data_source = StringDataSource("AAaaaAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaAAAAA") - data_source = StringDataSource( - "AAaaaAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaAAAAA" - ) + tokens = tokenizer_min_max_length_1_1.tokenize(data_source) - tokens = tokenizer.tokenize(data_source) + assert ( + len(tokens) == 21 + ), f"wrong number of tokens, expected: 21, found: {len(tokens)}" - self.assertEqual( - len(tokens), - 21, - msg="wrong number of tokens, expected: 21, found: {0} ".format( - len(tokens) - ), - ) - def test_min_length_10_init_max_length_20(self): +@pytest.fixture +def tokenizer_min_max_length_10_20(validator): + return StreamTokenizer( + validator, + min_length=10, + max_length=20, + max_continuous_silence=4, + init_min=3, + init_max_silence=3, + mode=0, + ) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=10, - max_length=20, - max_continuous_silence=4, - init_min=3, - init_max_silence=3, - mode=0, - ) - data_source = StringDataSource( - "aAaaaAaAaaAaAaaaaaaAAAAAaaaaaaAAAAAaaAAaaAAA" - ) - # ^ ^ ^ ^ - # 1 16 30 45 +def test_min_length_10_init_max_length_20(tokenizer_min_max_length_10_20): + data_source = StringDataSource( + "aAaaaAaAaaAaAaaaaaaAAAAAaaaaaaAAAAAaaAAaaAAA" + ) + # ^ ^ ^ ^ + # 1 16 30 45 - tokens = tokenizer.tokenize(data_source) + tokens = tokenizer_min_max_length_10_20.tokenize(data_source) - self.assertEqual( - len(tokens), - 2, - msg="wrong number of tokens, expected: 2, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2 = tokens[0], tokens[1] + assert ( + len(tokens) == 2 + ), f"wrong number of tokens, expected: 2, found: {len(tokens)}" + tok1, tok2 = tokens[0], tokens[1] - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AaaaAaAaaAaAaaaa", - msg=( - "wrong data for token 1, expected: 'AaaaAaAaaAaAaaaa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 1, - msg=( - "wrong start frame for token 1, expected: 1, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 16, - msg=( - "wrong end frame for token 1, expected: 16, found: {0} " - ).format(end), - ) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AaaaAaAaaAaAaaaa" + ), f"wrong data for token 1, expected: 'AaaaAaAaaAaAaaaa', found: '{data}'" + assert ( + start == 1 + ), f"wrong start frame for token 1, expected: 1, found: {start}" + assert end == 16, f"wrong end frame for token 1, expected: 16, found: {end}" - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAAaaAAaaAAA", - msg=( - "wrong data for token 2, expected: 'AAAAAaaAAaaAAA', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 30, - msg=( - "wrong start frame for token 2, expected: 30, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 43, - msg=( - "wrong end frame for token 2, expected: 43, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAAaaAAaaAAA" + ), f"wrong data for token 2, expected: 'AAAAAaaAAaaAAA', found: '{data}'" + assert ( + start == 30 + ), f"wrong start frame for token 2, expected: 30, found: {start}" + assert end == 43, f"wrong end frame for token 2, expected: 43, found: {end}" - def test_min_length_4_init_max_length_5(self): - tokenizer = StreamTokenizer( - self.A_validator, - min_length=4, - max_length=5, - max_continuous_silence=4, - init_min=3, - init_max_silence=3, - mode=0, - ) +@pytest.fixture +def tokenizer_min_max_length_4_5(validator): + return StreamTokenizer( + validator, + min_length=4, + max_length=5, + max_continuous_silence=4, + init_min=3, + init_max_silence=3, + mode=0, + ) - data_source = StringDataSource( - "aAaaaAaAaaAaAaaaaaAAAAAAAAaaaaaaAAAAAaaaaaAAaaAaa" - ) - # ^ ^^ ^ ^ ^ ^ ^ - # 18 2223 27 32 36 42 46 - tokens = tokenizer.tokenize(data_source) +def test_min_length_4_init_max_length_5(tokenizer_min_max_length_4_5): + data_source = StringDataSource( + "aAaaaAaAaaAaAaaaaaAAAAAAAAaaaaaaAAAAAaaaaaAAaaAaa" + ) + # ^ ^^ ^ ^ ^ ^ ^ + # 18 2223 27 32 36 42 46 - self.assertEqual( - len(tokens), - 4, - msg="wrong number of tokens, expected: 4, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2, tok3, tok4 = tokens[0], tokens[1], tokens[2], tokens[3] + tokens = tokenizer_min_max_length_4_5.tokenize(data_source) - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 18, - msg=( - "wrong start frame for token 1, expected: 18, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 22, - msg=( - "wrong end frame for token 1, expected: 22, found: {0} " - ).format(end), - ) + assert ( + len(tokens) == 4 + ), f"wrong number of tokens, expected: 4, found: {len(tokens)}" + tok1, tok2, tok3, tok4 = tokens[0], tokens[1], tokens[2], tokens[3] - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAaa", - msg=( - "wrong data for token 1, expected: 'AAAaa', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 23, - msg=( - "wrong start frame for token 1, expected: 23, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 27, - msg=( - "wrong end frame for token 1, expected: 27, found: {0} " - ).format(end), - ) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAA" + ), f"wrong data for token 1, expected: 'AAAAA', found: '{data}'" + assert ( + start == 18 + ), f"wrong start frame for token 1, expected: 18, found: {start}" + assert end == 22, f"wrong end frame for token 1, expected: 22, found: {end}" - data = "".join(tok3[0]) - start = tok3[1] - end = tok3[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 32, - msg=( - "wrong start frame for token 1, expected: 1, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 36, - msg=( - "wrong end frame for token 1, expected: 7, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAaa" + ), f"wrong data for token 2, expected: 'AAAaa', found: '{data}'" + assert ( + start == 23 + ), f"wrong start frame for token 2, expected: 23, found: {start}" + assert end == 27, f"wrong end frame for token 2, expected: 27, found: {end}" - data = "".join(tok4[0]) - start = tok4[1] - end = tok4[2] - self.assertEqual( - data, - "AAaaA", - msg=( - "wrong data for token 2, expected: 'AAaaA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 42, - msg=( - "wrong start frame for token 2, expected: 17, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 46, - msg=( - "wrong end frame for token 2, expected: 22, found: {0} " - ).format(end), - ) + data = "".join(tok3[0]) + start = tok3[1] + end = tok3[2] + assert ( + data == "AAAAA" + ), f"wrong data for token 3, expected: 'AAAAA', found: '{data}'" + assert ( + start == 32 + ), f"wrong start frame for token 3, expected: 32, found: {start}" + assert end == 36, f"wrong end frame for token 3, expected: 36, found: {end}" + data = "".join(tok4[0]) + start = tok4[1] + end = tok4[2] + assert ( + data == "AAaaA" + ), f"wrong data for token 4, expected: 'AAaaA', found: '{data}'" + assert ( + start == 42 + ), f"wrong start frame for token 4, expected: 42, found: {start}" + assert end == 46, f"wrong end frame for token 4, expected: 46, found: {end}" -class TestStreamTokenizerMaxContinuousSilence(unittest.TestCase): - def setUp(self): - self.A_validator = AValidator() - def test_min_5_max_10_max_continuous_silence_0(self): +@pytest.fixture +def tokenizer_max_continuous_silence_0(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=10, + max_continuous_silence=0, + init_min=3, + init_max_silence=3, + mode=0, + ) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=10, - max_continuous_silence=0, - init_min=3, - init_max_silence=3, - mode=0, - ) - data_source = StringDataSource("aaaAAAAAaAAAAAAaaAAAAAAAAAa") - # ^ ^ ^ ^ ^ ^ - # 3 7 9 14 17 25 +def test_min_5_max_10_max_continuous_silence_0( + tokenizer_max_continuous_silence_0, +): + data_source = StringDataSource("aaaAAAAAaAAAAAAaaAAAAAAAAAa") + # ^ ^ ^ ^ ^ ^ + # 3 7 9 14 17 25 - tokens = tokenizer.tokenize(data_source) + tokens = tokenizer_max_continuous_silence_0.tokenize(data_source) - self.assertEqual( - len(tokens), - 3, - msg="wrong number of tokens, expected: 3, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] + assert ( + len(tokens) == 3 + ), f"wrong number of tokens, expected: 3, found: {len(tokens)}" + tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 3, - msg=( - "wrong start frame for token 1, expected: 3, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 7, - msg=( - "wrong end frame for token 1, expected: 7, found: {0} " - ).format(end), - ) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAA" + ), f"wrong data for token 1, expected: 'AAAAA', found: '{data}'" + assert ( + start == 3 + ), f"wrong start frame for token 1, expected: 3, found: {start}" + assert end == 7, f"wrong end frame for token 1, expected: 7, found: {end}" - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 9, - msg=( - "wrong start frame for token 1, expected: 9, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 14, - msg=( - "wrong end frame for token 1, expected: 14, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAAA" + ), f"wrong data for token 2, expected: 'AAAAAA', found: '{data}'" + assert ( + start == 9 + ), f"wrong start frame for token 2, expected: 9, found: {start}" + assert end == 14, f"wrong end frame for token 2, expected: 14, found: {end}" - data = "".join(tok3[0]) - start = tok3[1] - end = tok3[2] - self.assertEqual( - data, - "AAAAAAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 17, - msg=( - "wrong start frame for token 1, expected: 17, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 25, - msg=( - "wrong end frame for token 1, expected: 25, found: {0} " - ).format(end), - ) + data = "".join(tok3[0]) + start = tok3[1] + end = tok3[2] + assert ( + data == "AAAAAAAAA" + ), f"wrong data for token 3, expected: 'AAAAAAAAA', found: '{data}'" + assert ( + start == 17 + ), f"wrong start frame for token 3, expected: 17, found: {start}" + assert end == 25, f"wrong end frame for token 3, expected: 25, found: {end}" - def test_min_5_max_10_max_continuous_silence_1(self): - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=10, - max_continuous_silence=1, - init_min=3, - init_max_silence=3, - mode=0, - ) +@pytest.fixture +def tokenizer_max_continuous_silence_1(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=10, + max_continuous_silence=1, + init_min=3, + init_max_silence=3, + mode=0, + ) - data_source = StringDataSource("aaaAAAAAaAAAAAAaaAAAAAAAAAa") - # ^ ^^ ^ ^ ^ - # 3 12131517 26 - # (12 13 15 17) - tokens = tokenizer.tokenize(data_source) +def test_min_5_max_10_max_continuous_silence_1( + tokenizer_max_continuous_silence_1, +): + data_source = StringDataSource("aaaAAAAAaAAAAAAaaAAAAAAAAAa") + # ^ ^^ ^ ^ ^ + # 3 12131517 26 + # (12 13 15 17) - self.assertEqual( - len(tokens), - 3, - msg="wrong number of tokens, expected: 3, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] + tokens = tokenizer_max_continuous_silence_1.tokenize(data_source) - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAAaAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAaAAAA', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 3, - msg=( - "wrong start frame for token 1, expected: 3, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 12, - msg=( - "wrong end frame for token 1, expected: 10, found: {0} " - ).format(end), - ) + assert ( + len(tokens) == 3 + ), f"wrong number of tokens, expected: 3, found: {len(tokens)}" + tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAa", - msg=( - "wrong data for token 1, expected: 'AAa', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 13, - msg=( - "wrong start frame for token 1, expected: 9, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 15, - msg=( - "wrong end frame for token 1, expected: 14, found: {0} " - ).format(end), - ) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAAaAAAA" + ), f"wrong data for token 1, expected: 'AAAAAaAAAA', found: '{data}'" + assert ( + start == 3 + ), f"wrong start frame for token 1, expected: 3, found: {start}" + assert end == 12, f"wrong end frame for token 1, expected: 12, found: {end}" - data = "".join(tok3[0]) - start = tok3[1] - end = tok3[2] - self.assertEqual( - data, - "AAAAAAAAAa", - msg=( - "wrong data for token 1, expected: 'AAAAAAAAAa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 17, - msg=( - "wrong start frame for token 1, expected: 17, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 26, - msg=( - "wrong end frame for token 1, expected: 26, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAa" + ), f"wrong data for token 2, expected: 'AAa', found: '{data}'" + assert ( + start == 13 + ), f"wrong start frame for token 2, expected: 13, found: {start}" + assert end == 15, f"wrong end frame for token 2, expected: 15, found: {end}" + data = "".join(tok3[0]) + start = tok3[1] + end = tok3[2] + assert ( + data == "AAAAAAAAAa" + ), f"wrong data for token 3, expected: 'AAAAAAAAAa', found: '{data}'" + assert ( + start == 17 + ), f"wrong start frame for token 3, expected: 17, found: {start}" + assert end == 26, f"wrong end frame for token 3, expected: 26, found: {end}" -class TestStreamTokenizerModes(unittest.TestCase): - def setUp(self): - self.A_validator = AValidator() - def test_STRICT_MIN_LENGTH(self): +@pytest.fixture +def tokenizer_strict_min_length(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=8, + max_continuous_silence=3, + init_min=3, + init_max_silence=3, + mode=StreamTokenizer.STRICT_MIN_LENGTH, + ) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=8, - max_continuous_silence=3, - init_min=3, - init_max_silence=3, - mode=StreamTokenizer.STRICT_MIN_LENGTH, - ) - data_source = StringDataSource("aaAAAAAAAAAAAA") - # ^ ^ - # 2 9 +def test_STRICT_MIN_LENGTH(tokenizer_strict_min_length): + data_source = StringDataSource("aaAAAAAAAAAAAA") + # ^ ^ + # 2 9 - tokens = tokenizer.tokenize(data_source) + tokens = tokenizer_strict_min_length.tokenize(data_source) - self.assertEqual( - len(tokens), - 1, - msg="wrong number of tokens, expected: 1, found: {0} ".format( - len(tokens) - ), - ) - tok1 = tokens[0] + assert ( + len(tokens) == 1 + ), f"wrong number of tokens, expected: 1, found: {len(tokens)}" + tok1 = tokens[0] - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 2, - msg=( - "wrong start frame for token 1, expected: 2, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 9, - msg=( - "wrong end frame for token 1, expected: 9, found: {0} " - ).format(end), - ) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAAAAA" + ), f"wrong data for token 1, expected: 'AAAAAAAA', found: '{data}'" + assert ( + start == 2 + ), f"wrong start frame for token 1, expected: 2, found: {start}" + assert end == 9, f"wrong end frame for token 1, expected: 9, found: {end}" - def test_DROP_TAILING_SILENCE(self): - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=10, - max_continuous_silence=2, - init_min=3, - init_max_silence=3, - mode=StreamTokenizer.DROP_TRAILING_SILENCE, - ) +@pytest.fixture +def tokenizer_drop_trailing_silence(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=10, + max_continuous_silence=2, + init_min=3, + init_max_silence=3, + mode=StreamTokenizer.DROP_TRAILING_SILENCE, + ) - data_source = StringDataSource("aaAAAAAaaaaa") - # ^ ^ - # 2 6 - tokens = tokenizer.tokenize(data_source) +def test_DROP_TAILING_SILENCE(tokenizer_drop_trailing_silence): + data_source = StringDataSource("aaAAAAAaaaaa") + # ^ ^ + # 2 6 - self.assertEqual( - len(tokens), - 1, - msg="wrong number of tokens, expected: 1, found: {0} ".format( - len(tokens) - ), - ) - tok1 = tokens[0] + tokens = tokenizer_drop_trailing_silence.tokenize(data_source) - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 2, - msg=( - "wrong start frame for token 1, expected: 2, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 6, - msg=( - "wrong end frame for token 1, expected: 6, found: {0} " - ).format(end), - ) + assert ( + len(tokens) == 1 + ), f"wrong number of tokens, expected: 1, found: {len(tokens)}" + tok1 = tokens[0] - def test_STRICT_MIN_LENGTH_and_DROP_TAILING_SILENCE(self): + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAA" + ), f"wrong data for token 1, expected: 'AAAAA', found: '{data}'" + assert ( + start == 2 + ), f"wrong start frame for token 1, expected: 2, found: {start}" + assert end == 6, f"wrong end frame for token 1, expected: 6, found: {end}" - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=8, - max_continuous_silence=3, - init_min=3, - init_max_silence=3, - mode=StreamTokenizer.STRICT_MIN_LENGTH - | StreamTokenizer.DROP_TRAILING_SILENCE, - ) - data_source = StringDataSource("aaAAAAAAAAAAAAaa") - # ^ ^ - # 2 8 +@pytest.fixture +def tokenizer_strict_min_and_drop_trailing_silence(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=8, + max_continuous_silence=3, + init_min=3, + init_max_silence=3, + mode=StreamTokenizer.STRICT_MIN_LENGTH + | StreamTokenizer.DROP_TRAILING_SILENCE, + ) - tokens = tokenizer.tokenize(data_source) - self.assertEqual( - len(tokens), - 1, - msg="wrong number of tokens, expected: 1, found: {0} ".format( - len(tokens) - ), - ) - tok1 = tokens[0] +def test_STRICT_MIN_LENGTH_and_DROP_TAILING_SILENCE( + tokenizer_strict_min_and_drop_trailing_silence, +): + data_source = StringDataSource("aaAAAAAAAAAAAAaa") + # ^ ^ + # 2 8 - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 2, - msg=( - "wrong start frame for token 1, expected: 2, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 9, - msg=( - "wrong end frame for token 1, expected: 9, found: {0} " - ).format(end), - ) + tokens = tokenizer_strict_min_and_drop_trailing_silence.tokenize( + data_source + ) + assert ( + len(tokens) == 1 + ), f"wrong number of tokens, expected: 1, found: {len(tokens)}" + tok1 = tokens[0] -class TestStreamTokenizerCallback(unittest.TestCase): - def setUp(self): - self.A_validator = AValidator() + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAAAAA" + ), f"wrong data for token 1, expected: 'AAAAAAAA', found: '{data}'" + assert ( + start == 2 + ), f"wrong start frame for token 1, expected: 2, found: {start}" + assert end == 9, f"wrong end frame for token 1, expected: 9, found: {end}" - def test_callback(self): - tokens = [] +@pytest.fixture +def tokenizer_callback(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=8, + max_continuous_silence=3, + init_min=3, + init_max_silence=3, + mode=0, + ) - def callback(data, start, end): - tokens.append((data, start, end)) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=8, - max_continuous_silence=3, - init_min=3, - init_max_silence=3, - mode=0, - ) +def test_callback(tokenizer_callback): + tokens = [] - data_source = StringDataSource("aaAAAAAAAAAAAAa") - # ^ ^^ ^ - # 2 910 14 + def callback(data, start, end): + tokens.append((data, start, end)) - tokenizer.tokenize(data_source, callback=callback) + data_source = StringDataSource("aaAAAAAAAAAAAAa") + # ^ ^^ ^ + # 2 910 14 - self.assertEqual( - len(tokens), - 2, - msg="wrong number of tokens, expected: 1, found: {0} ".format( - len(tokens) - ), - ) + tokenizer_callback.tokenize(data_source, callback=callback) - -if __name__ == "__main__": - unittest.main() + assert ( + len(tokens) == 2 + ), f"wrong number of tokens, expected: 2, found: {len(tokens)}"
--- a/tests/test_cmdline_util.py Fri May 24 21:30:34 2024 +0200 +++ b/tests/test_cmdline_util.py Sat May 25 21:54:13 2024 +0200 @@ -1,10 +1,8 @@ import os -import unittest -from unittest import TestCase -from unittest.mock import patch +import pytest from tempfile import TemporaryDirectory from collections import namedtuple -from genty import genty, genty_dataset +from unittest.mock import patch from auditok.cmdline_util import ( _AUDITOK_LOGGER, @@ -21,8 +19,8 @@ PrintWorker, ) -_ArgsNamespece = namedtuple( - "_ArgsNamespece", +_ArgsNamespace = namedtuple( + "_ArgsNamespace", [ "input", "max_read", @@ -57,285 +55,290 @@ ) -@genty -class TestCmdLineUtil(TestCase): - @genty_dataset( - no_record=("stream.ogg", None, False, None, "mix", "mix", False), - no_record_plot=("stream.ogg", None, True, None, None, None, False), - no_record_save_image=( - "stream.ogg", - None, - True, - "image.png", - None, - None, - False, - ), - record_plot=(None, None, True, None, None, None, True), - record_save_image=(None, None, False, "image.png", None, None, True), - int_use_channel=("stream.ogg", None, False, None, "1", 1, False), - save_detections_as=( - "stream.ogg", - "{id}.wav", - False, - None, - None, - None, - False, - ), - ) - def test_make_kwargs( - self, +@pytest.mark.parametrize( + "save_stream, save_detections_as, plot, save_image, use_channel, exp_use_channel, exp_record", + [ + # no_record + ("stream.ogg", None, False, None, "mix", "mix", False), + # no_record_plot + ("stream.ogg", None, True, None, None, None, False), + # no_record_save_image + ("stream.ogg", None, True, "image.png", None, None, False), + # record_plot + (None, None, True, None, None, None, True), + # record_save_image + (None, None, False, "image.png", None, None, True), + # int_use_channel + ("stream.ogg", None, False, None, "1", 1, False), + # save_detections_as + ("stream.ogg", "{id}.wav", False, None, None, None, False), + ], + ids=[ + "no_record", + "no_record_plot", + "no_record_save_image", + "record_plot", + "record_save_image", + "int_use_channel", + "save_detections_as", + ], +) +def test_make_kwargs( + save_stream, + save_detections_as, + plot, + save_image, + use_channel, + exp_use_channel, + exp_record, +): + args = ( + "file", + 30, + 0.01, + 16000, + 2, + 2, + use_channel, + "raw", + "ogg", + True, + None, + 1, save_stream, save_detections_as, plot, save_image, - use_channel, - exp_use_channel, - exp_record, - ): + 0.2, + 10, + 0.3, + False, + False, + 55, + ) + misc = ( + False, + False, + None, + True, + None, + "TIME_FORMAT", + "TIMESTAMP_FORMAT", + ) + args_ns = _ArgsNamespace(*(args + misc)) - args = ( - "file", - 30, - 0.01, - 16000, - 2, - 2, - use_channel, - "raw", - "ogg", - True, - None, - 1, - save_stream, - save_detections_as, - plot, - save_image, - 0.2, - 10, - 0.3, - False, - False, - 55, - ) - misc = ( - False, - False, - None, - True, - None, - "TIME_FORMAT", - "TIMESTAMP_FORMAT", - ) - args_ns = _ArgsNamespece(*(args + misc)) + io_kwargs = { + "input": "file", + "max_read": 30, + "block_dur": 0.01, + "sampling_rate": 16000, + "sample_width": 2, + "channels": 2, + "use_channel": exp_use_channel, + "save_stream": save_stream, + "save_detections_as": save_detections_as, + "audio_format": "raw", + "export_format": "ogg", + "large_file": True, + "frames_per_buffer": None, + "input_device_index": 1, + "record": exp_record, + } - io_kwargs = { - "input": "file", - "max_read": 30, - "block_dur": 0.01, - "sampling_rate": 16000, - "sample_width": 2, - "channels": 2, - "use_channel": exp_use_channel, - "save_stream": save_stream, - "save_detections_as": save_detections_as, - "audio_format": "raw", - "export_format": "ogg", - "large_file": True, - "frames_per_buffer": None, - "input_device_index": 1, - "record": exp_record, - } + split_kwargs = { + "min_dur": 0.2, + "max_dur": 10, + "max_silence": 0.3, + "drop_trailing_silence": False, + "strict_min_dur": False, + "energy_threshold": 55, + } - split_kwargs = { - "min_dur": 0.2, - "max_dur": 10, - "max_silence": 0.3, - "drop_trailing_silence": False, - "strict_min_dur": False, - "energy_threshold": 55, - } + miscellaneous = { + "echo": False, + "command": None, + "progress_bar": False, + "quiet": True, + "printf": None, + "time_format": "TIME_FORMAT", + "timestamp_format": "TIMESTAMP_FORMAT", + } - miscellaneous = { - "echo": False, - "command": None, - "progress_bar": False, - "quiet": True, - "printf": None, - "time_format": "TIME_FORMAT", - "timestamp_format": "TIMESTAMP_FORMAT", - } + expected = KeywordArguments(io_kwargs, split_kwargs, miscellaneous) + kwargs = make_kwargs(args_ns) + assert kwargs == expected - expected = KeywordArguments(io_kwargs, split_kwargs, miscellaneous) - kwargs = make_kwargs(args_ns) - self.assertEqual(kwargs, expected) - def test_make_logger_stderr_and_file(self): +def test_make_logger_stderr_and_file(capsys): + with TemporaryDirectory() as tmpdir: + file = os.path.join(tmpdir, "file.log") + logger = make_logger(stderr=True, file=file) + assert logger.name == _AUDITOK_LOGGER + assert len(logger.handlers) == 2 + assert logger.handlers[1].stream.name == file + logger.info("This is a debug message") + captured = capsys.readouterr() + assert "This is a debug message" in captured.err + + +def test_make_logger_None(): + logger = make_logger(stderr=False, file=None) + assert logger is None + + +def test_initialize_workers_all(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: with TemporaryDirectory() as tmpdir: - file = os.path.join(tmpdir, "file.log") - logger = make_logger(stderr=True, file=file) - self.assertEqual(logger.name, _AUDITOK_LOGGER) - self.assertEqual(len(logger.handlers), 2) - self.assertEqual(logger.handlers[0].stream.name, "<stderr>") - self.assertEqual(logger.handlers[1].stream.name, file) - - def test_make_logger_None(self): - logger = make_logger(stderr=False, file=None) - self.assertIsNone(logger) - - def test_initialize_workers_all(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: - with TemporaryDirectory() as tmpdir: - export_filename = os.path.join(tmpdir, "output.wav") - reader, observers = initialize_workers( - input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=export_filename, - export_format="wave", - save_detections_as="{id}.wav", - echo=True, - progress_bar=False, - command="some command", - quiet=False, - printf="abcd", - time_format="%S", - timestamp_format="%h:%M:%S", - ) - reader.stop() - self.assertTrue(patched_player_for.called) - self.assertIsInstance(reader, StreamSaverWorker) - for obs, cls in zip( - observers, - [ - RegionSaverWorker, - PlayerWorker, - CommandLineWorker, - PrintWorker, - ], - ): - self.assertIsInstance(obs, cls) - - def test_initialize_workers_no_RegionSaverWorker(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: - with TemporaryDirectory() as tmpdir: - export_filename = os.path.join(tmpdir, "output.wav") - reader, observers = initialize_workers( - input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=export_filename, - export_format="wave", - save_detections_as=None, - echo=True, - progress_bar=False, - command="some command", - quiet=False, - printf="abcd", - time_format="%S", - timestamp_format="%h:%M:%S", - ) - reader.stop() - self.assertTrue(patched_player_for.called) - self.assertIsInstance(reader, StreamSaverWorker) - for obs, cls in zip( - observers, [PlayerWorker, CommandLineWorker, PrintWorker] - ): - self.assertIsInstance(obs, cls) - - def test_initialize_workers_no_PlayerWorker(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: - with TemporaryDirectory() as tmpdir: - export_filename = os.path.join(tmpdir, "output.wav") - reader, observers = initialize_workers( - input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=export_filename, - export_format="wave", - save_detections_as="{id}.wav", - echo=False, - progress_bar=False, - command="some command", - quiet=False, - printf="abcd", - time_format="%S", - timestamp_format="%h:%M:%S", - ) - reader.stop() - self.assertFalse(patched_player_for.called) - self.assertIsInstance(reader, StreamSaverWorker) - for obs, cls in zip( - observers, - [RegionSaverWorker, CommandLineWorker, PrintWorker], - ): - self.assertIsInstance(obs, cls) - - def test_initialize_workers_no_CommandLineWorker(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: - with TemporaryDirectory() as tmpdir: - export_filename = os.path.join(tmpdir, "output.wav") - reader, observers = initialize_workers( - input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=export_filename, - export_format="wave", - save_detections_as="{id}.wav", - echo=True, - progress_bar=False, - command=None, - quiet=False, - printf="abcd", - time_format="%S", - timestamp_format="%h:%M:%S", - ) - reader.stop() - self.assertTrue(patched_player_for.called) - self.assertIsInstance(reader, StreamSaverWorker) - for obs, cls in zip( - observers, [RegionSaverWorker, PlayerWorker, PrintWorker] - ): - self.assertIsInstance(obs, cls) - - def test_initialize_workers_no_PrintWorker(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: - with TemporaryDirectory() as tmpdir: - export_filename = os.path.join(tmpdir, "output.wav") - reader, observers = initialize_workers( - input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=export_filename, - export_format="wave", - save_detections_as="{id}.wav", - echo=True, - progress_bar=False, - command="some command", - quiet=True, - printf="abcd", - time_format="%S", - timestamp_format="%h:%M:%S", - ) - reader.stop() - self.assertTrue(patched_player_for.called) - self.assertIsInstance(reader, StreamSaverWorker) - for obs, cls in zip( - observers, - [RegionSaverWorker, PlayerWorker, CommandLineWorker], - ): - self.assertIsInstance(obs, cls) - - def test_initialize_workers_no_observers(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: + export_filename = os.path.join(tmpdir, "output.wav") reader, observers = initialize_workers( input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=None, + save_stream=export_filename, + export_format="wave", + save_detections_as="{id}.wav", + echo=True, + progress_bar=False, + command="some command", + quiet=False, + printf="abcd", + time_format="%S", + timestamp_format="%h:%M:%S", + ) + reader.stop() + assert patched_player_for.called + assert isinstance(reader, StreamSaverWorker) + for obs, cls in zip( + observers, + [ + RegionSaverWorker, + PlayerWorker, + CommandLineWorker, + PrintWorker, + ], + ): + assert isinstance(obs, cls) + + +def test_initialize_workers_no_RegionSaverWorker(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: + with TemporaryDirectory() as tmpdir: + export_filename = os.path.join(tmpdir, "output.wav") + reader, observers = initialize_workers( + input="tests/data/test_16KHZ_mono_400Hz.wav", + save_stream=export_filename, export_format="wave", save_detections_as=None, echo=True, progress_bar=False, + command="some command", + quiet=False, + printf="abcd", + time_format="%S", + timestamp_format="%h:%M:%S", + ) + reader.stop() + assert patched_player_for.called + assert isinstance(reader, StreamSaverWorker) + for obs, cls in zip( + observers, [PlayerWorker, CommandLineWorker, PrintWorker] + ): + assert isinstance(obs, cls) + + +def test_initialize_workers_no_PlayerWorker(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: + with TemporaryDirectory() as tmpdir: + export_filename = os.path.join(tmpdir, "output.wav") + reader, observers = initialize_workers( + input="tests/data/test_16KHZ_mono_400Hz.wav", + save_stream=export_filename, + export_format="wave", + save_detections_as="{id}.wav", + echo=False, + progress_bar=False, + command="some command", + quiet=False, + printf="abcd", + time_format="%S", + timestamp_format="%h:%M:%S", + ) + reader.stop() + assert not patched_player_for.called + assert isinstance(reader, StreamSaverWorker) + for obs, cls in zip( + observers, + [RegionSaverWorker, CommandLineWorker, PrintWorker], + ): + assert isinstance(obs, cls) + + +def test_initialize_workers_no_CommandLineWorker(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: + with TemporaryDirectory() as tmpdir: + export_filename = os.path.join(tmpdir, "output.wav") + reader, observers = initialize_workers( + input="tests/data/test_16KHZ_mono_400Hz.wav", + save_stream=export_filename, + export_format="wave", + save_detections_as="{id}.wav", + echo=True, + progress_bar=False, command=None, + quiet=False, + printf="abcd", + time_format="%S", + timestamp_format="%h:%M:%S", + ) + reader.stop() + assert patched_player_for.called + assert isinstance(reader, StreamSaverWorker) + for obs, cls in zip( + observers, [RegionSaverWorker, PlayerWorker, PrintWorker] + ): + assert isinstance(obs, cls) + + +def test_initialize_workers_no_PrintWorker(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: + with TemporaryDirectory() as tmpdir: + export_filename = os.path.join(tmpdir, "output.wav") + reader, observers = initialize_workers( + input="tests/data/test_16KHZ_mono_400Hz.wav", + save_stream=export_filename, + export_format="wave", + save_detections_as="{id}.wav", + echo=True, + progress_bar=False, + command="some command", quiet=True, printf="abcd", time_format="%S", timestamp_format="%h:%M:%S", ) - self.assertTrue(patched_player_for.called) - self.assertFalse(isinstance(reader, StreamSaverWorker)) - self.assertTrue(len(observers), 0) + reader.stop() + assert patched_player_for.called + assert isinstance(reader, StreamSaverWorker) + for obs, cls in zip( + observers, + [RegionSaverWorker, PlayerWorker, CommandLineWorker], + ): + assert isinstance(obs, cls) -if __name__ == "__main__": - unittest.main() +def test_initialize_workers_no_observers(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: + reader, observers = initialize_workers( + input="tests/data/test_16KHZ_mono_400Hz.wav", + save_stream=None, + export_format="wave", + save_detections_as=None, + echo=True, + progress_bar=False, + command=None, + quiet=True, + printf="abcd", + time_format="%S", + timestamp_format="%h:%M:%S", + ) + assert patched_player_for.called + assert not isinstance(reader, StreamSaverWorker) + assert len(observers) == 1
--- a/tests/test_core.py Fri May 24 21:30:34 2024 +0200 +++ b/tests/test_core.py Sat May 25 21:54:13 2024 +0200 @@ -3,10 +3,8 @@ from random import random from tempfile import TemporaryDirectory from array import array as array_ -import unittest -from unittest import TestCase, mock -from unittest.mock import patch -from genty import genty, genty_dataset +import pytest +from unittest.mock import patch, Mock from auditok import load, split, AudioRegion, AudioParameterError from auditok.core import ( _duration_to_nb_windows, @@ -17,8 +15,6 @@ from auditok.util import AudioDataSource from auditok.io import get_audio_source -mock._magics.add("__round__") - def _make_random_length_regions( byte_seq, sampling_rate, sample_width, channels @@ -32,183 +28,178 @@ return regions -@genty -class TestFunctions(TestCase): - @genty_dataset( - no_skip_read_all=(0, -1), - no_skip_read_all_stereo=(0, -1, 2), - skip_2_read_all=(2, -1), - skip_2_read_all_None=(2, None), - skip_2_read_3=(2, 3), - skip_2_read_3_5_stereo=(2, 3.5, 2), - skip_2_4_read_3_5_stereo=(2.4, 3.5, 2), +@pytest.mark.parametrize( + "skip, max_read, channels", + [ + (0, -1, 1), + (0, -1, 2), + (2, -1, 1), + (2, None, 1), + (2, 3, 1), + (2, 3.5, 2), + (2.4, 3.5, 2), + ], + ids=[ + "no_skip_read_all", + "no_skip_read_all_stereo", + "skip_2_read_all", + "skip_2_read_all_None", + "skip_2_read_3", + "skip_2_read_3_5_stereo", + "skip_2_4_read_3_5_stereo", + ], +) +def test_load(skip, max_read, channels): + sampling_rate = 10 + sample_width = 2 + filename = "tests/data/test_split_10HZ_{}.raw" + filename = filename.format("mono" if channels == 1 else "stereo") + region = load( + filename, + skip=skip, + max_read=max_read, + sr=sampling_rate, + sw=sample_width, + ch=channels, ) - def test_load(self, skip, max_read, channels=1): - sampling_rate = 10 - sample_width = 2 - filename = "tests/data/test_split_10HZ_{}.raw" - filename = filename.format("mono" if channels == 1 else "stereo") - region = load( - filename, - skip=skip, - max_read=max_read, - sr=sampling_rate, - sw=sample_width, - ch=channels, + with open(filename, "rb") as fp: + fp.read(round(skip * sampling_rate * sample_width * channels)) + if max_read is None or max_read < 0: + to_read = -1 + else: + to_read = round(max_read * sampling_rate * sample_width * channels) + expected = fp.read(to_read) + assert bytes(region) == expected + + +@pytest.mark.parametrize( + "duration, analysis_window, round_fn, expected, kwargs", + [ + (0, 1, None, 0, None), + (0.3, 0.1, round, 3, None), + (0.35, 0.1, math.ceil, 4, None), + (0.35, 0.1, math.floor, 3, None), + (0.05, 0.1, round, 0, None), + (0.05, 0.1, math.ceil, 1, None), + (0.3, 0.1, math.floor, 3, {"epsilon": 1e-6}), + (-0.5, 0.1, math.ceil, ValueError, None), + (0.5, -0.1, math.ceil, ValueError, None), + ], + ids=[ + "zero_duration", + "multiple", + "not_multiple_ceil", + "not_multiple_floor", + "small_duration", + "small_duration_ceil", + "with_round_error", + "negative_duration", + "negative_analysis_window", + ], +) +def test_duration_to_nb_windows( + duration, analysis_window, round_fn, expected, kwargs +): + if expected == ValueError: + with pytest.raises(ValueError): + _duration_to_nb_windows(duration, analysis_window, round_fn) + else: + if kwargs is None: + kwargs = {} + result = _duration_to_nb_windows( + duration, analysis_window, round_fn, **kwargs ) - with open(filename, "rb") as fp: - fp.read(round(skip * sampling_rate * sample_width * channels)) - if max_read is None or max_read < 0: - to_read = -1 - else: - to_read = round( - max_read * sampling_rate * sample_width * channels - ) - expected = fp.read(to_read) - self.assertEqual(bytes(region), expected) + assert result == expected - @genty_dataset( - zero_duration=(0, 1, None, 0), - multiple=(0.3, 0.1, round, 3), - not_multiple_ceil=(0.35, 0.1, math.ceil, 4), - not_multiple_floor=(0.35, 0.1, math.floor, 3), - small_duration=(0.05, 0.1, round, 0), - small_duration_ceil=(0.05, 0.1, math.ceil, 1), - with_round_error=(0.3, 0.1, math.floor, 3, {"epsilon": 1e-6}), - negative_duration=(-0.5, 0.1, math.ceil, ValueError), - negative_analysis_window=(0.5, -0.1, math.ceil, ValueError), + +@pytest.mark.parametrize( + "channels, skip, max_read", + [ + (1, 0, None), + (1, 3, None), + (1, 2, -1), + (1, 2, 3), + (2, 0, None), + (2, 3, None), + (2, 2, -1), + (2, 2, 3), + ], + ids=[ + "mono_skip_0_max_read_None", + "mono_skip_3_max_read_None", + "mono_skip_2_max_read_negative", + "mono_skip_2_max_read_3", + "stereo_skip_0_max_read_None", + "stereo_skip_3_max_read_None", + "stereo_skip_2_max_read_negative", + "stereo_skip_2_max_read_3", + ], +) +def test_read_offline(channels, skip, max_read): + sampling_rate = 10 + sample_width = 2 + mono_or_stereo = "mono" if channels == 1 else "stereo" + filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) + with open(filename, "rb") as fp: + data = fp.read() + onset = round(skip * sampling_rate * sample_width * channels) + if max_read in (-1, None): + offset = len(data) + 1 + else: + offset = onset + round( + max_read * sampling_rate * sample_width * channels + ) + expected_data = data[onset:offset] + read_data, *audio_params = _read_offline( + filename, + skip=skip, + max_read=max_read, + sr=sampling_rate, + sw=sample_width, + ch=channels, ) - def test_duration_to_nb_windows( - self, duration, analysis_window, round_fn, expected, kwargs=None - ): - if expected == ValueError: - with self.assertRaises(expected): - _duration_to_nb_windows(duration, analysis_window, round_fn) - else: - if kwargs is None: - kwargs = {} - result = _duration_to_nb_windows( - duration, analysis_window, round_fn, **kwargs - ) - self.assertEqual(result, expected) + assert read_data == expected_data + assert tuple(audio_params) == (sampling_rate, sample_width, channels) - @genty_dataset( - mono_skip_0_max_read_None=(1, 0, None), - mono_skip_3_max_read_None=(1, 3, None), - mono_skip_2_max_read_negative=(1, 2, -1), - mono_skip_2_max_read_3=(1, 2, 3), - stereo_skip_0_max_read_None=(2, 0, None), - stereo_skip_3_max_read_None=(2, 3, None), - stereo_skip_2_max_read_negative=(2, 2, -1), - stereo_skip_2_max_read_3=(2, 2, 3), - ) - def test_read_offline(self, channels, skip, max_read=None): - sampling_rate = 10 - sample_width = 2 - mono_or_stereo = "mono" if channels == 1 else "stereo" - filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) - with open(filename, "rb") as fp: - data = fp.read() - onset = round(skip * sampling_rate * sample_width * channels) - if max_read in (-1, None): - offset = len(data) + 1 - else: - offset = onset + round( - max_read * sampling_rate * sample_width * channels - ) - expected_data = data[onset:offset] - read_data, *audio_params = _read_offline( - filename, - skip=skip, - max_read=max_read, - sr=sampling_rate, - sw=sample_width, - ch=channels, - ) - self.assertEqual(read_data, expected_data) - self.assertEqual( - tuple(audio_params), (sampling_rate, sample_width, channels) - ) - -@genty -class TestSplit(TestCase): - @genty_dataset( - simple=( +@pytest.mark.parametrize( + "min_dur, max_dur, max_silence, drop_trailing_silence, strict_min_dur, kwargs, expected", + [ + (0.2, 5, 0.2, False, False, {"eth": 50}, [(2, 16), (17, 31), (34, 76)]), + ( + 0.3, + 2, + 0.2, + False, + False, + {"eth": 50}, + [(2, 16), (17, 31), (34, 54), (54, 74), (74, 76)], + ), + (3, 5, 0.2, False, False, {"eth": 50}, [(34, 76)]), + (0.2, 80, 10, False, False, {"eth": 50}, [(2, 76)]), + ( + 0.2, + 5, + 0.0, + False, + False, + {"eth": 50}, + [(2, 14), (17, 24), (26, 29), (34, 76)], + ), + ( 0.2, 5, 0.2, False, False, - {"eth": 50}, - [(2, 16), (17, 31), (34, 76)], - ), - short_max_dur=( - 0.3, - 2, - 0.2, - False, - False, - {"eth": 50}, - [(2, 16), (17, 31), (34, 54), (54, 74), (74, 76)], - ), - long_min_dur=(3, 5, 0.2, False, False, {"eth": 50}, [(34, 76)]), - long_max_silence=(0.2, 80, 10, False, False, {"eth": 50}, [(2, 76)]), - zero_max_silence=( - 0.2, - 5, - 0.0, - False, - False, - {"eth": 50}, - [(2, 14), (17, 24), (26, 29), (34, 76)], - ), - low_energy_threshold=( - 0.2, - 5, - 0.2, - False, - False, {"energy_threshold": 40}, [(0, 50), (50, 76)], ), - high_energy_threshold=( - 0.2, - 5, - 0.2, - False, - False, - {"energy_threshold": 60}, - [], - ), - trim_leading_and_trailing_silence=( - 0.2, - 10, # use long max_dur - 0.5, # and a max_silence longer than any inter-region silence - True, - False, - {"eth": 50}, - [(2, 76)], - ), - drop_trailing_silence=( - 0.2, - 5, - 0.2, - True, - False, - {"eth": 50}, - [(2, 14), (17, 29), (34, 76)], - ), - drop_trailing_silence_2=( - 1.5, - 5, - 0.2, - True, - False, - {"eth": 50}, - [(34, 76)], - ), - strict_min_dur=( + (0.2, 5, 0.2, False, False, {"energy_threshold": 60}, []), + (0.2, 10, 0.5, True, False, {"eth": 50}, [(2, 76)]), + (0.2, 5, 0.2, True, False, {"eth": 50}, [(2, 14), (17, 29), (34, 76)]), + (1.5, 5, 0.2, True, False, {"eth": 50}, [(34, 76)]), + ( 0.3, 2, 0.2, @@ -217,266 +208,185 @@ {"eth": 50}, [(2, 16), (17, 31), (34, 54), (54, 74)], ), - ) - def test_split_params( - self, + ], + ids=[ + "simple", + "short_max_dur", + "long_min_dur", + "long_max_silence", + "zero_max_silence", + "low_energy_threshold", + "high_energy_threshold", + "trim_leading_and_trailing_silence", + "drop_trailing_silence", + "drop_trailing_silence_2", + "strict_min_dur", + ], +) +def test_split_params( + min_dur, + max_dur, + max_silence, + drop_trailing_silence, + strict_min_dur, + kwargs, + expected, +): + with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: + data = fp.read() + + regions = split( + data, min_dur, max_dur, max_silence, drop_trailing_silence, strict_min_dur, - kwargs, - expected, - ): - with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: - data = fp.read() + analysis_window=0.1, + sr=10, + sw=2, + ch=1, + **kwargs + ) - regions = split( - data, - min_dur, - max_dur, - max_silence, - drop_trailing_silence, - strict_min_dur, - analysis_window=0.1, - sr=10, - sw=2, - ch=1, - **kwargs - ) + region = AudioRegion(data, 10, 2, 1) + regions_ar = region.split( + min_dur, + max_dur, + max_silence, + drop_trailing_silence, + strict_min_dur, + analysis_window=0.1, + **kwargs + ) - region = AudioRegion(data, 10, 2, 1) - regions_ar = region.split( - min_dur, - max_dur, - max_silence, - drop_trailing_silence, - strict_min_dur, - analysis_window=0.1, - **kwargs - ) + regions = list(regions) + regions_ar = list(regions_ar) + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions)) + assert len(regions) == len(expected), err_msg + err_msg = "Wrong number of regions after AudioRegion.split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) + assert len(regions_ar) == len(expected), err_msg - regions = list(regions) - regions_ar = list(regions_ar) - err_msg = "Wrong number of regions after split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions)) - self.assertEqual(len(regions), len(expected), err_msg) - err_msg = "Wrong number of regions after AudioRegion.split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) - self.assertEqual(len(regions_ar), len(expected), err_msg) + sample_width = 2 + for reg, reg_ar, exp in zip(regions, regions_ar, expected): + onset, offset = exp + exp_data = data[onset * sample_width : offset * sample_width] + assert bytes(reg) == exp_data + assert reg == reg_ar - sample_width = 2 - for reg, reg_ar, exp in zip(regions, regions_ar, expected): - onset, offset = exp - exp_data = data[onset * sample_width : offset * sample_width] - self.assertEqual(bytes(reg), exp_data) - self.assertEqual(reg, reg_ar) - @genty_dataset( - stereo_all_default=(2, {}, [(2, 32), (34, 76)]), - mono_max_read=(1, {"max_read": 5}, [(2, 16), (17, 31), (34, 50)]), - mono_max_read_short_name=(1, {"mr": 5}, [(2, 16), (17, 31), (34, 50)]), - mono_use_channel_1=( - 1, - {"eth": 50, "use_channel": 0}, - [(2, 16), (17, 31), (34, 76)], - ), - mono_uc_1=(1, {"eth": 50, "uc": 1}, [(2, 16), (17, 31), (34, 76)]), - mono_use_channel_None=( - 1, - {"eth": 50, "use_channel": None}, - [(2, 16), (17, 31), (34, 76)], - ), - stereo_use_channel_1=( - 2, - {"eth": 50, "use_channel": 0}, - [(2, 16), (17, 31), (34, 76)], - ), - stereo_use_channel_no_use_channel_given=( - 2, - {"eth": 50}, - [(2, 32), (34, 76)], - ), - stereo_use_channel_minus_2=( - 2, - {"eth": 50, "use_channel": -2}, - [(2, 16), (17, 31), (34, 76)], - ), - stereo_uc_2=(2, {"eth": 50, "uc": 1}, [(10, 32), (36, 76)]), - stereo_uc_minus_1=(2, {"eth": 50, "uc": -1}, [(10, 32), (36, 76)]), - mono_uc_mix=( - 1, - {"eth": 50, "uc": "mix"}, - [(2, 16), (17, 31), (34, 76)], - ), - stereo_use_channel_mix=( - 2, - {"energy_threshold": 53.5, "use_channel": "mix"}, - [(54, 76)], - ), - stereo_uc_mix=(2, {"eth": 52, "uc": "mix"}, [(17, 26), (54, 76)]), - stereo_uc_mix_default_eth=( - 2, - {"uc": "mix"}, - [(10, 16), (17, 31), (36, 76)], - ), +@pytest.mark.parametrize( + "channels, kwargs, expected", + [ + (2, {}, [(2, 32), (34, 76)]), + (1, {"max_read": 5}, [(2, 16), (17, 31), (34, 50)]), + (1, {"mr": 5}, [(2, 16), (17, 31), (34, 50)]), + (1, {"eth": 50, "use_channel": 0}, [(2, 16), (17, 31), (34, 76)]), + (1, {"eth": 50, "uc": 1}, [(2, 16), (17, 31), (34, 76)]), + (1, {"eth": 50, "use_channel": None}, [(2, 16), (17, 31), (34, 76)]), + (2, {"eth": 50, "use_channel": 0}, [(2, 16), (17, 31), (34, 76)]), + (2, {"eth": 50}, [(2, 32), (34, 76)]), + (2, {"eth": 50, "use_channel": -2}, [(2, 16), (17, 31), (34, 76)]), + (2, {"eth": 50, "uc": 1}, [(10, 32), (36, 76)]), + (2, {"eth": 50, "uc": -1}, [(10, 32), (36, 76)]), + (1, {"eth": 50, "uc": "mix"}, [(2, 16), (17, 31), (34, 76)]), + (2, {"energy_threshold": 53.5, "use_channel": "mix"}, [(54, 76)]), + (2, {"eth": 52, "uc": "mix"}, [(17, 26), (54, 76)]), + (2, {"uc": "mix"}, [(10, 16), (17, 31), (36, 76)]), + ], + ids=[ + "stereo_all_default", + "mono_max_read", + "mono_max_read_short_name", + "mono_use_channel_1", + "mono_uc_1", + "mono_use_channel_None", + "stereo_use_channel_1", + "stereo_use_channel_no_use_channel_given", + "stereo_use_channel_minus_2", + "stereo_uc_2", + "stereo_uc_minus_1", + "mono_uc_mix", + "stereo_use_channel_mix", + "stereo_uc_mix", + "stereo_uc_mix_default_eth", + ], +) +def test_split_kwargs(channels, kwargs, expected): + + mono_or_stereo = "mono" if channels == 1 else "stereo" + filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) + with open(filename, "rb") as fp: + data = fp.read() + + regions = split( + data, + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + sr=10, + sw=2, + ch=channels, + **kwargs ) - def test_split_kwargs(self, channels, kwargs, expected): - mono_or_stereo = "mono" if channels == 1 else "stereo" - filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) - with open(filename, "rb") as fp: - data = fp.read() + region = AudioRegion(data, 10, 2, channels) + max_read = kwargs.get("max_read", kwargs.get("mr")) + if max_read is not None: + region = region.sec[:max_read] + kwargs.pop("max_read", None) + kwargs.pop("mr", None) - regions = split( - data, - min_dur=0.2, - max_dur=5, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - analysis_window=0.1, - sr=10, - sw=2, - ch=channels, - **kwargs - ) + regions_ar = region.split( + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + **kwargs + ) - region = AudioRegion(data, 10, 2, channels) - max_read = kwargs.get("max_read", kwargs.get("mr")) - if max_read is not None: - region = region.sec[:max_read] - kwargs.pop("max_read", None) - kwargs.pop("mr", None) + regions = list(regions) + regions_ar = list(regions_ar) + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions)) + assert len(regions) == len(expected), err_msg + err_msg = "Wrong number of regions after AudioRegion.split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) + assert len(regions_ar) == len(expected), err_msg - regions_ar = region.split( - min_dur=0.2, - max_dur=5, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - analysis_window=0.1, - **kwargs - ) + sample_width = 2 + sample_size_bytes = sample_width * channels + for reg, reg_ar, exp in zip(regions, regions_ar, expected): + onset, offset = exp + exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes] + assert len(bytes(reg)) == len(exp_data) + assert reg == reg_ar - regions = list(regions) - regions_ar = list(regions_ar) - err_msg = "Wrong number of regions after split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions)) - self.assertEqual(len(regions), len(expected), err_msg) - err_msg = "Wrong number of regions after AudioRegion.split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) - self.assertEqual(len(regions_ar), len(expected), err_msg) - sample_width = 2 - sample_size_bytes = sample_width * channels - for reg, reg_ar, exp in zip(regions, regions_ar, expected): - onset, offset = exp - exp_data = data[ - onset * sample_size_bytes : offset * sample_size_bytes - ] - self.assertEqual(len(bytes(reg)), len(exp_data)) - self.assertEqual(reg, reg_ar) - - @genty_dataset( - mono_aw_0_2_max_silence_0_2=( - 0.2, - 5, - 0.2, - 1, - {"aw": 0.2}, - [(2, 30), (34, 76)], - ), - mono_aw_0_2_max_silence_0_3=( - 0.2, - 5, - 0.3, - 1, - {"aw": 0.2}, - [(2, 30), (34, 76)], - ), - mono_aw_0_2_max_silence_0_4=( - 0.2, - 5, - 0.4, - 1, - {"aw": 0.2}, - [(2, 32), (34, 76)], - ), - mono_aw_0_2_max_silence_0=( - 0.2, - 5, - 0, - 1, - {"aw": 0.2}, - [(2, 14), (16, 24), (26, 28), (34, 76)], - ), - mono_aw_0_2=(0.2, 5, 0.2, 1, {"aw": 0.2}, [(2, 30), (34, 76)]), - mono_aw_0_3_max_silence_0=( - 0.3, - 5, - 0, - 1, - {"aw": 0.3}, - [(3, 12), (15, 24), (36, 76)], - ), - mono_aw_0_3_max_silence_0_3=( - 0.3, - 5, - 0.3, - 1, - {"aw": 0.3}, - [(3, 27), (36, 76)], - ), - mono_aw_0_3_max_silence_0_5=( - 0.3, - 5, - 0.5, - 1, - {"aw": 0.3}, - [(3, 27), (36, 76)], - ), - mono_aw_0_3_max_silence_0_6=( - 0.3, - 5, - 0.6, - 1, - {"aw": 0.3}, - [(3, 30), (36, 76)], - ), - mono_aw_0_4_max_silence_0=( - 0.2, - 5, - 0, - 1, - {"aw": 0.4}, - [(4, 12), (16, 24), (36, 76)], - ), - mono_aw_0_4_max_silence_0_3=( - 0.2, - 5, - 0.3, - 1, - {"aw": 0.4}, - [(4, 12), (16, 24), (36, 76)], - ), - mono_aw_0_4_max_silence_0_4=( - 0.2, - 5, - 0.4, - 1, - {"aw": 0.4}, - [(4, 28), (36, 76)], - ), - stereo_uc_None_analysis_window_0_2=( - 0.2, - 5, - 0.2, - 2, - {"analysis_window": 0.2}, - [(2, 32), (34, 76)], - ), - stereo_uc_any_analysis_window_0_2=( +@pytest.mark.parametrize( + "min_dur, max_dur, max_silence, channels, kwargs, expected", + [ + (0.2, 5, 0.2, 1, {"aw": 0.2}, [(2, 30), (34, 76)]), + (0.2, 5, 0.3, 1, {"aw": 0.2}, [(2, 30), (34, 76)]), + (0.2, 5, 0.4, 1, {"aw": 0.2}, [(2, 32), (34, 76)]), + (0.2, 5, 0, 1, {"aw": 0.2}, [(2, 14), (16, 24), (26, 28), (34, 76)]), + (0.2, 5, 0.2, 1, {"aw": 0.2}, [(2, 30), (34, 76)]), + (0.3, 5, 0, 1, {"aw": 0.3}, [(3, 12), (15, 24), (36, 76)]), + (0.3, 5, 0.3, 1, {"aw": 0.3}, [(3, 27), (36, 76)]), + (0.3, 5, 0.5, 1, {"aw": 0.3}, [(3, 27), (36, 76)]), + (0.3, 5, 0.6, 1, {"aw": 0.3}, [(3, 30), (36, 76)]), + (0.2, 5, 0, 1, {"aw": 0.4}, [(4, 12), (16, 24), (36, 76)]), + (0.2, 5, 0.3, 1, {"aw": 0.4}, [(4, 12), (16, 24), (36, 76)]), + (0.2, 5, 0.4, 1, {"aw": 0.4}, [(4, 28), (36, 76)]), + (0.2, 5, 0.2, 2, {"analysis_window": 0.2}, [(2, 32), (34, 76)]), + ( 0.2, 5, 0.2, @@ -484,7 +394,7 @@ {"uc": None, "analysis_window": 0.2}, [(2, 32), (34, 76)], ), - stereo_use_channel_None_aw_0_3_max_silence_0_2=( + ( 0.2, 5, 0.2, @@ -492,7 +402,7 @@ {"use_channel": None, "analysis_window": 0.3}, [(3, 30), (36, 76)], ), - stereo_use_channel_any_aw_0_3_max_silence_0_3=( + ( 0.2, 5, 0.3, @@ -500,7 +410,7 @@ {"use_channel": "any", "analysis_window": 0.3}, [(3, 33), (36, 76)], ), - stereo_use_channel_None_aw_0_4_max_silence_0_2=( + ( 0.2, 5, 0.2, @@ -508,7 +418,7 @@ {"use_channel": None, "analysis_window": 0.4}, [(4, 28), (36, 76)], ), - stereo_use_channel_any_aw_0_3_max_silence_0_4=( + ( 0.2, 5, 0.4, @@ -516,7 +426,7 @@ {"use_channel": "any", "analysis_window": 0.4}, [(4, 32), (36, 76)], ), - stereo_uc_0_analysis_window_0_2=( + ( 0.2, 5, 0.2, @@ -524,7 +434,7 @@ {"uc": 0, "analysis_window": 0.2}, [(2, 30), (34, 76)], ), - stereo_uc_1_analysis_window_0_2=( + ( 0.2, 5, 0.2, @@ -532,7 +442,7 @@ {"uc": 1, "analysis_window": 0.2}, [(10, 32), (36, 76)], ), - stereo_uc_mix_aw_0_1_max_silence_0=( + ( 0.2, 5, 0, @@ -540,7 +450,7 @@ {"uc": "mix", "analysis_window": 0.1}, [(10, 14), (17, 24), (26, 29), (36, 76)], ), - stereo_uc_mix_aw_0_1_max_silence_0_1=( + ( 0.2, 5, 0.1, @@ -548,7 +458,7 @@ {"uc": "mix", "analysis_window": 0.1}, [(10, 15), (17, 25), (26, 30), (36, 76)], ), - stereo_uc_mix_aw_0_1_max_silence_0_2=( + ( 0.2, 5, 0.2, @@ -556,7 +466,7 @@ {"uc": "mix", "analysis_window": 0.1}, [(10, 16), (17, 31), (36, 76)], ), - stereo_uc_mix_aw_0_1_max_silence_0_3=( + ( 0.2, 5, 0.3, @@ -564,7 +474,7 @@ {"uc": "mix", "analysis_window": 0.1}, [(10, 32), (36, 76)], ), - stereo_uc_avg_aw_0_2_max_silence_0_min_dur_0_3=( + ( 0.3, 5, 0, @@ -572,7 +482,7 @@ {"uc": "avg", "analysis_window": 0.2}, [(10, 14), (16, 24), (36, 76)], ), - stereo_uc_average_aw_0_2_max_silence_0_min_dur_0_41=( + ( 0.41, 5, 0, @@ -580,7 +490,7 @@ {"uc": "average", "analysis_window": 0.2}, [(16, 24), (36, 76)], ), - stereo_uc_mix_aw_0_2_max_silence_0_1=( + ( 0.2, 5, 0.1, @@ -588,7 +498,7 @@ {"uc": "mix", "analysis_window": 0.2}, [(10, 14), (16, 24), (26, 28), (36, 76)], ), - stereo_uc_mix_aw_0_2_max_silence_0_2=( + ( 0.2, 5, 0.2, @@ -596,7 +506,7 @@ {"uc": "mix", "analysis_window": 0.2}, [(10, 30), (36, 76)], ), - stereo_uc_mix_aw_0_2_max_silence_0_4=( + ( 0.2, 5, 0.4, @@ -604,7 +514,7 @@ {"uc": "mix", "analysis_window": 0.2}, [(10, 32), (36, 76)], ), - stereo_uc_mix_aw_0_2_max_silence_0_5=( + ( 0.2, 5, 0.5, @@ -612,7 +522,7 @@ {"uc": "mix", "analysis_window": 0.2}, [(10, 32), (36, 76)], ), - stereo_uc_mix_aw_0_2_max_silence_0_6=( + ( 0.2, 5, 0.6, @@ -620,7 +530,7 @@ {"uc": "mix", "analysis_window": 0.2}, [(10, 34), (36, 76)], ), - stereo_uc_mix_aw_0_3_max_silence_0=( + ( 0.2, 5, 0, @@ -628,7 +538,7 @@ {"uc": "mix", "analysis_window": 0.3}, [(9, 24), (27, 30), (36, 76)], ), - stereo_uc_mix_aw_0_3_max_silence_0_min_dur_0_3=( + ( 0.4, 5, 0, @@ -636,7 +546,7 @@ {"uc": "mix", "analysis_window": 0.3}, [(9, 24), (36, 76)], ), - stereo_uc_mix_aw_0_3_max_silence_0_6=( + ( 0.2, 5, 0.6, @@ -644,7 +554,7 @@ {"uc": "mix", "analysis_window": 0.3}, [(9, 57), (57, 76)], ), - stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_1=( + ( 0.2, 5.1, 0.6, @@ -652,7 +562,7 @@ {"uc": "mix", "analysis_window": 0.3}, [(9, 60), (60, 76)], ), - stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_2=( + ( 0.2, 5.2, 0.6, @@ -660,7 +570,7 @@ {"uc": "mix", "analysis_window": 0.3}, [(9, 60), (60, 76)], ), - stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_3=( + ( 0.2, 5.3, 0.6, @@ -668,7 +578,7 @@ {"uc": "mix", "analysis_window": 0.3}, [(9, 60), (60, 76)], ), - stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_4=( + ( 0.2, 5.4, 0.6, @@ -676,7 +586,7 @@ {"uc": "mix", "analysis_window": 0.3}, [(9, 63), (63, 76)], ), - stereo_uc_mix_aw_0_4_max_silence_0=( + ( 0.2, 5, 0, @@ -684,7 +594,7 @@ {"uc": "mix", "analysis_window": 0.4}, [(16, 24), (36, 76)], ), - stereo_uc_mix_aw_0_4_max_silence_0_3=( + ( 0.2, 5, 0.3, @@ -692,7 +602,7 @@ {"uc": "mix", "analysis_window": 0.4}, [(16, 24), (36, 76)], ), - stereo_uc_mix_aw_0_4_max_silence_0_4=( + ( 0.2, 5, 0.4, @@ -700,131 +610,172 @@ {"uc": "mix", "analysis_window": 0.4}, [(16, 28), (36, 76)], ), + ], + ids=[ + "mono_aw_0_2_max_silence_0_2", + "mono_aw_0_2_max_silence_0_3", + "mono_aw_0_2_max_silence_0_4", + "mono_aw_0_2_max_silence_0", + "mono_aw_0_2", + "mono_aw_0_3_max_silence_0", + "mono_aw_0_3_max_silence_0_3", + "mono_aw_0_3_max_silence_0_5", + "mono_aw_0_3_max_silence_0_6", + "mono_aw_0_4_max_silence_0", + "mono_aw_0_4_max_silence_0_3", + "mono_aw_0_4_max_silence_0_4", + "stereo_uc_None_analysis_window_0_2", + "stereo_uc_any_analysis_window_0_2", + "stereo_use_channel_None_aw_0_3_max_silence_0_2", + "stereo_use_channel_any_aw_0_3_max_silence_0_3", + "stereo_use_channel_None_aw_0_4_max_silence_0_2", + "stereo_use_channel_any_aw_0_3_max_silence_0_4", + "stereo_uc_0_analysis_window_0_2", + "stereo_uc_1_analysis_window_0_2", + "stereo_uc_mix_aw_0_1_max_silence_0", + "stereo_uc_mix_aw_0_1_max_silence_0_1", + "stereo_uc_mix_aw_0_1_max_silence_0_2", + "stereo_uc_mix_aw_0_1_max_silence_0_3", + "stereo_uc_avg_aw_0_2_max_silence_0_min_dur_0_3", + "stereo_uc_average_aw_0_2_max_silence_0_min_dur_0_41", + "stereo_uc_mix_aw_0_2_max_silence_0_1", + "stereo_uc_mix_aw_0_2_max_silence_0_2", + "stereo_uc_mix_aw_0_2_max_silence_0_4", + "stereo_uc_mix_aw_0_2_max_silence_0_5", + "stereo_uc_mix_aw_0_2_max_silence_0_6", + "stereo_uc_mix_aw_0_3_max_silence_0", + "stereo_uc_mix_aw_0_3_max_silence_0_min_dur_0_3", + "stereo_uc_mix_aw_0_3_max_silence_0_6", + "stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_1", + "stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_2", + "stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_3", + "stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_4", + "stereo_uc_mix_aw_0_4_max_silence_0", + "stereo_uc_mix_aw_0_4_max_silence_0_3", + "stereo_uc_mix_aw_0_4_max_silence_0_4", + ], +) +def test_split_analysis_window( + min_dur, max_dur, max_silence, channels, kwargs, expected +): + + mono_or_stereo = "mono" if channels == 1 else "stereo" + filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) + with open(filename, "rb") as fp: + data = fp.read() + + regions = split( + data, + min_dur=min_dur, + max_dur=max_dur, + max_silence=max_silence, + drop_trailing_silence=False, + strict_min_dur=False, + sr=10, + sw=2, + ch=channels, + eth=49.99, + **kwargs ) - def test_split_analysis_window( - self, min_dur, max_dur, max_silence, channels, kwargs, expected - ): - mono_or_stereo = "mono" if channels == 1 else "stereo" - filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) - with open(filename, "rb") as fp: - data = fp.read() + region = AudioRegion(data, 10, 2, channels) + regions_ar = region.split( + min_dur=min_dur, + max_dur=max_dur, + max_silence=max_silence, + drop_trailing_silence=False, + strict_min_dur=False, + eth=49.99, + **kwargs + ) - regions = split( - data, - min_dur=min_dur, - max_dur=max_dur, - max_silence=max_silence, - drop_trailing_silence=False, - strict_min_dur=False, - sr=10, - sw=2, - ch=channels, - eth=49.99, - **kwargs - ) + regions = list(regions) + regions_ar = list(regions_ar) + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions)) + assert len(regions) == len(expected), err_msg + err_msg = "Wrong number of regions after AudioRegion.split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) + assert len(regions_ar) == len(expected), err_msg - region = AudioRegion(data, 10, 2, channels) - regions_ar = region.split( - min_dur=min_dur, - max_dur=max_dur, - max_silence=max_silence, - drop_trailing_silence=False, - strict_min_dur=False, - eth=49.99, - **kwargs - ) + sample_width = 2 + sample_size_bytes = sample_width * channels + for reg, reg_ar, exp in zip(regions, regions_ar, expected): + onset, offset = exp + exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes] + assert bytes(reg) == exp_data + assert reg == reg_ar - regions = list(regions) - regions_ar = list(regions_ar) - err_msg = "Wrong number of regions after split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions)) - self.assertEqual(len(regions), len(expected), err_msg) - err_msg = "Wrong number of regions after AudioRegion.split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) - self.assertEqual(len(regions_ar), len(expected), err_msg) - sample_width = 2 - sample_size_bytes = sample_width * channels - for reg, reg_ar, exp in zip(regions, regions_ar, expected): - onset, offset = exp - exp_data = data[ - onset * sample_size_bytes : offset * sample_size_bytes - ] - self.assertEqual(bytes(reg), exp_data) - self.assertEqual(reg, reg_ar) +def test_split_custom_validator(): + filename = "tests/data/test_split_10HZ_mono.raw" + with open(filename, "rb") as fp: + data = fp.read() - def test_split_custom_validator(self): - filename = "tests/data/test_split_10HZ_mono.raw" - with open(filename, "rb") as fp: - data = fp.read() + regions = split( + data, + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + sr=10, + sw=2, + ch=1, + analysis_window=0.1, + validator=lambda x: array_("h", x)[0] >= 320, + ) - regions = split( - data, - min_dur=0.2, - max_dur=5, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - sr=10, - sw=2, - ch=1, - analysis_window=0.1, - validator=lambda x: array_("h", x)[0] >= 320, - ) + region = AudioRegion(data, 10, 2, 1) + regions_ar = region.split( + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + validator=lambda x: array_("h", x)[0] >= 320, + ) - region = AudioRegion(data, 10, 2, 1) - regions_ar = region.split( - min_dur=0.2, - max_dur=5, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - analysis_window=0.1, - validator=lambda x: array_("h", x)[0] >= 320, - ) + expected = [(2, 16), (17, 31), (34, 76)] + regions = list(regions) + regions_ar = list(regions_ar) + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions)) + assert len(regions) == len(expected), err_msg + err_msg = "Wrong number of regions after AudioRegion.split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) + assert len(regions_ar) == len(expected), err_msg - expected = [(2, 16), (17, 31), (34, 76)] - regions = list(regions) - regions_ar = list(regions_ar) - err_msg = "Wrong number of regions after split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions)) - self.assertEqual(len(regions), len(expected), err_msg) - err_msg = "Wrong number of regions after AudioRegion.split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) - self.assertEqual(len(regions_ar), len(expected), err_msg) + sample_size_bytes = 2 + for reg, reg_ar, exp in zip(regions, regions_ar, expected): + onset, offset = exp + exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes] + assert bytes(reg) == exp_data + assert reg == reg_ar - sample_size_bytes = 2 - for reg, reg_ar, exp in zip(regions, regions_ar, expected): - onset, offset = exp - exp_data = data[ - onset * sample_size_bytes : offset * sample_size_bytes - ] - self.assertEqual(bytes(reg), exp_data) - self.assertEqual(reg, reg_ar) - @genty_dataset( - filename_audio_format=( +@pytest.mark.parametrize( + "input, kwargs", + [ + ( "tests/data/test_split_10HZ_stereo.raw", {"audio_format": "raw", "sr": 10, "sw": 2, "ch": 2}, ), - filename_audio_format_short_name=( + ( "tests/data/test_split_10HZ_stereo.raw", {"fmt": "raw", "sr": 10, "sw": 2, "ch": 2}, ), - filename_no_audio_format=( - "tests/data/test_split_10HZ_stereo.raw", - {"sr": 10, "sw": 2, "ch": 2}, - ), - filename_no_long_audio_params=( + ("tests/data/test_split_10HZ_stereo.raw", {"sr": 10, "sw": 2, "ch": 2}), + ( "tests/data/test_split_10HZ_stereo.raw", {"sampling_rate": 10, "sample_width": 2, "channels": 2}, ), - bytes_=( + ( open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), {"sr": 10, "sw": 2, "ch": 2}, ), - audio_reader=( + ( AudioDataSource( "tests/data/test_split_10HZ_stereo.raw", sr=10, @@ -834,7 +785,7 @@ ), {}, ), - audio_region=( + ( AudioRegion( open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), 10, @@ -843,301 +794,242 @@ ), {}, ), - audio_source=( + ( get_audio_source( "tests/data/test_split_10HZ_stereo.raw", sr=10, sw=2, ch=2 ), {}, ), + ], + ids=[ + "filename_audio_format", + "filename_audio_format_short_name", + "filename_no_audio_format", + "filename_no_long_audio_params", + "bytes_", + "audio_reader", + "audio_region", + "audio_source", + ], +) +def test_split_input_type(input, kwargs): + + with open("tests/data/test_split_10HZ_stereo.raw", "rb") as fp: + data = fp.read() + + regions = split( + input, + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + **kwargs ) - def test_split_input_type(self, input, kwargs): + regions = list(regions) + expected = [(2, 32), (34, 76)] + sample_width = 2 + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(expected, regions) + assert len(regions) == len(expected), err_msg + for reg, exp in zip(regions, expected): + onset, offset = exp + exp_data = data[onset * sample_width * 2 : offset * sample_width * 2] + assert bytes(reg) == exp_data - with open("tests/data/test_split_10HZ_stereo.raw", "rb") as fp: - data = fp.read() - regions = split( - input, +@pytest.mark.parametrize( + "min_dur, max_dur, analysis_window", + [ + (0.5, 0.4, 0.1), + (0.44, 0.49, 0.1), + ], + ids=[ + "min_dur_greater_than_max_dur", + "durations_OK_but_wrong_number_of_analysis_windows", + ], +) +def test_split_wrong_min_max_dur(min_dur, max_dur, analysis_window): + + with pytest.raises(ValueError) as val_err: + split( + b"0" * 16, + min_dur=min_dur, + max_dur=max_dur, + max_silence=0.2, + sr=16000, + sw=1, + ch=1, + analysis_window=analysis_window, + ) + + err_msg = "'min_dur' ({0} sec.) results in {1} analysis " + err_msg += "window(s) ({1} == ceil({0} / {2})) which is " + err_msg += "higher than the number of analysis window(s) for " + err_msg += "'max_dur' ({3} == floor({4} / {2}))" + + err_msg = err_msg.format( + min_dur, + math.ceil(min_dur / analysis_window), + analysis_window, + math.floor(max_dur / analysis_window), + max_dur, + ) + assert err_msg == str(val_err.value) + + +@pytest.mark.parametrize( + "max_silence, max_dur, analysis_window", + [ + (0.5, 0.5, 0.1), + (0.5, 0.4, 0.1), + (0.44, 0.49, 0.1), + ], + ids=[ + "max_silence_equals_max_dur", + "max_silence_greater_than_max_dur", + "durations_OK_but_wrong_number_of_analysis_windows", + ], +) +def test_split_wrong_max_silence_max_dur(max_silence, max_dur, analysis_window): + + with pytest.raises(ValueError) as val_err: + split( + b"0" * 16, + min_dur=0.2, + max_dur=max_dur, + max_silence=max_silence, + sr=16000, + sw=1, + ch=1, + analysis_window=analysis_window, + ) + + err_msg = "'max_silence' ({0} sec.) results in {1} analysis " + err_msg += "window(s) ({1} == floor({0} / {2})) which is " + err_msg += "higher or equal to the number of analysis window(s) for " + err_msg += "'max_dur' ({3} == floor({4} / {2}))" + + err_msg = err_msg.format( + max_silence, + math.floor(max_silence / analysis_window), + analysis_window, + math.floor(max_dur / analysis_window), + max_dur, + ) + assert err_msg == str(val_err.value) + + +@pytest.mark.parametrize( + "wrong_param", + [ + {"min_dur": -1}, + {"min_dur": 0}, + {"max_dur": -1}, + {"max_dur": 0}, + {"max_silence": -1}, + {"analysis_window": 0}, + {"analysis_window": -1}, + ], + ids=[ + "negative_min_dur", + "zero_min_dur", + "negative_max_dur", + "zero_max_dur", + "negative_max_silence", + "zero_analysis_window", + "negative_analysis_window", + ], +) +def test_split_negative_temporal_params(wrong_param): + + params = { + "min_dur": 0.2, + "max_dur": 0.5, + "max_silence": 0.1, + "analysis_window": 0.1, + } + params.update(wrong_param) + with pytest.raises(ValueError) as val_err: + split(None, **params) + + name = set(wrong_param).pop() + value = wrong_param[name] + err_msg = "'{}' ({}) must be >{} 0".format( + name, value, "=" if name == "max_silence" else "" + ) + assert err_msg == str(val_err.value) + + +def test_split_too_small_analysis_window(): + with pytest.raises(ValueError) as val_err: + split(b"", sr=10, sw=1, ch=1, analysis_window=0.09) + err_msg = "Too small 'analysis_windows' (0.09) for sampling rate (10)." + err_msg += " Analysis windows should at least be 1/10 to cover one " + err_msg += "single data sample" + assert err_msg == str(val_err.value) + + +def test_split_and_plot(): + + with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: + data = fp.read() + + region = AudioRegion(data, 10, 2, 1) + with patch("auditok.plotting.plot") as patch_fn: + regions = region.split_and_plot( min_dur=0.2, max_dur=5, max_silence=0.2, drop_trailing_silence=False, strict_min_dur=False, analysis_window=0.1, - **kwargs + sr=10, + sw=2, + ch=1, + eth=50, ) - regions = list(regions) - expected = [(2, 32), (34, 76)] - sample_width = 2 - err_msg = "Wrong number of regions after split, expected: " - err_msg += "{}, found: {}".format(expected, regions) - self.assertEqual(len(regions), len(expected), err_msg) - for reg, exp in zip(regions, expected): - onset, offset = exp - exp_data = data[ - onset * sample_width * 2 : offset * sample_width * 2 - ] - self.assertEqual(bytes(reg), exp_data) + assert patch_fn.called + expected = [(2, 16), (17, 31), (34, 76)] + sample_width = 2 + expected_regions = [] + for onset, offset in expected: + onset *= sample_width + offset *= sample_width + expected_regions.append(AudioRegion(data[onset:offset], 10, 2, 1)) + assert regions == expected_regions - @genty_dataset( - min_dur_greater_than_max_dur=(0.5, 0.4, 0.1), - durations_OK_but_wrong_number_of_analysis_windows=(0.44, 0.49, 0.1), - ) - def test_split_wrong_min_max_dur(self, min_dur, max_dur, analysis_window): - with self.assertRaises(ValueError) as val_err: - split( - b"0" * 16, - min_dur=min_dur, - max_dur=max_dur, - max_silence=0.2, - sr=16000, - sw=1, - ch=1, - analysis_window=analysis_window, - ) +def test_split_exception(): + with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: + data = fp.read() + region = AudioRegion(data, 10, 2, 1) - err_msg = "'min_dur' ({0} sec.) results in {1} analysis " - err_msg += "window(s) ({1} == ceil({0} / {2})) which is " - err_msg += "higher than the number of analysis window(s) for " - err_msg += "'max_dur' ({3} == floor({4} / {2}))" + with pytest.raises(RuntimeWarning): + # max_read is not accepted when calling AudioRegion.split + region.split(max_read=2) - err_msg = err_msg.format( - min_dur, - math.ceil(min_dur / analysis_window), - analysis_window, - math.floor(max_dur / analysis_window), - max_dur, - ) - self.assertEqual(err_msg, str(val_err.exception)) - @genty_dataset( - max_silence_equals_max_dur=(0.5, 0.5, 0.1), - max_silence_greater_than_max_dur=(0.5, 0.4, 0.1), - durations_OK_but_wrong_number_of_analysis_windows=(0.44, 0.49, 0.1), - ) - def test_split_wrong_max_silence_max_dur( - self, max_silence, max_dur, analysis_window - ): - - with self.assertRaises(ValueError) as val_err: - split( - b"0" * 16, - min_dur=0.2, - max_dur=max_dur, - max_silence=max_silence, - sr=16000, - sw=1, - ch=1, - analysis_window=analysis_window, - ) - - err_msg = "'max_silence' ({0} sec.) results in {1} analysis " - err_msg += "window(s) ({1} == floor({0} / {2})) which is " - err_msg += "higher or equal to the number of analysis window(s) for " - err_msg += "'max_dur' ({3} == floor({4} / {2}))" - - err_msg = err_msg.format( - max_silence, - math.floor(max_silence / analysis_window), - analysis_window, - math.floor(max_dur / analysis_window), - max_dur, - ) - self.assertEqual(err_msg, str(val_err.exception)) - - @genty_dataset( - negative_min_dur=({"min_dur": -1},), - zero_min_dur=({"min_dur": 0},), - negative_max_dur=({"max_dur": -1},), - zero_max_dur=({"max_dur": 0},), - negative_max_silence=({"max_silence": -1},), - zero_analysis_window=({"analysis_window": 0},), - negative_analysis_window=({"analysis_window": -1},), - ) - def test_split_negative_temporal_params(self, wrong_param): - - params = { - "min_dur": 0.2, - "max_dur": 0.5, - "max_silence": 0.1, - "analysis_window": 0.1, - } - params.update(wrong_param) - with self.assertRaises(ValueError) as val_err: - split(None, **params) - - name = set(wrong_param).pop() - value = wrong_param[name] - err_msg = "'{}' ({}) must be >{} 0".format( - name, value, "=" if name == "max_silence" else "" - ) - self.assertEqual(err_msg, str(val_err.exception)) - - def test_split_too_small_analysis_window(self): - with self.assertRaises(ValueError) as val_err: - split(b"", sr=10, sw=1, ch=1, analysis_window=0.09) - err_msg = "Too small 'analysis_windows' (0.09) for sampling rate (10)." - err_msg += " Analysis windows should at least be 1/10 to cover one " - err_msg += "single data sample" - self.assertEqual(err_msg, str(val_err.exception)) - - def test_split_and_plot(self): - - with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: - data = fp.read() - - region = AudioRegion(data, 10, 2, 1) - with patch("auditok.plotting.plot") as patch_fn: - regions = region.split_and_plot( - min_dur=0.2, - max_dur=5, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - analysis_window=0.1, - sr=10, - sw=2, - ch=1, - eth=50, - ) - self.assertTrue(patch_fn.called) - expected = [(2, 16), (17, 31), (34, 76)] - sample_width = 2 - expected_regions = [] - for (onset, offset) in expected: - onset *= sample_width - offset *= sample_width - expected_regions.append(AudioRegion(data[onset:offset], 10, 2, 1)) - self.assertEqual(regions, expected_regions) - - def test_split_exception(self): - with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: - data = fp.read() - region = AudioRegion(data, 10, 2, 1) - - with self.assertRaises(RuntimeWarning): - # max_read is not accepted when calling AudioRegion.split - region.split(max_read=2) - - -@genty -class TestAudioRegion(TestCase): - @genty_dataset( - simple=(b"\0" * 8000, 0, 8000, 1, 1, 1, 1, 1000), - one_ms_less_than_1_sec=( - b"\0" * 7992, - 0, - 8000, - 1, - 1, - 0.999, - 0.999, - 999, - ), - tree_quarter_ms_less_than_1_sec=( - b"\0" * 7994, - 0, - 8000, - 1, - 1, - 0.99925, - 0.99925, - 999, - ), - half_ms_less_than_1_sec=( - b"\0" * 7996, - 0, - 8000, - 1, - 1, - 0.9995, - 0.9995, - 1000, - ), - quarter_ms_less_than_1_sec=( - b"\0" * 7998, - 0, - 8000, - 1, - 1, - 0.99975, - 0.99975, - 1000, - ), - simple_sample_width_2=(b"\0" * 8000 * 2, 0, 8000, 2, 1, 1, 1, 1000), - simple_stereo=(b"\0" * 8000 * 2, 0, 8000, 1, 2, 1, 1, 1000), - simple_multichannel=(b"\0" * 8000 * 5, 0, 8000, 1, 5, 1, 1, 1000), - simple_sample_width_2_multichannel=( - b"\0" * 8000 * 2 * 5, - 0, - 8000, - 2, - 5, - 1, - 1, - 1000, - ), - one_ms_less_than_1s_sw_2_multichannel=( - b"\0" * 7992 * 2 * 5, - 0, - 8000, - 2, - 5, - 0.999, - 0.999, - 999, - ), - tree_qrt_ms_lt_1_s_sw_2_multichannel=( - b"\0" * 7994 * 2 * 5, - 0, - 8000, - 2, - 5, - 0.99925, - 0.99925, - 999, - ), - half_ms_lt_1s_sw_2_multichannel=( - b"\0" * 7996 * 2 * 5, - 0, - 8000, - 2, - 5, - 0.9995, - 0.9995, - 1000, - ), - quarter_ms_lt_1s_sw_2_multichannel=( - b"\0" * 7998 * 2 * 5, - 0, - 8000, - 2, - 5, - 0.99975, - 0.99975, - 1000, - ), - arbitrary_length_1=( - b"\0" * int(8000 * 1.33), - 2.7, - 8000, - 1, - 1, - 4.03, - 1.33, - 1330, - ), - arbitrary_length_2=( - b"\0" * int(8000 * 0.476), - 11.568, - 8000, - 1, - 1, - 12.044, - 0.476, - 476, - ), - arbitrary_length_sw_2_multichannel=( +@pytest.mark.parametrize( + "data, start, sampling_rate, sample_width, channels, expected_end, expected_duration_s, expected_duration_ms", + [ + (b"\0" * 8000, 0, 8000, 1, 1, 1, 1, 1000), + (b"\0" * 7992, 0, 8000, 1, 1, 0.999, 0.999, 999), + (b"\0" * 7994, 0, 8000, 1, 1, 0.99925, 0.99925, 999), + (b"\0" * 7996, 0, 8000, 1, 1, 0.9995, 0.9995, 1000), + (b"\0" * 7998, 0, 8000, 1, 1, 0.99975, 0.99975, 1000), + (b"\0" * 8000 * 2, 0, 8000, 2, 1, 1, 1, 1000), + (b"\0" * 8000 * 2, 0, 8000, 1, 2, 1, 1, 1000), + (b"\0" * 8000 * 5, 0, 8000, 1, 5, 1, 1, 1000), + (b"\0" * 8000 * 2 * 5, 0, 8000, 2, 5, 1, 1, 1000), + (b"\0" * 7992 * 2 * 5, 0, 8000, 2, 5, 0.999, 0.999, 999), + (b"\0" * 7994 * 2 * 5, 0, 8000, 2, 5, 0.99925, 0.99925, 999), + (b"\0" * 7996 * 2 * 5, 0, 8000, 2, 5, 0.9995, 0.9995, 1000), + (b"\0" * 7998 * 2 * 5, 0, 8000, 2, 5, 0.99975, 0.99975, 1000), + (b"\0" * int(8000 * 1.33), 2.7, 8000, 1, 1, 4.03, 1.33, 1330), + (b"\0" * int(8000 * 0.476), 11.568, 8000, 1, 1, 12.044, 0.476, 476), + ( b"\0" * int(8000 * 1.711) * 2 * 3, 9.415, 8000, @@ -1147,7 +1039,7 @@ 1.711, 1711, ), - arbitrary_samplig_rate=( + ( b"\0" * int(3172 * 1.318), 17.236, 3172, @@ -1157,7 +1049,7 @@ int(3172 * 1.318) / 3172, 1318, ), - arbitrary_sr_sw_2_multichannel=( + ( b"\0" * int(11317 * 0.716) * 2 * 3, 18.811, 11317, @@ -1167,533 +1059,641 @@ int(11317 * 0.716) / 11317, 716, ), + ], + ids=[ + "simple", + "one_ms_less_than_1_sec", + "tree_quarter_ms_less_than_1_sec", + "half_ms_less_than_1_sec", + "quarter_ms_less_than_1_sec", + "simple_sample_width_2", + "simple_stereo", + "simple_multichannel", + "simple_sample_width_2_multichannel", + "one_ms_less_than_1s_sw_2_multichannel", + "tree_qrt_ms_lt_1_s_sw_2_multichannel", + "half_ms_lt_1s_sw_2_multichannel", + "quarter_ms_lt_1s_sw_2_multichannel", + "arbitrary_length_1", + "arbitrary_length_2", + "arbitrary_length_sw_2_multichannel", + "arbitrary_samplig_rate", + "arbitrary_sr_sw_2_multichannel", + ], +) +def test_creation( + data, + start, + sampling_rate, + sample_width, + channels, + expected_end, + expected_duration_s, + expected_duration_ms, +): + meta = {"start": start, "end": expected_end} + region = AudioRegion(data, sampling_rate, sample_width, channels, meta) + assert region.sampling_rate == sampling_rate + assert region.sr == sampling_rate + assert region.sample_width == sample_width + assert region.sw == sample_width + assert region.channels == channels + assert region.ch == channels + assert region.meta.start == start + assert region.meta.end == expected_end + assert region.duration == expected_duration_s + assert len(region.ms) == expected_duration_ms + assert bytes(region) == data + + +def test_creation_invalid_data_exception(): + with pytest.raises(AudioParameterError) as audio_param_err: + _ = AudioRegion( + data=b"ABCDEFGHI", sampling_rate=8, sample_width=2, channels=1 + ) + assert str(audio_param_err.value) == ( + "The length of audio data must be an integer " + "multiple of `sample_width * channels`" ) - def test_creation( - self, - data, - start, - sampling_rate, - sample_width, - channels, - expected_end, - expected_duration_s, - expected_duration_ms, - ): - meta = {"start": start, "end": expected_end} - region = AudioRegion(data, sampling_rate, sample_width, channels, meta) - self.assertEqual(region.sampling_rate, sampling_rate) - self.assertEqual(region.sr, sampling_rate) - self.assertEqual(region.sample_width, sample_width) - self.assertEqual(region.sw, sample_width) - self.assertEqual(region.channels, channels) - self.assertEqual(region.ch, channels) - self.assertEqual(region.meta.start, start) - self.assertEqual(region.meta.end, expected_end) - self.assertEqual(region.duration, expected_duration_s) - self.assertEqual(len(region.ms), expected_duration_ms) - self.assertEqual(bytes(region), data) - def test_creation_invalid_data_exception(self): - with self.assertRaises(AudioParameterError) as audio_param_err: - _ = AudioRegion( - data=b"ABCDEFGHI", sampling_rate=8, sample_width=2, channels=1 - ) - self.assertEqual( - "The length of audio data must be an integer " - "multiple of `sample_width * channels`", - str(audio_param_err.exception), - ) - @genty_dataset( - no_skip_read_all=(0, -1), - no_skip_read_all_stereo=(0, -1, 2), - skip_2_read_all=(2, -1), - skip_2_read_all_None=(2, None), - skip_2_read_3=(2, 3), - skip_2_read_3_5_stereo=(2, 3.5, 2), - skip_2_4_read_3_5_stereo=(2.4, 3.5, 2), +@pytest.mark.parametrize( + "skip, max_read, channels", + [ + (0, -1, 1), + (0, -1, 2), + (2, -1, 1), + (2, None, 1), + (2, 3, 1), + (2, 3.5, 2), + (2.4, 3.5, 2), + ], + ids=[ + "no_skip_read_all", + "no_skip_read_all_stereo", + "skip_2_read_all", + "skip_2_read_all_None", + "skip_2_read_3", + "skip_2_read_3_5_stereo", + "skip_2_4_read_3_5_stereo", + ], +) +def test_load_AudioRegion(skip, max_read, channels): + sampling_rate = 10 + sample_width = 2 + filename = "tests/data/test_split_10HZ_{}.raw" + filename = filename.format("mono" if channels == 1 else "stereo") + region = AudioRegion.load( + filename, + skip=skip, + max_read=max_read, + sr=sampling_rate, + sw=sample_width, + ch=channels, ) - def test_load(self, skip, max_read, channels=1): - sampling_rate = 10 - sample_width = 2 - filename = "tests/data/test_split_10HZ_{}.raw" - filename = filename.format("mono" if channels == 1 else "stereo") - region = AudioRegion.load( - filename, - skip=skip, - max_read=max_read, - sr=sampling_rate, - sw=sample_width, - ch=channels, - ) - with open(filename, "rb") as fp: - fp.read(round(skip * sampling_rate * sample_width * channels)) - if max_read is None or max_read < 0: - to_read = -1 - else: - to_read = round( - max_read * sampling_rate * sample_width * channels - ) - expected = fp.read(to_read) - self.assertEqual(bytes(region), expected) + with open(filename, "rb") as fp: + fp.read(round(skip * sampling_rate * sample_width * channels)) + if max_read is None or max_read < 0: + to_read = -1 + else: + to_read = round(max_read * sampling_rate * sample_width * channels) + expected = fp.read(to_read) + assert bytes(region) == expected - def test_load_from_microphone(self): - with patch("auditok.io.PyAudioSource") as patch_pyaudio_source: - with patch("auditok.core.AudioReader.read") as patch_reader: - patch_reader.return_value = None - with patch( - "auditok.core.AudioRegion.__init__" - ) as patch_AudioRegion: - patch_AudioRegion.return_value = None - AudioRegion.load( - None, skip=0, max_read=5, sr=16000, sw=2, ch=1 - ) - self.assertTrue(patch_pyaudio_source.called) - self.assertTrue(patch_reader.called) - self.assertTrue(patch_AudioRegion.called) - @genty_dataset(none=(None,), negative=(-1,)) - def test_load_from_microphone_without_max_read_exception(self, max_read): - with self.assertRaises(ValueError) as val_err: - AudioRegion.load(None, max_read=max_read, sr=16000, sw=2, ch=1) - self.assertEqual( - "'max_read' should not be None when reading from microphone", - str(val_err.exception), - ) +def test_load_from_microphone(): + with patch("auditok.io.PyAudioSource") as patch_pyaudio_source: + with patch("auditok.core.AudioReader.read") as patch_reader: + patch_reader.return_value = None + with patch( + "auditok.core.AudioRegion.__init__" + ) as patch_AudioRegion: + patch_AudioRegion.return_value = None + AudioRegion.load(None, skip=0, max_read=5, sr=16000, sw=2, ch=1) + assert patch_pyaudio_source.called + assert patch_reader.called + assert patch_AudioRegion.called - def test_load_from_microphone_with_nonzero_skip_exception(self): - with self.assertRaises(ValueError) as val_err: - AudioRegion.load(None, skip=1, max_read=5, sr=16000, sw=2, ch=1) - self.assertEqual( - "'skip' should be 0 when reading from microphone", - str(val_err.exception), - ) - @genty_dataset( - simple=("output.wav", 1.230, "output.wav"), - start=("output_{meta.start:g}.wav", 1.230, "output_1.23.wav"), - start_2=("output_{meta.start}.wav", 1.233712, "output_1.233712.wav"), - start_3=("output_{meta.start:.2f}.wav", 1.2300001, "output_1.23.wav"), - start_4=("output_{meta.start:.3f}.wav", 1.233712, "output_1.234.wav"), - start_5=( - "output_{meta.start:.8f}.wav", - 1.233712, - "output_1.23371200.wav", - ), - start_end_duration=( +@pytest.mark.parametrize( + "max_read", + [ + None, + -1, + ], + ids=[ + "none", + "negative", + ], +) +def test_load_from_microphone_without_max_read_exception(max_read): + with pytest.raises(ValueError) as val_err: + AudioRegion.load(None, max_read=max_read, sr=16000, sw=2, ch=1) + assert str(val_err.value) == ( + "'max_read' should not be None when reading from microphone" + ) + + +def test_load_from_microphone_with_nonzero_skip_exception(): + with pytest.raises(ValueError) as val_err: + AudioRegion.load(None, skip=1, max_read=5, sr=16000, sw=2, ch=1) + assert str(val_err.value) == ( + "'skip' should be 0 when reading from microphone" + ) + + +@pytest.mark.parametrize( + "format, start, expected", + [ + ("output.wav", 1.230, "output.wav"), + ("output_{meta.start:g}.wav", 1.230, "output_1.23.wav"), + ("output_{meta.start}.wav", 1.233712, "output_1.233712.wav"), + ("output_{meta.start:.2f}.wav", 1.2300001, "output_1.23.wav"), + ("output_{meta.start:.3f}.wav", 1.233712, "output_1.234.wav"), + ("output_{meta.start:.8f}.wav", 1.233712, "output_1.23371200.wav"), + ( "output_{meta.start}_{meta.end}_{duration}.wav", 1.455, "output_1.455_2.455_1.0.wav", ), - start_end_duration_2=( + ( "output_{meta.start}_{meta.end}_{duration}.wav", 1.455321, "output_1.455321_2.455321_1.0.wav", ), - ) - def test_save(self, format, start, expected): - with TemporaryDirectory() as tmpdir: - region = AudioRegion(b"0" * 160, 160, 1, 1) - meta = {"start": start, "end": start + region.duration} - region.meta = meta - format = os.path.join(tmpdir, format) - filename = region.save(format)[len(tmpdir) + 1 :] - self.assertEqual(filename, expected) + ], + ids=[ + "simple", + "start", + "start_2", + "start_3", + "start_4", + "start_5", + "start_end_duration", + "start_end_duration_2", + ], +) +def test_save(format, start, expected): + with TemporaryDirectory() as tmpdir: + region = AudioRegion(b"0" * 160, 160, 1, 1) + meta = {"start": start, "end": start + region.duration} + region.meta = meta + format = os.path.join(tmpdir, format) + filename = region.save(format)[len(tmpdir) + 1 :] + assert filename == expected - def test_save_file_exists_exception(self): - with TemporaryDirectory() as tmpdir: - filename = os.path.join(tmpdir, "output.wav") - open(filename, "w").close() - region = AudioRegion(b"0" * 160, 160, 1, 1) - with self.assertRaises(FileExistsError): - region.save(filename, exists_ok=False) - @genty_dataset( - first_half=( +def test_save_file_exists_exception(): + with TemporaryDirectory() as tmpdir: + filename = os.path.join(tmpdir, "output.wav") + open(filename, "w").close() + region = AudioRegion(b"0" * 160, 160, 1, 1) + with pytest.raises(FileExistsError): + region.save(filename, exists_ok=False) + + +@pytest.mark.parametrize( + "region, slice_, expected_data", + [ + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 500), b"a" * 80, ), - second_half=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(500, None), b"b" * 80, ), - second_half_negative=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-500, None), b"b" * 80, ), - middle=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(200, 750), b"a" * 48 + b"b" * 40, ), - middle_negative=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-800, -250), b"a" * 48 + b"b" * 40, ), - middle_sw2=( + ( AudioRegion(b"a" * 160 + b"b" * 160, 160, 2, 1), slice(200, 750), b"a" * 96 + b"b" * 80, ), - middle_ch2=( + ( AudioRegion(b"a" * 160 + b"b" * 160, 160, 1, 2), slice(200, 750), b"a" * 96 + b"b" * 80, ), - middle_sw2_ch2=( + ( AudioRegion(b"a" * 320 + b"b" * 320, 160, 2, 2), slice(200, 750), b"a" * 192 + b"b" * 160, ), - but_first_sample=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(1, None), b"a" * (4000 - 8) + b"b" * 4000, ), - but_first_sample_negative=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(-999, None), b"a" * (4000 - 8) + b"b" * 4000, ), - but_last_sample=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(0, 999), b"a" * 4000 + b"b" * (4000 - 8), ), - but_last_sample_negative=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(0, -1), b"a" * 4000 + b"b" * (4000 - 8), ), - big_negative_start=( - AudioRegion(b"a" * 160, 160, 1, 1), - slice(-5000, None), - b"a" * 160, - ), - big_negative_stop=( - AudioRegion(b"a" * 160, 160, 1, 1), - slice(None, -1500), - b"", - ), - empty=( - AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), - slice(0, 0), - b"", - ), - empty_start_stop_reversed=( - AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), - slice(200, 100), - b"", - ), - empty_big_positive_start=( - AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), - slice(2000, 3000), - b"", - ), - empty_negative_reversed=( - AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), - slice(-100, -200), - b"", - ), - empty_big_negative_stop=( - AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), - slice(0, -2000), - b"", - ), - arbitrary_sampling_rate=( + (AudioRegion(b"a" * 160, 160, 1, 1), slice(-5000, None), b"a" * 160), + (AudioRegion(b"a" * 160, 160, 1, 1), slice(None, -1500), b""), + (AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 0), b""), + (AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(200, 100), b""), + (AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(2000, 3000), b""), + (AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-100, -200), b""), + (AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, -2000), b""), + ( AudioRegion(b"a" * 124 + b"b" * 376, 1234, 1, 1), slice(100, 200), b"a" + b"b" * 123, ), - ) - def test_region_temporal_slicing(self, region, slice_, expected_data): - sub_region = region.millis[slice_] - self.assertEqual(bytes(sub_region), expected_data) - start_sec = slice_.start / 1000 if slice_.start is not None else None - stop_sec = slice_.stop / 1000 if slice_.stop is not None else None - sub_region = region.sec[start_sec:stop_sec] - self.assertEqual(bytes(sub_region), expected_data) + ], + ids=[ + "first_half", + "second_half", + "second_half_negative", + "middle", + "middle_negative", + "middle_sw2", + "middle_ch2", + "middle_sw2_ch2", + "but_first_sample", + "but_first_sample_negative", + "but_last_sample", + "but_last_sample_negative", + "big_negative_start", + "big_negative_stop", + "empty", + "empty_start_stop_reversed", + "empty_big_positive_start", + "empty_negative_reversed", + "empty_big_negative_stop", + "arbitrary_sampling_rate", + ], +) +def test_region_temporal_slicing(region, slice_, expected_data): + sub_region = region.millis[slice_] + assert bytes(sub_region) == expected_data + start_sec = slice_.start / 1000 if slice_.start is not None else None + stop_sec = slice_.stop / 1000 if slice_.stop is not None else None + sub_region = region.sec[start_sec:stop_sec] + assert bytes(sub_region) == expected_data - @genty_dataset( - first_half=( + +@pytest.mark.parametrize( + "region, slice_, time_shift, expected_data", + [ + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 80), 0, b"a" * 80, ), - second_half=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(80, None), 0.5, b"b" * 80, ), - second_half_negative=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-80, None), 0.5, b"b" * 80, ), - middle=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(160 // 5, 160 // 4 * 3), 0.2, b"a" * 48 + b"b" * 40, ), - middle_negative=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-160 // 5 * 4, -160 // 4), 0.2, b"a" * 48 + b"b" * 40, ), - middle_sw2=( + ( AudioRegion(b"a" * 160 + b"b" * 160, 160, 2, 1), slice(160 // 5, 160 // 4 * 3), 0.2, b"a" * 96 + b"b" * 80, ), - middle_ch2=( + ( AudioRegion(b"a" * 160 + b"b" * 160, 160, 1, 2), slice(160 // 5, 160 // 4 * 3), 0.2, b"a" * 96 + b"b" * 80, ), - middle_sw2_ch2=( + ( AudioRegion(b"a" * 320 + b"b" * 320, 160, 2, 2), slice(160 // 5, 160 // 4 * 3), 0.2, b"a" * 192 + b"b" * 160, ), - but_first_sample=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(1, None), 1 / 8000, b"a" * (4000 - 1) + b"b" * 4000, ), - but_first_sample_negative=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(-7999, None), 1 / 8000, b"a" * (4000 - 1) + b"b" * 4000, ), - but_last_sample=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(0, 7999), 0, b"a" * 4000 + b"b" * (4000 - 1), ), - but_last_sample_negative=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(0, -1), 0, b"a" * 4000 + b"b" * (4000 - 1), ), - big_negative_start=( - AudioRegion(b"a" * 160, 160, 1, 1), - slice(-1600, None), - 0, - b"a" * 160, - ), - big_negative_stop=( - AudioRegion(b"a" * 160, 160, 1, 1), - slice(None, -1600), - 0, - b"", - ), - empty=( - AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), - slice(0, 0), - 0, - b"", - ), - empty_start_stop_reversed=( + (AudioRegion(b"a" * 160, 160, 1, 1), slice(-1600, None), 0, b"a" * 160), + (AudioRegion(b"a" * 160, 160, 1, 1), slice(None, -1600), 0, b""), + (AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 0), 0, b""), + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(80, 40), 0.5, b"", ), - empty_big_positive_start=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(1600, 3000), 10, b"", ), - empty_negative_reversed=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-16, -32), 0.9, b"", ), - empty_big_negative_stop=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, -2000), 0, b"", ), - arbitrary_sampling_rate=( + ( AudioRegion(b"a" * 124 + b"b" * 376, 1235, 1, 1), slice(100, 200), 100 / 1235, b"a" * 24 + b"b" * 76, ), - arbitrary_sampling_rate_middle_sw2_ch2=( + ( AudioRegion(b"a" * 124 + b"b" * 376, 1235, 2, 2), slice(25, 50), 25 / 1235, b"a" * 24 + b"b" * 76, ), + ], + ids=[ + "first_half", + "second_half", + "second_half_negative", + "middle", + "middle_negative", + "middle_sw2", + "middle_ch2", + "middle_sw2_ch2", + "but_first_sample", + "but_first_sample_negative", + "but_last_sample", + "but_last_sample_negative", + "big_negative_start", + "big_negative_stop", + "empty", + "empty_start_stop_reversed", + "empty_big_positive_start", + "empty_negative_reversed", + "empty_big_negative_stop", + "arbitrary_sampling_rate", + "arbitrary_sampling_rate_middle_sw2_ch2", + ], +) +def test_region_sample_slicing(region, slice_, time_shift, expected_data): + sub_region = region[slice_] + assert bytes(sub_region) == expected_data + + +@pytest.mark.parametrize( + "sampling_rate, sample_width, channels", + [ + (8000, 1, 1), + (8000, 2, 2), + (5413, 2, 3), + ], + ids=[ + "simple", + "stereo_sw_2", + "arbitrary_sr_multichannel", + ], +) +def test_concatenation(sampling_rate, sample_width, channels): + + region_1, region_2 = _make_random_length_regions( + [b"a", b"b"], sampling_rate, sample_width, channels ) - def test_region_sample_slicing( - self, region, slice_, time_shift, expected_data - ): - sub_region = region[slice_] - self.assertEqual(bytes(sub_region), expected_data) + expected_duration = region_1.duration + region_2.duration + expected_data = bytes(region_1) + bytes(region_2) + concat_region = region_1 + region_2 + assert concat_region.duration == pytest.approx(expected_duration, abs=1e-6) + assert bytes(concat_region) == expected_data - @genty_dataset( - simple=(8000, 1, 1), - stereo_sw_2=(8000, 2, 2), - arbitrary_sr_multichannel=(5413, 2, 3), + +@pytest.mark.parametrize( + "sampling_rate, sample_width, channels", + [ + (8000, 1, 1), + (8000, 2, 2), + (5413, 2, 3), + ], + ids=[ + "simple", + "stereo_sw_2", + "arbitrary_sr_multichannel", + ], +) +def test_concatenation_many(sampling_rate, sample_width, channels): + + regions = _make_random_length_regions( + [b"a", b"b", b"c"], sampling_rate, sample_width, channels ) - def test_concatenation(self, sampling_rate, sample_width, channels): + expected_duration = sum(r.duration for r in regions) + expected_data = b"".join(bytes(r) for r in regions) + concat_region = sum(regions) - region_1, region_2 = _make_random_length_regions( - [b"a", b"b"], sampling_rate, sample_width, channels - ) - expected_duration = region_1.duration + region_2.duration - expected_data = bytes(region_1) + bytes(region_2) - concat_region = region_1 + region_2 - self.assertAlmostEqual( - concat_region.duration, expected_duration, places=6 - ) - self.assertEqual(bytes(concat_region), expected_data) + assert concat_region.duration == pytest.approx(expected_duration, abs=1e-6) + assert bytes(concat_region) == expected_data - @genty_dataset( - simple=(8000, 1, 1), - stereo_sw_2=(8000, 2, 2), - arbitrary_sr_multichannel=(5413, 2, 3), + +def test_concatenation_different_sampling_rate_error(): + + region_1 = AudioRegion(b"a" * 100, 8000, 1, 1) + region_2 = AudioRegion(b"b" * 100, 3000, 1, 1) + + with pytest.raises(ValueError) as val_err: + region_1 + region_2 + assert str(val_err.value) == ( + "Can only concatenate AudioRegions of the same " + "sampling rate (8000 != 3000)" ) - def test_concatenation_many(self, sampling_rate, sample_width, channels): - regions = _make_random_length_regions( - [b"a", b"b", b"c"], sampling_rate, sample_width, channels - ) - expected_duration = sum(r.duration for r in regions) - expected_data = b"".join(bytes(r) for r in regions) - concat_region = sum(regions) - self.assertAlmostEqual( - concat_region.duration, expected_duration, places=6 - ) - self.assertEqual(bytes(concat_region), expected_data) +def test_concatenation_different_sample_width_error(): - def test_concatenation_different_sampling_rate_error(self): + region_1 = AudioRegion(b"a" * 100, 8000, 2, 1) + region_2 = AudioRegion(b"b" * 100, 8000, 4, 1) - region_1 = AudioRegion(b"a" * 100, 8000, 1, 1) - region_2 = AudioRegion(b"b" * 100, 3000, 1, 1) + with pytest.raises(ValueError) as val_err: + region_1 + region_2 + assert str(val_err.value) == ( + "Can only concatenate AudioRegions of the same " "sample width (2 != 4)" + ) - with self.assertRaises(ValueError) as val_err: - region_1 + region_2 - self.assertEqual( - "Can only concatenate AudioRegions of the same " - "sampling rate (8000 != 3000)", - str(val_err.exception), - ) - def test_concatenation_different_sample_width_error(self): +def test_concatenation_different_number_of_channels_error(): - region_1 = AudioRegion(b"a" * 100, 8000, 2, 1) - region_2 = AudioRegion(b"b" * 100, 8000, 4, 1) + region_1 = AudioRegion(b"a" * 100, 8000, 1, 1) + region_2 = AudioRegion(b"b" * 100, 8000, 1, 2) - with self.assertRaises(ValueError) as val_err: - region_1 + region_2 - self.assertEqual( - "Can only concatenate AudioRegions of the same " - "sample width (2 != 4)", - str(val_err.exception), - ) + with pytest.raises(ValueError) as val_err: + region_1 + region_2 + assert str(val_err.value) == ( + "Can only concatenate AudioRegions of the same " + "number of channels (1 != 2)" + ) - def test_concatenation_different_number_of_channels_error(self): - region_1 = AudioRegion(b"a" * 100, 8000, 1, 1) - region_2 = AudioRegion(b"b" * 100, 8000, 1, 2) +@pytest.mark.parametrize( + "duration, expected_duration, expected_len, expected_len_ms", + [ + (0.01, 0.03, 240, 30), + (0.00575, 0.01725, 138, 17), + (0.00625, 0.01875, 150, 19), + ], + ids=[ + "simple", + "rounded_len_floor", + "rounded_len_ceil", + ], +) +def test_multiplication( + duration, expected_duration, expected_len, expected_len_ms +): + sw = 2 + data = b"0" * int(duration * 8000 * sw) + region = AudioRegion(data, 8000, sw, 1) + m_region = 1 * region * 3 + assert bytes(m_region) == data * 3 + assert m_region.sr == 8000 + assert m_region.sw == 2 + assert m_region.ch == 1 + assert m_region.duration == expected_duration + assert len(m_region) == expected_len + assert m_region.len == expected_len + assert m_region.s.len == expected_duration + assert len(m_region.ms) == expected_len_ms + assert m_region.ms.len == expected_len_ms - with self.assertRaises(ValueError) as val_err: - region_1 + region_2 - self.assertEqual( - "Can only concatenate AudioRegions of the same " - "number of channels (1 != 2)", - str(val_err.exception), - ) - @genty_dataset( - simple=(0.01, 0.03, 240, 30), - rounded_len_floor=(0.00575, 0.01725, 138, 17), - rounded_len_ceil=(0.00625, 0.01875, 150, 19), - ) - def test_multiplication( - self, duration, expected_duration, expected_len, expected_len_ms - ): - sw = 2 - data = b"0" * int(duration * 8000 * sw) - region = AudioRegion(data, 8000, sw, 1) - m_region = 1 * region * 3 - self.assertEqual(bytes(m_region), data * 3) - self.assertEqual(m_region.sr, 8000) - self.assertEqual(m_region.sw, 2) - self.assertEqual(m_region.ch, 1) - self.assertEqual(m_region.duration, expected_duration) - self.assertEqual(len(m_region), expected_len) - self.assertEqual(m_region.len, expected_len) - self.assertEqual(m_region.s.len, expected_duration) - self.assertEqual(len(m_region.ms), expected_len_ms) - self.assertEqual(m_region.ms.len, expected_len_ms) +@pytest.mark.parametrize( + "factor, _type", + [ + ("x", "str"), + (1.4, "float"), + ], + ids=[ + "_str", + "_float", + ], +) +def test_multiplication_non_int(factor, _type): + with pytest.raises(TypeError) as type_err: + AudioRegion(b"0" * 80, 8000, 1, 1) * factor + err_msg = "Can't multiply AudioRegion by a non-int of type '{}'" + assert err_msg.format(_type) == str(type_err.value) - @genty_dataset(_str=("x", "str"), _float=(1.4, "float")) - def test_multiplication_non_int(self, factor, _type): - with self.assertRaises(TypeError) as type_err: - AudioRegion(b"0" * 80, 8000, 1, 1) * factor - err_msg = "Can't multiply AudioRegion by a non-int of type '{}'" - self.assertEqual(err_msg.format(_type), str(type_err.exception)) - @genty_dataset( - simple=([b"a" * 80, b"b" * 80],), - extra_samples_1=([b"a" * 31, b"b" * 31, b"c" * 30],), - extra_samples_2=([b"a" * 31, b"b" * 30, b"c" * 30],), - extra_samples_3=([b"a" * 11, b"b" * 11, b"c" * 10, b"c" * 10],), - ) - def test_truediv(self, data): +@pytest.mark.parametrize( + "data", + [ + [b"a" * 80, b"b" * 80], + [b"a" * 31, b"b" * 31, b"c" * 30], + [b"a" * 31, b"b" * 30, b"c" * 30], + [b"a" * 11, b"b" * 11, b"c" * 10, b"c" * 10], + ], + ids=[ + "simple", + "extra_samples_1", + "extra_samples_2", + "extra_samples_3", + ], +) +def test_truediv(data): - region = AudioRegion(b"".join(data), 80, 1, 1) + region = AudioRegion(b"".join(data), 80, 1, 1) - sub_regions = region / len(data) - for data_i, region in zip(data, sub_regions): - self.assertEqual(len(data_i), len(bytes(region))) + sub_regions = region / len(data) + for data_i, region in zip(data, sub_regions): + assert len(data_i) == len(bytes(region)) - @genty_dataset( - mono_sw_1=(b"a" * 10, 1, 1, "b", [97] * 10), - mono_sw_2=(b"a" * 10, 2, 1, "h", [24929] * 5), - mono_sw_4=(b"a" * 8, 4, 1, "i", [1633771873] * 2), - stereo_sw_1=(b"ab" * 5, 1, 2, "b", [[97] * 5, [98] * 5]), - ) - def test_samples(self, data, sample_width, channels, fmt, expected): - region = AudioRegion(data, 10, sample_width, channels) - if isinstance(expected[0], list): - expected = [array_(fmt, exp) for exp in expected] - else: - expected = array_(fmt, expected) - samples = region.samples - equal = samples == expected - try: - # for numpy - equal = equal.all() - except AttributeError: - pass - self.assertTrue(equal) +@pytest.mark.parametrize( + "data, sample_width, channels, fmt, expected", + [ + (b"a" * 10, 1, 1, "b", [97] * 10), + (b"a" * 10, 2, 1, "h", [24929] * 5), + (b"a" * 8, 4, 1, "i", [1633771873] * 2), + (b"ab" * 5, 1, 2, "b", [[97] * 5, [98] * 5]), + ], + ids=[ + "mono_sw_1", + "mono_sw_2", + "mono_sw_4", + "stereo_sw_1", + ], +) +def test_samples(data, sample_width, channels, fmt, expected): - -if __name__ == "__main__": - unittest.main() + region = AudioRegion(data, 10, sample_width, channels) + if isinstance(expected[0], list): + expected = [array_(fmt, exp) for exp in expected] + else: + expected = array_(fmt, expected) + samples = region.samples + equal = samples == expected + try: + # for numpy + equal = equal.all() + except AttributeError: + pass + assert equal
--- a/tests/test_io.py Fri May 24 21:30:34 2024 +0200 +++ b/tests/test_io.py Sat May 25 21:54:13 2024 +0200 @@ -4,10 +4,8 @@ from array import array from tempfile import NamedTemporaryFile, TemporaryDirectory import filecmp -import unittest -from unittest import TestCase +import pytest from unittest.mock import patch, Mock -from genty import genty, genty_dataset from test_util import _sample_generator, _generate_pure_tone, PURE_TONE_DICT from auditok.signal import FORMAT from auditok.io import ( @@ -34,413 +32,604 @@ AUDIO_PARAMS_SHORT = {"sr": 16000, "sw": 2, "ch": 1} -@genty -class TestIO(TestCase): - @genty_dataset( - valid_mono=(b"\0" * 113, 1, 1), - valid_stereo=(b"\0" * 160, 1, 2), - invalid_mono_sw_2=(b"\0" * 113, 2, 1, False), - invalid_stereo_sw_1=(b"\0" * 113, 1, 2, False), - invalid_stereo_sw_2=(b"\0" * 158, 2, 2, False), +@pytest.mark.parametrize( + "data, sample_width, channels, valid", + [ + (b"\0" * 113, 1, 1, True), # valid_mono + (b"\0" * 160, 1, 2, True), # valid_stereo + (b"\0" * 113, 2, 1, False), # invalid_mono_sw_2 + (b"\0" * 113, 1, 2, False), # invalid_stereo_sw_1 + (b"\0" * 158, 2, 2, False), # invalid_stereo_sw_2 + ], + ids=[ + "valid_mono", + "valid_stereo", + "invalid_mono_sw_2", + "invalid_stereo_sw_1", + "invalid_stereo_sw_2", + ], +) +def test_check_audio_data(data, sample_width, channels, valid): + if not valid: + with pytest.raises(AudioParameterError): + check_audio_data(data, sample_width, channels) + else: + assert check_audio_data(data, sample_width, channels) is None + + +@pytest.mark.parametrize( + "fmt, filename, expected", + [ + ("wav", "filename.wav", "wav"), # extention_and_format_same + ("wav", "filename.mp3", "wav"), # extention_and_format_different + (None, "filename.wav", "wav"), # extention_no_format + ("wav", "filename", "wav"), # format_no_extension + (None, "filename", None), # no_format_no_extension + ("wave", "filename", "wav"), # wave_as_wav + (None, "filename.wave", "wav"), # wave_as_wav_extension + ], + ids=[ + "extention_and_format_same", + "extention_and_format_different", + "extention_no_format", + "format_no_extension", + "no_format_no_extension", + "wave_as_wav", + "wave_as_wav_extension", + ], +) +def test_guess_audio_format(fmt, filename, expected): + result = _guess_audio_format(fmt, filename) + assert result == expected + + +def test_get_audio_parameters_short_params(): + expected = (8000, 2, 1) + params = dict(zip(("sr", "sw", "ch"), expected)) + result = _get_audio_parameters(params) + assert result == expected + + +def test_get_audio_parameters_long_params(): + expected = (8000, 2, 1) + params = dict( + zip( + ("sampling_rate", "sample_width", "channels", "use_channel"), + expected, + ) ) - def test_check_audio_data(self, data, sample_width, channels, valid=True): + result = _get_audio_parameters(params) + assert result == expected - if not valid: - with self.assertRaises(AudioParameterError): - check_audio_data(data, sample_width, channels) - else: - self.assertIsNone(check_audio_data(data, sample_width, channels)) - @genty_dataset( - extention_and_format_same=("wav", "filename.wav", "wav"), - extention_and_format_different=("wav", "filename.mp3", "wav"), - extention_no_format=(None, "filename.wav", "wav"), - format_no_extension=("wav", "filename", "wav"), - no_format_no_extension=(None, "filename", None), - wave_as_wav=("wave", "filename", "wav"), - wave_as_wav_extension=(None, "filename.wave", "wav"), - ) - def test_guess_audio_format(self, fmt, filename, expected): - result = _guess_audio_format(fmt, filename) - self.assertEqual(result, expected) +def test_get_audio_parameters_long_params_shadow_short_ones(): + expected = (8000, 2, 1) + params = dict(zip(("sampling_rate", "sample_width", "channels"), expected)) + params.update(dict(zip(("sr", "sw", "ch"), "xxx"))) + result = _get_audio_parameters(params) + assert result == expected - def test_get_audio_parameters_short_params(self): - expected = (8000, 2, 1) - params = dict(zip(("sr", "sw", "ch"), expected)) - result = _get_audio_parameters(params) - self.assertEqual(result, expected) - def test_get_audio_parameters_long_params(self): - expected = (8000, 2, 1) - params = dict( - zip( - ("sampling_rate", "sample_width", "channels", "use_channel"), - expected, - ) - ) - result = _get_audio_parameters(params) - self.assertEqual(result, expected) +@pytest.mark.parametrize( + "values", + [ + ("x", 2, 1), # str_sampling_rate + (-8000, 2, 1), # negative_sampling_rate + (8000, "x", 1), # str_sample_width + (8000, -2, 1), # negative_sample_width + (8000, 2, "x"), # str_channels + (8000, 2, -1), # negative_channels + ], + ids=[ + "str_sampling_rate", + "negative_sampling_rate", + "str_sample_width", + "negative_sample_width", + "str_channels", + "negative_channels", + ], +) +def test_get_audio_parameters_invalid(values): + params = dict(zip(("sampling_rate", "sample_width", "channels"), values)) + with pytest.raises(AudioParameterError): + _get_audio_parameters(params) - def test_get_audio_parameters_long_params_shadow_short_ones(self): - expected = (8000, 2, 1) - params = dict( - zip(("sampling_rate", "sample_width", "channels"), expected) - ) - params.update(dict(zip(("sr", "sw", "ch"), "xxx"))) - result = _get_audio_parameters(params) - self.assertEqual(result, expected) - @genty_dataset( - str_sampling_rate=(("x", 2, 1),), - negative_sampling_rate=((-8000, 2, 1),), - str_sample_width=((8000, "x", 1),), - negative_sample_width=((8000, -2, 1),), - str_channels=((8000, 2, "x"),), - negative_channels=((8000, 2, -1),), - ) - def test_get_audio_parameters_invalid(self, values): - params = dict( - zip(("sampling_rate", "sample_width", "channels"), values) - ) - with self.assertRaises(AudioParameterError): - _get_audio_parameters(params) - - @genty_dataset( - raw_with_audio_format=( +@pytest.mark.parametrize( + "filename, audio_format, funtion_name, kwargs", + [ + ( "audio", "raw", "_load_raw", AUDIO_PARAMS_SHORT, - ), - raw_with_extension=( + ), # raw_with_audio_format + ( "audio.raw", None, "_load_raw", AUDIO_PARAMS_SHORT, - ), - wave_with_audio_format=("audio", "wave", "_load_wave"), - wav_with_audio_format=("audio", "wave", "_load_wave"), - wav_with_extension=("audio.wav", None, "_load_wave"), - format_and_extension_both_given=("audio.dat", "wav", "_load_wave"), - format_and_extension_both_given_b=("audio.raw", "wave", "_load_wave"), - no_format_nor_extension=("audio", None, "_load_with_pydub"), - other_formats_ogg=("audio.ogg", None, "_load_with_pydub"), - other_formats_webm=("audio", "webm", "_load_with_pydub"), + ), # raw_with_extension + ("audio", "wave", "_load_wave", None), # wave_with_audio_format + ("audio", "wave", "_load_wave", None), # wav_with_audio_format + ("audio.wav", None, "_load_wave", None), # wav_with_extension + ( + "audio.dat", + "wav", + "_load_wave", + None, + ), # format_and_extension_both_given_a + ( + "audio.raw", + "wave", + "_load_wave", + None, + ), # format_and_extension_both_given_b + ("audio", None, "_load_with_pydub", None), # no_format_nor_extension + ("audio.ogg", None, "_load_with_pydub", None), # other_formats_ogg + ("audio", "webm", "_load_with_pydub", None), # other_formats_webm + ], + ids=[ + "raw_with_audio_format", + "raw_with_extension", + "wave_with_audio_format", + "wav_with_audio_format", + "wav_with_extension", + "format_and_extension_both_given_a", + "format_and_extension_both_given_b", + "no_format_nor_extension", + "other_formats_ogg", + "other_formats_webm", + ], +) +def test_from_file(filename, audio_format, funtion_name, kwargs): + funtion_name = "auditok.io." + funtion_name + if kwargs is None: + kwargs = {} + with patch(funtion_name) as patch_function: + from_file(filename, audio_format, **kwargs) + assert patch_function.called + + +def test_from_file_large_file_raw(): + filename = "tests/data/test_16KHZ_mono_400Hz.raw" + audio_source = from_file( + filename, + large_file=True, + sampling_rate=16000, + sample_width=2, + channels=1, ) - def test_from_file( - self, filename, audio_format, funtion_name, kwargs=None - ): - funtion_name = "auditok.io." + funtion_name - if kwargs is None: - kwargs = {} - with patch(funtion_name) as patch_function: - from_file(filename, audio_format, **kwargs) - self.assertTrue(patch_function.called) + assert isinstance(audio_source, RawAudioSource) - def test_from_file_large_file_raw(self,): - filename = "tests/data/test_16KHZ_mono_400Hz.raw" - audio_source = from_file( - filename, - large_file=True, - sampling_rate=16000, - sample_width=2, - channels=1, - ) - self.assertIsInstance(audio_source, RawAudioSource) - def test_from_file_large_file_wave(self,): - filename = "tests/data/test_16KHZ_mono_400Hz.wav" - audio_source = from_file(filename, large_file=True) - self.assertIsInstance(audio_source, WaveAudioSource) +def test_from_file_large_file_wave(): + filename = "tests/data/test_16KHZ_mono_400Hz.wav" + audio_source = from_file(filename, large_file=True) + assert isinstance(audio_source, WaveAudioSource) - def test_from_file_large_file_compressed(self,): - filename = "tests/data/test_16KHZ_mono_400Hz.ogg" - with self.assertRaises(AudioIOError): - from_file(filename, large_file=True) - @genty_dataset( - missing_sampling_rate=("sr",), - missing_sample_width=("sw",), - missing_channels=("ch",), - ) - def test_from_file_missing_audio_param(self, missing_param): - with self.assertRaises(AudioParameterError): - params = AUDIO_PARAMS_SHORT.copy() - del params[missing_param] - from_file("audio", audio_format="raw", **params) +def test_from_file_large_file_compressed(): + filename = "tests/data/test_16KHZ_mono_400Hz.ogg" + with pytest.raises(AudioIOError): + from_file(filename, large_file=True) - def test_from_file_no_pydub(self): - with patch("auditok.io._WITH_PYDUB", False): - with self.assertRaises(AudioIOError): - from_file("audio", "mp3") - @patch("auditok.io._WITH_PYDUB", True) - @patch("auditok.io.BufferAudioSource") - @genty_dataset( - ogg_first_channel=("ogg", "from_ogg"), - ogg_second_channel=("ogg", "from_ogg"), - ogg_mix=("ogg", "from_ogg"), - ogg_default=("ogg", "from_ogg"), - mp3_left_channel=("mp3", "from_mp3"), - mp3_right_channel=("mp3", "from_mp3"), - flac_first_channel=("flac", "from_file"), - flac_second_channel=("flac", "from_file"), - flv_left_channel=("flv", "from_flv"), - webm_right_channel=("webm", "from_file"), - ) - def test_from_file_multichannel_audio_compressed( - self, audio_format, function, *mocks - ): - filename = "audio.{}".format(audio_format) - segment_mock = Mock() - segment_mock.sample_width = 2 - segment_mock.channels = 2 - segment_mock._data = b"abcd" - with patch("auditok.io.AudioSegment.{}".format(function)) as open_func: - open_func.return_value = segment_mock - from_file(filename) - self.assertTrue(open_func.called) +@pytest.mark.parametrize( + "missing_param", + [ + "sr", # missing_sampling_rate + "sw", # missing_sample_width + "ch", # missing_channels + ], + ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"], +) +def test_from_file_missing_audio_param(missing_param): + with pytest.raises(AudioParameterError): + params = AUDIO_PARAMS_SHORT.copy() + del params[missing_param] + from_file("audio", audio_format="raw", **params) - @genty_dataset( - mono=("mono_400", (400,)), - three_channel=("3channel_400-800-1600", (400, 800, 1600)), - mono_large_file=("mono_400", (400,), True), - three_channel_large_file=( + +def test_from_file_no_pydub(): + with patch("auditok.io._WITH_PYDUB", False): + with pytest.raises(AudioIOError): + from_file("audio", "mp3") + + +@pytest.mark.parametrize( + "audio_format, function", + [ + ("ogg", "from_ogg"), # ogg_first_channel + ("ogg", "from_ogg"), # ogg_second_channel + ("ogg", "from_ogg"), # ogg_mix + ("ogg", "from_ogg"), # ogg_default + ("mp3", "from_mp3"), # mp3_left_channel + ("mp3", "from_mp3"), # mp3_right_channel + ("flac", "from_file"), # flac_first_channel + ("flac", "from_file"), # flac_second_channel + ("flv", "from_flv"), # flv_left_channel + ("webm", "from_file"), # webm_right_channel + ], + ids=[ + "ogg_first_channel", + "ogg_second_channel", + "ogg_mix", + "ogg_default", + "mp3_left_channel", + "mp3_right_channel", + "flac_first_channel", + "flac_second_channel", + "flv_left_channel", + "webm_right_channel", + ], +) +@patch("auditok.io._WITH_PYDUB", True) +@patch("auditok.io.BufferAudioSource") +def test_from_file_multichannel_audio_compressed( + mock_buffer_audio_source, audio_format, function +): + filename = "audio.{}".format(audio_format) + segment_mock = Mock() + segment_mock.sample_width = 2 + segment_mock.channels = 2 + segment_mock._data = b"abcd" + with patch("auditok.io.AudioSegment.{}".format(function)) as open_func: + open_func.return_value = segment_mock + from_file(filename) + assert open_func.called + + +@pytest.mark.parametrize( + "file_id, frequencies, large_file", + [ + ("mono_400", (400,), False), # mono + ("3channel_400-800-1600", (400, 800, 1600), False), # three_channel + ("mono_400", (400,), True), # mono_large_file + ( "3channel_400-800-1600", (400, 800, 1600), True, - ), + ), # three_channel_large_file + ], + ids=[ + "mono", + "three_channel", + "mono_large_file", + "three_channel_large_file", + ], +) +def test_load_raw(file_id, frequencies, large_file): + filename = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + audio_source = _load_raw( + filename, 16000, 2, len(frequencies), large_file=large_file ) - def test_load_raw(self, file_id, frequencies, large_file=False): - filename = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - audio_source = _load_raw( - filename, 16000, 2, len(frequencies), large_file=large_file - ) - audio_source.open() - data = audio_source.read(-1) - audio_source.close() - expected_class = RawAudioSource if large_file else BufferAudioSource - self.assertIsInstance(audio_source, expected_class) - self.assertEqual(audio_source.sampling_rate, 16000) - self.assertEqual(audio_source.sample_width, 2) - self.assertEqual(audio_source.channels, len(frequencies)) - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - fmt = FORMAT[audio_source.sample_width] - expected = array(fmt, _sample_generator(*mono_channels)).tobytes() - self.assertEqual(data, expected) + audio_source.open() + data = audio_source.read(-1) + audio_source.close() + expected_class = RawAudioSource if large_file else BufferAudioSource + assert isinstance(audio_source, expected_class) + assert audio_source.sampling_rate == 16000 + assert audio_source.sample_width == 2 + assert audio_source.channels == len(frequencies) + mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] + fmt = FORMAT[audio_source.sample_width] + expected = array(fmt, _sample_generator(*mono_channels)).tobytes() + assert data == expected - @genty_dataset( - missing_sampling_rate=("sr",), - missing_sample_width=("sw",), - missing_channels=("ch",), - ) - def test_load_raw_missing_audio_param(self, missing_param): - with self.assertRaises(AudioParameterError): - params = AUDIO_PARAMS_SHORT.copy() - del params[missing_param] - srate, swidth, channels, _ = _get_audio_parameters(params) - _load_raw("audio", srate, swidth, channels) - @genty_dataset( - mono=("mono_400", (400,)), - three_channel=("3channel_400-800-1600", (400, 800, 1600)), - mono_large_file=("mono_400", (400,), True), - three_channel_large_file=( +@pytest.mark.parametrize( + "missing_param", + [ + "sr", # missing_sampling_rate + "sw", # missing_sample_width + "ch", # missing_channels + ], + ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"], +) +def test_load_raw_missing_audio_param(missing_param): + with pytest.raises(AudioParameterError): + params = AUDIO_PARAMS_SHORT.copy() + del params[missing_param] + srate, swidth, channels, _ = _get_audio_parameters(params) + _load_raw("audio", srate, swidth, channels) + + +@pytest.mark.parametrize( + "file_id, frequencies, large_file", + [ + ("mono_400", (400,), False), # mono + ("3channel_400-800-1600", (400, 800, 1600), False), # three_channel + ("mono_400", (400,), True), # mono_large_file + ( "3channel_400-800-1600", (400, 800, 1600), True, - ), + ), # three_channel_large_file + ], + ids=[ + "mono", + "three_channel", + "mono_large_file", + "three_channel_large_file", + ], +) +def test_load_wave(file_id, frequencies, large_file): + filename = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + audio_source = _load_wave(filename, large_file=large_file) + audio_source.open() + data = audio_source.read(-1) + audio_source.close() + expected_class = WaveAudioSource if large_file else BufferAudioSource + assert isinstance(audio_source, expected_class) + assert audio_source.sampling_rate == 16000 + assert audio_source.sample_width == 2 + assert audio_source.channels == len(frequencies) + mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] + fmt = FORMAT[audio_source.sample_width] + expected = array(fmt, _sample_generator(*mono_channels)).tobytes() + assert data == expected + + +@pytest.mark.parametrize( + "audio_format, channels, function", + [ + ("ogg", 2, "from_ogg"), # ogg_default_first_channel + ("ogg", 1, "from_ogg"), # ogg_first_channel + ("ogg", 2, "from_ogg"), # ogg_second_channel + ("ogg", 3, "from_ogg"), # ogg_mix_channels + ("mp3", 1, "from_mp3"), # mp3_left_channel + ("mp3", 2, "from_mp3"), # mp3_right_channel + ("mp3", 3, "from_mp3"), # mp3_mix_channels + ("flac", 2, "from_file"), # flac_first_channel + ("flac", 2, "from_file"), # flac_second_channel + ("flv", 1, "from_flv"), # flv_left_channel + ("webm", 2, "from_file"), # webm_right_channel + ("webm", 4, "from_file"), # webm_mix_channels + ], + ids=[ + "ogg_default_first_channel", + "ogg_first_channel", + "ogg_second_channel", + "ogg_mix_channels", + "mp3_left_channel", + "mp3_right_channel", + "mp3_mix_channels", + "flac_first_channel", + "flac_second_channel", + "flv_left_channel", + "webm_right_channel", + "webm_mix_channels", + ], +) +@patch("auditok.io._WITH_PYDUB", True) +@patch("auditok.io.BufferAudioSource") +def test_load_with_pydub( + mock_buffer_audio_source, audio_format, channels, function +): + filename = "audio.{}".format(audio_format) + segment_mock = Mock() + segment_mock.sample_width = 2 + segment_mock.channels = channels + segment_mock._data = b"abcdefgh" + with patch("auditok.io.AudioSegment.{}".format(function)) as open_func: + open_func.return_value = segment_mock + _load_with_pydub(filename, audio_format) + assert open_func.called + + +@pytest.mark.parametrize( + "filename, frequencies", + [ + ("mono_400Hz.raw", (400,)), # mono + ("3channel_400-800-1600Hz.raw", (400, 800, 1600)), # three_channel + ], + ids=["mono", "three_channel"], +) +def test_save_raw(filename, frequencies): + filename = "tests/data/test_16KHZ_{}".format(filename) + sample_width = 2 + fmt = FORMAT[sample_width] + mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] + data = array(fmt, _sample_generator(*mono_channels)).tobytes() + tmpfile = NamedTemporaryFile() + _save_raw(data, tmpfile.name) + assert filecmp.cmp(tmpfile.name, filename, shallow=False) + + +@pytest.mark.parametrize( + "filename, frequencies", + [ + ("mono_400Hz.wav", (400,)), # mono + ("3channel_400-800-1600Hz.wav", (400, 800, 1600)), # three_channel + ], + ids=["mono", "three_channel"], +) +def test_save_wave(filename, frequencies): + filename = "tests/data/test_16KHZ_{}".format(filename) + sampling_rate = 16000 + sample_width = 2 + channels = len(frequencies) + fmt = FORMAT[sample_width] + mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] + data = array(fmt, _sample_generator(*mono_channels)).tobytes() + tmpfile = NamedTemporaryFile() + _save_wave(data, tmpfile.name, sampling_rate, sample_width, channels) + assert filecmp.cmp(tmpfile.name, filename, shallow=False) + + +@pytest.mark.parametrize( + "missing_param", + [ + "sr", # missing_sampling_rate + "sw", # missing_sample_width + "ch", # missing_channels + ], + ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"], +) +def test_save_wave_missing_audio_param(missing_param): + with pytest.raises(AudioParameterError): + params = AUDIO_PARAMS_SHORT.copy() + del params[missing_param] + srate, swidth, channels, _ = _get_audio_parameters(params) + _save_wave(b"\0\0", "audio", srate, swidth, channels) + + +def test_save_with_pydub(): + with patch("auditok.io.AudioSegment.export") as export: + tmpdir = TemporaryDirectory() + filename = os.path.join(tmpdir.name, "audio.ogg") + _save_with_pydub(b"\0\0", filename, "ogg", 16000, 2, 1) + assert export.called + tmpdir.cleanup() + + +@pytest.mark.parametrize( + "filename, audio_format", + [ + ("audio", "raw"), # raw_with_audio_format + ("audio.raw", None), # raw_with_extension + ("audio.mp3", "raw"), # raw_with_audio_format_and_extension + ("audio", None), # raw_no_audio_format_nor_extension + ], + ids=[ + "raw_with_audio_format", + "raw_with_extension", + "raw_with_audio_format_and_extension", + "raw_no_audio_format_nor_extension", + ], +) +def test_to_file_raw(filename, audio_format): + exp_filename = "tests/data/test_16KHZ_mono_400Hz.raw" + tmpdir = TemporaryDirectory() + filename = os.path.join(tmpdir.name, filename) + data = PURE_TONE_DICT[400].tobytes() + to_file(data, filename, audio_format=audio_format) + assert filecmp.cmp(filename, exp_filename, shallow=False) + tmpdir.cleanup() + + +@pytest.mark.parametrize( + "filename, audio_format", + [ + ("audio", "wav"), # wav_with_audio_format + ("audio.wav", None), # wav_with_extension + ("audio.mp3", "wav"), # wav_with_audio_format_and_extension + ("audio", "wave"), # wave_with_audio_format + ("audio.wave", None), # wave_with_extension + ("audio.mp3", "wave"), # wave_with_audio_format_and_extension + ], + ids=[ + "wav_with_audio_format", + "wav_with_extension", + "wav_with_audio_format_and_extension", + "wave_with_audio_format", + "wave_with_extension", + "wave_with_audio_format_and_extension", + ], +) +def test_to_file_wave(filename, audio_format): + exp_filename = "tests/data/test_16KHZ_mono_400Hz.wav" + tmpdir = TemporaryDirectory() + filename = os.path.join(tmpdir.name, filename) + data = PURE_TONE_DICT[400].tobytes() + to_file( + data, + filename, + audio_format=audio_format, + sampling_rate=16000, + sample_width=2, + channels=1, ) - def test_load_wave(self, file_id, frequencies, large_file=False): - filename = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - audio_source = _load_wave(filename, large_file=large_file) - audio_source.open() - data = audio_source.read(-1) - audio_source.close() - expected_class = WaveAudioSource if large_file else BufferAudioSource - self.assertIsInstance(audio_source, expected_class) - self.assertEqual(audio_source.sampling_rate, 16000) - self.assertEqual(audio_source.sample_width, 2) - self.assertEqual(audio_source.channels, len(frequencies)) - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - fmt = FORMAT[audio_source.sample_width] - expected = array(fmt, _sample_generator(*mono_channels)).tobytes() - self.assertEqual(data, expected) + assert filecmp.cmp(filename, exp_filename, shallow=False) + tmpdir.cleanup() - @patch("auditok.io._WITH_PYDUB", True) - @patch("auditok.io.BufferAudioSource") - @genty_dataset( - ogg_default_first_channel=("ogg", 2, "from_ogg"), - ogg_first_channel=("ogg", 1, "from_ogg"), - ogg_second_channel=("ogg", 2, "from_ogg"), - ogg_mix_channels=("ogg", 3, "from_ogg"), - mp3_left_channel=("mp3", 1, "from_mp3"), - mp3_right_channel=("mp3", 2, "from_mp3"), - mp3_mix_channels=("mp3", 3, "from_mp3"), - flac_first_channel=("flac", 2, "from_file"), - flac_second_channel=("flac", 2, "from_file"), - flv_left_channel=("flv", 1, "from_flv"), - webm_right_channel=("webm", 2, "from_file"), - webm_mix_channels=("webm", 4, "from_file"), - ) - def test_load_with_pydub(self, audio_format, channels, function, *mocks): - filename = "audio.{}".format(audio_format) - segment_mock = Mock() - segment_mock.sample_width = 2 - segment_mock.channels = channels - segment_mock._data = b"abcdefgh" - with patch("auditok.io.AudioSegment.{}".format(function)) as open_func: - open_func.return_value = segment_mock - _load_with_pydub(filename, audio_format) - self.assertTrue(open_func.called) - @genty_dataset( - mono=("mono_400Hz.raw", (400,)), - three_channel=("3channel_400-800-1600Hz.raw", (400, 800, 1600)), - ) - def test_save_raw(self, filename, frequencies): - filename = "tests/data/test_16KHZ_{}".format(filename) - sample_width = 2 - fmt = FORMAT[sample_width] - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - data = array(fmt, _sample_generator(*mono_channels)).tobytes() - tmpfile = NamedTemporaryFile() - _save_raw(data, tmpfile.name) - self.assertTrue(filecmp.cmp(tmpfile.name, filename, shallow=False)) +@pytest.mark.parametrize( + "missing_param", + [ + "sr", # missing_sampling_rate + "sw", # missing_sample_width + "ch", # missing_channels + ], + ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"], +) +def test_to_file_missing_audio_param(missing_param): + params = AUDIO_PARAMS_SHORT.copy() + del params[missing_param] + with pytest.raises(AudioParameterError): + to_file(b"\0\0", "audio", audio_format="wav", **params) + with pytest.raises(AudioParameterError): + to_file(b"\0\0", "audio", audio_format="mp3", **params) - @genty_dataset( - mono=("mono_400Hz.wav", (400,)), - three_channel=("3channel_400-800-1600Hz.wav", (400, 800, 1600)), - ) - def test_save_wave(self, filename, frequencies): - filename = "tests/data/test_16KHZ_{}".format(filename) - sampling_rate = 16000 - sample_width = 2 - channels = len(frequencies) - fmt = FORMAT[sample_width] - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - data = array(fmt, _sample_generator(*mono_channels)).tobytes() - tmpfile = NamedTemporaryFile() - _save_wave(data, tmpfile.name, sampling_rate, sample_width, channels) - self.assertTrue(filecmp.cmp(tmpfile.name, filename, shallow=False)) - @genty_dataset( - missing_sampling_rate=("sr",), - missing_sample_width=("sw",), - missing_channels=("ch",), - ) - def test_save_wave_missing_audio_param(self, missing_param): - with self.assertRaises(AudioParameterError): - params = AUDIO_PARAMS_SHORT.copy() - del params[missing_param] - srate, swidth, channels, _ = _get_audio_parameters(params) - _save_wave(b"\0\0", "audio", srate, swidth, channels) +def test_to_file_no_pydub(): + with patch("auditok.io._WITH_PYDUB", False): + with pytest.raises(AudioIOError): + to_file("audio", b"", "mp3") - def test_save_with_pydub(self): - with patch("auditok.io.AudioSegment.export") as export: - tmpdir = TemporaryDirectory() - filename = os.path.join(tmpdir.name, "audio.ogg") - _save_with_pydub(b"\0\0", filename, "ogg", 16000, 2, 1) - self.assertTrue(export.called) - tmpdir.cleanup() - @genty_dataset( - raw_with_audio_format=("audio", "raw"), - raw_with_extension=("audio.raw", None), - raw_with_audio_format_and_extension=("audio.mp3", "raw"), - raw_no_audio_format_nor_extension=("audio", None), - ) - def test_to_file_raw(self, filename, audio_format): - exp_filename = "tests/data/test_16KHZ_mono_400Hz.raw" +@pytest.mark.parametrize( + "filename, audio_format", + [ + ("audio.ogg", None), # ogg_with_extension + ("audio", "ogg"), # ogg_with_audio_format + ("audio.wav", "ogg"), # ogg_format_with_wrong_extension + ], + ids=[ + "ogg_with_extension", + "ogg_with_audio_format", + "ogg_format_with_wrong_extension", + ], +) +@patch("auditok.io._WITH_PYDUB", True) +def test_to_file_compressed(filename, audio_format): + with patch("auditok.io.AudioSegment.export") as export: tmpdir = TemporaryDirectory() filename = os.path.join(tmpdir.name, filename) - data = PURE_TONE_DICT[400].tobytes() - to_file(data, filename, audio_format=audio_format) - self.assertTrue(filecmp.cmp(filename, exp_filename, shallow=False)) + to_file(b"\0\0", filename, audio_format, **AUDIO_PARAMS_SHORT) + assert export.called tmpdir.cleanup() - @genty_dataset( - wav_with_audio_format=("audio", "wav"), - wav_with_extension=("audio.wav", None), - wav_with_audio_format_and_extension=("audio.mp3", "wav"), - wave_with_audio_format=("audio", "wave"), - wave_with_extension=("audio.wave", None), - wave_with_audio_format_and_extension=("audio.mp3", "wave"), - ) - def test_to_file_wave(self, filename, audio_format): - exp_filename = "tests/data/test_16KHZ_mono_400Hz.wav" - tmpdir = TemporaryDirectory() - filename = os.path.join(tmpdir.name, filename) - data = PURE_TONE_DICT[400].tobytes() - to_file( - data, - filename, - audio_format=audio_format, - sampling_rate=16000, - sample_width=2, - channels=1, - ) - self.assertTrue(filecmp.cmp(filename, exp_filename, shallow=False)) - tmpdir.cleanup() - @genty_dataset( - missing_sampling_rate=("sr",), - missing_sample_width=("sw",), - missing_channels=("ch",), - ) - def test_to_file_missing_audio_param(self, missing_param): - params = AUDIO_PARAMS_SHORT.copy() - del params[missing_param] - with self.assertRaises(AudioParameterError): - to_file(b"\0\0", "audio", audio_format="wav", **params) - with self.assertRaises(AudioParameterError): - to_file(b"\0\0", "audio", audio_format="mp3", **params) - - def test_to_file_no_pydub(self): - with patch("auditok.io._WITH_PYDUB", False): - with self.assertRaises(AudioIOError): - to_file("audio", b"", "mp3") - - @patch("auditok.io._WITH_PYDUB", True) - @genty_dataset( - ogg_with_extension=("audio.ogg", None), - ogg_with_audio_format=("audio", "ogg"), - ogg_format_with_wrong_extension=("audio.wav", "ogg"), - ) - def test_to_file_compressed(self, filename, audio_format, *mocks): - with patch("auditok.io.AudioSegment.export") as export: - tmpdir = TemporaryDirectory() - filename = os.path.join(tmpdir.name, filename) - to_file(b"\0\0", filename, audio_format, **AUDIO_PARAMS_SHORT) - self.assertTrue(export.called) - tmpdir.cleanup() - - @genty_dataset( - string_wave=( +@pytest.mark.parametrize( + "input, expected_type, extra_args", + [ + ( "tests/data/test_16KHZ_mono_400Hz.wav", BufferAudioSource, - ), - string_wave_large_file=( + None, + ), # string_wave + ( "tests/data/test_16KHZ_mono_400Hz.wav", WaveAudioSource, {"large_file": True}, - ), - stdin=("-", StdinAudioSource), - string_raw=("tests/data/test_16KHZ_mono_400Hz.raw", BufferAudioSource), - string_raw_large_file=( + ), # string_wave_large_file + ("-", StdinAudioSource, None), # stdin + ( + "tests/data/test_16KHZ_mono_400Hz.raw", + BufferAudioSource, + None, + ), # string_raw + ( "tests/data/test_16KHZ_mono_400Hz.raw", RawAudioSource, {"large_file": True}, - ), - bytes_=(b"0" * 8000, BufferAudioSource), - ) - def test_get_audio_source(self, input, expected_type, extra_args=None): - kwargs = {"sampling_rate": 16000, "sample_width": 2, "channels": 1} - if extra_args is not None: - kwargs.update(extra_args) - audio_source = get_audio_source(input, **kwargs) - self.assertIsInstance(audio_source, expected_type) - - -if __name__ == "__main__": - unittest.main() + ), # string_raw_large_file + (b"0" * 8000, BufferAudioSource, None), # bytes_ + ], + ids=[ + "string_wave", + "string_wave_large_file", + "stdin", + "string_raw", + "string_raw_large_file", + "bytes_", + ], +) +def test_get_audio_source(input, expected_type, extra_args): + kwargs = {"sampling_rate": 16000, "sample_width": 2, "channels": 1} + if extra_args is not None: + kwargs.update(extra_args) + audio_source = get_audio_source(input, **kwargs) + assert isinstance(audio_source, expected_type)
--- a/tests/test_plotting.py Fri May 24 21:30:34 2024 +0200 +++ b/tests/test_plotting.py Sat May 25 21:54:13 2024 +0200 @@ -1,70 +1,79 @@ import os import sys -import unittest -from unittest import TestCase +import pytest from tempfile import TemporaryDirectory -from genty import genty, genty_dataset import matplotlib -matplotlib.use("AGG") # noqa E402 -import matplotlib.pyplot as plt -from auditok.core import AudioRegion +matplotlib.use("AGG") +import matplotlib.pyplot as plt # noqa E402 +from auditok.core import AudioRegion # noqa E402 if sys.version_info.minor <= 5: PREFIX = "py34_py35/" else: PREFIX = "" +SAVE_NEW_IMAGES = False +if SAVE_NEW_IMAGES: + import shutil # noqa E402 + matplotlib.rcParams["figure.figsize"] = (10, 4) -@genty -class TestPlotting(TestCase): - @genty_dataset(mono=(1,), stereo=(2,)) - def test_region_plot(self, channels): - type_ = "mono" if channels == 1 else "stereo" - audio_filename = "tests/data/test_split_10HZ_{}.raw".format(type_) - image_filename = "tests/images/{}plot_{}_region.png".format( - PREFIX, type_ +@pytest.mark.parametrize("channels", [1, 2], ids=["mono", "stereo"]) +def test_region_plot(channels): + type_ = "mono" if channels == 1 else "stereo" + audio_filename = f"tests/data/test_split_10HZ_{type_}.raw" + image_filename = f"tests/images/{PREFIX}plot_{type_}_region.png" + expected_image = plt.imread(image_filename) + with TemporaryDirectory() as tmpdir: + output_image_filename = os.path.join(tmpdir, "image.png") + region = AudioRegion.load(audio_filename, sr=10, sw=2, ch=channels) + region.plot(show=False, save_as=output_image_filename) + output_image = plt.imread(output_image_filename) + + if SAVE_NEW_IMAGES: + shutil.copy(output_image_filename, image_filename) + assert (output_image == expected_image).all() # mono, stereo + + +@pytest.mark.parametrize( + "channels, use_channel", + [ + (1, None), # mono + (2, "any"), # stereo_any + (2, 0), # stereo_uc_0 + (2, 1), # stereo_uc_1 + (2, "mix"), # stereo_uc_mix + ], + ids=["mono", "stereo_any", "stereo_uc_0", "stereo_uc_1", "stereo_uc_mix"], +) +def test_region_split_and_plot(channels, use_channel): + type_ = "mono" if channels == 1 else "stereo" + audio_filename = f"tests/data/test_split_10HZ_{type_}.raw" + if type_ == "mono": + fmt = "tests/images/{}split_and_plot_mono_region.png" + else: + fmt = "tests/images/{}split_and_plot_uc_{}_stereo_region.png" + image_filename = fmt.format(PREFIX, use_channel) + + expected_image = plt.imread(image_filename) + with TemporaryDirectory() as tmpdir: + output_image_filename = os.path.join(tmpdir, "image.png") + region = AudioRegion.load(audio_filename, sr=10, sw=2, ch=channels) + region.split_and_plot( + aw=0.1, + uc=use_channel, + max_silence=0, + show=False, + save_as=output_image_filename, ) - expected_image = plt.imread(image_filename) - with TemporaryDirectory() as tmpdir: - output_image_filename = os.path.join(tmpdir, "image.png") - region = AudioRegion.load(audio_filename, sr=10, sw=2, ch=channels) - region.plot(show=False, save_as=output_image_filename) - output_image = plt.imread(output_image_filename) - self.assertTrue((output_image == expected_image).all()) + output_image = plt.imread(output_image_filename) - @genty_dataset( - mono=(1,), - stereo_any=(2, "any"), - stereo_uc_0=(2, 0), - stereo_uc_1=(2, 1), - stereo_uc_mix=(2, "mix"), - ) - def test_region_split_and_plot(self, channels, use_channel=None): - type_ = "mono" if channels == 1 else "stereo" - audio_filename = "tests/data/test_split_10HZ_{}.raw".format(type_) - if type_ == "mono": - fmt = "tests/images/{}split_and_plot_mono_region.png" - else: - fmt = "tests/images/{}split_and_plot_uc_{}_stereo_region.png" - image_filename = fmt.format(PREFIX, use_channel) - - expected_image = plt.imread(image_filename) - with TemporaryDirectory() as tmpdir: - output_image_filename = os.path.join(tmpdir, "image.png") - region = AudioRegion.load(audio_filename, sr=10, sw=2, ch=channels) - region.split_and_plot( - aw=0.1, - uc=use_channel, - max_silence=0, - show=False, - save_as=output_image_filename, - ) - output_image = plt.imread(output_image_filename) - self.assertTrue((output_image == expected_image).all()) + if SAVE_NEW_IMAGES: + shutil.copy(output_image_filename, image_filename) + assert (output_image == expected_image).all() if __name__ == "__main__": - unittest.main() + pytest.main()
--- a/tests/test_signal.py Fri May 24 21:30:34 2024 +0200 +++ b/tests/test_signal.py Sat May 25 21:54:13 2024 +0200 @@ -1,188 +1,260 @@ -import unittest -from unittest import TestCase +import pytest from array import array as array_ -from genty import genty, genty_dataset import numpy as np from auditok import signal as signal_ from auditok import signal_numpy -@genty -class TestSignal(TestCase): - def setUp(self): - self.data = b"012345679ABC" - self.numpy_fmt = {"b": np.int8, "h": np.int16, "i": np.int32} +@pytest.fixture +def setup_data(): + return b"012345679ABC" - @genty_dataset( - int8_mono=(1, [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]]), - int16_mono=(2, [[12592, 13106, 13620, 14134, 16697, 17218]]), - int32_mono=(4, [[858927408, 926299444, 1128415545]]), - int8_stereo=(1, [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]]), - int16_stereo=(2, [[12592, 13620, 16697], [13106, 14134, 17218]]), - int32_3channel=(4, [[858927408], [926299444], [1128415545]]), - ) - def test_to_array(self, sample_width, expected): - channels = len(expected) - expected = [ - array_(signal_.FORMAT[sample_width], xi) for xi in expected - ] - result = signal_.to_array(self.data, sample_width, channels) - result_numpy = signal_numpy.to_array(self.data, sample_width, channels) - self.assertEqual(result, expected) - self.assertTrue((result_numpy == np.asarray(expected)).all()) - self.assertEqual(result_numpy.dtype, np.float64) - @genty_dataset( - int8_1channel_select_0=( +@pytest.fixture +def numpy_fmt(): + return {"b": np.int8, "h": np.int16, "i": np.int32} + + +@pytest.mark.parametrize( + "sample_width, expected", + [ + (1, [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]]), # int8_mono + (2, [[12592, 13106, 13620, 14134, 16697, 17218]]), # int16_mono + (4, [[858927408, 926299444, 1128415545]]), # int32_mono + ( + 1, + [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]], + ), # int8_stereo + (2, [[12592, 13620, 16697], [13106, 14134, 17218]]), # int16_stereo + (4, [[858927408], [926299444], [1128415545]]), # int32_3channel + ], + ids=[ + "int8_mono", + "int16_mono", + "int32_mono", + "int8_stereo", + "int16_stereo", + "int32_3channel", + ], +) +def test_to_array(setup_data, sample_width, expected): + data = setup_data + channels = len(expected) + expected = [array_(signal_.FORMAT[sample_width], xi) for xi in expected] + result = signal_.to_array(data, sample_width, channels) + result_numpy = signal_numpy.to_array(data, sample_width, channels) + assert result == expected + assert (result_numpy == np.asarray(expected)).all() + assert result_numpy.dtype == np.float64 + + +@pytest.mark.parametrize( + "fmt, channels, selected, expected", + [ + ( "b", 1, 0, [48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67], - ), - int8_2channel_select_0=("b", 2, 0, [48, 50, 52, 54, 57, 66]), - int8_3channel_select_0=("b", 3, 0, [48, 51, 54, 65]), - int8_3channel_select_1=("b", 3, 1, [49, 52, 55, 66]), - int8_3channel_select_2=("b", 3, 2, [50, 53, 57, 67]), - int8_4channel_select_0=("b", 4, 0, [48, 52, 57]), - int16_1channel_select_0=( + ), # int8_1channel_select_0 + ("b", 2, 0, [48, 50, 52, 54, 57, 66]), # int8_2channel_select_0 + ("b", 3, 0, [48, 51, 54, 65]), # int8_3channel_select_0 + ("b", 3, 1, [49, 52, 55, 66]), # int8_3channel_select_1 + ("b", 3, 2, [50, 53, 57, 67]), # int8_3channel_select_2 + ("b", 4, 0, [48, 52, 57]), # int8_4channel_select_0 + ( "h", 1, 0, [12592, 13106, 13620, 14134, 16697, 17218], - ), - int16_2channel_select_0=("h", 2, 0, [12592, 13620, 16697]), - int16_2channel_select_1=("h", 2, 1, [13106, 14134, 17218]), - int16_3channel_select_0=("h", 3, 0, [12592, 14134]), - int16_3channel_select_1=("h", 3, 1, [13106, 16697]), - int16_3channel_select_2=("h", 3, 2, [13620, 17218]), - int32_1channel_select_0=( + ), # int16_1channel_select_0 + ("h", 2, 0, [12592, 13620, 16697]), # int16_2channel_select_0 + ("h", 2, 1, [13106, 14134, 17218]), # int16_2channel_select_1 + ("h", 3, 0, [12592, 14134]), # int16_3channel_select_0 + ("h", 3, 1, [13106, 16697]), # int16_3channel_select_1 + ("h", 3, 2, [13620, 17218]), # int16_3channel_select_2 + ( "i", 1, 0, [858927408, 926299444, 1128415545], - ), - int32_3channel_select_0=("i", 3, 0, [858927408]), - int32_3channel_select_1=("i", 3, 1, [926299444]), - int32_3channel_select_2=("i", 3, 2, [1128415545]), + ), # int32_1channel_select_0 + ("i", 3, 0, [858927408]), # int32_3channel_select_0 + ("i", 3, 1, [926299444]), # int32_3channel_select_1 + ("i", 3, 2, [1128415545]), # int32_3channel_select_2 + ], + ids=[ + "int8_1channel_select_0", + "int8_2channel_select_0", + "int8_3channel_select_0", + "int8_3channel_select_1", + "int8_3channel_select_2", + "int8_4channel_select_0", + "int16_1channel_select_0", + "int16_2channel_select_0", + "int16_2channel_select_1", + "int16_3channel_select_0", + "int16_3channel_select_1", + "int16_3channel_select_2", + "int32_1channel_select_0", + "int32_3channel_select_0", + "int32_3channel_select_1", + "int32_3channel_select_2", + ], +) +def test_extract_single_channel( + setup_data, numpy_fmt, fmt, channels, selected, expected +): + data = setup_data + result = signal_.extract_single_channel(data, fmt, channels, selected) + expected = array_(fmt, expected) + expected_numpy_fmt = numpy_fmt[fmt] + assert result == expected + result_numpy = signal_numpy.extract_single_channel( + data, numpy_fmt[fmt], channels, selected ) - def test_extract_single_channel(self, fmt, channels, selected, expected): - result = signal_.extract_single_channel( - self.data, fmt, channels, selected - ) - expected = array_(fmt, expected) - expected_numpy_fmt = self.numpy_fmt[fmt] - self.assertEqual(result, expected) - result_numpy = signal_numpy.extract_single_channel( - self.data, self.numpy_fmt[fmt], channels, selected - ) - self.assertTrue(all(result_numpy == expected)) - self.assertEqual(result_numpy.dtype, expected_numpy_fmt) + assert all(result_numpy == expected) + assert result_numpy.dtype == expected_numpy_fmt - @genty_dataset( - int8_2channel=("b", 2, [48, 50, 52, 54, 61, 66]), - int8_4channel=("b", 4, [50, 54, 64]), - int16_1channel=("h", 1, [12592, 13106, 13620, 14134, 16697, 17218]), - int16_2channel=("h", 2, [12849, 13877, 16958]), - int32_3channel=("i", 3, [971214132]), + +@pytest.mark.parametrize( + "fmt, channels, expected", + [ + ("b", 2, [48, 50, 52, 54, 61, 66]), # int8_2channel + ("b", 4, [50, 54, 64]), # int8_4channel + ("h", 1, [12592, 13106, 13620, 14134, 16697, 17218]), # int16_1channel + ("h", 2, [12849, 13877, 16958]), # int16_2channel + ("i", 3, [971214132]), # int32_3channel + ], + ids=[ + "int8_2channel", + "int8_4channel", + "int16_1channel", + "int16_2channel", + "int32_3channel", + ], +) +def test_compute_average_channel( + setup_data, numpy_fmt, fmt, channels, expected +): + data = setup_data + result = signal_.compute_average_channel(data, fmt, channels) + expected = array_(fmt, expected) + expected_numpy_fmt = numpy_fmt[fmt] + assert result == expected + result_numpy = signal_numpy.compute_average_channel( + data, numpy_fmt[fmt], channels ) - def test_compute_average_channel(self, fmt, channels, expected): - result = signal_.compute_average_channel(self.data, fmt, channels) - expected = array_(fmt, expected) - expected_numpy_fmt = self.numpy_fmt[fmt] - self.assertEqual(result, expected) - result_numpy = signal_numpy.compute_average_channel( - self.data, self.numpy_fmt[fmt], channels - ) - self.assertTrue(all(result_numpy == expected)) - self.assertEqual(result_numpy.dtype, expected_numpy_fmt) + assert all(result_numpy == expected) + assert result_numpy.dtype == expected_numpy_fmt - @genty_dataset( - int8_2channel=(1, [48, 50, 52, 54, 61, 66]), - int16_2channel=(2, [12849, 13877, 16957]), - ) - def test_compute_average_channel_stereo(self, sample_width, expected): - result = signal_.compute_average_channel_stereo( - self.data, sample_width - ) - fmt = signal_.FORMAT[sample_width] - expected = array_(fmt, expected) - self.assertEqual(result, expected) - @genty_dataset( - int8_1channel=( +@pytest.mark.parametrize( + "sample_width, expected", + [ + (1, [48, 50, 52, 54, 61, 66]), # int8_2channel + (2, [12849, 13877, 16957]), # int16_2channel + ], + ids=["int8_2channel", "int16_2channel"], +) +def test_compute_average_channel_stereo(setup_data, sample_width, expected): + data = setup_data + result = signal_.compute_average_channel_stereo(data, sample_width) + fmt = signal_.FORMAT[sample_width] + expected = array_(fmt, expected) + assert result == expected + + +@pytest.mark.parametrize( + "fmt, channels, expected", + [ + ( "b", 1, [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]], - ), - int8_2channel=( + ), # int8_1channel + ( "b", 2, [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]], - ), - int8_4channel=( + ), # int8_2channel + ( "b", 4, [[48, 52, 57], [49, 53, 65], [50, 54, 66], [51, 55, 67]], - ), - int16_2channel=( + ), # int8_4channel + ( "h", 2, [[12592, 13620, 16697], [13106, 14134, 17218]], - ), - int32_3channel=("i", 3, [[858927408], [926299444], [1128415545]]), + ), # int16_2channel + ("i", 3, [[858927408], [926299444], [1128415545]]), # int32_3channel + ], + ids=[ + "int8_1channel", + "int8_2channel", + "int8_4channel", + "int16_2channel", + "int32_3channel", + ], +) +def test_separate_channels(setup_data, numpy_fmt, fmt, channels, expected): + data = setup_data + result = signal_.separate_channels(data, fmt, channels) + expected = [array_(fmt, exp) for exp in expected] + expected_numpy_fmt = numpy_fmt[fmt] + assert result == expected + result_numpy = signal_numpy.separate_channels( + data, numpy_fmt[fmt], channels ) - def test_separate_channels(self, fmt, channels, expected): - result = signal_.separate_channels(self.data, fmt, channels) - expected = [array_(fmt, exp) for exp in expected] - expected_numpy_fmt = self.numpy_fmt[fmt] - self.assertEqual(result, expected) + assert (result_numpy == expected).all() + assert result_numpy.dtype == expected_numpy_fmt - result_numpy = signal_numpy.separate_channels( - self.data, self.numpy_fmt[fmt], channels - ) - self.assertTrue((result_numpy == expected).all()) - self.assertEqual(result_numpy.dtype, expected_numpy_fmt) - @genty_dataset( - simple=([300, 320, 400, 600], 2, 52.50624901923348), - zero=([0], 2, -200), - zeros=([0, 0, 0], 2, -200), - ) - def test_calculate_energy_single_channel(self, x, sample_width, expected): - x = array_(signal_.FORMAT[sample_width], x) - energy = signal_.calculate_energy_single_channel(x, sample_width) - self.assertEqual(energy, expected) - energy = signal_numpy.calculate_energy_single_channel(x, sample_width) - self.assertEqual(energy, expected) +@pytest.mark.parametrize( + "x, sample_width, expected", + [ + ([300, 320, 400, 600], 2, 52.50624901923348), # simple + ([0], 2, -200), # zero + ([0, 0, 0], 2, -200), # zeros + ], + ids=["simple", "zero", "zeros"], +) +def test_calculate_energy_single_channel(x, sample_width, expected): + x = array_(signal_.FORMAT[sample_width], x) + energy = signal_.calculate_energy_single_channel(x, sample_width) + assert energy == expected + energy = signal_numpy.calculate_energy_single_channel(x, sample_width) + assert energy == expected - @genty_dataset( - min_=( + +@pytest.mark.parametrize( + "x, sample_width, aggregation_fn, expected", + [ + ( [[300, 320, 400, 600], [150, 160, 200, 300]], 2, min, 46.485649105953854, - ), - max_=( + ), # min_ + ( [[300, 320, 400, 600], [150, 160, 200, 300]], 2, max, 52.50624901923348, - ), + ), # max_ + ], + ids=["min_", "max_"], +) +def test_calculate_energy_multichannel( + x, sample_width, aggregation_fn, expected +): + x = [array_(signal_.FORMAT[sample_width], xi) for xi in x] + energy = signal_.calculate_energy_multichannel( + x, sample_width, aggregation_fn ) - def test_calculate_energy_multichannel( - self, x, sample_width, aggregation_fn, expected - ): - x = [array_(signal_.FORMAT[sample_width], xi) for xi in x] - energy = signal_.calculate_energy_multichannel( - x, sample_width, aggregation_fn - ) - self.assertEqual(energy, expected) - - energy = signal_numpy.calculate_energy_multichannel( - x, sample_width, aggregation_fn - ) - self.assertEqual(energy, expected) - - -if __name__ == "__main__": - unittest.main() + assert energy == expected + energy = signal_numpy.calculate_energy_multichannel( + x, sample_width, aggregation_fn + ) + assert energy == expected
--- a/tests/test_util.py Fri May 24 21:30:34 2024 +0200 +++ b/tests/test_util.py Sat May 25 21:54:13 2024 +0200 @@ -1,9 +1,7 @@ -import unittest -from unittest import TestCase +import pytest from unittest.mock import patch import math from array import array as array_ -from genty import genty, genty_dataset from auditok.util import ( AudioEnergyValidator, make_duration_formatter, @@ -11,7 +9,6 @@ ) from auditok import signal as signal_ from auditok import signal_numpy - from auditok.exceptions import TimeFormatError @@ -44,10 +41,7 @@ two_pi_step = 2 * math.pi * step data = array_( fmt, - ( - int(math.sin(two_pi_step * i) * volume) - for i in range(total_samples) - ), + (int(math.sin(two_pi_step * i) * volume) for i in range(total_samples)), ) return data @@ -63,70 +57,120 @@ ) -@genty -class TestFunctions(TestCase): - def setUp(self): +class TestFunctions: + def setup_method(self): self.data = b"012345679ABC" - @genty_dataset( - only_seconds=("%S", 5400, "5400.000"), - only_millis=("%I", 5400, "5400000"), - full=("%h:%m:%s.%i", 3725.365, "01:02:05.365"), - full_zero_hours=("%h:%m:%s.%i", 1925.075, "00:32:05.075"), - full_zero_minutes=("%h:%m:%s.%i", 3659.075, "01:00:59.075"), - full_zero_seconds=("%h:%m:%s.%i", 3720.075, "01:02:00.075"), - full_zero_millis=("%h:%m:%s.%i", 3725, "01:02:05.000"), - duplicate_directive=( - "%h %h:%m:%s.%i %s", - 3725.365, - "01 01:02:05.365 05", - ), - no_millis=("%h:%m:%s", 3725, "01:02:05"), - no_seconds=("%h:%m", 3725, "01:02"), - no_minutes=("%h", 3725, "01"), - no_hours=("%m:%s.%i", 3725, "02:05.000"), + @pytest.mark.parametrize( + "fmt, duration, expected", + [ + ("%S", 5400, "5400.000"), # only_seconds + ("%I", 5400, "5400000"), # only_millis + ("%h:%m:%s.%i", 3725.365, "01:02:05.365"), # full + ("%h:%m:%s.%i", 1925.075, "00:32:05.075"), # full_zero_hours + ("%h:%m:%s.%i", 3659.075, "01:00:59.075"), # full_zero_minutes + ("%h:%m:%s.%i", 3720.075, "01:02:00.075"), # full_zero_seconds + ("%h:%m:%s.%i", 3725, "01:02:05.000"), # full_zero_millis + ( + "%h %h:%m:%s.%i %s", + 3725.365, + "01 01:02:05.365 05", + ), # duplicate_directive + ("%h:%m:%s", 3725, "01:02:05"), # no_millis + ("%h:%m", 3725, "01:02"), # no_seconds + ("%h", 3725, "01"), # no_minutes + ("%m:%s.%i", 3725, "02:05.000"), # no_hours + ], + ids=[ + "only_seconds", + "only_millis", + "full", + "full_zero_hours", + "full_zero_minutes", + "full_zero_seconds", + "full_zero_millis", + "duplicate_directive", + "no_millis", + "no_seconds", + "no_minutes", + "no_hours", + ], ) def test_make_duration_formatter(self, fmt, duration, expected): formatter = make_duration_formatter(fmt) result = formatter(duration) - self.assertEqual(result, expected) + assert result == expected - @genty_dataset( - duplicate_only_seconds=("%S %S",), - duplicate_only_millis=("%I %I",), - unknown_directive=("%x",), + @pytest.mark.parametrize( + "fmt", + [ + "%S %S", # duplicate_only_seconds + "%I %I", # duplicate_only_millis + "%x", # unknown_directive + ], + ids=[ + "duplicate_only_seconds", + "duplicate_only_millis", + "unknown_directive", + ], ) def test_make_duration_formatter_error(self, fmt): - with self.assertRaises(TimeFormatError): + with pytest.raises(TimeFormatError): make_duration_formatter(fmt) - @genty_dataset( - int8_1channel_select_0=( - 1, - 1, - 0, - [48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67], - ), - int8_2channel_select_0=(1, 2, 0, [48, 50, 52, 54, 57, 66]), - int8_3channel_select_0=(1, 3, 0, [48, 51, 54, 65]), - int8_3channel_select_1=(1, 3, 1, [49, 52, 55, 66]), - int8_3channel_select_2=(1, 3, 2, [50, 53, 57, 67]), - int8_4channel_select_0=(1, 4, 0, [48, 52, 57]), - int16_1channel_select_0=( - 2, - 1, - 0, - [12592, 13106, 13620, 14134, 16697, 17218], - ), - int16_2channel_select_0=(2, 2, 0, [12592, 13620, 16697]), - int16_2channel_select_1=(2, 2, 1, [13106, 14134, 17218]), - int16_3channel_select_0=(2, 3, 0, [12592, 14134]), - int16_3channel_select_1=(2, 3, 1, [13106, 16697]), - int16_3channel_select_2=(2, 3, 2, [13620, 17218]), - int32_1channel_select_0=(4, 1, 0, [858927408, 926299444, 1128415545],), - int32_3channel_select_0=(4, 3, 0, [858927408]), - int32_3channel_select_1=(4, 3, 1, [926299444]), - int32_3channel_select_2=(4, 3, 2, [1128415545]), + @pytest.mark.parametrize( + "sample_width, channels, selected, expected", + [ + ( + 1, + 1, + 0, + [48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67], + ), # int8_1channel_select_0 + (1, 2, 0, [48, 50, 52, 54, 57, 66]), # int8_2channel_select_0 + (1, 3, 0, [48, 51, 54, 65]), # int8_3channel_select_0 + (1, 3, 1, [49, 52, 55, 66]), # int8_3channel_select_1 + (1, 3, 2, [50, 53, 57, 67]), # int8_3channel_select_2 + (1, 4, 0, [48, 52, 57]), # int8_4channel_select_0 + ( + 2, + 1, + 0, + [12592, 13106, 13620, 14134, 16697, 17218], + ), # int16_1channel_select_0 + (2, 2, 0, [12592, 13620, 16697]), # int16_2channel_select_0 + (2, 2, 1, [13106, 14134, 17218]), # int16_2channel_select_1 + (2, 3, 0, [12592, 14134]), # int16_3channel_select_0 + (2, 3, 1, [13106, 16697]), # int16_3channel_select_1 + (2, 3, 2, [13620, 17218]), # int16_3channel_select_2 + ( + 4, + 1, + 0, + [858927408, 926299444, 1128415545], + ), # int32_1channel_select_0 + (4, 3, 0, [858927408]), # int32_3channel_select_0 + (4, 3, 1, [926299444]), # int32_3channel_select_1 + (4, 3, 2, [1128415545]), # int32_3channel_select_2 + ], + ids=[ + "int8_1channel_select_0", + "int8_2channel_select_0", + "int8_3channel_select_0", + "int8_3channel_select_1", + "int8_3channel_select_2", + "int8_4channel_select_0", + "int16_1channel_select_0", + "int16_2channel_select_0", + "int16_2channel_select_1", + "int16_3channel_select_0", + "int16_3channel_select_1", + "int16_3channel_select_2", + "int32_1channel_select_0", + "int32_3channel_select_0", + "int32_3channel_select_1", + "int32_3channel_select_2", + ], ) def test_make_channel_selector_one_channel( self, sample_width, channels, selected, expected @@ -141,7 +185,7 @@ expected = array_(fmt, expected) if channels == 1: expected = bytes(expected) - self.assertEqual(result, expected) + assert result == expected # Use signal functions with numpy implementation with patch("auditok.util.signal", signal_numpy): @@ -151,21 +195,31 @@ expected = array_(fmt, expected) if channels == 1: expected = bytes(expected) - self.assertEqual(result_numpy, expected) + assert result_numpy == expected else: - self.assertTrue(all(result_numpy == expected)) + assert all(result_numpy == expected) - @genty_dataset( - int8_2channel=(1, 2, "avg", [48, 50, 52, 54, 61, 66]), - int8_4channel=(1, 4, "average", [50, 54, 64]), - int16_1channel=( - 2, - 1, - "mix", - [12592, 13106, 13620, 14134, 16697, 17218], - ), - int16_2channel=(2, 2, "avg", [12849, 13877, 16957]), - int32_3channel=(4, 3, "average", [971214132]), + @pytest.mark.parametrize( + "sample_width, channels, selected, expected", + [ + (1, 2, "avg", [48, 50, 52, 54, 61, 66]), # int8_2channel + (1, 4, "average", [50, 54, 64]), # int8_4channel + ( + 2, + 1, + "mix", + [12592, 13106, 13620, 14134, 16697, 17218], + ), # int16_1channel + (2, 2, "avg", [12849, 13877, 16957]), # int16_2channel + (4, 3, "average", [971214132]), # int32_3channel + ], + ids=[ + "int8_2channel", + "int8_4channel", + "int16_1channel", + "int16_2channel", + "int32_3channel", + ], ) def test_make_channel_selector_average( self, sample_width, channels, selected, expected @@ -179,7 +233,7 @@ expected = array_(fmt, expected) if channels == 1: expected = bytes(expected) - self.assertEqual(result, expected) + assert result == expected # Use signal functions with numpy implementation with patch("auditok.util.signal", signal_numpy): @@ -187,36 +241,51 @@ result_numpy = selector(self.data) if channels in (1, 2): - self.assertEqual(result_numpy, expected) + assert result_numpy == expected else: - self.assertTrue(all(result_numpy == expected)) + assert all(result_numpy == expected) - @genty_dataset( - int8_1channel=( - 1, - 1, - "any", - [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]], - ), - int8_2channel=( - 1, - 2, - None, - [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]], - ), - int8_4channel=( - 1, - 4, - "any", - [[48, 52, 57], [49, 53, 65], [50, 54, 66], [51, 55, 67]], - ), - int16_2channel=( - 2, - 2, - None, - [[12592, 13620, 16697], [13106, 14134, 17218]], - ), - int32_3channel=(4, 3, "any", [[858927408], [926299444], [1128415545]]), + @pytest.mark.parametrize( + "sample_width, channels, selected, expected", + [ + ( + 1, + 1, + "any", + [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]], + ), # int8_1channel + ( + 1, + 2, + None, + [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]], + ), # int8_2channel + ( + 1, + 4, + "any", + [[48, 52, 57], [49, 53, 65], [50, 54, 66], [51, 55, 67]], + ), # int8_4channel + ( + 2, + 2, + None, + [[12592, 13620, 16697], [13106, 14134, 17218]], + ), # int16_2channel + ( + 4, + 3, + "any", + [[858927408], [926299444], [1128415545]], + ), # int32_3channel + ], + ids=[ + "int8_1channel", + "int8_2channel", + "int8_4channel", + "int16_2channel", + "int32_3channel", + ], ) def test_make_channel_selector_any( self, sample_width, channels, selected, expected @@ -231,7 +300,7 @@ expected = [array_(fmt, exp) for exp in expected] if channels == 1: expected = bytes(expected[0]) - self.assertEqual(result, expected) + assert result == expected # Use signal functions with numpy implementation with patch("auditok.util.signal", signal_numpy): @@ -239,39 +308,66 @@ result_numpy = selector(self.data) if channels == 1: - self.assertEqual(result_numpy, expected) + assert result_numpy == expected else: - self.assertTrue((result_numpy == expected).all()) + assert (result_numpy == expected).all() -@genty -class TestAudioEnergyValidator(TestCase): - @genty_dataset( - mono_valid_uc_None=([350, 400], 1, None, True), - mono_valid_uc_any=([350, 400], 1, "any", True), - mono_valid_uc_0=([350, 400], 1, 0, True), - mono_valid_uc_mix=([350, 400], 1, "mix", True), - # previous cases are all the same since we have mono audio - mono_invalid_uc_None=([300, 300], 1, None, False), - stereo_valid_uc_None=([300, 400, 350, 300], 2, None, True), - stereo_valid_uc_any=([300, 400, 350, 300], 2, "any", True), - stereo_valid_uc_mix=([300, 400, 350, 300], 2, "mix", True), - stereo_valid_uc_avg=([300, 400, 350, 300], 2, "avg", True), - stereo_valid_uc_average=([300, 400, 300, 300], 2, "average", True), - stereo_valid_uc_mix_with_null_channel=( - [634, 0, 634, 0], - 2, - "mix", - True, - ), - stereo_valid_uc_0=([320, 100, 320, 100], 2, 0, True), - stereo_valid_uc_1=([100, 320, 100, 320], 2, 1, True), - stereo_invalid_uc_None=([280, 100, 280, 100], 2, None, False), - stereo_invalid_uc_any=([280, 100, 280, 100], 2, "any", False), - stereo_invalid_uc_mix=([400, 200, 400, 200], 2, "mix", False), - stereo_invalid_uc_0=([300, 400, 300, 400], 2, 0, False), - stereo_invalid_uc_1=([400, 300, 400, 300], 2, 1, False), - zeros=([0, 0, 0, 0], 2, None, False), +class TestAudioEnergyValidator: + @pytest.mark.parametrize( + "data, channels, use_channel, expected", + [ + ([350, 400], 1, None, True), # mono_valid_uc_None + ([350, 400], 1, "any", True), # mono_valid_uc_any + ([350, 400], 1, 0, True), # mono_valid_uc_0 + ([350, 400], 1, "mix", True), # mono_valid_uc_mix + ([300, 300], 1, None, False), # mono_invalid_uc_None + ([300, 400, 350, 300], 2, None, True), # stereo_valid_uc_None + ([300, 400, 350, 300], 2, "any", True), # stereo_valid_uc_any + ([300, 400, 350, 300], 2, "mix", True), # stereo_valid_uc_mix + ([300, 400, 350, 300], 2, "avg", True), # stereo_valid_uc_avg + ( + [300, 400, 300, 300], + 2, + "average", + True, + ), # stereo_valid_uc_average + ( + [634, 0, 634, 0], + 2, + "mix", + True, + ), # stereo_valid_uc_mix_with_null_channel + ([320, 100, 320, 100], 2, 0, True), # stereo_valid_uc_0 + ([100, 320, 100, 320], 2, 1, True), # stereo_valid_uc_1 + ([280, 100, 280, 100], 2, None, False), # stereo_invalid_uc_None + ([280, 100, 280, 100], 2, "any", False), # stereo_invalid_uc_any + ([400, 200, 400, 200], 2, "mix", False), # stereo_invalid_uc_mix + ([300, 400, 300, 400], 2, 0, False), # stereo_invalid_uc_0 + ([400, 300, 400, 300], 2, 1, False), # stereo_invalid_uc_1 + ([0, 0, 0, 0], 2, None, False), # zeros + ], + ids=[ + "mono_valid_uc_None", + "mono_valid_uc_any", + "mono_valid_uc_0", + "mono_valid_uc_mix", + "mono_invalid_uc_None", + "stereo_valid_uc_None", + "stereo_valid_uc_any", + "stereo_valid_uc_mix", + "stereo_valid_uc_avg", + "stereo_valid_uc_average", + "stereo_valid_uc_mix_with_null_channel", + "stereo_valid_uc_0", + "stereo_valid_uc_1", + "stereo_invalid_uc_None", + "stereo_invalid_uc_any", + "stereo_invalid_uc_mix", + "stereo_invalid_uc_0", + "stereo_invalid_uc_1", + "zeros", + ], ) def test_audio_energy_validator( self, data, channels, use_channel, expected @@ -285,10 +381,10 @@ ) if expected: - self.assertTrue(validator.is_valid(data)) + assert validator.is_valid(data) else: - self.assertFalse(validator.is_valid(data)) + assert not validator.is_valid(data) if __name__ == "__main__": - unittest.main() + pytest.main()
--- a/tests/test_workers.py Fri May 24 21:30:34 2024 +0200 +++ b/tests/test_workers.py Sat May 25 21:54:13 2024 +0200 @@ -1,9 +1,7 @@ import os -import unittest -from unittest import TestCase from unittest.mock import patch, call, Mock from tempfile import TemporaryDirectory -from genty import genty, genty_dataset +import pytest from auditok import AudioRegion, AudioDataSource from auditok.exceptions import AudioEncodingWarning from auditok.cmdline_util import make_logger @@ -17,215 +15,72 @@ ) -@genty -class TestWorkers(TestCase): - def setUp(self): +@pytest.fixture +def audio_data_source(): + reader = AudioDataSource( + input="tests/data/test_split_10HZ_mono.raw", + block_dur=0.1, + sr=10, + sw=2, + ch=1, + ) + yield reader + reader.close() - self.reader = AudioDataSource( - input="tests/data/test_split_10HZ_mono.raw", - block_dur=0.1, - sr=10, - sw=2, - ch=1, + +@pytest.fixture +def expected_detections(): + return [ + (0.2, 1.6), + (1.7, 3.1), + (3.4, 5.4), + (5.4, 7.4), + (7.4, 7.6), + ] + + +def test_TokenizerWorker(audio_data_source, expected_detections): + with TemporaryDirectory() as tmpdir: + file = os.path.join(tmpdir, "file.log") + logger = make_logger(file=file, name="test_TokenizerWorker") + tokenizer = TokenizerWorker( + audio_data_source, + logger=logger, + min_dur=0.3, + max_dur=2, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + eth=50, ) - self.expected = [ - (0.2, 1.6), - (1.7, 3.1), - (3.4, 5.4), - (5.4, 7.4), - (7.4, 7.6), - ] + tokenizer.start_all() + tokenizer.join() + with open(file) as fp: + log_lines = fp.readlines() - def tearDown(self): - self.reader.close() + log_fmt = ( + "[DET]: Detection {} (start: {:.3f}, end: {:.3f}, duration: {:.3f})" + ) + assert len(tokenizer.detections) == len(expected_detections) + for i, (det, exp, log_line) in enumerate( + zip(tokenizer.detections, expected_detections, log_lines), 1 + ): + start, end = exp + exp_log_line = log_fmt.format(i, start, end, end - start) + assert pytest.approx(det.start) == start + assert pytest.approx(det.end) == end + assert log_line[28:].strip() == exp_log_line - def test_TokenizerWorker(self): - with TemporaryDirectory() as tmpdir: - file = os.path.join(tmpdir, "file.log") - logger = make_logger(file=file, name="test_TokenizerWorker") - tokenizer = TokenizerWorker( - self.reader, - logger=logger, - min_dur=0.3, - max_dur=2, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - eth=50, - ) - tokenizer.start_all() - tokenizer.join() - # Get logged text - with open(file) as fp: - log_lines = fp.readlines() - log_fmt = "[DET]: Detection {} (start: {:.3f}, " - log_fmt += "end: {:.3f}, duration: {:.3f})" - self.assertEqual(len(tokenizer.detections), len(self.expected)) - for i, (det, exp, log_line) in enumerate( - zip(tokenizer.detections, self.expected, log_lines), 1 - ): - start, end = exp - exp_log_line = log_fmt.format(i, start, end, end - start) - self.assertAlmostEqual(det.start, start) - self.assertAlmostEqual(det.end, end) - # remove timestamp part and strip new line - self.assertEqual(log_line[28:].strip(), exp_log_line) - - def test_PlayerWorker(self): - with TemporaryDirectory() as tmpdir: - file = os.path.join(tmpdir, "file.log") - logger = make_logger(file=file, name="test_RegionSaverWorker") - player_mock = Mock() - observers = [PlayerWorker(player_mock, logger=logger)] - tokenizer = TokenizerWorker( - self.reader, - logger=logger, - observers=observers, - min_dur=0.3, - max_dur=2, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - eth=50, - ) - tokenizer.start_all() - tokenizer.join() - tokenizer._observers[0].join() - # Get logged text - with open(file) as fp: - log_lines = [ - line - for line in fp.readlines() - if line.startswith("[PLAY]") - ] - self.assertTrue(player_mock.play.called) - - self.assertEqual(len(tokenizer.detections), len(self.expected)) - log_fmt = "[PLAY]: Detection {id} played" - for i, (det, exp, log_line) in enumerate( - zip(tokenizer.detections, self.expected, log_lines), 1 - ): - start, end = exp - exp_log_line = log_fmt.format(id=i) - self.assertAlmostEqual(det.start, start) - self.assertAlmostEqual(det.end, end) - # Remove timestamp part and strip new line - self.assertEqual(log_line[28:].strip(), exp_log_line) - - def test_RegionSaverWorker(self): - filename_format = ( - "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav" - ) - with TemporaryDirectory() as tmpdir: - file = os.path.join(tmpdir, "file.log") - logger = make_logger(file=file, name="test_RegionSaverWorker") - observers = [RegionSaverWorker(filename_format, logger=logger)] - tokenizer = TokenizerWorker( - self.reader, - logger=logger, - observers=observers, - min_dur=0.3, - max_dur=2, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - eth=50, - ) - with patch("auditok.core.AudioRegion.save") as patched_save: - tokenizer.start_all() - tokenizer.join() - tokenizer._observers[0].join() - # Get logged text - with open(file) as fp: - log_lines = [ - line - for line in fp.readlines() - if line.startswith("[SAVE]") - ] - - # Assert RegionSaverWorker ran as expected - expected_save_calls = [ - call( - filename_format.format( - id=i, start=exp[0], end=exp[1], duration=exp[1] - exp[0] - ), - None, - ) - for i, exp in enumerate(self.expected, 1) - ] - - # Get calls to 'AudioRegion.save' - mock_calls = [ - c for i, c in enumerate(patched_save.mock_calls) if i % 2 == 0 - ] - self.assertEqual(mock_calls, expected_save_calls) - self.assertEqual(len(tokenizer.detections), len(self.expected)) - - log_fmt = "[SAVE]: Detection {id} saved as '{filename}'" - for i, (det, exp, log_line) in enumerate( - zip(tokenizer.detections, self.expected, log_lines), 1 - ): - start, end = exp - expected_filename = filename_format.format( - id=i, start=start, end=end, duration=end - start - ) - exp_log_line = log_fmt.format(i, expected_filename) - self.assertAlmostEqual(det.start, start) - self.assertAlmostEqual(det.end, end) - # Remove timestamp part and strip new line - self.assertEqual(log_line[28:].strip(), exp_log_line) - - def test_CommandLineWorker(self): - command_format = "do nothing with" - with TemporaryDirectory() as tmpdir: - file = os.path.join(tmpdir, "file.log") - logger = make_logger(file=file, name="test_CommandLineWorker") - observers = [CommandLineWorker(command_format, logger=logger)] - tokenizer = TokenizerWorker( - self.reader, - logger=logger, - observers=observers, - min_dur=0.3, - max_dur=2, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - eth=50, - ) - with patch("auditok.workers.os.system") as patched_os_system: - tokenizer.start_all() - tokenizer.join() - tokenizer._observers[0].join() - # Get logged text - with open(file) as fp: - log_lines = [ - line - for line in fp.readlines() - if line.startswith("[COMMAND]") - ] - - # Assert CommandLineWorker ran as expected - expected_save_calls = [call(command_format) for _ in self.expected] - self.assertEqual(patched_os_system.mock_calls, expected_save_calls) - self.assertEqual(len(tokenizer.detections), len(self.expected)) - log_fmt = "[COMMAND]: Detection {id} command '{command}'" - for i, (det, exp, log_line) in enumerate( - zip(tokenizer.detections, self.expected, log_lines), 1 - ): - start, end = exp - exp_log_line = log_fmt.format(i, command_format) - self.assertAlmostEqual(det.start, start) - self.assertAlmostEqual(det.end, end) - # Remove timestamp part and strip new line - self.assertEqual(log_line[28:].strip(), exp_log_line) - - def test_PrintWorker(self): - observers = [ - PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}") - ] +def test_PlayerWorker(audio_data_source, expected_detections): + with TemporaryDirectory() as tmpdir: + file = os.path.join(tmpdir, "file.log") + logger = make_logger(file=file, name="test_RegionSaverWorker") + player_mock = Mock() + observers = [PlayerWorker(player_mock, logger=logger)] tokenizer = TokenizerWorker( - self.reader, + audio_data_source, + logger=logger, observers=observers, min_dur=0.3, max_dur=2, @@ -234,121 +89,248 @@ strict_min_dur=False, eth=50, ) - with patch("builtins.print") as patched_print: + tokenizer.start_all() + tokenizer.join() + tokenizer._observers[0].join() + with open(file) as fp: + log_lines = [ + line for line in fp.readlines() if line.startswith("[PLAY]") + ] + + assert player_mock.play.called + assert len(tokenizer.detections) == len(expected_detections) + log_fmt = "[PLAY]: Detection {id} played" + for i, (det, exp, log_line) in enumerate( + zip(tokenizer.detections, expected_detections, log_lines), 1 + ): + start, end = exp + exp_log_line = log_fmt.format(id=i) + assert pytest.approx(det.start) == start + assert pytest.approx(det.end) == end + assert log_line[28:].strip() == exp_log_line + + +def test_RegionSaverWorker(audio_data_source, expected_detections): + filename_format = "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav" + with TemporaryDirectory() as tmpdir: + file = os.path.join(tmpdir, "file.log") + logger = make_logger(file=file, name="test_RegionSaverWorker") + observers = [RegionSaverWorker(filename_format, logger=logger)] + tokenizer = TokenizerWorker( + audio_data_source, + logger=logger, + observers=observers, + min_dur=0.3, + max_dur=2, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + eth=50, + ) + with patch("auditok.core.AudioRegion.save") as patched_save: tokenizer.start_all() tokenizer.join() tokenizer._observers[0].join() + with open(file) as fp: + log_lines = [ + line for line in fp.readlines() if line.startswith("[SAVE]") + ] - # Assert PrintWorker ran as expected - expected_print_calls = [ - call( - "[{}] {:.3f} {:.3f}, dur: {:.3f}".format( - i, exp[0], exp[1], exp[1] - exp[0] - ) - ) - for i, exp in enumerate(self.expected, 1) - ] - self.assertEqual(patched_print.mock_calls, expected_print_calls) - self.assertEqual(len(tokenizer.detections), len(self.expected)) - for det, exp in zip(tokenizer.detections, self.expected): - start, end = exp - self.assertAlmostEqual(det.start, start) - self.assertAlmostEqual(det.end, end) + expected_save_calls = [ + call( + filename_format.format( + id=i, start=exp[0], end=exp[1], duration=exp[1] - exp[0] + ), + None, + ) + for i, exp in enumerate(expected_detections, 1) + ] - def test_StreamSaverWorker_wav(self): - with TemporaryDirectory() as tmpdir: - expected_filename = os.path.join(tmpdir, "output.wav") - saver = StreamSaverWorker(self.reader, expected_filename) - saver.start() + mock_calls = [ + c for i, c in enumerate(patched_save.mock_calls) if i % 2 == 0 + ] + assert mock_calls == expected_save_calls + assert len(tokenizer.detections) == len(expected_detections) - tokenizer = TokenizerWorker(saver) + log_fmt = "[SAVE]: Detection {id} saved as '{filename}'" + for i, (det, exp, log_line) in enumerate( + zip(tokenizer.detections, expected_detections, log_lines), 1 + ): + start, end = exp + expected_filename = filename_format.format( + id=i, start=start, end=end, duration=end - start + ) + exp_log_line = log_fmt.format(id=i, filename=expected_filename) + assert pytest.approx(det.start) == start + assert pytest.approx(det.end) == end + assert log_line[28:].strip() == exp_log_line + + +def test_CommandLineWorker(audio_data_source, expected_detections): + command_format = "do nothing with" + with TemporaryDirectory() as tmpdir: + file = os.path.join(tmpdir, "file.log") + logger = make_logger(file=file, name="test_CommandLineWorker") + observers = [CommandLineWorker(command_format, logger=logger)] + tokenizer = TokenizerWorker( + audio_data_source, + logger=logger, + observers=observers, + min_dur=0.3, + max_dur=2, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + eth=50, + ) + with patch("auditok.workers.os.system") as patched_os_system: tokenizer.start_all() tokenizer.join() - saver.join() + tokenizer._observers[0].join() + with open(file) as fp: + log_lines = [ + line for line in fp.readlines() if line.startswith("[COMMAND]") + ] - output_filename = saver.save_stream() - region = AudioRegion.load( - "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1 + expected_save_calls = [call(command_format) for _ in expected_detections] + assert patched_os_system.mock_calls == expected_save_calls + assert len(tokenizer.detections) == len(expected_detections) + log_fmt = "[COMMAND]: Detection {id} command '{command}'" + for i, (det, exp, log_line) in enumerate( + zip(tokenizer.detections, expected_detections, log_lines), 1 + ): + start, end = exp + exp_log_line = log_fmt.format(id=i, command=command_format) + assert pytest.approx(det.start) == start + assert pytest.approx(det.end) == end + assert log_line[28:].strip() == exp_log_line + + +def test_PrintWorker(audio_data_source, expected_detections): + observers = [ + PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}") + ] + tokenizer = TokenizerWorker( + audio_data_source, + observers=observers, + min_dur=0.3, + max_dur=2, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + eth=50, + ) + with patch("builtins.print") as patched_print: + tokenizer.start_all() + tokenizer.join() + tokenizer._observers[0].join() + + expected_print_calls = [ + call( + "[{}] {:.3f} {:.3f}, dur: {:.3f}".format( + i, exp[0], exp[1], exp[1] - exp[0] ) + ) + for i, exp in enumerate(expected_detections, 1) + ] + assert patched_print.mock_calls == expected_print_calls + assert len(tokenizer.detections) == len(expected_detections) + for det, exp in zip(tokenizer.detections, expected_detections): + start, end = exp + assert pytest.approx(det.start) == start + assert pytest.approx(det.end) == end - expected_region = AudioRegion.load(output_filename) - self.assertEqual(output_filename, expected_filename) - self.assertEqual(region, expected_region) - self.assertEqual(saver.data, bytes(expected_region)) - def test_StreamSaverWorker_raw(self): - with TemporaryDirectory() as tmpdir: - expected_filename = os.path.join(tmpdir, "output") - saver = StreamSaverWorker( - self.reader, expected_filename, export_format="raw" - ) +def test_StreamSaverWorker_wav(audio_data_source): + with TemporaryDirectory() as tmpdir: + expected_filename = os.path.join(tmpdir, "output.wav") + saver = StreamSaverWorker(audio_data_source, expected_filename) + saver.start() + + tokenizer = TokenizerWorker(saver) + tokenizer.start_all() + tokenizer.join() + saver.join() + + output_filename = saver.save_stream() + region = AudioRegion.load( + "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1 + ) + + expected_region = AudioRegion.load(output_filename) + assert output_filename == expected_filename + assert region == expected_region + assert saver.data == bytes(expected_region) + + +def test_StreamSaverWorker_raw(audio_data_source): + with TemporaryDirectory() as tmpdir: + expected_filename = os.path.join(tmpdir, "output") + saver = StreamSaverWorker( + audio_data_source, expected_filename, export_format="raw" + ) + saver.start() + tokenizer = TokenizerWorker(saver) + tokenizer.start_all() + tokenizer.join() + saver.join() + output_filename = saver.save_stream() + region = AudioRegion.load( + "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1 + ) + expected_region = AudioRegion.load( + output_filename, sr=10, sw=2, ch=1, audio_format="raw" + ) + assert output_filename == expected_filename + assert region == expected_region + assert saver.data == bytes(expected_region) + + +def test_StreamSaverWorker_encode_audio(audio_data_source): + with TemporaryDirectory() as tmpdir: + with patch("auditok.workers._run_subprocess") as patch_rsp: + patch_rsp.return_value = (1, None, None) + expected_filename = os.path.join(tmpdir, "output.ogg") + tmp_expected_filename = expected_filename + ".wav" + saver = StreamSaverWorker(audio_data_source, expected_filename) saver.start() tokenizer = TokenizerWorker(saver) tokenizer.start_all() tokenizer.join() saver.join() - output_filename = saver.save_stream() - region = AudioRegion.load( - "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1 - ) - expected_region = AudioRegion.load( - output_filename, sr=10, sw=2, ch=1, audio_format="raw" - ) - self.assertEqual(output_filename, expected_filename) - self.assertEqual(region, expected_region) - self.assertEqual(saver.data, bytes(expected_region)) - - def test_StreamSaverWorker_encode_audio(self): - with TemporaryDirectory() as tmpdir: - with patch("auditok.workers._run_subprocess") as patch_rsp: - patch_rsp.return_value = (1, None, None) - expected_filename = os.path.join(tmpdir, "output.ogg") - tmp_expected_filename = expected_filename + ".wav" - saver = StreamSaverWorker(self.reader, expected_filename) - saver.start() - tokenizer = TokenizerWorker(saver) - tokenizer.start_all() - tokenizer.join() - saver.join() - with self.assertRaises(AudioEncodingWarning) as rt_warn: - saver.save_stream() - warn_msg = "Couldn't save audio data in the desired format " - warn_msg += "'ogg'. Either none of 'ffmpeg', 'avconv' or 'sox' " - warn_msg += "is installed or this format is not recognized.\n" - warn_msg += "Audio file was saved as '{}'" - self.assertEqual( - warn_msg.format(tmp_expected_filename), str(rt_warn.exception) - ) - ffmpef_avconv = [ - "-y", - "-f", - "wav", - "-i", - tmp_expected_filename, - "-f", - "ogg", - expected_filename, - ] - expected_calls = [ - call(["ffmpeg"] + ffmpef_avconv), - call(["avconv"] + ffmpef_avconv), - call( - [ - "sox", - "-t", - "wav", - tmp_expected_filename, - expected_filename, - ] - ), - ] - self.assertEqual(patch_rsp.mock_calls, expected_calls) - region = AudioRegion.load( - "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1 - ) - self.assertTrue(saver._exported) - self.assertEqual(saver.data, bytes(region)) - - -if __name__ == "__main__": - unittest.main() + with pytest.raises(AudioEncodingWarning) as rt_warn: + saver.save_stream() + warn_msg = "Couldn't save audio data in the desired format " + warn_msg += "'ogg'. Either none of 'ffmpeg', 'avconv' or 'sox' " + warn_msg += "is installed or this format is not recognized.\n" + warn_msg += "Audio file was saved as '{}'" + assert warn_msg.format(tmp_expected_filename) == str(rt_warn.value) + ffmpef_avconv = [ + "-y", + "-f", + "wav", + "-i", + tmp_expected_filename, + "-f", + "ogg", + expected_filename, + ] + expected_calls = [ + call(["ffmpeg"] + ffmpef_avconv), + call(["avconv"] + ffmpef_avconv), + call( + [ + "sox", + "-t", + "wav", + tmp_expected_filename, + expected_filename, + ] + ), + ] + assert patch_rsp.mock_calls == expected_calls + region = AudioRegion.load( + "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1 + ) + assert saver._exported + assert saver.data == bytes(region)