Mercurial > hg > auditok
changeset 316:b6c5125be036
Fix bugs in AudioEnergyValidator and signal_numpy and add tests
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Thu, 17 Oct 2019 21:21:29 +0100 |
parents | 5f1859160fd7 |
children | 18a9f0dcdaae |
files | auditok/signal_numpy.py auditok/util.py tests/test_core.py |
diffstat | 3 files changed, 119 insertions(+), 34 deletions(-) [+] |
line wrap: on
line diff
--- a/auditok/signal_numpy.py Wed Oct 16 21:58:54 2019 +0100 +++ b/auditok/signal_numpy.py Thu Oct 17 21:21:29 2019 +0100 @@ -1,8 +1,13 @@ import numpy as np -from .signal import average_channels_stereo, calculate_energy_single_channel, calculate_energy_multichannel +from .signal import ( + average_channels_stereo, + calculate_energy_single_channel, + calculate_energy_multichannel, +) FORMAT = {1: np.int8, 2: np.int16, 4: np.int32} + def to_array(data, sample_width, channels): fmt = FORMAT[sample_width] if channels == 1: @@ -22,5 +27,4 @@ def separate_channels(data, fmt, channels): array = np.frombuffer(data, dtype=fmt) - return array.reshape(-1, channels).T - + return np.asanyarray(array.reshape(-1, channels).T, order="C")
--- a/auditok/util.py Wed Oct 16 21:58:54 2019 +0100 +++ b/auditok/util.py Thu Oct 17 21:21:29 2019 +0100 @@ -56,9 +56,8 @@ if fmt is None: err_msg = "'sample_width' must be 1, 2 or 4, given: {}" raise ValueError(err_msg.format(sample_width)) - if channels == 1: - return lambda x : x + return lambda x: x if isinstance(selected, int): if selected < 0: @@ -68,20 +67,27 @@ err_msg += ", given: {}" raise ValueError(err_msg.format(selected)) return partial( - signal.extract_single_channel, fmt=fmt, channels=channels, selected=selected + signal.extract_single_channel, + fmt=fmt, + channels=channels, + selected=selected, ) if selected in ("mix", "avg", "average"): if channels == 2: # when data is stereo, using audioop when possible is much faster - return partial(signal.average_channels_stereo, sample_width=sample_width) - + return partial( + signal.average_channels_stereo, sample_width=sample_width + ) + return partial(signal.average_channels, fmt=fmt, channels=channels) if selected in (None, "any"): return partial(signal.separate_channels, fmt=fmt, channels=channels) - - raise ValueError("Selected channel must be an integer, None (alias 'any') or 'average' (alias 'avg' or 'mix')") + + raise ValueError( + "Selected channel must be an integer, None (alias 'any') or 'average' (alias 'avg' or 'mix')" + ) class DataSource: @@ -106,6 +112,7 @@ if read data is valid. Subclasses should implement :func:`is_valid` method. """ + @abstractmethod def is_valid(self, data): """ @@ -114,10 +121,14 @@ class AudioEnergyValidator(DataValidator): - def __init__(self, energy_threshold, sample_width, channels, use_channel=None): + def __init__( + self, energy_threshold, sample_width, channels, use_channel=None + ): self._sample_width = sample_width - self._selector = make_channel_selector(sample_width, channels, use_channel) - if channels == 1 or use_channel is not None: + self._selector = make_channel_selector( + sample_width, channels, use_channel + ) + if channels == 1 or use_channel not in (None, "any"): self._energy_fn = signal.calculate_energy_single_channel else: self._energy_fn = signal.calculate_energy_multichannel @@ -299,9 +310,13 @@ kwargs["bs"] = kwargs.pop("block_size", None) or kwargs.pop("bs", None) kwargs["hs"] = kwargs.pop("hop_size", None) or kwargs.pop("hs", None) kwargs["mt"] = kwargs.pop("max_time", None) or kwargs.pop("mt", None) - kwargs["asrc"] = kwargs.pop("audio_source", None) or kwargs.pop("asrc", None) + kwargs["asrc"] = kwargs.pop("audio_source", None) or kwargs.pop( + "asrc", None + ) kwargs["fn"] = kwargs.pop("filename", None) or kwargs.pop("fn", None) - kwargs["db"] = kwargs.pop("data_buffer", None) or kwargs.pop("db", None) + kwargs["db"] = kwargs.pop("data_buffer", None) or kwargs.pop( + "db", None + ) record = kwargs.pop("record", False) if not record: @@ -318,17 +333,19 @@ ) or kwargs.pop("fpb", None) if "sampling_rate" in kwargs or "sr" in kwargs: - kwargs["sampling_rate"] = kwargs.pop("sampling_rate", None) or kwargs.pop( - "sr", None - ) + kwargs["sampling_rate"] = kwargs.pop( + "sampling_rate", None + ) or kwargs.pop("sr", None) if "sample_width" in kwargs or "sw" in kwargs: - kwargs["sample_width"] = kwargs.pop("sample_width", None) or kwargs.pop( - "sw", None - ) + kwargs["sample_width"] = kwargs.pop( + "sample_width", None + ) or kwargs.pop("sw", None) if "channels" in kwargs or "ch" in kwargs: - kwargs["channels"] = kwargs.pop("channels", None) or kwargs.pop("ch", None) + kwargs["channels"] = kwargs.pop("channels", None) or kwargs.pop( + "ch", None + ) @staticmethod def ads(**kwargs): @@ -736,7 +753,9 @@ super(_FixedSizeAudioReader, self).__init__(audio_source) if block_dur <= 0: - raise ValueError("block_dur must be > 0, given: {}".format(block_dur)) + raise ValueError( + "block_dur must be > 0, given: {}".format(block_dur) + ) self._block_size = int(block_dur * self.sr) if self._block_size == 0: @@ -785,7 +804,9 @@ if block is None: yield None - _hop_size_bytes = self._hop_size * self._audio_source.sw * self._audio_source.ch + _hop_size_bytes = ( + self._hop_size * self._audio_source.sw * self._audio_source.ch + ) cache = block[_hop_size_bytes:] yield block @@ -830,7 +851,13 @@ """ def __init__( - self, input, block_dur=0.01, hop_dur=None, record=False, max_read=None, **kwargs + self, + input, + block_dur=0.01, + hop_dur=None, + record=False, + max_read=None, + **kwargs ): if not isinstance(input, AudioSource): input = get_audio_source(input, **kwargs) @@ -898,11 +925,15 @@ def __getattr__(self, name): if name in ("data", "rewind") and not self.rewindable: - raise AttributeError("'AudioReader' has no attribute '{}'".format(name)) + raise AttributeError( + "'AudioReader' has no attribute '{}'".format(name) + ) try: return getattr(self._audio_source, name) except AttributeError: - raise AttributeError("'AudioReader' has no attribute '{}'".format(name)) + raise AttributeError( + "'AudioReader' has no attribute '{}'".format(name) + ) # Keep AudioDataSource for compatibility @@ -911,7 +942,9 @@ class Recorder(AudioReader): - def __init__(self, input, block_dur=0.01, hop_dur=None, max_read=None, **kwargs): + def __init__( + self, input, block_dur=0.01, hop_dur=None, max_read=None, **kwargs + ): super().__init__( input, block_dur=block_dur,
--- a/tests/test_core.py Wed Oct 16 21:58:54 2019 +0100 +++ b/tests/test_core.py Thu Oct 17 21:21:29 2019 +0100 @@ -389,6 +389,54 @@ {"aw": 0.4}, [(4, 28), (36, 76)], ), + stereo_uc_None_analysis_window_0_2=( + 0.2, + 5, + 0.2, + 2, + {"analysis_window": 0.2}, + [(2, 32), (34, 76)], + ), + stereo_uc_any_analysis_window_0_2=( + 0.2, + 5, + 0.2, + 2, + {"uc": None, "analysis_window": 0.2}, + [(2, 32), (34, 76)], + ), + stereo_use_channel_None_aw_0_3_max_silence_0_2=( + 0.2, + 5, + 0.2, + 2, + {"use_channel": None, "analysis_window": 0.3}, + [(3, 30), (36, 76)], + ), + stereo_use_channel_any_aw_0_3_max_silence_0_3=( + 0.2, + 5, + 0.3, + 2, + {"use_channel": "any", "analysis_window": 0.3}, + [(3, 33), (36, 76)], + ), + stereo_use_channel_None_aw_0_4_max_silence_0_2=( + 0.2, + 5, + 0.2, + 2, + {"use_channel": None, "analysis_window": 0.4}, + [(4, 28), (36, 76)], + ), + stereo_use_channel_any_aw_0_3_max_silence_0_4=( + 0.2, + 5, + 0.4, + 2, + {"use_channel": "any", "analysis_window": 0.4}, + [(4, 32), (36, 76)], + ), stereo_uc_0_analysis_window_0_2=( 0.2, 5, @@ -437,20 +485,20 @@ {"uc": "mix", "analysis_window": 0.1}, [(10, 32), (36, 76)], ), - stereo_uc_mix_aw_0_2_max_silence_0_min_dur_0_3=( + stereo_uc_avg_aw_0_2_max_silence_0_min_dur_0_3=( 0.3, 5, 0, 2, - {"uc": "mix", "analysis_window": 0.2}, + {"uc": "avg", "analysis_window": 0.2}, [(10, 14), (16, 24), (36, 76)], ), - stereo_uc_mix_aw_0_2_max_silence_0_min_dur_0_41=( + stereo_uc_average_aw_0_2_max_silence_0_min_dur_0_41=( 0.41, 5, 0, 2, - {"uc": "mix", "analysis_window": 0.2}, + {"uc": "average", "analysis_window": 0.2}, [(16, 24), (36, 76)], ), stereo_uc_mix_aw_0_2_max_silence_0_1=( @@ -593,7 +641,7 @@ sr=10, sw=2, ch=channels, - eth= 49.99, + eth=49.99, **kwargs ) @@ -604,7 +652,7 @@ max_silence=max_silence, drop_trailing_silence=False, strict_min_dur=False, - eth= 49.99, + eth=49.99, **kwargs )