changeset 192:8e63fc5d1af6

Add tests for AudioRegion.save
author Amine Sehili <amine.sehili@gmail.com>
date Sun, 21 Apr 2019 19:10:07 +0100
parents 94d2f7560a8e
children b274a22c9685
files tests/test_core.py
diffstat 1 files changed, 33 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/tests/test_core.py	Thu Apr 18 20:30:49 2019 +0100
+++ b/tests/test_core.py	Sun Apr 21 19:10:07 2019 +0100
@@ -1,5 +1,7 @@
+import os
 import unittest
 from random import random
+from tempfile import TemporaryDirectory
 from genty import genty, genty_dataset
 from auditok import AudioRegion, AudioParameterError
 
@@ -209,6 +211,35 @@
         )
 
     @genty_dataset(
+        simple=("output.wav", 1.230, "output.wav"),
+        start=("output_{start}.wav", 1.230, "output_1.23.wav"),
+        start_2=("output_{start}.wav", 1.233712, "output_1.233712.wav"),
+        start_3=("output_{start}.wav", 1.2300001, "output_1.23.wav"),
+        start_4=("output_{start:.3f}.wav", 1.233712, "output_1.234.wav"),
+        start_5=(
+            "output_{start:.8f}.wav",
+            1.233712345,
+            "output_1.23371200.wav",
+        ),
+        start_end_duration=(
+            "output_{start}_{end}_{duration}.wav",
+            1.455,
+            "output_1.455_2.455_1.0.wav",
+        ),
+        start_end_duration_2=(
+            "output_{start}_{end}_{duration}.wav",
+            1.455321,
+            "output_1.455321_2.455321_1.0.wav",
+        ),
+    )
+    def test_save(self, format, start, expected):
+        with TemporaryDirectory() as tmpdir:
+            region = AudioRegion(b"0" * 160, start, 160, 1, 1)
+            format = os.path.join(tmpdir, format)
+            filename = region.save(format)[len(tmpdir) + 1 :]
+            self.assertEqual(filename, expected)
+
+    @genty_dataset(
         simple=(8000, 1, 1),
         stereo_sw_2=(8000, 2, 2),
         arbitray_sr_multichannel=(5413, 2, 3),
@@ -254,6 +285,8 @@
             concat_region.duration, expected_duration, places=6
         )
         self.assertEqual(bytes(concat_region), expected_data)
+        # see test_concatenation
+        self.assertEqual(len(concat_region), round(expected_duration * 1000))
 
     def test_concatenation_different_sampling_rate_error(self):