# HG changeset patch
# User Amine Sehili
# Date 1570893797 -3600
# Node ID bc963239143fa7cb4c7f7d49ce59dfc567f3531d
# Parent  334c8760e80fa4d9bf7b30d5f0e0ef379b6ba554
Refactor AudioRegion.load and add tests

diff -r 334c8760e80f -r bc963239143f auditok/core.py
--- a/auditok/core.py	Fri Oct 11 20:58:49 2019 +0100
+++ b/auditok/core.py	Sat Oct 12 16:23:17 2019 +0100
@@ -273,6 +273,73 @@
     return AudioRegion(data, sampling_rate, sample_width, channels, meta)
 
 
+def _read_chunks_online(max_read, **kwargs):
+    """Read audio from the microphone in successive blocks.
+
+    `max_read` and `kwargs` are forwarded to an `AudioReader` created with
+    None as its first argument; blocks are read until `read()` returns
+    None or the user presses Ctrl+C.
+
+    Returns a tuple `(data, sampling_rate, sample_width, channels)` where
+    `data` is the concatenation of all blocks read.
+    """
+    reader = AudioReader(None, block_dur=0.5, max_read=max_read, **kwargs)
+    reader.open()
+    data = []
+    try:
+        while True:
+            frame = reader.read()
+            if frame is None:
+                break
+            data.append(frame)
+    except KeyboardInterrupt:
+        # Stop data acquisition from microphone when pressing
+        # Ctrl+C on a [i]python session or a notebook
+        pass
+    finally:
+        # Close the reader even when read() raises something else.
+        reader.close()
+    return (
+        b"".join(data),
+        reader.sampling_rate,
+        reader.sample_width,
+        reader.channels,
+    )
+
+
+def _read_offline(input, skip=0, max_read=None, **kwargs):
+    """Read audio data from `input` through `get_audio_source`.
+
+    `skip` and `max_read` are multiplied by the source's sampling rate to
+    get a number of samples; a `max_read` of None or a negative value is
+    passed to `read` as None.
+
+    Returns a tuple `(data, sampling_rate, sample_width, channels)`.
+    """
+    audio_source = get_audio_source(input, **kwargs)
+    audio_source.open()
+    try:
+        if skip is not None and skip > 0:
+            skip_samples = round(skip * audio_source.sampling_rate)
+            audio_source.read(skip_samples)
+        if max_read is not None:
+            if max_read < 0:
+                max_read = None
+            else:
+                max_read = round(max_read * audio_source.sampling_rate)
+        data = audio_source.read(max_read)
+    finally:
+        # The code this helper replaces closed the source after reading;
+        # keep doing so, and also when read() raises.
+        audio_source.close()
+    return (
+        data,
+        audio_source.sampling_rate,
+        audio_source.sample_width,
+        audio_source.channels,
+    )
+
+
 def _check_convert_index(index, types, err_msg):
     if not isinstance(index, slice) or index.step is not None:
         raise TypeError(err_msg)
@@ -394,27 +461,34 @@
 
     @classmethod
     def load(cls, input, skip=0, max_read=None, **kwargs):
+        """Create an `AudioRegion` from `input`.
+
+        When `input` is None data is read from the microphone: `max_read`
+        must then be a non-negative number and `skip` must not be
+        positive. Otherwise `input`, `skip`, `max_read` and `kwargs` are
+        passed to `_read_offline`.
+
+        Raises ValueError for an invalid `max_read` or `skip` when
+        `input` is None.
+        """
-        if input is None and max_read is None:
-            raise ValueError(
-                "'max_read' should not be None when reading from microphone"
+        if input is None:
+            if max_read is None or max_read < 0:
+                raise ValueError(
+                    "'max_read' should not be None when reading from microphone"
+                )
+            if skip is not None and skip > 0:
+                raise ValueError(
+                    "'skip' should be 0 when reading from microphone"
+                )
+            data, sampling_rate, sample_width, channels = _read_chunks_online(
+                max_read, **kwargs
             )
-        audio_source = get_audio_source(input, **kwargs)
-        audio_source.open()
-        if skip is not None and skip > 0:
-            skip_samples = int(skip * audio_source.sampling_rate)
-            audio_source.read(skip_samples)
-        if max_read is None or max_read < 0:
-            max_read_samples = None
         else:
-            max_read_samples = round(max_read * audio_source.sampling_rate)
-        data = audio_source.read(max_read_samples)
-        audio_source.close()
-        return cls(
-            data,
-            audio_source.sampling_rate,
-            audio_source.sample_width,
-            audio_source.channels,
-        )
+            data, sampling_rate, sample_width, channels = _read_offline(
+                input, skip=skip, max_read=max_read, **kwargs
+            )
+
+        return cls(data, sampling_rate, sample_width, channels)
 
     @property
     def sec(self):
diff -r 334c8760e80f -r bc963239143f tests/test_core.py
--- a/tests/test_core.py	Fri Oct 11 20:58:49 2019 +0100
+++ b/tests/test_core.py	Sat Oct 12 16:23:17 2019 +0100
@@ -7,7 +7,7 @@
 from mock import patch
 from genty import genty, genty_dataset
 from auditok import split, AudioRegion, AudioParameterError
-from auditok.core import _duration_to_nb_windows
+from auditok.core import _duration_to_nb_windows, _read_chunks_online
 from auditok.util import AudioDataSource
 from auditok.io import (
     _normalize_use_channel,
@@ -1079,14 +1079,71 @@
             str(audio_param_err.exception),
         )
 
-    def test_load_exception(self):
+    @genty_dataset(
+        no_skip_read_all=(0, -1),
+        no_skip_read_all_stereo=(0, -1, 2),
+        skip_2_read_all=(2, -1),
+        skip_2_read_all_None=(2, None),
+        skip_2_read_3=(2, 3),
+        skip_2_read_3_5_stereo=(2, 3.5, 2),
+        skip_2_4_read_3_5_stereo=(2.4, 3.5, 2),
+    )
+    def test_load(self, skip, max_read, channels=1):
+        sampling_rate = 10
+        sample_width = 2
+        filename = "tests/data/test_split_10HZ_{}.raw"
+        filename = filename.format("mono" if channels == 1 else "stereo")
+        region = AudioRegion.load(
+            filename,
+            skip=skip,
+            max_read=max_read,
+            sr=sampling_rate,
+            sw=sample_width,
+            ch=channels,
+        )
+        with open(filename, "rb") as fp:
+            fp.read(round(skip * sampling_rate * sample_width * channels))
+            if max_read is None or max_read < 0:
+                to_read = -1
+            else:
+                to_read = round(
+                    max_read * sampling_rate * sample_width * channels
+                )
+            expected = fp.read(to_read)
+        self.assertEqual(bytes(region), expected)
+
+    def test_load_from_microphone(self):
+        with patch("auditok.io.PyAudioSource") as patch_pyaudio_source:
+            with patch("auditok.core.AudioReader.read") as patch_reader:
+                patch_reader.return_value = None
+                with patch(
+                    "auditok.core.AudioRegion.__init__"
+                ) as patch_AudioRegion:
+                    patch_AudioRegion.return_value = None
+                    AudioRegion.load(
+                        None, skip=0, max_read=5, sr=16000, sw=2, ch=1
+                    )
+        self.assertTrue(patch_pyaudio_source.called)
+        self.assertTrue(patch_reader.called)
+        self.assertTrue(patch_AudioRegion.called)
+
+    @genty_dataset(none=(None,), negative=(-1,))
+    def test_load_from_microphone_without_max_read_exception(self, max_read):
         with self.assertRaises(ValueError) as val_err:
-            AudioRegion.load(None, sr=16000, sw=2, ch=1)
+            AudioRegion.load(None, max_read=max_read, sr=16000, sw=2, ch=1)
         self.assertEqual(
             "'max_read' should not be None when reading from microphone",
             str(val_err.exception),
         )
 
+    def test_load_from_microphone_with_nonzero_skip_exception(self):
+        with self.assertRaises(ValueError) as val_err:
+            AudioRegion.load(None, skip=1, max_read=5, sr=16000, sw=2, ch=1)
+        self.assertEqual(
+            "'skip' should be 0 when reading from microphone",
+            str(val_err.exception),
+        )
+
     @genty_dataset(
         simple=("output.wav", 1.230, "output.wav"),
         start=("output_{meta.start:g}.wav", 1.230, "output_1.23.wav"),