Mercurial > hg > auditok
comparison tests/test_signal.py @ 400:323d59b404a2
Use pytest instead of genty
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Sat, 25 May 2024 21:54:13 +0200 |
parents | d653e3f58f3c |
children | 996948ada980 |
comparison
equal
deleted
inserted
replaced
399:08f893725d23 | 400:323d59b404a2 |
---|---|
1 import unittest | 1 import pytest |
2 from unittest import TestCase | |
3 from array import array as array_ | 2 from array import array as array_ |
4 from genty import genty, genty_dataset | |
5 import numpy as np | 3 import numpy as np |
6 from auditok import signal as signal_ | 4 from auditok import signal as signal_ |
7 from auditok import signal_numpy | 5 from auditok import signal_numpy |
8 | 6 |
9 | 7 |
10 @genty | 8 @pytest.fixture |
11 class TestSignal(TestCase): | 9 def setup_data(): |
12 def setUp(self): | 10 return b"012345679ABC" |
13 self.data = b"012345679ABC" | 11 |
14 self.numpy_fmt = {"b": np.int8, "h": np.int16, "i": np.int32} | 12 |
15 | 13 @pytest.fixture |
16 @genty_dataset( | 14 def numpy_fmt(): |
17 int8_mono=(1, [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]]), | 15 return {"b": np.int8, "h": np.int16, "i": np.int32} |
18 int16_mono=(2, [[12592, 13106, 13620, 14134, 16697, 17218]]), | 16 |
19 int32_mono=(4, [[858927408, 926299444, 1128415545]]), | 17 |
20 int8_stereo=(1, [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]]), | 18 @pytest.mark.parametrize( |
21 int16_stereo=(2, [[12592, 13620, 16697], [13106, 14134, 17218]]), | 19 "sample_width, expected", |
22 int32_3channel=(4, [[858927408], [926299444], [1128415545]]), | 20 [ |
23 ) | 21 (1, [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]]), # int8_mono |
24 def test_to_array(self, sample_width, expected): | 22 (2, [[12592, 13106, 13620, 14134, 16697, 17218]]), # int16_mono |
25 channels = len(expected) | 23 (4, [[858927408, 926299444, 1128415545]]), # int32_mono |
26 expected = [ | 24 ( |
27 array_(signal_.FORMAT[sample_width], xi) for xi in expected | 25 1, |
28 ] | 26 [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]], |
29 result = signal_.to_array(self.data, sample_width, channels) | 27 ), # int8_stereo |
30 result_numpy = signal_numpy.to_array(self.data, sample_width, channels) | 28 (2, [[12592, 13620, 16697], [13106, 14134, 17218]]), # int16_stereo |
31 self.assertEqual(result, expected) | 29 (4, [[858927408], [926299444], [1128415545]]), # int32_3channel |
32 self.assertTrue((result_numpy == np.asarray(expected)).all()) | 30 ], |
33 self.assertEqual(result_numpy.dtype, np.float64) | 31 ids=[ |
34 | 32 "int8_mono", |
35 @genty_dataset( | 33 "int16_mono", |
36 int8_1channel_select_0=( | 34 "int32_mono", |
35 "int8_stereo", | |
36 "int16_stereo", | |
37 "int32_3channel", | |
38 ], | |
39 ) | |
40 def test_to_array(setup_data, sample_width, expected): | |
41 data = setup_data | |
42 channels = len(expected) | |
43 expected = [array_(signal_.FORMAT[sample_width], xi) for xi in expected] | |
44 result = signal_.to_array(data, sample_width, channels) | |
45 result_numpy = signal_numpy.to_array(data, sample_width, channels) | |
46 assert result == expected | |
47 assert (result_numpy == np.asarray(expected)).all() | |
48 assert result_numpy.dtype == np.float64 | |
49 | |
50 | |
51 @pytest.mark.parametrize( | |
52 "fmt, channels, selected, expected", | |
53 [ | |
54 ( | |
37 "b", | 55 "b", |
38 1, | 56 1, |
39 0, | 57 0, |
40 [48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67], | 58 [48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67], |
41 ), | 59 ), # int8_1channel_select_0 |
42 int8_2channel_select_0=("b", 2, 0, [48, 50, 52, 54, 57, 66]), | 60 ("b", 2, 0, [48, 50, 52, 54, 57, 66]), # int8_2channel_select_0 |
43 int8_3channel_select_0=("b", 3, 0, [48, 51, 54, 65]), | 61 ("b", 3, 0, [48, 51, 54, 65]), # int8_3channel_select_0 |
44 int8_3channel_select_1=("b", 3, 1, [49, 52, 55, 66]), | 62 ("b", 3, 1, [49, 52, 55, 66]), # int8_3channel_select_1 |
45 int8_3channel_select_2=("b", 3, 2, [50, 53, 57, 67]), | 63 ("b", 3, 2, [50, 53, 57, 67]), # int8_3channel_select_2 |
46 int8_4channel_select_0=("b", 4, 0, [48, 52, 57]), | 64 ("b", 4, 0, [48, 52, 57]), # int8_4channel_select_0 |
47 int16_1channel_select_0=( | 65 ( |
48 "h", | 66 "h", |
49 1, | 67 1, |
50 0, | 68 0, |
51 [12592, 13106, 13620, 14134, 16697, 17218], | 69 [12592, 13106, 13620, 14134, 16697, 17218], |
52 ), | 70 ), # int16_1channel_select_0 |
53 int16_2channel_select_0=("h", 2, 0, [12592, 13620, 16697]), | 71 ("h", 2, 0, [12592, 13620, 16697]), # int16_2channel_select_0 |
54 int16_2channel_select_1=("h", 2, 1, [13106, 14134, 17218]), | 72 ("h", 2, 1, [13106, 14134, 17218]), # int16_2channel_select_1 |
55 int16_3channel_select_0=("h", 3, 0, [12592, 14134]), | 73 ("h", 3, 0, [12592, 14134]), # int16_3channel_select_0 |
56 int16_3channel_select_1=("h", 3, 1, [13106, 16697]), | 74 ("h", 3, 1, [13106, 16697]), # int16_3channel_select_1 |
57 int16_3channel_select_2=("h", 3, 2, [13620, 17218]), | 75 ("h", 3, 2, [13620, 17218]), # int16_3channel_select_2 |
58 int32_1channel_select_0=( | 76 ( |
59 "i", | 77 "i", |
60 1, | 78 1, |
61 0, | 79 0, |
62 [858927408, 926299444, 1128415545], | 80 [858927408, 926299444, 1128415545], |
63 ), | 81 ), # int32_1channel_select_0 |
64 int32_3channel_select_0=("i", 3, 0, [858927408]), | 82 ("i", 3, 0, [858927408]), # int32_3channel_select_0 |
65 int32_3channel_select_1=("i", 3, 1, [926299444]), | 83 ("i", 3, 1, [926299444]), # int32_3channel_select_1 |
66 int32_3channel_select_2=("i", 3, 2, [1128415545]), | 84 ("i", 3, 2, [1128415545]), # int32_3channel_select_2 |
67 ) | 85 ], |
68 def test_extract_single_channel(self, fmt, channels, selected, expected): | 86 ids=[ |
69 result = signal_.extract_single_channel( | 87 "int8_1channel_select_0", |
70 self.data, fmt, channels, selected | 88 "int8_2channel_select_0", |
71 ) | 89 "int8_3channel_select_0", |
72 expected = array_(fmt, expected) | 90 "int8_3channel_select_1", |
73 expected_numpy_fmt = self.numpy_fmt[fmt] | 91 "int8_3channel_select_2", |
74 self.assertEqual(result, expected) | 92 "int8_4channel_select_0", |
75 result_numpy = signal_numpy.extract_single_channel( | 93 "int16_1channel_select_0", |
76 self.data, self.numpy_fmt[fmt], channels, selected | 94 "int16_2channel_select_0", |
77 ) | 95 "int16_2channel_select_1", |
78 self.assertTrue(all(result_numpy == expected)) | 96 "int16_3channel_select_0", |
79 self.assertEqual(result_numpy.dtype, expected_numpy_fmt) | 97 "int16_3channel_select_1", |
80 | 98 "int16_3channel_select_2", |
81 @genty_dataset( | 99 "int32_1channel_select_0", |
82 int8_2channel=("b", 2, [48, 50, 52, 54, 61, 66]), | 100 "int32_3channel_select_0", |
83 int8_4channel=("b", 4, [50, 54, 64]), | 101 "int32_3channel_select_1", |
84 int16_1channel=("h", 1, [12592, 13106, 13620, 14134, 16697, 17218]), | 102 "int32_3channel_select_2", |
85 int16_2channel=("h", 2, [12849, 13877, 16958]), | 103 ], |
86 int32_3channel=("i", 3, [971214132]), | 104 ) |
87 ) | 105 def test_extract_single_channel( |
88 def test_compute_average_channel(self, fmt, channels, expected): | 106 setup_data, numpy_fmt, fmt, channels, selected, expected |
89 result = signal_.compute_average_channel(self.data, fmt, channels) | 107 ): |
90 expected = array_(fmt, expected) | 108 data = setup_data |
91 expected_numpy_fmt = self.numpy_fmt[fmt] | 109 result = signal_.extract_single_channel(data, fmt, channels, selected) |
92 self.assertEqual(result, expected) | 110 expected = array_(fmt, expected) |
93 result_numpy = signal_numpy.compute_average_channel( | 111 expected_numpy_fmt = numpy_fmt[fmt] |
94 self.data, self.numpy_fmt[fmt], channels | 112 assert result == expected |
95 ) | 113 result_numpy = signal_numpy.extract_single_channel( |
96 self.assertTrue(all(result_numpy == expected)) | 114 data, numpy_fmt[fmt], channels, selected |
97 self.assertEqual(result_numpy.dtype, expected_numpy_fmt) | 115 ) |
98 | 116 assert all(result_numpy == expected) |
99 @genty_dataset( | 117 assert result_numpy.dtype == expected_numpy_fmt |
100 int8_2channel=(1, [48, 50, 52, 54, 61, 66]), | 118 |
101 int16_2channel=(2, [12849, 13877, 16957]), | 119 |
102 ) | 120 @pytest.mark.parametrize( |
103 def test_compute_average_channel_stereo(self, sample_width, expected): | 121 "fmt, channels, expected", |
104 result = signal_.compute_average_channel_stereo( | 122 [ |
105 self.data, sample_width | 123 ("b", 2, [48, 50, 52, 54, 61, 66]), # int8_2channel |
106 ) | 124 ("b", 4, [50, 54, 64]), # int8_4channel |
107 fmt = signal_.FORMAT[sample_width] | 125 ("h", 1, [12592, 13106, 13620, 14134, 16697, 17218]), # int16_1channel |
108 expected = array_(fmt, expected) | 126 ("h", 2, [12849, 13877, 16958]), # int16_2channel |
109 self.assertEqual(result, expected) | 127 ("i", 3, [971214132]), # int32_3channel |
110 | 128 ], |
111 @genty_dataset( | 129 ids=[ |
112 int8_1channel=( | 130 "int8_2channel", |
131 "int8_4channel", | |
132 "int16_1channel", | |
133 "int16_2channel", | |
134 "int32_3channel", | |
135 ], | |
136 ) | |
137 def test_compute_average_channel( | |
138 setup_data, numpy_fmt, fmt, channels, expected | |
139 ): | |
140 data = setup_data | |
141 result = signal_.compute_average_channel(data, fmt, channels) | |
142 expected = array_(fmt, expected) | |
143 expected_numpy_fmt = numpy_fmt[fmt] | |
144 assert result == expected | |
145 result_numpy = signal_numpy.compute_average_channel( | |
146 data, numpy_fmt[fmt], channels | |
147 ) | |
148 assert all(result_numpy == expected) | |
149 assert result_numpy.dtype == expected_numpy_fmt | |
150 | |
151 | |
152 @pytest.mark.parametrize( | |
153 "sample_width, expected", | |
154 [ | |
155 (1, [48, 50, 52, 54, 61, 66]), # int8_2channel | |
156 (2, [12849, 13877, 16957]), # int16_2channel | |
157 ], | |
158 ids=["int8_2channel", "int16_2channel"], | |
159 ) | |
160 def test_compute_average_channel_stereo(setup_data, sample_width, expected): | |
161 data = setup_data | |
162 result = signal_.compute_average_channel_stereo(data, sample_width) | |
163 fmt = signal_.FORMAT[sample_width] | |
164 expected = array_(fmt, expected) | |
165 assert result == expected | |
166 | |
167 | |
168 @pytest.mark.parametrize( | |
169 "fmt, channels, expected", | |
170 [ | |
171 ( | |
113 "b", | 172 "b", |
114 1, | 173 1, |
115 [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]], | 174 [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]], |
116 ), | 175 ), # int8_1channel |
117 int8_2channel=( | 176 ( |
118 "b", | 177 "b", |
119 2, | 178 2, |
120 [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]], | 179 [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]], |
121 ), | 180 ), # int8_2channel |
122 int8_4channel=( | 181 ( |
123 "b", | 182 "b", |
124 4, | 183 4, |
125 [[48, 52, 57], [49, 53, 65], [50, 54, 66], [51, 55, 67]], | 184 [[48, 52, 57], [49, 53, 65], [50, 54, 66], [51, 55, 67]], |
126 ), | 185 ), # int8_4channel |
127 int16_2channel=( | 186 ( |
128 "h", | 187 "h", |
129 2, | 188 2, |
130 [[12592, 13620, 16697], [13106, 14134, 17218]], | 189 [[12592, 13620, 16697], [13106, 14134, 17218]], |
131 ), | 190 ), # int16_2channel |
132 int32_3channel=("i", 3, [[858927408], [926299444], [1128415545]]), | 191 ("i", 3, [[858927408], [926299444], [1128415545]]), # int32_3channel |
133 ) | 192 ], |
134 def test_separate_channels(self, fmt, channels, expected): | 193 ids=[ |
135 result = signal_.separate_channels(self.data, fmt, channels) | 194 "int8_1channel", |
136 expected = [array_(fmt, exp) for exp in expected] | 195 "int8_2channel", |
137 expected_numpy_fmt = self.numpy_fmt[fmt] | 196 "int8_4channel", |
138 self.assertEqual(result, expected) | 197 "int16_2channel", |
139 | 198 "int32_3channel", |
140 result_numpy = signal_numpy.separate_channels( | 199 ], |
141 self.data, self.numpy_fmt[fmt], channels | 200 ) |
142 ) | 201 def test_separate_channels(setup_data, numpy_fmt, fmt, channels, expected): |
143 self.assertTrue((result_numpy == expected).all()) | 202 data = setup_data |
144 self.assertEqual(result_numpy.dtype, expected_numpy_fmt) | 203 result = signal_.separate_channels(data, fmt, channels) |
145 | 204 expected = [array_(fmt, exp) for exp in expected] |
146 @genty_dataset( | 205 expected_numpy_fmt = numpy_fmt[fmt] |
147 simple=([300, 320, 400, 600], 2, 52.50624901923348), | 206 assert result == expected |
148 zero=([0], 2, -200), | 207 result_numpy = signal_numpy.separate_channels( |
149 zeros=([0, 0, 0], 2, -200), | 208 data, numpy_fmt[fmt], channels |
150 ) | 209 ) |
151 def test_calculate_energy_single_channel(self, x, sample_width, expected): | 210 assert (result_numpy == expected).all() |
152 x = array_(signal_.FORMAT[sample_width], x) | 211 assert result_numpy.dtype == expected_numpy_fmt |
153 energy = signal_.calculate_energy_single_channel(x, sample_width) | 212 |
154 self.assertEqual(energy, expected) | 213 |
155 energy = signal_numpy.calculate_energy_single_channel(x, sample_width) | 214 @pytest.mark.parametrize( |
156 self.assertEqual(energy, expected) | 215 "x, sample_width, expected", |
157 | 216 [ |
158 @genty_dataset( | 217 ([300, 320, 400, 600], 2, 52.50624901923348), # simple |
159 min_=( | 218 ([0], 2, -200), # zero |
219 ([0, 0, 0], 2, -200), # zeros | |
220 ], | |
221 ids=["simple", "zero", "zeros"], | |
222 ) | |
223 def test_calculate_energy_single_channel(x, sample_width, expected): | |
224 x = array_(signal_.FORMAT[sample_width], x) | |
225 energy = signal_.calculate_energy_single_channel(x, sample_width) | |
226 assert energy == expected | |
227 energy = signal_numpy.calculate_energy_single_channel(x, sample_width) | |
228 assert energy == expected | |
229 | |
230 | |
231 @pytest.mark.parametrize( | |
232 "x, sample_width, aggregation_fn, expected", | |
233 [ | |
234 ( | |
160 [[300, 320, 400, 600], [150, 160, 200, 300]], | 235 [[300, 320, 400, 600], [150, 160, 200, 300]], |
161 2, | 236 2, |
162 min, | 237 min, |
163 46.485649105953854, | 238 46.485649105953854, |
164 ), | 239 ), # min_ |
165 max_=( | 240 ( |
166 [[300, 320, 400, 600], [150, 160, 200, 300]], | 241 [[300, 320, 400, 600], [150, 160, 200, 300]], |
167 2, | 242 2, |
168 max, | 243 max, |
169 52.50624901923348, | 244 52.50624901923348, |
170 ), | 245 ), # max_ |
171 ) | 246 ], |
172 def test_calculate_energy_multichannel( | 247 ids=["min_", "max_"], |
173 self, x, sample_width, aggregation_fn, expected | 248 ) |
174 ): | 249 def test_calculate_energy_multichannel( |
175 x = [array_(signal_.FORMAT[sample_width], xi) for xi in x] | 250 x, sample_width, aggregation_fn, expected |
176 energy = signal_.calculate_energy_multichannel( | 251 ): |
177 x, sample_width, aggregation_fn | 252 x = [array_(signal_.FORMAT[sample_width], xi) for xi in x] |
178 ) | 253 energy = signal_.calculate_energy_multichannel( |
179 self.assertEqual(energy, expected) | 254 x, sample_width, aggregation_fn |
180 | 255 ) |
181 energy = signal_numpy.calculate_energy_multichannel( | 256 assert energy == expected |
182 x, sample_width, aggregation_fn | 257 energy = signal_numpy.calculate_energy_multichannel( |
183 ) | 258 x, sample_width, aggregation_fn |
184 self.assertEqual(energy, expected) | 259 ) |
185 | 260 assert energy == expected |
186 | |
187 if __name__ == "__main__": | |
188 unittest.main() |