annotate tests/test_workers.py @ 338:f424ac9193b7

Make sure all modules define __all__
author Amine Sehili <amine.sehili@gmail.com>
date Sun, 27 Oct 2019 15:23:00 +0100
parents 9f17aa9a4018
children 5732edbfae30
rev   line source
amine@274 1 import os
amine@337 2 import unittest
amine@274 3 from unittest import TestCase
amine@282 4 from unittest.mock import patch, call, Mock
amine@274 5 from tempfile import TemporaryDirectory
amine@274 6 from genty import genty, genty_dataset
amine@287 7 from auditok import AudioRegion, AudioDataSource
amine@292 8 from auditok.exceptions import AudioEncodingWarning
amine@274 9 from auditok.cmdline_util import make_logger
amine@274 10 from auditok.workers import (
amine@274 11 TokenizerWorker,
amine@274 12 StreamSaverWorker,
amine@274 13 RegionSaverWorker,
amine@274 14 PlayerWorker,
amine@274 15 CommandLineWorker,
amine@274 16 PrintWorker,
amine@274 17 )
amine@274 18
amine@274 19
@genty
class TestWorkers(TestCase):
    """Tests for the worker classes imported from ``auditok.workers``.

    Each test tokenizes the same short raw audio file and checks the
    detections, the side effects of the attached observer worker (player,
    region saver, command runner, printer, stream saver) and, where a logger
    is used, the messages written to the log file.
    """

    # Number of leading characters dropped from every log line before it is
    # compared with the expected message (the timestamp part of the line).
    _LOG_PREFIX_LEN = 28

    def setUp(self):
        """Open the test audio stream and define the expected detections."""
        self.reader = AudioDataSource(
            input="tests/data/test_split_10HZ_mono.raw",
            block_dur=0.1,
            sr=10,
            sw=2,
            ch=1,
        )
        # (start, end) of each detection expected with the tokenization
        # parameters used by `_make_tokenizer`.
        self.expected = [
            (0.2, 1.6),
            (1.7, 3.1),
            (3.4, 5.4),
            (5.4, 7.4),
            (7.4, 7.6),
        ]

    def tearDown(self):
        """Close the audio stream opened in `setUp`."""
        self.reader.close()

    def _make_tokenizer(self, **kwargs):
        """Return a `TokenizerWorker` reading from `self.reader`.

        The tokenization parameters are the ones `self.expected` was written
        for; `kwargs` (e.g. ``logger`` or ``observers``) are passed through.
        """
        return TokenizerWorker(
            self.reader,
            min_dur=0.3,
            max_dur=2,
            max_silence=0.2,
            drop_trailing_silence=False,
            strict_min_dur=False,
            eth=50,
            **kwargs
        )

    def _read_log_lines(self, path, prefix=""):
        """Return the lines of the log file `path` that start with `prefix`.

        With the default empty prefix every line is returned.
        """
        # NOTE(review): `_assert_log_messages` strips a leading timestamp
        # from each line, so lines presumably never start with a tag such as
        # "[PLAY]" and a non-empty `prefix` may select nothing, leaving the
        # log check vacuous -- confirm against make_logger's line format.
        with open(path) as fp:
            return [line for line in fp if line.startswith(prefix)]

    def _assert_detections(self, tokenizer):
        """Check `tokenizer.detections` against `self.expected`."""
        detections = tokenizer.detections
        self.assertEqual(len(detections), len(self.expected))
        for det, (start, end) in zip(detections, self.expected):
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)

    def _assert_log_messages(self, log_lines, expected_messages):
        """Compare each log line, timestamp removed, with its expected text.

        Lines and messages are paired in order; surplus items on either side
        are ignored.
        """
        for log_line, message in zip(log_lines, expected_messages):
            # Remove timestamp part and strip new line
            self.assertEqual(
                log_line[self._LOG_PREFIX_LEN :].strip(), message
            )

    def test_TokenizerWorker(self):
        """Detections are found and logged with a "[DET]" message."""
        with TemporaryDirectory() as tmpdir:
            log_path = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=log_path, name="test_TokenizerWorker")
            tokenizer = self._make_tokenizer(logger=logger)
            tokenizer.start_all()
            tokenizer.join()
            # Get logged text
            log_lines = self._read_log_lines(log_path)

        self._assert_detections(tokenizer)
        log_fmt = (
            "[DET]: Detection {} (start: {:.3f}, "
            "end: {:.3f}, duration: {:.3f})"
        )
        expected_messages = [
            log_fmt.format(i, start, end, end - start)
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        self._assert_log_messages(log_lines, expected_messages)

    def test_PlayerWorker(self):
        """The player's `play` is called and "[PLAY]" messages are logged."""
        with TemporaryDirectory() as tmpdir:
            log_path = os.path.join(tmpdir, "file.log")
            # Use this test's own logger name: the previous copy-pasted name
            # ("test_RegionSaverWorker") made two tests share one logger.
            logger = make_logger(file=log_path, name="test_PlayerWorker")
            player_mock = Mock()
            observers = [PlayerWorker(player_mock, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()
            # Get logged text
            log_lines = self._read_log_lines(log_path, "[PLAY]")

        self.assertTrue(player_mock.play.called)
        self._assert_detections(tokenizer)
        log_fmt = "[PLAY]: Detection {id} played"
        expected_messages = [
            log_fmt.format(id=i) for i, _ in enumerate(self.expected, 1)
        ]
        self._assert_log_messages(log_lines, expected_messages)

    def test_RegionSaverWorker(self):
        """`AudioRegion.save` is called once per detection."""
        filename_format = (
            "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav"
        )
        with TemporaryDirectory() as tmpdir:
            log_path = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=log_path, name="test_RegionSaverWorker")
            observers = [RegionSaverWorker(filename_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.core.AudioRegion.save") as patched_save:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            # Get logged text
            log_lines = self._read_log_lines(log_path, "[SAVE]")

        # Assert RegionSaverWorker ran as expected
        expected_filenames = [
            filename_format.format(
                id=i, start=start, end=end, duration=end - start
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        expected_save_calls = [
            call(filename, None) for filename in expected_filenames
        ]

        # Get calls to 'AudioRegion.save': only every other recorded call is
        # kept (presumably the entries in between are calls made on the
        # mocked return value of save() -- TODO confirm).
        save_calls = patched_save.mock_calls[::2]
        self.assertEqual(save_calls, expected_save_calls)
        self._assert_detections(tokenizer)

        log_fmt = "[SAVE]: Detection {id} saved as '{filename}'"
        # The template uses named fields, so it must be filled with keyword
        # arguments; positional arguments raise KeyError.
        expected_messages = [
            log_fmt.format(id=i, filename=filename)
            for i, filename in enumerate(expected_filenames, 1)
        ]
        self._assert_log_messages(log_lines, expected_messages)

    def test_CommandLineWorker(self):
        """`os.system` is called with the command once per detection."""
        command_format = "do nothing with"
        with TemporaryDirectory() as tmpdir:
            log_path = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=log_path, name="test_CommandLineWorker")
            observers = [CommandLineWorker(command_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.workers.os.system") as patched_os_system:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            # Get logged text
            log_lines = self._read_log_lines(log_path, "[COMMAND]")

        # Assert CommandLineWorker ran as expected
        expected_system_calls = [call(command_format) for _ in self.expected]
        self.assertEqual(patched_os_system.mock_calls, expected_system_calls)
        self._assert_detections(tokenizer)

        log_fmt = "[COMMAND]: Detection {id} command '{command}'"
        # Named fields require keyword arguments (positional ones raise
        # KeyError).
        expected_messages = [
            log_fmt.format(id=i, command=command_format)
            for i, _ in enumerate(self.expected, 1)
        ]
        self._assert_log_messages(log_lines, expected_messages)

    def test_PrintWorker(self):
        """`print` is called once per detection with the formatted text."""
        observers = [
            PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")
        ]
        tokenizer = self._make_tokenizer(observers=observers)
        with patch("builtins.print") as patched_print:
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()

        # Assert PrintWorker ran as expected
        expected_print_calls = [
            call(
                "[{}] {:.3f} {:.3f}, dur: {:.3f}".format(
                    i, start, end, end - start
                )
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        self.assertEqual(patched_print.mock_calls, expected_print_calls)
        self._assert_detections(tokenizer)

    def test_StreamSaverWorker_wav(self):
        """The stream saved as wav holds the same audio as the input file."""
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output.wav")
            saver = StreamSaverWorker(self.reader, expected_filename)
            saver.start()

            tokenizer = TokenizerWorker(saver)
            tokenizer.start_all()
            tokenizer.join()
            saver.join()

            output_filename = saver.save_stream()
            # Reference audio: the file `self.reader` was opened on.
            source_region = AudioRegion.load(
                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
            )
            # Audio actually written by the saver.
            saved_region = AudioRegion.load(output_filename)

            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(source_region, saved_region)
            self.assertEqual(saver.data, bytes(saved_region))

    def test_StreamSaverWorker_raw(self):
        """The stream saved as raw holds the same audio as the input file."""
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output")
            saver = StreamSaverWorker(
                self.reader, expected_filename, export_format="raw"
            )
            saver.start()

            tokenizer = TokenizerWorker(saver)
            tokenizer.start_all()
            tokenizer.join()
            saver.join()

            output_filename = saver.save_stream()
            # Reference audio: the file `self.reader` was opened on.
            source_region = AudioRegion.load(
                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
            )
            # Audio actually written by the saver; raw data carries no
            # header, so the audio parameters are given explicitly.
            saved_region = AudioRegion.load(
                output_filename, sr=10, sw=2, ch=1, audio_format="raw"
            )

            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(source_region, saved_region)
            self.assertEqual(saver.data, bytes(saved_region))

    def test_StreamSaverWorker_encode_audio(self):
        """A failed conversion raises `AudioEncodingWarning`.

        `_run_subprocess` is patched to return ``(1, None, None)`` for every
        converter, so `save_stream` is expected to try 'ffmpeg', 'avconv' and
        'sox' in turn and then raise, naming the intermediate wav file.
        """
        with TemporaryDirectory() as tmpdir:
            with patch("auditok.workers._run_subprocess") as patch_rsp:
                patch_rsp.return_value = (1, None, None)
                expected_filename = os.path.join(tmpdir, "output.ogg")
                tmp_expected_filename = expected_filename + ".wav"
                saver = StreamSaverWorker(self.reader, expected_filename)
                saver.start()

                tokenizer = TokenizerWorker(saver)
                tokenizer.start_all()
                tokenizer.join()
                saver.join()

                with self.assertRaises(AudioEncodingWarning) as rt_warn:
                    saver.save_stream()
                warn_msg = (
                    "Couldn't save audio data in the desired format "
                    "'ogg'. Either none of 'ffmpeg', 'avconv' or 'sox' "
                    "is installed or this format is not recognized.\n"
                    "Audio file was saved as '{}'"
                )
                self.assertEqual(
                    warn_msg.format(tmp_expected_filename),
                    str(rt_warn.exception),
                )

                # Arguments shared by the ffmpeg and avconv command lines.
                ffmpeg_avconv_args = [
                    "-y",
                    "-f",
                    "wav",
                    "-i",
                    tmp_expected_filename,
                    "-f",
                    "ogg",
                    expected_filename,
                ]
                expected_calls = [
                    call(["ffmpeg"] + ffmpeg_avconv_args),
                    call(["avconv"] + ffmpeg_avconv_args),
                    call(
                        [
                            "sox",
                            "-t",
                            "wav",
                            tmp_expected_filename,
                            expected_filename,
                        ]
                    ),
                ]
                self.assertEqual(patch_rsp.mock_calls, expected_calls)

                source_region = AudioRegion.load(
                    "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
                )
                self.assertTrue(saver._exported)
                self.assertEqual(saver.data, bytes(source_region))
amine@337 351
amine@337 352
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()