Mercurial > hg > auditok
changeset 299:73989d247f4e
Add test for callable validator in split
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Tue, 08 Oct 2019 20:16:11 +0100 |
parents | d5cbf4fc1416 |
children | 1732213b290a |
files | tests/test_core.py |
diffstat | 1 files changed, 218 insertions(+), 30 deletions(-) [+] |
line wrap: on
line diff
--- a/tests/test_core.py Tue Oct 08 20:01:34 2019 +0100 +++ b/tests/test_core.py Tue Oct 08 20:16:11 2019 +0100 @@ -16,7 +16,9 @@ ) -def _make_random_length_regions(byte_seq, sampling_rate, sample_width, channels): +def _make_random_length_regions( + byte_seq, sampling_rate, sample_width, channels +): regions = [] for b in byte_seq: duration = round(random() * 10, 6) @@ -57,7 +59,15 @@ @genty class TestSplit(TestCase): @genty_dataset( - simple=(0.2, 5, 0.2, False, False, {"eth": 50}, [(2, 16), (17, 31), (34, 76)]), + simple=( + 0.2, + 5, + 0.2, + False, + False, + {"eth": 50}, + [(2, 16), (17, 31), (34, 76)], + ), short_max_dur=( 0.3, 2, @@ -87,7 +97,15 @@ {"energy_threshold": 40}, [(0, 50), (50, 76)], ), - high_energy_threshold=(0.2, 5, 0.2, False, False, {"energy_threshold": 60}, []), + high_energy_threshold=( + 0.2, + 5, + 0.2, + False, + False, + {"energy_threshold": 60}, + [], + ), trim_leading_and_trailing_silence=( 0.2, 10, # use long max_dur @@ -106,7 +124,15 @@ {"eth": 50}, [(2, 14), (17, 29), (34, 76)], ), - drop_trailing_silence_2=(1.5, 5, 0.2, True, False, {"eth": 50}, [(34, 76)]), + drop_trailing_silence_2=( + 1.5, + 5, + 0.2, + True, + False, + {"eth": 50}, + [(34, 76)], + ), strict_min_dur=( 0.3, 2, @@ -170,7 +196,7 @@ exp_data = data[onset * sample_width : offset * sample_width] self.assertEqual(bytes(reg), exp_data) self.assertEqual(reg, reg_ar) - + @genty_dataset( stereo_all_default=(2, {}, [(2, 32), (34, 76)]), mono_max_read=(1, {"max_read": 5}, [(2, 16), (17, 31), (34, 50)]), @@ -191,7 +217,11 @@ {"eth": 50, "use_channel": 0}, [(2, 16), (17, 31), (34, 76)], ), - stereo_use_channel_no_use_channel_given=(2, {"eth": 50}, [(2, 32), (34, 76)]), + stereo_use_channel_no_use_channel_given=( + 2, + {"eth": 50}, + [(2, 32), (34, 76)], + ), stereo_use_channel_minus_2=( 2, {"eth": 50, "use_channel": -2}, @@ -199,14 +229,22 @@ ), stereo_uc_2=(2, {"eth": 50, "uc": 1}, [(10, 32), (36, 76)]), stereo_uc_minus_1=(2, {"eth": 50, "uc": -1}, [(10, 32), (36, 76)]), - mono_uc_mix=(1, {"eth": 50, "uc": "mix"}, [(2, 16), (17, 31), (34, 76)]), + mono_uc_mix=( + 1, + {"eth": 50, "uc": "mix"}, + [(2, 16), (17, 31), (34, 76)], + ), stereo_use_channel_mix=( 2, {"energy_threshold": 53.5, "use_channel": "mix"}, [(54, 76)], ), stereo_uc_mix=(2, {"eth": 52, "uc": "mix"}, [(17, 26), (54, 76)]), - stereo_uc_mix_default_eth=(2, {"uc": "mix"}, [(10, 16), (17, 31), (36, 76)]), + stereo_uc_mix_default_eth=( + 2, + {"uc": "mix"}, + [(10, 16), (17, 31), (36, 76)], + ), ) def test_split_kwargs(self, channels, kwargs, expected): @@ -253,14 +291,37 @@ sample_size_bytes = sample_width * channels for reg, reg_ar, exp in zip(regions, regions_ar, expected): onset, offset = exp - exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes] + exp_data = data[ + onset * sample_size_bytes : offset * sample_size_bytes + ] self.assertEqual(len(bytes(reg)), len(exp_data)) self.assertEqual(reg, reg_ar) @genty_dataset( - mono_aw_0_2_max_silence_0_2=(0.2, 5, 0.2, 1, {"aw": 0.2}, [(2, 30), (34, 76)]), - mono_aw_0_2_max_silence_0_3=(0.2, 5, 0.3, 1, {"aw": 0.2}, [(2, 30), (34, 76)]), - mono_aw_0_2_max_silence_0_4=(0.2, 5, 0.4, 1, {"aw": 0.2}, [(2, 32), (34, 76)]), + mono_aw_0_2_max_silence_0_2=( + 0.2, + 5, + 0.2, + 1, + {"aw": 0.2}, + [(2, 30), (34, 76)], + ), + mono_aw_0_2_max_silence_0_3=( + 0.2, + 5, + 0.3, + 1, + {"aw": 0.2}, + [(2, 30), (34, 76)], + ), + mono_aw_0_2_max_silence_0_4=( + 0.2, + 5, + 0.4, + 1, + {"aw": 0.2}, + [(2, 32), (34, 76)], + ), mono_aw_0_2_max_silence_0=( 0.2, 5, @@ -278,9 +339,30 @@ {"aw": 0.3}, [(3, 12), (15, 24), (36, 76)], ), - mono_aw_0_3_max_silence_0_3=(0.3, 5, 0.3, 1, {"aw": 0.3}, [(3, 27), (36, 76)]), - mono_aw_0_3_max_silence_0_5=(0.3, 5, 0.5, 1, {"aw": 0.3}, [(3, 27), (36, 76)]), - mono_aw_0_3_max_silence_0_6=(0.3, 5, 0.6, 1, {"aw": 0.3}, [(3, 30), (36, 76)]), + mono_aw_0_3_max_silence_0_3=( + 0.3, + 5, + 0.3, + 1, + {"aw": 0.3}, + [(3, 27), (36, 76)], + ), + mono_aw_0_3_max_silence_0_5=( + 0.3, + 5, + 0.5, + 1, + {"aw": 0.3}, + [(3, 27), (36, 76)], + ), + mono_aw_0_3_max_silence_0_6=( + 0.3, + 5, + 0.6, + 1, + {"aw": 0.3}, + [(3, 30), (36, 76)], + ), mono_aw_0_4_max_silence_0=( 0.2, 5, @@ -297,7 +379,14 @@ {"aw": 0.4}, [(4, 12), (16, 24), (36, 76)], ), - mono_aw_0_4_max_silence_0_4=(0.2, 5, 0.4, 1, {"aw": 0.4}, [(4, 28), (36, 76)]), + mono_aw_0_4_max_silence_0_4=( + 0.2, + 5, + 0.4, + 1, + {"aw": 0.4}, + [(4, 28), (36, 76)], + ), stereo_uc_0_analysis_window_0_2=( 0.2, 5, @@ -528,7 +617,58 @@ sample_size_bytes = sample_width * channels for reg, reg_ar, exp in zip(regions, regions_ar, expected): onset, offset = exp - exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes] + exp_data = data[ + onset * sample_size_bytes : offset * sample_size_bytes + ] + self.assertEqual(bytes(reg), exp_data) + self.assertEqual(reg, reg_ar) + + def test_split_custom_validator(self): + filename = "tests/data/test_split_10HZ_mono.raw" + with open(filename, "rb") as fp: + data = fp.read() + + regions = split( + data, + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + sr=10, + sw=2, + ch=1, + analysis_window=0.1, + validator=lambda x: array_("h", x)[0] >= 320, + ) + + region = AudioRegion(data, 10, 2, 1) + regions_ar = region.split( + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + validator=lambda x: array_("h", x)[0] >= 320, + ) + + expected = [(2, 16), (17, 31), (34, 76)] + regions = list(regions) + regions_ar = list(regions_ar) + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions)) + self.assertEqual(len(regions), len(expected), err_msg) + err_msg = "Wrong number of regions after AudioRegion.split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) + self.assertEqual(len(regions_ar), len(expected), err_msg) + + sample_size_bytes = 2 + for reg, reg_ar, exp in zip(regions, regions_ar, expected): + onset, offset = exp + exp_data = data[ + onset * sample_size_bytes : offset * sample_size_bytes + ] self.assertEqual(bytes(reg), exp_data) self.assertEqual(reg, reg_ar) @@ -565,7 +705,10 @@ ), audio_region=( AudioRegion( - open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), 10, 2, 2 + open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), + 10, + 2, + 2, ), {}, ), @@ -599,7 +742,9 @@ self.assertEqual(len(regions), len(expected), err_msg) for reg, exp in zip(regions, expected): onset, offset = exp - exp_data = data[onset * sample_width * 2 : offset * sample_width * 2] + exp_data = data[ + onset * sample_width * 2 : offset * sample_width * 2 + ] self.assertEqual(bytes(reg), exp_data) @genty_dataset( @@ -704,7 +849,7 @@ err_msg += " Analysis windows should at least be 1/10 to cover one " err_msg += "single data sample" self.assertEqual(err_msg, str(val_err.exception)) - + def test_split_and_plot(self): with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: @@ -734,11 +879,21 @@ expected_regions.append(AudioRegion(data[onset:offset], 10, 2, 1)) self.assertEqual(regions, expected_regions) + @genty class TestAudioRegion(TestCase): @genty_dataset( simple=(b"\0" * 8000, 0, 8000, 1, 1, 1, 1, 1000), - one_ms_less_than_1_sec=(b"\0" * 7992, 0, 8000, 1, 1, 0.999, 0.999, 999), + one_ms_less_than_1_sec=( + b"\0" * 7992, + 0, + 8000, + 1, + 1, + 0.999, + 0.999, + 999, + ), tree_quarter_ms_less_than_1_sec=( b"\0" * 7994, 0, @@ -749,7 +904,16 @@ 0.99925, 999, ), - half_ms_less_than_1_sec=(b"\0" * 7996, 0, 8000, 1, 1, 0.9995, 0.9995, 1000), + half_ms_less_than_1_sec=( + b"\0" * 7996, + 0, + 8000, + 1, + 1, + 0.9995, + 0.9995, + 1000, + ), quarter_ms_less_than_1_sec=( b"\0" * 7998, 0, @@ -906,7 +1070,11 @@ start_2=("output_{meta.start}.wav", 1.233712, "output_1.233712.wav"), start_3=("output_{meta.start:.2f}.wav", 1.2300001, "output_1.23.wav"), start_4=("output_{meta.start:.3f}.wav", 1.233712, "output_1.234.wav"), - start_5=("output_{meta.start:.8f}.wav", 1.233712, "output_1.23371200.wav"), + start_5=( + "output_{meta.start:.8f}.wav", + 1.233712, + "output_1.23371200.wav", + ), start_end_duration=( "output_{meta.start}_{meta.end}_{duration}.wav", 1.455, @@ -1001,8 +1169,16 @@ slice(-5000, None), b"a" * 160, ), - big_negative_stop=(AudioRegion(b"a" * 160, 160, 1, 1), slice(None, -1500), b""), - empty=(AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 0), b""), + big_negative_stop=( + AudioRegion(b"a" * 160, 160, 1, 1), + slice(None, -1500), + b"", + ), + empty=( + AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), + slice(0, 0), + b"", + ), empty_start_stop_reversed=( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(200, 100), @@ -1122,7 +1298,12 @@ 0, b"", ), - empty=(AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 0), 0, b""), + empty=( + AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), + slice(0, 0), + 0, + b"", + ), empty_start_stop_reversed=( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(80, 40), @@ -1160,7 +1341,9 @@ b"a" * 24 + b"b" * 76, ), ) - def test_region_sample_slicing(self, region, slice_, time_shift, expected_data): + def test_region_sample_slicing( + self, region, slice_, time_shift, expected_data + ): sub_region = region[slice_] self.assertEqual(bytes(sub_region), expected_data) @@ -1177,7 +1360,9 @@ expected_duration = region_1.duration + region_2.duration expected_data = bytes(region_1) + bytes(region_2) concat_region = region_1 + region_2 - self.assertAlmostEqual(concat_region.duration, expected_duration, places=6) + self.assertAlmostEqual( + concat_region.duration, expected_duration, places=6 + ) self.assertEqual(bytes(concat_region), expected_data) @genty_dataset( @@ -1194,7 +1379,9 @@ expected_data = b"".join(bytes(r) for r in regions) concat_region = sum(regions) - self.assertAlmostEqual(concat_region.duration, expected_duration, places=6) + self.assertAlmostEqual( + concat_region.duration, expected_duration, places=6 + ) self.assertEqual(bytes(concat_region), expected_data) def test_concatenation_different_sampling_rate_error(self): @@ -1218,7 +1405,8 @@ with self.assertRaises(ValueError) as val_err: region_1 + region_2 self.assertEqual( - "Can only concatenate AudioRegions of the same " "sample width (2 != 4)", + "Can only concatenate AudioRegions of the same " + "sample width (2 != 4)", str(val_err.exception), )