"""Bedrock LLM client with Anthropic prompt caching support.

Uses the Anthropic SDK's AnthropicBedrock client (Messages API) instead of
LangChain's ChatBedrockConverse (Converse API) because only the Messages API
supports cache_control markers.

Cache points:
  - system_prompt  → system[] with cache_control (static rules, shared across all calls)
  - cached_user_prefix → first user content block with cache_control (shared within a run, e.g. the manifest)
  - user_text      → dynamic per-call content, never cached
  - images         → appended after user_text, not cached
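
With caching enabled, the assembled request looks roughly like this (a sketch of
what generate() builds below; angle brackets are placeholders):

    system=[{"type": "text", "text": <system_prompt>, "cache_control": {"type": "ephemeral"}}]
    messages=[{"role": "user", "content": [
        {"type": "text", "text": <cached_user_prefix>, "cache_control": {"type": "ephemeral"}},
        {"type": "text", "text": <user_text>},
        {"type": "image", "source": {"type": "base64", ...}},
    ]}]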

When ENABLE_PROMPT_CACHING=false the client behaves identically but omits
cache_control markers, falling back to standard uncached behaviour.
"""
from __future__ import annotations

from anthropic import AnthropicBedrock

from shared.config import BEDROCK_REGION, ENABLE_PROMPT_CACHING
from shared.logging import get_logger
from shared.media.images import EncodedImage

logger = get_logger(__name__)


class CachingBedrockClient:
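    """Wrapper around AnthropicBedrock that injects cache_control markers.

    Holds the SDK client and sampling defaults (model, temperature, max_tokens);
    per-call content is passed to generate().
    """
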
    def __init__(
        self,
        model_id: str,
        temperature: float = 0.0,
        max_tokens: int = 32768,
    ) -> None:
        self._client = AnthropicBedrock(aws_region=BEDROCK_REGION)
        self._model_id = model_id
        self._temperature = temperature
        self._max_tokens = max_tokens

    def generate(
        self,
        user_text: str,
        *,
        system_prompt: str | None = None,
        cached_user_prefix: str | None = None,
        images: list[EncodedImage] | None = None,
    ) -> str:
        """Call Bedrock with optional prompt caching.

        Parameters
        ----------
        user_text:
            Dynamic per-call content placed last in the user message. Never cached.
        system_prompt:
            Static instructions placed in the system field. Marked cacheable when
            ENABLE_PROMPT_CACHING is true. Use for large rule/schema blocks that
            are identical across all calls in a run.
        cached_user_prefix:
            Optional first user content block marked cacheable. Use for content
            that is shared across all pages/modules in one run but differs between
            runs (e.g. the API manifest rendered text).
        images:
            EncodedImage objects appended as base64 image blocks after user_text.
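
        Returns
        -------
        str
            The concatenated text of the model's response blocks.

        Examples
        --------
        Illustrative sketch: the model ID is a placeholder, and STATIC_RULES /
        manifest_text are assumed to exist in the calling code.

        >>> client = CachingBedrockClient("anthropic.claude-3-5-sonnet-20241022-v2:0")
        >>> text = client.generate(
        ...     "Extract the endpoints used on this page.",
        ...     system_prompt=STATIC_RULES,
        ...     cached_user_prefix=manifest_text,
        ... )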
        """
        system: list[dict] | None = None
        if system_prompt:
            block: dict = {"type": "text", "text": system_prompt}
            if ENABLE_PROMPT_CACHING:
                block["cache_control"] = {"type": "ephemeral"}
            system = [block]

        content: list[dict] = []
        if cached_user_prefix:
            prefix_block: dict = {"type": "text", "text": cached_user_prefix}
            if ENABLE_PROMPT_CACHING:
                prefix_block["cache_control"] = {"type": "ephemeral"}
            content.append(prefix_block)

        content.append({"type": "text", "text": user_text})

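        # Images are appended after the dynamic text and carry no cache_control
        # marker, so they never extend a cache point.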
        if images:
            for img in images:
                content.append({
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": img.media_type,
                        "data": img.base64_data,
                    },
                })

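        # Claude 4 family models are invoked without an explicit temperature;
        # the configured value is only passed for earlier model families.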
        _is_claude4 = any(
            p in self._model_id
            for p in ("claude-opus-4", "claude-sonnet-4", "claude-haiku-4")
        )
        kwargs: dict = {
            "model": self._model_id,
            "max_tokens": self._max_tokens,
            "messages": [{"role": "user", "content": content}],
        }
        if not _is_claude4:
            kwargs["temperature"] = self._temperature
        if system is not None:
            kwargs["system"] = system

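        # Stream and collect the final message: the Anthropic SDK discourages
        # non-streaming requests with large max_tokens values (long-request
        # timeouts), so we stream even though only the final result is needed.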
        with self._client.messages.stream(**kwargs) as stream:
            msg = stream.get_final_message()

        u = msg.usage
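        # Cache token fields may be absent or None depending on SDK/model; default to 0.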
        cache_read = getattr(u, "cache_read_input_tokens", 0) or 0
        cache_write = getattr(u, "cache_creation_input_tokens", 0) or 0
        logger.info(
            "bedrock usage | model=%s input=%d output=%d "
            "cache_read=%d cache_write=%d caching_enabled=%s",
            self._model_id,
            u.input_tokens,
            u.output_tokens,
            cache_read,
            cache_write,
            ENABLE_PROMPT_CACHING,
        )
        # Warn only when a cacheable block was actually sent; otherwise zero cache tokens is expected.
        if ENABLE_PROMPT_CACHING and (system_prompt or cached_user_prefix) and cache_read == 0 and cache_write == 0:
            logger.warning("prompt caching enabled but no cache tokens reported; the model may not support cache_control")

        # Join all text blocks; indexing [0] would break if a non-text block
        # leads the response or the content list is empty.
        return "".join(block.text for block in msg.content if block.type == "text")
