changeset 434:89cc78530ea0
Merge branch 'master' of https://github.com/amsehili/auditok

author:    www-data <www-data@c4dm-xenserv-virt2.eecs.qmul.ac.uk>
date:      Wed, 30 Oct 2024 17:17:59 +0000
parents:   c801276ddf11, 0f8f60771784
children:  87048a881402
files:     .travis.yml README.rst auditok/signal_numpy.py tests/images/py34_py35/plot_mono_region.png tests/images/py34_py35/plot_stereo_region.png tests/images/py34_py35/split_and_plot_mono_region.png tests/images/py34_py35/split_and_plot_uc_0_stereo_region.png tests/images/py34_py35/split_and_plot_uc_1_stereo_region.png tests/images/py34_py35/split_and_plot_uc_any_stereo_region.png tests/images/py34_py35/split_and_plot_uc_mix_stereo_region.png
diffstat:  48 files changed, 7523 insertions(+), 7198 deletions(-)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/.github/workflows/ci.yml	Wed Oct 30 17:17:59 2024 +0000
@@ -0,0 +1,39 @@
+name: CI
+
+on:
+  push:
+    branches: [master, dev]
+  pull_request:
+    branches: [master]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    strategy:
+      matrix:
+        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Set up PYTHONPATH
+        run: echo "PYTHONPATH=$PYTHONPATH:${{ github.workspace }}" >> $GITHUB_ENV
+
+      - name: Install dependencies
+        run: |
+          sudo apt-get update --fix-missing
+          pip install numpy pytest pydub matplotlib
+
+      - name: Install specific package for Python 3.13 only
+        if: matrix.python-version == '3.13'
+        run: pip install audioop-lts
+
+      - name: Run tests
+        run: pytest -s -p no:warnings "tests"
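Note: the workflow above installs ``audioop-lts`` on Python 3.13 only. A plausible reason — an inference, not stated in the changeset — is that the standard-library ``audioop`` module was removed in Python 3.13, and ``audioop-lts`` republishes it under the same import name. A minimal sketch of a guarded import under that assumption:

.. code:: python

    # Sketch only: guard the audioop import so code keeps working on
    # Python 3.13+, where the stdlib module is gone unless the
    # `audioop-lts` backport is installed. Not auditok's actual code.
    try:
        import audioop  # stdlib on <= 3.12; from `audioop-lts` on 3.13+
    except ImportError as exc:
        raise ImportError(
            "audioop unavailable; on Python >= 3.13 run: pip install audioop-lts"
        ) from exc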
--- a/.gitignore	Thu Mar 30 10:17:57 2023 +0100
+++ b/.gitignore	Wed Oct 30 17:17:59 2024 +0000
@@ -2,11 +2,181 @@
 *pyc
 auditok/__pycache__
 .*.swp
-tags
-build
-dist
-MANIFEST.in
 *~
 .pydevproject
 .project
 TODO
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+tags/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+MANIFEST.in
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy__pycache__
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py__pycache__
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+pdm.lock
+
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# macOS specific files
+.DS_Store
+.AppleDouble
+.LSOverride
+Icon?
+__MACOSX
+
+# Thumbnails
+*.sublime-workspace
+*.sublime-project
+
+# macOS Finder directory metadata
+*.DS_Store
+*.AppleDouble
+*.LSOverride
+*.Icon*
+*.__MACOSX
--- a/.pre-commit-config.yaml	Thu Mar 30 10:17:57 2023 +0100
+++ b/.pre-commit-config.yaml	Wed Oct 30 17:17:59 2024 +0000
@@ -1,12 +1,23 @@
 repos:
+- repo: https://github.com/pycqa/isort
+  rev: 5.13.2
+  hooks:
+    - id: isort
+      args: ['--line-length=80']
 - repo: https://github.com/psf/black
-  rev: 20.8b1
+  rev: 24.4.2
   hooks:
   - id: black
-    language_version: python3.7
+- repo: https://github.com/PyCQA/flake8
+  rev: 7.1.0
+  hooks:
+    - id: flake8
+      additional_dependencies:
+        - flake8-bugbear
+        - flake8-comprehensions
+        - flake8-simplify
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v2.4.0
+  rev: v4.6.0
   hooks:
-  - id: flake8
-  - id: end-of-file-fixer
-  - id: trailing-whitespace
+  - id: end-of-file-fixer
+  - id: trailing-whitespace
--- a/.travis.yml	Thu Mar 30 10:17:57 2023 +0100
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,17 +0,0 @@
-before_install:
-    - sudo apt-get update --fix-missing
-install:
-    - pip install numpy
-    - pip install genty
-    - pip install pydub
-    - pip install "matplotlib<=3.2.1"
-language: python
-python:
-    - "3.4"
-    - "3.5"
-    - "3.6"
-    - "3.7"
-    - "3.8"
-    - "3.9"
-script:
-    - python -m unittest discover tests/
--- a/README.rst	Thu Mar 30 10:17:57 2023 +0100
+++ b/README.rst	Wed Oct 30 17:17:59 2024 +0000
@@ -1,51 +1,37 @@
 .. image:: doc/figures/auditok-logo.png
     :align: center
-    :alt: Build status
-.. image:: https://travis-ci.org/amsehili/auditok.svg?branch=master
-    :target: https://travis-ci.org/amsehili/auditok
+.. image:: https://github.com/amsehili/auditok/actions/workflows/ci.yml/badge.svg
+    :target: https://github.com/amsehili/auditok/actions/workflows/ci.yml/
+    :alt: Build Status
 
 .. image:: https://readthedocs.org/projects/auditok/badge/?version=latest
     :target: http://auditok.readthedocs.org/en/latest/?badge=latest
-    :alt: Documentation status
+    :alt: Documentation Status
 
-``auditok`` is an **Audio Activity Detection** tool that can process online data
-(read from an audio device or from standard input) as well as audio files.
-It can be used as a command-line program or by calling its API.
+``auditok`` is an **Audio Activity Detection** tool that processes online data
+(from an audio device or standard input) and audio files. It can be used via
+the command line or through its API.
 
-The latest version of the documentation can be found on
-`readthedocs. <https://auditok.readthedocs.io/en/latest/>`_
-
+Full documentation is available on `Read the Docs <https://auditok.readthedocs.io/en/latest/>`_.
 
 Installation
 ------------
 
-A basic version of ``auditok`` will run with standard Python (>=3.4). However,
-without installing additional dependencies, ``auditok`` can only deal with audio
-files in *wav* or *raw* formats. if you want more features, the following
-packages are needed:
+``auditok`` requires Python 3.7 or higher.
 
-- `pydub <https://github.com/jiaaro/pydub>`_ : read audio files in popular audio formats (ogg, mp3, etc.) or extract audio from a video file.
-- `pyaudio <https://people.csail.mit.edu/hubert/pyaudio>`_ : read audio data from the microphone and play audio back.
-- `tqdm <https://github.com/tqdm/tqdm>`_ : show progress bar while playing audio clips.
-- `matplotlib <https://matplotlib.org/stable/index.html>`_ : plot audio signal and detections.
-- `numpy <https://numpy.org/>`_ : required by matplotlib. Also used for some math operations instead of standard python if available.
-
-Install the latest stable version with pip:
-
+To install the latest stable version, use pip:
 
 .. code:: bash
 
     sudo pip install auditok
 
-
-Install the latest development version from github:
+To install the latest development version from GitHub:
 
 .. code:: bash
 
     pip install git+https://github.com/amsehili/auditok
 
-or
+Alternatively, clone the repository and install it manually:
 
 .. code:: bash
 
@@ -57,79 +43,112 @@
     cd auditok
     python setup.py install
 
-
 Basic example
 -------------
 
+Here's a simple example of using ``auditok`` to detect audio events:
+
 .. code:: python
 
     import auditok
 
-    # split returns a generator of AudioRegion objects
-    audio_regions = auditok.split(
+    # `split` returns a generator of AudioRegion objects
+    audio_events = auditok.split(
         "audio.wav",
-        min_dur=0.2,     # minimum duration of a valid audio event in seconds
-        max_dur=4,       # maximum duration of an event
-        max_silence=0.3, # maximum duration of tolerated continuous silence within an event
-        energy_threshold=55 # threshold of detection
+        min_dur=0.2,     # Minimum duration of a valid audio event in seconds
+        max_dur=4,       # Maximum duration of an event
+        max_silence=0.3, # Maximum tolerated silence duration within an event
+        energy_threshold=55  # Detection threshold
    )

-    for i, r in enumerate(audio_regions):
+    for i, r in enumerate(audio_events):
+        # AudioRegions returned by `split` have defined 'start' and 'end' attributes
+        print(f"Event {i}: {r.start:.3f}s -- {r.end:.3f}s")
 
-        # Regions returned by `split` have 'start' and 'end' metadata fields
-        print("Region {i}: {r.meta.start:.3f}s -- {r.meta.end:.3f}s".format(i=i, r=r))
+        # Play the audio event
+        r.play(progress_bar=True)
 
-        # play detection
-        # r.play(progress_bar=True)
+        # Save the event with start and end times in the filename
+        filename = r.save("event_{start:.3f}-{end:.3f}.wav")
+        print(f"Event saved as: {filename}")
 
-        # region's metadata can also be used with the `save` method
-        # (no need to explicitly specify region's object and `format` arguments)
-        filename = r.save("region_{meta.start:.3f}-{meta.end:.3f}.wav")
-        print("region saved as: {}".format(filename))
-
-output example:
+Example output:
 
 .. code:: bash
 
-    Region 0: 0.700s -- 1.400s
-    region saved as: region_0.700-1.400.wav
-    Region 1: 3.800s -- 4.500s
-    region saved as: region_3.800-4.500.wav
-    Region 2: 8.750s -- 9.950s
-    region saved as: region_8.750-9.950.wav
-    Region 3: 11.700s -- 12.400s
-    region saved as: region_11.700-12.400.wav
-    Region 4: 15.050s -- 15.850s
-    region saved as: region_15.050-15.850.wav
-
+    Event 0: 0.700s -- 1.400s
+    Event saved as: event_0.700-1.400.wav
+    Event 1: 3.800s -- 4.500s
+    Event saved as: event_3.800-4.500.wav
+    Event 2: 8.750s -- 9.950s
+    Event saved as: event_8.750-9.950.wav
+    Event 3: 11.700s -- 12.400s
+    Event saved as: event_11.700-12.400.wav
+    Event 4: 15.050s -- 15.850s
+    Event saved as: event_15.050-15.850.wav
 
 Split and plot
 --------------
 
-Visualize audio signal and detections:
+Visualize the audio signal with detected events:
 
 .. code:: python
 
     import auditok
 
-    region = auditok.load("audio.wav") # returns an AudioRegion object
-    regions = region.split_and_plot(...) # or just region.splitp()
+    region = auditok.load("audio.wav")   # Returns an AudioRegion object
+    regions = region.split_and_plot(...) # Or simply use `region.splitp()`
 
-output figure:
+Example output:
 
 .. image:: doc/figures/example_1.png
 
+Split an audio stream and re-join (glue) audio events with silence
+------------------------------------------------------------------
+
+The following code detects audio events within an audio stream, then inserts
+1 second of silence between them to create audio with pauses:
+
+.. code:: python
+
+    # Create a 1-second silent audio region
+    # Audio parameters must match the original stream
+    from auditok import split, make_silence
+
+    silence = make_silence(duration=1,
+                           sampling_rate=16000,
+                           sample_width=2,
+                           channels=1)
+    events = split("audio.wav")
+    audio_with_pauses = silence.join(events)
+
+Alternatively, use ``split_and_join_with_silence``:
+
+.. code:: python
+
+    from auditok import split_and_join_with_silence
+
+    audio_with_pauses = split_and_join_with_silence(
+        silence_duration=1, input="audio.wav"
+    )
+
+Export an ``AudioRegion`` as a ``numpy`` array
+----------------------------------------------
+
+.. code:: python
+
+    from auditok import load, AudioRegion
+
+    audio = load("audio.wav")  # or use `AudioRegion.load("audio.wav")`
+
+    x = audio.numpy()
+    assert x.shape[0] == audio.channels
+    assert x.shape[1] == len(audio)
+
 Limitations
 -----------
 
-Currently, the core detection algorithm is based on the energy of audio signal.
-While this is fast and works very well for audio streams with low background
-noise (e.g., podcasts with few people talking, language lessons, audio recorded
-in a rather quiet environment, etc.) the performance can drop as the level of
-noise increases. Furthermore, the algorithm makes no distinction between speech
-and other kinds of sounds, so you shouldn't use it for Voice Activity Detection
-if your audio data also contain non-speech events.
+The detection algorithm is based on audio signal energy. While it performs well
+in low-noise environments (e.g., podcasts, language lessons, or quiet recordings),
+performance may drop in noisy settings. Additionally, the algorithm does not
+distinguish between speech and other sounds, so it is not suitable for Voice
+Activity Detection in multi-sound environments.
 
 License
 -------
+
 MIT.
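Aside: the ``energy_threshold`` parameter used throughout the README corresponds to the log-energy formula documented in ``auditok.core`` later in this changeset, ``20 * log10(sqrt(dot(x, x) / len(x)))``. A short sketch — assuming a placeholder file ``audio.wav`` and the ``(channels, samples)`` array layout stated above — for estimating a sensible threshold from your own audio:

.. code:: python

    import numpy as np
    import auditok

    # Compute the log energy of one channel to help pick `energy_threshold`.
    # Formula taken from the AudioEnergyValidator docs quoted in this changeset.
    region = auditok.load("audio.wav")        # placeholder file name
    x = region.numpy()[0].astype(np.float64)  # first channel as float samples
    energy = 20 * np.log10(np.sqrt(np.dot(x, x) / len(x)))
    print(f"log energy: {energy:.1f} (split() default threshold is 50)")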
--- a/auditok/__init__.py	Thu Mar 30 10:17:57 2023 +0100
+++ b/auditok/__init__.py	Wed Oct 30 17:17:59 2024 +0000
@@ -2,7 +2,7 @@
 
 :author: Amine SEHILI <amine.sehili@gmail.com>
 
-2015-2021
+2015-2024
 
 :License:
 
@@ -10,8 +10,8 @@
 """
 
 from .core import *
+from .exceptions import *
 from .io import *
 from .util import *
-from .exceptions import *
 
-__version__ = "0.2.0"
+__version__ = "0.3.0"
--- a/auditok/cmdline.py	Thu Mar 30 10:17:57 2023 +0100
+++ b/auditok/cmdline.py	Wed Oct 30 17:17:59 2024 +0000
@@ -1,36 +1,32 @@
 #!/usr/bin/env python
 # encoding: utf-8
-"""
-`auditok` -- An Audio Activity Detection tool
+"""`auditok` -- An Audio Activity Detection Tool
 
-`auditok` is a program that can be used for Audio/Acoustic
-activity detection. It can read audio data from audio files as well
-as from the microphone or standard input.
+`auditok` is a program designed for audio or acoustic activity detection.
+It supports reading audio data from various sources, including audio files,
+microphones, and standard input.
 
 @author:     Mohamed El Amine SEHILI
-@copyright:  2015-2021 Mohamed El Amine SEHILI
+@copyright:  2015-2024 Mohamed El Amine SEHILI
 @license:    MIT
 @contact:    amine.sehili@gmail.com
-@deffield    updated: 01 Mar 2021
+@deffield    updated: 30 Oct 2024
 """
 
+import os
 import sys
-import os
+import threading
+import time
 from argparse import ArgumentParser
-import time
-import threading
 
-from auditok import __version__, AudioRegion
-from .util import AudioDataSource
-from .exceptions import EndOfProcessing, AudioEncodingWarning
-from .io import player_for
-from .cmdline_util import make_logger, make_kwargs, initialize_workers
-from . import workers
+from auditok import AudioRegion, __version__
+
+from .cmdline_util import initialize_workers, make_kwargs, make_logger
+from .exceptions import ArgumentError, EndOfProcessing
 
 __all__ = []
 __date__ = "2015-11-23"
-__updated__ = "2021-03-01"
+__updated__ = "2024-10-30"
 
 
 def main(argv=None):
@@ -129,6 +125,17 @@
         metavar="STRING",
     )
     group.add_argument(
+        "-j",
+        "--join-detections",
+        dest="join_detections",
+        type=float,
+        default=None,
+        help="Join (i.e., glue) detected audio events with a silence of "
+        "this duration. Should be used jointly with the --save-stream / "
+        "-O option.",
+        metavar="FLOAT",
+    )
+    group.add_argument(
         "-T",
         "--output-format",
         dest="output_format",
@@ -378,13 +385,16 @@
     )
 
     args = parser.parse_args(argv)
+    try:
+        kwargs = make_kwargs(args)
+    except ArgumentError as exc:
+        print(exc, file=sys.stderr)
+        return 1
+
     logger = make_logger(args.debug, args.debug_file)
-    kwargs = make_kwargs(args)
-    reader, observers = initialize_workers(
-        logger=logger, **kwargs.io, **kwargs.miscellaneous
-    )
-    tokenizer_worker = workers.TokenizerWorker(
-        reader, observers, logger=logger, **kwargs.split
+
+    stream_saver, tokenizer_worker = initialize_workers(
+        logger=logger, **kwargs.split, **kwargs.io, **kwargs.miscellaneous
     )
     tokenizer_worker.start_all()
 
@@ -397,16 +407,18 @@
         if tokenizer_worker is not None:
             tokenizer_worker.stop_all()
 
-        if isinstance(reader, workers.StreamSaverWorker):
-            reader.join()
+        if stream_saver is not None:
+            stream_saver.join()
             try:
-                reader.save_stream()
-            except AudioEncodingWarning as ae_warn:
-                print(str(ae_warn), file=sys.stderr)
+                stream_saver.export_audio()
+            except Exception as aee:
+                print(aee, file=sys.stderr)
 
         if args.plot or args.save_image is not None:
             from .plotting import plot
 
+            reader = tokenizer_worker.reader
+
             reader.rewind()
             record = AudioRegion(
                 reader.data, reader.sr, reader.sw, reader.ch
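The new ``-j/--join-detections`` option mirrors the ``split_and_join_with_silence`` function added to ``auditok.core`` in this same changeset. A sketch of the presumed library-level equivalent of ``auditok -j 1.0 -O output.wav audio.wav`` (file names are placeholders):

.. code:: python

    from auditok import split_and_join_with_silence

    # Join detected events with 1 s of silence between them.
    result = split_and_join_with_silence(input="audio.wav", silence_duration=1.0)
    if result is not None:  # None when no audio events were detected
        result.save("output.wav")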
--- a/auditok/cmdline_util.py	Thu Mar 30 10:17:57 2023 +0100
+++ b/auditok/cmdline_util.py	Wed Oct 30 17:17:59 2024 +0000
@@ -1,9 +1,11 @@
+import logging
 import sys
-import logging
 from collections import namedtuple
+
 from . import workers
-from .util import AudioDataSource
+from .exceptions import ArgumentError
 from .io import player_for
+from .util import AudioReader
 
 _AUDITOK_LOGGER = "AUDITOK_LOGGER"
 KeywordArguments = namedtuple(
@@ -21,6 +23,12 @@
     except (ValueError, TypeError):
         use_channel = args_ns.use_channel
 
+    if args_ns.join_detections is not None and args_ns.save_stream is None:
+        raise ArgumentError(
+            "using --join-detections/-j requires --save-stream/-O "
+            "to be specified."
+        )
+
     io_kwargs = {
         "input": args_ns.input,
         "audio_format": args_ns.input_format,
@@ -32,6 +40,7 @@
         "use_channel": use_channel,
         "save_stream": args_ns.save_stream,
         "save_detections_as": args_ns.save_detections_as,
+        "join_detections": args_ns.join_detections,
         "export_format": args_ns.output_format,
         "large_file": args_ns.large_file,
         "frames_per_buffer": args_ns.frame_per_buffer,
@@ -81,14 +90,30 @@
 
 def initialize_workers(logger=None, **kwargs):
     observers = []
-    reader = AudioDataSource(source=kwargs["input"], **kwargs)
+    reader = AudioReader(source=kwargs["input"], **kwargs)
     if kwargs["save_stream"] is not None:
-        reader = workers.StreamSaverWorker(
-            reader,
-            filename=kwargs["save_stream"],
-            export_format=kwargs["export_format"],
-        )
-        reader.start()
+
+        if kwargs["join_detections"] is not None:
+            stream_saver = workers.AudioEventsJoinerWorker(
+                silence_duration=kwargs["join_detections"],
+                filename=kwargs["save_stream"],
+                export_format=kwargs["export_format"],
+                sampling_rate=reader.sampling_rate,
+                sample_width=reader.sample_width,
+                channels=reader.channels,
+            )
+            observers.append(stream_saver)
+
+        else:
+            reader = workers.StreamSaverWorker(
+                reader,
+                filename=kwargs["save_stream"],
+                export_format=kwargs["export_format"],
+            )
+            stream_saver = reader
+            stream_saver.start()
+    else:
+        stream_saver = None
 
     if kwargs["save_detections_as"] is not None:
         worker = workers.RegionSaverWorker(
@@ -123,4 +148,8 @@
         )
         observers.append(worker)
 
-    return reader, observers
+    tokenizer_worker = workers.TokenizerWorker(
+        reader, observers, logger=logger, **kwargs
+    )
+
+    return stream_saver, tokenizer_worker
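``initialize_workers`` now wires a single reader into a tokenizer plus a list of observers (region savers, printers, players, and optionally the new ``AudioEventsJoinerWorker``), returning the stream saver separately so ``main()`` can join and export it. A deliberately generic sketch of that observer fan-out — illustrative names only, not auditok's actual worker classes:

.. code:: python

    # Illustrative observer wiring (hypothetical names, not auditok's API).
    class PrintObserver:
        def notify(self, event):
            print("detected:", event)

    class Tokenizer:
        def __init__(self, reader, observers):
            self.reader = reader
            self.observers = observers

        def run(self):
            for event in self.reader:        # each detected audio event
                for obs in self.observers:   # fan out to every observer
                    obs.notify(event)

    Tokenizer(iter(["event-1", "event-2"]), [PrintObserver()]).run()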
--- a/auditok/core.py	Thu Mar 30 10:17:57 2023 +0100
+++ b/auditok/core.py	Wed Oct 30 17:17:59 2024 +0000
@@ -1,24 +1,41 @@
 """
+Module for main data structures and tokenization algorithms.
+
 .. autosummary::
     :toctree: generated/
 
     load
     split
+    make_silence
+    split_and_join_with_silence
     AudioRegion
     StreamTokenizer
 """
+
+import math
 import os
-import math
-from .util import AudioReader, DataValidator, AudioEnergyValidator
-from .io import check_audio_data, to_file, player_for, get_audio_source
-from .exceptions import TooSamllBlockDuration
+import warnings
+from dataclasses import dataclass, field
+from pathlib import Path
+
+from .exceptions import AudioParameterError, TooSmallBlockDuration
+from .io import check_audio_data, get_audio_source, player_for, to_file
+from .plotting import plot
+from .util import AudioEnergyValidator, AudioReader, DataValidator
 
 try:
     from . import signal_numpy as signal
 except ImportError:
     from . import signal
 
-__all__ = ["load", "split", "AudioRegion", "StreamTokenizer"]
+__all__ = [
+    "load",
+    "split",
+    "make_silence",
+    "split_and_join_with_silence",
+    "AudioRegion",
+    "StreamTokenizer",
+]
 
 
 DEFAULT_ANALYSIS_WINDOW = 0.05
@@ -27,59 +44,65 @@
 
 
 def load(input, skip=0, max_read=None, **kwargs):
-    """Load audio data from a source and return it as an :class:`AudioRegion`.
+    """
+    Load audio data from a specified source and return it as an
+    :class:`AudioRegion`.
 
     Parameters
     ----------
-    input : None, str, bytes, AudioSource
-        source to read audio data from. If `str`, it should be a path to a
-        valid audio file. If `bytes`, it is used as raw audio data. If it is
-        "-", raw data will be read from stdin. If None, read audio data from
-        the microphone using PyAudio. If of type `bytes` or is a path to a
-        raw audio file then `sampling_rate`, `sample_width` and `channels`
-        parameters (or their alias) are required. If it's an
-        :class:`AudioSource` object it's used directly to read data.
+    input : None, str, Path, bytes, AudioSource
+        The source from which to read audio data. If a `str` or `Path`, it
+        should specify the path to a valid audio file. If `bytes`, it is
+        treated as raw audio data. If set to "-", raw data will be read from
+        standard input (stdin). If `None`, audio data is read from the
+        microphone using PyAudio. For `bytes` data or a raw audio file path,
+        `sampling_rate`, `sample_width`, and `channels` parameters (or their
+        aliases) must be specified. If an :class:`AudioSource` object is
+        provided, it is used directly to read data.
     skip : float, default: 0
-        amount, in seconds, of audio data to skip from source. If read from
-        a microphone, `skip` must be 0, otherwise a `ValueError` is raised.
+        Duration in seconds of audio data to skip from the beginning of the
+        source. When reading from a microphone, `skip` must be 0; otherwise,
+        a `ValueError` is raised.
    max_read : float, default: None
-        amount, in seconds, of audio data to read from source. If read from
-        microphone, `max_read` should not be None, otherwise a `ValueError` is
-        raised.
+        Duration in seconds of audio data to read from the source. When reading
+        from the microphone, `max_read` must not be `None`; otherwise, a
+        `ValueError` is raised.
     audio_format, fmt : str
-        type of audio data (e.g., wav, ogg, flac, raw, etc.). This will only
-        be used if `input` is a string path to an audio file. If not given,
-        audio type will be guessed from file name extension or from file
-        header.
+        Format of the audio data (e.g., wav, ogg, flac, raw, etc.). This is
+        only used if `input` is a string path to an audio file. If not
+        provided, the audio format is inferred from the file's extension or
+        header.
     sampling_rate, sr : int
-        sampling rate of audio data. Required if `input` is a raw audio file,
-        a `bytes` object or None (i.e., read from microphone).
+        Sampling rate of the audio data. Required if `input` is a raw audio
+        file, a `bytes` object, or `None` (i.e., when reading from the
+        microphone).
     sample_width, sw : int
-        number of bytes used to encode one audio sample, typically 1, 2 or 4.
-        Required for raw data, see `sampling_rate`.
+        Number of bytes used to encode a single audio sample, typically 1, 2,
+        or 4. Required for raw audio data; see `sampling_rate`.
     channels, ch : int
-        number of channels of audio data. Required for raw data, see
-        `sampling_rate`.
+        Number of channels in the audio data. Required for raw audio data;
+        see `sampling_rate`.
     large_file : bool, default: False
-        If True, AND if `input` is a path to a *wav* of a *raw* audio file
-        (and **only** these two formats) then audio file is not fully loaded to
-        memory in order to create the region (but the portion of data needed to
-        create the region is of course loaded to memory). Set to True if
-        `max_read` is significantly smaller then the size of a large audio file
-        that shouldn't be entirely loaded to memory.
+        If `True`, and `input` is a path to a *wav* or *raw* audio file, the
+        file is not fully loaded into memory to create the region (only the
+        necessary portion of data is loaded). This should be set to `True`
+        when `max_read` is much smaller than the total size of a large audio
+        file, to avoid loading the entire file into memory.
 
     Returns
     -------
-    region: AudioRegion
+    region : AudioRegion
 
     Raises
     ------
     ValueError
-        raised if `input` is None (i.e., read data from microphone) and `skip`
-        != 0 or `input` is None `max_read` is None (meaning that when reading
-        from the microphone, no data should be skipped, and maximum amount of
-        data to read should be explicitly provided).
+        Raised if `input` is `None` (i.e., reading from the microphone) and
+        `skip` is not 0, or if `max_read` is `None` when `input` is `None`.
+        This ensures that when reading from the microphone, no data is
+        skipped, and the maximum amount of data to read is explicitly
+        specified.
     """
+
     return AudioRegion.load(input, skip, max_read, **kwargs)
 
 
@@ -90,118 +113,114 @@
     max_silence=0.3,
     drop_trailing_silence=False,
     strict_min_dur=False,
-    **kwargs
+    **kwargs,
 ):
     """
-    Split audio data and return a generator of AudioRegions
+    Split audio data and return a generator of :class:`AudioRegion`s.
 
     Parameters
     ----------
-    input : str, bytes, AudioSource, AudioReader, AudioRegion or None
-        input audio data. If str, it should be a path to an existing audio file.
-        "-" is interpreted as standard input. If bytes, input is considered as
-        raw audio data. If None, read audio from microphone.
-        Every object that is not an `AudioReader` will be transformed into an
-        `AudioReader` before processing. If it is an `str` that refers to a raw
-        audio file, `bytes` or None, audio parameters should be provided using
-        kwargs (i.e., `samplig_rate`, `sample_width` and `channels` or their
-        alias).
-        If `input` is str then audio format will be guessed from file extension.
-        `audio_format` (alias `fmt`) kwarg can also be given to specify audio
-        format explicitly. If none of these options is available, rely on
-        backend (currently only pydub is supported) to load data.
-    min_dur : float, default: 0.2
-        minimun duration in seconds of a detected audio event. By using large
-        values for `min_dur`, very short audio events (e.g., very short 1-word
-        utterances like 'yes' or 'no') can be mis detected. Using very short
-        values might result in a high number of short, unuseful audio events.
-    max_dur : float, default: 5
-        maximum duration in seconds of a detected audio event. If an audio event
-        lasts more than `max_dur` it will be truncated. If the continuation of a
-        truncated audio event is shorter than `min_dur` then this continuation
-        is accepted as a valid audio event if `strict_min_dur` is False.
-        Otherwise it is rejected.
-    max_silence : float, default: 0.3
-        maximum duration of continuous silence within an audio event. There
-        might be many silent gaps of this duration within one audio event. If
-        the continuous silence happens at the end of the event than it's kept as
-        part of the event if `drop_trailing_silence` is False (default).
-    drop_trailing_silence : bool, default: False
-        Whether to remove trailing silence from detected events. To avoid abrupt
-        cuts in speech, trailing silence should be kept, therefore this
-        parameter should be False.
-    strict_min_dur : bool, default: False
-        strict minimum duration. Do not accept an audio event if it is shorter
-        than `min_dur` even if it is contiguous to the latest valid event. This
-        happens if the the latest detected event had reached `max_dur`.
+    input : str, Path, bytes, AudioSource, AudioReader, AudioRegion, or None
+        Audio data input. If `str` or `Path`, it should be the path to an audio
+        file. Use "-" to indicate standard input. If bytes, the input is treated
+        as raw audio data. If None, audio is read from the microphone.
+
+        Any input not of type `AudioReader` is converted into an `AudioReader`
+        before processing. If `input` is raw audio data (str, bytes, or None),
+        specify audio parameters using kwargs (e.g., `sampling_rate`,
+        `sample_width`, `channels`).
+
+        For string inputs, audio format is inferred from the file extension, or
+        specify explicitly via `audio_format` or `fmt`. Otherwise, the backend
+        (currently only `pydub`) handles loading data.
+
+    min_dur : float, default=0.2
+        Minimum duration in seconds of a detected audio event. Higher values
+        can exclude very short utterances (e.g., single words like "yes" or
+        "no"). Lower values may increase the number of short audio events.
+
+    max_dur : float, default=5
+        Maximum duration in seconds for an audio event. Events longer than this
+        are truncated. If the remainder of a truncated event is shorter than
+        `min_dur`, it is included as a valid event if `strict_min_dur` is False;
+        otherwise, it is rejected.
+
+    max_silence : float, default=0.3
+        Maximum duration of continuous silence allowed within an audio event.
+        Multiple silent gaps of this duration may appear in a single event.
+        Trailing silence at the end of an event is kept if
+        `drop_trailing_silence` is False.
+
+    drop_trailing_silence : bool, default=False
+        Whether to remove trailing silence from detected events. To prevent
+        abrupt speech cuts, it is recommended to keep trailing silence, so
+        default is False.
+
+    strict_min_dur : bool, default=False
+        Whether to strictly enforce `min_dur` for all events, rejecting any
+        event shorter than `min_dur`, even if contiguous with a valid event.
 
     Other Parameters
     ----------------
-    analysis_window, aw : float, default: 0.05 (50 ms)
-        duration of analysis window in seconds. A value between 0.01 (10 ms) and
-        0.1 (100 ms) should be good for most use-cases.
+    analysis_window, aw : float, default=0.05 (50 ms)
+        Duration of analysis window in seconds. Values between 0.01 and 0.1 are
+        generally effective.
+
     audio_format, fmt : str
-        type of audio data (e.g., wav, ogg, flac, raw, etc.). This will only be
-        used if `input` is a string path to an audio file. If not given, audio
-        type will be guessed from file name extension or from file header.
+        Type of audio data (e.g., wav, ogg, flac, raw). Used if `input` is a
+        file path. If not specified, audio format is inferred from the file
+        extension or header.
+
     sampling_rate, sr : int
-        sampling rate of audio data. Required if `input` is a raw audio file, is
-        a bytes object or None (i.e., read from microphone).
+        Sampling rate of audio data, required if `input` is raw data (bytes or
+        None).
+
     sample_width, sw : int
-        number of bytes used to encode one audio sample, typically 1, 2 or 4.
-        Required for raw data, see `sampling_rate`.
+        Number of bytes per audio sample (typically 1, 2, or 4). Required for
+        raw audio; see `sampling_rate`.
+
     channels, ch : int
-        number of channels of audio data. Required for raw data, see
-        `sampling_rate`.
+        Number of audio channels. Required for raw data; see `sampling_rate`.
+
     use_channel, uc : {None, "mix"} or int
-        which channel to use for split if `input` has multiple audio channels.
-        Regardless of which channel is used for splitting, returned audio events
-        contain data from *all* channels, just as `input`.
-        The following values are accepted:
+        Channel selection for splitting if `input` has multiple channels. All
+        channels are retained in detected events. Options:
 
-        - None (alias "any"): accept audio activity from any channel, even if
-          other channels are silent. This is the default behavior.
+        - None or "any" (default): accept activity from any channel.
+        - "mix" or "average": mix all channels into a single averaged channel.
+        - int (0 <= value < channels): use the specified channel ID for
          splitting.
 
-        - "mix" ("avg" or "average"): mix down all channels (i.e. compute
-          average channel) and split the resulting channel.
+    large_file : bool, default=False
+        If True and `input` is a path to a wav or raw file, audio is processed
+        lazily. Otherwise, the entire file is loaded before splitting. Set to
+        True if file size exceeds available memory.
 
-        - int (0 <=, > `channels`): use one channel, specified by integer id,
-          for split.
+    max_read, mr : float, default=None
+        Maximum data read from source in seconds. Default is to read to end.
 
-    large_file : bool, default: False
-        If True, AND if `input` is a path to a *wav* of a *raw* audio file
-        (and only these two formats) then audio data is lazily loaded to memory
-        (i.e., one analysis window a time). Otherwise the whole file is loaded
-        to memory before split. Set to True if the size of the file is larger
-        than available memory.
-    max_read, mr : float, default: None, read until end of stream
-        maximum data to read from source in seconds.
-    validator, val : callable, DataValidator
-        custom data validator. If `None` (default), an `AudioEnergyValidor` is
-        used with the given energy threshold. Can be a callable or an instance
-        of `DataValidator` that implements `is_valid`. In either case, it'll be
-        called with with a window of audio data as the first parameter.
-    energy_threshold, eth : float, default: 50
-        energy threshold for audio activity detection. Audio regions that have
-        enough windows of with a signal energy equal to or above this threshold
-        are considered valid audio events. Here we are referring to this amount
-        as the energy of the signal but to be more accurate, it is the log
-        energy of computed as: `20 * log10(sqrt(dot(x, x) / len(x)))` (see
-        :class:`AudioEnergyValidator` and
-        :func:`calculate_energy_single_channel`). If `validator` is given, this
-        argument is ignored.
+    validator, val : callable or DataValidator, default=None
+        Custom validator for audio data. If None, uses `AudioEnergyValidator`
+        with the given `energy_threshold`. Should be callable or an instance of
+        `DataValidator` implementing `is_valid`.
+
+    energy_threshold, eth : float, default=50
+        Energy threshold for audio activity detection. Audio regions with
+        sufficient signal energy above this threshold are considered valid.
+        Calculated as the log energy: `20 * log10(sqrt(dot(x, x) / len(x)))`.
+        Ignored if `validator` is specified.
 
     Yields
     ------
     AudioRegion
-        a generator of detected :class:`AudioRegion` s.
+        Generator yielding detected :class:`AudioRegion` instances.
     """
+
     if min_dur <= 0:
-        raise ValueError("'min_dur' ({}) must be > 0".format(min_dur))
+        raise ValueError(f"'min_dur' ({min_dur}) must be > 0")
     if max_dur <= 0:
-        raise ValueError("'max_dur' ({}) must be > 0".format(max_dur))
+        raise ValueError(f"'max_dur' ({max_dur}) must be > 0")
     if max_silence < 0:
-        raise ValueError("'max_silence' ({}) must be >= 0".format(max_silence))
+        raise ValueError(f"'max_silence' ({max_silence}) must be >= 0")
 
     if isinstance(input, AudioReader):
         source = input
@@ -212,7 +231,7 @@
         )
         if analysis_window <= 0:
             raise ValueError(
-                "'analysis_window' ({}) must be > 0".format(analysis_window)
+                f"'analysis_window' ({analysis_window}) must be > 0"
             )
 
         params = kwargs.copy()
@@ -225,11 +244,12 @@
             input = bytes(input)
         try:
             source = AudioReader(input, block_dur=analysis_window, **params)
-        except TooSamllBlockDuration as exc:
-            err_msg = "Too small 'analysis_windows' ({0}) for sampling rate "
-            err_msg += "({1}). Analysis windows should at least be 1/{1} to "
-            err_msg += "cover one single data sample"
-            raise ValueError(err_msg.format(exc.block_dur, exc.sampling_rate))
+        except TooSmallBlockDuration as exc:
+            err_msg = f"Too small 'analysis_window' ({exc.block_dur}) for "
+            err_msg += f"sampling rate ({exc.sampling_rate}). Analysis window "
+            err_msg += f"should at least be 1/{exc.sampling_rate} to cover "
+            err_msg += "one data sample"
+            raise ValueError(err_msg) from exc
 
     validator = kwargs.get("validator", kwargs.get("val"))
     if validator is None:
@@ -301,37 +321,92 @@
     return region_gen
 
 
+def make_silence(duration, sampling_rate=16000, sample_width=2, channels=1):
+    """
+    Generate a silence of specified duration.
+
+    Parameters
+    ----------
+    duration : float
+        Duration of silence in seconds.
+    sampling_rate : int, optional
+        Sampling rate of the audio data, default is 16000.
+    sample_width : int, optional
+        Number of bytes per audio sample, default is 2.
+    channels : int, optional
+        Number of audio channels, default is 1.
+
+    Returns
+    -------
+    AudioRegion
+        A "silent" AudioRegion of the specified duration.
+    """
+    size = round(duration * sampling_rate) * sample_width * channels
+    data = b"\0" * size
+    region = AudioRegion(data, sampling_rate, sample_width, channels)
+    return region
+
+
+def split_and_join_with_silence(input, silence_duration, **kwargs):
+    """
+    Split input audio and join (glue) the resulting regions with a specified
+    silence duration between them. This can be used to adjust the length of
+    silence between audio events, either shortening or lengthening pauses.
+
+    Parameters
+    ----------
+    silence_duration : float
+        Duration of silence in seconds between audio events.
+
+    Returns
+    -------
+    AudioRegion or None
+        An :meth:`AudioRegion` with the specified between-events silence
+        duration. Returns None if no audio events are detected in the input
+        data.
+    """
+    regions = list(split(input, **kwargs))
+    if regions:
+        first = regions[0]
+        # create a silence with the same parameters as input audio
+        silence = make_silence(silence_duration, first.sr, first.sw, first.ch)
+        return silence.join(regions)
+    return None
+
+
 def _duration_to_nb_windows(
     duration, analysis_window, round_fn=round, epsilon=0
 ):
     """
-    Converts a given duration into a positive integer of analysis windows.
-    if `duration / analysis_window` is not an integer, the result will be
-    rounded to the closest bigger integer. If `duration == 0`, returns `0`.
-    If `duration < analysis_window`, returns 1.
-    `duration` and `analysis_window` can be in seconds or milliseconds but
-    must be in the same unit.
+    Helper function to convert a given duration into a positive integer
+    of analysis windows. If `duration / analysis_window` is not an integer,
+    the result will be rounded up to the nearest integer. If `duration == 0`,
+    returns 0. If `duration < analysis_window`, returns 1.
+
+    Both `duration` and `analysis_window` should be in the same units,
+    either seconds or milliseconds.
 
     Parameters
     ----------
     duration : float
-        a given duration in seconds or ms.
-    analysis_window: float
-        size of analysis window, in the same unit as `duration`.
-    round_fn : callable
-        function called to round the result. Default: `round`.
-    epsilon : float
-        small value to add to the division result before rounding.
-        E.g., `0.3 / 0.1 = 2.9999999999999996`, when called with
-        `round_fn=math.floor` returns `2` instead of `3`. Adding a small value
-        to `0.3 / 0.1` avoids this error.
+        The given duration in seconds or milliseconds.
+    analysis_window : float
+        The size of each analysis window, in the same units as `duration`.
+    round_fn : callable, optional
+        A function for rounding the result, default is `round`.
+    epsilon : float, optional
+        A small value added before rounding to address floating-point
+        precision issues, ensuring accurate rounding for cases like
+        `0.3 / 0.1`, where `round_fn=math.floor` would otherwise yield
+        an incorrect result.
 
     Returns
     -------
     nb_windows : int
-        minimum number of `analysis_window`'s to cover `durartion`. That means
-        that `analysis_window * nb_windows >= duration`.
+        The minimum number of `analysis_window` units needed to cover
+        `duration`, ensuring `analysis_window * nb_windows >= duration`.
     """
+
     if duration < 0 or analysis_window <= 0:
         err_msg = "'duration' ({}) must be >= 0 and 'analysis_window' ({}) > 0"
         raise ValueError(err_msg.format(duration, analysis_window))
@@ -349,54 +424,49 @@
     channels,
 ):
     """
-    Helper function to create an `AudioRegion` from parameters returned by
-    tokenization object. It takes care of setting up region `start` and `end`
-    in metadata.
+    Helper function to create an :class:`AudioRegion` from parameters provided
+    by a tokenization object. This function handles setting the `start` and `end`
+    metadata for the region.
 
     Parameters
     ----------
-    frame_duration: float
-        duration of analysis window in seconds
+    frame_duration : float
+        Duration of each analysis window in seconds.
     start_frame : int
-        index of the fisrt analysis window
-    samling_rate : int
-        sampling rate of audio data
+        Index of the first analysis window.
+    sampling_rate : int
+        Sampling rate of the audio data.
     sample_width : int
-        number of bytes of one audio sample
+        Number of bytes per audio sample.
     channels : int
-        number of channels of audio data
+        Number of audio channels.
 
     Returns
     -------
     audio_region : AudioRegion
-        AudioRegion whose start time is calculeted as:
-        `1000 * start_frame * frame_duration`
+        An AudioRegion object with `start` time calculated as:
+        `1000 * start_frame * frame_duration`.
     """
     start = start_frame * frame_duration
     data = b"".join(data_frames)
-    duration = len(data) / (sampling_rate * sample_width * channels)
-    meta = {"start": start, "end": start + duration}
-    return AudioRegion(data, sampling_rate, sample_width, channels, meta)
+    return AudioRegion(data, sampling_rate, sample_width, channels, start)
 
 
 def _read_chunks_online(max_read, **kwargs):
     """
     Helper function to read audio data from an online blocking source
-    (i.e., microphone). Used to build an `AudioRegion` and can intercept
-    KeyboardInterrupt so that reading stops as soon as this exception is
-    raised. Makes building `AudioRegion`s on [i]python sessions and jupyter
-    notebooks more user friendly.
+    (e.g., a microphone). This function builds an `AudioRegion` and can
+    intercept `KeyboardInterrupt` to stop reading immediately when the
+    exception is raised, making it more user-friendly for [i]Python sessions
+    and Jupyter notebooks.
 
     Parameters
     ----------
     max_read : float
-        maximum amount of data to read in seconds.
+        Maximum duration of audio data to read, in seconds.
     kwargs :
-        audio parameters (sampling_rate, sample_width and channels).
-
-    See also
-    --------
-    `AudioRegion.build`
+        Additional audio parameters (e.g., `sampling_rate`, `sample_width`,
+        and `channels`).
     """
     reader = AudioReader(None, block_dur=0.5, max_read=max_read, **kwargs)
     reader.open()
@@ -409,7 +479,7 @@
             data.append(frame)
     except KeyboardInterrupt:
         # Stop data acquisition from microphone when pressing
-        # Ctrl+C on a [i]python session or a notebook
+        # Ctrl+C in an [i]python session or a notebook
         pass
     reader.close()
     return (
@@ -422,27 +492,25 @@
 
 def _read_offline(input, skip=0, max_read=None, **kwargs):
     """
-    Helper function to read audio data from an offline (i.e., file). Used to
-    build `AudioRegion`s.
+    Helper function to read audio data from an offline source (e.g., file).
+    This function is used to build :class:`AudioRegion` objects.
 
     Parameters
     ----------
-    input : str, bytes
-        path to audio file (if str), or a bytes object representing raw audio
-        data.
-    skip : float, default 0
-        amount of data to skip from the begining of audio source.
-    max_read : float, default: None
-        maximum amount of audio data to read. Default: None, means read until
-        end of stream.
+    input : str or bytes
+        Path to an audio file (if str) or a bytes object representing raw
+        audio data.
+    skip : float, optional, default=0
+        Amount of data to skip from the beginning of the audio source, in
+        seconds.
+    max_read : float, optional, default=None
+        Maximum duration of audio data to read, in seconds. Default is None,
+        which reads until the end of the stream.
     kwargs :
-        audio parameters (sampling_rate, sample_width and channels).
+        Additional audio parameters (e.g., `sampling_rate`, `sample_width`,
+        and `channels`).
+    """
 
-    See also
-    --------
-    `AudioRegion.build`
-
-    """
     audio_source = get_audio_source(input, **kwargs)
     audio_source.open()
     if skip is not None and skip > 0:
@@ -475,8 +543,9 @@
 
 class _SecondsView:
-    """A class to create a view of `AudioRegion` that can be sliced using
-    indices in seconds.
+    """
+    A class to create a view of an :class:`AudioRegion` that supports slicing
+    with time-based indices in seconds.
     """
 
     def __init__(self, region):
@@ -513,7 +582,7 @@
         start_sec = start_ms / 1000
         stop_sec = None if stop_ms is None else stop_ms / 1000
         index = slice(start_sec, stop_sec)
-        return super(_MillisView, self).__getitem__(index)
+        return super().__getitem__(index)
 
     def __len__(self):
         """
@@ -530,9 +599,16 @@
 
 class _AudioRegionMetadata(dict):
-    """A class to store `AudioRegion`'s metadata."""
+    """A class to store :class:`AudioRegion`'s metadata."""
 
     def __getattr__(self, name):
+        warnings.warn(
+            "`AudioRegion.meta` is deprecated and will be removed in future "
+            "versions. For the 'start' and 'end' fields, please use "
+            "`AudioRegion.start` and `AudioRegion.end`.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         if name in self:
             return self[name]
         else:
@@ -549,82 +625,81 @@
         return str(self)
 
 
+@dataclass(frozen=True)
 class AudioRegion(object):
     """
-    AudioRegion encapsulates raw audio data and provides an interface to
-    perform simple operations on it. Use `AudioRegion.load` to build an
-    `AudioRegion` from different types of objects.
+    `AudioRegion` encapsulates raw audio data and provides an interface for
+    performing basic audio operations. Use :meth:`AudioRegion.load` or
+    :func:`load` to create an `AudioRegion` from various input types.
 
     Parameters
     ----------
     data : bytes
-        raw audio data as a bytes object
+        Raw audio data as a bytes object.
     sampling_rate : int
-        sampling rate of audio data
+        Sampling rate of the audio data.
     sample_width : int
-        number of bytes of one audio sample
+        Number of bytes per audio sample.
     channels : int
-        number of channels of audio data
-    meta : dict, default: None
-        any collection of <key:value> elements used to build metadata for
-        this `AudioRegion`. Meta data can be accessed via `region.meta.key`
-        if `key` is a valid python attribute name, or via `region.meta[key]`
-        if not. Note that the :func:`split` function (or the
-        :meth:`AudioRegion.split` method) returns `AudioRegions` with a ``start``
-        and a ``stop`` meta values that indicate the location in seconds of the
-        region in original audio data.
-
-    See also
-    --------
-    AudioRegion.load
-
+        Number of audio channels.
+    start : float, optional, default=None
+        Optional start time of the region, typically provided by the `split`
+        function.
     """
 
-    def __init__(self, data, sampling_rate, sample_width, channels, meta=None):
-        check_audio_data(data, sample_width, channels)
-        self._data = data
-        self._sampling_rate = sampling_rate
-        self._sample_width = sample_width
-        self._channels = channels
-        self._samples = None
-        self.splitp = self.split_and_plot
+    data: bytes
+    sampling_rate: int
+    sample_width: int
+    channels: int
+    start: float = field(default=None, repr=None)
 
-        if meta is not None:
-            self._meta = _AudioRegionMetadata(meta)
+    def __post_init__(self):
+
+        check_audio_data(self.data, self.sample_width, self.channels)
+        object.__setattr__(self, "splitp", self.split_and_plot)
+        duration = len(self.data) / (
+            self.sampling_rate * self.sample_width * self.channels
+        )
+        object.__setattr__(self, "duration", duration)
+
+        if self.start is not None:
+            object.__setattr__(self, "end", self.start + self.duration)
+            object.__setattr__(
+                self,
+                "meta",
+                _AudioRegionMetadata({"start": self.start, "end": self.end}),
+            )
         else:
-            self._meta = None
+            object.__setattr__(self, "end", None)
+            object.__setattr__(self, "meta", None)
 
-        self._seconds_view = _SecondsView(self)
-        self.sec = self.seconds
-        self.s = self.seconds
+        # `seconds` and `millis` are defined below as @property with docstring
+        object.__setattr__(self, "_seconds_view", _SecondsView(self))
+        object.__setattr__(self, "_millis_view", _MillisView(self))
 
-        self._millis_view = _MillisView(self)
-        self.ms = self.millis
-
-    @property
-    def meta(self):
-        return self._meta
-
-    @meta.setter
-    def meta(self, new_meta):
-        """Meta data of audio region."""
-        self._meta = _AudioRegionMetadata(new_meta)
+        object.__setattr__(self, "sec", self.seconds)
+        object.__setattr__(self, "s", self.seconds)
+        object.__setattr__(self, "ms", self.millis)
 
     @classmethod
     def load(cls, input, skip=0, max_read=None, **kwargs):
         """
-        Create an `AudioRegion` by loading data from `input`. See :func:`load`
-        for parameters descripion.
+        Create an :class:`AudioRegion` by loading data from `input`.
+
+        See :func:`load` for a full description of parameters.
 
         Returns
        -------
-        region: AudioRegion
+        region : AudioRegion
+            An AudioRegion instance created from the specified input data.
 
        Raises
        ------
        ValueError
-            raised if `input` is None and `skip` != 0 or `max_read` is None.
+            Raised if `input` is None and either `skip` is not 0 or `max_read`
+            is None.
         """
+
        if input is None:
            if skip > 0:
                raise ValueError(
@@ -648,136 +723,131 @@
     @property
     def seconds(self):
         """
-        A view to slice audio region by seconds (using ``region.seconds[start:end]``).
+        A view to slice audio region by seconds using
+        ``region.seconds[start:end]``.
         """
         return self._seconds_view
 
     @property
     def millis(self):
-        """A view to slice audio region by milliseconds (using ``region.millis[start:end]``)."""
+        """A view to slice audio region by milliseconds using
+        ``region.millis[start:end]``."""
         return self._millis_view
 
     @property
-    def duration(self):
-        """
-        Returns region duration in seconds.
-        """
-        return len(self._data) / (
-            self.sampling_rate * self.sample_width * self.channels
-        )
-
-    @property
-    def sampling_rate(self):
-        """Sampling rate of audio data."""
-        return self._sampling_rate
-
-    @property
     def sr(self):
         """Sampling rate of audio data, alias for `sampling_rate`."""
-        return self._sampling_rate
-
-    @property
-    def sample_width(self):
-        """Number of bytes per sample, one channel considered."""
-        return self._sample_width
+        return self.sampling_rate
 
     @property
     def sw(self):
-        """Number of bytes per sample, alias for `sampling_rate`."""
-        return self._sample_width
-
-    @property
-    def channels(self):
-        """Number of channels of audio data."""
-        return self._channels
+        """Number of bytes per sample, alias for `sample_width`."""
+        return self.sample_width
 
     @property
     def ch(self):
         """Number of channels of audio data, alias for `channels`."""
-        return self._channels
+        return self.channels
 
     def play(self, progress_bar=False, player=None, **progress_bar_kwargs):
         """
-        Play audio region.
+        Play the audio region.
 
         Parameters
         ----------
-        progress_bar : bool, default: False
-            whether to use a progress bar while playing audio. Default: False.
-            `progress_bar` requires `tqdm`, if not installed, no progress bar
-            will be shown.
-        player : AudioPalyer, default: None
-            audio player to use. if None (default), use `player_for()`
-            to get a new audio player.
-        progress_bar_kwargs : kwargs
-            keyword arguments to pass to `tqdm` progress_bar builder (e.g.,
-            use `leave=False` to clean up the screen when play finishes).
+        progress_bar : bool, optional, default=False
+            Whether to display a progress bar during playback. Requires `tqdm`,
+            if not installed, no progress bar will be shown.
+        player : AudioPlayer, optional, default=None
+            Audio player to use for playback. If None (default), a new player
+            is obtained via `player_for()`.
+        progress_bar_kwargs : dict, optional
+            Additional keyword arguments to pass to the `tqdm` progress bar
+            (e.g., `leave=False` to clear the bar from the screen upon
+            completion).
         """
+
         if player is None:
             player = player_for(self)
-        player.play(
-            self._data, progress_bar=progress_bar, **progress_bar_kwargs
-        )
+        player.play(self.data, progress_bar=progress_bar, **progress_bar_kwargs)
 
-    def save(self, file, audio_format=None, exists_ok=True, **audio_parameters):
+    def save(
+        self, filename, audio_format=None, exists_ok=True, **audio_parameters
+    ):
         """
-        Save audio region to file.
+        Save the audio region to a file.
 
         Parameters
         ----------
-        file : str
-            path to output audio file. May contain `{duration}` placeholder
-            as well as any place holder that this region's metadata might
-            contain (e.g., regions returned by `split` contain metadata with
-            `start` and `end` attributes that can be used to build output file
-            name as `{meta.start}` and `{meta.end}`. See examples using
-            placeholders with formatting.
-
-        audio_format : str, default: None
-            format used to save audio data. If None (default), format is guessed
-            from file name's extension. If file name has no extension, audio
-            data is saved as a raw (headerless) audio file.
-        exists_ok : bool, default: True
-            If True, overwrite `file` if a file with the same name exists.
-            If False, raise an `IOError` if `file` exists.
-        audio_parameters: dict
-            any keyword arguments to be passed to audio saving backend.
+        filename : str or Path
+            Path to the output audio file. If a string, it may include `{start}`,
Regions created by `split` + contain `start` and `end` attributes that can be used to format the + filename, as shown in the example. + audio_format : str, optional, default=None + Format used to save the audio data. If None (default), the format is + inferred from the file extension. If the filename has no extension, + the audio is saved as a raw (headerless) audio file. + exists_ok : bool, optional, default=True + If True, overwrite the file if it already exists. If False, raise an + `IOError` if the file exists. + audio_parameters : dict, optional + Additional keyword arguments to pass to the audio-saving backend. Returns ------- - file: str - name of output file with replaced placehoders. + file : str + The output filename with placeholders filled in. + Raises - IOError if `file` exists and `exists_ok` is False. - + ------ + IOError + If `filename` exists and `exists_ok` is False. Examples -------- - >>> region = AudioRegion(b'\\0' * 2 * 24000, + Create an AudioRegion, specifying `start`. The `end` will be computed + based on `start` and the region's duration. + + >>> region = AudioRegion(b'\0' * 2 * 24000, >>> sampling_rate=16000, >>> sample_width=2, - >>> channels=1) - >>> region.meta.start = 2.25 - >>> region.meta.end = 2.25 + region.duration - >>> region.save('audio_{meta.start}-{meta.end}.wav') - >>> audio_2.25-3.75.wav - >>> region.save('region_{meta.start:.3f}_{duration:.3f}.wav') - audio_2.250_1.500.wav + >>> channels=1, + >>> start=2.25) + >>> region + <AudioRegion(duration=1.500, sampling_rate=16000, sample_width=2, channels=1)> + + >>> assert region.end == 3.75 + >>> assert region.save('audio_{start}-{end}.wav') == "audio_2.25-3.75.wav" + >>> filename = region.save('audio_{start:.3f}-{end:.3f}_{duration:.3f}.wav') + >>> assert filename == "audio_2.250-3.750_1.500.wav" """ - if isinstance(file, str): - file = file.format(duration=self.duration, meta=self.meta) - if not exists_ok and os.path.exists(file): - raise FileExistsError("file '{file}' exists".format(file=file)) + + if isinstance(filename, Path): + if not exists_ok and filename.exists(): + raise FileExistsError( + "file '{filename}' exists".format(filename=str(filename)) + ) + if isinstance(filename, str): + filename = filename.format( + duration=self.duration, + meta=self.meta, + start=self.start, + end=self.end, + ) + if not exists_ok and os.path.exists(filename): + raise FileExistsError( + "file '{filename}' exists".format(filename=filename) + ) to_file( - self._data, - file, + self.data, + filename, audio_format, sr=self.sr, sw=self.sw, ch=self.ch, audio_parameters=audio_parameters, ) - return file + return filename def split( self, @@ -786,9 +856,10 @@ max_silence=0.3, drop_trailing_silence=False, strict_min_dur=False, - **kwargs + **kwargs, ): - """Split audio region. See :func:`auditok.split()` for a comprehensive + """ + Split audio region. See :func:`auditok.split` for a comprehensive description of split parameters. See Also :meth:`AudioRegio.split_and_plot`. """ @@ -804,7 +875,7 @@ max_silence=max_silence, drop_trailing_silence=drop_trailing_silence, strict_min_dur=strict_min_dur, - **kwargs + **kwargs, ) def plot( @@ -816,39 +887,37 @@ dpi=120, theme="auditok", ): - """Plot audio region, one sub-plot for each channel. + """ + Plot the audio region with one subplot per channel. 
Parameters ---------- - scale_signal : bool, default: True - if true, scale signal by subtracting its mean and dividing by its + scale_signal : bool, optional, default=True + If True, scale the signal by subtracting its mean and dividing by its standard deviation before plotting. - show : bool - whether to show plotted signal right after the call. - figsize : tuple, default: None - width and height of the figure to pass to `matplotlib`. - save_as : str, default None. - if provided, also save plot to file. - dpi : int, default: 120 - plot dpi to pass to `matplotlib`. - theme : str or dict, default: "auditok" - plot theme to use. Currently only "auditok" theme is implemented. To - provide you own them see :attr:`auditok.plotting.AUDITOK_PLOT_THEME`. + show : bool, optional, default=False + Whether to display the plot immediately after the function call. + figsize : tuple, optional, default=None + Width and height of the figure, passed to `matplotlib`. + save_as : str, optional, default=None + If specified, save the plot to the given filename. + dpi : int, optional, default=120 + Dots per inch (DPI) for the plot, passed to `matplotlib`. + theme : str or dict, optional, default="auditok" + Plot theme to use. Only the "auditok" theme is currently implemented. + To define a custom theme, refer to + :attr:`auditok.plotting.AUDITOK_PLOT_THEME`. """ - try: - from auditok.plotting import plot - plot( - self, - scale_signal=scale_signal, - show=show, - figsize=figsize, - save_as=save_as, - dpi=dpi, - theme=theme, - ) - except ImportError: - raise RuntimeWarning("Plotting requires matplotlib") + plot( + self, + scale_signal=scale_signal, + show=show, + figsize=figsize, + save_as=save_as, + dpi=dpi, + theme=theme, + ) def split_and_plot( self, @@ -863,70 +932,105 @@ save_as=None, dpi=120, theme="auditok", - **kwargs + **kwargs, ): - """Split region and plot signal and detections. Alias: :meth:`splitp`. - See :func:`auditok.split()` for a comprehensive description of split - parameters. Also see :meth:`plot` for plot parameters. """ - try: - from auditok.plotting import plot + Split the audio region, then plot the signal and detected regions. - regions = self.split( - min_dur=min_dur, - max_dur=max_dur, - max_silence=max_silence, - drop_trailing_silence=drop_trailing_silence, - strict_min_dur=strict_min_dur, - **kwargs + Alias + ----- + :meth:`splitp` + + Refer to :func:`auditok.split()` for a detailed description of split + parameters, and to :meth:`plot` for plot-specific parameters. 
+        """
+        regions = self.split(
+            min_dur=min_dur,
+            max_dur=max_dur,
+            max_silence=max_silence,
+            drop_trailing_silence=drop_trailing_silence,
+            strict_min_dur=strict_min_dur,
+            **kwargs,
+        )
+        regions = list(regions)
+        detections = ((reg.meta.start, reg.meta.end) for reg in regions)
+        eth = kwargs.get(
+            "energy_threshold", kwargs.get("eth", DEFAULT_ENERGY_THRESHOLD)
+        )
+        plot(
+            self,
+            scale_signal=scale_signal,
+            detections=detections,
+            energy_threshold=eth,
+            show=show,
+            figsize=figsize,
+            save_as=save_as,
+            dpi=dpi,
+            theme=theme,
+        )
+        return regions
+
+    def _check_other_parameters(self, other):
+        if other.sr != self.sr:
+            raise AudioParameterError(
+                "Can only concatenate AudioRegions of the same "
+                "sampling rate ({} != {})".format(self.sr, other.sr)
             )
-            regions = list(regions)
-            detections = ((reg.meta.start, reg.meta.end) for reg in regions)
-            eth = kwargs.get(
-                "energy_threshold", kwargs.get("eth", DEFAULT_ENERGY_THRESHOLD)
+        if other.sw != self.sw:
+            raise AudioParameterError(
+                "Can only concatenate AudioRegions of the same "
+                "sample width ({} != {})".format(self.sw, other.sw)
             )
-            plot(
-                self,
-                scale_signal=scale_signal,
-                detections=detections,
-                energy_threshold=eth,
-                show=show,
-                figsize=figsize,
-                save_as=save_as,
-                dpi=dpi,
-                theme=theme,
+        if other.ch != self.ch:
+            raise AudioParameterError(
+                "Can only concatenate AudioRegions of the same "
+                "number of channels ({} != {})".format(self.ch, other.ch)
             )
-            return regions
-        except ImportError:
-            raise RuntimeWarning("Plotting requires matplotlib")
 
-    def __array__(self):
-        return self.samples
+    def _check_iter_others(self, others):
+        for other in others:
+            self._check_other_parameters(other)
+            yield other
+
+    def join(self, others):
+        data = self.data.join(
+            other.data for other in self._check_iter_others(others)
+        )
+        return AudioRegion(data, self.sr, self.sw, self.ch)
 
     @property
     def samples(self):
-        """Audio region as arrays of samples, one array per channel."""
-        if self._samples is None:
-            self._samples = signal.to_array(
-                self._data, self.sample_width, self.channels
-            )
-        return self._samples
+        warnings.warn(
+            "`AudioRegion.samples` is deprecated and will be removed in future "
+            "versions. Please use `AudioRegion.numpy()`.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.numpy()
+
+    def __array__(self):
+        return self.numpy()
+
+    def numpy(self):
+        """Audio region as a 2D numpy array of shape (n_channels, n_samples)."""
+        return signal.to_array(self.data, self.sample_width, self.channels)
 
     def __len__(self):
         """
         Return region length in number of samples.
         """
-        return len(self._data) // (self.sample_width * self.channels)
+        return len(self.data) // (self.sample_width * self.channels)
 
     @property
     def len(self):
         """
-        Return region length in number of samples.
+        Return the length of the audio region in number of samples.
         """
+
         return len(self)
 
     def __bytes__(self):
-        return self._data
+        return self.data
 
     def __str__(self):
         return (
@@ -937,43 +1041,39 @@
         )
 
     def __repr__(self):
-        return str(self)
+        return "<{}>".format(str(self))
 
     def __add__(self, other):
         """
-        Concatenates this region and `other` and return a new region.
-        Both regions must have the same sampling rate, sample width
-        and number of channels. If not, raises a `ValueError`.
+        Concatenate this audio region with `other`, returning a new region.
+
+        Both regions must have the same sampling rate, sample width, and number
+        of channels. If they differ, an `AudioParameterError` is raised.
""" + if not isinstance(other, AudioRegion): raise TypeError( "Can only concatenate AudioRegion, " 'not "{}"'.format(type(other)) ) - if other.sr != self.sr: - raise ValueError( - "Can only concatenate AudioRegions of the same " - "sampling rate ({} != {})".format(self.sr, other.sr) - ) - if other.sw != self.sw: - raise ValueError( - "Can only concatenate AudioRegions of the same " - "sample width ({} != {})".format(self.sw, other.sw) - ) - if other.ch != self.ch: - raise ValueError( - "Can only concatenate AudioRegions of the same " - "number of channels ({} != {})".format(self.ch, other.ch) - ) - data = self._data + other._data + self._check_other_parameters(other) + data = self.data + other.data return AudioRegion(data, self.sr, self.sw, self.ch) def __radd__(self, other): """ - Concatenates `other` and this region. `other` should be an - `AudioRegion` with the same audio parameters as this region - but can exceptionally be `0` to make it possible to concatenate - many regions with `sum`. + Concatenate `other` with this audio region. + + Parameters + ---------- + other : AudioRegion or int + An `AudioRegion` with the same audio parameters as this region, or + `0` to enable concatenating multiple regions using `sum`. + + Returns + ------- + AudioRegion + A new `AudioRegion` representing the concatenation result. """ if other == 0: return self @@ -983,7 +1083,7 @@ if not isinstance(n, int): err_msg = "Can't multiply AudioRegion by a non-int of type '{}'" raise TypeError(err_msg.format(type(n))) - data = self._data * n + data = self.data * n return AudioRegion(data, self.sr, self.sw, self.ch) def __rmul__(self, n): @@ -1011,7 +1111,7 @@ if not isinstance(other, AudioRegion): return False return ( - (self._data == other._data) + (self.data == other.data) and (self.sr == other.sr) and (self.sw == other.sw) and (self.ch == other.ch) @@ -1023,7 +1123,7 @@ start_sample, stop_sample = _check_convert_index(index, (int), err_msg) bytes_per_sample = self.sample_width * self.channels - len_samples = len(self._data) // bytes_per_sample + len_samples = len(self.data) // bytes_per_sample if start_sample < 0: start_sample = max(start_sample + len_samples, 0) @@ -1036,99 +1136,66 @@ else: offset = None - data = self._data[onset:offset] + data = self.data[onset:offset] return AudioRegion(data, self.sr, self.sw, self.ch) class StreamTokenizer: """ - Class for stream tokenizers. It implements a 4-state automaton scheme - to extract sub-sequences of interest on the fly. + Class for stream tokenizers, implementing a 4-state automaton scheme + to extract relevant sub-sequences from a data stream in real-time. Parameters ---------- - validator : callable, DataValidator (must implement `is_valid`) - called with each data frame read from source. Should take one positional - argument and return True or False for valid and invalid frames - respectively. + validator : callable or :class:`DataValidator` (must implement `is_valid`). + Called with each data frame read from the source. Should take a + single argument and return True or False to indicate valid and + invalid frames, respectively. min_length : int - Minimum number of frames of a valid token. This includes all - tolerated non valid frames within the token. + Minimum number of frames in a valid token, including any tolerated + non-valid frames within the token. max_length : int - Maximum number of frames of a valid token. This includes all - tolerated non valid frames within the token. 
+ Maximum number of frames in a valid token, including all tolerated + non-valid frames within the token. max_continuous_silence : int - Maximum number of consecutive non-valid frames within a token. - Note that, within a valid token, there may be many tolerated - *silent* regions that contain each a number of non valid frames up - to `max_continuous_silence` + Maximum number of consecutive non-valid frames within a token. Each + silent region may contain up to `max_continuous_silence` frames. - init_min : int - Minimum number of consecutive valid frames that must be - **initially** gathered before any sequence of non valid frames can - be tolerated. This option is not always needed, it can be used to - drop non-valid tokens as early as possible. **Default = 0** means - that the option is by default ineffective. + init_min : int, default=0 + Minimum number of consecutive valid frames required before + tolerating any non-valid frames. Helps discard non-valid tokens + early if needed. - init_max_silence : int - Maximum number of tolerated consecutive non-valid frames if the - number already gathered valid frames has not yet reached - 'init_min'.This argument is normally used if `init_min` is used. - **Default = 0**, by default this argument is not taken into - consideration. + init_max_silence : int, default=0 + Maximum number of tolerated consecutive non-valid frames before + reaching `init_min`. Used if `init_min` is specified. mode : int - mode can be one of the following: + Defines the tokenizer behavior with the following options: - -1 `StreamTokenizer.NORMAL` : do not drop trailing silence, and - accept a token shorter than `min_length` if it is the continuation - of the latest delivered token. + - `StreamTokenizer.NORMAL` (0, default): Do not drop trailing silence + and allow tokens shorter than `min_length` if they immediately follow + a delivered token. - -2 `StreamTokenizer.STRICT_MIN_LENGTH`: if token `i` is delivered - because `max_length` is reached, and token `i+1` is immediately - adjacent to token `i` (i.e. token `i` ends at frame `k` and token - `i+1` starts at frame `k+1`) then accept token `i+1` only of it has - a size of at least `min_length`. The default behavior is to accept - token `i+1` event if it is shorter than `min_length` (provided that - the above conditions are fulfilled of course). + - `StreamTokenizer.STRICT_MIN_LENGTH` (2): If a token `i` is + delivered at `max_length`, any adjacent token `i+1` must meet + `min_length`. - -3 `StreamTokenizer.DROP_TRAILING_SILENCE`: drop all tailing - non-valid frames from a token to be delivered if and only if it - is not **truncated**. This can be a bit tricky. A token is actually - delivered if: + - `StreamTokenizer.DROP_TRAILING_SILENCE` (4): Drop all trailing + non-valid frames from a token unless the token is truncated + (e.g., at `max_length`). - - `max_continuous_silence` is reached. - - - Its length reaches `max_length`. This is referred to as a - **truncated** token. - - In the current implementation, a `StreamTokenizer`'s decision is only - based on already seen data and on incoming data. Thus, if a token is - truncated at a non-valid but tolerated frame (`max_length` is reached - but `max_continuous_silence` not yet) any tailing silence will be kept - because it can potentially be part of valid token (if `max_length` was - bigger). But if `max_continuous_silence` is reached before - `max_length`, the delivered token will not be considered as truncated - but a result of *normal* end of detection (i.e. 
no more valid data). - In that case the trailing silence can be removed if you use the - `StreamTokenizer.DROP_TRAILING_SILENCE` mode. - - -4 `(StreamTokenizer.STRICT_MIN_LENGTH | StreamTokenizer.DROP_TRAILING_SILENCE)`: - use both options. That means: first remove tailing silence, then - check if the token still has a length of at least `min_length`. - - - + - `StreamTokenizer.STRICT_MIN_LENGTH | StreamTokenizer.DROP_TRAILING_SILENCE`: + Apply both `STRICT_MIN_LENGTH` and `DROP_TRAILING_SILENCE`. Examples -------- - - In the following code, without `STRICT_MIN_LENGTH`, the 'BB' token is - accepted although it is shorter than `min_length` (3), because it - immediately follows the latest delivered token: + In the following, without `STRICT_MIN_LENGTH`, the 'BB' token is + accepted even though it is shorter than `min_length` (3) because it + immediately follows the last delivered token: >>> from auditok.core import StreamTokenizer >>> from auditok.util import StringDataSource, DataValidator @@ -1136,42 +1203,43 @@ >>> class UpperCaseChecker(DataValidator): >>> def is_valid(self, frame): return frame.isupper() + >>> dsource = StringDataSource("aaaAAAABBbbb") - >>> tokenizer = StreamTokenizer(validator=UpperCaseChecker(), - min_length=3, - max_length=4, - max_continuous_silence=0) + >>> tokenizer = StreamTokenizer( + >>> validator=UpperCaseChecker(), + >>> min_length=3, + >>> max_length=4, + >>> max_continuous_silence=0 + >>> ) >>> tokenizer.tokenize(dsource) [(['A', 'A', 'A', 'A'], 3, 6), (['B', 'B'], 7, 8)] + Using `STRICT_MIN_LENGTH` mode rejects the 'BB' token: - The following tokenizer will however reject the 'BB' token: - - >>> dsource = StringDataSource("aaaAAAABBbbb") - >>> tokenizer = StreamTokenizer(validator=UpperCaseChecker(), - min_length=3, max_length=4, - max_continuous_silence=0, - mode=StreamTokenizer.STRICT_MIN_LENGTH) + >>> tokenizer = StreamTokenizer( + >>> validator=UpperCaseChecker(), + >>> min_length=3, + >>> max_length=4, + >>> max_continuous_silence=0, + >>> mode=StreamTokenizer.STRICT_MIN_LENGTH + >>> ) >>> tokenizer.tokenize(dsource) [(['A', 'A', 'A', 'A'], 3, 6)] - + With `DROP_TRAILING_SILENCE`, trailing silence is removed if not truncated: >>> tokenizer = StreamTokenizer( - >>> validator=UpperCaseChecker(), - >>> min_length=3, - >>> max_length=6, - >>> max_continuous_silence=3, - >>> mode=StreamTokenizer.DROP_TRAILING_SILENCE - >>> ) + >>> validator=UpperCaseChecker(), + >>> min_length=3, + >>> max_length=6, + >>> max_continuous_silence=3, + >>> mode=StreamTokenizer.DROP_TRAILING_SILENCE + >>> ) >>> dsource = StringDataSource("aaaAAAaaaBBbbbb") >>> tokenizer.tokenize(dsource) [(['A', 'A', 'A', 'a', 'a', 'a'], 3, 8), (['B', 'B'], 9, 10)] - The first token is delivered with its tailing silence because it is - truncated while the second one has its tailing frames removed. - - Without `StreamTokenizer.DROP_TRAILING_SILENCE` the output would be: + Without `DROP_TRAILING_SILENCE`, the output includes trailing frames: .. code:: python @@ -1179,7 +1247,6 @@ (['A', 'A', 'A', 'a', 'a', 'a'], 3, 8), (['B', 'B', 'b', 'b', 'b'], 9, 13) ] - """ SILENCE = 0 @@ -1272,32 +1339,41 @@ def tokenize(self, data_source, callback=None, generator=False): """ - Read data from `data_source`, one frame a time, and process the read - frames in order to detect sequences of frames that make up valid - tokens. + Read data from `data_source` one frame at a time and process each frame + to detect sequences that form valid tokens. 
- :Parameters: - `data_source` : instance of the :class:`DataSource` class that - implements a `read` method. 'read' should return a slice of - signal, i.e. frame (of whatever type as long as it can be - processed by validator) and None if there is no more signal. + Parameters + ---------- + data_source : DataSource + An instance of the :class:`DataSource` class that implements a `read` + method. `read` should return a slice of the signal (a frame of any + type that can be processed by the validator) or None when there is no + more data in the source. - `callback` : an optional 3-argument function. - If a `callback` function is given, it will be called each time - a valid token is found. + callback : callable, optional + A function that takes three arguments. If provided, `callback` is + called each time a valid token is detected. + generator : bool, optional, default=False + If True, the method yields tokens as they are detected, rather than + returning a list. If False, a list of tokens is returned. - :Returns: - A list of tokens if `callback` is None. Each token is tuple with the - following elements: + Returns + ------- + list of tuples or generator + A list of tokens if `generator` is False, or a generator yielding + tokens if `generator` is True. Each token is a tuple with the + following structure: - .. code python + .. code:: python (data, start, end) - where `data` is a list of read frames, `start`: index of the first - frame in the original data and `end` : index of the last frame. + where `data` is a list of frames in the token, `start` is the index + of the first frame in the original data, and `end` is the index of + the last frame. """ + token_gen = self._iter_tokens(data_source) if callback: for token in token_gen:
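The updated `save` docstring earlier in this hunk introduces `{start}`, `{end}` and `{duration}` filename placeholders. Below is a minimal sketch mirroring the doctest values, assuming `AudioRegion` is importable from the package root as in the examples:

.. code:: python

    from auditok import AudioRegion

    # 24000 mono samples of silence at 16 kHz, 2 bytes each -> a 1.5 s region
    region = AudioRegion(
        b"\x00" * 2 * 24000,
        sampling_rate=16000,
        sample_width=2,
        channels=1,
        start=2.25,  # `end` is derived as start + duration = 3.75
    )

    # Placeholders are filled in by `save`; format specs such as :.3f work too.
    filename = region.save("audio_{start:.3f}-{end:.3f}_{duration:.3f}.wav")
    assert filename == "audio_2.250-3.750_1.500.wav"

With `exists_ok=False`, saving to an existing path raises `FileExistsError`, as in the rewritten method body above.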
--- a/auditok/dataset.py	Thu Mar 30 10:17:57 2023 +0100
+++ b/auditok/dataset.py	Wed Oct 30 17:17:59 2024 +0000
@@ -1,5 +1,5 @@
 """
-This module contains links to audio files that can be used for test purposes.
+A module containing links to audio files that can be used for testing purposes.
 
 .. autosummary::
     :toctree: generated/
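Returning to `StreamTokenizer.tokenize` from the core.py hunk above: besides returning a list of tokens, it accepts a three-argument callback invoked as `callback(data, start, end)` for each detected token. A minimal sketch reusing the `UpperCaseChecker` validator from the doctests, with the same expected tokens:

.. code:: python

    from auditok.core import StreamTokenizer
    from auditok.util import StringDataSource, DataValidator

    class UpperCaseChecker(DataValidator):
        def is_valid(self, frame):
            return frame.isupper()

    def on_token(data, start, end):
        # `data` is the list of frames; `start`/`end` are frame indices.
        print("".join(data), start, end)

    tokenizer = StreamTokenizer(
        validator=UpperCaseChecker(),
        min_length=3,
        max_length=4,
        max_continuous_silence=0,
    )
    tokenizer.tokenize(StringDataSource("aaaAAAABBbbb"), callback=on_token)
    # Expected, per the doctest above: "AAAA 3 6" then "BB 7 8"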
--- a/auditok/exceptions.py Thu Mar 30 10:17:57 2023 +0100 +++ b/auditok/exceptions.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,14 +1,14 @@ -class DuplicateArgument(Exception): - pass +class ArgumentError(Exception): + """Raised when command line arguments have invalid values.""" -class TooSamllBlockDuration(ValueError): +class TooSmallBlockDuration(ValueError): """Raised when block_dur results in a block_size smaller than one sample.""" def __init__(self, message, block_dur, sampling_rate): self.block_dur = block_dur self.sampling_rate = sampling_rate - super(TooSamllBlockDuration, self).__init__(message) + super().__init__(message) class TimeFormatError(Exception):
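The renamed `TooSmallBlockDuration` keeps the offending values on the exception object, as set in its `__init__` above. A small sketch of raising and catching it, with hypothetical values:

.. code:: python

    from auditok.exceptions import TooSmallBlockDuration

    try:
        # hypothetical values: a block duration shorter than one sample
        raise TooSmallBlockDuration(
            "block_dur gives a block_size smaller than one sample",
            block_dur=1e-6,
            sampling_rate=16000,
        )
    except TooSmallBlockDuration as exc:
        print(exc, exc.block_dur, exc.sampling_rate)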
--- a/auditok/io.py Thu Mar 30 10:17:57 2023 +0100 +++ b/auditok/io.py Wed Oct 30 17:17:59 2024 +0000 @@ -15,12 +15,12 @@ to_file player_for """ + import os import sys import wave -import warnings from abc import ABC, abstractmethod -from functools import partial + from .exceptions import AudioIOError, AudioParameterError try: @@ -71,9 +71,30 @@ ) -def _guess_audio_format(fmt, filename): +def _guess_audio_format(filename, fmt): + """Guess the audio format from a file extension or normalize a provided + format. + + This helper function attempts to determine the audio format based on the + file extension of `filename` or by normalizing the format specified by the + user in `fmt`. + + Parameters + ---------- + filename : str or Path + The audio file name, including its extension. + fmt : str + The un-normalized format provided by the user. + + Returns + ------- + str or None + The guessed audio format as a string, or None if no format could be + determined. + """ + if fmt is None: - extension = os.path.splitext(filename.lower())[1][1:] + extension = os.path.splitext(filename)[1][1:].lower() if extension: fmt = extension else: @@ -86,38 +107,47 @@ def _get_audio_parameters(param_dict): """ - Get audio parameters from a dictionary of parameters. An audio parameter can - have a long name or a short name. If the long name is present, the short - name will be ignored. If neither is present then `AudioParameterError` is - raised. + Retrieve audio parameters from a dictionary of parameters. - Expected parameters are: + Each audio parameter can have a long name or a short name. If both are + present, the long name takes precedence. If neither is found, an + `AudioParameterError` is raised. - - `sampling_rate`, `sr` : int, sampling rate. + Expected parameters: + - `sampling_rate`, `sr` : int, the sampling rate. + - `sample_width`, `sw` : int, the sample size in bytes. + - `channels`, `ch` : int, the number of audio channels. - - `sample_width`, `sw` : int, sample size in bytes. - - - `channels`, `ch` : int, number of channels. + Parameters + ---------- + param_dict : dict + A dictionary containing audio parameters, with possible keys as + defined above. Returns ------- - audio_parameters : tuple - a tuple for audio parameters as (sampling_rate, sample_width, channels). + tuple + A tuple containing audio parameters as + (sampling_rate, sample_width, channels). + + Raises + ------ + AudioParameterError + If a required parameter is missing, is not an integer, or is not a + positive value. """ - err_message = ( - "'{ln}' (or '{sn}') must be a positive integer, found: '{val}'" - ) + parameters = [] - for (long_name, short_name) in ( + for long_name, short_name in ( ("sampling_rate", "sr"), ("sample_width", "sw"), ("channels", "ch"), ): param = param_dict.get(long_name, param_dict.get(short_name)) if param is None or not isinstance(param, int) or param <= 0: - raise AudioParameterError( - err_message.format(ln=long_name, sn=short_name, val=param) - ) + err_message = f"{long_name!r} (or {short_name!r}) must be a " + err_message += f"positive integer, passed value: {param}." + raise AudioParameterError(err_message) parameters.append(param) sampling_rate, sample_width, channels = parameters return sampling_rate, sample_width, channels @@ -127,21 +157,25 @@ """ Base class for audio source objects. - Subclasses should implement methods to open/close and audio stream - and read the desired amount of audio samples. + This class provides a foundation for audio source objects. 
Subclasses are
+    expected to implement methods to open and close an audio stream, as well as
+    to read the desired number of audio samples.
 
     Parameters
     ----------
     sampling_rate : int
-        number of samples per second of audio data.
+        The number of samples per second of audio data.
     sample_width : int
-        size in bytes of one audio sample. Possible values: 1, 2 or 4.
+        The size, in bytes, of each audio sample. Accepted values are 1, 2, or 4.
     channels : int
-        number of channels of audio data.
+        The number of audio channels.
     """
 
     def __init__(
-        self, sampling_rate, sample_width, channels,
+        self,
+        sampling_rate,
+        sample_width,
+        channels,
     ):
 
         if sample_width not in (1, 2, 4):
@@ -167,23 +201,26 @@
 
     @abstractmethod
     def read(self, size):
-        """
-        Read and return `size` audio samples at most.
+        """Read and return up to `size` audio samples.
+
+        This abstract method reads audio data and returns it as a bytes object,
+        containing at most `size` samples.
 
         Parameters
-        -----------
+        ----------
         size : int
-            Number of samples to read.
+            The number of samples to read.
 
         Returns
         -------
-        data : bytes
-            Audio data as a bytes object of length `N * sample_width * channels`
-            where `N` equals:
+        bytes
+            A bytes object containing the audio data, with a length of
+            `N * sample_width * channels`, where `N` is:
 
-            - `size` if `size` <= remaining samples
-
-            - remaining samples if `size` > remaining samples
+            - `size`, if `size` is less than or equal to the number of remaining
+              samples
+            - the number of remaining samples, if `size` exceeds the remaining
+              samples
         """
 
     @property
@@ -218,14 +255,25 @@
         """Number of channels in audio stream (alias for `channels`)."""
         return self.channels
 
+    def __str__(self):
+        return f"{self.__class__.__name__}(sampling_rate={self.sr}, sample_width={self.sw}, channels={self.ch})"  # noqa: B950
+
+    def __repr__(self):
+        return (
+            f"<{self.__class__.__name__}("
+            f"sampling_rate={self.sampling_rate!r}, "
+            f"sample_width={self.sample_width!r}, "
+            f"channels={self.channels!r})>"
+        )
+
 
 class Rewindable(AudioSource):
-    """
-    Base class for rewindable audio streams.
+    """Base class for rewindable audio sources.
 
-    Subclasses should implement a method to return back to the start of an the
-    stream (`rewind`), as well as a property getter/setter named `position` that
-    reads/sets stream position expressed in number of samples.
+    This class serves as a base for audio sources that support rewinding.
+    Subclasses should implement a method to return to the beginning of the
+    stream (`rewind`), and provide a property `position` that allows getting
+    and setting the current stream position, expressed in number of samples.
     """
 
     @abstractmethod
@@ -266,26 +314,32 @@
 
 class BufferAudioSource(Rewindable):
-    """
-    An `AudioSource` that encapsulates and reads data from a memory buffer.
+    """An `AudioSource` that reads audio data from a memory buffer.
 
-    This class implements the `Rewindable` interface.
+    This class implements the `Rewindable` interface, allowing audio data
+    stored in a buffer to be read with support for rewinding and position
+    control.
+
     Parameters
     ----------
     data : bytes
-        audio data
-    sampling_rate : int, default: 16000
-        number of samples per second of audio data.
+        The audio data stored in a memory buffer.
+    sampling_rate : int, optional, default=16000
+        The number of samples per second of audio data.
+ sample_width : int, optional, default=2 + The size in bytes of one audio sample. Accepted values are 1, 2, or 4. + channels : int, optional, default=1 + The number of audio channels. """ def __init__( - self, data, sampling_rate=16000, sample_width=2, channels=1, + self, + data, + sampling_rate=16000, + sample_width=2, + channels=1, ): - AudioSource.__init__(self, sampling_rate, sample_width, channels) + super().__init__(sampling_rate, sample_width, channels) check_audio_data(data, sample_width, channels) self._data = data self._sample_size_all_channels = sample_width * channels @@ -355,21 +409,23 @@ class FileAudioSource(AudioSource): - """ - Base class `AudioSource`s that read audio data from a file. + """Base class for `AudioSource`s that read audio data from a file. + + This class provides a foundation for audio sources that retrieve audio data + from file sources. Parameters ---------- - sampling_rate : int, default: 16000 - number of samples per second of audio data. - sample_width : int, default: 2 - size in bytes of one audio sample. Possible values: 1, 2 or 4. - channels : int, default: 1 - number of channels of audio data. + sampling_rate : int, optional, default=16000 + The number of samples per second of audio data. + sample_width : int, optional, default=2 + The size in bytes of one audio sample. Accepted values are 1, 2, or 4. + channels : int, optional, default=1 + The number of audio channels. """ def __init__(self, sampling_rate, sample_width, channels): - AudioSource.__init__(self, sampling_rate, sample_width, channels) + super().__init__(sampling_rate, sample_width, channels) self._audio_stream = None def __del__(self): @@ -399,33 +455,32 @@ class RawAudioSource(FileAudioSource): """ - A class for an `AudioSource` that reads data from a raw (headerless) audio - file. + An `AudioSource` class for reading data from a raw (headerless) audio file. - This class should be used for large raw audio files to avoid loading the - whole data to memory. + This class is suitable for large raw audio files, allowing for efficient + data handling without loading the entire file into memory. Parameters ---------- - filename : str - path to a raw audio file. + filename : str or Path + The path to the raw audio file. sampling_rate : int - Number of samples per second of audio data. + The number of samples per second of audio data. sample_width : int - Size in bytes of one audio sample. Possible values : 1, 2, 4. + The size in bytes of each audio sample. Accepted values are 1, 2, or 4. channels : int - Number of channels of audio data. + The number of audio channels. """ - def __init__(self, file, sampling_rate, sample_width, channels): - FileAudioSource.__init__(self, sampling_rate, sample_width, channels) - self._file = file + def __init__(self, filename, sampling_rate, sample_width, channels): + super().__init__(sampling_rate, sample_width, channels) + self._filename = filename self._audio_stream = None self._sample_size = sample_width * channels def open(self): if self._audio_stream is None: - self._audio_stream = open(self._file, "rb") + self._audio_stream = open(self._filename, "rb") def _read_from_stream(self, size): if size is None or size < 0: @@ -438,23 +493,22 @@ class WaveAudioSource(FileAudioSource): """ - A class for an `AudioSource` that reads data from a wave file. + An `AudioSource` class for reading data from a wave file. - This class should be used for large wave files to avoid loading the whole - data to memory. 
+ This class is suitable for large wave files, allowing for efficient data + handling without loading the entire file into memory. Parameters ---------- - filename : str - path to a valid wave file. + filename : str or Path + The path to a valid wave file. """ def __init__(self, filename): - self._filename = filename + self._filename = str(filename) # wave requires an str filename self._audio_stream = None stream = wave.open(self._filename, "rb") - FileAudioSource.__init__( - self, + super().__init__( stream.getframerate(), stream.getsampwidth(), stream.getnchannels(), @@ -472,23 +526,25 @@ class PyAudioSource(AudioSource): - """ - A class for an `AudioSource` that reads data from built-in microphone using - PyAudio (https://people.csail.mit.edu/hubert/pyaudio/). + """An `AudioSource` class for reading data from a built-in microphone using + PyAudio. + + This class leverages PyAudio (https://people.csail.mit.edu/hubert/pyaudio/) + to capture audio data directly from a microphone. Parameters ---------- - sampling_rate : int, default: 16000 - number of samples per second of audio data. - sample_width : int, default: 2 - size in bytes of one audio sample. Possible values: 1, 2 or 4. - channels : int, default: 1 - number of channels of audio data. - frames_per_buffer : int, default: 1024 - PyAudio number of frames per buffer. - input_device_index: None or int, default: None - PyAudio index of audio device to read audio data from. If None default - device is used. + sampling_rate : int, optional, default=16000 + The number of samples per second of audio data. + sample_width : int, optional, default=2 + The size in bytes of each audio sample. Accepted values are 1, 2, or 4. + channels : int, optional, default=1 + The number of audio channels. + frames_per_buffer : int, optional, default=1024 + The number of frames per buffer, as specified by PyAudio. + input_device_index : int or None, optional, default=None + The PyAudio index of the audio device to read from. If None, the default + audio device is used. """ def __init__( @@ -500,7 +556,7 @@ input_device_index=None, ): - AudioSource.__init__(self, sampling_rate, sample_width, channels) + super().__init__(sampling_rate, sample_width, channels) self._chunk_size = frames_per_buffer self.input_device_index = input_device_index @@ -545,22 +601,28 @@ class StdinAudioSource(FileAudioSource): """ - A class for an `AudioSource` that reads data from standard input. + An `AudioSource` class for reading audio data from standard input. + + This class is designed to capture audio data directly from standard input, + making it suitable for streaming audio sources. Parameters ---------- - sampling_rate : int, default: 16000 - number of samples per second of audio data. - sample_width : int, default: 2 - size in bytes of one audio sample. Possible values: 1, 2 or 4. - channels : int, default: 1 - number of channels of audio data. + sampling_rate : int, optional, default=16000 + The number of samples per second of audio data. + sample_width : int, optional, default=2 + The size in bytes of each audio sample. Accepted values are 1, 2, or 4. + channels : int, optional, default=1 + The number of audio channels. 
""" def __init__( - self, sampling_rate=16000, sample_width=2, channels=1, + self, + sampling_rate=16000, + sample_width=2, + channels=1, ): - FileAudioSource.__init__(self, sampling_rate, sample_width, channels) + super().__init__(sampling_rate, sample_width, channels) self._is_open = False self._sample_size = sample_width * channels self._stream = sys.stdin.buffer @@ -595,22 +657,26 @@ class PyAudioPlayer: - """ - A class for audio playback using Pyaudio + """A class for audio playback using PyAudio. + + This class facilitates audio playback through the PyAudio library (https://people.csail.mit.edu/hubert/pyaudio/). Parameters ---------- - sampling_rate : int, default: 16000 - number of samples per second of audio data. - sample_width : int, default: 2 - size in bytes of one audio sample. Possible values: 1, 2 or 4. - channels : int, default: 1 - number of channels of audio data. + sampling_rate : int, optional, default=16000 + The number of samples per second of audio data. + sample_width : int, optional, default=2 + The size in bytes of each audio sample. Accepted values are 1, 2, or 4. + channels : int, optional, default=1 + The number of audio channels. """ def __init__( - self, sampling_rate=16000, sample_width=2, channels=1, + self, + sampling_rate=16000, + sample_width=2, + channels=1, ): if sample_width not in (1, 2, 4): raise ValueError("Sample width in bytes must be one of 1, 2 or 4") @@ -640,7 +706,7 @@ chunk_gen, total=nb_chunks, duration=duration, - **progress_bar_kwargs + **progress_bar_kwargs, ) if self.stream.is_stopped(): self.stream.start_stream() @@ -674,21 +740,26 @@ def player_for(source): """ - Return an `AudioPlayer` compatible with `source` (i.e., has the same - sampling rate, sample width and number of channels). + Return an `AudioPlayer` compatible with the specified `source`. + + This function creates an `AudioPlayer` instance (currently only + `PyAudioPlayer` is implemented) that matches the audio properties of the + provided `source`, ensuring compatibility in terms of sampling rate, sample + width, and number of channels. Parameters ---------- source : AudioSource - An object that has `sampling_rate`, `sample_width` and `sample_width` + An object with `sampling_rate`, `sample_width`, and `channels` attributes. Returns ------- - player : PyAudioPlayer - An audio player that has the same sampling rate, sample width - and number of channels as `source`. + PyAudioPlayer + An audio player with the same sampling rate, sample width, and number + of channels as `source`. """ + return PyAudioPlayer( source.sampling_rate, source.sample_width, source.channels ) @@ -696,30 +767,37 @@ def get_audio_source(input=None, **kwargs): """ - Create and return an AudioSource from input. + Create and return an `AudioSource` based on the specified input. + + This function generates an `AudioSource` instance from various input types, + allowing flexibility for audio data sources such as file paths, raw data, + standard input, or microphone input via PyAudio. Parameters ---------- - input : str, bytes, "-" or None (default) - source to read audio data from. If `str`, it should be a path to a valid - audio file. If `bytes`, it is used as raw audio data. If it is "-", - raw data will be read from stdin. If None, read audio data from the - microphone using PyAudio. - kwargs - audio parameters used to build the `AudioSource` object. Depending on - the nature of `input`, theses may be omitted (e.g., when `input` is an - audio file in a popular audio format such as wav, ogg, flac, etc.) 
or - include parameters such as `sampling_rate`, `sample_width`, `channels` - (or their respective short name versions `sr`, `sw` and `ch`) if `input` - is a path to a raw (headerless) audio file, a bytes object for raw audio - data or None (to read data from built-in microphone). See the respective - `AudioSource` classes from more information about possible parameters. + input : str, bytes, "-", or None, optional + The source to read audio data from. Possible values are: + - `str`: Path to a valid audio file. + - `bytes`: Raw audio data. + - "-": Read raw data from standard input. + - None (default): Read audio data from the microphone using PyAudio. + kwargs : dict, optional + Additional audio parameters used to construct the `AudioSource` object. + Depending on the `input` type, these may be optional (e.g., for common + audio file formats such as wav, ogg, or flac). When required, parameters + include `sampling_rate`, `sample_width`, `channels`, or their short + forms `sr`, `sw`, and `ch`. These parameters are typically needed when + `input` is a path to a raw audio file, a bytes object with raw audio + data, or None (for microphone input). See respective `AudioSource` + classes for detailed parameter requirements. Returns ------- - source : AudioSource - audio source created from input parameters + AudioSource + An audio source created based on the specified input and audio + parameters. """ + if input == "-": return StdinAudioSource(*_get_audio_parameters(kwargs)) @@ -737,35 +815,44 @@ return PyAudioSource( *_get_audio_parameters(kwargs), frames_per_buffer=frames_per_buffer, - input_device_index=input_device_index + input_device_index=input_device_index, ) -def _load_raw(file, sampling_rate, sample_width, channels, large_file=False): +def _load_raw( + filename, sampling_rate, sample_width, channels, large_file=False +): """ - Load a raw audio file with standard Python. If `large_file` is True, return - a `RawAudioSource` object that reads data lazily from disk, otherwise load - all data to memory and return a `BufferAudioSource` object. + Load a raw audio file using standard Python file handling. + + This function loads audio data from a raw file. If `large_file` is set to + True, it returns a `RawAudioSource` object that reads data lazily from disk. + Otherwise, it loads all data into memory and returns a `BufferAudioSource` + object. Parameters ---------- - file : str - path to a raw audio data file. + filename : str or Path + The path to the raw audio data file. sampling_rate : int - sampling rate of audio data. + The sampling rate of the audio data. sample_width : int - size in bytes of one audio sample. + The size, in bytes, of each audio sample. channels : int - number of channels of audio data. - large_file : bool - if True, return a `RawAudioSource` otherwise a `BufferAudioSource` - object. + The number of audio channels. + large_file : bool, optional + If True, a `RawAudioSource` is returned to allow lazy data loading from + disk. If False, returns a `BufferAudioSource` with all data loaded into + memory. Returns ------- - source : RawAudioSource or BufferAudioSource - an `AudioSource` that reads data from input file. + AudioSource + An `AudioSource` that reads data from the specified file. The source is + either a `RawAudioSource` (for lazy loading) or a `BufferAudioSource` + (for in-memory loading), depending on the value of `large_file`. 
""" + if None in (sampling_rate, sample_width, channels): raise AudioParameterError( "All audio parameters are required for raw audio files" @@ -773,13 +860,13 @@ if large_file: return RawAudioSource( - file, + filename, sampling_rate=sampling_rate, sample_width=sample_width, channels=channels, ) - with open(file, "rb") as fp: + with open(filename, "rb") as fp: data = fp.read() return BufferAudioSource( data, @@ -789,28 +876,35 @@ ) -def _load_wave(file, large_file=False): +def _load_wave(filename, large_file=False): """ - Load a wave audio file with standard Python. If `large_file` is True, return - a `WaveAudioSource` object that reads data lazily from disk, otherwise load - all data to memory and return a `BufferAudioSource` object. + Load a wave audio file using standard Python module `wave`. + + This function loads audio data from a wave (.wav) file. If `large_file` is + set to True, it returns a `WaveAudioSource` object that reads data lazily + from disk. Otherwise, it loads all data into memory and returns a + `BufferAudioSource` object. Parameters ---------- - file : str - path to a wav audio data file - large_file : bool - if True, return a `WaveAudioSource` otherwise a `BufferAudioSource` - object. + filename : str or Path + The path to the wave audio data file. + large_file : bool, optional + If True, a `WaveAudioSource` is returned to allow lazy data loading from + disk. If False, returns a `BufferAudioSource` with all data loaded into + memory. Returns ------- - source : WaveAudioSource or BufferAudioSource - an `AudioSource` that reads data from input file. + AudioSource + An `AudioSource` that reads data from the specified file. The source is + either a `WaveAudioSource` (for lazy loading) or a `BufferAudioSource` + (for in-memory loading), depending on the value of `large_file`. """ + if large_file: - return WaveAudioSource(file) - with wave.open(file) as fp: + return WaveAudioSource(filename) + with wave.open(str(filename)) as fp: channels = fp.getnchannels() srate = fp.getframerate() swidth = fp.getsampwidth() @@ -820,30 +914,33 @@ ) -def _load_with_pydub(file, audio_format=None): +def _load_with_pydub(filename, audio_format=None): """ - Open compressed audio or video file using pydub. If a video file - is passed, its audio track(s) are extracted and loaded. + Load audio from a compressed audio or video file using `pydub`. + + This function uses `pydub` to load compressed audio files. If a video file + is specified, the audio track(s) are extracted and loaded. Parameters ---------- - file : str - path to audio file. - audio_format : str, default: None - string, audio/video file format if known (e.g. raw, webm, wav, ogg) + filename : str or Path + The path to the audio file. + audio_format : str, optional, default=None + The audio file format, if known (e.g., raw, webm, wav, ogg). Returns ------- - source : BufferAudioSource - an `AudioSource` that reads data from input file. + BufferAudioSource + An `AudioSource` that reads data from the specified file. """ + func_dict = { "mp3": AudioSegment.from_mp3, "ogg": AudioSegment.from_ogg, "flv": AudioSegment.from_flv, } open_function = func_dict.get(audio_format, AudioSegment.from_file) - segment = open_function(file) + segment = open_function(filename) return BufferAudioSource( data=segment.raw_data, sampling_rate=segment.frame_rate, @@ -853,63 +950,61 @@ def from_file(filename, audio_format=None, large_file=False, **kwargs): - """ - Read audio data from `filename` and return an `AudioSource` object. 
- if `audio_format` is None, the appropriate `AudioSource` class is guessed - from file's extension. `filename` can be a compressed audio or video file. - This will require installing `pydub` (https://github.com/jiaaro/pydub). + """Read audio data from `filename` and return an `AudioSource` object. - The normal behavior is to load all audio data to memory from which a - :class:`BufferAudioSource` object is created. This should be convenient - most of the time unless audio file is very large. In that case, and - in order to load audio data in lazy manner (i.e. read data from disk each - time :func:`AudioSource.read` is called), `large_file` should be True. + If `audio_format` is None, the appropriate `AudioSource` class is inferred + from the file extension. The `filename` can refer to a compressed audio or + video file; if a video file is provided, its audio track(s) are extracted. + This functionality requires `pydub` (https://github.com/jiaaro/pydub). - Note that the current implementation supports only wave and raw formats for - lazy audio loading. + By default, all audio data is loaded into memory to create a + `BufferAudioSource` object, suitable for most cases. For very large files, + set `large_file=True` to enable lazy loading, which reads audio data from + disk each time `AudioSource.read` is called. Currently, lazy loading + supports only wave and raw formats. - If an audio format is `raw`, the following keyword arguments are required: + If `audio_format` is `raw`, the following keyword arguments are required: - - `sampling_rate`, `sr`: int, sampling rate of audio data. + - `sampling_rate`, `sr`: int, sampling rate of audio data. - `sample_width`, `sw`: int, size in bytes of one audio sample. - `channels`, `ch`: int, number of channels of audio data. - See also + See Also -------- - :func:`to_file`. + to_file : A related function for saving audio data to a file. Parameters ---------- - filename : str - path to input audio or video file. - audio_format : str - audio format used to save data (e.g. raw, webm, wav, ogg). - large_file : bool, default: False - if True, audio won't fully be loaded to memory but only when a window - is read from disk. - + filename : str or Path + The path to the input audio or video file. + audio_format : str, optional + The audio format (e.g., raw, webm, wav, ogg). + large_file : bool, optional, default=False + If True, the audio data is read lazily from disk rather than being + fully loaded into memory. Other Parameters ---------------- - sampling_rate, sr: int - sampling rate of audio data + sampling_rate, sr : int + The sampling rate of the audio data. sample_width : int - sample width (i.e. number of bytes used to represent one audio sample) + The sample width in bytes (i.e., number of bytes per audio sample). channels : int - number of channels of audio data + The number of audio channels. Returns ------- - audio_source : AudioSource - an :class:`AudioSource` object that reads data from input file. + AudioSource + An `AudioSource` object that reads data from the specified file. Raises ------ - `AudioIOError` - raised if audio data cannot be read in the given - format or if `format` is `raw` and one or more audio parameters are missing. + AudioIOError + If audio data cannot be read in the given format or if `audio_format` + is `raw` and one or more required audio parameters are missing. 
""" - audio_format = _guess_audio_format(audio_format, filename) + + audio_format = _guess_audio_format(filename, audio_format) if audio_format == "raw": srate, swidth, channels = _get_audio_parameters(kwargs) @@ -930,23 +1025,57 @@ def _save_raw(data, file): """ - Saves audio data as a headerless (i.e. raw) file. - See also :func:`to_file`. + Save audio data as a headerless (raw) file. + + This function writes audio data to a file in raw format, without any header + information. + + Parameters + ---------- + data : bytes + The audio data to be saved. + file : str or Path + The path to the file where audio data will be saved. + + See Also + -------- + to_file : A related function for saving audio data in various formats. """ + with open(file, "wb") as fp: fp.write(data) def _save_wave(data, file, sampling_rate, sample_width, channels): """ - Saves audio data to a wave file. - See also :func:`to_file`. + Save audio data to a wave file. + + This function writes audio data to a file in the wave format, including + header information based on the specified audio parameters. + + Parameters + ---------- + data : bytes + The audio data to be saved. + file : str or Path + The path to the file where audio data will be saved. + sampling_rate : int + The sampling rate of the audio data. + sample_width : int + The size, in bytes, of each audio sample. + channels : int + The number of audio channels. + + See Also + -------- + to_file : A related function for saving audio data in various formats. """ + if None in (sampling_rate, sample_width, channels): raise AudioParameterError( "All audio parameters are required to save wave audio files" ) - with wave.open(file, "w") as fp: + with wave.open(str(file), "w") as fp: fp.setframerate(sampling_rate) fp.setsampwidth(sample_width) fp.setnchannels(channels) @@ -957,9 +1086,31 @@ data, file, audio_format, sampling_rate, sample_width, channels ): """ - Saves audio data with pydub (https://github.com/jiaaro/pydub). - See also :func:`to_file`. + Save audio data using pydub. + + This function saves audio data to a file in various formats supported by + pydub (https://github.com/jiaaro/pydub), such as mp3, wav, ogg, etc. + + Parameters + ---------- + data : bytes + The audio data to be saved. + file : str or Path + The path to the file where audio data will be saved. + audio_format : str + The audio format to save the file in (e.g., mp3, wav, ogg). + sampling_rate : int + The sampling rate of the audio data. + sample_width : int + The size, in bytes, of each audio sample. + channels : int + The number of audio channels. + + See Also + -------- + to_file : A related function for saving audio data in various formats. """ + segment = AudioSegment( data, frame_rate=sampling_rate, @@ -970,52 +1121,52 @@ segment.export(fp, format=audio_format) -def to_file(data, file, audio_format=None, **kwargs): +def to_file(data, filename, audio_format=None, **kwargs): """ - Writes audio data to file. If `audio_format` is `None`, output - audio format will be guessed from extension. If `audio_format` - is `None` and `file` comes without an extension then audio - data will be written as a raw audio file. + Write audio data to a file. + + This function writes audio data to a file in the specified format. If + `audio_format` is None, the output format will be inferred from the file + extension. If `audio_format` is None and `filename` has no extension, + the data will be saved as a raw audio file. Parameters ---------- data : bytes-like - audio data to be written. 
Can be a `bytes`, `bytearray`, - `memoryview`, `array` or `numpy.ndarray` object. - file : str - path to output audio file. - audio_format : str - audio format used to save data (e.g. raw, webm, wav, ogg) - kwargs: dict - If an audio format other than `raw` is used, the following keyword - arguments are required: + The audio data to be written. Accepts `bytes`, `bytearray`, `memoryview`, + `array`, or `numpy.ndarray` objects. + filename : str or Path + The path to the output audio file. + audio_format : str, optional + The audio format to use for saving the data (e.g., raw, webm, wav, ogg). + kwargs : dict, optional + Additional parameters required for non-raw audio formats: - - `sampling_rate`, `sr`: int, sampling rate of audio data. - - `sample_width`, `sw`: int, size in bytes of one audio sample. - - `channels`, `ch`: int, number of channels of audio data. + - `sampling_rate`, `sr` : int, the sampling rate of the audio data. + - `sample_width`, `sw` : int, the size in bytes of one audio sample. + - `channels`, `ch` : int, the number of audio channels. Raises ------ - `AudioParameterError` if output format is different than raw and one or more - audio parameters are missing. `AudioIOError` if audio data cannot be written - in the desired format. + AudioParameterError + Raised if the output format is not raw and one or more required audio + parameters are missing. + AudioIOError + Raised if the audio data cannot be written in the specified format. """ - audio_format = _guess_audio_format(audio_format, file) + + audio_format = _guess_audio_format(filename, audio_format) if audio_format in (None, "raw"): - _save_raw(data, file) + _save_raw(data, filename) return - try: - sampling_rate, sample_width, channels = _get_audio_parameters(kwargs) - except AudioParameterError as exc: - err_message = "All audio parameters are required to save formats " - "other than raw. Error detail: {}".format(exc) - raise AudioParameterError(err_message) + sampling_rate, sample_width, channels = _get_audio_parameters(kwargs) if audio_format in ("wav", "wave"): - _save_wave(data, file, sampling_rate, sample_width, channels) + _save_wave(data, filename, sampling_rate, sample_width, channels) elif _WITH_PYDUB: _save_with_pydub( - data, file, audio_format, sampling_rate, sample_width, channels + data, filename, audio_format, sampling_rate, sample_width, channels ) else: - err_message = "cannot write file format {} (file name: {})" - raise AudioIOError(err_message.format(audio_format, file)) + raise AudioIOError( + f"cannot write file format {audio_format} (file name: {filename})" + )
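The reworked `from_file`/`to_file` pair above can round-trip raw and wave audio. A minimal sketch, assuming a hypothetical raw file `audio.dat` and using the short-form parameter names accepted by `_get_audio_parameters`:

.. code:: python

    from auditok.io import from_file, to_file

    # Raw (headerless) audio needs explicit parameters: sr/sw/ch, or the
    # long names sampling_rate/sample_width/channels.
    src = from_file("audio.dat", audio_format="raw", sr=16000, sw=2, ch=1)
    src.open()
    data = src.read(16000)  # up to one second of samples
    src.close()

    # With audio_format=None the output format is guessed from the file
    # extension; wave output requires the same three parameters.
    to_file(data, "one_second.wav", sampling_rate=16000, sample_width=2, channels=1)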
--- a/auditok/plotting.py Thu Mar 30 10:17:57 2023 +0100 +++ b/auditok/plotting.py Wed Oct 30 17:17:59 2024 +0000 @@ -40,7 +40,7 @@ ls = theme.get("linestyle", theme.get("ls")) lw = theme.get("linewidth", theme.get("lw")) alpha = theme.get("alpha") - for (start, end) in detections: + for start, end in detections: subplot.axvspan(start, end, fc=fc, ec=ec, ls=ls, lw=lw, alpha=alpha)
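The loop above shades each detected `(start, end)` pair on the signal axes. A standalone matplotlib sketch of the same idea, with hypothetical detections and theme values:

.. code:: python

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot(range(100), [0] * 100)  # placeholder signal

    detections = [(10, 25), (60, 80)]  # hypothetical (start, end) pairs
    for start, end in detections:
        # mirrors subplot.axvspan(start, end, fc=fc, ec=ec, ls=ls, lw=lw, alpha=alpha)
        ax.axvspan(start, end, fc="g", ec="b", ls="--", lw=0.5, alpha=0.4)
    plt.show()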
--- a/auditok/signal.py Thu Mar 30 10:17:57 2023 +0100 +++ b/auditok/signal.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,179 +1,119 @@ """ -Module for basic audio signal processing and array operations. +Module for main signal processing operations. .. autosummary:: :toctree: generated/ to_array - extract_single_channel - compute_average_channel - compute_average_channel_stereo - separate_channels - calculate_energy_single_channel - calculate_energy_multichannel + calculate_energy """ -from array import array as array_ -import audioop -import math -FORMAT = {1: "b", 2: "h", 4: "i"} -_EPSILON = 1e-10 +import numpy as np + +__all__ = [ + "SAMPLE_WIDTH_TO_DTYPE", + "to_array", + "calculate_energy", +] + +SAMPLE_WIDTH_TO_DTYPE = {1: np.int8, 2: np.int16, 4: np.int32} +EPSILON = 1e-10 + + +def _get_numpy_dtype(sample_width): + """ + Helper function to convert a sample width to the corresponding NumPy data + type. + + Parameters + ---------- + sample_width : int + The width of the sample in bytes. Accepted values are 1, 2, or 4. + + Returns + ------- + numpy.dtype + The corresponding NumPy data type for the specified sample width. + + Raises + ------ + ValueError + If `sample_width` is not one of the accepted values (1, 2, or 4). + """ + + dtype = SAMPLE_WIDTH_TO_DTYPE.get(sample_width) + if dtype is None: + err_msg = "'sample_width' must be 1, 2 or 4, given: {}" + raise ValueError(err_msg.format(sample_width)) + return dtype def to_array(data, sample_width, channels): - """Extract individual channels of audio data and return a list of arrays of - numeric samples. This will always return a list of `array.array` objects - (one per channel) even if audio data is mono. + """ + Convert raw audio data into a NumPy array. + + This function transforms raw audio data, specified by sample width and + number of channels, into a 2-D NumPy array of `numpy.float64` data type. + The array will be arranged by channels and samples. Parameters ---------- data : bytes - raw audio data. + The raw audio data. sample_width : int - size in bytes of one audio sample (one channel considered). + The sample width (in bytes) of each audio sample. + channels : int + The number of audio channels. Returns ------- - samples_arrays : list - list of arrays of audio samples. + numpy.ndarray + A 2-D NumPy array representing the audio data. The shape of the array + will be (number of channels, number of samples), with data type + `numpy.float64`. + + Raises + ------ + ValueError + If `sample_width` is not an accepted value for conversion by the helper + function `_get_numpy_dtype`. """ - fmt = FORMAT[sample_width] - if channels == 1: - return [array_(fmt, data)] - return separate_channels(data, fmt, channels) + dtype = _get_numpy_dtype(sample_width) + array = np.frombuffer(data, dtype=dtype).astype(np.float64) + return array.reshape(channels, -1, order="F") -def extract_single_channel(data, fmt, channels, selected): - samples = array_(fmt, data) - return samples[selected::channels] +def calculate_energy(x, agg_fn=None): + """Calculate the energy of audio data. -def compute_average_channel(data, fmt, channels): - """ - Compute and return average channel of multi-channel audio data. If the - number of channels is 2, use :func:`compute_average_channel_stereo` (much - faster). This function uses satandard `array` module to convert `bytes` data - into an array of numeric values. + The energy is calculated as: + + .. 
math:: + \text{energy} = 20 \log\left(\sqrt{\frac{1}{N} \sum_{i=1}^{N} a_i^2}\right) % # noqa: W605 + + where `a_i` is the i-th audio sample and `N` is the total number of samples + in `x`. Parameters ---------- - data : bytes - multi-channel audio data to mix down. - fmt : str - format (single character) to pass to `array.array` to convert `data` - into an array of samples. This should be "b" if audio data's sample width - is 1, "h" if it's 2 and "i" if it's 4. - channels : int - number of channels of audio data. + x : array + Array of audio data, which may contain multiple channels. + agg_fn : callable, optional + Aggregation function to use for multi-channel data. If None, the energy + will be computed and returned separately for each channel. Returns ------- - mono_audio : bytes - mixed down audio data. + float or numpy.ndarray + The energy of the audio signal. If `x` is multichannel and `agg_fn` is + None, this will be an array of energies, one per channel. """ - all_channels = array_(fmt, data) - mono_channels = [ - array_(fmt, all_channels[ch::channels]) for ch in range(channels) - ] - avg_arr = array_( - fmt, - (round(sum(samples) / channels) for samples in zip(*mono_channels)), - ) - return avg_arr - -def compute_average_channel_stereo(data, sample_width): - """Compute and return average channel of stereo audio data. This function - should be used when the number of channels is exactly 2 because in that - case we can use standard `audioop` module which *much* faster then calling - :func:`compute_average_channel`. - - Parameters - ---------- - data : bytes - 2-channel audio data to mix down. - sample_width : int - size in bytes of one audio sample (one channel considered). - - Returns - ------- - mono_audio : bytes - mixed down audio data. - """ - fmt = FORMAT[sample_width] - arr = array_(fmt, audioop.tomono(data, sample_width, 0.5, 0.5)) - return arr - - -def separate_channels(data, fmt, channels): - """Create a list of arrays of audio samples (`array.array` objects), one for - each channel. - - Parameters - ---------- - data : bytes - multi-channel audio data to mix down. - fmt : str - format (single character) to pass to `array.array` to convert `data` - into an array of samples. This should be "b" if audio data's sample width - is 1, "h" if it's 2 and "i" if it's 4. - channels : int - number of channels of audio data. - - Returns - ------- - channels_arr : list - list of audio channels, each as a standard `array.array`. - """ - all_channels = array_(fmt, data) - mono_channels = [ - array_(fmt, all_channels[ch::channels]) for ch in range(channels) - ] - return mono_channels - - -def calculate_energy_single_channel(data, sample_width): - """Calculate the energy of mono audio data. Energy is computed as: - - .. math:: energy = 20 \log(\sqrt({1}/{N}\sum_{i}^{N}{a_i}^2)) % # noqa: W605 - - where `a_i` is the i-th audio sample and `N` is the number of audio samples - in data. - - Parameters - ---------- - data : bytes - single-channel audio data. - sample_width : int - size in bytes of one audio sample. - - Returns - ------- - energy : float - energy of audio signal. - """ - energy_sqrt = max(audioop.rms(data, sample_width), _EPSILON) - return 20 * math.log10(energy_sqrt) - - -def calculate_energy_multichannel(x, sample_width, aggregation_fn=max): - """Calculate the energy of multi-channel audio data. Energy is calculated - channel-wise. An aggregation function is applied to the resulting energies - (default: `max`). Also see :func:`calculate_energy_single_channel`. 
- - Parameters - ---------- - data : bytes - single-channel audio data. - sample_width : int - size in bytes of one audio sample (one channel considered). - aggregation_fn : callable, default: max - aggregation function to apply to the resulting per-channel energies. - - Returns - ------- - energy : float - aggregated energy of multi-channel audio signal. - """ - energies = (calculate_energy_single_channel(xi, sample_width) for xi in x) - return aggregation_fn(energies) + x = np.array(x).astype(np.float64) + energy_sqrt = np.sqrt(np.mean(x**2, axis=-1)) + energy_sqrt = np.clip(energy_sqrt, a_min=EPSILON, a_max=None) + energy = 20 * np.log10(energy_sqrt) + if agg_fn is not None: + energy = agg_fn(energy) + return energy
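As a quick end-to-end illustration of the reworked module (a minimal sketch: only `to_array` and `calculate_energy` and their signatures are taken from the diff above; the interleaved test buffer is fabricated):

.. code:: python

    import numpy as np

    from auditok.signal import calculate_energy, to_array

    # fabricated test data: 2 interleaved 16-bit channels, 8 samples each
    data = np.arange(16, dtype=np.int16).tobytes()

    samples = to_array(data, sample_width=2, channels=2)
    assert samples.shape == (2, 8)  # (channels, samples), dtype float64

    per_channel = calculate_energy(samples)         # one energy per channel
    aggregated = calculate_energy(samples, np.max)  # aggregated with np.max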
--- a/auditok/signal_numpy.py Thu Mar 30 10:17:57 2023 +0100 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,30 +0,0 @@ -import numpy as np -from .signal import ( - compute_average_channel_stereo, - calculate_energy_single_channel, - calculate_energy_multichannel, -) - -FORMAT = {1: np.int8, 2: np.int16, 4: np.int32} - - -def to_array(data, sample_width, channels): - fmt = FORMAT[sample_width] - if channels == 1: - return np.frombuffer(data, dtype=fmt).astype(np.float64) - return separate_channels(data, fmt, channels).astype(np.float64) - - -def extract_single_channel(data, fmt, channels, selected): - samples = np.frombuffer(data, dtype=fmt) - return np.asanyarray(samples[selected::channels], order="C") - - -def compute_average_channel(data, fmt, channels): - array = np.frombuffer(data, dtype=fmt).astype(np.float64) - return array.reshape(-1, channels).mean(axis=1).round().astype(fmt) - - -def separate_channels(data, fmt, channels): - array = np.frombuffer(data, dtype=fmt) - return np.asanyarray(array.reshape(-1, channels).T, order="C")
--- a/auditok/util.py Thu Mar 30 10:17:57 2023 +0100 +++ b/auditok/util.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,4 +1,6 @@ """ +Module for high-level audio input-output operations. + .. autosummary:: :toctree: generated/ @@ -8,28 +10,23 @@ make_duration_formatter make_channel_selector """ + +import warnings from abc import ABC, abstractmethod -import warnings from functools import partial + +import numpy as np + +from . import signal +from .exceptions import TimeFormatError, TooSmallBlockDuration from .io import ( AudioIOError, AudioSource, - from_file, BufferAudioSource, PyAudioSource, + from_file, get_audio_source, ) -from .exceptions import ( - DuplicateArgument, - TooSamllBlockDuration, - TimeFormatError, -) - -try: - from . import signal_numpy as signal -except ImportError: - from . import signal - __all__ = [ "make_duration_formatter", @@ -37,8 +34,6 @@ "DataSource", "DataValidator", "StringDataSource", - "ADSFactory", - "AudioDataSource", "AudioReader", "Recorder", "AudioEnergyValidator", @@ -47,73 +42,77 @@ def make_duration_formatter(fmt): """ - Make and return a function used to format durations in seconds. Accepted - format directives are: + Create and return a function to format durations in seconds using a + specified format. Accepted format directives are: - - ``%S`` : absolute number of seconds with 3 decimals. This direction should - be used alone. + - ``%S`` : absolute seconds with 3 decimals; must be used alone. - ``%i`` : milliseconds - ``%s`` : seconds - ``%m`` : minutes - ``%h`` : hours - These last 4 directives should all be specified. They can be placed anywhere - in the input string. + The last four directives (%i, %s, %m, %h) should all be specified and can + be placed in any order within the input format string. Parameters ---------- fmt : str - duration format. + Format string specifying the duration format. Returns ------- formatter : callable - a function that takes a duration in seconds (float) and returns a string - that corresponds to that duration. + A function that takes a duration in seconds (float) and returns a + formatted string. Raises ------ TimeFormatError - if the format contains an unknown directive. + Raised if the format contains an unknown directive. Examples -------- - - Using ``%S``: + Using ``%S`` for total seconds with three decimal precision: .. code:: python formatter = make_duration_formatter("%S") formatter(123.589) - '123.589' + # '123.589' formatter(123) - '123.000' + # '123.000' - Using the other directives: + Using combined directives: .. code:: python formatter = make_duration_formatter("%h:%m:%s.%i") - formatter(3600+120+3.25) - '01:02:03.250' + formatter(3723.25) + # '01:02:03.250' formatter = make_duration_formatter("%h hrs, %m min, %s sec and %i ms") - formatter(3600+120+3.25) - '01 hrs, 02 min, 03 sec and 250 ms' + formatter(3723.25) + # '01 hrs, 02 min, 03 sec and 250 ms' - # omitting one of the 4 directives might result in a wrong duration - formatter = make_duration_formatter("%m min, %s sec and %i ms") - formatter(3600+120+3.25) - '02 min, 03 sec and 250 ms' + Note: + Omitting any of the four main directives (%i, %s, %m, %h) may result + in incorrect formatting: + + .. 
code:: python
+
+        formatter = make_duration_formatter("%m min, %s sec and %i ms")
+        formatter(3723.25)
+        # '02 min, 03 sec and 250 ms'
     """
+
     if fmt == "%S":
 
-        def fromatter(seconds):
+        def formatter(seconds):
             return "{:.3f}".format(seconds)
 
     elif fmt == "%I":
 
-        def fromatter(seconds):
+        def formatter(seconds):
             return "{0}".format(int(seconds * 1000))
 
     else:
@@ -129,71 +128,63 @@
             except ValueError:
                 pass
 
-        def fromatter(seconds):
+        def formatter(seconds):
             millis = int(seconds * 1000)
             hrs, millis = divmod(millis, 3600000)
             mins, millis = divmod(millis, 60000)
             secs, millis = divmod(millis, 1000)
             return fmt.format(hrs=hrs, mins=mins, secs=secs, millis=millis)
 
-    return fromatter
+    return formatter
 
 
 def make_channel_selector(sample_width, channels, selected=None):
-    """Create and return a callable used for audio channel selection. The
-    returned selector can be used as `selector(audio_data)` and returns data
-    that contains selected channel only.
+    """
+    Create and return a callable for selecting a specific audio channel. The
+    returned `selector` function can be used as `selector(audio_data)` and
+    returns data for the specified channel.
 
-    Importantly, if `selected` is None or equals "any", `selector(audio_data)`
-    will separate and return a list of available channels:
-    `[data_channe_1, data_channe_2, ...].`
+    If `selected` is None or "any", the `selector` will separate and return a
+    list of available channels: `[data_channel_1, data_channel_2, ...]`.
 
-    Note also that returned `selector` expects `bytes` format for input data but
-    does notnecessarily return a `bytes` object. In fact, in order to extract
-    the desired channel (or compute the average channel if `selected` = "avg"),
-    it first converts input data into a `array.array` (or `numpy.ndarray`)
-    object. After channel of interst is selected/computed, it is returned as
-    such, without any reconversion to `bytes`. This behavior is wanted for
-    efficiency purposes because returned objects can be directly used as buffers
-    of bytes. In any case, returned objects can be converted back to `bytes`
-    using `bytes(obj)`.
+    Note that `selector` expects input data in `bytes` format but does not
+    necessarily return a `bytes` object. To select or compute the desired
+    channel (or average channel if `selected="avg"`), it converts the input
+    data into an `array.array` or `numpy.ndarray`. After selection, the data
+    is returned as is, without reconversion to `bytes`, for efficiency. The
+    output can be converted back to `bytes` with `bytes(obj)` if needed.
 
-    Exception to this is the special case where `channels` = 1 in which input
-    data is returned without any processing.
-
+    Special case: if `channels` is 1, the input data is simply converted to an
+    array and returned with no channel selection applied.
 
     Parameters
     ----------
     sample_width : int
-        number of bytes used to encode one audio sample, should be 1, 2 or 4.
+        Number of bytes per audio sample; should be 1, 2, or 4.
     channels : int
-        number of channels of raw audio data that the returned selector should
-        expect.
-    selected : int or str, default: None
-        audio channel to select and return when calling `selector(raw_data)`. It
-        should be an int >= `-channels` and < `channels`. If one of "mix",
-        "avg" or "average" is passed then `selector` will return the average
-        channel of audio data. If None or "any", return a list of all available
-        channels at each call.
+    selected : int or str, optional
+        Channel to select in each call to `selector(raw_data)`. 
Acceptable values: + - An integer in range [-channels, channels). + - "mix", "avg", or "average" for averaging across channels. + - None or "any" to return a list of all channels. Returns ------- selector : callable - a callable that can be used as `selector(audio_data)` and returns data - that contains channel of interst. + A function that can be called as `selector(audio_data)` and returns data + for the selected channel. Raises ------ ValueError - if `sample_width` is not one of 1, 2 or 4, or if `selected` has an - unexpected value. + If `sample_width` is not one of {1, 2, 4}, or if `selected` has an + unsupported value. """ - fmt = signal.FORMAT.get(sample_width) - if fmt is None: - err_msg = "'sample_width' must be 1, 2 or 4, given: {}" - raise ValueError(err_msg.format(sample_width)) - if channels == 1: - return lambda x: x + to_array_ = partial( + signal.to_array, sample_width=sample_width, channels=channels + ) + if channels == 1 or selected in (None, "any"): + return to_array_ if isinstance(selected, int): if selected < 0: @@ -202,27 +193,10 @@ err_msg = "Selected channel must be >= -channels and < channels" err_msg += ", given: {}" raise ValueError(err_msg.format(selected)) - return partial( - signal.extract_single_channel, - fmt=fmt, - channels=channels, - selected=selected, - ) + return lambda x: to_array_(x)[selected] if selected in ("mix", "avg", "average"): - if channels == 2: - # when data is stereo, using audioop when possible is much faster - return partial( - signal.compute_average_channel_stereo, - sample_width=sample_width, - ) - - return partial( - signal.compute_average_channel, fmt=fmt, channels=channels - ) - - if selected in (None, "any"): - return partial(signal.separate_channels, fmt=fmt, channels=channels) + return lambda x: to_array_(x).mean(axis=0) raise ValueError( "Selected channel must be an integer, None (alias 'any') or 'average' " @@ -232,29 +206,50 @@ class DataSource(ABC): """ - Base class for objects passed to :func:`StreamTokenizer.tokenize`. - Subclasses should implement a :func:`DataSource.read` method. + Base class for objects used as data sources in + :func:`StreamTokenizer.tokenize`. + + Subclasses should implement a :func:`DataSource.read` method, which is + expected to return a frame (or slice) of data from the source, and None + when there is no more data to read. """ @abstractmethod def read(self): """ - Read a block (i.e., window) of data read from this source. - If no more data is available, return None. + Read a block (or window) of data from this source. + + Returns + ------- + data : object or None + A block of data from the source. If no more data is available, + should return None. """ class DataValidator(ABC): """ - Base class for a validator object used by :class:`.core.StreamTokenizer` - to check if read data is valid. - Subclasses should implement :func:`is_valid` method. + Base class for validator objects used by :class:`.core.StreamTokenizer` + to verify the validity of read data. + + Subclasses should implement the :func:`is_valid` method to define the + specific criteria for data validity. """ @abstractmethod def is_valid(self, data): """ - Check whether `data` is valid + Determine whether the provided `data` meets validity criteria. + + Parameters + ---------- + data : object + The data to be validated. + + Returns + ------- + bool + True if `data` is valid, otherwise False. """ @@ -264,77 +259,79 @@ samples (see :func:`AudioEnergyValidator.is_valid`), the energy is computed as: - .. 
math:: energy = 20 \log(\sqrt({1}/{N}\sum_{i}^{N}{a_i}^2)) % # noqa: W605
+    .. math::
+        energy = 20 \log\left(\sqrt{\frac{1}{N} \sum_{i=1}^{N} a_i^2}\right) % # noqa: W605
 
-    where `a_i` is the i-th audio sample.
+    where `a_i` represents the i-th audio sample.
 
     Parameters
     ----------
     energy_threshold : float
-        minimum energy that audio window should have to be valid.
+        Minimum energy required for an audio window to be considered valid.
     sample_width : int
-        size in bytes of one audio sample.
+        Size in bytes of a single audio sample.
     channels : int
-        number of channels of audio data.
+        Number of audio channels in the data.
     use_channel : {None, "any", "mix", "avg", "average"} or int
-        channel to use for energy computation. The following values are
-        accepted:
+        Specifies the channel used for energy computation:
 
-        - None (alias "any") : compute energy for each of the channels and return
-          the maximum value.
-        - "mix" (alias "avg" or "average") : compute the average channel then
-          compute its energy.
-        - int (>= 0 , < `channels`) : compute the energy of the specified channel
-          and ignore the other ones.
+        - None or "any": Compute energy for each channel and return the maximum.
+        - "mix" (or "avg" / "average"): Average across all channels, then
+          compute energy.
+        - int (0 <= value < `channels`): Compute energy for the specified channel
+          only, ignoring others.
 
     Returns
     -------
     energy : float
-        energy of the audio window.
+        Computed energy of the audio window, used to validate if the window
+        meets the `energy_threshold`.
     """
 
     def __init__(
         self, energy_threshold, sample_width, channels, use_channel=None
     ):
+        self._energy_threshold = energy_threshold
         self._sample_width = sample_width
         self._selector = make_channel_selector(
             sample_width, channels, use_channel
         )
-        if channels == 1 or use_channel not in (None, "any"):
-            self._energy_fn = signal.calculate_energy_single_channel
-        else:
-            self._energy_fn = signal.calculate_energy_multichannel
-        self._energy_threshold = energy_threshold
+        self._energy_agg_fn = np.max if use_channel in (None, "any") else None
 
     def is_valid(self, data):
         """
+        Determine if the audio data meets the energy threshold.
 
         Parameters
         ----------
         data : bytes-like
-            array of raw audio data
+            An array of raw audio data.
 
         Returns
         -------
         bool
-            True if the energy of audio data is >= threshold, False otherwise.
+            True if the energy of the audio data is greater than or equal to
+            the specified threshold; otherwise, False.
         """
-        log_energy = self._energy_fn(self._selector(data), self._sample_width)
+
+        log_energy = signal.calculate_energy(
+            self._selector(data), self._energy_agg_fn
+        )
         return log_energy >= self._energy_threshold
 
 
 class StringDataSource(DataSource):
     """
-    Class that represent a :class:`DataSource` as a string buffer.
-    Each call to :func:`DataSource.read` returns on character and moves one
-    step forward. If the end of the buffer is reached, :func:`read` returns
+    A :class:`DataSource` implementation that reads from a string buffer.
+
+    Each call to :meth:`read` returns one character from the buffer and advances
+    by one position. When the end of the buffer is reached, :meth:`read` returns
     None.
 
     Parameters
     ----------
     data : str
-        a string object used as data.
-
+        The string data to be used as the source.
     """
 
     def __init__(self, data):
@@ -350,7 +347,7 @@
         Returns
         -------
         char : str
-            current character or None if end of buffer is reached.
+            current character or None if the end of the buffer is reached. 
""" if self._current >= len(self._data): @@ -374,317 +371,6 @@ self._current = 0 -class ADSFactory: - """ - .. deprecated:: 2.0.0 - `ADSFactory` will be removed in auditok 2.0.1, use instances of - :class:`AudioReader` instead. - - Factory class that makes it easy to create an - :class:`AudioDataSource` object that implements - :class:`DataSource` and can therefore be passed to - :func:`auditok.core.StreamTokenizer.tokenize`. - - Whether you read audio data from a file, the microphone or a memory buffer, - this factory instantiates and returns the right - :class:`AudioDataSource` object. - - There are many other features you want a :class:`AudioDataSource` object to - have, such as: memorize all read audio data so that you can rewind and reuse - it (especially useful when reading data from the microphone), read a fixed - amount of data (also useful when reading from the microphone), read - overlapping audio frames (often needed when dosing a spectral analysis of - data). - - :func:`ADSFactory.ads` automatically creates and return object with the - desired behavior according to the supplied keyword arguments. - """ - - @staticmethod # noqa: C901 - def _check_normalize_args(kwargs): - - for k in kwargs: - if k not in [ - "block_dur", - "hop_dur", - "block_size", - "hop_size", - "max_time", - "record", - "audio_source", - "filename", - "data_buffer", - "frames_per_buffer", - "sampling_rate", - "sample_width", - "channels", - "sr", - "sw", - "ch", - "asrc", - "fn", - "fpb", - "db", - "mt", - "rec", - "bd", - "hd", - "bs", - "hs", - ]: - raise ValueError("Invalid argument: {0}".format(k)) - - if "block_dur" in kwargs and "bd" in kwargs: - raise DuplicateArgument( - "Either 'block_dur' or 'bd' must be specified, not both" - ) - - if "hop_dur" in kwargs and "hd" in kwargs: - raise DuplicateArgument( - "Either 'hop_dur' or 'hd' must be specified, not both" - ) - - if "block_size" in kwargs and "bs" in kwargs: - raise DuplicateArgument( - "Either 'block_size' or 'bs' must be specified, not both" - ) - - if "hop_size" in kwargs and "hs" in kwargs: - raise DuplicateArgument( - "Either 'hop_size' or 'hs' must be specified, not both" - ) - - if "max_time" in kwargs and "mt" in kwargs: - raise DuplicateArgument( - "Either 'max_time' or 'mt' must be specified, not both" - ) - - if "audio_source" in kwargs and "asrc" in kwargs: - raise DuplicateArgument( - "Either 'audio_source' or 'asrc' must be specified, not both" - ) - - if "filename" in kwargs and "fn" in kwargs: - raise DuplicateArgument( - "Either 'filename' or 'fn' must be specified, not both" - ) - - if "data_buffer" in kwargs and "db" in kwargs: - raise DuplicateArgument( - "Either 'filename' or 'db' must be specified, not both" - ) - - if "frames_per_buffer" in kwargs and "fbb" in kwargs: - raise DuplicateArgument( - "Either 'frames_per_buffer' or 'fpb' must be specified, not " - "both" - ) - - if "sampling_rate" in kwargs and "sr" in kwargs: - raise DuplicateArgument( - "Either 'sampling_rate' or 'sr' must be specified, not both" - ) - - if "sample_width" in kwargs and "sw" in kwargs: - raise DuplicateArgument( - "Either 'sample_width' or 'sw' must be specified, not both" - ) - - if "channels" in kwargs and "ch" in kwargs: - raise DuplicateArgument( - "Either 'channels' or 'ch' must be specified, not both" - ) - - if "record" in kwargs and "rec" in kwargs: - raise DuplicateArgument( - "Either 'record' or 'rec' must be specified, not both" - ) - - kwargs["bd"] = kwargs.pop("block_dur", None) or kwargs.pop("bd", None) - kwargs["hd"] = 
kwargs.pop("hop_dur", None) or kwargs.pop("hd", None) - kwargs["bs"] = kwargs.pop("block_size", None) or kwargs.pop("bs", None) - kwargs["hs"] = kwargs.pop("hop_size", None) or kwargs.pop("hs", None) - kwargs["mt"] = kwargs.pop("max_time", None) or kwargs.pop("mt", None) - kwargs["asrc"] = kwargs.pop("audio_source", None) or kwargs.pop( - "asrc", None - ) - kwargs["fn"] = kwargs.pop("filename", None) or kwargs.pop("fn", None) - kwargs["db"] = kwargs.pop("data_buffer", None) or kwargs.pop("db", None) - - record = kwargs.pop("record", False) - if not record: - record = kwargs.pop("rec", False) - if not isinstance(record, bool): - raise TypeError("'record' must be a boolean") - - kwargs["rec"] = record - - # keep long names for arguments meant for BufferAudioSource - # and PyAudioSource - if "frames_per_buffer" in kwargs or "fpb" in kwargs: - kwargs["frames_per_buffer"] = kwargs.pop( - "frames_per_buffer", None - ) or kwargs.pop("fpb", None) - - if "sampling_rate" in kwargs or "sr" in kwargs: - kwargs["sampling_rate"] = kwargs.pop( - "sampling_rate", None - ) or kwargs.pop("sr", None) - - if "sample_width" in kwargs or "sw" in kwargs: - kwargs["sample_width"] = kwargs.pop( - "sample_width", None - ) or kwargs.pop("sw", None) - - if "channels" in kwargs or "ch" in kwargs: - kwargs["channels"] = kwargs.pop("channels", None) or kwargs.pop( - "ch", None - ) - - @staticmethod - def ads(**kwargs): - """ - Create an return an :class:`AudioDataSource`. The type and - behavior of the object is the result - of the supplied parameters. Called without any parameters, the class - will read audio data from the available built-in microphone with the - default parameters. - - Parameters - ---------- - sampling_rate, sr : int, default: 16000 - number of audio samples per second of input audio stream. - sample_width, sw : int, default: 2 - number of bytes per sample, must be one of 1, 2 or 4 - channels, ch : int, default: 1 - number of audio channels, only a value of 1 is currently accepted. - frames_per_buffer, fpb : int, default: 1024 - number of samples of PyAudio buffer. - audio_source, asrc : `AudioSource` - `AudioSource` to read data from - filename, fn : str - create an `AudioSource` object using this file - data_buffer, db : str - build an `io.BufferAudioSource` using data in `data_buffer`. - If this keyword is used, - `sampling_rate`, `sample_width` and `channels` are passed to - `io.BufferAudioSource` constructor and used instead of default - values. - max_time, mt : float - maximum time (in seconds) to read. Default behavior: read until - there is no more data - available. - record, rec : bool, default = False - save all read data in cache. Provide a navigable object which has a - `rewind` method. - block_dur, bd : float - processing block duration in seconds. This represents the quantity - of audio data to return each time the :func:`read` method is - invoked. If `block_dur` is 0.025 (i.e. 25 ms) and the sampling rate - is 8000 and the sample width is 2 bytes, :func:`read` returns a - buffer of 0.025 * 8000 * 2 = 400 bytes at most. This parameter will - be looked for (and used if available) before `block_size`. If - neither parameter is given, `block_dur` will be set to 0.01 second - (i.e. 10 ms) - hop_dur, hd : float - quantity of data to skip from current processing window. if - `hop_dur` is supplied then there will be an overlap of `block_dur` - - `hop_dur` between two adjacent blocks. This parameter will be - looked for (and used if available) before `hop_size`. 
- If neither parameter is given, `hop_dur` will be set to `block_dur` - which means that there will be no overlap between two consecutively - read blocks. - block_size, bs : int - number of samples to read each time the `read` method is called. - Default: a block size that represents a window of 10ms, so for a - sampling rate of 16000, the default `block_size` is 160 samples, - for a rate of 44100, `block_size` = 441 samples, etc. - hop_size, hs : int - determines the number of overlapping samples between two adjacent - read windows. For a `hop_size` of value *N*, the overlap is - `block_size` - *N*. Default : `hop_size` = `block_size`, means that - there is no overlap. - - Returns - ------- - audio_data_source : AudioDataSource - an `AudioDataSource` object build with input parameters. - """ - warnings.warn( - "'ADSFactory' is deprecated and will be removed in a future " - "release. Please use AudioReader class instead.", - DeprecationWarning, - ) - - # check and normalize keyword arguments - ADSFactory._check_normalize_args(kwargs) - - block_dur = kwargs.pop("bd") - hop_dur = kwargs.pop("hd") - block_size = kwargs.pop("bs") - hop_size = kwargs.pop("hs") - max_time = kwargs.pop("mt") - audio_source = kwargs.pop("asrc") - filename = kwargs.pop("fn") - data_buffer = kwargs.pop("db") - record = kwargs.pop("rec") - - # Case 1: an audio source is supplied - if audio_source is not None: - if (filename, data_buffer) != (None, None): - raise Warning( - "You should provide one of 'audio_source', 'filename' or \ - 'data_buffer' keyword parameters. 'audio_source' will be \ - used" - ) - - # Case 2: a file name is supplied - elif filename is not None: - if data_buffer is not None: - raise Warning( - "You should provide one of 'filename' or 'data_buffer'\ - keyword parameters. 'filename' will be used" - ) - audio_source = from_file(filename) - - # Case 3: a data_buffer is supplied - elif data_buffer is not None: - audio_source = BufferAudioSource(data=data_buffer, **kwargs) - - # Case 4: try to access native audio input - else: - audio_source = PyAudioSource(**kwargs) - - if block_dur is not None: - if block_size is not None: - raise DuplicateArgument( - "Either 'block_dur' or 'block_size' can be specified, not \ - both" - ) - elif block_size is not None: - block_dur = block_size / audio_source.sr - else: - block_dur = 0.01 # 10 ms - - # Read overlapping blocks of data - if hop_dur is not None: - if hop_size is not None: - raise DuplicateArgument( - "Either 'hop_dur' or 'hop_size' can be specified, not both" - ) - elif hop_size is not None: - hop_dur = hop_size / audio_source.sr - - ads = AudioDataSource( - audio_source, - block_dur=block_dur, - hop_dur=hop_dur, - record=record, - max_read=max_time, - ) - return ads - - class _AudioReadingProxy: def __init__(self, audio_source): @@ -726,12 +412,14 @@ class _Recorder(_AudioReadingProxy): """ - Class for `AudioReader` objects that can record all data they read. Useful - when reading data from microphone. + A class for `AudioReader` objects that records all data read from the source. + + This class is particularly useful for capturing audio data when reading from + a microphone or similar live audio sources. """ def __init__(self, audio_source): - super(_Recorder, self).__init__(audio_source) + super().__init__(audio_source) self._cache = [] self._read_block = self._read_and_cache self._read_from_cache = False @@ -743,7 +431,7 @@ @property def data(self): if self._data is None: - err_msg = "Unrewinded recorder. 
`rewind` should be called before "
+            err_msg = "Recorder not rewound. `rewind` should be called before "
             err_msg += "accessing recorded data"
             raise RuntimeError(err_msg)
         return self._data
@@ -774,13 +462,14 @@
 
 class _Limiter(_AudioReadingProxy):
     """
-    Class for `AudioReader` objects that can read a fixed amount of data.
-    This can be useful when reading data from the microphone or from large
-    audio files.
+    A class for `AudioReader` objects that restricts the amount of data read.
+
+    This class is useful for limiting data intake when reading from a microphone
+    or large audio files, ensuring only a specified amount of data is processed.
     """
 
     def __init__(self, audio_source, max_read):
-        super(_Limiter, self).__init__(audio_source)
+        super().__init__(audio_source)
         self._max_read = max_read
         self._max_samples = round(max_read * self.sr)
         self._bytes_per_sample = self.sw * self.ch
@@ -807,17 +496,17 @@
         return block
 
     def rewind(self):
-        super(_Limiter, self).rewind()
+        super().rewind()
         self._read_samples = 0
 
 
 class _FixedSizeAudioReader(_AudioReadingProxy):
     """
-    Class to read fixed-size audio windows from source.
+    A class to read fixed-size audio windows from a source.
    """
 
     def __init__(self, audio_source, block_dur):
-        super(_FixedSizeAudioReader, self).__init__(audio_source)
+        super().__init__(audio_source)
 
         if block_dur <= 0:
             raise ValueError(
@@ -829,7 +518,7 @@
             err_msg = "Too small block_dur ({0:f}) for sampling rate ({1}). "
             err_msg += "block_dur should cover at least one sample "
             err_msg += "(i.e. 1/{1})"
-            raise TooSamllBlockDuration(
+            raise TooSmallBlockDuration(
                 err_msg.format(block_dur, self.sr), block_dur, self.sr
             )
@@ -850,16 +539,19 @@
 
 class _OverlapAudioReader(_FixedSizeAudioReader):
     """
-    Class for `AudioReader` objects that can read and return overlapping audio
+    A class for `AudioReader` objects that reads and returns overlapping audio
     windows.
+
+    Useful for applications requiring overlapping segments, such as audio
+    analysis or feature extraction.
     """
 
     def __init__(self, audio_source, block_dur, hop_dur):
 
-        if hop_dur >= block_dur:
-            raise ValueError('"hop_dur" should be < "block_dur"')
+        if hop_dur > block_dur:
+            raise ValueError('"hop_dur" should be <= "block_dur"')
 
-        super(_OverlapAudioReader, self).__init__(audio_source, block_dur)
+        super().__init__(audio_source, block_dur)
 
         self._hop_size = int(hop_dur * self.sr)
         self._blocks = self._iter_blocks_with_overlap()
@@ -896,7 +588,7 @@
         return None
 
     def rewind(self):
-        super(_OverlapAudioReader, self).rewind()
+        super().rewind()
         self._blocks = self._iter_blocks_with_overlap()
 
     @property
@@ -913,86 +605,74 @@
 
 class AudioReader(DataSource):
     """
-    Class to read fixed-size chunks of audio data from a source. A source can
-    be a file on disk, standard input (with `input` = "-") or microphone. This
-    is normally used by tokenization algorithms that expect source objects with
-    a `read` function that returns a windows of data of the same size at each
-    call expect when remaining data does not make up a full window.
+    A class to read fixed-size chunks of audio data from a source, which can
+    be a file, standard input (with `input` set to "-"), or a microphone.
+    Typically used by tokenization algorithms that require source objects with
+    a `read` function to return data windows of consistent size, except for
+    the last window if remaining data is insufficient.
 
-    Objects of this class can be set up to return audio windows with a given
-    overlap and to record the whole stream for later access (useful when
-    reading data from the microphone). 
They can also have - a limit for the maximum amount of data to read. + This class supports overlapping audio windows, recording the audio stream + for later access (useful for microphone input), and limiting the maximum + amount of data read. Parameters ---------- - input : str, bytes, AudioSource, AudioReader, AudioRegion or None - input audio data. If the type of the passed argument is `str`, it should - be a path to an existing audio file. "-" is interpreted as standardinput. - If the type is `bytes`, input is considered as a buffer of raw audio - data. If None, read audio from microphone. Every object that is not an - :class:`AudioReader` will be transformed, when possible, into an - :class:`AudioSource` before processing. If it is an `str` that refers to - a raw audio file, `bytes` or None, audio parameters should be provided - using kwargs (i.e., `samplig_rate`, `sample_width` and `channels` or - their alias). - block_dur: float, default: 0.01 - length in seconds of audio windows to return at each `read` call. - hop_dur: float, default: None - length in seconds of data amount to skip from previous window. If - defined, it is used to compute the temporal overlap between previous and - current window (nameply `overlap = block_dur - hop_dur`). Default, None, - means that consecutive windows do not overlap. - record: bool, default: False - whether to record read audio data for later access. If True, audio data - can be retrieved by first calling `rewind()`, then using the `data` - property. Note that once `rewind()` is called, no new data will be read - from source (subsequent `read()` call will read data from cache) and - that there's no need to call `rewind()` again to access `data` property. - max_read: float, default: None - maximum amount of audio data to read in seconds. Default is None meaning - that data will be read until end of stream is reached or, when reading - from microphone a Ctrl-C is sent. + input : str, bytes, AudioSource, AudioReader, AudioRegion, or None + Input audio data. If a string, it should be the path to an audio file + (use "-" for standard input). If bytes, the input is treated as raw + audio data. If None, audio is read from a microphone. Any input that + is not an :class:`AudioReader` will be converted, if possible, to an + :class:`AudioSource` for processing. For raw audio (string path, bytes, + or None), specify audio parameters using kwargs (`sampling_rate`, + `sample_width`, `channels` or their aliases: `sr`, `sw`, `ch`). + block_dur : float, default=0.01 + Duration of audio data (in seconds) to return in each `read` call. + hop_dur : float, optional + Duration of data to skip (in seconds) from the previous window. If set, + it is used to calculate temporal overlap between the current and + previous window (`overlap = block_dur - hop_dur`). If None (default), + windows do not overlap. + record : bool, default=False + Whether to record audio data for later access. If True, recorded audio + can be accessed using the `data` property after calling `rewind()`. + Note: after `rewind()`, no new data is read from the source—subsequent + `read` calls use the cached data. + max_read : float, optional + Maximum duration of audio data to read (in seconds). If None (default), + data is read until the end of the stream or, for microphone input, until + a Ctrl-C interruption. - When `input` is None, of type bytes or a raw audio files some of the - follwing kwargs are mandatory. 
+ Additional audio parameters may be required if `input` is raw audio + (None, bytes, or raw audio file): Other Parameters ---------------- audio_format, fmt : str - type of audio data (e.g., wav, ogg, flac, raw, etc.). This will only be - used if `input` is a string path to an audio file. If not given, audio - type will be guessed from file name extension or from file header. + Type of audio data (e.g., wav, ogg, flac, raw). Used if `input` is a + file path. If not provided, the format is inferred from the file + extension or header. sampling_rate, sr : int - sampling rate of audio data. Required if `input` is a raw audio file, is - a bytes object or None (i.e., read from microphone). + Sampling rate of the audio data. Required for raw audio (bytes, None, + or raw file). sample_width, sw : int - number of bytes used to encode one audio sample, typically 1, 2 or 4. - Required for raw data, see `sampling_rate`. + Number of bytes per audio sample (typically 1, 2, or 4). Required for + raw data. channels, ch : int - number of channels of audio data. Required for raw data, see - `sampling_rate`. + Number of audio channels. Required for raw data. use_channel, uc : {None, "any", "mix", "avg", "average"} or int - which channel to use for split if `input` has multiple audio channels. - Regardless of which channel is used for splitting, returned audio events - contain data from *all* the channels of `input`. The following values - are accepted: + Specifies the channel used for split if `input` has multiple channels. + All returned audio data includes data from *all* input channels. Options: - - None (alias "any"): accept audio activity from any channel, even if - other channels are silent. This is the default behavior. + - None or "any": Use any active channel, regardless of silence in others. + (Default) + - "mix" / "avg" / "average": Combine all channels by averaging. + - int: Use the specified channel ID (0 <= value < `channels`). - - "mix" (alias "avg" or "average"): mix down all channels (i.e., compute - average channel) and split the resulting channel. - - - int (>= 0 , < `channels`): use one channel, specified by its integer - id, for split. - - large_file : bool, default: False - If True, AND if `input` is a path to a *wav* of a *raw* audio file - (and only these two formats) then audio data is lazily loaded to memory - (i.e., one analysis window a time). Otherwise the whole file is loaded - to memory before split. Set to True if the size of the file is larger - than available memory. + large_file : bool, default=False + If True and `input` is a path to a *wav* or *raw* file, audio data is + loaded lazily (one analysis window at a time). Otherwise, the entire + file is loaded before processing. Use True for large files exceeding + available memory. 
""" def __init__( @@ -1002,7 +682,7 @@ hop_dur=None, record=False, max_read=None, - **kwargs + **kwargs, ): if not isinstance(input, AudioSource): input = get_audio_source(input, **kwargs) @@ -1012,10 +692,11 @@ if max_read is not None: input = _Limiter(input, max_read) self._max_read = max_read - if hop_dur is not None: + if hop_dur is None or hop_dur == block_dur: + input = _FixedSizeAudioReader(input, block_dur) + else: input = _OverlapAudioReader(input, block_dur, hop_dur) - else: - input = _FixedSizeAudioReader(input, block_dur) + self._audio_source = input def __repr__(self): @@ -1027,9 +708,9 @@ if self.max_read is not None: max_read = "{:.3f}".format(self.max_read) return ( - "{cls}(block_dur={block_dur}, " + "<{cls}(block_dur={block_dur}, " "hop_dur={hop_dur}, record={rewindable}, " - "max_read={max_read})" + "max_read={max_read})>" ).format( cls=self.__class__.__name__, block_dur=block_dur, @@ -1075,26 +756,23 @@ ) try: return getattr(self._audio_source, name) - except AttributeError: + except AttributeError as exc: raise AttributeError( - "'AudioReader' has no attribute '{}'".format(name) - ) - - -# Keep AudioDataSource for compatibility -# Remove in a future version when ADSFactory is removed -AudioDataSource = AudioReader + f"'AudioReader' has no attribute {name!r}" + ) from exc class Recorder(AudioReader): - """Class to read fixed-size chunks of audio data from a source and keeps - data in a cache. Using this class is equivalent to initializing - :class:`AudioReader` with `record=True`. For more information about the - other parameters see :class:`AudioReader`. + """ + A class to read fixed-size chunks of audio data from a source and store + them in a cache. This class is equivalent to initializing + :class:`AudioReader` with `record=True`. For more details on additional + parameters, refer to :class:`AudioReader`. - Once the desired amount of data is read, you can call the :func:`rewind` - method then get the recorded data via the :attr:`data` attribute. You can also - re-read cached data one window a time by calling :func:`read`. + Once the desired amount of data is read, you can call the :meth:`rewind` + method to access the recorded data via the :attr:`data` attribute. The + cached data can also be re-read in fixed-size windows by calling + :meth:`read`. """ def __init__( @@ -1106,5 +784,5 @@ hop_dur=hop_dur, record=True, max_read=max_read, - **kwargs + **kwargs, )
--- a/auditok/workers.py Thu Mar 30 10:17:57 2023 +0100 +++ b/auditok/workers.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,22 +1,18 @@ import os +import subprocess import sys +import wave +from abc import ABCMeta, abstractmethod +from collections import namedtuple +from datetime import datetime, timedelta +from queue import Empty, Queue from tempfile import NamedTemporaryFile -from abc import ABCMeta, abstractmethod from threading import Thread -from datetime import datetime, timedelta -from collections import namedtuple -import wave -import subprocess -from queue import Queue, Empty + +from .core import make_silence, split +from .exceptions import AudioEncodingError, AudioEncodingWarning from .io import _guess_audio_format -from .util import AudioDataSource, make_duration_formatter -from .core import split -from .exceptions import ( - EndOfProcessing, - AudioEncodingError, - AudioEncodingWarning, -) - +from .util import AudioReader, make_duration_formatter _STOP_PROCESSING = "STOP_PROCESSING" _Detection = namedtuple("_Detection", "id start end duration") @@ -32,9 +28,9 @@ ) as proc: stdout, stderr = proc.communicate() return proc.returncode, stdout, stderr - except Exception: + except Exception as exc: err_msg = "Couldn't export audio using command: '{}'".format(command) - raise AudioEncodingError(err_msg) + raise AudioEncodingError(err_msg) from exc class Worker(Thread, metaclass=ABCMeta): @@ -86,15 +82,16 @@ return None -class TokenizerWorker(Worker, AudioDataSource): +class TokenizerWorker(Worker, AudioReader): def __init__(self, reader, observers=None, logger=None, **kwargs): self._observers = observers if observers is not None else [] self._reader = reader - self._audio_region_gen = split(self, **kwargs) + kwargs["input"] = self + self._audio_region_gen = split(**kwargs) self._detections = [] self._log_format = "[DET]: Detection {0.id} (start: {0.start:.3f}, " self._log_format += "end: {0.end:.3f}, duration: {0.duration:.3f})" - Worker.__init__(self, timeout=0.2, logger=logger) + super().__init__(timeout=0.2, logger=logger) def _process_message(self): pass @@ -103,6 +100,10 @@ def detections(self): return self._detections + @property + def reader(self): + return self._reader + def _notify_observers(self, message): for observer in self._observers: observer.send(message) @@ -150,27 +151,41 @@ return getattr(self._reader, name) -class StreamSaverWorker(Worker): +class AudioDataSaverWorker(Worker): + def __init__( self, - audio_reader, filename, - export_format=None, - cache_size_sec=0.5, + export_format, + sampling_rate, + sample_width, + channels, timeout=0.2, ): - self._reader = audio_reader - sample_size_bytes = self._reader.sw * self._reader.ch - self._cache_size = cache_size_sec * self._reader.sr * sample_size_bytes + + super().__init__(timeout=timeout) self._output_filename = filename - self._export_format = _guess_audio_format(export_format, filename) + self._sampling_rate = sampling_rate + self._sample_width = sample_width + self._channels = channels + + self._export_format = _guess_audio_format(filename, export_format) if self._export_format is None: self._export_format = "wav" self._init_output_stream() self._exported = False - self._cache = [] - self._total_cached = 0 - Worker.__init__(self, timeout=timeout) + + @property + def sr(self): + return self._sampling_rate + + @property + def sw(self): + return self._sample_width + + @property + def ch(self): + return self._channels def _get_non_existent_filename(self): filename = self._output_filename + ".wav" @@ -186,74 +201,23 @@ 
else: self._tmp_output_filename = self._output_filename self._wfp = wave.open(self._tmp_output_filename, "wb") - self._wfp.setframerate(self._reader.sr) - self._wfp.setsampwidth(self._reader.sw) - self._wfp.setnchannels(self._reader.ch) - - @property - def sr(self): - return self._reader.sampling_rate - - @property - def sw(self): - return self._reader.sample_width - - @property - def ch(self): - return self._reader.channels - - def __del__(self): - self._post_process() - - if ( - (self._tmp_output_filename != self._output_filename) - and self._exported - and os.path.exists(self._tmp_output_filename) - ): - os.remove(self._tmp_output_filename) - - def _process_message(self, data): - self._cache.append(data) - self._total_cached += len(data) - if self._total_cached >= self._cache_size: - self._write_cached_data() - - def _post_process(self): - while True: - try: - data = self._inbox.get_nowait() - if data != _STOP_PROCESSING: - self._cache.append(data) - self._total_cached += len(data) - except Empty: - break - self._write_cached_data() - self._wfp.close() - - def _write_cached_data(self): - if self._cache: - data = b"".join(self._cache) - self._wfp.writeframes(data) - self._cache = [] - self._total_cached = 0 - - def open(self): - self._reader.open() - - def close(self): - self._reader.close() - self.stop() - - def rewind(self): - # ensure compatibility with AudioDataSource with record=True - pass + self._wfp.setframerate(self.sr) + self._wfp.setsampwidth(self.sw) + self._wfp.setnchannels(self.ch) @property def data(self): with wave.open(self._tmp_output_filename, "rb") as wfp: return wfp.readframes(-1) - def save_stream(self): + def export_audio(self): + try: + self._encode_export_audio() + except AudioEncodingError as ae_error: + raise AudioEncodingWarning(str(ae_error)) from ae_error + return self._output_filename + + def _encode_export_audio(self): if self._exported: return self._output_filename @@ -264,26 +228,29 @@ return self._output_filename try: self._export_with_ffmpeg_or_avconv() + except AudioEncodingError: try: self._export_with_sox() - except AudioEncodingError: + except AudioEncodingError as exc: warn_msg = "Couldn't save audio data in the desired format " - warn_msg += "'{}'. 
Either none of 'ffmpeg', 'avconv' or 'sox' " + warn_msg += "'{}'.\nEither none of 'ffmpeg', 'avconv' or 'sox' " warn_msg += "is installed or this format is not recognized.\n" warn_msg += "Audio file was saved as '{}'" - raise AudioEncodingWarning( + raise AudioEncodingError( warn_msg.format( self._export_format, self._tmp_output_filename ) - ) - finally: + ) from exc + else: + self._exported = True + else: self._exported = True return self._output_filename def _export_raw(self): - with open(self._output_filename, "wb") as wfp: - wfp.write(self.data) + with open(self._output_filename, "wb") as fp: + fp.write(self.data) def _export_with_ffmpeg_or_avconv(self): command = [ @@ -319,6 +286,84 @@ def close_output(self): self._wfp.close() + def __del__(self): + self._post_process() + + if ( + (self._tmp_output_filename != self._output_filename) + and self._exported + and os.path.exists(self._tmp_output_filename) + ): + os.remove(self._tmp_output_filename) + + +class StreamSaverWorker(AudioDataSaverWorker): + def __init__( + self, + audio_reader, + filename, + export_format=None, + cache_size_sec=0.5, + timeout=0.2, + ): + self._reader = audio_reader + super().__init__( + filename, + export_format, + self._reader.sr, + self._reader.sw, + self._reader.ch, + timeout=timeout, + ) + + sample_size_bytes = self._reader.sw * self._reader.ch + self._cache_size = cache_size_sec * self._reader.sr * sample_size_bytes + + self._exported = False + self._cache = [] + self._total_cached = 0 + + def _process_message(self, data): + self._cache.append(data) + self._total_cached += len(data) + if self._total_cached >= self._cache_size: + self._write_cached_data() + + def _post_process(self): + while True: + try: + data = self._inbox.get_nowait() + if data != _STOP_PROCESSING: + self._cache.append(data) + self._total_cached += len(data) + except Empty: + break + self._write_cached_data() + self._wfp.close() + + def _write_cached_data(self): + if self._cache: + data = b"".join(self._cache) + self._wfp.writeframes(data) + self._cache = [] + self._total_cached = 0 + + def open(self): + self._reader.open() + + def close(self): + self._reader.close() + self.stop() + + def rewind(self): + # ensure compatibility with AudioReader with record=True + pass + + @property + def data(self): + with wave.open(self._tmp_output_filename, "rb") as wfp: + return wfp.readframes(-1) + def read(self): data = self._reader.read() if data is not None: @@ -328,9 +373,60 @@ return data def __getattr__(self, name): - if name == "data": - return self.data - return getattr(self._reader, name) + try: + return getattr(self._reader, name) + except AttributeError: + return getattr(self, name) + + +class AudioEventsJoinerWorker(AudioDataSaverWorker): + + def __init__( + self, + silence_duration, + filename, + export_format, + sampling_rate, + sample_width, + channels, + timeout=0.2, + ): + + super().__init__( + filename, + export_format, + sampling_rate, + sample_width, + channels, + timeout, + ) + + self._silence_data = make_silence( + silence_duration, sampling_rate, sample_width, channels + ).data + self._first_event = True + + def _process_message(self, message): + _, audio_event = message + self._write_audio_event(audio_event.data) + + def _post_process(self): + while True: + try: + message = self._inbox.get_nowait() + if message != _STOP_PROCESSING: + _, audio_event = message + self._write_audio_event(audio_event.data) + except Empty: + break + self._wfp.close() + + def _write_audio_event(self, data): + if not self._first_event: + 
self._wfp.writeframes(self._silence_data) + else: + self._first_event = False + self._wfp.writeframes(data) class PlayerWorker(Worker): @@ -338,7 +434,7 @@ self._player = player self._progress_bar = progress_bar self._log_format = "[PLAY]: Detection {id} played" - Worker.__init__(self, timeout=timeout, logger=logger) + super().__init__(timeout=timeout, logger=logger) def _process_message(self, message): _id, audio_region = message @@ -357,13 +453,13 @@ audio_format=None, timeout=0.2, logger=None, - **audio_parameters + **audio_parameters, ): self._filename_format = filename_format self._audio_format = audio_format self._audio_parameters = audio_parameters self._debug_format = "[SAVE]: Detection {id} saved as '{filename}'" - Worker.__init__(self, timeout=timeout, logger=logger) + super().__init__(timeout=timeout, logger=logger) def _process_message(self, message): _id, audio_region = message @@ -384,7 +480,7 @@ class CommandLineWorker(Worker): def __init__(self, command, timeout=0.2, logger=None): self._command = command - Worker.__init__(self, timeout=timeout, logger=logger) + super().__init__(timeout=timeout, logger=logger) self._debug_format = "[COMMAND]: Detection {id} command: '{command}'" def _process_message(self, message): @@ -411,7 +507,7 @@ self._format_time = make_duration_formatter(time_format) self._timestamp_format = timestamp_format self.detections = [] - Worker.__init__(self, timeout=timeout) + super().__init__(timeout=timeout) def _process_message(self, message): _id, audio_region = message
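Outside the worker machinery, the effect of the new `AudioEventsJoinerWorker` can be approximated with the `make_silence` and `split` helpers imported at the top of this module. A hedged sketch: the `make_silence` signature is inferred from its call site above, `speech.wav` is a placeholder input, and events are assumed to expose `sr`, `sw`, `ch` and `data` as used elsewhere in this changeset.

.. code:: python

    from auditok.core import make_silence, split

    events = list(split("speech.wav"))  # placeholder input file

    if events:
        first = events[0]
        # one second of silence, matching the events' audio parameters
        silence = make_silence(1.0, first.sr, first.sw, first.ch)
        with open("joined.raw", "wb") as fp:
            for i, event in enumerate(events):
                if i > 0:
                    fp.write(bytes(silence.data))
                fp.write(bytes(event.data))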
--- a/doc/command_line_usage.rst Thu Mar 30 10:17:57 2023 +0100 +++ b/doc/command_line_usage.rst Wed Oct 30 17:17:59 2024 +0000 @@ -1,53 +1,56 @@ -``auditok`` can also be used from the command-line. For more information about -parameters and their description type: +Command-line guide +================== +``auditok`` can also be used from the command line. For information +about available parameters and descriptions, type: .. code:: bash auditok -h -In the following we'll a few examples that covers most use-cases. +Below, we provide several examples covering the most common use cases. -Read and split audio data online --------------------------------- +Read audio data and detect audio events online +---------------------------------------------- -To try ``auditok`` from the command line with you voice, you should either -install `pyaudio <https://people.csail.mit.edu/hubert/pyaudio>`_ so that ``auditok`` -can directly read data from the microphone, or record data with an external program -(e.g., `sox`) and redirect its output to ``auditok``. +To try ``auditok`` from the command line with your own voice, you’ll need to +either install `pyaudio <https://people.csail.mit.edu/hubert/pyaudio>`_ so +that ``auditok`` can read directly from the microphone, or record audio with +an external program (e.g., `sox`) and redirect its output to ``auditok``. -Read data from the microphone (`pyaudio` installed): +To read data directly from the microphone and use default parameters for audio +data and tokenization, simply type: .. code:: bash auditok -This will print the *id*, *start time* and *end time* of each detected audio -event. Note that we didn't pass any additional arguments to the previous command, -so ``auditok`` will use default values. The most important arguments are: +This will print the **id**, **start time**, and **end time** of each detected +audio event. As mentioned above, no additional arguments were passed in the +previous command, so ``auditok`` will use its default values. The most important +arguments are: -- ``-n``, ``--min-duration`` : minimum duration of a valid audio event in seconds, default: 0.2 -- ``-m``, ``--max-duration`` : maximum duration of a valid audio event in seconds, default: 5 -- ``-s``, ``--max-silence`` : maximum duration of a consecutive silence within a valid audio event in seconds, default: 0.3 -- ``-e``, ``--energy-threshold`` : energy threshold for detection, default: 50 +- ``-n``, ``--min-duration``: minimum duration of a valid audio event in seconds, default: 0.2 +- ``-m``, ``--max-duration``: maximum duration of a valid audio event in seconds, default: 5 +- ``-s``, ``--max-silence``: maximum duration of a continuous silence within a valid audio event in seconds, default: 0.3 +- ``-e``, ``--energy-threshold``: energy threshold for detection, default: 50 Read audio data with an external program ---------------------------------------- - -If you don't have `pyaudio`, you can use `sox` for data acquisition -(`sudo apt-get install sox`) and make ``auditok`` read data from standard input: +You can use an external program, such as `sox` (``sudo apt-get install sox``), +to record audio data in real-time, redirect it, and have `auditok` read the data +from standard input: .. code:: bash rec -q -t raw -r 16000 -c 1 -b 16 -e signed - | auditok - -r 16000 -w 2 -c 1 -Note that when data is read from standard input, the same audio parameters must -be used for both `sox` (or any other data generation/acquisition tool) and -``auditok``. 
The following table provides a summary of the audio parameters:
 
 +-----------------+------------+------------------+-----------------------+
 | Audio parameter | sox option | `auditok` option | `auditok` default     |
 +-----------------+------------+------------------+-----------------------+
@@ -61,17 +64,17 @@
 | Encoding        | -e         | NA               | always a signed int   |
 +-----------------+------------+------------------+-----------------------+
 
-According to this table, the previous command can be run with the default
-parameters as:
+Based on the table, the previous command can be run with the default parameters as:
 
 .. code:: bash
 
-    rec -q -t raw -r 16000 -c 1 -b 16 -e signed - | auditok -i -
+    rec -q -t raw -r 16000 -c 1 -b 16 -e signed - | auditok -
+
 
 Play back audio detections
 --------------------------
 
-Use the ``-E`` option (for echo):
+Use the ``-E`` (or ``--echo``) option:
 
 .. code:: bash
 
     auditok -E
     # or
     rec -q -t raw -r 16000 -c 1 -b 16 -e signed - | auditok - -E
 
-The second command works without further argument because data is recorded with
-``auditok``'s default audio parameters . If one of the parameters is not at the
-default value you should specify it alongside ``-E``.
-
-
 Using ``-E`` requires `pyaudio`, if it's not installed you can use the ``-C``
 (used to run an external command with detected audio event as argument):
 
@@ -101,10 +99,10 @@
 Print out detection information
 -------------------------------
 
-By default ``auditok`` prints out the **id**, the **start** and the **end** of
-each detected audio event. The latter two values represent the absolute position
-of the event within input stream (file or microphone) in seconds. The following
-listing is an example output with the default format:
+By default, ``auditok`` outputs the **id**, **start**, and **end** times for each
+detected audio event. The start and end values indicate the beginning and end of
+the event within the input stream (file or microphone) in seconds. Below is an
+example of the output in the default format:
 
 .. code:: bash
 
@@ -123,7 +121,7 @@
 
     auditok audio.wav --printf "{id}: [{timestamp}] start:{start}, end:{end}, dur: {duration}"
 
-the output would be something like:
+the output will look like:
 
 .. code:: bash
 
@@ -145,7 +143,7 @@
 ---------------
 
 You can save audio events to disk as they're detected using ``-o`` or
-``--save-detections-as``. To get a uniq file name for each event, you can use
+``--save-detections-as``. To create a unique file name for each event, you can use
 ``{id}``, ``{start}``, ``{end}`` and ``{duration}`` placeholders.
 
 Example:
@@ -153,9 +151,9 @@
 
     auditok --save-detections-as "{id}_{start}_{end}.wav"
 
-When using ``{start}``, ``{end}`` and ``{duration}`` placeholders, it's
-recommended that the number of decimals of the corresponding values be limited
-to 3. You can use something like:
+When using ``{start}``, ``{end}``, and ``{duration}`` placeholders, it is
+recommended to limit the number of decimal places for these values to 3. You
+can do this with a format like:
 
 .. code:: bash
 
@@ -165,22 +163,39 @@
 
 Save whole audio stream
 -----------------------
 
-When reading audio data from the microphone, you most certainly want to save it
-to disk. For this you can use the ``-O`` or ``--save-stream`` option.
+When reading audio data from the microphone, you may want to save it to disk.
+To do this, use the ``-O`` or ``--save-stream`` option:
 
 .. 
code:: bash - auditok --save-stream "stream.wav" + auditok --save-stream output.wav -Note this will work even if you read data from another file on disk. +Note that this will work even if you read data from a file on disk. +Join detected audio events with a silence of a given duration +------------------------------------------------------------- + +Sometimes, you may want to detect audio events while also +creating a file that contains the same events with modified +pause durations. + +To do this, use the ``-j`` or ``--join-detections`` option together +with the ``-O`` / ``--save-stream`` option. In the example below, we +read data from `input.wav` and save audio events to `output.wav`, adding +1-second pauses between them: + + +.. code:: bash + + auditok input.wav --join-detections 1 -O output.wav + Plot detections --------------- Audio signal and detections can be plotted using the ``-p`` or ``--plot`` option. You can also save plot to disk using ``--save-image``. The following example -does both: +demonstrates both: .. code:: bash
--- a/doc/conf.py Thu Mar 30 10:17:57 2023 +0100 +++ b/doc/conf.py Wed Oct 30 17:17:59 2024 +0000 @@ -12,11 +12,11 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys +import ast import os import re -import ast import shlex +import sys # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the @@ -63,9 +63,9 @@ master_doc = "index" # General information about the project. -project = u"auditok" -copyright = u"2015-2021, Amine Sehili" -author = u"Amine Sehili" +project = "auditok" +copyright = "2015-2021, Amine Sehili" +author = "Amine Sehili" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -253,8 +253,8 @@ ( master_doc, "auditok.tex", - u"auditok Documentation", - u"Amine Sehili", + "auditok Documentation", + "Amine Sehili", "manual", ), ] @@ -284,7 +284,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [(master_doc, "auditok", u"auditok Documentation", [author], 1)] +man_pages = [(master_doc, "auditok", "auditok Documentation", [author], 1)] # If true, show URL addresses after external links. # man_show_urls = False @@ -299,7 +299,7 @@ ( master_doc, "auditok", - u"auditok Documentation", + "auditok Documentation", author, "auditok", "Audio Activity Detection tool.",
--- a/doc/examples.rst	Thu Mar 30 10:17:57 2023 +0100
+++ b/doc/examples.rst	Wed Oct 30 17:17:59 2024 +0000
@@ -1,38 +1,54 @@
Load audio data
---------------

-Audio data is loaded with the :func:`load` function which can read from audio
-files, the microphone or use raw audio data.
+Audio data is loaded using the :func:`load` function, which can read from
+audio files, capture from the microphone, or accept raw audio data
+(as a ``bytes`` object).

From a file
===========

-If the first argument of :func:`load` is a string, it should be a path to an
-audio file.
+If the first argument of :func:`load` is a string or a `Path`, it should
+refer to an existing audio file.

.. code:: python

import auditok
region = auditok.load("audio.ogg")

-If input file contains raw (headerless) audio data, passing `audio_format="raw"`
-and other audio parameters (`sampling_rate`, `sample_width` and `channels`) is
-mandatory. In the following example we pass audio parameters with their short
-names:
+If the input file contains raw (headerless) audio data, specifying audio
+parameters (``sampling_rate``, ``sample_width``, and ``channels``) is required.
+Additionally, if the file name does not end with 'raw', you should explicitly
+pass `audio_format="raw"` to the function.
+
+In the example below, we provide audio parameters using their abbreviated names:

.. code:: python

region = auditok.load("audio.dat",
audio_format="raw",
sr=44100, # alias for `sampling_rate`
- sw=2 # alias for `sample_width`
+ sw=2, # alias for `sample_width`
ch=1 # alias for `channels`
)
+Alternatively, you can use :class:`AudioRegion` to load audio data:
+
+.. code:: python
+
+ from auditok import AudioRegion
+ region = AudioRegion.load("audio.dat",
+ audio_format="raw",
+ sr=44100, # alias for `sampling_rate`
+ sw=2, # alias for `sample_width`
+ ch=1 # alias for `channels`
+ )
+
+
From a `bytes` object
=====================

-If the type of the first argument `bytes`, it's interpreted as raw audio data:
+If the first argument is of type `bytes`, it is interpreted as raw audio data:

.. code:: python

@@ -43,7 +59,7 @@
region = auditok.load(data, sr=sr, sw=sw, ch=ch)
print(region)
# alternatively you can use
- #region = auditok.AudioRegion(data, sr, sw, ch)
+ region = auditok.AudioRegion(data, sr, sw, ch)

output:

@@ -54,9 +70,9 @@
From the microphone
===================

-If the first argument is `None`, :func:`load` will try to read data from the
-microphone. Audio parameters, as well as the `max_read` parameter are mandatory:
-
+If the first argument is `None`, :func:`load` will attempt to read data from the
+microphone. In this case, audio parameters, along with the `max_read` parameter,
+are required.

.. code:: python

@@ -76,8 +92,8 @@
Skip part of audio data
=======================

-If the `skip` parameter is > 0, :func:`load` will skip that amount in seconds
-of leading audio data:
+If the ``skip`` parameter is greater than 0, :func:`load` will skip the specified
+amount of leading audio data, measured in seconds:

.. code:: python

@@ -90,7 +106,7 @@
Limit the amount of read audio
==============================

-If the `max_read` parameter is > 0, :func:`load` will read at most that amount
+If the ``max_read`` parameter is > 0, :func:`load` will read at most that amount
in seconds of audio data:

.. code:: python

@@ -99,67 +115,64 @@
region = auditok.load("audio.ogg", max_read=5)
assert region.duration <= 5

-This argument is mandatory when reading data from the microphone.
+This argument is required when reading data from the microphone.
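For instance, microphone input and ``max_read`` can be combined; below is a
minimal sketch (assuming `pyaudio` is installed; the 3-second limit and the
output file name are arbitrary choices for illustration):

.. code:: python

 import auditok

 # Audio parameters and `max_read` are required when the first
 # argument is None (i.e., when reading from the microphone)
 region = auditok.load(None, sr=16000, sw=2, ch=1, max_read=3)
 region.save("mic_3_seconds.wav")  # write the captured audio to disk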
Basic split example
-------------------

-In the following we'll use the :func:`split` function to tokenize an audio file,
-requiring that valid audio events be at least 0.2 second long, at most 4 seconds
-long and contain a maximum of 0.3 second of continuous silence. Limiting the size
-of detected events to 4 seconds means that an event of, say, 9.5 seconds will
-be returned as two 4-second events plus a third 1.5-second event. Moreover, a
-valid event might contain many *silences* as far as none of them exceeds 0.3
-second.
+In the following example, we'll use the :func:`split` function to tokenize an
+audio file. We'll specify that valid audio events must be at least 0.2 seconds
+long, no longer than 4 seconds, and contain no more than 0.3 seconds of continuous
+silence. By setting a 4-second limit, an event lasting 9.5 seconds, for instance,
+will be returned as two 4-second events plus a final 1.5-second event. Additionally,
+a valid event may contain multiple silences, as long as none exceed 0.3 seconds.

-:func:`split` returns a generator of :class:`AudioRegion`. An :class:`AudioRegion`
-can be played, saved, repeated (i.e., multiplied by an integer) and concatenated
-with another region (see examples below). Notice that :class:`AudioRegion` objects
-returned by :func:`split` have a ``start`` a ``stop`` information stored in
-their meta data that can be accessed like `object.meta.start`.
+:func:`split` returns a generator of :class:`AudioRegion` objects. Each
+:class:`AudioRegion` can be played, saved, repeated (multiplied by an integer),
+and concatenated with another region (see examples below). Note that
+:class:`AudioRegion` objects returned by :func:`split` include `start` and `end`
+attributes, which mark the beginning and end of the audio event relative to the
+input audio stream.

.. code:: python

import auditok

- # split returns a generator of AudioRegion objects
- audio_regions = auditok.split(
+ # `split` returns a generator of AudioRegion objects
+ audio_events = auditok.split(
"audio.wav",
- min_dur=0.2, # minimum duration of a valid audio event in seconds
- max_dur=4, # maximum duration of an event
- max_silence=0.3, # maximum duration of tolerated continuous silence within an event
- energy_threshold=55 # threshold of detection
+ min_dur=0.2, # Minimum duration of a valid audio event in seconds
+ max_dur=4, # Maximum duration of an event
+ max_silence=0.3, # Maximum tolerated silence duration within an event
+ energy_threshold=55 # Detection threshold
)

- for i, r in enumerate(audio_regions):
+ for i, r in enumerate(audio_events):
+ # AudioRegions returned by `split` have defined 'start' and 'end' attributes
+ print(f"Event {i}: {r.start:.3f}s -- {r.end:.3f}s")

- # Regions returned by `split` have 'start' and 'end' metadata fields
- print("Region {i}: {r.meta.start:.3f}s -- {r.meta.end:.3f}s".format(i=i, r=r))
+ # Play the audio event
+ r.play(progress_bar=True)

- # play detection
- # r.play(progress_bar=True)
+ # Save the event with start and end times in the filename
+ filename = r.save("event_{start:.3f}-{end:.3f}.wav")
+ print(f"Event saved as: {filename}")

- # region's metadata can also be used with the `save` method
- # (no need to explicitly specify region's object and `format` arguments)
- filename = r.save("region_{meta.start:.3f}-{meta.end:.3f}.wav")
- print("region saved as: {}".format(filename))
-
-output example:
+Example output:

.. code:: bash
- Region 0: 0.700s -- 1.400s
- region saved as: region_0.700-1.400.wav
- Region 1: 3.800s -- 4.500s
- region saved as: region_3.800-4.500.wav
- Region 2: 8.750s -- 9.950s
- region saved as: region_8.750-9.950.wav
- Region 3: 11.700s -- 12.400s
- region saved as: region_11.700-12.400.wav
- Region 4: 15.050s -- 15.850s
- region saved as: region_15.050-15.850.wav
-
+ Event 0: 0.700s -- 1.400s
+ Event saved as: event_0.700-1.400.wav
+ Event 1: 3.800s -- 4.500s
+ Event saved as: event_3.800-4.500.wav
+ Event 2: 8.750s -- 9.950s
+ Event saved as: event_8.750-9.950.wav
+ Event 3: 11.700s -- 12.400s
+ Event saved as: event_11.700-12.400.wav
+ Event 4: 15.050s -- 15.850s
+ Event saved as: event_15.050-15.850.wav

Split and plot
--------------

@@ -176,11 +189,36 @@
.. image:: figures/example_1.png

+Split an audio stream and re-join (glue) audio events with silence
+------------------------------------------------------------------
+
+The following code detects audio events within an audio stream, then inserts
+1 second of silence between them to create an audio file with pauses:
+
+.. code:: python
+
+ # Create a 1-second silent audio region
+ # Audio parameters must match the original stream
+ from auditok import split, make_silence
+ silence = make_silence(duration=1,
+ sampling_rate=16000,
+ sample_width=2,
+ channels=1)
+ events = split("audio.wav")
+ audio_with_pauses = silence.join(events)
+
+Alternatively, use ``split_and_join_with_silence``:
+
+.. code:: python
+
+ from auditok import split_and_join_with_silence
+ audio_with_pauses = split_and_join_with_silence(silence_duration=1, input="audio.wav")
+
Read and split data from the microphone
---------------------------------------

-If the first argument of :func:`split` is None, audio data is read from the
+If the first argument of :func:`split` is ``None``, audio data is read from the
microphone (requires `pyaudio <https://people.csail.mit.edu/hubert/pyaudio>`_):

.. code:: python

@@ -200,15 +238,16 @@
pass

-:func:`split` will continue reading audio data until you press ``Ctrl-C``. If
-you want to read a specific amount of audio data, pass the desired number of
-seconds with the `max_read` argument.
+:func:`split` will continue reading audio data until you press ``Ctrl-C``. To read
+a specific amount of audio data, pass the desired number of seconds using the
+`max_read` argument.

Access recorded data after split
--------------------------------

-Using a :class:`Recorder` object you can get hold of acquired audio data:
+Using a :class:`Recorder` object, you can access audio data read from a file
+or from the microphone. In the following code, press ``Ctrl-C`` to stop recording:

.. code:: python

@@ -221,11 +260,13 @@
eth = 55 # alias for energy_threshold, default value is 50

rec = auditok.Recorder(input=None, sr=sr, sw=sw, ch=ch)
+ events = []

try:
for region in auditok.split(rec, sr=sr, sw=sw, ch=ch, eth=eth):
print(region)
- region.play(progress_bar=True) # progress bar requires `tqdm`
+ region.play(progress_bar=True)
+ events.append(region)
except KeyboardInterrupt:
pass

@@ -233,6 +274,7 @@
full_audio = load(rec.data, sr=sr, sw=sw, ch=ch)
# alternatively you can use
full_audio = auditok.AudioRegion(rec.data, sr, sw, ch)
+ full_audio.play(progress_bar=True)

:class:`Recorder` also accepts a `max_read` argument.

@@ -240,9 +282,8 @@
Working with AudioRegions
-------------------------

-The following are a couple of interesting operations you can do with
-:class:`AudioRegion` objects.
-
+In the following sections, we will review several operations
+that can be performed with :class:`AudioRegion` objects.

Basic region information
========================

@@ -257,6 +298,9 @@
region.sample_width # alias `sw`
region.channels # alias `ch`

+When an audio region is returned by the :func:`split` function, it includes defined
+``start`` and ``end`` attributes that refer to the beginning and end of the audio
+event relative to the input audio stream.

Concatenate regions
===================

@@ -268,7 +312,8 @@
region_2 = auditok.load("audio_2.wav")
region_3 = region_1 + region_2

-Particularly useful if you want to join regions returned by :func:`split`:
+This is particularly useful when you want to join regions returned by the
+:func:`split` function:

.. code:: python

@@ -290,8 +335,7 @@
Split one region into N regions of equal size
=============================================

-Divide by a positive integer (this has nothing to do with silence-based
-tokenization):
+Divide by a positive integer (this is unrelated to silence-based tokenization!):

.. code:: python

@@ -300,21 +344,21 @@
regions = regions / 5
assert sum(regions) == region

-Note that if no perfect division is possible, the last region might be a bit
-shorter than the previous N-1 regions.
+Note that if an exact split is not possible, the last region may be shorter
+than the preceding N-1 regions.

Slice a region by samples, seconds or milliseconds
==================================================

-Slicing an :class:`AudioRegion` can be interesting in many situations. You can for
-example remove a fixed-size portion of audio data from the beginning or from the
-end of a region or crop a region by an arbitrary amount as a data augmentation
-strategy.
+Slicing an :class:`AudioRegion` can be useful in various situations.
+For example, you can remove a fixed-length portion of audio data from
+the beginning or end of a region, or crop a region by an arbitrary amount
+as a data augmentation strategy.

-The most accurate way to slice an `AudioRegion` is to use indices that
-directly refer to raw audio samples. In the following example, assuming that the
-sampling rate of audio data is 16000, you can extract a 5-second region from
-main region, starting from the 20th second as follows:
+The most accurate way to slice an `AudioRegion` is by using indices that
+directly refer to raw audio samples. In the following example, assuming
+the audio data has a sampling rate of 16000, you can extract a 5-second
+segment from the main region, starting at the 20th second, as follows:

.. code:: python

@@ -324,9 +368,9 @@
stop = 25 * 16000
five_second_region = region[start:stop]

-This allows you to practically start and stop at any audio sample within the region.
-Just as with a `list` you can omit one of `start` and `stop`, or both. You can
-also use negative indices:
+This allows you to start and stop at any audio sample within the region. Similar
+to a ``list``, you can omit either ``start`` or ``stop``, or both. Negative
+indices are also supported:

.. code:: python

@@ -335,9 +379,9 @@
start = -3 * region.sr # `sr` is an alias of `sampling_rate`
three_last_seconds = region[start:]

-While slicing by raw samples is flexible, slicing with temporal indices is more
-intuitive. You can do so by accessing the ``millis`` or ``seconds`` views of an
-`AudioRegion` (or their shortcut alias `ms` and `sec` or `s`).
+While slicing by raw samples offers flexibility, using temporal indices is
+often more intuitive.
You can achieve this by accessing the ``millis`` or ``seconds`` +*views* of an :class:`AudioRegion` (or using their shortcut aliases ``ms``, ``sec``, or ``s``). With the ``millis`` view: @@ -346,6 +390,8 @@ import auditok region = auditok.load("audio.wav") five_second_region = region.millis[5000:10000] + # or + five_second_region = region.ms[5000:10000] or with the ``seconds`` view: @@ -354,6 +400,10 @@ import auditok region = auditok.load("audio.wav") five_second_region = region.seconds[5:10] + # or + five_second_region = region.sec[5:10] + # or + five_second_region = region.s[5:10] ``seconds`` indices can also be floats: @@ -363,27 +413,13 @@ region = auditok.load("audio.wav") five_second_region = region.seconds[2.5:7.5] -Get arrays of audio samples -=========================== - -If `numpy` is not installed, the `samples` attributes is a list of audio samples -arrays (standard `array.array` objects), one per channels. If numpy is installed, -`samples` is a 2-D `numpy.ndarray` where the fist dimension is the channel -and the second is the the sample. +Export an ``AudioRegion`` as a ``numpy`` array +============================================== .. code:: python - import auditok - region = auditok.load("audio.wav") - samples = region.samples - assert len(samples) == region.channels - - -If `numpy` is installed you can use: - -.. code:: python - - import numpy as np - region = auditok.load("audio.wav") - samples = np.asarray(region) - assert len(samples.shape) == 2 + from auditok import load, AudioRegion + audio = load("audio.wav") # or use `AudioRegion.load("audio.wav")` + x = audio.numpy() + assert x.shape[0] == audio.channels + assert x.shape[1] == len(audio)
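As a quick usage sketch of the exported array (the RMS computation below is
just an illustration, not part of ``auditok``):

.. code:: python

 import numpy as np
 from auditok import load

 audio = load("audio.wav")
 x = audio.numpy()  # 2-D array of shape (channels, samples)
 # Cast to float before squaring to avoid integer overflow,
 # then compute the per-channel root mean square
 rms = np.sqrt(np.mean(x.astype(np.float64) ** 2, axis=1))
 print(rms)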
--- a/doc/index.rst	Thu Mar 30 10:17:57 2023 +0100
+++ b/doc/index.rst	Wed Oct 30 17:17:59 2024 +0000
@@ -1,8 +1,8 @@
auditok, an AUDIo TOKenization tool
===================================

-.. image:: https://travis-ci.org/amsehili/auditok.svg?branch=master
- :target: https://travis-ci.org/amsehili/auditok
+.. image:: https://github.com/amsehili/auditok/actions/workflows/ci.yml/badge.svg
+ :target: https://github.com/amsehili/auditok/actions/workflows/ci.yml/
:alt: Build Status

.. image:: https://readthedocs.org/projects/auditok/badge/?version=latest

@@ -11,9 +11,10 @@

-``auditok`` is an **Audio Activity Detection** tool that can process online data
-(read from an audio device or from standard input) as well as audio files. It
-can be used as a command line program or by calling its API.
+``auditok`` is an **Audio Activity Detection** tool that processes online data
+(from an audio device or standard input) and audio files. It can be used via the command line or through its API.
+
+Full documentation is available on `Read the Docs <https://auditok.readthedocs.io/en/latest/>`_.

.. toctree::

@@ -39,8 +40,8 @@
util
io
signal
- dataset

License
-------
+ MIT.
--- a/doc/installation.rst	Thu Mar 30 10:17:57 2023 +0100
+++ b/doc/installation.rst	Wed Oct 30 17:17:59 2024 +0000
@@ -1,31 +1,31 @@
Installation
------------

-A basic version of ``auditok`` will run with standard Python (>=3.4). However,
-without installing additional dependencies, ``auditok`` can only deal with audio
-files in *wav* or *raw* formats. if you want more features, the following
-packages are needed:
+**Dependencies**

-- `pydub <https://github.com/jiaaro/pydub>`_ : read audio files in popular audio formats (ogg, mp3, etc.) or extract audio from a video file.
-- `pyaudio <https://people.csail.mit.edu/hubert/pyaudio>`_ : read audio data from the microphone and play audio back.
-- `tqdm <https://github.com/tqdm/tqdm>`_ : show progress bar while playing audio clips.
-- `matplotlib <https://matplotlib.org/stable/index.html>`_ : plot audio signal and detections.
-- `numpy <https://numpy.org/>`_ : required by matplotlib. Also used for some math operations instead of standard python if available.
+The following dependencies are required by ``auditok`` and will be installed automatically:
+- `numpy <https://numpy.org/>`_: used for signal processing.
+- `pydub <https://github.com/jiaaro/pydub>`_: to read audio files in popular formats (e.g., ogg, mp3) or extract audio from video files.
+- `pyaudio <https://people.csail.mit.edu/hubert/pyaudio>`_: to read audio data from the microphone and play audio back.
+- `tqdm <https://github.com/tqdm/tqdm>`_: to display a progress bar while playing audio clips.
+- `matplotlib <https://matplotlib.org/stable/index.html>`_: to plot audio signal and detections.

-Install the latest stable version with pip:
+``auditok`` requires Python 3.7 or higher.
+
+To install the latest stable version, use pip:

.. code:: bash

sudo pip install auditok

-Install with the latest development version from github:
+To install the latest development version from GitHub:

.. code:: bash

pip install git+https://github.com/amsehili/auditok

-or
+Alternatively, clone the repository and install it manually:

.. code:: bash
--- a/pyproject.toml Thu Mar 30 10:17:57 2023 +0100 +++ b/pyproject.toml Wed Oct 30 17:17:59 2024 +0000 @@ -14,3 +14,6 @@ | dist )/ ''' + +[tool.isort] +profile = "black"
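For reference, the ``black`` profile makes `isort`'s import ordering agree with
`black`'s formatting; a typical local run (a sketch, the target paths are
assumptions) would be:

.. code:: bash

 pip install isort black
 isort auditok tests
 black auditok tests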
--- a/setup.py Thu Mar 30 10:17:57 2023 +0100 +++ b/setup.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,9 +1,9 @@ +import ast +import re import sys -import re -import ast + from setuptools import setup - _version_re = re.compile(r"__version__\s+=\s+(.*)") with open("auditok/__init__.py", "rt") as f: @@ -26,6 +26,13 @@ zip_safe=False, platforms="ANY", provides=["auditok"], + install_requires=[ + "numpy", + "matplotlib", + "pydub", + "pyaudio", + "tqdm", + ], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Console", @@ -36,12 +43,13 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python", - "Programming Language :: Python :: 3.4", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Multimedia :: Sound/Audio :: Analysis", "Topic :: Scientific/Engineering :: Information Analysis", ],
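With ``install_requires`` declared in ``setup.py``, a plain installation now
pulls numpy, matplotlib, pydub, pyaudio, and tqdm automatically, e.g.:

.. code:: bash

 pip install auditok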
--- a/tests/test_AudioReader.py Thu Mar 30 10:17:57 2023 +0100 +++ b/tests/test_AudioReader.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,1005 +1,17 @@ -""" -@author: Amine Sehili <amine.sehili@gmail.com> -September 2015 - -""" - -import unittest -from functools import partial import sys import wave -from genty import genty, genty_dataset +from functools import partial + +import pytest + from auditok import ( + AudioReader, + BufferAudioSource, + Recorder, + WaveAudioSource, dataset, - ADSFactory, - AudioDataSource, - AudioReader, - Recorder, - BufferAudioSource, - WaveAudioSource, - DuplicateArgument, ) - - -class TestADSFactoryFileAudioSource(unittest.TestCase): - def setUp(self): - self.audio_source = WaveAudioSource( - filename=dataset.one_to_six_arabic_16000_mono_bc_noise - ) - - def test_ADS_type(self): - - ads = ADSFactory.ads(audio_source=self.audio_source) - - err_msg = "wrong type for ads object, expected: 'AudioDataSource', " - err_msg += "found: {0}" - self.assertIsInstance( - ads, AudioDataSource, err_msg.format(type(ads)), - ) - - def test_default_block_size(self): - ads = ADSFactory.ads(audio_source=self.audio_source) - size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong default block_size, expected: 160, found: {0}".format(size), - ) - - def test_block_size(self): - ads = ADSFactory.ads(audio_source=self.audio_source, block_size=512) - size = ads.block_size - self.assertEqual( - size, - 512, - "Wrong block_size, expected: 512, found: {0}".format(size), - ) - - # with alias keyword - ads = ADSFactory.ads(audio_source=self.audio_source, bs=160) - size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong block_size, expected: 160, found: {0}".format(size), - ) - - def test_block_duration(self): - - ads = ADSFactory.ads( - audio_source=self.audio_source, block_dur=0.01 - ) # 10 ms - size = ads.block_size - self.assertEqual( - size, - 160, - "Wrong block_size, expected: 160, found: {0}".format(size), - ) - - # with alias keyword - ads = ADSFactory.ads(audio_source=self.audio_source, bd=0.025) # 25 ms - size = ads.block_size - self.assertEqual( - size, - 400, - "Wrong block_size, expected: 400, found: {0}".format(size), - ) - - def test_hop_duration(self): - - ads = ADSFactory.ads( - audio_source=self.audio_source, block_dur=0.02, hop_dur=0.01 - ) # 10 ms - size = ads.hop_size - self.assertEqual( - size, 160, "Wrong hop_size, expected: 160, found: {0}".format(size) - ) - - # with alias keyword - ads = ADSFactory.ads( - audio_source=self.audio_source, bd=0.025, hop_dur=0.015 - ) # 15 ms - size = ads.hop_size - self.assertEqual( - size, - 240, - "Wrong block_size, expected: 240, found: {0}".format(size), - ) - - def test_sampling_rate(self): - ads = ADSFactory.ads(audio_source=self.audio_source) - - srate = ads.sampling_rate - self.assertEqual( - srate, - 16000, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) - - def test_sample_width(self): - ads = ADSFactory.ads(audio_source=self.audio_source) - - swidth = ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) - - def test_channels(self): - ads = ADSFactory.ads(audio_source=self.audio_source) - - channels = ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) - - def test_read(self): - ads = ADSFactory.ads(audio_source=self.audio_source, block_size=256) - - ads.open() - ads_data = ads.read() - ads.close() - - audio_source = WaveAudioSource( - 
filename=dataset.one_to_six_arabic_16000_mono_bc_noise - ) - audio_source.open() - audio_source_data = audio_source.read(256) - audio_source.close() - - self.assertEqual( - ads_data, audio_source_data, "Unexpected data read from ads" - ) - - def test_Limiter_Deco_read(self): - # read a maximum of 0.75 seconds from audio source - ads = ADSFactory.ads(audio_source=self.audio_source, max_time=0.75) - - ads_data = [] - ads.open() - while True: - block = ads.read() - if block is None: - break - ads_data.append(block) - ads.close() - ads_data = b"".join(ads_data) - - audio_source = WaveAudioSource( - filename=dataset.one_to_six_arabic_16000_mono_bc_noise - ) - audio_source.open() - audio_source_data = audio_source.read(int(16000 * 0.75)) - audio_source.close() - - self.assertEqual( - ads_data, audio_source_data, "Unexpected data read from LimiterADS" - ) - - def test_Limiter_Deco_read_limit(self): - # read a maximum of 1.191 seconds from audio source - ads = ADSFactory.ads(audio_source=self.audio_source, max_time=1.191) - total_samples = round(ads.sampling_rate * 1.191) - nb_full_blocks, last_block_size = divmod(total_samples, ads.block_size) - total_samples_with_overlap = ( - nb_full_blocks * ads.block_size + last_block_size - ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) - - total_read = 0 - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - total_read += len(block) - - ads.close() - err_msg = "Wrong data length read from LimiterADS, expected: {0}, " - err_msg += "found: {1}" - self.assertEqual( - total_read, - expected_read_bytes, - err_msg.format(expected_read_bytes, total_read), - ) - - def test_Recorder_Deco_read(self): - ads = ADSFactory.ads( - audio_source=self.audio_source, record=True, block_size=500 - ) - - ads_data = [] - ads.open() - for i in range(10): - block = ads.read() - if block is None: - break - ads_data.append(block) - ads.close() - ads_data = b"".join(ads_data) - - audio_source = WaveAudioSource( - filename=dataset.one_to_six_arabic_16000_mono_bc_noise - ) - audio_source.open() - audio_source_data = audio_source.read(500 * 10) - audio_source.close() - - self.assertEqual( - ads_data, - audio_source_data, - "Unexpected data read from RecorderADS", - ) - - def test_Recorder_Deco_is_rewindable(self): - ads = ADSFactory.ads(audio_source=self.audio_source, record=True) - - self.assertTrue( - ads.rewindable, "RecorderADS.is_rewindable should return True" - ) - - def test_Recorder_Deco_rewind_and_read(self): - ads = ADSFactory.ads( - audio_source=self.audio_source, record=True, block_size=320 - ) - - ads.open() - for i in range(10): - ads.read() - - ads.rewind() - - # read all available data after rewind - ads_data = [] - while True: - block = ads.read() - if block is None: - break - ads_data.append(block) - ads.close() - ads_data = b"".join(ads_data) - - audio_source = WaveAudioSource( - filename=dataset.one_to_six_arabic_16000_mono_bc_noise - ) - audio_source.open() - audio_source_data = audio_source.read(320 * 10) - audio_source.close() - - self.assertEqual( - ads_data, - audio_source_data, - "Unexpected data read from RecorderADS", - ) - - def test_Overlap_Deco_read(self): - - # Use arbitrary valid block_size and hop_size - block_size = 1714 - hop_size = 313 - - ads = ADSFactory.ads( - audio_source=self.audio_source, - block_size=block_size, - hop_size=hop_size, - ) - - # Read all available data overlapping blocks - ads.open() - ads_data = [] - while True: - block = ads.read() - if block is None: 
- break - ads_data.append(block) - ads.close() - - # Read all data from file and build a BufferAudioSource - fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") - wave_data = fp.readframes(fp.getnframes()) - fp.close() - audio_source = BufferAudioSource( - wave_data, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting - for i, block in enumerate(ads_data): - - tmp = audio_source.read(block_size) - - self.assertEqual( - block, - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - - audio_source.position = (i + 1) * hop_size - - audio_source.close() - - def test_Limiter_Overlap_Deco_read(self): - - block_size = 256 - hop_size = 200 - - ads = ADSFactory.ads( - audio_source=self.audio_source, - max_time=0.50, - block_size=block_size, - hop_size=hop_size, - ) - - # Read all available data overlapping blocks - ads.open() - ads_data = [] - while True: - block = ads.read() - if block is None: - break - ads_data.append(block) - ads.close() - - # Read all data from file and build a BufferAudioSource - fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") - wave_data = fp.readframes(fp.getnframes()) - fp.close() - audio_source = BufferAudioSource( - wave_data, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting - for i, block in enumerate(ads_data): - tmp = audio_source.read(len(block) // (ads.sw * ads.ch)) - self.assertEqual( - len(block), - len(tmp), - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - audio_source.position = (i + 1) * hop_size - - audio_source.close() - - def test_Limiter_Overlap_Deco_read_limit(self): - - block_size = 313 - hop_size = 207 - ads = ADSFactory.ads( - audio_source=self.audio_source, - max_time=1.932, - block_size=block_size, - hop_size=hop_size, - ) - - total_samples = round(ads.sampling_rate * 1.932) - first_read_size = block_size - next_read_size = block_size - hop_size - nb_next_blocks, last_block_size = divmod( - (total_samples - first_read_size), next_read_size - ) - total_samples_with_overlap = ( - first_read_size + next_read_size * nb_next_blocks + last_block_size - ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) - - cache_size = (block_size - hop_size) * ads.sample_width * ads.channels - total_read = cache_size - - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - total_read += len(block) - cache_size - - ads.close() - err_msg = "Wrong data length read from LimiterADS, expected: {0}, " - err_msg += "found: {1}" - self.assertEqual( - total_read, - expected_read_bytes, - err_msg.format(expected_read_bytes, total_read), - ) - - def test_Recorder_Overlap_Deco_is_rewindable(self): - ads = ADSFactory.ads( - audio_source=self.audio_source, - block_size=320, - hop_size=160, - record=True, - ) - self.assertTrue( - ads.rewindable, "RecorderADS.is_rewindable should return True" - ) - - def test_Recorder_Overlap_Deco_rewind_and_read(self): - - # Use arbitrary valid block_size and hop_size - block_size = 1600 - hop_size = 400 - - ads = ADSFactory.ads( - audio_source=self.audio_source, - block_size=block_size, - hop_size=hop_size, - record=True, - ) - - # Read all available data overlapping blocks - ads.open() - i = 0 - while True: - block = ads.read() - if 
block is None: - break - i += 1 - - ads.rewind() - - # Read all data from file and build a BufferAudioSource - fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") - wave_data = fp.readframes(fp.getnframes()) - fp.close() - audio_source = BufferAudioSource( - wave_data, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting - for j in range(i): - - tmp = audio_source.read(block_size) - - self.assertEqual( - ads.read(), - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - audio_source.position = (j + 1) * hop_size - - ads.close() - audio_source.close() - - def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): - - # Use arbitrary valid block_size and hop_size - block_size = 1600 - hop_size = 400 - - ads = ADSFactory.ads( - audio_source=self.audio_source, - max_time=1.50, - block_size=block_size, - hop_size=hop_size, - record=True, - ) - - # Read all available data overlapping blocks - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - - ads.rewind() - - # Read all data from file and build a BufferAudioSource - fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r") - wave_data = fp.readframes(fp.getnframes()) - fp.close() - audio_source = BufferAudioSource( - wave_data, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting - for j in range(i): - - tmp = audio_source.read(block_size) - - self.assertEqual( - ads.read(), - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - audio_source.position = (j + 1) * hop_size - - ads.close() - audio_source.close() - - def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_limit(self): - - # Use arbitrary valid block_size and hop_size - block_size = 1000 - hop_size = 200 - - ads = ADSFactory.ads( - audio_source=self.audio_source, - max_time=1.317, - block_size=block_size, - hop_size=hop_size, - record=True, - ) - total_samples = round(ads.sampling_rate * 1.317) - first_read_size = block_size - next_read_size = block_size - hop_size - nb_next_blocks, last_block_size = divmod( - (total_samples - first_read_size), next_read_size - ) - total_samples_with_overlap = ( - first_read_size + next_read_size * nb_next_blocks + last_block_size - ) - expected_read_bytes = ( - total_samples_with_overlap * ads.sw * ads.channels - ) - - cache_size = (block_size - hop_size) * ads.sample_width * ads.channels - total_read = cache_size - - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - total_read += len(block) - cache_size - - ads.close() - err_msg = "Wrong data length read from LimiterADS, expected: {0}, " - err_msg += "found: {1}" - self.assertEqual( - total_read, - expected_read_bytes, - err_msg.format(expected_read_bytes, total_read), - ) - - -class TestADSFactoryBufferAudioSource(unittest.TestCase): - def setUp(self): - self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" - self.ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - block_size=4, - ) - - def test_ADS_BAS_sampling_rate(self): - srate = self.ads.sampling_rate - self.assertEqual( - srate, - 16, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) - - def test_ADS_BAS_sample_width(self): - swidth = 
self.ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) - - def test_ADS_BAS_channels(self): - channels = self.ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) - - def test_Limiter_Recorder_Overlap_Deco_rewind_and_read(self): - - # Use arbitrary valid block_size and hop_size - block_size = 5 - hop_size = 4 - - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - max_time=0.80, - block_size=block_size, - hop_size=hop_size, - record=True, - ) - - # Read all available data overlapping blocks - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - - ads.rewind() - - # Build a BufferAudioSource - audio_source = BufferAudioSource( - self.signal, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from OverlapADS to those read - # from an audio source with a manual position setting - for j in range(i): - - tmp = audio_source.read(block_size) - - block = ads.read() - - self.assertEqual( - block, - tmp, - "Unexpected block '{}' (N={}) read from OverlapADS".format( - block, i - ), - ) - audio_source.position = (j + 1) * hop_size - - ads.close() - audio_source.close() - - -class TestADSFactoryAlias(unittest.TestCase): - def setUp(self): - self.signal = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" - - def test_sampling_rate_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sr=16, - sample_width=2, - channels=1, - block_dur=0.5, - ) - srate = ads.sampling_rate - self.assertEqual( - srate, - 16, - "Wrong sampling rate, expected: 16000, found: {0}".format(srate), - ) - - def test_sampling_rate_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sr=16, - sampling_rate=16, - sample_width=2, - channels=1, - ) - self.assertRaises(DuplicateArgument, func) - - def test_sample_width_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sw=2, - channels=1, - block_dur=0.5, - ) - swidth = ads.sample_width - self.assertEqual( - swidth, - 2, - "Wrong sample width, expected: 2, found: {0}".format(swidth), - ) - - def test_sample_width_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sw=2, - sample_width=2, - channels=1, - ) - self.assertRaises(DuplicateArgument, func) - - def test_channels_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - ch=1, - block_dur=4, - ) - channels = ads.channels - self.assertEqual( - channels, - 1, - "Wrong number of channels, expected: 1, found: {0}".format( - channels - ), - ) - - def test_channels_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - ch=1, - channels=1, - ) - self.assertRaises(DuplicateArgument, func) - - def test_block_size_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bs=8, - ) - size = ads.block_size - self.assertEqual( - size, - 8, - "Wrong block_size using bs alias, expected: 8, found: {0}".format( - size - ), - ) - - def test_block_size_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bs=4, - block_size=4, - ) - self.assertRaises(DuplicateArgument, func) - - def 
test_block_duration_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bd=0.75, - ) - # 0.75 ms = 0.75 * 16 = 12 - size = ads.block_size - err_msg = "Wrong block_size set with a block_dur alias 'bd', " - err_msg += "expected: 8, found: {0}" - self.assertEqual( - size, 12, err_msg.format(size), - ) - - def test_block_duration_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bd=4, - block_dur=4, - ) - self.assertRaises(DuplicateArgument, func) - - def test_block_size_duration_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bd=4, - bs=12, - ) - self.assertRaises(DuplicateArgument, func) - - def test_hop_duration_alias(self): - - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bd=0.75, - hd=0.5, - ) - size = ads.hop_size - self.assertEqual( - size, - 8, - "Wrong block_size using bs alias, expected: 8, found: {0}".format( - size - ), - ) - - def test_hop_duration_duplicate(self): - - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bd=0.75, - hd=0.5, - hop_dur=0.5, - ) - self.assertRaises(DuplicateArgument, func) - - def test_hop_size_duration_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bs=8, - hs=4, - hd=1, - ) - self.assertRaises(DuplicateArgument, func) - - def test_hop_size_greater_than_block_size(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - bs=4, - hs=8, - ) - self.assertRaises(ValueError, func) - - def test_filename_duplicate(self): - - func = partial( - ADSFactory.ads, - fn=dataset.one_to_six_arabic_16000_mono_bc_noise, - filename=dataset.one_to_six_arabic_16000_mono_bc_noise, - ) - self.assertRaises(DuplicateArgument, func) - - def test_data_buffer_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - db=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - ) - self.assertRaises(DuplicateArgument, func) - - def test_max_time_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - mt=10, - block_dur=0.5, - ) - self.assertEqual( - ads.max_read, - 10, - "Wrong AudioDataSource.max_read, expected: 10, found: {}".format( - ads.max_read - ), - ) - - def test_max_time_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - mt=True, - max_time=True, - ) - - self.assertRaises(DuplicateArgument, func) - - def test_record_alias(self): - ads = ADSFactory.ads( - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - rec=True, - block_dur=0.5, - ) - self.assertTrue( - ads.rewindable, "AudioDataSource.rewindable expected to be True" - ) - - def test_record_duplicate(self): - func = partial( - ADSFactory.ads, - data_buffer=self.signal, - sampling_rate=16, - sample_width=2, - channels=1, - rec=True, - record=True, - ) - self.assertRaises(DuplicateArgument, func) - - def test_Limiter_Recorder_Overlap_Deco_rewind_and_read_alias(self): - - # Use arbitrary valid block_size and hop_size - block_size = 5 - hop_size = 4 - - ads = ADSFactory.ads( - db=self.signal, - sr=16, - sw=2, - ch=1, - 
mt=0.80, - bs=block_size, - hs=hop_size, - rec=True, - ) - - # Read all available data overlapping blocks - ads.open() - i = 0 - while True: - block = ads.read() - if block is None: - break - i += 1 - - ads.rewind() - - # Build a BufferAudioSource - audio_source = BufferAudioSource( - self.signal, ads.sampling_rate, ads.sample_width, ads.channels - ) - audio_source.open() - - # Compare all blocks read from AudioDataSource to those read - # from an audio source with manual position definition - for j in range(i): - tmp = audio_source.read(block_size) - block = ads.read() - self.assertEqual( - block, - tmp, - "Unexpected block (N={0}) read from OverlapADS".format(i), - ) - audio_source.position = (j + 1) * hop_size - ads.close() - audio_source.close() +from auditok.util import _Limiter, _OverlapAudioReader def _read_all_data(reader): @@ -1012,68 +24,748 @@ return b"".join(blocks) -@genty -class TestAudioReader(unittest.TestCase): +class TestAudioReaderWithFileAudioSource: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + self.audio_source = WaveAudioSource( + filename=dataset.one_to_six_arabic_16000_mono_bc_noise + ) + self.audio_source.open() + yield + self.audio_source.close() - # TODO move all tests here when backward compatibility - # with ADSFactory is dropped + def test_AudioReader_type(self): + reader = AudioReader(input=self.audio_source) + err_msg = "wrong object type, expected: 'AudioReader', found: {0}" + assert isinstance(reader, AudioReader), err_msg.format(type(reader)) - @genty_dataset( - mono=("mono_400", 0.5, 16000), - multichannel=("3channel_400-800-1600", 0.5, 16000 * 3), + def _test_default_block_size(self): + reader = AudioReader(input=self.audio_source) + data = reader.read() + size = len(data) + assert ( + size == 160 + ), "Wrong default block_size, expected: 160, found: {0}".format(size) + + @pytest.mark.parametrize( + "block_dur, expected_nb_samples", + [ + (None, 160), # default: 10 ms + (0.025, 400), # 25 ms + ], + ids=["default", "_25ms"], ) - def test_Limiter(self, file_id, max_read, size): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read(size) + def test_block_duration(self, block_dur, expected_nb_samples): + """Test the number of samples read for a given block duration.""" + if block_dur is not None: + reader = AudioReader(input=self.audio_source, block_dur=block_dur) + else: + reader = AudioReader(input=self.audio_source) + data = reader.read() + nb_samples = len(data) // reader.sample_width + assert ( + nb_samples == expected_nb_samples + ), f"Wrong block_size, expected: {expected_nb_samples}, found: {nb_samples}" - reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read) - reader.open() + @pytest.mark.parametrize( + "block_dur, hop_dur, expected_nb_blocks, expected_last_block_nb_samples", + [ + (None, None, 1879, 126), # default: 10 ms + (0.01, None, 1879, 126), # block_dur_10ms_hop_dur_None + (0.01, 0.01, 1879, 126), # block_dur_10ms_hop_dur_10ms + (0.02, None, 940, 126), # block_dur_20ms_hop_dur_None + (0.025, None, 752, 206), # block_dur_25ms_hop_dur_None + (0.02, 0.01, 1878, 286), # block_dur_20ms_hop_dur_10ms + (0.025, 0.005, 3754, 366), # block_dur_25ms_hop_dur_5ms + ], + ids=[ + "default", + "block_dur_10ms_hop_dur_None", + "block_dur_10ms_hop_dur_10ms", + "block_dur_20ms_hop_dur_None", + "block_dur_25ms_hop_dur_None", + "block_dur_20ms_hop_dur_10ms", + "block_dur_25ms_hop_dur_5ms", + ], + ) + 
def test_hop_duration(
+ self,
+ block_dur,
+ hop_dur,
+ expected_nb_blocks,
+ expected_last_block_nb_samples,
+ ):
+ """Test the number of read blocks and the duration of the last block for
+ different 'block_dur' and 'hop_dur' values.
+
+ Args:
+ block_dur (float or None): block duration in seconds.
+ hop_dur (float or None): hop duration in seconds.
+ expected_nb_blocks (int): expected number of read blocks.
+ expected_last_block_nb_samples (int): expected number of samples
+ in the last block.
+ """
+ if block_dur is not None:
+ reader = AudioReader(
+ input=self.audio_source, block_dur=block_dur, hop_dur=hop_dur
+ )
+ else:
+ reader = AudioReader(input=self.audio_source, hop_dur=hop_dur)
+
+ nb_blocks = 0
+ last_block_nb_samples = None
+ while True:
+ data = reader.read()
+ if data is not None:
+ nb_blocks += 1
+ last_block_nb_samples = len(data) // reader.sample_width
+ else:
+ break
+ err_msg = "Wrong number of blocks read from source, expected: "
+ err_msg += f"{expected_nb_blocks}, found: {nb_blocks}"
+ assert nb_blocks == expected_nb_blocks, err_msg
+
+ err_msg = (
+ "Wrong number of samples in last block read from source, expected: "
+ )
+ err_msg += (
+ f"{expected_last_block_nb_samples}, found: {last_block_nb_samples}"
+ )
+
+ assert last_block_nb_samples == expected_last_block_nb_samples, err_msg
+
+ def test_hop_duration_exception(self):
+ """Test passing hop_dur > block_dur raises ValueError"""
+ with pytest.raises(ValueError):
+ AudioReader(self.audio_source, block_dur=0.01, hop_dur=0.015)
+
+ @pytest.mark.parametrize(
+ "block_dur, hop_dur",
+ [
+ (None, None), # default
+ (0.01, None), # block_dur_10ms_hop_dur_None
+ (None, 0.01), # block_dur_None__hop_dur_10ms
+ (0.05, 0.05), # block_dur_50ms_hop_dur_50ms
+ ],
+ ids=[
+ "default",
+ "block_dur_10ms_hop_dur_None",
+ "block_dur_None__hop_dur_10ms",
+ "block_dur_50ms_hop_dur_50ms",
+ ],
+ )
+ def test_reader_class_block_dur_equals_hop_dur(self, block_dur, hop_dur):
+ """Test passing hop_dur == block_dur does not create an instance of
+ '_OverlapAudioReader'.
+ """ + if block_dur is not None: + reader = AudioReader( + input=self.audio_source, block_dur=block_dur, hop_dur=hop_dur + ) + else: + reader = AudioReader(input=self.audio_source, hop_dur=hop_dur) + assert not isinstance(reader, _OverlapAudioReader) + + def test_sampling_rate(self): + reader = AudioReader(input=self.audio_source) + sampling_rate = reader.sampling_rate + assert ( + sampling_rate == 16000 + ), f"Wrong sampling rate, expected: 16000, found: {sampling_rate}" + + def test_sample_width(self): + reader = AudioReader(input=self.audio_source) + sample_width = reader.sample_width + assert ( + sample_width == 2 + ), f"Wrong sample width, expected: 2, found: {sample_width}" + + def test_channels(self): + reader = AudioReader(input=self.audio_source) + channels = reader.channels + assert ( + channels == 1 + ), f"Wrong number of channels, expected: 1, found: {channels}" + + def test_read(self): + reader = AudioReader(input=self.audio_source, block_dur=0.02) + reader_data = reader.read() + audio_source = WaveAudioSource( + filename=dataset.one_to_six_arabic_16000_mono_bc_noise + ) + audio_source.open() + audio_source_data = audio_source.read(320) + audio_source.close() + assert ( + reader_data == audio_source_data + ), "Unexpected data read from AudioReader" + + def test_read_with_overlap(self): + reader = AudioReader( + input=self.audio_source, block_dur=0.02, hop_dur=0.01 + ) + _ = reader.read() # first block + reader_data = reader.read() # second block with 0.01 S overlap + audio_source = WaveAudioSource( + filename=dataset.one_to_six_arabic_16000_mono_bc_noise + ) + audio_source.open() + _ = audio_source.read(160) + audio_source_data = audio_source.read(320) + audio_source.close() + assert ( + reader_data == audio_source_data + ), "Unexpected data read from AudioReader" + + def test_read_from_AudioReader_with_max_read(self): + # read a maximum of 0.75 seconds from audio source + reader = AudioReader(input=self.audio_source, max_read=0.75) + assert isinstance(reader._audio_source._audio_source, _Limiter) + reader_data = _read_all_data(reader) + + audio_source = WaveAudioSource( + filename=dataset.one_to_six_arabic_16000_mono_bc_noise + ) + audio_source.open() + audio_source_data = audio_source.read(int(16000 * 0.75)) + audio_source.close() + + assert ( + reader_data == audio_source_data + ), f"Unexpected data read from AudioReader with 'max_read = {0.75}'" + + def test_read_data_size_from_AudioReader_with_max_read(self): + # read a maximum of 1.191 seconds from audio source + reader = AudioReader(input=self.audio_source, max_read=1.191) + assert isinstance(reader._audio_source._audio_source, _Limiter) + total_samples = round(reader.sampling_rate * 1.191) + block_size = int(reader.block_dur * reader.sampling_rate) + nb_full_blocks, last_block_size = divmod(total_samples, block_size) + total_samples_with_overlap = ( + nb_full_blocks * block_size + last_block_size + ) + expected_read_bytes = ( + total_samples_with_overlap * reader.sample_width * reader.channels + ) + + reader_data = _read_all_data(reader) + total_read = len(reader_data) + err_msg = f"Wrong data length read from LimiterADS, expected: {expected_read_bytes}, found: {total_read}" + assert total_read == expected_read_bytes, err_msg + + def test_read_from_Recorder(self): + reader = Recorder(input=self.audio_source, block_dur=0.025) + reader_data = [] + for _ in range(10): + block = reader.read() + if block is None: + break + reader_data.append(block) + reader_data = b"".join(reader_data) + + audio_source = WaveAudioSource( + 
filename=dataset.one_to_six_arabic_16000_mono_bc_noise
+ )
+ audio_source.open()
+ audio_source_data = audio_source.read(400 * 10)
+ audio_source.close()
+
+ assert (
+ reader_data == audio_source_data
+ ), "Unexpected data read from Recorder"
+
+ def test_AudioReader_rewindable(self):
+ reader = AudioReader(input=self.audio_source, record=True)
+ assert (
+ reader.rewindable
+ ), "AudioReader with record=True should be rewindable"
+
+ def test_AudioReader_record_and_rewind(self):
+ reader = AudioReader(
+ input=self.audio_source, record=True, block_dur=0.02
+ )
+ # read 0.02 * 10 = 0.2 sec. of data
+ for i in range(10):
+ reader.read()
+ reader.rewind()
+
+ # read all available data after rewind
+ reader_data = _read_all_data(reader)
+
+ audio_source = WaveAudioSource(
+ filename=dataset.one_to_six_arabic_16000_mono_bc_noise
+ )
+ audio_source.open()
+ audio_source_data = audio_source.read(320 * 10) # read 0.2 sec. of data
+ audio_source.close()
+
+ assert (
+ reader_data == audio_source_data
+ ), "Unexpected data read from AudioReader with record = True"
+
+ def test_Recorder_record_and_rewind(self):
+ recorder = Recorder(input=self.audio_source, block_dur=0.02)
+ # read 0.02 * 10 = 0.2 sec. of data
+ for i in range(10):
+ recorder.read()
+
+ recorder.rewind()
+
+ # read all available data after rewind
+ recorder_data = _read_all_data(recorder)
+
+ audio_source = WaveAudioSource(
+ filename=dataset.one_to_six_arabic_16000_mono_bc_noise
+ )
+ audio_source.open()
+ audio_source_data = audio_source.read(320 * 10) # read 0.2 sec. of data
+ audio_source.close()
+
+ assert (
+ recorder_data == audio_source_data
+ ), "Unexpected data read from Recorder"
+
+ def test_read_overlapping_blocks(self):
+ # Use arbitrary valid block_size and hop_size
+ block_size = 1714
+ hop_size = 313
+ block_dur = block_size / self.audio_source.sampling_rate
+ hop_dur = hop_size / self.audio_source.sampling_rate
+
+ reader = AudioReader(
+ input=self.audio_source,
+ block_dur=block_dur,
+ hop_dur=hop_dur,
+ )
+
+ # Read all available overlapping blocks of data
+ reader_data = []
+ while True:
+ block = reader.read()
+ if block is None:
+ break
+ reader_data.append(block)
+
+ # Read all data from file and build a BufferAudioSource
+ fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r")
+ wave_data = fp.readframes(fp.getnframes())
+ fp.close()
+ audio_source = BufferAudioSource(
+ wave_data,
+ reader.sampling_rate,
+ reader.sample_width,
+ reader.channels,
+ )
+ audio_source.open()
+
+ # Compare all blocks read from the AudioReader to those read from an
+ # audio source with a manual position setting
+ for i, block in enumerate(reader_data):
+ tmp = audio_source.read(block_size)
+ assert (
+ block == tmp
+ ), f"Unexpected data (block {i}) from reader with overlapping blocks"
+ audio_source.position = (i + 1) * hop_size
+
+ audio_source.close()
+
+ def test_read_overlapping_blocks_with_max_read(self):
+ block_size = 256
+ hop_size = 200
+ block_dur = block_size / self.audio_source.sampling_rate
+ hop_dur = hop_size / self.audio_source.sampling_rate
+
+ reader = AudioReader(
+ input=self.audio_source,
+ block_dur=block_dur,
+ hop_dur=hop_dur,
+ max_read=0.5,
+ )
+
+ # Read all available overlapping blocks of data
+ reader_data = []
+ while True:
+ block = reader.read()
+ if block is None:
+ break
+ reader_data.append(block)
+
+ # Read all data from file and build a BufferAudioSource
+ fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r")
+ wave_data = fp.readframes(fp.getnframes())
+ fp.close()
+ audio_source = BufferAudioSource(
+ wave_data,
+ reader.sampling_rate,
+ reader.sample_width,
+ reader.channels,
+ )
+ audio_source.open()
+
+ # Compare all blocks read from the AudioReader to those read from an
+ # audio source with a manual position setting
+ for i, block in enumerate(reader_data):
+ tmp = audio_source.read(len(block) // (reader.sw * reader.ch))
+ assert (
+ block == tmp
+ ), f"Unexpected data (block {i}) from reader with overlapping blocks and max_read"
+ audio_source.position = (i + 1) * hop_size
+
+ audio_source.close()
+
+ def test_length_read_overlapping_blocks_with_max_read(self):
+ block_size = 313
+ hop_size = 207
+ block_dur = block_size / self.audio_source.sampling_rate
+ hop_dur = hop_size / self.audio_source.sampling_rate
+
+ reader = AudioReader(
+ input=self.audio_source,
+ max_read=1.932,
+ block_dur=block_dur,
+ hop_dur=hop_dur,
+ )
+
+ total_samples = round(reader.sampling_rate * 1.932)
+ first_read_size = block_size
+ next_read_size = block_size - hop_size
+ nb_next_blocks, last_block_size = divmod(
+ (total_samples - first_read_size), next_read_size
+ )
+ total_samples_with_overlap = (
+ first_read_size + next_read_size * nb_next_blocks + last_block_size
+ )
+ expected_read_bytes = (
+ total_samples_with_overlap * reader.sw * reader.channels
+ )
+
+ cache_size = (
+ (block_size - hop_size) * reader.sample_width * reader.channels
+ )
+ total_read = cache_size
+
+ i = 0
+ while True:
+ block = reader.read()
+ if block is None:
+ break
+ i += 1
+ total_read += len(block) - cache_size
+
+ err_msg = (
+ "Wrong data length read from AudioReader with max_read, expected: {0}, found: {1}"
+ )
+ assert total_read == expected_read_bytes, err_msg.format(
+ expected_read_bytes, total_read
+ )
+
+ def test_reader_with_overlapping_blocks__rewindable(self):
+ reader = AudioReader(
+ input=self.audio_source,
+ block_dur=320,
+ hop_dur=160,
+ record=True,
+ )
+ assert (
+ reader.rewindable
+ ), "AudioReader with record=True should be rewindable"
+
+ def test_overlapping_blocks_with_max_read_rewind_and_read(self):
+ # Use arbitrary valid block_size and hop_size
+ block_size = 1600
+ hop_size = 400
+ block_dur = block_size / self.audio_source.sampling_rate
+ hop_dur = hop_size / self.audio_source.sampling_rate
+
+ reader = AudioReader(
+ input=self.audio_source,
+ block_dur=block_dur,
+ hop_dur=hop_dur,
+ record=True,
+ )
+
+ # Read all available data in overlapping blocks
+ i = 0
+ while True:
+ block = reader.read()
+ if block is None:
+ break
+ i += 1
+
+ reader.rewind()
+
+ # Read all data from file and build a BufferAudioSource
+ fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r")
+ wave_data = fp.readframes(fp.getnframes())
+ fp.close()
+ audio_source = BufferAudioSource(
+ wave_data,
+ reader.sampling_rate,
+ reader.sample_width,
+ reader.channels,
+ )
+ audio_source.open()
+
+ # Compare blocks read from AudioReader to those read from a BufferAudioSource with manual position setting
+ for j in range(i):
+ tmp = audio_source.read(block_size)
+ assert (
+ reader.read() == tmp
+ ), f"Unexpected data (block {j}) from reader with overlapping blocks and record = True"
+ audio_source.position = (j + 1) * hop_size
+
+ audio_source.close()
+
+ def test_overlapping_blocks_with_record_and_max_read_rewind_and_read(self):
+ # Use arbitrary valid block_size and hop_size
+ block_size = 1600
+ hop_size = 400
+ block_dur = block_size / self.audio_source.sampling_rate
+ hop_dur = hop_size / self.audio_source.sampling_rate
+
+ reader = AudioReader(
+ input=self.audio_source,
+ max_time=1.50,
+ block_dur=block_dur,
+ hop_dur=hop_dur,
+ record=True,
+ )
+
+ # Read all available data in overlapping blocks
+ i = 0
+ while True:
+ block = reader.read()
+ if block is None:
+ break
+ i += 1
+
+ reader.rewind()
+
+ # Read all data from file and build a BufferAudioSource
+ fp = wave.open(dataset.one_to_six_arabic_16000_mono_bc_noise, "r")
+ wave_data = fp.readframes(fp.getnframes())
+ fp.close()
+ audio_source = BufferAudioSource(
+ wave_data,
+ reader.sampling_rate,
+ reader.sample_width,
+ reader.channels,
+ )
+ audio_source.open()
+
+ # Compare all blocks read from AudioReader to those read from BufferAudioSource with a manual position setting
+ for j in range(i):
+ tmp = audio_source.read(block_size)
+ assert (
+ reader.read() == tmp
+ ), "Unexpected block (N={0}) read from AudioReader".format(j)
+ audio_source.position = (j + 1) * hop_size
+
+ audio_source.close()
+
+ def test_length_read_overlapping_blocks_with_record_and_max_read(self):
+ # Use arbitrary valid block_size and hop_size
+ block_size = 1000
+ hop_size = 200
+ block_dur = block_size / self.audio_source.sampling_rate
+ hop_dur = hop_size / self.audio_source.sampling_rate
+
+ reader = AudioReader(
+ input=self.audio_source,
+ block_dur=block_dur,
+ hop_dur=hop_dur,
+ record=True,
+ max_read=1.317,
+ )
+ total_samples = round(reader.sampling_rate * 1.317)
+ first_read_size = block_size
+ next_read_size = block_size - hop_size
+ nb_next_blocks, last_block_size = divmod(
+ (total_samples - first_read_size), next_read_size
+ )
+ total_samples_with_overlap = (
+ first_read_size + next_read_size * nb_next_blocks + last_block_size
+ )
+ expected_read_bytes = (
+ total_samples_with_overlap * reader.sample_width * reader.channels
+ )
+
+ cache_size = (
+ (block_size - hop_size) * reader.sample_width * reader.channels
+ )
+ total_read = cache_size
+
+ i = 0
+ while True:
+ block = reader.read()
+ if block is None:
+ break
+ i += 1
+ total_read += len(block) - cache_size
+
+ err_msg = f"Wrong data length read from AudioReader, expected: {expected_read_bytes}, found: {total_read}"
+ assert total_read == expected_read_bytes, err_msg
+
+
+def test_AudioReader_raw_data():
+
+ data = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
+ block_size = 5
+ hop_size = 4
+ reader = AudioReader(
+ input=data,
+ sampling_rate=16,
+ sample_width=2,
+ channels=1,
+ block_dur=block_size / 16,
+ hop_dur=hop_size / 16,
+ max_read=0.80,
+ record=True,
+ )
+ reader.open()
+
+ assert (
+ reader.sampling_rate == 16
+ ), f"Wrong sampling rate, expected: 16, found: {reader.sampling_rate}"
+
+ assert (
+ reader.sample_width == 2
+ ), f"Wrong sample width, expected: 2, found: {reader.sample_width}"
+
+ # Read all available data in overlapping blocks
+ i = 0
+ while True:
+ block = reader.read()
+ if block is None:
+ break
+ i += 1
+
+ reader.rewind()
+
+ # Build a BufferAudioSource
+ audio_source = BufferAudioSource(
+ data, reader.sampling_rate, reader.sample_width, reader.channels
+ )
+ audio_source.open()
+
+ # Compare all blocks read from AudioReader to those read from an audio
+ # source with a manual position setting
+ for j in range(i):
+ tmp = audio_source.read(block_size)
+ block = reader.read()
+ assert (
+ block == tmp
+ ), f"Unexpected block '{block}' (N={j}) read from AudioReader"
+ audio_source.position = (j + 1) * hop_size
+ audio_source.close()
+ reader.close()
+
+
+def test_AudioReader_alias_params():
+ reader = AudioReader(
+ input=b"0" * 1600,
+ sr=16000,
+ sw=2,
+ channels=1,
+ )
+ assert 
reader.sampling_rate == 16000, ( + "Unexpected sampling rate: reader.sampling_rate = " + + f"{reader.sampling_rate} instead of 16000" + ) + assert reader.sr == 16000, ( + "Unexpected sampling rate: reader.sr = " + + f"{reader.sr} instead of 16000" + ) + assert reader.sample_width == 2, ( + "Unexpected sample width: reader.sample_width = " + + f"{reader.sample_width} instead of 2" + ) + assert reader.sw == 2, ( + "Unexpected sample width: reader.sw = " + f"{reader.sw} instead of 2" + ) + assert reader.channels == 1, ( + "Unexpected number of channels: reader.channels = " + + f"{reader.channels} instead of 1" + ) + assert reader.ch == 1, ( + "Unexpected number of channels: reader.ch = " + + f"{reader.ch} instead of 1" + ) + + +@pytest.mark.parametrize( + "file_id, max_read, size", + [ + ("mono_400", 0.5, 16000), # mono + ("3channel_400-800-1600", 0.5, 16000 * 3), # multichannel + ], + ids=["mono", "multichannel"], +) +def test_Limiter(file_id, max_read, size): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read(size) + + reader = AudioReader(input_wav, block_dur=0.1, max_read=max_read) + reader.open() + data = _read_all_data(reader) + reader.close() + assert data == expected + + +@pytest.mark.parametrize( + "file_id", + [ + "mono_400", # mono + "3channel_400-800-1600", # multichannel + ], + ids=["mono", "multichannel"], +) +def test_Recorder(file_id): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read() + + reader = AudioReader(input_wav, block_dur=0.1, record=True) + reader.open() + data = _read_all_data(reader) + assert data == expected + + # rewind many times + for _ in range(3): + reader.rewind() data = _read_all_data(reader) - reader.close() - self.assertEqual(data, expected) + assert data == expected + assert data == reader.data + reader.close() - @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) - def test_Recorder(self, file_id): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read() - reader = AudioReader(input_wav, block_dur=0.1, record=True) - reader.open() +@pytest.mark.parametrize( + "file_id", + [ + "mono_400", # mono + "3channel_400-800-1600", # multichannel + ], + ids=["mono", "multichannel"], +) +def test_Recorder_alias(file_id): + input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) + input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + with open(input_raw, "rb") as fp: + expected = fp.read() + + reader = Recorder(input_wav, block_dur=0.1) + reader.open() + data = _read_all_data(reader) + assert data == expected + + # rewind many times + for _ in range(3): + reader.rewind() data = _read_all_data(reader) - self.assertEqual(data, expected) - - # rewind many times - for _ in range(3): - reader.rewind() - data = _read_all_data(reader) - self.assertEqual(data, expected) - self.assertEqual(data, reader.data) - reader.close() - - @genty_dataset(mono=("mono_400",), multichannel=("3channel_400-800-1600",)) - def test_Recorder_alias(self, file_id): - input_wav = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - input_raw = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - with open(input_raw, "rb") as fp: - expected = fp.read() - - reader = 
Recorder(input_wav, block_dur=0.1) - reader.open() - data = _read_all_data(reader) - self.assertEqual(data, expected) - - # rewind many times - for _ in range(3): - reader.rewind() - data = _read_all_data(reader) - self.assertEqual(data, expected) - self.assertEqual(data, reader.data) - reader.close() - - -if __name__ == "__main__": - unittest.main() + assert data == expected + assert data == reader.data + reader.close()
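Note on the byte-count arithmetic used by the overlapping-block tests above: the first block consumes block_size fresh samples, every subsequent block adds only block_size - hop_size new ones, and the final block may be shorter. A minimal sketch of that computation, assuming the same parameters as the tests (the helper name below is illustrative and not part of auditok or of this changeset):

def expected_overlap_bytes(total_samples, block_size, hop_size, sample_width, channels):
    # First block: all of its samples are new.
    first_read = block_size
    # Each later block overlaps the previous one by hop_size samples,
    # so only block_size - hop_size of its samples are new data.
    next_read = block_size - hop_size
    nb_next_blocks, last_block = divmod(total_samples - first_read, next_read)
    total = first_read + next_read * nb_next_blocks + last_block
    return total * sample_width * channels

# With the parameters of test_length_read_overlapping_blocks_with_max_read
# (max_read=1.932 s at 16 kHz, 16-bit mono, block_size=313, hop_size=207):
print(expected_overlap_bytes(round(16000 * 1.932), 313, 207, 2, 1))  # 61824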
--- a/tests/test_AudioSource.py	Thu Mar 30 10:17:57 2023 +0100
+++ b/tests/test_AudioSource.py	Wed Oct 30 17:17:59 2024 +0000
@@ -1,17 +1,80 @@
"""
@author: Amine Sehili <amine.sehili@gmail.com>
"""
+
from array import array
-import unittest
-from genty import genty, genty_dataset
+
+import numpy as np
+import pytest
+
from auditok.io import (
+ AudioIOError,
AudioParameterError,
BufferAudioSource,
RawAudioSource,
WaveAudioSource,
)
-from auditok.signal import FORMAT
-from test_util import PURE_TONE_DICT, _sample_generator
+from auditok.signal import SAMPLE_WIDTH_TO_DTYPE
+
+
+def _sample_generator(*data_buffers):
+ """
+ Takes a list of many mono audio data buffers and makes a sample generator
+ of interleaved audio samples, one sample from each channel. The resulting
+ generator can be used to build a multichannel audio buffer.
+ >>> gen = _sample_generator("abcd", "ABCD")
+ >>> list(gen)
+ ['a', 'A', 'b', 'B', 'c', 'C', 'd', 'D']
+ """
+ frame_gen = zip(*data_buffers)
+ return (sample for frame in frame_gen for sample in frame)
+
+
+def _generate_pure_tone(
+ frequency, duration_sec=1, sampling_rate=16000, sample_width=2, volume=1e4
+):
+ """
+ Generates a pure tone with the given frequency.
+ """
+ assert frequency <= sampling_rate / 2
+ max_value = (2 ** (sample_width * 8) // 2) - 1
+ if volume > max_value:
+ volume = max_value
+ dtype = SAMPLE_WIDTH_TO_DTYPE[sample_width]
+ total_samples = int(sampling_rate * duration_sec)
+ step = frequency / sampling_rate
+ two_pi_step = 2 * np.pi * step
+ data = np.array(
+ [int(np.sin(two_pi_step * i) * volume) for i in range(total_samples)]
+ ).astype(dtype)
+ return data
+
+
+@pytest.fixture
+def pure_tone_data(request):
+ # The frequency must come in through indirect parametrization:
+ # a bare `freq` argument would be looked up as a (nonexistent) fixture.
+ PURE_TONE_DICT = {
+ freq: _generate_pure_tone(freq, 1, 16000, 2)
+ for freq in (400, 800, 1600)
+ }
+ PURE_TONE_DICT.update(
+ {
+ freq: _generate_pure_tone(freq, 0.1, 16000, 2)
+ for freq in (600, 1150, 2400, 7220)
+ }
+ )
+ return PURE_TONE_DICT[request.param]
+
+
+PURE_TONE_DICT = {
+ freq: _generate_pure_tone(freq, 1, 16000, 2) for freq in (400, 800, 1600)
+}
+PURE_TONE_DICT.update(
+ {
+ freq: _generate_pure_tone(freq, 0.1, 16000, 2)
+ for freq in (600, 1150, 2400, 7220)
+ }
+)


def audio_source_read_all_gen(audio_source, size=None):
@@ -24,202 +87,166 @@
yield data


-@genty
-class TestAudioSource(unittest.TestCase):
+@pytest.mark.parametrize(
+ "file_suffix, frequencies",
+ [
+ ("mono_400Hz", (400,)), # mono
+ ("3channel_400-800-1600Hz", (400, 800, 1600)), # multichannel
+ ],
+ ids=["mono", "multichannel"],
+)
+def test_BufferAudioSource_read_all(file_suffix, frequencies):
+ file = "tests/data/test_16KHZ_{}.raw".format(file_suffix)
+ with open(file, "rb") as fp:
+ expected = fp.read()
+ channels = len(frequencies)
+ audio_source = BufferAudioSource(expected, 16000, 2, channels)
+ audio_source.open()
+ data = audio_source.read(None)
+ assert data == expected
+ audio_source.rewind()
+ data = audio_source.read(-10)
+ assert data == expected
+ audio_source.close()

- # TODO when use_channel is None, return samples from all channels
- @genty_dataset(
- mono=("mono_400Hz", (400,)),
- multichannel=("3channel_400-800-1600Hz", (400, 800, 1600)),
- )
- def test_BufferAudioSource_read_all(self, file_suffix, frequencies):
- file = "tests/data/test_16KHZ_{}.raw".format(file_suffix)
- with open(file, "rb") as fp:
- expected = fp.read()
- channels = len(frequencies)
- audio_source = BufferAudioSource(expected, 16000, 2, channels)
- audio_source.open()
- data = audio_source.read(None)
- self.assertEqual(data, expected)
- audio_source.rewind()
- 
data = audio_source.read(-10) - self.assertEqual(data, expected) - audio_source.close() +@pytest.mark.parametrize( + "file_suffix, frequencies", + [ + ("mono_400Hz", (400,)), # mono + ("3channel_400-800-1600Hz", (400, 800, 1600)), # multichannel + ], + ids=["mono", "multichannel"], +) +def test_RawAudioSource(file_suffix, frequencies): + file = "tests/data/test_16KHZ_{}.raw".format(file_suffix) + channels = len(frequencies) + audio_source = RawAudioSource(file, 16000, 2, channels) + audio_source.open() + data_read_all = b"".join(audio_source_read_all_gen(audio_source)) + audio_source.close() + mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] + dtype = SAMPLE_WIDTH_TO_DTYPE[audio_source.sample_width] + expected = np.fromiter(_sample_generator(*mono_channels), dtype).tobytes() - @genty_dataset( - mono=("mono_400Hz", (400,)), - multichannel=("3channel_400-800-1600Hz", (400, 800, 1600)), - ) - def test_RawAudioSource(self, file_suffix, frequencies): - file = "tests/data/test_16KHZ_{}.raw".format(file_suffix) - channels = len(frequencies) - audio_source = RawAudioSource(file, 16000, 2, channels) - audio_source.open() - data_read_all = b"".join(audio_source_read_all_gen(audio_source)) - audio_source.close() - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - fmt = FORMAT[audio_source.sample_width] - expected = array(fmt, _sample_generator(*mono_channels)).tobytes() + assert data_read_all == expected - self.assertEqual(data_read_all, expected) + # assert read all data with None + audio_source = RawAudioSource(file, 16000, 2, channels) + audio_source.open() + data_read_all = audio_source.read(None) + audio_source.close() + assert data_read_all == expected - # assert read all data with None - audio_source = RawAudioSource(file, 16000, 2, channels) - audio_source.open() - data_read_all = audio_source.read(None) - audio_source.close() - self.assertEqual(data_read_all, expected) + # assert read all data with a negative size + audio_source = RawAudioSource(file, 16000, 2, channels) + audio_source.open() + data_read_all = audio_source.read(-10) + audio_source.close() + assert data_read_all == expected - # assert read all data with a negative size - audio_source = RawAudioSource(file, 16000, 2, channels) - audio_source.open() - data_read_all = audio_source.read(-10) - audio_source.close() - self.assertEqual(data_read_all, expected) - @genty_dataset( - mono=("mono_400Hz", (400,)), - multichannel=("3channel_400-800-1600Hz", (400, 800, 1600)), - ) - def test_WaveAudioSource(self, file_suffix, frequencies): - file = "tests/data/test_16KHZ_{}.wav".format(file_suffix) - audio_source = WaveAudioSource(file) - audio_source.open() - data = b"".join(audio_source_read_all_gen(audio_source)) - audio_source.close() - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - fmt = FORMAT[audio_source.sample_width] - expected = array(fmt, _sample_generator(*mono_channels)).tobytes() +@pytest.mark.parametrize( + "file_suffix, frequencies", + [ + ("mono_400Hz", (400,)), # mono + ("3channel_400-800-1600Hz", (400, 800, 1600)), # multichannel + ], + ids=["mono", "multichannel"], +) +def test_WaveAudioSource(file_suffix, frequencies): + file = "tests/data/test_16KHZ_{}.wav".format(file_suffix) + audio_source = WaveAudioSource(file) + audio_source.open() + data = b"".join(audio_source_read_all_gen(audio_source)) + audio_source.close() + mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] + dtype = SAMPLE_WIDTH_TO_DTYPE[audio_source.sample_width] + expected = 
np.fromiter(_sample_generator(*mono_channels), dtype).tobytes() - self.assertEqual(data, expected) + assert data == expected - # assert read all data with None - audio_source = WaveAudioSource(file) - audio_source.open() - data_read_all = audio_source.read(None) - audio_source.close() - self.assertEqual(data_read_all, expected) + # assert read all data with None + audio_source = WaveAudioSource(file) + audio_source.open() + data_read_all = audio_source.read(None) + audio_source.close() + assert data_read_all == expected - # assert read all data with a negative size - audio_source = WaveAudioSource(file) - audio_source.open() - data_read_all = audio_source.read(-10) - audio_source.close() - self.assertEqual(data_read_all, expected) + # assert read all data with a negative size + audio_source = WaveAudioSource(file) + audio_source.open() + data_read_all = audio_source.read(-10) + audio_source.close() + assert data_read_all == expected -@genty -class TestBufferAudioSource_SR10_SW1_CH1(unittest.TestCase): - def setUp(self): +class TestBufferAudioSource_SR10_SW1_CH1: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): self.data = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" self.audio_source = BufferAudioSource( data=self.data, sampling_rate=10, sample_width=1, channels=1 ) self.audio_source.open() - - def tearDown(self): + yield self.audio_source.close() def test_sr10_sw1_ch1_read_1(self): block = self.audio_source.read(1) exp = b"A" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr10_sw1_ch1_read_6(self): block = self.audio_source.read(6) exp = b"ABCDEF" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr10_sw1_ch1_read_multiple(self): block = self.audio_source.read(1) exp = b"A" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(6) exp = b"BCDEFG" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(13) exp = b"HIJKLMNOPQRST" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(9999) exp = b"UVWXYZ012345" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr10_sw1_ch1_read_all(self): block = self.audio_source.read(9999) - self.assertEqual( - block, - self.data, - msg="wrong block, expected: {}, found: {} ".format( - self.data, block - ), - ) + assert block == self.data block = self.audio_source.read(1) - self.assertEqual( - block, - None, - msg="wrong block, expected: {}, found: {} ".format(None, block), - ) + assert block is None def test_sr10_sw1_ch1_sampling_rate(self): srate = self.audio_source.sampling_rate - self.assertEqual( - srate, - 10, - msg="wrong sampling rate, expected: 10, found: {0} ".format(srate), - ) + assert srate == 10 def test_sr10_sw1_ch1_sample_width(self): swidth = self.audio_source.sample_width - self.assertEqual( - swidth, - 1, - msg="wrong sample width, expected: 1, found: {0} ".format(swidth), - ) + assert swidth == 1 def test_sr10_sw1_ch1_channels(self): channels = self.audio_source.channels - self.assertEqual( - channels, - 1, - msg="wrong number of channels, expected: 1, found: {0} 
".format( - channels - ), - ) + assert channels == 1 - @genty_dataset( - empty=([], 0, 0, 0), - zero=([0], 0, 0, 0), - five=([5], 5, 0.5, 500), - multiple=([5, 20], 25, 2.5, 2500), + @pytest.mark.parametrize( + "block_sizes, expected_sample, expected_second, expected_ms", + [ + ([], 0, 0, 0), # empty + ([0], 0, 0, 0), # zero + ([5], 5, 0.5, 500), # five + ([5, 20], 25, 2.5, 2500), # multiple + ], + ids=["empty", "zero", "five", "multiple"], ) def test_position( self, block_sizes, expected_sample, expected_second, expected_ms @@ -227,38 +254,24 @@ for block_size in block_sizes: self.audio_source.read(block_size) position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset( - zero=(0, 0, 0, 0), - one=(1, 1, 0.1, 100), - ten=(10, 10, 1, 1000), - negative_1=(-1, 31, 3.1, 3100), - negative_2=(-7, 25, 2.5, 2500), + @pytest.mark.parametrize( + "position, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (1, 1, 0.1, 100), # one + (10, 10, 1, 1000), # ten + (-1, 31, 3.1, 3100), # negative_1 + (-7, 25, 2.5, 2500), # negative_2 + ], + ids=["zero", "one", "ten", "negative_1", "negative_2"], ) def test_position_setter( self, position, expected_sample, expected_second, expected_ms @@ -266,38 +279,24 @@ self.audio_source.position = position position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset( - zero=(0, 0, 0, 0), - one=(0.1, 1, 0.1, 100), - ten=(1, 10, 1, 1000), - negative_1=(-0.1, 31, 3.1, 3100), - negative_2=(-0.7, 25, 2.5, 2500), + @pytest.mark.parametrize( + "position_s, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (0.1, 1, 0.1, 100), # one + (1, 10, 1, 1000), # ten + (-0.1, 31, 3.1, 3100), # negative_1 + (-0.7, 25, 2.5, 2500), # negative_2 + ], + ids=["zero", "one", "ten", "negative_1", "negative_2"], ) def test_position_s_setter( self, position_s, expected_sample, expected_second, expected_ms @@ -305,38 +304,24 @@ self.audio_source.position_s = position_s position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - 
expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset( - zero=(0, 0, 0, 0), - one=(100, 1, 0.1, 100), - ten=(1000, 10, 1, 1000), - negative_1=(-100, 31, 3.1, 3100), - negative_2=(-700, 25, 2.5, 2500), + @pytest.mark.parametrize( + "position_ms, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (100, 1, 0.1, 100), # one + (1000, 10, 1, 1000), # ten + (-100, 31, 3.1, 3100), # negative_1 + (-700, 25, 2.5, 2500), # negative_2 + ], + ids=["zero", "one", "ten", "negative_1", "negative_2"], ) def test_position_ms_setter( self, position_ms, expected_sample, expected_second, expected_ms @@ -344,222 +329,157 @@ self.audio_source.position_ms = position_ms position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset(positive=((100,)), negative=(-100,)) + @pytest.mark.parametrize( + "position", + [ + 100, # positive + -100, # negative + ], + ids=["positive", "negative"], + ) def test_position_setter_out_of_range(self, position): - with self.assertRaises(IndexError): + with pytest.raises(IndexError): self.audio_source.position = position - @genty_dataset(positive=((100,)), negative=(-100,)) + @pytest.mark.parametrize( + "position_s", + [ + 100, # positive + -100, # negative + ], + ids=["positive", "negative"], + ) def test_position_s_setter_out_of_range(self, position_s): - with self.assertRaises(IndexError): + with pytest.raises(IndexError): self.audio_source.position_s = position_s - @genty_dataset(positive=((10000,)), negative=(-10000,)) + @pytest.mark.parametrize( + "position_ms", + [ + 10000, # positive + -10000, # negative + ], + ids=["positive", "negative"], + ) def test_position_ms_setter_out_of_range(self, position_ms): - with self.assertRaises(IndexError): + with pytest.raises(IndexError): self.audio_source.position_ms = position_ms def test_sr10_sw1_ch1_initial_position_s_0(self): tp = self.audio_source.position_s - self.assertEqual( - tp, - 0.0, - msg="wrong time position, expected: 0.0, found: {0} ".format(tp), - ) + assert tp == 0.0 def test_sr10_sw1_ch1_position_s_1_after_read(self): srate = self.audio_source.sampling_rate # read one second self.audio_source.read(srate) tp = self.audio_source.position_s - self.assertEqual( - tp, - 1.0, - msg="wrong time position, expected: 1.0, found: {0} ".format(tp), - ) + assert tp == 1.0 def test_sr10_sw1_ch1_position_s_2_5(self): # read 2.5 seconds self.audio_source.read(25) tp = self.audio_source.position_s - self.assertEqual( - tp, - 2.5, - msg="wrong time position, expected: 2.5, found: {0} ".format(tp), - ) + assert tp == 2.5 def 
test_sr10_sw1_ch1_position_s_0(self): self.audio_source.read(10) self.audio_source.position_s = 0 tp = self.audio_source.position_s - self.assertEqual( - tp, - 0.0, - msg="wrong time position, expected: 0.0, found: {0} ".format(tp), - ) + assert tp == 0.0 def test_sr10_sw1_ch1_position_s_1(self): self.audio_source.position_s = 1 tp = self.audio_source.position_s - self.assertEqual( - tp, - 1.0, - msg="wrong time position, expected: 1.0, found: {0} ".format(tp), - ) + assert tp == 1.0 def test_sr10_sw1_ch1_rewind(self): self.audio_source.read(10) self.audio_source.rewind() tp = self.audio_source.position - self.assertEqual( - tp, 0, msg="wrong position, expected: 0.0, found: {0} ".format(tp) - ) + assert tp == 0 def test_sr10_sw1_ch1_read_closed(self): self.audio_source.close() - with self.assertRaises(Exception): + with pytest.raises(AudioIOError): self.audio_source.read(1) -@genty -class TestBufferAudioSource_SR16_SW2_CH1(unittest.TestCase): - def setUp(self): +class TestBufferAudioSource_SR16_SW2_CH1: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): self.data = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" self.audio_source = BufferAudioSource( data=self.data, sampling_rate=16, sample_width=2, channels=1 ) self.audio_source.open() - - def tearDown(self): + yield self.audio_source.close() def test_sr16_sw2_ch1_read_1(self): block = self.audio_source.read(1) exp = b"AB" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr16_sw2_ch1_read_6(self): block = self.audio_source.read(6) exp = b"ABCDEFGHIJKL" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr16_sw2_ch1_read_multiple(self): block = self.audio_source.read(1) exp = b"AB" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(6) exp = b"CDEFGHIJKLMN" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(5) exp = b"OPQRSTUVWX" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(9999) exp = b"YZ012345" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr16_sw2_ch1_read_all(self): block = self.audio_source.read(9999) - self.assertEqual( - block, - self.data, - msg="wrong block, expected: {0}, found: {1} ".format( - self.data, block - ), - ) + assert block == self.data block = self.audio_source.read(1) - self.assertEqual( - block, - None, - msg="wrong block, expected: {0}, found: {1} ".format(None, block), - ) + assert block is None def test_sr16_sw2_ch1_sampling_rate(self): srate = self.audio_source.sampling_rate - self.assertEqual( - srate, - 16, - msg="wrong sampling rate, expected: 10, found: {0} ".format(srate), - ) + assert srate == 16 def test_sr16_sw2_ch1_sample_width(self): swidth = self.audio_source.sample_width - self.assertEqual( - swidth, - 2, - msg="wrong sample width, expected: 1, found: {0} ".format(swidth), - ) + assert swidth == 2 def test_sr16_sw2_ch1_channels(self): + channels = self.audio_source.channels + assert channels == 1 - channels = self.audio_source.channels - self.assertEqual( - channels, - 1, - msg="wrong number of channels, expected: 
1, found: {0} ".format( - channels - ), - ) - - @genty_dataset( - empty=([], 0, 0, 0), - zero=([0], 0, 0, 0), - two=([2], 2, 2 / 16, int(2000 / 16)), - eleven=([11], 11, 11 / 16, int(11 * 1000 / 16)), - multiple=([4, 8], 12, 0.75, 750), + @pytest.mark.parametrize( + "block_sizes, expected_sample, expected_second, expected_ms", + [ + ([], 0, 0, 0), # empty + ([0], 0, 0, 0), # zero + ([2], 2, 2 / 16, int(2000 / 16)), # two + ([11], 11, 11 / 16, int(11 * 1000 / 16)), # eleven + ([4, 8], 12, 0.75, 750), # multiple + ], + ids=["empty", "zero", "two", "eleven", "multiple"], ) def test_position( self, block_sizes, expected_sample, expected_second, expected_ms @@ -567,46 +487,30 @@ for block_size in block_sizes: self.audio_source.read(block_size) position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms def test_sr16_sw2_ch1_read_position_0(self): self.audio_source.read(10) self.audio_source.position = 0 pos = self.audio_source.position - self.assertEqual( - pos, 0, msg="wrong position, expected: 0, found: {0} ".format(pos) - ) + assert pos == 0 - @genty_dataset( - zero=(0, 0, 0, 0), - one=(1, 1, 1 / 16, int(1000 / 16)), - ten=(10, 10, 10 / 16, int(10000 / 16)), - negative_1=(-1, 15, 15 / 16, int(15000 / 16)), - negative_2=(-7, 9, 9 / 16, int(9000 / 16)), + @pytest.mark.parametrize( + "position, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (1, 1, 1 / 16, int(1000 / 16)), # one + (10, 10, 10 / 16, int(10000 / 16)), # ten + (-1, 15, 15 / 16, int(15000 / 16)), # negative_1 + (-7, 9, 9 / 16, int(9000 / 16)), # negative_2 + ], + ids=["zero", "one", "ten", "negative_1", "negative_2"], ) def test_position_setter( self, position, expected_sample, expected_second, expected_ms @@ -614,39 +518,25 @@ self.audio_source.position = position position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset( - zero=(0, 0, 0, 0), - one=(0.1, 1, 1 / 16, int(1000 / 16)), - two=(1 / 8, 2, 1 / 8, int(1 / 8 * 1000)), - twelve=(0.75, 12, 0.75, 750), - negative_1=(-0.1, 15, 15 / 16, int(15000 / 16)), - negative_2=(-0.7, 5, 5 / 16, int(5000 / 16)), + @pytest.mark.parametrize( + "position_s, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (0.1, 1, 1 / 16, int(1000 / 16)), # one + (1 / 8, 2, 1 / 8, int(1 / 
8 * 1000)), # two + (0.75, 12, 0.75, 750), # twelve + (-0.1, 15, 15 / 16, int(15000 / 16)), # negative_1 + (-0.7, 5, 5 / 16, int(5000 / 16)), # negative_2 + ], + ids=["zero", "one", "two", "twelve", "negative_1", "negative_2"], ) def test_position_s_setter( self, position_s, expected_sample, expected_second, expected_ms @@ -654,39 +544,25 @@ self.audio_source.position_s = position_s position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms - @genty_dataset( - zero=(0, 0, 0, 0), - one=(100, 1, 1 / 16, int(1000 / 16)), - ten=(1000, 16, 1, 1000), - negative_1=(-100, 15, 15 / 16, int(15 * 1000 / 16)), - negative_2=(-500, 8, 0.5, 500), - negative_3=(-700, 5, 5 / 16, int(5 * 1000 / 16)), + @pytest.mark.parametrize( + "position_ms, expected_sample, expected_second, expected_ms", + [ + (0, 0, 0, 0), # zero + (100, 1, 1 / 16, int(1000 / 16)), # one + (1000, 16, 1, 1000), # ten + (-100, 15, 15 / 16, int(15 * 1000 / 16)), # negative_1 + (-500, 8, 0.5, 500), # negative_2 + (-700, 5, 5 / 16, int(5 * 1000 / 16)), # negative_3 + ], + ids=["zero", "one", "ten", "negative_1", "negative_2", "negative_3"], ) def test_position_ms_setter( self, position_ms, expected_sample, expected_second, expected_ms @@ -694,266 +570,161 @@ self.audio_source.position_ms = position_ms position = self.audio_source.position - self.assertEqual( - position, - expected_sample, - msg="wrong stream position, expected: {}, found: {}".format( - expected_sample, position - ), - ) + assert position == expected_sample position_s = self.audio_source.position_s - self.assertEqual( - position_s, - expected_second, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_second, position_s - ), - ) + assert position_s == expected_second position_ms = self.audio_source.position_ms - self.assertEqual( - position_ms, - expected_ms, - msg="wrong stream position_s, expected: {}, found: {}".format( - expected_ms, position_ms - ), - ) + assert position_ms == expected_ms def test_sr16_sw2_ch1_rewind(self): self.audio_source.read(10) self.audio_source.rewind() tp = self.audio_source.position - self.assertEqual( - tp, 0, msg="wrong position, expected: 0.0, found: {0} ".format(tp) - ) + assert tp == 0 -class TestBufferAudioSource_SR11_SW4_CH1(unittest.TestCase): - def setUp(self): +class TestBufferAudioSource_SR11_SW4_CH1: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): self.data = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789abcdefgh" self.audio_source = BufferAudioSource( data=self.data, sampling_rate=11, sample_width=4, channels=1 ) self.audio_source.open() - - def tearDown(self): + yield self.audio_source.close() def test_sr11_sw4_ch1_read_1(self): block = self.audio_source.read(1) exp = b"ABCD" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr11_sw4_ch1_read_6(self): block = self.audio_source.read(6) exp = 
b"ABCDEFGHIJKLMNOPQRSTUVWX" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr11_sw4_ch1_read_multiple(self): block = self.audio_source.read(1) exp = b"ABCD" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(6) exp = b"EFGHIJKLMNOPQRSTUVWXYZ01" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(3) exp = b"23456789abcd" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp block = self.audio_source.read(9999) exp = b"efgh" - self.assertEqual( - block, - exp, - msg="wrong block, expected: {}, found: {} ".format(exp, block), - ) + assert block == exp def test_sr11_sw4_ch1_read_all(self): block = self.audio_source.read(9999) - self.assertEqual( - block, - self.data, - msg="wrong block, expected: {0}, found: {1} ".format( - self.data, block - ), - ) + assert block == self.data block = self.audio_source.read(1) - self.assertEqual( - block, - None, - msg="wrong block, expected: {0}, found: {1} ".format(None, block), - ) + assert block is None def test_sr11_sw4_ch1_sampling_rate(self): srate = self.audio_source.sampling_rate - self.assertEqual( - srate, - 11, - msg="wrong sampling rate, expected: 10, found: {0} ".format(srate), - ) + assert srate == 11 def test_sr11_sw4_ch1_sample_width(self): swidth = self.audio_source.sample_width - self.assertEqual( - swidth, - 4, - msg="wrong sample width, expected: 1, found: {0} ".format(swidth), - ) + assert swidth == 4 def test_sr11_sw4_ch1_channels(self): channels = self.audio_source.channels - self.assertEqual( - channels, - 1, - msg="wrong number of channels, expected: 1, found: {0} ".format( - channels - ), - ) + assert channels == 1 def test_sr11_sw4_ch1_intial_position_0(self): pos = self.audio_source.position - self.assertEqual( - pos, 0, msg="wrong position, expected: 0, found: {0} ".format(pos) - ) + assert pos == 0 def test_sr11_sw4_ch1_position_5(self): self.audio_source.read(5) pos = self.audio_source.position - self.assertEqual( - pos, 5, msg="wrong position, expected: 5, found: {0} ".format(pos) - ) + assert pos == 5 def test_sr11_sw4_ch1_position_9(self): self.audio_source.read(5) self.audio_source.read(4) pos = self.audio_source.position - self.assertEqual( - pos, 9, msg="wrong position, expected: 5, found: {0} ".format(pos) - ) + assert pos == 9 def test_sr11_sw4_ch1_position_0(self): self.audio_source.read(10) self.audio_source.position = 0 pos = self.audio_source.position - self.assertEqual( - pos, 0, msg="wrong position, expected: 0, found: {0} ".format(pos) - ) + assert pos == 0 def test_sr11_sw4_ch1_position_10(self): self.audio_source.position = 10 pos = self.audio_source.position - self.assertEqual( - pos, - 10, - msg="wrong position, expected: 10, found: {0} ".format(pos), - ) + assert pos == 10 def test_sr11_sw4_ch1_initial_position_s_0(self): tp = self.audio_source.position_s - self.assertEqual( - tp, - 0.0, - msg="wrong time position, expected: 0.0, found: {0} ".format(tp), - ) + assert tp == 0.0 def test_sr11_sw4_ch1_position_s_1_after_read(self): srate = self.audio_source.sampling_rate # read one second self.audio_source.read(srate) tp = self.audio_source.position_s - self.assertEqual( - tp, - 1.0, - msg="wrong time position, expected: 1.0, found: {0} 
".format(tp), - ) + assert tp == 1.0 def test_sr11_sw4_ch1_position_s_0_63(self): # read 2.5 seconds self.audio_source.read(7) tp = self.audio_source.position_s - self.assertAlmostEqual( - tp, - 0.636363636364, - msg="wrong time position, expected: 0.636363636364, " - "found: {0} ".format(tp), - ) + assert tp, pytest.approx(0.636363636364) def test_sr11_sw4_ch1_position_s_0(self): self.audio_source.read(10) self.audio_source.position_s = 0 tp = self.audio_source.position_s - self.assertEqual( - tp, - 0.0, - msg="wrong time position, expected: 0.0, found: {0} ".format(tp), - ) + assert tp == 0.0 def test_sr11_sw4_ch1_position_s_1(self): self.audio_source.position_s = 1 tp = self.audio_source.position_s - self.assertEqual( - tp, - 1.0, - msg="wrong time position, expected: 1.0, found: {0} ".format(tp), - ) + assert tp == 1.0 def test_sr11_sw4_ch1_rewind(self): self.audio_source.read(10) self.audio_source.rewind() tp = self.audio_source.position - self.assertEqual( - tp, 0, msg="wrong position, expected: 0.0, found: {0} ".format(tp) - ) + assert tp == 0 -class TestBufferAudioSourceCreationException(unittest.TestCase): +class TestBufferAudioSourceCreationException: def test_wrong_sample_width_value(self): - with self.assertRaises(AudioParameterError) as audio_param_err: + with pytest.raises(AudioParameterError) as audio_param_err: _ = BufferAudioSource( data=b"ABCDEFGHI", sampling_rate=9, sample_width=3, channels=1 ) - self.assertEqual( - "Sample width must be one of: 1, 2 or 4 (bytes)", - str(audio_param_err.exception), + assert ( + str(audio_param_err.value) + == "Sample width must be one of: 1, 2 or 4 (bytes)" ) def test_wrong_data_buffer_size(self): - with self.assertRaises(AudioParameterError) as audio_param_err: + with pytest.raises(AudioParameterError) as audio_param_err: _ = BufferAudioSource( data=b"ABCDEFGHI", sampling_rate=8, sample_width=2, channels=1 ) - self.assertEqual( - "The length of audio data must be an integer " - "multiple of `sample_width * channels`", - str(audio_param_err.exception), - ) + msg = "The length of audio data must be an integer multiple of " + msg += "`sample_width * channels`" + assert str(audio_param_err.value) == msg -class TestAudioSourceProperties(unittest.TestCase): +class TestAudioSourceProperties: def test_read_properties(self): data = b"" sampling_rate = 8000 @@ -963,9 +734,9 @@ data, sampling_rate, sample_width, channels ) - self.assertEqual(a_source.sampling_rate, sampling_rate) - self.assertEqual(a_source.sample_width, sample_width) - self.assertEqual(a_source.channels, channels) + assert a_source.sampling_rate == sampling_rate + assert a_source.sample_width == sample_width + assert a_source.channels == channels def test_set_readonly_properties_exception(self): data = b"" @@ -976,13 +747,17 @@ data, sampling_rate, sample_width, channels ) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): a_source.sampling_rate = 16000 + + with pytest.raises(AttributeError): a_source.sample_width = 1 + + with pytest.raises(AttributeError): a_source.channels = 2 -class TestAudioSourceShortProperties(unittest.TestCase): +class TestAudioSourceShortProperties: def test_read_short_properties(self): data = b"" sampling_rate = 8000 @@ -992,9 +767,9 @@ data, sampling_rate, sample_width, channels ) - self.assertEqual(a_source.sr, sampling_rate) - self.assertEqual(a_source.sw, sample_width) - self.assertEqual(a_source.ch, channels) + assert a_source.sr == sampling_rate + assert a_source.sw == sample_width + assert a_source.ch == channels def 
test_set_readonly_short_properties_exception(self): data = b"" @@ -1005,11 +780,11 @@ data, sampling_rate, sample_width, channels ) - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): a_source.sr = 16000 + + with pytest.raises(AttributeError): a_source.sw = 1 + + with pytest.raises(AttributeError): a_source.ch = 2 - - -if __name__ == "__main__": - unittest.main()
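Note on the position semantics exercised by the parametrized tests above: position is counted in samples, position_s in seconds, position_ms in milliseconds, and negative values index back from the end of the buffer, like Python sequence indices. A small sketch, using only the BufferAudioSource API shown in this diff and the same 32-byte buffer as the tests:

from auditok.io import BufferAudioSource

# 32 one-byte mono samples at 10 Hz, i.e. 3.2 seconds of audio
src = BufferAudioSource(
    data=b"ABCDEFGHIJKLMNOPQRSTUVWXYZ012345",
    sampling_rate=10,
    sample_width=1,
    channels=1,
)
src.open()
src.position = -1  # counts back from the end, like a list index
assert src.position == 31
assert src.position_s == 3.1
assert src.position_ms == 3100
src.close()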
--- a/tests/test_StreamTokenizer.py Thu Mar 30 10:17:57 2023 +0100 +++ b/tests/test_StreamTokenizer.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,11 +1,8 @@ -""" -@author: Amine Sehili <amine.sehili@gmail.com> -September 2015 +import os -""" +import pytest -import unittest -from auditok import StreamTokenizer, StringDataSource, DataValidator +from auditok import DataValidator, StreamTokenizer, StringDataSource class AValidator(DataValidator): @@ -13,1017 +10,672 @@ return frame == "A" -class TestStreamTokenizerInitParams(unittest.TestCase): - def setUp(self): - self.A_validator = AValidator() +@pytest.fixture +def validator(): + return AValidator() - # Completely deactivate init_min and init_max_silence - # The tokenizer will only rely on the other parameters - # Note that if init_min = 0, the value of init_max_silence - # will have no effect - def test_init_min_0_init_max_silence_0(self): - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=20, - max_continuous_silence=4, - init_min=0, - init_max_silence=0, - mode=0, - ) +def test_init_min_0_init_max_silence_0(validator): + tokenizer = StreamTokenizer( + validator, + min_length=5, + max_length=20, + max_continuous_silence=4, + init_min=0, + init_max_silence=0, + mode=0, + ) - data_source = StringDataSource("aAaaaAaAaaAaAaaaaaaaAAAAAAAA") - # ^ ^ ^ ^ - # 2 16 20 27 - tokens = tokenizer.tokenize(data_source) + data_source = StringDataSource("aAaaaAaAaaAaAaaaaaaaAAAAAAAA") + # ^ ^ ^ ^ + # 2 16 20 27 + tokens = tokenizer.tokenize(data_source) - self.assertEqual( - len(tokens), - 2, - msg="wrong number of tokens, expected: 2, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2 = tokens[0], tokens[1] + assert ( + len(tokens) == 2 + ), "wrong number of tokens, expected: 2, found: {}".format(len(tokens)) + tok1, tok2 = tokens[0], tokens[1] - # tok1[0]: data - # tok1[1]: start frame (included) - # tok1[2]: end frame (included) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AaaaAaAaaAaAaaaa" + ), "wrong data for token 1, expected: 'AaaaAaAaaAaAaaaa', found: {}".format( + data + ) + assert ( + start == 1 + ), "wrong start frame for token 1, expected: 1, found: {}".format(start) + assert ( + end == 16 + ), "wrong end frame for token 1, expected: 16, found: {}".format(end) - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AaaaAaAaaAaAaaaa", - msg=( - "wrong data for token 1, expected: 'AaaaAaAaaAaAaaaa', " - "found: {0} " - ).format(data), - ) - self.assertEqual( - start, - 1, - msg=( - "wrong start frame for token 1, expected: 1, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 16, - msg=( - "wrong end frame for token 1, expected: 16, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAAAAA" + ), "wrong data for token 2, expected: 'AAAAAAAA', found: {}".format(data) + assert ( + start == 20 + ), "wrong start frame for token 2, expected: 20, found: {}".format(start) + assert ( + end == 27 + ), "wrong end frame for token 2, expected: 27, found: {}".format(end) - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAAAA', found: {0} " - ).format(data), - ) - self.assertEqual( - start, - 20, - msg=( - "wrong start frame for token 2, expected: 20, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 27, - msg=( - "wrong end frame for token 2, expected: 27, found: 
{0} " - ).format(end), - ) - # A valid token is considered as so iff the tokenizer encounters - # at least valid frames (init_min = 3) between witch there - # are at most 0 consecutive non valid frames (init_max_silence = 0) - # The tokenizer will only rely on the other parameters - # In other words, a valid token must start with 3 valid frames - def test_init_min_3_init_max_silence_0(self): +def test_init_min_3_init_max_silence_0(validator): + tokenizer = StreamTokenizer( + validator, + min_length=5, + max_length=20, + max_continuous_silence=4, + init_min=3, + init_max_silence=0, + mode=0, + ) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=20, - max_continuous_silence=4, - init_min=3, - init_max_silence=0, - mode=0, - ) + data_source = StringDataSource("aAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaaAAAAA") + # ^ ^ ^ ^ + # 18 30 33 37 - data_source = StringDataSource( - "aAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaaAAAAA" - ) - # ^ ^ ^ ^ - # 18 30 33 37 + tokens = tokenizer.tokenize(data_source) - tokens = tokenizer.tokenize(data_source) + assert ( + len(tokens) == 2 + ), "wrong number of tokens, expected: 2, found: {}".format(len(tokens)) + tok1, tok2 = tokens[0], tokens[1] - self.assertEqual( - len(tokens), - 2, - msg="wrong number of tokens, expected: 2, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2 = tokens[0], tokens[1] + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAAAAAAaaaa" + ), "wrong data for token 1, expected: 'AAAAAAAAAaaaa', found: {}".format( + data + ) + assert ( + start == 18 + ), "wrong start frame for token 1, expected: 18, found: {}".format(start) + assert ( + end == 30 + ), "wrong end frame for token 1, expected: 30, found: {}".format(end) - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAAAAAAaaaa", - msg=( - "wrong data for token 1, expected: 'AAAAAAAAAaaaa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 18, - msg=( - "wrong start frame for token 1, expected: 18, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 30, - msg=( - "wrong end frame for token 1, expected: 30, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAA" + ), "wrong data for token 2, expected: 'AAAAA', found: {}".format(data) + assert ( + start == 33 + ), "wrong start frame for token 2, expected: 33, found: {}".format(start) + assert ( + end == 37 + ), "wrong end frame for token 2, expected: 37, found: {}".format(end) - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 33, - msg=( - "wrong start frame for token 2, expected: 33, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 37, - msg=( - "wrong end frame for token 2, expected: 37, found: {0} " - ).format(end), - ) - # A valid token is considered iff the tokenizer encounters - # at least valid frames (init_min = 3) between witch there - # are at most 2 consecutive non valid frames (init_max_silence = 2) - def test_init_min_3_init_max_silence_2(self): +def test_init_min_3_init_max_silence_2(validator): + tokenizer = StreamTokenizer( + validator, + min_length=5, + max_length=20, + max_continuous_silence=4, + init_min=3, + init_max_silence=2, + mode=0, + ) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=20, - 
max_continuous_silence=4, - init_min=3, - init_max_silence=2, - mode=0, - ) + data_source = StringDataSource("aAaaaAaAaaAaAaaaaaaAAAAAAAAAaaaaaaaAAAAA") + # ^ ^ ^ ^ ^ ^ + # 5 16 19 31 35 39 + tokens = tokenizer.tokenize(data_source) - data_source = StringDataSource( - "aAaaaAaAaaAaAaaaaaaAAAAAAAAAaaaaaaaAAAAA" - ) - # ^ ^ ^ ^ ^ ^ - # 5 16 19 31 35 39 - tokens = tokenizer.tokenize(data_source) + assert ( + len(tokens) == 3 + ), "wrong number of tokens, expected: 3, found: {}".format(len(tokens)) + tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] - self.assertEqual( - len(tokens), - 3, - msg="wrong number of tokens, expected: 3, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AaAaaAaAaaaa" + ), "wrong data for token 1, expected: 'AaAaaAaA', found: {}".format(data) + assert ( + start == 5 + ), "wrong start frame for token 1, expected: 5, found: {}".format(start) + assert ( + end == 16 + ), "wrong end frame for token 1, expected: 16, found: {}".format(end) - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AaAaaAaAaaaa", - msg=( - "wrong data for token 1, expected: 'AaAaaAaA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 5, - msg=( - "wrong start frame for token 1, expected: 5, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 16, - msg=( - "wrong end frame for token 1, expected: 16, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAAAAAAaaaa" + ), "wrong data for token 2, expected: 'AAAAAAAAAaaaa', found: {}".format( + data + ) + assert ( + start == 19 + ), "wrong start frame for token 2, expected: 19, found: {}".format(start) + assert ( + end == 31 + ), "wrong end frame for token 2, expected: 31, found: {}".format(end) - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAAAAAAaaaa", - msg=( - "wrong data for token 2, expected: 'AAAAAAAAAaaaa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 19, - msg=( - "wrong start frame for token 2, expected: 19, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 31, - msg=( - "wrong end frame for token 2, expected: 31, found: {0} " - ).format(end), - ) + data = "".join(tok3[0]) + start = tok3[1] + end = tok3[2] + assert ( + data == "AAAAA" + ), "wrong data for token 3, expected: 'AAAAA', found: {}".format(data) + assert ( + start == 35 + ), "wrong start frame for token 3, expected: 35, found: {}".format(start) + assert ( + end == 39 + ), "wrong end frame for token 3, expected: 39, found: {}".format(end) - data = "".join(tok3[0]) - start = tok3[1] - end = tok3[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 3, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 35, - msg=( - "wrong start frame for token 2, expected: 35, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 39, - msg=( - "wrong end frame for token 2, expected: 39, found: {0} " - ).format(end), - ) +@pytest.fixture +def tokenizer_min_max_length(validator): + return StreamTokenizer( + validator, + min_length=6, + max_length=20, + max_continuous_silence=2, + init_min=3, + init_max_silence=3, + mode=0, + ) -class TestStreamTokenizerMinMaxLength(unittest.TestCase): - def setUp(self): - self.A_validator = AValidator() - def test_min_length_6_init_max_length_20(self): 
+def test_min_length_6_init_max_length_20(tokenizer_min_max_length): + data_source = StringDataSource("aAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaAAAAA") + # ^ ^ ^ ^ + # 1 14 18 28 - tokenizer = StreamTokenizer( - self.A_validator, - min_length=6, - max_length=20, - max_continuous_silence=2, - init_min=3, - init_max_silence=3, - mode=0, - ) + tokens = tokenizer_min_max_length.tokenize(data_source) - data_source = StringDataSource("aAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaAAAAA") - # ^ ^ ^ ^ - # 1 14 18 28 + assert ( + len(tokens) == 2 + ), "wrong number of tokens, expected: 2, found: {}".format(len(tokens)) + tok1, tok2 = tokens[0], tokens[1] - tokens = tokenizer.tokenize(data_source) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AaaaAaAaaAaAaa" + ), "wrong data for token 1, expected: 'AaaaAaAaaAaAaa', found: {}".format( + data + ) + assert ( + start == 1 + ), "wrong start frame for token 1, expected: 1, found: {}".format(start) + assert ( + end == 14 + ), "wrong end frame for token 1, expected: 14, found: {}".format(end) - self.assertEqual( - len(tokens), - 2, - msg="wrong number of tokens, expected: 2, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2 = tokens[0], tokens[1] + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAAAAAAaa" + ), "wrong data for token 2, expected: 'AAAAAAAAAaa', found: {}".format(data) + assert ( + start == 18 + ), "wrong start frame for token 2, expected: 18, found: {}".format(start) + assert ( + end == 28 + ), "wrong end frame for token 2, expected: 28, found: {}".format(end) - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AaaaAaAaaAaAaa", - msg=( - "wrong data for token 1, expected: 'AaaaAaAaaAaAaa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 1, - msg=( - "wrong start frame for token 1, expected: 1, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 14, - msg=( - "wrong end frame for token 1, expected: 14, found: {0} " - ).format(end), - ) - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAAAAAAaa", - msg=( - "wrong data for token 2, expected: 'AAAAAAAAAaa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 18, - msg=( - "wrong start frame for token 2, expected: 18, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 28, - msg=( - "wrong end frame for token 2, expected: 28, found: {0} " - ).format(end), - ) +@pytest.fixture +def tokenizer_min_max_length_1_1(validator): + return StreamTokenizer( + validator, + min_length=1, + max_length=1, + max_continuous_silence=0, + init_min=0, + init_max_silence=0, + mode=0, + ) - def test_min_length_1_init_max_length_1(self): - tokenizer = StreamTokenizer( - self.A_validator, - min_length=1, - max_length=1, - max_continuous_silence=0, - init_min=0, - init_max_silence=0, - mode=0, - ) +def test_min_length_1_init_max_length_1(tokenizer_min_max_length_1_1): + data_source = StringDataSource("AAaaaAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaAAAAA") - data_source = StringDataSource( - "AAaaaAaaaAaAaaAaAaaaaaAAAAAAAAAaaaaaAAAAA" - ) + tokens = tokenizer_min_max_length_1_1.tokenize(data_source) - tokens = tokenizer.tokenize(data_source) + assert ( + len(tokens) == 21 + ), "wrong number of tokens, expected: 21, found: {}".format(len(tokens)) - self.assertEqual( - len(tokens), - 21, - msg="wrong number of tokens, expected: 21, found: {0} ".format( - len(tokens) - ), - ) - def test_min_length_10_init_max_length_20(self): 
+@pytest.fixture +def tokenizer_min_max_length_10_20(validator): + return StreamTokenizer( + validator, + min_length=10, + max_length=20, + max_continuous_silence=4, + init_min=3, + init_max_silence=3, + mode=0, + ) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=10, - max_length=20, - max_continuous_silence=4, - init_min=3, - init_max_silence=3, - mode=0, - ) - data_source = StringDataSource( - "aAaaaAaAaaAaAaaaaaaAAAAAaaaaaaAAAAAaaAAaaAAA" - ) - # ^ ^ ^ ^ - # 1 16 30 45 +def test_min_length_10_init_max_length_20(tokenizer_min_max_length_10_20): + data_source = StringDataSource( + "aAaaaAaAaaAaAaaaaaaAAAAAaaaaaaAAAAAaaAAaaAAA" + ) + # ^ ^ ^ ^ + # 1 16 30 45 - tokens = tokenizer.tokenize(data_source) + tokens = tokenizer_min_max_length_10_20.tokenize(data_source) - self.assertEqual( - len(tokens), - 2, - msg="wrong number of tokens, expected: 2, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2 = tokens[0], tokens[1] + assert ( + len(tokens) == 2 + ), "wrong number of tokens, expected: 2, found: {}".format(len(tokens)) + tok1, tok2 = tokens[0], tokens[1] - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AaaaAaAaaAaAaaaa", - msg=( - "wrong data for token 1, expected: 'AaaaAaAaaAaAaaaa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 1, - msg=( - "wrong start frame for token 1, expected: 1, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 16, - msg=( - "wrong end frame for token 1, expected: 16, found: {0} " - ).format(end), - ) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AaaaAaAaaAaAaaaa" + ), "wrong data for token 1, expected: 'AaaaAaAaaAaAaaaa', found: {}".format( + data + ) + assert ( + start == 1 + ), "wrong start frame for token 1, expected: 1, found: {}".format(start) + assert ( + end == 16 + ), "wrong end frame for token 1, expected: 16, found: {}".format(end) - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAAaaAAaaAAA", - msg=( - "wrong data for token 2, expected: 'AAAAAaaAAaaAAA', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 30, - msg=( - "wrong start frame for token 2, expected: 30, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 43, - msg=( - "wrong end frame for token 2, expected: 43, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAAaaAAaaAAA" + ), "wrong data for token 2, expected: 'AAAAAaaAAaaAAA', found: {}".format( + data + ) + assert ( + start == 30 + ), "wrong start frame for token 2, expected: 30, found: {}".format(start) + assert ( + end == 43 + ), "wrong end frame for token 2, expected: 43, found: {}".format(end) - def test_min_length_4_init_max_length_5(self): - tokenizer = StreamTokenizer( - self.A_validator, - min_length=4, - max_length=5, - max_continuous_silence=4, - init_min=3, - init_max_silence=3, - mode=0, - ) +@pytest.fixture +def tokenizer_min_max_length_4_5(validator): + return StreamTokenizer( + validator, + min_length=4, + max_length=5, + max_continuous_silence=4, + init_min=3, + init_max_silence=3, + mode=0, + ) - data_source = StringDataSource( - "aAaaaAaAaaAaAaaaaaAAAAAAAAaaaaaaAAAAAaaaaaAAaaAaa" - ) - # ^ ^^ ^ ^ ^ ^ ^ - # 18 2223 27 32 36 42 46 - tokens = tokenizer.tokenize(data_source) +def test_min_length_4_init_max_length_5(tokenizer_min_max_length_4_5): + data_source = StringDataSource( + "aAaaaAaAaaAaAaaaaaAAAAAAAAaaaaaaAAAAAaaaaaAAaaAaa" + ) + # 
^ ^^ ^ ^ ^ ^ ^ + # 18 2223 27 32 36 42 46 - self.assertEqual( - len(tokens), - 4, - msg="wrong number of tokens, expected: 4, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2, tok3, tok4 = tokens[0], tokens[1], tokens[2], tokens[3] + tokens = tokenizer_min_max_length_4_5.tokenize(data_source) - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 18, - msg=( - "wrong start frame for token 1, expected: 18, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 22, - msg=( - "wrong end frame for token 1, expected: 22, found: {0} " - ).format(end), - ) + assert ( + len(tokens) == 4 + ), "wrong number of tokens, expected: 4, found: {}".format(len(tokens)) + tok1, tok2, tok3, tok4 = tokens[0], tokens[1], tokens[2], tokens[3] - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAaa", - msg=( - "wrong data for token 1, expected: 'AAAaa', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 23, - msg=( - "wrong start frame for token 1, expected: 23, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 27, - msg=( - "wrong end frame for token 1, expected: 27, found: {0} " - ).format(end), - ) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAA" + ), "wrong data for token 1, expected: 'AAAAA', found: {}".format(data) + assert ( + start == 18 + ), "wrong start frame for token 1, expected: 18, found: {}".format(start) + assert ( + end == 22 + ), "wrong end frame for token 1, expected: 22, found: {}".format(end) - data = "".join(tok3[0]) - start = tok3[1] - end = tok3[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 32, - msg=( - "wrong start frame for token 1, expected: 1, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 36, - msg=( - "wrong end frame for token 1, expected: 7, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAaa" + ), "wrong data for token 2, expected: 'AAAaa', found: {}".format(data) + assert ( + start == 23 + ), "wrong start frame for token 2, expected: 23, found: {}".format(start) + assert ( + end == 27 + ), "wrong end frame for token 2, expected: 27, found: {}".format(end) - data = "".join(tok4[0]) - start = tok4[1] - end = tok4[2] - self.assertEqual( - data, - "AAaaA", - msg=( - "wrong data for token 2, expected: 'AAaaA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 42, - msg=( - "wrong start frame for token 2, expected: 17, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 46, - msg=( - "wrong end frame for token 2, expected: 22, found: {0} " - ).format(end), - ) + data = "".join(tok3[0]) + start = tok3[1] + end = tok3[2] + assert ( + data == "AAAAA" + ), "wrong data for token 3, expected: 'AAAAA', found: {}".format(data) + assert ( + start == 32 + ), "wrong start frame for token 3, expected: 32, found: {}".format(start) + assert ( + end == 36 + ), "wrong end frame for token 3, expected: 36, found: {}".format(end) + data = "".join(tok4[0]) + start = tok4[1] + end = tok4[2] + assert ( + data == "AAaaA" + ), "wrong data for token 4, expected: 'AAaaA', found: {}".format(data) + assert ( + start == 42 + ), "wrong start frame for token 4, expected: 42, found: 
{}".format(start) + assert ( + end == 46 + ), "wrong end frame for token 4, expected: 46, found: {}".format(end) -class TestStreamTokenizerMaxContinuousSilence(unittest.TestCase): - def setUp(self): - self.A_validator = AValidator() - def test_min_5_max_10_max_continuous_silence_0(self): +@pytest.fixture +def tokenizer_max_continuous_silence_0(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=10, + max_continuous_silence=0, + init_min=3, + init_max_silence=3, + mode=0, + ) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=10, - max_continuous_silence=0, - init_min=3, - init_max_silence=3, - mode=0, - ) - data_source = StringDataSource("aaaAAAAAaAAAAAAaaAAAAAAAAAa") - # ^ ^ ^ ^ ^ ^ - # 3 7 9 14 17 25 +def test_min_5_max_10_max_continuous_silence_0( + tokenizer_max_continuous_silence_0, +): + data_source = StringDataSource("aaaAAAAAaAAAAAAaaAAAAAAAAAa") + # ^ ^ ^ ^ ^ ^ + # 3 7 9 14 17 25 - tokens = tokenizer.tokenize(data_source) + tokens = tokenizer_max_continuous_silence_0.tokenize(data_source) - self.assertEqual( - len(tokens), - 3, - msg="wrong number of tokens, expected: 3, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] + assert ( + len(tokens) == 3 + ), "wrong number of tokens, expected: 3, found: {}".format(len(tokens)) + tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 3, - msg=( - "wrong start frame for token 1, expected: 3, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 7, - msg=( - "wrong end frame for token 1, expected: 7, found: {0} " - ).format(end), - ) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAA" + ), "wrong data for token 1, expected: 'AAAAA', found: {}".format(data) + assert ( + start == 3 + ), "wrong start frame for token 1, expected: 3, found: {}".format(start) + assert ( + end == 7 + ), "wrong end frame for token 1, expected: 7, found: {}".format(end) - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 9, - msg=( - "wrong start frame for token 1, expected: 9, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 14, - msg=( - "wrong end frame for token 1, expected: 14, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAAAAA" + ), "wrong data for token 2, expected: 'AAAAAA', found: {}".format(data) + assert ( + start == 9 + ), "wrong start frame for token 2, expected: 9, found: {}".format(start) + assert ( + end == 14 + ), "wrong end frame for token 2, expected: 14, found: {}".format(end) - data = "".join(tok3[0]) - start = tok3[1] - end = tok3[2] - self.assertEqual( - data, - "AAAAAAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 17, - msg=( - "wrong start frame for token 1, expected: 17, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 25, - msg=( - "wrong end frame for token 1, expected: 25, found: {0} " - ).format(end), - ) + data = "".join(tok3[0]) + start = tok3[1] + end = tok3[2] + assert ( + data == "AAAAAAAAA" + ), "wrong data for 
token 3, expected: 'AAAAAAAAA', found: {}".format(data) + assert ( + start == 17 + ), "wrong start frame for token 3, expected: 17, found: {}".format(start) + assert ( + end == 25 + ), "wrong end frame for token 3, expected: 25, found: {}".format(end) - def test_min_5_max_10_max_continuous_silence_1(self): - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=10, - max_continuous_silence=1, - init_min=3, - init_max_silence=3, - mode=0, - ) +@pytest.fixture +def tokenizer_max_continuous_silence_1(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=10, + max_continuous_silence=1, + init_min=3, + init_max_silence=3, + mode=0, + ) - data_source = StringDataSource("aaaAAAAAaAAAAAAaaAAAAAAAAAa") - # ^ ^^ ^ ^ ^ - # 3 12131517 26 - # (12 13 15 17) - tokens = tokenizer.tokenize(data_source) +def test_min_5_max_10_max_continuous_silence_1( + tokenizer_max_continuous_silence_1, +): + data_source = StringDataSource("aaaAAAAAaAAAAAAaaAAAAAAAAAa") + # ^ ^^ ^ ^ ^ + # 3 12131517 26 + # (12 13 15 17) - self.assertEqual( - len(tokens), - 3, - msg="wrong number of tokens, expected: 3, found: {0} ".format( - len(tokens) - ), - ) - tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] + tokens = tokenizer_max_continuous_silence_1.tokenize(data_source) - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAAaAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAaAAAA', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 3, - msg=( - "wrong start frame for token 1, expected: 3, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 12, - msg=( - "wrong end frame for token 1, expected: 10, found: {0} " - ).format(end), - ) + assert ( + len(tokens) == 3 + ), "wrong number of tokens, expected: 3, found: {}".format(len(tokens)) + tok1, tok2, tok3 = tokens[0], tokens[1], tokens[2] - data = "".join(tok2[0]) - start = tok2[1] - end = tok2[2] - self.assertEqual( - data, - "AAa", - msg=( - "wrong data for token 1, expected: 'AAa', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 13, - msg=( - "wrong start frame for token 1, expected: 9, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 15, - msg=( - "wrong end frame for token 1, expected: 14, found: {0} " - ).format(end), - ) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAAaAAAA" + ), "wrong data for token 1, expected: 'AAAAAaAAAA', found: {}".format(data) + assert ( + start == 3 + ), "wrong start frame for token 1, expected: 3, found: {}".format(start) + assert ( + end == 12 + ), "wrong end frame for token 1, expected: 12, found: {}".format(end) - data = "".join(tok3[0]) - start = tok3[1] - end = tok3[2] - self.assertEqual( - data, - "AAAAAAAAAa", - msg=( - "wrong data for token 1, expected: 'AAAAAAAAAa', " - "found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 17, - msg=( - "wrong start frame for token 1, expected: 17, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 26, - msg=( - "wrong end frame for token 1, expected: 26, found: {0} " - ).format(end), - ) + data = "".join(tok2[0]) + start = tok2[1] + end = tok2[2] + assert ( + data == "AAa" + ), "wrong data for token 2, expected: 'AAa', found: {}".format(data) + assert ( + start == 13 + ), "wrong start frame for token 2, expected: 13, found: {}".format(start) + assert ( + end == 15 + ), "wrong end frame for token 2, expected: 15, found: {}".format(end) + data = "".join(tok3[0]) + start = 
tok3[1] + end = tok3[2] + assert ( + data == "AAAAAAAAAa" + ), "wrong data for token 3, expected: 'AAAAAAAAAa', found: {}".format(data) + assert ( + start == 17 + ), "wrong start frame for token 3, expected: 17, found: {}".format(start) + assert ( + end == 26 + ), "wrong end frame for token 3, expected: 26, found: {}".format(end) -class TestStreamTokenizerModes(unittest.TestCase): - def setUp(self): - self.A_validator = AValidator() - def test_STRICT_MIN_LENGTH(self): +@pytest.fixture +def tokenizer_strict_min_length(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=8, + max_continuous_silence=3, + init_min=3, + init_max_silence=3, + mode=StreamTokenizer.STRICT_MIN_LENGTH, + ) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=8, - max_continuous_silence=3, - init_min=3, - init_max_silence=3, - mode=StreamTokenizer.STRICT_MIN_LENGTH, - ) - data_source = StringDataSource("aaAAAAAAAAAAAA") - # ^ ^ - # 2 9 +def test_STRICT_MIN_LENGTH(tokenizer_strict_min_length): + data_source = StringDataSource("aaAAAAAAAAAAAA") + # ^ ^ + # 2 9 - tokens = tokenizer.tokenize(data_source) + tokens = tokenizer_strict_min_length.tokenize(data_source) - self.assertEqual( - len(tokens), - 1, - msg="wrong number of tokens, expected: 1, found: {0} ".format( - len(tokens) - ), - ) - tok1 = tokens[0] + assert ( + len(tokens) == 1 + ), "wrong number of tokens, expected: 1, found: {}".format(len(tokens)) + tok1 = tokens[0] - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 2, - msg=( - "wrong start frame for token 1, expected: 2, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 9, - msg=( - "wrong end frame for token 1, expected: 9, found: {0} " - ).format(end), - ) + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAAAAA" + ), "wrong data for token 1, expected: 'AAAAAAAA', found: {}".format(data) + assert ( + start == 2 + ), "wrong start frame for token 1, expected: 2, found: {}".format(start) + assert ( + end == 9 + ), "wrong end frame for token 1, expected: 9, found: {}".format(end) - def test_DROP_TAILING_SILENCE(self): - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=10, - max_continuous_silence=2, - init_min=3, - init_max_silence=3, - mode=StreamTokenizer.DROP_TRAILING_SILENCE, - ) +@pytest.fixture +def tokenizer_drop_trailing_silence(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=10, + max_continuous_silence=2, + init_min=3, + init_max_silence=3, + mode=StreamTokenizer.DROP_TRAILING_SILENCE, + ) - data_source = StringDataSource("aaAAAAAaaaaa") - # ^ ^ - # 2 6 - tokens = tokenizer.tokenize(data_source) +def test_DROP_TRAILING_SILENCE(tokenizer_drop_trailing_silence): + data_source = StringDataSource("aaAAAAAaaaaa") + # ^ ^ + # 2 6 + tokens = tokenizer_drop_trailing_silence.tokenize(data_source) - self.assertEqual( - len(tokens), - 1, - msg="wrong number of tokens, expected: 1, found: {0} ".format( - len(tokens) - ), - ) - tok1 = tokens[0] + assert ( + len(tokens) == 1 + ), "wrong number of tokens, expected: 1, found: {}".format(len(tokens)) + tok1 = tokens[0] - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 2, - msg=( - "wrong start frame for token 1, expected: 2, found: {0} " - ).format(start), - ) - self.assertEqual( - 
end, - 6, - msg=( - "wrong end frame for token 1, expected: 6, found: {0} " - ).format(end), - ) + assert ( + len(tokens) == 1 + ), "wrong number of tokens, expected: 1, found: {}".format(len(tokens)) + tok1 = tokens[0] - def test_STRICT_MIN_LENGTH_and_DROP_TAILING_SILENCE(self): + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAA" + ), "wrong data for token 1, expected: 'AAAAA', found: {}".format(data) + assert ( + start == 2 + ), "wrong start frame for token 1, expected: 2, found: {}".format(start) + assert ( + end == 6 + ), "wrong end frame for token 1, expected: 6, found: {}".format(end) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=8, - max_continuous_silence=3, - init_min=3, - init_max_silence=3, - mode=StreamTokenizer.STRICT_MIN_LENGTH - | StreamTokenizer.DROP_TRAILING_SILENCE, - ) - data_source = StringDataSource("aaAAAAAAAAAAAAaa") - # ^ ^ - # 2 8 +@pytest.fixture +def tokenizer_strict_min_and_drop_trailing_silence(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=8, + max_continuous_silence=3, + init_min=3, + init_max_silence=3, + mode=StreamTokenizer.STRICT_MIN_LENGTH + | StreamTokenizer.DROP_TRAILING_SILENCE, + ) - tokens = tokenizer.tokenize(data_source) - self.assertEqual( - len(tokens), - 1, - msg="wrong number of tokens, expected: 1, found: {0} ".format( - len(tokens) - ), - ) - tok1 = tokens[0] +def test_STRICT_MIN_LENGTH_and_DROP_TRAILING_SILENCE( + tokenizer_strict_min_and_drop_trailing_silence, +): + data_source = StringDataSource("aaAAAAAAAAAAAAaa") + # ^ ^ + # 2 8 - data = "".join(tok1[0]) - start = tok1[1] - end = tok1[2] - self.assertEqual( - data, - "AAAAAAAA", - msg=( - "wrong data for token 1, expected: 'AAAAAAAA', found: '{0}' " - ).format(data), - ) - self.assertEqual( - start, - 2, - msg=( - "wrong start frame for token 1, expected: 2, found: {0} " - ).format(start), - ) - self.assertEqual( - end, - 9, - msg=( - "wrong end frame for token 1, expected: 9, found: {0} " - ).format(end), - ) + tokens = tokenizer_strict_min_and_drop_trailing_silence.tokenize( + data_source + ) + assert ( + len(tokens) == 1 + ), "wrong number of tokens, expected: 1, found: {}".format(len(tokens)) + tok1 = tokens[0] -class TestStreamTokenizerCallback(unittest.TestCase): - def setUp(self): - self.A_validator = AValidator() + data = "".join(tok1[0]) + start = tok1[1] + end = tok1[2] + assert ( + data == "AAAAAAAA" + ), "wrong data for token 1, expected: 'AAAAAAAA', found: {}".format(data) + assert ( + start == 2 + ), "wrong start frame for token 1, expected: 2, found: {}".format(start) + assert ( + end == 9 + ), "wrong end frame for token 1, expected: 9, found: {}".format(end) - def test_callback(self): - tokens = [] +@pytest.fixture +def tokenizer_callback(validator): + return StreamTokenizer( + validator, + min_length=5, + max_length=8, + max_continuous_silence=3, + init_min=3, + init_max_silence=3, + mode=0, + ) - def callback(data, start, end): - tokens.append((data, start, end)) - tokenizer = StreamTokenizer( - self.A_validator, - min_length=5, - max_length=8, - max_continuous_silence=3, - init_min=3, - init_max_silence=3, - mode=0, - ) +def test_callback(tokenizer_callback): + tokens = [] - data_source = StringDataSource("aaAAAAAAAAAAAAa") - # ^ ^^ ^ - # 2 910 14 + def callback(data, start, end): + tokens.append((data, start, end)) - tokenizer.tokenize(data_source, callback=callback) + data_source = StringDataSource("aaAAAAAAAAAAAAa") + # ^ ^^ ^ + # 2 910 14 - self.assertEqual( - 
len(tokens), - 2, - msg="wrong number of tokens, expected: 1, found: {0} ".format( - len(tokens) - ), - ) + tokenizer_callback.tokenize(data_source, callback=callback) - -if __name__ == "__main__": - unittest.main() + assert ( + len(tokens) == 2 + ), "wrong number of tokens, expected: 2, found: {}".format(len(tokens))
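The hunks above migrate the StreamTokenizer tests from unittest classes to pytest: each setUp becomes a fixture, and every self.assertEqual(..., msg=...) becomes a bare assert with a trailing message. Below is a minimal sketch of the resulting pattern; the AValidator helper and the import locations are assumptions based on the fixtures above, and the expected token reflects the behavior exercised in these tests, where mode=0 keeps the tolerated trailing silent frame inside the token.

import pytest

from auditok import StreamTokenizer
from auditok.util import DataValidator, StringDataSource


class AValidator(DataValidator):
    # Assumed test helper: a frame counts as signal if it is the letter "A".
    def is_valid(self, frame):
        return frame == "A"


@pytest.fixture
def validator():
    return AValidator()


@pytest.fixture
def tokenizer(validator):
    # Same constructor shape as the fixtures above, with smaller values.
    return StreamTokenizer(
        validator,
        min_length=5,
        max_length=10,
        max_continuous_silence=1,
        init_min=0,
        init_max_silence=0,
        mode=0,
    )


def test_single_token(tokenizer):
    tokens = tokenizer.tokenize(StringDataSource("aaAAAAAAAaa"))
    assert len(tokens) == 1
    data, start, end = "".join(tokens[0][0]), tokens[0][1], tokens[0][2]
    # Frames 2-8 are "A"; the one tolerated silent frame at index 9 is kept.
    assert (data, start, end) == ("AAAAAAAa", 2, 9)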
--- a/tests/test_cmdline_util.py Thu Mar 30 10:17:57 2023 +0100 +++ b/tests/test_cmdline_util.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,28 +1,29 @@ import os -import unittest -from unittest import TestCase +from collections import namedtuple +from tempfile import TemporaryDirectory from unittest.mock import patch -from tempfile import TemporaryDirectory -from collections import namedtuple -from genty import genty, genty_dataset + +import pytest from auditok.cmdline_util import ( _AUDITOK_LOGGER, + KeywordArguments, + initialize_workers, make_kwargs, make_logger, - initialize_workers, - KeywordArguments, ) +from auditok.exceptions import ArgumentError from auditok.workers import ( + AudioEventsJoinerWorker, + CommandLineWorker, + PlayerWorker, + PrintWorker, + RegionSaverWorker, StreamSaverWorker, - RegionSaverWorker, - PlayerWorker, - CommandLineWorker, - PrintWorker, ) -_ArgsNamespece = namedtuple( - "_ArgsNamespece", +_ArgsNamespace = namedtuple( + "_ArgsNamespace", [ "input", "max_read", @@ -38,6 +39,7 @@ "input_device_index", "save_stream", "save_detections_as", + "join_detections", "plot", "save_image", "min_duration", @@ -57,285 +59,376 @@ ) -@genty -class TestCmdLineUtil(TestCase): - @genty_dataset( - no_record=("stream.ogg", None, False, None, "mix", "mix", False), - no_record_plot=("stream.ogg", None, True, None, None, None, False), - no_record_save_image=( - "stream.ogg", - None, - True, - "image.png", - None, - None, - False, - ), - record_plot=(None, None, True, None, None, None, True), - record_save_image=(None, None, False, "image.png", None, None, True), - int_use_channel=("stream.ogg", None, False, None, "1", 1, False), - save_detections_as=( - "stream.ogg", - "{id}.wav", - False, - None, - None, - None, - False, - ), - ) - def test_make_kwargs( - self, +@pytest.mark.parametrize( + "save_stream, save_detections_as, join_detections, plot, save_image, use_channel, exp_use_channel, exp_record", # noqa: B950 + [ + # no_record_no_join + ("stream.ogg", None, None, False, None, "mix", "mix", False), + # no_record_plot_join + ("stream.ogg", None, 1.0, True, None, None, None, False), + # no_record_save_image + ("stream.ogg", None, None, True, "image.png", None, None, False), + # record_plot + (None, None, None, True, None, None, None, True), + # record_save_image + (None, None, None, False, "image.png", None, None, True), + # int_use_channel + ("stream.ogg", None, None, False, None, "1", 1, False), + # save_detections_as + ("stream.ogg", "{id}.wav", None, False, None, None, None, False), + ], + ids=[ + "no_record_no_join", + "no_record_plot", + "no_record_save_image", + "record_plot", + "record_save_image", + "int_use_channel", + "save_detections_as", + ], +) +def test_make_kwargs( + save_stream, + save_detections_as, + join_detections, + plot, + save_image, + use_channel, + exp_use_channel, + exp_record, +): + args = ( + "file", + 30, + 0.01, + 16000, + 2, + 2, + use_channel, + "raw", + "ogg", + True, + None, + 1, save_stream, save_detections_as, + join_detections, plot, save_image, - use_channel, - exp_use_channel, - exp_record, - ): + 0.2, + 10, + 0.3, + False, + False, + 55, + ) + misc = ( + False, + False, + None, + True, + None, + "TIME_FORMAT", + "TIMESTAMP_FORMAT", + ) + args_ns = _ArgsNamespace(*(args + misc)) - args = ( - "file", - 30, - 0.01, - 16000, - 2, - 2, - use_channel, - "raw", - "ogg", - True, - None, - 1, - save_stream, - save_detections_as, - plot, - save_image, - 0.2, - 10, - 0.3, - False, - False, - 55, - ) - misc = ( - False, - False, - None, - True, - None, - 
"TIME_FORMAT", - "TIMESTAMP_FORMAT", - ) - args_ns = _ArgsNamespece(*(args + misc)) + io_kwargs = { + "input": "file", + "max_read": 30, + "block_dur": 0.01, + "sampling_rate": 16000, + "sample_width": 2, + "channels": 2, + "use_channel": exp_use_channel, + "save_stream": save_stream, + "save_detections_as": save_detections_as, + "join_detections": join_detections, + "audio_format": "raw", + "export_format": "ogg", + "large_file": True, + "frames_per_buffer": None, + "input_device_index": 1, + "record": exp_record, + } - io_kwargs = { - "input": "file", - "max_read": 30, - "block_dur": 0.01, - "sampling_rate": 16000, - "sample_width": 2, - "channels": 2, - "use_channel": exp_use_channel, - "save_stream": save_stream, - "save_detections_as": save_detections_as, - "audio_format": "raw", - "export_format": "ogg", - "large_file": True, - "frames_per_buffer": None, - "input_device_index": 1, - "record": exp_record, - } + split_kwargs = { + "min_dur": 0.2, + "max_dur": 10, + "max_silence": 0.3, + "drop_trailing_silence": False, + "strict_min_dur": False, + "energy_threshold": 55, + } - split_kwargs = { - "min_dur": 0.2, - "max_dur": 10, - "max_silence": 0.3, - "drop_trailing_silence": False, - "strict_min_dur": False, - "energy_threshold": 55, - } + miscellaneous = { + "echo": False, + "command": None, + "progress_bar": False, + "quiet": True, + "printf": None, + "time_format": "TIME_FORMAT", + "timestamp_format": "TIMESTAMP_FORMAT", + } - miscellaneous = { - "echo": False, - "command": None, - "progress_bar": False, - "quiet": True, - "printf": None, - "time_format": "TIME_FORMAT", - "timestamp_format": "TIMESTAMP_FORMAT", - } + expected = KeywordArguments(io_kwargs, split_kwargs, miscellaneous) + kwargs = make_kwargs(args_ns) + assert kwargs == expected - expected = KeywordArguments(io_kwargs, split_kwargs, miscellaneous) - kwargs = make_kwargs(args_ns) - self.assertEqual(kwargs, expected) - def test_make_logger_stderr_and_file(self): +def test_make_kwargs_error(): + + args = ( + "file", + 30, + 0.01, + 16000, + 2, + 2, + 1, + "raw", + "ogg", + True, + None, + 1, + None, # save_stream + None, + 1.0, # join_detections + None, + None, + 0.2, + 10, + 0.3, + False, + False, + 55, + False, + False, + None, + True, + None, + "TIME_FORMAT", + "TIMESTAMP_FORMAT", + ) + + args_ns = _ArgsNamespace(*args) + expected_err_msg = "using --join-detections/-j requires " + expected_err_msg += "--save-stream/-O to be specified." 
+ with pytest.raises(ArgumentError) as arg_err: + make_kwargs(args_ns) + assert str(arg_err.value) == expected_err_msg + + +def test_make_logger_stderr_and_file(capsys): + with TemporaryDirectory() as tmpdir: + file = os.path.join(tmpdir, "file.log") + logger = make_logger(stderr=True, file=file) + assert logger.name == _AUDITOK_LOGGER + assert len(logger.handlers) == 2 + assert logger.handlers[1].stream.name == file + logger.info("This is a debug message") + captured = capsys.readouterr() + assert "This is a debug message" in captured.err + + +def test_make_logger_None(): + logger = make_logger(stderr=False, file=None) + assert logger is None + + +def test_initialize_workers_all_plus_full_stream_saver(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: with TemporaryDirectory() as tmpdir: - file = os.path.join(tmpdir, "file.log") - logger = make_logger(stderr=True, file=file) - self.assertEqual(logger.name, _AUDITOK_LOGGER) - self.assertEqual(len(logger.handlers), 2) - self.assertEqual(logger.handlers[0].stream.name, "<stderr>") - self.assertEqual(logger.handlers[1].stream.name, file) + export_filename = os.path.join(tmpdir, "output.wav") + reader, tokenizer_worker = initialize_workers( + input="tests/data/test_16KHZ_mono_400Hz.wav", + save_stream=export_filename, + export_format="wave", + save_detections_as="{id}.wav", + join_detections=None, + echo=True, + progress_bar=False, + command="some command", + quiet=False, + printf="abcd", + time_format="%S", + timestamp_format="%h:%M:%S", + ) + reader.stop() + assert patched_player_for.called + assert isinstance(reader, StreamSaverWorker) + for obs, cls in zip( + tokenizer_worker._observers, + [ + RegionSaverWorker, + PlayerWorker, + CommandLineWorker, + PrintWorker, + ], + ): + assert isinstance(obs, cls) - def test_make_logger_None(self): - logger = make_logger(stderr=False, file=None) - self.assertIsNone(logger) - def test_initialize_workers_all(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: - with TemporaryDirectory() as tmpdir: - export_filename = os.path.join(tmpdir, "output.wav") - reader, observers = initialize_workers( - input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=export_filename, - export_format="wave", - save_detections_as="{id}.wav", - echo=True, - progress_bar=False, - command="some command", - quiet=False, - printf="abcd", - time_format="%S", - timestamp_format="%h:%M:%S", - ) - reader.stop() - self.assertTrue(patched_player_for.called) - self.assertIsInstance(reader, StreamSaverWorker) - for obs, cls in zip( - observers, - [ - RegionSaverWorker, - PlayerWorker, - CommandLineWorker, - PrintWorker, - ], - ): - self.assertIsInstance(obs, cls) +def test_initialize_workers_all_plus_audio_event_joiner(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: + with TemporaryDirectory() as tmpdir: + export_filename = os.path.join(tmpdir, "output.wav") + reader, tokenizer_worker = initialize_workers( + input="tests/data/test_16KHZ_mono_400Hz.wav", + save_stream=export_filename, + export_format="wave", + save_detections_as="{id}.wav", + join_detections=1, + echo=True, + progress_bar=False, + command="some command", + quiet=False, + printf="abcd", + time_format="%S", + timestamp_format="%h:%M:%S", + ) + assert patched_player_for.called + assert not isinstance(reader, StreamSaverWorker) + for obs, cls in zip( + tokenizer_worker._observers, + [ + AudioEventsJoinerWorker, + RegionSaverWorker, + PlayerWorker, + CommandLineWorker, + PrintWorker, + ], + ): + 
assert isinstance(obs, cls) - def test_initialize_workers_no_RegionSaverWorker(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: - with TemporaryDirectory() as tmpdir: - export_filename = os.path.join(tmpdir, "output.wav") - reader, observers = initialize_workers( - input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=export_filename, - export_format="wave", - save_detections_as=None, - echo=True, - progress_bar=False, - command="some command", - quiet=False, - printf="abcd", - time_format="%S", - timestamp_format="%h:%M:%S", - ) - reader.stop() - self.assertTrue(patched_player_for.called) - self.assertIsInstance(reader, StreamSaverWorker) - for obs, cls in zip( - observers, [PlayerWorker, CommandLineWorker, PrintWorker] - ): - self.assertIsInstance(obs, cls) - def test_initialize_workers_no_PlayerWorker(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: - with TemporaryDirectory() as tmpdir: - export_filename = os.path.join(tmpdir, "output.wav") - reader, observers = initialize_workers( - input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=export_filename, - export_format="wave", - save_detections_as="{id}.wav", - echo=False, - progress_bar=False, - command="some command", - quiet=False, - printf="abcd", - time_format="%S", - timestamp_format="%h:%M:%S", - ) - reader.stop() - self.assertFalse(patched_player_for.called) - self.assertIsInstance(reader, StreamSaverWorker) - for obs, cls in zip( - observers, - [RegionSaverWorker, CommandLineWorker, PrintWorker], - ): - self.assertIsInstance(obs, cls) - - def test_initialize_workers_no_CommandLineWorker(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: - with TemporaryDirectory() as tmpdir: - export_filename = os.path.join(tmpdir, "output.wav") - reader, observers = initialize_workers( - input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=export_filename, - export_format="wave", - save_detections_as="{id}.wav", - echo=True, - progress_bar=False, - command=None, - quiet=False, - printf="abcd", - time_format="%S", - timestamp_format="%h:%M:%S", - ) - reader.stop() - self.assertTrue(patched_player_for.called) - self.assertIsInstance(reader, StreamSaverWorker) - for obs, cls in zip( - observers, [RegionSaverWorker, PlayerWorker, PrintWorker] - ): - self.assertIsInstance(obs, cls) - - def test_initialize_workers_no_PrintWorker(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: - with TemporaryDirectory() as tmpdir: - export_filename = os.path.join(tmpdir, "output.wav") - reader, observers = initialize_workers( - input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=export_filename, - export_format="wave", - save_detections_as="{id}.wav", - echo=True, - progress_bar=False, - command="some command", - quiet=True, - printf="abcd", - time_format="%S", - timestamp_format="%h:%M:%S", - ) - reader.stop() - self.assertTrue(patched_player_for.called) - self.assertIsInstance(reader, StreamSaverWorker) - for obs, cls in zip( - observers, - [RegionSaverWorker, PlayerWorker, CommandLineWorker], - ): - self.assertIsInstance(obs, cls) - - def test_initialize_workers_no_observers(self): - with patch("auditok.cmdline_util.player_for") as patched_player_for: - reader, observers = initialize_workers( +def test_initialize_workers_no_RegionSaverWorker(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: + with TemporaryDirectory() as tmpdir: + export_filename = os.path.join(tmpdir, "output.wav") + reader, 
tokenizer_worker = initialize_workers( input="tests/data/test_16KHZ_mono_400Hz.wav", - save_stream=None, + save_stream=export_filename, export_format="wave", save_detections_as=None, + join_detections=None, + echo=True, + progress_bar=False, + command="some command", + quiet=False, + printf="abcd", + time_format="%S", + timestamp_format="%h:%M:%S", + ) + reader.stop() + assert patched_player_for.called + assert isinstance(reader, StreamSaverWorker) + for obs, cls in zip( + tokenizer_worker._observers, + [PlayerWorker, CommandLineWorker, PrintWorker], + ): + assert isinstance(obs, cls) + + +def test_initialize_workers_no_PlayerWorker(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: + with TemporaryDirectory() as tmpdir: + export_filename = os.path.join(tmpdir, "output.wav") + reader, tokenizer_worker = initialize_workers( + input="tests/data/test_16KHZ_mono_400Hz.wav", + save_stream=export_filename, + export_format="wave", + save_detections_as="{id}.wav", + join_detections=None, + echo=False, + progress_bar=False, + command="some command", + quiet=False, + printf="abcd", + time_format="%S", + timestamp_format="%h:%M:%S", + ) + reader.stop() + assert not patched_player_for.called + assert isinstance(reader, StreamSaverWorker) + for obs, cls in zip( + tokenizer_worker._observers, + [RegionSaverWorker, CommandLineWorker, PrintWorker], + ): + assert isinstance(obs, cls) + + +def test_initialize_workers_no_CommandLineWorker(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: + with TemporaryDirectory() as tmpdir: + export_filename = os.path.join(tmpdir, "output.wav") + reader, tokenizer_worker = initialize_workers( + input="tests/data/test_16KHZ_mono_400Hz.wav", + save_stream=export_filename, + export_format="wave", + save_detections_as="{id}.wav", + join_detections=None, echo=True, progress_bar=False, command=None, + quiet=False, + printf="abcd", + time_format="%S", + timestamp_format="%h:%M:%S", + ) + reader.stop() + assert patched_player_for.called + assert isinstance(reader, StreamSaverWorker) + for obs, cls in zip( + tokenizer_worker._observers, + [RegionSaverWorker, PlayerWorker, PrintWorker], + ): + assert isinstance(obs, cls) + + +def test_initialize_workers_no_PrintWorker(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: + with TemporaryDirectory() as tmpdir: + export_filename = os.path.join(tmpdir, "output.wav") + reader, tokenizer_worker = initialize_workers( + input="tests/data/test_16KHZ_mono_400Hz.wav", + save_stream=export_filename, + export_format="wave", + save_detections_as="{id}.wav", + join_detections=None, + echo=True, + progress_bar=False, + command="some command", quiet=True, printf="abcd", time_format="%S", timestamp_format="%h:%M:%S", ) - self.assertTrue(patched_player_for.called) - self.assertFalse(isinstance(reader, StreamSaverWorker)) - self.assertTrue(len(observers), 0) + reader.stop() + assert patched_player_for.called + assert isinstance(reader, StreamSaverWorker) + for obs, cls in zip( + tokenizer_worker._observers, + [RegionSaverWorker, PlayerWorker, CommandLineWorker], + ): + assert isinstance(obs, cls) -if __name__ == "__main__": - unittest.main() +def test_initialize_workers_no_observers(): + with patch("auditok.cmdline_util.player_for") as patched_player_for: + reader, tokenizer_worker = initialize_workers( + input="tests/data/test_16KHZ_mono_400Hz.wav", + save_stream=None, + export_format="wave", + save_detections_as=None, + echo=True, + progress_bar=False, + command=None, + quiet=True, + 
printf="abcd", + time_format="%S", + timestamp_format="%h:%M:%S", + ) + assert patched_player_for.called + assert not isinstance(reader, StreamSaverWorker) + assert len(tokenizer_worker._observers) == 1
--- a/tests/test_core.py Thu Mar 30 10:17:57 2023 +0100 +++ b/tests/test_core.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,21 +1,31 @@ +import math import os -import math +from pathlib import Path from random import random from tempfile import TemporaryDirectory -from array import array as array_ -import unittest -from unittest import TestCase, mock +from unittest import mock from unittest.mock import patch -from genty import genty, genty_dataset -from auditok import load, split, AudioRegion, AudioParameterError + +import numpy as np +import pytest + +from auditok import ( + AudioParameterError, + AudioRegion, + load, + make_silence, + split, + split_and_join_with_silence, +) from auditok.core import ( _duration_to_nb_windows, _make_audio_region, _read_chunks_online, _read_offline, ) -from auditok.util import AudioDataSource from auditok.io import get_audio_source +from auditok.signal import to_array +from auditok.util import AudioReader mock._magics.add("__round__") @@ -32,109 +42,232 @@ return regions -@genty -class TestFunctions(TestCase): - @genty_dataset( - no_skip_read_all=(0, -1), - no_skip_read_all_stereo=(0, -1, 2), - skip_2_read_all=(2, -1), - skip_2_read_all_None=(2, None), - skip_2_read_3=(2, 3), - skip_2_read_3_5_stereo=(2, 3.5, 2), - skip_2_4_read_3_5_stereo=(2.4, 3.5, 2), +@pytest.mark.parametrize( + "skip, max_read, channels", + [ + (0, -1, 1), # no_skip_read_all + (0, -1, 2), # no_skip_read_all_stereo + (2, -1, 1), # skip_2_read_all + (2, None, 1), # skip_2_read_all_None + (2, 3, 1), # skip_2_read_3 + (2, 3.5, 2), # skip_2_read_3_5_stereo + (2.4, 3.5, 2), # skip_2_4_read_3_5_stereo + ], + ids=[ + "no_skip_read_all", + "no_skip_read_all_stereo", + "skip_2_read_all", + "skip_2_read_all_None", + "skip_2_read_3", + "skip_2_read_3_5_stereo", + "skip_2_4_read_3_5_stereo", + ], +) +def test_load(skip, max_read, channels): + sampling_rate = 10 + sample_width = 2 + filename = "tests/data/test_split_10HZ_{}.raw" + filename = filename.format("mono" if channels == 1 else "stereo") + region = load( + filename, + skip=skip, + max_read=max_read, + sr=sampling_rate, + sw=sample_width, + ch=channels, ) - def test_load(self, skip, max_read, channels=1): - sampling_rate = 10 - sample_width = 2 - filename = "tests/data/test_split_10HZ_{}.raw" - filename = filename.format("mono" if channels == 1 else "stereo") - region = load( - filename, - skip=skip, - max_read=max_read, - sr=sampling_rate, - sw=sample_width, - ch=channels, + with open(filename, "rb") as fp: + fp.read(round(skip * sampling_rate * sample_width * channels)) + if max_read is None or max_read < 0: + to_read = -1 + else: + to_read = round(max_read * sampling_rate * sample_width * channels) + expected = fp.read(to_read) + assert bytes(region) == expected + + +@pytest.mark.parametrize( + "duration, sampling_rate, sample_width, channels", + [ + (1.05, 16000, 1, 1), # mono_16K_1byte + (1.5, 16000, 2, 1), # mono_16K_2byte + (1.0001, 44100, 2, 2), # stereo_44100_2byte + (1.000005, 48000, 2, 3), # 3channel_48K_2byte + (1.0001, 48000, 4, 4), # 4channel_48K_4byte + (0, 48000, 4, 4), # 4channel_4K_4byte_0sec + ], + ids=[ + "mono_16K_1byte", + "mono_16K_2byte", + "stereo_44100_2byte", + "3channel_48000_2byte", + "4channel_48K_4byte", + "4channel_4K_4byte_0sec", + ], +) +def test_make_silence(duration, sampling_rate, sample_width, channels): + silence = make_silence(duration, sampling_rate, sample_width, channels) + size = round(duration * sampling_rate) * sample_width * channels + expected_data = b"\0" * size + expected_duration = size / 
(sampling_rate * sample_width * channels) + assert silence.duration == expected_duration + assert silence.data == expected_data + + +@pytest.mark.parametrize( + "duration", + [ + 0, # zero_second + 1, # one_second + 1.0001, # 1.0001_second + ], + ids=[ + "zero_second", + "one_second", + "1.0001_second", + ], +) +def test_split_and_join_with_silence(duration): + sampling_rate = 10 + sample_width = 2 + channels = 1 + + regions = split( + input="tests/data/test_split_10HZ_mono.raw", + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + sr=sampling_rate, + sw=sample_width, + ch=channels, + eth=50, + ) + + size = round(duration * sampling_rate) * sample_width * channels + join_data = b"\0" * size + expected_data = join_data.join(region.data for region in regions) + expected_region = AudioRegion( + expected_data, sampling_rate, sample_width, channels + ) + + region_with_silence = split_and_join_with_silence( + input="tests/data/test_split_10HZ_mono.raw", + silence_duration=duration, + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + sr=sampling_rate, + sw=sample_width, + ch=channels, + eth=50, + ) + assert region_with_silence == expected_region + + +@pytest.mark.parametrize( + "duration, analysis_window, round_fn, expected, kwargs", + [ + (0, 1, None, 0, None), # zero_duration + (0.3, 0.1, round, 3, None), # multiple + (0.35, 0.1, math.ceil, 4, None), # not_multiple_ceil + (0.35, 0.1, math.floor, 3, None), # not_multiple_floor + (0.05, 0.1, round, 0, None), # small_duration + (0.05, 0.1, math.ceil, 1, None), # small_duration_ceil + (0.3, 0.1, math.floor, 3, {"epsilon": 1e-6}), # with_round_error + (-0.5, 0.1, math.ceil, ValueError, None), # negative_duration + (0.5, -0.1, math.ceil, ValueError, None), # negative_analysis_window + ], + ids=[ + "zero_duration", + "multiple", + "not_multiple_ceil", + "not_multiple_floor", + "small_duration", + "small_duration_ceil", + "with_round_error", + "negative_duration", + "negative_analysis_window", + ], +) +def test_duration_to_nb_windows( + duration, analysis_window, round_fn, expected, kwargs +): + if expected == ValueError: + with pytest.raises(ValueError): + _duration_to_nb_windows(duration, analysis_window, round_fn) + else: + if kwargs is None: + kwargs = {} + result = _duration_to_nb_windows( + duration, analysis_window, round_fn, **kwargs ) - with open(filename, "rb") as fp: - fp.read(round(skip * sampling_rate * sample_width * channels)) - if max_read is None or max_read < 0: - to_read = -1 - else: - to_read = round( - max_read * sampling_rate * sample_width * channels - ) - expected = fp.read(to_read) - self.assertEqual(bytes(region), expected) + assert result == expected - @genty_dataset( - zero_duration=(0, 1, None, 0), - multiple=(0.3, 0.1, round, 3), - not_multiple_ceil=(0.35, 0.1, math.ceil, 4), - not_multiple_floor=(0.35, 0.1, math.floor, 3), - small_duration=(0.05, 0.1, round, 0), - small_duration_ceil=(0.05, 0.1, math.ceil, 1), - with_round_error=(0.3, 0.1, math.floor, 3, {"epsilon": 1e-6}), - negative_duration=(-0.5, 0.1, math.ceil, ValueError), - negative_analysis_window=(0.5, -0.1, math.ceil, ValueError), + +@pytest.mark.parametrize( + "channels, skip, max_read", + [ + (1, 0, None), # mono_skip_0_max_read_None + (1, 3, None), # mono_skip_3_max_read_None + (1, 2, -1), # mono_skip_2_max_read_negative + (1, 2, 3), # mono_skip_2_max_read_3 + (2, 0, None), # 
stereo_skip_0_max_read_None + (2, 3, None), # stereo_skip_3_max_read_None + (2, 2, -1), # stereo_skip_2_max_read_negative + (2, 2, 3), # stereo_skip_2_max_read_3 + ], + ids=[ + "mono_skip_0_max_read_None", + "mono_skip_3_max_read_None", + "mono_skip_2_max_read_negative", + "mono_skip_2_max_read_3", + "stereo_skip_0_max_read_None", + "stereo_skip_3_max_read_None", + "stereo_skip_2_max_read_negative", + "stereo_skip_2_max_read_3", + ], +) +def test_read_offline(channels, skip, max_read): + sampling_rate = 10 + sample_width = 2 + mono_or_stereo = "mono" if channels == 1 else "stereo" + filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) + with open(filename, "rb") as fp: + data = fp.read() + onset = round(skip * sampling_rate * sample_width * channels) + if max_read in (-1, None): + offset = len(data) + 1 + else: + offset = onset + round( + max_read * sampling_rate * sample_width * channels + ) + expected_data = data[onset:offset] + read_data, *audio_params = _read_offline( + filename, + skip=skip, + max_read=max_read, + sr=sampling_rate, + sw=sample_width, + ch=channels, ) - def test_duration_to_nb_windows( - self, duration, analysis_window, round_fn, expected, kwargs=None - ): - if expected == ValueError: - with self.assertRaises(expected): - _duration_to_nb_windows(duration, analysis_window, round_fn) - else: - if kwargs is None: - kwargs = {} - result = _duration_to_nb_windows( - duration, analysis_window, round_fn, **kwargs - ) - self.assertEqual(result, expected) + assert read_data == expected_data + assert tuple(audio_params) == (sampling_rate, sample_width, channels) - @genty_dataset( - mono_skip_0_max_read_None=(1, 0, None), - mono_skip_3_max_read_None=(1, 3, None), - mono_skip_2_max_read_negative=(1, 2, -1), - mono_skip_2_max_read_3=(1, 2, 3), - stereo_skip_0_max_read_None=(2, 0, None), - stereo_skip_3_max_read_None=(2, 3, None), - stereo_skip_2_max_read_negative=(2, 2, -1), - stereo_skip_2_max_read_3=(2, 2, 3), - ) - def test_read_offline(self, channels, skip, max_read=None): - sampling_rate = 10 - sample_width = 2 - mono_or_stereo = "mono" if channels == 1 else "stereo" - filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) - with open(filename, "rb") as fp: - data = fp.read() - onset = round(skip * sampling_rate * sample_width * channels) - if max_read in (-1, None): - offset = len(data) + 1 - else: - offset = onset + round( - max_read * sampling_rate * sample_width * channels - ) - expected_data = data[onset:offset] - read_data, *audio_params = _read_offline( - filename, - skip=skip, - max_read=max_read, - sr=sampling_rate, - sw=sample_width, - ch=channels, - ) - self.assertEqual(read_data, expected_data) - self.assertEqual( - tuple(audio_params), (sampling_rate, sample_width, channels) - ) - -@genty -class TestSplit(TestCase): - @genty_dataset( - simple=( +@pytest.mark.parametrize( + ( + "min_dur, max_dur, max_silence, drop_trailing_silence, " + + "strict_min_dur, kwargs, expected" + ), + [ + ( 0.2, 5, 0.2, @@ -142,8 +275,8 @@ False, {"eth": 50}, [(2, 16), (17, 31), (34, 76)], - ), - short_max_dur=( + ), # simple + ( 0.3, 2, 0.2, @@ -151,10 +284,10 @@ False, {"eth": 50}, [(2, 16), (17, 31), (34, 54), (54, 74), (74, 76)], - ), - long_min_dur=(3, 5, 0.2, False, False, {"eth": 50}, [(34, 76)]), - long_max_silence=(0.2, 80, 10, False, False, {"eth": 50}, [(2, 76)]), - zero_max_silence=( + ), # short_max_dur + (3, 5, 0.2, False, False, {"eth": 50}, [(34, 76)]), # long_min_dur + (0.2, 80, 10, False, False, {"eth": 50}, [(2, 76)]), # long_max_silence 
+ ( 0.2, 5, 0.0, @@ -162,8 +295,8 @@ False, {"eth": 50}, [(2, 14), (17, 24), (26, 29), (34, 76)], - ), - low_energy_threshold=( + ), # zero_max_silence + ( 0.2, 5, 0.2, @@ -171,8 +304,8 @@ False, {"energy_threshold": 40}, [(0, 50), (50, 76)], - ), - high_energy_threshold=( + ), # low_energy_threshold + ( 0.2, 5, 0.2, @@ -180,17 +313,17 @@ False, {"energy_threshold": 60}, [], - ), - trim_leading_and_trailing_silence=( + ), # high_energy_threshold + ( 0.2, - 10, # use long max_dur - 0.5, # and a max_silence longer than any inter-region silence + 10, + 0.5, True, False, {"eth": 50}, [(2, 76)], - ), - drop_trailing_silence=( + ), # trim_leading_and_trailing_silence + ( 0.2, 5, 0.2, @@ -198,8 +331,8 @@ False, {"eth": 50}, [(2, 14), (17, 29), (34, 76)], - ), - drop_trailing_silence_2=( + ), # drop_trailing_silence + ( 1.5, 5, 0.2, @@ -207,8 +340,8 @@ False, {"eth": 50}, [(34, 76)], - ), - strict_min_dur=( + ), # drop_trailing_silence_2 + ( 0.3, 2, 0.2, @@ -216,120 +349,916 @@ True, {"eth": 50}, [(2, 16), (17, 31), (34, 54), (54, 74)], - ), - ) - def test_split_params( - self, + ), # strict_min_dur + ], + ids=[ + "simple", + "short_max_dur", + "long_min_dur", + "long_max_silence", + "zero_max_silence", + "low_energy_threshold", + "high_energy_threshold", + "trim_leading_and_trailing_silence", + "drop_trailing_silence", + "drop_trailing_silence_2", + "strict_min_dur", + ], +) +def test_split_params( + min_dur, + max_dur, + max_silence, + drop_trailing_silence, + strict_min_dur, + kwargs, + expected, +): + with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: + data = fp.read() + + regions = split( + data, min_dur, max_dur, max_silence, drop_trailing_silence, strict_min_dur, - kwargs, - expected, - ): - with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: - data = fp.read() + analysis_window=0.1, + sr=10, + sw=2, + ch=1, + **kwargs + ) - regions = split( - data, - min_dur, - max_dur, - max_silence, - drop_trailing_silence, - strict_min_dur, - analysis_window=0.1, - sr=10, - sw=2, - ch=1, - **kwargs - ) + region = AudioRegion(data, 10, 2, 1) + regions_ar = region.split( + min_dur, + max_dur, + max_silence, + drop_trailing_silence, + strict_min_dur, + analysis_window=0.1, + **kwargs + ) - region = AudioRegion(data, 10, 2, 1) - regions_ar = region.split( - min_dur, - max_dur, - max_silence, - drop_trailing_silence, - strict_min_dur, - analysis_window=0.1, - **kwargs - ) + regions = list(regions) + regions_ar = list(regions_ar) + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions)) + assert len(regions) == len(expected), err_msg + err_msg = "Wrong number of regions after AudioRegion.split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) + assert len(regions_ar) == len(expected), err_msg - regions = list(regions) - regions_ar = list(regions_ar) - err_msg = "Wrong number of regions after split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions)) - self.assertEqual(len(regions), len(expected), err_msg) - err_msg = "Wrong number of regions after AudioRegion.split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) - self.assertEqual(len(regions_ar), len(expected), err_msg) + sample_width = 2 + for reg, reg_ar, exp in zip(regions, regions_ar, expected): + onset, offset = exp + exp_data = data[onset * sample_width : offset * sample_width] + assert bytes(reg) == exp_data + assert reg == reg_ar - sample_width = 2 - for reg, reg_ar, exp in 
zip(regions, regions_ar, expected): - onset, offset = exp - exp_data = data[onset * sample_width : offset * sample_width] - self.assertEqual(bytes(reg), exp_data) - self.assertEqual(reg, reg_ar) - @genty_dataset( - stereo_all_default=(2, {}, [(2, 32), (34, 76)]), - mono_max_read=(1, {"max_read": 5}, [(2, 16), (17, 31), (34, 50)]), - mono_max_read_short_name=(1, {"mr": 5}, [(2, 16), (17, 31), (34, 50)]), - mono_use_channel_1=( +@pytest.mark.parametrize( + "channels, kwargs, expected", + [ + (2, {}, [(2, 32), (34, 76)]), # stereo_all_default + (1, {"max_read": 5}, [(2, 16), (17, 31), (34, 50)]), # mono_max_read + ( + 1, + {"mr": 5}, + [(2, 16), (17, 31), (34, 50)], + ), # mono_max_read_short_name + ( 1, {"eth": 50, "use_channel": 0}, [(2, 16), (17, 31), (34, 76)], - ), - mono_uc_1=(1, {"eth": 50, "uc": 1}, [(2, 16), (17, 31), (34, 76)]), - mono_use_channel_None=( + ), # mono_use_channel_1 + (1, {"eth": 50, "uc": 1}, [(2, 16), (17, 31), (34, 76)]), # mono_uc_1 + ( 1, {"eth": 50, "use_channel": None}, [(2, 16), (17, 31), (34, 76)], - ), - stereo_use_channel_1=( + ), # mono_use_channel_None + ( 2, {"eth": 50, "use_channel": 0}, [(2, 16), (17, 31), (34, 76)], - ), - stereo_use_channel_no_use_channel_given=( + ), # stereo_use_channel_1 + ( 2, {"eth": 50}, [(2, 32), (34, 76)], - ), - stereo_use_channel_minus_2=( + ), # stereo_use_channel_no_use_channel_given + ( 2, {"eth": 50, "use_channel": -2}, [(2, 16), (17, 31), (34, 76)], - ), - stereo_uc_2=(2, {"eth": 50, "uc": 1}, [(10, 32), (36, 76)]), - stereo_uc_minus_1=(2, {"eth": 50, "uc": -1}, [(10, 32), (36, 76)]), - mono_uc_mix=( + ), # stereo_use_channel_minus_2 + (2, {"eth": 50, "uc": 1}, [(10, 32), (36, 76)]), # stereo_uc_2 + (2, {"eth": 50, "uc": -1}, [(10, 32), (36, 76)]), # stereo_uc_minus_1 + ( 1, {"eth": 50, "uc": "mix"}, [(2, 16), (17, 31), (34, 76)], - ), - stereo_use_channel_mix=( + ), # mono_uc_mix + ( 2, {"energy_threshold": 53.5, "use_channel": "mix"}, [(54, 76)], - ), - stereo_uc_mix=(2, {"eth": 52, "uc": "mix"}, [(17, 26), (54, 76)]), - stereo_uc_mix_default_eth=( + ), # stereo_use_channel_mix + (2, {"eth": 52, "uc": "mix"}, [(17, 26), (54, 76)]), # stereo_uc_mix + ( 2, {"uc": "mix"}, [(10, 16), (17, 31), (36, 76)], - ), + ), # stereo_uc_mix_default_eth + ], + ids=[ + "stereo_all_default", + "mono_max_read", + "mono_max_read_short_name", + "mono_use_channel_1", + "mono_uc_1", + "mono_use_channel_None", + "stereo_use_channel_1", + "stereo_use_channel_no_use_channel_given", + "stereo_use_channel_minus_2", + "stereo_uc_2", + "stereo_uc_minus_1", + "mono_uc_mix", + "stereo_use_channel_mix", + "stereo_uc_mix", + "stereo_uc_mix_default_eth", + ], +) +def test_split_kwargs(channels, kwargs, expected): + + mono_or_stereo = "mono" if channels == 1 else "stereo" + filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) + with open(filename, "rb") as fp: + data = fp.read() + + regions = split( + data, + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + sr=10, + sw=2, + ch=channels, + **kwargs ) - def test_split_kwargs(self, channels, kwargs, expected): - mono_or_stereo = "mono" if channels == 1 else "stereo" - filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) - with open(filename, "rb") as fp: - data = fp.read() + region = AudioRegion(data, 10, 2, channels) + max_read = kwargs.get("max_read", kwargs.get("mr")) + if max_read is not None: + region = region.sec[:max_read] + kwargs.pop("max_read", None) + kwargs.pop("mr", None) - regions = 
split( - data, + regions_ar = region.split( + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + **kwargs + ) + + regions = list(regions) + regions_ar = list(regions_ar) + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions)) + assert len(regions) == len(expected), err_msg + err_msg = "Wrong number of regions after AudioRegion.split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) + assert len(regions_ar) == len(expected), err_msg + + sample_width = 2 + sample_size_bytes = sample_width * channels + for reg, reg_ar, exp in zip( + regions, + regions_ar, + expected, + ): + onset, offset = exp + exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes] + assert len(bytes(reg)) == len(exp_data) + assert reg == reg_ar + + +@pytest.mark.parametrize( + "min_dur, max_dur, max_silence, channels, kwargs, expected", + [ + ( + 0.2, + 5, + 0.2, + 1, + {"aw": 0.2}, + [(2, 30), (34, 76)], + ), # mono_aw_0_2_max_silence_0_2 + ( + 0.2, + 5, + 0.3, + 1, + {"aw": 0.2}, + [(2, 30), (34, 76)], + ), # mono_aw_0_2_max_silence_0_3 + ( + 0.2, + 5, + 0.4, + 1, + {"aw": 0.2}, + [(2, 32), (34, 76)], + ), # mono_aw_0_2_max_silence_0_4 + ( + 0.2, + 5, + 0, + 1, + {"aw": 0.2}, + [(2, 14), (16, 24), (26, 28), (34, 76)], + ), # mono_aw_0_2_max_silence_0 + (0.2, 5, 0.2, 1, {"aw": 0.2}, [(2, 30), (34, 76)]), # mono_aw_0_2 + ( + 0.3, + 5, + 0, + 1, + {"aw": 0.3}, + [(3, 12), (15, 24), (36, 76)], + ), # mono_aw_0_3_max_silence_0 + ( + 0.3, + 5, + 0.3, + 1, + {"aw": 0.3}, + [(3, 27), (36, 76)], + ), # mono_aw_0_3_max_silence_0_3 + ( + 0.3, + 5, + 0.5, + 1, + {"aw": 0.3}, + [(3, 27), (36, 76)], + ), # mono_aw_0_3_max_silence_0_5 + ( + 0.3, + 5, + 0.6, + 1, + {"aw": 0.3}, + [(3, 30), (36, 76)], + ), # mono_aw_0_3_max_silence_0_6 + ( + 0.2, + 5, + 0, + 1, + {"aw": 0.4}, + [(4, 12), (16, 24), (36, 76)], + ), # mono_aw_0_4_max_silence_0 + ( + 0.2, + 5, + 0.3, + 1, + {"aw": 0.4}, + [(4, 12), (16, 24), (36, 76)], + ), # mono_aw_0_4_max_silence_0_3 + ( + 0.2, + 5, + 0.4, + 1, + {"aw": 0.4}, + [(4, 28), (36, 76)], + ), # mono_aw_0_4_max_silence_0_4 + ( + 0.2, + 5, + 0.2, + 2, + {"analysis_window": 0.2}, + [(2, 32), (34, 76)], + ), # stereo_uc_None_analysis_window_0_2 + ( + 0.2, + 5, + 0.2, + 2, + {"uc": None, "analysis_window": 0.2}, + [(2, 32), (34, 76)], + ), # stereo_uc_any_analysis_window_0_2 + ( + 0.2, + 5, + 0.2, + 2, + {"use_channel": None, "analysis_window": 0.3}, + [(3, 30), (36, 76)], + ), # stereo_use_channel_None_aw_0_3_max_silence_0_2 + ( + 0.2, + 5, + 0.3, + 2, + {"use_channel": "any", "analysis_window": 0.3}, + [(3, 33), (36, 76)], + ), # stereo_use_channel_any_aw_0_3_max_silence_0_3 + ( + 0.2, + 5, + 0.2, + 2, + {"use_channel": None, "analysis_window": 0.4}, + [(4, 28), (36, 76)], + ), # stereo_use_channel_None_aw_0_4_max_silence_0_2 + ( + 0.2, + 5, + 0.4, + 2, + {"use_channel": "any", "analysis_window": 0.4}, + [(4, 32), (36, 76)], + ), # stereo_use_channel_any_aw_0_3_max_silence_0_4 + ( + 0.2, + 5, + 0.2, + 2, + {"uc": 0, "analysis_window": 0.2}, + [(2, 30), (34, 76)], + ), # stereo_uc_0_analysis_window_0_2 + ( + 0.2, + 5, + 0.2, + 2, + {"uc": 1, "analysis_window": 0.2}, + [(10, 32), (36, 76)], + ), # stereo_uc_1_analysis_window_0_2 + ( + 0.2, + 5, + 0, + 2, + {"uc": "mix", "analysis_window": 0.1}, + [(10, 14), (17, 24), (26, 29), (36, 76)], + ), # stereo_uc_mix_aw_0_1_max_silence_0 + ( + 0.2, + 5, + 0.1, + 2, + {"uc": "mix", 
"analysis_window": 0.1}, + [(10, 15), (17, 25), (26, 30), (36, 76)], + ), # stereo_uc_mix_aw_0_1_max_silence_0_1 + ( + 0.2, + 5, + 0.2, + 2, + {"uc": "mix", "analysis_window": 0.1}, + [(10, 16), (17, 31), (36, 76)], + ), # stereo_uc_mix_aw_0_1_max_silence_0_2 + ( + 0.2, + 5, + 0.3, + 2, + {"uc": "mix", "analysis_window": 0.1}, + [(10, 32), (36, 76)], + ), # stereo_uc_mix_aw_0_1_max_silence_0_3 + ( + 0.3, + 5, + 0, + 2, + {"uc": "avg", "analysis_window": 0.2}, + [(10, 14), (16, 24), (36, 76)], + ), # stereo_uc_avg_aw_0_2_max_silence_0_min_dur_0_3 + ( + 0.41, + 5, + 0, + 2, + {"uc": "average", "analysis_window": 0.2}, + [(16, 24), (36, 76)], + ), # stereo_uc_average_aw_0_2_max_silence_0_min_dur_0_41 + ( + 0.2, + 5, + 0.1, + 2, + {"uc": "mix", "analysis_window": 0.2}, + [(10, 14), (16, 24), (26, 28), (36, 76)], + ), # stereo_uc_mix_aw_0_2_max_silence_0_1 + ( + 0.2, + 5, + 0.2, + 2, + {"uc": "mix", "analysis_window": 0.2}, + [(10, 30), (36, 76)], + ), # stereo_uc_mix_aw_0_2_max_silence_0_2 + ( + 0.2, + 5, + 0.4, + 2, + {"uc": "mix", "analysis_window": 0.2}, + [(10, 32), (36, 76)], + ), # stereo_uc_mix_aw_0_2_max_silence_0_4 + ( + 0.2, + 5, + 0.5, + 2, + {"uc": "mix", "analysis_window": 0.2}, + [(10, 32), (36, 76)], + ), # stereo_uc_mix_aw_0_2_max_silence_0_5 + ( + 0.2, + 5, + 0.6, + 2, + {"uc": "mix", "analysis_window": 0.2}, + [(10, 34), (36, 76)], + ), # stereo_uc_mix_aw_0_2_max_silence_0_6 + ( + 0.2, + 5, + 0, + 2, + {"uc": "mix", "analysis_window": 0.3}, + [(9, 24), (27, 30), (36, 76)], + ), # stereo_uc_mix_aw_0_3_max_silence_0 + ( + 0.4, + 5, + 0, + 2, + {"uc": "mix", "analysis_window": 0.3}, + [(9, 24), (36, 76)], + ), # stereo_uc_mix_aw_0_3_max_silence_0_min_dur_0_3 + ( + 0.2, + 5, + 0.6, + 2, + {"uc": "mix", "analysis_window": 0.3}, + [(9, 57), (57, 76)], + ), # stereo_uc_mix_aw_0_3_max_silence_0_6 + ( + 0.2, + 5.1, + 0.6, + 2, + {"uc": "mix", "analysis_window": 0.3}, + [(9, 60), (60, 76)], + ), # stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_1 + ( + 0.2, + 5.2, + 0.6, + 2, + {"uc": "mix", "analysis_window": 0.3}, + [(9, 60), (60, 76)], + ), # stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_2 + ( + 0.2, + 5.3, + 0.6, + 2, + {"uc": "mix", "analysis_window": 0.3}, + [(9, 60), (60, 76)], + ), # stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_3 + ( + 0.2, + 5.4, + 0.6, + 2, + {"uc": "mix", "analysis_window": 0.3}, + [(9, 63), (63, 76)], + ), # stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_4 + ( + 0.2, + 5, + 0, + 2, + {"uc": "mix", "analysis_window": 0.4}, + [(16, 24), (36, 76)], + ), # stereo_uc_mix_aw_0_4_max_silence_0 + ( + 0.2, + 5, + 0.3, + 2, + {"uc": "mix", "analysis_window": 0.4}, + [(16, 24), (36, 76)], + ), # stereo_uc_mix_aw_0_4_max_silence_0_3 + ( + 0.2, + 5, + 0.4, + 2, + {"uc": "mix", "analysis_window": 0.4}, + [(16, 28), (36, 76)], + ), # stereo_uc_mix_aw_0_4_max_silence_0_4 + ], + ids=[ + "mono_aw_0_2_max_silence_0_2", + "mono_aw_0_2_max_silence_0_3", + "mono_aw_0_2_max_silence_0_4", + "mono_aw_0_2_max_silence_0", + "mono_aw_0_2", + "mono_aw_0_3_max_silence_0", + "mono_aw_0_3_max_silence_0_3", + "mono_aw_0_3_max_silence_0_5", + "mono_aw_0_3_max_silence_0_6", + "mono_aw_0_4_max_silence_0", + "mono_aw_0_4_max_silence_0_3", + "mono_aw_0_4_max_silence_0_4", + "stereo_uc_None_analysis_window_0_2", + "stereo_uc_any_analysis_window_0_2", + "stereo_use_channel_None_aw_0_3_max_silence_0_2", + "stereo_use_channel_any_aw_0_3_max_silence_0_3", + "stereo_use_channel_None_aw_0_4_max_silence_0_2", + "stereo_use_channel_any_aw_0_3_max_silence_0_4", + "stereo_uc_0_analysis_window_0_2", + 
"stereo_uc_1_analysis_window_0_2", + "stereo_uc_mix_aw_0_1_max_silence_0", + "stereo_uc_mix_aw_0_1_max_silence_0_1", + "stereo_uc_mix_aw_0_1_max_silence_0_2", + "stereo_uc_mix_aw_0_1_max_silence_0_3", + "stereo_uc_avg_aw_0_2_max_silence_0_min_dur_0_3", + "stereo_uc_average_aw_0_2_max_silence_0_min_dur_0_41", + "stereo_uc_mix_aw_0_2_max_silence_0_1", + "stereo_uc_mix_aw_0_2_max_silence_0_2", + "stereo_uc_mix_aw_0_2_max_silence_0_4", + "stereo_uc_mix_aw_0_2_max_silence_0_5", + "stereo_uc_mix_aw_0_2_max_silence_0_6", + "stereo_uc_mix_aw_0_3_max_silence_0", + "stereo_uc_mix_aw_0_3_max_silence_0_min_dur_0_3", + "stereo_uc_mix_aw_0_3_max_silence_0_6", + "stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_1", + "stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_2", + "stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_3", + "stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_4", + "stereo_uc_mix_aw_0_4_max_silence_0", + "stereo_uc_mix_aw_0_4_max_silence_0_3", + "stereo_uc_mix_aw_0_4_max_silence_0_4", + ], +) +def test_split_analysis_window( + min_dur, max_dur, max_silence, channels, kwargs, expected +): + + mono_or_stereo = "mono" if channels == 1 else "stereo" + filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) + with open(filename, "rb") as fp: + data = fp.read() + + regions = split( + data, + min_dur=min_dur, + max_dur=max_dur, + max_silence=max_silence, + drop_trailing_silence=False, + strict_min_dur=False, + sr=10, + sw=2, + ch=channels, + eth=49.99, + **kwargs + ) + + region = AudioRegion(data, 10, 2, channels) + regions_ar = region.split( + min_dur=min_dur, + max_dur=max_dur, + max_silence=max_silence, + drop_trailing_silence=False, + strict_min_dur=False, + eth=49.99, + **kwargs + ) + + regions = list(regions) + regions_ar = list(regions_ar) + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions)) + assert len(regions) == len(expected), err_msg + err_msg = "Wrong number of regions after AudioRegion.split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) + assert len(regions_ar) == len(expected), err_msg + + sample_width = 2 + sample_size_bytes = sample_width * channels + for reg, reg_ar, exp in zip( + regions, + regions_ar, + expected, + ): + onset, offset = exp + exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes] + assert bytes(reg) == exp_data + assert reg == reg_ar + + +def test_split_custom_validator(): + filename = "tests/data/test_split_10HZ_mono.raw" + with open(filename, "rb") as fp: + data = fp.read() + + regions = split( + data, + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + sr=10, + sw=2, + ch=1, + analysis_window=0.1, + validator=lambda x: to_array(x, sample_width=2, channels=1)[0] >= 320, + ) + + region = AudioRegion(data, 10, 2, 1) + regions_ar = region.split( + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + validator=lambda x: to_array(x, sample_width=2, channels=1)[0] >= 320, + ) + + expected = [(2, 16), (17, 31), (34, 76)] + regions = list(regions) + regions_ar = list(regions_ar) + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions)) + assert len(regions) == len(expected), err_msg + err_msg = "Wrong number of regions after AudioRegion.split, expected: " + err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) + assert len(regions_ar) 
== len(expected), err_msg + + sample_size_bytes = 2 + for reg, reg_ar, exp in zip( + regions, + regions_ar, + expected, + ): + onset, offset = exp + exp_data = data[onset * sample_size_bytes : offset * sample_size_bytes] + assert bytes(reg) == exp_data + assert reg == reg_ar + + +@pytest.mark.parametrize( + "input, kwargs", + [ + ( + "tests/data/test_split_10HZ_stereo.raw", + {"audio_format": "raw", "sr": 10, "sw": 2, "ch": 2}, + ), # filename_audio_format + ( + "tests/data/test_split_10HZ_stereo.raw", + {"fmt": "raw", "sr": 10, "sw": 2, "ch": 2}, + ), # filename_audio_format_short_name + ( + "tests/data/test_split_10HZ_stereo.raw", + {"sr": 10, "sw": 2, "ch": 2}, + ), # filename_no_audio_format + ( + "tests/data/test_split_10HZ_stereo.raw", + {"sampling_rate": 10, "sample_width": 2, "channels": 2}, + ), # filename_no_long_audio_params + ( + open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), + {"sr": 10, "sw": 2, "ch": 2}, + ), # bytes_ + ( + AudioReader( + "tests/data/test_split_10HZ_stereo.raw", + sr=10, + sw=2, + ch=2, + block_dur=0.1, + ), + {}, + ), # audio_reader + ( + AudioRegion( + open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), + 10, + 2, + 2, + ), + {}, + ), # audio_region + ( + get_audio_source( + "tests/data/test_split_10HZ_stereo.raw", sr=10, sw=2, ch=2 + ), + {}, + ), # audio_source + ], + ids=[ + "filename_audio_format", + "filename_audio_format_short_name", + "filename_no_audio_format", + "filename_no_long_audio_params", + "bytes_", + "audio_reader", + "audio_region", + "audio_source", + ], +) +def test_split_input_type(input, kwargs): + + with open("tests/data/test_split_10HZ_stereo.raw", "rb") as fp: + data = fp.read() + + regions = split( + input, + min_dur=0.2, + max_dur=5, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + analysis_window=0.1, + **kwargs + ) + regions = list(regions) + expected = [(2, 32), (34, 76)] + sample_width = 2 + err_msg = "Wrong number of regions after split, expected: " + err_msg += "{}, found: {}".format(expected, regions) + assert len(regions) == len(expected), err_msg + for reg, exp in zip( + regions, + expected, + ): + onset, offset = exp + exp_data = data[onset * sample_width * 2 : offset * sample_width * 2] + assert bytes(reg) == exp_data + + +@pytest.mark.parametrize( + "min_dur, max_dur, analysis_window", + [ + (0.5, 0.4, 0.1), + (0.44, 0.49, 0.1), + ], + ids=[ + "min_dur_greater_than_max_dur", + "durations_OK_but_wrong_number_of_analysis_windows", + ], +) +def test_split_wrong_min_max_dur(min_dur, max_dur, analysis_window): + + with pytest.raises(ValueError) as val_err: + split( + b"0" * 16, + min_dur=min_dur, + max_dur=max_dur, + max_silence=0.2, + sr=16000, + sw=1, + ch=1, + analysis_window=analysis_window, + ) + + err_msg = "'min_dur' ({0} sec.) 
results in {1} analysis " + err_msg += "window(s) ({1} == ceil({0} / {2})) which is " + err_msg += "higher than the number of analysis window(s) for " + err_msg += "'max_dur' ({3} == floor({4} / {2}))" + + err_msg = err_msg.format( + min_dur, + math.ceil(min_dur / analysis_window), + analysis_window, + math.floor(max_dur / analysis_window), + max_dur, + ) + assert err_msg == str(val_err.value) + + +@pytest.mark.parametrize( + "max_silence, max_dur, analysis_window", + [ + (0.5, 0.5, 0.1), # max_silence_equals_max_dur + (0.5, 0.4, 0.1), # max_silence_greater_than_max_dur + (0.44, 0.49, 0.1), # durations_OK_but_wrong_number_of_analysis_windows + ], + ids=[ + "max_silence_equals_max_dur", + "max_silence_greater_than_max_dur", + "durations_OK_but_wrong_number_of_analysis_windows", + ], +) +def test_split_wrong_max_silence_max_dur(max_silence, max_dur, analysis_window): + + with pytest.raises(ValueError) as val_err: + split( + b"0" * 16, + min_dur=0.2, + max_dur=max_dur, + max_silence=max_silence, + sr=16000, + sw=1, + ch=1, + analysis_window=analysis_window, + ) + + err_msg = "'max_silence' ({0} sec.) results in {1} analysis " + err_msg += "window(s) ({1} == floor({0} / {2})) which is " + err_msg += "higher or equal to the number of analysis window(s) for " + err_msg += "'max_dur' ({3} == floor({4} / {2}))" + + err_msg = err_msg.format( + max_silence, + math.floor(max_silence / analysis_window), + analysis_window, + math.floor(max_dur / analysis_window), + max_dur, + ) + assert err_msg == str(val_err.value) + + +@pytest.mark.parametrize( + "wrong_param", + [ + {"min_dur": -1}, # negative_min_dur + {"min_dur": 0}, # zero_min_dur + {"max_dur": -1}, # negative_max_dur + {"max_dur": 0}, # zero_max_dur + {"max_silence": -1}, # negative_max_silence + {"analysis_window": 0}, # zero_analysis_window + {"analysis_window": -1}, # negative_analysis_window + ], + ids=[ + "negative_min_dur", + "zero_min_dur", + "negative_max_dur", + "zero_max_dur", + "negative_max_silence", + "zero_analysis_window", + "negative_analysis_window", + ], +) +def test_split_negative_temporal_params(wrong_param): + + params = { + "min_dur": 0.2, + "max_dur": 0.5, + "max_silence": 0.1, + "analysis_window": 0.1, + } + params.update(wrong_param) + with pytest.raises(ValueError) as val_err: + split(None, **params) + + name = set(wrong_param).pop() + value = wrong_param[name] + err_msg = "'{}' ({}) must be >{} 0".format( + name, value, "=" if name == "max_silence" else "" + ) + assert err_msg == str(val_err.value) + + +def test_split_too_small_analysis_window(): + with pytest.raises(ValueError) as val_err: + split(b"", sr=10, sw=1, ch=1, analysis_window=0.09) + err_msg = "Too small 'analysis_window' (0.09) for sampling rate (10)." 
+ err_msg += " Analysis window should at least be 1/10 to cover one " + err_msg += "data sample" + assert err_msg == str(val_err.value) + + +def test_split_and_plot(): + + with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: + data = fp.read() + + region = AudioRegion(data, 10, 2, 1) + with patch("auditok.core.plot") as patch_fn: + regions = region.split_and_plot( min_dur=0.2, max_dur=5, max_silence=0.2, @@ -338,693 +1267,38 @@ analysis_window=0.1, sr=10, sw=2, - ch=channels, - **kwargs + ch=1, + eth=50, ) + assert patch_fn.called + expected = [(2, 16), (17, 31), (34, 76)] + sample_width = 2 + expected_regions = [] + for onset, offset in expected: + onset *= sample_width + offset *= sample_width + expected_regions.append(AudioRegion(data[onset:offset], 10, 2, 1)) + assert regions == expected_regions - region = AudioRegion(data, 10, 2, channels) - max_read = kwargs.get("max_read", kwargs.get("mr")) - if max_read is not None: - region = region.sec[:max_read] - kwargs.pop("max_read", None) - kwargs.pop("mr", None) - regions_ar = region.split( - min_dur=0.2, - max_dur=5, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - analysis_window=0.1, - **kwargs - ) +def test_split_exception(): + with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: + data = fp.read() + region = AudioRegion(data, 10, 2, 1) - regions = list(regions) - regions_ar = list(regions_ar) - err_msg = "Wrong number of regions after split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions)) - self.assertEqual(len(regions), len(expected), err_msg) - err_msg = "Wrong number of regions after AudioRegion.split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) - self.assertEqual(len(regions_ar), len(expected), err_msg) + with pytest.raises(RuntimeWarning): + # max_read is not accepted when calling AudioRegion.split + region.split(max_read=2) - sample_width = 2 - sample_size_bytes = sample_width * channels - for reg, reg_ar, exp in zip(regions, regions_ar, expected): - onset, offset = exp - exp_data = data[ - onset * sample_size_bytes : offset * sample_size_bytes - ] - self.assertEqual(len(bytes(reg)), len(exp_data)) - self.assertEqual(reg, reg_ar) - @genty_dataset( - mono_aw_0_2_max_silence_0_2=( - 0.2, - 5, - 0.2, - 1, - {"aw": 0.2}, - [(2, 30), (34, 76)], - ), - mono_aw_0_2_max_silence_0_3=( - 0.2, - 5, - 0.3, - 1, - {"aw": 0.2}, - [(2, 30), (34, 76)], - ), - mono_aw_0_2_max_silence_0_4=( - 0.2, - 5, - 0.4, - 1, - {"aw": 0.2}, - [(2, 32), (34, 76)], - ), - mono_aw_0_2_max_silence_0=( - 0.2, - 5, - 0, - 1, - {"aw": 0.2}, - [(2, 14), (16, 24), (26, 28), (34, 76)], - ), - mono_aw_0_2=(0.2, 5, 0.2, 1, {"aw": 0.2}, [(2, 30), (34, 76)]), - mono_aw_0_3_max_silence_0=( - 0.3, - 5, - 0, - 1, - {"aw": 0.3}, - [(3, 12), (15, 24), (36, 76)], - ), - mono_aw_0_3_max_silence_0_3=( - 0.3, - 5, - 0.3, - 1, - {"aw": 0.3}, - [(3, 27), (36, 76)], - ), - mono_aw_0_3_max_silence_0_5=( - 0.3, - 5, - 0.5, - 1, - {"aw": 0.3}, - [(3, 27), (36, 76)], - ), - mono_aw_0_3_max_silence_0_6=( - 0.3, - 5, - 0.6, - 1, - {"aw": 0.3}, - [(3, 30), (36, 76)], - ), - mono_aw_0_4_max_silence_0=( - 0.2, - 5, - 0, - 1, - {"aw": 0.4}, - [(4, 12), (16, 24), (36, 76)], - ), - mono_aw_0_4_max_silence_0_3=( - 0.2, - 5, - 0.3, - 1, - {"aw": 0.4}, - [(4, 12), (16, 24), (36, 76)], - ), - mono_aw_0_4_max_silence_0_4=( - 0.2, - 5, - 0.4, - 1, - {"aw": 0.4}, - [(4, 28), (36, 76)], - ), - stereo_uc_None_analysis_window_0_2=( - 0.2, - 5, - 0.2, - 2, - {"analysis_window": 0.2}, - [(2, 32), 
(34, 76)], - ), - stereo_uc_any_analysis_window_0_2=( - 0.2, - 5, - 0.2, - 2, - {"uc": None, "analysis_window": 0.2}, - [(2, 32), (34, 76)], - ), - stereo_use_channel_None_aw_0_3_max_silence_0_2=( - 0.2, - 5, - 0.2, - 2, - {"use_channel": None, "analysis_window": 0.3}, - [(3, 30), (36, 76)], - ), - stereo_use_channel_any_aw_0_3_max_silence_0_3=( - 0.2, - 5, - 0.3, - 2, - {"use_channel": "any", "analysis_window": 0.3}, - [(3, 33), (36, 76)], - ), - stereo_use_channel_None_aw_0_4_max_silence_0_2=( - 0.2, - 5, - 0.2, - 2, - {"use_channel": None, "analysis_window": 0.4}, - [(4, 28), (36, 76)], - ), - stereo_use_channel_any_aw_0_3_max_silence_0_4=( - 0.2, - 5, - 0.4, - 2, - {"use_channel": "any", "analysis_window": 0.4}, - [(4, 32), (36, 76)], - ), - stereo_uc_0_analysis_window_0_2=( - 0.2, - 5, - 0.2, - 2, - {"uc": 0, "analysis_window": 0.2}, - [(2, 30), (34, 76)], - ), - stereo_uc_1_analysis_window_0_2=( - 0.2, - 5, - 0.2, - 2, - {"uc": 1, "analysis_window": 0.2}, - [(10, 32), (36, 76)], - ), - stereo_uc_mix_aw_0_1_max_silence_0=( - 0.2, - 5, - 0, - 2, - {"uc": "mix", "analysis_window": 0.1}, - [(10, 14), (17, 24), (26, 29), (36, 76)], - ), - stereo_uc_mix_aw_0_1_max_silence_0_1=( - 0.2, - 5, - 0.1, - 2, - {"uc": "mix", "analysis_window": 0.1}, - [(10, 15), (17, 25), (26, 30), (36, 76)], - ), - stereo_uc_mix_aw_0_1_max_silence_0_2=( - 0.2, - 5, - 0.2, - 2, - {"uc": "mix", "analysis_window": 0.1}, - [(10, 16), (17, 31), (36, 76)], - ), - stereo_uc_mix_aw_0_1_max_silence_0_3=( - 0.2, - 5, - 0.3, - 2, - {"uc": "mix", "analysis_window": 0.1}, - [(10, 32), (36, 76)], - ), - stereo_uc_avg_aw_0_2_max_silence_0_min_dur_0_3=( - 0.3, - 5, - 0, - 2, - {"uc": "avg", "analysis_window": 0.2}, - [(10, 14), (16, 24), (36, 76)], - ), - stereo_uc_average_aw_0_2_max_silence_0_min_dur_0_41=( - 0.41, - 5, - 0, - 2, - {"uc": "average", "analysis_window": 0.2}, - [(16, 24), (36, 76)], - ), - stereo_uc_mix_aw_0_2_max_silence_0_1=( - 0.2, - 5, - 0.1, - 2, - {"uc": "mix", "analysis_window": 0.2}, - [(10, 14), (16, 24), (26, 28), (36, 76)], - ), - stereo_uc_mix_aw_0_2_max_silence_0_2=( - 0.2, - 5, - 0.2, - 2, - {"uc": "mix", "analysis_window": 0.2}, - [(10, 30), (36, 76)], - ), - stereo_uc_mix_aw_0_2_max_silence_0_4=( - 0.2, - 5, - 0.4, - 2, - {"uc": "mix", "analysis_window": 0.2}, - [(10, 32), (36, 76)], - ), - stereo_uc_mix_aw_0_2_max_silence_0_5=( - 0.2, - 5, - 0.5, - 2, - {"uc": "mix", "analysis_window": 0.2}, - [(10, 32), (36, 76)], - ), - stereo_uc_mix_aw_0_2_max_silence_0_6=( - 0.2, - 5, - 0.6, - 2, - {"uc": "mix", "analysis_window": 0.2}, - [(10, 34), (36, 76)], - ), - stereo_uc_mix_aw_0_3_max_silence_0=( - 0.2, - 5, - 0, - 2, - {"uc": "mix", "analysis_window": 0.3}, - [(9, 24), (27, 30), (36, 76)], - ), - stereo_uc_mix_aw_0_3_max_silence_0_min_dur_0_3=( - 0.4, - 5, - 0, - 2, - {"uc": "mix", "analysis_window": 0.3}, - [(9, 24), (36, 76)], - ), - stereo_uc_mix_aw_0_3_max_silence_0_6=( - 0.2, - 5, - 0.6, - 2, - {"uc": "mix", "analysis_window": 0.3}, - [(9, 57), (57, 76)], - ), - stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_1=( - 0.2, - 5.1, - 0.6, - 2, - {"uc": "mix", "analysis_window": 0.3}, - [(9, 60), (60, 76)], - ), - stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_2=( - 0.2, - 5.2, - 0.6, - 2, - {"uc": "mix", "analysis_window": 0.3}, - [(9, 60), (60, 76)], - ), - stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_3=( - 0.2, - 5.3, - 0.6, - 2, - {"uc": "mix", "analysis_window": 0.3}, - [(9, 60), (60, 76)], - ), - stereo_uc_mix_aw_0_3_max_silence_0_6_max_dur_5_4=( - 0.2, - 5.4, - 0.6, - 2, - {"uc": "mix", 
"analysis_window": 0.3}, - [(9, 63), (63, 76)], - ), - stereo_uc_mix_aw_0_4_max_silence_0=( - 0.2, - 5, - 0, - 2, - {"uc": "mix", "analysis_window": 0.4}, - [(16, 24), (36, 76)], - ), - stereo_uc_mix_aw_0_4_max_silence_0_3=( - 0.2, - 5, - 0.3, - 2, - {"uc": "mix", "analysis_window": 0.4}, - [(16, 24), (36, 76)], - ), - stereo_uc_mix_aw_0_4_max_silence_0_4=( - 0.2, - 5, - 0.4, - 2, - {"uc": "mix", "analysis_window": 0.4}, - [(16, 28), (36, 76)], - ), - ) - def test_split_analysis_window( - self, min_dur, max_dur, max_silence, channels, kwargs, expected - ): - - mono_or_stereo = "mono" if channels == 1 else "stereo" - filename = "tests/data/test_split_10HZ_{}.raw".format(mono_or_stereo) - with open(filename, "rb") as fp: - data = fp.read() - - regions = split( - data, - min_dur=min_dur, - max_dur=max_dur, - max_silence=max_silence, - drop_trailing_silence=False, - strict_min_dur=False, - sr=10, - sw=2, - ch=channels, - eth=49.99, - **kwargs - ) - - region = AudioRegion(data, 10, 2, channels) - regions_ar = region.split( - min_dur=min_dur, - max_dur=max_dur, - max_silence=max_silence, - drop_trailing_silence=False, - strict_min_dur=False, - eth=49.99, - **kwargs - ) - - regions = list(regions) - regions_ar = list(regions_ar) - err_msg = "Wrong number of regions after split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions)) - self.assertEqual(len(regions), len(expected), err_msg) - err_msg = "Wrong number of regions after AudioRegion.split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) - self.assertEqual(len(regions_ar), len(expected), err_msg) - - sample_width = 2 - sample_size_bytes = sample_width * channels - for reg, reg_ar, exp in zip(regions, regions_ar, expected): - onset, offset = exp - exp_data = data[ - onset * sample_size_bytes : offset * sample_size_bytes - ] - self.assertEqual(bytes(reg), exp_data) - self.assertEqual(reg, reg_ar) - - def test_split_custom_validator(self): - filename = "tests/data/test_split_10HZ_mono.raw" - with open(filename, "rb") as fp: - data = fp.read() - - regions = split( - data, - min_dur=0.2, - max_dur=5, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - sr=10, - sw=2, - ch=1, - analysis_window=0.1, - validator=lambda x: array_("h", x)[0] >= 320, - ) - - region = AudioRegion(data, 10, 2, 1) - regions_ar = region.split( - min_dur=0.2, - max_dur=5, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - analysis_window=0.1, - validator=lambda x: array_("h", x)[0] >= 320, - ) - - expected = [(2, 16), (17, 31), (34, 76)] - regions = list(regions) - regions_ar = list(regions_ar) - err_msg = "Wrong number of regions after split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions)) - self.assertEqual(len(regions), len(expected), err_msg) - err_msg = "Wrong number of regions after AudioRegion.split, expected: " - err_msg += "{}, found: {}".format(len(expected), len(regions_ar)) - self.assertEqual(len(regions_ar), len(expected), err_msg) - - sample_size_bytes = 2 - for reg, reg_ar, exp in zip(regions, regions_ar, expected): - onset, offset = exp - exp_data = data[ - onset * sample_size_bytes : offset * sample_size_bytes - ] - self.assertEqual(bytes(reg), exp_data) - self.assertEqual(reg, reg_ar) - - @genty_dataset( - filename_audio_format=( - "tests/data/test_split_10HZ_stereo.raw", - {"audio_format": "raw", "sr": 10, "sw": 2, "ch": 2}, - ), - filename_audio_format_short_name=( - "tests/data/test_split_10HZ_stereo.raw", - {"fmt": 
"raw", "sr": 10, "sw": 2, "ch": 2}, - ), - filename_no_audio_format=( - "tests/data/test_split_10HZ_stereo.raw", - {"sr": 10, "sw": 2, "ch": 2}, - ), - filename_no_long_audio_params=( - "tests/data/test_split_10HZ_stereo.raw", - {"sampling_rate": 10, "sample_width": 2, "channels": 2}, - ), - bytes_=( - open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), - {"sr": 10, "sw": 2, "ch": 2}, - ), - audio_reader=( - AudioDataSource( - "tests/data/test_split_10HZ_stereo.raw", - sr=10, - sw=2, - ch=2, - block_dur=0.1, - ), - {}, - ), - audio_region=( - AudioRegion( - open("tests/data/test_split_10HZ_stereo.raw", "rb").read(), - 10, - 2, - 2, - ), - {}, - ), - audio_source=( - get_audio_source( - "tests/data/test_split_10HZ_stereo.raw", sr=10, sw=2, ch=2 - ), - {}, - ), - ) - def test_split_input_type(self, input, kwargs): - - with open("tests/data/test_split_10HZ_stereo.raw", "rb") as fp: - data = fp.read() - - regions = split( - input, - min_dur=0.2, - max_dur=5, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - analysis_window=0.1, - **kwargs - ) - regions = list(regions) - expected = [(2, 32), (34, 76)] - sample_width = 2 - err_msg = "Wrong number of regions after split, expected: " - err_msg += "{}, found: {}".format(expected, regions) - self.assertEqual(len(regions), len(expected), err_msg) - for reg, exp in zip(regions, expected): - onset, offset = exp - exp_data = data[ - onset * sample_width * 2 : offset * sample_width * 2 - ] - self.assertEqual(bytes(reg), exp_data) - - @genty_dataset( - min_dur_greater_than_max_dur=(0.5, 0.4, 0.1), - durations_OK_but_wrong_number_of_analysis_windows=(0.44, 0.49, 0.1), - ) - def test_split_wrong_min_max_dur(self, min_dur, max_dur, analysis_window): - - with self.assertRaises(ValueError) as val_err: - split( - b"0" * 16, - min_dur=min_dur, - max_dur=max_dur, - max_silence=0.2, - sr=16000, - sw=1, - ch=1, - analysis_window=analysis_window, - ) - - err_msg = "'min_dur' ({0} sec.) results in {1} analysis " - err_msg += "window(s) ({1} == ceil({0} / {2})) which is " - err_msg += "higher than the number of analysis window(s) for " - err_msg += "'max_dur' ({3} == floor({4} / {2}))" - - err_msg = err_msg.format( - min_dur, - math.ceil(min_dur / analysis_window), - analysis_window, - math.floor(max_dur / analysis_window), - max_dur, - ) - self.assertEqual(err_msg, str(val_err.exception)) - - @genty_dataset( - max_silence_equals_max_dur=(0.5, 0.5, 0.1), - max_silence_greater_than_max_dur=(0.5, 0.4, 0.1), - durations_OK_but_wrong_number_of_analysis_windows=(0.44, 0.49, 0.1), - ) - def test_split_wrong_max_silence_max_dur( - self, max_silence, max_dur, analysis_window - ): - - with self.assertRaises(ValueError) as val_err: - split( - b"0" * 16, - min_dur=0.2, - max_dur=max_dur, - max_silence=max_silence, - sr=16000, - sw=1, - ch=1, - analysis_window=analysis_window, - ) - - err_msg = "'max_silence' ({0} sec.) 
results in {1} analysis " - err_msg += "window(s) ({1} == floor({0} / {2})) which is " - err_msg += "higher or equal to the number of analysis window(s) for " - err_msg += "'max_dur' ({3} == floor({4} / {2}))" - - err_msg = err_msg.format( - max_silence, - math.floor(max_silence / analysis_window), - analysis_window, - math.floor(max_dur / analysis_window), - max_dur, - ) - self.assertEqual(err_msg, str(val_err.exception)) - - @genty_dataset( - negative_min_dur=({"min_dur": -1},), - zero_min_dur=({"min_dur": 0},), - negative_max_dur=({"max_dur": -1},), - zero_max_dur=({"max_dur": 0},), - negative_max_silence=({"max_silence": -1},), - zero_analysis_window=({"analysis_window": 0},), - negative_analysis_window=({"analysis_window": -1},), - ) - def test_split_negative_temporal_params(self, wrong_param): - - params = { - "min_dur": 0.2, - "max_dur": 0.5, - "max_silence": 0.1, - "analysis_window": 0.1, - } - params.update(wrong_param) - with self.assertRaises(ValueError) as val_err: - split(None, **params) - - name = set(wrong_param).pop() - value = wrong_param[name] - err_msg = "'{}' ({}) must be >{} 0".format( - name, value, "=" if name == "max_silence" else "" - ) - self.assertEqual(err_msg, str(val_err.exception)) - - def test_split_too_small_analysis_window(self): - with self.assertRaises(ValueError) as val_err: - split(b"", sr=10, sw=1, ch=1, analysis_window=0.09) - err_msg = "Too small 'analysis_windows' (0.09) for sampling rate (10)." - err_msg += " Analysis windows should at least be 1/10 to cover one " - err_msg += "single data sample" - self.assertEqual(err_msg, str(val_err.exception)) - - def test_split_and_plot(self): - - with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: - data = fp.read() - - region = AudioRegion(data, 10, 2, 1) - with patch("auditok.plotting.plot") as patch_fn: - regions = region.split_and_plot( - min_dur=0.2, - max_dur=5, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - analysis_window=0.1, - sr=10, - sw=2, - ch=1, - eth=50, - ) - self.assertTrue(patch_fn.called) - expected = [(2, 16), (17, 31), (34, 76)] - sample_width = 2 - expected_regions = [] - for (onset, offset) in expected: - onset *= sample_width - offset *= sample_width - expected_regions.append(AudioRegion(data[onset:offset], 10, 2, 1)) - self.assertEqual(regions, expected_regions) - - def test_split_exception(self): - with open("tests/data/test_split_10HZ_mono.raw", "rb") as fp: - data = fp.read() - region = AudioRegion(data, 10, 2, 1) - - with self.assertRaises(RuntimeWarning): - # max_read is not accepted when calling AudioRegion.split - region.split(max_read=2) - - -@genty -class TestAudioRegion(TestCase): - @genty_dataset( - simple=(b"\0" * 8000, 0, 8000, 1, 1, 1, 1, 1000), - one_ms_less_than_1_sec=( +@pytest.mark.parametrize( + ( + "data, start, sampling_rate, sample_width, channels, expected_end, " + + "expected_duration_s, expected_duration_ms" + ), + [ + (b"\0" * 8000, 0, 8000, 1, 1, 1, 1, 1000), # simple + ( b"\0" * 7992, 0, 8000, @@ -1033,8 +1307,8 @@ 0.999, 0.999, 999, - ), - tree_quarter_ms_less_than_1_sec=( + ), # one_ms_less_than_1_sec + ( b"\0" * 7994, 0, 8000, @@ -1043,8 +1317,8 @@ 0.99925, 0.99925, 999, - ), - half_ms_less_than_1_sec=( + ), # tree_quarter_ms_less_than_1_sec + ( b"\0" * 7996, 0, 8000, @@ -1053,8 +1327,8 @@ 0.9995, 0.9995, 1000, - ), - quarter_ms_less_than_1_sec=( + ), # half_ms_less_than_1_sec + ( b"\0" * 7998, 0, 8000, @@ -1063,11 +1337,11 @@ 0.99975, 0.99975, 1000, - ), - simple_sample_width_2=(b"\0" * 8000 * 2, 0, 8000, 2, 1, 
1, 1, 1000), - simple_stereo=(b"\0" * 8000 * 2, 0, 8000, 1, 2, 1, 1, 1000), - simple_multichannel=(b"\0" * 8000 * 5, 0, 8000, 1, 5, 1, 1, 1000), - simple_sample_width_2_multichannel=( + ), # quarter_ms_less_than_1_sec + (b"\0" * 8000 * 2, 0, 8000, 2, 1, 1, 1, 1000), # simple_sample_width_2 + (b"\0" * 8000 * 2, 0, 8000, 1, 2, 1, 1, 1000), # simple_stereo + (b"\0" * 8000 * 5, 0, 8000, 1, 5, 1, 1, 1000), # simple_multichannel + ( b"\0" * 8000 * 2 * 5, 0, 8000, @@ -1076,8 +1350,8 @@ 1, 1, 1000, - ), - one_ms_less_than_1s_sw_2_multichannel=( + ), # simple_sample_width_2_multichannel + ( b"\0" * 7992 * 2 * 5, 0, 8000, @@ -1086,8 +1360,8 @@ 0.999, 0.999, 999, - ), - tree_qrt_ms_lt_1_s_sw_2_multichannel=( + ), # one_ms_less_than_1s_sw_2_multichannel + ( b"\0" * 7994 * 2 * 5, 0, 8000, @@ -1096,8 +1370,8 @@ 0.99925, 0.99925, 999, - ), - half_ms_lt_1s_sw_2_multichannel=( + ), # tree_qrt_ms_lt_1_s_sw_2_multichannel + ( b"\0" * 7996 * 2 * 5, 0, 8000, @@ -1106,8 +1380,8 @@ 0.9995, 0.9995, 1000, - ), - quarter_ms_lt_1s_sw_2_multichannel=( + ), # half_ms_lt_1s_sw_2_multichannel + ( b"\0" * 7998 * 2 * 5, 0, 8000, @@ -1116,8 +1390,8 @@ 0.99975, 0.99975, 1000, - ), - arbitrary_length_1=( + ), # quarter_ms_lt_1s_sw_2_multichannel + ( b"\0" * int(8000 * 1.33), 2.7, 8000, @@ -1126,8 +1400,8 @@ 4.03, 1.33, 1330, - ), - arbitrary_length_2=( + ), # arbitrary_length_1 + ( b"\0" * int(8000 * 0.476), 11.568, 8000, @@ -1136,8 +1410,8 @@ 12.044, 0.476, 476, - ), - arbitrary_length_sw_2_multichannel=( + ), # arbitrary_length_2 + ( b"\0" * int(8000 * 1.711) * 2 * 3, 9.415, 8000, @@ -1146,8 +1420,8 @@ 11.126, 1.711, 1711, - ), - arbitrary_samplig_rate=( + ), # arbitrary_length_sw_2_multichannel + ( b"\0" * int(3172 * 1.318), 17.236, 3172, @@ -1156,8 +1430,8 @@ 17.236 + int(3172 * 1.318) / 3172, int(3172 * 1.318) / 3172, 1318, - ), - arbitrary_sr_sw_2_multichannel=( + ), # arbitrary_sampling_rate + ( b"\0" * int(11317 * 0.716) * 2 * 3, 18.811, 11317, @@ -1166,534 +1440,767 @@ 18.811 + int(11317 * 0.716) / 11317, int(11317 * 0.716) / 11317, 716, - ), + ), # arbitrary_sr_sw_2_multichannel + ], + ids=[ + "simple", + "one_ms_less_than_1_sec", + "tree_quarter_ms_less_than_1_sec", + "half_ms_less_than_1_sec", + "quarter_ms_less_than_1_sec", + "simple_sample_width_2", + "simple_stereo", + "simple_multichannel", + "simple_sample_width_2_multichannel", + "one_ms_less_than_1s_sw_2_multichannel", + "tree_qrt_ms_lt_1_s_sw_2_multichannel", + "half_ms_lt_1s_sw_2_multichannel", + "quarter_ms_lt_1s_sw_2_multichannel", + "arbitrary_length_1", + "arbitrary_length_2", + "arbitrary_length_sw_2_multichannel", + "arbitrary_sampling_rate", + "arbitrary_sr_sw_2_multichannel", + ], +) +def test_creation( + data, + start, + sampling_rate, + sample_width, + channels, + expected_end, + expected_duration_s, + expected_duration_ms, +): + region = AudioRegion(data, sampling_rate, sample_width, channels, start) + assert region.sampling_rate == sampling_rate + assert region.sr == sampling_rate + assert region.sample_width == sample_width + assert region.sw == sample_width + assert region.channels == channels + assert region.ch == channels + assert region.meta.start == start + assert region.meta.end == expected_end + assert region.duration == expected_duration_s + assert len(region.ms) == expected_duration_ms + assert bytes(region) == data + + +def test_creation_invalid_data_exception(): + with pytest.raises(AudioParameterError) as audio_param_err: + _ = AudioRegion( + data=b"ABCDEFGHI", sampling_rate=8, sample_width=2, channels=1 + ) + assert 
str(audio_param_err.value) == ( + "The length of audio data must be an integer " + "multiple of `sample_width * channels`" ) - def test_creation( - self, - data, - start, - sampling_rate, - sample_width, - channels, - expected_end, - expected_duration_s, - expected_duration_ms, - ): - meta = {"start": start, "end": expected_end} - region = AudioRegion(data, sampling_rate, sample_width, channels, meta) - self.assertEqual(region.sampling_rate, sampling_rate) - self.assertEqual(region.sr, sampling_rate) - self.assertEqual(region.sample_width, sample_width) - self.assertEqual(region.sw, sample_width) - self.assertEqual(region.channels, channels) - self.assertEqual(region.ch, channels) - self.assertEqual(region.meta.start, start) - self.assertEqual(region.meta.end, expected_end) - self.assertEqual(region.duration, expected_duration_s) - self.assertEqual(len(region.ms), expected_duration_ms) - self.assertEqual(bytes(region), data) - def test_creation_invalid_data_exception(self): - with self.assertRaises(AudioParameterError) as audio_param_err: - _ = AudioRegion( - data=b"ABCDEFGHI", sampling_rate=8, sample_width=2, channels=1 - ) - self.assertEqual( - "The length of audio data must be an integer " - "multiple of `sample_width * channels`", - str(audio_param_err.exception), - ) - @genty_dataset( - no_skip_read_all=(0, -1), - no_skip_read_all_stereo=(0, -1, 2), - skip_2_read_all=(2, -1), - skip_2_read_all_None=(2, None), - skip_2_read_3=(2, 3), - skip_2_read_3_5_stereo=(2, 3.5, 2), - skip_2_4_read_3_5_stereo=(2.4, 3.5, 2), +@pytest.mark.parametrize( + "skip, max_read, channels", + [ + (0, -1, 1), # no_skip_read_all + (0, -1, 2), # no_skip_read_all_stereo + (2, -1, 1), # skip_2_read_all + (2, None, 1), # skip_2_read_all_None + (2, 3, 1), # skip_2_read_3 + (2, 3.5, 2), # skip_2_read_3_5_stereo + (2.4, 3.5, 2), # skip_2_4_read_3_5_stereo + ], + ids=[ + "no_skip_read_all", + "no_skip_read_all_stereo", + "skip_2_read_all", + "skip_2_read_all_None", + "skip_2_read_3", + "skip_2_read_3_5_stereo", + "skip_2_4_read_3_5_stereo", + ], +) +def test_load_AudioRegion(skip, max_read, channels): + sampling_rate = 10 + sample_width = 2 + filename = "tests/data/test_split_10HZ_{}.raw" + filename = filename.format("mono" if channels == 1 else "stereo") + region = AudioRegion.load( + filename, + skip=skip, + max_read=max_read, + sr=sampling_rate, + sw=sample_width, + ch=channels, ) - def test_load(self, skip, max_read, channels=1): - sampling_rate = 10 - sample_width = 2 - filename = "tests/data/test_split_10HZ_{}.raw" - filename = filename.format("mono" if channels == 1 else "stereo") - region = AudioRegion.load( - filename, - skip=skip, - max_read=max_read, - sr=sampling_rate, - sw=sample_width, - ch=channels, - ) - with open(filename, "rb") as fp: - fp.read(round(skip * sampling_rate * sample_width * channels)) - if max_read is None or max_read < 0: - to_read = -1 - else: - to_read = round( - max_read * sampling_rate * sample_width * channels - ) - expected = fp.read(to_read) - self.assertEqual(bytes(region), expected) + with open(filename, "rb") as fp: + fp.read(round(skip * sampling_rate * sample_width * channels)) + if max_read is None or max_read < 0: + to_read = -1 + else: + to_read = round(max_read * sampling_rate * sample_width * channels) + expected = fp.read(to_read) + assert bytes(region) == expected - def test_load_from_microphone(self): - with patch("auditok.io.PyAudioSource") as patch_pyaudio_source: - with patch("auditok.core.AudioReader.read") as patch_reader: - patch_reader.return_value = None - 
with patch( - "auditok.core.AudioRegion.__init__" - ) as patch_AudioRegion: - patch_AudioRegion.return_value = None - AudioRegion.load( - None, skip=0, max_read=5, sr=16000, sw=2, ch=1 - ) - self.assertTrue(patch_pyaudio_source.called) - self.assertTrue(patch_reader.called) - self.assertTrue(patch_AudioRegion.called) - @genty_dataset(none=(None,), negative=(-1,)) - def test_load_from_microphone_without_max_read_exception(self, max_read): - with self.assertRaises(ValueError) as val_err: - AudioRegion.load(None, max_read=max_read, sr=16000, sw=2, ch=1) - self.assertEqual( - "'max_read' should not be None when reading from microphone", - str(val_err.exception), - ) +def test_load_from_microphone(): + with patch("auditok.io.PyAudioSource") as patch_pyaudio_source: + with patch("auditok.core.AudioReader.read") as patch_reader: + patch_reader.return_value = None + with patch( + "auditok.core.AudioRegion.__init__" + ) as patch_AudioRegion: + patch_AudioRegion.return_value = None + AudioRegion.load(None, skip=0, max_read=5, sr=16000, sw=2, ch=1) + assert patch_pyaudio_source.called + assert patch_reader.called + assert patch_AudioRegion.called - def test_load_from_microphone_with_nonzero_skip_exception(self): - with self.assertRaises(ValueError) as val_err: - AudioRegion.load(None, skip=1, max_read=5, sr=16000, sw=2, ch=1) - self.assertEqual( - "'skip' should be 0 when reading from microphone", - str(val_err.exception), - ) - @genty_dataset( - simple=("output.wav", 1.230, "output.wav"), - start=("output_{meta.start:g}.wav", 1.230, "output_1.23.wav"), - start_2=("output_{meta.start}.wav", 1.233712, "output_1.233712.wav"), - start_3=("output_{meta.start:.2f}.wav", 1.2300001, "output_1.23.wav"), - start_4=("output_{meta.start:.3f}.wav", 1.233712, "output_1.234.wav"), - start_5=( +@pytest.mark.parametrize( + "max_read", + [ + None, # None + -1, # negative + ], + ids=[ + "None", + "negative", + ], +) +def test_load_from_microphone_without_max_read_exception(max_read): + with pytest.raises(ValueError) as val_err: + AudioRegion.load(None, max_read=max_read, sr=16000, sw=2, ch=1) + assert str(val_err.value) == ( + "'max_read' should not be None when reading from microphone" + ) + + +def test_load_from_microphone_with_nonzero_skip_exception(): + with pytest.raises(ValueError) as val_err: + AudioRegion.load(None, skip=1, max_read=5, sr=16000, sw=2, ch=1) + assert str(val_err.value) == ( + "'skip' should be 0 when reading from microphone" + ) + + +@pytest.mark.parametrize( + "format, start, expected", + [ + ("output.wav", 1.230, "output.wav"), # simple + ("output_{meta.start:g}.wav", 1.230, "output_1.23.wav"), # start + ("output_{meta.start}.wav", 1.233712, "output_1.233712.wav"), # start_2 + ( + "output_{meta.start:.2f}.wav", + 1.2300001, + "output_1.23.wav", + ), # start_3 + ( + "output_{meta.start:.3f}.wav", + 1.233712, + "output_1.234.wav", + ), # start_4 + ( "output_{meta.start:.8f}.wav", 1.233712, "output_1.23371200.wav", - ), - start_end_duration=( + ), # start_5 + ( "output_{meta.start}_{meta.end}_{duration}.wav", 1.455, "output_1.455_2.455_1.0.wav", - ), - start_end_duration_2=( + ), # start_end_duration + ( "output_{meta.start}_{meta.end}_{duration}.wav", 1.455321, "output_1.455321_2.455321_1.0.wav", - ), + ), # start_end_duration_2 + ], + ids=[ + "simple", + "start", + "start_2", + "start_3", + "start_4", + "start_5", + "start_end_duration", + "start_end_duration_2", + ], +) +def test_save(format, start, expected): + with TemporaryDirectory() as tmpdir: + region = AudioRegion(b"0" * 160, 160, 1, 
1, start) + format = os.path.join(tmpdir, format) + filename = region.save(format)[len(tmpdir) + 1 :] + assert filename == expected + + +def test_save_file_exists_exception(): + with TemporaryDirectory() as tmpdir: + filename = os.path.join(tmpdir, "output.wav") + open(filename, "w").close() + region = AudioRegion(b"0" * 160, 160, 1, 1) + with pytest.raises(FileExistsError): + region.save(filename, exists_ok=False) + + with pytest.raises(FileExistsError): + region.save(Path(filename), exists_ok=False) + + +@pytest.mark.parametrize( + "sampling_rate, sample_width, channels", + [ + (16000, 1, 1), # mono_16K_1byte + (16000, 2, 1), # mono_16K_2byte + (44100, 2, 2), # stereo_44100_2byte + (44100, 2, 3), # 3channel_44100_2byte + ], + ids=[ + "mono_16K_1byte", + "mono_16K_2byte", + "stereo_44100_2byte", + "3channel_44100_2byte", + ], +) +def test_join(sampling_rate, sample_width, channels): + duration = 1 + size = int(duration * sampling_rate * sample_width * channels) + glue_data = b"\0" * size + regions_data = [ + b"\1" * int(size * 1.5), + b"\2" * int(size * 0.5), + b"\3" * int(size * 0.75), + ] + + glue_region = AudioRegion(glue_data, sampling_rate, sample_width, channels) + regions = [ + AudioRegion(data, sampling_rate, sample_width, channels) + for data in regions_data + ] + joined = glue_region.join(regions) + assert joined.data == glue_data.join(regions_data) + assert joined.duration == duration * 2 + 1.5 + 0.5 + 0.75 + + +@pytest.mark.parametrize( + "sampling_rate, sample_width, channels", + [ + (32000, 1, 1), # different_sampling_rate + (16000, 2, 1), # different_sample_width + (16000, 1, 2), # different_channels + ], + ids=[ + "different_sampling_rate", + "different_sample_width", + "different_channels", + ], +) +def test_join_exception(sampling_rate, sample_width, channels): + + glue_sampling_rate = 16000 + glue_sample_width = 1 + glue_channels = 1 + + duration = 1 + size = int( + duration * glue_sampling_rate * glue_sample_width * glue_channels ) - def test_save(self, format, start, expected): - with TemporaryDirectory() as tmpdir: - region = AudioRegion(b"0" * 160, 160, 1, 1) - meta = {"start": start, "end": start + region.duration} - region.meta = meta - format = os.path.join(tmpdir, format) - filename = region.save(format)[len(tmpdir) + 1 :] - self.assertEqual(filename, expected) + glue_data = b"\0" * size + glue_region = AudioRegion( + glue_data, glue_sampling_rate, glue_sample_width, glue_channels + ) - def test_save_file_exists_exception(self): - with TemporaryDirectory() as tmpdir: - filename = os.path.join(tmpdir, "output.wav") - open(filename, "w").close() - region = AudioRegion(b"0" * 160, 160, 1, 1) - with self.assertRaises(FileExistsError): - region.save(filename, exists_ok=False) + size = int(duration * sampling_rate * sample_width * channels) + regions_data = [ + b"\1" * int(size * 1.5), + b"\2" * int(size * 0.5), + b"\3" * int(size * 0.75), + ] + regions = [ + AudioRegion(data, sampling_rate, sample_width, channels) + for data in regions_data + ] - @genty_dataset( - first_half=( + with pytest.raises(AudioParameterError): + glue_region.join(regions) + + +@pytest.mark.parametrize( + "region, slice_, expected_data", + [ + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 500), - b"a" * 80, + b"a" * 80, # first_half ), - second_half=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(500, None), - b"b" * 80, + b"b" * 80, # second_half ), - second_half_negative=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-500, None), - b"b" * 80, + b"b" * 80, # 
second_half_negative ), - middle=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(200, 750), - b"a" * 48 + b"b" * 40, + b"a" * 48 + b"b" * 40, # middle ), - middle_negative=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-800, -250), - b"a" * 48 + b"b" * 40, + b"a" * 48 + b"b" * 40, # middle_negative ), - middle_sw2=( + ( AudioRegion(b"a" * 160 + b"b" * 160, 160, 2, 1), slice(200, 750), - b"a" * 96 + b"b" * 80, + b"a" * 96 + b"b" * 80, # middle_sw2 ), - middle_ch2=( + ( AudioRegion(b"a" * 160 + b"b" * 160, 160, 1, 2), slice(200, 750), - b"a" * 96 + b"b" * 80, + b"a" * 96 + b"b" * 80, # middle_ch2 ), - middle_sw2_ch2=( + ( AudioRegion(b"a" * 320 + b"b" * 320, 160, 2, 2), slice(200, 750), - b"a" * 192 + b"b" * 160, + b"a" * 192 + b"b" * 160, # middle_sw2_ch2 ), - but_first_sample=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(1, None), - b"a" * (4000 - 8) + b"b" * 4000, + b"a" * (4000 - 8) + b"b" * 4000, # but_first_sample ), - but_first_sample_negative=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(-999, None), - b"a" * (4000 - 8) + b"b" * 4000, + b"a" * (4000 - 8) + b"b" * 4000, # but_first_sample_negative ), - but_last_sample=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(0, 999), - b"a" * 4000 + b"b" * (4000 - 8), + b"a" * 4000 + b"b" * (4000 - 8), # but_last_sample ), - but_last_sample_negative=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(0, -1), - b"a" * 4000 + b"b" * (4000 - 8), + b"a" * 4000 + b"b" * (4000 - 8), # but_last_sample_negative ), - big_negative_start=( + ( AudioRegion(b"a" * 160, 160, 1, 1), slice(-5000, None), - b"a" * 160, + b"a" * 160, # big_negative_start ), - big_negative_stop=( + ( AudioRegion(b"a" * 160, 160, 1, 1), slice(None, -1500), - b"", + b"", # big_negative_stop ), - empty=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 0), - b"", + b"", # empty ), - empty_start_stop_reversed=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(200, 100), - b"", + b"", # empty_start_stop_reversed ), - empty_big_positive_start=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(2000, 3000), - b"", + b"", # empty_big_positive_start ), - empty_negative_reversed=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-100, -200), - b"", + b"", # empty_negative_reversed ), - empty_big_negative_stop=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, -2000), - b"", + b"", # empty_big_negative_stop ), - arbitrary_sampling_rate=( + ( AudioRegion(b"a" * 124 + b"b" * 376, 1234, 1, 1), slice(100, 200), - b"a" + b"b" * 123, + b"a" + b"b" * 123, # arbitrary_sampling_rate ), - ) - def test_region_temporal_slicing(self, region, slice_, expected_data): - sub_region = region.millis[slice_] - self.assertEqual(bytes(sub_region), expected_data) - start_sec = slice_.start / 1000 if slice_.start is not None else None - stop_sec = slice_.stop / 1000 if slice_.stop is not None else None - sub_region = region.sec[start_sec:stop_sec] - self.assertEqual(bytes(sub_region), expected_data) + ], + ids=[ + "first_half", + "second_half", + "second_half_negative", + "middle", + "middle_negative", + "middle_sw2", + "middle_ch2", + "middle_sw2_ch2", + "but_first_sample", + "but_first_sample_negative", + "but_last_sample", + "but_last_sample_negative", + "big_negative_start", + "big_negative_stop", + "empty", + "empty_start_stop_reversed", + "empty_big_positive_start", + "empty_negative_reversed", + "empty_big_negative_stop", + "arbitrary_sampling_rate", + ], +) +def 
test_region_temporal_slicing(region, slice_, expected_data): + sub_region = region.millis[slice_] + assert bytes(sub_region) == expected_data + start_sec = slice_.start / 1000 if slice_.start is not None else None + stop_sec = slice_.stop / 1000 if slice_.stop is not None else None + sub_region = region.sec[start_sec:stop_sec] + assert bytes(sub_region) == expected_data - @genty_dataset( - first_half=( + +@pytest.mark.parametrize( + "region, slice_, time_shift, expected_data", + [ + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 80), 0, - b"a" * 80, + b"a" * 80, # first_half ), - second_half=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(80, None), 0.5, - b"b" * 80, + b"b" * 80, # second_half ), - second_half_negative=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-80, None), 0.5, - b"b" * 80, + b"b" * 80, # second_half_negative ), - middle=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(160 // 5, 160 // 4 * 3), 0.2, - b"a" * 48 + b"b" * 40, + b"a" * 48 + b"b" * 40, # middle ), - middle_negative=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-160 // 5 * 4, -160 // 4), 0.2, - b"a" * 48 + b"b" * 40, + b"a" * 48 + b"b" * 40, # middle_negative ), - middle_sw2=( + ( AudioRegion(b"a" * 160 + b"b" * 160, 160, 2, 1), slice(160 // 5, 160 // 4 * 3), 0.2, - b"a" * 96 + b"b" * 80, + b"a" * 96 + b"b" * 80, # middle_sw2 ), - middle_ch2=( + ( AudioRegion(b"a" * 160 + b"b" * 160, 160, 1, 2), slice(160 // 5, 160 // 4 * 3), 0.2, - b"a" * 96 + b"b" * 80, + b"a" * 96 + b"b" * 80, # middle_ch2 ), - middle_sw2_ch2=( + ( AudioRegion(b"a" * 320 + b"b" * 320, 160, 2, 2), slice(160 // 5, 160 // 4 * 3), 0.2, - b"a" * 192 + b"b" * 160, + b"a" * 192 + b"b" * 160, # middle_sw2_ch2 ), - but_first_sample=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(1, None), 1 / 8000, - b"a" * (4000 - 1) + b"b" * 4000, + b"a" * (4000 - 1) + b"b" * 4000, # but_first_sample ), - but_first_sample_negative=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(-7999, None), 1 / 8000, - b"a" * (4000 - 1) + b"b" * 4000, + b"a" * (4000 - 1) + b"b" * 4000, # but_first_sample_negative ), - but_last_sample=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(0, 7999), 0, - b"a" * 4000 + b"b" * (4000 - 1), + b"a" * 4000 + b"b" * (4000 - 1), # but_last_sample ), - but_last_sample_negative=( + ( AudioRegion(b"a" * 4000 + b"b" * 4000, 8000, 1, 1), slice(0, -1), 0, - b"a" * 4000 + b"b" * (4000 - 1), + b"a" * 4000 + b"b" * (4000 - 1), # but_last_sample_negative ), - big_negative_start=( + ( AudioRegion(b"a" * 160, 160, 1, 1), slice(-1600, None), 0, - b"a" * 160, + b"a" * 160, # big_negative_start ), - big_negative_stop=( + ( AudioRegion(b"a" * 160, 160, 1, 1), slice(None, -1600), 0, - b"", + b"", # big_negative_stop ), - empty=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, 0), 0, - b"", + b"", # empty ), - empty_start_stop_reversed=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(80, 40), 0.5, - b"", + b"", # empty_start_stop_reversed ), - empty_big_positive_start=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(1600, 3000), 10, - b"", + b"", # empty_big_positive_start ), - empty_negative_reversed=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(-16, -32), 0.9, - b"", + b"", # empty_negative_reversed ), - empty_big_negative_stop=( + ( AudioRegion(b"a" * 80 + b"b" * 80, 160, 1, 1), slice(0, -2000), 0, - b"", + b"", # empty_big_negative_stop ), - arbitrary_sampling_rate=( + ( AudioRegion(b"a" * 124 + b"b" * 376, 
1235, 1, 1), slice(100, 200), 100 / 1235, - b"a" * 24 + b"b" * 76, + b"a" * 24 + b"b" * 76, # arbitrary_sampling_rate ), - arbitrary_sampling_rate_middle_sw2_ch2=( + ( AudioRegion(b"a" * 124 + b"b" * 376, 1235, 2, 2), slice(25, 50), 25 / 1235, - b"a" * 24 + b"b" * 76, + b"a" * 24 + b"b" * 76, # arbitrary_sampling_rate_middle_sw2_ch2 ), + ], + ids=[ + "first_half", + "second_half", + "second_half_negative", + "middle", + "middle_negative", + "middle_sw2", + "middle_ch2", + "middle_sw2_ch2", + "but_first_sample", + "but_first_sample_negative", + "but_last_sample", + "but_last_sample_negative", + "big_negative_start", + "big_negative_stop", + "empty", + "empty_start_stop_reversed", + "empty_big_positive_start", + "empty_negative_reversed", + "empty_big_negative_stop", + "arbitrary_sampling_rate", + "arbitrary_sampling_rate_middle_sw2_ch2", + ], +) +def test_region_sample_slicing(region, slice_, time_shift, expected_data): + sub_region = region[slice_] + assert bytes(sub_region) == expected_data + + +@pytest.mark.parametrize( + "sampling_rate, sample_width, channels", + [ + (8000, 1, 1), # simple + (8000, 2, 2), # stereo_sw_2 + (5413, 2, 3), # arbitrary_sr_multichannel + ], + ids=[ + "simple", + "stereo_sw_2", + "arbitrary_sr_multichannel", + ], +) +def test_concatenation(sampling_rate, sample_width, channels): + + region_1, region_2 = _make_random_length_regions( + [b"a", b"b"], sampling_rate, sample_width, channels ) - def test_region_sample_slicing( - self, region, slice_, time_shift, expected_data + expected_duration = region_1.duration + region_2.duration + expected_data = bytes(region_1) + bytes(region_2) + concat_region = region_1 + region_2 + assert concat_region.duration == pytest.approx(expected_duration, abs=1e-6) + assert bytes(concat_region) == expected_data + + +@pytest.mark.parametrize( + "sampling_rate, sample_width, channels", + [ + (8000, 1, 1), # simple + (8000, 2, 2), # stereo_sw_2 + (5413, 2, 3), # arbitrary_sr_multichannel + ], + ids=[ + "simple", + "stereo_sw_2", + "arbitrary_sr_multichannel", + ], +) +def test_concatenation_many(sampling_rate, sample_width, channels): + + regions = _make_random_length_regions( + [b"a", b"b", b"c"], sampling_rate, sample_width, channels + ) + expected_duration = sum(r.duration for r in regions) + expected_data = b"".join(bytes(r) for r in regions) + concat_region = sum(regions) + + assert concat_region.duration == pytest.approx(expected_duration, abs=1e-6) + assert bytes(concat_region) == expected_data + + +def test_concatenation_different_sampling_rate_error(): + region_1 = AudioRegion(b"a" * 100, 8000, 1, 1) + region_2 = AudioRegion(b"b" * 100, 3000, 1, 1) + + with pytest.raises(AudioParameterError) as val_err: + region_1 + region_2 + assert str(val_err.value) == ( + "Can only concatenate AudioRegions of the same " + "sampling rate (8000 != 3000)" # different_sampling_rate + ) + + +def test_concatenation_different_sample_width_error(): + region_1 = AudioRegion(b"a" * 100, 8000, 2, 1) + region_2 = AudioRegion(b"b" * 100, 8000, 4, 1) + + with pytest.raises(AudioParameterError) as val_err: + region_1 + region_2 + assert str(val_err.value) == ( + "Can only concatenate AudioRegions of the same sample width (2 != 4)" + ) + + +def test_concatenation_different_number_of_channels_error(): + region_1 = AudioRegion(b"a" * 100, 8000, 1, 1) + region_2 = AudioRegion(b"b" * 100, 8000, 1, 2) + + with pytest.raises(AudioParameterError) as val_err: + region_1 + region_2 + assert str(val_err.value) == ( + "Can only concatenate AudioRegions of the same " 
+ "number of channels (1 != 2)" # different_number_of_channels + ) + + +@pytest.mark.parametrize( + "duration, expected_duration, expected_len, expected_len_ms", + [ + (0.01, 0.03, 240, 30), # simple + (0.00575, 0.01725, 138, 17), # rounded_len_floor + (0.00625, 0.01875, 150, 19), # rounded_len_ceil + ], + ids=[ + "simple", + "rounded_len_floor", + "rounded_len_ceil", + ], +) +def test_multiplication( + duration, expected_duration, expected_len, expected_len_ms +): + sw = 2 + data = b"0" * int(duration * 8000 * sw) + region = AudioRegion(data, 8000, sw, 1) + m_region = 1 * region * 3 + assert bytes(m_region) == data * 3 + assert m_region.sr == 8000 + assert m_region.sw == 2 + assert m_region.ch == 1 + assert m_region.duration == expected_duration + assert len(m_region) == expected_len + assert m_region.len == expected_len + assert m_region.s.len == expected_duration + assert len(m_region.ms) == expected_len_ms + assert m_region.ms.len == expected_len_ms + + +@pytest.mark.parametrize( + "factor, _type", + [ + ("x", str), # string + (1.4, float), # float + ], + ids=[ + "string", + "float", + ], +) +def test_multiplication_non_int(factor, _type): + with pytest.raises(TypeError) as type_err: + AudioRegion(b"0" * 80, 8000, 1, 1) * factor + err_msg = "Can't multiply AudioRegion by a non-int of type '{}'" + assert err_msg.format(_type) == str(type_err.value) + + +@pytest.mark.parametrize( + "data", + [ + [b"a" * 80, b"b" * 80], # simple + [b"a" * 31, b"b" * 31, b"c" * 30], # extra_samples_1 + [b"a" * 31, b"b" * 30, b"c" * 30], # extra_samples_2 + [b"a" * 11, b"b" * 11, b"c" * 10, b"c" * 10], # extra_samples_3 + ], + ids=[ + "simple", + "extra_samples_1", + "extra_samples_2", + "extra_samples_3", + ], +) +def test_truediv(data): + + region = AudioRegion(b"".join(data), 80, 1, 1) + + sub_regions = region / len(data) + for data_i, region in zip( + data, + sub_regions, ): - sub_region = region[slice_] - self.assertEqual(bytes(sub_region), expected_data) + assert len(data_i) == len(bytes(region)) - @genty_dataset( - simple=(8000, 1, 1), - stereo_sw_2=(8000, 2, 2), - arbitrary_sr_multichannel=(5413, 2, 3), - ) - def test_concatenation(self, sampling_rate, sample_width, channels): - region_1, region_2 = _make_random_length_regions( - [b"a", b"b"], sampling_rate, sample_width, channels - ) - expected_duration = region_1.duration + region_2.duration - expected_data = bytes(region_1) + bytes(region_2) - concat_region = region_1 + region_2 - self.assertAlmostEqual( - concat_region.duration, expected_duration, places=6 - ) - self.assertEqual(bytes(concat_region), expected_data) +@pytest.mark.parametrize( + "data, sample_width, channels, expected", + [ + (b"a" * 10, 1, 1, [97] * 10), # mono_sw_1 + (b"a" * 10, 2, 1, [24929] * 5), # mono_sw_2 + (b"a" * 8, 4, 1, [1633771873] * 2), # mono_sw_4 + (b"ab" * 5, 1, 2, [[97] * 5, [98] * 5]), # stereo_sw_1 + ], + ids=[ + "mono_sw_1", + "mono_sw_2", + "mono_sw_4", + "stereo_sw_1", + ], +) +def test_samples(data, sample_width, channels, expected): - @genty_dataset( - simple=(8000, 1, 1), - stereo_sw_2=(8000, 2, 2), - arbitrary_sr_multichannel=(5413, 2, 3), - ) - def test_concatenation_many(self, sampling_rate, sample_width, channels): - - regions = _make_random_length_regions( - [b"a", b"b", b"c"], sampling_rate, sample_width, channels - ) - expected_duration = sum(r.duration for r in regions) - expected_data = b"".join(bytes(r) for r in regions) - concat_region = sum(regions) - - self.assertAlmostEqual( - concat_region.duration, expected_duration, places=6 - ) - 
self.assertEqual(bytes(concat_region), expected_data) - - def test_concatenation_different_sampling_rate_error(self): - - region_1 = AudioRegion(b"a" * 100, 8000, 1, 1) - region_2 = AudioRegion(b"b" * 100, 3000, 1, 1) - - with self.assertRaises(ValueError) as val_err: - region_1 + region_2 - self.assertEqual( - "Can only concatenate AudioRegions of the same " - "sampling rate (8000 != 3000)", - str(val_err.exception), - ) - - def test_concatenation_different_sample_width_error(self): - - region_1 = AudioRegion(b"a" * 100, 8000, 2, 1) - region_2 = AudioRegion(b"b" * 100, 8000, 4, 1) - - with self.assertRaises(ValueError) as val_err: - region_1 + region_2 - self.assertEqual( - "Can only concatenate AudioRegions of the same " - "sample width (2 != 4)", - str(val_err.exception), - ) - - def test_concatenation_different_number_of_channels_error(self): - - region_1 = AudioRegion(b"a" * 100, 8000, 1, 1) - region_2 = AudioRegion(b"b" * 100, 8000, 1, 2) - - with self.assertRaises(ValueError) as val_err: - region_1 + region_2 - self.assertEqual( - "Can only concatenate AudioRegions of the same " - "number of channels (1 != 2)", - str(val_err.exception), - ) - - @genty_dataset( - simple=(0.01, 0.03, 240, 30), - rounded_len_floor=(0.00575, 0.01725, 138, 17), - rounded_len_ceil=(0.00625, 0.01875, 150, 19), - ) - def test_multiplication( - self, duration, expected_duration, expected_len, expected_len_ms - ): - sw = 2 - data = b"0" * int(duration * 8000 * sw) - region = AudioRegion(data, 8000, sw, 1) - m_region = 1 * region * 3 - self.assertEqual(bytes(m_region), data * 3) - self.assertEqual(m_region.sr, 8000) - self.assertEqual(m_region.sw, 2) - self.assertEqual(m_region.ch, 1) - self.assertEqual(m_region.duration, expected_duration) - self.assertEqual(len(m_region), expected_len) - self.assertEqual(m_region.len, expected_len) - self.assertEqual(m_region.s.len, expected_duration) - self.assertEqual(len(m_region.ms), expected_len_ms) - self.assertEqual(m_region.ms.len, expected_len_ms) - - @genty_dataset(_str=("x", "str"), _float=(1.4, "float")) - def test_multiplication_non_int(self, factor, _type): - with self.assertRaises(TypeError) as type_err: - AudioRegion(b"0" * 80, 8000, 1, 1) * factor - err_msg = "Can't multiply AudioRegion by a non-int of type '{}'" - self.assertEqual(err_msg.format(_type), str(type_err.exception)) - - @genty_dataset( - simple=([b"a" * 80, b"b" * 80],), - extra_samples_1=([b"a" * 31, b"b" * 31, b"c" * 30],), - extra_samples_2=([b"a" * 31, b"b" * 30, b"c" * 30],), - extra_samples_3=([b"a" * 11, b"b" * 11, b"c" * 10, b"c" * 10],), - ) - def test_truediv(self, data): - - region = AudioRegion(b"".join(data), 80, 1, 1) - - sub_regions = region / len(data) - for data_i, region in zip(data, sub_regions): - self.assertEqual(len(data_i), len(bytes(region))) - - @genty_dataset( - mono_sw_1=(b"a" * 10, 1, 1, "b", [97] * 10), - mono_sw_2=(b"a" * 10, 2, 1, "h", [24929] * 5), - mono_sw_4=(b"a" * 8, 4, 1, "i", [1633771873] * 2), - stereo_sw_1=(b"ab" * 5, 1, 2, "b", [[97] * 5, [98] * 5]), - ) - def test_samples(self, data, sample_width, channels, fmt, expected): - - region = AudioRegion(data, 10, sample_width, channels) - if isinstance(expected[0], list): - expected = [array_(fmt, exp) for exp in expected] - else: - expected = array_(fmt, expected) - samples = region.samples - equal = samples == expected - try: - # for numpy - equal = equal.all() - except AttributeError: - pass - self.assertTrue(equal) - - -if __name__ == "__main__": - unittest.main() + region = AudioRegion(data, 10, 
sample_width, channels) + expected = np.array(expected) + assert (region.samples == expected).all() + assert (region.numpy() == expected).all() + assert (np.array(region) == expected).all()
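The edits above follow one mechanical recipe, applied to every test file in this changeset: each named genty_dataset entry becomes a plain tuple in pytest.mark.parametrize with the old name preserved in ids, unittest assertions become bare asserts, and assertAlmostEqual(..., places=6) becomes a comparison against pytest.approx(..., abs=1e-6). A minimal sketch of the equivalence (the test and its values are illustrative, not taken from this diff):

import pytest

@pytest.mark.parametrize(
    "factor, expected",
    [(2, 0.2), (3, 0.3)],
    ids=["double", "triple"],  # the former genty dataset names
)
def test_scaled_duration(factor, expected):
    # a bare assert with pytest.approx replaces
    # self.assertAlmostEqual(factor * 0.1, expected, places=6)
    assert factor * 0.1 == pytest.approx(expected, abs=1e-6)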
--- a/tests/test_io.py Thu Mar 30 10:17:57 2023 +0100 +++ b/tests/test_io.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,446 +1,831 @@ +import filecmp import os -import sys -import math -from array import array +import wave +from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory -import filecmp -import unittest -from unittest import TestCase -from unittest.mock import patch, Mock -from genty import genty, genty_dataset -from test_util import _sample_generator, _generate_pure_tone, PURE_TONE_DICT -from auditok.signal import FORMAT +from unittest.mock import Mock, patch + +import numpy as np +import pytest +from test_AudioSource import PURE_TONE_DICT, _sample_generator + +import auditok from auditok.io import ( AudioIOError, AudioParameterError, BufferAudioSource, RawAudioSource, + StdinAudioSource, WaveAudioSource, - StdinAudioSource, - check_audio_data, + _get_audio_parameters, _guess_audio_format, - _get_audio_parameters, _load_raw, _load_wave, _load_with_pydub, - get_audio_source, - from_file, _save_raw, _save_wave, _save_with_pydub, + check_audio_data, + from_file, + get_audio_source, to_file, ) +from auditok.signal import SAMPLE_WIDTH_TO_DTYPE +AUDIO_PARAMS = {"sampling_rate": 16000, "sample_width": 2, "channels": 1} AUDIO_PARAMS_SHORT = {"sr": 16000, "sw": 2, "ch": 1} -@genty -class TestIO(TestCase): - @genty_dataset( - valid_mono=(b"\0" * 113, 1, 1), - valid_stereo=(b"\0" * 160, 1, 2), - invalid_mono_sw_2=(b"\0" * 113, 2, 1, False), - invalid_stereo_sw_1=(b"\0" * 113, 1, 2, False), - invalid_stereo_sw_2=(b"\0" * 158, 2, 2, False), +@pytest.mark.parametrize( + "data, sample_width, channels, valid", + [ + (b"\0" * 113, 1, 1, True), # valid_mono + (b"\0" * 160, 1, 2, True), # valid_stereo + (b"\0" * 113, 2, 1, False), # invalid_mono_sw_2 + (b"\0" * 113, 1, 2, False), # invalid_stereo_sw_1 + (b"\0" * 158, 2, 2, False), # invalid_stereo_sw_2 + ], + ids=[ + "valid_mono", + "valid_stereo", + "invalid_mono_sw_2", + "invalid_stereo_sw_1", + "invalid_stereo_sw_2", + ], +) +def test_check_audio_data(data, sample_width, channels, valid): + if not valid: + with pytest.raises(AudioParameterError): + check_audio_data(data, sample_width, channels) + else: + assert check_audio_data(data, sample_width, channels) is None + + +@pytest.mark.parametrize( + "filename, audio_format, expected", + [ + ("filename.wav", "wav", "wav"), # extension_and_format_same + ("filename.mp3", "wav", "wav"), # extension_and_format_different + ("filename.wav", None, "wav"), # extension_no_format + ("filename", "wav", "wav"), # format_no_extension + ("filename", None, None), # no_format_no_extension + ("filename", "wave", "wav"), # wave_as_wav + ("filename.wave", None, "wav"), # wave_as_wav_extension + ], + ids=[ + "extension_and_format_same", + "extension_and_format_different", + "extension_no_format", + "format_no_extension", + "no_format_no_extension", + "wave_as_wav", + "wave_as_wav_extension", + ], +) +def test_guess_audio_format(filename, audio_format, expected): + result = _guess_audio_format(filename, audio_format) + assert result == expected + + result = _guess_audio_format(Path(filename), audio_format) + assert result == expected + + +def test_get_audio_parameters_short_params(): + expected = (8000, 2, 1) + params = dict(zip(("sr", "sw", "ch"), expected)) + result = _get_audio_parameters(params) + assert result == expected + + +def test_get_audio_parameters_long_params(): + expected = (8000, 2, 1) + params = dict(zip(("sampling_rate", "sample_width", "channels"), expected)) + result = 
_get_audio_parameters(params) + assert result == expected + + +def test_get_audio_parameters_long_params_shadow_short_ones(): + expected = (8000, 2, 1) + params = dict( + zip( + ("sampling_rate", "sample_width", "channels"), + expected, + ) ) - def test_check_audio_data(self, data, sample_width, channels, valid=True): - - if not valid: - with self.assertRaises(AudioParameterError): - check_audio_data(data, sample_width, channels) - else: - self.assertIsNone(check_audio_data(data, sample_width, channels)) - - @genty_dataset( - extention_and_format_same=("wav", "filename.wav", "wav"), - extention_and_format_different=("wav", "filename.mp3", "wav"), - extention_no_format=(None, "filename.wav", "wav"), - format_no_extension=("wav", "filename", "wav"), - no_format_no_extension=(None, "filename", None), - wave_as_wav=("wave", "filename", "wav"), - wave_as_wav_extension=(None, "filename.wave", "wav"), - ) - def test_guess_audio_format(self, fmt, filename, expected): - result = _guess_audio_format(fmt, filename) - self.assertEqual(result, expected) - - def test_get_audio_parameters_short_params(self): - expected = (8000, 2, 1) - params = dict(zip(("sr", "sw", "ch"), expected)) - result = _get_audio_parameters(params) - self.assertEqual(result, expected) - - def test_get_audio_parameters_long_params(self): - expected = (8000, 2, 1) - params = dict( + params.update( + dict( zip( - ("sampling_rate", "sample_width", "channels", "use_channel"), - expected, + ("sr", "sw", "ch"), + "xxx", ) ) - result = _get_audio_parameters(params) - self.assertEqual(result, expected) + ) + result = _get_audio_parameters(params) + assert result == expected - def test_get_audio_parameters_long_params_shadow_short_ones(self): - expected = (8000, 2, 1) - params = dict( - zip(("sampling_rate", "sample_width", "channels"), expected) + +@pytest.mark.parametrize( + "missing_param", + [ + "sampling_rate", # missing_sampling_rate + "sample_width", # missing_sample_width + "channels", # missing_channels + ], + ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"], +) +def test_get_audio_parameters_missing_parameter(missing_param): + params = AUDIO_PARAMS.copy() + del params[missing_param] + with pytest.raises(AudioParameterError): + _get_audio_parameters(params) + + +@pytest.mark.parametrize( + "missing_param", + [ + "sr", # missing_sampling_rate + "sw", # missing_sample_width + "ch", # missing_channels + ], + ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"], +) +def test_get_audio_parameters_missing_parameter_short(missing_param): + params = AUDIO_PARAMS_SHORT.copy() + del params[missing_param] + with pytest.raises(AudioParameterError): + _get_audio_parameters(params) + + +@pytest.mark.parametrize( + "values", + [ + ("x", 2, 1), # str_sampling_rate + (-8000, 2, 1), # negative_sampling_rate + (8000, "x", 1), # str_sample_width + (8000, -2, 1), # negative_sample_width + (8000, 2, "x"), # str_channels + (8000, 2, -1), # negative_channels + ], + ids=[ + "str_sampling_rate", + "negative_sampling_rate", + "str_sample_width", + "negative_sample_width", + "str_channels", + "negative_channels", + ], +) +def test_get_audio_parameters_invalid(values): + params = dict( + zip( + ("sampling_rate", "sample_width", "channels"), + values, ) - params.update(dict(zip(("sr", "sw", "ch"), "xxx"))) - result = _get_audio_parameters(params) - self.assertEqual(result, expected) + ) + with pytest.raises(AudioParameterError): + _get_audio_parameters(params) - @genty_dataset( - str_sampling_rate=(("x", 2, 1),), 
-        negative_sampling_rate=((-8000, 2, 1),),
-        str_sample_width=((8000, "x", 1),),
-        negative_sample_width=((8000, -2, 1),),
-        str_channels=((8000, 2, "x"),),
-        negative_channels=((8000, 2, -1),),
-    )
-    def test_get_audio_parameters_invalid(self, values):
-        params = dict(
-            zip(("sampling_rate", "sample_width", "channels"), values)
-        )
-        with self.assertRaises(AudioParameterError):
-            _get_audio_parameters(params)

-    @genty_dataset(
-        raw_with_audio_format=(
+
+@pytest.mark.parametrize(
+    "filename, audio_format, function_name, kwargs",
+    [
+        (
             "audio",
             "raw",
             "_load_raw",
             AUDIO_PARAMS_SHORT,
-        ),
-        raw_with_extension=(
+        ),  # raw_with_audio_format
+        (
             "audio.raw",
             None,
             "_load_raw",
             AUDIO_PARAMS_SHORT,
-        ),
-        wave_with_audio_format=("audio", "wave", "_load_wave"),
-        wav_with_audio_format=("audio", "wave", "_load_wave"),
-        wav_with_extension=("audio.wav", None, "_load_wave"),
-        format_and_extension_both_given=("audio.dat", "wav", "_load_wave"),
-        format_and_extension_both_given_b=("audio.raw", "wave", "_load_wave"),
-        no_format_nor_extension=("audio", None, "_load_with_pydub"),
-        other_formats_ogg=("audio.ogg", None, "_load_with_pydub"),
-        other_formats_webm=("audio", "webm", "_load_with_pydub"),
+        ),  # raw_with_extension
+        ("audio", "wave", "_load_wave", None),  # wave_with_audio_format
+        ("audio", "wave", "_load_wave", None),  # wav_with_audio_format
+        ("audio.wav", None, "_load_wave", None),  # wav_with_extension
+        (
+            "audio.dat",
+            "wav",
+            "_load_wave",
+            None,
+        ),  # format_and_extension_both_given_a
+        (
+            "audio.raw",
+            "wave",
+            "_load_wave",
+            None,
+        ),  # format_and_extension_both_given_b
+        ("audio", None, "_load_with_pydub", None),  # no_format_nor_extension
+        ("audio.ogg", None, "_load_with_pydub", None),  # other_formats_ogg
+        ("audio", "webm", "_load_with_pydub", None),  # other_formats_webm
+    ],
+    ids=[
+        "raw_with_audio_format",
+        "raw_with_extension",
+        "wave_with_audio_format",
+        "wav_with_audio_format",
+        "wav_with_extension",
+        "format_and_extension_both_given_a",
+        "format_and_extension_both_given_b",
+        "no_format_nor_extension",
+        "other_formats_ogg",
+        "other_formats_webm",
+    ],
+)
+def test_from_file(filename, audio_format, function_name, kwargs):
+    function_name = "auditok.io." + function_name
+    if kwargs is None:
+        kwargs = {}
+    with patch(function_name) as patch_function:
+        from_file(filename, audio_format, **kwargs)
+    assert patch_function.called
+
+
+@pytest.mark.parametrize(
+    "large_file, cls, size, use_pathlib",
+    [
+        (False, BufferAudioSource, -1, False),  # large_file_false_negative_size
+        (False, BufferAudioSource, None, False),  # large_file_false_None_size
+        (
+            False,
+            BufferAudioSource,
+            None,
+            True,
+        ),  # large_file_false_None_size_Path
+        (True, RawAudioSource, -1, False),  # large_file_true_negative_size
+        (True, RawAudioSource, None, False),  # large_file_true_None_size
+        (True, RawAudioSource, -1, True),  # large_file_true_negative_size_Path
+    ],
+    ids=[
+        "large_file_false_negative_size",
+        "large_file_false_None_size",
+        "large_file_false_None_size_Path",
+        "large_file_true_negative_size",
+        "large_file_true_None_size",
+        "large_file_true_negative_size_Path",
+    ],
+)
+def test_from_file_raw_read_all(large_file, cls, size, use_pathlib):
+    filename = "tests/data/test_16KHZ_mono_400Hz.raw"
+    if use_pathlib:
+        filename = Path(filename)
+    audio_source = from_file(
+        filename,
+        large_file=large_file,
+        sampling_rate=16000,
+        sample_width=2,
+        channels=1,
     )
-    def test_from_file(
-        self, filename, audio_format, funtion_name, kwargs=None
-    ):
-        funtion_name = "auditok.io." + funtion_name
-        if kwargs is None:
-            kwargs = {}
-        with patch(funtion_name) as patch_function:
-            from_file(filename, audio_format, **kwargs)
-        self.assertTrue(patch_function.called)
+    assert isinstance(audio_source, cls)

-    def test_from_file_large_file_raw(self,):
-        filename = "tests/data/test_16KHZ_mono_400Hz.raw"
-        audio_source = from_file(
-            filename,
-            large_file=True,
-            sampling_rate=16000,
-            sample_width=2,
-            channels=1,
-        )
-        self.assertIsInstance(audio_source, RawAudioSource)
+    with open(filename, "rb") as fp:
+        expected = fp.read()
+    audio_source.open()
+    data = audio_source.read(size)
+    audio_source.close()
+    assert data == expected

-    def test_from_file_large_file_wave(self,):
-        filename = "tests/data/test_16KHZ_mono_400Hz.wav"
-        audio_source = from_file(filename, large_file=True)
-        self.assertIsInstance(audio_source, WaveAudioSource)

-    def test_from_file_large_file_compressed(self,):
-        filename = "tests/data/test_16KHZ_mono_400Hz.ogg"
-        with self.assertRaises(AudioIOError):
-            from_file(filename, large_file=True)
+@pytest.mark.parametrize(
+    "large_file, cls, size, use_pathlib",
+    [
+        (False, BufferAudioSource, -1, False),  # large_file_false_negative_size
+        (False, BufferAudioSource, None, False),  # large_file_false_None_size
+        (
+            False,
+            BufferAudioSource,
+            None,
+            True,
+        ),  # large_file_false_None_size_Path
+        (True, WaveAudioSource, -1, False),  # large_file_true_negative_size
+        (True, WaveAudioSource, None, False),  # large_file_true_None_size
+        (True, WaveAudioSource, -1, True),  # large_file_true_negative_size_Path
+    ],
+    ids=[
+        "large_file_false_negative_size",
+        "large_file_false_None_size",
+        "large_file_false_None_size_Path",
+        "large_file_true_negative_size",
+        "large_file_true_None_size",
+        "large_file_true_negative_size_Path",
+    ],
+)
+def test_from_file_wave_read_all(large_file, cls, size, use_pathlib):
+    filename = "tests/data/test_16KHZ_mono_400Hz.wav"
+    if use_pathlib:
+        filename = Path(filename)
+    audio_source = from_file(
+        filename,
+        large_file=large_file,
+        sampling_rate=16000,
+        sample_width=2,
+        channels=1,
+    )
+    assert isinstance(audio_source, cls)

-    @genty_dataset(
-        missing_sampling_rate=("sr",),
-        missing_sample_width=("sw",),
-
missing_channels=("ch",), - ) - def test_from_file_missing_audio_param(self, missing_param): - with self.assertRaises(AudioParameterError): - params = AUDIO_PARAMS_SHORT.copy() - del params[missing_param] - from_file("audio", audio_format="raw", **params) + with wave.open(str(filename)) as fp: + expected = fp.readframes(-1) + audio_source.open() + data = audio_source.read(size) + audio_source.close() + assert data == expected - def test_from_file_no_pydub(self): - with patch("auditok.io._WITH_PYDUB", False): - with self.assertRaises(AudioIOError): - from_file("audio", "mp3") - @patch("auditok.io._WITH_PYDUB", True) - @patch("auditok.io.BufferAudioSource") - @genty_dataset( - ogg_first_channel=("ogg", "from_ogg"), - ogg_second_channel=("ogg", "from_ogg"), - ogg_mix=("ogg", "from_ogg"), - ogg_default=("ogg", "from_ogg"), - mp3_left_channel=("mp3", "from_mp3"), - mp3_right_channel=("mp3", "from_mp3"), - flac_first_channel=("flac", "from_file"), - flac_second_channel=("flac", "from_file"), - flv_left_channel=("flv", "from_flv"), - webm_right_channel=("webm", "from_file"), - ) - def test_from_file_multichannel_audio_compressed( - self, audio_format, function, *mocks - ): - filename = "audio.{}".format(audio_format) - segment_mock = Mock() - segment_mock.sample_width = 2 - segment_mock.channels = 2 - segment_mock._data = b"abcd" - with patch("auditok.io.AudioSegment.{}".format(function)) as open_func: - open_func.return_value = segment_mock - from_file(filename) - self.assertTrue(open_func.called) +def test_from_file_large_file_compressed(): + filename = "tests/data/test_16KHZ_mono_400Hz.ogg" + with pytest.raises(AudioIOError): + from_file(filename, large_file=True) - @genty_dataset( - mono=("mono_400", (400,)), - three_channel=("3channel_400-800-1600", (400, 800, 1600)), - mono_large_file=("mono_400", (400,), True), - three_channel_large_file=( + +@pytest.mark.parametrize( + "missing_param", + [ + "sr", # missing_sampling_rate + "sw", # missing_sample_width + "ch", # missing_channels + ], + ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"], +) +def test_from_file_missing_audio_param(missing_param): + params = AUDIO_PARAMS_SHORT.copy() + del params[missing_param] + with pytest.raises(AudioParameterError): + from_file("audio", audio_format="raw", **params) + + +def test_from_file_no_pydub(): + with patch("auditok.io._WITH_PYDUB", False): + with pytest.raises(AudioIOError): + from_file("audio", "mp3") + + +@pytest.mark.parametrize( + "audio_format, function", + [ + ("ogg", "from_ogg"), # ogg_first_channel + ("ogg", "from_ogg"), # ogg_second_channel + ("ogg", "from_ogg"), # ogg_mix + ("ogg", "from_ogg"), # ogg_default + ("mp3", "from_mp3"), # mp3_left_channel + ("mp3", "from_mp3"), # mp3_right_channel + ("flac", "from_file"), # flac_first_channel + ("flac", "from_file"), # flac_second_channel + ("flv", "from_flv"), # flv_left_channel + ("webm", "from_file"), # webm_right_channel + ], + ids=[ + "ogg_first_channel", + "ogg_second_channel", + "ogg_mix", + "ogg_default", + "mp3_left_channel", + "mp3_right_channel", + "flac_first_channel", + "flac_second_channel", + "flv_left_channel", + "webm_right_channel", + ], +) +@patch("auditok.io._WITH_PYDUB", True) +@patch("auditok.io.BufferAudioSource") +def test_from_file_multichannel_audio_compressed( + mock_buffer_audio_source, audio_format, function +): + filename = "audio.{}".format(audio_format) + segment_mock = Mock() + segment_mock.sample_width = 2 + segment_mock.channels = 2 + segment_mock._data = b"abcd" + with 
patch("auditok.io.AudioSegment.{}".format(function)) as open_func: + open_func.return_value = segment_mock + from_file(filename) + assert open_func.called + + +@pytest.mark.parametrize( + "file_id, frequencies, large_file", + [ + ("mono_400", (400,), False), # mono + ("3channel_400-800-1600", (400, 800, 1600), False), # three_channel + ("mono_400", (400,), True), # mono_large_file + ( "3channel_400-800-1600", (400, 800, 1600), True, - ), + ), # three_channel_large_file + ], + ids=[ + "mono", + "three_channel", + "mono_large_file", + "three_channel_large_file", + ], +) +def test_load_raw(file_id, frequencies, large_file): + filename = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) + audio_source = _load_raw( + filename, 16000, 2, len(frequencies), large_file=large_file ) - def test_load_raw(self, file_id, frequencies, large_file=False): - filename = "tests/data/test_16KHZ_{}Hz.raw".format(file_id) - audio_source = _load_raw( - filename, 16000, 2, len(frequencies), large_file=large_file - ) - audio_source.open() - data = audio_source.read(-1) - audio_source.close() - expected_class = RawAudioSource if large_file else BufferAudioSource - self.assertIsInstance(audio_source, expected_class) - self.assertEqual(audio_source.sampling_rate, 16000) - self.assertEqual(audio_source.sample_width, 2) - self.assertEqual(audio_source.channels, len(frequencies)) - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - fmt = FORMAT[audio_source.sample_width] - expected = array(fmt, _sample_generator(*mono_channels)).tobytes() - self.assertEqual(data, expected) + audio_source.open() + data = audio_source.read(-1) + audio_source.close() + expected_class = RawAudioSource if large_file else BufferAudioSource + assert isinstance(audio_source, expected_class) + assert audio_source.sampling_rate == 16000 + assert audio_source.sample_width == 2 + assert audio_source.channels == len(frequencies) + mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] + dtype = SAMPLE_WIDTH_TO_DTYPE[audio_source.sample_width] + expected = np.fromiter( + _sample_generator(*mono_channels), dtype=dtype + ).tobytes() + assert data == expected - @genty_dataset( - missing_sampling_rate=("sr",), - missing_sample_width=("sw",), - missing_channels=("ch",), - ) - def test_load_raw_missing_audio_param(self, missing_param): - with self.assertRaises(AudioParameterError): - params = AUDIO_PARAMS_SHORT.copy() - del params[missing_param] - srate, swidth, channels, _ = _get_audio_parameters(params) - _load_raw("audio", srate, swidth, channels) - @genty_dataset( - mono=("mono_400", (400,)), - three_channel=("3channel_400-800-1600", (400, 800, 1600)), - mono_large_file=("mono_400", (400,), True), - three_channel_large_file=( +def test_load_raw_missing_audio_param(): + with pytest.raises(AudioParameterError): + _load_raw("audio", sampling_rate=None, sample_width=1, channels=1) + + with pytest.raises(AudioParameterError): + _load_raw("audio", sampling_rate=16000, sample_width=None, channels=1) + + with pytest.raises(AudioParameterError): + _load_raw("audio", sampling_rate=16000, sample_width=1, channels=None) + + +@pytest.mark.parametrize( + "file_id, frequencies, large_file", + [ + ("mono_400", (400,), False), # mono + ("3channel_400-800-1600", (400, 800, 1600), False), # three_channel + ("mono_400", (400,), True), # mono_large_file + ( "3channel_400-800-1600", (400, 800, 1600), True, - ), + ), # three_channel_large_file + ], + ids=[ + "mono", + "three_channel", + "mono_large_file", + "three_channel_large_file", + ], +) +def 
test_load_wave(file_id, frequencies, large_file):
+    filename = "tests/data/test_16KHZ_{}Hz.wav".format(file_id)
+    audio_source = _load_wave(filename, large_file=large_file)
+    audio_source.open()
+    data = audio_source.read(-1)
+    audio_source.close()
+    expected_class = WaveAudioSource if large_file else BufferAudioSource
+    assert isinstance(audio_source, expected_class)
+    assert audio_source.sampling_rate == 16000
+    assert audio_source.sample_width == 2
+    assert audio_source.channels == len(frequencies)
+    mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies]
+    dtype = SAMPLE_WIDTH_TO_DTYPE[audio_source.sample_width]
+    expected = np.fromiter(
+        _sample_generator(*mono_channels), dtype=dtype
+    ).tobytes()
+    assert data == expected
+
+
+@pytest.mark.parametrize(
+    "audio_format, channels, function",
+    [
+        ("ogg", 2, "from_ogg"),  # ogg_default_first_channel
+        ("ogg", 1, "from_ogg"),  # ogg_first_channel
+        ("ogg", 2, "from_ogg"),  # ogg_second_channel
+        ("ogg", 3, "from_ogg"),  # ogg_mix_channels
+        ("mp3", 1, "from_mp3"),  # mp3_left_channel
+        ("mp3", 2, "from_mp3"),  # mp3_right_channel
+        ("mp3", 3, "from_mp3"),  # mp3_mix_channels
+        ("flac", 2, "from_file"),  # flac_first_channel
+        ("flac", 2, "from_file"),  # flac_second_channel
+        ("flv", 1, "from_flv"),  # flv_left_channel
+        ("webm", 2, "from_file"),  # webm_right_channel
+        ("webm", 4, "from_file"),  # webm_mix_channels
+    ],
+    ids=[
+        "ogg_default_first_channel",
+        "ogg_first_channel",
+        "ogg_second_channel",
+        "ogg_mix_channels",
+        "mp3_left_channel",
+        "mp3_right_channel",
+        "mp3_mix_channels",
+        "flac_first_channel",
+        "flac_second_channel",
+        "flv_left_channel",
+        "webm_right_channel",
+        "webm_mix_channels",
+    ],
+)
+@patch("auditok.io._WITH_PYDUB", True)
+@patch("auditok.io.BufferAudioSource")
+def test_load_with_pydub(
+    mock_buffer_audio_source, audio_format, channels, function
+):
+    filename = "audio.{}".format(audio_format)
+    segment_mock = Mock()
+    segment_mock.sample_width = 2
+    segment_mock.channels = channels
+    segment_mock._data = b"abcdefgh"
+    with patch("auditok.io.AudioSegment.{}".format(function)) as open_func:
+        open_func.return_value = segment_mock
+        _load_with_pydub(filename, audio_format)
+        assert open_func.called
+
+
+@pytest.mark.parametrize(
+    "filename, frequencies, use_pathlib",
+    [
+        ("mono_400Hz.raw", (400,), False),  # mono
+        ("mono_400Hz.raw", (400,), True),  # mono_pathlib
+        (
+            "3channel_400-800-1600Hz.raw",
+            (400, 800, 1600),
+            False,
+        ),  # three_channel
+    ],
+    ids=["mono", "mono_pathlib", "three_channel"],
+)
+def test_save_raw(filename, frequencies, use_pathlib):
+    filename = "tests/data/test_16KHZ_{}".format(filename)
+    if use_pathlib:
+        filename = Path(filename)
+    sample_width = 2
+    dtype = SAMPLE_WIDTH_TO_DTYPE[sample_width]
+    mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies]
+    data = np.fromiter(_sample_generator(*mono_channels), dtype=dtype).tobytes()
+    tmpfile = NamedTemporaryFile()
+    _save_raw(data, tmpfile.name)
+    assert filecmp.cmp(tmpfile.name, filename, shallow=False)
+
+
+@pytest.mark.parametrize(
+    "filename, frequencies, use_pathlib",
+    [
+        ("mono_400Hz.wav", (400,), False),  # mono
+        ("mono_400Hz.wav", (400,), True),  # mono_pathlib
+        (
+            "3channel_400-800-1600Hz.wav",
+            (400, 800, 1600),
+            False,
+        ),  # three_channel
+    ],
+    ids=["mono", "mono_pathlib", "three_channel"],
+)
+def test_save_wave(filename, frequencies, use_pathlib):
+    filename = "tests/data/test_16KHZ_{}".format(filename)
+    if use_pathlib:
+        filename = Path(filename)
+    sampling_rate = 16000
+    sample_width = 2
+    channels =
len(frequencies) + mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] + dtype = SAMPLE_WIDTH_TO_DTYPE[sample_width] + data = np.fromiter(_sample_generator(*mono_channels), dtype=dtype).tobytes() + tmpfile = NamedTemporaryFile() + _save_wave(data, tmpfile.name, sampling_rate, sample_width, channels) + assert filecmp.cmp(tmpfile.name, filename, shallow=False) + + +@pytest.mark.parametrize( + "missing_param", + [ + "sr", # missing_sampling_rate + "sw", # missing_sample_width + "ch", # missing_channels + ], + ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"], +) +def test_save_wave_missing_audio_param(missing_param): + with pytest.raises(AudioParameterError): + _save_wave( + b"\0\0", "audio", sampling_rate=None, sample_width=1, channels=1 + ) + + with pytest.raises(AudioParameterError): + _save_wave( + b"\0\0", "audio", sampling_rate=16000, sample_width=None, channels=1 + ) + + with pytest.raises(AudioParameterError): + _save_wave( + b"\0\0", "audio", sampling_rate=16000, sample_width=1, channels=None + ) + + +def test_save_with_pydub(): + with patch("auditok.io.AudioSegment.export") as export: + tmpdir = TemporaryDirectory() + filename = os.path.join(tmpdir.name, "audio.ogg") + _save_with_pydub(b"\0\0", filename, "ogg", 16000, 2, 1) + assert export.called + tmpdir.cleanup() + + +@pytest.mark.parametrize( + "filename, audio_format", + [ + ("audio", "raw"), # raw_with_audio_format + ("audio.raw", None), # raw_with_extension + ("audio.mp3", "raw"), # raw_with_audio_format_and_extension + ("audio", None), # raw_no_audio_format_nor_extension + ], + ids=[ + "raw_with_audio_format", + "raw_with_extension", + "raw_with_audio_format_and_extension", + "raw_no_audio_format_nor_extension", + ], +) +def test_to_file_raw(filename, audio_format): + exp_filename = "tests/data/test_16KHZ_mono_400Hz.raw" + tmpdir = TemporaryDirectory() + filename = os.path.join(tmpdir.name, filename) + data = PURE_TONE_DICT[400].tobytes() + to_file(data, filename, audio_format=audio_format) + assert filecmp.cmp(filename, exp_filename, shallow=False) + tmpdir.cleanup() + + +@pytest.mark.parametrize( + "filename, audio_format", + [ + ("audio", "wav"), # wav_with_audio_format + ("audio.wav", None), # wav_with_extension + ("audio.mp3", "wav"), # wav_with_audio_format_and_extension + ("audio", "wave"), # wave_with_audio_format + ("audio.wave", None), # wave_with_extension + ("audio.mp3", "wave"), # wave_with_audio_format_and_extension + ], + ids=[ + "wav_with_audio_format", + "wav_with_extension", + "wav_with_audio_format_and_extension", + "wave_with_audio_format", + "wave_with_extension", + "wave_with_audio_format_and_extension", + ], +) +def test_to_file_wave(filename, audio_format): + exp_filename = "tests/data/test_16KHZ_mono_400Hz.wav" + tmpdir = TemporaryDirectory() + filename = os.path.join(tmpdir.name, filename) + data = PURE_TONE_DICT[400].tobytes() + to_file( + data, + filename, + audio_format=audio_format, + sampling_rate=16000, + sample_width=2, + channels=1, ) - def test_load_wave(self, file_id, frequencies, large_file=False): - filename = "tests/data/test_16KHZ_{}Hz.wav".format(file_id) - audio_source = _load_wave(filename, large_file=large_file) - audio_source.open() - data = audio_source.read(-1) - audio_source.close() - expected_class = WaveAudioSource if large_file else BufferAudioSource - self.assertIsInstance(audio_source, expected_class) - self.assertEqual(audio_source.sampling_rate, 16000) - self.assertEqual(audio_source.sample_width, 2) - 
self.assertEqual(audio_source.channels, len(frequencies)) - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - fmt = FORMAT[audio_source.sample_width] - expected = array(fmt, _sample_generator(*mono_channels)).tobytes() - self.assertEqual(data, expected) + assert filecmp.cmp(filename, exp_filename, shallow=False) + tmpdir.cleanup() - @patch("auditok.io._WITH_PYDUB", True) - @patch("auditok.io.BufferAudioSource") - @genty_dataset( - ogg_default_first_channel=("ogg", 2, "from_ogg"), - ogg_first_channel=("ogg", 1, "from_ogg"), - ogg_second_channel=("ogg", 2, "from_ogg"), - ogg_mix_channels=("ogg", 3, "from_ogg"), - mp3_left_channel=("mp3", 1, "from_mp3"), - mp3_right_channel=("mp3", 2, "from_mp3"), - mp3_mix_channels=("mp3", 3, "from_mp3"), - flac_first_channel=("flac", 2, "from_file"), - flac_second_channel=("flac", 2, "from_file"), - flv_left_channel=("flv", 1, "from_flv"), - webm_right_channel=("webm", 2, "from_file"), - webm_mix_channels=("webm", 4, "from_file"), - ) - def test_load_with_pydub(self, audio_format, channels, function, *mocks): - filename = "audio.{}".format(audio_format) - segment_mock = Mock() - segment_mock.sample_width = 2 - segment_mock.channels = channels - segment_mock._data = b"abcdefgh" - with patch("auditok.io.AudioSegment.{}".format(function)) as open_func: - open_func.return_value = segment_mock - _load_with_pydub(filename, audio_format) - self.assertTrue(open_func.called) - @genty_dataset( - mono=("mono_400Hz.raw", (400,)), - three_channel=("3channel_400-800-1600Hz.raw", (400, 800, 1600)), - ) - def test_save_raw(self, filename, frequencies): - filename = "tests/data/test_16KHZ_{}".format(filename) - sample_width = 2 - fmt = FORMAT[sample_width] - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - data = array(fmt, _sample_generator(*mono_channels)).tobytes() - tmpfile = NamedTemporaryFile() - _save_raw(data, tmpfile.name) - self.assertTrue(filecmp.cmp(tmpfile.name, filename, shallow=False)) +@pytest.mark.parametrize( + "missing_param", + [ + "sr", # missing_sampling_rate + "sw", # missing_sample_width + "ch", # missing_channels + ], + ids=["missing_sampling_rate", "missing_sample_width", "missing_channels"], +) +def test_to_file_missing_audio_param(missing_param): + params = AUDIO_PARAMS_SHORT.copy() + del params[missing_param] + with pytest.raises(AudioParameterError): + to_file(b"\0\0", "audio", audio_format="wav", **params) + with pytest.raises(AudioParameterError): + to_file(b"\0\0", "audio", audio_format="mp3", **params) - @genty_dataset( - mono=("mono_400Hz.wav", (400,)), - three_channel=("3channel_400-800-1600Hz.wav", (400, 800, 1600)), - ) - def test_save_wave(self, filename, frequencies): - filename = "tests/data/test_16KHZ_{}".format(filename) - sampling_rate = 16000 - sample_width = 2 - channels = len(frequencies) - fmt = FORMAT[sample_width] - mono_channels = [PURE_TONE_DICT[freq] for freq in frequencies] - data = array(fmt, _sample_generator(*mono_channels)).tobytes() - tmpfile = NamedTemporaryFile() - _save_wave(data, tmpfile.name, sampling_rate, sample_width, channels) - self.assertTrue(filecmp.cmp(tmpfile.name, filename, shallow=False)) - @genty_dataset( - missing_sampling_rate=("sr",), - missing_sample_width=("sw",), - missing_channels=("ch",), - ) - def test_save_wave_missing_audio_param(self, missing_param): - with self.assertRaises(AudioParameterError): - params = AUDIO_PARAMS_SHORT.copy() - del params[missing_param] - srate, swidth, channels, _ = _get_audio_parameters(params) - _save_wave(b"\0\0", "audio", srate, 
swidth, channels) +def test_to_file_no_pydub(): + with patch("auditok.io._WITH_PYDUB", False): + with pytest.raises(AudioIOError): + to_file("audio", b"", "mp3") - def test_save_with_pydub(self): - with patch("auditok.io.AudioSegment.export") as export: - tmpdir = TemporaryDirectory() - filename = os.path.join(tmpdir.name, "audio.ogg") - _save_with_pydub(b"\0\0", filename, "ogg", 16000, 2, 1) - self.assertTrue(export.called) - tmpdir.cleanup() - @genty_dataset( - raw_with_audio_format=("audio", "raw"), - raw_with_extension=("audio.raw", None), - raw_with_audio_format_and_extension=("audio.mp3", "raw"), - raw_no_audio_format_nor_extension=("audio", None), - ) - def test_to_file_raw(self, filename, audio_format): - exp_filename = "tests/data/test_16KHZ_mono_400Hz.raw" +@pytest.mark.parametrize( + "filename, audio_format", + [ + ("audio.ogg", None), # ogg_with_extension + ("audio", "ogg"), # ogg_with_audio_format + ("audio.wav", "ogg"), # ogg_format_with_wrong_extension + ], + ids=[ + "ogg_with_extension", + "ogg_with_audio_format", + "ogg_format_with_wrong_extension", + ], +) +@patch("auditok.io._WITH_PYDUB", True) +def test_to_file_compressed(filename, audio_format): + with patch("auditok.io.AudioSegment.export") as export: tmpdir = TemporaryDirectory() filename = os.path.join(tmpdir.name, filename) - data = PURE_TONE_DICT[400].tobytes() - to_file(data, filename, audio_format=audio_format) - self.assertTrue(filecmp.cmp(filename, exp_filename, shallow=False)) + to_file(b"\0\0", filename, audio_format, **AUDIO_PARAMS_SHORT) + assert export.called tmpdir.cleanup() - @genty_dataset( - wav_with_audio_format=("audio", "wav"), - wav_with_extension=("audio.wav", None), - wav_with_audio_format_and_extension=("audio.mp3", "wav"), - wave_with_audio_format=("audio", "wave"), - wave_with_extension=("audio.wave", None), - wave_with_audio_format_and_extension=("audio.mp3", "wave"), - ) - def test_to_file_wave(self, filename, audio_format): - exp_filename = "tests/data/test_16KHZ_mono_400Hz.wav" - tmpdir = TemporaryDirectory() - filename = os.path.join(tmpdir.name, filename) - data = PURE_TONE_DICT[400].tobytes() - to_file( - data, - filename, - audio_format=audio_format, - sampling_rate=16000, - sample_width=2, - channels=1, - ) - self.assertTrue(filecmp.cmp(filename, exp_filename, shallow=False)) - tmpdir.cleanup() - @genty_dataset( - missing_sampling_rate=("sr",), - missing_sample_width=("sw",), - missing_channels=("ch",), - ) - def test_to_file_missing_audio_param(self, missing_param): - params = AUDIO_PARAMS_SHORT.copy() - del params[missing_param] - with self.assertRaises(AudioParameterError): - to_file(b"\0\0", "audio", audio_format="wav", **params) - with self.assertRaises(AudioParameterError): - to_file(b"\0\0", "audio", audio_format="mp3", **params) - - def test_to_file_no_pydub(self): - with patch("auditok.io._WITH_PYDUB", False): - with self.assertRaises(AudioIOError): - to_file("audio", b"", "mp3") - - @patch("auditok.io._WITH_PYDUB", True) - @genty_dataset( - ogg_with_extension=("audio.ogg", None), - ogg_with_audio_format=("audio", "ogg"), - ogg_format_with_wrong_extension=("audio.wav", "ogg"), - ) - def test_to_file_compressed(self, filename, audio_format, *mocks): - with patch("auditok.io.AudioSegment.export") as export: - tmpdir = TemporaryDirectory() - filename = os.path.join(tmpdir.name, filename) - to_file(b"\0\0", filename, audio_format, **AUDIO_PARAMS_SHORT) - self.assertTrue(export.called) - tmpdir.cleanup() - - @genty_dataset( - string_wave=( +@pytest.mark.parametrize( + "input, 
expected_type, extra_args",
+    [
+        (
             "tests/data/test_16KHZ_mono_400Hz.wav",
             BufferAudioSource,
-        ),
-        string_wave_large_file=(
+            None,
+        ),  # string_wave
+        (
             "tests/data/test_16KHZ_mono_400Hz.wav",
             WaveAudioSource,
             {"large_file": True},
-        ),
-        stdin=("-", StdinAudioSource),
-        string_raw=("tests/data/test_16KHZ_mono_400Hz.raw", BufferAudioSource),
-        string_raw_large_file=(
+        ),  # string_wave_large_file
+        ("-", StdinAudioSource, None),  # stdin
+        (
+            "tests/data/test_16KHZ_mono_400Hz.raw",
+            BufferAudioSource,
+            None,
+        ),  # string_raw
+        (
             "tests/data/test_16KHZ_mono_400Hz.raw",
             RawAudioSource,
             {"large_file": True},
-        ),
-        bytes_=(b"0" * 8000, BufferAudioSource),
+        ),  # string_raw_large_file
+        (b"0" * 8000, BufferAudioSource, None),  # bytes_
+    ],
+    ids=[
+        "string_wave",
+        "string_wave_large_file",
+        "stdin",
+        "string_raw",
+        "string_raw_large_file",
+        "bytes_",
+    ],
+)
+def test_get_audio_source(input, expected_type, extra_args):
+    kwargs = {"sampling_rate": 16000, "sample_width": 2, "channels": 1}
+    if extra_args is not None:
+        kwargs.update(extra_args)
+    audio_source = get_audio_source(input, **kwargs)
+    assert isinstance(audio_source, expected_type)
+    assert audio_source.sampling_rate == 16000, (
+        "Unexpected sampling rate: audio_source.sampling_rate = "
+        + f"{audio_source.sampling_rate} instead of 16000"
    )
-    def test_get_audio_source(self, input, expected_type, extra_args=None):
-        kwargs = {"sampling_rate": 16000, "sample_width": 2, "channels": 1}
-        if extra_args is not None:
-            kwargs.update(extra_args)
-        audio_source = get_audio_source(input, **kwargs)
-        self.assertIsInstance(audio_source, expected_type)
+    assert audio_source.sr == 16000, (
+        "Unexpected sampling rate: audio_source.sr = "
+        + f"{audio_source.sr} instead of 16000"
+    )
+    assert audio_source.sample_width == 2, (
+        "Unexpected sample width: audio_source.sample_width = "
+        + f"{audio_source.sample_width} instead of 2"
+    )
+    assert audio_source.sw == 2, (
+        "Unexpected sample width: audio_source.sw = "
+        + f"{audio_source.sw} instead of 2"
+    )
+    assert audio_source.channels == 1, (
+        "Unexpected number of channels: audio_source.channels = "
+        + f"{audio_source.channels} instead of 1"
+    )
+    assert audio_source.ch == 1, (
+        "Unexpected number of channels: audio_source.ch = "
+        + f"{audio_source.ch} instead of 1"
+    )


-if __name__ == "__main__":
-    unittest.main()
+def test_get_audio_source_alias_params():
+    audio_source = get_audio_source(b"0" * 1600, sr=16000, sw=2, ch=1)
+    assert audio_source.sampling_rate == 16000, (
+        "Unexpected sampling rate: audio_source.sampling_rate = "
+        + f"{audio_source.sampling_rate} instead of 16000"
+    )
+    assert audio_source.sr == 16000, (
+        "Unexpected sampling rate: audio_source.sr = "
+        + f"{audio_source.sr} instead of 16000"
+    )
+    assert audio_source.sample_width == 2, (
+        "Unexpected sample width: audio_source.sample_width = "
+        + f"{audio_source.sample_width} instead of 2"
+    )
+    assert audio_source.sw == 2, (
+        "Unexpected sample width: audio_source.sw = "
+        + f"{audio_source.sw} instead of 2"
+    )
+    assert audio_source.channels == 1, (
+        "Unexpected number of channels: audio_source.channels = "
+        + f"{audio_source.channels} instead of 1"
+    )
+    assert audio_source.ch == 1, (
+        "Unexpected number of channels: audio_source.ch = "
+        + f"{audio_source.ch} instead of 1"
+    )
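For context on the assertions above: auditok's audio sources expose each parameter under a long name (sampling_rate, sample_width, channels) and a short alias (sr, sw, ch), and get_audio_source accepts either spelling as keyword arguments, which is what the final test above exercises. A quick usage sketch (the silent buffer is arbitrary):

from auditok.io import get_audio_source

data = b"\0" * 32000  # one second at 16 kHz, 2 bytes per sample, mono
long_form = get_audio_source(data, sampling_rate=16000, sample_width=2, channels=1)
short_form = get_audio_source(data, sr=16000, sw=2, ch=1)
assert long_form.sr == short_form.sampling_rate == 16000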
--- a/tests/test_plotting.py Thu Mar 30 10:17:57 2023 +0100 +++ b/tests/test_plotting.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,70 +1,73 @@ import os import sys -import unittest -from unittest import TestCase from tempfile import TemporaryDirectory -from genty import genty, genty_dataset + import matplotlib +import pytest -matplotlib.use("AGG") # noqa E402 -import matplotlib.pyplot as plt -from auditok.core import AudioRegion +matplotlib.use("AGG") +import matplotlib.pyplot as plt # noqa E402 -if sys.version_info.minor <= 5: - PREFIX = "py34_py35/" -else: - PREFIX = "" +from auditok.core import AudioRegion # noqa E402 + +SAVE_NEW_IMAGES = False +if SAVE_NEW_IMAGES: + import shutil # noqa E402 matplotlib.rcParams["figure.figsize"] = (10, 4) -@genty -class TestPlotting(TestCase): - @genty_dataset(mono=(1,), stereo=(2,)) - def test_region_plot(self, channels): - type_ = "mono" if channels == 1 else "stereo" - audio_filename = "tests/data/test_split_10HZ_{}.raw".format(type_) - image_filename = "tests/images/{}plot_{}_region.png".format( - PREFIX, type_ +@pytest.mark.parametrize("channels", [1, 2], ids=["mono", "stereo"]) +def test_region_plot(channels): + type_ = "mono" if channels == 1 else "stereo" + audio_filename = "tests/data/test_split_10HZ_{}.raw".format(type_) + image_filename = "tests/images/plot_{}_region.png".format(type_) + expected_image = plt.imread(image_filename) + with TemporaryDirectory() as tmpdir: + output_image_filename = os.path.join(tmpdir, "image.png") + region = AudioRegion.load(audio_filename, sr=10, sw=2, ch=channels) + region.plot(show=False, save_as=output_image_filename) + output_image = plt.imread(output_image_filename) + + if SAVE_NEW_IMAGES: + shutil.copy(output_image_filename, image_filename) + assert (output_image == expected_image).all() # mono, stereo + + +@pytest.mark.parametrize( + "channels, use_channel", + [ + (1, None), # mono + (2, "any"), # stereo_any + (2, 0), # stereo_uc_0 + (2, 1), # stereo_uc_1 + (2, "mix"), # stereo_uc_mix + ], + ids=["mono", "stereo_any", "stereo_uc_0", "stereo_uc_1", "stereo_uc_mix"], +) +def test_region_split_and_plot(channels, use_channel): + type_ = "mono" if channels == 1 else "stereo" + audio_filename = "tests/data/test_split_10HZ_{}.raw".format(type_) + if type_ == "mono": + image_filename = "tests/images/split_and_plot_mono_region.png" + else: + image_filename = ( + f"tests/images/split_and_plot_uc_{use_channel}_stereo_region.png" ) - expected_image = plt.imread(image_filename) - with TemporaryDirectory() as tmpdir: - output_image_filename = os.path.join(tmpdir, "image.png") - region = AudioRegion.load(audio_filename, sr=10, sw=2, ch=channels) - region.plot(show=False, save_as=output_image_filename) - output_image = plt.imread(output_image_filename) - self.assertTrue((output_image == expected_image).all()) - @genty_dataset( - mono=(1,), - stereo_any=(2, "any"), - stereo_uc_0=(2, 0), - stereo_uc_1=(2, 1), - stereo_uc_mix=(2, "mix"), - ) - def test_region_split_and_plot(self, channels, use_channel=None): - type_ = "mono" if channels == 1 else "stereo" - audio_filename = "tests/data/test_split_10HZ_{}.raw".format(type_) - if type_ == "mono": - fmt = "tests/images/{}split_and_plot_mono_region.png" - else: - fmt = "tests/images/{}split_and_plot_uc_{}_stereo_region.png" - image_filename = fmt.format(PREFIX, use_channel) + expected_image = plt.imread(image_filename) + with TemporaryDirectory() as tmpdir: + output_image_filename = os.path.join(tmpdir, "image.png") + region = AudioRegion.load(audio_filename, sr=10, sw=2, 
ch=channels) + region.split_and_plot( + aw=0.1, + uc=use_channel, + max_silence=0, + show=False, + save_as=output_image_filename, + ) + output_image = plt.imread(output_image_filename) - expected_image = plt.imread(image_filename) - with TemporaryDirectory() as tmpdir: - output_image_filename = os.path.join(tmpdir, "image.png") - region = AudioRegion.load(audio_filename, sr=10, sw=2, ch=channels) - region.split_and_plot( - aw=0.1, - uc=use_channel, - max_silence=0, - show=False, - save_as=output_image_filename, - ) - output_image = plt.imread(output_image_filename) - self.assertTrue((output_image == expected_image).all()) - - -if __name__ == "__main__": - unittest.main() + if SAVE_NEW_IMAGES: + shutil.copy(output_image_filename, image_filename) + assert (output_image == expected_image).all()
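The SAVE_NEW_IMAGES flag added above doubles as a baseline-regeneration switch: when set to True, the freshly rendered plot is copied over the stored reference image before the pixel-exact comparison runs, so the same tests can refresh the expected images after an intentional change to the plotting code. The pattern reduced to a standalone sketch (file paths are illustrative):

import shutil

import matplotlib

matplotlib.use("AGG")  # off-screen rendering, so no display is needed in CI
import matplotlib.pyplot as plt

def compare_to_baseline(rendered_png, baseline_png, regenerate=False):
    if regenerate:
        shutil.copy(rendered_png, baseline_png)  # refresh the stored reference
    expected = plt.imread(baseline_png)
    actual = plt.imread(rendered_png)
    assert (actual == expected).all()  # exact pixel equality, as in the tests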
--- a/tests/test_signal.py Thu Mar 30 10:17:57 2023 +0100 +++ b/tests/test_signal.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,188 +1,94 @@ -import unittest -from unittest import TestCase from array import array as array_ -from genty import genty, genty_dataset + import numpy as np -from auditok import signal as signal_ -from auditok import signal_numpy +import pytest +from auditok import signal -@genty -class TestSignal(TestCase): - def setUp(self): - self.data = b"012345679ABC" - self.numpy_fmt = {"b": np.int8, "h": np.int16, "i": np.int32} +# from auditok import signal as signal_ +# from auditok import signal - @genty_dataset( - int8_mono=(1, [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]]), - int16_mono=(2, [[12592, 13106, 13620, 14134, 16697, 17218]]), - int32_mono=(4, [[858927408, 926299444, 1128415545]]), - int8_stereo=(1, [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]]), - int16_stereo=(2, [[12592, 13620, 16697], [13106, 14134, 17218]]), - int32_3channel=(4, [[858927408], [926299444], [1128415545]]), - ) - def test_to_array(self, sample_width, expected): - channels = len(expected) - expected = [ - array_(signal_.FORMAT[sample_width], xi) for xi in expected - ] - result = signal_.to_array(self.data, sample_width, channels) - result_numpy = signal_numpy.to_array(self.data, sample_width, channels) - self.assertEqual(result, expected) - self.assertTrue((result_numpy == np.asarray(expected)).all()) - self.assertEqual(result_numpy.dtype, np.float64) - @genty_dataset( - int8_1channel_select_0=( - "b", - 1, - 0, - [48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67], - ), - int8_2channel_select_0=("b", 2, 0, [48, 50, 52, 54, 57, 66]), - int8_3channel_select_0=("b", 3, 0, [48, 51, 54, 65]), - int8_3channel_select_1=("b", 3, 1, [49, 52, 55, 66]), - int8_3channel_select_2=("b", 3, 2, [50, 53, 57, 67]), - int8_4channel_select_0=("b", 4, 0, [48, 52, 57]), - int16_1channel_select_0=( - "h", - 1, - 0, - [12592, 13106, 13620, 14134, 16697, 17218], - ), - int16_2channel_select_0=("h", 2, 0, [12592, 13620, 16697]), - int16_2channel_select_1=("h", 2, 1, [13106, 14134, 17218]), - int16_3channel_select_0=("h", 3, 0, [12592, 14134]), - int16_3channel_select_1=("h", 3, 1, [13106, 16697]), - int16_3channel_select_2=("h", 3, 2, [13620, 17218]), - int32_1channel_select_0=( - "i", - 1, - 0, - [858927408, 926299444, 1128415545], - ), - int32_3channel_select_0=("i", 3, 0, [858927408]), - int32_3channel_select_1=("i", 3, 1, [926299444]), - int32_3channel_select_2=("i", 3, 2, [1128415545]), - ) - def test_extract_single_channel(self, fmt, channels, selected, expected): - result = signal_.extract_single_channel( - self.data, fmt, channels, selected - ) - expected = array_(fmt, expected) - expected_numpy_fmt = self.numpy_fmt[fmt] - self.assertEqual(result, expected) - result_numpy = signal_numpy.extract_single_channel( - self.data, self.numpy_fmt[fmt], channels, selected - ) - self.assertTrue(all(result_numpy == expected)) - self.assertEqual(result_numpy.dtype, expected_numpy_fmt) +@pytest.fixture +def setup_data(): + return b"012345679ABC" - @genty_dataset( - int8_2channel=("b", 2, [48, 50, 52, 54, 61, 66]), - int8_4channel=("b", 4, [50, 54, 64]), - int16_1channel=("h", 1, [12592, 13106, 13620, 14134, 16697, 17218]), - int16_2channel=("h", 2, [12849, 13877, 16958]), - int32_3channel=("i", 3, [971214132]), - ) - def test_compute_average_channel(self, fmt, channels, expected): - result = signal_.compute_average_channel(self.data, fmt, channels) - expected = array_(fmt, expected) - expected_numpy_fmt = 
self.numpy_fmt[fmt] - self.assertEqual(result, expected) - result_numpy = signal_numpy.compute_average_channel( - self.data, self.numpy_fmt[fmt], channels - ) - self.assertTrue(all(result_numpy == expected)) - self.assertEqual(result_numpy.dtype, expected_numpy_fmt) - @genty_dataset( - int8_2channel=(1, [48, 50, 52, 54, 61, 66]), - int16_2channel=(2, [12849, 13877, 16957]), - ) - def test_compute_average_channel_stereo(self, sample_width, expected): - result = signal_.compute_average_channel_stereo( - self.data, sample_width - ) - fmt = signal_.FORMAT[sample_width] - expected = array_(fmt, expected) - self.assertEqual(result, expected) - - @genty_dataset( - int8_1channel=( - "b", +@pytest.mark.parametrize( + "sample_width, expected", + [ + ( 1, [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]], - ), - int8_2channel=( - "b", - 2, + ), # int8_1channel + ( + 1, [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]], - ), - int8_4channel=( - "b", - 4, + ), # int8_2channel + ( + 1, [[48, 52, 57], [49, 53, 65], [50, 54, 66], [51, 55, 67]], - ), - int16_2channel=( - "h", - 2, - [[12592, 13620, 16697], [13106, 14134, 17218]], - ), - int32_3channel=("i", 3, [[858927408], [926299444], [1128415545]]), - ) - def test_separate_channels(self, fmt, channels, expected): - result = signal_.separate_channels(self.data, fmt, channels) - expected = [array_(fmt, exp) for exp in expected] - expected_numpy_fmt = self.numpy_fmt[fmt] - self.assertEqual(result, expected) + ), # int8_4channel + (2, [[12592, 13106, 13620, 14134, 16697, 17218]]), # int16_1channel + (2, [[12592, 13620, 16697], [13106, 14134, 17218]]), # int16_2channel + (4, [[858927408, 926299444, 1128415545]]), # int32_1channel + (4, [[858927408], [926299444], [1128415545]]), # int32_3channel + ], + ids=[ + "int8_1channel", + "int8_2channel", + "int8_4channel", + "int16_1channel", + "int16_2channel", + "int32_1channel", + "int32_3channel", + ], +) +def test_to_array(setup_data, sample_width, expected): + data = setup_data + channels = len(expected) + expected = np.array(expected) + result = signal.to_array(data, sample_width, channels) + assert (result == expected).all() + assert result.dtype == np.float64 + assert result.shape == expected.shape - result_numpy = signal_numpy.separate_channels( - self.data, self.numpy_fmt[fmt], channels - ) - self.assertTrue((result_numpy == expected).all()) - self.assertEqual(result_numpy.dtype, expected_numpy_fmt) - @genty_dataset( - simple=([300, 320, 400, 600], 2, 52.50624901923348), - zero=([0], 2, -200), - zeros=([0, 0, 0], 2, -200), - ) - def test_calculate_energy_single_channel(self, x, sample_width, expected): - x = array_(signal_.FORMAT[sample_width], x) - energy = signal_.calculate_energy_single_channel(x, sample_width) - self.assertEqual(energy, expected) - energy = signal_numpy.calculate_energy_single_channel(x, sample_width) - self.assertEqual(energy, expected) - - @genty_dataset( - min_=( +@pytest.mark.parametrize( + "x, aggregation_fn, expected", + [ + ([300, 320, 400, 600], None, 52.506639194632434), # mono_simple + ([0, 0, 0], None, -200), # mono_zeros + ( [[300, 320, 400, 600], [150, 160, 200, 300]], - 2, + None, + [52.506639194632434, 46.48603928135281], + ), # stereo_no_agg + ( + [[300, 320, 400, 600], [150, 160, 200, 300]], + np.mean, + 49.49633923799262, + ), # stereo_mean_agg + ( + [[300, 320, 400, 600], [150, 160, 200, 300]], min, - 46.485649105953854, - ), - max_=( + 46.48603928135281, + ), # stereo_min_agg + ( [[300, 320, 400, 600], [150, 160, 200, 300]], - 2, max, - 52.50624901923348, - ), - 
) - def test_calculate_energy_multichannel( - self, x, sample_width, aggregation_fn, expected - ): - x = [array_(signal_.FORMAT[sample_width], xi) for xi in x] - energy = signal_.calculate_energy_multichannel( - x, sample_width, aggregation_fn - ) - self.assertEqual(energy, expected) - - energy = signal_numpy.calculate_energy_multichannel( - x, sample_width, aggregation_fn - ) - self.assertEqual(energy, expected) - - -if __name__ == "__main__": - unittest.main() + 52.506639194632434, + ), # stereo_max_agg + ], + ids=[ + "mono_simple", + "mono_zeros", + "stereo_no_agg", + "mean_agg", + "stereo_min_agg", + "stereo_max_agg", + ], +) +def test_calculate_energy(x, aggregation_fn, expected): + energy = signal.calculate_energy(x, aggregation_fn) + assert (energy == expected).all()
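The expected values in test_calculate_energy above are consistent with defining signal energy as 10 * log10 of the mean squared sample value, in dB, with all-zero input reported as a floor of -200 dB (the value that formula yields for a mean power of 1e-20) rather than minus infinity. Checking the mono_simple case by hand (a verification sketch, not library code):

import numpy as np

x = np.array([300, 320, 400, 600], dtype=np.float64)
mean_power = (x**2).mean()  # (90000 + 102400 + 160000 + 360000) / 4 = 178100.0
energy_db = 10 * np.log10(mean_power)
print(energy_db)  # 52.506639194632434, matching the asserted value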
--- a/tests/test_util.py Thu Mar 30 10:17:57 2023 +0100 +++ b/tests/test_util.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,277 +1,294 @@ -import unittest -from unittest import TestCase -from unittest.mock import patch import math from array import array as array_ -from genty import genty, genty_dataset +from unittest.mock import patch + +import numpy as np +import pytest + +from auditok import signal +from auditok.exceptions import TimeFormatError from auditok.util import ( AudioEnergyValidator, + make_channel_selector, make_duration_formatter, - make_channel_selector, -) -from auditok import signal as signal_ -from auditok import signal_numpy - -from auditok.exceptions import TimeFormatError - - -def _sample_generator(*data_buffers): - """ - Takes a list of many mono audio data buffers and makes a sample generator - of interleaved audio samples, one sample from each channel. The resulting - generator can be used to build a multichannel audio buffer. - >>> gen = _sample_generator("abcd", "ABCD") - >>> list(gen) - ["a", "A", 1, 1, "c", "C", "d", "D"] - """ - frame_gen = zip(*data_buffers) - return (sample for frame in frame_gen for sample in frame) - - -def _generate_pure_tone( - frequency, duration_sec=1, sampling_rate=16000, sample_width=2, volume=1e4 -): - """ - Generates a pure tone with the given frequency. - """ - assert frequency <= sampling_rate / 2 - max_value = (2 ** (sample_width * 8) // 2) - 1 - if volume > max_value: - volume = max_value - fmt = signal_.FORMAT[sample_width] - total_samples = int(sampling_rate * duration_sec) - step = frequency / sampling_rate - two_pi_step = 2 * math.pi * step - data = array_( - fmt, - ( - int(math.sin(two_pi_step * i) * volume) - for i in range(total_samples) - ), - ) - return data - - -PURE_TONE_DICT = { - freq: _generate_pure_tone(freq, 1, 16000, 2) for freq in (400, 800, 1600) -} -PURE_TONE_DICT.update( - { - freq: _generate_pure_tone(freq, 0.1, 16000, 2) - for freq in (600, 1150, 2400, 7220) - } ) -@genty -class TestFunctions(TestCase): - def setUp(self): - self.data = b"012345679ABC" +@pytest.fixture +def setup_data(): + return b"012345679ABC" - @genty_dataset( - only_seconds=("%S", 5400, "5400.000"), - only_millis=("%I", 5400, "5400000"), - full=("%h:%m:%s.%i", 3725.365, "01:02:05.365"), - full_zero_hours=("%h:%m:%s.%i", 1925.075, "00:32:05.075"), - full_zero_minutes=("%h:%m:%s.%i", 3659.075, "01:00:59.075"), - full_zero_seconds=("%h:%m:%s.%i", 3720.075, "01:02:00.075"), - full_zero_millis=("%h:%m:%s.%i", 3725, "01:02:05.000"), - duplicate_directive=( + +@pytest.mark.parametrize( + "fmt, duration, expected", + [ + ("%S", 5400, "5400.000"), # only_seconds + ("%I", 5400, "5400000"), # only_millis + ("%h:%m:%s.%i", 3725.365, "01:02:05.365"), # full + ("%h:%m:%s.%i", 1925.075, "00:32:05.075"), # full_zero_hours + ("%h:%m:%s.%i", 3659.075, "01:00:59.075"), # full_zero_minutes + ("%h:%m:%s.%i", 3720.075, "01:02:00.075"), # full_zero_seconds + ("%h:%m:%s.%i", 3725, "01:02:05.000"), # full_zero_millis + ( "%h %h:%m:%s.%i %s", 3725.365, "01 01:02:05.365 05", - ), - no_millis=("%h:%m:%s", 3725, "01:02:05"), - no_seconds=("%h:%m", 3725, "01:02"), - no_minutes=("%h", 3725, "01"), - no_hours=("%m:%s.%i", 3725, "02:05.000"), - ) - def test_make_duration_formatter(self, fmt, duration, expected): - formatter = make_duration_formatter(fmt) - result = formatter(duration) - self.assertEqual(result, expected) + ), # duplicate_directive + ("%h:%m:%s", 3725, "01:02:05"), # no_millis + ("%h:%m", 3725, "01:02"), # no_seconds + ("%h", 3725, "01"), # no_minutes + 
("%m:%s.%i", 3725, "02:05.000"), # no_hours + ], + ids=[ + "only_seconds", + "only_millis", + "full", + "full_zero_hours", + "full_zero_minutes", + "full_zero_seconds", + "full_zero_millis", + "duplicate_directive", + "no_millis", + "no_seconds", + "no_minutes", + "no_hours", + ], +) +def test_make_duration_formatter(fmt, duration, expected): + formatter = make_duration_formatter(fmt) + result = formatter(duration) + assert result == expected - @genty_dataset( - duplicate_only_seconds=("%S %S",), - duplicate_only_millis=("%I %I",), - unknown_directive=("%x",), - ) - def test_make_duration_formatter_error(self, fmt): - with self.assertRaises(TimeFormatError): - make_duration_formatter(fmt) - @genty_dataset( - int8_1channel_select_0=( +@pytest.mark.parametrize( + "fmt", + [ + "%S %S", # duplicate_only_seconds + "%I %I", # duplicate_only_millis + "%x", # unknown_directive + ], + ids=[ + "duplicate_only_seconds", + "duplicate_only_millis", + "unknown_directive", + ], +) +def test_make_duration_formatter_error(fmt): + with pytest.raises(TimeFormatError): + make_duration_formatter(fmt) + + +@pytest.mark.parametrize( + "sample_width, channels, selected, expected", + [ + ( 1, 1, 0, [48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67], - ), - int8_2channel_select_0=(1, 2, 0, [48, 50, 52, 54, 57, 66]), - int8_3channel_select_0=(1, 3, 0, [48, 51, 54, 65]), - int8_3channel_select_1=(1, 3, 1, [49, 52, 55, 66]), - int8_3channel_select_2=(1, 3, 2, [50, 53, 57, 67]), - int8_4channel_select_0=(1, 4, 0, [48, 52, 57]), - int16_1channel_select_0=( + ), # int8_1channel_select_0 + (1, 2, 0, [48, 50, 52, 54, 57, 66]), # int8_2channel_select_0 + (1, 3, 0, [48, 51, 54, 65]), # int8_3channel_select_0 + (1, 3, 1, [49, 52, 55, 66]), # int8_3channel_select_1 + (1, 3, 2, [50, 53, 57, 67]), # int8_3channel_select_2 + (1, 4, 0, [48, 52, 57]), # int8_4channel_select_0 + ( 2, 1, 0, [12592, 13106, 13620, 14134, 16697, 17218], - ), - int16_2channel_select_0=(2, 2, 0, [12592, 13620, 16697]), - int16_2channel_select_1=(2, 2, 1, [13106, 14134, 17218]), - int16_3channel_select_0=(2, 3, 0, [12592, 14134]), - int16_3channel_select_1=(2, 3, 1, [13106, 16697]), - int16_3channel_select_2=(2, 3, 2, [13620, 17218]), - int32_1channel_select_0=(4, 1, 0, [858927408, 926299444, 1128415545],), - int32_3channel_select_0=(4, 3, 0, [858927408]), - int32_3channel_select_1=(4, 3, 1, [926299444]), - int32_3channel_select_2=(4, 3, 2, [1128415545]), - ) - def test_make_channel_selector_one_channel( - self, sample_width, channels, selected, expected - ): + ), # int16_1channel_select_0 + (2, 2, 0, [12592, 13620, 16697]), # int16_2channel_select_0 + (2, 2, 1, [13106, 14134, 17218]), # int16_2channel_select_1 + (2, 3, 0, [12592, 14134]), # int16_3channel_select_0 + (2, 3, 1, [13106, 16697]), # int16_3channel_select_1 + (2, 3, 2, [13620, 17218]), # int16_3channel_select_2 + ( + 4, + 1, + 0, + [858927408, 926299444, 1128415545], + ), # int32_1channel_select_0 + (4, 3, 0, [858927408]), # int32_3channel_select_0 + (4, 3, 1, [926299444]), # int32_3channel_select_1 + (4, 3, 2, [1128415545]), # int32_3channel_select_2 + ], + ids=[ + "int8_1channel_select_0", + "int8_2channel_select_0", + "int8_3channel_select_0", + "int8_3channel_select_1", + "int8_3channel_select_2", + "int8_4channel_select_0", + "int16_1channel_select_0", + "int16_2channel_select_0", + "int16_2channel_select_1", + "int16_3channel_select_0", + "int16_3channel_select_1", + "int16_3channel_select_2", + "int32_1channel_select_0", + "int32_3channel_select_0", + "int32_3channel_select_1", + 
"int32_3channel_select_2", + ], +) +def test_make_channel_selector_one_channel( + setup_data, sample_width, channels, selected, expected +): - # force using signal functions with standard python implementation - with patch("auditok.util.signal", signal_): - selector = make_channel_selector(sample_width, channels, selected) - result = selector(self.data) + selector = make_channel_selector(sample_width, channels, selected) + result = selector(setup_data) - fmt = signal_.FORMAT[sample_width] - expected = array_(fmt, expected) - if channels == 1: - expected = bytes(expected) - self.assertEqual(result, expected) + dtype = signal.SAMPLE_WIDTH_TO_DTYPE[sample_width] + expected = np.array(expected).astype(dtype) + assert (result == expected).all() - # Use signal functions with numpy implementation - with patch("auditok.util.signal", signal_numpy): - selector = make_channel_selector(sample_width, channels, selected) - result_numpy = selector(self.data) - expected = array_(fmt, expected) - if channels == 1: - expected = bytes(expected) - self.assertEqual(result_numpy, expected) - else: - self.assertTrue(all(result_numpy == expected)) - - @genty_dataset( - int8_2channel=(1, 2, "avg", [48, 50, 52, 54, 61, 66]), - int8_4channel=(1, 4, "average", [50, 54, 64]), - int16_1channel=( +@pytest.mark.parametrize( + "sample_width, channels, selected, expected", + [ + ( + 1, + 1, + "avg", + [48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67], + ), # int8_1channel + (1, 2, "mix", [48.5, 50.5, 52.5, 54.5, 61, 66.5]), # int8_2channel + (1, 4, "average", [49.5, 53.5, 63.75]), # int8_4channel + ( 2, 1, "mix", [12592, 13106, 13620, 14134, 16697, 17218], - ), - int16_2channel=(2, 2, "avg", [12849, 13877, 16957]), - int32_3channel=(4, 3, "average", [971214132]), - ) - def test_make_channel_selector_average( - self, sample_width, channels, selected, expected - ): - # force using signal functions with standard python implementation - with patch("auditok.util.signal", signal_): - selector = make_channel_selector(sample_width, channels, selected) - result = selector(self.data) + ), # int16_1channel + (2, 2, "avg", [12849, 13877, 16957.5]), # int16_2channel + (4, 3, "average", [971214132.33]), # int32_3channel + ], + ids=[ + "int8_1channel", + "int8_2channel", + "int8_4channel", + "int16_1channel", + "int16_2channel", + "int32_3channel", + ], +) +def test_make_channel_selector_average( + setup_data, sample_width, channels, selected, expected +): - fmt = signal_.FORMAT[sample_width] - expected = array_(fmt, expected) - if channels == 1: - expected = bytes(expected) - self.assertEqual(result, expected) + selector = make_channel_selector(sample_width, channels, selected) + result = selector(setup_data).round(2) + assert (result == expected).all() - # Use signal functions with numpy implementation - with patch("auditok.util.signal", signal_numpy): - selector = make_channel_selector(sample_width, channels, selected) - result_numpy = selector(self.data) - if channels in (1, 2): - self.assertEqual(result_numpy, expected) - else: - self.assertTrue(all(result_numpy == expected)) - - @genty_dataset( - int8_1channel=( +@pytest.mark.parametrize( + "sample_width, channels, selected, expected", + [ + ( 1, 1, "any", [[48, 49, 50, 51, 52, 53, 54, 55, 57, 65, 66, 67]], - ), - int8_2channel=( + ), # int8_1channel + ( 1, 2, None, [[48, 50, 52, 54, 57, 66], [49, 51, 53, 55, 65, 67]], - ), - int8_4channel=( + ), # int8_2channel + ( 1, 4, "any", [[48, 52, 57], [49, 53, 65], [50, 54, 66], [51, 55, 67]], - ), - int16_2channel=( + ), # int8_4channel + 
( 2, 2, None, [[12592, 13620, 16697], [13106, 14134, 17218]], - ), - int32_3channel=(4, 3, "any", [[858927408], [926299444], [1128415545]]), - ) - def test_make_channel_selector_any( - self, sample_width, channels, selected, expected - ): + ), # int16_2channel + ( + 4, + 3, + "any", + [[858927408], [926299444], [1128415545]], + ), # int32_3channel + ], + ids=[ + "int8_1channel", + "int8_2channel", + "int8_4channel", + "int16_2channel", + "int32_3channel", + ], +) +def test_make_channel_selector_any( + setup_data, sample_width, channels, selected, expected +): - # force using signal functions with standard python implementation - with patch("auditok.util.signal", signal_): - selector = make_channel_selector(sample_width, channels, selected) - result = selector(self.data) + # Use signal functions with numpy implementation + selector = make_channel_selector(sample_width, channels, selected) + result = selector(setup_data) + assert (result == expected).all() - fmt = signal_.FORMAT[sample_width] - expected = [array_(fmt, exp) for exp in expected] - if channels == 1: - expected = bytes(expected[0]) - self.assertEqual(result, expected) - # Use signal functions with numpy implementation - with patch("auditok.util.signal", signal_numpy): - selector = make_channel_selector(sample_width, channels, selected) - result_numpy = selector(self.data) - - if channels == 1: - self.assertEqual(result_numpy, expected) - else: - self.assertTrue((result_numpy == expected).all()) - - -@genty -class TestAudioEnergyValidator(TestCase): - @genty_dataset( - mono_valid_uc_None=([350, 400], 1, None, True), - mono_valid_uc_any=([350, 400], 1, "any", True), - mono_valid_uc_0=([350, 400], 1, 0, True), - mono_valid_uc_mix=([350, 400], 1, "mix", True), - # previous cases are all the same since we have mono audio - mono_invalid_uc_None=([300, 300], 1, None, False), - stereo_valid_uc_None=([300, 400, 350, 300], 2, None, True), - stereo_valid_uc_any=([300, 400, 350, 300], 2, "any", True), - stereo_valid_uc_mix=([300, 400, 350, 300], 2, "mix", True), - stereo_valid_uc_avg=([300, 400, 350, 300], 2, "avg", True), - stereo_valid_uc_average=([300, 400, 300, 300], 2, "average", True), - stereo_valid_uc_mix_with_null_channel=( - [634, 0, 634, 0], - 2, - "mix", - True, - ), - stereo_valid_uc_0=([320, 100, 320, 100], 2, 0, True), - stereo_valid_uc_1=([100, 320, 100, 320], 2, 1, True), - stereo_invalid_uc_None=([280, 100, 280, 100], 2, None, False), - stereo_invalid_uc_any=([280, 100, 280, 100], 2, "any", False), - stereo_invalid_uc_mix=([400, 200, 400, 200], 2, "mix", False), - stereo_invalid_uc_0=([300, 400, 300, 400], 2, 0, False), - stereo_invalid_uc_1=([400, 300, 400, 300], 2, 1, False), - zeros=([0, 0, 0, 0], 2, None, False), +class TestAudioEnergyValidator: + @pytest.mark.parametrize( + "data, channels, use_channel, expected", + [ + ([350, 400], 1, None, True), # mono_valid_uc_None + ([350, 400], 1, "any", True), # mono_valid_uc_any + ([350, 400], 1, 0, True), # mono_valid_uc_0 + ([350, 400], 1, "mix", True), # mono_valid_uc_mix + ([300, 300], 1, None, False), # mono_invalid_uc_None + ([300, 400, 350, 300], 2, None, True), # stereo_valid_uc_None + ([300, 400, 350, 300], 2, "any", True), # stereo_valid_uc_any + ([300, 400, 350, 300], 2, "mix", True), # stereo_valid_uc_mix + ([300, 400, 350, 300], 2, "avg", True), # stereo_valid_uc_avg + ( + [300, 400, 300, 300], + 2, + "average", + True, + ), # stereo_valid_uc_average + ( + [634, 0, 634, 0], + 2, + "mix", + True, + ), # stereo_valid_uc_mix_with_null_channel + ([320, 100, 320, 100], 
2, 0, True), # stereo_valid_uc_0 + ([100, 320, 100, 320], 2, 1, True), # stereo_valid_uc_1 + ([280, 100, 280, 100], 2, None, False), # stereo_invalid_uc_None + ([280, 100, 280, 100], 2, "any", False), # stereo_invalid_uc_any + ([400, 200, 400, 200], 2, "mix", False), # stereo_invalid_uc_mix + ([300, 400, 300, 400], 2, 0, False), # stereo_invalid_uc_0 + ([400, 300, 400, 300], 2, 1, False), # stereo_invalid_uc_1 + ([0, 0, 0, 0], 2, None, False), # zeros + ], + ids=[ + "mono_valid_uc_None", + "mono_valid_uc_any", + "mono_valid_uc_0", + "mono_valid_uc_mix", + "mono_invalid_uc_None", + "stereo_valid_uc_None", + "stereo_valid_uc_any", + "stereo_valid_uc_mix", + "stereo_valid_uc_avg", + "stereo_valid_uc_average", + "stereo_valid_uc_mix_with_null_channel", + "stereo_valid_uc_0", + "stereo_valid_uc_1", + "stereo_invalid_uc_None", + "stereo_invalid_uc_any", + "stereo_invalid_uc_mix", + "stereo_invalid_uc_0", + "stereo_invalid_uc_1", + "zeros", + ], ) def test_audio_energy_validator( self, data, channels, use_channel, expected @@ -285,10 +302,6 @@ ) if expected: - self.assertTrue(validator.is_valid(data)) + assert validator.is_valid(data) else: - self.assertFalse(validator.is_valid(data)) - - -if __name__ == "__main__": - unittest.main() + assert not validator.is_valid(data)
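The selector tests above drive `make_channel_selector(sample_width, channels, selected)`, which turns interleaved PCM bytes into per-channel or mixed numpy samples. A minimal usage sketch, assuming only the positional signature and the numpy return values implied by these tests:

    import numpy as np

    from auditok.util import make_channel_selector

    # Two frames of interleaved 16-bit stereo PCM.
    data = np.array([100, -100, 200, -200], dtype=np.int16).tobytes()

    left = make_channel_selector(2, 2, 0)     # keep channel 0
    mix = make_channel_selector(2, 2, "avg")  # per-frame channel average

    print(left(data))  # [100 200]
    print(mix(data))   # [0. 0.]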
--- a/tests/test_workers.py Thu Mar 30 10:17:57 2023 +0100 +++ b/tests/test_workers.py Wed Oct 30 17:17:59 2024 +0000 @@ -1,231 +1,94 @@ import os -import unittest -from unittest import TestCase -from unittest.mock import patch, call, Mock from tempfile import TemporaryDirectory -from genty import genty, genty_dataset -from auditok import AudioRegion, AudioDataSource -from auditok.exceptions import AudioEncodingWarning +from unittest.mock import Mock, call, patch + +import pytest + +import auditok.workers +from auditok import AudioReader, AudioRegion, split, split_and_join_with_silence from auditok.cmdline_util import make_logger from auditok.workers import ( + AudioEventsJoinerWorker, + CommandLineWorker, + PlayerWorker, + PrintWorker, + RegionSaverWorker, + StreamSaverWorker, TokenizerWorker, - StreamSaverWorker, - RegionSaverWorker, - PlayerWorker, - CommandLineWorker, - PrintWorker, ) -@genty -class TestWorkers(TestCase): - def setUp(self): +@pytest.fixture +def audio_data_source(): + reader = AudioReader( + input="tests/data/test_split_10HZ_mono.raw", + block_dur=0.1, + sr=10, + sw=2, + ch=1, + ) + yield reader + reader.close() - self.reader = AudioDataSource( - input="tests/data/test_split_10HZ_mono.raw", - block_dur=0.1, - sr=10, - sw=2, - ch=1, + +@pytest.fixture +def expected_detections(): + return [ + (0.2, 1.6), + (1.7, 3.1), + (3.4, 5.4), + (5.4, 7.4), + (7.4, 7.6), + ] + + +def test_TokenizerWorker(audio_data_source, expected_detections): + with TemporaryDirectory() as tmpdir: + file = os.path.join(tmpdir, "file.log") + logger = make_logger(file=file, name="test_TokenizerWorker") + tokenizer = TokenizerWorker( + audio_data_source, + logger=logger, + min_dur=0.3, + max_dur=2, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + eth=50, ) - self.expected = [ - (0.2, 1.6), - (1.7, 3.1), - (3.4, 5.4), - (5.4, 7.4), - (7.4, 7.6), - ] + tokenizer.start_all() + tokenizer.join() + with open(file) as fp: + log_lines = fp.readlines() - def tearDown(self): - self.reader.close() + log_fmt = ( + "[DET]: Detection {} (start: {:.3f}, end: {:.3f}, duration: {:.3f})" + ) + assert len(tokenizer.detections) == len(expected_detections) + for i, (det, exp, log_line) in enumerate( + zip( + tokenizer.detections, + expected_detections, + log_lines, + ), + 1, + ): + start, end = exp + exp_log_line = log_fmt.format(i, start, end, end - start) + assert pytest.approx(det.start) == start + assert pytest.approx(det.end) == end + assert log_line[28:].strip() == exp_log_line - def test_TokenizerWorker(self): - with TemporaryDirectory() as tmpdir: - file = os.path.join(tmpdir, "file.log") - logger = make_logger(file=file, name="test_TokenizerWorker") - tokenizer = TokenizerWorker( - self.reader, - logger=logger, - min_dur=0.3, - max_dur=2, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - eth=50, - ) - tokenizer.start_all() - tokenizer.join() - # Get logged text - with open(file) as fp: - log_lines = fp.readlines() - log_fmt = "[DET]: Detection {} (start: {:.3f}, " - log_fmt += "end: {:.3f}, duration: {:.3f})" - self.assertEqual(len(tokenizer.detections), len(self.expected)) - for i, (det, exp, log_line) in enumerate( - zip(tokenizer.detections, self.expected, log_lines), 1 - ): - start, end = exp - exp_log_line = log_fmt.format(i, start, end, end - start) - self.assertAlmostEqual(det.start, start) - self.assertAlmostEqual(det.end, end) - # remove timestamp part and strip new line - self.assertEqual(log_line[28:].strip(), exp_log_line) - - def 
test_PlayerWorker(self): - with TemporaryDirectory() as tmpdir: - file = os.path.join(tmpdir, "file.log") - logger = make_logger(file=file, name="test_RegionSaverWorker") - player_mock = Mock() - observers = [PlayerWorker(player_mock, logger=logger)] - tokenizer = TokenizerWorker( - self.reader, - logger=logger, - observers=observers, - min_dur=0.3, - max_dur=2, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - eth=50, - ) - tokenizer.start_all() - tokenizer.join() - tokenizer._observers[0].join() - # Get logged text - with open(file) as fp: - log_lines = [ - line - for line in fp.readlines() - if line.startswith("[PLAY]") - ] - self.assertTrue(player_mock.play.called) - - self.assertEqual(len(tokenizer.detections), len(self.expected)) - log_fmt = "[PLAY]: Detection {id} played" - for i, (det, exp, log_line) in enumerate( - zip(tokenizer.detections, self.expected, log_lines), 1 - ): - start, end = exp - exp_log_line = log_fmt.format(id=i) - self.assertAlmostEqual(det.start, start) - self.assertAlmostEqual(det.end, end) - # Remove timestamp part and strip new line - self.assertEqual(log_line[28:].strip(), exp_log_line) - - def test_RegionSaverWorker(self): - filename_format = ( - "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav" - ) - with TemporaryDirectory() as tmpdir: - file = os.path.join(tmpdir, "file.log") - logger = make_logger(file=file, name="test_RegionSaverWorker") - observers = [RegionSaverWorker(filename_format, logger=logger)] - tokenizer = TokenizerWorker( - self.reader, - logger=logger, - observers=observers, - min_dur=0.3, - max_dur=2, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - eth=50, - ) - with patch("auditok.core.AudioRegion.save") as patched_save: - tokenizer.start_all() - tokenizer.join() - tokenizer._observers[0].join() - # Get logged text - with open(file) as fp: - log_lines = [ - line - for line in fp.readlines() - if line.startswith("[SAVE]") - ] - - # Assert RegionSaverWorker ran as expected - expected_save_calls = [ - call( - filename_format.format( - id=i, start=exp[0], end=exp[1], duration=exp[1] - exp[0] - ), - None, - ) - for i, exp in enumerate(self.expected, 1) - ] - - # Get calls to 'AudioRegion.save' - mock_calls = [ - c for i, c in enumerate(patched_save.mock_calls) if i % 2 == 0 - ] - self.assertEqual(mock_calls, expected_save_calls) - self.assertEqual(len(tokenizer.detections), len(self.expected)) - - log_fmt = "[SAVE]: Detection {id} saved as '{filename}'" - for i, (det, exp, log_line) in enumerate( - zip(tokenizer.detections, self.expected, log_lines), 1 - ): - start, end = exp - expected_filename = filename_format.format( - id=i, start=start, end=end, duration=end - start - ) - exp_log_line = log_fmt.format(i, expected_filename) - self.assertAlmostEqual(det.start, start) - self.assertAlmostEqual(det.end, end) - # Remove timestamp part and strip new line - self.assertEqual(log_line[28:].strip(), exp_log_line) - - def test_CommandLineWorker(self): - command_format = "do nothing with" - with TemporaryDirectory() as tmpdir: - file = os.path.join(tmpdir, "file.log") - logger = make_logger(file=file, name="test_CommandLineWorker") - observers = [CommandLineWorker(command_format, logger=logger)] - tokenizer = TokenizerWorker( - self.reader, - logger=logger, - observers=observers, - min_dur=0.3, - max_dur=2, - max_silence=0.2, - drop_trailing_silence=False, - strict_min_dur=False, - eth=50, - ) - with patch("auditok.workers.os.system") as patched_os_system: - tokenizer.start_all() - 
tokenizer.join() - tokenizer._observers[0].join() - # Get logged text - with open(file) as fp: - log_lines = [ - line - for line in fp.readlines() - if line.startswith("[COMMAND]") - ] - - # Assert CommandLineWorker ran as expected - expected_save_calls = [call(command_format) for _ in self.expected] - self.assertEqual(patched_os_system.mock_calls, expected_save_calls) - self.assertEqual(len(tokenizer.detections), len(self.expected)) - log_fmt = "[COMMAND]: Detection {id} command '{command}'" - for i, (det, exp, log_line) in enumerate( - zip(tokenizer.detections, self.expected, log_lines), 1 - ): - start, end = exp - exp_log_line = log_fmt.format(i, command_format) - self.assertAlmostEqual(det.start, start) - self.assertAlmostEqual(det.end, end) - # Remove timestamp part and strip new line - self.assertEqual(log_line[28:].strip(), exp_log_line) - - def test_PrintWorker(self): - observers = [ - PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}") - ] +def test_PlayerWorker(audio_data_source, expected_detections): + with TemporaryDirectory() as tmpdir: + file = os.path.join(tmpdir, "file.log") + logger = make_logger(file=file, name="test_RegionSaverWorker") + player_mock = Mock() + observers = [PlayerWorker(player_mock, logger=logger)] tokenizer = TokenizerWorker( - self.reader, + audio_data_source, + logger=logger, observers=observers, min_dur=0.3, max_dur=2, @@ -234,121 +97,320 @@ strict_min_dur=False, eth=50, ) - with patch("builtins.print") as patched_print: + tokenizer.start_all() + tokenizer.join() + tokenizer._observers[0].join() + with open(file) as fp: + log_lines = [ + line for line in fp.readlines() if line.startswith("[PLAY]") + ] + + assert player_mock.play.called + assert len(tokenizer.detections) == len(expected_detections) + log_fmt = "[PLAY]: Detection {id} played" + for i, (det, exp, log_line) in enumerate( + zip( + tokenizer.detections, + expected_detections, + log_lines, + ), + 1, + ): + start, end = exp + exp_log_line = log_fmt.format(id=i) + assert pytest.approx(det.start) == start + assert pytest.approx(det.end) == end + assert log_line[28:].strip() == exp_log_line + + +def test_RegionSaverWorker(audio_data_source, expected_detections): + filename_format = "Region_{id}_{start:.6f}-{end:.3f}_{duration:.3f}.wav" + with TemporaryDirectory() as tmpdir: + file = os.path.join(tmpdir, "file.log") + logger = make_logger(file=file, name="test_RegionSaverWorker") + observers = [RegionSaverWorker(filename_format, logger=logger)] + tokenizer = TokenizerWorker( + audio_data_source, + logger=logger, + observers=observers, + min_dur=0.3, + max_dur=2, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + eth=50, + ) + with patch("auditok.core.AudioRegion.save") as patched_save: tokenizer.start_all() tokenizer.join() tokenizer._observers[0].join() + with open(file) as fp: + log_lines = [ + line for line in fp.readlines() if line.startswith("[SAVE]") + ] - # Assert PrintWorker ran as expected - expected_print_calls = [ - call( - "[{}] {:.3f} {:.3f}, dur: {:.3f}".format( - i, exp[0], exp[1], exp[1] - exp[0] - ) + expected_save_calls = [ + call( + filename_format.format( + id=i, start=exp[0], end=exp[1], duration=exp[1] - exp[0] + ), + None, + ) + for i, exp in enumerate(expected_detections, 1) + ] + + mock_calls = [ + c for i, c in enumerate(patched_save.mock_calls) if i % 2 == 0 + ] + assert mock_calls == expected_save_calls + assert len(tokenizer.detections) == len(expected_detections) + + log_fmt = "[SAVE]: Detection {id} saved as '{filename}'" + for 
i, (det, exp, log_line) in enumerate( + zip( + tokenizer.detections, + expected_detections, + log_lines, + ), + 1, + ): + start, end = exp + expected_filename = filename_format.format( + id=i, start=start, end=end, duration=end - start + ) + exp_log_line = log_fmt.format(id=i, filename=expected_filename) + assert pytest.approx(det.start) == start + assert pytest.approx(det.end) == end + assert log_line[28:].strip() == exp_log_line + + +def test_CommandLineWorker(audio_data_source, expected_detections): + command_format = "do nothing with" + with TemporaryDirectory() as tmpdir: + file = os.path.join(tmpdir, "file.log") + logger = make_logger(file=file, name="test_CommandLineWorker") + observers = [CommandLineWorker(command_format, logger=logger)] + tokenizer = TokenizerWorker( + audio_data_source, + logger=logger, + observers=observers, + min_dur=0.3, + max_dur=2, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + eth=50, + ) + with patch("auditok.workers.os.system") as patched_os_system: + tokenizer.start_all() + tokenizer.join() + tokenizer._observers[0].join() + with open(file) as fp: + log_lines = [ + line for line in fp.readlines() if line.startswith("[COMMAND]") + ] + + expected_save_calls = [call(command_format) for _ in expected_detections] + assert patched_os_system.mock_calls == expected_save_calls + assert len(tokenizer.detections) == len(expected_detections) + log_fmt = "[COMMAND]: Detection {id} command '{command}'" + for i, (det, exp, log_line) in enumerate( + zip( + tokenizer.detections, + expected_detections, + log_lines, + ), + 1, + ): + start, end = exp + exp_log_line = log_fmt.format(id=i, command=command_format) + assert pytest.approx(det.start) == start + assert pytest.approx(det.end) == end + assert log_line[28:].strip() == exp_log_line + + +def test_PrintWorker(audio_data_source, expected_detections): + observers = [ + PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}") + ] + tokenizer = TokenizerWorker( + audio_data_source, + observers=observers, + min_dur=0.3, + max_dur=2, + max_silence=0.2, + drop_trailing_silence=False, + strict_min_dur=False, + eth=50, + ) + with patch("builtins.print") as patched_print: + tokenizer.start_all() + tokenizer.join() + tokenizer._observers[0].join() + + expected_print_calls = [ + call( + "[{}] {:.3f} {:.3f}, dur: {:.3f}".format( + i, exp[0], exp[1], exp[1] - exp[0] ) - for i, exp in enumerate(self.expected, 1) - ] - self.assertEqual(patched_print.mock_calls, expected_print_calls) - self.assertEqual(len(tokenizer.detections), len(self.expected)) - for det, exp in zip(tokenizer.detections, self.expected): - start, end = exp - self.assertAlmostEqual(det.start, start) - self.assertAlmostEqual(det.end, end) + ) + for i, exp in enumerate(expected_detections, 1) + ] + assert patched_print.mock_calls == expected_print_calls + assert len(tokenizer.detections) == len(expected_detections) + for det, exp in zip( + tokenizer.detections, + expected_detections, + ): + start, end = exp + assert pytest.approx(det.start) == start + assert pytest.approx(det.end) == end - def test_StreamSaverWorker_wav(self): - with TemporaryDirectory() as tmpdir: - expected_filename = os.path.join(tmpdir, "output.wav") - saver = StreamSaverWorker(self.reader, expected_filename) + +def test_StreamSaverWorker_wav(audio_data_source): + with TemporaryDirectory() as tmpdir: + expected_filename = os.path.join(tmpdir, "output.wav") + saver = StreamSaverWorker(audio_data_source, expected_filename) + saver.start() + + tokenizer = 
TokenizerWorker(saver)
+        tokenizer.start_all()
+        tokenizer.join()
+        saver.join()
+
+        output_filename = saver.export_audio()
+        region = AudioRegion.load(
+            "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
+        )
+
+        expected_region = AudioRegion.load(output_filename)
+        assert output_filename == expected_filename
+        assert region == expected_region
+        assert saver.data == bytes(expected_region)
+
+
+@pytest.mark.parametrize(
+    "export_format",
+    [
+        "raw",  # raw
+        "wav",  # wav
+    ],
+    ids=[
+        "raw",
+        "wav",
+    ],
+)
+def test_StreamSaverWorker(audio_data_source, export_format):
+    with TemporaryDirectory() as tmpdir:
+        expected_filename = os.path.join(tmpdir, f"output.{export_format}")
+        saver = StreamSaverWorker(
+            audio_data_source, expected_filename, export_format=export_format
+        )
+        saver.start()
+        tokenizer = TokenizerWorker(saver)
+        tokenizer.start_all()
+        tokenizer.join()
+        saver.join()
+        output_filename = saver.export_audio()
+        region = AudioRegion.load(
+            "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
+        )
+        expected_region = AudioRegion.load(
+            output_filename, sr=10, sw=2, ch=1, audio_format=export_format
+        )
+        assert output_filename == expected_filename
+        assert region == expected_region
+        assert saver.data == bytes(expected_region)
+
+
+def test_StreamSaverWorker_encode_audio(audio_data_source):
+    with TemporaryDirectory() as tmpdir:
+        with patch("auditok.workers._run_subprocess") as patch_rsp:
+            patch_rsp.return_value = (1, None, None)
+            expected_filename = os.path.join(tmpdir, "output.ogg")
+            tmp_expected_filename = expected_filename + ".wav"
+            saver = StreamSaverWorker(audio_data_source, expected_filename)
             saver.start()
-            tokenizer = TokenizerWorker(saver)
             tokenizer.start_all()
             tokenizer.join()
             saver.join()
-            output_filename = saver.save_stream()
-            region = AudioRegion.load(
-                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
-            )
+            with pytest.raises(auditok.workers.AudioEncodingError) as ae_error:
+                saver._encode_export_audio()
-            expected_region = AudioRegion.load(output_filename)
-            self.assertEqual(output_filename, expected_filename)
-            self.assertEqual(region, expected_region)
-            self.assertEqual(saver.data, bytes(expected_region))
+            warn_msg = "Couldn't save audio data in the desired format "
+            warn_msg += "'ogg'.\nEither none of 'ffmpeg', 'avconv' or 'sox' "
+            warn_msg += "is installed or this format is not recognized.\n"
+            warn_msg += "Audio file was saved as '{}'"
+            assert warn_msg.format(tmp_expected_filename) == str(ae_error.value)
+            ffmpeg_avconv = [
+                "-y",
+                "-f",
+                "wav",
+                "-i",
+                tmp_expected_filename,
+                "-f",
+                "ogg",
+                expected_filename,
+            ]
+            expected_calls = [
+                call(["ffmpeg"] + ffmpeg_avconv),
+                call(["avconv"] + ffmpeg_avconv),
+                call(
+                    [
+                        "sox",
+                        "-t",
+                        "wav",
+                        tmp_expected_filename,
+                        expected_filename,
+                    ]
+                ),
+            ]
+            assert patch_rsp.mock_calls == expected_calls
+            region = AudioRegion.load(
+                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
+            )
+            assert not saver._exported
+            assert saver.data == bytes(region)
-    def test_StreamSaverWorker_raw(self):
-        with TemporaryDirectory() as tmpdir:
-            expected_filename = os.path.join(tmpdir, "output")
-            saver = StreamSaverWorker(
-                self.reader, expected_filename, export_format="raw"
-            )
-            saver.start()
-            tokenizer = TokenizerWorker(saver)
-            tokenizer.start_all()
-            tokenizer.join()
-            saver.join()
-            output_filename = saver.save_stream()
-            region = AudioRegion.load(
-                "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
-            )
-            expected_region = AudioRegion.load(
-                output_filename, sr=10, sw=2, ch=1, audio_format="raw"
-            )
-            self.assertEqual(output_filename, expected_filename)
-            self.assertEqual(region, expected_region)
-            self.assertEqual(saver.data, bytes(expected_region))
-    def test_StreamSaverWorker_encode_audio(self):
-        with TemporaryDirectory() as tmpdir:
-            with patch("auditok.workers._run_subprocess") as patch_rsp:
-                patch_rsp.return_value = (1, None, None)
-                expected_filename = os.path.join(tmpdir, "output.ogg")
-                tmp_expected_filename = expected_filename + ".wav"
-                saver = StreamSaverWorker(self.reader, expected_filename)
-                saver.start()
-                tokenizer = TokenizerWorker(saver)
-                tokenizer.start_all()
-                tokenizer.join()
-                saver.join()
-                with self.assertRaises(AudioEncodingWarning) as rt_warn:
-                    saver.save_stream()
-                warn_msg = "Couldn't save audio data in the desired format "
-                warn_msg += "'ogg'. Either none of 'ffmpeg', 'avconv' or 'sox' "
-                warn_msg += "is installed or this format is not recognized.\n"
-                warn_msg += "Audio file was saved as '{}'"
-                self.assertEqual(
-                    warn_msg.format(tmp_expected_filename), str(rt_warn.exception)
-                )
-                ffmpef_avconv = [
-                    "-y",
-                    "-f",
-                    "wav",
-                    "-i",
-                    tmp_expected_filename,
-                    "-f",
-                    "ogg",
-                    expected_filename,
-                ]
-                expected_calls = [
-                    call(["ffmpeg"] + ffmpef_avconv),
-                    call(["avconv"] + ffmpef_avconv),
-                    call(
-                        [
-                            "sox",
-                            "-t",
-                            "wav",
-                            tmp_expected_filename,
-                            expected_filename,
-                        ]
-                    ),
-                ]
-                self.assertEqual(patch_rsp.mock_calls, expected_calls)
-                region = AudioRegion.load(
-                    "tests/data/test_split_10HZ_mono.raw", sr=10, sw=2, ch=1
-                )
-                self.assertTrue(saver._exported)
-                self.assertEqual(saver.data, bytes(region))
+@pytest.mark.parametrize(
+    "export_format",
+    [
+        "raw",  # raw
+        "wav",  # wav
+    ],
+    ids=[
+        "raw",
+        "wav",
+    ],
+)
+def test_AudioEventsJoinerWorker(audio_data_source, export_format):
+    with TemporaryDirectory() as tmpdir:
+        expected_filename = os.path.join(tmpdir, f"output.{export_format}")
+        joiner = AudioEventsJoinerWorker(
+            silence_duration=1.0,
+            filename=expected_filename,
+            export_format=export_format,
+            sampling_rate=audio_data_source.sampling_rate,
+            sample_width=audio_data_source.sample_width,
+            channels=audio_data_source.channels,
+        )
+        tokenizer = TokenizerWorker(audio_data_source, observers=[joiner])
+        tokenizer.start_all()
+        tokenizer.join()
+        joiner.join()
-if __name__ == "__main__":
-    unittest.main()
+        output_filename = joiner.export_audio()
+        expected_region = split_and_join_with_silence(
+            "tests/data/test_split_10HZ_mono.raw",
+            silence_duration=1.0,
+            sr=10,
+            sw=2,
+            ch=1,
+            aw=0.1,
+        )
+        assert output_filename == expected_filename
+        assert joiner.data == bytes(expected_region)
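Taken together, the reworked worker tests describe the processing pipeline: an `AudioReader` feeds a `TokenizerWorker`, whose observers (here a `PrintWorker`) consume detected events. Below is a minimal driver assembled from the exact calls exercised in these tests; the data file and parameters are those of the fixtures above.

    from auditok import AudioReader
    from auditok.workers import PrintWorker, TokenizerWorker

    reader = AudioReader(
        input="tests/data/test_split_10HZ_mono.raw",
        block_dur=0.1,
        sr=10,
        sw=2,
        ch=1,
    )
    observers = [
        PrintWorker(print_format="[{id}] {start} {end}, dur: {duration}")
    ]
    tokenizer = TokenizerWorker(
        reader,
        observers=observers,
        min_dur=0.3,
        max_dur=2,
        max_silence=0.2,
        drop_trailing_silence=False,
        strict_min_dur=False,
        eth=50,  # energy threshold, as used throughout the tests
    )
    tokenizer.start_all()  # start reader/observer threads and tokenize
    tokenizer.join()
    tokenizer._observers[0].join()  # the tests join observers the same way
    reader.close()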