"""Test the middleware's per-turn token summation from provider usage."""

from __future__ import annotations

from app.services.agent.memory_middleware import _sum_turn_tokens


class _Msg:
    def __init__(self, usage_metadata=None, response_metadata=None):
        self.usage_metadata = usage_metadata
        self.response_metadata = response_metadata


def test_sum_from_usage_metadata_across_messages():
    """Token totals from every message's usage_metadata are added together."""
    model_call = _Msg(usage_metadata={"total_tokens": 1200})
    tool_round = _Msg(usage_metadata={"total_tokens": 800})

    total = _sum_turn_tokens([model_call, tool_round])

    assert total == 2000


def test_falls_back_to_response_metadata():
    """With no usage_metadata, the total comes from response_metadata's token_usage."""
    token_usage = {"total_tokens": 500}
    message = _Msg(response_metadata={"token_usage": token_usage})

    assert _sum_turn_tokens([message]) == 500


def test_missing_usage_returns_zero():
    """Messages that carry no usage data at all sum to zero."""
    bare_messages = [_Msg() for _ in range(2)]

    assert _sum_turn_tokens(bare_messages) == 0
