"""Tests for the CRAG graph: confidence bypass, grade toggle, retry path, attempt cap."""

import pytest
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import PrivateAttr

import app.settings.setting as ss
from app.services.kb.crag_graph import build_crag_graph


class _Stub(BaseRetriever):
    """Test-double retriever that returns a fixed list of documents.

    The query argument is ignored. The docs passed to ``__init__`` are
    snapshotted, and every call returns a fresh shallow copy of that snapshot.
    This means neither the caller nor the graph under test can mutate the
    stub's stored docs and thereby change what a later call returns.
    """

    # BaseRetriever is a pydantic model, so per-instance state that is not a
    # declared field has to live in a pydantic private attribute.
    _docs: list = PrivateAttr()

    def __init__(self, docs):
        super().__init__()
        # Copy so later mutation of the caller's list has no effect here.
        self._docs = list(docs)

    def _get_relevant_documents(self, q, *, run_manager):
        # Fresh list per call: in-place edits by the consumer stay local.
        return list(self._docs)

    async def _aget_relevant_documents(self, q, *, run_manager):
        return list(self._docs)


def _initial(query="x"):
    return {
        "query": query,
        "docs": [],
        "relevance_ok": False,
        "attempt": 0,
        "top_score": None,
        "grade_score": None,
    }


@pytest.mark.asyncio
async def test_high_confidence_bypass_skips_llm_grade(monkeypatch):
    """A primary hit with a high relevance_score is accepted without a retry.

    The retriever-supplied ``relevance_score`` (0.85) is expected to surface
    as the graph's ``grade_score``. ``relevance_ok`` should be set and
    ``attempt`` should stay at 0.

    NOTE(review): this does not directly observe that the LLM grader was
    skipped. It infers that from ``grade_score`` equalling the retriever
    score.
    """
    monkeypatch.setattr(ss.settings, "KB_ENABLE_CRAG_GRADE", True)
    docs = [
        Document(
            page_content="x", metadata={"relevance_score": 0.85, "article_id": "a"}
        )
    ]
    g = build_crag_graph(primary_retriever=_Stub(docs), retry_retriever=None)
    out = await g.ainvoke(_initial())
    assert out["relevance_ok"]
    # Float comparison with a tolerance, so the assertion does not depend on
    # the score surviving the graph bit-for-bit.
    assert out["grade_score"] == pytest.approx(0.85)
    assert out["attempt"] == 0  # no retry


@pytest.mark.asyncio
async def test_toggle_off_short_circuits_to_relevance_ok(monkeypatch):
    """With KB_ENABLE_CRAG_GRADE disabled, retrieved docs are accepted as-is.

    The document carries no relevance_score, and ``grade_score`` must remain
    ``None``, showing that no grading step populated it.
    """
    monkeypatch.setattr(ss.settings, "KB_ENABLE_CRAG_GRADE", False)
    retriever = _Stub([Document(page_content="x", metadata={"article_id": "a"})])
    graph = build_crag_graph(primary_retriever=retriever, retry_retriever=None)
    result = await graph.ainvoke(_initial())
    assert result["relevance_ok"]
    # Grading never ran, so nothing filled in the score.
    assert result["grade_score"] is None


@pytest.mark.asyncio
async def test_empty_primary_no_retry_no_results(monkeypatch):
    """An empty primary with no retry retriever ends empty after one retry cycle.

    "no_retry" in the name refers to ``retry_retriever=None``. The graph
    still passes through its retry step exactly once (``attempt == 1``) and
    finishes with no documents.
    """
    monkeypatch.setattr(ss.settings, "KB_ENABLE_CRAG_GRADE", True)
    graph = build_crag_graph(primary_retriever=_Stub([]), retry_retriever=None)
    result = await graph.ainvoke(_initial())
    assert result["docs"] == []
    # NOTE(review): presumably the retry step falls back to the primary
    # retriever when none is supplied; confirm in crag_graph.
    assert result["attempt"] == 1


@pytest.mark.asyncio
async def test_retry_path_uses_alternate_retriever(monkeypatch):
    """When the primary comes back empty, the retry retriever's docs are returned."""
    monkeypatch.setattr(ss.settings, "KB_ENABLE_CRAG_GRADE", True)
    hit = Document(
        page_content="hit",
        metadata={"relevance_score": 0.95, "article_id": "a"},
    )
    graph = build_crag_graph(
        primary_retriever=_Stub([]),
        retry_retriever=_Stub([hit]),
    )
    result = await graph.ainvoke(_initial())
    assert result["attempt"] == 1
    final_docs = result["docs"]
    assert len(final_docs) == 1
    assert final_docs[0].page_content == "hit"


@pytest.mark.asyncio
async def test_attempt_cap_terminates_at_one(monkeypatch):
    """Retrievers that always return nothing must stop after a single retry."""
    monkeypatch.setattr(ss.settings, "KB_ENABLE_CRAG_GRADE", True)
    empty_primary = _Stub([])
    empty_retry = _Stub([])
    graph = build_crag_graph(
        primary_retriever=empty_primary, retry_retriever=empty_retry
    )
    result = await graph.ainvoke(_initial())
    # Exactly one retry, then END: the attempt cap prevents further loops.
    assert result["attempt"] == 1
