"""Pydantic schema validation — pure data, no infra."""

from datetime import UTC

import pytest
from pydantic import ValidationError

from app.router.kb_admin_schemas import (
    ArticleCreate,
    ArticleUpdate,
    KBStats,
    RetrievalLogRead,
)


def test_article_create_normalises_tags():
    """ArticleCreate is expected to clean its tag list and fill in defaults."""
    raw_tags = [" Animals ", "DONG-vat", "", "  "]
    article = ArticleCreate(
        article_id="voi",
        title="Voi rừng",
        body_md="x" * 100,
        tags=raw_tags,
    )

    # Expect whitespace stripped, case lowered, and blank entries removed.
    assert article.tags == ["animals", "dong-vat"]

    # None of these fields were passed in, so these are the values the
    # model supplies on its own.
    supplied = (article.language, article.status, article.source_type)
    assert supplied == ("vi", "draft", "markdown")


def test_article_create_rejects_bad_id():
    """Check which ``article_id`` values ArticleCreate accepts and rejects.

    Every case reuses the same title and body as the accepted case, so
    ``article_id`` is the only field that differs between the payload that
    validates and the payloads that must raise. Using a different (shorter)
    title for the rejected cases would let an unrelated title rule satisfy
    ``pytest.raises`` and hide a broken id check.
    """
    shared_fields = {"title": "ttt", "body_md": "x" * 50}

    # leading digit is fine -- this also shows the shared fields validate
    ArticleCreate(article_id="1abc", **shared_fields)

    for bad_id in ("UPPER", "1-good"):
        with pytest.raises(ValidationError):
            ArticleCreate(article_id=bad_id, **shared_fields)


def test_article_create_rejects_short_body():
    """A nine-character ``body_md`` is expected to fail validation."""
    payload = {
        "article_id": "a",
        "title": "title-ok",
        "body_md": "too short",
    }
    with pytest.raises(ValidationError):
        ArticleCreate(**payload)


def test_article_create_enforces_tag_cap():
    """Eleven distinct tags (``t0`` .. ``t10``) are expected to be rejected."""
    too_many_tags = [f"t{i}" for i in range(11)]
    payload = {
        "article_id": "a",
        "title": "title-ok",
        "body_md": "x" * 50,
        "tags": too_many_tags,
    }
    with pytest.raises(ValidationError):
        ArticleCreate(**payload)


def test_article_update_all_optional():
    """ArticleUpdate can be built with no arguments.

    With ``None`` values excluded, the resulting dump is expected to be empty.
    """
    empty_patch = ArticleUpdate()
    dumped = empty_patch.model_dump(exclude_none=True)
    assert dumped == {}


def test_article_update_tag_trim_and_lower():
    """ArticleUpdate tags are expected to be stripped, lowercased and de-blanked."""
    raw_tags = [" A ", "b", ""]
    expected_tags = ["a", "b"]
    patch = ArticleUpdate(tags=raw_tags)
    assert patch.tags == expected_tags


def test_kbstats_round_trip():
    """Values passed to KBStats are expected to come back out of ``model_dump``."""
    payload = {
        "total_articles": 10,
        "by_status": {"published": 7, "draft": 3},
        "by_source": {"markdown": 6, "pdf": 4},
        "by_language": {"vi": 8, "en": 2},
        "total_chunks": 42,
        "retrieval_7d": {"hit": 90, "no_match": 5, "error": 2},
        "miss_rate_7d": 0.072,
    }
    dumped = KBStats(**payload).model_dump()

    # Spot-check one scalar field and one nested mapping entry.
    assert dumped["miss_rate_7d"] == 0.072
    assert dumped["by_source"]["pdf"] == 4


def test_retrieval_log_read_optional_fields():
    """RetrievalLogRead is expected to accept ``None`` for its nullable fields.

    The record is built with ``user_id``, ``device_id``, ``session_id``,
    ``query_lang``, ``crag_grade`` and ``error`` all set to ``None``; the
    populated fields are then read back.
    """
    # Only ``datetime`` is needed here; ``UTC`` comes from the module-level import.
    from datetime import datetime

    r = RetrievalLogRead(
        id="abc",
        tenant_id="global",
        user_id=None,
        device_id=None,
        session_id=None,
        query="voi",
        query_lang=None,
        status="hit",
        top_article_ids=["voi"],
        top_chunk_ids=["c1"],
        top_scores=[0.9],
        crag_grade=None,
        stages_ms={"total_ms": 120},
        error=None,
        # Timezone-aware timestamp.
        created_at=datetime.now(UTC),
    )
    assert r.status == "hit"
    assert r.stages_ms["total_ms"] == 120
