changeset 122:3728814e1233

Add tests for from_file with channel selection
author Amine Sehili <amine.sehili@gmail.com>
date Sun, 03 Feb 2019 14:17:14 +0100
parents dcf8a245ba1b
children 34435df8cf02
files auditok/io.py tests/test_io.py
diffstat 2 files changed, 37 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/io.py	Sat Feb 02 14:22:02 2019 +0100
+++ b/auditok/io.py	Sun Feb 03 14:17:14 2019 +0100
@@ -796,9 +796,7 @@
     )
 
 
-def from_file(
-    filename, audio_format=None, use_channel=None, large_file=False, **kwargs
-):
+def from_file(filename, audio_format=None, large_file=False, **kwargs):
     """
     Read audio data from `filename` and return an `AudioSource` object.
     if `audio_format` is None, the appropriate :class:`AudioSource` class is
@@ -823,10 +821,6 @@
         path to input audio or video file.
     `audio_format`: str
         audio format used to save data  (e.g. raw, webm, wav, ogg)
-    `use_channel`: int
-        audio channel to extract from input file if file is not mono audio.
-        This must be an integer >= 0 and < channels, or one of the special
-        values `left` and `right` (treated as 0 and 1 respectively).
     `large_file`: bool
         If True, audio won't fully be loaded to memory but only when a window
         is read disk.
@@ -842,6 +836,10 @@
         sample width (i.e. number of bytes used to represent one audio sample)
     `channels`: int
         number of channels of audio data
+    `use_channel`: int, str
+        audio channel to extract from input file if file is not mono audio.
+        This must be an integer >= 0 and < channels, or one of the special
+        values `left` and `right` (treated as 0 and 1 respectively).
 
     :Returns:
 
@@ -860,7 +858,7 @@
             filename, srate, swidth, channels, use_channel, large_file
         )
 
-    use_channel = _normalize_use_channel(kwargs.get("use_channel", None))
+    use_channel = _normalize_use_channel(kwargs.get("use_channel"))
     if audio_format in ["wav", "wave"]:
         return _load_wave(filename, large_file, use_channel)
     if large_file:
--- a/tests/test_io.py	Sat Feb 02 14:22:02 2019 +0100
+++ b/tests/test_io.py	Sun Feb 03 14:17:14 2019 +0100
@@ -216,6 +216,37 @@
                 from_file("audio", "mp3")
 
     @genty_dataset(
+        raw_first_channel=("raw", 0, 400),
+        raw_second_channel=("raw", 1, 800),
+        raw_third_channel=("raw", 2, 1600),
+        raw_left_channel=("raw", "left", 400),
+        raw_right_channel=("raw", "right", 800),
+        wav_first_channel=("wav", 0, 400),
+        wav_second_channel=("wav", 1, 800),
+        wav_third_channel=("wav", 2, 1600),
+        wav_left_channel=("wav", "left", 400),
+        wav_right_channel=("wav", "right", 800),
+    )
+    def test_from_file_multichannel_audio(
+        self, audio_format, use_channel, frequency
+    ):
+        expected = PURE_TONE_DICT[frequency]
+        filename = "tests/data/test_16KHZ_3channel_400-800-1600Hz.{}".format(
+            audio_format
+        )
+        sample_width = 2
+        audio_source = from_file(
+            filename,
+            sampling_rate=16000,
+            sample_width=sample_width,
+            channels=3,
+            use_channel=use_channel,
+        )
+        fmt = DATA_FORMAT[sample_width]
+        data = array(fmt, audio_source._buffer)
+        self.assertEqual(data, expected)
+
+    @genty_dataset(
         mono=("mono_400Hz.wav", (400,)),
         three_channel=("3channel_400-800-1600Hz.wav", (400, 800, 1600)),
     )