amine@274
|
1 import os
|
amine@274
|
2 from unittest import TestCase
|
amine@274
|
3 from unittest.mock import patch, call
|
amine@274
|
4 from tempfile import TemporaryDirectory
|
amine@274
|
5 from genty import genty, genty_dataset
|
amine@274
|
6 from auditok import AudioDataSource
|
amine@274
|
7 from auditok.cmdline_util import make_logger
|
amine@274
|
8 from auditok.workers import (
|
amine@274
|
9 TokenizerWorker,
|
amine@274
|
10 StreamSaverWorker,
|
amine@274
|
11 RegionSaverWorker,
|
amine@274
|
12 PlayerWorker,
|
amine@274
|
13 CommandLineWorker,
|
amine@274
|
14 PrintWorker,
|
amine@274
|
15 )
|
amine@274
|
16
|
amine@274
|
17
|
amine@274
|
@genty
class TestWorkers(TestCase):
    """Tests for the auditok worker classes.

    Each test runs a ``TokenizerWorker`` over the same 10 Hz mono raw file
    and checks both the detections (against ``self.expected``) and the side
    effects of any observer workers attached to it: log lines, save calls,
    shell commands or printed output.
    """

    # Width of the timestamp prefix written before every message by the file
    # logger returned from ``make_logger``. Assumed to be fixed -- TODO confirm
    # against make_logger's formatter if it ever changes.
    _LOG_PREFIX_LEN = 28

    # Tokenization parameters shared by every test in this class.
    _TOKENIZER_KWARGS = dict(
        min_dur=0.3,
        max_dur=2,
        max_silence=0.2,
        drop_trailing_silence=False,
        strict_min_dur=False,
        eth=50,
    )

    def setUp(self):
        self.reader = AudioDataSource(
            input="tests/data/test_split_10HZ_mono.raw",
            block_dur=0.1,
            sr=10,
            sw=2,
            ch=1,
        )
        # (start, end) pairs, in seconds, of the regions the tokenizer is
        # expected to detect in the test file with ``_TOKENIZER_KWARGS``.
        self.expected = [
            (0.2, 1.6),
            (1.7, 3.1),
            (3.4, 5.4),
            (5.4, 7.4),
            (7.4, 7.6),
        ]

    def tearDown(self):
        self.reader.close()

    def _make_tokenizer(self, **kwargs):
        """Return a TokenizerWorker reading ``self.reader``.

        The shared ``_TOKENIZER_KWARGS`` are used, and ``kwargs`` (e.g.
        ``logger``, ``observers``) are added on top of them.
        """
        params = dict(self._TOKENIZER_KWARGS)
        params.update(kwargs)
        return TokenizerWorker(self.reader, **params)

    @staticmethod
    def _close_logger(logger):
        """Close and detach every handler attached to ``logger``.

        This releases the log file before the temporary directory holding
        it is removed; Windows refuses to delete a file that is still open.
        It also prevents handlers from piling up if ``make_logger`` hands
        back the same logger object for a reused name.
        """
        if logger is None:
            return
        # assumes make_logger returns a stdlib logging.Logger; getattr keeps
        # this a no-op for anything without a ``handlers`` list.
        for handler in list(getattr(logger, "handlers", [])):
            handler.close()
            logger.removeHandler(handler)

    @classmethod
    def _read_log_messages(cls, path, tag=None):
        """Return the messages logged to ``path``, one entry per line.

        The timestamp prefix and the trailing newline are removed from each
        line. If ``tag`` is given (e.g. ``"[SAVE]"``), only messages that
        start with it are kept. The tag is matched after the timestamp is
        stripped, because the raw lines begin with the timestamp.
        """
        with open(path) as fp:
            messages = [line[cls._LOG_PREFIX_LEN:].strip() for line in fp]
        if tag is not None:
            messages = [msg for msg in messages if msg.startswith(tag)]
        return messages

    def test_TokenizerWorker(self):
        with TemporaryDirectory() as tmpdir:
            log_file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=log_file, name="test_TokenizerWorker")
            tokenizer = self._make_tokenizer(logger=logger)
            tokenizer.start_all()
            tokenizer.join()
            # Flush and release the log file before reading it and before
            # the temporary directory is cleaned up.
            self._close_logger(logger)
            log_messages = self._read_log_messages(log_file)

        log_fmt = "[DET]: Detection {} (start: {:.3f}, "
        log_fmt += "end: {:.3f}, duration: {:.3f})"
        self.assertEqual(len(tokenizer.detections), len(self.expected))
        # Guard against zip() silently truncating on a short log.
        self.assertGreaterEqual(len(log_messages), len(self.expected))
        for i, (det, exp, log_message) in enumerate(
            zip(tokenizer.detections, self.expected, log_messages), 1
        ):
            start, end = exp
            exp_log_message = log_fmt.format(i, start, end, end - start)
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)
            self.assertEqual(log_message, exp_log_message)

    def test_RegionSaverWorker(self):
        filename_format = (
            "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav"
        )
        with TemporaryDirectory() as tmpdir:
            log_file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=log_file, name="test_RegionSaverWorker")
            observers = [RegionSaverWorker(filename_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.core.AudioRegion.save") as patched_save:
                # Have the mock return the requested filename instead of a
                # MagicMock. The worker uses save's return value, so this
                # makes the filename written to the log predictable.
                patched_save.side_effect = (
                    lambda filename, *args, **kwargs: filename
                )
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            self._close_logger(logger)
            log_messages = self._read_log_messages(log_file, tag="[SAVE]")

        expected_filenames = [
            filename_format.format(
                id=i, start=start, end=end, duration=end - start
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]

        # Assert RegionSaverWorker saved every region under the expected name
        expected_save_calls = [call(name, None) for name in expected_filenames]
        # call_args_list holds only direct calls to 'AudioRegion.save'.
        self.assertEqual(patched_save.call_args_list, expected_save_calls)
        self.assertEqual(len(tokenizer.detections), len(self.expected))
        # One [SAVE] message per detection; an empty log must not pass.
        self.assertEqual(len(log_messages), len(self.expected))

        log_fmt = "[SAVE]: Detection {id} saved as '{filename}'"
        for i, (det, exp, filename, log_message) in enumerate(
            zip(
                tokenizer.detections,
                self.expected,
                expected_filenames,
                log_messages,
            ),
            1,
        ):
            start, end = exp
            exp_log_message = log_fmt.format(id=i, filename=filename)
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)
            self.assertEqual(log_message, exp_log_message)

    def test_CommandLineWorker(self):
        command_format = "do nothing with"
        with TemporaryDirectory() as tmpdir:
            log_file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=log_file, name="test_CommandLineWorker")
            observers = [CommandLineWorker(command_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.workers.os.system") as patched_os_system:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            self._close_logger(logger)
            log_messages = self._read_log_messages(log_file, tag="[COMMAND]")

        # Assert CommandLineWorker ran the command once per detection
        expected_system_calls = [call(command_format) for _ in self.expected]
        self.assertEqual(patched_os_system.mock_calls, expected_system_calls)
        self.assertEqual(len(tokenizer.detections), len(self.expected))
        # One [COMMAND] message per detection; an empty log must not pass.
        self.assertEqual(len(log_messages), len(self.expected))

        log_fmt = "[COMMAND]: Detection {id} command '{command}'"
        for i, (det, exp, log_message) in enumerate(
            zip(tokenizer.detections, self.expected, log_messages), 1
        ):
            start, end = exp
            exp_log_message = log_fmt.format(id=i, command=command_format)
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)
            self.assertEqual(log_message, exp_log_message)

    def test_PrintWorker(self):
        observers = [
            PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")
        ]
        tokenizer = self._make_tokenizer(observers=observers)
        with patch("builtins.print") as patched_print:
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()

        # Assert PrintWorker printed one formatted line per detection
        expected_print_calls = [
            call(
                "[{}] {:.3f} {:.3f}, dur: {:.3f}".format(
                    i, *exp, exp[1] - exp[0]
                )
            )
            for i, exp in enumerate(self.expected, 1)
        ]
        self.assertEqual(patched_print.mock_calls, expected_print_calls)
        self.assertEqual(len(tokenizer.detections), len(self.expected))

        for det, (start, end) in zip(tokenizer.detections, self.expected):
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)