import asyncio
import contextlib
import importlib
import json
import os
import sys
import types
import unittest
from unittest.mock import AsyncMock, patch

from fastapi import WebSocketDisconnect

from app.router.voice_handler import receive_client_hello


class FakeHandshakeWebSocket:
    """Scripted WebSocket double for exercising ``receive_client_hello``.

    ``receive()`` hands out the queued events (ASGI-style dicts in these tests)
    in FIFO order, and ``close()`` only records its arguments.

    Args:
        messages: Events to return from ``receive()``. The sequence is copied,
            so the caller's object is never consumed.
        wait_forever: When the queue is empty, block indefinitely instead of
            failing immediately, so a caller-side timeout can fire.
    """

    def __init__(self, messages=None, wait_forever: bool = False):
        self.messages = list(messages) if messages else []
        self.wait_forever = wait_forever
        # One {"code": ..., "reason": ...} dict per close() call, in call order.
        self.close_calls: list[dict] = []

    async def receive(self):
        """Return the next queued event.

        Raises:
            AssertionError: the queue is empty (after waiting, if
                ``wait_forever`` is set and the wait ever ends).
        """
        if not self.messages:
            if self.wait_forever:
                # This event is never set, so the await only ends via
                # cancellation (e.g. an enclosing timeout).
                await asyncio.Event().wait()
            raise AssertionError("No queued websocket message")
        return self.messages.pop(0)

    async def close(self, code: int, reason: str = ""):
        """Record the close request instead of closing anything."""
        self.close_calls.append(dict(code=code, reason=reason))


class FakeEndpointWebSocket:
    """WebSocket double for calling ``websocket_endpoint`` directly.

    It records ``accept()``, ``close()`` and ``send_text()`` calls instead of
    doing network I/O. ``client.host`` and ``app.state`` exist presumably
    because the endpoint reads them (NOTE(review): confirm against the router).
    """

    def __init__(self):
        self.accepted = False
        # One {"code": ..., "reason": ...} dict per close() call.
        self.close_calls: list[dict] = []
        # JSON-decoded payloads passed to send_text(), in send order.
        self.text_messages: list[dict] = []

        client_type = type("Client", (), {"host": "127.0.0.1"})
        state_type = type("State", (), {})
        app_type = type("App", (), {"state": state_type()})
        self.client = client_type()
        self.app = app_type()

    async def accept(self):
        """Mark the connection as accepted."""
        self.accepted = True

    async def close(self, code: int, reason: str = ""):
        """Record the close request instead of closing anything."""
        self.close_calls.append(dict(code=code, reason=reason))

    async def send_text(self, payload: str):
        """Decode ``payload`` as JSON and keep it for later assertions."""
        decoded = json.loads(payload)
        self.text_messages.append(decoded)


def import_voice_router_with_dependency_stubs():
    """Import ``app.router.voice`` with the Soniox/Cartesia SDKs replaced by stubs.

    The stub modules exist in ``sys.modules`` only while the import runs. They
    expose empty placeholder classes under the names listed below. Any
    previously imported ``app.router.voice`` is discarded first, so the
    router's module body re-executes against the stubs.

    Returns:
        The freshly imported ``app.router.voice`` module. It is left registered
        in ``sys.modules`` so that dotted ``patch("app.router.voice.X")``
        targets resolve to this exact module object.
    """
    stub_module_names = (
        "soniox",
        "soniox.types",
        "soniox.realtime",
        "soniox.realtime.async_stt",
        "cartesia",
        "cartesia.types",
        "cartesia.types.websocket_response",
    )
    stubbed_modules = {name: types.ModuleType(name) for name in stub_module_names}

    # Placeholder classes attached to each stub module (no behaviour at all).
    placeholder_classes = {
        "soniox": ("AsyncSonioxClient",),
        "soniox.types": ("RealtimeSTTConfig", "RealtimeEvent"),
        "soniox.realtime.async_stt": ("AsyncRealtimeSTTSession",),
        "cartesia": ("AsyncCartesia",),
        "cartesia.types.websocket_response": ("Chunk", "Done", "FlushDone", "Error"),
    }
    for module_name, class_names in placeholder_classes.items():
        for class_name in class_names:
            setattr(stubbed_modules[module_name], class_name, type(class_name, (), {}))

    router_module_name = "app.router.voice"
    sys.modules.pop(router_module_name, None)
    with patch.dict(sys.modules, stubbed_modules):
        voice_module = importlib.import_module(router_module_name)

    # patch.dict clears and restores sys.modules wholesale on exit, which also
    # drops the module imported above. mock.patch resolves dotted targets via
    # importlib, so without this re-registration patch("app.router.voice.X")
    # would re-import the router (against the real SDKs, or fail and fall
    # back to the package attribute). The patch could then land on a
    # different module object than the one returned here.
    sys.modules[router_module_name] = voice_module
    return voice_module


class VoiceHandshakeTests(unittest.IsolatedAsyncioTestCase):
    """Tests for ``receive_client_hello`` and the ``websocket_endpoint`` session lifecycle."""

    async def test_receive_client_hello_ignores_noise_until_hello(self):
        # A binary frame, non-JSON text and non-hello JSON precede the hello.
        websocket = FakeHandshakeWebSocket(
            messages=[
                {"type": "websocket.receive", "bytes": b"\x00\x01"},
                {"type": "websocket.receive", "text": "not-json"},
                {
                    "type": "websocket.receive",
                    "text": json.dumps({"type": "ping"}),
                },
                {
                    "type": "websocket.receive",
                    "text": json.dumps({"type": "hello", "version": 1}),
                },
            ]
        )

        hello_data = await receive_client_hello(websocket, timeout=0.1)

        self.assertEqual(hello_data["type"], "hello")
        self.assertEqual(hello_data["version"], 1)
        self.assertEqual(websocket.close_calls, [])

    async def test_receive_client_hello_closes_with_4001_on_timeout(self):
        websocket = FakeHandshakeWebSocket(wait_forever=True)

        with self.assertRaises(TimeoutError):
            await receive_client_hello(websocket, timeout=0.01)

        self.assertEqual(
            websocket.close_calls,
            [{"code": 4001, "reason": "Hello timeout"}],
        )

    async def test_receive_client_hello_preserves_disconnect_details(self):
        websocket = FakeHandshakeWebSocket(
            messages=[
                {
                    "type": "websocket.disconnect",
                    "code": 1001,
                    "reason": "client left",
                }
            ]
        )

        with self.assertRaises(WebSocketDisconnect) as context:
            await receive_client_hello(websocket, timeout=0.1)

        self.assertEqual(context.exception.code, 1001)
        self.assertEqual(context.exception.reason, "client left")
        self.assertEqual(websocket.close_calls, [])

    @staticmethod
    def _patch_voice_router(stack, voice_router, voice_state, **targets):
        """Patch ``voice_router`` dependencies for the lifetime of ``stack``.

        Patches go through ``patch.object`` on the module object itself rather
        than dotted strings. They therefore always hit the module whose
        ``websocket_endpoint`` the test calls. Omitting ``new`` still lets
        ``patch`` choose AsyncMock vs MagicMock from the real attribute.

        Args:
            stack: ``contextlib.ExitStack`` that owns (and later undoes) the patches.
            voice_router: The stub-imported ``app.router.voice`` module.
            voice_state: Replacement for ``voice_router._vh.voice_state``.
            **targets: Attribute name on ``voice_router`` -> keyword arguments
                for ``patch.object`` (e.g. ``{"new": ...}`` or
                ``{"return_value": ...}``), entered in keyword order.

        Returns:
            dict mapping each target name to what its patcher yielded (the
            ``new`` value, or the mock that ``patch.object`` created).
        """
        stack.enter_context(
            patch.object(voice_router._vh, "voice_state", new=voice_state)
        )
        return {
            name: stack.enter_context(patch.object(voice_router, name, **patch_kwargs))
            for name, patch_kwargs in targets.items()
        }

    async def _run_minimax_endpoint(self, tts_connect):
        """Drive ``websocket_endpoint`` through a MiniMax-TTS session with stubbed services.

        Args:
            tts_connect: Mock used as ``MiniMaxTTSService(...).connect``. Give
                it a ``side_effect`` to simulate a TTS connection failure.

        Returns:
            SimpleNamespace holding:
            - the fake websocket;
            - the Soniox and TTS doubles;
            - the ``run_receive_loop`` mock;
            - the ``dispose_output_encoder`` mock;
            - the encoder object returned by ``create_output_encoder``.
        """
        websocket = FakeEndpointWebSocket()
        voice_router = import_voice_router_with_dependency_stubs()
        voice_state = types.SimpleNamespace(llm_service=object())
        agent = types.SimpleNamespace(
            tts_provider="minimax",
            voice_id="voice-123",
            tts_model_id="speech-2.8-hd",
        )
        profile = types.SimpleNamespace()
        soniox_client = types.SimpleNamespace(
            connect=AsyncMock(),
            close=AsyncMock(),
        )
        tts_service = types.SimpleNamespace(
            connect=tts_connect,
            close=AsyncMock(),
        )
        summary_repo = types.SimpleNamespace(
            get_by_session_id=AsyncMock(return_value=None)
        )
        memory = types.SimpleNamespace(init_memory=lambda *args, **kwargs: None)
        dialogue = types.SimpleNamespace(
            update_system_message=lambda *args, **kwargs: None
        )
        output_encoder = object()
        run_receive_loop = AsyncMock(return_value=None)

        with contextlib.ExitStack() as stack:
            mocks = self._patch_voice_router(
                stack,
                voice_router,
                voice_state,
                authenticate_voice_client={"new": AsyncMock(return_value=True)},
                receive_client_hello={
                    "new": AsyncMock(return_value={"type": "hello", "version": 1})
                },
                build_welcome_message={"return_value": {"type": "hello_ack"}},
                resolve_agent_profile={
                    "new": AsyncMock(return_value=(profile, agent))
                },
                build_system_prompt={"new": AsyncMock(return_value="system prompt")},
                ConversationSummaryRepository={"return_value": summary_repo},
                MemoryProvider={"return_value": memory},
                Dialogue={"return_value": dialogue},
                RbChatHistoryRepository={"return_value": object()},
                create_output_encoder={"return_value": output_encoder},
                dispose_output_encoder={},
                SonioxRTClient={"return_value": soniox_client},
                MiniMaxTTSService={"return_value": tts_service},
                run_receive_loop={"new": run_receive_loop},
            )
            await voice_router.websocket_endpoint(
                websocket,
                device_id_header="device-1",
                client_id_header="client-1",
            )

        return types.SimpleNamespace(
            websocket=websocket,
            soniox_client=soniox_client,
            tts_service=tts_service,
            run_receive_loop=run_receive_loop,
            dispose_output_encoder=mocks["dispose_output_encoder"],
            output_encoder=output_encoder,
        )

    async def test_websocket_endpoint_keeps_timeout_out_of_generic_1011_path(self):
        websocket = FakeEndpointWebSocket()
        voice_router = import_voice_router_with_dependency_stubs()

        with contextlib.ExitStack() as stack:
            self._patch_voice_router(
                stack,
                voice_router,
                object(),
                authenticate_voice_client={"new": AsyncMock(return_value=True)},
                receive_client_hello={
                    "new": AsyncMock(side_effect=asyncio.TimeoutError())
                },
            )
            await voice_router.websocket_endpoint(
                websocket,
                device_id_header="device-1",
                client_id_header="client-1",
            )

        # receive_client_hello itself issues the 4001 close on timeout (see the
        # handshake test above), so the endpoint must not add a generic 1011 close.
        self.assertTrue(websocket.accepted)
        self.assertEqual(websocket.close_calls, [])

    async def test_websocket_endpoint_connects_minimax_with_session_scoped_soniox(self):
        run = await self._run_minimax_endpoint(tts_connect=AsyncMock())

        self.assertTrue(run.websocket.accepted)
        self.assertEqual(run.websocket.text_messages[0], {"type": "hello_ack"})
        run.soniox_client.connect.assert_awaited_once()
        run.tts_service.connect.assert_awaited_once()
        run.run_receive_loop.assert_awaited_once()
        run.soniox_client.close.assert_awaited_once()
        run.tts_service.close.assert_awaited_once()
        run.dispose_output_encoder.assert_called_once_with(run.output_encoder)
        self.assertEqual(run.websocket.close_calls, [])

    async def test_websocket_endpoint_closes_soniox_if_minimax_connect_fails(self):
        run = await self._run_minimax_endpoint(
            tts_connect=AsyncMock(side_effect=RuntimeError("minimax down"))
        )

        run.soniox_client.connect.assert_awaited_once()
        run.soniox_client.close.assert_awaited_once()
        run.tts_service.connect.assert_awaited_once()
        run.tts_service.close.assert_awaited_once()
        run.run_receive_loop.assert_not_awaited()
        run.dispose_output_encoder.assert_called_once()
        self.assertEqual(
            run.websocket.text_messages[-1],
            {"type": "error", "message": "Failed to connect to TTS service"},
        )
        self.assertEqual(run.websocket.close_calls, [{"code": 1011, "reason": ""}])


async def smoke_ws(timeout: float = 10.0, linger: float = 1.0) -> bool:
    """Run a manual end-to-end smoke test against a live voice websocket server.

    Steps:
    - Connect to ``VOICE_WS_URL`` (default ``ws://localhost:8003/xiaozhi/v1/``)
      with fixed test device/client headers.
    - Send a ``hello`` message and print the first reply.
    - Send a small blob of dummy bytes as audio.

    Args:
        timeout: Seconds to wait for the reply to ``hello``. Without a bound,
            a server that accepts but never answers would hang the script.
        linger: Seconds to keep the connection open after sending the dummy
            audio before closing.

    Returns:
        True if every step completed. False if the reply timed out or
        connecting/sending failed; the failure is printed in either case.
    """
    import websockets  # third-party; only needed for this manual script

    uri = os.getenv("VOICE_WS_URL", "ws://localhost:8003/xiaozhi/v1/")
    headers = {"device-id": "test-device-123", "client-id": "test-client-123"}

    print(f"Connecting to {uri}...")
    try:
        async with websockets.connect(uri, additional_headers=headers) as websocket:
            print("Connected successfully!")

            hello_msg = {"type": "hello", "version": 1}
            print(f"Sending hello: {hello_msg}")
            await websocket.send(json.dumps(hello_msg))

            try:
                response = await asyncio.wait_for(websocket.recv(), timeout=timeout)
            except asyncio.TimeoutError:
                print(f"No response to hello within {timeout}s.")
                return False
            print(f"Server response: {response}")

            # Arbitrary bytes, not real Opus: this only exercises the binary path.
            dummy_opus = b"\x00\x01\x02\x03" * 10
            print(f"Sending {len(dummy_opus)} bytes of dummy audio...")
            await websocket.send(dummy_opus)

            await asyncio.sleep(linger)
    except Exception as exc:
        print(f"Connection failed: {exc}")
        return False

    print("Smoke test completed successfully.")
    return True


if __name__ == "__main__":
    # Report the outcome as the exit status so shells/CI can detect failure.
    sys.exit(0 if asyncio.run(smoke_ws()) else 1)
