changeset 186:00790b3d5aa2

Refactor StreamSaverWorker - Add functions to export stream with ffmpeg/avconv or sox - Fall back to wave if none of these tools is available (avoid losing data if exporting using a compressed format fails) - Add `data` property to align with AudioDataSource API
author Amine Sehili <amine.sehili@gmail.com>
date Mon, 01 Apr 2019 21:12:52 +0100
parents da8d454ece74
children 949678a8cf25
files auditok/cmdline_util.py auditok/io.py auditok/workers.py
diffstat 3 files changed, 174 insertions(+), 37 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/cmdline_util.py	Sat Mar 30 14:41:42 2019 +0100
+++ b/auditok/cmdline_util.py	Mon Apr 01 21:12:52 2019 +0100
@@ -15,13 +15,18 @@
         record = args_ns.plot or (args_ns.save_image is not None)
     else:
         record = False
+    try:
+        use_channel = int(args_ns.use_channel)
+    except ValueError:
+        use_channel = args_ns.use_channel
+
     io_kwargs = {
-        "max_read_time": args_ns.max_time,
+        "max_read": args_ns.max_time,
         "block_dur": args_ns.analysis_window,
         "sampling_rate": args_ns.sampling_rate,
         "sample_width": args_ns.sample_width,
         "channels": args_ns.channels,
-        "use_channel": args_ns.use_channel,
+        "use_channel": use_channel,
         "input_type": args_ns.input_type,
         "output_type": args_ns.output_type,
         "large_file": args_ns.large_file,
@@ -35,7 +40,7 @@
         "max_dur": args_ns.max_duration,
         "max_silence": args_ns.max_silence,
         "drop_trailing_silence": args_ns.drop_trailing_silence,
-        "strict_min_length": args_ns.strict_min_length,
+        "strict_min_dur": args_ns.strict_min_duration,
         "energy_threshold": args_ns.energy_threshold,
     }
     return KeywordArguments(io_kwargs, split_kwargs)
@@ -45,8 +50,6 @@
     """
     Accepted format directives: %i %s %m %h
     """
-    # check directives are correct
-
     if fmt == "%S":
 
         def fromatter(seconds):
--- a/auditok/io.py	Sat Mar 30 14:41:42 2019 +0100
+++ b/auditok/io.py	Mon Apr 01 21:12:52 2019 +0100
@@ -102,9 +102,9 @@
     try:
         return ["left", "right"].index(use_channel)
     except ValueError:
-        err_message = "'use_channel' parameter must be an integer "
-        "or one of ('left', 'right', 'mix'), found: '{}'".format(use_channel)
-        raise AudioParameterError(err_message)
+        err_message = "'use_channel' parameter must be an integer or one of "
+        err_message += "('left', 'right', 'mix'), found: '{}'"
+        raise AudioParameterError(err_message.format(use_channel))
 
 
 def _get_audio_parameters(param_dict):
@@ -796,7 +796,6 @@
     if input is not None:
         return from_file(filename=input,
                          audio_format=kwargs.get('audio_format'),
-                         large_file=kwargs.get('large_file', False),
                          **kwargs)
 
     # read data from microphone via pyaudio
--- a/auditok/workers.py	Sat Mar 30 14:41:42 2019 +0100
+++ b/auditok/workers.py	Mon Apr 01 21:12:52 2019 +0100
@@ -1,10 +1,14 @@
 from __future__ import print_function
+import os
 import sys
 from abc import ABCMeta, abstractmethod
 from threading import Thread
 from datetime import datetime, timedelta
 from collections import namedtuple
-import tempfile
+import wave
+import subprocess
+from .io import _guess_audio_format
+from .util import AudioDataSource
 
 try:
     import future
@@ -26,6 +30,18 @@
     pass
 
 
+class AudioEncodingError(Exception):
+    pass
+
+
+def _run_subprocess(command):
+    with subprocess.Popen(
+            command, stdin=open(os.devnull, "rb"), stdout=subprocess.PIPE
+        ) as proc:
+        stdout, stderr = proc.communicate()
+        return proc.returncode, stdout, stderr
+
+
 class Worker(Thread):
     def __init__(self, timeout=0.5, logger=None):
         self._timeout = timeout
@@ -76,15 +92,15 @@
             return None
 
 
-class TokenizerWorker(Worker):
+class TokenizerWorker(Worker, AudioDataSource):
     def __init__(self, reader, observers=None, logger=None, **kwargs):
         self._observers = observers if observers is not None else []
         self._reader = reader
         self._audio_region_gen = split(self, **kwargs)
         self._audio_regions = []
         self._log_format = "[DET]: Detection {id} (start: {start:.3f}, "
-        self._log_format += "end: {end:.3f}, duration: {duration:.3f})"
-        Worker.__init__(self, logger=logger)
+        self._log_format = "end: {end:.3f}, duration: {duration:.3f})"
+        Worker.__init__(self, timeout=0.5, logger=logger)
 
     @property
     def audio_regions(self):
@@ -105,7 +121,10 @@
         self._init_start_processing_timestamp()
         for _id, audio_region in enumerate(self._audio_region_gen, start=1):
             ar_meta = _AudioRegionMeta(
-                _id, audio_region.start, audio_region.end, audio_region.duration
+                _id,
+                audio_region.start,
+                audio_region.end,
+                audio_region.duration,
             )
             self._audio_regions.append(ar_meta)
             if self._logger is not None:
@@ -115,7 +134,9 @@
                     end=audio_region.end,
                     duration=audio_region.duration,
                 )
-                self._log(message + " " + str(self._start_processing_timestamp))
+                self._log(
+                    message + " " + str(self._start_processing_timestamp)
+                )
             self._notify_observers((_id, audio_region))
         self._notify_observers(_STOP_PROCESSING)
         self._reader.close()
@@ -136,7 +157,7 @@
         self.stop()
         for observer in self._observers:
             observer.stop()
-        self._reader.stop()
+        self._reader.close()
 
     def read(self):
         if self._stop_requested():
@@ -144,41 +165,52 @@
         else:
             return self._reader.read()
 
+    def __getattr__(self, name):
+        return getattr(self._reader, name)
 
-class StreamSaverWorker(Worker):
-    def __init__(self, audio_data_source, filename=None, cache_size=16000, timeout=0.5):
+
+class StreamSaverWorker(Worker, AudioDataSource):
+    def __init__(
+        self,
+        audio_data_source,
+        filename,
+        format=None,
+        cache_size=16000,
+        timeout=0.5,
+    ):
 
         self._audio_data_source = audio_data_source
         self._cache_size = cache_size
-        if filename is not None:
-            self._filename = filename
-            self._fp = open(self._filename, "wb")
+        self._output_filename = filename
+        self._export_format = _guess_audio_format(format, filename)
+        if self._export_format != "raw":
+            self._tmp_output_filename = self._output_filename + ".raw"
         else:
-            self._fp = tempfile.NamedTemporaryFile()
-            self._filename = self._fp.name
-
+            self._tmp_output_filename = self._output_filename
+        self._fp = open(self._tmp_output_filename, "wb")
+        self._exported = False
         self._cache = []
         self._total_cached = 0
         Worker.__init__(self, timeout=timeout)
 
     @property
-    def sampling_rate(self):
+    def sr(self):
         return self._audio_data_source.sampling_rate
 
     @property
-    def sample_width(self):
-        return self._audio_data_source.get_sample_width
+    def sw(self):
+        return self._audio_data_source.sample_width
 
     @property
-    def channels(self):
-        return self._audio_data_source.get_channels
+    def ch(self):
+        return self._audio_data_source.channels
 
     def __del__(self):
-        self._fp.close()
-
-    @property
-    def filename(self):
-        return self._filename
+        self._post_process()
+        if (
+            self._tmp_output_filename != self._output_filename
+        ) and self._exported:
+            os.remove(self._tmp_output_filename)
 
     def _process_message(self, data):
         self._cache.append(data)
@@ -196,6 +228,7 @@
             except Empty:
                 break
         self._write_cached_data()
+        self._fp.close()
 
     def _write_cached_data(self):
         if self._cache:
@@ -203,12 +236,105 @@
             self._fp.write(data)
             self._cache = []
             self._total_cached = 0
+            self._fp.flush()
 
     def open(self):
         self._audio_data_source.open()
 
     def close(self):
         self._audio_data_source.close()
+        self.stop()
+
+    def rewind(self):
+        # ensure compatibility with AudioDataSource with record=True
+        pass
+
+    @property
+    def data(self):
+        with open(self._tmp_output_filename, "rb") as fp:
+            return fp.read()
+
+    def save_stream(self):
+        if self._export_format == "raw":
+            return
+        if self._export_format == "wav":
+            self._export_wave()
+            self._exported = True
+            return
+        try:
+            self._export_with_ffmpeg_or_avconv()
+        except AudioEncodingError:
+            try:
+                self._save_with_sox()
+            except AudioEncodingError:
+                warn_msg = "Couldn't save data in the required format '{}'"
+                print(warn_msg.format(self._export_format), file=sys.stderr)
+                print("Saving stream as a wave file...", file=sys.stderr)
+                self._output_filename += ".wav"
+                self._export_wave()
+                print("Audio data saved to '{}'".format(self._output_filename))
+        finally:
+            self._exported = True
+        return self._output_filename
+
+    def _export_wave(self):
+        with open(self._tmp_output_filename, "rb") as fp:
+            with wave.open(self._output_filename, "wb") as wfp:
+                wfp.setframerate(self.sr)
+                wfp.setsampwidth(self.sw)
+                wfp.setnchannels(self.ch)
+                # read blocks of 4 seconds
+                block_size = self.sr * self.sw * self.ch * 4
+                while True:
+                    block = fp.read(block_size)
+                    if not block:
+                        return
+                    wfp.writeframes(block)
+
+    def _export_with_ffmpeg_or_avconv(self):
+        pcm_fmt = {1: "s8", 2: "s16le", 4: "s32le"}[self.sw]
+        command = [
+            "-y",
+            "-f",
+            pcm_fmt,
+            "-ar",
+            str(self.sr),
+            "-ac",
+            str(self.ch),
+            "-i",
+            self._tmp_output_filename,
+            "-f",
+            self._export_format,
+            self._output_filename,
+        ]
+        returncode, stdout, stderr = _run_subprocess(["ffmpeg"] + command)
+        if  returncode != 0:
+            returncode, stdout, stderr = _run_subprocess(["avconv"] + command)
+            if returncode != 0:
+                raise AudioEncodingError(stderr)
+        return stdout
+
+    def _export_with_sox(self):
+        command = [
+            "sox",
+            "-t",
+            "raw",
+            "-r",
+            str(self.sr),
+            "-c",
+            str(self.ch),
+            "-b",
+            str(self.sw * 8),
+            "-e",
+            "signed",
+            self._tmp_output_filename,
+            self._output_filename,
+        ]
+        print(" ".join(command))
+        returncode, stdout, stderr = _run_subprocess(command)
+        if returncode != 0:
+            raise AudioEncodingError(stderr)
+        return stdout
 
     def close_output(self):
         self._fp.close()
@@ -221,11 +347,16 @@
             self.send(_STOP_PROCESSING)
         return data
 
+    def __getattr__(self, name):
+        if name == "data":
+            return self.data
+        return getattr(self._audio_data_source, name)
+
 
 class PlayerWorker(Worker):
     def __init__(self, progress_bar=False, timeout=0.5, logger=None):
         self._progress_bar = progress_bar
-        self._log_format = "[PLAY]: Detection {id} played (start:{start:.3f}, "
+        self._log_format = "[PLAY]: Detection {id} played (start:{start:.3f},"
         self._log_format += "end:{end:.3f}, dur:{duration:.3f})"
         Worker.__init__(self, timeout=timeout, logger=logger)
 
@@ -243,7 +374,9 @@
 
 
 class RegionSaverWorker(Worker):
-    def __init__(self, name_format, filetype=None, timeout=0.2, logger=None, **kwargs):
+    def __init__(
+        self, name_format, filetype=None, timeout=0.2, logger=None, **kwargs
+    ):
         self._name_format = name_format
         self._filetype = filetype
         self._audio_kwargs = kwargs
@@ -253,7 +386,9 @@
     def _process_message(self, message):
         _id, audio_region = message
         filename = self._name_format.replace("{id}", str(_id))
-        filename = audio_region.save(filename, self._filetype, **self._audio_kwargs)
+        filename = audio_region.save(
+            filename, self._filetype, **self._audio_kwargs
+        )
         if self._logger:
             message = self._debug_format.format(id=_id, filename=filename)
             self._log(message)