import json
import math
import struct
import unittest

import opuslib_next

from app.services.audio_pipeline import (
    CHANNELS,
    FRAME_SIZE,
    SAMPLE_RATE,
    apply_pcm_gain_s16le,
    build_tts_response_from_streamed_tokens,
    create_output_encoder,
    dispose_output_encoder,
    prepare_tts_chunks,
    split_text_for_tts,
    synthesize_and_send_audio,
)


def build_tone_frames(frequency_hz: int, frame_count: int = 6) -> list[bytes]:
    """Generate consecutive PCM frames containing a sine tone.

    Each frame holds ``FRAME_SIZE`` samples packed as little-endian signed
    16-bit integers with a peak amplitude of 12000. The time base keeps
    running from one frame to the next, so the tone is phase-continuous
    across the returned frames.

    Args:
        frequency_hz: Tone frequency in Hz.
        frame_count: Number of frames to produce.

    Returns:
        A list of ``frame_count`` packed frames.
    """
    amplitude = 12000
    radians_per_second = 2 * math.pi * frequency_hz
    frame_format = f"<{FRAME_SIZE}h"

    def render_frame(first_sample: int) -> bytes:
        # Samples are addressed by their absolute position in the stream so
        # the sine phase does not restart at frame boundaries.
        pcm = [
            int(amplitude * math.sin(radians_per_second * (position / SAMPLE_RATE)))
            for position in range(first_sample, first_sample + FRAME_SIZE)
        ]
        return struct.pack(frame_format, *pcm)

    return [render_frame(number * FRAME_SIZE) for number in range(frame_count)]


def decode_packets(packets: list[bytes]) -> list[int]:
    """Decode encoded packets back into a flat list of 16-bit sample values.

    All packets are fed, in order, through a single freshly created
    ``opuslib_next.Decoder`` configured with ``SAMPLE_RATE`` and ``CHANNELS``.

    Args:
        packets: Encoded packets, in stream order.

    Returns:
        The decoded samples of every packet, concatenated in order.
    """
    decoder = opuslib_next.Decoder(SAMPLE_RATE, CHANNELS)
    decoded: list[int] = []
    for opus_packet in packets:
        pcm = decoder.decode(opus_packet, FRAME_SIZE)
        # Every two bytes of decoder output are one little-endian int16 value.
        value_count = len(pcm) // 2
        decoded += struct.unpack(f"<{value_count}h", pcm)
    return decoded


def mean_absolute_error(left: list[int], right: list[int]) -> float:
    """Return the mean absolute difference between two sample sequences.

    Only the overlapping prefix is compared: when the sequences differ in
    length, the extra trailing samples of the longer one are ignored.

    Args:
        left: First sequence of sample values.
        right: Second sequence of sample values.

    Returns:
        The average of ``abs(left[i] - right[i])`` over the overlapping prefix.

    Raises:
        ValueError: If either sequence is empty, since there is nothing to
            compare and the mean would be undefined.
    """
    sample_count = min(len(left), len(right))
    if sample_count == 0:
        # Fail with an explicit message instead of an opaque ZeroDivisionError.
        raise ValueError("mean_absolute_error requires at least one overlapping sample")
    # zip() stops at the shorter sequence, matching sample_count.
    return sum(abs(a - b) for a, b in zip(left, right)) / sample_count


class FakeTTSService:
    """TTS test double that replays pre-baked PCM chunks."""

    def __init__(self, pcm_chunks: list[bytes], output_gain: float = 1.0):
        # Stored only; this class never reads it. Presumably consumed by the
        # pipeline under test -- confirm against synthesize_and_send_audio.
        self.output_gain = output_gain
        # Chunks handed back, in order, by synthesize_stream().
        self.pcm_chunks = pcm_chunks

    async def synthesize_stream(self, text_chunks: list[str]):
        """Yield the configured PCM chunks in order.

        ``text_chunks`` is accepted for interface compatibility and ignored.
        """
        for pcm in self.pcm_chunks:
            yield pcm


class FakeWebSocket:
    """Websocket test double that records everything sent through it."""

    def __init__(self):
        # Binary payloads, stored exactly as passed to send_bytes().
        self.binary_messages: list[bytes] = []
        # Text payloads, stored after JSON decoding.
        self.text_messages: list[dict] = []

    async def send_text(self, payload: str):
        """Parse ``payload`` as JSON and record the decoded value."""
        decoded = json.loads(payload)
        self.text_messages.append(decoded)

    async def send_bytes(self, payload: bytes):
        """Record a binary payload unchanged."""
        self.binary_messages.append(payload)


class TrackingEncoder:
    """Encoder test double that logs the dimensions of each encode request."""

    def __init__(self):
        # One (payload length in bytes, requested frame size) entry per call.
        self.calls: list[tuple[int, int]] = []

    def encode(self, frame_bytes: bytes, frame_size: int) -> bytes:
        """Record the call and return a placeholder packet.

        The packet is ``b"frame-<n>"`` where ``n`` is the 1-based number of
        this call, so every returned packet is distinct.
        """
        record = (len(frame_bytes), frame_size)
        self.calls.append(record)
        call_number = len(self.calls)
        return f"frame-{call_number}".encode("utf-8")


class RecordingEncoder:
    """Encoder test double that keeps every frame it is asked to encode."""

    def __init__(self):
        # Requested frame size of each call, parallel to ``frames``.
        self.frame_sizes: list[int] = []
        # Raw payload of each call, in call order.
        self.frames: list[bytes] = []

    def encode(self, frame_bytes: bytes, frame_size: int) -> bytes:
        """Store the payload and frame size, then return a constant packet."""
        self.frame_sizes.append(frame_size)
        self.frames.append(frame_bytes)
        return b"frame"


class AudioPipelineTests(unittest.IsolatedAsyncioTestCase):
    """Tests for PCM gain, TTS text chunking, output encoders and audio streaming."""

    def _make_output_encoder(self):
        """Create an output encoder that is disposed when the test finishes.

        Cleanup is registered immediately after creation, so the encoder is
        released even if a later step of the test (including the creation of
        another encoder) raises.
        """
        encoder = create_output_encoder()
        self.addCleanup(dispose_output_encoder, encoder)
        return encoder

    async def _synthesize(self, pcm_chunks, encoder, output_gain=1.0):
        """Run ``synthesize_and_send_audio`` against fakes.

        Args:
            pcm_chunks: PCM chunks the fake TTS service will yield.
            encoder: Encoder double passed as ``output_encoder``.
            output_gain: ``output_gain`` attribute of the fake TTS service.

        Returns:
            The ``FakeWebSocket`` that received the pipeline's messages.
        """
        websocket = FakeWebSocket()
        await synthesize_and_send_audio(
            ["hello"],
            FakeTTSService(pcm_chunks, output_gain=output_gain),
            websocket,
            "session-1",
            output_encoder=encoder,
        )
        return websocket

    def test_apply_pcm_gain_s16le_amplifies_with_clipping_protection(self):
        """A gain of 2.0 doubles in-range samples and saturates at the int16 limits."""
        pcm = struct.pack("<4h", 1000, -1000, 25000, -25000)

        boosted = apply_pcm_gain_s16le(pcm, 2.0)

        self.assertEqual(struct.unpack("<4h", boosted), (2000, -2000, 32767, -32768))

    def test_split_text_for_tts_keeps_sentence_endings_with_chunks(self):
        """Sentence punctuation stays with its sentence; the separating space leads the next chunk."""
        chunks = split_text_for_tts("Hello there. How are you?")

        self.assertEqual(chunks, ["Hello there.", " How are you?"])

    def test_split_text_for_tts_keeps_text_without_punctuation_in_one_chunk(self):
        """Text without sentence punctuation is returned as a single chunk."""
        chunks = split_text_for_tts("Just keep speaking without punctuation")

        self.assertEqual(chunks, ["Just keep speaking without punctuation"])

    def test_split_text_for_tts_keeps_final_tail_without_punctuation(self):
        """An unpunctuated tail after the last sentence is kept as its own chunk."""
        chunks = split_text_for_tts("Sentence one. Final tail")

        self.assertEqual(chunks, ["Sentence one.", " Final tail"])

    def test_prepare_tts_chunks_logs_and_preserves_text_when_joined(self):
        """Joining the prepared chunks reproduces the input text exactly."""
        text = "First sentence. Second one without final punctuation"

        chunks = prepare_tts_chunks(text, session_id="session-1")

        self.assertEqual("".join(chunks), text)
        self.assertEqual(chunks[0], "First sentence.")
        self.assertTrue(chunks[-1].endswith("punctuation"))

    def test_build_tts_response_from_streamed_tokens_uses_full_response_for_chunking(self):
        """Streamed tokens are concatenated first and chunked as one full response."""
        tokens = [
            "This",
            " is",
            " one",
            " sentence.",
            " Final",
            " tail",
        ]

        full_response, text_chunks = build_tts_response_from_streamed_tokens(
            tokens,
            session_id="session-1",
        )

        self.assertEqual(full_response, "This is one sentence. Final tail")
        self.assertEqual("".join(text_chunks), full_response)
        self.assertEqual(len(text_chunks), 2)
        self.assertEqual(text_chunks[0], "This is one sentence.")
        self.assertTrue(text_chunks[-1].endswith("Final tail"))

    def test_create_output_encoder_returns_distinct_instances(self):
        """Each ``create_output_encoder`` call returns a new object exposing ``encode``."""
        encoder_a = self._make_output_encoder()
        encoder_b = self._make_output_encoder()

        self.assertIsNot(encoder_a, encoder_b)
        self.assertTrue(hasattr(encoder_a, "encode"))
        self.assertTrue(hasattr(encoder_b, "encode"))

    def test_dedicated_output_encoders_keep_stream_state_isolated(self):
        """Interleaving two streams through one encoder changes the decoded audio.

        Two different tones are encoded once with a dedicated encoder each and
        once interleaved frame-by-frame through a single shared encoder. The
        decoded output of the shared run must differ from the dedicated run by
        a mean absolute error above 100 for both streams.
        """
        stream_a = build_tone_frames(440)
        stream_b = build_tone_frames(880)

        encoder_a = self._make_output_encoder()
        encoder_b = self._make_output_encoder()
        shared_encoder = self._make_output_encoder()

        dedicated_packets_a = [encoder_a.encode(frame, FRAME_SIZE) for frame in stream_a]
        dedicated_packets_b = [encoder_b.encode(frame, FRAME_SIZE) for frame in stream_b]

        shared_packets_a = []
        shared_packets_b = []
        # Alternate A/B frames so the shared encoder's state is carried across streams.
        for frame_a, frame_b in zip(stream_a, stream_b):
            shared_packets_a.append(shared_encoder.encode(frame_a, FRAME_SIZE))
            shared_packets_b.append(shared_encoder.encode(frame_b, FRAME_SIZE))

        dedicated_a = decode_packets(dedicated_packets_a)
        dedicated_b = decode_packets(dedicated_packets_b)
        shared_a = decode_packets(shared_packets_a)
        shared_b = decode_packets(shared_packets_b)

        self.assertGreater(mean_absolute_error(shared_a, dedicated_a), 100)
        self.assertGreater(mean_absolute_error(shared_b, dedicated_b), 100)

    async def test_synthesize_and_send_audio_uses_the_passed_encoder(self):
        """The supplied encoder receives every frame and start/stop states bracket the audio."""
        full_frame = b"\x01\x00" * FRAME_SIZE
        partial_frame = b"\x02\x00" * 20
        encoder = TrackingEncoder()

        websocket = await self._synthesize([full_frame, partial_frame], encoder)

        # 1 full frame + 1 padded full-size final frame for the short tail.
        # Comparing the whole list reports the actual calls on a mismatch
        # instead of raising IndexError when fewer calls were made.
        self.assertEqual(encoder.calls, [(FRAME_SIZE * 2, FRAME_SIZE)] * 2)
        self.assertEqual(len(websocket.binary_messages), 2)
        self.assertTrue(websocket.text_messages, "no text messages were sent")
        self.assertEqual(websocket.text_messages[0]["state"], "start")
        self.assertEqual(websocket.text_messages[-1]["state"], "stop")

    async def test_synthesize_and_send_audio_applies_tts_output_gain(self):
        """The TTS service's ``output_gain`` scales the PCM before it is encoded."""
        sample_value = 1000
        full_frame = struct.pack(f"<{FRAME_SIZE}h", *([sample_value] * FRAME_SIZE))
        encoder = RecordingEncoder()

        await self._synthesize([full_frame], encoder, output_gain=2.0)

        self.assertTrue(encoder.frames, "encoder received no frames")
        first_frame = struct.unpack(f"<{FRAME_SIZE}h", encoder.frames[0])
        # Tuple equality gives a useful diff on failure, unlike assertTrue(all(...)).
        self.assertEqual(first_frame, (sample_value * 2,) * FRAME_SIZE)

    async def test_synthesize_and_send_audio_pads_partial_tail_to_one_full_frame(self):
        """A 120-sample remainder after a full frame is zero-padded to one full frame."""
        full_frame = b"\x01\x00" * FRAME_SIZE
        residual_frame = b"\x02\x00" * 120
        encoder = RecordingEncoder()

        websocket = await self._synthesize([full_frame, residual_frame], encoder)

        self.assertEqual(encoder.frame_sizes, [FRAME_SIZE, FRAME_SIZE])
        padded_samples = struct.unpack(f"<{FRAME_SIZE}h", encoder.frames[1])
        self.assertEqual(padded_samples[:120], (2,) * 120)
        self.assertEqual(padded_samples[120:], (0,) * (FRAME_SIZE - 120))
        self.assertEqual(len(websocket.binary_messages), 2)

    async def test_synthesize_and_send_audio_pads_short_tail_to_one_full_frame(self):
        """Audio shorter than one frame is zero-padded and sent as a single full frame."""
        partial_frame = b"\x03\x00" * 20
        encoder = RecordingEncoder()

        websocket = await self._synthesize([partial_frame], encoder)

        self.assertEqual(encoder.frame_sizes, [FRAME_SIZE])
        padded_samples = struct.unpack(f"<{FRAME_SIZE}h", encoder.frames[0])
        self.assertEqual(padded_samples[:20], (3,) * 20)
        self.assertEqual(padded_samples[20:], (0,) * (FRAME_SIZE - 20))
        self.assertEqual(len(websocket.binary_messages), 1)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
