changeset 264:ca5449269024

Add test for split_and_plot
author Amine Sehili <amine.sehili@gmail.com>
date Thu, 12 Sep 2019 20:59:12 +0100
parents 58eb13b55163
children 73a01556cd5e
files tests/test_core.py
diffstat 1 files changed, 58 insertions(+), 167 deletions(-) [+]
line wrap: on
line diff
--- a/tests/test_core.py	Wed Sep 11 20:41:33 2019 +0100
+++ b/tests/test_core.py	Thu Sep 12 20:59:12 2019 +0100
@@ -4,6 +4,7 @@
 from tempfile import TemporaryDirectory
 from array import array as array_
 from unittest import TestCase
+from mock import patch
 from genty import genty, genty_dataset
 from auditok import split, AudioRegion, AudioParameterError
 from auditok.core import _duration_to_nb_windows
@@ -15,9 +16,7 @@
 )
 
 
-def _make_random_length_regions(
-    byte_seq, sampling_rate, sample_width, channels
-):
+def _make_random_length_regions(byte_seq, sampling_rate, sample_width, channels):
     regions = []
     for b in byte_seq:
         duration = round(random() * 10, 6)
@@ -58,15 +57,7 @@
 @genty
 class TestSplit(TestCase):
     @genty_dataset(
-        simple=(
-            0.2,
-            5,
-            0.2,
-            False,
-            False,
-            {"eth": 50},
-            [(2, 16), (17, 31), (34, 76)],
-        ),
+        simple=(0.2, 5, 0.2, False, False, {"eth": 50}, [(2, 16), (17, 31), (34, 76)]),
         short_max_dur=(
             0.3,
             2,
@@ -96,15 +87,7 @@
             {"energy_threshold": 40},
             [(0, 50), (50, 76)],
         ),
-        high_energy_threshold=(
-            0.2,
-            5,
-            0.2,
-            False,
-            False,
-            {"energy_threshold": 60},
-            [],
-        ),
+        high_energy_threshold=(0.2, 5, 0.2, False, False, {"energy_threshold": 60}, []),
         trim_leading_and_trailing_silence=(
             0.2,
             10,  # use long max_dur
@@ -123,15 +106,7 @@
             {"eth": 50},
             [(2, 14), (17, 29), (34, 76)],
         ),
-        drop_trailing_silence_2=(
-            1.5,
-            5,
-            0.2,
-            True,
-            False,
-            {"eth": 50},
-            [(34, 76)],
-        ),
+        drop_trailing_silence_2=(1.5, 5, 0.2, True, False, {"eth": 50}, [(34, 76)]),
         strict_min_dur=(
             0.3,
             2,
@@ -195,7 +170,7 @@
             exp_data = data[onset * sample_width : offset * sample_width]
             self.assertEqual(bytes(reg), exp_data)
             self.assertEqual(reg, reg_ar)
-
+  
     @genty_dataset(
         stereo_all_default=(2, {}, [(2, 32), (34, 76)]),
         mono_max_read=(1, {"max_read": 5}, [(2, 16), (17, 31), (34, 50)]),
@@ -216,11 +191,7 @@
             {"eth": 50, "use_channel": 0},
             [(2, 16), (17, 31), (34, 76)],
         ),
-        stereo_use_channel_no_use_channel_given=(
-            2,
-            {"eth": 50},
-            [(2, 32), (34, 76)],
-        ),
+        stereo_use_channel_no_use_channel_given=(2, {"eth": 50}, [(2, 32), (34, 76)]),
         stereo_use_channel_minus_2=(
             2,
             {"eth": 50, "use_channel": -2},
@@ -228,22 +199,14 @@
         ),
         stereo_uc_2=(2, {"eth": 50, "uc": 1}, [(10, 32), (36, 76)]),
         stereo_uc_minus_1=(2, {"eth": 50, "uc": -1}, [(10, 32), (36, 76)]),
-        mono_uc_mix=(
-            1,
-            {"eth": 50, "uc": "mix"},
-            [(2, 16), (17, 31), (34, 76)],
-        ),
+        mono_uc_mix=(1, {"eth": 50, "uc": "mix"}, [(2, 16), (17, 31), (34, 76)]),
         stereo_use_channel_mix=(
             2,
             {"energy_threshold": 53.5, "use_channel": "mix"},
             [(54, 76)],
         ),
         stereo_uc_mix=(2, {"eth": 52, "uc": "mix"}, [(17, 26), (54, 76)]),
-        stereo_uc_mix_default_eth=(
-            2,
-            {"uc": "mix"},
-            [(10, 16), (17, 31), (36, 76)],
-        ),
+        stereo_uc_mix_default_eth=(2, {"uc": "mix"}, [(10, 16), (17, 31), (36, 76)]),
     )
     def test_split_kwargs(self, channels, kwargs, expected):
 
@@ -290,37 +253,14 @@
         sample_size_bytes = sample_width * channels
         for reg, reg_ar, exp in zip(regions, regions_ar, expected):
             onset, offset = exp
-            exp_data = data[
-                onset * sample_size_bytes : offset * sample_size_bytes
-            ]
+            exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes]
             self.assertEqual(len(bytes(reg)), len(exp_data))
             self.assertEqual(reg, reg_ar)
 
     @genty_dataset(
-        mono_aw_0_2_max_silence_0_2=(
-            0.2,
-            5,
-            0.2,
-            1,
-            {"aw": 0.2},
-            [(2, 30), (34, 76)],
-        ),
-        mono_aw_0_2_max_silence_0_3=(
-            0.2,
-            5,
-            0.3,
-            1,
-            {"aw": 0.2},
-            [(2, 30), (34, 76)],
-        ),
-        mono_aw_0_2_max_silence_0_4=(
-            0.2,
-            5,
-            0.4,
-            1,
-            {"aw": 0.2},
-            [(2, 32), (34, 76)],
-        ),
+        mono_aw_0_2_max_silence_0_2=(0.2, 5, 0.2, 1, {"aw": 0.2}, [(2, 30), (34, 76)]),
+        mono_aw_0_2_max_silence_0_3=(0.2, 5, 0.3, 1, {"aw": 0.2}, [(2, 30), (34, 76)]),
+        mono_aw_0_2_max_silence_0_4=(0.2, 5, 0.4, 1, {"aw": 0.2}, [(2, 32), (34, 76)]),
         mono_aw_0_2_max_silence_0=(
             0.2,
             5,
@@ -338,30 +278,9 @@
             {"aw": 0.3},
             [(3, 12), (15, 24), (36, 76)],
         ),
-        mono_aw_0_3_max_silence_0_3=(
-            0.3,
-            5,
-            0.3,
-            1,
-            {"aw": 0.3},
-            [(3, 27), (36, 76)],
-        ),
-        mono_aw_0_3_max_silence_0_5=(
-            0.3,
-            5,
-            0.5,
-            1,
-            {"aw": 0.3},
-            [(3, 27), (36, 76)],
-        ),
-        mono_aw_0_3_max_silence_0_6=(
-            0.3,
-            5,
-            0.6,
-            1,
-            {"aw": 0.3},
-            [(3, 30), (36, 76)],
-        ),
+        mono_aw_0_3_max_silence_0_3=(0.3, 5, 0.3, 1, {"aw": 0.3}, [(3, 27), (36, 76)]),
+        mono_aw_0_3_max_silence_0_5=(0.3, 5, 0.5, 1, {"aw": 0.3}, [(3, 27), (36, 76)]),
+        mono_aw_0_3_max_silence_0_6=(0.3, 5, 0.6, 1, {"aw": 0.3}, [(3, 30), (36, 76)]),
         mono_aw_0_4_max_silence_0=(
             0.2,
             5,
@@ -378,14 +297,7 @@
             {"aw": 0.4},
             [(4, 12), (16, 24), (36, 76)],
         ),
-        mono_aw_0_4_max_silence_0_4=(
-            0.2,
-            5,
-            0.4,
-            1,
-            {"aw": 0.4},
-            [(4, 28), (36, 76)],
-        ),
+        mono_aw_0_4_max_silence_0_4=(0.2, 5, 0.4, 1, {"aw": 0.4}, [(4, 28), (36, 76)]),
         stereo_uc_0_analysis_window_0_2=(
             0.2,
             5,
@@ -616,9 +528,7 @@
         sample_size_bytes = sample_width * channels
         for reg, reg_ar, exp in zip(regions, regions_ar, expected):
             onset, offset = exp
-            exp_data = data[
-                onset * sample_size_bytes : offset * sample_size_bytes
-            ]
+            exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes]
             self.assertEqual(bytes(reg), exp_data)
             self.assertEqual(reg, reg_ar)
 
@@ -655,10 +565,7 @@
         ),
         audio_region=(
             AudioRegion(
-                open("tests/data/test_split_10HZ_stereo.raw", "rb").read(),
-                10,
-                2,
-                2,
+                open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), 10, 2, 2
             ),
             {},
         ),
@@ -692,9 +599,7 @@
         self.assertEqual(len(regions), len(expected), err_msg)
         for reg, exp in zip(regions, expected):
             onset, offset = exp
-            exp_data = data[
-                onset * sample_width * 2 : offset * sample_width * 2
-            ]
+            exp_data = data[onset * sample_width * 2 : offset * sample_width * 2]
             self.assertEqual(bytes(reg), exp_data)
 
     @genty_dataset(
@@ -799,22 +704,41 @@
         err_msg += " Analysis windows should at least be 1/10 to cover one "
         err_msg += "single data sample"
         self.assertEqual(err_msg, str(val_err.exception))
+    
+    def test_split_and_plot(self):
 
+        with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp:
+            data = fp.read()
+
+        region = AudioRegion(data, 10, 2, 1)
+        with patch("auditok.core.plot_detections") as patch_fn:
+            regions = region.split_and_plot(
+                min_dur=0.2,
+                max_dur=5,
+                max_silence=0.2,
+                drop_trailing_silence=False,
+                strict_min_dur=False,
+                analysis_window=0.1,
+                sr=10,
+                sw=2,
+                ch=1,
+                eth=50,
+            )
+        self.assertTrue(patch_fn.called)
+        expected = [(2, 16), (17, 31), (34, 76)]
+        sample_width = 2
+        expected_regions = []
+        for (onset, offset) in expected:
+            onset *= sample_width
+            offset *= sample_width
+            expected_regions.append(AudioRegion(data[onset:offset], 10, 2, 1))
+        self.assertEqual(regions, expected_regions)
 
 @genty
 class TestAudioRegion(TestCase):
     @genty_dataset(
         simple=(b"\0" * 8000, 0, 8000, 1, 1, 1, 1, 1000),
-        one_ms_less_than_1_sec=(
-            b"\0" * 7992,
-            0,
-            8000,
-            1,
-            1,
-            0.999,
-            0.999,
-            999,
-        ),
+        one_ms_less_than_1_sec=(b"\0" * 7992, 0, 8000, 1, 1, 0.999, 0.999, 999),
         tree_quarter_ms_less_than_1_sec=(
             b"\0" * 7994,
             0,
@@ -825,16 +749,7 @@
             0.99925,
             999,
         ),
-        half_ms_less_than_1_sec=(
-            b"\0" * 7996,
-            0,
-            8000,
-            1,
-            1,
-            0.9995,
-            0.9995,
-            1000,
-        ),
+        half_ms_less_than_1_sec=(b"\0" * 7996, 0, 8000, 1, 1, 0.9995, 0.9995, 1000),
         quarter_ms_less_than_1_sec=(
             b"\0" * 7998,
             0,
@@ -991,11 +906,7 @@
         start_2=("output_{meta.start}.wav", 1.233712, "output_1.233712.wav"),
         start_3=("output_{meta.start:.2f}.wav", 1.2300001, "output_1.23.wav"),
         start_4=("output_{meta.start:.3f}.wav", 1.233712, "output_1.234.wav"),
-        start_5=(
-            "output_{meta.start:.8f}.wav",
-            1.233712,
-            "output_1.23371200.wav",
-        ),
+        start_5=("output_{meta.start:.8f}.wav", 1.233712, "output_1.23371200.wav"),
         start_end_duration=(
             "output_{meta.start}_{meta.end}_{duration}.wav",
             1.455,
@@ -1090,16 +1001,8 @@
             slice(-5000, None),
             b"a" * 160,
         ),
-        big_negative_stop=(
-            AudioRegion(b"a" * 160, 160, 1, 1),
-            slice(None, -1500),
-            b"",
-        ),
-        empty=(
-            AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1),
-            slice(0, 0),
-            b"",
-        ),
+        big_negative_stop=(AudioRegion(b"a" * 160, 160, 1, 1), slice(None, -1500), b""),
+        empty=(AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 0), b""),
         empty_start_stop_reversed=(
             AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1),
             slice(200, 100),
@@ -1219,12 +1122,7 @@
             0,
             b"",
         ),
-        empty=(
-            AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1),
-            slice(0, 0),
-            0,
-            b"",
-        ),
+        empty=(AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 0), 0, b""),
         empty_start_stop_reversed=(
             AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1),
             slice(80, 40),
@@ -1262,9 +1160,7 @@
             b"a" * 24 + b"b" * 76,
         ),
     )
-    def test_region_sample_slicing(
-        self, region, slice_, time_shift, expected_data
-    ):
+    def test_region_sample_slicing(self, region, slice_, time_shift, expected_data):
         sub_region = region[slice_]
         self.assertEqual(bytes(sub_region), expected_data)
 
@@ -1281,9 +1177,7 @@
         expected_duration = region_1.duration + region_2.duration
         expected_data = bytes(region_1) + bytes(region_2)
         concat_region = region_1 + region_2
-        self.assertAlmostEqual(
-            concat_region.duration, expected_duration, places=6
-        )
+        self.assertAlmostEqual(concat_region.duration, expected_duration, places=6)
         self.assertEqual(bytes(concat_region), expected_data)
 
     @genty_dataset(
@@ -1300,9 +1194,7 @@
         expected_data = b"".join(bytes(r) for r in regions)
         concat_region = sum(regions)
 
-        self.assertAlmostEqual(
-            concat_region.duration, expected_duration, places=6
-        )
+        self.assertAlmostEqual(concat_region.duration, expected_duration, places=6)
         self.assertEqual(bytes(concat_region), expected_data)
 
     def test_concatenation_different_sampling_rate_error(self):
@@ -1326,8 +1218,7 @@
         with self.assertRaises(ValueError) as val_err:
             region_1 + region_2
         self.assertEqual(
-            "Can only concatenate AudioRegions of the same "
-            "sample width (2 != 4)",
+            "Can only concatenate AudioRegions of the same " "sample width (2 != 4)",
             str(val_err.exception),
         )