changeset 119:c768fa017e21

Add tests for _extract_selected_channel
author Amine Sehili <amine.sehili@gmail.com>
date Sat, 02 Feb 2019 11:53:59 +0100
parents 1af0c6050073
children 9b117eb6ecfd
files auditok/io.py tests/test_io.py
diffstat 2 files changed, 37 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/io.py	Fri Feb 01 20:55:34 2019 +0100
+++ b/auditok/io.py	Sat Feb 02 11:53:59 2019 +0100
@@ -194,6 +194,8 @@
             use_channel, channels, "s" if channels > 1 else ""
         )
         raise AudioParameterError(err_message)
+    elif use_channel < 0:
+        use_channel += channels
     fmt = DATA_FORMAT[sample_width]
     buffer = array(fmt, data)
     return _array_to_bytes(buffer[use_channel::channels])
--- a/tests/test_io.py	Fri Feb 01 20:55:34 2019 +0100
+++ b/tests/test_io.py	Sat Feb 02 11:53:59 2019 +0100
@@ -12,6 +12,7 @@
     check_audio_data,
     _array_to_bytes,
     _mix_audio_channels,
+    _extract_selected_channel,
     _save_raw,
     _save_wave,
 )
@@ -124,6 +125,40 @@
         self.assertEqual(mixed, expected)
 
     @genty_dataset(
+        mono_1byte=([400], 1, 0),
+        stereo_1byte_2st_channel=([400, 600], 1, 1),
+        mono_2byte=([400], 2, 0),
+        stereo_2byte_1st_channel=([400, 600], 2, 0),
+        stereo_2byte_2nd_channel=([400, 600], 2, 1),
+        three_channel_2byte_last_negative_idx=([400, 600, 1150], 2, -1),
+        three_channel_2byte_2nd_negative_idx=([400, 600, 1150], 2, -2),
+        three_channel_2byte_1st_negative_idx=([400, 600, 1150], 2, -3),
+        three_channel_4byte_1st=([400, 600, 1150], 4, 0),
+        three_channel_4byte_last_negative_idx=([400, 600, 1150], 4, -1),
+    )
+    def test_extract_selected_channel(
+        self, frequencies, sample_width, use_channel
+    ):
+
+        mono_channels = [
+            _generate_pure_tone(
+                freq,
+                duration_sec=0.1,
+                sampling_rate=16000,
+                sample_width=sample_width,
+            )
+            for freq in frequencies
+        ]
+        channels = len(frequencies)
+        fmt = DATA_FORMAT[sample_width]
+        expected = _array_to_bytes(mono_channels[use_channel])
+        data = _array_to_bytes(array(fmt, _sample_generator(*mono_channels)))
+        selected_channel = _extract_selected_channel(
+            data, channels, sample_width, use_channel
+        )
+        self.assertEqual(selected_channel, expected)
+
+    @genty_dataset(
         mono=("mono_400Hz.raw", (400,)),
         three_channel=("3channel_400-800-1600Hz.raw", (400, 800, 1600)),
     )