changeset 160:017994445d87

Add tests for RawAudioSource
author Amine Sehili <amine.sehili@gmail.com>
date Wed, 27 Feb 2019 21:16:44 +0100
parents 3439ba35aba0
children e91d97f7a632
files tests/test_audio_source.py
diffstat 1 files changed, 60 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/tests/test_audio_source.py	Tue Feb 26 20:20:18 2019 +0100
+++ b/tests/test_audio_source.py	Wed Feb 27 21:16:44 2019 +0100
@@ -1,8 +1,67 @@
 """
 @author: Amine Sehili <amine.sehili@gmail.com>
 """
+from array import array
 import unittest
-from auditok import BufferAudioSource, AudioParameterError
+from genty import genty, genty_dataset
+from auditok.io import (
+    AudioParameterError,
+    _array_to_bytes,
+    DATA_FORMAT,
+    BufferAudioSource,
+    RawAudioSource,
+    WaveAudioSource,
+)
+from test_util import PURE_TONE_DICT
+
+
+def audio_source_read_all_gen(audio_source, size=None):
+    if size is None:
+        size = int(audio_source.sr * 0.1)  # 100ms
+    while True:
+        data = audio_source.read(size)
+        if data is None:
+            break
+        yield data
+
+
+@genty
+class TestAudioSource(unittest.TestCase):
+
+    # TODO when use_channel is None, return samples from all channels
+
+    @genty_dataset(
+        mono_default=("mono_400Hz", 1, None, 400),
+        mono_mix=("mono_400Hz", 1, "mix", 400),
+        mono_channel_selection=("mono_400Hz", 1, 2, 400),
+        multichannel_default=("3channel_400-800-1600Hz", 3, None, 400),
+        multichannel_channel_selection=("3channel_400-800-1600Hz", 3, 1, 800),
+    )
+    def test_RawAudioSource(
+        self, file_suffix, channels, use_channel, frequency
+    ):
+        file = "tests/data/test_16KHZ_{}.raw".format(file_suffix)
+        audio_source = RawAudioSource(file, 16000, 2, channels, use_channel)
+        audio_source.open()
+        data = b"".join(audio_source_read_all_gen(audio_source))
+        audio_source.close()
+        expected = _array_to_bytes(PURE_TONE_DICT[frequency])
+        self.assertEqual(data, expected)
+
+    def test_RawAudioSource_mix(self):
+        file = "tests/data/test_16KHZ_3channel_400-800-1600Hz.raw"
+        audio_source = RawAudioSource(file, 16000, 2, 3, use_channel="mix")
+        audio_source.open()
+        data = b"".join(audio_source_read_all_gen(audio_source))
+        audio_source.close()
+
+        mono_channels = [PURE_TONE_DICT[freq] for freq in [400, 800, 1600]]
+        fmt = DATA_FORMAT[2]
+        expected = _array_to_bytes(
+            array(fmt, (sum(samples) // 3 for samples in zip(*mono_channels)))
+        )
+        expected = expected
+        self.assertEqual(data, expected)
 
 
 class TestBufferAudioSource_SR10_SW1_CH1(unittest.TestCase):