Mercurial > hg > auditok
changeset 238:f16fc2c3d12b
Return all data in AudioSource if read called with None or a negative number
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Sun, 21 Jul 2019 16:20:20 +0100 |
parents | c684f90cc3cd |
children | 6c3b56eb8052 |
files | auditok/io.py tests/test_AudioSource.py |
diffstat | 2 files changed, 71 insertions(+), 6 deletions(-) [+] |
line wrap: on
line diff
--- a/auditok/io.py Sat Jul 20 12:10:11 2019 +0100 +++ b/auditok/io.py Sun Jul 21 16:20:20 2019 +0100 @@ -456,11 +456,12 @@ def read(self, size): if not self._is_open: raise AudioIOError("Stream is not open") - bytes_to_read = self._sample_size_all_channels * size - data = self._buffer[ - self._current_position_bytes : self._current_position_bytes - + bytes_to_read - ] + if size is None or size < 0: + offset = None + else: + bytes_to_read = self._sample_size_all_channels * size + offset = self._current_position_bytes + bytes_to_read + data = self._buffer[self._current_position_bytes : offset] if data: self._current_position_bytes += len(data) return data @@ -587,7 +588,10 @@ self._audio_stream = open(self._file, "rb") def _read_from_stream(self, size): - bytes_to_read = size * self._sample_size + if size is None or size < 0: + bytes_to_read = None + else: + bytes_to_read = size * self._sample_size data = self._audio_stream.read(bytes_to_read) return data @@ -622,6 +626,8 @@ self._audio_stream = wave.open(self._filename) def _read_from_stream(self, size): + if size is None or size < 0: + size = -1 return self._audio_stream.readframes(size)
--- a/tests/test_AudioSource.py Sat Jul 20 12:10:11 2019 +0100 +++ b/tests/test_AudioSource.py Sun Jul 21 16:20:20 2019 +0100 @@ -44,6 +44,36 @@ 1600, ), ) + def test_BufferAudioSource_read_all( + self, file_suffix, channels, use_channel, frequency + ): + file = "tests/data/test_16KHZ_{}.raw".format(file_suffix) + with open(file, "rb") as fp: + expected = fp.read() + audio_source = BufferAudioSource(expected, 16000, 2, channels) + audio_source.open() + data = audio_source.read(None) + self.assertEqual(data, expected) + audio_source.rewind() + data = audio_source.read(-10) + self.assertEqual(data, expected) + audio_source.close() + + + @genty_dataset( + mono_default=("mono_400Hz", 1, None, 400), + mono_mix=("mono_400Hz", 1, "mix", 400), + mono_channel_selection=("mono_400Hz", 1, 2, 400), + multichannel_default=("3channel_400-800-1600Hz", 3, None, 400), + multichannel_channel_select_1st=("3channel_400-800-1600Hz", 3, 1, 400), + multichannel_channel_select_2nd=("3channel_400-800-1600Hz", 3, 2, 800), + multichannel_channel_select_3rd=( + "3channel_400-800-1600Hz", + 3, + 3, + 1600, + ), + ) def test_RawAudioSource( self, file_suffix, channels, use_channel, frequency ): @@ -55,6 +85,21 @@ expected = _array_to_bytes(PURE_TONE_DICT[frequency]) self.assertEqual(data, expected) + # assert read all data with None + audio_source = RawAudioSource(file, 16000, 2, channels, use_channel) + audio_source.open() + data_read_all = audio_source.read(None) + audio_source.close() + self.assertEqual(data_read_all, expected) + + # assert read all data with a negative size + audio_source = RawAudioSource(file, 16000, 2, channels, use_channel) + audio_source.open() + data_read_all = audio_source.read(-10) + audio_source.close() + self.assertEqual(data_read_all, expected) + + def test_RawAudioSource_mix(self): file = "tests/data/test_16KHZ_3channel_400-800-1600Hz.raw" audio_source = RawAudioSource(file, 16000, 2, 3, use_channel="mix") @@ -95,6 +140,20 @@ expected = _array_to_bytes(PURE_TONE_DICT[frequency]) self.assertEqual(data, expected) + # assert read all data with None + audio_source = WaveAudioSource(file, use_channel) + audio_source.open() + data_read_all = audio_source.read(None) + audio_source.close() + self.assertEqual(data_read_all, expected) + + # assert read all data with a negative size + audio_source = WaveAudioSource(file, use_channel) + audio_source.open() + data_read_all = audio_source.read(-10) + audio_source.close() + self.assertEqual(data_read_all, expected) + def test_WaveAudioSource_mix(self): file = "tests/data/test_16KHZ_3channel_400-800-1600Hz.wav" audio_source = WaveAudioSource(file, use_channel="mix")