changeset 296:5af0974b3446

Add Recorder class as an alias of AudioReader with record=True
author Amine Sehili <amine.sehili@gmail.com>
date Mon, 07 Oct 2019 20:58:23 +0100
parents 49082909193c
children 7259b1eb9329
files auditok/util.py tests/test_AudioDataSource.py
diffstat 2 files changed, 58 insertions(+), 48 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/util.py	Mon Oct 07 20:44:57 2019 +0100
+++ b/auditok/util.py	Mon Oct 07 20:58:23 2019 +0100
@@ -51,6 +51,7 @@
     "ADSFactory",
     "AudioDataSource",
     "AudioReader",
+    "Recorder",
     "AudioEnergyValidator",
 ]
 
@@ -72,10 +73,7 @@
             err_msg += ", given: {}"
             raise ValueError(err_msg.format(selected))
         return partial(
-            signal.extract_single_channel,
-            fmt=fmt,
-            channels=channels,
-            selected=selected,
+            signal.extract_single_channel, fmt=fmt, channels=channels, selected=selected
         )
 
     if selected in ("mix", "avg", "average"):
@@ -118,12 +116,8 @@
 
 
 class AudioEnergyValidator(DataValidator):
-    def __init__(
-        self, energy_threshold, sample_width, channels, use_channel=None
-    ):
-        self._selector = make_channel_selector(
-            sample_width, channels, use_channel
-        )
+    def __init__(self, energy_threshold, sample_width, channels, use_channel=None):
+        self._selector = make_channel_selector(sample_width, channels, use_channel)
         if channels == 1 or use_channel is not None:
             self._energy_fn = signal.calculate_energy_single_channel
         else:
@@ -305,13 +299,9 @@
         kwargs["bs"] = kwargs.pop("block_size", None) or kwargs.pop("bs", None)
         kwargs["hs"] = kwargs.pop("hop_size", None) or kwargs.pop("hs", None)
         kwargs["mt"] = kwargs.pop("max_time", None) or kwargs.pop("mt", None)
-        kwargs["asrc"] = kwargs.pop("audio_source", None) or kwargs.pop(
-            "asrc", None
-        )
+        kwargs["asrc"] = kwargs.pop("audio_source", None) or kwargs.pop("asrc", None)
         kwargs["fn"] = kwargs.pop("filename", None) or kwargs.pop("fn", None)
-        kwargs["db"] = kwargs.pop("data_buffer", None) or kwargs.pop(
-            "db", None
-        )
+        kwargs["db"] = kwargs.pop("data_buffer", None) or kwargs.pop("db", None)
 
         record = kwargs.pop("record", False)
         if not record:
@@ -328,19 +318,17 @@
             ) or kwargs.pop("fpb", None)
 
         if "sampling_rate" in kwargs or "sr" in kwargs:
-            kwargs["sampling_rate"] = kwargs.pop(
-                "sampling_rate", None
-            ) or kwargs.pop("sr", None)
+            kwargs["sampling_rate"] = kwargs.pop("sampling_rate", None) or kwargs.pop(
+                "sr", None
+            )
 
         if "sample_width" in kwargs or "sw" in kwargs:
-            kwargs["sample_width"] = kwargs.pop(
-                "sample_width", None
-            ) or kwargs.pop("sw", None)
+            kwargs["sample_width"] = kwargs.pop("sample_width", None) or kwargs.pop(
+                "sw", None
+            )
 
         if "channels" in kwargs or "ch" in kwargs:
-            kwargs["channels"] = kwargs.pop("channels", None) or kwargs.pop(
-                "ch", None
-            )
+            kwargs["channels"] = kwargs.pop("channels", None) or kwargs.pop("ch", None)
 
     @staticmethod
     def ads(**kwargs):
@@ -748,9 +736,7 @@
         super(_FixedSizeAudioReader, self).__init__(audio_source)
 
         if block_dur <= 0:
-            raise ValueError(
-                "block_dur must be > 0, given: {}".format(block_dur)
-            )
+            raise ValueError("block_dur must be > 0, given: {}".format(block_dur))
 
         self._block_size = int(block_dur * self.sr)
         if self._block_size == 0:
@@ -799,9 +785,7 @@
         if block is None:
             yield None
 
-        _hop_size_bytes = (
-            self._hop_size * self._audio_source.sw * self._audio_source.ch
-        )
+        _hop_size_bytes = self._hop_size * self._audio_source.sw * self._audio_source.ch
         cache = block[_hop_size_bytes:]
         yield block
 
@@ -846,13 +830,7 @@
     """
 
     def __init__(
-        self,
-        input,
-        block_dur=0.01,
-        hop_dur=None,
-        record=False,
-        max_read=None,
-        **kwargs
+        self, input, block_dur=0.01, hop_dur=None, record=False, max_read=None, **kwargs
     ):
         if not isinstance(input, AudioSource):
             input = get_audio_source(input, **kwargs)
@@ -877,10 +855,11 @@
         if self.max_read is not None:
             max_read = "{:.3f}".format(self.max_read)
         return (
-            "AudioReader(source, block_dur={block_dur}, "
+            "{cls}(block_dur={block_dur}, "
             "hop_dur={hop_dur}, record={rewindable}, "
             "max_read={max_read})"
         ).format(
+            cls=self.__class__.__name__,
             block_dur=block_dur,
             hop_dur=hop_dur,
             rewindable=self._record,
@@ -919,16 +898,25 @@
 
     def __getattr__(self, name):
         if name in ("data", "rewind") and not self.rewindable:
-            raise AttributeError(
-                "'AudioReader' has no attribute '{}'".format(name)
-            )
+            raise AttributeError("'AudioReader' has no attribute '{}'".format(name))
         try:
             return getattr(self._audio_source, name)
         except AttributeError:
-            raise AttributeError(
-                "'AudioReader' has no attribute '{}'".format(name)
-            )
+            raise AttributeError("'AudioReader' has no attribute '{}'".format(name))
+
 
 # Keep AudioDataSource for compatibility
 # Remove in a future version when ADSFactory is dropped
-AudioDataSource = AudioReader
\ No newline at end of file
+AudioDataSource = AudioReader
+
+
+class Recorder(AudioReader):
+    def __init__(self, input, block_dur=0.01, hop_dur=None, max_read=None, **kwargs):
+        super().__init__(
+            input,
+            block_dur=block_dur,
+            hop_dur=hop_dur,
+            record=True,
+            max_read=max_read,
+            **kwargs
+        )
--- a/tests/test_AudioDataSource.py	Mon Oct 07 20:44:57 2019 +0100
+++ b/tests/test_AudioDataSource.py	Mon Oct 07 20:58:23 2019 +0100
@@ -13,6 +13,8 @@
     dataset,
     ADSFactory,
     AudioDataSource,
+    AudioReader,
+    Recorder,
     BufferAudioSource,
     WaveAudioSource,
     DuplicateArgument,
@@ -1034,7 +1036,7 @@
         with open(input_raw, "rb") as fp:
             expected = fp.read(size)
 
-        reader = AudioDataSource(input_wav, block_dur=0.1, max_read=max_read)
+        reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read)
         reader.open()
         data = _read_all_data(reader)
         reader.close()
@@ -1047,7 +1049,7 @@
         with open(input_raw, "rb") as fp:
             expected = fp.read()
 
-        reader = AudioDataSource(input_wav, block_dur=0.1, record=True)
+        reader = AudioReader(input_wav, block_dur=0.1, record=True)
         reader.open()
         data = _read_all_data(reader)
         self.assertEqual(data, expected)
@@ -1060,6 +1062,26 @@
             self.assertEqual(data, reader.data)
         reader.close()
 
+    @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",))
+    def test_Recorder_alias(self, file_id):
+        input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id)
+        input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id)
+        with open(input_raw, "rb") as fp:
+            expected = fp.read()
+
+        reader = Recorder(input_wav, block_dur=0.1)
+        reader.open()
+        data = _read_all_data(reader)
+        self.assertEqual(data, expected)
+
+        # rewind many times
+        for _ in range(3):
+            reader.rewind()
+            data = _read_all_data(reader)
+            self.assertEqual(data, expected)
+            self.assertEqual(data, reader.data)
+        reader.close()
+
 
 if __name__ == "__main__":
     unittest.main()