changeset 294:76b473409a46

Fix bug in AudioDataSource with recorder=True
author Amine Sehili <amine.sehili@gmail.com>
date Sun, 06 Oct 2019 21:12:02 +0200
parents 755bb58f3db0
children 49082909193c
files auditok/util.py tests/test_AudioDataSource.py
diffstat 2 files changed, 39 insertions(+), 21 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/util.py	Sun Oct 06 19:15:33 2019 +0100
+++ b/auditok/util.py	Sun Oct 06 21:12:02 2019 +0200
@@ -665,6 +665,7 @@
         super(_Recorder, self).__init__(audio_source)
         self._cache = []
         self._read_block = self._read_and_cache
+        self._read_from_cache = False
         self._data = None
 
     def read(self, size):
@@ -682,14 +683,17 @@
         return True
 
     def rewind(self):
-        if self._cache:
-            self._data = self._concatenate(self._cache)
+        if self._read_from_cache:
+            self._audio_source.rewind()
+        else:
+            self._data = b"".join(self._cache)
             self._cache = None
             self._audio_source = BufferAudioSource(
                 self._data, self.sr, self.sw, self.ch
             )
             self._read_block = self._audio_source.read
             self.open()
+            self._read_from_cache = True
 
     def _read_and_cache(self, size):
         # Read and save read data
@@ -698,17 +702,6 @@
             self._cache.append(block)
         return block
 
-    def _concatenate(self, data):
-        try:
-            # should always work for python 2
-            # work for python 3 ONLY if data is a list (or an iterator)
-            # whose each element is a 'bytes' objects
-            data = b"".join(data)
-            return data
-        except TypeError:
-            # work for 'str' in python 2 and python 3
-            return "".join(data)
-
 
 class _Limiter(_AudioSourceProxy):
     """
--- a/tests/test_AudioDataSource.py	Sun Oct 06 19:15:33 2019 +0100
+++ b/tests/test_AudioDataSource.py	Sun Oct 06 21:12:02 2019 +0200
@@ -670,7 +670,7 @@
 
 class TestADSFactoryAlias(unittest.TestCase):
     def setUp(self):
-        self.signal = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
+        self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
 
     def test_sampling_rate_alias(self):
         ads = ADSFactory.ads(
@@ -1008,6 +1008,16 @@
         audio_source.close()
 
 
+def _read_all_data(reader):
+    blocks = []
+    while True:
+        data = reader.read()
+        if data is None:
+            break
+        blocks.append(data)
+    return b"".join(blocks)
+
+
 @genty
 class TestAudioReader(unittest.TestCase):
 
@@ -1026,15 +1036,30 @@
 
         reader = AudioDataSource(input_wav, block_dur=0.1, max_read=max_read)
         reader.open()
-        blocks = []
-        while True:
-            data = reader.read()
-            if data is None:
-                break
-            blocks.append(data)
-        data = b"".join(blocks)
+        data = _read_all_data(reader)
+        reader.close()
         self.assertEqual(data, expected)
 
+    @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",))
+    def test_Recorder(self, file_id):
+        input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id)
+        input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id)
+        with open(input_raw, "rb") as fp:
+            expected = fp.read()
+
+        reader = AudioDataSource(input_wav, block_dur=0.1, record=True)
+        reader.open()
+        data = _read_all_data(reader)
+        self.assertEqual(data, expected)
+
+        # rewind many times
+        for _ in range(3):
+            reader.rewind()
+            data = _read_all_data(reader)
+            self.assertEqual(data, expected)
+            self.assertEqual(data, reader.data)
+        reader.close()
+
 
 if __name__ == "__main__":
     unittest.main()