changeset 189:d665d51309f8

Refactor get_audio_source
author Amine Sehili <amine.sehili@gmail.com>
date Mon, 15 Apr 2019 21:12:05 +0100
parents 0914e845c21b
children 4d60f490bb5d
files auditok/io.py
diffstat 1 files changed, 54 insertions(+), 67 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/io.py	Wed Apr 10 21:05:12 2019 +0100
+++ b/auditok/io.py	Mon Apr 15 21:12:05 2019 +0100
@@ -57,9 +57,10 @@
     "player_for",
 ]
 
-DEFAULT_SAMPLE_RATE = 16000
+DEFAULT_SAMPLING_RATE = 16000
 DEFAULT_SAMPLE_WIDTH = 2
 DEFAULT_NB_CHANNELS = 1
+DEFAULT_USE_CHANNEL = 0
 DATA_FORMAT = {1: "b", 2: "h", 4: "i"}
 
 
@@ -139,9 +140,7 @@
                                          channels,
                                          use_channel)
     """
-    err_message = (
-        "'{ln}' (or '{sn}') must be a positive integer, found: '{val}'"
-    )
+    err_message = "'{ln}' (or '{sn}') must be a positive integer, found: '{val}'"
     parameters = []
     for (long_name, short_name) in (
         ("sampling_rate", "sr"),
@@ -177,12 +176,8 @@
         return audioop.tomono(data, sample_width, 0.5, 0.5)
     fmt = DATA_FORMAT[sample_width]
     buffer = array(fmt, data)
-    mono_channels = [
-        array(fmt, buffer[ch::channels]) for ch in range(channels)
-    ]
-    avg_arr = array(
-        fmt, (sum(samples) // channels for samples in zip(*mono_channels))
-    )
+    mono_channels = [array(fmt, buffer[ch::channels]) for ch in range(channels)]
+    avg_arr = array(fmt, (sum(samples) // channels for samples in zip(*mono_channels)))
     return _array_to_bytes(avg_arr)
 
 
@@ -226,15 +221,13 @@
 
     def __init__(
         self,
-        sampling_rate=DEFAULT_SAMPLE_RATE,
+        sampling_rate=DEFAULT_SAMPLING_RATE,
         sample_width=DEFAULT_SAMPLE_WIDTH,
         channels=DEFAULT_NB_CHANNELS,
     ):
 
         if not sample_width in (1, 2, 4):
-            raise AudioParameterError(
-                "Sample width must be one of: 1, 2 or 4 (bytes)"
-            )
+            raise AudioParameterError("Sample width must be one of: 1, 2 or 4 (bytes)")
 
         self._sampling_rate = sampling_rate
         self._sample_width = sample_width
@@ -364,7 +357,7 @@
         """ Return the total number of already read samples """
         warnings.warn(
             "'get_position' is deprecated, use 'position' property instead",
-            DeprecationWarning
+            DeprecationWarning,
         )
         return self.position
 
@@ -372,7 +365,7 @@
         """ Return the total duration in seconds of already read data """
         warnings.warn(
             "'get_time_position' is deprecated, use 'position_s' or 'position_ms' properties instead",
-            DeprecationWarning
+            DeprecationWarning,
         )
         return self.position_s
 
@@ -386,7 +379,7 @@
         """
         warnings.warn(
             "'set_position' is deprecated, set 'position' property instead",
-            DeprecationWarning
+            DeprecationWarning,
         )
         self.position = position
 
@@ -400,7 +393,7 @@
         """
         warnings.warn(
             "'set_time_position' is deprecated, set 'position_s' or 'position_ms' properties instead",
-            DeprecationWarning
+            DeprecationWarning,
         )
         self.position_s = time_position
 
@@ -414,7 +407,7 @@
     def __init__(
         self,
         data_buffer,
-        sampling_rate=DEFAULT_SAMPLE_RATE,
+        sampling_rate=DEFAULT_SAMPLING_RATE,
         sample_width=DEFAULT_SAMPLE_WIDTH,
         channels=DEFAULT_NB_CHANNELS,
     ):
@@ -440,8 +433,7 @@
             raise AudioIOError("Stream is not open")
         bytes_to_read = self._sample_size_all_channels * size
         data = self._buffer[
-            self._current_position_bytes : self._current_position_bytes
-            + bytes_to_read
+            self._current_position_bytes : self._current_position_bytes + bytes_to_read
         ]
         if data:
             self._current_position_bytes += len(data)
@@ -554,9 +546,7 @@
 
 
 class RawAudioSource(_FileAudioSource, Rewindable):
-    def __init__(
-        self, file, sampling_rate, sample_width, channels, use_channel=0
-    ):
+    def __init__(self, file, sampling_rate, sample_width, channels, use_channel=0):
         _FileAudioSource.__init__(
             self, sampling_rate, sample_width, channels, use_channel
         )
@@ -614,7 +604,7 @@
 
     def __init__(
         self,
-        sampling_rate=DEFAULT_SAMPLE_RATE,
+        sampling_rate=DEFAULT_SAMPLING_RATE,
         sample_width=DEFAULT_SAMPLE_WIDTH,
         channels=DEFAULT_NB_CHANNELS,
         frames_per_buffer=1024,
@@ -673,7 +663,7 @@
 
     def __init__(
         self,
-        sampling_rate=DEFAULT_SAMPLE_RATE,
+        sampling_rate=DEFAULT_SAMPLING_RATE,
         sample_width=DEFAULT_SAMPLE_WIDTH,
         channels=DEFAULT_NB_CHANNELS,
         use_channel=0,
@@ -713,7 +703,7 @@
 
     def __init__(
         self,
-        sampling_rate=DEFAULT_SAMPLE_RATE,
+        sampling_rate=DEFAULT_SAMPLING_RATE,
         sample_width=DEFAULT_SAMPLE_WIDTH,
         channels=DEFAULT_NB_CHANNELS,
     ):
@@ -752,9 +742,7 @@
 
     def _chunk_data(self, data):
         # make audio chunks of 100 ms to allow interruption (like ctrl+c)
-        chunk_size = int(
-            (self.sampling_rate * self.sample_width * self.channels) / 10
-        )
+        chunk_size = int((self.sampling_rate * self.sample_width * self.channels) / 10)
         start = 0
         while start < len(data):
             yield data[start : start + chunk_size]
@@ -782,34 +770,49 @@
         audio_source.get_channels(),
     )
 
+
 def get_audio_source(input=None, **kwargs):
+    """
+    Create and return an AudioSource from input.
 
-    # read data from standard input
+    Parameters:
+
+        ´input´ : str, bytes, "-" or None
+        Source to read audio data from. If str, it should be a path to a valid
+        audio file. If bytes, it is interpreted as raw audio data. if equals to
+        "-", raw data will be read from stdin. If None, read audio data from
+        microphone using PyAudio.
+    """
+
+    sampling_rate = kwargs.get("sampling_rate", kwargs.get("sr", DEFAULT_SAMPLING_RATE))
+    sample_width = kwargs.get("sample_rate", kwargs.get("sw", DEFAULT_SAMPLE_WIDTH))
+    channels = kwargs.get("channels", kwargs.get("ch", DEFAULT_NB_CHANNELS))
+    use_channel = kwargs.get("use_channel", kwargs.get("uc", DEFAULT_USE_CHANNEL))
     if input == "-":
-        return StdinAudioSource(**kwargs)
+        return StdinAudioSource(sampling_rate, sample_width, channels, use_channel)
 
-    # create AudioSource from raw data
     if isinstance(input, bytes):
-        return BufferAudioSource(input, **kwargs)
+        return BufferAudioSource(input, sampling_rate, sample_width, channels)
 
     # read data from a file
     if input is not None:
-        return from_file(filename=input,
-                         audio_format=kwargs.get('audio_format'),
-                         **kwargs)
+        return from_file(filename=input, **kwargs)
 
     # read data from microphone via pyaudio
     else:
-        return PyAudioSource(**kwargs)
+        frames_per_buffer = kwargs.get("frames_per_buffer", 1024)
+        input_device_index = kwargs.get("input_device_index")
+        return PyAudioSource(
+            sampling_rate=sampling_rate,
+            sample_width=sample_width,
+            channels=channels,
+            frames_per_buffer=frames_per_buffer,
+            input_device_index=input_device_index,
+        )
 
 
 def _load_raw(
-    file,
-    sampling_rate,
-    sample_width,
-    channels,
-    use_channel=0,
-    large_file=False,
+    file, sampling_rate, sample_width, channels, use_channel=0, large_file=False
 ):
     """
     Load a raw audio file with standard Python.
@@ -855,14 +858,9 @@
             data = fp.read()
         if channels != 1:
             # TODO check if striding with mmap doesn't load all data to memory
-            data = _extract_selected_channel(
-                data, channels, sample_width, use_channel
-            )
+            data = _extract_selected_channel(data, channels, sample_width, use_channel)
         return BufferAudioSource(
-            data,
-            sampling_rate=sampling_rate,
-            sample_width=sample_width,
-            channels=1,
+            data, sampling_rate=sampling_rate, sample_width=sample_width, channels=1
         )
 
 
@@ -883,9 +881,7 @@
         data = fp.readframes(-1)
     if channels > 1:
         data = _extract_selected_channel(data, channels, swidth, use_channel)
-    return BufferAudioSource(
-        data, sampling_rate=srate, sample_width=swidth, channels=1
-    )
+    return BufferAudioSource(data, sampling_rate=srate, sample_width=swidth, channels=1)
 
 
 def _load_with_pydub(filename, audio_format, use_channel=0):
@@ -979,9 +975,7 @@
 
     if audio_format == "raw":
         srate, swidth, channels, use_channel = _get_audio_parameters(kwargs)
-        return _load_raw(
-            filename, srate, swidth, channels, use_channel, large_file
-        )
+        return _load_raw(filename, srate, swidth, channels, use_channel, large_file)
 
     use_channel = _normalize_use_channel(kwargs.get("use_channel"))
     if audio_format in ["wav", "wave"]:
@@ -993,9 +987,7 @@
             filename, audio_format=audio_format, use_channel=use_channel
         )
     else:
-        raise AudioIOError(
-            "pydub is required for audio formats other than raw or wav"
-        )
+        raise AudioIOError("pydub is required for audio formats other than raw or wav")
 
 
 def _save_raw(data, file):
@@ -1023,18 +1015,13 @@
         fp.writeframes(data)
 
 
-def _save_with_pydub(
-    data, file, audio_format, sampling_rate, sample_width, channels
-):
+def _save_with_pydub(data, file, audio_format, sampling_rate, sample_width, channels):
     """
     Saves audio data with pydub (https://github.com/jiaaro/pydub).
     See also :func:`to_file`.
     """
     segment = AudioSegment(
-        data,
-        frame_rate=sampling_rate,
-        sample_width=sample_width,
-        channels=channels,
+        data, frame_rate=sampling_rate, sample_width=sample_width, channels=channels
     )
     with open(file, "wb") as fp:
         segment.export(fp, format=audio_format)