annotate tests/test_workers.py @ 316:b6c5125be036

Fix bugs in AudioEnergyValidator and signal_numpy and add tests
author Amine Sehili <amine.sehili@gmail.com>
date Thu, 17 Oct 2019 21:21:29 +0100
parents 9907db0843cb
children 9f17aa9a4018
rev   line source
amine@274 1 import os
amine@274 2 from unittest import TestCase
amine@282 3 from unittest.mock import patch, call, Mock
amine@274 4 from tempfile import TemporaryDirectory
amine@274 5 from genty import genty, genty_dataset
amine@287 6 from auditok import AudioRegion, AudioDataSource
amine@292 7 from auditok.exceptions import AudioEncodingWarning
amine@274 8 from auditok.cmdline_util import make_logger
amine@274 9 from auditok.workers import (
amine@274 10 TokenizerWorker,
amine@274 11 StreamSaverWorker,
amine@274 12 RegionSaverWorker,
amine@274 13 PlayerWorker,
amine@274 14 CommandLineWorker,
amine@274 15 PrintWorker,
amine@274 16 )
amine@274 17
amine@274 18
@genty
class TestWorkers(TestCase):
    """Tests for the workers defined in ``auditok.workers``.

    Most tests tokenize ``tests/data/test_split_10HZ_mono.raw`` with the
    same detection parameters (see ``_make_tokenizer``) and expect the
    detections listed in ``self.expected``.
    """

    # Number of leading characters of each log line taken by the timestamp
    # that the logger returned by ``make_logger`` writes before the message.
    # NOTE(review): value inherited from the original tests (``line[28:]``);
    # confirm against ``make_logger``'s format if that ever changes.
    _LOG_PREFIX_LEN = 28

    def setUp(self):
        self.reader = AudioDataSource(
            input="tests/data/test_split_10HZ_mono.raw",
            block_dur=0.1,
            sr=10,
            sw=2,
            ch=1,
        )
        # (start, end) of each detection the tokenizer is expected to emit
        self.expected = [
            (0.2, 1.6),
            (1.7, 3.1),
            (3.4, 5.4),
            (5.4, 7.4),
            (7.4, 7.6),
        ]

    def tearDown(self):
        self.reader.close()

    def _make_tokenizer(self, **kwargs):
        """Return a ``TokenizerWorker`` over ``self.reader`` configured with
        the detection parameters that yield ``self.expected``.

        Extra keyword arguments (e.g. ``logger``, ``observers``) are
        forwarded unchanged to ``TokenizerWorker``.
        """
        return TokenizerWorker(
            self.reader,
            min_dur=0.3,
            max_dur=2,
            max_silence=0.2,
            drop_trailing_silence=False,
            strict_min_dur=False,
            eth=50,
            **kwargs
        )

    def _read_log_messages(self, log_file, tag):
        """Return the messages in ``log_file`` that start with ``tag``.

        The timestamp prefix of each line is removed *before* filtering on
        ``tag`` (a raw line never starts with the tag because of that
        prefix), and surrounding whitespace is stripped.
        """
        with open(log_file) as fp:
            messages = [line[self._LOG_PREFIX_LEN:].strip() for line in fp]
        return [msg for msg in messages if msg.startswith(tag)]

    def _assert_detections(self, detections):
        """Assert ``detections`` match ``self.expected`` in count and in
        start/end values."""
        self.assertEqual(len(detections), len(self.expected))
        for det, (start, end) in zip(detections, self.expected):
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)

    def test_TokenizerWorker(self):
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_TokenizerWorker")
            tokenizer = self._make_tokenizer(logger=logger)
            tokenizer.start_all()
            tokenizer.join()
            log_messages = self._read_log_messages(file, "[DET]")

        self._assert_detections(tokenizer.detections)
        log_fmt = "[DET]: Detection {} (start: {:.3f}, "
        log_fmt += "end: {:.3f}, duration: {:.3f})"
        expected_messages = [
            log_fmt.format(i, start, end, end - start)
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        self.assertEqual(log_messages, expected_messages)

    def test_PlayerWorker(self):
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_PlayerWorker")
            player_mock = Mock()
            observers = [PlayerWorker(player_mock, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()
            log_messages = self._read_log_messages(file, "[PLAY]")

        self.assertTrue(player_mock.play.called)
        self._assert_detections(tokenizer.detections)
        log_fmt = "[PLAY]: Detection {id} played"
        expected_messages = [
            log_fmt.format(id=i) for i in range(1, len(self.expected) + 1)
        ]
        self.assertEqual(log_messages, expected_messages)

    def test_RegionSaverWorker(self):
        filename_format = (
            "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav"
        )
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_RegionSaverWorker")
            observers = [RegionSaverWorker(filename_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.core.AudioRegion.save") as patched_save:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            log_messages = self._read_log_messages(file, "[SAVE]")

        expected_filenames = [
            filename_format.format(
                id=i, start=start, end=end, duration=end - start
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        # Assert RegionSaverWorker ran as expected. call_args_list holds only
        # direct calls to 'AudioRegion.save', not calls made on its
        # (mocked) return value, so no index filtering is needed.
        expected_save_calls = [call(name, None) for name in expected_filenames]
        self.assertEqual(patched_save.call_args_list, expected_save_calls)
        self._assert_detections(tokenizer.detections)

        log_fmt = "[SAVE]: Detection {id} saved as '{filename}'"
        expected_messages = [
            log_fmt.format(id=i, filename=name)
            for i, name in enumerate(expected_filenames, 1)
        ]
        self.assertEqual(log_messages, expected_messages)

    def test_CommandLineWorker(self):
        command_format = "do nothing with"
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_CommandLineWorker")
            observers = [CommandLineWorker(command_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.workers.os.system") as patched_os_system:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            log_messages = self._read_log_messages(file, "[COMMAND]")

        # Assert CommandLineWorker ran one command per detection
        expected_system_calls = [call(command_format) for _ in self.expected]
        self.assertEqual(patched_os_system.mock_calls, expected_system_calls)
        self._assert_detections(tokenizer.detections)

        log_fmt = "[COMMAND]: Detection {id} command '{command}'"
        expected_messages = [
            log_fmt.format(id=i, command=command_format)
            for i in range(1, len(self.expected) + 1)
        ]
        self.assertEqual(log_messages, expected_messages)

    def test_PrintWorker(self):
        observers = [
            PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")
        ]
        tokenizer = self._make_tokenizer(observers=observers)
        with patch("builtins.print") as patched_print:
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()

        # Assert PrintWorker ran as expected
        expected_print_calls = [
            call(
                "[{}] {:.3f} {:.3f}, dur: {:.3f}".format(
                    i, start, end, end - start
                )
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        self.assertEqual(patched_print.mock_calls, expected_print_calls)
        self._assert_detections(tokenizer.detections)

    def test_StreamSaverWorker_wav(self):
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output.wav")
            saver = StreamSaverWorker(self.reader, expected_filename)
            saver.start()

            tokenizer = TokenizerWorker(saver)
            tokenizer.start_all()
            tokenizer.join()
            saver.join()

            output_filename = saver.save_stream()
            expected_region = AudioRegion.load(
                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
            )
            saved_region = AudioRegion.load(output_filename)

            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(saved_region, expected_region)
            self.assertEqual(saver.data, bytes(saved_region))

    def test_StreamSaverWorker_raw(self):
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output")
            saver = StreamSaverWorker(
                self.reader, expected_filename, export_format="raw"
            )
            saver.start()
            tokenizer = TokenizerWorker(saver)
            tokenizer.start_all()
            tokenizer.join()
            saver.join()

            output_filename = saver.save_stream()
            expected_region = AudioRegion.load(
                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
            )
            saved_region = AudioRegion.load(
                output_filename, sr=10, sw=2, ch=1, audio_format="raw"
            )

            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(saved_region, expected_region)
            self.assertEqual(saver.data, bytes(saved_region))

    def test_StreamSaverWorker_encode_audio(self):
        with TemporaryDirectory() as tmpdir:
            with patch("auditok.workers._run_subprocess") as patch_rsp:
                # Non-zero return code: every external encoder "fails"
                patch_rsp.return_value = (1, None, None)
                expected_filename = os.path.join(tmpdir, "output.ogg")
                tmp_expected_filename = expected_filename + ".wav"
                saver = StreamSaverWorker(self.reader, expected_filename)
                saver.start()
                tokenizer = TokenizerWorker(saver)
                tokenizer.start_all()
                tokenizer.join()
                saver.join()
                with self.assertRaises(AudioEncodingWarning) as rt_warn:
                    saver.save_stream()

            warn_msg = "Couldn't save audio data in the desired format "
            warn_msg += "'ogg'. Either none of 'ffmpeg', 'avconv' or 'sox' "
            warn_msg += "is installed or this format is not recognized.\n"
            warn_msg += "Audio file was saved as '{}'"
            self.assertEqual(
                warn_msg.format(tmp_expected_filename), str(rt_warn.exception)
            )

            # ffmpeg and avconv share the same command-line arguments
            ffmpeg_avconv_args = [
                "-y",
                "-f",
                "wav",
                "-i",
                tmp_expected_filename,
                "-f",
                "ogg",
                expected_filename,
            ]
            expected_calls = [
                call(["ffmpeg"] + ffmpeg_avconv_args),
                call(["avconv"] + ffmpeg_avconv_args),
                call(
                    [
                        "sox",
                        "-t",
                        "wav",
                        tmp_expected_filename,
                        expected_filename,
                    ]
                ),
            ]
            self.assertEqual(patch_rsp.mock_calls, expected_calls)

            expected_region = AudioRegion.load(
                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
            )
            self.assertTrue(saver._exported)
            self.assertEqual(saver.data, bytes(expected_region))