changeset 212:de60431f343b

Add tests for split with different input types
author Amine Sehili <amine.sehili@gmail.com>
date Sun, 30 Jun 2019 15:37:12 +0100
parents ed6b3cecb407
children 6ff5411ef661
files auditok/core.py tests/test_core.py
diffstat 2 files changed, 76 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/core.py	Wed Jun 26 20:20:55 2019 +0100
+++ b/auditok/core.py	Sun Jun 30 15:37:12 2019 +0100
@@ -97,6 +97,7 @@
 
         params = kwargs.copy()
         params["max_read"] = params.get("max_read", params.get("mr"))
+        params["audio_format"] = params.get("audio_format", params.get("fmt"))
         if isinstance(input, AudioRegion):
             params["sampling_rate"] = input.sr
             params["sample_width"] = input.sw
--- a/tests/test_core.py	Wed Jun 26 20:20:55 2019 +0100
+++ b/tests/test_core.py	Sun Jun 30 15:37:12 2019 +0100
@@ -4,6 +4,7 @@
 from tempfile import TemporaryDirectory
 from genty import genty, genty_dataset
 from auditok import split, AudioRegion, AudioParameterError
+from auditok.util import AudioDataSource
 from auditok.io import (
     _normalize_use_channel,
     _extract_selected_channel,
@@ -203,7 +204,7 @@
             ch=channels,
             **kwargs
         )
-
+        regions = list(regions)
         sample_width = 2
         import numpy as np
 
@@ -215,12 +216,83 @@
             data = _extract_selected_channel(
                 data, channels, sample_width, use_channel=use_channel
             )
-
-        regions = list(regions)
         err_msg = "Wrong number of regions after split, expected: "
         err_msg += "{}, found: {}".format(expected, regions)
         self.assertEqual(len(regions), len(expected), err_msg)
+        for reg, exp in zip(regions, expected):
+            onset, offset = exp
+            exp_data = data[onset * sample_width : offset * sample_width]
+            self.assertEqual(bytes(reg), exp_data)
 
+    @genty_dataset(
+        filename_audio_format=(
+            "tests/data/test_split_10HZ_stereo.raw",
+            {"audio_format": "raw", "sr": 10, "sw": 2, "ch": 2},
+        ),
+        filename_audio_format_short_name=(
+            "tests/data/test_split_10HZ_stereo.raw",
+            {"fmt": "raw", "sr": 10, "sw": 2, "ch": 2},
+        ),
+        filename_no_audio_format=(
+            "tests/data/test_split_10HZ_stereo.raw",
+            {"sr": 10, "sw": 2, "ch": 2},
+        ),
+        filename_no_long_audio_params=(
+            "tests/data/test_split_10HZ_stereo.raw",
+            {"sampling_rate": 10, "sample_width": 2, "channels": 2},
+        ),
+        bytes_=(
+            open("tests/data/test_split_10HZ_stereo.raw", "rb").read(),
+            {"sr": 10, "sw": 2, "ch": 2},
+        ),
+        audio_reader=(
+            AudioDataSource(
+                "tests/data/test_split_10HZ_stereo.raw",
+                sr=10,
+                sw=2,
+                ch=2,
+                block_dur=0.1,
+            ),
+            {},
+        ),
+        audio_region=(
+            AudioRegion(
+                open("tests/data/test_split_10HZ_stereo.raw", "rb").read(),
+                0,
+                10,
+                2,
+                2,
+            ),
+            {},
+        ),
+        audio_source=(
+            get_audio_source(
+                "tests/data/test_split_10HZ_stereo.raw", sr=10, sw=2, ch=2
+            ),
+            {},
+        ),
+    )
+    def test_split_input_type(self, input, kwargs):
+
+        with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp:
+            data = fp.read()
+
+        regions = split(
+            input,
+            min_dur=0.2,
+            max_dur=5,
+            max_silence=0.2,
+            drop_trailing_silence=False,
+            strict_min_dur=False,
+            analysis_window=0.1,
+            **kwargs
+        )
+        regions = list(regions)
+        expected = [(2, 16), (17, 31), (34, 76)]
+        sample_width = 2
+        err_msg = "Wrong number of regions after split, expected: "
+        err_msg += "{}, found: {}".format(expected, regions)
+        self.assertEqual(len(regions), len(expected), err_msg)
         for reg, exp in zip(regions, expected):
             onset, offset = exp
             exp_data = data[onset * sample_width : offset * sample_width]