annotate tests/test_workers.py @ 282:d40571459b37

Add test for PlayerWorker
author Amine Sehili <amine.sehili@gmail.com>
date Mon, 30 Sep 2019 21:33:16 +0100
parents 87bb649f5d3a
children d13ce50446b7
rev   line source
amine@274 1 import os
amine@274 2 from unittest import TestCase
amine@282 3 from unittest.mock import patch, call, Mock
amine@274 4 from tempfile import TemporaryDirectory
amine@274 5 from genty import genty, genty_dataset
amine@274 6 from auditok import AudioDataSource
amine@274 7 from auditok.cmdline_util import make_logger
amine@274 8 from auditok.workers import (
amine@274 9 TokenizerWorker,
amine@274 10 StreamSaverWorker,
amine@274 11 RegionSaverWorker,
amine@274 12 PlayerWorker,
amine@274 13 CommandLineWorker,
amine@274 14 PrintWorker,
amine@274 15 )
amine@274 16
amine@274 17
@genty
class TestWorkers(TestCase):
    """Tests for the auditok worker classes.

    Every test tokenizes the same 10 Hz mono raw file and checks the
    detections against ``self.expected``. Observer workers are exercised by
    patching or mocking their side effects and inspecting what they log.
    """

    # Each line written through the file logger starts with a timestamp
    # prefix of this many characters; the actual message follows it.
    _LOG_PREFIX_LEN = 28

    def setUp(self):
        self.reader = AudioDataSource(
            input="tests/data/test_split_10HZ_mono.raw",
            block_dur=0.1,
            sr=10,
            sw=2,
            ch=1,
        )
        # (start, end), in seconds, of every detection expected from the
        # tokenizer built by ``_make_tokenizer``.
        self.expected = [
            (0.2, 1.6),
            (1.7, 3.1),
            (3.4, 5.4),
            (5.4, 7.4),
            (7.4, 7.6),
        ]

    def tearDown(self):
        self.reader.close()

    def _make_tokenizer(self, **kwargs):
        """Return a TokenizerWorker over ``self.reader``.

        It uses the detection parameters shared by every test. Extra keyword
        arguments (e.g. ``logger``, ``observers``) are passed through.
        """
        return TokenizerWorker(
            self.reader,
            min_dur=0.3,
            max_dur=2,
            max_silence=0.2,
            drop_trailing_silence=False,
            strict_min_dur=False,
            eth=50,
            **kwargs
        )

    @classmethod
    def _read_log_messages(cls, logger, file, prefix):
        """Close ``logger``'s handlers and return the messages in ``file``.

        The timestamp prefix and surrounding whitespace are stripped from
        every line, and only messages starting with ``prefix`` are kept.

        Handlers are closed first for two reasons. The log file must not
        still be open when the enclosing temporary directory is removed.
        And loggers that share a name must not accumulate handlers across
        tests.
        """
        # NOTE(review): assumes make_logger returns a logging.Logger; an
        # object without ``handlers`` is simply left alone.
        for handler in list(getattr(logger, "handlers", ())):
            handler.close()
            logger.removeHandler(handler)
        with open(file) as fp:
            messages = [line[cls._LOG_PREFIX_LEN:].strip() for line in fp]
        return [message for message in messages if message.startswith(prefix)]

    def _assert_detections(self, detections):
        """Assert ``detections`` match ``self.expected`` in count and bounds."""
        self.assertEqual(len(detections), len(self.expected))
        for det, (start, end) in zip(detections, self.expected):
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)

    def test_TokenizerWorker(self):
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_TokenizerWorker")
            tokenizer = self._make_tokenizer(logger=logger)
            tokenizer.start_all()
            tokenizer.join()
            log_messages = self._read_log_messages(logger, file, "[DET]")

        self._assert_detections(tokenizer.detections)

        # Check the count explicitly: zip() below would silently skip
        # missing log lines.
        self.assertEqual(len(log_messages), len(self.expected))
        log_fmt = "[DET]: Detection {} (start: {:.3f}, "
        log_fmt += "end: {:.3f}, duration: {:.3f})"
        for i, ((start, end), message) in enumerate(
            zip(self.expected, log_messages), 1
        ):
            exp_message = log_fmt.format(i, start, end, end - start)
            self.assertEqual(message, exp_message)

    def test_PlayerWorker(self):
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_PlayerWorker")
            player_mock = Mock()
            observers = [PlayerWorker(player_mock, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()
            log_messages = self._read_log_messages(logger, file, "[PLAY]")

        self.assertTrue(player_mock.play.called)
        self._assert_detections(tokenizer.detections)

        self.assertEqual(len(log_messages), len(self.expected))
        log_fmt = "[PLAY]: Detection {id} played"
        for i, message in enumerate(log_messages, 1):
            self.assertEqual(message, log_fmt.format(id=i))

    def test_RegionSaverWorker(self):
        filename_format = (
            "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav"
        )
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_RegionSaverWorker")
            observers = [RegionSaverWorker(filename_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.core.AudioRegion.save") as patched_save:
                # Return the requested file name instead of a Mock. The
                # logged name is then a plain string, whether the worker logs
                # the name it built or the value returned by ``save``.
                patched_save.side_effect = (
                    lambda filename, *args, **kwargs: filename
                )
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            log_messages = self._read_log_messages(logger, file, "[SAVE]")

        expected_filenames = [
            filename_format.format(
                id=i, start=start, end=end, duration=end - start
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        # call_args_list holds only direct calls to ``save``, unlike
        # mock_calls which would also record uses of its return value.
        expected_save_calls = [
            call(filename, None) for filename in expected_filenames
        ]
        self.assertEqual(patched_save.call_args_list, expected_save_calls)
        self._assert_detections(tokenizer.detections)

        self.assertEqual(len(log_messages), len(self.expected))
        log_fmt = "[SAVE]: Detection {id} saved as '{filename}'"
        for i, (filename, message) in enumerate(
            zip(expected_filenames, log_messages), 1
        ):
            exp_message = log_fmt.format(id=i, filename=filename)
            self.assertEqual(message, exp_message)

    def test_CommandLineWorker(self):
        command_format = "do nothing with"
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_CommandLineWorker")
            observers = [CommandLineWorker(command_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.workers.os.system") as patched_os_system:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            log_messages = self._read_log_messages(logger, file, "[COMMAND]")

        # ``command_format`` contains no placeholders, so every detection is
        # expected to run exactly this string.
        expected_system_calls = [call(command_format) for _ in self.expected]
        self.assertEqual(
            patched_os_system.call_args_list, expected_system_calls
        )
        self._assert_detections(tokenizer.detections)

        self.assertEqual(len(log_messages), len(self.expected))
        log_fmt = "[COMMAND]: Detection {id} command '{command}'"
        for i, message in enumerate(log_messages, 1):
            exp_message = log_fmt.format(id=i, command=command_format)
            self.assertEqual(message, exp_message)

    def test_PrintWorker(self):
        observers = [
            PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")
        ]
        tokenizer = self._make_tokenizer(observers=observers)
        with patch("builtins.print") as patched_print:
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()

        # Each detection is expected to be printed once, with times rendered
        # to three decimals.
        expected_print_calls = [
            call(
                "[{}] {:.3f} {:.3f}, dur: {:.3f}".format(
                    i, start, end, end - start
                )
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        self.assertEqual(patched_print.call_args_list, expected_print_calls)
        self._assert_detections(tokenizer.detections)