changeset 406:79bd3de43a5b

Accept pathlib.Path for io
author Amine Sehili <amine.sehili@gmail.com>
date Wed, 19 Jun 2024 22:48:54 +0200
parents f56b4d8adfb8
children 6c33626d0bff
files auditok/io.py auditok/workers.py tests/test_io.py
diffstat 3 files changed, 158 insertions(+), 68 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/io.py	Mon Jun 17 19:45:51 2024 +0200
+++ b/auditok/io.py	Wed Jun 19 22:48:54 2024 +0200
@@ -71,9 +71,19 @@
         )
 
 
-def _guess_audio_format(fmt, filename):
+def _guess_audio_format(filename, fmt):
+    """Helper function to guess audio format from file extension, or by
+    normalizing a user provided format.
+
+    Args:
+        filename (str, Path): audio file name.
+        fmt (str): un-normalized user provided format.
+
+    Returns:
+        str, None: guessed audio format or None if no format could be guessed.
+    """
     if fmt is None:
-        extension = os.path.splitext(filename.lower())[1][1:]
+        extension = os.path.splitext(filename)[1][1:].lower()
         if extension:
             fmt = extension
         else:
@@ -411,7 +421,7 @@
 
     Parameters
     ----------
-    filename : str
+    filename : str, Path
         path to a raw audio file.
     sampling_rate : int
         Number of samples per second of audio data.
@@ -421,15 +431,15 @@
         Number of channels of audio data.
     """
 
-    def __init__(self, file, sampling_rate, sample_width, channels):
+    def __init__(self, filename, sampling_rate, sample_width, channels):
         FileAudioSource.__init__(self, sampling_rate, sample_width, channels)
-        self._file = file
+        self._filename = filename
         self._audio_stream = None
         self._sample_size = sample_width * channels
 
     def open(self):
         if self._audio_stream is None:
-            self._audio_stream = open(self._file, "rb")
+            self._audio_stream = open(self._filename, "rb")
 
     def _read_from_stream(self, size):
         if size is None or size < 0:
@@ -449,12 +459,12 @@
 
     Parameters
     ----------
-    filename : str
+    filename : str, Path
         path to a valid wave file.
     """
 
     def __init__(self, filename):
-        self._filename = filename
+        self._filename = str(filename)  # wave requires an str filename
         self._audio_stream = None
         stream = wave.open(self._filename, "rb")
         FileAudioSource.__init__(
@@ -751,7 +761,9 @@
         )
 
 
-def _load_raw(file, sampling_rate, sample_width, channels, large_file=False):
+def _load_raw(
+    filename, sampling_rate, sample_width, channels, large_file=False
+):
     """
     Load a raw audio file with standard Python. If `large_file` is True, return
     a `RawAudioSource` object that reads data lazily from disk, otherwise load
@@ -759,7 +771,7 @@
 
     Parameters
     ----------
-    file : str
+    filename : str, Path
         path to a raw audio data file.
     sampling_rate : int
         sampling rate of audio data.
@@ -776,6 +788,7 @@
     source : RawAudioSource or BufferAudioSource
         an `AudioSource` that reads data from input file.
     """
+
     if None in (sampling_rate, sample_width, channels):
         raise AudioParameterError(
             "All audio parameters are required for raw audio files"
@@ -783,13 +796,13 @@
 
     if large_file:
         return RawAudioSource(
-            file,
+            filename,
             sampling_rate=sampling_rate,
             sample_width=sample_width,
             channels=channels,
         )
 
-    with open(file, "rb") as fp:
+    with open(filename, "rb") as fp:
         data = fp.read()
     return BufferAudioSource(
         data,
@@ -799,7 +812,7 @@
     )
 
 
-def _load_wave(file, large_file=False):
+def _load_wave(filename, large_file=False):
     """
     Load a wave audio file with standard Python. If `large_file` is True, return
     a `WaveAudioSource` object that reads data lazily from disk, otherwise load
@@ -807,7 +820,7 @@
 
     Parameters
     ----------
-    file : str
+    filename : str, Path
         path to a wav audio data file
     large_file : bool
         if True, return a `WaveAudioSource` otherwise a `BufferAudioSource`
@@ -818,9 +831,10 @@
     source : WaveAudioSource or BufferAudioSource
         an `AudioSource` that reads data from input file.
     """
+
     if large_file:
-        return WaveAudioSource(file)
-    with wave.open(file) as fp:
+        return WaveAudioSource(filename)
+    with wave.open(filename) as fp:
         channels = fp.getnchannels()
         srate = fp.getframerate()
         swidth = fp.getsampwidth()
@@ -830,30 +844,31 @@
     )
 
 
-def _load_with_pydub(file, audio_format=None):
+def _load_with_pydub(filename, audio_format=None):
     """
     Open compressed audio or video file using pydub. If a video file
     is passed, its audio track(s) are extracted and loaded.
 
     Parameters
     ----------
-    file : str
+    filename : str, Path
         path to audio file.
     audio_format : str, default: None
-        string, audio/video file format if known (e.g. raw, webm, wav, ogg)
+        audio file format if known (e.g. raw, webm, wav, ogg)
 
     Returns
     -------
     source : BufferAudioSource
         an `AudioSource` that reads data from input file.
     """
+
     func_dict = {
         "mp3": AudioSegment.from_mp3,
         "ogg": AudioSegment.from_ogg,
         "flv": AudioSegment.from_flv,
     }
     open_function = func_dict.get(audio_format, AudioSegment.from_file)
-    segment = open_function(file)
+    segment = open_function(filename)
     return BufferAudioSource(
         data=segment.raw_data,
         sampling_rate=segment.frame_rate,
@@ -890,7 +905,7 @@
 
     Parameters
     ----------
-    filename : str
+    filename : str, Path
         path to input audio or video file.
     audio_format : str
         audio format used to save data  (e.g. raw, webm, wav, ogg).
@@ -919,7 +934,7 @@
         raised if audio data cannot be read in the given
         format or if `format` is `raw` and one or more audio parameters are missing.
     """
-    audio_format = _guess_audio_format(audio_format, filename)
+    audio_format = _guess_audio_format(filename, audio_format)
 
     if audio_format == "raw":
         srate, swidth, channels = _get_audio_parameters(kwargs)
@@ -956,7 +971,7 @@
         raise AudioParameterError(
             "All audio parameters are required to save wave audio files"
         )
-    with wave.open(file, "w") as fp:
+    with wave.open(str(file), "w") as fp:
         fp.setframerate(sampling_rate)
         fp.setsampwidth(sample_width)
         fp.setnchannels(channels)
@@ -980,11 +995,11 @@
         segment.export(fp, format=audio_format)
 
 
-def to_file(data, file, audio_format=None, **kwargs):
+def to_file(data, filename, audio_format=None, **kwargs):
     """
     Writes audio data to file. If `audio_format` is `None`, output
     audio format will be guessed from extension. If `audio_format`
-    is `None` and `file` comes without an extension then audio
+    is `None` and `filename` comes without an extension then audio
     data will be written as a raw audio file.
 
     Parameters
@@ -992,7 +1007,7 @@
     data : bytes-like
         audio data to be written. Can be a `bytes`, `bytearray`,
         `memoryview`, `array` or `numpy.ndarray` object.
-    file : str
+    filename : str, Path
         path to output audio file.
     audio_format : str
         audio format used to save data (e.g. raw, webm, wav, ogg)
@@ -1010,18 +1025,18 @@
     audio parameters are missing. `AudioIOError` if audio data cannot be written
     in the desired format.
     """
-    audio_format = _guess_audio_format(audio_format, file)
+    audio_format = _guess_audio_format(filename, audio_format)
     if audio_format in (None, "raw"):
-        _save_raw(data, file)
+        _save_raw(data, filename)
         return
     sampling_rate, sample_width, channels = _get_audio_parameters(kwargs)
     if audio_format in ("wav", "wave"):
-        _save_wave(data, file, sampling_rate, sample_width, channels)
+        _save_wave(data, filename, sampling_rate, sample_width, channels)
     elif _WITH_PYDUB:
         _save_with_pydub(
-            data, file, audio_format, sampling_rate, sample_width, channels
+            data, filename, audio_format, sampling_rate, sample_width, channels
         )
     else:
         raise AudioIOError(
-            f"cannot write file format {audio_format} (file name: {file})"
+            f"cannot write file format {audio_format} (file name: {filename})"
         )
--- a/auditok/workers.py	Mon Jun 17 19:45:51 2024 +0200
+++ b/auditok/workers.py	Wed Jun 19 22:48:54 2024 +0200
@@ -32,9 +32,9 @@
         ) as proc:
             stdout, stderr = proc.communicate()
             return proc.returncode, stdout, stderr
-    except Exception:
+    except Exception as exc:
         err_msg = "Couldn't export audio using command: '{}'".format(command)
-        raise AudioEncodingError(err_msg)
+        raise AudioEncodingError(err_msg) from exc
 
 
 class Worker(Thread, metaclass=ABCMeta):
@@ -163,7 +163,7 @@
         sample_size_bytes = self._reader.sw * self._reader.ch
         self._cache_size = cache_size_sec * self._reader.sr * sample_size_bytes
         self._output_filename = filename
-        self._export_format = _guess_audio_format(export_format, filename)
+        self._export_format = _guess_audio_format(filename, export_format)
         if self._export_format is None:
             self._export_format = "wav"
         self._init_output_stream()
@@ -267,7 +267,7 @@
         except AudioEncodingError:
             try:
                 self._export_with_sox()
-            except AudioEncodingError:
+            except AudioEncodingError as exc:
                 warn_msg = "Couldn't save audio data in the desired format "
                 warn_msg += "'{}'. Either none of 'ffmpeg', 'avconv' or 'sox' "
                 warn_msg += "is installed or this format is not recognized.\n"
@@ -276,7 +276,7 @@
                     warn_msg.format(
                         self._export_format, self._tmp_output_filename
                     )
-                )
+                ) from exc
         finally:
             self._exported = True
         return self._output_filename
--- a/tests/test_io.py	Mon Jun 17 19:45:51 2024 +0200
+++ b/tests/test_io.py	Wed Jun 19 22:48:54 2024 +0200
@@ -2,7 +2,9 @@
 import math
 import os
 import sys
+import wave
 from array import array
+from pathlib import Path
 from tempfile import NamedTemporaryFile, TemporaryDirectory
 from unittest.mock import Mock, patch
 
@@ -66,28 +68,31 @@
 
 
 @pytest.mark.parametrize(
-    "fmt, filename, expected",
+    "filename, audio_format, expected",
     [
-        ("wav", "filename.wav", "wav"),  # extention_and_format_same
-        ("wav", "filename.mp3", "wav"),  # extention_and_format_different
-        (None, "filename.wav", "wav"),  # extention_no_format
-        ("wav", "filename", "wav"),  # format_no_extension
-        (None, "filename", None),  # no_format_no_extension
-        ("wave", "filename", "wav"),  # wave_as_wav
-        (None, "filename.wave", "wav"),  # wave_as_wav_extension
+        ("filename.wav", "wav", "wav"),  # extension_and_format_same
+        ("filename.mp3", "wav", "wav"),  # extension_and_format_different
+        ("filename.wav", None, "wav"),  # extension_no_format
+        ("filename", "wav", "wav"),  # format_no_extension
+        ("filename", None, None),  # no_format_no_extension
+        ("filename", "wave", "wav"),  # wave_as_wav
+        ("filename.wave", None, "wav"),  # wave_as_wav_extension
     ],
     ids=[
-        "extention_and_format_same",
-        "extention_and_format_different",
-        "extention_no_format",
+        "extension_and_format_same",
+        "extension_and_format_different",
+        "extension_no_format",
         "format_no_extension",
         "no_format_no_extension",
         "wave_as_wav",
         "wave_as_wav_extension",
     ],
 )
-def test_guess_audio_format(fmt, filename, expected):
-    result = _guess_audio_format(fmt, filename)
+def test_guess_audio_format(filename, audio_format, expected):
+    result = _guess_audio_format(filename, audio_format)
+    assert result == expected
+
+    result = _guess_audio_format(Path(filename), audio_format)
     assert result == expected
 
 
@@ -133,7 +138,6 @@
     ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"],
 )
 def test_get_audio_parameters_missing_parameter(missing_param):
-
     params = AUDIO_PARAMS.copy()
     del params[missing_param]
     with pytest.raises(AudioParameterError):
@@ -150,7 +154,6 @@
     ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"],
 )
 def test_get_audio_parameters_missing_parameter_short(missing_param):
-
     params = AUDIO_PARAMS_SHORT.copy()
     del params[missing_param]
     with pytest.raises(AudioParameterError):
@@ -240,22 +243,80 @@
     assert patch_function.called
 
 
-def test_from_file_large_file_raw():
-    filename = "tests/data/test_16KHZ_mono_400Hz.raw"
+@pytest.mark.parametrize(
+    "large_file, cls, size, use_pathlib",
+    [
+        (False, BufferAudioSource, -1, False),  # large_file_false_negative_size
+        (False, BufferAudioSource, None, False),  # large_file_false_None_size
+        (True, RawAudioSource, -1, False),  # large_file_true_negative_size
+        (True, RawAudioSource, None, False),  # large_file_true_None_size
+        (True, RawAudioSource, -1, True),  # large_file_true_negative_size_Path
+    ],
+    ids=[
+        "large_file_false_negative_size",
+        "large_file_false_None_size",
+        "large_file_true_negative_size",
+        "large_file_true_None_size",
+        "large_file_true_negative_size_Path",
+    ],
+)
+def test_from_file_raw_read_all(large_file, cls, size, use_pathlib):
+    filename = Path("tests/data/test_16KHZ_mono_400Hz.raw")
+    if use_pathlib:
+        filename = Path(filename)
     audio_source = from_file(
         filename,
-        large_file=True,
+        large_file=large_file,
         sampling_rate=16000,
         sample_width=2,
         channels=1,
     )
-    assert isinstance(audio_source, RawAudioSource)
+    assert isinstance(audio_source, cls)
 
+    with open(filename, "rb") as fp:
+        expected = fp.read()
+    audio_source.open()
+    data = audio_source.read(size)
+    audio_source.close()
+    assert data == expected
 
-def test_from_file_large_file_wave():
+
+@pytest.mark.parametrize(
+    "large_file, cls, size, use_pathlib",
+    [
+        (False, BufferAudioSource, -1, False),  # large_file_false_negative_size
+        (False, BufferAudioSource, None, False),  # large_file_false_None_size
+        (True, WaveAudioSource, -1, False),  # large_file_true_negative_size
+        (True, WaveAudioSource, None, False),  # large_file_true_None_size
+        (True, WaveAudioSource, -1, True),  # large_file_true_negative_size_Path
+    ],
+    ids=[
+        "large_file_false_negative_size",
+        "large_file_false_None_size",
+        "large_file_true_negative_size",
+        "large_file_true_None_size",
+        "large_file_true_negative_size_Path",
+    ],
+)
+def test_from_file_wave_read_all(large_file, cls, size, use_pathlib):
     filename = "tests/data/test_16KHZ_mono_400Hz.wav"
-    audio_source = from_file(filename, large_file=True)
-    assert isinstance(audio_source, WaveAudioSource)
+    if use_pathlib:
+        filename = Path(filename)
+    audio_source = from_file(
+        filename,
+        large_file=large_file,
+        sampling_rate=16000,
+        sample_width=2,
+        channels=1,
+    )
+    assert isinstance(audio_source, cls)
+
+    with wave.open(str(filename)) as fp:
+        expected = fp.readframes(-1)
+    audio_source.open()
+    data = audio_source.read(size)
+    audio_source.close()
+    assert data == expected
 
 
 def test_from_file_large_file_compressed():
@@ -466,15 +527,22 @@
 
 
 @pytest.mark.parametrize(
-    "filename, frequencies",
+    "filename, frequencies, use_pathlib",
     [
-        ("mono_400Hz.raw", (400,)),  # mono
-        ("3channel_400-800-1600Hz.raw", (400, 800, 1600)),  # three_channel
+        ("mono_400Hz.raw", (400,), False),  # mono
+        ("mono_400Hz.raw", (400,), True),  # mono_pathlib
+        (
+            "3channel_400-800-1600Hz.raw",
+            (400, 800, 1600),
+            False,
+        ),  # three_channel
     ],
-    ids=["mono", "three_channel"],
+    ids=["mono", "three_channel", "use_pathlib"],
 )
-def test_save_raw(filename, frequencies):
+def test_save_raw(filename, frequencies, use_pathlib):
     filename = "tests/data/test_16KHZ_{}".format(filename)
+    if use_pathlib:
+        filename = Path(filename)
     sample_width = 2
     dtype = SAMPLE_WIDTH_TO_DTYPE[sample_width]
     mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies]
@@ -485,15 +553,22 @@
 
 
 @pytest.mark.parametrize(
-    "filename, frequencies",
+    "filename, frequencies, use_pathlib",
     [
-        ("mono_400Hz.wav", (400,)),  # mono
-        ("3channel_400-800-1600Hz.wav", (400, 800, 1600)),  # three_channel
+        ("mono_400Hz.wav", (400,), False),  # mono
+        ("mono_400Hz.wav", (400,), True),  # mono_pathlib
+        (
+            "3channel_400-800-1600Hz.wav",
+            (400, 800, 1600),
+            False,
+        ),  # three_channel
     ],
-    ids=["mono", "three_channel"],
+    ids=["mono", "mono_pathlib", "three_channel"],
 )
-def test_save_wave(filename, frequencies):
+def test_save_wave(filename, frequencies, use_pathlib):
     filename = "tests/data/test_16KHZ_{}".format(filename)
+    if use_pathlib:
+        filename = str(filename)
     sampling_rate = 16000
     sample_width = 2
     channels = len(frequencies)