Mercurial > hg > auditok
changeset 293:755bb58f3db0
Fix bug in AudioDataSource with max_read for multichannel audio
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Sun, 06 Oct 2019 19:15:33 +0100 |
parents | 9907db0843cb |
children | 76b473409a46 |
files | auditok/util.py tests/test_AudioDataSource.py |
diffstat | 2 files changed, 33 insertions(+), 12 deletions(-) [+] |
line wrap: on
line diff
--- a/auditok/util.py Sun Oct 06 18:46:04 2019 +0200 +++ b/auditok/util.py Sun Oct 06 19:15:33 2019 +0100 @@ -721,12 +721,13 @@ super(_Limiter, self).__init__(audio_source) self._max_read = max_read self._max_samples = round(max_read * self.sr) + self._bytes_per_sample = self.sw * self.ch self._read_samples = 0 @property def data(self): data = self._audio_source.data - max_read_bytes = self._max_samples * self.sw * self.ch + max_read_bytes = self._max_samples * self._bytes_per_sample return data[:max_read_bytes] @property @@ -737,12 +738,10 @@ size = min(self._max_samples - self._read_samples, size) if size <= 0: return None - block = self._audio_source.read(size) if block is None: return None - - self._read_samples += len(block) // self._audio_source.sw + self._read_samples += len(block) // self._bytes_per_sample return block def rewind(self):
--- a/tests/test_AudioDataSource.py Sun Oct 06 18:46:04 2019 +0200 +++ b/tests/test_AudioDataSource.py Sun Oct 06 19:15:33 2019 +0100 @@ -7,6 +7,8 @@ import unittest from functools import partial import sys +import wave +from genty import genty, genty_dataset from auditok import ( dataset, ADSFactory, @@ -15,14 +17,6 @@ WaveAudioSource, DuplicateArgument, ) -import wave - - -try: - from builtins import range -except ImportError: - if sys.version_info < (3, 0): - range = xrange class TestADSFactoryFileAudioSource(unittest.TestCase): @@ -1014,5 +1008,33 @@ audio_source.close() +@genty +class TestAudioReader(unittest.TestCase): + + # TODO move all tests here when backward compatibility + # with ADSFactory is dropped + + @genty_dataset( + mono=("mono_400", 0.5, 16000), + multichannel=("3channel_400-800-1600", 0.5, 16000 * 3), + ) + def test_Limiter(self, file_id, max_read, size): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read(size) + + reader = AudioDataSource(input_wav, block_dur=0.1, max_read=max_read) + reader.open() + blocks = [] + while True: + data = reader.read() + if data is None: + break + blocks.append(data) + data = b"".join(blocks) + self.assertEqual(data, expected) + + if __name__ == "__main__": unittest.main()