Mercurial > hg > auditok
changeset 176:c2fa3a12058e
Move thread workers to a separate module and refactor code
author | Amine Sehili <amine.sehili@gmail.com> |
---|---|
date | Wed, 13 Mar 2019 21:08:20 +0100 |
parents | 592ec1821452 |
children | 2acbdbd18327 |
files | auditok/workers.py |
diffstat | 1 files changed, 290 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/auditok/workers.py Wed Mar 13 21:08:20 2019 +0100 @@ -0,0 +1,290 @@ +from __future__ import print_function +import sys +from abc import ABCMeta, abstractmethod +from threading import Thread +from datetime import datetime, timedelta +from collections import namedtuple +import tempfile + +try: + import future + from queue import Queue, Empty +except ImportError: + if sys.version_info >= (3, 0): + from queue import Queue, Empty + else: + from Queue import Queue, Empty + +from .core import split +from .cmdline_util import make_duration_fromatter + +_STOP_PROCESSING = "STOP_PROCESSING" +_AudioRegionMeta = namedtuple("_AudioRegionMeta", "id start end duration") + + +class EndOfProcessing(Exception): + pass + + +class Worker(Thread): + def __init__(self, timeout=0.5, logger=None): + self._timeout = timeout + self._logger = logger + self._start_processing_timestamp = None + self._inbox = Queue() + Thread.__init__(self) + + def run(self): + while True: + message = self._get_message() + if message == _STOP_PROCESSING: + break + if message is not None: + self._process_message(message) + self._post_process() + + @abstractmethod + def _process_message(self, message): + """Process incoming messages""" + + def _post_process(self): + pass + + def _log(self, message): + self._logger.warning(message) + + def _stop_requested(self): + try: + message = self._inbox.get_nowait() + if message == _STOP_PROCESSING: + return True + except Empty: + return False + + def stop(self): + self.send(_STOP_PROCESSING) + self.join() + + def send(self, message): + self._inbox.put(message) + + def _get_message(self): + try: + message = self._inbox.get(timeout=self._timeout) + return message + except Empty: + return None + + +class TokenizerWorker(Worker): + def __init__(self, reader, observers=None, logger=None, **kwargs): + self._observers = observers if observers is not None else [] + self._reader = reader + self._audio_region_gen = split(self, **kwargs) + self._audio_regions = [] + self._log_format = "[DET]: Detection {id} (start: {start:.3f}, " + self._log_format += "end: {end:.3f}, duration: {duration:.3f})" + Worker.__init__(self, logger=logger) + + @property + def audio_regions(self): + return self._audio_regions + + def _notify_observers(self, message): + for observer in self._observers: + observer.send(message) + + def _init_start_processing_timestamp(self): + timestamp = datetime.now() + self._start_processing_timestamp = timestamp + for observer in self._observers: + observer._start_processing_timestamp = timestamp + + def run(self): + self._reader.open() + self._init_start_processing_timestamp() + for _id, audio_region in enumerate(self._audio_region_gen, start=1): + ar_meta = _AudioRegionMeta( + _id, audio_region.start, audio_region.end, audio_region.duration + ) + self._audio_regions.append(ar_meta) + if self._logger is not None: + message = self._log_format.format( + id=_id, + start=audio_region.start, + end=audio_region.end, + duration=audio_region.duration, + ) + self._log(message + " " + str(self._start_processing_timestamp)) + self._notify_observers((_id, audio_region)) + self._notify_observers(_STOP_PROCESSING) + self._reader.close() + + def add_observer(self, observer): + observer.start_processing_timestamp = self._start_processing_timestamp + self._observers.append(observer) + + def remove_observer(self, observer): + self._observers.remove(observer) + + def start_all(self): + for observer in self._observers: + observer.start() + self.start() + + def stop_all(self): + self.stop() + for observer in self._observers: + observer.stop() + self._reader.stop() + + def read(self): + if self._stop_requested(): + return None + else: + return self._reader.read() + + +class StreamSaverWorker(Worker): + def __init__(self, audio_data_source, filename=None, cache_size=16000, timeout=0.5): + + self._audio_data_source = audio_data_source + self._cache_size = cache_size + if filename is not None: + self._filename = filename + self._fp = open(self._filename, "wb") + else: + self._fp = tempfile.NamedTemporaryFile() + self._filename = self._fp.name + + self._cache = [] + self._total_cached = 0 + Worker.__init__(self, timeout=timeout) + + @property + def sampling_rate(self): + return self._audio_data_source.sampling_rate + + @property + def sample_width(self): + return self._audio_data_source.get_sample_width + + @property + def channels(self): + return self._audio_data_source.get_channels + + def __del__(self): + self._fp.close() + + @property + def filename(self): + return self._filename + + def _process_message(self, data): + self._cache.append(data) + self._total_cached += len(data) + if self._total_cached >= self._cache_size: + self._write_cached_data() + + def _post_process(self): + while True: + try: + data = self._inbox.get_nowait() + if data != _STOP_PROCESSING: + self._cache.append(data) + self._total_cached += len(data) + except Empty: + break + self._write_cached_data() + + def _write_cached_data(self): + if self._cache: + data = b"".join(self._cache) + self._fp.write(data) + self._cache = [] + self._total_cached = 0 + + def open(self): + self._audio_data_source.open() + + def close(self): + self._audio_data_source.close() + + def close_output(self): + self._fp.close() + + def read(self): + data = self._audio_data_source.read() + if data is not None: + self.send(data) + else: + self.send(_STOP_PROCESSING) + return data + + +class PlayerWorker(Worker): + def __init__(self, progress_bar=False, timeout=0.5, logger=None): + self._progress_bar = progress_bar + self._log_format = "[PLAY]: Detection {id} played (start:{start:.3f}, " + self._log_format += "end:{end:.3f}, dur:{duration:.3f})" + Worker.__init__(self, timeout=timeout, logger=logger) + + def _process_message(self, message): + _id, audio_region = message + if self._logger is not None: + message = self._log_format.format( + id=_id, + start=audio_region.start, + end=audio_region.end, + duration=audio_region.duration, + ) + self._log(message) + audio_region.play(self._progress_bar) + + +class RegionSaverWorker(Worker): + def __init__(self, name_format, filetype=None, timeout=0.2, logger=None, **kwargs): + self._name_format = name_format + self._filetype = filetype + self._audio_kwargs = kwargs + self._debug_format = '[SAVE]: Detection {id} saved as "{filename}"' + Worker.__init__(self, timeout=timeout, logger=logger) + + def _process_message(self, message): + _id, audio_region = message + filename = self._name_format.replace("{id}", str(_id)) + filename = audio_region.save(filename, self._filetype, **self._audio_kwargs) + if self._logger: + message = self._debug_format.format(id=_id, filename=filename) + self._log(message) + + +class PrintWorker(Worker): + def __init__( + self, + print_format="{start} {end}", + time_format="%S", + timestamp_format="%Y/%m/%d %H:%M:%S.%f", + timeout=0.2, + ): + + self._print_format = print_format + self._format_time = make_duration_fromatter(time_format) + self._timestamp_format = timestamp_format + self.detections = [] + Worker.__init__(self, timeout=timeout) + + def _process_message(self, message): + _id, audio_region = message + timestamp = self._start_processing_timestamp + timedelta( + seconds=audio_region.start + ) + timestamp = timestamp.strftime(self._timestamp_format) + text = self._print_format.format( + id=_id, + start=self._format_time(audio_region.start), + end=self._format_time(audio_region.end), + duration=self._format_time(audio_region.duration), + timestamp=timestamp, + ) + print(text)