"""Direct Anthropic-SDK Bedrock client (``anthropic.AnthropicBedrock``) — for steps
that call the API without LangChain.

Used by step-03 backend generation.
"""
from __future__ import annotations

import os
from abc import ABC, abstractmethod


class LLMClient(ABC):
    @abstractmethod
    def generate(self, prompt: str, images: list[dict] | None = None) -> str:
        """Send a prompt (and optional images) to the model; return text response.

        Each entry in ``images`` must have keys: ``media_type`` (e.g. ``"image/png"``)
        and ``data`` (base64-encoded string).
        """
        raise NotImplementedError


class BedrockLLMClient(LLMClient):
    """:class:`LLMClient` backed by Anthropic models served through AWS Bedrock.

    Uses ``anthropic.AnthropicBedrock`` with the model, temperature and region
    taken from ``shared.config`` (``BACKEND_MODEL_ID``, ``BACKEND_TEMPERATURE``,
    ``BEDROCK_REGION``).
    """

    def __init__(self, *, max_tokens: int = 32768) -> None:
        """Create the underlying Bedrock client.

        Args:
            max_tokens: Upper bound on output tokens for each ``generate`` call.
                Defaults to 32768.

        AWS credentials are read from ``AWS_ACCESS_KEY_ID`` (or
        ``AWS_ACCESS_KEY``), ``AWS_SECRET_ACCESS_KEY`` (or ``AWS_SECRET_KEY``)
        and ``AWS_SESSION_TOKEN``. Unset variables are passed as ``None``,
        leaving credential resolution to the SDK (presumably boto3's default
        chain — verify against the SDK docs).
        """
        import anthropic

        from shared.config import (
            BACKEND_MODEL_ID,
            BACKEND_TEMPERATURE,
            BEDROCK_REGION,
        )

        self._client = anthropic.AnthropicBedrock(
            aws_region=BEDROCK_REGION,
            aws_access_key=os.getenv("AWS_ACCESS_KEY_ID") or os.getenv("AWS_ACCESS_KEY"),
            aws_secret_key=os.getenv("AWS_SECRET_ACCESS_KEY") or os.getenv("AWS_SECRET_KEY"),
            # Temporary (STS/SSO) credentials are only valid together with their
            # session token; once explicit keys are passed, the token is not
            # picked up from the environment automatically.
            aws_session_token=os.getenv("AWS_SESSION_TOKEN"),
        )
        self._model_id = BACKEND_MODEL_ID
        self._max_tokens = max_tokens
        # Temperature is omitted for Claude 4-family model IDs. The original
        # rationale was that these models reject the parameter.
        # NOTE(review): confirm against current Bedrock/Anthropic model docs.
        _model = BACKEND_MODEL_ID or ""
        _is_claude4 = any(p in _model for p in ("claude-opus-4", "claude-sonnet-4", "claude-haiku-4"))
        self._temperature = None if _is_claude4 else BACKEND_TEMPERATURE

    def generate(self, prompt: str, images: list[dict] | None = None) -> str:
        """Stream a single user turn to the model and return its reply text.

        Args:
            prompt: User prompt text; it is placed after any images.
            images: Optional images, each a dict with ``media_type`` and
                ``data`` (base64 string) keys.

        Returns:
            The reply text as returned by ``stream.get_final_text()``.

        A warning is logged when the reply stopped because ``max_tokens`` was
        reached (the text is then likely truncated). Call metrics are reported
        to the active run log, if any, on a best-effort basis.
        """
        import logging
        from time import perf_counter

        request: dict = {
            "model": self._model_id,
            "max_tokens": self._max_tokens,
            "messages": [{"role": "user", "content": self._build_content(prompt, images)}],
        }
        if self._temperature is not None:
            request["temperature"] = self._temperature

        t0 = perf_counter()
        with self._client.messages.stream(**request) as stream:
            text = stream.get_final_text()
            final_message = stream.get_final_message()
        elapsed_ms = (perf_counter() - t0) * 1000

        if getattr(final_message, "stop_reason", None) == "max_tokens":
            # Generated backend code cut off mid-way is worse than useless;
            # make it visible instead of returning it silently.
            logging.getLogger(__name__).warning(
                "Bedrock response for model %s hit max_tokens=%d; output is likely truncated",
                self._model_id,
                self._max_tokens,
            )

        self._record_usage(final_message, elapsed_ms)
        return text

    @staticmethod
    def _build_content(prompt: str, images: list[dict] | None) -> list[dict]:
        """Build user-message content blocks: base64 image blocks first, then the prompt text."""
        content: list[dict] = [
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": img["media_type"],
                    "data": img["data"],
                },
            }
            for img in images or ()
        ]
        content.append({"type": "text", "text": prompt})
        return content

    def _record_usage(self, final_message, elapsed_ms: float) -> None:
        """Report model, latency and token counts to the active run log, if one exists.

        Best-effort: telemetry must never break generation, so any failure is
        logged at DEBUG level (with traceback) and otherwise ignored.
        """
        import logging

        try:
            from shared.run_log import get_active

            rl = get_active()
            if rl is None:
                return
            usage = final_message.usage
            rl.record_llm_call(
                model=self._model_id,
                duration_ms=elapsed_ms,
                input_tokens=int(usage.input_tokens),
                output_tokens=int(usage.output_tokens),
                label="backend_gen",
            )
        except Exception:
            logging.getLogger(__name__).debug(
                "Failed to record LLM usage for backend_gen", exc_info=True
            )


def create_llm_client() -> LLMClient:
    """Build the LLM client used for backend generation.

    Returns:
        A freshly constructed :class:`BedrockLLMClient`, typed as the
        :class:`LLMClient` interface so callers do not depend on the backend.
    """
    client: LLMClient = BedrockLLMClient()
    return client
