changeset 176:c2fa3a12058e

Move thread workers to a separate module and refactor code
author Amine Sehili <amine.sehili@gmail.com>
date Wed, 13 Mar 2019 21:08:20 +0100
parents 592ec1821452
children 2acbdbd18327
files auditok/workers.py
diffstat 1 files changed, 290 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/auditok/workers.py	Wed Mar 13 21:08:20 2019 +0100
@@ -0,0 +1,290 @@
+from __future__ import print_function
+import sys
+from abc import ABCMeta, abstractmethod
+from threading import Thread
+from datetime import datetime, timedelta
+from collections import namedtuple
+import tempfile
+
+try:
+    import future
+    from queue import Queue, Empty
+except ImportError:
+    if sys.version_info >= (3, 0):
+        from queue import Queue, Empty
+    else:
+        from Queue import Queue, Empty
+
+from .core import split
+from .cmdline_util import make_duration_fromatter
+
+_STOP_PROCESSING = "STOP_PROCESSING"
+_AudioRegionMeta = namedtuple("_AudioRegionMeta", "id start end duration")
+
+
+class EndOfProcessing(Exception):
+    pass
+
+
+class Worker(Thread):
+    def __init__(self, timeout=0.5, logger=None):
+        self._timeout = timeout
+        self._logger = logger
+        self._start_processing_timestamp = None
+        self._inbox = Queue()
+        Thread.__init__(self)
+
+    def run(self):
+        while True:
+            message = self._get_message()
+            if message == _STOP_PROCESSING:
+                break
+            if message is not None:
+                self._process_message(message)
+        self._post_process()
+
+    @abstractmethod
+    def _process_message(self, message):
+        """Process incoming messages"""
+
+    def _post_process(self):
+        pass
+
+    def _log(self, message):
+        self._logger.warning(message)
+
+    def _stop_requested(self):
+        try:
+            message = self._inbox.get_nowait()
+            if message == _STOP_PROCESSING:
+                return True
+        except Empty:
+            return False
+
+    def stop(self):
+        self.send(_STOP_PROCESSING)
+        self.join()
+
+    def send(self, message):
+        self._inbox.put(message)
+
+    def _get_message(self):
+        try:
+            message = self._inbox.get(timeout=self._timeout)
+            return message
+        except Empty:
+            return None
+
+
+class TokenizerWorker(Worker):
+    def __init__(self, reader, observers=None, logger=None, **kwargs):
+        self._observers = observers if observers is not None else []
+        self._reader = reader
+        self._audio_region_gen = split(self, **kwargs)
+        self._audio_regions = []
+        self._log_format = "[DET]: Detection {id} (start: {start:.3f}, "
+        self._log_format += "end: {end:.3f}, duration: {duration:.3f})"
+        Worker.__init__(self, logger=logger)
+
+    @property
+    def audio_regions(self):
+        return self._audio_regions
+
+    def _notify_observers(self, message):
+        for observer in self._observers:
+            observer.send(message)
+
+    def _init_start_processing_timestamp(self):
+        timestamp = datetime.now()
+        self._start_processing_timestamp = timestamp
+        for observer in self._observers:
+            observer._start_processing_timestamp = timestamp
+
+    def run(self):
+        self._reader.open()
+        self._init_start_processing_timestamp()
+        for _id, audio_region in enumerate(self._audio_region_gen, start=1):
+            ar_meta = _AudioRegionMeta(
+                _id, audio_region.start, audio_region.end, audio_region.duration
+            )
+            self._audio_regions.append(ar_meta)
+            if self._logger is not None:
+                message = self._log_format.format(
+                    id=_id,
+                    start=audio_region.start,
+                    end=audio_region.end,
+                    duration=audio_region.duration,
+                )
+                self._log(message + " " + str(self._start_processing_timestamp))
+            self._notify_observers((_id, audio_region))
+        self._notify_observers(_STOP_PROCESSING)
+        self._reader.close()
+
+    def add_observer(self, observer):
+        observer.start_processing_timestamp = self._start_processing_timestamp
+        self._observers.append(observer)
+
+    def remove_observer(self, observer):
+        self._observers.remove(observer)
+
+    def start_all(self):
+        for observer in self._observers:
+            observer.start()
+        self.start()
+
+    def stop_all(self):
+        self.stop()
+        for observer in self._observers:
+            observer.stop()
+        self._reader.stop()
+
+    def read(self):
+        if self._stop_requested():
+            return None
+        else:
+            return self._reader.read()
+
+
+class StreamSaverWorker(Worker):
+    def __init__(self, audio_data_source, filename=None, cache_size=16000, timeout=0.5):
+
+        self._audio_data_source = audio_data_source
+        self._cache_size = cache_size
+        if filename is not None:
+            self._filename = filename
+            self._fp = open(self._filename, "wb")
+        else:
+            self._fp = tempfile.NamedTemporaryFile()
+            self._filename = self._fp.name
+
+        self._cache = []
+        self._total_cached = 0
+        Worker.__init__(self, timeout=timeout)
+
+    @property
+    def sampling_rate(self):
+        return self._audio_data_source.sampling_rate
+
+    @property
+    def sample_width(self):
+        return self._audio_data_source.get_sample_width
+
+    @property
+    def channels(self):
+        return self._audio_data_source.get_channels
+
+    def __del__(self):
+        self._fp.close()
+
+    @property
+    def filename(self):
+        return self._filename
+
+    def _process_message(self, data):
+        self._cache.append(data)
+        self._total_cached += len(data)
+        if self._total_cached >= self._cache_size:
+            self._write_cached_data()
+
+    def _post_process(self):
+        while True:
+            try:
+                data = self._inbox.get_nowait()
+                if data != _STOP_PROCESSING:
+                    self._cache.append(data)
+                    self._total_cached += len(data)
+            except Empty:
+                break
+        self._write_cached_data()
+
+    def _write_cached_data(self):
+        if self._cache:
+            data = b"".join(self._cache)
+            self._fp.write(data)
+            self._cache = []
+            self._total_cached = 0
+
+    def open(self):
+        self._audio_data_source.open()
+
+    def close(self):
+        self._audio_data_source.close()
+
+    def close_output(self):
+        self._fp.close()
+
+    def read(self):
+        data = self._audio_data_source.read()
+        if data is not None:
+            self.send(data)
+        else:
+            self.send(_STOP_PROCESSING)
+        return data
+
+
+class PlayerWorker(Worker):
+    def __init__(self, progress_bar=False, timeout=0.5, logger=None):
+        self._progress_bar = progress_bar
+        self._log_format = "[PLAY]: Detection {id} played (start:{start:.3f}, "
+        self._log_format += "end:{end:.3f}, dur:{duration:.3f})"
+        Worker.__init__(self, timeout=timeout, logger=logger)
+
+    def _process_message(self, message):
+        _id, audio_region = message
+        if self._logger is not None:
+            message = self._log_format.format(
+                id=_id,
+                start=audio_region.start,
+                end=audio_region.end,
+                duration=audio_region.duration,
+            )
+            self._log(message)
+        audio_region.play(self._progress_bar)
+
+
+class RegionSaverWorker(Worker):
+    def __init__(self, name_format, filetype=None, timeout=0.2, logger=None, **kwargs):
+        self._name_format = name_format
+        self._filetype = filetype
+        self._audio_kwargs = kwargs
+        self._debug_format = '[SAVE]: Detection {id} saved as "{filename}"'
+        Worker.__init__(self, timeout=timeout, logger=logger)
+
+    def _process_message(self, message):
+        _id, audio_region = message
+        filename = self._name_format.replace("{id}", str(_id))
+        filename = audio_region.save(filename, self._filetype, **self._audio_kwargs)
+        if self._logger:
+            message = self._debug_format.format(id=_id, filename=filename)
+            self._log(message)
+
+
+class PrintWorker(Worker):
+    def __init__(
+        self,
+        print_format="{start} {end}",
+        time_format="%S",
+        timestamp_format="%Y/%m/%d %H:%M:%S.%f",
+        timeout=0.2,
+    ):
+
+        self._print_format = print_format
+        self._format_time = make_duration_fromatter(time_format)
+        self._timestamp_format = timestamp_format
+        self.detections = []
+        Worker.__init__(self, timeout=timeout)
+
+    def _process_message(self, message):
+        _id, audio_region = message
+        timestamp = self._start_processing_timestamp + timedelta(
+            seconds=audio_region.start
+        )
+        timestamp = timestamp.strftime(self._timestamp_format)
+        text = self._print_format.format(
+            id=_id,
+            start=self._format_time(audio_region.start),
+            end=self._format_time(audio_region.end),
+            duration=self._format_time(audio_region.duration),
+            timestamp=timestamp,
+        )
+        print(text)