changeset 238:f16fc2c3d12b

Return all data in AudioSource if read called with None or a negative number
author Amine Sehili <amine.sehili@gmail.com>
date Sun, 21 Jul 2019 16:20:20 +0100
parents c684f90cc3cd
children 6c3b56eb8052
files auditok/io.py tests/test_AudioSource.py
diffstat 2 files changed, 71 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/io.py	Sat Jul 20 12:10:11 2019 +0100
+++ b/auditok/io.py	Sun Jul 21 16:20:20 2019 +0100
@@ -456,11 +456,12 @@
     def read(self, size):
         if not self._is_open:
             raise AudioIOError("Stream is not open")
-        bytes_to_read = self._sample_size_all_channels * size
-        data = self._buffer[
-            self._current_position_bytes : self._current_position_bytes
-            + bytes_to_read
-        ]
+        if size is None or size < 0:
+            offset = None
+        else:
+            bytes_to_read = self._sample_size_all_channels * size
+            offset = self._current_position_bytes + bytes_to_read
+        data = self._buffer[self._current_position_bytes : offset]
         if data:
             self._current_position_bytes += len(data)
             return data
@@ -587,7 +588,10 @@
             self._audio_stream = open(self._file, "rb")
 
     def _read_from_stream(self, size):
-        bytes_to_read = size * self._sample_size
+        if size is None or size < 0:
+            bytes_to_read = None
+        else:
+            bytes_to_read = size * self._sample_size
         data = self._audio_stream.read(bytes_to_read)
         return data
 
@@ -622,6 +626,8 @@
             self._audio_stream = wave.open(self._filename)
 
     def _read_from_stream(self, size):
+        if size is None or size < 0:
+            size = -1
         return self._audio_stream.readframes(size)
 
 
--- a/tests/test_AudioSource.py	Sat Jul 20 12:10:11 2019 +0100
+++ b/tests/test_AudioSource.py	Sun Jul 21 16:20:20 2019 +0100
@@ -44,6 +44,36 @@
             1600,
         ),
     )
+    def test_BufferAudioSource_read_all(
+        self, file_suffix, channels, use_channel, frequency
+    ):
+        file = "tests/data/test_16KHZ_{}.raw".format(file_suffix)
+        with open(file, "rb") as fp:
+            expected = fp.read()
+            audio_source = BufferAudioSource(expected, 16000, 2, channels)
+            audio_source.open()
+            data = audio_source.read(None)
+            self.assertEqual(data, expected)
+            audio_source.rewind()
+            data = audio_source.read(-10)
+            self.assertEqual(data, expected)
+            audio_source.close()
+
+
+    @genty_dataset(
+        mono_default=("mono_400Hz", 1, None, 400),
+        mono_mix=("mono_400Hz", 1, "mix", 400),
+        mono_channel_selection=("mono_400Hz", 1, 2, 400),
+        multichannel_default=("3channel_400-800-1600Hz", 3, None, 400),
+        multichannel_channel_select_1st=("3channel_400-800-1600Hz", 3, 1, 400),
+        multichannel_channel_select_2nd=("3channel_400-800-1600Hz", 3, 2, 800),
+        multichannel_channel_select_3rd=(
+            "3channel_400-800-1600Hz",
+            3,
+            3,
+            1600,
+        ),
+    )
     def test_RawAudioSource(
         self, file_suffix, channels, use_channel, frequency
     ):
@@ -55,6 +85,21 @@
         expected = _array_to_bytes(PURE_TONE_DICT[frequency])
         self.assertEqual(data, expected)
 
+        # assert read all data with None
+        audio_source = RawAudioSource(file, 16000, 2, channels, use_channel)
+        audio_source.open()
+        data_read_all = audio_source.read(None)
+        audio_source.close()
+        self.assertEqual(data_read_all, expected)
+
+        # assert read all data with a negative size
+        audio_source = RawAudioSource(file, 16000, 2, channels, use_channel)
+        audio_source.open()
+        data_read_all = audio_source.read(-10)
+        audio_source.close()
+        self.assertEqual(data_read_all, expected)
+
+
     def test_RawAudioSource_mix(self):
         file = "tests/data/test_16KHZ_3channel_400-800-1600Hz.raw"
         audio_source = RawAudioSource(file, 16000, 2, 3, use_channel="mix")
@@ -95,6 +140,20 @@
         expected = _array_to_bytes(PURE_TONE_DICT[frequency])
         self.assertEqual(data, expected)
 
+        # assert read all data with None
+        audio_source = WaveAudioSource(file, use_channel)
+        audio_source.open()
+        data_read_all = audio_source.read(None)
+        audio_source.close()
+        self.assertEqual(data_read_all, expected)
+
+        # assert read all data with a negative size
+        audio_source = WaveAudioSource(file, use_channel)
+        audio_source.open()
+        data_read_all = audio_source.read(-10)
+        audio_source.close()
+        self.assertEqual(data_read_all, expected)
+
     def test_WaveAudioSource_mix(self):
         file = "tests/data/test_16KHZ_3channel_400-800-1600Hz.wav"
         audio_source = WaveAudioSource(file, use_channel="mix")