annotate tests/test_workers.py @ 287:d13ce50446b7

Add tests for StreamSaverWorker
author Amine Sehili <amine.sehili@gmail.com>
date Fri, 04 Oct 2019 21:48:01 +0100
parents d40571459b37
children 9907db0843cb
rev   line source
amine@274 1 import os
amine@274 2 from unittest import TestCase
amine@282 3 from unittest.mock import patch, call, Mock
amine@274 4 from tempfile import TemporaryDirectory
amine@274 5 from genty import genty, genty_dataset
amine@287 6 from auditok import AudioRegion, AudioDataSource
amine@274 7 from auditok.cmdline_util import make_logger
amine@274 8 from auditok.workers import (
amine@274 9 TokenizerWorker,
amine@274 10 StreamSaverWorker,
amine@274 11 RegionSaverWorker,
amine@274 12 PlayerWorker,
amine@274 13 CommandLineWorker,
amine@274 14 PrintWorker,
amine@274 15 )
amine@274 16
amine@274 17
@genty
class TestWorkers(TestCase):
    """Tests for the worker classes of ``auditok.workers``.

    Tests read the same raw fixture through ``self.reader``. The tokenizer
    based tests compare detections against ``self.expected`` and check what
    the attached observer (player, region saver, command runner, printer)
    did for each detection. The stream saver tests check the saved stream.
    """

    # Raw audio fixture and the audio parameters used to read it.
    _INPUT_FILE = "tests/data/test_split_10HZ_mono.raw"
    _AUDIO_PARAMS = {"sr": 10, "sw": 2, "ch": 1}

    # Detection parameters under which the fixture yields ``self.expected``.
    _DETECTION_PARAMS = {
        "min_dur": 0.3,
        "max_dur": 2,
        "max_silence": 0.2,
        "drop_trailing_silence": False,
        "strict_min_dur": False,
        "eth": 50,
    }

    # Number of leading characters (the timestamp part) written before each
    # logged message. NOTE(review): assumes make_logger's file formatter
    # emits a fixed-width 28-character prefix.
    _LOG_PREFIX_LEN = 28

    def setUp(self):
        self.reader = AudioDataSource(
            input=self._INPUT_FILE, block_dur=0.1, **self._AUDIO_PARAMS
        )
        # (start, end) of every detection expected from the fixture with
        # ``_DETECTION_PARAMS``, in detection order.
        self.expected = [
            (0.2, 1.6),
            (1.7, 3.1),
            (3.4, 5.4),
            (5.4, 7.4),
            (7.4, 7.6),
        ]

    def tearDown(self):
        self.reader.close()

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    def _make_tokenizer(self, logger=None, observers=None):
        """Return a ``TokenizerWorker`` reading ``self.reader``.

        ``logger`` and ``observers`` are only forwarded when given, so the
        worker's own defaults apply otherwise.
        """
        kwargs = dict(self._DETECTION_PARAMS)
        if logger is not None:
            kwargs["logger"] = logger
        if observers is not None:
            kwargs["observers"] = observers
        return TokenizerWorker(self.reader, **kwargs)

    @classmethod
    def _read_log_messages(cls, file, tag):
        """Return the messages logged to ``file`` that start with ``tag``.

        The timestamp prefix and the trailing newline of every line are
        removed *before* filtering. Filtering the raw lines would never
        match, because each line starts with the timestamp.
        """
        with open(file) as fp:
            messages = [line[cls._LOG_PREFIX_LEN:].strip() for line in fp]
        return [msg for msg in messages if msg.startswith(tag)]

    @staticmethod
    def _close_logger(logger):
        """Detach and close every handler attached to ``logger``.

        This releases the log file before its temporary directory is
        removed. It also stops handlers from piling up on a logger name
        that is reused. NOTE(review): assumes make_logger returns a
        ``logging.Logger``; anything without ``handlers`` is left alone.
        """
        for handler in list(getattr(logger, "handlers", ())):
            logger.removeHandler(handler)
            handler.close()

    def _load_source_region(self):
        """Load the whole input fixture as a single ``AudioRegion``."""
        return AudioRegion.load(self._INPUT_FILE, **self._AUDIO_PARAMS)

    def _run_stream_saver(self, filename, **kwargs):
        """Run a ``StreamSaverWorker`` over ``self.reader``.

        A ``TokenizerWorker`` drives the saver. The saver is returned once
        both have been joined. Extra keyword arguments go to the saver.
        """
        saver = StreamSaverWorker(self.reader, filename, **kwargs)
        saver.start()
        tokenizer = TokenizerWorker(saver)
        tokenizer.start_all()
        tokenizer.join()
        saver.join()
        return saver

    def _assert_detections(self, detections):
        """Check that ``detections`` match ``self.expected`` one-to-one."""
        self.assertEqual(len(detections), len(self.expected))
        for det, (start, end) in zip(detections, self.expected):
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)

    # ------------------------------------------------------------------ #
    # Tests
    # ------------------------------------------------------------------ #

    def test_TokenizerWorker(self):
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_TokenizerWorker")
            try:
                tokenizer = self._make_tokenizer(logger=logger)
                tokenizer.start_all()
                tokenizer.join()
                log_messages = self._read_log_messages(file, "[DET]")
            finally:
                self._close_logger(logger)

        log_fmt = (
            "[DET]: Detection {} (start: {:.3f}, "
            "end: {:.3f}, duration: {:.3f})"
        )
        expected_messages = [
            log_fmt.format(i, start, end, end - start)
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        self._assert_detections(tokenizer.detections)
        self.assertEqual(log_messages, expected_messages)

    def test_PlayerWorker(self):
        player_mock = Mock()
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_PlayerWorker")
            try:
                observers = [PlayerWorker(player_mock, logger=logger)]
                tokenizer = self._make_tokenizer(
                    logger=logger, observers=observers
                )
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
                log_messages = self._read_log_messages(file, "[PLAY]")
            finally:
                self._close_logger(logger)

        self.assertTrue(player_mock.play.called)
        self._assert_detections(tokenizer.detections)
        expected_messages = [
            "[PLAY]: Detection {id} played".format(id=i)
            for i in range(1, len(self.expected) + 1)
        ]
        self.assertEqual(log_messages, expected_messages)

    def test_RegionSaverWorker(self):
        filename_format = (
            "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav"
        )
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_RegionSaverWorker")
            try:
                observers = [RegionSaverWorker(filename_format, logger=logger)]
                tokenizer = self._make_tokenizer(
                    logger=logger, observers=observers
                )
                with patch("auditok.core.AudioRegion.save") as patched_save:
                    # Return the requested filename instead of a fresh Mock,
                    # so the filename reported in the log is a plain string
                    # whichever value the worker decides to log.
                    patched_save.side_effect = (
                        lambda filename, *args, **kwargs: filename
                    )
                    tokenizer.start_all()
                    tokenizer.join()
                    tokenizer._observers[0].join()
                log_messages = self._read_log_messages(file, "[SAVE]")
            finally:
                self._close_logger(logger)

        expected_filenames = [
            filename_format.format(
                id=i, start=start, end=end, duration=end - start
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        # call_args_list holds only the direct calls to 'AudioRegion.save',
        # not calls made on its return values (unlike mock_calls).
        expected_save_calls = [call(name, None) for name in expected_filenames]
        self.assertEqual(patched_save.call_args_list, expected_save_calls)
        self._assert_detections(tokenizer.detections)

        log_fmt = "[SAVE]: Detection {id} saved as '{filename}'"
        expected_messages = [
            log_fmt.format(id=i, filename=name)
            for i, name in enumerate(expected_filenames, 1)
        ]
        self.assertEqual(log_messages, expected_messages)

    def test_CommandLineWorker(self):
        command_format = "do nothing with"
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_CommandLineWorker")
            try:
                observers = [CommandLineWorker(command_format, logger=logger)]
                tokenizer = self._make_tokenizer(
                    logger=logger, observers=observers
                )
                with patch("auditok.workers.os.system") as patched_os_system:
                    tokenizer.start_all()
                    tokenizer.join()
                    tokenizer._observers[0].join()
                log_messages = self._read_log_messages(file, "[COMMAND]")
            finally:
                self._close_logger(logger)

        # One shell command per detection
        expected_system_calls = [call(command_format) for _ in self.expected]
        self.assertEqual(patched_os_system.mock_calls, expected_system_calls)
        self._assert_detections(tokenizer.detections)

        log_fmt = "[COMMAND]: Detection {id} command '{command}'"
        expected_messages = [
            log_fmt.format(id=i, command=command_format)
            for i in range(1, len(self.expected) + 1)
        ]
        self.assertEqual(log_messages, expected_messages)

    def test_PrintWorker(self):
        observers = [
            PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")
        ]
        tokenizer = self._make_tokenizer(observers=observers)
        with patch("builtins.print") as patched_print:
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()

        print_fmt = "[{}] {:.3f} {:.3f}, dur: {:.3f}"
        expected_print_calls = [
            call(print_fmt.format(i, start, end, end - start))
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        self.assertEqual(patched_print.mock_calls, expected_print_calls)
        self._assert_detections(tokenizer.detections)

    def test_StreamSaverWorker_wav(self):
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output.wav")
            saver = self._run_stream_saver(expected_filename)

            output_filename = saver.save_stream()
            source_region = self._load_source_region()
            saved_region = AudioRegion.load(output_filename)

            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(source_region, saved_region)
            self.assertEqual(saver.data, bytes(saved_region))

    def test_StreamSaverWorker_raw(self):
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output")
            saver = self._run_stream_saver(
                expected_filename, export_format="raw"
            )

            output_filename = saver.save_stream()
            source_region = self._load_source_region()
            saved_region = AudioRegion.load(
                output_filename, audio_format="raw", **self._AUDIO_PARAMS
            )

            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(source_region, saved_region)
            self.assertEqual(saver.data, bytes(saved_region))

    def test_StreamSaverWorker_encode_audio(self):
        with TemporaryDirectory() as tmpdir:
            with patch("auditok.workers._run_subprocess") as patch_rsp:
                # Every external converter "fails" (non-zero return code)
                patch_rsp.return_value = (1, None, None)
                expected_filename = os.path.join(tmpdir, "output.ogg")
                tmp_expected_filename = expected_filename + ".wav"
                saver = self._run_stream_saver(expected_filename)

                with self.assertRaises(RuntimeWarning) as rt_warn:
                    saver.save_stream()
                warn_msg = (
                    "Couldn't save audio data in the desired format "
                    "'ogg'. Either none of 'ffmpeg', 'avconv' or 'sox' "
                    "is installed or this format is not recognized.\n"
                    "Audio file was saved as '{}'"
                )
                self.assertEqual(
                    warn_msg.format(tmp_expected_filename),
                    str(rt_warn.exception),
                )

                # ffmpeg and avconv share the same command-line arguments
                ffmpeg_avconv_args = [
                    "-y",
                    "-f",
                    "wav",
                    "-i",
                    tmp_expected_filename,
                    "-f",
                    "ogg",
                    expected_filename,
                ]
                sox_args = [
                    "sox",
                    "-t",
                    "wav",
                    tmp_expected_filename,
                    expected_filename,
                ]
                expected_calls = [
                    call(["ffmpeg"] + ffmpeg_avconv_args),
                    call(["avconv"] + ffmpeg_avconv_args),
                    call(sox_args),
                ]
                self.assertEqual(patch_rsp.mock_calls, expected_calls)

                source_region = self._load_source_region()
                self.assertTrue(saver._exported)
                self.assertEqual(saver.data, bytes(source_region))