changeset 275:a1388f0d18d3
Add test for PrintWorker
author      Amine Sehili <amine.sehili@gmail.com>
date        Fri, 20 Sep 2019 21:02:08 +0200
parents     961f35fc09a8
children    f0252da17455
files       tests/test_workers.py
diffstat    1 files changed, 57 insertions(+), 18 deletions(-)
--- a/tests/test_workers.py	Wed Sep 18 20:21:10 2019 +0200
+++ b/tests/test_workers.py	Fri Sep 20 21:02:08 2019 +0200
@@ -17,21 +17,33 @@
 @genty
 class TestWorkers(TestCase):

-    def test_TokenizerWorker(self):
-        reader = AudioDataSource(
+    def setUp(self):
+
+        self.reader = AudioDataSource(
             input="tests/data/test_split_10HZ_mono.raw",
             block_dur=0.1,
             sr=10,
             sw=2,
             ch=1,
         )
+        self.expected = [
+            (0.2, 1.6),
+            (1.7, 3.1),
+            (3.4, 5.4),
+            (5.4, 7.4),
+            (7.4, 7.6),
+        ]
+
+    def tearDown(self):
+        self.reader.close()
+
+    def test_TokenizerWorker(self):
+
         with TemporaryDirectory() as tmpdir:
             file = os.path.join(tmpdir, "file.log")
-            observers = [PrintWorker()]
             logger = make_logger(file=file, name="test_TokenizerWorker")
             tokenizer = TokenizerWorker(
-                reader,
-                observers=observers,
+                self.reader,
                 logger=logger,
                 min_dur=0.3,
                 max_dur=2,
@@ -40,26 +52,17 @@
                 strict_min_dur=False,
                 eth=50,
             )
-            with patch("builtins.print") as patched_print:
-                tokenizer.start_all()
-                tokenizer.join()
-                tokenizer._observers[0].join()
+            tokenizer.start_all()
+            tokenizer.join()
             # Get logged text
             with open(file) as fp:
                 log_lines = fp.readlines()

-        expected = [(0.2, 1.6), (1.7, 3.1), (3.4, 5.4), (5.4, 7.4), (7.4, 7.6)]
-        # Assert PrintWorker ran as expected
-        expected_print_calls = [
-            call("{:.3f} {:.3f}".format(*exp)) for exp in expected
-        ]
-        self.assertEqual(patched_print.mock_calls, expected_print_calls)
-        self.assertEqual(len(tokenizer.detections), len(expected))
-
         log_fmt = "[DET]: Detection {} (start: {:.3f}, "
         log_fmt += "end: {:.3f}, duration: {:.3f})"
+        self.assertEqual(len(tokenizer.detections), len(self.expected))
         for i, (det, exp, log_line) in enumerate(
-            zip(tokenizer.detections, expected, log_lines), 1
+            zip(tokenizer.detections, self.expected, log_lines), 1
         ):
             start, end = exp
             exp_log_line = log_fmt.format(i, start, end, end - start)
@@ -67,3 +70,39 @@
             self.assertAlmostEqual(det.end, end)
             # remove timestamp part and strip new line
             self.assertEqual(log_line[28:].strip(), exp_log_line)
+
+    def test_PrintWorker(self):
+        observers = [
+            PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")
+        ]
+        tokenizer = TokenizerWorker(
+            self.reader,
+            observers=observers,
+            min_dur=0.3,
+            max_dur=2,
+            max_silence=0.2,
+            drop_trailing_silence=False,
+            strict_min_dur=False,
+            eth=50,
+        )
+        with patch("builtins.print") as patched_print:
+            tokenizer.start_all()
+            tokenizer.join()
+            tokenizer._observers[0].join()
+
+        # Assert PrintWorker ran as expected
+        expected_print_calls = [
+            call(
+                "[{}] {:.3f} {:.3f}, dur: {:.3f}".format(
+                    i, *exp, exp[1] - exp[0]
+                )
+            )
+            for i, exp in enumerate(self.expected, 1)
+        ]
+        self.assertEqual(patched_print.mock_calls, expected_print_calls)
+        self.assertEqual(len(tokenizer.detections), len(self.expected))
+
+        for det, exp in zip(tokenizer.detections, self.expected):
+            start, end = exp
+            self.assertAlmostEqual(det.start, start)
+            self.assertAlmostEqual(det.end, end)
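
The new test_PrintWorker checks PrintWorker's output by patching builtins.print and comparing the mock's recorded calls against the expected strings. For readers unfamiliar with that pattern, below is a minimal, self-contained sketch using only the standard library's unittest.mock; the announce helper is a hypothetical stand-in for PrintWorker's output loop, not part of auditok:

    from unittest.mock import patch, call

    def announce(detections):
        # Hypothetical stand-in: print one line per detection, mimicking
        # the "[{id}] {start} {end}, dur: {duration}" print_format above.
        for i, (start, end) in enumerate(detections, 1):
            print("[{}] {:.3f} {:.3f}, dur: {:.3f}".format(i, start, end, end - start))

    detections = [(0.2, 1.6), (1.7, 3.1)]
    with patch("builtins.print") as patched_print:
        # Inside this block, print is a MagicMock that records each call.
        announce(detections)

    # Each recorded call must match the corresponding expected call(...) object.
    expected = [
        call("[{}] {:.3f} {:.3f}, dur: {:.3f}".format(i, s, e, e - s))
        for i, (s, e) in enumerate(detections, 1)
    ]
    assert patched_print.mock_calls == expected

Comparing patched_print.mock_calls to a list of call(...) objects asserts both the content and the order of the printed lines, which is exactly what the test above relies on.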