changeset 411:0e938065a2db

AudioRegion as a dataclass
author Amine Sehili <amine.sehili@gmail.com>
date Thu, 20 Jun 2024 21:45:08 +0200
parents 9c9112e23c1c
children 5a6685f1e42d
files auditok/core.py tests/test_core.py
diffstat 2 files changed, 117 insertions(+), 89 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/core.py	Wed Jun 19 23:30:18 2024 +0200
+++ b/auditok/core.py	Thu Jun 20 21:45:08 2024 +0200
@@ -10,6 +10,9 @@
 
 import math
 import os
+import warnings
+from dataclasses import dataclass, field
+from pathlib import Path
 
 from .exceptions import TooSmallBlockDuration
 from .io import check_audio_data, get_audio_source, player_for, to_file
@@ -378,9 +381,7 @@
     """
     start = start_frame * frame_duration
     data = b"".join(data_frames)
-    duration = len(data) / (sampling_rate * sample_width * channels)
-    meta = {"start": start, "end": start + duration}
-    return AudioRegion(data, sampling_rate, sample_width, channels, meta)
+    return AudioRegion(data, sampling_rate, sample_width, channels, start)
 
 
 def _read_chunks_online(max_read, **kwargs):
@@ -537,6 +538,13 @@
     """A class to store `AudioRegion`'s metadata."""
 
     def __getattr__(self, name):
+        warnings.warn(
+            "`AudioRegion.meta` is deprecated and will be removed in future "
+            "versions. For the 'start' and 'end' fields, please use "
+            "`AudioRegion.start` and `AudioRegion.end`.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         if name in self:
             return self[name]
         else:
@@ -553,6 +561,7 @@
         return str(self)
 
 
+@dataclass(frozen=True)
 class AudioRegion(object):
     """
     AudioRegion encapsulates raw audio data and provides an interface to
@@ -581,15 +590,20 @@
     See also
     --------
     AudioRegion.load
-
     """
 
-    def __init__(self, data, sampling_rate, sample_width, channels, meta=None):
+    data: bytes
+    sampling_rate: int
+    sample_width: int
+    channels: int
+    start: float = field(default=None, repr=None)
+
+    def init__(self, data, sampling_rate, sample_width, channels, meta=None):
         check_audio_data(data, sample_width, channels)
-        self._data = data
-        self._sampling_rate = sampling_rate
-        self._sample_width = sample_width
-        self._channels = channels
+        self.data = data
+        self.sampling_rate = sampling_rate
+        self.sample_width = sample_width
+        self.channels = channels
         self._samples = None
         self.splitp = self.split_and_plot
 
@@ -605,14 +619,36 @@
         self._millis_view = _MillisView(self)
         self.ms = self.millis
 
-    @property
-    def meta(self):
-        return self._meta
+    def __post_init__(self):
 
-    @meta.setter
-    def meta(self, new_meta):
-        """Meta data of audio region."""
-        self._meta = _AudioRegionMetadata(new_meta)
+        check_audio_data(self.data, self.sample_width, self.channels)
+
+        object.__setattr__(self, "splitp", self.split_and_plot)
+        object.__setattr__(self, "_samples", None)
+
+        duration = len(self.data) / (
+            self.sampling_rate * self.sample_width * self.channels
+        )
+        object.__setattr__(self, "duration", duration)
+
+        if self.start is not None:
+            object.__setattr__(self, "end", self.start + self.duration)
+            object.__setattr__(
+                self,
+                "meta",
+                _AudioRegionMetadata({"start": self.start, "end": self.end}),
+            )
+        else:
+            object.__setattr__(self, "end", None)
+            object.__setattr__(self, "meta", None)
+
+        # `seconds` and `millis` are defined below as @property with docstring
+        object.__setattr__(self, "_seconds_view", _SecondsView(self))
+        object.__setattr__(self, "_millis_view", _MillisView(self))
+
+        object.__setattr__(self, "sec", self.seconds)
+        object.__setattr__(self, "s", self.seconds)
+        object.__setattr__(self, "ms", self.millis)
 
     @classmethod
     def load(cls, input, skip=0, max_read=None, **kwargs):
@@ -664,43 +700,19 @@
         return self._millis_view
 
     @property
-    def duration(self):
-        """
-        Returns region duration in seconds.
-        """
-        return len(self._data) / (
-            self.sampling_rate * self.sample_width * self.channels
-        )
-
-    @property
-    def sampling_rate(self):
-        """Sampling rate of audio data."""
-        return self._sampling_rate
-
-    @property
     def sr(self):
         """Sampling rate of audio data, alias for `sampling_rate`."""
-        return self._sampling_rate
-
-    @property
-    def sample_width(self):
-        """Number of bytes per sample, one channel considered."""
-        return self._sample_width
+        return self.sampling_rate
 
     @property
     def sw(self):
-        """Number of bytes per sample, alias for `sampling_rate`."""
-        return self._sample_width
-
-    @property
-    def channels(self):
-        """Number of channels of audio data."""
-        return self._channels
+        """Number of bytes per sample, alias for `sample_width`."""
+        return self.sample_width
 
     @property
     def ch(self):
         """Number of channels of audio data, alias for `channels`."""
-        return self._channels
+        return self.channels
 
     def play(self, progress_bar=False, player=None, **progress_bar_kwargs):
         """
@@ -721,24 +733,22 @@
         """
         if player is None:
             player = player_for(self)
-        player.play(
-            self._data, progress_bar=progress_bar, **progress_bar_kwargs
-        )
+        player.play(self.data, progress_bar=progress_bar, **progress_bar_kwargs)
 
-    def save(self, file, audio_format=None, exists_ok=True, **audio_parameters):
+    def save(
+        self, filename, audio_format=None, exists_ok=True, **audio_parameters
+    ):
         """
         Save audio region to file.
 
         Parameters
         ----------
-        file : str
-            path to output audio file. May contain `{duration}` placeholder
-            as well as any place holder that this region's metadata might
-            contain (e.g., regions returned by `split` contain metadata with
-            `start` and `end` attributes that can be used to build output file
-            name as `{meta.start}` and `{meta.end}`. See examples using
-            placeholders with formatting.
-
+        filename : str, Path
+            path to output audio file. If of type `str`, it may contain
+            `{start}`, `{end}` and `{duration}` placeholders.
+            Regions returned by `split` have `start` and `end`
+            attributes that can be used to build the output file name, as in
+            the example below.
         audio_format : str, default: None
             format used to save audio data. If None (default), format is guessed
             from file name's extension. If file name has no extension, audio
@@ -752,38 +762,55 @@
         Returns
         -------
         file: str
-            name of output file with replaced placehoders.
+            name of output file with filled placeholders.
         Raises
-            IOError if `file` exists and `exists_ok` is False.
+            IOError if `filename` exists and `exists_ok` is False.
 
 
         Examples
         --------
-        >>> region = AudioRegion(b'\\0' * 2 * 24000,
+        Create an AudioRegion, explicitly passing a value for `start`. `end`
+        will be computed based on `start` and the region's duration.
+
+        >>> region = AudioRegion(b'\\0' * 2 * 24000,
         >>>                      sampling_rate=16000,
         >>>                      sample_width=2,
-        >>>                      channels=1)
-        >>> region.meta.start = 2.25
-        >>> region.meta.end = 2.25 + region.duration
-        >>> region.save('audio_{meta.start}-{meta.end}.wav')
-        >>> audio_2.25-3.75.wav
-        >>> region.save('region_{meta.start:.3f}_{duration:.3f}.wav')
-        audio_2.250_1.500.wav
+        >>>                      channels=1,
+        >>>                      start=2.25)
+        >>> region
+        <AudioRegion(duration=1.500, sampling_rate=16000, sample_width=2, channels=1)>
+
+        >>> assert region.end == 3.75
+        >>> assert region.save('audio_{start}-{end}.wav') == "audio_2.25-3.75.wav"
+        >>> filename = region.save('audio_{start:.3f}-{end:.3f}_{duration:.3f}.wav')
+        >>> assert filename == "audio_2.250-3.750_1.500.wav"
         """
-        if isinstance(file, str):
-            file = file.format(duration=self.duration, meta=self.meta)
-            if not exists_ok and os.path.exists(file):
-                raise FileExistsError("file '{file}' exists".format(file=file))
+        if isinstance(filename, Path):
+            if not exists_ok and filename.exists():
+                raise FileExistsError(
+                    "file '{filename}' exists".format(filename=str(filename))
+                )
+        if isinstance(filename, str):
+            filename = filename.format(
+                duration=self.duration,
+                meta=self.meta,
+                start=self.start,
+                end=self.end,
+            )
+            if not exists_ok and os.path.exists(filename):
+                raise FileExistsError(
+                    "file '{filename}' exists".format(filename=filename)
+                )
         to_file(
-            self._data,
-            file,
+            self.data,
+            filename,
             audio_format,
             sr=self.sr,
             sw=self.sw,
             ch=self.ch,
             audio_parameters=audio_parameters,
         )
-        return file
+        return filename
 
     def split(
         self,
@@ -899,11 +926,11 @@
     @property
     def samples(self):
         """Audio region as arrays of samples, one array per channel."""
-        if self._samples is None:
-            self._samples = signal.to_array(
-                self._data, self.sample_width, self.channels
+        if self._samples is None:  # TODO: cache result in `_samples`
+            _samples = signal.to_array(
+                self.data, self.sample_width, self.channels
             )
-        return self._samples
+        return _samples
 
     def __array__(self):
         return self.samples
@@ -915,7 +942,7 @@
         """
         Return region length in number of samples.
         """
-        return len(self._data) // (self.sample_width * self.channels)
+        return len(self.data) // (self.sample_width * self.channels)
 
     @property
     def len(self):
@@ -925,7 +952,7 @@
         return len(self)
 
     def __bytes__(self):
-        return self._data
+        return self.data
 
     def __str__(self):
         return (
@@ -964,7 +991,7 @@
                 "Can only concatenate AudioRegions of the same "
                 "number of channels ({} != {})".format(self.ch, other.ch)
             )
-        data = self._data + other._data
+        data = self.data + other.data
         return AudioRegion(data, self.sr, self.sw, self.ch)
 
     def __radd__(self, other):
@@ -982,7 +1009,7 @@
         if not isinstance(n, int):
             err_msg = "Can't multiply AudioRegion by a non-int of type '{}'"
             raise TypeError(err_msg.format(type(n)))
-        data = self._data * n
+        data = self.data * n
         return AudioRegion(data, self.sr, self.sw, self.ch)
 
     def __rmul__(self, n):
@@ -1010,7 +1037,7 @@
         if not isinstance(other, AudioRegion):
             return False
         return (
-            (self._data == other._data)
+            (self.data == other.data)
             and (self.sr == other.sr)
             and (self.sw == other.sw)
             and (self.ch == other.ch)
@@ -1022,7 +1049,7 @@
         start_sample, stop_sample = _check_convert_index(index, (int), err_msg)
 
         bytes_per_sample = self.sample_width * self.channels
-        len_samples = len(self._data) // bytes_per_sample
+        len_samples = len(self.data) // bytes_per_sample
 
         if start_sample < 0:
             start_sample = max(start_sample + len_samples, 0)
@@ -1035,7 +1062,7 @@
         else:
             offset = None
 
-        data = self._data[onset:offset]
+        data = self.data[onset:offset]
         return AudioRegion(data, self.sr, self.sw, self.ch)
 
 
--- a/tests/test_core.py	Wed Jun 19 23:30:18 2024 +0200
+++ b/tests/test_core.py	Thu Jun 20 21:45:08 2024 +0200
@@ -1,5 +1,6 @@
 import math
 import os
+from pathlib import Path
 from random import random
 from tempfile import TemporaryDirectory
 from unittest.mock import Mock, patch
@@ -1362,8 +1363,7 @@
     expected_duration_s,
     expected_duration_ms,
 ):
-    meta = {"start": start, "end": expected_end}
-    region = AudioRegion(data, sampling_rate, sample_width, channels, meta)
+    region = AudioRegion(data, sampling_rate, sample_width, channels, start)
     assert region.sampling_rate == sampling_rate
     assert region.sr == sampling_rate
     assert region.sample_width == sample_width
@@ -1518,9 +1518,7 @@
 )
 def test_save(format, start, expected):
     with TemporaryDirectory() as tmpdir:
-        region = AudioRegion(b"0" * 160, 160, 1, 1)
-        meta = {"start": start, "end": start + region.duration}
-        region.meta = meta
+        region = AudioRegion(b"0" * 160, 160, 1, 1, start)
         format = os.path.join(tmpdir, format)
         filename = region.save(format)[len(tmpdir) + 1 :]
         assert filename == expected
@@ -1534,6 +1532,9 @@
         with pytest.raises(FileExistsError):
             region.save(filename, exists_ok=False)
 
+        with pytest.raises(FileExistsError):
+            region.save(Path(filename), exists_ok=False)
+
 
 @pytest.mark.parametrize(
     "region, slice_, expected_data",