"""Unit tests for OmnivoiceTTSService — event contract, tag stripping, barge-in.

No network: the AsyncOpenAI client is replaced with a fake whose streaming
response yields canned PCM chunks.
"""

from __future__ import annotations

import asyncio

from app.services.tts.omnivoice import OmnivoiceTTSService


class _FakeStreamingResponse:
    def __init__(self, chunks):
        self._chunks = chunks

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc):
        return False

    async def iter_bytes(self):
        for c in self._chunks:
            yield c


class _FakeSpeech:
    def __init__(self, chunks, calls):
        self._chunks = chunks
        self._calls = calls

        class _WSR:
            def create(_self, **kwargs):
                calls.append(kwargs)
                return _FakeStreamingResponse(chunks)

        self.with_streaming_response = _WSR()


class _FakeAudio:
    """Fake ``client.audio`` namespace; only ``speech`` is populated."""

    def __init__(self, chunks, calls):
        speech_namespace = _FakeSpeech(chunks, calls)
        self.speech = speech_namespace


class _FakeClient:
    def __init__(self, chunks, calls):
        self.audio = _FakeAudio(chunks, calls)

    async def close(self):
        return None


def _make_service(chunks):
    """Build a service wired to a fake client; return ``(service, calls)``.

    ``calls`` accumulates the kwargs of every synth request the service
    issues, and each request streams back ``chunks``.
    """
    recorded: list[dict] = []
    service = OmnivoiceTTSService(api_key="k", voice_vi="vi_voice", voice_en="en_voice")
    service.client = _FakeClient(chunks, recorded)
    return service, recorded


async def _collect(gen):
    out = []
    async for ev in gen:
        out.append(ev)
    return out


async def test_event_contract_per_sentence():
    """Each sentence is framed by one start/end pair with its audio in between."""
    svc, _ = _make_service([b"\x01\x02", b"\x03\x04"])
    events = await _collect(svc.stream_utterance(["Xin chào.", "Tạm biệt."]))

    # Two sentences → two start/end pairs, audio in between.
    assert events[0] == ("sentence_start", "Xin chào.")
    assert ("audio", b"\x01\x02") in events
    assert events.count(("sentence_end", "")) == 2

    # Both sentences are announced, in input order, each exactly once.
    start_texts = [e[1] for e in events if e[0] == "sentence_start"]
    assert start_texts == ["Xin chào.", "Tạm biệt."]

    starts = [i for i, e in enumerate(events) if e[0] == "sentence_start"]
    ends = [i for i, e in enumerate(events) if e[0] == "sentence_end"]
    # Check the lengths first so a miscount fails as an assertion, not IndexError.
    assert len(starts) == len(ends) == 2
    # Order: each sentence's start precedes its end, and sentences don't overlap.
    assert starts[0] < ends[0] < starts[1] < ends[1]

    # Each sentence's audio sits strictly between its own start and end...
    for start, end in zip(starts, ends):
        assert any(e[0] == "audio" for e in events[start + 1 : end])
    # ...and no audio leaks outside a start/end pair.
    audio_positions = [i for i, e in enumerate(events) if e[0] == "audio"]
    assert all(
        any(start < i < end for start, end in zip(starts, ends))
        for i in audio_positions
    )


async def test_code_switch_uses_per_run_voice():
    """A code-switched sentence is one display sentence but one synth call per language run."""
    svc, calls = _make_service([b"\x00"])
    events = await _collect(svc.stream_utterance(["Con thích <en>dinosaur</en> xanh."]))

    # Single sentence boundary: exactly one sentence_start of ANY text, so a
    # stray per-run start would fail. Display text is tag-stripped and concatenated.
    starts = [e for e in events if e[0] == "sentence_start"]
    assert starts == [("sentence_start", "Con thích dinosaur xanh.")]
    assert events[0] == ("sentence_start", "Con thích dinosaur xanh.")
    assert events.count(("sentence_end", "")) == 1

    # Audio from the runs is emitted inside the single start/end pair.
    end_at = events.index(("sentence_end", ""))
    assert any(e[0] == "audio" for e in events[1:end_at])

    # Three runs (vi / en / vi) → three synth requests, each tag-free.
    inputs = [c["input"] for c in calls]
    assert inputs == ["Con thích ", "dinosaur", " xanh."]
    voices = [c["voice"] for c in calls]
    assert voices == ["vi_voice", "en_voice", "vi_voice"]
    assert all(c["response_format"] == "pcm" for c in calls)
    assert all("<en>" not in c["input"] and "</en>" not in c["input"] for c in calls)


async def test_skips_non_speakable_fragments():
    """Punctuation-only and whitespace-only fragments never reach the synthesizer."""
    service, requests = _make_service([b"\x00"])
    fragments = ["..", "  ", "Ổn."]
    emitted = await _collect(service.stream_utterance(fragments))

    # Exactly one synth request, and it is for the only speakable fragment.
    assert [req["input"] for req in requests] == ["Ổn."]
    assert ("sentence_start", "Ổn.") in emitted


async def test_cancel_before_start_yields_nothing():
    """A cancel event that is already set before streaming begins suppresses everything."""
    service, requests = _make_service([b"\x00"])
    cancelled = asyncio.Event()
    cancelled.set()

    emitted = await _collect(service.stream_utterance(["Một câu."], cancel_event=cancelled))

    # No events are emitted and no synth request is made.
    assert not emitted
    assert not requests
