amine@274
|
1 import os
|
amine@274
|
2 from unittest import TestCase
|
amine@282
|
3 from unittest.mock import patch, call, Mock
|
amine@274
|
4 from tempfile import TemporaryDirectory
|
amine@274
|
5 from genty import genty, genty_dataset
|
amine@287
|
6 from auditok import AudioRegion, AudioDataSource
|
amine@292
|
7 from auditok.exceptions import AudioEncodingWarning
|
amine@274
|
8 from auditok.cmdline_util import make_logger
|
amine@274
|
9 from auditok.workers import (
|
amine@274
|
10 TokenizerWorker,
|
amine@274
|
11 StreamSaverWorker,
|
amine@274
|
12 RegionSaverWorker,
|
amine@274
|
13 PlayerWorker,
|
amine@274
|
14 CommandLineWorker,
|
amine@274
|
15 PrintWorker,
|
amine@274
|
16 )
|
amine@274
|
17
|
amine@274
|
18
|
amine@274
|
@genty
class TestWorkers(TestCase):
    """Tests for the auditok workers driven by a TokenizerWorker.

    All tokenizing tests read the same 10 Hz mono raw fixture with the
    same detection parameters (TOKENIZER_PARAMS), so they share the
    detections listed in ``self.expected``.
    """

    # Detection parameters used by every test that tokenizes
    # ``self.reader`` through ``_make_tokenizer``.
    TOKENIZER_PARAMS = {
        "min_dur": 0.3,
        "max_dur": 2,
        "max_silence": 0.2,
        "drop_trailing_silence": False,
        "strict_min_dur": False,
        "eth": 50,
    }

    # Number of leading characters (the timestamp part) stripped from each
    # log-file line before comparing it with an expected message.
    LOG_PREFIX_LEN = 28

    def setUp(self):
        self.reader = AudioDataSource(
            input="tests/data/test_split_10HZ_mono.raw",
            block_dur=0.1,
            sr=10,
            sw=2,
            ch=1,
        )
        # (start, end) of each detection, in seconds, that TOKENIZER_PARAMS
        # yield on the fixture above.
        self.expected = [
            (0.2, 1.6),
            (1.7, 3.1),
            (3.4, 5.4),
            (5.4, 7.4),
            (7.4, 7.6),
        ]

    def tearDown(self):
        self.reader.close()

    def _make_tokenizer(self, **kwargs):
        """Return a TokenizerWorker reading ``self.reader``.

        Uses TOKENIZER_PARAMS; extra keyword arguments (``logger``,
        ``observers``, ...) are passed on to TokenizerWorker unchanged.
        """
        params = dict(self.TOKENIZER_PARAMS, **kwargs)
        return TokenizerWorker(self.reader, **params)

    def _assert_detections(self, detections):
        """Assert that ``detections`` match ``self.expected`` one-to-one."""
        self.assertEqual(len(detections), len(self.expected))
        for det, (start, end) in zip(detections, self.expected):
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)

    @classmethod
    def _read_log_messages(cls, file, prefix=None):
        """Return the lines of log ``file`` without timestamp and newline.

        When ``prefix`` is given, only raw lines starting with it are kept.

        NOTE(review): every raw line is expected to begin with a
        LOG_PREFIX_LEN-character timestamp, so a message tag such as
        "[PLAY]" is not at the very start of the line and this filter may
        keep nothing, which makes the callers' per-message checks vacuous.
        Filtering on the message part looks intended; confirm against the
        actual worker log output before changing it.
        """
        with open(file) as fp:
            lines = fp.readlines()
        return [
            line[cls.LOG_PREFIX_LEN :].strip()
            for line in lines
            if prefix is None or line.startswith(prefix)
        ]

    def test_TokenizerWorker(self):
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_TokenizerWorker")
            tokenizer = self._make_tokenizer(logger=logger)
            tokenizer.start_all()
            tokenizer.join()
            log_messages = self._read_log_messages(file)

            self._assert_detections(tokenizer.detections)

            log_fmt = "[DET]: Detection {} (start: {:.3f}, "
            log_fmt += "end: {:.3f}, duration: {:.3f})"
            for i, ((start, end), message) in enumerate(
                zip(self.expected, log_messages), 1
            ):
                self.assertEqual(
                    message, log_fmt.format(i, start, end, end - start)
                )

    def test_PlayerWorker(self):
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            # Use a logger name unique to this test: sharing a name with
            # test_RegionSaverWorker would presumably share that logger's
            # handlers (verify against make_logger).
            logger = make_logger(file=file, name="test_PlayerWorker")
            player_mock = Mock()
            observers = [PlayerWorker(player_mock, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()
            log_messages = self._read_log_messages(file, "[PLAY]")

            self.assertTrue(player_mock.play.called)
            self._assert_detections(tokenizer.detections)

            log_fmt = "[PLAY]: Detection {id} played"
            for i, (_, message) in enumerate(
                zip(self.expected, log_messages), 1
            ):
                self.assertEqual(message, log_fmt.format(id=i))

    def test_RegionSaverWorker(self):
        filename_format = (
            "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav"
        )
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_RegionSaverWorker")
            observers = [RegionSaverWorker(filename_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.core.AudioRegion.save") as patched_save:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
                log_messages = self._read_log_messages(file, "[SAVE]")

            expected_filenames = [
                filename_format.format(
                    id=i, start=start, end=end, duration=end - start
                )
                for i, (start, end) in enumerate(self.expected, 1)
            ]
            # AudioRegion.save must be called once per detection with
            # (filename, None). call_args_list keeps only direct calls to
            # the mock, whereas mock_calls also records calls made on the
            # objects it returns, which would have to be skipped by hand.
            self.assertEqual(
                patched_save.call_args_list,
                [call(filename, None) for filename in expected_filenames],
            )
            self._assert_detections(tokenizer.detections)

            log_fmt = "[SAVE]: Detection {id} saved as '{filename}'"
            for i, (filename, message) in enumerate(
                zip(expected_filenames, log_messages), 1
            ):
                # log_fmt uses named fields, so they must be filled by
                # keyword; positional arguments raise KeyError('id').
                self.assertEqual(
                    message, log_fmt.format(id=i, filename=filename)
                )

    def test_CommandLineWorker(self):
        command_format = "do nothing with"
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_CommandLineWorker")
            observers = [CommandLineWorker(command_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.workers.os.system") as patched_os_system:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
                log_messages = self._read_log_messages(file, "[COMMAND]")

            # command_format has no placeholders, so every detection runs
            # the exact same command.
            expected_calls = [call(command_format) for _ in self.expected]
            self.assertEqual(patched_os_system.mock_calls, expected_calls)
            self._assert_detections(tokenizer.detections)

            log_fmt = "[COMMAND]: Detection {id} command '{command}'"
            for i, (_, message) in enumerate(
                zip(self.expected, log_messages), 1
            ):
                # Named fields must be filled by keyword (see log_fmt).
                self.assertEqual(
                    message, log_fmt.format(id=i, command=command_format)
                )

    def test_PrintWorker(self):
        observers = [
            PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")
        ]
        tokenizer = self._make_tokenizer(observers=observers)
        with patch("builtins.print") as patched_print:
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()

        # print_format carries no format spec, yet the printed values are
        # expected with three decimals.
        expected_print_calls = [
            call(
                "[{}] {:.3f} {:.3f}, dur: {:.3f}".format(
                    i, start, end, end - start
                )
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        self.assertEqual(patched_print.mock_calls, expected_print_calls)
        self._assert_detections(tokenizer.detections)

    def test_StreamSaverWorker_wav(self):
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output.wav")
            saver = StreamSaverWorker(self.reader, expected_filename)
            saver.start()

            # The saver wraps self.reader and the tokenizer reads through it.
            tokenizer = TokenizerWorker(saver)
            tokenizer.start_all()
            tokenizer.join()
            saver.join()

            output_filename = saver.save_stream()
            input_region = AudioRegion.load(
                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
            )
            output_region = AudioRegion.load(output_filename)
            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(input_region, output_region)
            self.assertEqual(saver.data, bytes(output_region))

    def test_StreamSaverWorker_raw(self):
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output")
            saver = StreamSaverWorker(
                self.reader, expected_filename, export_format="raw"
            )
            saver.start()
            tokenizer = TokenizerWorker(saver)
            tokenizer.start_all()
            tokenizer.join()
            saver.join()

            output_filename = saver.save_stream()
            input_region = AudioRegion.load(
                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
            )
            output_region = AudioRegion.load(
                output_filename, sr=10, sw=2, ch=1, audio_format="raw"
            )
            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(input_region, output_region)
            self.assertEqual(saver.data, bytes(output_region))

    def test_StreamSaverWorker_encode_audio(self):
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output.ogg")
            # WAV file the saver falls back to (named in the warning below).
            tmp_expected_filename = expected_filename + ".wav"
            with patch("auditok.workers._run_subprocess") as patch_rsp:
                # Make every encoder invocation fail; the tuple is presumably
                # (return code, stdout, stderr).
                patch_rsp.return_value = (1, None, None)
                saver = StreamSaverWorker(self.reader, expected_filename)
                saver.start()
                tokenizer = TokenizerWorker(saver)
                tokenizer.start_all()
                tokenizer.join()
                saver.join()
                with self.assertRaises(AudioEncodingWarning) as rt_warn:
                    saver.save_stream()

            warn_msg = "Couldn't save audio data in the desired format "
            warn_msg += "'ogg'. Either none of 'ffmpeg', 'avconv' or 'sox' "
            warn_msg += "is installed or this format is not recognized.\n"
            warn_msg += "Audio file was saved as '{}'"
            self.assertEqual(
                warn_msg.format(tmp_expected_filename), str(rt_warn.exception)
            )

            # Each encoder is tried in turn: ffmpeg, avconv, then sox.
            ffmpeg_avconv_args = [
                "-y",
                "-f",
                "wav",
                "-i",
                tmp_expected_filename,
                "-f",
                "ogg",
                expected_filename,
            ]
            expected_calls = [
                call(["ffmpeg"] + ffmpeg_avconv_args),
                call(["avconv"] + ffmpeg_avconv_args),
                call(
                    [
                        "sox",
                        "-t",
                        "wav",
                        tmp_expected_filename,
                        expected_filename,
                    ]
                ),
            ]
            self.assertEqual(patch_rsp.mock_calls, expected_calls)

            input_region = AudioRegion.load(
                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
            )
            self.assertTrue(saver._exported)
            self.assertEqual(saver.data, bytes(input_region))
|