changeset 371:8d3e2b492c6f

Add load function
author Amine Sehili <amine.sehili@gmail.com>
date Mon, 11 Jan 2021 21:13:08 +0100
parents 4d9edd170403
children d653e3f58f3c
files auditok/core.py tests/test_core.py
diffstat 2 files changed, 51 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/core.py	Sun Jan 10 22:54:47 2021 +0100
+++ b/auditok/core.py	Mon Jan 11 21:13:08 2021 +0100
@@ -2,6 +2,7 @@
 .. autosummary::
     :toctree: generated/
 
+    load
     split
     AudioRegion
     StreamTokenizer
@@ -17,7 +18,7 @@
 except ImportError:
     from . import signal
 
-__all__ = ["split", "AudioRegion", "StreamTokenizer"]
+__all__ = ["load", "split", "AudioRegion", "StreamTokenizer"]
 
 
 DEFAULT_ANALYSIS_WINDOW = 0.05
@@ -25,6 +26,13 @@
 _EPSILON = 1e-10
 
 
+def load(input, skip=0, max_read=None, **kwargs):
+    """Load audio data from a source and return it as an :class:`AudioRegion`.
+    For more information about the parameters, see :meth:`AudioRegion.load`.
+    """
+    return AudioRegion.load(input, skip, max_read, **kwargs)
+
+
 def split(
     input,
     min_dur=0.2,
@@ -562,17 +570,20 @@
         Parameters
         ----------
         input : None, str, bytes, AudioSource
-            source to load data from. If None, load data from microphone. If
-            bytes, create region from raw data. If str, load data from file.
-            Input can also an AudioSource object.
+            source to read audio data from. If `str`, it should be a path to a
+            valid audio file. If `bytes`, it is used as raw audio data. If it is
+            "-", raw data will be read from stdin. If None, read audio data from
+            the microphone using PyAudio. If it is of type `bytes` or a path to
+            a raw audio file, then the `sampling_rate`, `sample_width` and
+            `channels` parameters (or their aliases) are required. If it is an
+            :class:`AudioSource` object, it is used directly to read data.
         skip : float, default: 0
             amount, in seconds, of audio data to skip from source. If read from
-            microphone, `skip` must be 0, otherwise a `ValueError` is raised.
+            a microphone, `skip` must be 0, otherwise a `ValueError` is raised.
         max_read : float, default: None
             amount, in seconds, of audio data to read from source. If read from
             microphone, `max_read` should not be None, otherwise a ValueError is
             raised.
-
         audio_format, fmt : str
             type of audio data (e.g., wav, ogg, flac, raw, etc.). This will only
             be used if `input` is a string path to an audio file. If not given,
--- a/tests/test_core.py	Sun Jan 10 22:54:47 2021 +0100
+++ b/tests/test_core.py	Mon Jan 11 21:13:08 2021 +0100
@@ -7,7 +7,7 @@
 from unittest import TestCase, mock
 from unittest.mock import patch
 from genty import genty, genty_dataset
-from auditok import split, AudioRegion, AudioParameterError
+from auditok import load, split, AudioRegion, AudioParameterError
 from auditok.core import (
     _duration_to_nb_windows,
     _make_audio_region,
@@ -35,6 +35,39 @@
 @genty
 class TestFunctions(TestCase):
     @genty_dataset(
+        no_skip_read_all=(0, -1),
+        no_skip_read_all_stereo=(0, -1, 2),
+        skip_2_read_all=(2, -1),
+        skip_2_read_all_None=(2, None),
+        skip_2_read_3=(2, 3),
+        skip_2_read_3_5_stereo=(2, 3.5, 2),
+        skip_2_4_read_3_5_stereo=(2.4, 3.5, 2),
+    )
+    def test_load(self, skip, max_read, channels=1):
+        sampling_rate = 10
+        sample_width = 2
+        filename = "tests/data/test_split_10HZ_{}.raw"
+        filename = filename.format("mono" if channels == 1 else "stereo")
+        region = load(
+            filename,
+            skip=skip,
+            max_read=max_read,
+            sr=sampling_rate,
+            sw=sample_width,
+            ch=channels,
+        )
+        with open(filename, "rb") as fp:
+            fp.read(round(skip * sampling_rate * sample_width * channels))
+            if max_read is None or max_read < 0:
+                to_read = -1
+            else:
+                to_read = round(
+                    max_read * sampling_rate * sample_width * channels
+                )
+            expected = fp.read(to_read)
+        self.assertEqual(bytes(region), expected)
+
+    @genty_dataset(
         zero_duration=(0, 1, None, 0),
         multiple=(0.3, 0.1, round, 3),
         not_multiple_ceil=(0.35, 0.1, math.ceil, 4),