# HG changeset patch # User Amine Sehili # Date 1570478303 -3600 # Node ID 5af0974b344619a35e8d3f63fcea17522c0b981e # Parent 49082909193c93f848af8e52f814b250c4deb30c Add Recorder class as an alias of AudioReader with record=True diff -r 49082909193c -r 5af0974b3446 auditok/util.py --- a/auditok/util.py Mon Oct 07 20:44:57 2019 +0100 +++ b/auditok/util.py Mon Oct 07 20:58:23 2019 +0100 @@ -51,6 +51,7 @@ "ADSFactory", "AudioDataSource", "AudioReader", + "Recorder", "AudioEnergyValidator", ] @@ -72,10 +73,7 @@ err_msg += ", given: {}" raise ValueError(err_msg.format(selected)) return partial( - signal.extract_single_channel, - fmt=fmt, - channels=channels, - selected=selected, + signal.extract_single_channel, fmt=fmt, channels=channels, selected=selected ) if selected in ("mix", "avg", "average"): @@ -118,12 +116,8 @@ class AudioEnergyValidator(DataValidator): - def __init__( - self, energy_threshold, sample_width, channels, use_channel=None - ): - self._selector = make_channel_selector( - sample_width, channels, use_channel - ) + def __init__(self, energy_threshold, sample_width, channels, use_channel=None): + self._selector = make_channel_selector(sample_width, channels, use_channel) if channels == 1 or use_channel is not None: self._energy_fn = signal.calculate_energy_single_channel else: @@ -305,13 +299,9 @@ kwargs["bs"] = kwargs.pop("block_size", None) or kwargs.pop("bs", None) kwargs["hs"] = kwargs.pop("hop_size", None) or kwargs.pop("hs", None) kwargs["mt"] = kwargs.pop("max_time", None) or kwargs.pop("mt", None) - kwargs["asrc"] = kwargs.pop("audio_source", None) or kwargs.pop( - "asrc", None - ) + kwargs["asrc"] = kwargs.pop("audio_source", None) or kwargs.pop("asrc", None) kwargs["fn"] = kwargs.pop("filename", None) or kwargs.pop("fn", None) - kwargs["db"] = kwargs.pop("data_buffer", None) or kwargs.pop( - "db", None - ) + kwargs["db"] = kwargs.pop("data_buffer", None) or kwargs.pop("db", None) record = kwargs.pop("record", False) if not record: @@ -328,19 +318,17 @@ ) or kwargs.pop("fpb", None) if "sampling_rate" in kwargs or "sr" in kwargs: - kwargs["sampling_rate"] = kwargs.pop( - "sampling_rate", None - ) or kwargs.pop("sr", None) + kwargs["sampling_rate"] = kwargs.pop("sampling_rate", None) or kwargs.pop( + "sr", None + ) if "sample_width" in kwargs or "sw" in kwargs: - kwargs["sample_width"] = kwargs.pop( - "sample_width", None - ) or kwargs.pop("sw", None) + kwargs["sample_width"] = kwargs.pop("sample_width", None) or kwargs.pop( + "sw", None + ) if "channels" in kwargs or "ch" in kwargs: - kwargs["channels"] = kwargs.pop("channels", None) or kwargs.pop( - "ch", None - ) + kwargs["channels"] = kwargs.pop("channels", None) or kwargs.pop("ch", None) @staticmethod def ads(**kwargs): @@ -748,9 +736,7 @@ super(_FixedSizeAudioReader, self).__init__(audio_source) if block_dur <= 0: - raise ValueError( - "block_dur must be > 0, given: {}".format(block_dur) - ) + raise ValueError("block_dur must be > 0, given: {}".format(block_dur)) self._block_size = int(block_dur * self.sr) if self._block_size == 0: @@ -799,9 +785,7 @@ if block is None: yield None - _hop_size_bytes = ( - self._hop_size * self._audio_source.sw * self._audio_source.ch - ) + _hop_size_bytes = self._hop_size * self._audio_source.sw * self._audio_source.ch cache = block[_hop_size_bytes:] yield block @@ -846,13 +830,7 @@ """ def __init__( - self, - input, - block_dur=0.01, - hop_dur=None, - record=False, - max_read=None, - **kwargs + self, input, block_dur=0.01, hop_dur=None, record=False, max_read=None, **kwargs ): if not isinstance(input, AudioSource): input = get_audio_source(input, **kwargs) @@ -877,10 +855,11 @@ if self.max_read is not None: max_read = "{:.3f}".format(self.max_read) return ( - "AudioReader(source, block_dur={block_dur}, " + "{cls}(block_dur={block_dur}, " "hop_dur={hop_dur}, record={rewindable}, " "max_read={max_read})" ).format( + cls=self.__class__.__name__, block_dur=block_dur, hop_dur=hop_dur, rewindable=self._record, @@ -919,16 +898,25 @@ def __getattr__(self, name): if name in ("data", "rewind") and not self.rewindable: - raise AttributeError( - "'AudioReader' has no attribute '{}'".format(name) - ) + raise AttributeError("'AudioReader' has no attribute '{}'".format(name)) try: return getattr(self._audio_source, name) except AttributeError: - raise AttributeError( - "'AudioReader' has no attribute '{}'".format(name) - ) + raise AttributeError("'AudioReader' has no attribute '{}'".format(name)) + # Keep AudioDataSource for compatibility # Remove in a future version when ADSFactory is dropped -AudioDataSource = AudioReader \ No newline at end of file +AudioDataSource = AudioReader + + +class Recorder(AudioReader): + def __init__(self, input, block_dur=0.01, hop_dur=None, max_read=None, **kwargs): + super().__init__( + input, + block_dur=block_dur, + hop_dur=hop_dur, + record=True, + max_read=max_read, + **kwargs + ) diff -r 49082909193c -r 5af0974b3446 tests/test_AudioDataSource.py --- a/tests/test_AudioDataSource.py Mon Oct 07 20:44:57 2019 +0100 +++ b/tests/test_AudioDataSource.py Mon Oct 07 20:58:23 2019 +0100 @@ -13,6 +13,8 @@ dataset, ADSFactory, AudioDataSource, + AudioReader, + Recorder, BufferAudioSource, WaveAudioSource, DuplicateArgument, @@ -1034,7 +1036,7 @@ with open(input_raw, "rb") as fp: expected = fp.read(size) - reader = AudioDataSource(input_wav, block_dur=0.1, max_read=max_read) + reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read) reader.open() data = _read_all_data(reader) reader.close() @@ -1047,7 +1049,7 @@ with open(input_raw, "rb") as fp: expected = fp.read() - reader = AudioDataSource(input_wav, block_dur=0.1, record=True) + reader = AudioReader(input_wav, block_dur=0.1, record=True) reader.open() data = _read_all_data(reader) self.assertEqual(data, expected) @@ -1060,6 +1062,26 @@ self.assertEqual(data, reader.data) reader.close() + @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) + def test_Recorder_alias(self, file_id): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read() + + reader = Recorder(input_wav, block_dur=0.1) + reader.open() + data = _read_all_data(reader) + self.assertEqual(data, expected) + + # rewind many times + for _ in range(3): + reader.rewind() + data = _read_all_data(reader) + self.assertEqual(data, expected) + self.assertEqual(data, reader.data) + reader.close() + if __name__ == "__main__": unittest.main()