view tests/test_core.py @ 191:94d2f7560a8e

Fix bug in test AudioRegion concatenation
author Amine Sehili <amine.sehili@gmail.com>
date Thu, 18 Apr 2019 20:30:49 +0100
parents ceebd8e331c4
children 8e63fc5d1af6
line wrap: on
line source
import unittest
from random import random
from genty import genty, genty_dataset
from auditok import AudioRegion, AudioParameterError


def _make_random_length_regions(
    byte_seq, sampling_rate, sample_width, channels
):
    """Return one randomly sized `AudioRegion` per item of `byte_seq`.

    For every item, a duration is drawn with `random() * 10` (rounded to 6
    decimals) and a start with `random() * 13` (rounded to 3 decimals). The
    item is repeated `int(duration * sampling_rate)` times, then by
    `sample_width` and by `channels`, to form the region's data.

    The regions are returned as a list, in the order of `byte_seq`.
    """

    def _random_region(fill):
        # Draw the duration first and the start second, so the sequence of
        # `random()` calls stays the same for every region.
        seconds = round(random() * 10, 6)
        n_samples = int(seconds * sampling_rate)
        payload = fill * n_samples * sample_width * channels
        offset = round(random() * 13, 3)
        return AudioRegion(
            payload, offset, sampling_rate, sample_width, channels
        )

    return [_random_region(fill) for fill in byte_seq]


@genty
class TestAudioRegion(unittest.TestCase):
    @genty_dataset(
        simple=(b"\0" * 8000, 0, 8000, 1, 1, 1, 1, 1000),
        one_ms_less_than_1_sec=(
            b"\0" * 7992,
            0,
            8000,
            1,
            1,
            0.999,
            0.999,
            999,
        ),
        tree_quarter_ms_less_than_1_sec=(
            b"\0" * 7994,
            0,
            8000,
            1,
            1,
            0.99925,
            0.99925,
            999,
        ),
        half_ms_less_than_1_sec=(
            b"\0" * 7996,
            0,
            8000,
            1,
            1,
            0.9995,
            0.9995,
            1000,
        ),
        quarter_ms_less_than_1_sec=(
            b"\0" * 7998,
            0,
            8000,
            1,
            1,
            0.99975,
            0.99975,
            1000,
        ),
        simple_sample_width_2=(b"\0" * 8000 * 2, 0, 8000, 2, 1, 1, 1, 1000),
        simple_stereo=(b"\0" * 8000 * 2, 0, 8000, 1, 2, 1, 1, 1000),
        simple_multichannel=(b"\0" * 8000 * 5, 0, 8000, 1, 5, 1, 1, 1000),
        simple_sample_width_2_multichannel=(
            b"\0" * 8000 * 2 * 5,
            0,
            8000,
            2,
            5,
            1,
            1,
            1000,
        ),
        one_ms_less_than_1s_sw_2_multichannel=(
            b"\0" * 7992 * 2 * 5,
            0,
            8000,
            2,
            5,
            0.999,
            0.999,
            999,
        ),
        tree_qrt_ms_lt_1_s_sw_2_multichannel=(
            b"\0" * 7994 * 2 * 5,
            0,
            8000,
            2,
            5,
            0.99925,
            0.99925,
            999,
        ),
        half_ms_lt_1s_sw_2_multichannel=(
            b"\0" * 7996 * 2 * 5,
            0,
            8000,
            2,
            5,
            0.9995,
            0.9995,
            1000,
        ),
        quarter_ms_lt_1s_sw_2_multichannel=(
            b"\0" * 7998 * 2 * 5,
            0,
            8000,
            2,
            5,
            0.99975,
            0.99975,
            1000,
        ),
        arbitrary_length_1=(
            b"\0" * int(8000 * 1.33),
            2.7,
            8000,
            1,
            1,
            4.03,
            1.33,
            1330,
        ),
        arbitrary_length_2=(
            b"\0" * int(8000 * 0.476),
            11.568,
            8000,
            1,
            1,
            12.044,
            0.476,
            476,
        ),
        arbitrary_length_sw_2_multichannel=(
            b"\0" * int(8000 * 1.711) * 2 * 3,
            9.415,
            8000,
            2,
            3,
            11.126,
            1.711,
            1711,
        ),
        arbitrary_samplig_rate=(
            b"\0" * int(3172 * 1.318),
            17.236,
            3172,
            1,
            1,
            17.236 + int(3172 * 1.318) / 3172,
            int(3172 * 1.318) / 3172,
            1318,
        ),
        arbitrary_sr_sw_2_multichannel=(
            b"\0" * int(11317 * 0.716) * 2 * 3,
            18.811,
            11317,
            2,
            3,
            18.811 + int(11317 * 0.716) / 11317,
            int(11317 * 0.716) / 11317,
            716,
        ),
    )
    def test_creation(
        self,
        data,
        start,
        sampling_rate,
        sample_width,
        channels,
        expected_end,
        expected_duration_s,
        expected_duration_ms,
    ):
        # Build a region from one dataset row and check every value it
        # exposes: audio parameters (long and short attribute names), time
        # boundaries, duration in seconds, length (the dataset expects
        # milliseconds) and the raw bytes.
        region = AudioRegion(
            data, start, sampling_rate, sample_width, channels
        )

        # Each audio parameter must be readable through both of its names.
        audio_params = (
            ("sampling_rate", "sr", sampling_rate),
            ("sample_width", "sw", sample_width),
            ("channels", "ch", channels),
        )
        for full_name, alias, expected in audio_params:
            self.assertEqual(getattr(region, full_name), expected)
            self.assertEqual(getattr(region, alias), expected)

        # Time boundaries and duration.
        self.assertEqual(region.start, start)
        self.assertEqual(region.end, expected_end)
        self.assertEqual(region.duration, expected_duration_s)
        self.assertEqual(len(region), expected_duration_ms)

        # The region must give back exactly the bytes it was built from.
        self.assertEqual(bytes(region), data)

    def test_creation_invalid_data_exception(self):
        # 9 bytes of data with sample_width=2 and channels=1: the length is
        # not a multiple of `sample_width * channels`, so creation must fail
        # with AudioParameterError and the message below.
        expected_message = (
            "The length of audio data must be an integer "
            "multiple of `sample_width * channels`"
        )
        invalid_params = dict(
            data=b"ABCDEFGHI",
            start=0,
            sampling_rate=8,
            sample_width=2,
            channels=1,
        )

        with self.assertRaises(AudioParameterError) as raised:
            AudioRegion(**invalid_params)

        self.assertEqual(expected_message, str(raised.exception))

    @genty_dataset(
        simple=(8000, 1, 1),
        stereo_sw_2=(8000, 2, 2),
        arbitray_sr_multichannel=(5413, 2, 3),
    )
    def test_concatenation(self, sampling_rate, sample_width, channels):
        # `left + right` must start where `left` starts, last as long as
        # both regions together and carry the bytes of `left` followed by
        # the bytes of `right`.
        left, right = _make_random_length_regions(
            [b"a", b"b"], sampling_rate, sample_width, channels
        )

        # Capture the expectations before concatenating.
        origin = left.start
        joined_duration = left.duration + right.duration
        joined_bytes = bytes(left) + bytes(right)

        joined = left + right

        self.assertEqual(joined.start, origin)
        # End and duration come from float arithmetic: compare to 6 places.
        self.assertAlmostEqual(
            joined.end, origin + joined_duration, places=6
        )
        self.assertAlmostEqual(joined.duration, joined_duration, places=6)
        self.assertEqual(bytes(joined), joined_bytes)

    @genty_dataset(
        simple=(8000, 1, 1),
        stereo_sw_2=(8000, 2, 2),
        arbitray_sr_multichannel=(5413, 2, 3),
    )
    def test_concatenation_many(self, sampling_rate, sample_width, channels):
        # Same checks as for two regions, but three regions are combined
        # with the built-in `sum`.
        parts = _make_random_length_regions(
            [b"a", b"b", b"c"], sampling_rate, sample_width, channels
        )

        origin = parts[0].start
        total_duration = sum([part.duration for part in parts])
        all_bytes = b"".join(map(bytes, parts))

        combined = sum(parts)

        self.assertEqual(combined.start, origin)
        # End and duration come from float arithmetic: compare to 6 places.
        self.assertAlmostEqual(
            combined.end, origin + total_duration, places=6
        )
        self.assertAlmostEqual(combined.duration, total_duration, places=6)
        self.assertEqual(bytes(combined), all_bytes)

    def test_concatenation_different_sampling_rate_error(self):
        # Two regions that differ only in sampling rate (8000 vs 3000):
        # adding them must raise ValueError with the message below.
        expected_message = (
            "Can only concatenate AudioRegions of the same "
            "sampling rate (8000 != 3000)"
        )
        at_8000 = AudioRegion(b"a" * 100, 0, 8000, 1, 1)
        at_3000 = AudioRegion(b"b" * 100, 0, 3000, 1, 1)

        with self.assertRaises(ValueError) as raised:
            at_8000 + at_3000

        self.assertEqual(expected_message, str(raised.exception))

    def test_concatenation_different_sample_width_error(self):
        # Two regions that differ only in sample width (2 vs 4): adding
        # them must raise ValueError with the message below.
        expected_message = (
            "Can only concatenate AudioRegions of the same "
            "sample width (2 != 4)"
        )
        width_2 = AudioRegion(b"a" * 100, 0, 8000, 2, 1)
        width_4 = AudioRegion(b"b" * 100, 0, 8000, 4, 1)

        with self.assertRaises(ValueError) as raised:
            width_2 + width_4

        self.assertEqual(expected_message, str(raised.exception))

    def test_concatenation_different_number_of_channels_error(self):
        # Two regions that differ only in number of channels (1 vs 2):
        # adding them must raise ValueError with the message below.
        expected_message = (
            "Can only concatenate AudioRegions of the same "
            "number of channels (1 != 2)"
        )
        mono = AudioRegion(b"a" * 100, 0, 8000, 1, 1)
        stereo = AudioRegion(b"b" * 100, 0, 8000, 1, 2)

        with self.assertRaises(ValueError) as raised:
            mono + stereo

        self.assertEqual(expected_message, str(raised.exception))