changeset 251:7652b6115c2d

Move lower level logic from cmdline.py to cmdline_util.py
author Amine Sehili <amine.sehili@gmail.com>
date Wed, 28 Aug 2019 21:04:52 +0200
parents a9f4adfae459
children d3a815e1b001
files auditok/cmdline.py auditok/cmdline_util.py
diffstat 2 files changed, 63 insertions(+), 51 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/cmdline.py	Mon Aug 26 21:58:43 2019 +0200
+++ b/auditok/cmdline.py	Wed Aug 28 21:04:52 2019 +0200
@@ -20,10 +20,10 @@
 import time
 import threading
 
-from auditok import __version__
+from auditok import __version__, AudioRegion
 from .util import AudioDataSource
 from .io import player_for
-from .cmdline_util import make_logger, make_kwargs
+from .cmdline_util import make_logger, make_kwargs, initialize_workers
 from . import workers
 
 
@@ -41,9 +41,7 @@
         parser = ArgumentParser(
             prog=program_name, description="An Audio Tokenization tool"
         )
-        parser.add_argument(
-            "--version", "-v", action="version", version=version
-        )
+        parser.add_argument("--version", "-v", action="version", version=version)
         group = parser.add_argument_group("Input-Output options")
         group.add_argument(
             "-i",
@@ -354,46 +352,7 @@
         args = parser.parse_args(argv)
         logger = make_logger(args.debug, args.debug_file)
         kwargs = make_kwargs(args)
-        observers = []
-
-        reader = AudioDataSource(args.input, **kwargs.io)
-        if args.output_main is not None:
-            reader = workers.StreamSaverWorker(reader, args.output_main)
-            reader.start()
-
-        if args.output_tokens is not None:
-            worker = workers.RegionSaverWorker(
-                args.output_tokens, args.output_type, logger=logger
-            )
-            observers.append(worker)
-
-        if args.echo:
-            player = player_for(reader)
-            progress_bar = args.progress_bar
-            worker = workers.PlayerWorker(
-                player, progress_bar=progress_bar, logger=logger
-            )
-            observers.append(worker)
-
-        if args.command is not None:
-            worker = workers.CommandLineWorker(
-                command=args.command, logger=logger
-            )
-            observers.append(worker)
-
-        if not args.quiet:
-            print_format = (
-                args.printf.replace("\\n", "\n")
-                .replace("\\t", "\t")
-                .replace("\\r", "\r")
-            )
-            time_format = args.time_format
-            timestamp_format = args.timestamp_format
-            worker = workers.PrintWorker(
-                print_format, time_format, timestamp_format
-            )
-            observers.append(worker)
-
+        reader, observers = initialize_workers(args, logger=logger, **kwargs.io)
         tokenizer_worker = workers.TokenizerWorker(
             reader, observers, logger=logger, **kwargs.split
         )
@@ -410,14 +369,22 @@
             if args.output_main is not None:
                 reader.save_stream()
             if args.plot or args.save_image is not None:
-                from plotting import plot_signal_and_detections
-                import numpy as np
+                from .plotting import plot_detections
 
-                formats = {1: np.int8, 2: np.int16, 4: np.int32}
                 reader.rewind()
-                signal = np.from_buffer(reader.data, dtype=formats[reader.sw])
-                regions = tokenizer_worker.audio_regions
-                plot_signal_and_detections(signal, regions, args.save_image)
+                main_region = AudioRegion(
+                    reader.data, sr=reader.sr, sw=reader.sw, channels=reader.ch
+                )
+                detections = (
+                    (det.start, det.end) for det in tokenizer_worker.audio_regions
+                )
+                plot_detections(
+                    main_region,
+                    reader.sr,
+                    detections,
+                    show=True,
+                    save_as=args.save_image,
+                )
         return 0
 
 
--- a/auditok/cmdline_util.py	Mon Aug 26 21:58:43 2019 +0200
+++ b/auditok/cmdline_util.py	Wed Aug 28 21:04:52 2019 +0200
@@ -1,6 +1,9 @@
 import sys
 import logging
 from collections import namedtuple
+from . import workers
+from .util import AudioDataSource
+from .io import player_for
 
 LOGGER_NAME = "AUDITOK_LOGGER"
 KeywordArguments = namedtuple("KeywordArguments", ["io", "split"])
@@ -99,3 +102,45 @@
         handler.setLevel(logging.DEBUG)
         logger.addHandler(handler)
     return logger
+
+
+def initialize_workers(args, logger=None, **io_kwargs):
+    observers = []
+
+    reader = AudioDataSource(args.input, **io_kwargs)
+    if args.output_main is not None:
+        reader = workers.StreamSaverWorker(reader, args.output_main)
+        reader.start()
+
+    if args.output_tokens is not None:
+        worker = workers.RegionSaverWorker(
+            args.output_tokens, args.output_type, logger=logger
+        )
+        observers.append(worker)
+
+    if args.echo:
+        player = player_for(reader)
+        progress_bar = args.progress_bar
+        worker = workers.PlayerWorker(
+            player, progress_bar=progress_bar, logger=logger
+        )
+        observers.append(worker)
+
+    if args.command is not None:
+        worker = workers.CommandLineWorker(command=args.command, logger=logger)
+        observers.append(worker)
+
+    if not args.quiet:
+        print_format = (
+            args.printf.replace("\\n", "\n")
+            .replace("\\t", "\t")
+            .replace("\\r", "\r")
+        )
+        time_format = args.time_format
+        timestamp_format = args.timestamp_format
+        worker = workers.PrintWorker(
+            print_format, time_format, timestamp_format
+        )
+        observers.append(worker)
+
+    return reader, observers