"""Unit tests for BargeInDetector service."""

import asyncio
import pytest
from app.services.barge_in import BargeInDetector, BargeInState


class MockToken:
    """Lightweight stand-in for a Soniox ``Token`` in detector tests.

    Exposes only the attributes these tests set: ``text``, ``is_final`` and
    ``confidence``. ``confidence`` may be ``None`` to model a token without
    a score.
    """

    def __init__(self, text, is_final=False, confidence=0.8):
        self.text, self.is_final, self.confidence = text, is_final, confidence


@pytest.mark.asyncio
async def test_ignore_when_idle():
    """A keyword heard while the detector is IDLE must be ignored.

    Neither the interrupt callback nor the TTS cancel event may fire when
    ``agent_start_speaking()`` has not been called.
    """
    calls = []

    async def record(t):
        calls.append(t)

    detector = BargeInDetector(on_interrupt=record)
    keyword = MockToken("stop", is_final=False, confidence=0.9)
    await detector.evaluate_tokens([keyword])

    assert len(calls) == 0
    assert not detector.tts_cancel_event.is_set()
    assert not d.tts_cancel_event.is_set()


@pytest.mark.asyncio
async def test_keyword_interrupt():
    """A confident keyword while AGENT_SPEAKING fires exactly one interrupt.

    The callback runs once and the TTS cancel event is set. The detector
    state itself is expected to remain AGENT_SPEAKING afterwards.
    """
    calls = []

    async def record(t):
        calls.append(t)

    detector = BargeInDetector(on_interrupt=record)
    detector.agent_start_speaking()
    keyword = MockToken("stop", is_final=False, confidence=0.9)
    await detector.evaluate_tokens([keyword])

    assert len(calls) == 1
    assert detector.tts_cancel_event.is_set()
    assert detector.state == BargeInState.AGENT_SPEAKING


@pytest.mark.asyncio
async def test_vietnamese_keyword_interrupt():
    """Vietnamese keyword 'dung lai' triggers interrupt.

    As with the English keyword test, a triggered interrupt must both invoke
    the callback once and set the TTS cancel event.
    """
    fired = []

    async def on_int(t):
        fired.append(t)

    d = BargeInDetector(on_interrupt=on_int)
    d.agent_start_speaking()
    await d.evaluate_tokens([MockToken("dung lai", is_final=False, confidence=0.85)])
    assert len(fired) == 1
    # Previously unchecked: the cancel event is what actually stops TTS playback.
    assert d.tts_cancel_event.is_set()


@pytest.mark.asyncio
async def test_low_confidence_ignored():
    """A keyword scored below ``confidence_threshold`` must be ignored.

    With the threshold at 0.65, a 0.3-confidence "stop" neither invokes the
    callback nor sets the TTS cancel event.
    """
    calls = []

    async def record(t):
        calls.append(t)

    detector = BargeInDetector(on_interrupt=record, confidence_threshold=0.65)
    detector.agent_start_speaking()
    weak = MockToken("stop", is_final=False, confidence=0.3)
    await detector.evaluate_tokens([weak])

    assert len(calls) == 0
    assert not detector.tts_cancel_event.is_set()


@pytest.mark.asyncio
async def test_long_utterance_interrupt():
    """3+ word utterance with high confidence triggers interrupt.

    The two partial tokens together spell five words, which clears
    ``min_words_for_interrupt=3``. The callback must run once and the TTS
    cancel event must be set.
    """
    fired = []

    async def on_int(t):
        fired.append(t)

    d = BargeInDetector(on_interrupt=on_int, min_words_for_interrupt=3)
    d.agent_start_speaking()
    tokens = [
        MockToken("i want to ", is_final=False, confidence=0.85),
        MockToken("ask something", is_final=False, confidence=0.80),
    ]
    await d.evaluate_tokens(tokens)
    assert len(fired) == 1
    # Previously unchecked: an interrupt must also signal TTS cancellation.
    assert d.tts_cancel_event.is_set()


@pytest.mark.asyncio
async def test_short_non_keyword_ignored():
    """Short non-keyword utterance should be ignored.

    A single filler word is below ``min_words_for_interrupt=3`` and is not a
    keyword. Neither the callback nor the TTS cancel event may fire.
    """
    fired = []

    async def on_int(t):
        fired.append(t)

    d = BargeInDetector(on_interrupt=on_int, min_words_for_interrupt=3)
    d.agent_start_speaking()
    await d.evaluate_tokens([MockToken("um", is_final=False, confidence=0.7)])
    assert len(fired) == 0
    # Match the other "ignored" tests: TTS must not be cancelled either.
    assert not d.tts_cancel_event.is_set()


@pytest.mark.asyncio
async def test_double_fire_prevention():
    """Only the first interrupt within one TTS turn reaches the callback.

    "stop" fires the interrupt. A following "wait" in the same speaking turn
    (no ``agent_stop_speaking``/``agent_start_speaking`` in between) must not
    fire again.
    """
    calls = []

    async def record(t):
        calls.append(t)

    detector = BargeInDetector(on_interrupt=record)
    detector.agent_start_speaking()
    for word in ("stop", "wait"):
        await detector.evaluate_tokens(
            [MockToken(word, is_final=False, confidence=0.9)]
        )

    assert len(calls) == 1


@pytest.mark.asyncio
async def test_final_tokens_ignored():
    """Final (is_final=True) tokens should not trigger interrupt.

    Even a confident keyword must be skipped when the token is final.
    Neither the callback nor the TTS cancel event may fire.
    """
    fired = []

    async def on_int(t):
        fired.append(t)

    d = BargeInDetector(on_interrupt=on_int)
    d.agent_start_speaking()
    await d.evaluate_tokens([MockToken("stop", is_final=True, confidence=0.9)])
    assert len(fired) == 0
    # Match the other "ignored" tests: TTS must not be cancelled either.
    assert not d.tts_cancel_event.is_set()


@pytest.mark.asyncio
async def test_no_confidence_ignored():
    """Tokens with confidence=None should be filtered out.

    A keyword with no confidence score must be skipped. Neither the callback
    nor the TTS cancel event may fire.
    """
    fired = []

    async def on_int(t):
        fired.append(t)

    d = BargeInDetector(on_interrupt=on_int)
    d.agent_start_speaking()
    await d.evaluate_tokens([MockToken("stop", is_final=False, confidence=None)])
    assert len(fired) == 0
    # Match the other "ignored" tests: TTS must not be cancelled either.
    assert not d.tts_cancel_event.is_set()


@pytest.mark.asyncio
async def test_state_transitions():
    """Walk the detector lifecycle IDLE -> AGENT_SPEAKING -> IDLE.

    Entering AGENT_SPEAKING must not pre-set the TTS cancel event.
    """

    async def noop(t):
        return None

    detector = BargeInDetector(on_interrupt=noop)
    assert detector.state == BargeInState.IDLE

    detector.agent_start_speaking()
    assert detector.state == BargeInState.AGENT_SPEAKING
    assert not detector.tts_cancel_event.is_set()

    detector.agent_stop_speaking()
    assert detector.state == BargeInState.IDLE


@pytest.mark.asyncio
async def test_reset():
    """``reset()`` returns a triggered detector to a clean IDLE state.

    First an interrupt is triggered so the cancel event is set. After
    ``reset()`` the state is IDLE and the cancel event is cleared.
    """
    calls = []

    async def record(t):
        calls.append(t)

    detector = BargeInDetector(on_interrupt=record)
    detector.agent_start_speaking()
    keyword = MockToken("stop", is_final=False, confidence=0.9)
    await detector.evaluate_tokens([keyword])
    assert detector.tts_cancel_event.is_set()

    detector.reset()

    assert detector.state == BargeInState.IDLE
    assert not detector.tts_cancel_event.is_set()


@pytest.mark.asyncio
async def test_interrupt_resets_after_new_tts_turn():
    """Each fresh ``agent_start_speaking()`` re-arms interrupt detection.

    One interrupt fires in the first speaking turn. After stopping and
    starting a new turn, a second keyword fires again, for two callbacks
    in total.
    """
    calls = []

    async def record(t):
        calls.append(t)

    detector = BargeInDetector(on_interrupt=record)

    # Turn 1: the first keyword interrupts.
    detector.agent_start_speaking()
    await detector.evaluate_tokens(
        [MockToken("stop", is_final=False, confidence=0.9)]
    )
    assert len(calls) == 1

    # Turn 2: end the previous turn, begin a new one, interrupt again.
    detector.agent_stop_speaking()
    detector.agent_start_speaking()
    await detector.evaluate_tokens(
        [MockToken("wait", is_final=False, confidence=0.9)]
    )
    assert len(calls) == 2
