"""llama.cpp backend — local LLM inference via llama-cpp-python.

Fallback for when Ollama's Metal backend crashes on macOS Ventura.
Uses CPU inference (n_gpu_layers=0) for compatibility. Slower but reliable.

Setup:
    pip install llama-cpp-python
    # Models are reused from Ollama's blob store
"""

import json
import os
import re
import sys
import time

from llama_cpp import Llama

from config import (
    OLLAMA_MODEL, MAX_TOKENS, TEMPERATURE,
    TIMEOUT_SECONDS, MAX_RETRIES, JSON_SCHEMA_PATH,
)
from models import MantaraSchema


# Locations of Ollama's on-disk model store, expanded once at import time.
# _resolve_model_path() uses them to map a model name such as
# "qwen2.5-coder:7b" to its GGUF blob file.
# Blob files, named "sha256-<digest>".
_OLLAMA_BLOBS_DIR = os.path.expanduser("~/.ollama/models/blobs")
# Manifests for the official "library" namespace, laid out as <model>/<tag>.
_OLLAMA_MANIFESTS_DIR = os.path.expanduser("~/.ollama/models/manifests/registry.ollama.ai/library")


def _log(msg: str):
    print(f"  [llama.cpp] {msg}", file=sys.stderr)


def _resolve_model_path(model_name: str) -> str:
    """Resolve an Ollama model name (e.g., 'qwen2.5-coder:7b') to its GGUF blob path.

    The name is split at the first ':' into model and tag. A missing or empty
    tag means "latest". Plain names are looked up in the official "library"
    namespace. Names of the form "namespace/model" are looked up next to that
    directory, under the same registry directory.

    Raises:
        FileNotFoundError: If the manifest, or the blob it points to, is missing.
        ValueError: If the manifest lists no model layer.
    """
    name, _, tag = model_name.partition(":")
    tag = tag or "latest"

    if "/" in name:
        # e.g. "someuser/mymodel" -> <registry>/someuser/mymodel/<tag>, a
        # sibling of the "library" namespace directory.
        manifest_dir = os.path.join(os.path.dirname(_OLLAMA_MANIFESTS_DIR), *name.split("/"))
    else:
        manifest_dir = os.path.join(_OLLAMA_MANIFESTS_DIR, name)
    manifest_path = os.path.join(manifest_dir, tag)

    # isfile rather than exists: a path that resolves to a directory would
    # otherwise fail later in open() with IsADirectoryError instead of this hint.
    if not os.path.isfile(manifest_path):
        raise FileNotFoundError(
            f"Model '{model_name}' not found in Ollama store. "
            f"Run: ollama pull {model_name}"
        )

    with open(manifest_path, encoding="utf-8") as f:
        manifest = json.load(f)

    # .get() so a malformed manifest falls through to the ValueError below
    # rather than raising KeyError.
    for layer in manifest.get("layers", []):
        if layer.get("mediaType") == "application/vnd.ollama.image.model":
            digest = layer["digest"].replace("sha256:", "")
            blob_path = os.path.join(_OLLAMA_BLOBS_DIR, f"sha256-{digest}")
            if os.path.isfile(blob_path):
                return blob_path
            raise FileNotFoundError(f"Model blob not found: {blob_path}")

    raise ValueError(f"No model layer found in manifest for '{model_name}'")


def _load_json_schema() -> str:
    """Load the mantara JSON schema to inject into the prompt."""
    with open(JSON_SCHEMA_PATH) as f:
        schema = json.load(f)
    return json.dumps(schema, indent=2)


def _extract_json(text: str) -> str:
    """Extract JSON from model response, stripping markdown fences if present."""
    text = text.strip()
    if text.startswith("```"):
        text = re.sub(r'^```(?:json)?\s*\n?', '', text)
        text = re.sub(r'\n?```\s*$', '', text)
        text = text.strip()

    start = text.find('{')
    if start == -1:
        return text

    depth = 0
    for i in range(start, len(text)):
        if text[i] == '{':
            depth += 1
        elif text[i] == '}':
            depth -= 1
            if depth == 0:
                return text[start:i + 1]

    return text[start:]


# Output-format instructions appended to the caller's system prompt by
# LlamaCppBackend.generate(). The {json_schema} placeholder is filled with
# str.format(), so any literal brace added to this template must be doubled
# ({{ / }}) or formatting will fail.
_SCHEMA_INSTRUCTION = """

---

OUTPUT FORMAT: You MUST respond with ONLY a valid JSON object — no markdown, no code fences, no explanation.
The JSON must conform exactly to the following JSON Schema:

{json_schema}

CRITICAL:
- Output raw JSON only. Do NOT wrap in ```json``` or any markdown.
- Every field marked "required" MUST be present.
- All pattern constraints (snake_case, schema.table(col) format) MUST be followed.
- ENUM type_name must match pattern: schema_name.xxx_enum
- ENUM values must be lowercase snake_case (no symbols, no spaces).
"""


class LlamaCppBackend:
    """llama-cpp-python backend for local inference.

    Models are located in Ollama's local store and kept in a per-instance
    cache, so each model is loaded at most once per backend instance.
    """

    def __init__(self, *, n_ctx: int = 32768, n_gpu_layers: int = 0):
        """Create a backend and load the mantara JSON schema for prompting.

        Args:
            n_ctx: Context window passed to ``Llama``. The default is the full
                32768 tokens because the system prompt alone is ~15K tokens.
            n_gpu_layers: Layers to offload to the GPU. The default of 0 keeps
                inference on the CPU because Metal crashes on macOS Ventura.
        """
        self._json_schema = _load_json_schema()
        self._llm_cache: dict[str, Llama] = {}
        self._n_ctx = n_ctx
        self._n_gpu_layers = n_gpu_layers

    def _get_llm(self, model_name: str) -> Llama:
        """Return the ``Llama`` for *model_name*, loading and caching it on first use.

        The first call for a model is slow (~5-13s); later calls reuse the
        cached instance.

        Raises:
            FileNotFoundError: If the model or its blob is not in the Ollama store.
            ValueError: If the model's manifest has no model layer.
        """
        llm = self._llm_cache.get(model_name)
        if llm is None:
            model_path = _resolve_model_path(model_name)
            _log(f"Loading {model_name} from {model_path}...")
            start = time.perf_counter()
            llm = Llama(
                model_path=model_path,
                n_ctx=self._n_ctx,
                n_gpu_layers=self._n_gpu_layers,
                verbose=False,
            )
            self._llm_cache[model_name] = llm
            _log(f"Loaded in {time.perf_counter() - start:.1f}s")
        return llm

    def generate(self, system_prompt: str, user_input: str, model: str | None = None) -> MantaraSchema:
        """Generate a MantaraSchema via llama.cpp.

        The mantara JSON schema is appended to *system_prompt*, and the model
        runs in JSON mode. The reply is parsed, its top-level "$schema" key is
        dropped, and the rest is validated with Pydantic.

        Args:
            system_prompt: Base system prompt. Schema instructions are appended.
            user_input: The user message.
            model: Ollama-style model name. Defaults to ``OLLAMA_MODEL``.

        Returns:
            The validated ``MantaraSchema``.

        Raises:
            ValueError: The model could not be resolved, or the reply was
                empty, not valid JSON, not a JSON object, or failed Pydantic
                validation. These errors are raised immediately, not retried.
            RuntimeError: Every attempt failed with an ``OSError``. This
                includes ``ConnectionError`` and a missing model file
                (``FileNotFoundError``).
        """
        use_model = model or OLLAMA_MODEL
        last_error: OSError | None = None

        full_prompt = system_prompt + _SCHEMA_INSTRUCTION.format(json_schema=self._json_schema)

        for attempt in range(1 + MAX_RETRIES):
            try:
                llm = self._get_llm(use_model)
                start = time.perf_counter()

                output = llm.create_chat_completion(
                    messages=[
                        {"role": "system", "content": full_prompt},
                        {"role": "user", "content": user_input},
                    ],
                    response_format={"type": "json_object"},
                    max_tokens=MAX_TOKENS,
                    temperature=TEMPERATURE,
                )

                elapsed = round(time.perf_counter() - start, 1)
                # "or {}" also covers a usage key that is present but None.
                usage = output.get("usage") or {}
                _log(
                    f"model={use_model}  "
                    f"prompt_tokens={usage.get('prompt_tokens', '?')}  "
                    f"completion_tokens={usage.get('completion_tokens', '?')}  "
                    f"latency={elapsed}s"
                )

                raw = output["choices"][0]["message"]["content"]
                if not raw:
                    raise ValueError("Model returned empty response")

                json_str = _extract_json(raw)
                try:
                    data = json.loads(json_str)
                except json.JSONDecodeError as e:
                    raise ValueError(
                        f"Model returned invalid JSON: {e}\n"
                        f"Raw (first 500 chars): {raw[:500]}"
                    ) from e

                # A list or scalar would crash on .pop() below with a
                # TypeError/AttributeError. Report it like other bad output.
                if not isinstance(data, dict):
                    raise ValueError(
                        f"Model returned JSON {type(data).__name__}, expected an object\n"
                        f"Raw (first 500 chars): {raw[:500]}"
                    )

                data.pop("$schema", None)

                try:
                    schema = MantaraSchema.model_validate(data)
                except Exception as e:
                    raise ValueError(
                        f"JSON parsed but failed Pydantic validation: {e}\n"
                        f"Keys: {list(data.keys())}"
                    ) from e

                return schema

            # ConnectionError is a subclass of OSError, so this covers both.
            # NOTE(review): this also retries FileNotFoundError from model
            # resolution, which will not succeed on retry. Consider failing fast.
            except OSError as e:
                last_error = e
                if attempt < MAX_RETRIES:
                    wait = 2 ** (attempt + 1)
                    _log(f"Error (attempt {attempt + 1}): {e} — retrying in {wait}s")
                    time.sleep(wait)
                else:
                    _log(f"All {1 + MAX_RETRIES} attempts failed.")

        raise RuntimeError(
            f"Failed after {1 + MAX_RETRIES} attempts. "
            f"Last error: {type(last_error).__name__}: {last_error}"
        ) from last_error

    def chat(self, model: str, messages: list, max_tokens: int = 4000,
             temperature: float = 0.2) -> dict:
        """Run a chat completion for V2 pipeline steps (analyze, plan).

        The request always uses JSON mode
        (``response_format={"type": "json_object"}``). OpenAI model names are
        first mapped to a local model via ``_resolve_model``.

        Returns:
            The raw response dict from ``create_chat_completion``.
        """
        use_model = self._resolve_model(model)
        llm = self._get_llm(use_model)

        start = time.perf_counter()
        output = llm.create_chat_completion(
            messages=messages,
            response_format={"type": "json_object"},
            max_tokens=max_tokens,
            temperature=temperature,
        )
        elapsed = round(time.perf_counter() - start, 1)

        usage = output.get("usage") or {}
        _log(f"chat model={use_model}  tokens={usage.get('total_tokens', '?')}  latency={elapsed}s")

        return output

    def _resolve_model(self, model: str) -> str:
        """Map OpenAI model names to ``OLLAMA_MODEL``; return other names unchanged."""
        mapping = {
            "gpt-4o": OLLAMA_MODEL,
            "gpt-4o-mini": OLLAMA_MODEL,
        }
        return mapping.get(model, model)
