Mercurial > hg > auditok
changeset 294:76b473409a46
Fix bug in AudioDataSource with recorder=True
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Sun, 06 Oct 2019 21:12:02 +0200 |
parents | 755bb58f3db0 |
children | 49082909193c |
files | auditok/util.py tests/test_AudioDataSource.py |
diffstat | 2 files changed, 39 insertions(+), 21 deletions(-) [+] |
line wrap: on
line diff
--- a/auditok/util.py Sun Oct 06 19:15:33 2019 +0100 +++ b/auditok/util.py Sun Oct 06 21:12:02 2019 +0200 @@ -665,6 +665,7 @@ super(_Recorder, self).__init__(audio_source) self._cache = [] self._read_block = self._read_and_cache + self._read_from_cache = False self._data = None def read(self, size): @@ -682,14 +683,17 @@ return True def rewind(self): - if self._cache: - self._data = self._concatenate(self._cache) + if self._read_from_cache: + self._audio_source.rewind() + else: + self._data = b"".join(self._cache) self._cache = None self._audio_source = BufferAudioSource( self._data, self.sr, self.sw, self.ch ) self._read_block = self._audio_source.read self.open() + self._read_from_cache = True def _read_and_cache(self, size): # Read and save read data @@ -698,17 +702,6 @@ self._cache.append(block) return block - def _concatenate(self, data): - try: - # should always work for python 2 - # work for python 3 ONLY if data is a list (or an iterator) - # whose each element is a 'bytes' objects - data = b"".join(data) - return data - except TypeError: - # work for 'str' in python 2 and python 3 - return "".join(data) - class _Limiter(_AudioSourceProxy): """
--- a/tests/test_AudioDataSource.py Sun Oct 06 19:15:33 2019 +0100 +++ b/tests/test_AudioDataSource.py Sun Oct 06 21:12:02 2019 +0200 @@ -670,7 +670,7 @@ class TestADSFactoryAlias(unittest.TestCase): def setUp(self): - self.signal = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" + self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" def test_sampling_rate_alias(self): ads = ADSFactory.ads( @@ -1008,6 +1008,16 @@ audio_source.close() +def _read_all_data(reader): + blocks = [] + while True: + data = reader.read() + if data is None: + break + blocks.append(data) + return b"".join(blocks) + + @genty class TestAudioReader(unittest.TestCase): @@ -1026,15 +1036,30 @@ reader = AudioDataSource(input_wav, block_dur=0.1, max_read=max_read) reader.open() - blocks = [] - while True: - data = reader.read() - if data is None: - break - blocks.append(data) - data = b"".join(blocks) + data = _read_all_data(reader) + reader.close() self.assertEqual(data, expected) + @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) + def test_Recorder(self, file_id): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read() + + reader = AudioDataSource(input_wav, block_dur=0.1, record=True) + reader.open() + data = _read_all_data(reader) + self.assertEqual(data, expected) + + # rewind many times + for _ in range(3): + reader.rewind() + data = _read_all_data(reader) + self.assertEqual(data, expected) + self.assertEqual(data, reader.data) + reader.close() + if __name__ == "__main__": unittest.main()