changeset 414:9f83c1ecb03b

implement 'join' for AudioRegion
author Amine Sehili <amine.sehili@gmail.com>
date Tue, 15 Oct 2024 21:56:12 +0200
parents 0a6bc66562d3
children e26dcf224846
files auditok/core.py tests/test_core.py
diffstat 2 files changed, 111 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/core.py	Fri Jun 21 20:12:53 2024 +0200
+++ b/auditok/core.py	Tue Oct 15 21:56:12 2024 +0200
@@ -14,7 +14,7 @@
 from dataclasses import dataclass, field
 from pathlib import Path
 
-from .exceptions import TooSmallBlockDuration
+from .exceptions import AudioParameterError, TooSmallBlockDuration
 from .io import check_audio_data, get_audio_source, player_for, to_file
 from .plotting import plot
 from .util import AudioEnergyValidator, AudioReader, DataValidator
@@ -897,6 +897,34 @@
         )
         return regions
 
+    def _check_other_parameters(self, other):
+        """Check that `other` has the same audio parameters as this region.
+
+        Raises AudioParameterError on the first mismatch (sr, sw, then ch).
+        """
+        for attr, label in (
+            ("sr", "sampling rate"),
+            ("sw", "sample width"),
+            ("ch", "number of channels"),
+        ):
+            theirs, mine = getattr(other, attr), getattr(self, attr)
+            if theirs != mine:
+                raise AudioParameterError(
+                    "Can only concatenate AudioRegions of the same "
+                    "{} ({} != {})".format(label, mine, theirs)
+                )
+
+    def _check_iter_others(self, others):
+        for region in others:  # lazy: each item is checked when consumed
+            self._check_other_parameters(region)  # raises on any mismatch
+            yield region
+
+    def join(self, others):
+        """Join `others` into one region, using this region as separator."""
+        checked = self._check_iter_others(others)  # validates lazily
+        payload = self.data.join(region.data for region in checked)
+        return AudioRegion(payload, self.sr, self.sw, self.ch)
+
     @property
     def samples(self):
         """Audio region as arrays of samples, one array per channel."""
@@ -950,21 +978,7 @@
                 "Can only concatenate AudioRegion, "
                 'not "{}"'.format(type(other))
             )
-        if other.sr != self.sr:
-            raise ValueError(
-                "Can only concatenate AudioRegions of the same "
-                "sampling rate ({} != {})".format(self.sr, other.sr)
-            )
-        if other.sw != self.sw:
-            raise ValueError(
-                "Can only concatenate AudioRegions of the same "
-                "sample width ({} != {})".format(self.sw, other.sw)
-            )
-        if other.ch != self.ch:
-            raise ValueError(
-                "Can only concatenate AudioRegions of the same "
-                "number of channels ({} != {})".format(self.ch, other.ch)
-            )
+        self._check_other_parameters(other)
         data = self.data + other.data
         return AudioRegion(data, self.sr, self.sw, self.ch)
 
--- a/tests/test_core.py	Fri Jun 21 20:12:53 2024 +0200
+++ b/tests/test_core.py	Tue Oct 15 21:56:12 2024 +0200
@@ -1537,6 +1537,84 @@
 
 
 @pytest.mark.parametrize(
+    "sampling_rate, sample_width, channels",
+    [
+        pytest.param(16000, 1, 1, id="mono_16K_1byte"),
+        pytest.param(16000, 2, 1, id="mono_16K_2byte"),
+        pytest.param(44100, 2, 2, id="stereo_44100_2byte"),
+        pytest.param(44100, 2, 3, id="3channel_44100_2byte"),
+    ],
+)
+def test_join(sampling_rate, sample_width, channels):
+    """`join` must glue regions together the way `bytes.join` does.
+
+    The region `join` is called on acts as the separator: its data must
+    show up between consecutive regions and count towards the duration.
+    """
+    params = (sampling_rate, sample_width, channels)
+    duration = 1
+    size = int(duration * sampling_rate * sample_width * channels)
+    glue_data = b"\0" * size
+    # Three regions lasting 1.5, 0.5 and 0.75 times `duration`.
+    regions_data = [
+        bytes([value]) * int(size * factor)
+        for value, factor in ((1, 1.5), (2, 0.5), (3, 0.75))
+    ]
+
+    glue_region = AudioRegion(glue_data, *params)
+    regions = [AudioRegion(data, *params) for data in regions_data]
+    joined = glue_region.join(regions)
+
+    expected_data = glue_data.join(regions_data)
+    assert joined.data == expected_data
+    # Two glue segments of `duration` seconds separate the three regions.
+    assert joined.duration == duration * 2 + 1.5 + 0.5 + 0.75
+
+
+@pytest.mark.parametrize(
+    "sampling_rate, sample_width, channels",
+    [
+        pytest.param(32000, 1, 1, id="different_sampling_rate"),
+        pytest.param(16000, 2, 1, id="different_sample_width"),
+        pytest.param(16000, 1, 2, id="different_channels"),
+    ],
+)
+def test_join_exception(sampling_rate, sample_width, channels):
+    """`join` must refuse regions whose audio parameters differ.
+
+    The glue region always uses a sampling rate of 16000, a sample width
+    of 1 and 1 channel. Each parametrized case changes exactly one of
+    these values for the regions to join, so `join` is expected to raise
+    AudioParameterError.
+    """
+    duration = 1
+
+    # Region that plays the role of the separator.
+    glue_sr, glue_sw, glue_ch = 16000, 1, 1
+    glue_size = int(duration * glue_sr * glue_sw * glue_ch)
+    glue_data = b"\0" * glue_size
+    glue_region = AudioRegion(glue_data, glue_sr, glue_sw, glue_ch)
+
+    # Regions to join, lasting 1.5, 0.5 and 0.75 times `duration` and
+    # created with the parametrized (mismatching) audio parameters.
+    fill_and_factor = (
+        (1, 1.5),
+        (2, 0.5),
+        (3, 0.75),
+    )
+    params = (sampling_rate, sample_width, channels)
+    size = int(duration * sampling_rate * sample_width * channels)
+    regions_data = [
+        bytes([value]) * int(size * factor)
+        for value, factor in fill_and_factor
+    ]
+    regions = [AudioRegion(data, *params) for data in regions_data]
+
+    with pytest.raises(AudioParameterError):
+        glue_region.join(regions)
+
+
+@pytest.mark.parametrize(
     "region, slice_, expected_data",
     [
         (
@@ -1886,7 +1964,7 @@
     region_1 = AudioRegion(b"a" * 100, 8000, 1, 1)
     region_2 = AudioRegion(b"b" * 100, 3000, 1, 1)
 
-    with pytest.raises(ValueError) as val_err:
+    with pytest.raises(AudioParameterError) as val_err:
         region_1 + region_2
     assert str(val_err.value) == (
         "Can only concatenate AudioRegions of the same "
@@ -1898,7 +1976,7 @@
     region_1 = AudioRegion(b"a" * 100, 8000, 2, 1)
     region_2 = AudioRegion(b"b" * 100, 8000, 4, 1)
 
-    with pytest.raises(ValueError) as val_err:
+    with pytest.raises(AudioParameterError) as val_err:
         region_1 + region_2
     assert str(val_err.value) == (
         "Can only concatenate AudioRegions of the same sample width (2 != 4)"
@@ -1909,7 +1987,7 @@
     region_1 = AudioRegion(b"a" * 100, 8000, 1, 1)
     region_2 = AudioRegion(b"b" * 100, 8000, 1, 2)
 
-    with pytest.raises(ValueError) as val_err:
+    with pytest.raises(AudioParameterError) as val_err:
         region_1 + region_2
     assert str(val_err.value) == (
         "Can only concatenate AudioRegions of the same "