"""Structured LLM call logging — shared across all pipeline steps."""
from __future__ import annotations

import logging
import re
import sys
import threading
from contextlib import contextmanager
from time import perf_counter
from typing import Iterator


def extract_token_usage(response: object) -> dict[str, int]:
    """
    Pull input/output/total token counts from a LangChain Bedrock response.

    Tries two locations:
    - response.response_metadata['usage']  → Bedrock camelCase keys
    - response.usage_metadata              → LangChain snake_case standard
    Returns zeros when neither is present (e.g. mocked responses in tests).
    """
    meta = getattr(response, "response_metadata", None)
    if isinstance(meta, dict):
        usage = meta.get("usage") or meta.get("Usage") or {}
        if usage:
            return {
                "input_tokens": int(usage.get("inputTokens", 0)),
                "output_tokens": int(usage.get("outputTokens", 0)),
                "total_tokens": int(usage.get("totalTokens", 0)),
            }

    usage_meta = getattr(response, "usage_metadata", None)
    if isinstance(usage_meta, dict):
        return {
            "input_tokens": int(usage_meta.get("input_tokens", 0)),
            "output_tokens": int(usage_meta.get("output_tokens", 0)),
            "total_tokens": int(usage_meta.get("total_tokens", 0)),
        }

    return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}


def shorten_model_id(model_id: str) -> str:
    """
    Convert a Bedrock inference-profile ID to a short human-readable label.

    "global.anthropic.claude-sonnet-4-5-20250929-v1:0"  →  "claude-sonnet-4-5"
    """
    # str.split() returns the whole string when the marker is absent, so
    # non-Anthropic IDs pass through unchanged.
    name = model_id.split("anthropic.")[-1]
    # Drop the -YYYYMMDD date stamp and anything after it (e.g. "-v1:0").
    name = re.sub(r"-\d{8}.*", "", name)
    return name[:40]
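
# Illustrative mappings (the first is from the docstring; the others are
# assumptions about the same Bedrock naming scheme, not a full contract):
#
#     "global.anthropic.claude-sonnet-4-5-20250929-v1:0" -> "claude-sonnet-4-5"
#     "anthropic.claude-3-haiku-20240307-v1:0"           -> "claude-3-haiku"
#     "some.other.model-id"                              -> "some.other.model-id"  (passthrough, capped at 40 chars)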


def _fmt_tokens(n: int) -> str:
    return f"{n:,}"


def log_llm_response(
    logger: logging.Logger,
    *,
    step: str,
    model: str,
    response: object,
    elapsed_ms: float,
    attempt: int = 1,
    total_attempts: int = 1,
    page_id: str | None = None,
) -> dict[str, int]:
    """Emit one structured INFO line summarising a completed LLM call."""
    usage = extract_token_usage(response)
    short_model = shorten_model_id(model)

    parts: list[str] = [
        f"step={step}",
        f"model={short_model}",
        f"attempt={attempt}/{total_attempts}",
    ]
    if page_id:
        parts.append(f"page={page_id}")
    parts += [
        f"in={_fmt_tokens(usage['input_tokens'])}",
        f"out={_fmt_tokens(usage['output_tokens'])}",
        f"total={_fmt_tokens(usage['total_tokens'])} tok",
        f"{elapsed_ms:,.0f}ms",
    ]

    logger.info("[LLM] %s", " | ".join(parts))
    return usage
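
# Example of the emitted line (values illustrative):
#
#     [LLM] step=Page detection | model=claude-sonnet-4-5 | attempt=1/3 | page=p_042 | in=1,812 | out=215 | total=2,027 tok | 8,431ms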


@contextmanager
def timed_llm_call(
    logger: logging.Logger,
    *,
    step: str,
    model: str,
    attempt: int = 1,
    total_attempts: int = 1,
    page_id: str | None = None,
) -> Iterator[dict]:
    """
    Context manager that times a model.invoke() block and logs usage.

    Usage::

        with timed_llm_call(logger, step="Page detection", model=model_name,
                             attempt=attempt, total_attempts=max_attempts) as stats:
            stats["response"] = model.invoke(...)

        usage = stats["usage"]
    """
    stats: dict = {}
    t0 = perf_counter()

    short = shorten_model_id(model)
    # Console label: model plus page for per-page calls, model plus step otherwise.
    label = f"{short}  {('page=' + page_id) if page_id else step}"

    # One unconditional line so the call start is visible even in piped logs.
    print(f"     *  {label}  calling...", flush=True)

    _is_tty = sys.stdout.isatty()
    stop_spinner = threading.Event()
    _FRAMES = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"  # braille spinner frames

    def _spin() -> None:
        """Animate a spinner on a TTY; emit a heartbeat line every 10s otherwise."""
        i = 0
        next_heartbeat = 10.0
        while not stop_spinner.wait(0.15):
            elapsed = perf_counter() - t0
            if _is_tty:
                frame = _FRAMES[i % len(_FRAMES)]
                print(f"\r     {frame}  {label}  {elapsed:.0f}s ", end="", flush=True)
            elif elapsed >= next_heartbeat:
                # Non-interactive output (piped logs, CI): plain progress
                # lines instead of \r redraws, which would corrupt nothing
                # here but pile up uselessly on a TTY.
                print(f"     ... {elapsed:.0f}s", flush=True)
                next_heartbeat += 10.0
            i += 1

    spinner_thread = threading.Thread(target=_spin, daemon=True)
    spinner_thread.start()

    def _stop_and_clear() -> None:
        stop_spinner.set()
        spinner_thread.join(timeout=0.5)
        if _is_tty:
            # Blank out the spinner line; width sized to the longest text we
            # printed rather than a fixed 60 columns, so long labels leave
            # no residue.
            width = len(label) + 20
            print("\r" + " " * width + "\r", end="", flush=True)

    try:
        yield stats
    except Exception:
        _stop_and_clear()
        elapsed_ms = (perf_counter() - t0) * 1000
        logger.error(
            "[LLM] %s FAILED | model=%s | attempt=%s/%s%s | %.0fms",
            step,
            shorten_model_id(model),
            attempt,
            total_attempts,
            f" | page={page_id}" if page_id else "",
            elapsed_ms,
        )
        raise
    else:
        _stop_and_clear()
        elapsed_ms = (perf_counter() - t0) * 1000
        response = stats.get("response")
        usage = log_llm_response(
            logger,
            step=step,
            model=model,
            response=response,
            elapsed_ms=elapsed_ms,
            attempt=attempt,
            total_attempts=total_attempts,
            page_id=page_id,
        )
        stats["usage"] = usage
        stats["elapsed_ms"] = elapsed_ms

        # Best-effort: also record the call in the active run log, if one is
        # installed. Imported lazily and errors swallowed so run-log issues
        # never break the pipeline step itself.
        try:
            from shared.run_log import get_active
            rl = get_active()
            if rl is not None:
                rl.record_llm_call(
                    model=model,
                    duration_ms=elapsed_ms,
                    input_tokens=int(usage.get("input_tokens", 0)),
                    output_tokens=int(usage.get("output_tokens", 0)),
                    label=f"{step}{(' · page=' + page_id) if page_id else ''}",
                )
        except Exception:
            pass
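
# A hedged sketch of how a pipeline step might drive this with retries.
# `model`, `model_id`, `prompt`, and `max_attempts` are illustrative names,
# not part of this module:
#
#     for attempt in range(1, max_attempts + 1):
#         try:
#             with timed_llm_call(logger, step="Page detection", model=model_id,
#                                 attempt=attempt, total_attempts=max_attempts) as stats:
#                 stats["response"] = model.invoke(prompt)
#         except Exception:
#             if attempt == max_attempts:
#                 raise
#             continue
#         break  # success: stats["usage"] and stats["elapsed_ms"] are populated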
