import json
import unittest
from unittest.mock import AsyncMock, patch

from app.services.tts.minimax import MiniMaxTTSService


class FakeWebSocket:
    """In-memory stand-in for the websocket connection used by the MiniMax tests.

    Scripted replies are serialized to JSON strings up front and handed out in
    FIFO order by ``recv``. Every payload passed to ``send`` is JSON-decoded and
    recorded so tests can inspect the outgoing protocol messages.
    """

    def __init__(self, responses: list[dict]):
        # Serialize once so ``recv`` hands back JSON text rather than dicts.
        self.responses = list(map(json.dumps, responses))
        # Decoded copies of every payload passed to ``send``, in call order.
        self.sent_messages: list[dict] = []
        # How many times ``close`` has been awaited.
        self.close_calls = 0

    async def send(self, payload: str):
        """Decode *payload* as JSON and append it to ``sent_messages``."""
        decoded = json.loads(payload)
        self.sent_messages.append(decoded)

    async def recv(self):
        """Return the next queued reply string.

        Raises:
            AssertionError: if the test did not script enough replies.
        """
        if self.responses:
            return self.responses.pop(0)
        raise AssertionError("No queued websocket response")

    async def close(self):
        """Record the close call; there is no real connection to tear down."""
        self.close_calls += 1


class MiniMaxTTSServiceTests(unittest.IsolatedAsyncioTestCase):
    """Protocol-level tests for ``MiniMaxTTSService`` against a scripted fake socket."""

    @staticmethod
    def _patch_connect(fake_ws):
        """Patch the service module's ``websockets.connect`` to resolve to *fake_ws*.

        Intended for use as a context manager. The bound value is the
        ``AsyncMock`` itself, so callers can assert on how it was awaited.
        """
        return patch(
            "app.services.tts.minimax.websockets.connect",
            new=AsyncMock(return_value=fake_ws),
        )

    @staticmethod
    async def _collect(stream):
        """Drain an async iterator into a list."""
        return [item async for item in stream]

    async def test_synthesize_stream_uses_ws_url_and_reuses_socket(self):
        # Two turns scripted back to back on a single socket: the connection
        # handshake, the first task (two audio frames, the last flagged final),
        # then the second task (one final audio frame).
        scripted = [
            {"event": "connected_success"},
            {"event": "task_started"},
            {"data": {"audio": "0100"}},
            {"data": {"audio": "0200"}, "is_final": True},
            {"event": "task_finished"},
            {"event": "task_started"},
            {"data": {"audio": "0300"}, "is_final": True},
        ]
        fake_ws = FakeWebSocket(scripted)

        with self._patch_connect(fake_ws) as connect_mock:
            service = MiniMaxTTSService(
                api_key="test-api-key",
                base_url="https://api.minimax.io/",
                model_id="speech-2.8-hd",
                voice_id="voice-123",
                sample_rate=24000,
            )

            turn_one = await self._collect(service.synthesize_stream(["Hello.", " Next sentence?"]))
            turn_two = await self._collect(service.synthesize_stream(["Another turn."]))
            await service.close()

        # Hex-encoded audio frames come back as raw bytes, one chunk per frame.
        self.assertEqual(turn_one, [bytes.fromhex("0100"), bytes.fromhex("0200")])
        self.assertEqual(turn_two, [bytes.fromhex("0300")])

        # One connection serves both turns, and the https base URL (trailing
        # slash included) is turned into the wss endpoint.
        connect_mock.assert_awaited_once()
        connect_call = connect_mock.await_args
        self.assertEqual(connect_call.args[0], "wss://api.minimax.io/ws/v1/t2a_v2")
        self.assertEqual(
            connect_call.kwargs["additional_headers"]["Authorization"],
            "Bearer test-api-key",
        )

        sent = fake_ws.sent_messages
        expected_events = [
            "task_start",
            "task_continue",
            "task_continue",
            "task_finish",
            "task_start",
            "task_continue",
            "task_finish",
        ]
        self.assertEqual([msg["event"] for msg in sent], expected_events)

        # Each input text fragment goes out verbatim in its own task_continue.
        continue_texts = [msg["text"] for msg in sent if msg["event"] == "task_continue"]
        self.assertEqual(continue_texts, ["Hello.", " Next sentence?", "Another turn."])

        self.assertEqual(
            sent[0]["audio_setting"],
            {"sample_rate": 24000, "format": "pcm", "channel": 1},
        )
        self.assertEqual(fake_ws.close_calls, 1)

    async def test_synthesize_stream_raises_on_api_error_and_resets_socket(self):
        # The handshake succeeds, then the server answers with an error
        # ``base_resp`` frame instead of audio.
        error_frame = {
            "base_resp": {
                "status_code": 2049,
                "status_msg": "invalid api key",
            }
        }
        fake_ws = FakeWebSocket(
            [
                {"event": "connected_success"},
                {"event": "task_started"},
                error_frame,
            ]
        )

        with self._patch_connect(fake_ws):
            service = MiniMaxTTSService(
                api_key="test-api-key",
                base_url="https://api.minimax.io",
            )

            with self.assertRaises(RuntimeError) as raised:
                await self._collect(service.synthesize_stream(["hello"]))

        # The server's status message surfaces in the error. The service
        # closes the socket and drops its cached reference, presumably so a
        # later call opens a fresh connection.
        self.assertIn("invalid api key", str(raised.exception))
        self.assertIsNone(service._ws)
        self.assertEqual(fake_ws.close_calls, 1)
