"""Tests for the arq worker's extract_memory job (LLM + I/O mocked)."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

from app.services.memory.memory_schema import (
    MemoryExtraction,
    MemoryFact,
    UserProfile,
)
from app.workers.jobs.memory_jobs import extract_memory

# Two-turn sample chat handed back by the patched ``drain_buffer`` in the tests
# below. User turn (Vietnamese): "My name is An, I have a cat named Miu";
# assistant turn: "So cute!". It mirrors the facts the mocked extraction returns.
DIALOGUE = [
    {
        "role": "user",
        "content": "Mình tên An, có con mèo tên Miu",
    },
    {
        "role": "assistant",
        "content": "Dễ thương quá!",
    },
]


def _llm_ctx(extraction: MemoryExtraction | Exception):
    """Build a ctx whose llm.with_structured_output(...).ainvoke returns/raises."""
    runnable = MagicMock()
    if isinstance(extraction, Exception):
        runnable.ainvoke = AsyncMock(side_effect=extraction)
    else:
        runnable.ainvoke = AsyncMock(return_value=extraction)
    llm = MagicMock()
    llm.with_structured_output.return_value = runnable
    return {"llm": llm}


async def test_extract_memory_writes_merged_struct_and_caches():
    """A successful extraction is persisted via ``upsert_memory`` and cached."""
    new_fact = MemoryFact(content="An có mèo Miu", category="relationship", salience=0.9)
    extraction = MemoryExtraction(profile=UserProfile(name="An"), facts=[new_fact])

    profile_repo = AsyncMock()
    # Start from an empty stored memory: no profile fields, no facts.
    profile_repo.get_memory_struct.return_value = ({}, [])
    cache_mock = AsyncMock()

    jobs_module = "app.workers.jobs.memory_jobs"
    with (
        patch(f"{jobs_module}.drain_buffer", AsyncMock(return_value=DIALOGUE)),
        patch(f"{jobs_module}.agent_profile_repo", profile_repo),
        patch(f"{jobs_module}.cache_context", cache_mock),
    ):
        await extract_memory(_llm_ctx(extraction), "dev1")

    profile_repo.upsert_memory.assert_awaited_once()
    written = profile_repo.upsert_memory.await_args.kwargs
    assert written["profile"]["name"] == "An"
    assert any(fact["content"] == "An có mèo Miu" for fact in written["facts"])

    cache_mock.assert_awaited_once()
    # The second positional argument is the rendered context that gets cached.
    _, cached_text, *_ = cache_mock.await_args.args
    assert "An có mèo Miu" in cached_text


async def test_extract_memory_empty_buffer_is_noop():
    """An empty dialogue buffer must short-circuit the job.

    Nothing is sent to the LLM, nothing is written to the repo, and nothing is
    cached.
    """
    repo = AsyncMock()
    cache = AsyncMock()
    ctx = _llm_ctx(MemoryExtraction(profile=UserProfile()))
    with (
        patch("app.workers.jobs.memory_jobs.drain_buffer", AsyncMock(return_value=[])),
        patch("app.workers.jobs.memory_jobs.agent_profile_repo", repo),
        # Patched so that a regression which caches on an empty buffer is
        # caught here instead of running the real ``cache_context``.
        patch("app.workers.jobs.memory_jobs.cache_context", cache),
    ):
        await extract_memory(ctx, "dev1")

    repo.upsert_memory.assert_not_called()
    cache.assert_not_awaited()
    # A true no-op never reaches the (mocked) structured-output LLM.
    ctx["llm"].with_structured_output.return_value.ainvoke.assert_not_awaited()


async def test_extract_memory_llm_error_skips_without_crash():
    """An LLM failure is absorbed by the job and nothing is persisted."""
    repo = AsyncMock()
    repo.get_memory_struct.return_value = ({}, [])
    ctx = _llm_ctx(RuntimeError("boom"))
    with (
        patch(
            "app.workers.jobs.memory_jobs.drain_buffer",
            AsyncMock(return_value=DIALOGUE),
        ),
        patch("app.workers.jobs.memory_jobs.agent_profile_repo", repo),
        patch("app.workers.jobs.memory_jobs.cache_context", AsyncMock()),
    ):
        # Must not raise even though the LLM call blows up.
        await extract_memory(ctx, "dev1")

    # Guard against a vacuous pass: the job must actually have reached the
    # failing LLM call, rather than returning early for some unrelated reason.
    ctx["llm"].with_structured_output.return_value.ainvoke.assert_awaited()
    repo.upsert_memory.assert_not_called()
