Mercurial > hg > auditok
changeset 330:9665dc53c394
Rename test_AudioDataSource.py as test_AudioReader.py and refactor it
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Wed, 23 Oct 2019 21:24:33 +0100 |
parents | e0c7ae720cc6 |
children | 9741b52f194a |
files | tests/test_AudioDataSource.py tests/test_AudioReader.py tests/test_AudioSource.py |
diffstat | 3 files changed, 1086 insertions(+), 1095 deletions(-) [+] |
line wrap: on
line diff
--- a/tests/test_AudioDataSource.py Wed Oct 23 20:50:01 2019 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,1087 +0,0 @@ -""" -@author: Amine Sehili <amine.sehili@gmail.com> -September 2015 - -""" - -import unittest -from functools import partial -import sys -import wave -from genty import genty, genty_dataset -from auditok import ( - dataset, - ADSFactory, - AudioDataSource, - AudioReader, - Recorder, - BufferAudioSource, - WaveAudioSource, - DuplicateArgument, -) - - -class TestADSFactoryFileAudioSource(unittest.TestCase): - def setUp(self): - self.audio_source = WaveAudioSource( - filename=dataset.one_to_six_arabic_16000_mono_bc_noise - ) - - def test_ADS_type(self): - - ads = ADSFactory.ads(audio_source=self.audio_source) - - self.assertIsInstance( - ads, - AudioDataSource, - msg="wrong type for ads object, expected: 'AudioDataSource', found: {0}".format( - type(ads) - ), - ) - - def test_default_block_size(self): - ads = ADSFactory.ads(audio_source=self.audio_source) - size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong default block_size, expected: 160, found: {0}".format(size), - ) - - def test_block_size(self): - ads = ADSFactory.ads(audio_source=self.audio_source, block_size=512) - size = ads.block_size - self.assertEqual( - size, - 512, - "Wrong block_size, expected: 512, found: {0}".format(size), - ) - - # with alias keyword - ads = ADSFactory.ads(audio_source=self.audio_source, bs=160) - size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong block_size, expected: 160, found: {0}".format(size), - ) - - def test_block_duration(self): - - ads = ADSFactory.ads( - audio_source=self.audio_source, block_dur=0.01 - ) # 10 ms - size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong block_size, expected: 160, found: {0}".format(size), - ) - - # with alias keyword - ads = ADSFactory.ads(audio_source=self.audio_source, bd=0.025) # 25 ms - size = ads.block_size - self.assertEqual( - size, - 400, - "Wrong block_size, expected: 400, found: {0}".format(size), - ) - - def test_hop_duration(self): - - ads = ADSFactory.ads( - audio_source=self.audio_source, block_dur=0.02, hop_dur=0.01 - ) # 10 ms - size = ads.hop_size - self.assertEqual( - size, 160, "Wrong hop_size, expected: 160, found: {0}".format(size) - ) - - # with alias keyword - ads = ADSFactory.ads( - audio_source=self.audio_source, bd=0.025, hop_dur=0.015 - ) # 15 ms - size = ads.hop_size - self.assertEqual( - size, - 240, - "Wrong block_size, expected: 240, found: {0}".format(size), - ) - - def test_sampling_rate(self): - ads = ADSFactory.ads(audio_source=self.audio_source) - - srate = ads.sampling_rate - self.assertEqual( - srate, - 16000, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) - - def test_sample_width(self): - ads = ADSFactory.ads(audio_source=self.audio_source) - - swidth = ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) - - def test_channels(self): - ads = ADSFactory.ads(audio_source=self.audio_source) - - channels = ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) - - def test_read(self): - ads = ADSFactory.ads(audio_source=self.audio_source, block_size=256) - - ads.open() - ads_data = ads.read() - ads.close() - - audio_source = WaveAudioSource( - filename=dataset.one_to_six_arabic_16000_mono_bc_noise - ) - audio_source.open() - audio_source_data = audio_source.read(256) - audio_source.close() - - self.assertEqual( - ads_data, audio_source_data, "Unexpected data read from ads" - ) - - def test_Limiter_Deco_read(self): - # read a maximum of 0.75 seconds from audio source - ads = ADSFactory.ads(audio_source=self.audio_source, max_time=0.75) - - ads_data = [] - ads.open() - while True: - block = ads.read() - if block is None: - break - ads_data.append(block) - ads.close() - ads_data = b"".join(ads_data) - - audio_source = WaveAudioSource( - filename=dataset.one_to_six_arabic_16000_mono_bc_noise - ) - audio_source.open() - audio_source_data = audio_source.read(int(16000 * 0.75)) - audio_source.close() - - self.assertEqual( - ads_data, audio_source_data, "Unexpected data read from LimiterADS" - ) - - def test_Limiter_Deco_read_limit(self): - # read a maximum of 1.191 seconds from audio source - ads = ADSFactory.ads(audio_source=self.audio_source, max_time=1.191) - total_samples = round(ads.sampling_rate * 1.191) - nb_full_blocks, last_block_size = divmod(total_samples, ads.block_size) - total_samples_with_overlap = ( - nb_full_blocks * ads.block_size + last_block_size - ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) - - total_read = 0 - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - total_read += len(block) - - ads.close() - - self.assertEqual( - total_read, - expected_read_bytes, - "Wrong data length read from LimiterADS, expected: {0}, found: {1}".format( - expected_read_bytes, total_read - ), - ) - - def test_Recorder_Deco_read(self): - ads = ADSFactory.ads( - audio_source=self.audio_source, record=True, block_size=500 - ) - - ads_data = [] - ads.open() - for i in range(10): - block = ads.read() - if block is None: - break - ads_data.append(block) - ads.close() - ads_data = b"".join(ads_data) - - audio_source = WaveAudioSource( - filename=dataset.one_to_six_arabic_16000_mono_bc_noise - ) - audio_source.open() - audio_source_data = audio_source.read(500 * 10) - audio_source.close() - - self.assertEqual( - ads_data, - audio_source_data, - "Unexpected data read from RecorderADS", - ) - - def test_Recorder_Deco_is_rewindable(self): - ads = ADSFactory.ads(audio_source=self.audio_source, record=True) - - self.assertTrue( - ads.rewindable, "RecorderADS.is_rewindable should return True" - ) - - def test_Recorder_Deco_rewind_and_read(self): - ads = ADSFactory.ads( - audio_source=self.audio_source, record=True, block_size=320 - ) - - ads.open() - for i in range(10): - ads.read() - - ads.rewind() - - # read all available data after rewind - ads_data = [] - while True: - block = ads.read() - if block is None: - break - ads_data.append(block) - ads.close() - ads_data = b"".join(ads_data) - - audio_source = WaveAudioSource( - filename=dataset.one_to_six_arabic_16000_mono_bc_noise - ) - audio_source.open() - audio_source_data = audio_source.read(320 * 10) - audio_source.close() - - self.assertEqual( - ads_data, - audio_source_data, - "Unexpected data read from RecorderADS", - ) - - def test_Overlap_Deco_read(self): - - # Use arbitrary valid block_size and hop_size - block_size = 1714 - hop_size = 313 - - ads = ADSFactory.ads( - audio_source=self.audio_source, - block_size=block_size, - hop_size=hop_size, - ) - - # Read all available data overlapping blocks - ads.open() - ads_data = [] - while True: - block = ads.read() - if block is None: - break - ads_data.append(block) - ads.close() - - # Read all data from file and build a BufferAudioSource - fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") - wave_data = fp.readframes(fp.getnframes()) - fp.close() - audio_source = BufferAudioSource( - wave_data, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting - for i, block in enumerate(ads_data): - - tmp = audio_source.read(block_size) - - self.assertEqual( - block, - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - - audio_source.position = (i + 1) * hop_size - - audio_source.close() - - def test_Limiter_Overlap_Deco_read(self): - - block_size = 256 - hop_size = 200 - - ads = ADSFactory.ads( - audio_source=self.audio_source, - max_time=0.50, - block_size=block_size, - hop_size=hop_size, - ) - - # Read all available data overlapping blocks - ads.open() - ads_data = [] - while True: - block = ads.read() - if block is None: - break - ads_data.append(block) - ads.close() - - # Read all data from file and build a BufferAudioSource - fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") - wave_data = fp.readframes(fp.getnframes()) - fp.close() - audio_source = BufferAudioSource( - wave_data, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting - for i, block in enumerate(ads_data): - tmp = audio_source.read(len(block) // (ads.sw * ads.ch)) - self.assertEqual( - len(block), - len(tmp), - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - audio_source.position = (i + 1) * hop_size - - audio_source.close() - - def test_Limiter_Overlap_Deco_read_limit(self): - - block_size = 313 - hop_size = 207 - ads = ADSFactory.ads( - audio_source=self.audio_source, - max_time=1.932, - block_size=block_size, - hop_size=hop_size, - ) - - total_samples = round(ads.sampling_rate * 1.932) - first_read_size = block_size - next_read_size = block_size - hop_size - nb_next_blocks, last_block_size = divmod( - (total_samples - first_read_size), next_read_size - ) - total_samples_with_overlap = ( - first_read_size + next_read_size * nb_next_blocks + last_block_size - ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) - - cache_size = (block_size - hop_size) * ads.sample_width * ads.channels - total_read = cache_size - - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - total_read += len(block) - cache_size - - ads.close() - self.assertEqual( - total_read, - expected_read_bytes, - "Wrong data length read from LimiterADS, expected: {0}, found: {1}".format( - expected_read_bytes, total_read - ), - ) - - def test_Recorder_Overlap_Deco_is_rewindable(self): - ads = ADSFactory.ads( - audio_source=self.audio_source, - block_size=320, - hop_size=160, - record=True, - ) - self.assertTrue( - ads.rewindable, "RecorderADS.is_rewindable should return True" - ) - - def test_Recorder_Overlap_Deco_rewind_and_read(self): - - # Use arbitrary valid block_size and hop_size - block_size = 1600 - hop_size = 400 - - ads = ADSFactory.ads( - audio_source=self.audio_source, - block_size=block_size, - hop_size=hop_size, - record=True, - ) - - # Read all available data overlapping blocks - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - - ads.rewind() - - # Read all data from file and build a BufferAudioSource - fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") - wave_data = fp.readframes(fp.getnframes()) - fp.close() - audio_source = BufferAudioSource( - wave_data, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting - for j in range(i): - - tmp = audio_source.read(block_size) - - self.assertEqual( - ads.read(), - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - audio_source.position = (j + 1) * hop_size - - ads.close() - audio_source.close() - - def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): - - # Use arbitrary valid block_size and hop_size - block_size = 1600 - hop_size = 400 - - ads = ADSFactory.ads( - audio_source=self.audio_source, - max_time=1.50, - block_size=block_size, - hop_size=hop_size, - record=True, - ) - - # Read all available data overlapping blocks - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - - ads.rewind() - - # Read all data from file and build a BufferAudioSource - fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") - wave_data = fp.readframes(fp.getnframes()) - fp.close() - audio_source = BufferAudioSource( - wave_data, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting - for j in range(i): - - tmp = audio_source.read(block_size) - - self.assertEqual( - ads.read(), - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - audio_source.position = (j + 1) * hop_size - - ads.close() - audio_source.close() - - def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_limit(self): - - # Use arbitrary valid block_size and hop_size - block_size = 1000 - hop_size = 200 - - ads = ADSFactory.ads( - audio_source=self.audio_source, - max_time=1.317, - block_size=block_size, - hop_size=hop_size, - record=True, - ) - total_samples = round(ads.sampling_rate * 1.317) - first_read_size = block_size - next_read_size = block_size - hop_size - nb_next_blocks, last_block_size = divmod( - (total_samples - first_read_size), next_read_size - ) - total_samples_with_overlap = ( - first_read_size + next_read_size * nb_next_blocks + last_block_size - ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) - - cache_size = (block_size - hop_size) * ads.sample_width * ads.channels - total_read = cache_size - - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - total_read += len(block) - cache_size - - ads.close() - self.assertEqual( - total_read, - expected_read_bytes, - "Wrong data length read from LimiterADS, expected: {0}, found: {1}".format( - expected_read_bytes, total_read - ), - ) - - -class TestADSFactoryBufferAudioSource(unittest.TestCase): - def setUp(self): - self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" - self.ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - block_size=4, - ) - - def test_ADS_BAS_sampling_rate(self): - srate = self.ads.sampling_rate - self.assertEqual( - srate, - 16, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) - - def test_ADS_BAS_get_sample_width(self): - swidth = self.ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) - - def test_ADS_BAS_get_channels(self): - channels = self.ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) - - def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): - - # Use arbitrary valid block_size and hop_size - block_size = 5 - hop_size = 4 - - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - max_time=0.80, - block_size=block_size, - hop_size=hop_size, - record=True, - ) - - # Read all available data overlapping blocks - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - - ads.rewind() - - # Build a BufferAudioSource - audio_source = BufferAudioSource( - self.signal, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting - for j in range(i): - - tmp = audio_source.read(block_size) - - block = ads.read() - - self.assertEqual( - block, - tmp, - "Unexpected block '{}' (N={}) read from OverlapADS".format( - block, i - ), - ) - audio_source.position = (j + 1) * hop_size - - ads.close() - audio_source.close() - - -class TestADSFactoryAlias(unittest.TestCase): - def setUp(self): - self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" - - def test_sampling_rate_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sr=16, - sample_width=2, - channels=1, - block_dur=0.5, - ) - srate = ads.sampling_rate - self.assertEqual( - srate, - 16, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) - - def test_sampling_rate_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sr=16, - sampling_rate=16, - sample_width=2, - channels=1, - ) - self.assertRaises(DuplicateArgument, func) - - def test_sample_width_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sw=2, - channels=1, - block_dur=0.5, - ) - swidth = ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) - - def test_sample_width_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sw=2, - sample_width=2, - channels=1, - ) - self.assertRaises(DuplicateArgument, func) - - def test_channels_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - ch=1, - block_dur=4, - ) - channels = ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) - - def test_channels_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - ch=1, - channels=1, - ) - self.assertRaises(DuplicateArgument, func) - - def test_block_size_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bs=8, - ) - size = ads.block_size - self.assertEqual( - size, - 8, - "Wrong block_size using bs alias, expected: 8, found: {0}".format( - size - ), - ) - - def test_block_size_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bs=4, - block_size=4, - ) - self.assertRaises(DuplicateArgument, func) - - def test_block_duration_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bd=0.75, - ) - # 0.75 ms = 0.75 * 16 = 12 - size = ads.block_size - self.assertEqual( - size, - 12, - "Wrong block_size set with a block_dur alias 'bd', expected: 8, found: {0}".format( - size - ), - ) - - def test_block_duration_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bd=4, - block_dur=4, - ) - self.assertRaises(DuplicateArgument, func) - - def test_block_size_duration_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bd=4, - bs=12, - ) - self.assertRaises(DuplicateArgument, func) - - def test_hop_duration_alias(self): - - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bd=0.75, - hd=0.5, - ) - size = ads.hop_size - self.assertEqual( - size, - 8, - "Wrong block_size using bs alias, expected: 8, found: {0}".format( - size - ), - ) - - def test_hop_duration_duplicate(self): - - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bd=0.75, - hd=0.5, - hop_dur=0.5, - ) - self.assertRaises(DuplicateArgument, func) - - def test_hop_size_duration_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bs=8, - hs=4, - hd=1, - ) - self.assertRaises(DuplicateArgument, func) - - def test_hop_size_greater_than_block_size(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bs=4, - hs=8, - ) - self.assertRaises(ValueError, func) - - def test_filename_alias(self): - ads = ADSFactory.ads(fn=dataset.one_to_six_arabic_16000_mono_bc_noise) - - def test_filename_duplicate(self): - - func = partial( - ADSFactory.ads, - fn=dataset.one_to_six_arabic_16000_mono_bc_noise, - filename=dataset.one_to_six_arabic_16000_mono_bc_noise, - ) - self.assertRaises(DuplicateArgument, func) - - def test_data_buffer_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - db=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - ) - self.assertRaises(DuplicateArgument, func) - - def test_max_time_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - mt=10, - block_dur=0.5, - ) - self.assertEqual( - ads.max_read, - 10, - "Wrong AudioDataSource.max_read, expected: 10, found: {}".format( - ads.max_read - ), - ) - - def test_max_time_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - mt=True, - max_time=True, - ) - - self.assertRaises(DuplicateArgument, func) - - def test_record_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - rec=True, - block_dur=0.5, - ) - self.assertTrue( - ads.rewindable, "AudioDataSource.rewindable expected to be True" - ) - - def test_record_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - rec=True, - record=True, - ) - self.assertRaises(DuplicateArgument, func) - - def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_alias(self): - - # Use arbitrary valid block_size and hop_size - block_size = 5 - hop_size = 4 - - ads = ADSFactory.ads( - db=self.signal, - sr=16, - sw=2, - ch=1, - mt=0.80, - bs=block_size, - hs=hop_size, - rec=True, - ) - - # Read all available data overlapping blocks - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - - ads.rewind() - - # Build a BufferAudioSource - audio_source = BufferAudioSource( - self.signal, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from AudioDataSource to those read - # from an audio source with manual position definition - for j in range(i): - tmp = audio_source.read(block_size) - block = ads.read() - self.assertEqual( - block, - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - audio_source.position = (j + 1) * hop_size - ads.close() - audio_source.close() - - -def _read_all_data(reader): - blocks = [] - while True: - data = reader.read() - if data is None: - break - blocks.append(data) - return b"".join(blocks) - - -@genty -class TestAudioReader(unittest.TestCase): - - # TODO move all tests here when backward compatibility - # with ADSFactory is dropped - - @genty_dataset( - mono=("mono_400", 0.5, 16000), - multichannel=("3channel_400-800-1600", 0.5, 16000 * 3), - ) - def test_Limiter(self, file_id, max_read, size): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read(size) - - reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read) - reader.open() - data = _read_all_data(reader) - reader.close() - self.assertEqual(data, expected) - - @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) - def test_Recorder(self, file_id): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read() - - reader = AudioReader(input_wav, block_dur=0.1, record=True) - reader.open() - data = _read_all_data(reader) - self.assertEqual(data, expected) - - # rewind many times - for _ in range(3): - reader.rewind() - data = _read_all_data(reader) - self.assertEqual(data, expected) - self.assertEqual(data, reader.data) - reader.close() - - @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) - def test_Recorder_alias(self, file_id): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read() - - reader = Recorder(input_wav, block_dur=0.1) - reader.open() - data = _read_all_data(reader) - self.assertEqual(data, expected) - - # rewind many times - for _ in range(3): - reader.rewind() - data = _read_all_data(reader) - self.assertEqual(data, expected) - self.assertEqual(data, reader.data) - reader.close() - - -if __name__ == "__main__": - unittest.main()
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tests/test_AudioReader.py Wed Oct 23 21:24:33 2019 +0100 @@ -0,0 +1,1079 @@ +""" +@author: Amine Sehili <amine.sehili@gmail.com> +September 2015 + +""" + +import unittest +from functools import partial +import sys +import wave +from genty import genty, genty_dataset +from auditok import ( + dataset, + ADSFactory, + AudioDataSource, + AudioReader, + Recorder, + BufferAudioSource, + WaveAudioSource, + DuplicateArgument, +) + + +class TestADSFactoryFileAudioSource(unittest.TestCase): + def setUp(self): + self.audio_source = WaveAudioSource( + filename=dataset.one_to_six_arabic_16000_mono_bc_noise + ) + + def test_ADS_type(self): + + ads = ADSFactory.ads(audio_source=self.audio_source) + + err_msg = "wrong type for ads object, expected: 'AudioDataSource', " + err_msg += "found: {0}" + self.assertIsInstance( + ads, AudioDataSource, err_msg.format(type(ads)), + ) + + def test_default_block_size(self): + ads = ADSFactory.ads(audio_source=self.audio_source) + size = ads.block_size + self.assertEqual( + size, + 160, + "Wrong default block_size, expected: 160, found: {0}".format(size), + ) + + def test_block_size(self): + ads = ADSFactory.ads(audio_source=self.audio_source, block_size=512) + size = ads.block_size + self.assertEqual( + size, + 512, + "Wrong block_size, expected: 512, found: {0}".format(size), + ) + + # with alias keyword + ads = ADSFactory.ads(audio_source=self.audio_source, bs=160) + size = ads.block_size + self.assertEqual( + size, + 160, + "Wrong block_size, expected: 160, found: {0}".format(size), + ) + + def test_block_duration(self): + + ads = ADSFactory.ads( + audio_source=self.audio_source, block_dur=0.01 + ) # 10 ms + size = ads.block_size + self.assertEqual( + size, + 160, + "Wrong block_size, expected: 160, found: {0}".format(size), + ) + + # with alias keyword + ads = ADSFactory.ads(audio_source=self.audio_source, bd=0.025) # 25 ms + size = ads.block_size + self.assertEqual( + size, + 400, + "Wrong block_size, expected: 400, found: {0}".format(size), + ) + + def test_hop_duration(self): + + ads = ADSFactory.ads( + audio_source=self.audio_source, block_dur=0.02, hop_dur=0.01 + ) # 10 ms + size = ads.hop_size + self.assertEqual( + size, 160, "Wrong hop_size, expected: 160, found: {0}".format(size) + ) + + # with alias keyword + ads = ADSFactory.ads( + audio_source=self.audio_source, bd=0.025, hop_dur=0.015 + ) # 15 ms + size = ads.hop_size + self.assertEqual( + size, + 240, + "Wrong block_size, expected: 240, found: {0}".format(size), + ) + + def test_sampling_rate(self): + ads = ADSFactory.ads(audio_source=self.audio_source) + + srate = ads.sampling_rate + self.assertEqual( + srate, + 16000, + "Wrong sampling rate, expected: 16000, found: {0}".format(srate), + ) + + def test_sample_width(self): + ads = ADSFactory.ads(audio_source=self.audio_source) + + swidth = ads.sample_width + self.assertEqual( + swidth, + 2, + "Wrong sample width, expected: 2, found: {0}".format(swidth), + ) + + def test_channels(self): + ads = ADSFactory.ads(audio_source=self.audio_source) + + channels = ads.channels + self.assertEqual( + channels, + 1, + "Wrong number of channels, expected: 1, found: {0}".format( + channels + ), + ) + + def test_read(self): + ads = ADSFactory.ads(audio_source=self.audio_source, block_size=256) + + ads.open() + ads_data = ads.read() + ads.close() + + audio_source = WaveAudioSource( + filename=dataset.one_to_six_arabic_16000_mono_bc_noise + ) + audio_source.open() + audio_source_data = audio_source.read(256) + audio_source.close() + + self.assertEqual( + ads_data, audio_source_data, "Unexpected data read from ads" + ) + + def test_Limiter_Deco_read(self): + # read a maximum of 0.75 seconds from audio source + ads = ADSFactory.ads(audio_source=self.audio_source, max_time=0.75) + + ads_data = [] + ads.open() + while True: + block = ads.read() + if block is None: + break + ads_data.append(block) + ads.close() + ads_data = b"".join(ads_data) + + audio_source = WaveAudioSource( + filename=dataset.one_to_six_arabic_16000_mono_bc_noise + ) + audio_source.open() + audio_source_data = audio_source.read(int(16000 * 0.75)) + audio_source.close() + + self.assertEqual( + ads_data, audio_source_data, "Unexpected data read from LimiterADS" + ) + + def test_Limiter_Deco_read_limit(self): + # read a maximum of 1.191 seconds from audio source + ads = ADSFactory.ads(audio_source=self.audio_source, max_time=1.191) + total_samples = round(ads.sampling_rate * 1.191) + nb_full_blocks, last_block_size = divmod(total_samples, ads.block_size) + total_samples_with_overlap = ( + nb_full_blocks * ads.block_size + last_block_size + ) + expected_read_bytes = ( + total_samples_with_overlap * ads.sw * ads.channels + ) + + total_read = 0 + ads.open() + i = 0 + while True: + block = ads.read() + if block is None: + break + i += 1 + total_read += len(block) + + ads.close() + err_msg = "Wrong data length read from LimiterADS, expected: {0}, " + err_msg += "found: {1}" + self.assertEqual( + total_read, + expected_read_bytes, + err_msg.format(expected_read_bytes, total_read), + ) + + def test_Recorder_Deco_read(self): + ads = ADSFactory.ads( + audio_source=self.audio_source, record=True, block_size=500 + ) + + ads_data = [] + ads.open() + for i in range(10): + block = ads.read() + if block is None: + break + ads_data.append(block) + ads.close() + ads_data = b"".join(ads_data) + + audio_source = WaveAudioSource( + filename=dataset.one_to_six_arabic_16000_mono_bc_noise + ) + audio_source.open() + audio_source_data = audio_source.read(500 * 10) + audio_source.close() + + self.assertEqual( + ads_data, + audio_source_data, + "Unexpected data read from RecorderADS", + ) + + def test_Recorder_Deco_is_rewindable(self): + ads = ADSFactory.ads(audio_source=self.audio_source, record=True) + + self.assertTrue( + ads.rewindable, "RecorderADS.is_rewindable should return True" + ) + + def test_Recorder_Deco_rewind_and_read(self): + ads = ADSFactory.ads( + audio_source=self.audio_source, record=True, block_size=320 + ) + + ads.open() + for i in range(10): + ads.read() + + ads.rewind() + + # read all available data after rewind + ads_data = [] + while True: + block = ads.read() + if block is None: + break + ads_data.append(block) + ads.close() + ads_data = b"".join(ads_data) + + audio_source = WaveAudioSource( + filename=dataset.one_to_six_arabic_16000_mono_bc_noise + ) + audio_source.open() + audio_source_data = audio_source.read(320 * 10) + audio_source.close() + + self.assertEqual( + ads_data, + audio_source_data, + "Unexpected data read from RecorderADS", + ) + + def test_Overlap_Deco_read(self): + + # Use arbitrary valid block_size and hop_size + block_size = 1714 + hop_size = 313 + + ads = ADSFactory.ads( + audio_source=self.audio_source, + block_size=block_size, + hop_size=hop_size, + ) + + # Read all available data overlapping blocks + ads.open() + ads_data = [] + while True: + block = ads.read() + if block is None: + break + ads_data.append(block) + ads.close() + + # Read all data from file and build a BufferAudioSource + fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") + wave_data = fp.readframes(fp.getnframes()) + fp.close() + audio_source = BufferAudioSource( + wave_data, ads.sampling_rate, ads.sample_width, ads.channels + ) + audio_source.open() + + # Compare all blocks read from OverlapADS to those read + # from an audio source with a manual position setting + for i, block in enumerate(ads_data): + + tmp = audio_source.read(block_size) + + self.assertEqual( + block, + tmp, + "Unexpected block (N={0}) read from OverlapADS".format(i), + ) + + audio_source.position = (i + 1) * hop_size + + audio_source.close() + + def test_Limiter_Overlap_Deco_read(self): + + block_size = 256 + hop_size = 200 + + ads = ADSFactory.ads( + audio_source=self.audio_source, + max_time=0.50, + block_size=block_size, + hop_size=hop_size, + ) + + # Read all available data overlapping blocks + ads.open() + ads_data = [] + while True: + block = ads.read() + if block is None: + break + ads_data.append(block) + ads.close() + + # Read all data from file and build a BufferAudioSource + fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") + wave_data = fp.readframes(fp.getnframes()) + fp.close() + audio_source = BufferAudioSource( + wave_data, ads.sampling_rate, ads.sample_width, ads.channels + ) + audio_source.open() + + # Compare all blocks read from OverlapADS to those read + # from an audio source with a manual position setting + for i, block in enumerate(ads_data): + tmp = audio_source.read(len(block) // (ads.sw * ads.ch)) + self.assertEqual( + len(block), + len(tmp), + "Unexpected block (N={0}) read from OverlapADS".format(i), + ) + audio_source.position = (i + 1) * hop_size + + audio_source.close() + + def test_Limiter_Overlap_Deco_read_limit(self): + + block_size = 313 + hop_size = 207 + ads = ADSFactory.ads( + audio_source=self.audio_source, + max_time=1.932, + block_size=block_size, + hop_size=hop_size, + ) + + total_samples = round(ads.sampling_rate * 1.932) + first_read_size = block_size + next_read_size = block_size - hop_size + nb_next_blocks, last_block_size = divmod( + (total_samples - first_read_size), next_read_size + ) + total_samples_with_overlap = ( + first_read_size + next_read_size * nb_next_blocks + last_block_size + ) + expected_read_bytes = ( + total_samples_with_overlap * ads.sw * ads.channels + ) + + cache_size = (block_size - hop_size) * ads.sample_width * ads.channels + total_read = cache_size + + ads.open() + i = 0 + while True: + block = ads.read() + if block is None: + break + i += 1 + total_read += len(block) - cache_size + + ads.close() + err_msg = "Wrong data length read from LimiterADS, expected: {0}, " + err_msg += "found: {1}" + self.assertEqual( + total_read, + expected_read_bytes, + err_msg.format(expected_read_bytes, total_read), + ) + + def test_Recorder_Overlap_Deco_is_rewindable(self): + ads = ADSFactory.ads( + audio_source=self.audio_source, + block_size=320, + hop_size=160, + record=True, + ) + self.assertTrue( + ads.rewindable, "RecorderADS.is_rewindable should return True" + ) + + def test_Recorder_Overlap_Deco_rewind_and_read(self): + + # Use arbitrary valid block_size and hop_size + block_size = 1600 + hop_size = 400 + + ads = ADSFactory.ads( + audio_source=self.audio_source, + block_size=block_size, + hop_size=hop_size, + record=True, + ) + + # Read all available data overlapping blocks + ads.open() + i = 0 + while True: + block = ads.read() + if block is None: + break + i += 1 + + ads.rewind() + + # Read all data from file and build a BufferAudioSource + fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") + wave_data = fp.readframes(fp.getnframes()) + fp.close() + audio_source = BufferAudioSource( + wave_data, ads.sampling_rate, ads.sample_width, ads.channels + ) + audio_source.open() + + # Compare all blocks read from OverlapADS to those read + # from an audio source with a manual position setting + for j in range(i): + + tmp = audio_source.read(block_size) + + self.assertEqual( + ads.read(), + tmp, + "Unexpected block (N={0}) read from OverlapADS".format(i), + ) + audio_source.position = (j + 1) * hop_size + + ads.close() + audio_source.close() + + def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): + + # Use arbitrary valid block_size and hop_size + block_size = 1600 + hop_size = 400 + + ads = ADSFactory.ads( + audio_source=self.audio_source, + max_time=1.50, + block_size=block_size, + hop_size=hop_size, + record=True, + ) + + # Read all available data overlapping blocks + ads.open() + i = 0 + while True: + block = ads.read() + if block is None: + break + i += 1 + + ads.rewind() + + # Read all data from file and build a BufferAudioSource + fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") + wave_data = fp.readframes(fp.getnframes()) + fp.close() + audio_source = BufferAudioSource( + wave_data, ads.sampling_rate, ads.sample_width, ads.channels + ) + audio_source.open() + + # Compare all blocks read from OverlapADS to those read + # from an audio source with a manual position setting + for j in range(i): + + tmp = audio_source.read(block_size) + + self.assertEqual( + ads.read(), + tmp, + "Unexpected block (N={0}) read from OverlapADS".format(i), + ) + audio_source.position = (j + 1) * hop_size + + ads.close() + audio_source.close() + + def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_limit(self): + + # Use arbitrary valid block_size and hop_size + block_size = 1000 + hop_size = 200 + + ads = ADSFactory.ads( + audio_source=self.audio_source, + max_time=1.317, + block_size=block_size, + hop_size=hop_size, + record=True, + ) + total_samples = round(ads.sampling_rate * 1.317) + first_read_size = block_size + next_read_size = block_size - hop_size + nb_next_blocks, last_block_size = divmod( + (total_samples - first_read_size), next_read_size + ) + total_samples_with_overlap = ( + first_read_size + next_read_size * nb_next_blocks + last_block_size + ) + expected_read_bytes = ( + total_samples_with_overlap * ads.sw * ads.channels + ) + + cache_size = (block_size - hop_size) * ads.sample_width * ads.channels + total_read = cache_size + + ads.open() + i = 0 + while True: + block = ads.read() + if block is None: + break + i += 1 + total_read += len(block) - cache_size + + ads.close() + err_msg = "Wrong data length read from LimiterADS, expected: {0}, " + err_msg += "found: {1}" + self.assertEqual( + total_read, + expected_read_bytes, + err_msg.format(expected_read_bytes, total_read), + ) + + +class TestADSFactoryBufferAudioSource(unittest.TestCase): + def setUp(self): + self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" + self.ads = ADSFactory.ads( + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + block_size=4, + ) + + def test_ADS_BAS_sampling_rate(self): + srate = self.ads.sampling_rate + self.assertEqual( + srate, + 16, + "Wrong sampling rate, expected: 16000, found: {0}".format(srate), + ) + + def test_ADS_BAS_get_sample_width(self): + swidth = self.ads.sample_width + self.assertEqual( + swidth, + 2, + "Wrong sample width, expected: 2, found: {0}".format(swidth), + ) + + def test_ADS_BAS_get_channels(self): + channels = self.ads.channels + self.assertEqual( + channels, + 1, + "Wrong number of channels, expected: 1, found: {0}".format( + channels + ), + ) + + def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): + + # Use arbitrary valid block_size and hop_size + block_size = 5 + hop_size = 4 + + ads = ADSFactory.ads( + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + max_time=0.80, + block_size=block_size, + hop_size=hop_size, + record=True, + ) + + # Read all available data overlapping blocks + ads.open() + i = 0 + while True: + block = ads.read() + if block is None: + break + i += 1 + + ads.rewind() + + # Build a BufferAudioSource + audio_source = BufferAudioSource( + self.signal, ads.sampling_rate, ads.sample_width, ads.channels + ) + audio_source.open() + + # Compare all blocks read from OverlapADS to those read + # from an audio source with a manual position setting + for j in range(i): + + tmp = audio_source.read(block_size) + + block = ads.read() + + self.assertEqual( + block, + tmp, + "Unexpected block '{}' (N={}) read from OverlapADS".format( + block, i + ), + ) + audio_source.position = (j + 1) * hop_size + + ads.close() + audio_source.close() + + +class TestADSFactoryAlias(unittest.TestCase): + def setUp(self): + self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" + + def test_sampling_rate_alias(self): + ads = ADSFactory.ads( + data_buffer=self.signal, + sr=16, + sample_width=2, + channels=1, + block_dur=0.5, + ) + srate = ads.sampling_rate + self.assertEqual( + srate, + 16, + "Wrong sampling rate, expected: 16000, found: {0}".format(srate), + ) + + def test_sampling_rate_duplicate(self): + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + sr=16, + sampling_rate=16, + sample_width=2, + channels=1, + ) + self.assertRaises(DuplicateArgument, func) + + def test_sample_width_alias(self): + ads = ADSFactory.ads( + data_buffer=self.signal, + sampling_rate=16, + sw=2, + channels=1, + block_dur=0.5, + ) + swidth = ads.sample_width + self.assertEqual( + swidth, + 2, + "Wrong sample width, expected: 2, found: {0}".format(swidth), + ) + + def test_sample_width_duplicate(self): + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + sampling_rate=16, + sw=2, + sample_width=2, + channels=1, + ) + self.assertRaises(DuplicateArgument, func) + + def test_channels_alias(self): + ads = ADSFactory.ads( + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + ch=1, + block_dur=4, + ) + channels = ads.channels + self.assertEqual( + channels, + 1, + "Wrong number of channels, expected: 1, found: {0}".format( + channels + ), + ) + + def test_channels_duplicate(self): + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + ch=1, + channels=1, + ) + self.assertRaises(DuplicateArgument, func) + + def test_block_size_alias(self): + ads = ADSFactory.ads( + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + bs=8, + ) + size = ads.block_size + self.assertEqual( + size, + 8, + "Wrong block_size using bs alias, expected: 8, found: {0}".format( + size + ), + ) + + def test_block_size_duplicate(self): + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + bs=4, + block_size=4, + ) + self.assertRaises(DuplicateArgument, func) + + def test_block_duration_alias(self): + ads = ADSFactory.ads( + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + bd=0.75, + ) + # 0.75 ms = 0.75 * 16 = 12 + size = ads.block_size + err_msg = "Wrong block_size set with a block_dur alias 'bd', " + err_msg += "expected: 8, found: {0}" + self.assertEqual( + size, 12, err_msg.format(size), + ) + + def test_block_duration_duplicate(self): + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + bd=4, + block_dur=4, + ) + self.assertRaises(DuplicateArgument, func) + + def test_block_size_duration_duplicate(self): + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + bd=4, + bs=12, + ) + self.assertRaises(DuplicateArgument, func) + + def test_hop_duration_alias(self): + + ads = ADSFactory.ads( + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + bd=0.75, + hd=0.5, + ) + size = ads.hop_size + self.assertEqual( + size, + 8, + "Wrong block_size using bs alias, expected: 8, found: {0}".format( + size + ), + ) + + def test_hop_duration_duplicate(self): + + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + bd=0.75, + hd=0.5, + hop_dur=0.5, + ) + self.assertRaises(DuplicateArgument, func) + + def test_hop_size_duration_duplicate(self): + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + bs=8, + hs=4, + hd=1, + ) + self.assertRaises(DuplicateArgument, func) + + def test_hop_size_greater_than_block_size(self): + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + bs=4, + hs=8, + ) + self.assertRaises(ValueError, func) + + def test_filename_duplicate(self): + + func = partial( + ADSFactory.ads, + fn=dataset.one_to_six_arabic_16000_mono_bc_noise, + filename=dataset.one_to_six_arabic_16000_mono_bc_noise, + ) + self.assertRaises(DuplicateArgument, func) + + def test_data_buffer_duplicate(self): + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + db=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + ) + self.assertRaises(DuplicateArgument, func) + + def test_max_time_alias(self): + ads = ADSFactory.ads( + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + mt=10, + block_dur=0.5, + ) + self.assertEqual( + ads.max_read, + 10, + "Wrong AudioDataSource.max_read, expected: 10, found: {}".format( + ads.max_read + ), + ) + + def test_max_time_duplicate(self): + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + mt=True, + max_time=True, + ) + + self.assertRaises(DuplicateArgument, func) + + def test_record_alias(self): + ads = ADSFactory.ads( + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + rec=True, + block_dur=0.5, + ) + self.assertTrue( + ads.rewindable, "AudioDataSource.rewindable expected to be True" + ) + + def test_record_duplicate(self): + func = partial( + ADSFactory.ads, + data_buffer=self.signal, + sampling_rate=16, + sample_width=2, + channels=1, + rec=True, + record=True, + ) + self.assertRaises(DuplicateArgument, func) + + def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_alias(self): + + # Use arbitrary valid block_size and hop_size + block_size = 5 + hop_size = 4 + + ads = ADSFactory.ads( + db=self.signal, + sr=16, + sw=2, + ch=1, + mt=0.80, + bs=block_size, + hs=hop_size, + rec=True, + ) + + # Read all available data overlapping blocks + ads.open() + i = 0 + while True: + block = ads.read() + if block is None: + break + i += 1 + + ads.rewind() + + # Build a BufferAudioSource + audio_source = BufferAudioSource( + self.signal, ads.sampling_rate, ads.sample_width, ads.channels + ) + audio_source.open() + + # Compare all blocks read from AudioDataSource to those read + # from an audio source with manual position definition + for j in range(i): + tmp = audio_source.read(block_size) + block = ads.read() + self.assertEqual( + block, + tmp, + "Unexpected block (N={0}) read from OverlapADS".format(i), + ) + audio_source.position = (j + 1) * hop_size + ads.close() + audio_source.close() + + +def _read_all_data(reader): + blocks = [] + while True: + data = reader.read() + if data is None: + break + blocks.append(data) + return b"".join(blocks) + + +@genty +class TestAudioReader(unittest.TestCase): + + # TODO move all tests here when backward compatibility + # with ADSFactory is dropped + + @genty_dataset( + mono=("mono_400", 0.5, 16000), + multichannel=("3channel_400-800-1600", 0.5, 16000 * 3), + ) + def test_Limiter(self, file_id, max_read, size): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read(size) + + reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read) + reader.open() + data = _read_all_data(reader) + reader.close() + self.assertEqual(data, expected) + + @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) + def test_Recorder(self, file_id): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read() + + reader = AudioReader(input_wav, block_dur=0.1, record=True) + reader.open() + data = _read_all_data(reader) + self.assertEqual(data, expected) + + # rewind many times + for _ in range(3): + reader.rewind() + data = _read_all_data(reader) + self.assertEqual(data, expected) + self.assertEqual(data, reader.data) + reader.close() + + @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) + def test_Recorder_alias(self, file_id): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read() + + reader = Recorder(input_wav, block_dur=0.1) + reader.open() + data = _read_all_data(reader) + self.assertEqual(data, expected) + + # rewind many times + for _ in range(3): + reader.rewind() + data = _read_all_data(reader) + self.assertEqual(data, expected) + self.assertEqual(data, reader.data) + reader.close() + + +if __name__ == "__main__": + unittest.main()
--- a/tests/test_AudioSource.py Wed Oct 23 20:50:01 2019 +0200 +++ b/tests/test_AudioSource.py Wed Oct 23 21:24:33 2019 +0100 @@ -385,7 +385,7 @@ with self.assertRaises(IndexError): self.audio_source.position_ms = position_ms - def test_sr10_sw1_ch1_position_s_0(self): + def test_sr10_sw1_ch1_initial_position_s_0(self): tp = self.audio_source.position_s self.assertEqual( tp, @@ -393,8 +393,8 @@ msg="wrong time position, expected: 0.0, found: {0} ".format(tp), ) - def test_sr10_sw1_ch1_position_s_1(self): - srate = self.audio_source.get_sampling_rate() + def test_sr10_sw1_ch1_position_s_1_after_read(self): + srate = self.audio_source.sampling_rate # read one second self.audio_source.read(srate) tp = self.audio_source.position_s @@ -834,7 +834,7 @@ ), ) - def test_sr11_sw4_ch1_position_0(self): + def test_sr11_sw4_ch1_intial_position_0(self): pos = self.audio_source.position self.assertEqual( pos, 0, msg="wrong position, expected: 0, found: {0} ".format(pos) @@ -872,7 +872,7 @@ msg="wrong position, expected: 10, found: {0} ".format(pos), ) - def test_sr11_sw4_ch1_position_s_0(self): + def test_sr11_sw4_ch1_initial_position_s_0(self): tp = self.audio_source.position_s self.assertEqual( tp, @@ -880,8 +880,8 @@ msg="wrong time position, expected: 0.0, found: {0} ".format(tp), ) - def test_sr11_sw4_ch1_position_s_1(self): - srate = self.audio_source.get_sampling_rate() + def test_sr11_sw4_ch1_position_s_1_after_read(self): + srate = self.audio_source.sampling_rate # read one second self.audio_source.read(srate) tp = self.audio_source.position_s @@ -913,7 +913,6 @@ ) def test_sr11_sw4_ch1_position_s_1(self): - self.audio_source.position_s = 1 tp = self.audio_source.position_s self.assertEqual(