# HG changeset patch # User Amine Sehili # Date 1570652488 -3600 # Node ID 1732213b290a9a51aefcdedbb9bbb70b69115bd9 # Parent 73989d247f4ea8306ed14f7119af84a2f3491e03 Refactor plotting code - Update theme - Add the possibility to use a use-supplied theme - Plot multichannel audio - Optionally scale the signal before plotting - Add arguments for figure size and resolution diff -r 73989d247f4e -r 1732213b290a auditok/plotting.py --- a/auditok/plotting.py Tue Oct 08 20:16:11 2019 +0100 +++ b/auditok/plotting.py Wed Oct 09 21:21:28 2019 +0100 @@ -1,28 +1,135 @@ import matplotlib.pyplot as plt import numpy as np +AUDITOK_PLOT_THEME = { + "figure": {"facecolor": "#482a36", "alpha": 0.2}, + "plot": {"facecolor": "#282a36"}, + "energy_threshold": { + "color": "#e31f8f", + "linestyle": "--", + "linewidth": 1, + }, + "signal": {"color": "#40d970", "linestyle": "-", "linewidth": 1}, + "detections": { + "facecolor": "#777777", + "edgecolor": "#ff8c1a", + "linewidth": 1, + "alpha": 0.75, + }, +} -def plot(data, sampling_rate, show=True): - y = np.asarray(data) - ymax = np.abs(y).max() - nb_samples = y.shape[-1] + +def _make_time_axis(nb_samples, sampling_rate): sample_duration = 1 / sampling_rate x = np.linspace(0, sample_duration * (nb_samples - 1), nb_samples) - plt.plot(x, y / ymax, c="#024959") - plt.ylim(-3, 3) - if show: - plt.show() + return x -def plot_detections(data, sampling_rate, detections, show=True, save_as=None): +def _plot_line(x, y, theme, xlabel=None, ylabel=None, **kwargs): + color = theme.get("color", theme.get("c")) + ls = theme.get("linestyle", theme.get("ls")) + lw = theme.get("linewidth", theme.get("lw")) + plt.plot(x, y, c=color, ls=ls, lw=lw, **kwargs) + plt.xlabel(xlabel) + plt.ylabel(ylabel) - plot(data, sampling_rate, show=False) - if detections is not None: - for (start, end) in detections: - plt.axvspan(start, end, facecolor="g", ec="r", lw=2, alpha=0.4) + +def _plot_detections(subplot, detections, theme): + fc = theme.get("facecolor", theme.get("fc")) + ec = theme.get("edgecolor", theme.get("ec")) + ls = theme.get("linestyle", theme.get("ls")) + lw = theme.get("linewidth", theme.get("lw")) + alpha = theme.get("alpha") + for (start, end) in detections: + subplot.axvspan(start, end, fc=fc, ec=ec, ls=ls, lw=lw, alpha=alpha) + + +def plot( + audio_region, + scale_signal=True, + detections=None, + energy_threshold=None, + show=True, + figsize=None, + save_as=None, + dpi=120, + theme="auditok", +): + y = np.asarray(audio_region) + if len(y.shape) == 1: + y = y.reshape(1, -1) + nb_subplots, nb_samples = y.shape + time_axis = _make_time_axis(nb_samples, audio_region.sampling_rate) + if energy_threshold is not None: + eth_log10 = energy_threshold * np.log(10) / 10 + amplitude_threshold = np.sqrt(np.exp(eth_log10)) + else: + amplitude_threshold = None + if detections is None: + detections = [] + else: + detections = list(detections) + if theme == "auditok": + theme = AUDITOK_PLOT_THEME + + fig = plt.figure(figsize=figsize, dpi=dpi) + fig_theme = theme.get("figure", theme.get("fig", {})) + fig_fc = fig_theme.get("facecolor", fig_theme.get("ffc")) + fig_alpha = fig_theme.get("alpha", 1) + fig.patch.set_facecolor(fig_fc) + fig.patch.set_alpha(fig_alpha) + + plot_theme = theme.get("plot", {}) + plot_fc = plot_theme.get("facecolor", plot_theme.get("pfc")) + for sid, samples in enumerate(y, 1): + + ax = fig.add_subplot(nb_subplots, 1, sid) + ax.set_facecolor(plot_fc) + if scale_signal: + mean = samples.mean() + std = samples.std() + samples = (samples - mean) / std + max_ = samples.max() + plt.ylim(-1.5 * max_, 1.5 * max_) + + if amplitude_threshold is not None: + if scale_signal: + amp_th = (amplitude_threshold - mean) / std + else: + amp_th = amplitude_threshold + eth_theme = theme.get("energy_threshold", theme.get("eth", {})) + _plot_line( + [time_axis[0], time_axis[-1]], + [amp_th] * 2, + eth_theme, + label="Detection threshold", + ) + if sid == 1: + legend = plt.legend( + ["Detection threshold"], + facecolor=fig_fc, + framealpha=0.1, + bbox_to_anchor=(0.0, 1.15, 1.0, 0.102), + loc=2, + ) + legend = plt.gca().add_artist(legend) + + signal_theme = theme.get("signal", {}) + _plot_line( + time_axis, + samples, + signal_theme, + xlabel="Time (seconds)", + ylabel="Signal{}".format(" (scaled)" if scale_signal else ""), + ) + detections_theme = theme.get("detections", {}) + _plot_detections(ax, detections, detections_theme) + plt.title("Channel {}".format(sid)) + + plt.tight_layout() if save_as is not None: - plt.savefig(save_as, dpi=120) + plt.savefig(save_as, dpi=dpi) if show: plt.show()