"""Tests for the observer/worker classes in ``auditok.workers``.

Each test builds a :class:`TokenizerWorker` over a small 10 Hz raw audio
fixture and verifies both the detections produced and the side effects of
the attached observer worker (logging, playback, region saving, shell
commands, printing).
"""
import os
from unittest import TestCase
from unittest.mock import patch, call, Mock
from tempfile import TemporaryDirectory
from genty import genty, genty_dataset
from auditok import AudioDataSource
from auditok.cmdline_util import make_logger
from auditok.workers import (
    TokenizerWorker,
    StreamSaverWorker,
    RegionSaverWorker,
    PlayerWorker,
    CommandLineWorker,
    PrintWorker,
)


@genty
class TestWorkers(TestCase):
    def setUp(self):
        # 10 Hz, 16-bit (sw=2) mono fixture, read in 0.1 s blocks
        # (i.e. one sample per analysis block).
        self.reader = AudioDataSource(
            input="tests/data/test_split_10HZ_mono.raw",
            block_dur=0.1,
            sr=10,
            sw=2,
            ch=1,
        )
        # (start, end) in seconds of each detection expected from the
        # fixture with the tokenizer parameters used by every test below.
        self.expected = [
            (0.2, 1.6),
            (1.7, 3.1),
            (3.4, 5.4),
            (5.4, 7.4),
            (7.4, 7.6),
        ]

    def tearDown(self):
        self.reader.close()

    def test_TokenizerWorker(self):
        """Tokenizer alone: detections match and each one is logged."""
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_TokenizerWorker")
            tokenizer = TokenizerWorker(
                self.reader,
                logger=logger,
                min_dur=0.3,
                max_dur=2,
                max_silence=0.2,
                drop_trailing_silence=False,
                strict_min_dur=False,
                eth=50,
            )
            tokenizer.start_all()
            tokenizer.join()
            # Get logged text
            with open(file) as fp:
                log_lines = fp.readlines()

        log_fmt = "[DET]: Detection {} (start: {:.3f}, "
        log_fmt += "end: {:.3f}, duration: {:.3f})"
        self.assertEqual(len(tokenizer.detections), len(self.expected))
        for i, (det, exp, log_line) in enumerate(
            zip(tokenizer.detections, self.expected, log_lines), 1
        ):
            start, end = exp
            exp_log_line = log_fmt.format(i, start, end, end - start)
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)
            # Remove timestamp part (first 28 chars) and strip newline
            self.assertEqual(log_line[28:].strip(), exp_log_line)

    def test_PlayerWorker(self):
        """PlayerWorker observer: player is called and playback is logged."""
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            # NOTE: logger name was previously copy-pasted from the
            # RegionSaverWorker test; the name is never asserted on.
            logger = make_logger(file=file, name="test_PlayerWorker")
            player_mock = Mock()
            observers = [PlayerWorker(player_mock, logger=logger)]
            tokenizer = TokenizerWorker(
                self.reader,
                logger=logger,
                observers=observers,
                min_dur=0.3,
                max_dur=2,
                max_silence=0.2,
                drop_trailing_silence=False,
                strict_min_dur=False,
                eth=50,
            )
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()
            # Keep only the player's log records
            with open(file) as fp:
                log_lines = [
                    line
                    for line in fp.readlines()
                    if line.startswith("[PLAY]")
                ]
        self.assertTrue(player_mock.play.called)

        self.assertEqual(len(tokenizer.detections), len(self.expected))
        log_fmt = "[PLAY]: Detection {id} played"
        for i, (det, exp, log_line) in enumerate(
            zip(tokenizer.detections, self.expected, log_lines), 1
        ):
            start, end = exp
            exp_log_line = log_fmt.format(id=i)
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)
            # Remove timestamp part and strip newline
            self.assertEqual(log_line[28:].strip(), exp_log_line)

    def test_RegionSaverWorker(self):
        """RegionSaverWorker observer: regions saved with formatted names."""
        filename_format = (
            "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav"
        )
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_RegionSaverWorker")
            observers = [RegionSaverWorker(filename_format, logger=logger)]
            tokenizer = TokenizerWorker(
                self.reader,
                logger=logger,
                observers=observers,
                min_dur=0.3,
                max_dur=2,
                max_silence=0.2,
                drop_trailing_silence=False,
                strict_min_dur=False,
                eth=50,
            )
            with patch("auditok.core.AudioRegion.save") as patched_save:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            # Keep only the saver's log records
            with open(file) as fp:
                log_lines = [
                    line
                    for line in fp.readlines()
                    if line.startswith("[SAVE]")
                ]

        # Assert RegionSaverWorker ran as expected
        expected_save_calls = [
            call(
                filename_format.format(
                    id=i, start=exp[0], end=exp[1], duration=exp[1] - exp[0]
                ),
                None,
            )
            for i, exp in enumerate(self.expected, 1)
        ]

        # Get calls to 'AudioRegion.save'. mock_calls interleaves the
        # direct save(...) calls with calls recorded on save's return
        # value, hence keeping even indices only — TODO confirm against
        # RegionSaverWorker implementation.
        mock_calls = [
            c for i, c in enumerate(patched_save.mock_calls) if i % 2 == 0
        ]
        self.assertEqual(mock_calls, expected_save_calls)
        self.assertEqual(len(tokenizer.detections), len(self.expected))

        # Fixed: placeholder '{filename}' was corrupted and the format
        # call used positional args against named fields, which raised
        # KeyError: 'id' at runtime.
        log_fmt = "[SAVE]: Detection {id} saved as '{filename}'"
        for i, (det, exp, log_line) in enumerate(
            zip(tokenizer.detections, self.expected, log_lines), 1
        ):
            start, end = exp
            expected_filename = filename_format.format(
                id=i, start=start, end=end, duration=end - start
            )
            exp_log_line = log_fmt.format(id=i, filename=expected_filename)
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)
            # Remove timestamp part and strip newline
            self.assertEqual(log_line[28:].strip(), exp_log_line)

    def test_CommandLineWorker(self):
        """CommandLineWorker observer: command run once per detection."""
        command_format = "do nothing with"
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_CommandLineWorker")
            observers = [CommandLineWorker(command_format, logger=logger)]
            tokenizer = TokenizerWorker(
                self.reader,
                logger=logger,
                observers=observers,
                min_dur=0.3,
                max_dur=2,
                max_silence=0.2,
                drop_trailing_silence=False,
                strict_min_dur=False,
                eth=50,
            )
            with patch("auditok.workers.os.system") as patched_os_system:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            # Keep only the command worker's log records
            with open(file) as fp:
                log_lines = [
                    line
                    for line in fp.readlines()
                    if line.startswith("[COMMAND]")
                ]

        # Assert CommandLineWorker ran as expected
        expected_save_calls = [call(command_format) for _ in self.expected]
        self.assertEqual(patched_os_system.mock_calls, expected_save_calls)
        self.assertEqual(len(tokenizer.detections), len(self.expected))
        log_fmt = "[COMMAND]: Detection {id} command '{command}'"
        for i, (det, exp, log_line) in enumerate(
            zip(tokenizer.detections, self.expected, log_lines), 1
        ):
            start, end = exp
            # Fixed: was log_fmt.format(i, command_format) — positional
            # args against named fields raised KeyError: 'id'.
            exp_log_line = log_fmt.format(id=i, command=command_format)
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)
            # Remove timestamp part and strip newline
            self.assertEqual(log_line[28:].strip(), exp_log_line)

    def test_PrintWorker(self):
        """PrintWorker observer: one formatted print() call per detection."""
        observers = [
            PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")
        ]
        tokenizer = TokenizerWorker(
            self.reader,
            observers=observers,
            min_dur=0.3,
            max_dur=2,
            max_silence=0.2,
            drop_trailing_silence=False,
            strict_min_dur=False,
            eth=50,
        )
        with patch("builtins.print") as patched_print:
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()

        # Assert PrintWorker ran as expected
        expected_print_calls = [
            call(
                "[{}] {:.3f} {:.3f}, dur: {:.3f}".format(
                    i, *exp, exp[1] - exp[0]
                )
            )
            for i, exp in enumerate(self.expected, 1)
        ]
        self.assertEqual(patched_print.mock_calls, expected_print_calls)
        self.assertEqual(len(tokenizer.detections), len(self.expected))
        for det, exp in zip(tokenizer.detections, self.expected):
            start, end = exp
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)