import os
from unittest import TestCase
from unittest.mock import patch, call, Mock
from tempfile import TemporaryDirectory
from genty import genty, genty_dataset
from auditok import AudioRegion, AudioDataSource
from auditok.exceptions import AudioEncodingWarning
from auditok.cmdline_util import make_logger
from auditok.workers import (
    TokenizerWorker,
    StreamSaverWorker,
    RegionSaverWorker,
    PlayerWorker,
    CommandLineWorker,
    PrintWorker,
)

# Raw audio fixture used by every test in this module; it is always read
# with sr=10, sw=2, ch=1.
_TEST_AUDIO_FILE = "tests/data/test_split_10HZ_mono.raw"

# Number of leading characters (the timestamp part) that each line of a log
# file written by the logger from ``make_logger`` carries before the message.
_LOG_PREFIX_LEN = 28


def _read_log_messages(filename, tag=None):
    """Return the messages written to the log file *filename*.

    The first ``_LOG_PREFIX_LEN`` characters (the timestamp part) and the
    surrounding whitespace, including the trailing newline, are removed
    from each line. When *tag* is given (e.g. ``"[PLAY]"``), only messages
    starting with it are returned. The tag is matched *after* the timestamp
    is removed: raw lines begin with the timestamp, not with the tag, so
    testing the raw line would filter out every line.
    """
    with open(filename) as fp:
        messages = [line[_LOG_PREFIX_LEN:].strip() for line in fp]
    if tag is not None:
        messages = [msg for msg in messages if msg.startswith(tag)]
    return messages


@genty
class TestWorkers(TestCase):
    """Tests for the worker classes of ``auditok.workers``.

    Every tokenizer-based test uses the same detection settings on the same
    fixture, so its detections are expected to match ``self.expected``.
    """

    def setUp(self):
        self.reader = AudioDataSource(
            input=_TEST_AUDIO_FILE,
            block_dur=0.1,
            sr=10,
            sw=2,
            ch=1,
        )
        # (start, end) bounds of each expected detection, in order.
        self.expected = [
            (0.2, 1.6),
            (1.7, 3.1),
            (3.4, 5.4),
            (5.4, 7.4),
            (7.4, 7.6),
        ]

    def tearDown(self):
        self.reader.close()

    def _make_tokenizer(self, **kwargs):
        """Return a ``TokenizerWorker`` reading ``self.reader``.

        All tokenizer-based tests share the detection settings below;
        *kwargs* (such as ``logger`` or ``observers``) are passed on to
        ``TokenizerWorker`` in addition to them. Options not given here are
        not passed at all, so ``TokenizerWorker``'s own defaults apply.
        """
        params = dict(
            min_dur=0.3,
            max_dur=2,
            max_silence=0.2,
            drop_trailing_silence=False,
            strict_min_dur=False,
            eth=50,
        )
        params.update(kwargs)
        return TokenizerWorker(self.reader, **params)

    def _load_test_region(self):
        """Load the whole test fixture as a single ``AudioRegion``."""
        return AudioRegion.load(_TEST_AUDIO_FILE, sr=10, sw=2, ch=1)

    def _assert_detections(self, detections):
        """Assert that *detections* match ``self.expected`` one for one."""
        self.assertEqual(len(detections), len(self.expected))
        for det, (start, end) in zip(detections, self.expected):
            self.assertAlmostEqual(det.start, start)
            self.assertAlmostEqual(det.end, end)

    def test_TokenizerWorker(self):
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_TokenizerWorker")
            tokenizer = self._make_tokenizer(logger=logger)
            tokenizer.start_all()
            tokenizer.join()
            log_lines = _read_log_messages(file, tag="[DET]")

            self._assert_detections(tokenizer.detections)
            log_fmt = "[DET]: Detection {} (start: {:.3f}, "
            log_fmt += "end: {:.3f}, duration: {:.3f})"
            expected_log_lines = [
                log_fmt.format(i, start, end, end - start)
                for i, (start, end) in enumerate(self.expected, 1)
            ]
            self.assertEqual(log_lines, expected_log_lines)

    def test_PlayerWorker(self):
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            # Logger name unique to this test (it used to reuse
            # "test_RegionSaverWorker"); if make_logger looks loggers up by
            # name, a shared name could mix both tests' log handlers.
            logger = make_logger(file=file, name="test_PlayerWorker")
            player_mock = Mock()
            observers = [PlayerWorker(player_mock, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()
            # The tokenizer logs to the same file, so keep [PLAY] lines only.
            log_lines = _read_log_messages(file, tag="[PLAY]")

            self.assertTrue(player_mock.play.called)
            self._assert_detections(tokenizer.detections)
            log_fmt = "[PLAY]: Detection {id} played"
            expected_log_lines = [
                log_fmt.format(id=i)
                for i in range(1, len(self.expected) + 1)
            ]
            self.assertEqual(log_lines, expected_log_lines)

    def test_RegionSaverWorker(self):
        filename_format = (
            "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav"
        )
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_RegionSaverWorker")
            observers = [RegionSaverWorker(filename_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.core.AudioRegion.save") as patched_save:
                # Make the mocked 'save' return the file name it was asked
                # to write (its first positional argument), so anything the
                # worker derives from the return value is a plain string
                # instead of a MagicMock.
                patched_save.side_effect = (
                    lambda filename, *args, **kwargs: filename
                )
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            log_lines = _read_log_messages(file, tag="[SAVE]")

            expected_filenames = [
                filename_format.format(
                    id=i, start=start, end=end, duration=end - start
                )
                for i, (start, end) in enumerate(self.expected, 1)
            ]
            # call_args_list holds only the direct calls to
            # 'AudioRegion.save'; unlike mock_calls it does not also list
            # calls made on the objects the mock returned.
            expected_save_calls = [
                call(filename, None) for filename in expected_filenames
            ]
            self.assertEqual(patched_save.call_args_list, expected_save_calls)
            self._assert_detections(tokenizer.detections)

            log_fmt = "[SAVE]: Detection {id} saved as '{filename}'"
            expected_log_lines = [
                log_fmt.format(id=i, filename=filename)
                for i, filename in enumerate(expected_filenames, 1)
            ]
            self.assertEqual(log_lines, expected_log_lines)

    def test_CommandLineWorker(self):
        command_format = "do nothing with"
        with TemporaryDirectory() as tmpdir:
            file = os.path.join(tmpdir, "file.log")
            logger = make_logger(file=file, name="test_CommandLineWorker")
            observers = [CommandLineWorker(command_format, logger=logger)]
            tokenizer = self._make_tokenizer(
                logger=logger, observers=observers
            )
            with patch("auditok.workers.os.system") as patched_os_system:
                tokenizer.start_all()
                tokenizer.join()
                tokenizer._observers[0].join()
            log_lines = _read_log_messages(file, tag="[COMMAND]")

            # One 'os.system' call per detection, with the command as given.
            expected_system_calls = [
                call(command_format) for _ in self.expected
            ]
            self.assertEqual(
                patched_os_system.mock_calls, expected_system_calls
            )
            self._assert_detections(tokenizer.detections)

            log_fmt = "[COMMAND]: Detection {id} command '{command}'"
            expected_log_lines = [
                log_fmt.format(id=i, command=command_format)
                for i in range(1, len(self.expected) + 1)
            ]
            self.assertEqual(log_lines, expected_log_lines)

    def test_PrintWorker(self):
        observers = [
            PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")
        ]
        tokenizer = self._make_tokenizer(observers=observers)
        with patch("builtins.print") as patched_print:
            tokenizer.start_all()
            tokenizer.join()
            tokenizer._observers[0].join()

        # Assert PrintWorker ran as expected
        expected_print_calls = [
            call(
                "[{}] {:.3f} {:.3f}, dur: {:.3f}".format(
                    i, start, end, end - start
                )
            )
            for i, (start, end) in enumerate(self.expected, 1)
        ]
        self.assertEqual(patched_print.mock_calls, expected_print_calls)
        self._assert_detections(tokenizer.detections)

    def test_StreamSaverWorker_wav(self):
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output.wav")
            saver = StreamSaverWorker(self.reader, expected_filename)
            saver.start()

            tokenizer = TokenizerWorker(saver)
            tokenizer.start_all()
            tokenizer.join()
            saver.join()

            output_filename = saver.save_stream()
            source_region = self._load_test_region()
            saved_region = AudioRegion.load(output_filename)

            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(source_region, saved_region)
            self.assertEqual(saver.data, bytes(saved_region))

    def test_StreamSaverWorker_raw(self):
        with TemporaryDirectory() as tmpdir:
            expected_filename = os.path.join(tmpdir, "output")
            saver = StreamSaverWorker(
                self.reader, expected_filename, export_format="raw"
            )
            saver.start()

            tokenizer = TokenizerWorker(saver)
            tokenizer.start_all()
            tokenizer.join()
            saver.join()

            output_filename = saver.save_stream()
            source_region = self._load_test_region()
            saved_region = AudioRegion.load(
                output_filename, sr=10, sw=2, ch=1, audio_format="raw"
            )

            self.assertEqual(output_filename, expected_filename)
            self.assertEqual(source_region, saved_region)
            self.assertEqual(saver.data, bytes(saved_region))

    def test_StreamSaverWorker_encode_audio(self):
        with TemporaryDirectory() as tmpdir:
            with patch("auditok.workers._run_subprocess") as patch_rsp:
                # Make every external encoder invocation report a failure.
                patch_rsp.return_value = (1, None, None)
                expected_filename = os.path.join(tmpdir, "output.ogg")
                # File the worker falls back to when encoding fails.
                tmp_expected_filename = expected_filename + ".wav"
                saver = StreamSaverWorker(self.reader, expected_filename)
                saver.start()

                tokenizer = TokenizerWorker(saver)
                tokenizer.start_all()
                tokenizer.join()
                saver.join()

                with self.assertRaises(AudioEncodingWarning) as rt_warn:
                    saver.save_stream()
                warn_msg = "Couldn't save audio data in the desired format "
                warn_msg += "'ogg'. Either none of 'ffmpeg', 'avconv' or 'sox' "
                warn_msg += "is installed or this format is not recognized.\n"
                warn_msg += "Audio file was saved as '{}'"
                self.assertEqual(
                    warn_msg.format(tmp_expected_filename), str(rt_warn.exception)
                )

                # Each encoder is tried in turn: ffmpeg, avconv, then sox.
                ffmpeg_avconv_args = [
                    "-y",
                    "-f",
                    "wav",
                    "-i",
                    tmp_expected_filename,
                    "-f",
                    "ogg",
                    expected_filename,
                ]
                expected_calls = [
                    call(["ffmpeg"] + ffmpeg_avconv_args),
                    call(["avconv"] + ffmpeg_avconv_args),
                    call(
                        [
                            "sox",
                            "-t",
                            "wav",
                            tmp_expected_filename,
                            expected_filename,
                        ]
                    ),
                ]
                self.assertEqual(patch_rsp.mock_calls, expected_calls)

                source_region = self._load_test_region()
                self.assertTrue(saver._exported)
                self.assertEqual(saver.data, bytes(source_region))