amine@274
|
1 import os
|
amine@337
|
2 import unittest
|
amine@274
|
3 from unittest import TestCase
|
amine@282
|
4 from unittest.mock import patch, call, Mock
|
amine@274
|
5 from tempfile import TemporaryDirectory
|
amine@274
|
6 from genty import genty, genty_dataset
|
amine@287
|
7 from auditok import AudioRegion, AudioDataSource
|
amine@292
|
8 from auditok.exceptions import AudioEncodingWarning
|
amine@274
|
9 from auditok.cmdline_util import make_logger
|
amine@274
|
10 from auditok.workers import (
|
amine@274
|
11 TokenizerWorker,
|
amine@274
|
12 StreamSaverWorker,
|
amine@274
|
13 RegionSaverWorker,
|
amine@274
|
14 PlayerWorker,
|
amine@274
|
15 CommandLineWorker,
|
amine@274
|
16 PrintWorker,
|
amine@274
|
17 )
|
amine@274
|
18
|
amine@274
|
19
|
amine@274
|
@genty
class TestWorkers(TestCase):
    """Tests for the worker classes exposed by ``auditok.workers``.

    Each test reads ``tests/data/test_split_10HZ_mono.raw`` through
    ``self.reader`` and checks what a worker (or a chain of workers)
    produces: detections, log lines, mock calls or saved audio data.
    """

    def setUp(self):
        """Open the test audio stream and define the expected detections."""
        self.reader = AudioDataSource(
            input="tests/data/test_split_10HZ_mono.raw",
            block_dur=0.1,
            sr=10,
            sw=2,
            ch=1,
        )
        # (start, end) pair of every detection expected from a tokenizer
        # built with the parameters used by `_make_tokenizer`.
        self.expected = [
            (0.2, 1.6),
            (1.7, 3.1),
            (3.4, 5.4),
            (5.4, 7.4),
            (7.4, 7.6),
        ]

    def tearDown(self):
        """Close the audio stream opened by `setUp`."""
        self.reader.close()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _make_tokenizer(self, **kwargs):
        """Return a `TokenizerWorker` reading from `self.reader`.

        The detection parameters are the ones every observer test relies
        on to obtain `self.expected`. Extra keyword arguments (such as
        `logger` or `observers`) are forwarded to `TokenizerWorker`.
        """
        return TokenizerWorker(
            self.reader,
            min_dur=0.3,
            max_dur=2,
            max_silence=0.2,
            drop_trailing_silence=False,
            strict_min_dur=False,
            eth=50,
            **kwargs
        )

    def _read_log_lines(self, path, prefix=""):
        """Return the lines of the log file `path` starting with `prefix`.

        With the default empty prefix every line is returned.
        """
        # NOTE(review): the tests strip a 28-character timestamp prefix
        # (`log_line[28:]`) before comparing log lines, so a raw line
        # presumably never starts with a tag such as "[PLAY]". If so, the
        # filtered list is always empty and the per-line log assertions of
        # the observer tests never execute. Confirm against the format
        # used by `make_logger` before tightening this filter.
        with open(path) as fp:
            return [line for line in fp.readlines() if line.startswith(prefix)]

    def _assert_detections(self, detections):
        """Assert that `detections` match `self.expected` one to one.

        This check is deliberately independent from the log-line checks so
        that a short or empty list of log lines cannot silently skip it.
        """
        self.assertEqual(len(detections), len(self.expected))
        for det, (start, end) in zip(detections, self.expected):
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)

    def _run_stream_saver(self, saver):
        """Start `saver`, tokenize through it and wait for both workers."""
        saver.start()
        tokenizer = TokenizerWorker(saver)
        tokenizer.start_all()
        tokenizer.join()
        saver.join()

    def _load_reference_region(self):
        """Load the raw test file as an `AudioRegion`."""
        return AudioRegion.load(
            "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
        )

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_TokenizerWorker(self):
        """Check the tokenizer's detections and its "[DET]" log lines."""
        with TemporaryDirectory() as tmpdir:
            log_path = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=log_path, name="test_TokenizerWorker")
            tokenizer = self._make_tokenizer(logger=logger)
            tokenizer.start_all()
            tokenizer.join()
            # Get logged text while the temporary directory still exists
            log_lines = self._read_log_lines(log_path)

        self._assert_detections(tokenizer.detections)

        log_fmt = "[DET]: Detection {} (start: {:.3f}, "
        log_fmt += "end: {:.3f}, duration: {:.3f})"
        for i, ((start, end), log_line) in enumerate(
            zip(self.expected, log_lines), 1
        ):
            exp_log_line = log_fmt.format(i, start, end, end - start)
            # Remove timestamp part and strip new line
            self.assertEqual(log_line[28:].strip(), exp_log_line)

    def test_PlayerWorker(self):
        """Check that `PlayerWorker` uses the player for the detections."""
        with TemporaryDirectory() as tmpdir:
            log_path = os.path.join(tmpdir, "file.log")
            # Use a logger name unique to this test: loggers are shared by
            # name, so reusing another test's name would share its handlers.
            logger = make_logger(file=log_path, name="test_PlayerWorker")
            player_mock = Mock()
            observers = [PlayerWorker(player_mock, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()
            # Get logged text while the temporary directory still exists
            log_lines = self._read_log_lines(log_path, "[PLAY]")

        self.assertTrue(player_mock.play.called)
        self._assert_detections(tokenizer.detections)

        log_fmt = "[PLAY]: Detection {id} played"
        for i, (_, log_line) in enumerate(zip(self.expected, log_lines), 1):
            exp_log_line = log_fmt.format(id=i)
            # Remove timestamp part and strip new line
            self.assertEqual(log_line[28:].strip(), exp_log_line)

    def test_RegionSaverWorker(self):
        """Check the file names `RegionSaverWorker` saves detections to."""
        filename_format = (
            "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav"
        )
        with TemporaryDirectory() as tmpdir:
            log_path = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=log_path, name="test_RegionSaverWorker")
            observers = [RegionSaverWorker(filename_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.core.AudioRegion.save") as patched_save:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            # Get logged text while the temporary directory still exists
            log_lines = self._read_log_lines(log_path, "[SAVE]")

        expected_filenames = [
            filename_format.format(
                id=i, start=start, end=end, duration=end - start
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]

        # Assert RegionSaverWorker ran as expected. Only every other
        # recorded call is compared; presumably the remaining entries are
        # calls made on the mock returned by the patched `save`.
        expected_save_calls = [call(name, None) for name in expected_filenames]
        save_calls = patched_save.mock_calls[::2]
        self.assertEqual(save_calls, expected_save_calls)
        self._assert_detections(tokenizer.detections)

        log_fmt = "[SAVE]: Detection {id} saved as '{filename}'"
        for i, (expected_filename, log_line) in enumerate(
            zip(expected_filenames, log_lines), 1
        ):
            # The template has named fields, so keyword arguments are
            # required (positional arguments would raise KeyError).
            exp_log_line = log_fmt.format(id=i, filename=expected_filename)
            # Remove timestamp part and strip new line
            self.assertEqual(log_line[28:].strip(), exp_log_line)

    def test_CommandLineWorker(self):
        """Check that `CommandLineWorker` runs its command per detection."""
        command_format = "do nothing with"
        with TemporaryDirectory() as tmpdir:
            log_path = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=log_path, name="test_CommandLineWorker")
            observers = [CommandLineWorker(command_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.workers.os.system") as patched_os_system:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            # Get logged text while the temporary directory still exists
            log_lines = self._read_log_lines(log_path, "[COMMAND]")

        # Assert CommandLineWorker ran as expected
        expected_calls = [call(command_format) for _ in self.expected]
        self.assertEqual(patched_os_system.mock_calls, expected_calls)
        self._assert_detections(tokenizer.detections)

        log_fmt = "[COMMAND]: Detection {id} command '{command}'"
        for i, (_, log_line) in enumerate(zip(self.expected, log_lines), 1):
            # The template has named fields, so keyword arguments are
            # required (positional arguments would raise KeyError).
            exp_log_line = log_fmt.format(id=i, command=command_format)
            # Remove timestamp part and strip new line
            self.assertEqual(log_line[28:].strip(), exp_log_line)

    def test_PrintWorker(self):
        """Check what `PrintWorker` prints for every detection."""
        observers = [
            PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")
        ]
        tokenizer = self._make_tokenizer(observers=observers)
        with patch("builtins.print") as patched_print:
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()

        # Assert PrintWorker ran as expected
        expected_print_calls = [
            call(
                "[{}] {:.3f} {:.3f}, dur: {:.3f}".format(
                    i, start, end, end - start
                )
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        self.assertEqual(patched_print.mock_calls, expected_print_calls)
        self._assert_detections(tokenizer.detections)

    def test_StreamSaverWorker_wav(self):
        """Check that the stream is saved, unchanged, as a wav file."""
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output.wav")
            saver = StreamSaverWorker(self.reader, expected_filename)
            self._run_stream_saver(saver)

            output_filename = saver.save_stream()
            reference_region = self._load_reference_region()
            saved_region = AudioRegion.load(output_filename)

            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(reference_region, saved_region)
            self.assertEqual(saver.data, bytes(saved_region))

    def test_StreamSaverWorker_raw(self):
        """Check that the stream is saved, unchanged, as a raw file."""
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output")
            saver = StreamSaverWorker(
                self.reader, expected_filename, export_format="raw"
            )
            self._run_stream_saver(saver)

            output_filename = saver.save_stream()
            reference_region = self._load_reference_region()
            saved_region = AudioRegion.load(
                output_filename, sr=10, sw=2, ch=1, audio_format="raw"
            )

            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(reference_region, saved_region)
            self.assertEqual(saver.data, bytes(saved_region))

    def test_StreamSaverWorker_encode_audio(self):
        """Check the warning raised when no encoder can produce the file.

        `auditok.workers._run_subprocess` is patched to return
        `(1, None, None)` (presumably a failing return code) and the test
        expects `save_stream` to try ffmpeg, avconv and sox in turn before
        raising `AudioEncodingWarning`.
        """
        with TemporaryDirectory() as tmpdir:
            with patch("auditok.workers._run_subprocess") as patch_rsp:
                patch_rsp.return_value = (1, None, None)
                expected_filename = os.path.join(tmpdir, "output.ogg")
                tmp_expected_filename = expected_filename + ".wav"
                saver = StreamSaverWorker(self.reader, expected_filename)
                self._run_stream_saver(saver)
                with self.assertRaises(AudioEncodingWarning) as rt_warn:
                    saver.save_stream()

            warn_msg = "Couldn't save audio data in the desired format "
            warn_msg += "'ogg'. Either none of 'ffmpeg', 'avconv' or 'sox' "
            warn_msg += "is installed or this format is not recognized.\n"
            warn_msg += "Audio file was saved as '{}'"
            self.assertEqual(
                warn_msg.format(tmp_expected_filename), str(rt_warn.exception)
            )

            # Arguments shared by the ffmpeg and avconv command lines
            ffmpeg_avconv_args = [
                "-y",
                "-f",
                "wav",
                "-i",
                tmp_expected_filename,
                "-f",
                "ogg",
                expected_filename,
            ]
            expected_calls = [
                call(["ffmpeg"] + ffmpeg_avconv_args),
                call(["avconv"] + ffmpeg_avconv_args),
                call(
                    [
                        "sox",
                        "-t",
                        "wav",
                        tmp_expected_filename,
                        expected_filename,
                    ]
                ),
            ]
            self.assertEqual(patch_rsp.mock_calls, expected_calls)

            reference_region = self._load_reference_region()
            self.assertTrue(saver._exported)
            self.assertEqual(saver.data, bytes(reference_region))
|
amine@337
|
351
|
amine@337
|
352
|
amine@337
|
if __name__ == "__main__":
    # Run the tests of this module when the file is executed as a script.
    unittest.main()
|