changeset 316:b6c5125be036

Fix bugs in AudioEnergyValidator and signal_numpy and add tests
author Amine Sehili <amine.sehili@gmail.com>
date Thu, 17 Oct 2019 21:21:29 +0100
parents 5f1859160fd7
children 18a9f0dcdaae
files auditok/signal_numpy.py auditok/util.py tests/test_core.py
diffstat 3 files changed, 119 insertions(+), 34 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/signal_numpy.py	Wed Oct 16 21:58:54 2019 +0100
+++ b/auditok/signal_numpy.py	Thu Oct 17 21:21:29 2019 +0100
@@ -1,8 +1,13 @@
 import numpy as np
-from .signal import average_channels_stereo, calculate_energy_single_channel, calculate_energy_multichannel
+from .signal import (
+    average_channels_stereo,
+    calculate_energy_single_channel,
+    calculate_energy_multichannel,
+)
 
 FORMAT = {1: np.int8, 2: np.int16, 4: np.int32}
 
+
 def to_array(data, sample_width, channels):
     fmt = FORMAT[sample_width]
     if channels == 1:
@@ -22,5 +27,4 @@
 
 def separate_channels(data, fmt, channels):
     array = np.frombuffer(data, dtype=fmt)
-    return array.reshape(-1, channels).T
-
+    return np.asanyarray(array.reshape(-1, channels).T, order="C")
--- a/auditok/util.py	Wed Oct 16 21:58:54 2019 +0100
+++ b/auditok/util.py	Thu Oct 17 21:21:29 2019 +0100
@@ -56,9 +56,8 @@
     if fmt is None:
         err_msg = "'sample_width' must be 1, 2 or 4, given: {}"
         raise ValueError(err_msg.format(sample_width))
-
     if channels == 1:
-        return lambda x : x
+        return lambda x: x
 
     if isinstance(selected, int):
         if selected < 0:
@@ -68,20 +67,27 @@
             err_msg += ", given: {}"
             raise ValueError(err_msg.format(selected))
         return partial(
-            signal.extract_single_channel, fmt=fmt, channels=channels, selected=selected
+            signal.extract_single_channel,
+            fmt=fmt,
+            channels=channels,
+            selected=selected,
         )
 
     if selected in ("mix", "avg", "average"):
         if channels == 2:
             # when data is stereo, using audioop when possible is much faster
-            return partial(signal.average_channels_stereo, sample_width=sample_width)
-        
+            return partial(
+                signal.average_channels_stereo, sample_width=sample_width
+            )
+
         return partial(signal.average_channels, fmt=fmt, channels=channels)
 
     if selected in (None, "any"):
         return partial(signal.separate_channels, fmt=fmt, channels=channels)
-    
-    raise ValueError("Selected channel must be an integer, None (alias 'any') or 'average' (alias 'avg' or 'mix')")
+
+    raise ValueError(
+        "Selected channel must be an integer, None (alias 'any') or 'average' (alias 'avg' or 'mix')"
+    )
 
 
 class DataSource:
@@ -106,6 +112,7 @@
     if read data is valid.
     Subclasses should implement :func:`is_valid` method.
    """
+
     @abstractmethod
     def is_valid(self, data):
         """
@@ -114,10 +121,14 @@
 
 
 class AudioEnergyValidator(DataValidator):
-    def __init__(self, energy_threshold, sample_width, channels, use_channel=None):
+    def __init__(
+        self, energy_threshold, sample_width, channels, use_channel=None
+    ):
         self._sample_width = sample_width
-        self._selector = make_channel_selector(sample_width, channels, use_channel)
-        if channels == 1 or use_channel is not None:
+        self._selector = make_channel_selector(
+            sample_width, channels, use_channel
+        )
+        if channels == 1 or use_channel not in (None, "any"):
             self._energy_fn = signal.calculate_energy_single_channel
         else:
             self._energy_fn = signal.calculate_energy_multichannel
@@ -299,9 +310,13 @@
         kwargs["bs"] = kwargs.pop("block_size", None) or kwargs.pop("bs", None)
         kwargs["hs"] = kwargs.pop("hop_size", None) or kwargs.pop("hs", None)
         kwargs["mt"] = kwargs.pop("max_time", None) or kwargs.pop("mt", None)
-        kwargs["asrc"] = kwargs.pop("audio_source", None) or kwargs.pop("asrc", None)
+        kwargs["asrc"] = kwargs.pop("audio_source", None) or kwargs.pop(
+            "asrc", None
+        )
         kwargs["fn"] = kwargs.pop("filename", None) or kwargs.pop("fn", None)
-        kwargs["db"] = kwargs.pop("data_buffer", None) or kwargs.pop("db", None)
+        kwargs["db"] = kwargs.pop("data_buffer", None) or kwargs.pop(
+            "db", None
+        )
 
         record = kwargs.pop("record", False)
         if not record:
@@ -318,17 +333,19 @@
             ) or kwargs.pop("fpb", None)
 
         if "sampling_rate" in kwargs or "sr" in kwargs:
-            kwargs["sampling_rate"] = kwargs.pop("sampling_rate", None) or kwargs.pop(
-                "sr", None
-            )
+            kwargs["sampling_rate"] = kwargs.pop(
+                "sampling_rate", None
+            ) or kwargs.pop("sr", None)
 
         if "sample_width" in kwargs or "sw" in kwargs:
-            kwargs["sample_width"] = kwargs.pop("sample_width", None) or kwargs.pop(
-                "sw", None
-            )
+            kwargs["sample_width"] = kwargs.pop(
+                "sample_width", None
+            ) or kwargs.pop("sw", None)
 
         if "channels" in kwargs or "ch" in kwargs:
-            kwargs["channels"] = kwargs.pop("channels", None) or kwargs.pop("ch", None)
+            kwargs["channels"] = kwargs.pop("channels", None) or kwargs.pop(
+                "ch", None
+            )
 
     @staticmethod
     def ads(**kwargs):
@@ -736,7 +753,9 @@
         super(_FixedSizeAudioReader, self).__init__(audio_source)
 
         if block_dur <= 0:
-            raise ValueError("block_dur must be > 0, given: {}".format(block_dur))
+            raise ValueError(
+                "block_dur must be > 0, given: {}".format(block_dur)
+            )
 
         self._block_size = int(block_dur * self.sr)
         if self._block_size == 0:
@@ -785,7 +804,9 @@
         if block is None:
             yield None
 
-        _hop_size_bytes = self._hop_size * self._audio_source.sw * self._audio_source.ch
+        _hop_size_bytes = (
+            self._hop_size * self._audio_source.sw * self._audio_source.ch
+        )
         cache = block[_hop_size_bytes:]
         yield block
 
@@ -830,7 +851,13 @@
     """
 
     def __init__(
-        self, input, block_dur=0.01, hop_dur=None, record=False, max_read=None, **kwargs
+        self,
+        input,
+        block_dur=0.01,
+        hop_dur=None,
+        record=False,
+        max_read=None,
+        **kwargs
     ):
         if not isinstance(input, AudioSource):
             input = get_audio_source(input, **kwargs)
@@ -898,11 +925,15 @@
 
     def __getattr__(self, name):
         if name in ("data", "rewind") and not self.rewindable:
-            raise AttributeError("'AudioReader' has no attribute '{}'".format(name))
+            raise AttributeError(
+                "'AudioReader' has no attribute '{}'".format(name)
+            )
         try:
             return getattr(self._audio_source, name)
         except AttributeError:
-            raise AttributeError("'AudioReader' has no attribute '{}'".format(name))
+            raise AttributeError(
+                "'AudioReader' has no attribute '{}'".format(name)
+            )
 
 
 # Keep AudioDataSource for compatibility
@@ -911,7 +942,9 @@
 
 
 class Recorder(AudioReader):
-    def __init__(self, input, block_dur=0.01, hop_dur=None, max_read=None, **kwargs):
+    def __init__(
+        self, input, block_dur=0.01, hop_dur=None, max_read=None, **kwargs
+    ):
         super().__init__(
             input,
             block_dur=block_dur,
--- a/tests/test_core.py	Wed Oct 16 21:58:54 2019 +0100
+++ b/tests/test_core.py	Thu Oct 17 21:21:29 2019 +0100
@@ -389,6 +389,54 @@
             {"aw": 0.4},
             [(4, 28), (36, 76)],
         ),
+        stereo_uc_None_analysis_window_0_2=(
+            0.2,
+            5,
+            0.2,
+            2,
+            {"analysis_window": 0.2},
+            [(2, 32), (34, 76)],
+        ),
+        stereo_uc_any_analysis_window_0_2=(
+            0.2,
+            5,
+            0.2,
+            2,
+            {"uc": None, "analysis_window": 0.2},
+            [(2, 32), (34, 76)],
+        ),
+        stereo_use_channel_None_aw_0_3_max_silence_0_2=(
+            0.2,
+            5,
+            0.2,
+            2,
+            {"use_channel": None, "analysis_window": 0.3},
+            [(3, 30), (36, 76)],
+        ),
+        stereo_use_channel_any_aw_0_3_max_silence_0_3=(
+            0.2,
+            5,
+            0.3,
+            2,
+            {"use_channel": "any", "analysis_window": 0.3},
+            [(3, 33), (36, 76)],
+        ),
+        stereo_use_channel_None_aw_0_4_max_silence_0_2=(
+            0.2,
+            5,
+            0.2,
+            2,
+            {"use_channel": None, "analysis_window": 0.4},
+            [(4, 28), (36, 76)],
+        ),
+        stereo_use_channel_any_aw_0_3_max_silence_0_4=(
+            0.2,
+            5,
+            0.4,
+            2,
+            {"use_channel": "any", "analysis_window": 0.4},
+            [(4, 32), (36, 76)],
+        ),
         stereo_uc_0_analysis_window_0_2=(
             0.2,
             5,
@@ -437,20 +485,20 @@
             {"uc": "mix", "analysis_window": 0.1},
             [(10, 32), (36, 76)],
         ),
-        stereo_uc_mix_aw_0_2_max_silence_0_min_dur_0_3=(
+        stereo_uc_avg_aw_0_2_max_silence_0_min_dur_0_3=(
             0.3,
             5,
             0,
             2,
-            {"uc": "mix", "analysis_window": 0.2},
+            {"uc": "avg", "analysis_window": 0.2},
             [(10, 14), (16, 24), (36, 76)],
         ),
-        stereo_uc_mix_aw_0_2_max_silence_0_min_dur_0_41=(
+        stereo_uc_average_aw_0_2_max_silence_0_min_dur_0_41=(
             0.41,
             5,
             0,
             2,
-            {"uc": "mix", "analysis_window": 0.2},
+            {"uc": "average", "analysis_window": 0.2},
             [(16, 24), (36, 76)],
         ),
         stereo_uc_mix_aw_0_2_max_silence_0_1=(
@@ -593,7 +641,7 @@
             sr=10,
             sw=2,
             ch=channels,
-            eth= 49.99,
+            eth=49.99,
             **kwargs
         )
 
@@ -604,7 +652,7 @@
             max_silence=max_silence,
             drop_trailing_silence=False,
             strict_min_dur=False,
-            eth= 49.99,
+            eth=49.99,
             **kwargs
         )