view auditok/workers.py @ 274:961f35fc09a8

Add test_workers.py - Add test for TokenizerWorker
author Amine Sehili <amine.sehili@gmail.com>
date Wed, 18 Sep 2019 20:21:10 +0200
parents eb77a08a608a
children f0252da17455
line wrap: on
line source
from __future__ import print_function
import os
import sys
from abc import ABCMeta, abstractmethod
from threading import Thread
from datetime import datetime, timedelta
from collections import namedtuple
import wave
import subprocess
from .io import _guess_audio_format
from .util import AudioDataSource

try:
    import future
    from queue import Queue, Empty
except ImportError:
    if sys.version_info >= (3, 0):
        from queue import Queue, Empty
    else:
        from Queue import Queue, Empty

from .core import split
from . import cmdline_util

# Sentinel placed in a worker's inbox to tell it to leave its processing loop
# (compared against incoming messages in Worker.run and
# Worker._stop_requested).
_STOP_PROCESSING = "STOP_PROCESSING"
# Immutable record of one detection: its sequential id, its start and end
# positions and its duration, as collected by TokenizerWorker.run.
_Detection = namedtuple("_Detection", "id start end duration")


class EndOfProcessing(Exception):
    """Exception type available to signal that processing has ended."""


class AudioEncodingError(Exception):
    """Raised when audio data cannot be exported by an external encoder."""


def _run_subprocess(command):
    """Run an external command and collect its exit status and output.

    The child's standard input is connected to the null device and both its
    standard output and standard error are captured.

    :param command: argument list handed unchanged to ``subprocess.Popen``.
    :returns: a ``(returncode, stdout, stderr)`` tuple; ``stdout`` and
        ``stderr`` are the bytes captured from the child process.
    :raises AudioEncodingError: if the command cannot be started or an error
        occurs while communicating with it.
    """
    try:
        # Open the null device in a ``with`` block so the handle is closed
        # once the child has finished instead of being leaked.
        with open(os.devnull, "rb") as devnull:
            # ``Popen`` is not used as a context manager: that form does not
            # exist on Python 2, which this module still supports.
            proc = subprocess.Popen(
                command,
                stdin=devnull,
                stdout=subprocess.PIPE,
                # Without a pipe here ``communicate`` returns ``None`` for
                # stderr and callers have no diagnostic to report.
                stderr=subprocess.PIPE,
            )
            stdout, stderr = proc.communicate()
        return proc.returncode, stdout, stderr
    except Exception:
        # Translate any launch/communication failure into the domain error
        # callers use to fall back to another encoder. ``Exception`` (rather
        # than a bare ``except``) lets KeyboardInterrupt/SystemExit through.
        err_msg = "Can not export audio with command: {}".format(command)
        raise AudioEncodingError(err_msg)


class Worker(Thread):
    """Thread that consumes messages from a private inbox queue.

    Messages are delivered with :meth:`send`. The thread body polls the inbox
    and hands every message to :meth:`_process_message` until the
    ``_STOP_PROCESSING`` sentinel arrives, then calls :meth:`_post_process`.

    Note that ``_process_message`` is decorated with ``abstractmethod`` but
    the class does not use an ``ABCMeta`` metaclass, so the decorator is
    informational only and instantiation is not blocked.

    :param timeout: number of seconds to wait for a message on each poll of
        the inbox.
    :param logger: optional object with a ``warning(message)`` method used by
        :meth:`_log`; ``None`` disables logging.
    """

    def __init__(self, timeout=0.5, logger=None):
        self._timeout = timeout
        self._logger = logger
        # Set from outside (see TokenizerWorker._init_start_processing_timestamp).
        self._start_processing_timestamp = None
        self._inbox = Queue()
        Thread.__init__(self)

    def run(self):
        """Process inbox messages until the stop sentinel is received."""
        while True:
            message = self._get_message()
            if message == _STOP_PROCESSING:
                break
            # ``None`` means the poll timed out without a message.
            if message is not None:
                self._process_message(message)
        self._post_process()

    @abstractmethod
    def _process_message(self, message):
        """Process incoming messages"""

    def _post_process(self):
        """Hook called once after the processing loop ends; no-op here."""
        pass

    def _log(self, message):
        """Emit `message` through the logger's ``warning`` method.

        Does nothing when no logger was configured, so callers do not need to
        guard every call.
        """
        if self._logger is not None:
            self._logger.warning(message)

    def _stop_requested(self):
        """Return True if the next pending inbox message is the stop sentinel.

        The check is non-blocking. The message that is examined is removed
        from the inbox whether or not it is the sentinel. Returns False when
        the inbox is empty or the message is anything else.
        """
        try:
            message = self._inbox.get_nowait()
        except Empty:
            return False
        if message == _STOP_PROCESSING:
            return True
        return False

    def stop(self):
        """Send the stop sentinel to this worker and wait for it to finish."""
        self.send(_STOP_PROCESSING)
        self.join()

    def send(self, message):
        """Put `message` in this worker's inbox."""
        self._inbox.put(message)

    def _get_message(self):
        """Return the next inbox message, or None if the poll times out."""
        try:
            return self._inbox.get(timeout=self._timeout)
        except Empty:
            return None


class TokenizerWorker(Worker, AudioDataSource):
    """Worker that tokenizes an audio stream and dispatches the detections.

    The worker passes itself to ``split`` as the audio source, so the
    tokenizer reads data through :meth:`read`, which stops delivering data as
    soon as a stop request is pending. Every region produced is recorded in
    :attr:`detections` and forwarded to all observers as an
    ``(id, audio_region)`` tuple.

    :param reader: underlying audio source; must provide ``open``, ``read``
        and ``close``. Unknown attributes of this worker are looked up on it.
    :param observers: optional list of workers that receive the detections.
    :param logger: optional logger, see :class:`Worker`.
    :param kwargs: extra keyword arguments forwarded to ``split``.
    """

    def __init__(self, reader, observers=None, logger=None, **kwargs):
        self._observers = observers if observers is not None else []
        self._reader = reader
        self._audio_region_gen = split(self, **kwargs)
        self._detections = []
        self._log_format = "[DET]: Detection {0.id} (start: {0.start:.3f}, "
        self._log_format += "end: {0.end:.3f}, duration: {0.duration:.3f})"
        Worker.__init__(self, timeout=0.2, logger=logger)

    @property
    def detections(self):
        """List of ``_Detection`` records collected so far."""
        return self._detections

    def _notify_observers(self, message):
        """Send `message` to every observer."""
        for observer in self._observers:
            observer.send(message)

    def _init_start_processing_timestamp(self):
        """Record the current time on this worker and on every observer."""
        timestamp = datetime.now()
        self._start_processing_timestamp = timestamp
        for observer in self._observers:
            observer._start_processing_timestamp = timestamp

    def run(self):
        """Tokenize the stream, recording and dispatching each detection.

        Observers are always sent the stop sentinel and the reader is always
        closed, even if tokenization raises; otherwise observer threads would
        keep polling their inbox forever.
        """
        self._reader.open()
        try:
            self._init_start_processing_timestamp()
            for _id, audio_region in enumerate(
                self._audio_region_gen, start=1
            ):
                detection = _Detection(
                    _id,
                    audio_region.meta.start,
                    audio_region.meta.end,
                    audio_region.duration,
                )
                self._detections.append(detection)
                if self._logger is not None:
                    message = self._log_format.format(detection)
                    self._log(message)
                self._notify_observers((_id, audio_region))
        finally:
            self._notify_observers(_STOP_PROCESSING)
            self._reader.close()

    def start_all(self):
        """Start every observer, then this worker."""
        for observer in self._observers:
            observer.start()
        self.start()

    def stop_all(self):
        """Stop this worker, then every observer, then close the reader."""
        self.stop()
        for observer in self._observers:
            observer.stop()
        self._reader.close()

    def read(self):
        """Return the reader's next data, or None once a stop is requested."""
        if self._stop_requested():
            return None
        else:
            return self._reader.read()

    def __getattr__(self, name):
        """Delegate unknown attributes to the underlying reader."""
        # Only reached when normal lookup fails. If `_reader` itself is not
        # set yet, delegating would re-enter this method without end.
        if name == "_reader":
            raise AttributeError(name)
        return getattr(self._reader, name)


class StreamSaverWorker(Worker, AudioDataSource):
    """Audio source wrapper that records everything read through it.

    :meth:`read` forwards to the wrapped source and also sends each chunk to
    this worker's inbox; the worker thread buffers the chunks and writes them
    to a temporary raw file. :meth:`save_stream` then converts that file to
    the requested format.

    :param audio_data_source: wrapped source; must provide ``open``, ``read``,
        ``close`` and the ``sampling_rate``, ``sample_width`` and ``channels``
        attributes. Unknown attributes of this worker are looked up on it.
    :param filename: path of the file to export.
    :param export_format: output format passed with `filename` to
        ``_guess_audio_format``; the value ``"raw"`` means the data file is
        the output itself.
    :param cache_size: number of buffered bytes that triggers a write to the
        file.
    :param timeout: inbox polling timeout, see :class:`Worker`.
    """

    def __init__(
        self,
        audio_data_source,
        filename,
        export_format=None,
        cache_size=16000,
        timeout=0.5,
    ):

        self._audio_data_source = audio_data_source
        self._cache_size = cache_size
        self._output_filename = filename
        self._export_format = _guess_audio_format(export_format, filename)
        # For non-raw formats data is first written to a side file that is
        # converted later; for raw output it is written directly in place.
        if self._export_format != "raw":
            self._tmp_output_filename = self._output_filename + ".raw"
        else:
            self._tmp_output_filename = self._output_filename
        self._fp = open(self._tmp_output_filename, "wb")
        self._exported = False
        self._cache = []
        self._total_cached = 0
        Worker.__init__(self, timeout=timeout)

    @property
    def sr(self):
        """Sampling rate of the wrapped source."""
        return self._audio_data_source.sampling_rate

    @property
    def sw(self):
        """Sample width of the wrapped source."""
        return self._audio_data_source.sample_width

    @property
    def ch(self):
        """Number of channels of the wrapped source."""
        return self._audio_data_source.channels

    def __del__(self):
        # Flush whatever is still pending, then drop the temporary file only
        # if it is distinct from the output and an export has succeeded.
        self._post_process()
        if (
            self._tmp_output_filename != self._output_filename
        ) and self._exported:
            os.remove(self._tmp_output_filename)

    def _process_message(self, data):
        """Buffer `data` and write the buffer out once it is large enough."""
        self._cache.append(data)
        self._total_cached += len(data)
        if self._total_cached >= self._cache_size:
            self._write_cached_data()

    def _post_process(self):
        """Drain the inbox, write all buffered data and close the file."""
        while True:
            try:
                data = self._inbox.get_nowait()
                if data != _STOP_PROCESSING:
                    self._cache.append(data)
                    self._total_cached += len(data)
            except Empty:
                break
        self._write_cached_data()
        self._fp.close()

    def _write_cached_data(self):
        """Write the buffered chunks to the file and reset the buffer."""
        if self._cache:
            data = b"".join(self._cache)
            self._fp.write(data)
            self._cache = []
            self._total_cached = 0
            self._fp.flush()

    def open(self):
        """Open the wrapped source."""
        self._audio_data_source.open()

    def close(self):
        """Close the wrapped source and stop this worker."""
        self._audio_data_source.close()
        self.stop()

    def rewind(self):
        # ensure compatibility with AudioDataSource with record=True
        pass

    @property
    def data(self):
        """Bytes currently stored in the temporary raw file."""
        with open(self._tmp_output_filename, "rb") as fp:
            return fp.read()

    def save_stream(self):
        """Export the recorded data in the requested format.

        For ``"raw"`` nothing has to be converted. For ``"wav"`` the data is
        written with the ``wave`` module. For any other format ffmpeg/avconv
        is tried first, then sox; if both fail the data is saved as a wave
        file whose name is the requested name with ``".wav"`` appended.

        The stream is marked as exported only when an export actually
        succeeded, so a failed export never leads to the raw data file being
        deleted by ``__del__``.

        :returns: the name of the file the data was saved to.
        """
        if self._export_format == "raw":
            return self._output_filename
        if self._export_format == "wav":
            self._export_wave()
            self._exported = True
            return self._output_filename
        try:
            self._export_with_ffmpeg_or_avconv()
        except AudioEncodingError:
            try:
                self._export_with_sox()
            except AudioEncodingError:
                warn_msg = "Couldn't save data in the required format '{}'"
                print(warn_msg.format(self._export_format), file=sys.stderr)
                print("Saving stream as a wave file...", file=sys.stderr)
                self._output_filename += ".wav"
                self._export_wave()
                print(
                    "Audio data saved to '{}'".format(self._output_filename),
                    file=sys.stderr,
                )
        self._exported = True
        return self._output_filename

    def _export_wave(self):
        """Copy the raw data file into a wave file at the output path."""
        with open(self._tmp_output_filename, "rb") as fp:
            with wave.open(self._output_filename, "wb") as wfp:
                wfp.setframerate(self.sr)
                wfp.setsampwidth(self.sw)
                wfp.setnchannels(self.ch)
                # read blocks of 4 seconds
                block_size = self.sr * self.sw * self.ch * 4
                while True:
                    block = fp.read(block_size)
                    if not block:
                        return
                    wfp.writeframes(block)

    def _export_with_ffmpeg_or_avconv(self):
        """Convert the raw data file with ffmpeg, falling back to avconv.

        avconv is tried both when ffmpeg exits with a non-zero status and
        when ffmpeg cannot be launched at all.

        :returns: the standard output captured from the successful command.
        :raises AudioEncodingError: if neither program succeeds.
        """
        pcm_fmt = {1: "s8", 2: "s16le", 4: "s32le"}[self.sw]
        command = [
            "-y",
            "-f",
            pcm_fmt,
            "-ar",
            str(self.sr),
            "-ac",
            str(self.ch),
            "-i",
            self._tmp_output_filename,
            "-f",
            self._export_format,
            self._output_filename,
        ]
        last_error = None
        for program in ("ffmpeg", "avconv"):
            try:
                returncode, stdout, stderr = _run_subprocess(
                    [program] + command
                )
            except AudioEncodingError as exc:
                # The program could not be run (e.g. it is not installed):
                # remember why and try the next candidate.
                last_error = exc
                continue
            if returncode == 0:
                return stdout
            last_error = AudioEncodingError(stderr)
        raise last_error

    def _export_with_sox(self):
        """Convert the raw data file with sox.

        :returns: the standard output captured from sox.
        :raises AudioEncodingError: if sox cannot be run or exits with a
            non-zero status.
        """
        command = [
            "sox",
            "-t",
            "raw",
            "-r",
            str(self.sr),
            "-c",
            str(self.ch),
            "-b",
            str(self.sw * 8),
            "-e",
            "signed",
            self._tmp_output_filename,
            self._output_filename,
        ]
        returncode, stdout, stderr = _run_subprocess(command)
        if returncode != 0:
            raise AudioEncodingError(stderr)
        return stdout

    def close_output(self):
        """Close the raw data file."""
        self._fp.close()

    def read(self):
        """Read from the wrapped source and queue the data for saving.

        When the source returns ``None`` the stop sentinel is queued instead,
        which makes the worker thread flush and finish.

        :returns: whatever the wrapped source's ``read`` returned.
        """
        data = self._audio_data_source.read()
        if data is not None:
            self.send(data)
        else:
            self.send(_STOP_PROCESSING)
        return data

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped source."""
        # Only reached when normal lookup fails. If `_audio_data_source`
        # itself is not set yet, delegating would re-enter this method
        # without end.
        if name == "_audio_data_source":
            raise AttributeError(name)
        # Normal lookups of `data` are served by the property above and do
        # not get here; the branch is kept for backward compatibility.
        if name == "data":
            return self.data
        return getattr(self._audio_data_source, name)


class PlayerWorker(Worker):
    """Worker that plays back each detection it receives.

    :param player: player object handed to ``audio_region.play``.
    :param progress_bar: forwarded to ``audio_region.play``.
    :param timeout: inbox polling timeout, see :class:`Worker`.
    :param logger: optional logger; when set, a line is logged for each
        detection before it is played.
    """

    def __init__(self, player, progress_bar=False, timeout=0.5, logger=None):
        self._player = player
        self._progress_bar = progress_bar
        self._log_format = "[PLAY]: Detection {id} played (start:{start:.3f},"
        self._log_format += "end:{end:.3f}, dur:{duration:.3f})"
        Worker.__init__(self, timeout=timeout, logger=logger)

    def _process_message(self, message):
        """Log (if a logger is set) and play one ``(id, audio_region)``."""
        _id, audio_region = message
        if self._logger is not None:
            # Start and end are read from the region's metadata, as every
            # other worker in this module does.
            message = self._log_format.format(
                id=_id,
                start=audio_region.meta.start,
                end=audio_region.meta.end,
                duration=audio_region.duration,
            )
            self._log(message)
        audio_region.play(
            player=self._player, progress_bar=self._progress_bar, leave=False
        )


class RegionSaverWorker(Worker):
    """Worker that saves each detection it receives to its own file.

    :param name_format: format string for the file name; it is formatted
        with the ``id``, ``start``, ``end`` and ``duration`` of the detection.
    :param filetype: passed to ``audio_region.save`` as its second argument.
    :param timeout: inbox polling timeout, see :class:`Worker`.
    :param logger: optional logger; when set, a line is logged for each file
        saved.
    :param kwargs: extra keyword arguments forwarded to ``audio_region.save``.
    """

    def __init__(
        self, name_format, filetype=None, timeout=0.2, logger=None, **kwargs
    ):
        Worker.__init__(self, timeout=timeout, logger=logger)
        self._debug_format = '[SAVE]: Detection {id} saved as "{filename}"'
        self._audio_kwargs = kwargs
        self._filetype = filetype
        self._name_format = name_format

    def _process_message(self, message):
        """Save one ``(id, audio_region)`` and log the resulting file name."""
        detection_id, region = message
        target = self._name_format.format(
            id=detection_id,
            start=region.meta.start,
            end=region.meta.end,
            duration=region.duration,
        )
        saved_as = region.save(target, self._filetype, **self._audio_kwargs)
        if not self._logger:
            return
        self._log(
            self._debug_format.format(id=detection_id, filename=saved_as)
        )


class CommandLineWorker(Worker):
    """Worker that stores a command to associate with detections.

    This class does not override ``_process_message``, so incoming messages
    are handled by the inherited implementation.

    :param command: the command, stored as given.
    :param timeout: inbox polling timeout, see :class:`Worker`.
    :param logger: optional logger, see :class:`Worker`.
    """

    def __init__(self, command, timeout=0.2, logger=None):
        super(CommandLineWorker, self).__init__(timeout=timeout, logger=logger)
        self._command = command


class PrintWorker(Worker):
    """Worker that prints one formatted line per detection received.

    :param print_format: format string for the printed line; it may use the
        ``id``, ``start``, ``end``, ``duration`` and ``timestamp`` fields.
    :param time_format: passed to ``cmdline_util.make_duration_fromatter`` to
        build the callable that renders start, end and duration.
    :param timestamp_format: ``strftime`` format for the ``timestamp`` field.
    :param timeout: inbox polling timeout, see :class:`Worker`.
    """

    def __init__(
        self,
        print_format="{start} {end}",
        time_format="%S",
        timestamp_format="%Y/%m/%d %H:%M:%S.%f",
        timeout=0.2,
    ):
        self._print_format = print_format
        self._format_time = cmdline_util.make_duration_fromatter(time_format)
        self._timestamp_format = timestamp_format
        self.detections = []
        Worker.__init__(self, timeout=timeout)

    def _process_message(self, message):
        """Print the line describing one ``(id, audio_region)`` message."""
        detection_id, region = message
        # Absolute time of the detection: processing start time plus the
        # region's start offset, taken as a number of seconds.
        offset = timedelta(seconds=region.meta.start)
        wall_clock = self._start_processing_timestamp + offset
        stamp = wall_clock.strftime(self._timestamp_format)
        render = self._format_time
        line = self._print_format.format(
            id=detection_id,
            start=render(region.meta.start),
            end=render(region.meta.end),
            duration=render(region.duration),
            timestamp=stamp,
        )
        print(line)