changeset 241:79b668c48fce

Make sure split keeps original number of channels
author Amine Sehili <amine.sehili@gmail.com>
date Fri, 26 Jul 2019 20:46:53 +0100
parents 173ffca58d23
children 90445f084929
files auditok/core.py auditok/util.py tests/test_core.py
diffstat 3 files changed, 185 insertions(+), 191 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/core.py	Thu Jul 25 20:50:52 2019 +0100
+++ b/auditok/core.py	Fri Jul 26 20:46:53 2019 +0100
@@ -74,12 +74,10 @@
         nuumber of channels of audio data. Only needed for raw audio files.
     use_channel, uc: int, str
         which channel to use if input has multichannel audio data. Can be an
-        int (0 being the first channel), or one of the following special str
-        values:
-        - 'left': first channel (equivalent to 0)
-        - 'right': second channel (equivalent to 1)
-        - 'mix': compute average channel
-        Default: 0, use the first channel.
+        int (0 being the first channel), or one of the following values:
+            - None, "any": a valid frame from any single channel makes the
+              parallel frames from all other channels automatically valid.
+            - 'mix': compute average channel (i.e. mix down all channels)
     max_read, mr: float
         maximum data to read in seconds. Default: `None`, read until there is
         no more data to read.
@@ -132,8 +130,10 @@
         energy_threshold = kwargs.get(
             "energy_threshold", kwargs.get("eth", DEFAULT_ENERGY_THRESHOLD)
         )
-        validator = AudioEnergyValidator(source.sw, energy_threshold)
-
+        use_channel = kwargs.get("use_channel", kwargs.get("uc"))
+        validator = AudioEnergyValidator(
+            energy_threshold, source.sw, source.ch, use_channel=use_channel
+        )
     mode = (
         StreamTokenizer.DROP_TRAILING_SILENCE if drop_trailing_silence else 0
     )
--- a/auditok/util.py	Thu Jul 25 20:50:52 2019 +0100
+++ b/auditok/util.py	Fri Jul 26 20:46:53 2019 +0100
@@ -14,9 +14,9 @@
         ADSFactory.RecorderADS
         DataValidator
         AudioEnergyValidator
-
 """
 from __future__ import division
+import sys
 from abc import ABCMeta, abstractmethod
 import math
 from array import array
@@ -29,14 +29,16 @@
     get_audio_source,
 )
 from .exceptions import DuplicateArgument, TooSamllBlockDuration
-import sys
 
 try:
     import numpy
 
+    np = numpy
     _WITH_NUMPY = True
+    _FORMAT = {1: np.int8, 2: np.int16, 4: np.int32}
 except ImportError as e:
     _WITH_NUMPY = False
+    _FORMAT = {1: "b", 2: "h", 4: "i"}
 
 try:
     from builtins import str
@@ -56,6 +58,119 @@
 ]
 
 
+def make_channel_selector(sample_width, channels, selected=None):
+    fmt = _FORMAT.get(sample_width)
+    if fmt is None:
+        err_msg = "'sample_width' must be 1, 2 or 4, given: {}"
+        raise ValueError(err_msg.format(sample_width))
+
+    if channels == 1:
+        if _WITH_NUMPY:
+
+            def _as_array(data):
+                return np.frombuffer(data, dtype=fmt).astype(np.float64)
+
+        else:
+
+            def _as_array(data):
+                return array(fmt, data)
+
+        return _as_array
+
+    if isinstance(selected, int):
+        if selected < 0:
+            selected += channels
+        if selected < 0 or selected >= channels:
+            err_msg = "Selected channel must be >= -channels and < 'channels'"
+            err_msg += ", given: {}"
+            raise ValueError(err_msg.format(selected))
+        if _WITH_NUMPY:
+
+            def _extract_single_channel(data):
+                samples = np.frombuffer(data, dtype=fmt)
+                return samples[selected::channels].astype(np.float64)
+
+        else:
+
+            def _extract_single_channel(data):
+                samples = array(fmt, data)
+                return samples[selected::channels]
+
+        return _extract_single_channel
+
+    if selected in ("mix", "avg", "average"):
+        if _WITH_NUMPY:
+
+            def _average_channels(data):
+                array = np.frombuffer(data, dtype=fmt).astype(np.float64)
+                return array.reshape(-1, channels).mean(axis=1)
+
+        else:
+
+            def _average_channels(data):
+                all_channels = array(fmt, data)
+                mono_channels = [
+                    array(fmt, all_channels[ch::channels])
+                    for ch in range(channels)
+                ]
+                avg_arr = array(
+                    fmt,
+                    (
+                        sum(samples) // channels
+                        for samples in zip(*mono_channels)
+                    ),
+                )
+                return avg_arr
+
+        return _average_channels
+
+    if selected is None:
+        if _WITH_NUMPY:
+
+            def _split_channels(data):
+                array = np.frombuffer(data, dtype=fmt).astype(np.float64)
+                return array.reshape(-1, channels).T
+
+        else:
+
+            def _split_channels(data):
+                all_channels = array(fmt, data)
+                mono_channels = [
+                    array(fmt, all_channels[ch::channels])
+                    for ch in range(channels)
+                ]
+                return mono_channels
+
+        return _split_channels
+
+
+if _WITH_NUMPY:
+
+    def _calculate_energy_single_channel(x):
+        return 10 * np.log10(np.dot(x, x).clip(min=1e-20) / x.size)
+
+
+else:
+
+    def _calculate_energy_single_channel(x):
+        energy = max(sum(i ** 2 for i in x) / len(x), 1e-20)
+        return 10 * math.log10(energy)
+
+
+if _WITH_NUMPY:
+
+    def _calculate_energy_multichannel(x, aggregation_fn=np.max):
+        energy = 10 * np.log10((x * x).mean(axis=1).clip(min=1e-20))
+        return aggregation_fn(energy)
+
+
+else:
+
+    def _calculate_energy_multichannel(x, aggregation_fn=max):
+        energies = (_calculate_energy_single_channel(xi) for xi in x)
+        return aggregation_fn(energies)
+
+
 class DataSource:
     """
     Base class for objects passed to :func:`auditok.core.StreamTokenizer.tokenize`.
@@ -88,6 +203,23 @@
         """
 
 
+class AudioEnergyValidator(DataValidator):
+    def __init__(
+        self, energy_threshold, sample_width, channels, use_channel=None
+    ):
+        self._selector = make_channel_selector(
+            sample_width, channels, use_channel
+        )
+        if channels == 1 or use_channel is not None:
+            self._energy_fn = _calculate_energy_single_channel
+        else:
+            self._energy_fn = _calculate_energy_multichannel
+        self._energy_threshold = energy_threshold
+
+    def is_valid(self, data):
+        return self._energy_fn(self._selector(data)) > self._energy_threshold
+
+
 class StringDataSource(DataSource):
     """
     A class that represent a :class:`DataSource` as a string buffer.
@@ -881,107 +1013,3 @@
             raise AttributeError(
                 "'AudioDataSource' has no attribute '{}'".format(name)
             )
-
-
-class AudioEnergyValidator(DataValidator):
-    """
-    The most basic auditok audio frame validator.
-    This validator computes the log energy of an input audio frame
-    and return True if the result is >= a given threshold, False
-    otherwise.
-
-    :Parameters:
-
-    `sample_width` : *(int)*
-        Number of bytes of one audio sample. This is used to convert data from `basestring` or `Bytes` to
-        an array of floats.
-
-    `energy_threshold` : *(float)*
-        A threshold used to check whether an input data buffer is valid.
-    """
-
-    if _WITH_NUMPY:
-        _formats = {1: numpy.int8, 2: numpy.int16, 4: numpy.int32}
-
-        @staticmethod
-        def _convert(signal, sample_width):
-            return numpy.array(
-                numpy.frombuffer(
-                    signal, dtype=AudioEnergyValidator._formats[sample_width]
-                ),
-                dtype=numpy.float64,
-            )
-
-        @staticmethod
-        def _signal_energy(signal):
-            return float(numpy.dot(signal, signal)) / len(signal)
-
-        @staticmethod
-        def _signal_log_energy(signal):
-            energy = AudioEnergyValidator._signal_energy(signal)
-            if energy <= 0:
-                return -200
-            return 10.0 * numpy.log10(energy)
-
-    else:
-        _formats = {1: "b", 2: "h", 4: "i"}
-
-        @staticmethod
-        def _convert(signal, sample_width):
-            return array(
-                "d", array(AudioEnergyValidator._formats[sample_width], signal)
-            )
-
-        @staticmethod
-        def _signal_energy(signal):
-            energy = 0.0
-            for a in signal:
-                energy += a * a
-            return energy / len(signal)
-
-        @staticmethod
-        def _signal_log_energy(signal):
-            energy = AudioEnergyValidator._signal_energy(signal)
-            if energy <= 0:
-                return -200
-            return 10.0 * math.log10(energy)
-
-    def __init__(self, sample_width, energy_threshold=45):
-        self.sample_width = sample_width
-        self._energy_threshold = energy_threshold
-
-    def is_valid(self, data):
-        """
-        Check if data is valid. Audio data will be converted into an array (of
-        signed values) of which the log energy is computed. Log energy is computed
-        as follows:
-
-        .. code:: python
-
-            arr = AudioEnergyValidator._convert(signal, sample_width)
-            energy = float(numpy.dot(arr, arr)) / len(arr)
-            log_energy = 10. * numpy.log10(energy)
-
-
-        :Parameters:
-
-        `data` : either a *string* or a *Bytes* buffer
-            `data` is converted into a numerical array using the `sample_width`
-            given in the constructor.
-
-        :Returns:
-
-        True if `log_energy` >= `energy_threshold`, False otherwise.
-        """
-
-        signal = AudioEnergyValidator._convert(data, self.sample_width)
-        return (
-            AudioEnergyValidator._signal_log_energy(signal)
-            >= self._energy_threshold
-        )
-
-    def get_energy_threshold(self):
-        return self._energy_threshold
-
-    def set_energy_threshold(self, threshold):
-        self._energy_threshold = threshold
--- a/tests/test_core.py	Thu Jul 25 20:50:52 2019 +0100
+++ b/tests/test_core.py	Fri Jul 26 20:46:53 2019 +0100
@@ -182,25 +182,15 @@
             self.assertEqual(bytes(reg), exp_data)
 
     @genty_dataset(
-        stereo_all_default=(2, {}, [(2, 16), (17, 31), (34, 76)]),
+        stereo_all_default=(2, {}, [(2, 32), (34, 76)]),
         mono_max_read=(1, {"max_read": 5}, [(2, 16), (17, 31), (34, 50)]),
         mono_max_read_short_name=(1, {"mr": 5}, [(2, 16), (17, 31), (34, 50)]),
         mono_use_channel_1=(
             1,
-            {"eth": 50, "use_channel": 1},
+            {"eth": 50, "use_channel": 0},
             [(2, 16), (17, 31), (34, 76)],
         ),
         mono_uc_1=(1, {"eth": 50, "uc": 1}, [(2, 16), (17, 31), (34, 76)]),
-        mono_use_channel_left=(
-            1,
-            {"eth": 50, "use_channel": "left"},
-            [(2, 16), (17, 31), (34, 76)],
-        ),
-        mono_uc_left=(
-            1,
-            {"eth": 50, "uc": "left"},
-            [(2, 16), (17, 31), (34, 76)],
-        ),
         mono_use_channel_None=(
             1,
             {"eth": 50, "use_channel": None},
@@ -208,30 +198,20 @@
         ),
         stereo_use_channel_1=(
             2,
-            {"eth": 50, "use_channel": 1},
-            [(2, 16), (17, 31), (34, 76)],
-        ),
-        stereo_use_channel_left=(
-            2,
-            {"eth": 50, "use_channel": "left"},
+            {"eth": 50, "use_channel": 0},
             [(2, 16), (17, 31), (34, 76)],
         ),
         stereo_use_channel_no_use_channel_given=(
             2,
             {"eth": 50},
-            [(2, 16), (17, 31), (34, 76)],
+            [(2, 32), (34, 76)],
         ),
         stereo_use_channel_minus_2=(
             2,
             {"eth": 50, "use_channel": -2},
             [(2, 16), (17, 31), (34, 76)],
         ),
-        stereo_uc_2=(2, {"eth": 50, "uc": 2}, [(10, 32), (36, 76)]),
-        stereo_use_channel_right=(
-            2,
-            {"eth": 50, "use_channel": "right"},
-            [(10, 32), (36, 76)],
-        ),
+        stereo_uc_2=(2, {"eth": 50, "uc": 1}, [(10, 32), (36, 76)]),
         stereo_uc_minus_1=(2, {"eth": 50, "uc": -1}, [(10, 32), (36, 76)]),
         mono_uc_mix=(
             1,
@@ -272,23 +252,16 @@
         )
         regions = list(regions)
         sample_width = 2
-        import numpy as np
-
-        use_channel = kwargs.get("use_channel", kwargs.get("uc"))
-        # extrat channel of interest
-        if channels != 1:
-            use_channel = kwargs.get("use_channel", kwargs.get("uc"))
-            use_channel = _normalize_use_channel(use_channel)
-            data = _extract_selected_channel(
-                data, channels, sample_width, use_channel=use_channel
-            )
         err_msg = "Wrong number of regions after split, expected: "
-        err_msg += "{}, found: {}".format(expected, regions)
+        err_msg += "{}, found: {}".format(len(expected), len(regions))
         self.assertEqual(len(regions), len(expected), err_msg)
+        sample_size_bytes = sample_width * channels
         for reg, exp in zip(regions, expected):
             onset, offset = exp
-            exp_data = data[onset * sample_width : offset * sample_width]
-            self.assertEqual(bytes(reg), exp_data)
+            exp_data = data[
+                onset * sample_size_bytes : offset * sample_size_bytes
+            ]
+            self.assertEqual(len(bytes(reg)), len(exp_data))
 
     @genty_dataset(
         mono_aw_0_2_max_silence_0_2=(
@@ -296,7 +269,7 @@
             5,
             0.2,
             1,
-            {"uc": 1, "aw": 0.2},
+            {"aw": 0.2},
             [(2, 30), (34, 76)],
         ),
         mono_aw_0_2_max_silence_0_3=(
@@ -304,7 +277,7 @@
             5,
             0.3,
             1,
-            {"uc": 1, "aw": 0.2},
+            {"aw": 0.2},
             [(2, 30), (34, 76)],
         ),
         mono_aw_0_2_max_silence_0_4=(
@@ -312,7 +285,7 @@
             5,
             0.4,
             1,
-            {"uc": 1, "aw": 0.2},
+            {"aw": 0.2},
             [(2, 32), (34, 76)],
         ),
         mono_aw_0_2_max_silence_0=(
@@ -320,23 +293,16 @@
             5,
             0,
             1,
-            {"uc": 1, "aw": 0.2},
+            {"aw": 0.2},
             [(2, 14), (16, 24), (26, 28), (34, 76)],
         ),
-        mono_aw_0_2=(
-            0.2,
-            5,
-            0.2,
-            1,
-            {"uc": 1, "aw": 0.2},
-            [(2, 30), (34, 76)],
-        ),
+        mono_aw_0_2=(0.2, 5, 0.2, 1, {"aw": 0.2}, [(2, 30), (34, 76)]),
         mono_aw_0_3_max_silence_0=(
             0.3,
             5,
             0,
             1,
-            {"uc": 1, "aw": 0.3},
+            {"aw": 0.3},
             [(3, 12), (15, 24), (36, 76)],
         ),
         mono_aw_0_3_max_silence_0_3=(
@@ -344,7 +310,7 @@
             5,
             0.3,
             1,
-            {"uc": 1, "aw": 0.3},
+            {"aw": 0.3},
             [(3, 27), (36, 76)],
         ),
         mono_aw_0_3_max_silence_0_5=(
@@ -352,7 +318,7 @@
             5,
             0.5,
             1,
-            {"uc": 1, "aw": 0.3},
+            {"aw": 0.3},
             [(3, 27), (36, 76)],
         ),
         mono_aw_0_3_max_silence_0_6=(
@@ -360,7 +326,7 @@
             5,
             0.6,
             1,
-            {"uc": 1, "aw": 0.3},
+            {"aw": 0.3},
             [(3, 30), (36, 76)],
         ),
         mono_aw_0_4_max_silence_0=(
@@ -368,7 +334,7 @@
             5,
             0,
             1,
-            {"uc": 1, "aw": 0.4},
+            {"aw": 0.4},
             [(4, 12), (16, 24), (36, 76)],
         ),
         mono_aw_0_4_max_silence_0_3=(
@@ -376,7 +342,7 @@
             5,
             0.3,
             1,
-            {"uc": 1, "aw": 0.4},
+            {"aw": 0.4},
             [(4, 12), (16, 24), (36, 76)],
         ),
         mono_aw_0_4_max_silence_0_4=(
@@ -384,23 +350,23 @@
             5,
             0.4,
             1,
-            {"uc": 1, "aw": 0.4},
+            {"aw": 0.4},
             [(4, 28), (36, 76)],
         ),
+        stereo_uc_0_analysis_window_0_2=(
+            0.2,
+            5,
+            0.2,
+            2,
+            {"uc": 0, "analysis_window": 0.2},
+            [(2, 30), (34, 76)],
+        ),
         stereo_uc_1_analysis_window_0_2=(
             0.2,
             5,
             0.2,
             2,
             {"uc": 1, "analysis_window": 0.2},
-            [(2, 30), (34, 76)],
-        ),
-        stereo_uc_2_analysis_window_0_2=(
-            0.2,
-            5,
-            0.2,
-            2,
-            {"uc": 2, "analysis_window": 0.2},
             [(10, 32), (36, 76)],
         ),
         stereo_uc_mix_aw_0_1_max_silence_0=(
@@ -597,20 +563,18 @@
         sample_width = 2
         import numpy as np
 
-        use_channel = kwargs.get("use_channel", kwargs.get("uc"))
-        # extrat channel of interest
-        if channels != 1:
-            use_channel = kwargs.get("use_channel", kwargs.get("uc"))
-            use_channel = _normalize_use_channel(use_channel)
-            data = _extract_selected_channel(
-                data, channels, sample_width, use_channel=use_channel
-            )
         err_msg = "Wrong number of regions after split, expected: "
         err_msg += "{}, found: {}".format(expected, regions)
         self.assertEqual(len(regions), len(expected), err_msg)
         for reg, exp in zip(regions, expected):
             onset, offset = exp
-            exp_data = data[onset * sample_width : offset * sample_width]
+            exp_data = data[
+                onset
+                * sample_width
+                * channels : offset
+                * sample_width
+                * channels
+            ]
             self.assertEqual(bytes(reg), exp_data)
 
     @genty_dataset(
@@ -663,7 +627,7 @@
     )
     def test_split_input_type(self, input, kwargs):
 
-        with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp:
+        with open("tests/data/test_split_10HZ_stereo.raw", "rb") as fp:
             data = fp.read()
 
         regions = split(
@@ -677,14 +641,16 @@
             **kwargs
         )
         regions = list(regions)
-        expected = [(2, 16), (17, 31), (34, 76)]
+        expected = [(2, 32), (34, 76)]
         sample_width = 2
         err_msg = "Wrong number of regions after split, expected: "
         err_msg += "{}, found: {}".format(expected, regions)
         self.assertEqual(len(regions), len(expected), err_msg)
         for reg, exp in zip(regions, expected):
             onset, offset = exp
-            exp_data = data[onset * sample_width : offset * sample_width]
+            exp_data = data[
+                onset * sample_width * 2 : offset * sample_width * 2
+            ]
             self.assertEqual(bytes(reg), exp_data)
 
     @genty_dataset(