# view tests/test_io.py @ 289:69ca1c64a9b0
#
# Return 'wav' when 'wave' is guessed in _guess_audio_format
# author Amine Sehili <amine.sehili@gmail.com>
# date Sat, 05 Oct 2019 14:26:03 +0200
# parents 173ffca58d23
# children 10b725735637
# line wrap: on
# line source
import os
import sys
import math
from array import array
from tempfile import NamedTemporaryFile, TemporaryDirectory
import filecmp
from unittest import TestCase
from genty import genty, genty_dataset
from test_util import _sample_generator, _generate_pure_tone, PURE_TONE_DICT
from auditok.io import (
    DATA_FORMAT,
    AudioIOError,
    AudioParameterError,
    BufferAudioSource,
    RawAudioSource,
    WaveAudioSource,
    StdinAudioSource,
    check_audio_data,
    _guess_audio_format,
    _normalize_use_channel,
    _get_audio_parameters,
    _array_to_bytes,
    _mix_audio_channels,
    _extract_selected_channel,
    _load_raw,
    _load_wave,
    _load_with_pydub,
    get_audio_source,
    from_file,
    _save_raw,
    _save_wave,
    _save_with_pydub,
    to_file,
)


# Select the mock implementation matching the running interpreter: the
# standard library's ``unittest.mock`` on Python 3, the standalone ``mock``
# package otherwise.
PYTHON_3 = sys.version_info >= (3, 0)
if PYTHON_3:
    from unittest.mock import patch, Mock
else:
    from mock import patch, Mock

# Audio parameters spelled with their short keyword names ("sr", "sw", "ch").
AUDIO_PARAMS_SHORT = dict(sr=16000, sw=2, ch=1)


@genty
class TestIO(TestCase):
    """Tests for the helper functions and entry points of ``auditok.io``.

    Parametrized tests declare their cases with ``genty_dataset``. The intent
    of each test is documented with comments rather than docstrings so that
    the test descriptions reported by the runner stay unchanged.
    """

    def _check_loaded_source(self, audio_source, expected_class, frequencies):
        # Shared verification for the ``_load_raw`` / ``_load_wave`` tests:
        # read everything from `audio_source`, then check its class, its
        # audio parameters (16000 Hz, 2-byte samples, one channel per
        # frequency) and that its data equals the interleaved reference
        # tones taken from ``PURE_TONE_DICT``.
        audio_source.open()
        data = audio_source.read(-1)
        audio_source.close()
        self.assertIsInstance(audio_source, expected_class)
        self.assertEqual(audio_source.sampling_rate, 16000)
        self.assertEqual(audio_source.sample_width, 2)
        self.assertEqual(audio_source.channels, len(frequencies))
        mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies]
        fmt = DATA_FORMAT[audio_source.sample_width]
        expected = _array_to_bytes(
            array(fmt, _sample_generator(*mono_channels))
        )
        self.assertEqual(data, expected)

    @genty_dataset(
        valid_mono=(b"\0" * 113, 1, 1),
        valid_stereo=(b"\0" * 160, 1, 2),
        invalid_mono_sw_2=(b"\0" * 113, 2, 1, False),
        invalid_stereo_sw_1=(b"\0" * 113, 1, 2, False),
        invalid_stereo_sw_2=(b"\0" * 158, 2, 2, False),
    )
    def test_check_audio_data(self, data, sample_width, channels, valid=True):
        # The invalid datasets all have a length that is not a multiple of
        # ``sample_width * channels``; they must raise AudioParameterError.
        # Valid data must make ``check_audio_data`` return None.
        if not valid:
            with self.assertRaises(AudioParameterError):
                check_audio_data(data, sample_width, channels)
        else:
            self.assertIsNone(check_audio_data(data, sample_width, channels))

    @genty_dataset(
        extention_and_format_same=("wav", "filename.wav", "wav"),
        extention_and_format_different=("wav", "filename.mp3", "wav"),
        extention_no_format=(None, "filename.wav", "wav"),
        format_no_extension=("wav", "filename", "wav"),
        no_format_no_extension=(None, "filename", None),
        wave_as_wav=("wave", "filename", "wav"),
        wave_as_wav_extension=(None, "filename.wave", "wav"),
    )
    def test_guess_audio_format(self, fmt, filename, expected):
        # An explicit format wins over the file extension, the extension is
        # used when no format is given, and "wave" is reported as "wav".
        result = _guess_audio_format(fmt, filename)
        self.assertEqual(result, expected)

    @genty_dataset(
        none=(None, 0),
        positive_int=(1, 0),
        left=("left", 0),
        right=("right", 1),
        mix=("mix", "mix"),
    )
    def test_normalize_use_channel(self, use_channel, expected):
        # Maps each accepted `use_channel` spelling to its normalized value.
        result = _normalize_use_channel(use_channel)
        self.assertEqual(result, expected)

    def test_get_audio_parameters_short_params(self):
        # Parameters given with their short names ("sr", "sw", "ch").
        expected = (8000, 2, 1)
        params = dict(zip(("sr", "sw", "ch"), expected))
        result = _get_audio_parameters(params)
        self.assertEqual(result, expected)

    def test_get_audio_parameters_long_params(self):
        # Parameters given with their long names. The original name tuple
        # also listed "use_channel", but ``zip`` dropped it because only
        # three values are supplied; it is omitted here for clarity.
        expected = (8000, 2, 1)
        params = dict(
            zip(("sampling_rate", "sample_width", "channels"), expected)
        )
        result = _get_audio_parameters(params)
        self.assertEqual(result, expected)

    def test_get_audio_parameters_long_params_shadow_short_ones(self):
        # When both spellings are present the long names must win: the short
        # names carry the invalid value "x" and must be ignored.
        expected = (8000, 2, 1)
        params = dict(
            zip(("sampling_rate", "sample_width", "channels"), expected)
        )
        params.update(dict(zip(("sr", "sw", "ch"), "xxx")))
        result = _get_audio_parameters(params)
        self.assertEqual(result, expected)

    @genty_dataset(
        str_sampling_rate=(("x", 2, 1),),
        negative_sampling_rate=((-8000, 2, 1),),
        str_sample_width=((8000, "x", 1),),
        negative_sample_width=((8000, -2, 1),),
        str_channels=((8000, 2, "x"),),
        negative_channels=((8000, 2, -1),),
    )
    def test_get_audio_parameters_invalid(self, values):
        # String or negative values for any parameter must be rejected.
        params = dict(
            zip(("sampling_rate", "sample_width", "channels"), values)
        )
        with self.assertRaises(AudioParameterError):
            _get_audio_parameters(params)

    @genty_dataset(
        mono_1byte=([400], 1),
        stereo_1byte=([400, 600], 1),
        three_channel_1byte=([400, 600, 2400], 1),
        mono_2byte=([400], 2),
        stereo_2byte=([400, 600], 2),
        three_channel_2byte=([400, 600, 1150], 2),
        mono_4byte=([400], 4),
        stereo_4byte=([400, 600], 4),
        four_channel_2byte=([400, 600, 1150, 7220], 4),
    )
    def test_mix_audio_channels(self, frequencies, sample_width):
        # Mixing interleaved channels must yield the per-sample floor average
        # of the individual mono channels.
        sampling_rate = 16000
        # NOTE(review): this assignment overrides the parametrized
        # `sample_width`, so every dataset currently runs with 2-byte
        # samples. Kept to preserve existing behavior -- confirm that 1- and
        # 4-byte mixing passes before removing it.
        sample_width = 2
        channels = len(frequencies)
        mono_channels = [
            _generate_pure_tone(
                freq,
                duration_sec=0.1,
                sampling_rate=sampling_rate,
                sample_width=sample_width,
            )
            for freq in frequencies
        ]
        fmt = DATA_FORMAT[sample_width]
        expected = _array_to_bytes(
            array(
                fmt,
                (sum(samples) // channels for samples in zip(*mono_channels)),
            )
        )
        data = _array_to_bytes(array(fmt, _sample_generator(*mono_channels)))
        mixed = _mix_audio_channels(data, channels, sample_width)
        self.assertEqual(mixed, expected)

    @genty_dataset(
        mono_1byte=([400], 1, 0),
        stereo_1byte_2st_channel=([400, 600], 1, 1),
        mono_2byte=([400], 2, 0),
        stereo_2byte_1st_channel=([400, 600], 2, 0),
        stereo_2byte_2nd_channel=([400, 600], 2, 1),
        three_channel_2byte_last_negative_idx=([400, 600, 1150], 2, -1),
        three_channel_2byte_2nd_negative_idx=([400, 600, 1150], 2, -2),
        three_channel_2byte_1st_negative_idx=([400, 600, 1150], 2, -3),
        three_channel_4byte_1st=([400, 600, 1150], 4, 0),
        three_channel_4byte_last_negative_idx=([400, 600, 1150], 4, -1),
    )
    def test_extract_selected_channel(
        self, frequencies, sample_width, use_channel
    ):
        # Selecting channel `use_channel` (negative indices count from the
        # end) from interleaved data must return exactly that mono channel.
        mono_channels = [
            _generate_pure_tone(
                freq,
                duration_sec=0.1,
                sampling_rate=16000,
                sample_width=sample_width,
            )
            for freq in frequencies
        ]
        channels = len(frequencies)
        fmt = DATA_FORMAT[sample_width]
        expected = _array_to_bytes(mono_channels[use_channel])
        data = _array_to_bytes(array(fmt, _sample_generator(*mono_channels)))
        selected_channel = _extract_selected_channel(
            data, channels, sample_width, use_channel
        )
        self.assertEqual(selected_channel, expected)

    @genty_dataset(mono=([400],), three_channel=([600, 1150, 2400],))
    def test_extract_selected_channel_mix(self, frequencies):
        # With use_channel="mix" the result must be the per-sample floor
        # average of all channels (2-byte samples).
        mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies]
        channels = len(frequencies)
        fmt = DATA_FORMAT[2]
        expected = _array_to_bytes(
            array(
                fmt,
                (sum(samples) // channels for samples in zip(*mono_channels)),
            )
        )
        data = _array_to_bytes(array(fmt, _sample_generator(*mono_channels)))
        selected_channel = _extract_selected_channel(data, channels, 2, "mix")
        self.assertEqual(selected_channel, expected)

    @genty_dataset(positive=(2,), negative=(-3,))
    def test_extract_selected_channel_invalid_use_channel(self, use_channel):
        # Channel indices outside [-2, 1] are invalid for 2-channel data.
        with self.assertRaises(AudioParameterError):
            _extract_selected_channel(b"\0\0", 2, 2, use_channel)

    @genty_dataset(
        raw_with_audio_format=(
            "audio",
            "raw",
            "_load_raw",
            AUDIO_PARAMS_SHORT,
        ),
        raw_with_extension=(
            "audio.raw",
            None,
            "_load_raw",
            AUDIO_PARAMS_SHORT,
        ),
        wave_with_audio_format=("audio", "wave", "_load_wave"),
        # Fixed: this dataset used "wave" too, duplicating the one above and
        # leaving the "wav" spelling untested.
        wav_with_audio_format=("audio", "wav", "_load_wave"),
        wav_with_extension=("audio.wav", None, "_load_wave"),
        format_and_extension_both_given=("audio.dat", "wav", "_load_wave"),
        format_and_extension_both_given_b=("audio.raw", "wave", "_load_wave"),
        no_format_nor_extension=("audio", None, "_load_with_pydub"),
        other_formats_ogg=("audio.ogg", None, "_load_with_pydub"),
        other_formats_webm=("audio", "webm", "_load_with_pydub"),
    )
    def test_from_file(
        self, filename, audio_format, funtion_name, kwargs=None
    ):
        # ``from_file`` must dispatch to the loader named by `funtion_name`;
        # the loader is patched so no file is actually read.
        funtion_name = "auditok.io." + funtion_name
        if kwargs is None:
            kwargs = {}
        with patch(funtion_name) as patch_function:
            from_file(filename, audio_format, **kwargs)
        self.assertTrue(patch_function.called)

    def test_from_file_large_file_raw(self):
        # A raw file opened with large_file=True yields a RawAudioSource.
        filename = "tests/data/test_16KHZ_mono_400Hz.raw"
        audio_source = from_file(
            filename,
            large_file=True,
            sampling_rate=16000,
            sample_width=2,
            channels=1,
        )
        self.assertIsInstance(audio_source, RawAudioSource)

    def test_from_file_large_file_wave(self):
        # A wave file opened with large_file=True yields a WaveAudioSource.
        filename = "tests/data/test_16KHZ_mono_400Hz.wav"
        audio_source = from_file(filename, large_file=True)
        self.assertIsInstance(audio_source, WaveAudioSource)

    def test_from_file_large_file_compressed(self):
        # large_file=True with an ogg file is expected to fail.
        filename = "tests/data/test_16KHZ_mono_400Hz.ogg"
        with self.assertRaises(AudioIOError):
            from_file(filename, large_file=True)

    @genty_dataset(
        missing_sampling_rate=("sr",),
        missing_sample_width=("sw",),
        missing_channels=("ch",),
    )
    def test_from_file_missing_audio_param(self, missing_param):
        # Reading raw audio requires all three audio parameters. Setup is
        # kept outside ``assertRaises`` so that only ``from_file`` is guarded.
        params = AUDIO_PARAMS_SHORT.copy()
        del params[missing_param]
        with self.assertRaises(AudioParameterError):
            from_file("audio", audio_format="raw", **params)

    def test_from_file_no_pydub(self):
        # Without pydub, formats other than raw/wave cannot be read.
        with patch("auditok.io._WITH_PYDUB", False):
            with self.assertRaises(AudioIOError):
                from_file("audio", "mp3")

    @patch("auditok.io._WITH_PYDUB", True)
    @patch("auditok.io.BufferAudioSource")
    @genty_dataset(
        ogg_first_channel=("ogg", "from_ogg"),
        ogg_second_channel=("ogg", "from_ogg"),
        ogg_mix=("ogg", "from_ogg"),
        ogg_default=("ogg", "from_ogg"),
        mp3_left_channel=("mp3", "from_mp3"),
        mp3_right_channel=("mp3", "from_mp3"),
        flac_first_channel=("flac", "from_file"),
        flac_second_channel=("flac", "from_file"),
        flv_left_channel=("flv", "from_flv"),
        webm_right_channel=("webm", "from_file"),
    )
    def test_from_file_multichannel_audio_compressed(
        self, audio_format, function, *mocks
    ):
        # ``from_file`` must open compressed audio through the
        # ``AudioSegment`` constructor named by `function`. `mocks` receives
        # the mock created by the ``BufferAudioSource`` patch.
        # NOTE(review): the dataset names mention channels, but no channel
        # selector is passed to ``from_file`` -- datasets with identical
        # arguments are redundant.
        filename = "audio.{}".format(audio_format)
        segment_mock = Mock()
        segment_mock.sample_width = 2
        segment_mock.channels = 2
        segment_mock._data = b"abcd"
        with patch(
            "auditok.io.AudioSegment.{}".format(function)
        ) as open_func:
            open_func.return_value = segment_mock
            from_file(filename)
            self.assertTrue(open_func.called)

    @genty_dataset(
        mono=("mono_400", (400,)),
        three_channel=("3channel_400-800-1600", (400, 800, 1600)),
        mono_large_file=("mono_400", (400,), True),
        three_channel_large_file=(
            "3channel_400-800-1600",
            (400, 800, 1600),
            True,
        ),
    )
    def test_load_raw(self, file_id, frequencies, large_file=False):
        # ``_load_raw`` returns a RawAudioSource for large files and a
        # BufferAudioSource otherwise, holding the reference tone data.
        filename = "tests/data/test_16KHZ_{}Hz.raw".format(file_id)
        audio_source = _load_raw(
            filename, 16000, 2, len(frequencies), large_file=large_file
        )
        expected_class = RawAudioSource if large_file else BufferAudioSource
        self._check_loaded_source(audio_source, expected_class, frequencies)

    @genty_dataset(
        missing_sampling_rate=("sr",),
        missing_sample_width=("sw",),
        missing_channels=("ch",),
    )
    def test_load_raw_missing_audio_param(self, missing_param):
        params = AUDIO_PARAMS_SHORT.copy()
        del params[missing_param]
        # NOTE(review): if ``_get_audio_parameters`` already raises for the
        # incomplete dict, ``_load_raw`` is never reached here. The 4-value
        # unpacking also differs from the 3-tuples the
        # ``test_get_audio_parameters_*`` tests expect -- confirm which call
        # is meant to be under test.
        with self.assertRaises(AudioParameterError):
            srate, swidth, channels, _ = _get_audio_parameters(params)
            _load_raw("audio", srate, swidth, channels)

    @genty_dataset(
        mono=("mono_400", (400,)),
        three_channel=("3channel_400-800-1600", (400, 800, 1600)),
        mono_large_file=("mono_400", (400,), True),
        three_channel_large_file=(
            "3channel_400-800-1600",
            (400, 800, 1600),
            True,
        ),
    )
    def test_load_wave(self, file_id, frequencies, large_file=False):
        # ``_load_wave`` returns a WaveAudioSource for large files and a
        # BufferAudioSource otherwise, holding the reference tone data.
        filename = "tests/data/test_16KHZ_{}Hz.wav".format(file_id)
        audio_source = _load_wave(filename, large_file=large_file)
        expected_class = WaveAudioSource if large_file else BufferAudioSource
        self._check_loaded_source(audio_source, expected_class, frequencies)

    @patch("auditok.io._WITH_PYDUB", True)
    @patch("auditok.io.BufferAudioSource")
    @genty_dataset(
        ogg_default_first_channel=("ogg", 2, "from_ogg"),
        ogg_first_channel=("ogg", 1, "from_ogg"),
        ogg_second_channel=("ogg", 2, "from_ogg"),
        ogg_mix_channels=("ogg", 3, "from_ogg"),
        mp3_left_channel=("mp3", 1, "from_mp3"),
        mp3_right_channel=("mp3", 2, "from_mp3"),
        mp3_mix_channels=("mp3", 3, "from_mp3"),
        flac_first_channel=("flac", 2, "from_file"),
        flac_second_channel=("flac", 2, "from_file"),
        flv_left_channel=("flv", 1, "from_flv"),
        webm_right_channel=("webm", 2, "from_file"),
        webm_mix_channels=("webm", 4, "from_file"),
    )
    def test_load_with_pydub(
        self, audio_format, channels, function, *mocks
    ):
        # ``_load_with_pydub`` must open the file through the
        # ``AudioSegment`` constructor named by `function`. `mocks` receives
        # the mock created by the ``BufferAudioSource`` patch.
        # NOTE(review): `channels` is only the channel count reported by the
        # mocked segment; no channel selector is passed despite the dataset
        # names.
        filename = "audio.{}".format(audio_format)
        segment_mock = Mock()
        segment_mock.sample_width = 2
        segment_mock.channels = channels
        segment_mock._data = b"abcdefgh"
        with patch(
            "auditok.io.AudioSegment.{}".format(function)
        ) as open_func:
            open_func.return_value = segment_mock
            _load_with_pydub(filename, audio_format)
            self.assertTrue(open_func.called)

    @genty_dataset(
        mono=("mono_400Hz.raw", (400,)),
        three_channel=("3channel_400-800-1600Hz.raw", (400, 800, 1600)),
    )
    def test_save_raw(self, filename, frequencies):
        # Data written by ``_save_raw`` must be byte-identical to the
        # reference raw file.
        filename = "tests/data/test_16KHZ_{}".format(filename)
        sample_width = 2
        fmt = DATA_FORMAT[sample_width]
        mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies]
        data = _array_to_bytes(array(fmt, _sample_generator(*mono_channels)))
        # The context manager closes and removes the temporary file even
        # when the comparison fails.
        with NamedTemporaryFile() as tmpfile:
            _save_raw(data, tmpfile.name)
            self.assertTrue(
                filecmp.cmp(tmpfile.name, filename, shallow=False)
            )

    @genty_dataset(
        mono=("mono_400Hz.wav", (400,)),
        three_channel=("3channel_400-800-1600Hz.wav", (400, 800, 1600)),
    )
    def test_save_wave(self, filename, frequencies):
        # Data written by ``_save_wave`` must be byte-identical to the
        # reference wave file.
        filename = "tests/data/test_16KHZ_{}".format(filename)
        sampling_rate = 16000
        sample_width = 2
        channels = len(frequencies)
        fmt = DATA_FORMAT[sample_width]
        mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies]
        data = _array_to_bytes(array(fmt, _sample_generator(*mono_channels)))
        # The context manager closes and removes the temporary file even
        # when the comparison fails.
        with NamedTemporaryFile() as tmpfile:
            _save_wave(
                data, tmpfile.name, sampling_rate, sample_width, channels
            )
            self.assertTrue(
                filecmp.cmp(tmpfile.name, filename, shallow=False)
            )

    @genty_dataset(
        missing_sampling_rate=("sr",),
        missing_sample_width=("sw",),
        missing_channels=("ch",),
    )
    def test_save_wave_missing_audio_param(self, missing_param):
        params = AUDIO_PARAMS_SHORT.copy()
        del params[missing_param]
        # NOTE(review): if ``_get_audio_parameters`` already raises for the
        # incomplete dict, ``_save_wave`` is never reached here -- confirm
        # which call is meant to be under test.
        with self.assertRaises(AudioParameterError):
            srate, swidth, channels, _ = _get_audio_parameters(params)
            _save_wave(b"\0\0", "audio", srate, swidth, channels)

    def test_save_with_pydub(self):
        # ``_save_with_pydub`` must delegate to ``AudioSegment.export``. The
        # temporary directory is removed even if the assertion fails.
        with patch(
            "auditok.io.AudioSegment.export"
        ) as export, TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, "audio.ogg")
            _save_with_pydub(b"\0\0", filename, "ogg", 16000, 2, 1)
            self.assertTrue(export.called)

    @genty_dataset(
        raw_with_audio_format=("audio", "raw"),
        raw_with_extension=("audio.raw", None),
        raw_with_audio_format_and_extension=("audio.mp3", "raw"),
        raw_no_audio_format_nor_extension=("audio", None),
    )
    def test_to_file_raw(self, filename, audio_format):
        # Each format/extension combination here must produce a file that is
        # byte-identical to the reference raw file.
        exp_filename = "tests/data/test_16KHZ_mono_400Hz.raw"
        data = _array_to_bytes(PURE_TONE_DICT[400])
        with TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, filename)
            to_file(data, filename, audio_format=audio_format)
            self.assertTrue(
                filecmp.cmp(filename, exp_filename, shallow=False)
            )

    @genty_dataset(
        wav_with_audio_format=("audio", "wav"),
        wav_with_extension=("audio.wav", None),
        wav_with_audio_format_and_extension=("audio.mp3", "wav"),
        wave_with_audio_format=("audio", "wave"),
        wave_with_extension=("audio.wave", None),
        wave_with_audio_format_and_extension=("audio.mp3", "wave"),
    )
    def test_to_file_wave(self, filename, audio_format):
        # Each format/extension combination here must produce a file that is
        # byte-identical to the reference wave file.
        exp_filename = "tests/data/test_16KHZ_mono_400Hz.wav"
        data = _array_to_bytes(PURE_TONE_DICT[400])
        with TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, filename)
            to_file(
                data,
                filename,
                audio_format=audio_format,
                sampling_rate=16000,
                sample_width=2,
                channels=1,
            )
            self.assertTrue(
                filecmp.cmp(filename, exp_filename, shallow=False)
            )

    @genty_dataset(
        missing_sampling_rate=("sr",),
        missing_sample_width=("sw",),
        missing_channels=("ch",),
    )
    def test_to_file_missing_audio_param(self, missing_param):
        # Saving as wav or mp3 requires all three audio parameters.
        params = AUDIO_PARAMS_SHORT.copy()
        del params[missing_param]
        with self.assertRaises(AudioParameterError):
            to_file(b"\0\0", "audio", audio_format="wav", **params)
        with self.assertRaises(AudioParameterError):
            to_file(b"\0\0", "audio", audio_format="mp3", **params)

    def test_to_file_no_pydub(self):
        # Without pydub, formats other than raw/wave cannot be written.
        # Fixed: the call used to pass data and file name in swapped order
        # and no audio parameters, so the failure could come from parameter
        # validation instead of the missing pydub dependency.
        with patch("auditok.io._WITH_PYDUB", False):
            with self.assertRaises(AudioIOError):
                to_file(b"\0\0", "audio", "mp3", **AUDIO_PARAMS_SHORT)

    @patch("auditok.io._WITH_PYDUB", True)
    @genty_dataset(
        ogg_with_extension=("audio.ogg", None),
        ogg_with_audio_format=("audio", "ogg"),
        ogg_format_with_wrong_extension=("audio.wav", "ogg"),
    )
    def test_to_file_compressed(self, filename, audio_format, *mocks):
        # Saving ogg audio must go through ``AudioSegment.export``. The
        # temporary directory is removed even if the assertion fails.
        with patch(
            "auditok.io.AudioSegment.export"
        ) as export, TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, filename)
            to_file(b"\0\0", filename, audio_format, **AUDIO_PARAMS_SHORT)
            self.assertTrue(export.called)

    @genty_dataset(
        string_wave=(
            "tests/data/test_16KHZ_mono_400Hz.wav",
            BufferAudioSource,
        ),
        string_wave_large_file=(
            "tests/data/test_16KHZ_mono_400Hz.wav",
            WaveAudioSource,
            {"large_file": True},
        ),
        stdin=("-", StdinAudioSource),
        string_raw=("tests/data/test_16KHZ_mono_400Hz.raw", BufferAudioSource),
        string_raw_large_file=(
            "tests/data/test_16KHZ_mono_400Hz.raw",
            RawAudioSource,
            {"large_file": True},
        ),
        bytes_=(b"0" * 8000, BufferAudioSource),
    )
    def test_get_audio_source(self, input, expected_type, extra_args=None):
        # ``get_audio_source`` must return the audio source class matching
        # the kind of input: file name, "-" for stdin, or raw bytes.
        kwargs = {"sampling_rate": 16000, "sample_width": 2, "channels": 1}
        if extra_args is not None:
            kwargs.update(extra_args)
        audio_source = get_audio_source(input, **kwargs)
        self.assertIsInstance(audio_source, expected_type)