changeset 243:f55007434c6b

Move signal processing functions to a separate module
author Amine Sehili <amine.sehili@gmail.com>
date Mon, 29 Jul 2019 20:37:31 +0100
parents 90445f084929
children ee6d2294cdd5
files auditok/signal.py auditok/signal_numpy.py auditok/util.py
diffstat 3 files changed, 92 insertions(+), 104 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/auditok/signal.py	Mon Jul 29 20:37:31 2019 +0100
@@ -0,0 +1,43 @@
+from array import array
+import math
+
+FORMAT = {1: "b", 2: "h", 4: "i"}
+_EPSILON = 1e-20
+
+
+def to_array(data, fmt):
+    return array(fmt, data)
+
+
+def extract_single_channel(data, fmt, channels, selected):
+    samples = array(fmt, data)
+    return samples[selected::channels]
+
+
+def average_channels(data, fmt, channels):
+    all_channels = array(fmt, data)
+    mono_channels = [
+        array(fmt, all_channels[ch::channels]) for ch in range(channels)
+    ]
+    avg_arr = array(
+        fmt, (sum(samples) // channels for samples in zip(*mono_channels))
+    )
+    return avg_arr
+
+
+def separate_channels(data, fmt, channels):
+    all_channels = array(fmt, data)
+    mono_channels = [
+        array(fmt, all_channels[ch::channels]) for ch in range(channels)
+    ]
+    return mono_channels
+
+
+def calculate_energy_single_channel(x):
+    energy = max(sum(i ** 2 for i in x) / len(x), 1e-20)
+    return 10 * math.log10(energy)
+
+
+def calculate_energy_multichannel(x, aggregation_fn=max):
+    energies = (calculate_energy_single_channel(xi) for xi in x)
+    return aggregation_fn(energies)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/auditok/signal_numpy.py	Mon Jul 29 20:37:31 2019 +0100
@@ -0,0 +1,32 @@
+import numpy as np
+
+FORMAT = {1: np.int8, 2: np.int16, 4: np.int32}
+_EPSILON = 1e-20
+
+
+def to_array(data, fmt):
+    return np.frombuffer(data, dtype=fmt).astype(np.float64)
+
+
+def extract_single_channel(data, fmt, channels, selected):
+    samples = np.frombuffer(data, dtype=fmt)
+    return samples[selected::channels].astype(np.float64)
+
+
+def average_channels(data, fmt, channels):
+    array = np.frombuffer(data, dtype=fmt).astype(np.float64)
+    return array.reshape(-1, channels).mean(axis=1)
+
+
+def separate_channels(data, fmt, channels):
+    array = np.frombuffer(data, dtype=fmt).astype(np.float64)
+    return array.reshape(-1, channels).T
+
+
+def calculate_energy_single_channel(x):
+    return 10 * np.log10(np.dot(x, x).clip(min=_EPSILON) / x.size)
+
+
+def calculate_energy_multichannel(x, aggregation_fn=np.max):
+    energy = 10 * np.log10((x * x).mean(axis=1).clip(min=_EPSILON))
+    return aggregation_fn(energy)
--- a/auditok/util.py	Sat Jul 27 22:49:45 2019 +0100
+++ b/auditok/util.py	Mon Jul 29 20:37:31 2019 +0100
@@ -20,6 +20,7 @@
 from abc import ABCMeta, abstractmethod
 import math
 from array import array
+from functools import partial
 from .io import (
     AudioIOError,
     AudioSource,
@@ -31,14 +32,9 @@
 from .exceptions import DuplicateArgument, TooSamllBlockDuration
 
 try:
-    import numpy
-
-    np = numpy
-    _WITH_NUMPY = True
-    _FORMAT = {1: np.int8, 2: np.int16, 4: np.int32}
+    import signal_numpy as signal
 except ImportError as e:
-    _WITH_NUMPY = False
-    _FORMAT = {1: "b", 2: "h", 4: "i"}
+    from . import signal
 
 try:
     from builtins import str
@@ -59,23 +55,13 @@
 
 
 def make_channel_selector(sample_width, channels, selected=None):
-    fmt = _FORMAT.get(sample_width)
+    fmt = signal.FORMAT.get(sample_width)
     if fmt is None:
         err_msg = "'sample_width' must be 1, 2 or 4, given: {}"
         raise ValueError(err_msg.format(sample_width))
 
     if channels == 1:
-        if _WITH_NUMPY:
-
-            def _as_array(data):
-                return np.frombuffer(data, dtype=fmt).astype(np.float64)
-
-        else:
-
-            def _as_array(data):
-                return array(fmt, data)
-
-        return _as_array
+        return partial(signal.to_array, fmt=fmt)
 
     if isinstance(selected, int):
         if selected < 0:
@@ -84,91 +70,18 @@
             err_msg = "Selected channel must be >= -channels and < 'channels'"
             err_msg += ", given: {}"
             raise ValueError(err_msg.format(selected))
-        if _WITH_NUMPY:
-
-            def _extract_single_channel(data):
-                samples = np.frombuffer(data, dtype=fmt)
-                return samples[selected::channels].astype(np.float64)
-
-        else:
-
-            def _extract_single_channel(data):
-                samples = array(fmt, data)
-                return samples[selected::channels]
-
-        return _extract_single_channel
+        return partial(
+            signal.extract_single_channel,
+            fmt=fmt,
+            channels=channels,
+            selected=selected,
+        )
 
     if selected in ("mix", "avg", "average"):
-        if _WITH_NUMPY:
+        return partial(signal.average_channels, fmt=fmt, channels=channels)
 
-            def _average_channels(data):
-                array = np.frombuffer(data, dtype=fmt).astype(np.float64)
-                return array.reshape(-1, channels).mean(axis=1)
-
-        else:
-
-            def _average_channels(data):
-                all_channels = array(fmt, data)
-                mono_channels = [
-                    array(fmt, all_channels[ch::channels])
-                    for ch in range(channels)
-                ]
-                avg_arr = array(
-                    fmt,
-                    (
-                        sum(samples) // channels
-                        for samples in zip(*mono_channels)
-                    ),
-                )
-                return avg_arr
-
-        return _average_channels
-
-    if selected is None:
-        if _WITH_NUMPY:
-
-            def _split_channels(data):
-                array = np.frombuffer(data, dtype=fmt).astype(np.float64)
-                return array.reshape(-1, channels).T
-
-        else:
-
-            def _split_channels(data):
-                all_channels = array(fmt, data)
-                mono_channels = [
-                    array(fmt, all_channels[ch::channels])
-                    for ch in range(channels)
-                ]
-                return mono_channels
-
-        return _split_channels
-
-
-if _WITH_NUMPY:
-
-    def _calculate_energy_single_channel(x):
-        return 10 * np.log10(np.dot(x, x).clip(min=1e-20) / x.size)
-
-
-else:
-
-    def _calculate_energy_single_channel(x):
-        energy = max(sum(i ** 2 for i in x) / len(x), 1e-20)
-        return 10 * math.log10(energy)
-
-
-if _WITH_NUMPY:
-
-    def _calculate_energy_multichannel(x, aggregation_fn=np.max):
-        energy = 10 * np.log10((x * x).mean(axis=1).clip(min=1e-20))
-        return aggregation_fn(energy)
-
-
-else:
-
-    def _calculate_energy_multichannel(x, aggregation_fn=max):
-        energies = (_calculate_energy_single_channel(xi) for xi in x)
-        return aggregation_fn(energies)
+    if selected in (None, "any"):
+        return partial(signal.separate_channels, fmt=fmt, channels=channels)
 
 
 class DataSource:
@@ -211,13 +124,13 @@
             sample_width, channels, use_channel
         )
         if channels == 1 or use_channel is not None:
-            self._energy_fn = _calculate_energy_single_channel
+            self._energy_fn = signal.calculate_energy_single_channel
         else:
-            self._energy_fn = _calculate_energy_multichannel
+            self._energy_fn = signal.calculate_energy_multichannel
         self._energy_threshold = energy_threshold
 
     def is_valid(self, data):
-        return self._energy_fn(self._selector(data)) > self._energy_threshold
+        return self._energy_fn(self._selector(data)) >= self._energy_threshold
 
 
 class StringDataSource(DataSource):