changeset 308:bc963239143f

Refactor AudioRegion.load and add tests
author Amine Sehili <amine.sehili@gmail.com>
date Sat, 12 Oct 2019 16:23:17 +0100
parents 334c8760e80f
children 0ea9521c80d8
files auditok/core.py tests/test_core.py
diffstat 2 files changed, 119 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/core.py	Fri Oct 11 20:58:49 2019 +0100
+++ b/auditok/core.py	Sat Oct 12 16:23:17 2019 +0100
@@ -273,6 +273,49 @@
     return AudioRegion(data, sampling_rate, sample_width, channels, meta)
 
 
+def _read_chunks_online(max_read, **kwargs):
+    reader = AudioReader(None, block_dur=0.5, max_read=max_read, **kwargs)
+    reader.open()
+    data = []
+    try:
+        while True:
+            frame = reader.read()
+            if frame is None:
+                break
+            data.append(frame)
+    except KeyboardInterrupt:
+        # Stop data acquisition from microphone when pressing
+        # Ctrl+C on a [i]python session or a notebook
+        pass
+    reader.close()
+    return (
+        b"".join(data),
+        reader.sampling_rate,
+        reader.sample_width,
+        reader.channels,
+    )
+
+
+def _read_offline(input, skip=0, max_read=None, **kwargs):
+    audio_source = get_audio_source(input, **kwargs)
+    audio_source.open()
+    if skip is not None and skip > 0:
+        skip_samples = round(skip * audio_source.sampling_rate)
+        audio_source.read(skip_samples)
+    if max_read is not None:
+        if max_read < 0:
+            max_read = None
+        else:
+            max_read = round(max_read * audio_source.sampling_rate)
+    data = audio_source.read(max_read)
+    return (
+        data,
+        audio_source.sampling_rate,
+        audio_source.sample_width,
+        audio_source.channels,
+    )
+
+
 def _check_convert_index(index, types, err_msg):
     if not isinstance(index, slice) or index.step is not None:
         raise TypeError(err_msg)
@@ -394,27 +437,24 @@
 
     @classmethod
     def load(cls, input, skip=0, max_read=None, **kwargs):
-        if input is None and max_read is None:
-            raise ValueError(
-                "'max_read' should not be None when reading from microphone"
+        if input is None:
+            if max_read is None or max_read < 0:
+                raise ValueError(
+                    "'max_read' should not be None when reading from microphone"
+                )
+            if skip > 0:
+                raise ValueError(
+                    "'skip' should be 0 when reading from microphone"
+                )
+            data, sampling_rate, sample_width, channels = _read_chunks_online(
+                max_read, **kwargs
             )
-        audio_source = get_audio_source(input, **kwargs)
-        audio_source.open()
-        if skip is not None and skip > 0:
-            skip_samples = int(skip * audio_source.sampling_rate)
-            audio_source.read(skip_samples)
-        if max_read is None or max_read < 0:
-            max_read_samples = None
         else:
-            max_read_samples = round(max_read * audio_source.sampling_rate)
-        data = audio_source.read(max_read_samples)
-        audio_source.close()
-        return cls(
-            data,
-            audio_source.sampling_rate,
-            audio_source.sample_width,
-            audio_source.channels,
-        )
+            data, sampling_rate, sample_width, channels = _read_offline(
+                input, skip=skip, max_read=max_read, **kwargs
+            )
+
+        return cls(data, sampling_rate, sample_width, channels)
 
     @property
     def sec(self):
--- a/tests/test_core.py	Fri Oct 11 20:58:49 2019 +0100
+++ b/tests/test_core.py	Sat Oct 12 16:23:17 2019 +0100
@@ -7,7 +7,7 @@
 from mock import patch
 from genty import genty, genty_dataset
 from auditok import split, AudioRegion, AudioParameterError
-from auditok.core import _duration_to_nb_windows
+from auditok.core import _duration_to_nb_windows, _read_chunks_online
 from auditok.util import AudioDataSource
 from auditok.io import (
     _normalize_use_channel,
@@ -1079,14 +1079,71 @@
             str(audio_param_err.exception),
         )
 
-    def test_load_exception(self):
+    @genty_dataset(
+        no_skip_read_all=(0, -1),
+        no_skip_read_all_stereo=(0, -1, 2),
+        skip_2_read_all=(2, -1),
+        skip_2_read_all_None=(2, None),
+        skip_2_read_3=(2, 3),
+        skip_2_read_3_5_stereo=(2, 3.5, 2),
+        skip_2_4_read_3_5_stereo=(2.4, 3.5, 2),
+    )
+    def test_load(self, skip, max_read, channels=1):
+        sampling_rate = 10
+        sample_width = 2
+        filename = "tests/data/test_split_10HZ_{}.raw"
+        filename = filename.format("mono" if channels == 1 else "stereo")
+        region = AudioRegion.load(
+            filename,
+            skip=skip,
+            max_read=max_read,
+            sr=sampling_rate,
+            sw=sample_width,
+            ch=channels,
+        )
+        with open(filename, "rb") as fp:
+            fp.read(round(skip * sampling_rate * sample_width * channels))
+            if max_read is None or max_read < 0:
+                to_read = -1
+            else:
+                to_read = round(
+                    max_read * sampling_rate * sample_width * channels
+                )
+            expected = fp.read(to_read)
+        self.assertEqual(bytes(region), expected)
+
+    def test_load_from_microphone(self):
+        with patch("auditok.io.PyAudioSource") as patch_pyaudio_source:
+            with patch("auditok.core.AudioReader.read") as patch_reader:
+                patch_reader.return_value = None
+                with patch(
+                    "auditok.core.AudioRegion.__init__"
+                ) as patch_AudioRegion:
+                    patch_AudioRegion.return_value = None
+                    AudioRegion.load(
+                        None, skip=0, max_read=5, sr=16000, sw=2, ch=1
+                    )
+        self.assertTrue(patch_pyaudio_source.called)
+        self.assertTrue(patch_reader.called)
+        self.assertTrue(patch_AudioRegion.called)
+
+    @genty_dataset(none=(None,), negative=(-1,))
+    def test_load_from_microphone_without_max_read_exception(self, max_read):
         with self.assertRaises(ValueError) as val_err:
-            AudioRegion.load(None, sr=16000, sw=2, ch=1)
+            AudioRegion.load(None, max_read=max_read, sr=16000, sw=2, ch=1)
         self.assertEqual(
             "'max_read' should not be None when reading from microphone",
             str(val_err.exception),
         )
 
+    def test_load_from_microphone_with_nonzero_skip_exception(self):
+        with self.assertRaises(ValueError) as val_err:
+            AudioRegion.load(None, skip=1, max_read=5, sr=16000, sw=2, ch=1)
+        self.assertEqual(
+            "'skip' should be 0 when reading from microphone",
+            str(val_err.exception),
+        )
+
     @genty_dataset(
         simple=("output.wav", 1.230, "output.wav"),
         start=("output_{meta.start:g}.wav", 1.230, "output_1.23.wav"),