changeset 300:1732213b290a

Refactor plotting code - Update theme - Add the possibility to use a use-supplied theme - Plot multichannel audio - Optionally scale the signal before plotting - Add arguments for figure size and resolution
author Amine Sehili <amine.sehili@gmail.com>
date Wed, 09 Oct 2019 21:21:28 +0100
parents 73989d247f4e
children 0f6ef74c65a9
files auditok/plotting.py
diffstat 1 files changed, 121 insertions(+), 14 deletions(-) [+]
line wrap: on
line diff
--- a/auditok/plotting.py	Tue Oct 08 20:16:11 2019 +0100
+++ b/auditok/plotting.py	Wed Oct 09 21:21:28 2019 +0100
@@ -1,28 +1,135 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
+AUDITOK_PLOT_THEME = {
+    "figure": {"facecolor": "#482a36", "alpha": 0.2},
+    "plot": {"facecolor": "#282a36"},
+    "energy_threshold": {
+        "color": "#e31f8f",
+        "linestyle": "--",
+        "linewidth": 1,
+    },
+    "signal": {"color": "#40d970", "linestyle": "-", "linewidth": 1},
+    "detections": {
+        "facecolor": "#777777",
+        "edgecolor": "#ff8c1a",
+        "linewidth": 1,
+        "alpha": 0.75,
+    },
+}
 
-def plot(data, sampling_rate, show=True):
-    y = np.asarray(data)
-    ymax = np.abs(y).max()
-    nb_samples = y.shape[-1]
+
+def _make_time_axis(nb_samples, sampling_rate):
     sample_duration = 1 / sampling_rate
     x = np.linspace(0, sample_duration * (nb_samples - 1), nb_samples)
-    plt.plot(x, y / ymax, c="#024959")
-    plt.ylim(-3, 3)
-    if show:
-        plt.show()
+    return x
 
 
-def plot_detections(data, sampling_rate, detections, show=True, save_as=None):
+def _plot_line(x, y, theme, xlabel=None, ylabel=None, **kwargs):
+    color = theme.get("color", theme.get("c"))
+    ls = theme.get("linestyle", theme.get("ls"))
+    lw = theme.get("linewidth", theme.get("lw"))
+    plt.plot(x, y, c=color, ls=ls, lw=lw, **kwargs)
+    plt.xlabel(xlabel)
+    plt.ylabel(ylabel)
 
-    plot(data, sampling_rate, show=False)
-    if detections is not None:
-        for (start, end) in detections:
-            plt.axvspan(start, end, facecolor="g", ec="r", lw=2, alpha=0.4)
+
+def _plot_detections(subplot, detections, theme):
+    fc = theme.get("facecolor", theme.get("fc"))
+    ec = theme.get("edgecolor", theme.get("ec"))
+    ls = theme.get("linestyle", theme.get("ls"))
+    lw = theme.get("linewidth", theme.get("lw"))
+    alpha = theme.get("alpha")
+    for (start, end) in detections:
+        subplot.axvspan(start, end, fc=fc, ec=ec, ls=ls, lw=lw, alpha=alpha)
+
+
+def plot(
+    audio_region,
+    scale_signal=True,
+    detections=None,
+    energy_threshold=None,
+    show=True,
+    figsize=None,
+    save_as=None,
+    dpi=120,
+    theme="auditok",
+):
+    y = np.asarray(audio_region)
+    if len(y.shape) == 1:
+        y = y.reshape(1, -1)
+    nb_subplots, nb_samples = y.shape
+    time_axis = _make_time_axis(nb_samples, audio_region.sampling_rate)
+    if energy_threshold is not None:
+        eth_log10 = energy_threshold * np.log(10) / 10
+        amplitude_threshold = np.sqrt(np.exp(eth_log10))
+    else:
+        amplitude_threshold = None
+    if detections is None:
+        detections = []
+    else:
+        detections = list(detections)
+    if theme == "auditok":
+        theme = AUDITOK_PLOT_THEME
+
+    fig = plt.figure(figsize=figsize, dpi=dpi)
+    fig_theme = theme.get("figure", theme.get("fig", {}))
+    fig_fc = fig_theme.get("facecolor", fig_theme.get("ffc"))
+    fig_alpha = fig_theme.get("alpha", 1)
+    fig.patch.set_facecolor(fig_fc)
+    fig.patch.set_alpha(fig_alpha)
+
+    plot_theme = theme.get("plot", {})
+    plot_fc = plot_theme.get("facecolor", plot_theme.get("pfc"))
+    for sid, samples in enumerate(y, 1):
+
+        ax = fig.add_subplot(nb_subplots, 1, sid)
+        ax.set_facecolor(plot_fc)
+        if scale_signal:
+            mean = samples.mean()
+            std = samples.std()
+            samples = (samples - mean) / std
+            max_ = samples.max()
+            plt.ylim(-1.5 * max_, 1.5 * max_)
+
+        if amplitude_threshold is not None:
+            if scale_signal:
+                amp_th = (amplitude_threshold - mean) / std
+            else:
+                amp_th = amplitude_threshold
+            eth_theme = theme.get("energy_threshold", theme.get("eth", {}))
+            _plot_line(
+                [time_axis[0], time_axis[-1]],
+                [amp_th] * 2,
+                eth_theme,
+                label="Detection threshold",
+            )
+            if sid == 1:
+                legend = plt.legend(
+                    ["Detection threshold"],
+                    facecolor=fig_fc,
+                    framealpha=0.1,
+                    bbox_to_anchor=(0.0, 1.15, 1.0, 0.102),
+                    loc=2,
+                )
+                legend = plt.gca().add_artist(legend)
+
+        signal_theme = theme.get("signal", {})
+        _plot_line(
+            time_axis,
+            samples,
+            signal_theme,
+            xlabel="Time (seconds)",
+            ylabel="Signal{}".format(" (scaled)" if scale_signal else ""),
+        )
+        detections_theme = theme.get("detections", {})
+        _plot_detections(ax, detections, detections_theme)
+        plt.title("Channel {}".format(sid))
+
+    plt.tight_layout()
 
     if save_as is not None:
-        plt.savefig(save_as, dpi=120)
+        plt.savefig(save_as, dpi=dpi)
 
     if show:
         plt.show()