comparison tests/test_core.py @ 426:c5b4178aa80f

Refactor tests
author Amine Sehili <amine.sehili@gmail.com>
date Tue, 29 Oct 2024 18:50:28 +0100
parents 14efef6f4bae
children
comparison
equal deleted inserted replaced
425:1b78211b7e07 426:c5b4178aa80f
1 import math 1 import math
2 import os 2 import os
3 from pathlib import Path 3 from pathlib import Path
4 from random import random 4 from random import random
5 from tempfile import TemporaryDirectory 5 from tempfile import TemporaryDirectory
6 from unittest.mock import Mock, patch 6 from unittest import mock
7 from unittest.mock import patch
7 8
8 import numpy as np 9 import numpy as np
9 import pytest 10 import pytest
10 11
11 from auditok import ( 12 from auditok import (
23 _read_offline, 24 _read_offline,
24 ) 25 )
25 from auditok.io import get_audio_source 26 from auditok.io import get_audio_source
26 from auditok.signal import to_array 27 from auditok.signal import to_array
27 from auditok.util import AudioReader 28 from auditok.util import AudioReader
29
30 mock._magics.add("__round__")
28 31
29 32
30 def _make_random_length_regions( 33 def _make_random_length_regions(
31 byte_seq, sampling_rate, sample_width, channels 34 byte_seq, sampling_rate, sample_width, channels
32 ): 35 ):
407 err_msg = "Wrong number of regions after AudioRegion.split, expected: " 410 err_msg = "Wrong number of regions after AudioRegion.split, expected: "
408 err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) 411 err_msg += "{}, found: {}".format(len(expected), len(regions_ar))
409 assert len(regions_ar) == len(expected), err_msg 412 assert len(regions_ar) == len(expected), err_msg
410 413
411 sample_width = 2 414 sample_width = 2
412 for reg, reg_ar, exp in zip(regions, regions_ar, expected, strict=True): 415 for reg, reg_ar, exp in zip(regions, regions_ar, expected):
413 onset, offset = exp 416 onset, offset = exp
414 exp_data = data[onset * sample_width : offset * sample_width] 417 exp_data = data[onset * sample_width : offset * sample_width]
415 assert bytes(reg) == exp_data 418 assert bytes(reg) == exp_data
416 assert reg == reg_ar 419 assert reg == reg_ar
417 420
536 err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) 539 err_msg += "{}, found: {}".format(len(expected), len(regions_ar))
537 assert len(regions_ar) == len(expected), err_msg 540 assert len(regions_ar) == len(expected), err_msg
538 541
539 sample_width = 2 542 sample_width = 2
540 sample_size_bytes = sample_width * channels 543 sample_size_bytes = sample_width * channels
541 for reg, reg_ar, exp in zip(regions, regions_ar, expected, strict=True): 544 for reg, reg_ar, exp in zip(
545 regions,
546 regions_ar,
547 expected,
548 ):
542 onset, offset = exp 549 onset, offset = exp
543 exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes] 550 exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes]
544 assert len(bytes(reg)) == len(exp_data) 551 assert len(bytes(reg)) == len(exp_data)
545 assert reg == reg_ar 552 assert reg == reg_ar
546 553
957 err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) 964 err_msg += "{}, found: {}".format(len(expected), len(regions_ar))
958 assert len(regions_ar) == len(expected), err_msg 965 assert len(regions_ar) == len(expected), err_msg
959 966
960 sample_width = 2 967 sample_width = 2
961 sample_size_bytes = sample_width * channels 968 sample_size_bytes = sample_width * channels
962 for reg, reg_ar, exp in zip(regions, regions_ar, expected, strict=True): 969 for reg, reg_ar, exp in zip(
970 regions,
971 regions_ar,
972 expected,
973 ):
963 onset, offset = exp 974 onset, offset = exp
964 exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes] 975 exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes]
965 assert bytes(reg) == exp_data 976 assert bytes(reg) == exp_data
966 assert reg == reg_ar 977 assert reg == reg_ar
967 978
1005 err_msg = "Wrong number of regions after AudioRegion.split, expected: " 1016 err_msg = "Wrong number of regions after AudioRegion.split, expected: "
1006 err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) 1017 err_msg += "{}, found: {}".format(len(expected), len(regions_ar))
1007 assert len(regions_ar) == len(expected), err_msg 1018 assert len(regions_ar) == len(expected), err_msg
1008 1019
1009 sample_size_bytes = 2 1020 sample_size_bytes = 2
1010 for reg, reg_ar, exp in zip(regions, regions_ar, expected, strict=True): 1021 for reg, reg_ar, exp in zip(
1022 regions,
1023 regions_ar,
1024 expected,
1025 ):
1011 onset, offset = exp 1026 onset, offset = exp
1012 exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes] 1027 exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes]
1013 assert bytes(reg) == exp_data 1028 assert bytes(reg) == exp_data
1014 assert reg == reg_ar 1029 assert reg == reg_ar
1015 1030
1093 expected = [(2, 32), (34, 76)] 1108 expected = [(2, 32), (34, 76)]
1094 sample_width = 2 1109 sample_width = 2
1095 err_msg = "Wrong number of regions after split, expected: " 1110 err_msg = "Wrong number of regions after split, expected: "
1096 err_msg += "{}, found: {}".format(expected, regions) 1111 err_msg += "{}, found: {}".format(expected, regions)
1097 assert len(regions) == len(expected), err_msg 1112 assert len(regions) == len(expected), err_msg
1098 for reg, exp in zip(regions, expected, strict=True): 1113 for reg, exp in zip(
1114 regions,
1115 expected,
1116 ):
1099 onset, offset = exp 1117 onset, offset = exp
1100 exp_data = data[onset * sample_width * 2 : offset * sample_width * 2] 1118 exp_data = data[onset * sample_width * 2 : offset * sample_width * 2]
1101 assert bytes(reg) == exp_data 1119 assert bytes(reg) == exp_data
1102 1120
1103 1121
2155 def test_truediv(data): 2173 def test_truediv(data):
2156 2174
2157 region = AudioRegion(b"".join(data), 80, 1, 1) 2175 region = AudioRegion(b"".join(data), 80, 1, 1)
2158 2176
2159 sub_regions = region / len(data) 2177 sub_regions = region / len(data)
2160 for data_i, region in zip(data, sub_regions, strict=True): 2178 for data_i, region in zip(
2179 data,
2180 sub_regions,
2181 ):
2161 assert len(data_i) == len(bytes(region)) 2182 assert len(data_i) == len(bytes(region))
2162 2183
2163 2184
2164 @pytest.mark.parametrize( 2185 @pytest.mark.parametrize(
2165 "data, sample_width, channels, expected", 2186 "data, sample_width, channels, expected",