changeset 274:961f35fc09a8

Add test_workers.py - Add test for TokenizerWorker
author Amine Sehili <amine.sehili@gmail.com>
date Wed, 18 Sep 2019 20:21:10 +0200
parents fe8d97a97ea5
children a1388f0d18d3
files tests/test_workers.py
diffstat 1 files changed, 69 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/test_workers.py	Wed Sep 18 20:21:10 2019 +0200
@@ -0,0 +1,69 @@
+import os
+from unittest import TestCase
+from unittest.mock import patch, call
+from tempfile import TemporaryDirectory
+from genty import genty, genty_dataset
+from auditok import AudioDataSource
+from auditok.cmdline_util import make_logger
+from auditok.workers import (
+    TokenizerWorker,
+    StreamSaverWorker,
+    RegionSaverWorker,
+    PlayerWorker,
+    CommandLineWorker,
+    PrintWorker,
+)
+
+
+@genty
+class TestWorkers(TestCase):
+    def test_TokenizerWorker(self):
+        reader = AudioDataSource(
+            input="tests/data/test_split_10HZ_mono.raw",
+            block_dur=0.1,
+            sr=10,
+            sw=2,
+            ch=1,
+        )
+        with TemporaryDirectory() as tmpdir:
+            file = os.path.join(tmpdir, "file.log")
+            observers = [PrintWorker()]
+            logger = make_logger(file=file, name="test_TokenizerWorker")
+            tokenizer = TokenizerWorker(
+                reader,
+                observers=observers,
+                logger=logger,
+                min_dur=0.3,
+                max_dur=2,
+                max_silence=0.2,
+                drop_trailing_silence=False,
+                strict_min_dur=False,
+                eth=50,
+            )
+            with patch("builtins.print") as patched_print:
+                tokenizer.start_all()
+                tokenizer.join()
+                tokenizer._observers[0].join()
+            # Get logged text
+            with open(file) as fp:
+                log_lines = fp.readlines()
+
+        expected = [(0.2, 1.6), (1.7, 3.1), (3.4, 5.4), (5.4, 7.4), (7.4, 7.6)]
+        # Asser PrintWorker ran as expected
+        expected_print_calls = [
+            call("{:.3f} {:.3f}".format(*exp)) for exp in expected
+        ]
+        self.assertEqual(patched_print.mock_calls, expected_print_calls)
+        self.assertEqual(len(tokenizer.detections), len(expected))
+
+        log_fmt = "[DET]: Detection {} (start: {:.3f}, "
+        log_fmt += "end: {:.3f}, duration: {:.3f})"
+        for i, (det, exp, log_line) in enumerate(
+            zip(tokenizer.detections, expected, log_lines), 1
+        ):
+            start, end = exp
+            exp_log_line = log_fmt.format(i, start, end, end - start)
+            self.assertAlmostEqual(det.start, start)
+            self.assertAlmostEqual(det.end, end)
+            # remove timestamp part and strip new line
+            self.assertEqual(log_line[28:].strip(), exp_log_line)