Mercurial > hg > auditok
changeset 212:de60431f343b
Add tests for split with different input types
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Sun, 30 Jun 2019 15:37:12 +0100 |
parents | ed6b3cecb407 |
children | 6ff5411ef661 |
files | auditok/core.py tests/test_core.py |
diffstat | 2 files changed, 76 insertions(+), 3 deletions(-) [+] |
line wrap: on
line diff
--- a/auditok/core.py Wed Jun 26 20:20:55 2019 +0100 +++ b/auditok/core.py Sun Jun 30 15:37:12 2019 +0100 @@ -97,6 +97,7 @@ params = kwargs.copy() params["max_read"] = params.get("max_read", params.get("mr")) + params["audio_format"] = params.get("audio_format", params.get("fmt")) if isinstance(input, AudioRegion): params["sampling_rate"] = input.sr params["sample_width"] = input.sw
--- a/tests/test_core.py Wed Jun 26 20:20:55 2019 +0100 +++ b/tests/test_core.py Sun Jun 30 15:37:12 2019 +0100 @@ -4,6 +4,7 @@ from tempfile import TemporaryDirectory from genty import genty, genty_dataset from auditok import split, AudioRegion, AudioParameterError +from auditok.util import AudioDataSource from auditok.io import ( _normalize_use_channel, _extract_selected_channel, @@ -203,7 +204,7 @@ ch=channels, **kwargs ) - + regions = list(regions) sample_width = 2 import numpy as np @@ -215,12 +216,83 @@ data = _extract_selected_channel( data, channels, sample_width, use_channel=use_channel ) - - regions = list(regions) err_msg = "Wrong number of regions after split, expected: " err_msg += "{}, found: {}".format(expected, regions) self.assertEqual(len(regions), len(expected), err_msg) + for reg, exp in zip(regions, expected): + onset, offset = exp + exp_data = data[onset * sample_width : offset * sample_width] + self.assertEqual(bytes(reg), exp_data) + @genty_dataset( + filename_audio_format=( + "tests/data/test_split_10HZ_stereo.raw", + {"audio_format": "raw", "sr": 10, "sw": 2, "ch": 2}, + ), + filename_audio_format_short_name=( + "tests/data/test_split_10HZ_stereo.raw", + {"fmt": "raw", "sr": 10, "sw": 2, "ch": 2}, + ), + filename_no_audio_format=( + "tests/data/test_split_10HZ_stereo.raw", + {"sr": 10, "sw": 2, "ch": 2}, + ), + filename_no_long_audio_params=( + "tests/data/test_split_10HZ_stereo.raw", + {"sampling_rate": 10, "sample_width": 2, "channels": 2}, + ), + bytes_=( + open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), + {"sr": 10, "sw": 2, "ch": 2}, + ), + audio_reader=( + AudioDataSource( + "tests/data/test_split_10HZ_stereo.raw", + sr=10, + sw=2, + ch=2, + block_dur=0.1, + ), + {}, + ), + audio_region=( + AudioRegion( + open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), + 0, + 10, + 2, + 2, + ), + {}, + ), + audio_source=( + get_audio_source( + "tests/data/test_split_10HZ_stereo.raw", sr=10, sw=2, ch=2 + ), + {}, + ), + ) + def test_split_input_type(self, input, kwargs): + + with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: + data = fp.read() + + regions = split( + input, + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + **kwargs + ) + regions = list(regions) + expected = [(2, 16), (17, 31), (34, 76)] + sample_width = 2 + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(expected, regions) + self.assertEqual(len(regions), len(expected), err_msg) for reg, exp in zip(regions, expected): onset, offset = exp exp_data = data[onset * sample_width : offset * sample_width]