changeset 207:12e6837c5961

Add tests for split
author Amine Sehili <amine.sehili@gmail.com>
date Wed, 12 Jun 2019 20:27:32 +0100
parents b10480e4453e
children 29472a5a798a
files auditok/core.py tests/data/test_split_10HZ_mono.raw tests/test_core.py
diffstat 3 files changed, 115 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/core.py	Wed Jun 05 20:58:15 2019 +0100
+++ b/auditok/core.py	Wed Jun 12 20:27:32 2019 +0100
@@ -26,7 +26,6 @@
     max_silence=0.3,
     drop_trailing_silence=False,
     strict_min_dur=False,
-    analysis_window=0.01,
     **kwargs
 ):
     """Splits audio data and returns a generator of `AudioRegion`s
@@ -90,11 +89,13 @@
     """
     if isinstance(input, AudioDataSource):
         source = input
+        analysis_window = source.block_dur
     else:
-        block_dur = kwargs.get("analysis_window", DEFAULT_ANALYSIS_WINDOW)
+        analysis_window = kwargs.get(
+            "analysis_window", DEFAULT_ANALYSIS_WINDOW
+        )
         max_read = kwargs.get("max_read")
         params = kwargs.copy()
-        print(isinstance(input, AudioRegion))
         if isinstance(input, AudioRegion):
             params["sampling_rate"] = input.sr
             params["sample_width"] = input.sw
@@ -102,7 +103,7 @@
             input = bytes(input)
 
         source = AudioDataSource(
-            input, block_dur=block_dur, max_read=max_read, **params
+            input, block_dur=analysis_window, max_read=max_read, **params
         )
 
     validator = kwargs.get("validator")
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/data/test_split_10HZ_mono.raw	Wed Jun 12 20:27:32 2019 +0100
@@ -0,0 +1,1 @@
+,,@@@@@@@@@@@@,,,XXXXXXX,,@@@,,,,,@@@@@@@@@@@@@@@@@@@@XXXXXXXXXXXXXXXXXXXXXX
\ No newline at end of file
--- a/tests/test_core.py	Wed Jun 05 20:58:15 2019 +0100
+++ b/tests/test_core.py	Wed Jun 12 20:27:32 2019 +0100
@@ -1,9 +1,9 @@
 import os
-import unittest
+from unittest import TestCase
 from random import random
 from tempfile import TemporaryDirectory
 from genty import genty, genty_dataset
-from auditok import AudioRegion, AudioParameterError
+from auditok import split, AudioRegion, AudioParameterError
 
 
 def _make_random_length_regions(
@@ -22,7 +22,113 @@
 
 
 @genty
-class TestAudioRegion(unittest.TestCase):
+class TestSplit(TestCase):
+    @genty_dataset(
+        simple=(
+            0.2,
+            5,
+            0.2,
+            False,
+            False,
+            {"eth": 50},
+            [(2, 16), (17, 31), (34, 76)],
+        ),
+        low_energy_threshold=(
+            0.2,
+            5,
+            0.2,
+            False,
+            False,
+            {"energy_threshold": 40},
+            [(0, 50), (50, 76)],
+        ),
+        high_energy_threshold=(
+            0.2,
+            5,
+            0.2,
+            False,
+            False,
+            {"energy_threshold": 60},
+            [],
+        ),
+        trim_leading_and_trailing_silence=(
+            0.2,
+            10,  # use long max_dur
+            0.5,  # and a max_silence longer than any inter-region silence
+            True,
+            False,
+            {"eth": 50},
+            [(2, 76)],
+        ),
+        drop_trailing_silence=(
+            0.2,
+            5,
+            0.2,
+            True,
+            False,
+            {"eth": 50},
+            [(2, 14), (17, 29), (34, 76)],
+        ),
+        drop_trailing_silence_2=(
+            1.5,
+            5,
+            0.2,
+            True,
+            False,
+            {"eth": 50},
+            [(34, 76)],
+        ),
+        strict_min_dur=(
+            0.3,
+            2,
+            0.2,
+            False,
+            True,
+            {"eth": 50},
+            [(2, 16), (17, 31), (34, 54), (54, 74)],
+        ),
+    )
+    def test_split_params(
+        self,
+        min_dur,
+        max_dur,
+        max_silence,
+        drop_trailing_silence,
+        strict_min_dur,
+        kwargs,
+        expected,
+    ):
+        with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp:
+            data = fp.read()
+
+        regions = split(
+            data,
+            min_dur,
+            max_dur,
+            max_silence,
+            drop_trailing_silence,
+            strict_min_dur,
+            analysis_window=0.1,
+            sr=10,
+            sw=2,
+            ch=1,
+            **kwargs
+        )
+        regions = list(regions)
+        print(regions)
+        err_msg = "Wrong number of regions after split, expected: "
+        err_msg += "{}, found: {}".format(len(regions), len(expected))
+        self.assertEqual(len(regions), len(expected), err_msg)
+
+        sample_width = 2
+        for reg, exp in zip(regions, expected):
+            onset, offset = exp
+            exp_data = data[onset * sample_width : offset * sample_width]
+            self.assertEqual(bytes(reg), exp_data)
+
+
+@genty
+class TestAudioRegion(TestCase):
     @genty_dataset(
         simple=(b"\0" * 8000, 0, 8000, 1, 1, 1, 1, 1000),
         one_ms_less_than_1_sec=(