# HG changeset patch # User Amine Sehili # Date 1569006128 -7200 # Node ID a1388f0d18d3a29de00e171d0fa29729c42f0963 # Parent 961f35fc09a8f90e626ec9dcad69b7ca09998895 Add test for PrintWorker diff -r 961f35fc09a8 -r a1388f0d18d3 tests/test_workers.py --- a/tests/test_workers.py Wed Sep 18 20:21:10 2019 +0200 +++ b/tests/test_workers.py Fri Sep 20 21:02:08 2019 +0200 @@ -17,21 +17,33 @@ @genty class TestWorkers(TestCase): - def test_TokenizerWorker(self): - reader = AudioDataSource( + def setUp(self): + + self.reader = AudioDataSource( input="tests/data/test_split_10HZ_mono.raw", block_dur=0.1, sr=10, sw=2, ch=1, ) + self.expected = [ + (0.2, 1.6), + (1.7, 3.1), + (3.4, 5.4), + (5.4, 7.4), + (7.4, 7.6), + ] + + def tearDown(self): + self.reader.close() + + def test_TokenizerWorker(self): + with TemporaryDirectory() as tmpdir: file = os.path.join(tmpdir, "file.log") - observers = [PrintWorker()] logger = make_logger(file=file, name="test_TokenizerWorker") tokenizer = TokenizerWorker( - reader, - observers=observers, + self.reader, logger=logger, min_dur=0.3, max_dur=2, @@ -40,26 +52,17 @@ strict_min_dur=False, eth=50, ) - with patch("builtins.print") as patched_print: - tokenizer.start_all() - tokenizer.join() - tokenizer._observers[0].join() + tokenizer.start_all() + tokenizer.join() # Get logged text with open(file) as fp: log_lines = fp.readlines() - expected = [(0.2, 1.6), (1.7, 3.1), (3.4, 5.4), (5.4, 7.4), (7.4, 7.6)] - # Asser PrintWorker ran as expected - expected_print_calls = [ - call("{:.3f} {:.3f}".format(*exp)) for exp in expected - ] - self.assertEqual(patched_print.mock_calls, expected_print_calls) - self.assertEqual(len(tokenizer.detections), len(expected)) - log_fmt = "[DET]: Detection {} (start: {:.3f}, " log_fmt += "end: {:.3f}, duration: {:.3f})" + self.assertEqual(len(tokenizer.detections), len(self.expected)) for i, (det, exp, log_line) in enumerate( - zip(tokenizer.detections, expected, log_lines), 1 + zip(tokenizer.detections, self.expected, log_lines), 1 ): start, end = exp exp_log_line = log_fmt.format(i, start, end, end - start) @@ -67,3 +70,39 @@ self.assertAlmostEqual(det.end, end) # remove timestamp part and strip new line self.assertEqual(log_line[28:].strip(), exp_log_line) + + def test_PrintWorker(self): + observers = [ + PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}") + ] + tokenizer = TokenizerWorker( + self.reader, + observers=observers, + min_dur=0.3, + max_dur=2, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + eth=50, + ) + with patch("builtins.print") as patched_print: + tokenizer.start_all() + tokenizer.join() + tokenizer._observers[0].join() + + # Asser PrintWorker ran as expected + expected_print_calls = [ + call( + "[{}] {:.3f} {:.3f}, dur: {:.3f}".format( + i, *exp, exp[1] - exp[0] + ) + ) + for i, exp in enumerate(self.expected, 1) + ] + self.assertEqual(patched_print.mock_calls, expected_print_calls) + self.assertEqual(len(tokenizer.detections), len(self.expected)) + + for det, exp in zip(tokenizer.detections, self.expected): + start, end = exp + self.assertAlmostEqual(det.start, start) + self.assertAlmostEqual(det.end, end)