"""Direct Anthropic API backend for schema generation.

Uses Claude via the Anthropic API with a standard API key.
"""

import json
import os
import time
import re
import sys

from models import MantaraSchema

try:
    import anthropic
    HAS_ANTHROPIC = True
except ImportError:
    HAS_ANTHROPIC = False


def _log(msg: str):
    """Emit one backend-tagged diagnostic line on stderr."""
    prefix = "  [anthropic]"
    print(prefix, msg, file=sys.stderr)


class AnthropicBackend:
    """Schema generator using Claude via the direct Anthropic API.

    All configuration is read from environment variables when the backend
    is constructed (see ``__init__``).
    """

    # Total number of API attempts ``generate`` makes (first try + retries).
    _MAX_ATTEMPTS = 3

    # A markdown code fence anywhere in the text, optionally tagged ``json``.
    # Group 1 is the fence body; the match is non-greedy, so the first
    # closing fence ends the block.
    _FENCED_BLOCK_RE = re.compile(r"```(?:json)?[ \t]*\r?\n(.*?)```", re.DOTALL)

    def __init__(self):
        """Read backend configuration from the environment.

        Environment variables (defaults in parentheses):
            ANTHROPIC_API_KEY: required API key.
            ANTHROPIC_MODEL: default model ("claude-sonnet-4-6").
            MANTARA_MAX_TOKENS: max output tokens per request (16000).
            MANTARA_TEMPERATURE: sampling temperature (0.2).
            MANTARA_TIMEOUT: timeout passed to the SDK client (600).

        Raises:
            RuntimeError: if the ``anthropic`` package is not installed or
                ANTHROPIC_API_KEY is unset or empty.
            ValueError: if a numeric environment variable cannot be parsed.
        """
        if not HAS_ANTHROPIC:
            raise RuntimeError(
                "anthropic package not installed. Run: pip install anthropic"
            )

        self.api_key = os.getenv("ANTHROPIC_API_KEY", "")
        self.model = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-6")
        self.max_tokens = int(os.getenv("MANTARA_MAX_TOKENS", "16000"))
        self.temperature = float(os.getenv("MANTARA_TEMPERATURE", "0.2"))
        self.timeout = int(os.getenv("MANTARA_TIMEOUT", "600"))

        if not self.api_key:
            raise RuntimeError("ANTHROPIC_API_KEY not set.")

    def _create_client(self):
        """Create a new Anthropic client.

        The caller owns the returned client and should close it, for example
        by using it as a context manager.
        """
        return anthropic.Anthropic(
            api_key=self.api_key,
            timeout=self.timeout,
        )

    @staticmethod
    def _parse_object(candidate: str) -> dict | None:
        """Parse *candidate* as JSON and return it only if it is a JSON object."""
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    def _extract_json(self, text: str) -> dict:
        """Extract a JSON object from Claude's response text.

        Strategies, tried in order:
          1. the whole text as JSON;
          2. the text with a fence wrapping the entire reply stripped (keeps
             any inner ``` sequences, e.g. inside string values, intact);
          3. any fenced code block anywhere in the text (handles a prose
             preamble or trailing notes around the fence);
          4. the span from the first ``{`` to the last ``}``;
          5. the first complete JSON value starting at the first ``{``,
             ignoring whatever text follows it.

        Only JSON objects are accepted, because the caller unpacks the result
        as keyword arguments.

        Raises:
            ValueError: if no strategy yields a JSON object.
        """
        # Strategy 1: Direct parse
        data = self._parse_object(text)
        if data is not None:
            return data

        # Strategy 2: Strip markdown fences wrapping the whole reply
        cleaned = text.strip()
        if cleaned.startswith("```"):
            cleaned = re.sub(r"^```(?:json)?\s*\n?", "", cleaned)
            cleaned = re.sub(r"\n?```\s*$", "", cleaned)
            data = self._parse_object(cleaned)
            if data is not None:
                return data

        # Strategy 3: Any fenced block, e.g. after an explanatory preamble
        for fence in self._FENCED_BLOCK_RE.finditer(text):
            data = self._parse_object(fence.group(1))
            if data is not None:
                return data

        # Strategy 4: Find outermost braces
        match = re.search(r"\{[\s\S]*\}", text)
        if match:
            data = self._parse_object(match.group())
            if data is not None:
                return data

        # Strategy 5: Decode one value starting at the first brace. This
        # tolerates trailing prose that itself contains '}', which defeats
        # the greedy span in strategy 4.
        start = text.find("{")
        if start != -1:
            try:
                parsed, _end = json.JSONDecoder().raw_decode(text, start)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict):
                return parsed

        raise ValueError(f"Could not extract valid JSON from response (length: {len(text)})")

    def _request_schema(
        self, client, system_prompt: str, user_input: str, model: str
    ) -> MantaraSchema:
        """Make one API call and turn its reply into a validated MantaraSchema.

        Raises:
            ValueError: if the reply has no text or no extractable JSON
                object (with a hint when the reply hit ``max_tokens``).
            Exception: whatever the SDK call or ``MantaraSchema`` validation
                raises is propagated unchanged.
        """
        start = time.monotonic()

        response = client.messages.create(
            model=model,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            system=system_prompt,
            messages=[{"role": "user", "content": user_input}],
        )

        elapsed = round(time.monotonic() - start, 1)

        # Log usage
        usage = getattr(response, "usage", None)
        if usage:
            _log(
                f"model={model}  "
                f"input_tokens={usage.input_tokens}  "
                f"output_tokens={usage.output_tokens}  "
                f"latency={elapsed}s"
            )

        # Concatenate every text block so a reply split across blocks is not
        # truncated; non-text blocks are skipped.
        text_content = "".join(
            block.text for block in response.content if block.type == "text"
        )
        if not text_content:
            raise ValueError("No text content in Claude response")

        try:
            data = self._extract_json(text_content)
        except ValueError as e:
            # A reply cut off at the token limit is the usual reason JSON
            # fails to parse; say so instead of reporting a generic failure.
            if getattr(response, "stop_reason", None) == "max_tokens":
                raise ValueError(
                    f"Response truncated at max_tokens={self.max_tokens}; "
                    "increase MANTARA_MAX_TOKENS"
                ) from e
            raise

        # Validate against MantaraSchema
        return MantaraSchema(**data)

    def generate(self, system_prompt: str, user_input: str, model: str | None = None) -> MantaraSchema:
        """Generate a MantaraSchema via the Anthropic API.

        Makes up to ``_MAX_ATTEMPTS`` attempts with exponential backoff
        (1s, 2s, ... between attempts). API errors that cannot succeed on a
        retry (authentication, permission, bad request, not found) abort
        immediately.

        NOTE: the SDK client may also retry some failures internally (its
        ``max_retries`` setting), so the number of HTTP requests can exceed
        ``_MAX_ATTEMPTS``.

        Args:
            system_prompt: System prompt sent with every attempt.
            user_input: Content of the single user message.
            model: Optional model override; falls back to ``self.model``.

        Returns:
            The validated ``MantaraSchema``.

        Raises:
            RuntimeError: if every attempt fails or a non-retryable API error
                occurs. The underlying error is chained as ``__cause__``.
        """
        use_model = model or self.model
        # These fail identically on every retry, so retrying only adds delay.
        non_retryable = (
            anthropic.AuthenticationError,
            anthropic.PermissionDeniedError,
            anthropic.BadRequestError,
            anthropic.NotFoundError,
        )

        last_error: Exception | None = None
        # The context manager closes the client's connection pool when done.
        with self._create_client() as client:
            for attempt in range(self._MAX_ATTEMPTS):
                try:
                    return self._request_schema(
                        client, system_prompt, user_input, use_model
                    )
                except non_retryable as e:
                    raise RuntimeError(
                        f"Anthropic backend failed with a non-retryable error: {e}"
                    ) from e
                except Exception as e:
                    last_error = e
                    if attempt < self._MAX_ATTEMPTS - 1:
                        wait = 2 ** attempt  # 1s, 2s, ...
                        _log(f"Attempt {attempt + 1} failed: {e}. Retrying in {wait}s...")
                        time.sleep(wait)

        raise RuntimeError(
            f"Anthropic backend failed after {self._MAX_ATTEMPTS} attempts. Last error: {last_error}"
        ) from last_error
