changeset 208:29472a5a798a

Harmonize calls to _extract_selected_channel
author Amine Sehili <amine.sehili@gmail.com>
date Fri, 14 Jun 2019 22:40:04 +0100
parents 12e6837c5961
children 9047740c5092
files auditok/io.py tests/test_AudioSource.py tests/test_io.py
diffstat 3 files changed, 46 insertions(+), 61 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/io.py	Wed Jun 12 20:27:32 2019 +0100
+++ b/auditok/io.py	Fri Jun 14 22:40:04 2019 +0100
@@ -109,8 +109,10 @@
     """
     if use_channel is None:
         return 0
-    if use_channel == "mix" or isinstance(use_channel, int):
-        return use_channel
+    if use_channel == "mix":
+        return "mix"
+    if isinstance(use_channel, int):
+        return use_channel - 1
     try:
         return ["left", "right"].index(use_channel)
     except ValueError:
@@ -167,7 +169,7 @@
             )
         parameters.append(param)
     sampling_rate, sample_width, channels = parameters
-    use_channel = param_dict.get("use_channel", param_dict.get("uc", 0))
+    use_channel = param_dict.get("use_channel", param_dict.get("uc", 1))
     use_channel = _normalize_use_channel(use_channel)
     return sampling_rate, sample_width, channels, use_channel
 
@@ -835,24 +837,16 @@
         "-", raw data will be read from stdin. If None, read audio data from
         microphone using PyAudio.
     """
-
-    sampling_rate = kwargs.get(
-        "sampling_rate", kwargs.get("sr", DEFAULT_SAMPLING_RATE)
-    )
-    sample_width = kwargs.get(
-        "sample_rate", kwargs.get("sw", DEFAULT_SAMPLE_WIDTH)
-    )
-    channels = kwargs.get("channels", kwargs.get("ch", DEFAULT_NB_CHANNELS))
-    use_channel = kwargs.get(
-        "use_channel", kwargs.get("uc", DEFAULT_USE_CHANNEL)
-    )
+    sampling_rate, sample_width, channels, use_channel = _get_audio_parameters(kwargs)
     if input == "-":
         return StdinAudioSource(
             sampling_rate, sample_width, channels, use_channel
         )
 
     if isinstance(input, bytes):
-        return BufferAudioSource(input, sampling_rate, sample_width, channels)
+        use_channel = _normalize_use_channel(use_channel)
+        data = _extract_selected_channel(input, channels, sample_width, use_channel)
+        return BufferAudioSource(data, sampling_rate, sample_width, channels)
 
     # read data from a file
     if input is not None:
@@ -922,7 +916,6 @@
         with open(file, "rb") as fp:
             data = fp.read()
         if channels != 1:
-            # TODO check if striding with mmap doesn't load all data to memory
             data = _extract_selected_channel(
                 data, channels, sample_width, use_channel
             )
@@ -1158,4 +1151,4 @@
         )
     else:
         err_message = "cannot write file format {} (file name: {})"
-        raise AudioIOError(err_message.format(audio_format, file))
+        raise AudioIOError(err_message.format(audio_format, file))
\ No newline at end of file
--- a/tests/test_AudioSource.py	Wed Jun 12 20:27:32 2019 +0100
+++ b/tests/test_AudioSource.py	Fri Jun 14 22:40:04 2019 +0100
@@ -35,7 +35,9 @@
         mono_mix=("mono_400Hz", 1, "mix", 400),
         mono_channel_selection=("mono_400Hz", 1, 2, 400),
         multichannel_default=("3channel_400-800-1600Hz", 3, None, 400),
-        multichannel_channel_selection=("3channel_400-800-1600Hz", 3, 1, 800),
+        multichannel_channel_select_1st=("3channel_400-800-1600Hz", 3, 1, 400),
+        multichannel_channel_select_2nd=("3channel_400-800-1600Hz", 3, 2, 800),
+        multichannel_channel_select_3rd=("3channel_400-800-1600Hz", 3, 3, 1600),
     )
     def test_RawAudioSource(
         self, file_suffix, channels, use_channel, frequency
@@ -68,7 +70,9 @@
         mono_mix=("mono_400Hz", 1, "mix", 400),
         mono_channel_selection=("mono_400Hz", 1, 2, 400),
         multichannel_default=("3channel_400-800-1600Hz", 3, None, 400),
-        multichannel_channel_selection=("3channel_400-800-1600Hz", 3, 1, 800),
+        multichannel_channel_select_1st=("3channel_400-800-1600Hz", 3, 1, 400),
+        multichannel_channel_select_2nd=("3channel_400-800-1600Hz", 3, 2, 800),
+        multichannel_channel_select_3rd=("3channel_400-800-1600Hz", 3, 3, 1600),
     )
     def test_WaveAudioSource(
         self, file_suffix, channels, use_channel, frequency
@@ -1072,4 +1076,4 @@
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
\ No newline at end of file
--- a/tests/test_io.py	Wed Jun 12 20:27:32 2019 +0100
+++ b/tests/test_io.py	Fri Jun 14 22:40:04 2019 +0100
@@ -74,8 +74,7 @@
 
     @genty_dataset(
         none=(None, 0),
-        positive_int=(1, 1),
-        negative_int=(-1, -1),
+        positive_int=(1, 0),
         left=("left", 0),
         right=("right", 1),
         mix=("mix", "mix"),
@@ -85,7 +84,8 @@
         self.assertEqual(result, expected)
 
     @genty_dataset(
-        simple=((8000, 2, 1, 0), (8000, 2, 1, 0)),
+        int_1=((8000, 2, 1, 1), (8000, 2, 1, 0)),
+        int_2=((8000, 2, 1, 2), (8000, 2, 1, 1)),
         use_channel_left=((8000, 2, 1, "left"), (8000, 2, 1, 0)),
         use_channel_right=((8000, 2, 1, "right"), (8000, 2, 1, 1)),
         use_channel_mix=((8000, 2, 1, "mix"), (8000, 2, 1, "mix")),
@@ -93,12 +93,13 @@
         no_use_channel=((8000, 2, 2), (8000, 2, 2, 0)),
     )
     def test_get_audio_parameters_short_params(self, values, expected):
-        params = {k: v for k, v in zip(("sr", "sw", "ch", "uc"), values)}
+        params = dict(zip(("sr", "sw", "ch", "uc"), values))
         result = _get_audio_parameters(params)
         self.assertEqual(result, expected)
 
     @genty_dataset(
-        simple=((8000, 2, 1, 0), (8000, 2, 1, 0)),
+        int_1=((8000, 2, 1, 1), (8000, 2, 1, 0)),
+        int_2=((8000, 2, 1, 2), (8000, 2, 1, 1)),
         use_channel_left=((8000, 2, 1, "left"), (8000, 2, 1, 0)),
         use_channel_right=((8000, 2, 1, "right"), (8000, 2, 1, 1)),
         use_channel_mix=((8000, 2, 1, "mix"), (8000, 2, 1, "mix")),
@@ -106,29 +107,16 @@
         no_use_channel=((8000, 2, 2), (8000, 2, 2, 0)),
     )
     def test_get_audio_parameters_long_params(self, values, expected):
-        params = {
-            k: v
-            for k, v in zip(
-                ("sampling_rate", "sample_width", "channels", "use_channel"),
-                values,
-            )
-        }
+        params = dict(zip(("sampling_rate", "sample_width", "channels", "use_channel"), values))
         result = _get_audio_parameters(params)
         self.assertEqual(result, expected)
 
-    @genty_dataset(simple=((8000, 2, 1, 0), (8000, 2, 1, 0)))
-    def test_get_audio_parameters_short_and_long_params(
+    @genty_dataset(simple=((8000, 2, 1, 1), (8000, 2, 1, 0)))
+    def test_get_audio_parameters_long_params_shadow_short_ones(
         self, values, expected
     ):
-        params = {
-            k: v
-            for k, v in zip(
-                ("sampling_rate", "sample_width", "channels", "use_channel"),
-                values,
-            )
-        }
-
-        params.update({k: v for k, v in zip(("sr", "sw", "ch", "uc"), "xxxx")})
+        params = dict(zip(("sampling_rate", "sample_width", "channels", "use_channel"), values))
+        params.update(dict(zip(("sr", "sw", "ch", "uc"), "xxxx")))
         result = _get_audio_parameters(params)
         self.assertEqual(result, expected)
 
@@ -141,13 +129,10 @@
         negative_channels=((8000, 2, -1, 0),),
     )
     def test_get_audio_parameters_invalid(self, values):
-        params = {
-            k: v
-            for k, v in zip(
-                ("sampling_rate", "sample_width", "channels", "use_channel"),
-                values,
-            )
-        }
+        # TODO 0 or negative use_channel must raise AudioParameterError
+        # change implementation, don't accept negative uc
+        # hifglight everywhere in doc that uc must be positive
+        params = dict(zip(("sampling_rate", "sample_width", "channels", "use_channel"), values))
         with self.assertRaises(AudioParameterError):
             _get_audio_parameters(params)
 
@@ -311,14 +296,14 @@
                 from_file("audio", "mp3")
 
     @genty_dataset(
-        raw_first_channel=("raw", 0, 400),
-        raw_second_channel=("raw", 1, 800),
-        raw_third_channel=("raw", 2, 1600),
+        raw_first_channel=("raw", 1, 400),
+        raw_second_channel=("raw", 2, 800),
+        raw_third_channel=("raw", 3, 1600),
         raw_left_channel=("raw", "left", 400),
         raw_right_channel=("raw", "right", 800),
-        wav_first_channel=("wav", 0, 400),
-        wav_second_channel=("wav", 1, 800),
-        wav_third_channel=("wav", 2, 1600),
+        wav_first_channel=("wav", 1, 400),
+        wav_second_channel=("wav", 2, 800),
+        wav_third_channel=("wav", 3, 1600),
         wav_left_channel=("wav", "left", 400),
         wav_right_channel=("wav", "right", 800),
     )
@@ -378,13 +363,13 @@
     @patch("auditok.io._WITH_PYDUB", True)
     @patch("auditok.io.BufferAudioSource")
     @genty_dataset(
-        ogg_first_channel=("ogg", 0, "from_ogg"),
-        ogg_second_channel=("ogg", 1, "from_ogg"),
+        ogg_first_channel=("ogg", 1, "from_ogg"),
+        ogg_second_channel=("ogg", 2, "from_ogg"),
         ogg_mix=("ogg", "mix", "from_ogg"),
         ogg_default=("ogg", None, "from_ogg"),
         mp3_left_channel=("mp3", "left", "from_mp3"),
         mp3_right_channel=("mp3", "right", "from_mp3"),
-        flac_first_channel=("flac", 0, "from_file"),
+        flac_first_channel=("flac", 1, "from_file"),
         flac_second_channel=("flac", 1, "from_file"),
         flv_left_channel=("flv", "left", "from_flv"),
         webm_right_channel=("webm", "right", "from_file"),
@@ -406,9 +391,12 @@
                 self.assertTrue(open_func.called)
                 self.assertTrue(ext_mock.called)
 
-                use_channel = {"left": 0, "right": 1, None: 0}.get(
+                use_channel = {"left": 1, "right": 2, None: 1}.get(
                     use_channel, use_channel
                 )
+                if isinstance(use_channel, int):
+                    # _extract_selected_channel will be called with a channel starting from 0
+                    use_channel -= 1
                 ext_mock.assert_called_with(
                     segment_mock._data,
                     segment_mock.channels,
@@ -773,4 +761,4 @@
         if extra_args is not None:
             kwargs.update(extra_args)
         audio_source = get_audio_source(input, **kwargs)
-        self.assertIsInstance(audio_source, expected_type)
+        self.assertIsInstance(audio_source, expected_type)
\ No newline at end of file