"""Tests for AgentMemoryProvider: token-threshold buffering/trigger + cached
query_memory. Redis is faked; enqueue + repo are mocked."""

from __future__ import annotations

from unittest.mock import AsyncMock, patch

import fakeredis.aioredis as fakeaio

from app.share.memory.agent_memory_provider import AgentMemoryProvider

# One complete exchange (a user message followed by the assistant's reply),
# reused as the turn payload by the buffering tests below.
DIALOGUE = [
    dict(role="user", content="Mình tên An"),
    dict(role="assistant", content="Chào An!"),
]


def _provider(device="dev1"):
    """Return an AgentMemoryProvider initialised for *device* with no LLM.

    ``device`` is passed through as ``role_id``; the tests below look up Redis
    keys suffixed with that id (e.g. ``mem:tok:dev1``).
    """
    provider = AgentMemoryProvider()
    provider.init_memory(role_id=device, llm=None)
    return provider


async def test_save_memory_below_threshold_does_not_enqueue():
    """Saves whose summed token count stays below the trigger only accumulate."""
    redis_stub = fakeaio.FakeRedis(decode_responses=True)
    enqueue_mock = AsyncMock()
    redis_patch = patch(
        "app.services.memory.memory_buffer.get_redis",
        AsyncMock(return_value=redis_stub),
    )
    enqueue_patch = patch(
        "app.share.memory.agent_memory_provider.enqueue_extract", enqueue_mock
    )
    with redis_patch, enqueue_patch:
        provider = _provider()
        per_call_tokens = [1000] * 3  # 3000 in total, still under 5000
        for tokens in per_call_tokens:
            await provider.save_memory(DIALOGUE, token_count=tokens)
        enqueue_mock.assert_not_called()
        stored_total = await redis_stub.get("mem:tok:dev1")
        assert int(stored_total) == 3000


async def test_save_memory_crossing_threshold_enqueues():
    """Once the buffered token total reaches the trigger, extraction is queued."""
    redis_stub = fakeaio.FakeRedis(decode_responses=True)
    enqueue_mock = AsyncMock()
    redis_patch = patch(
        "app.services.memory.memory_buffer.get_redis",
        AsyncMock(return_value=redis_stub),
    )
    enqueue_patch = patch(
        "app.share.memory.agent_memory_provider.enqueue_extract", enqueue_mock
    )
    with redis_patch, enqueue_patch:
        provider = _provider()
        for _ in range(2):  # 2 * 3000 = 6000, at or over the 5000 trigger
            await provider.save_memory(DIALOGUE, token_count=3000)
        enqueue_mock.assert_awaited_once_with("dev1")
        # Resetting the counter is the worker's job (done atomically with the
        # buffer drain), so the provider leaves the full running total in place.
        running_total = await redis_stub.get("mem:tok:dev1")
        assert int(running_total) == 6000


async def test_drain_buffer_clears_dialogue_and_resets_counter():
    """drain_buffer hands back every buffered message and wipes both keys."""
    from app.services.memory.memory_buffer import drain_buffer, push_turn

    redis_stub = fakeaio.FakeRedis(decode_responses=True)
    get_redis_stub = AsyncMock(return_value=redis_stub)
    with patch("app.services.memory.memory_buffer.get_redis", get_redis_stub):
        for _ in range(2):
            await push_turn("dev1", DIALOGUE, 3000)
        messages = await drain_buffer("dev1")
        # Two turns of two messages each come back as a single flat list.
        assert len(messages) == 4
        # Token counter is gone ...
        assert await redis_stub.get("mem:tok:dev1") is None
        # ... and so is the buffered dialogue list.
        assert await redis_stub.lrange("mem:buf:dev1", 0, -1) == []


async def test_save_memory_short_dialogue_skipped():
    """A too-short dialogue is neither buffered nor scheduled for extraction.

    The token count (9999) is far above the 5000 trigger, so only the
    short-dialogue skip can keep extraction from being enqueued.
    ``enqueue_extract`` is patched, as in the other save tests, so a
    regression fails the assertion instead of reaching the real queue.
    """
    fake = fakeaio.FakeRedis(decode_responses=True)
    enqueue = AsyncMock()
    with (
        patch(
            "app.services.memory.memory_buffer.get_redis", AsyncMock(return_value=fake)
        ),
        patch("app.share.memory.agent_memory_provider.enqueue_extract", enqueue),
    ):
        p = _provider()
        await p.save_memory([{"role": "user", "content": "hi"}], token_count=9999)
        enqueue.assert_not_called()
        assert await fake.get("mem:tok:dev1") is None  # counter untouched
        assert await fake.lrange("mem:buf:dev1", 0, -1) == []  # nothing buffered


async def test_query_memory_cache_hit_skips_mongo():
    """A cached context string is returned as-is without querying the repo."""
    redis_stub = fakeaio.FakeRedis(decode_responses=True)
    await redis_stub.set("mem:ctx:dev1", "cached ctx")
    repo_mock = AsyncMock()
    redis_patch = patch(
        "app.services.memory.memory_buffer.get_redis",
        AsyncMock(return_value=redis_stub),
    )
    repo_patch = patch(
        "app.share.memory.agent_memory_provider.agent_profile_repo", repo_mock
    )
    with redis_patch, repo_patch:
        result = await _provider().query_memory("anything")
        assert result == "cached ctx"
        repo_mock.get_memory_struct.assert_not_called()


async def test_query_memory_miss_rebuilds_and_caches():
    """On a cache miss the context is rebuilt from the repo and written back."""
    redis_stub = fakeaio.FakeRedis(decode_responses=True)
    repo_mock = AsyncMock()
    repo_mock.get_memory_struct.return_value = ({"name": "An"}, [])
    redis_patch = patch(
        "app.services.memory.memory_buffer.get_redis",
        AsyncMock(return_value=redis_stub),
    )
    repo_patch = patch(
        "app.share.memory.agent_memory_provider.agent_profile_repo", repo_mock
    )
    with redis_patch, repo_patch:
        context = await _provider().query_memory("x")
        assert "name: An" in context
        # The freshly built context must now be in the cache.
        cached = await redis_stub.get("mem:ctx:dev1")
        assert cached == context


async def test_query_memory_no_struct_returns_empty():
    """With an empty cache and no stored struct, query_memory returns ``""``.

    Also asserts the repo lookup actually ran. Otherwise an early exit before
    the lookup, which also yields ``""``, would let this test pass without
    exercising the "no struct" path.
    """
    fake = fakeaio.FakeRedis(decode_responses=True)
    repo = AsyncMock()
    repo.get_memory_struct.return_value = None
    with (
        patch(
            "app.services.memory.memory_buffer.get_redis", AsyncMock(return_value=fake)
        ),
        patch("app.share.memory.agent_memory_provider.agent_profile_repo", repo),
    ):
        p = _provider()
        assert await p.query_memory("x") == ""
        repo.get_memory_struct.assert_awaited_once()
