changeset 293:755bb58f3db0

Fix bug in AudioDataSource with max_read for multichannel audio
author Amine Sehili <amine.sehili@gmail.com>
date Sun, 06 Oct 2019 19:15:33 +0100
parents 9907db0843cb
children 76b473409a46
files auditok/util.py tests/test_AudioDataSource.py
diffstat 2 files changed, 33 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/util.py	Sun Oct 06 18:46:04 2019 +0200
+++ b/auditok/util.py	Sun Oct 06 19:15:33 2019 +0100
@@ -721,12 +721,13 @@
         super(_Limiter, self).__init__(audio_source)
         self._max_read = max_read
         self._max_samples = round(max_read * self.sr)
+        self._bytes_per_sample = self.sw * self.ch
         self._read_samples = 0
 
     @property
     def data(self):
         data = self._audio_source.data
-        max_read_bytes = self._max_samples * self.sw * self.ch
+        max_read_bytes = self._max_samples * self._bytes_per_sample
         return data[:max_read_bytes]
 
     @property
@@ -737,12 +738,10 @@
         size = min(self._max_samples - self._read_samples, size)
         if size <= 0:
             return None
-
         block = self._audio_source.read(size)
         if block is None:
             return None
-
-        self._read_samples += len(block) // self._audio_source.sw
+        self._read_samples += len(block) // self._bytes_per_sample
         return block
 
     def rewind(self):
--- a/tests/test_AudioDataSource.py	Sun Oct 06 18:46:04 2019 +0200
+++ b/tests/test_AudioDataSource.py	Sun Oct 06 19:15:33 2019 +0100
@@ -7,6 +7,8 @@
 import unittest
 from functools import partial
 import sys
+import wave
+from genty import genty, genty_dataset
 from auditok import (
     dataset,
     ADSFactory,
@@ -15,14 +17,6 @@
     WaveAudioSource,
     DuplicateArgument,
 )
-import wave
-
-
-try:
-    from builtins import range
-except ImportError:
-    if sys.version_info < (3, 0):
-        range = xrange
 
 
 class TestADSFactoryFileAudioSource(unittest.TestCase):
@@ -1014,5 +1008,33 @@
         audio_source.close()
 
 
+@genty
+class TestAudioReader(unittest.TestCase):
+
+    # TODO move all tests here when backward compatibility
+    # with ADSFactory is dropped
+
+    @genty_dataset(
+        mono=("mono_400", 0.5, 16000),
+        multichannel=("3channel_400-800-1600", 0.5, 16000 * 3),
+    )
+    def test_Limiter(self, file_id, max_read, size):
+        input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id)
+        input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id)
+        with open(input_raw, "rb") as fp:
+            expected = fp.read(size)
+
+        reader = AudioDataSource(input_wav, block_dur=0.1, max_read=max_read)
+        reader.open()
+        blocks = []
+        while True:
+            data = reader.read()
+            if data is None:
+                break
+            blocks.append(data)
+        data = b"".join(blocks)
+        self.assertEqual(data, expected)
+
+
 if __name__ == "__main__":
     unittest.main()