Mercurial > hg > auditok
view tests/test_io.py @ 400:323d59b404a2
Use pytest instead of genty
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Sat, 25 May 2024 21:54:13 +0200 |
parents | 9f17aa9a4018 |
children | 996948ada980 |
line wrap: on
line source
"""Tests for the loading and saving helpers exposed by ``auditok.io``."""

import filecmp
import math
import os
import sys
from array import array
from tempfile import NamedTemporaryFile, TemporaryDirectory
from unittest.mock import Mock, patch

import pytest

from test_util import _sample_generator, _generate_pure_tone, PURE_TONE_DICT
from auditok.signal import FORMAT
from auditok.io import (
    AudioIOError,
    AudioParameterError,
    BufferAudioSource,
    RawAudioSource,
    WaveAudioSource,
    StdinAudioSource,
    check_audio_data,
    _guess_audio_format,
    _get_audio_parameters,
    _load_raw,
    _load_wave,
    _load_with_pydub,
    get_audio_source,
    from_file,
    _save_raw,
    _save_wave,
    _save_with_pydub,
    to_file,
)

# Audio parameters (short names) shared by the tests that need a complete,
# valid parameter set.
AUDIO_PARAMS_SHORT = {"sr": 16000, "sw": 2, "ch": 1}


@pytest.mark.parametrize(
    "data, sample_width, channels, valid",
    [
        (b"\0" * 113, 1, 1, True),  # valid_mono
        (b"\0" * 160, 1, 2, True),  # valid_stereo
        (b"\0" * 113, 2, 1, False),  # invalid_mono_sw_2
        (b"\0" * 113, 1, 2, False),  # invalid_stereo_sw_1
        (b"\0" * 158, 2, 2, False),  # invalid_stereo_sw_2
    ],
    ids=[
        "valid_mono",
        "valid_stereo",
        "invalid_mono_sw_2",
        "invalid_stereo_sw_1",
        "invalid_stereo_sw_2",
    ],
)
def test_check_audio_data(data, sample_width, channels, valid):
    """`check_audio_data` returns None for valid data and raises otherwise."""
    if valid:
        assert check_audio_data(data, sample_width, channels) is None
    else:
        with pytest.raises(AudioParameterError):
            check_audio_data(data, sample_width, channels)


@pytest.mark.parametrize(
    "fmt, filename, expected",
    [
        ("wav", "filename.wav", "wav"),  # extension_and_format_same
        ("wav", "filename.mp3", "wav"),  # extension_and_format_different
        (None, "filename.wav", "wav"),  # extension_no_format
        ("wav", "filename", "wav"),  # format_no_extension
        (None, "filename", None),  # no_format_no_extension
        ("wave", "filename", "wav"),  # wave_as_wav
        (None, "filename.wave", "wav"),  # wave_as_wav_extension
    ],
    # The ids keep their historical spelling so that existing test node ids
    # (and any ``-k`` selections relying on them) do not change.
    ids=[
        "extention_and_format_same",
        "extention_and_format_different",
        "extention_no_format",
        "format_no_extension",
        "no_format_no_extension",
        "wave_as_wav",
        "wave_as_wav_extension",
    ],
)
def test_guess_audio_format(fmt, filename, expected):
    """The explicit format wins over the extension; 'wave' maps to 'wav'."""
    result = _guess_audio_format(fmt, filename)
    assert result == expected


def test_get_audio_parameters_short_params():
    """Short parameter names ('sr', 'sw', 'ch') are accepted."""
    expected = (8000, 2, 1)
    params = dict(zip(("sr", "sw", "ch"), expected))
    result = _get_audio_parameters(params)
    assert result == expected


def test_get_audio_parameters_long_params():
    """Long parameter names are accepted."""
    expected = (8000, 2, 1)
    params = dict(zip(("sampling_rate", "sample_width", "channels"), expected))
    result = _get_audio_parameters(params)
    assert result == expected


def test_get_audio_parameters_long_params_shadow_short_ones():
    """Long names take precedence over (here invalid) short names."""
    expected = (8000, 2, 1)
    params = dict(zip(("sampling_rate", "sample_width", "channels"), expected))
    params.update(dict(zip(("sr", "sw", "ch"), "xxx")))
    result = _get_audio_parameters(params)
    assert result == expected


@pytest.mark.parametrize(
    "values",
    [
        ("x", 2, 1),  # str_sampling_rate
        (-8000, 2, 1),  # negative_sampling_rate
        (8000, "x", 1),  # str_sample_width
        (8000, -2, 1),  # negative_sample_width
        (8000, 2, "x"),  # str_channels
        (8000, 2, -1),  # negative_channels
    ],
    ids=[
        "str_sampling_rate",
        "negative_sampling_rate",
        "str_sample_width",
        "negative_sample_width",
        "str_channels",
        "negative_channels",
    ],
)
def test_get_audio_parameters_invalid(values):
    """Non-integer or negative parameters raise `AudioParameterError`."""
    params = dict(zip(("sampling_rate", "sample_width", "channels"), values))
    with pytest.raises(AudioParameterError):
        _get_audio_parameters(params)


@pytest.mark.parametrize(
    "filename, audio_format, function_name, kwargs",
    [
        (
            "audio",
            "raw",
            "_load_raw",
            AUDIO_PARAMS_SHORT,
        ),  # raw_with_audio_format
        (
            "audio.raw",
            None,
            "_load_raw",
            AUDIO_PARAMS_SHORT,
        ),  # raw_with_extension
        ("audio", "wave", "_load_wave", None),  # wave_with_audio_format
        ("audio", "wav", "_load_wave", None),  # wav_with_audio_format
        ("audio.wav", None, "_load_wave", None),  # wav_with_extension
        (
            "audio.dat",
            "wav",
            "_load_wave",
            None,
        ),  # format_and_extension_both_given_a
        (
            "audio.raw",
            "wave",
            "_load_wave",
            None,
        ),  # format_and_extension_both_given_b
        ("audio", None, "_load_with_pydub", None),  # no_format_nor_extension
        ("audio.ogg", None, "_load_with_pydub", None),  # other_formats_ogg
        ("audio", "webm", "_load_with_pydub", None),  # other_formats_webm
    ],
    ids=[
        "raw_with_audio_format",
        "raw_with_extension",
        "wave_with_audio_format",
        "wav_with_audio_format",
        "wav_with_extension",
        "format_and_extension_both_given_a",
        "format_and_extension_both_given_b",
        "no_format_nor_extension",
        "other_formats_ogg",
        "other_formats_webm",
    ],
)
def test_from_file(filename, audio_format, function_name, kwargs):
    """`from_file` dispatches to the loader matching format/extension."""
    target = "auditok.io." + function_name
    if kwargs is None:
        kwargs = {}
    # The loader is patched, so no file is actually opened.
    with patch(target) as patch_function:
        from_file(filename, audio_format, **kwargs)
    assert patch_function.called


def test_from_file_large_file_raw():
    """A raw file opened with ``large_file=True`` gives a `RawAudioSource`."""
    filename = "tests/data/test_16KHZ_mono_400Hz.raw"
    audio_source = from_file(
        filename,
        large_file=True,
        sampling_rate=16000,
        sample_width=2,
        channels=1,
    )
    assert isinstance(audio_source, RawAudioSource)


def test_from_file_large_file_wave():
    """A wave file opened with ``large_file=True`` gives a `WaveAudioSource`."""
    filename = "tests/data/test_16KHZ_mono_400Hz.wav"
    audio_source = from_file(filename, large_file=True)
    assert isinstance(audio_source, WaveAudioSource)


def test_from_file_large_file_compressed():
    """``large_file=True`` with an ogg file raises `AudioIOError`."""
    filename = "tests/data/test_16KHZ_mono_400Hz.ogg"
    with pytest.raises(AudioIOError):
        from_file(filename, large_file=True)


@pytest.mark.parametrize(
    "missing_param",
    [
        "sr",  # missing_sampling_rate
        "sw",  # missing_sample_width
        "ch",  # missing_channels
    ],
    ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"],
)
def test_from_file_missing_audio_param(missing_param):
    """Loading raw audio without one audio parameter raises."""
    # Setup stays outside ``pytest.raises`` so that only the call under test
    # can satisfy the expected exception.
    params = AUDIO_PARAMS_SHORT.copy()
    del params[missing_param]
    with pytest.raises(AudioParameterError):
        from_file("audio", audio_format="raw", **params)


def test_from_file_no_pydub():
    """Without pydub, loading an mp3 file raises `AudioIOError`."""
    with patch("auditok.io._WITH_PYDUB", False):
        with pytest.raises(AudioIOError):
            from_file("audio", "mp3")


@pytest.mark.parametrize(
    "audio_format, function",
    [
        ("ogg", "from_ogg"),  # ogg_first_channel
        ("ogg", "from_ogg"),  # ogg_second_channel
        ("ogg", "from_ogg"),  # ogg_mix
        ("ogg", "from_ogg"),  # ogg_default
        ("mp3", "from_mp3"),  # mp3_left_channel
        ("mp3", "from_mp3"),  # mp3_right_channel
        ("flac", "from_file"),  # flac_first_channel
        ("flac", "from_file"),  # flac_second_channel
        ("flv", "from_flv"),  # flv_left_channel
        ("webm", "from_file"),  # webm_right_channel
    ],
    ids=[
        "ogg_first_channel",
        "ogg_second_channel",
        "ogg_mix",
        "ogg_default",
        "mp3_left_channel",
        "mp3_right_channel",
        "flac_first_channel",
        "flac_second_channel",
        "flv_left_channel",
        "webm_right_channel",
    ],
)
@patch("auditok.io._WITH_PYDUB", True)
@patch("auditok.io.BufferAudioSource")
def test_from_file_multichannel_audio_compressed(
    mock_buffer_audio_source, audio_format, function
):
    """`from_file` opens compressed audio via the matching pydub function."""
    filename = "audio.{}".format(audio_format)
    segment_mock = Mock()
    segment_mock.sample_width = 2
    segment_mock.channels = 2
    segment_mock._data = b"abcd"
    with patch("auditok.io.AudioSegment.{}".format(function)) as open_func:
        open_func.return_value = segment_mock
        from_file(filename)
    assert open_func.called


@pytest.mark.parametrize(
    "file_id, frequencies, large_file",
    [
        ("mono_400", (400,), False),  # mono
        ("3channel_400-800-1600", (400, 800, 1600), False),  # three_channel
        ("mono_400", (400,), True),  # mono_large_file
        (
            "3channel_400-800-1600",
            (400, 800, 1600),
            True,
        ),  # three_channel_large_file
    ],
    ids=[
        "mono",
        "three_channel",
        "mono_large_file",
        "three_channel_large_file",
    ],
)
def test_load_raw(file_id, frequencies, large_file):
    """`_load_raw` returns the right source type, parameters and data."""
    filename = "tests/data/test_16KHZ_{}Hz.raw".format(file_id)
    audio_source = _load_raw(
        filename, 16000, 2, len(frequencies), large_file=large_file
    )
    audio_source.open()
    data = audio_source.read(-1)
    audio_source.close()

    expected_class = RawAudioSource if large_file else BufferAudioSource
    assert isinstance(audio_source, expected_class)
    assert audio_source.sampling_rate == 16000
    assert audio_source.sample_width == 2
    assert audio_source.channels == len(frequencies)

    # Rebuild the expected bytes from the reference pure tones.
    mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies]
    fmt = FORMAT[audio_source.sample_width]
    expected = array(fmt, _sample_generator(*mono_channels)).tobytes()
    assert data == expected


@pytest.mark.parametrize(
    "missing_param",
    [
        "sr",  # missing_sampling_rate
        "sw",  # missing_sample_width
        "ch",  # missing_channels
    ],
    ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"],
)
def test_load_raw_missing_audio_param(missing_param):
    """An incomplete parameter set raises before raw data can be loaded."""
    params = AUDIO_PARAMS_SHORT.copy()
    del params[missing_param]
    with pytest.raises(AudioParameterError):
        # `_get_audio_parameters` yields a 3-tuple (see the
        # test_get_audio_parameters_* tests), hence three targets.
        srate, swidth, channels = _get_audio_parameters(params)
        _load_raw("audio", srate, swidth, channels)


@pytest.mark.parametrize(
    "file_id, frequencies, large_file",
    [
        ("mono_400", (400,), False),  # mono
        ("3channel_400-800-1600", (400, 800, 1600), False),  # three_channel
        ("mono_400", (400,), True),  # mono_large_file
        (
            "3channel_400-800-1600",
            (400, 800, 1600),
            True,
        ),  # three_channel_large_file
    ],
    ids=[
        "mono",
        "three_channel",
        "mono_large_file",
        "three_channel_large_file",
    ],
)
def test_load_wave(file_id, frequencies, large_file):
    """`_load_wave` returns the right source type, parameters and data."""
    filename = "tests/data/test_16KHZ_{}Hz.wav".format(file_id)
    audio_source = _load_wave(filename, large_file=large_file)
    audio_source.open()
    data = audio_source.read(-1)
    audio_source.close()

    expected_class = WaveAudioSource if large_file else BufferAudioSource
    assert isinstance(audio_source, expected_class)
    assert audio_source.sampling_rate == 16000
    assert audio_source.sample_width == 2
    assert audio_source.channels == len(frequencies)

    # Rebuild the expected bytes from the reference pure tones.
    mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies]
    fmt = FORMAT[audio_source.sample_width]
    expected = array(fmt, _sample_generator(*mono_channels)).tobytes()
    assert data == expected


@pytest.mark.parametrize(
    "audio_format, channels, function",
    [
        ("ogg", 2, "from_ogg"),  # ogg_default_first_channel
        ("ogg", 1, "from_ogg"),  # ogg_first_channel
        ("ogg", 2, "from_ogg"),  # ogg_second_channel
        ("ogg", 3, "from_ogg"),  # ogg_mix_channels
        ("mp3", 1, "from_mp3"),  # mp3_left_channel
        ("mp3", 2, "from_mp3"),  # mp3_right_channel
        ("mp3", 3, "from_mp3"),  # mp3_mix_channels
        ("flac", 2, "from_file"),  # flac_first_channel
        ("flac", 2, "from_file"),  # flac_second_channel
        ("flv", 1, "from_flv"),  # flv_left_channel
        ("webm", 2, "from_file"),  # webm_right_channel
        ("webm", 4, "from_file"),  # webm_mix_channels
    ],
    ids=[
        "ogg_default_first_channel",
        "ogg_first_channel",
        "ogg_second_channel",
        "ogg_mix_channels",
        "mp3_left_channel",
        "mp3_right_channel",
        "mp3_mix_channels",
        "flac_first_channel",
        "flac_second_channel",
        "flv_left_channel",
        "webm_right_channel",
        "webm_mix_channels",
    ],
)
@patch("auditok.io._WITH_PYDUB", True)
@patch("auditok.io.BufferAudioSource")
def test_load_with_pydub(
    mock_buffer_audio_source, audio_format, channels, function
):
    """`_load_with_pydub` calls the pydub opener matching the format."""
    filename = "audio.{}".format(audio_format)
    segment_mock = Mock()
    segment_mock.sample_width = 2
    segment_mock.channels = channels
    segment_mock._data = b"abcdefgh"
    with patch("auditok.io.AudioSegment.{}".format(function)) as open_func:
        open_func.return_value = segment_mock
        _load_with_pydub(filename, audio_format)
    assert open_func.called


@pytest.mark.parametrize(
    "filename, frequencies",
    [
        ("mono_400Hz.raw", (400,)),  # mono
        ("3channel_400-800-1600Hz.raw", (400, 800, 1600)),  # three_channel
    ],
    ids=["mono", "three_channel"],
)
def test_save_raw(filename, frequencies):
    """`_save_raw` writes a file identical to the reference raw file."""
    filename = "tests/data/test_16KHZ_{}".format(filename)
    sample_width = 2
    fmt = FORMAT[sample_width]
    mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies]
    data = array(fmt, _sample_generator(*mono_channels)).tobytes()
    # The context manager closes (and removes) the temporary file even when
    # the assertion fails.
    with NamedTemporaryFile() as tmpfile:
        _save_raw(data, tmpfile.name)
        assert filecmp.cmp(tmpfile.name, filename, shallow=False)


@pytest.mark.parametrize(
    "filename, frequencies",
    [
        ("mono_400Hz.wav", (400,)),  # mono
        ("3channel_400-800-1600Hz.wav", (400, 800, 1600)),  # three_channel
    ],
    ids=["mono", "three_channel"],
)
def test_save_wave(filename, frequencies):
    """`_save_wave` writes a file identical to the reference wave file."""
    filename = "tests/data/test_16KHZ_{}".format(filename)
    sampling_rate = 16000
    sample_width = 2
    channels = len(frequencies)
    fmt = FORMAT[sample_width]
    mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies]
    data = array(fmt, _sample_generator(*mono_channels)).tobytes()
    # The context manager closes (and removes) the temporary file even when
    # the assertion fails.
    with NamedTemporaryFile() as tmpfile:
        _save_wave(data, tmpfile.name, sampling_rate, sample_width, channels)
        assert filecmp.cmp(tmpfile.name, filename, shallow=False)


@pytest.mark.parametrize(
    "missing_param",
    [
        "sr",  # missing_sampling_rate
        "sw",  # missing_sample_width
        "ch",  # missing_channels
    ],
    ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"],
)
def test_save_wave_missing_audio_param(missing_param):
    """An incomplete parameter set raises before a wave file can be saved."""
    params = AUDIO_PARAMS_SHORT.copy()
    del params[missing_param]
    with pytest.raises(AudioParameterError):
        # `_get_audio_parameters` yields a 3-tuple (see the
        # test_get_audio_parameters_* tests), hence three targets.
        srate, swidth, channels = _get_audio_parameters(params)
        _save_wave(b"\0\0", "audio", srate, swidth, channels)


def test_save_with_pydub():
    """`_save_with_pydub` exports the data through `AudioSegment.export`."""
    with patch("auditok.io.AudioSegment.export") as export:
        # The directory is removed on exit even if the assertion fails.
        with TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, "audio.ogg")
            _save_with_pydub(b"\0\0", filename, "ogg", 16000, 2, 1)
            assert export.called


@pytest.mark.parametrize(
    "filename, audio_format",
    [
        ("audio", "raw"),  # raw_with_audio_format
        ("audio.raw", None),  # raw_with_extension
        ("audio.mp3", "raw"),  # raw_with_audio_format_and_extension
        ("audio", None),  # raw_no_audio_format_nor_extension
    ],
    ids=[
        "raw_with_audio_format",
        "raw_with_extension",
        "raw_with_audio_format_and_extension",
        "raw_no_audio_format_nor_extension",
    ],
)
def test_to_file_raw(filename, audio_format):
    """`to_file` writes raw data identical to the reference raw file."""
    exp_filename = "tests/data/test_16KHZ_mono_400Hz.raw"
    data = PURE_TONE_DICT[400].tobytes()
    # The directory is removed on exit even if the assertion fails.
    with TemporaryDirectory() as tmpdir:
        filename = os.path.join(tmpdir, filename)
        to_file(data, filename, audio_format=audio_format)
        assert filecmp.cmp(filename, exp_filename, shallow=False)


@pytest.mark.parametrize(
    "filename, audio_format",
    [
        ("audio", "wav"),  # wav_with_audio_format
        ("audio.wav", None),  # wav_with_extension
        ("audio.mp3", "wav"),  # wav_with_audio_format_and_extension
        ("audio", "wave"),  # wave_with_audio_format
        ("audio.wave", None),  # wave_with_extension
        ("audio.mp3", "wave"),  # wave_with_audio_format_and_extension
    ],
    ids=[
        "wav_with_audio_format",
        "wav_with_extension",
        "wav_with_audio_format_and_extension",
        "wave_with_audio_format",
        "wave_with_extension",
        "wave_with_audio_format_and_extension",
    ],
)
def test_to_file_wave(filename, audio_format):
    """`to_file` writes wave data identical to the reference wave file."""
    exp_filename = "tests/data/test_16KHZ_mono_400Hz.wav"
    data = PURE_TONE_DICT[400].tobytes()
    # The directory is removed on exit even if the assertion fails.
    with TemporaryDirectory() as tmpdir:
        filename = os.path.join(tmpdir, filename)
        to_file(
            data,
            filename,
            audio_format=audio_format,
            sampling_rate=16000,
            sample_width=2,
            channels=1,
        )
        assert filecmp.cmp(filename, exp_filename, shallow=False)


@pytest.mark.parametrize(
    "missing_param",
    [
        "sr",  # missing_sampling_rate
        "sw",  # missing_sample_width
        "ch",  # missing_channels
    ],
    ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"],
)
def test_to_file_missing_audio_param(missing_param):
    """Saving as wav or mp3 without one audio parameter raises."""
    params = AUDIO_PARAMS_SHORT.copy()
    del params[missing_param]
    with pytest.raises(AudioParameterError):
        to_file(b"\0\0", "audio", audio_format="wav", **params)
    with pytest.raises(AudioParameterError):
        to_file(b"\0\0", "audio", audio_format="mp3", **params)


def test_to_file_no_pydub():
    """Without pydub, saving to mp3 raises `AudioIOError`."""
    with patch("auditok.io._WITH_PYDUB", False):
        with pytest.raises(AudioIOError):
            # Argument order is (data, filename, audio_format), as in the
            # other `to_file` calls of this module.
            # NOTE(review): no audio parameters are passed here, so the error
            # may come from parameter validation rather than from the pydub
            # check -- confirm against `auditok.io.to_file`.
            to_file(b"", "audio", "mp3")


@pytest.mark.parametrize(
    "filename, audio_format",
    [
        ("audio.ogg", None),  # ogg_with_extension
        ("audio", "ogg"),  # ogg_with_audio_format
        ("audio.wav", "ogg"),  # ogg_format_with_wrong_extension
    ],
    ids=[
        "ogg_with_extension",
        "ogg_with_audio_format",
        "ogg_format_with_wrong_extension",
    ],
)
@patch("auditok.io._WITH_PYDUB", True)
def test_to_file_compressed(filename, audio_format):
    """`to_file` exports ogg data through `AudioSegment.export`."""
    with patch("auditok.io.AudioSegment.export") as export:
        # The directory is removed on exit even if the assertion fails.
        with TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, filename)
            to_file(b"\0\0", filename, audio_format, **AUDIO_PARAMS_SHORT)
            assert export.called


@pytest.mark.parametrize(
    "input_source, expected_type, extra_args",
    [
        (
            "tests/data/test_16KHZ_mono_400Hz.wav",
            BufferAudioSource,
            None,
        ),  # string_wave
        (
            "tests/data/test_16KHZ_mono_400Hz.wav",
            WaveAudioSource,
            {"large_file": True},
        ),  # string_wave_large_file
        ("-", StdinAudioSource, None),  # stdin
        (
            "tests/data/test_16KHZ_mono_400Hz.raw",
            BufferAudioSource,
            None,
        ),  # string_raw
        (
            "tests/data/test_16KHZ_mono_400Hz.raw",
            RawAudioSource,
            {"large_file": True},
        ),  # string_raw_large_file
        (b"0" * 8000, BufferAudioSource, None),  # bytes_
    ],
    ids=[
        "string_wave",
        "string_wave_large_file",
        "stdin",
        "string_raw",
        "string_raw_large_file",
        "bytes_",
    ],
)
def test_get_audio_source(input_source, expected_type, extra_args):
    """`get_audio_source` returns the source type matching its input."""
    kwargs = {"sampling_rate": 16000, "sample_width": 2, "channels": 1}
    if extra_args is not None:
        kwargs.update(extra_args)
    audio_source = get_audio_source(input_source, **kwargs)
    assert isinstance(audio_source, expected_type)