changeset 277:0a5374fcb700

Add test for RegionSaverWorker
author Amine Sehili <amine.sehili@gmail.com>
date Sat, 21 Sep 2019 12:13:01 +0100
parents f0252da17455
children b9d52cf32c89
files tests/test_workers.py
diffstat 1 files changed, 63 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/tests/test_workers.py	Sat Sep 21 11:39:49 2019 +0100
+++ b/tests/test_workers.py	Sat Sep 21 12:13:01 2019 +0100
@@ -71,6 +71,69 @@
             # remove timestamp part and strip new line
             self.assertEqual(log_line[28:].strip(), exp_log_line)
 
+    def test_RegionSaverWorker(self):
+        filename_format = (
+            "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav"
+        )
+        with TemporaryDirectory() as tmpdir:
+            file = os.path.join(tmpdir, "file.log")
+            logger = make_logger(file=file, name="test_RegionSaverWorker")
+            observers = [RegionSaverWorker(filename_format, logger=logger)]
+            tokenizer = TokenizerWorker(
+                self.reader,
+                logger=logger,
+                observers=observers,
+                min_dur=0.3,
+                max_dur=2,
+                max_silence=0.2,
+                drop_trailing_silence=False,
+                strict_min_dur=False,
+                eth=50,
+            )
+            with patch("auditok.core.AudioRegion.save") as patched_save:
+                tokenizer.start_all()
+                tokenizer.join()
+                tokenizer._observers[0].join()
+            # Get logged text
+            with open(file) as fp:
+                log_lines = [
+                    line
+                    for line in fp.readlines()
+                    if line.startswith("[SAVE]")
+                ]
+
+        # Asser PrintWorker ran as expected
+        expected_save_calls = [
+            call(
+                filename_format.format(
+                    id=i, start=exp[0], end=exp[1], duration=exp[1] - exp[0]
+                ),
+                None,
+            )
+            for i, exp in enumerate(self.expected, 1)
+        ]
+
+        # get calls to 'AudioRegion.save'
+        mock_calls = [
+            c for i, c in enumerate(patched_save.mock_calls) if i % 2 == 0
+        ]
+        self.assertEqual(mock_calls, expected_save_calls)
+        self.assertEqual(len(tokenizer.detections), len(self.expected))
+
+        log_fmt = '[SAVE]: Detection {id} saved as "{filename}"'
+        for i, (det, exp, log_line) in enumerate(
+            zip(tokenizer.detections, self.expected, log_lines), 1
+        ):
+            start, end = exp
+            expected_filename = filename_format.format(
+                id=i, start=start, end=end, duration=end - start
+            )
+            exp_log_line = log_fmt.format(i, expected_filename)
+            self.assertAlmostEqual(det.start, start)
+            self.assertAlmostEqual(det.end, end)
+            # remove timestamp part and strip new line
+            self.assertEqual(log_line[28:].strip(), exp_log_line)
+
     def test_PrintWorker(self):
         observers = [
             PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")