amine@274
|
1 import os
|
amine@282
|
2 from unittest.mock import patch, call, Mock
|
amine@274
|
3 from tempfile import TemporaryDirectory
|
amine@400
|
4 import pytest
|
amine@287
|
5 from auditok import AudioRegion, AudioDataSource
|
amine@292
|
6 from auditok.exceptions import AudioEncodingWarning
|
amine@274
|
7 from auditok.cmdline_util import make_logger
|
amine@274
|
8 from auditok.workers import (
|
amine@274
|
9 TokenizerWorker,
|
amine@274
|
10 StreamSaverWorker,
|
amine@274
|
11 RegionSaverWorker,
|
amine@274
|
12 PlayerWorker,
|
amine@274
|
13 CommandLineWorker,
|
amine@274
|
14 PrintWorker,
|
amine@274
|
15 )
|
amine@274
|
16
|
amine@274
|
17
|
@pytest.fixture
def audio_data_source():
    """Provide an ``AudioDataSource`` over the 10 Hz mono raw test file.

    The source is opened with ``sr=10``, ``sw=2``, ``ch=1`` and a block
    duration of 0.1, handed to the requesting test, and closed once that
    test has finished.
    """
    source = AudioDataSource(
        input="tests/data/test_split_10HZ_mono.raw",
        block_dur=0.1,
        sr=10,
        sw=2,
        ch=1,
    )
    try:
        yield source
    finally:
        # Teardown: runs when pytest resumes the fixture after the test.
        source.close()
@pytest.fixture
def expected_detections():
    """Return the ``(start, end)`` pairs the tokenizer tests expect.

    These correspond to the settings shared by the tokenizer tests in this
    module (``min_dur=0.3``, ``max_dur=2``, ``max_silence=0.2``,
    ``drop_trailing_silence=False``, ``strict_min_dur=False``, ``eth=50``)
    applied to ``audio_data_source``. A fresh list is returned on each call.
    """
    detections = (
        (0.2, 1.6),
        (1.7, 3.1),
        (3.4, 5.4),
        (5.4, 7.4),
        (7.4, 7.6),
    )
    return list(detections)
def test_TokenizerWorker(audio_data_source, expected_detections):
    """Check ``TokenizerWorker`` detections and its ``[DET]`` log lines.

    Detections must match ``expected_detections`` and each one must have a
    matching log line. Log lines are compared from character 28 onward,
    which skips the line prefix (presumably the timestamp added by the
    logger's formatter).
    """
    with TemporaryDirectory() as tmpdir:
        file = os.path.join(tmpdir, "file.log")
        logger = make_logger(file=file, name="test_TokenizerWorker")
        tokenizer = TokenizerWorker(
            audio_data_source,
            logger=logger,
            min_dur=0.3,
            max_dur=2,
            max_silence=0.2,
            drop_trailing_silence=False,
            strict_min_dur=False,
            eth=50,
        )
        tokenizer.start_all()
        tokenizer.join()
        # Release the logger's handlers so the log file is flushed and is
        # no longer held open while the temporary directory is deleted.
        for handler in list(logger.handlers):
            handler.close()
            logger.removeHandler(handler)
        with open(file) as fp:
            log_lines = [
                line for line in fp.readlines() if line[28:].startswith("[DET]")
            ]

    log_fmt = (
        "[DET]: Detection {} (start: {:.3f}, end: {:.3f}, duration: {:.3f})"
    )
    assert len(tokenizer.detections) == len(expected_detections)
    # zip() stops at its shortest input. Pin the number of log lines so
    # that a missing line cannot make the loop below silently skip checks.
    assert len(log_lines) == len(expected_detections)
    for i, (det, exp, log_line) in enumerate(
        zip(tokenizer.detections, expected_detections, log_lines), 1
    ):
        start, end = exp
        exp_log_line = log_fmt.format(i, start, end, end - start)
        assert pytest.approx(det.start) == start
        assert pytest.approx(det.end) == end
        assert log_line[28:].strip() == exp_log_line
def test_PlayerWorker(audio_data_source, expected_detections):
    """Check that ``PlayerWorker`` plays and logs every detection.

    A ``Mock`` stands in for the player. Every detection must match
    ``expected_detections`` and have a ``[PLAY]`` log line. Lines are
    matched on their text after the first 28 characters (the line prefix,
    presumably a timestamp), because the tokenizer logs to the same file.
    """
    with TemporaryDirectory() as tmpdir:
        file = os.path.join(tmpdir, "file.log")
        # Use a logger name unique to this test, so its handlers are not
        # shared with another test that requests the same logger name.
        logger = make_logger(file=file, name="test_PlayerWorker")
        player_mock = Mock()
        observers = [PlayerWorker(player_mock, logger=logger)]
        tokenizer = TokenizerWorker(
            audio_data_source,
            logger=logger,
            observers=observers,
            min_dur=0.3,
            max_dur=2,
            max_silence=0.2,
            drop_trailing_silence=False,
            strict_min_dur=False,
            eth=50,
        )
        tokenizer.start_all()
        tokenizer.join()
        tokenizer._observers[0].join()
        # Release the logger's handlers so the log file is flushed and is
        # no longer held open while the temporary directory is deleted.
        for handler in list(logger.handlers):
            handler.close()
            logger.removeHandler(handler)
        with open(file) as fp:
            # The tag follows the line prefix; checking line.startswith()
            # would never match and would leave the loop below empty.
            log_lines = [
                line
                for line in fp.readlines()
                if line[28:].startswith("[PLAY]")
            ]

    assert player_mock.play.called
    assert len(tokenizer.detections) == len(expected_detections)
    # zip() stops at its shortest input, so pin the log line count too.
    assert len(log_lines) == len(expected_detections)
    log_fmt = "[PLAY]: Detection {id} played"
    for i, (det, exp, log_line) in enumerate(
        zip(tokenizer.detections, expected_detections, log_lines), 1
    ):
        start, end = exp
        exp_log_line = log_fmt.format(id=i)
        assert pytest.approx(det.start) == start
        assert pytest.approx(det.end) == end
        assert log_line[28:].strip() == exp_log_line
def test_RegionSaverWorker(audio_data_source, expected_detections):
    """Check that ``RegionSaverWorker`` saves and logs every detection.

    ``AudioRegion.save`` is patched. The test checks:

    * the filename each region is saved under, built from
      ``filename_format``;
    * the detections themselves;
    * one ``[SAVE]`` log line per detection, matched on its text after the
      first 28 characters (the line prefix, presumably a timestamp).
    """
    filename_format = "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav"
    with TemporaryDirectory() as tmpdir:
        file = os.path.join(tmpdir, "file.log")
        logger = make_logger(file=file, name="test_RegionSaverWorker")
        observers = [RegionSaverWorker(filename_format, logger=logger)]
        tokenizer = TokenizerWorker(
            audio_data_source,
            logger=logger,
            observers=observers,
            min_dur=0.3,
            max_dur=2,
            max_silence=0.2,
            drop_trailing_silence=False,
            strict_min_dur=False,
            eth=50,
        )
        # Have the patched save() return the filename it was given. The
        # worker appears to stringify save()'s return value (the previous
        # even-index filtering of mock_calls suggests so), and a plain
        # MagicMock would show up in the log as its repr.
        with patch(
            "auditok.core.AudioRegion.save",
            side_effect=lambda filename, *args, **kwargs: filename,
        ) as patched_save:
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()
        # Release the logger's handlers so the log file is flushed and is
        # no longer held open while the temporary directory is deleted.
        for handler in list(logger.handlers):
            handler.close()
            logger.removeHandler(handler)
        with open(file) as fp:
            log_lines = [
                line
                for line in fp.readlines()
                if line[28:].startswith("[SAVE]")
            ]

    expected_save_calls = [
        call(
            filename_format.format(
                id=i, start=exp[0], end=exp[1], duration=exp[1] - exp[0]
            ),
            None,
        )
        for i, exp in enumerate(expected_detections, 1)
    ]
    # call_args_list records only direct calls to the mock. mock_calls also
    # records calls made on its return values, which is what the old
    # even-index filtering was working around.
    assert patched_save.call_args_list == expected_save_calls
    assert len(tokenizer.detections) == len(expected_detections)
    # zip() stops at its shortest input, so pin the log line count too.
    assert len(log_lines) == len(expected_detections)

    log_fmt = "[SAVE]: Detection {id} saved as '{filename}'"
    for i, (det, exp, log_line) in enumerate(
        zip(tokenizer.detections, expected_detections, log_lines), 1
    ):
        start, end = exp
        expected_filename = filename_format.format(
            id=i, start=start, end=end, duration=end - start
        )
        exp_log_line = log_fmt.format(id=i, filename=expected_filename)
        assert pytest.approx(det.start) == start
        assert pytest.approx(det.end) == end
        assert log_line[28:].strip() == exp_log_line
def test_CommandLineWorker(audio_data_source, expected_detections):
    """Check that ``CommandLineWorker`` runs and logs a command per detection.

    ``os.system`` is patched inside ``auditok.workers``. Each detection
    must trigger exactly one call with the command and produce one
    ``[COMMAND]`` log line. Log lines are matched on their text after the
    first 28 characters (the line prefix, presumably a timestamp).
    """
    command_format = "do nothing with"
    with TemporaryDirectory() as tmpdir:
        file = os.path.join(tmpdir, "file.log")
        logger = make_logger(file=file, name="test_CommandLineWorker")
        observers = [CommandLineWorker(command_format, logger=logger)]
        tokenizer = TokenizerWorker(
            audio_data_source,
            logger=logger,
            observers=observers,
            min_dur=0.3,
            max_dur=2,
            max_silence=0.2,
            drop_trailing_silence=False,
            strict_min_dur=False,
            eth=50,
        )
        with patch("auditok.workers.os.system") as patched_os_system:
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()
        # Release the logger's handlers so the log file is flushed and is
        # no longer held open while the temporary directory is deleted.
        for handler in list(logger.handlers):
            handler.close()
            logger.removeHandler(handler)
        with open(file) as fp:
            log_lines = [
                line
                for line in fp.readlines()
                if line[28:].startswith("[COMMAND]")
            ]

    expected_system_calls = [call(command_format) for _ in expected_detections]
    assert patched_os_system.mock_calls == expected_system_calls
    assert len(tokenizer.detections) == len(expected_detections)
    # zip() stops at its shortest input, so pin the log line count too.
    assert len(log_lines) == len(expected_detections)
    log_fmt = "[COMMAND]: Detection {id} command '{command}'"
    for i, (det, exp, log_line) in enumerate(
        zip(tokenizer.detections, expected_detections, log_lines), 1
    ):
        start, end = exp
        exp_log_line = log_fmt.format(id=i, command=command_format)
        assert pytest.approx(det.start) == start
        assert pytest.approx(det.end) == end
        assert log_line[28:].strip() == exp_log_line
def test_PrintWorker(audio_data_source, expected_detections):
    """Check that ``PrintWorker`` prints one formatted line per detection.

    ``print`` is patched. The recorded calls are compared with lines built
    from ``expected_detections``, with start, end and duration shown to
    three decimals.
    """
    print_worker = PrintWorker(
        print_format="[{id}] {start} {end}, dur: {duration}"
    )
    tokenizer = TokenizerWorker(
        audio_data_source,
        observers=[print_worker],
        min_dur=0.3,
        max_dur=2,
        max_silence=0.2,
        drop_trailing_silence=False,
        strict_min_dur=False,
        eth=50,
    )
    with patch("builtins.print") as patched_print:
        tokenizer.start_all()
        tokenizer.join()
        tokenizer._observers[0].join()

    expected_print_calls = []
    for idx, (start, end) in enumerate(expected_detections, 1):
        printed = "[{}] {:.3f} {:.3f}, dur: {:.3f}".format(
            idx, start, end, end - start
        )
        expected_print_calls.append(call(printed))
    assert patched_print.mock_calls == expected_print_calls

    assert len(tokenizer.detections) == len(expected_detections)
    for detection, (start, end) in zip(
        tokenizer.detections, expected_detections
    ):
        assert pytest.approx(detection.start) == start
        assert pytest.approx(detection.end) == end
def test_StreamSaverWorker_wav(audio_data_source):
    """Stream the test audio through ``StreamSaverWorker`` into a WAV file.

    The saved file must be written at the requested path and must reload
    to the same audio as the raw input file. ``saver.data`` must hold
    exactly those bytes.
    """
    with TemporaryDirectory() as tmpdir:
        expected_filename = os.path.join(tmpdir, "output.wav")
        saver = StreamSaverWorker(audio_data_source, expected_filename)
        saver.start()

        # The saver wraps the data source, so the tokenizer reads through it.
        tokenizer = TokenizerWorker(saver)
        tokenizer.start_all()
        tokenizer.join()
        saver.join()

        output_filename = saver.save_stream()
        source_region = AudioRegion.load(
            "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
        )
        saved_region = AudioRegion.load(output_filename)

        assert output_filename == expected_filename
        assert source_region == saved_region
        assert saver.data == bytes(saved_region)
def test_StreamSaverWorker_raw(audio_data_source):
    """Stream the test audio through ``StreamSaverWorker`` as raw data.

    With ``export_format="raw"`` the output must be written at the requested
    path (no extension). Reloaded with the source's audio parameters, it
    must equal the raw input file, and ``saver.data`` must hold exactly
    those bytes.
    """
    with TemporaryDirectory() as tmpdir:
        expected_filename = os.path.join(tmpdir, "output")
        saver = StreamSaverWorker(
            audio_data_source, expected_filename, export_format="raw"
        )
        saver.start()

        tokenizer = TokenizerWorker(saver)
        tokenizer.start_all()
        tokenizer.join()
        saver.join()

        output_filename = saver.save_stream()
        params = dict(sr=10, sw=2, ch=1)
        source_region = AudioRegion.load(
            "tests/data/test_split_10HZ_mono.raw", **params
        )
        # Raw output has no header, so the format and parameters must be
        # given explicitly when loading it back.
        saved_region = AudioRegion.load(
            output_filename, audio_format="raw", **params
        )

        assert output_filename == expected_filename
        assert source_region == saved_region
        assert saver.data == bytes(saved_region)
def test_StreamSaverWorker_encode_audio(audio_data_source):
    """Check the fallback when no external encoder can produce the format.

    ``auditok.workers._run_subprocess`` is patched to always report
    failure. Saving to ``.ogg`` must try ffmpeg, then avconv, then sox, and
    raise ``AudioEncodingWarning`` naming the ``.wav`` file that was written
    instead. The streamed data must still equal the raw input file.
    """
    with TemporaryDirectory() as tmpdir:
        with patch("auditok.workers._run_subprocess") as patch_rsp:
            # Presumably (returncode, stdout, stderr); a return code of 1
            # simulates every encoder invocation failing.
            patch_rsp.return_value = (1, None, None)
            expected_filename = os.path.join(tmpdir, "output.ogg")
            tmp_expected_filename = expected_filename + ".wav"

            saver = StreamSaverWorker(audio_data_source, expected_filename)
            saver.start()
            tokenizer = TokenizerWorker(saver)
            tokenizer.start_all()
            tokenizer.join()
            saver.join()

            with pytest.raises(AudioEncodingWarning) as rt_warn:
                saver.save_stream()

            warn_msg = (
                "Couldn't save audio data in the desired format "
                "'ogg'. Either none of 'ffmpeg', 'avconv' or 'sox' "
                "is installed or this format is not recognized.\n"
                "Audio file was saved as '{}'"
            )
            assert warn_msg.format(tmp_expected_filename) == str(rt_warn.value)

            # ffmpeg and avconv share one argument list; sox has its own.
            converter_args = [
                "-y",
                "-f",
                "wav",
                "-i",
                tmp_expected_filename,
                "-f",
                "ogg",
                expected_filename,
            ]
            sox_args = [
                "sox",
                "-t",
                "wav",
                tmp_expected_filename,
                expected_filename,
            ]
            expected_calls = [
                call([tool] + converter_args) for tool in ("ffmpeg", "avconv")
            ]
            expected_calls.append(call(sox_args))
            assert patch_rsp.mock_calls == expected_calls

            source_region = AudioRegion.load(
                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
            )
            assert saver._exported
            assert saver.data == bytes(source_region)