Mercurial > hg > auditok
changeset 277:0a5374fcb700
Add test for RegionSaverWorker
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Sat, 21 Sep 2019 12:13:01 +0100 |
parents | f0252da17455 |
children | b9d52cf32c89 |
files | tests/test_workers.py |
diffstat | 1 files changed, 63 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- a/tests/test_workers.py Sat Sep 21 11:39:49 2019 +0100 +++ b/tests/test_workers.py Sat Sep 21 12:13:01 2019 +0100 @@ -71,6 +71,69 @@ # remove timestamp part and strip new line self.assertEqual(log_line[28:].strip(), exp_log_line) + def test_RegionSaverWorker(self): + filename_format = ( + "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav" + ) + with TemporaryDirectory() as tmpdir: + file = os.path.join(tmpdir, "file.log") + logger = make_logger(file=file, name="test_RegionSaverWorker") + observers = [RegionSaverWorker(filename_format, logger=logger)] + tokenizer = TokenizerWorker( + self.reader, + logger=logger, + observers=observers, + min_dur=0.3, + max_dur=2, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + eth=50, + ) + with patch("auditok.core.AudioRegion.save") as patched_save: + tokenizer.start_all() + tokenizer.join() + tokenizer._observers[0].join() + # Get logged text + with open(file) as fp: + log_lines = [ + line + for line in fp.readlines() + if line.startswith("[SAVE]") + ] + + # Asser PrintWorker ran as expected + expected_save_calls = [ + call( + filename_format.format( + id=i, start=exp[0], end=exp[1], duration=exp[1] - exp[0] + ), + None, + ) + for i, exp in enumerate(self.expected, 1) + ] + + # get calls to 'AudioRegion.save' + mock_calls = [ + c for i, c in enumerate(patched_save.mock_calls) if i % 2 == 0 + ] + self.assertEqual(mock_calls, expected_save_calls) + self.assertEqual(len(tokenizer.detections), len(self.expected)) + + log_fmt = '[SAVE]: Detection {id} saved as "{filename}"' + for i, (det, exp, log_line) in enumerate( + zip(tokenizer.detections, self.expected, log_lines), 1 + ): + start, end = exp + expected_filename = filename_format.format( + id=i, start=start, end=end, duration=end - start + ) + exp_log_line = log_fmt.format(i, expected_filename) + self.assertAlmostEqual(det.start, start) + self.assertAlmostEqual(det.end, end) + # remove timestamp part and strip new line + self.assertEqual(log_line[28:].strip(), exp_log_line) + def test_PrintWorker(self): observers = [ PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")