Mercurial > hg > auditok
view tests/test_workers.py @ 275:a1388f0d18d3
Add test for PrintWorker
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Fri, 20 Sep 2019 21:02:08 +0200 |
parents | 961f35fc09a8 |
children | 0a5374fcb700 |
line wrap: on
line source
import os
from unittest import TestCase
from unittest.mock import patch, call
from tempfile import TemporaryDirectory
from genty import genty, genty_dataset
from auditok import AudioDataSource
from auditok.cmdline_util import make_logger
from auditok.workers import (
    TokenizerWorker,
    StreamSaverWorker,
    RegionSaverWorker,
    PlayerWorker,
    CommandLineWorker,
    PrintWorker,
)


@genty
class TestWorkers(TestCase):
    """Tests for the worker classes of ``auditok.workers``."""

    def setUp(self):
        """Create the shared audio reader and the expected detections."""
        self.reader = AudioDataSource(
            input="tests/data/test_split_10HZ_mono.raw",
            block_dur=0.1,
            sr=10,
            sw=2,
            ch=1,
        )
        # Expected (start, end) pairs of the detections, in order
        # (presumably expressed in seconds -- confirm against auditok).
        self.expected = [
            (0.2, 1.6),
            (1.7, 3.1),
            (3.4, 5.4),
            (5.4, 7.4),
            (7.4, 7.6),
        ]

    def tearDown(self):
        """Close the audio reader opened in ``setUp``."""
        self.reader.close()

    def _assert_detections(self, detections):
        """Check that *detections* match ``self.expected`` in count and bounds."""
        self.assertEqual(len(detections), len(self.expected))
        for det, (start, end) in zip(detections, self.expected):
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)

    def test_TokenizerWorker(self):
        """TokenizerWorker records the expected detections and logs each one."""
        with TemporaryDirectory() as tmpdir:
            log_path = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=log_path, name="test_TokenizerWorker")
            tokenizer = TokenizerWorker(
                self.reader,
                logger=logger,
                min_dur=0.3,
                max_dur=2,
                max_silence=0.2,
                drop_trailing_silence=False,
                strict_min_dur=False,
                eth=50,
            )
            tokenizer.start_all()
            tokenizer.join()
            # Get logged text
            with open(log_path) as fp:
                log_lines = fp.readlines()

        log_fmt = "[DET]: Detection {} (start: {:.3f}, "
        log_fmt += "end: {:.3f}, duration: {:.3f})"

        self._assert_detections(tokenizer.detections)
        # zip() stops at the shortest input, so without this guard a short
        # or empty log file would let the log checks below pass vacuously.
        self.assertGreaterEqual(len(log_lines), len(self.expected))
        for i, ((start, end), log_line) in enumerate(
            zip(self.expected, log_lines), 1
        ):
            exp_log_line = log_fmt.format(i, start, end, end - start)
            # remove timestamp part and strip new line
            self.assertEqual(log_line[28:].strip(), exp_log_line)

    def test_PrintWorker(self):
        """PrintWorker prints one formatted line per detection, in order."""
        print_worker = PrintWorker(
            print_format="[{id}] {start} {end}, dur: {duration}"
        )
        tokenizer = TokenizerWorker(
            self.reader,
            observers=[print_worker],
            min_dur=0.3,
            max_dur=2,
            max_silence=0.2,
            drop_trailing_silence=False,
            strict_min_dur=False,
            eth=50,
        )
        with patch("builtins.print") as patched_print:
            tokenizer.start_all()
            tokenizer.join()
            # Join the observer through our own reference instead of reaching
            # into the tokenizer's private ``_observers`` list.
            print_worker.join()

        # Assert PrintWorker ran as expected
        expected_print_calls = [
            call(
                "[{}] {:.3f} {:.3f}, dur: {:.3f}".format(
                    i, start, end, end - start
                )
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        self.assertEqual(patched_print.mock_calls, expected_print_calls)
        self._assert_detections(tokenizer.detections)