Mercurial > hg > auditok
changeset 241:79b668c48fce
Make sure split keeps original number of channels
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Fri, 26 Jul 2019 20:46:53 +0100 |
parents | 173ffca58d23 |
children | 90445f084929 |
files | auditok/core.py auditok/util.py tests/test_core.py |
diffstat | 3 files changed, 185 insertions(+), 191 deletions(-) [+] |
line wrap: on
line diff
--- a/auditok/core.py Thu Jul 25 20:50:52 2019 +0100 +++ b/auditok/core.py Fri Jul 26 20:46:53 2019 +0100 @@ -74,12 +74,10 @@ nuumber of channels of audio data. Only needed for raw audio files. use_channel, uc: int, str which channel to use if input has multichannel audio data. Can be an - int (0 being the first channel), or one of the following special str - values: - - 'left': first channel (equivalent to 0) - - 'right': second channel (equivalent to 1) - - 'mix': compute average channel - Default: 0, use the first channel. + int (0 being the first channel), or one of the following values: + - None, "any": a valid frame from one any given channel makes + parallel frames from all other channels automatically valid. + - 'mix': compute average channel (i.e. mix down all channels) max_read, mr: float maximum data to read in seconds. Default: `None`, read until there is no more data to read. @@ -132,8 +130,10 @@ energy_threshold = kwargs.get( "energy_threshold", kwargs.get("eth", DEFAULT_ENERGY_THRESHOLD) ) - validator = AudioEnergyValidator(source.sw, energy_threshold) - + use_channel = kwargs.get("use_channel", kwargs.get("uc")) + validator = AudioEnergyValidator( + energy_threshold, source.sw, source.ch, use_channel=use_channel + ) mode = ( StreamTokenizer.DROP_TRAILING_SILENCE if drop_trailing_silence else 0 )
--- a/auditok/util.py Thu Jul 25 20:50:52 2019 +0100 +++ b/auditok/util.py Fri Jul 26 20:46:53 2019 +0100 @@ -14,9 +14,9 @@ ADSFactory.RecorderADS DataValidator AudioEnergyValidator - """ from __future__ import division +import sys from abc import ABCMeta, abstractmethod import math from array import array @@ -29,14 +29,16 @@ get_audio_source, ) from .exceptions import DuplicateArgument, TooSamllBlockDuration -import sys try: import numpy + np = numpy _WITH_NUMPY = True + _FORMAT = {1: np.int8, 2: np.int16, 4: np.int32} except ImportError as e: _WITH_NUMPY = False + _FORMAT = {1: "b", 2: "h", 4: "i"} try: from builtins import str @@ -56,6 +58,119 @@ ] +def make_channel_selector(sample_width, channels, selected=None): + fmt = _FORMAT.get(sample_width) + if fmt is None: + err_msg = "'sample_width' must be 1, 2 or 4, given: {}" + raise ValueError(err_msg.format(sample_width)) + + if channels == 1: + if _WITH_NUMPY: + + def _as_array(data): + return np.frombuffer(data, dtype=fmt).astype(np.float64) + + else: + + def _as_array(data): + return array(fmt, data) + + return _as_array + + if isinstance(selected, int): + if selected < 0: + selected += channels + if selected < 0 or selected >= channels: + err_msg = "Selected channel must be >= -channels and < 'channels'" + err_msg += ", given: {}" + raise ValueError(err_msg.format(selected)) + if _WITH_NUMPY: + + def _extract_single_channel(data): + samples = np.frombuffer(data, dtype=fmt) + return samples[selected::channels].astype(np.float64) + + else: + + def _extract_single_channel(data): + samples = array(fmt, data) + return samples[selected::channels] + + return _extract_single_channel + + if selected in ("mix", "avg", "average"): + if _WITH_NUMPY: + + def _average_channels(data): + array = np.frombuffer(data, dtype=fmt).astype(np.float64) + return array.reshape(-1, channels).mean(axis=1) + + else: + + def _average_channels(data): + all_channels = array(fmt, data) + mono_channels = [ + array(fmt, all_channels[ch::channels]) + for ch in range(channels) + ] + avg_arr = array( + fmt, + ( + sum(samples) // channels + for samples in zip(*mono_channels) + ), + ) + return avg_arr + + return _average_channels + + if selected is None: + if _WITH_NUMPY: + + def _split_channels(data): + array = np.frombuffer(data, dtype=fmt).astype(np.float64) + return array.reshape(-1, channels).T + + else: + + def _split_channels(data): + all_channels = array(fmt, data) + mono_channels = [ + array(fmt, all_channels[ch::channels]) + for ch in range(channels) + ] + return mono_channels + + return _split_channels + + +if _WITH_NUMPY: + + def _calculate_energy_single_channel(x): + return 10 * np.log10(np.dot(x, x).clip(min=1e-20) / x.size) + + +else: + + def _calculate_energy_single_channel(x): + energy = max(sum(i ** 2 for i in x) / len(x), 1e-20) + return 10 * math.log10(energy) + + +if _WITH_NUMPY: + + def _calculate_energy_multichannel(x, aggregation_fn=np.max): + energy = 10 * np.log10((x * x).mean(axis=1).clip(min=1e-20)) + return aggregation_fn(energy) + + +else: + + def _calculate_energy_multichannel(x, aggregation_fn=max): + energies = (_calculate_energy_single_channel(xi) for xi in x) + return aggregation_fn(energies) + + class DataSource: """ Base class for objects passed to :func:`auditok.core.StreamTokenizer.tokenize`. @@ -88,6 +203,23 @@ """ +class AudioEnergyValidator(DataValidator): + def __init__( + self, energy_threshold, sample_width, channels, use_channel=None + ): + self._selector = make_channel_selector( + sample_width, channels, use_channel + ) + if channels == 1 or use_channel is not None: + self._energy_fn = _calculate_energy_single_channel + else: + self._energy_fn = _calculate_energy_multichannel + self._energy_threshold = energy_threshold + + def is_valid(self, data): + return self._energy_fn(self._selector(data)) > self._energy_threshold + + class StringDataSource(DataSource): """ A class that represent a :class:`DataSource` as a string buffer. @@ -881,107 +1013,3 @@ raise AttributeError( "'AudioDataSource' has no attribute '{}'".format(name) ) - - -class AudioEnergyValidator(DataValidator): - """ - The most basic auditok audio frame validator. - This validator computes the log energy of an input audio frame - and return True if the result is >= a given threshold, False - otherwise. - - :Parameters: - - `sample_width` : *(int)* - Number of bytes of one audio sample. This is used to convert data from `basestring` or `Bytes` to - an array of floats. - - `energy_threshold` : *(float)* - A threshold used to check whether an input data buffer is valid. - """ - - if _WITH_NUMPY: - _formats = {1: numpy.int8, 2: numpy.int16, 4: numpy.int32} - - @staticmethod - def _convert(signal, sample_width): - return numpy.array( - numpy.frombuffer( - signal, dtype=AudioEnergyValidator._formats[sample_width] - ), - dtype=numpy.float64, - ) - - @staticmethod - def _signal_energy(signal): - return float(numpy.dot(signal, signal)) / len(signal) - - @staticmethod - def _signal_log_energy(signal): - energy = AudioEnergyValidator._signal_energy(signal) - if energy <= 0: - return -200 - return 10.0 * numpy.log10(energy) - - else: - _formats = {1: "b", 2: "h", 4: "i"} - - @staticmethod - def _convert(signal, sample_width): - return array( - "d", array(AudioEnergyValidator._formats[sample_width], signal) - ) - - @staticmethod - def _signal_energy(signal): - energy = 0.0 - for a in signal: - energy += a * a - return energy / len(signal) - - @staticmethod - def _signal_log_energy(signal): - energy = AudioEnergyValidator._signal_energy(signal) - if energy <= 0: - return -200 - return 10.0 * math.log10(energy) - - def __init__(self, sample_width, energy_threshold=45): - self.sample_width = sample_width - self._energy_threshold = energy_threshold - - def is_valid(self, data): - """ - Check if data is valid. Audio data will be converted into an array (of - signed values) of which the log energy is computed. Log energy is computed - as follows: - - .. code:: python - - arr = AudioEnergyValidator._convert(signal, sample_width) - energy = float(numpy.dot(arr, arr)) / len(arr) - log_energy = 10. * numpy.log10(energy) - - - :Parameters: - - `data` : either a *string* or a *Bytes* buffer - `data` is converted into a numerical array using the `sample_width` - given in the constructor. - - :Returns: - - True if `log_energy` >= `energy_threshold`, False otherwise. - """ - - signal = AudioEnergyValidator._convert(data, self.sample_width) - return ( - AudioEnergyValidator._signal_log_energy(signal) - >= self._energy_threshold - ) - - def get_energy_threshold(self): - return self._energy_threshold - - def set_energy_threshold(self, threshold): - self._energy_threshold = threshold
--- a/tests/test_core.py Thu Jul 25 20:50:52 2019 +0100 +++ b/tests/test_core.py Fri Jul 26 20:46:53 2019 +0100 @@ -182,25 +182,15 @@ self.assertEqual(bytes(reg), exp_data) @genty_dataset( - stereo_all_default=(2, {}, [(2, 16), (17, 31), (34, 76)]), + stereo_all_default=(2, {}, [(2, 32), (34, 76)]), mono_max_read=(1, {"max_read": 5}, [(2, 16), (17, 31), (34, 50)]), mono_max_read_short_name=(1, {"mr": 5}, [(2, 16), (17, 31), (34, 50)]), mono_use_channel_1=( 1, - {"eth": 50, "use_channel": 1}, + {"eth": 50, "use_channel": 0}, [(2, 16), (17, 31), (34, 76)], ), mono_uc_1=(1, {"eth": 50, "uc": 1}, [(2, 16), (17, 31), (34, 76)]), - mono_use_channel_left=( - 1, - {"eth": 50, "use_channel": "left"}, - [(2, 16), (17, 31), (34, 76)], - ), - mono_uc_left=( - 1, - {"eth": 50, "uc": "left"}, - [(2, 16), (17, 31), (34, 76)], - ), mono_use_channel_None=( 1, {"eth": 50, "use_channel": None}, @@ -208,30 +198,20 @@ ), stereo_use_channel_1=( 2, - {"eth": 50, "use_channel": 1}, - [(2, 16), (17, 31), (34, 76)], - ), - stereo_use_channel_left=( - 2, - {"eth": 50, "use_channel": "left"}, + {"eth": 50, "use_channel": 0}, [(2, 16), (17, 31), (34, 76)], ), stereo_use_channel_no_use_channel_given=( 2, {"eth": 50}, - [(2, 16), (17, 31), (34, 76)], + [(2, 32), (34, 76)], ), stereo_use_channel_minus_2=( 2, {"eth": 50, "use_channel": -2}, [(2, 16), (17, 31), (34, 76)], ), - stereo_uc_2=(2, {"eth": 50, "uc": 2}, [(10, 32), (36, 76)]), - stereo_use_channel_right=( - 2, - {"eth": 50, "use_channel": "right"}, - [(10, 32), (36, 76)], - ), + stereo_uc_2=(2, {"eth": 50, "uc": 1}, [(10, 32), (36, 76)]), stereo_uc_minus_1=(2, {"eth": 50, "uc": -1}, [(10, 32), (36, 76)]), mono_uc_mix=( 1, @@ -272,23 +252,16 @@ ) regions = list(regions) sample_width = 2 - import numpy as np - - use_channel = kwargs.get("use_channel", kwargs.get("uc")) - # extrat channel of interest - if channels != 1: - use_channel = kwargs.get("use_channel", kwargs.get("uc")) - use_channel = _normalize_use_channel(use_channel) - data = _extract_selected_channel( - data, channels, sample_width, use_channel=use_channel - ) err_msg = "Wrong number of regions after split, expected: " - err_msg += "{}, found: {}".format(expected, regions) + err_msg += "{}, found: {}".format(len(expected), len(regions)) self.assertEqual(len(regions), len(expected), err_msg) + sample_size_bytes = sample_width * channels for reg, exp in zip(regions, expected): onset, offset = exp - exp_data = data[onset * sample_width : offset * sample_width] - self.assertEqual(bytes(reg), exp_data) + exp_data = data[ + onset * sample_size_bytes : offset * sample_size_bytes + ] + self.assertEqual(len(bytes(reg)), len(exp_data)) @genty_dataset( mono_aw_0_2_max_silence_0_2=( @@ -296,7 +269,7 @@ 5, 0.2, 1, - {"uc": 1, "aw": 0.2}, + {"aw": 0.2}, [(2, 30), (34, 76)], ), mono_aw_0_2_max_silence_0_3=( @@ -304,7 +277,7 @@ 5, 0.3, 1, - {"uc": 1, "aw": 0.2}, + {"aw": 0.2}, [(2, 30), (34, 76)], ), mono_aw_0_2_max_silence_0_4=( @@ -312,7 +285,7 @@ 5, 0.4, 1, - {"uc": 1, "aw": 0.2}, + {"aw": 0.2}, [(2, 32), (34, 76)], ), mono_aw_0_2_max_silence_0=( @@ -320,23 +293,16 @@ 5, 0, 1, - {"uc": 1, "aw": 0.2}, + {"aw": 0.2}, [(2, 14), (16, 24), (26, 28), (34, 76)], ), - mono_aw_0_2=( - 0.2, - 5, - 0.2, - 1, - {"uc": 1, "aw": 0.2}, - [(2, 30), (34, 76)], - ), + mono_aw_0_2=(0.2, 5, 0.2, 1, {"aw": 0.2}, [(2, 30), (34, 76)]), mono_aw_0_3_max_silence_0=( 0.3, 5, 0, 1, - {"uc": 1, "aw": 0.3}, + {"aw": 0.3}, [(3, 12), (15, 24), (36, 76)], ), mono_aw_0_3_max_silence_0_3=( @@ -344,7 +310,7 @@ 5, 0.3, 1, - {"uc": 1, "aw": 0.3}, + {"aw": 0.3}, [(3, 27), (36, 76)], ), mono_aw_0_3_max_silence_0_5=( @@ -352,7 +318,7 @@ 5, 0.5, 1, - {"uc": 1, "aw": 0.3}, + {"aw": 0.3}, [(3, 27), (36, 76)], ), mono_aw_0_3_max_silence_0_6=( @@ -360,7 +326,7 @@ 5, 0.6, 1, - {"uc": 1, "aw": 0.3}, + {"aw": 0.3}, [(3, 30), (36, 76)], ), mono_aw_0_4_max_silence_0=( @@ -368,7 +334,7 @@ 5, 0, 1, - {"uc": 1, "aw": 0.4}, + {"aw": 0.4}, [(4, 12), (16, 24), (36, 76)], ), mono_aw_0_4_max_silence_0_3=( @@ -376,7 +342,7 @@ 5, 0.3, 1, - {"uc": 1, "aw": 0.4}, + {"aw": 0.4}, [(4, 12), (16, 24), (36, 76)], ), mono_aw_0_4_max_silence_0_4=( @@ -384,23 +350,23 @@ 5, 0.4, 1, - {"uc": 1, "aw": 0.4}, + {"aw": 0.4}, [(4, 28), (36, 76)], ), + stereo_uc_0_analysis_window_0_2=( + 0.2, + 5, + 0.2, + 2, + {"uc": 0, "analysis_window": 0.2}, + [(2, 30), (34, 76)], + ), stereo_uc_1_analysis_window_0_2=( 0.2, 5, 0.2, 2, {"uc": 1, "analysis_window": 0.2}, - [(2, 30), (34, 76)], - ), - stereo_uc_2_analysis_window_0_2=( - 0.2, - 5, - 0.2, - 2, - {"uc": 2, "analysis_window": 0.2}, [(10, 32), (36, 76)], ), stereo_uc_mix_aw_0_1_max_silence_0=( @@ -597,20 +563,18 @@ sample_width = 2 import numpy as np - use_channel = kwargs.get("use_channel", kwargs.get("uc")) - # extrat channel of interest - if channels != 1: - use_channel = kwargs.get("use_channel", kwargs.get("uc")) - use_channel = _normalize_use_channel(use_channel) - data = _extract_selected_channel( - data, channels, sample_width, use_channel=use_channel - ) err_msg = "Wrong number of regions after split, expected: " err_msg += "{}, found: {}".format(expected, regions) self.assertEqual(len(regions), len(expected), err_msg) for reg, exp in zip(regions, expected): onset, offset = exp - exp_data = data[onset * sample_width : offset * sample_width] + exp_data = data[ + onset + * sample_width + * channels : offset + * sample_width + * channels + ] self.assertEqual(bytes(reg), exp_data) @genty_dataset( @@ -663,7 +627,7 @@ ) def test_split_input_type(self, input, kwargs): - with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: + with open("tests/data/test_split_10HZ_stereo.raw", "rb") as fp: data = fp.read() regions = split( @@ -677,14 +641,16 @@ **kwargs ) regions = list(regions) - expected = [(2, 16), (17, 31), (34, 76)] + expected = [(2, 32), (34, 76)] sample_width = 2 err_msg = "Wrong number of regions after split, expected: " err_msg += "{}, found: {}".format(expected, regions) self.assertEqual(len(regions), len(expected), err_msg) for reg, exp in zip(regions, expected): onset, offset = exp - exp_data = data[onset * sample_width : offset * sample_width] + exp_data = data[ + onset * sample_width * 2 : offset * sample_width * 2 + ] self.assertEqual(bytes(reg), exp_data) @genty_dataset(