changeset 210:74864841228a

Fix _get_audio_parameters
author Amine Sehili <amine.sehili@gmail.com>
date Mon, 17 Jun 2019 21:03:58 +0100
parents 9047740c5092
children ed6b3cecb407
files auditok/core.py auditok/io.py tests/data/test_split_10HZ_stereo.raw tests/test_core.py tests/test_io.py
diffstat 5 files changed, 64 insertions(+), 53 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/core.py	Fri Jun 14 22:43:21 2019 +0100
+++ b/auditok/core.py	Mon Jun 17 21:03:58 2019 +0100
@@ -36,10 +36,6 @@
     input: str, bytes, AudioSource, AudioRegion, AudioDataSource
         input audio data. If str, it should be a path to an existing audio
         file. If bytes, input is considered as raw audio data.
-    audio_format: str
-        type of audio date (e.g., wav, ogg, raw, etc.). This will only be used
-        if ´input´ is a string path to audio file. If not given, audio type
-        will be guessed from file name extension or from file header.
     min_dur: float
         minimun duration in seconds of a detected audio event. Default: 0.2.
         Using large values, very short audio events (e.g., very short 1-word
@@ -57,11 +53,15 @@
         strict minimum duration. Drop an event if it is shorter than ´min_dur´
         even if it is continguous to the latest valid event. This happens if
         the the latest event had reached ´max_dur´.
-    analysis_window: float
+    analysis_window, aw: float
         duration of analysis window in seconds. Default: 0.05 second (50 ms).
         A value up to 0.1 second (100 ms) should be good for most use-cases.
         You might need a different value, especially if you use a custom
         validator.
+    audio_format, fmt: str
+        type of audio data (e.g., wav, ogg, raw, etc.). This will only be used
+        if ´input´ is a string path to audio file. If not given, audio type
+        will be guessed from file name extension or from file header.
     sampling_rate, sr: int
         sampling rate of audio data. Only needed for raw audio files/data.
     sample_width, sw: int
@@ -77,13 +77,13 @@
         - 'right': second channel (equivalent to 1)
         - 'mix': compute average channel
         Default: 0, use the first channel.
-    max_read: float
+    max_read, mr: float
         maximum data to read in seconds. Default: `None`, read until there is
         no more data to read.
-    validator: DataValidator
+    validator, val: DataValidator
         custom data validator. If ´None´ (default), an `AudioEnergyValidor` is
         used with the given energy threshold.
-    energy_threshold: float
+    energy_threshold, eth: float
         energy threshlod for audio activity detection, default: 50. If a custom
         validator is given, this argumemt will be ignored.
     """
@@ -92,21 +92,20 @@
         analysis_window = source.block_dur
     else:
         analysis_window = kwargs.get(
-            "analysis_window", DEFAULT_ANALYSIS_WINDOW
+            "analysis_window", kwargs.get("aw", DEFAULT_ANALYSIS_WINDOW)
         )
-        max_read = kwargs.get("max_read")
+
         params = kwargs.copy()
+        params["max_read"] = params.get("max_read", params.get("mr"))
         if isinstance(input, AudioRegion):
             params["sampling_rate"] = input.sr
             params["sample_width"] = input.sw
             params["channels"] = input.ch
             input = bytes(input)
 
-        source = AudioDataSource(
-            input, block_dur=analysis_window, max_read=max_read, **params
-        )
+        source = AudioDataSource(input, block_dur=analysis_window, **params)
 
-    validator = kwargs.get("validator")
+    validator = kwargs.get("validator", kwargs.get("val"))
     if validator is None:
         energy_threshold = kwargs.get(
             "energy_threshold", kwargs.get("eth", DEFAULT_ENERGY_THRESHOLD)
--- a/auditok/io.py	Fri Jun 14 22:43:21 2019 +0100
+++ b/auditok/io.py	Mon Jun 17 21:03:58 2019 +0100
@@ -107,17 +107,19 @@
     str 'mix' returns it as is. If it's `left` or `right` returns 0 or 1
     respectively.
     """
+    err_message = "'use_channel' parameter must be a non-zero integer or one of "
+    err_message += "('left', 'right', 'mix'), found: '{}'"
     if use_channel is None:
         return 0
     if use_channel == "mix":
         return "mix"
     if isinstance(use_channel, int):
-        return use_channel - 1
+        if use_channel == 0:
+            raise AudioParameterError(err_message.format(use_channel))
+        return use_channel - 1 if use_channel > 0 else use_channel
     try:
         return ["left", "right"].index(use_channel)
     except ValueError:
-        err_message = "'use_channel' parameter must be an integer or one of "
-        err_message += "('left', 'right', 'mix'), found: '{}'"
         raise AudioParameterError(err_message.format(use_channel))
 
 
@@ -169,8 +171,7 @@
             )
         parameters.append(param)
     sampling_rate, sample_width, channels = parameters
-    use_channel = param_dict.get("use_channel", param_dict.get("uc", 1))
-    use_channel = _normalize_use_channel(use_channel)
+    use_channel = param_dict.get("use_channel", param_dict.get("uc"))
     return sampling_rate, sample_width, channels, use_channel
 
 
@@ -203,7 +204,8 @@
 def _extract_selected_channel(data, channels, sample_width, use_channel):
     if use_channel == "mix":
         return _mix_audio_channels(data, channels, sample_width)
-    elif use_channel >= channels or use_channel < -channels:
+
+    if use_channel >= channels or use_channel < -channels:
         err_message = "use_channel == {} but audio data has only {} channel{}."
         err_message += " Selected channel must be 'mix' or an integer >= "
         err_message += "-channels and < channels"
@@ -569,7 +571,7 @@
 
 class RawAudioSource(_FileAudioSource, Rewindable):
     def __init__(
-        self, file, sampling_rate, sample_width, channels, use_channel=0
+        self, file, sampling_rate, sample_width, channels, use_channel=None
     ):
         _FileAudioSource.__init__(
             self, sampling_rate, sample_width, channels, use_channel
@@ -600,7 +602,7 @@
             path to a valid wave file.
     """
 
-    def __init__(self, filename, use_channel=0):
+    def __init__(self, filename, use_channel=None):
         self._filename = filename
         self._audio_stream = None
         stream = wave.open(self._filename, "rb")
@@ -850,7 +852,7 @@
         data = _extract_selected_channel(
             input, channels, sample_width, use_channel
         )
-        return BufferAudioSource(data, sampling_rate, sample_width, channels)
+        return BufferAudioSource(data, sampling_rate, sample_width, 1)
 
     # read data from a file
     if input is not None:
@@ -874,7 +876,7 @@
     sampling_rate,
     sample_width,
     channels,
-    use_channel=0,
+    use_channel=1,
     large_file=False,
 ):
     """
@@ -920,6 +922,7 @@
         with open(file, "rb") as fp:
             data = fp.read()
         if channels != 1:
+            use_channel = _normalize_use_channel(use_channel) # TODO: should happen in BufferAudioSource
             data = _extract_selected_channel(
                 data, channels, sample_width, use_channel
             )
@@ -931,7 +934,7 @@
         )
 
 
-def _load_wave(filename, large_file=False, use_channel=0):
+def _load_wave(filename, large_file=False, use_channel=1):
     """
     Load a wave audio file with standard Python.
     If `large_file` is True, audio data will be lazily
@@ -947,13 +950,14 @@
         swidth = fp.getsampwidth()
         data = fp.readframes(-1)
     if channels > 1:
+        use_channel = _normalize_use_channel(use_channel) # TODO: should happen in BufferAudioSource
         data = _extract_selected_channel(data, channels, swidth, use_channel)
     return BufferAudioSource(
         data, sampling_rate=srate, sample_width=swidth, channels=1
     )
 
 
-def _load_with_pydub(filename, audio_format, use_channel=0):
+def _load_with_pydub(filename, audio_format, use_channel=1):
     """Open compressed audio file using pydub. If a video file
     is passed, its audio track(s) are extracted and loaded.
     This function should not be called directely, use :func:`from_file`
@@ -975,6 +979,7 @@
     segment = open_function(filename)
     data = segment._data
     if segment.channels > 1:
+        use_channel = _normalize_use_channel(use_channel) # TODO: should happen in BufferAudioSource
         data = _extract_selected_channel(
             data, segment.channels, segment.sample_width, use_channel
         )
@@ -1048,7 +1053,7 @@
             filename, srate, swidth, channels, use_channel, large_file
         )
 
-    use_channel = _normalize_use_channel(kwargs.get("use_channel"))
+    use_channel = kwargs.get("use_channel", kwargs.get("uc"))    
     if audio_format in ["wav", "wave"]:
         return _load_wave(filename, large_file, use_channel)
     if large_file:
@@ -1155,4 +1160,4 @@
         )
     else:
         err_message = "cannot write file format {} (file name: {})"
-        raise AudioIOError(err_message.format(audio_format, file))
+        raise AudioIOError(err_message.format(audio_format, file))
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/data/test_split_10HZ_stereo.raw	Mon Jun 17 21:03:58 2019 +0100
@@ -0,0 +1,1 @@
+,,,,@,@,@,@,@,@,@,@,@@@@@@@@,@,@,@X@X@X@X@X@X@X@,@,@@@@@@@,@,,,,,,,,@,@,@@@@@@@@@@@@@@@@@@XXXXXXXXXXXXXXXXXXXXXX
\ No newline at end of file
--- a/tests/test_core.py	Fri Jun 14 22:43:21 2019 +0100
+++ b/tests/test_core.py	Mon Jun 17 21:03:58 2019 +0100
@@ -4,6 +4,11 @@
 from tempfile import TemporaryDirectory
 from genty import genty, genty_dataset
 from auditok import split, AudioRegion, AudioParameterError
+from auditok.io import (
+    _normalize_use_channel,
+    _extract_selected_channel,
+    get_audio_source,
+)
 
 
 def _make_random_length_regions(
@@ -115,9 +120,8 @@
             **kwargs
         )
         regions = list(regions)
-        print(regions)
         err_msg = "Wrong number of regions after split, expected: "
-        err_msg += "{}, found: {}".format(len(regions), len(expected))
+        err_msg += "{}, found: {}".format(len(expected), len(regions))
         self.assertEqual(len(regions), len(expected), err_msg)
 
         sample_width = 2
--- a/tests/test_io.py	Fri Jun 14 22:43:21 2019 +0100
+++ b/tests/test_io.py	Mon Jun 17 21:03:58 2019 +0100
@@ -84,13 +84,13 @@
         self.assertEqual(result, expected)
 
     @genty_dataset(
-        int_1=((8000, 2, 1, 1), (8000, 2, 1, 0)),
-        int_2=((8000, 2, 1, 2), (8000, 2, 1, 1)),
-        use_channel_left=((8000, 2, 1, "left"), (8000, 2, 1, 0)),
-        use_channel_right=((8000, 2, 1, "right"), (8000, 2, 1, 1)),
+        int_1=((8000, 2, 1, 1), (8000, 2, 1, 1)),
+        int_2=((8000, 2, 1, 2), (8000, 2, 1, 2)),
+        use_channel_left=((8000, 2, 1, "left"), (8000, 2, 1, "left")),
+        use_channel_right=((8000, 2, 1, "right"), (8000, 2, 1, "right")),
         use_channel_mix=((8000, 2, 1, "mix"), (8000, 2, 1, "mix")),
-        use_channel_None=((8000, 2, 2, None), (8000, 2, 2, 0)),
-        no_use_channel=((8000, 2, 2), (8000, 2, 2, 0)),
+        use_channel_None=((8000, 2, 2, None), (8000, 2, 2, None)),
+        no_use_channel=((8000, 2, 2), (8000, 2, 2, None)),
     )
     def test_get_audio_parameters_short_params(self, values, expected):
         params = dict(zip(("sr", "sw", "ch", "uc"), values))
@@ -98,13 +98,13 @@
         self.assertEqual(result, expected)
 
     @genty_dataset(
-        int_1=((8000, 2, 1, 1), (8000, 2, 1, 0)),
-        int_2=((8000, 2, 1, 2), (8000, 2, 1, 1)),
-        use_channel_left=((8000, 2, 1, "left"), (8000, 2, 1, 0)),
-        use_channel_right=((8000, 2, 1, "right"), (8000, 2, 1, 1)),
+        int_1=((8000, 2, 1, 1), (8000, 2, 1, 1)),
+        int_2=((8000, 2, 1, 2), (8000, 2, 1, 2)),
+        use_channel_left=((8000, 2, 1, "left"), (8000, 2, 1, "left")),
+        use_channel_right=((8000, 2, 1, "right"), (8000, 2, 1, "right")),
         use_channel_mix=((8000, 2, 1, "mix"), (8000, 2, 1, "mix")),
-        use_channel_None=((8000, 2, 2, None), (8000, 2, 2, 0)),
-        no_use_channel=((8000, 2, 2), (8000, 2, 2, 0)),
+        use_channel_None=((8000, 2, 2, None), (8000, 2, 2, None)),
+        no_use_channel=((8000, 2, 2), (8000, 2, 2, None)),
     )
     def test_get_audio_parameters_long_params(self, values, expected):
         params = dict(
@@ -116,7 +116,7 @@
         result = _get_audio_parameters(params)
         self.assertEqual(result, expected)
 
-    @genty_dataset(simple=((8000, 2, 1, 1), (8000, 2, 1, 0)))
+    @genty_dataset(simple=((8000, 2, 1, 1), (8000, 2, 1, 1)))
     def test_get_audio_parameters_long_params_shadow_short_ones(
         self, values, expected
     ):
@@ -459,9 +459,9 @@
 
     @genty_dataset(
         dafault_first_channel=(None, 400),
-        first_channel=(0, 400),
-        second_channel=(1, 800),
-        third_channel=(2, 1600),
+        first_channel=(1, 400),
+        second_channel=(2, 800),
+        third_channel=(3, 1600),
         negative_first_channel=(-3, 400),
         negative_second_channel=(-2, 800),
         negative_third_channel=(-1, 1600),
@@ -537,9 +537,9 @@
 
     @genty_dataset(
         dafault_first_channel=(None, 400),
-        first_channel=(0, 400),
-        second_channel=(1, 800),
-        third_channel=(2, 1600),
+        first_channel=(1, 400),
+        second_channel=(2, 800),
+        third_channel=(3, 1600),
         negative_first_channel=(-3, 400),
         negative_second_channel=(-2, 800),
         negative_third_channel=(-1, 1600),
@@ -596,7 +596,7 @@
         mp3_left_channel=("mp3", 1, "left", "from_mp3"),
         mp3_right_channel=("mp3", 2, "right", "from_mp3"),
         mp3_mix_channels=("mp3", 3, "mix", "from_mp3"),
-        flac_first_channel=("flac", 2, 0, "from_file"),
+        flac_first_channel=("flac", 2, 1, "from_file"),
         flac_second_channel=("flac", 2, 1, "from_file"),
         flv_left_channel=("flv", 1, "left", "from_flv"),
         webm_right_channel=("webm", 2, "right", "from_file"),
@@ -615,9 +615,11 @@
                 "auditok.io.AudioSegment.{}".format(function)
             ) as open_func:
                 open_func.return_value = segment_mock
-                use_channel = {"left": 0, "right": 1, None: 0}.get(
+                normalized_use_channel = {"left": 1, "right": 2, None: 0}.get(
                     use_channel, use_channel
                 )
+                if isinstance(normalized_use_channel, int) and normalized_use_channel > 0:
+                     normalized_use_channel -= 1
                 _load_with_pydub(filename, audio_format, use_channel)
                 self.assertTrue(open_func.called)
                 if channels > 1:
@@ -626,7 +628,7 @@
                         segment_mock._data,
                         segment_mock.channels,
                         segment_mock.sample_width,
-                        use_channel,
+                        normalized_use_channel,
                     )
                 else:
                     self.assertFalse(ext_mock.called)
@@ -776,4 +778,4 @@
         if extra_args is not None:
             kwargs.update(extra_args)
         audio_source = get_audio_source(input, **kwargs)
-        self.assertIsInstance(audio_source, expected_type)
+        self.assertIsInstance(audio_source, expected_type)
\ No newline at end of file