changeset 299:73989d247f4e

Add test for callable validator in split
author Amine Sehili <amine.sehili@gmail.com>
date Tue, 08 Oct 2019 20:16:11 +0100
parents d5cbf4fc1416
children 1732213b290a
files tests/test_core.py
diffstat 1 files changed, 218 insertions(+), 30 deletions(-) [+]
line wrap: on
line diff
--- a/tests/test_core.py	Tue Oct 08 20:01:34 2019 +0100
+++ b/tests/test_core.py	Tue Oct 08 20:16:11 2019 +0100
@@ -16,7 +16,9 @@
 )
 
 
-def _make_random_length_regions(byte_seq, sampling_rate, sample_width, channels):
+def _make_random_length_regions(
+    byte_seq, sampling_rate, sample_width, channels
+):
     regions = []
     for b in byte_seq:
         duration = round(random() * 10, 6)
@@ -57,7 +59,15 @@
 @genty
 class TestSplit(TestCase):
     @genty_dataset(
-        simple=(0.2, 5, 0.2, False, False, {"eth": 50}, [(2, 16), (17, 31), (34, 76)]),
+        simple=(
+            0.2,
+            5,
+            0.2,
+            False,
+            False,
+            {"eth": 50},
+            [(2, 16), (17, 31), (34, 76)],
+        ),
         short_max_dur=(
             0.3,
             2,
@@ -87,7 +97,15 @@
             {"energy_threshold": 40},
             [(0, 50), (50, 76)],
         ),
-        high_energy_threshold=(0.2, 5, 0.2, False, False, {"energy_threshold": 60}, []),
+        high_energy_threshold=(
+            0.2,
+            5,
+            0.2,
+            False,
+            False,
+            {"energy_threshold": 60},
+            [],
+        ),
         trim_leading_and_trailing_silence=(
             0.2,
             10,  # use long max_dur
@@ -106,7 +124,15 @@
             {"eth": 50},
             [(2, 14), (17, 29), (34, 76)],
         ),
-        drop_trailing_silence_2=(1.5, 5, 0.2, True, False, {"eth": 50}, [(34, 76)]),
+        drop_trailing_silence_2=(
+            1.5,
+            5,
+            0.2,
+            True,
+            False,
+            {"eth": 50},
+            [(34, 76)],
+        ),
         strict_min_dur=(
             0.3,
             2,
@@ -170,7 +196,7 @@
             exp_data = data[onset * sample_width : offset * sample_width]
             self.assertEqual(bytes(reg), exp_data)
             self.assertEqual(reg, reg_ar)
-  
+
     @genty_dataset(
         stereo_all_default=(2, {}, [(2, 32), (34, 76)]),
         mono_max_read=(1, {"max_read": 5}, [(2, 16), (17, 31), (34, 50)]),
@@ -191,7 +217,11 @@
             {"eth": 50, "use_channel": 0},
             [(2, 16), (17, 31), (34, 76)],
         ),
-        stereo_use_channel_no_use_channel_given=(2, {"eth": 50}, [(2, 32), (34, 76)]),
+        stereo_use_channel_no_use_channel_given=(
+            2,
+            {"eth": 50},
+            [(2, 32), (34, 76)],
+        ),
         stereo_use_channel_minus_2=(
             2,
             {"eth": 50, "use_channel": -2},
@@ -199,14 +229,22 @@
         ),
         stereo_uc_2=(2, {"eth": 50, "uc": 1}, [(10, 32), (36, 76)]),
         stereo_uc_minus_1=(2, {"eth": 50, "uc": -1}, [(10, 32), (36, 76)]),
-        mono_uc_mix=(1, {"eth": 50, "uc": "mix"}, [(2, 16), (17, 31), (34, 76)]),
+        mono_uc_mix=(
+            1,
+            {"eth": 50, "uc": "mix"},
+            [(2, 16), (17, 31), (34, 76)],
+        ),
         stereo_use_channel_mix=(
             2,
             {"energy_threshold": 53.5, "use_channel": "mix"},
             [(54, 76)],
         ),
         stereo_uc_mix=(2, {"eth": 52, "uc": "mix"}, [(17, 26), (54, 76)]),
-        stereo_uc_mix_default_eth=(2, {"uc": "mix"}, [(10, 16), (17, 31), (36, 76)]),
+        stereo_uc_mix_default_eth=(
+            2,
+            {"uc": "mix"},
+            [(10, 16), (17, 31), (36, 76)],
+        ),
     )
     def test_split_kwargs(self, channels, kwargs, expected):
 
@@ -253,14 +291,37 @@
         sample_size_bytes = sample_width * channels
         for reg, reg_ar, exp in zip(regions, regions_ar, expected):
             onset, offset = exp
-            exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes]
+            exp_data = data[
+                onset * sample_size_bytes : offset * sample_size_bytes
+            ]
             self.assertEqual(len(bytes(reg)), len(exp_data))
             self.assertEqual(reg, reg_ar)
 
     @genty_dataset(
-        mono_aw_0_2_max_silence_0_2=(0.2, 5, 0.2, 1, {"aw": 0.2}, [(2, 30), (34, 76)]),
-        mono_aw_0_2_max_silence_0_3=(0.2, 5, 0.3, 1, {"aw": 0.2}, [(2, 30), (34, 76)]),
-        mono_aw_0_2_max_silence_0_4=(0.2, 5, 0.4, 1, {"aw": 0.2}, [(2, 32), (34, 76)]),
+        mono_aw_0_2_max_silence_0_2=(
+            0.2,
+            5,
+            0.2,
+            1,
+            {"aw": 0.2},
+            [(2, 30), (34, 76)],
+        ),
+        mono_aw_0_2_max_silence_0_3=(
+            0.2,
+            5,
+            0.3,
+            1,
+            {"aw": 0.2},
+            [(2, 30), (34, 76)],
+        ),
+        mono_aw_0_2_max_silence_0_4=(
+            0.2,
+            5,
+            0.4,
+            1,
+            {"aw": 0.2},
+            [(2, 32), (34, 76)],
+        ),
         mono_aw_0_2_max_silence_0=(
             0.2,
             5,
@@ -278,9 +339,30 @@
             {"aw": 0.3},
             [(3, 12), (15, 24), (36, 76)],
         ),
-        mono_aw_0_3_max_silence_0_3=(0.3, 5, 0.3, 1, {"aw": 0.3}, [(3, 27), (36, 76)]),
-        mono_aw_0_3_max_silence_0_5=(0.3, 5, 0.5, 1, {"aw": 0.3}, [(3, 27), (36, 76)]),
-        mono_aw_0_3_max_silence_0_6=(0.3, 5, 0.6, 1, {"aw": 0.3}, [(3, 30), (36, 76)]),
+        mono_aw_0_3_max_silence_0_3=(
+            0.3,
+            5,
+            0.3,
+            1,
+            {"aw": 0.3},
+            [(3, 27), (36, 76)],
+        ),
+        mono_aw_0_3_max_silence_0_5=(
+            0.3,
+            5,
+            0.5,
+            1,
+            {"aw": 0.3},
+            [(3, 27), (36, 76)],
+        ),
+        mono_aw_0_3_max_silence_0_6=(
+            0.3,
+            5,
+            0.6,
+            1,
+            {"aw": 0.3},
+            [(3, 30), (36, 76)],
+        ),
         mono_aw_0_4_max_silence_0=(
             0.2,
             5,
@@ -297,7 +379,14 @@
             {"aw": 0.4},
             [(4, 12), (16, 24), (36, 76)],
         ),
-        mono_aw_0_4_max_silence_0_4=(0.2, 5, 0.4, 1, {"aw": 0.4}, [(4, 28), (36, 76)]),
+        mono_aw_0_4_max_silence_0_4=(
+            0.2,
+            5,
+            0.4,
+            1,
+            {"aw": 0.4},
+            [(4, 28), (36, 76)],
+        ),
         stereo_uc_0_analysis_window_0_2=(
             0.2,
             5,
@@ -528,7 +617,58 @@
         sample_size_bytes = sample_width * channels
         for reg, reg_ar, exp in zip(regions, regions_ar, expected):
             onset, offset = exp
-            exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes]
+            exp_data = data[
+                onset * sample_size_bytes : offset * sample_size_bytes
+            ]
+            self.assertEqual(bytes(reg), exp_data)
+            self.assertEqual(reg, reg_ar)
+
+    def test_split_custom_validator(self):
+        filename = "tests/data/test_split_10HZ_mono.raw"
+        with open(filename, "rb") as fp:
+            data = fp.read()
+
+        regions = split(
+            data,
+            min_dur=0.2,
+            max_dur=5,
+            max_silence=0.2,
+            drop_trailing_silence=False,
+            strict_min_dur=False,
+            sr=10,
+            sw=2,
+            ch=1,
+            analysis_window=0.1,
+            validator=lambda x: array_("h", x)[0] >= 320,
+        )
+
+        region = AudioRegion(data, 10, 2, 1)
+        regions_ar = region.split(
+            min_dur=0.2,
+            max_dur=5,
+            max_silence=0.2,
+            drop_trailing_silence=False,
+            strict_min_dur=False,
+            analysis_window=0.1,
+            validator=lambda x: array_("h", x)[0] >= 320,
+        )
+
+        expected = [(2, 16), (17, 31), (34, 76)]
+        regions = list(regions)
+        regions_ar = list(regions_ar)
+        err_msg = "Wrong number of regions after split, expected: "
+        err_msg += "{}, found: {}".format(len(expected), len(regions))
+        self.assertEqual(len(regions), len(expected), err_msg)
+        err_msg = "Wrong number of regions after AudioRegion.split, expected: "
+        err_msg += "{}, found: {}".format(len(expected), len(regions_ar))
+        self.assertEqual(len(regions_ar), len(expected), err_msg)
+
+        sample_size_bytes = 2
+        for reg, reg_ar, exp in zip(regions, regions_ar, expected):
+            onset, offset = exp
+            exp_data = data[
+                onset * sample_size_bytes : offset * sample_size_bytes
+            ]
             self.assertEqual(bytes(reg), exp_data)
             self.assertEqual(reg, reg_ar)
 
@@ -565,7 +705,10 @@
         ),
         audio_region=(
             AudioRegion(
-                open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), 10, 2, 2
+                open("tests/data/test_split_10HZ_stereo.raw", "rb").read(),
+                10,
+                2,
+                2,
             ),
             {},
         ),
@@ -599,7 +742,9 @@
         self.assertEqual(len(regions), len(expected), err_msg)
         for reg, exp in zip(regions, expected):
             onset, offset = exp
-            exp_data = data[onset * sample_width * 2 : offset * sample_width * 2]
+            exp_data = data[
+                onset * sample_width * 2 : offset * sample_width * 2
+            ]
             self.assertEqual(bytes(reg), exp_data)
 
     @genty_dataset(
@@ -704,7 +849,7 @@
         err_msg += " Analysis windows should at least be 1/10 to cover one "
         err_msg += "single data sample"
         self.assertEqual(err_msg, str(val_err.exception))
-    
+
     def test_split_and_plot(self):
 
         with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp:
@@ -734,11 +879,21 @@
             expected_regions.append(AudioRegion(data[onset:offset], 10, 2, 1))
         self.assertEqual(regions, expected_regions)
 
+
 @genty
 class TestAudioRegion(TestCase):
     @genty_dataset(
         simple=(b"\0" * 8000, 0, 8000, 1, 1, 1, 1, 1000),
-        one_ms_less_than_1_sec=(b"\0" * 7992, 0, 8000, 1, 1, 0.999, 0.999, 999),
+        one_ms_less_than_1_sec=(
+            b"\0" * 7992,
+            0,
+            8000,
+            1,
+            1,
+            0.999,
+            0.999,
+            999,
+        ),
         tree_quarter_ms_less_than_1_sec=(
             b"\0" * 7994,
             0,
@@ -749,7 +904,16 @@
             0.99925,
             999,
         ),
-        half_ms_less_than_1_sec=(b"\0" * 7996, 0, 8000, 1, 1, 0.9995, 0.9995, 1000),
+        half_ms_less_than_1_sec=(
+            b"\0" * 7996,
+            0,
+            8000,
+            1,
+            1,
+            0.9995,
+            0.9995,
+            1000,
+        ),
         quarter_ms_less_than_1_sec=(
             b"\0" * 7998,
             0,
@@ -906,7 +1070,11 @@
         start_2=("output_{meta.start}.wav", 1.233712, "output_1.233712.wav"),
         start_3=("output_{meta.start:.2f}.wav", 1.2300001, "output_1.23.wav"),
         start_4=("output_{meta.start:.3f}.wav", 1.233712, "output_1.234.wav"),
-        start_5=("output_{meta.start:.8f}.wav", 1.233712, "output_1.23371200.wav"),
+        start_5=(
+            "output_{meta.start:.8f}.wav",
+            1.233712,
+            "output_1.23371200.wav",
+        ),
         start_end_duration=(
             "output_{meta.start}_{meta.end}_{duration}.wav",
             1.455,
@@ -1001,8 +1169,16 @@
             slice(-5000, None),
             b"a" * 160,
         ),
-        big_negative_stop=(AudioRegion(b"a" * 160, 160, 1, 1), slice(None, -1500), b""),
-        empty=(AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 0), b""),
+        big_negative_stop=(
+            AudioRegion(b"a" * 160, 160, 1, 1),
+            slice(None, -1500),
+            b"",
+        ),
+        empty=(
+            AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1),
+            slice(0, 0),
+            b"",
+        ),
         empty_start_stop_reversed=(
             AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1),
             slice(200, 100),
@@ -1122,7 +1298,12 @@
             0,
             b"",
         ),
-        empty=(AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 0), 0, b""),
+        empty=(
+            AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1),
+            slice(0, 0),
+            0,
+            b"",
+        ),
         empty_start_stop_reversed=(
             AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1),
             slice(80, 40),
@@ -1160,7 +1341,9 @@
             b"a" * 24 + b"b" * 76,
         ),
     )
-    def test_region_sample_slicing(self, region, slice_, time_shift, expected_data):
+    def test_region_sample_slicing(
+        self, region, slice_, time_shift, expected_data
+    ):
         sub_region = region[slice_]
         self.assertEqual(bytes(sub_region), expected_data)
 
@@ -1177,7 +1360,9 @@
         expected_duration = region_1.duration + region_2.duration
         expected_data = bytes(region_1) + bytes(region_2)
         concat_region = region_1 + region_2
-        self.assertAlmostEqual(concat_region.duration, expected_duration, places=6)
+        self.assertAlmostEqual(
+            concat_region.duration, expected_duration, places=6
+        )
         self.assertEqual(bytes(concat_region), expected_data)
 
     @genty_dataset(
@@ -1194,7 +1379,9 @@
         expected_data = b"".join(bytes(r) for r in regions)
         concat_region = sum(regions)
 
-        self.assertAlmostEqual(concat_region.duration, expected_duration, places=6)
+        self.assertAlmostEqual(
+            concat_region.duration, expected_duration, places=6
+        )
         self.assertEqual(bytes(concat_region), expected_data)
 
     def test_concatenation_different_sampling_rate_error(self):
@@ -1218,7 +1405,8 @@
         with self.assertRaises(ValueError) as val_err:
             region_1 + region_2
         self.assertEqual(
-            "Can only concatenate AudioRegions of the same " "sample width (2 != 4)",
+            "Can only concatenate AudioRegions of the same "
+            "sample width (2 != 4)",
             str(val_err.exception),
         )