"""Per-run telemetry: steps, LLM calls, models, tokens, durations.

Usage from the orchestrator:

    from shared.run_log import RunLog, set_active

    log = RunLog(run_id="20260430_120000")
    set_active(log)

    with log.step("step-01-input-ingestion"):
        ...

    log.finalize(output_dir)   # writes run_log.json + run_log.md

Any LLM call instrumented through `record_llm_call` (the ir_pipeline
timed_llm_call hook + the backend gen client + the agent) is automatically
attached to the currently-active step.
"""
from __future__ import annotations

import json
import re
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from time import perf_counter
from typing import Iterator


# ─────────────────────────────────────────────────────────────────────────────
# Data model
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class LLMCall:
    """One instrumented LLM invocation, attributed to a single pipeline step."""

    step: str                  # name of the step that was active when the call ran
    model: str                 # full model id as reported by the client
    label: str | None          # optional caller-supplied tag for this call
    duration_ms: float
    input_tokens: int
    output_tokens: int

    @property
    def total_tokens(self) -> int:
        """Combined prompt + completion token count."""
        return self.input_tokens + self.output_tokens

    def to_dict(self) -> dict:
        """Return a JSON-serializable view of this call."""
        payload = {
            "step": self.step,
            "model": self.model,
            "label": self.label,
            "duration_ms": round(self.duration_ms, 1),
        }
        payload["input_tokens"] = self.input_tokens
        payload["output_tokens"] = self.output_tokens
        payload["total_tokens"] = self.total_tokens
        return payload


@dataclass
class StepEvent:
    """Telemetry for one named pipeline step, including its attached LLM calls."""

    name: str
    started_at: str
    duration_ms: float | None = None
    status: str = "running"          # running | ok | failed
    error: str | None = None
    llm_calls: list[LLMCall] = field(default_factory=list)
    notes: dict = field(default_factory=dict)

    @property
    def total_input_tokens(self) -> int:
        """Sum of prompt tokens across all calls in this step."""
        return sum(call.input_tokens for call in self.llm_calls)

    @property
    def total_output_tokens(self) -> int:
        """Sum of completion tokens across all calls in this step."""
        return sum(call.output_tokens for call in self.llm_calls)

    @property
    def total_tokens(self) -> int:
        """Prompt + completion tokens for the whole step."""
        return self.total_input_tokens + self.total_output_tokens

    @property
    def models_used(self) -> list[str]:
        """Distinct model ids, preserving first-use order."""
        # dict.fromkeys deduplicates while keeping insertion order.
        return list(dict.fromkeys(call.model for call in self.llm_calls))

    def to_dict(self) -> dict:
        """Return a JSON-serializable view of this step and its calls."""
        duration = None if self.duration_ms is None else round(self.duration_ms, 1)
        return {
            "name": self.name,
            "started_at": self.started_at,
            "duration_ms": duration,
            "status": self.status,
            "error": self.error,
            "models_used": self.models_used,
            "llm_call_count": len(self.llm_calls),
            "input_tokens": self.total_input_tokens,
            "output_tokens": self.total_output_tokens,
            "total_tokens": self.total_tokens,
            "notes": self.notes,
            "llm_calls": [call.to_dict() for call in self.llm_calls],
        }


# ─────────────────────────────────────────────────────────────────────────────
# RunLog
# ─────────────────────────────────────────────────────────────────────────────

class RunLog:
    """Collects per-run telemetry: steps, LLM calls, tokens, durations.

    One instance is created per pipeline run. Steps are opened with the
    `step` context manager; LLM calls are attached to whichever step is
    innermost on the stack via `record_llm_call`. `finalize` stamps the
    total duration and writes `run_log.json` + `run_log.md`.
    """

    def __init__(self, run_id: str) -> None:
        self.run_id = run_id
        self.started_at = datetime.now(timezone.utc).isoformat(timespec="seconds")
        self._t0 = perf_counter()              # monotonic anchor for total duration
        self.duration_ms: float | None = None  # set by finalize()
        self.steps: list[StepEvent] = []       # every step, flat, in start order
        self._stack: list[StepEvent] = []      # currently-open steps, innermost last
        self.summary: dict = {}                # free-form project metadata for reports

    @contextmanager
    def step(self, name: str, **notes) -> Iterator[StepEvent]:
        """Open a named, timed step; records ok/failed status on exit.

        Extra keyword arguments are stored verbatim in the step's `notes`.
        Exceptions propagate after being recorded (repr) on the step event.
        Steps may nest; LLM calls attach to the innermost open step.
        """
        ev = StepEvent(
            name=name,
            started_at=datetime.now(timezone.utc).isoformat(timespec="seconds"),
            notes=dict(notes),
        )
        self.steps.append(ev)
        self._stack.append(ev)
        t0 = perf_counter()
        try:
            yield ev
        except Exception as exc:
            ev.status = "failed"
            ev.error = repr(exc)
            raise
        else:
            ev.status = "ok"
        finally:
            # Single timing/cleanup path for every exit (success, failure,
            # or early generator close) — was duplicated per branch before.
            ev.duration_ms = (perf_counter() - t0) * 1000
            self._stack.pop()

    def record_llm_call(
        self,
        *,
        model: str,
        duration_ms: float,
        input_tokens: int,
        output_tokens: int,
        label: str | None = None,
    ) -> None:
        """Attach an LLM call to the currently-active step (no-op if none)."""
        if not self._stack:
            # No step open — e.g. a call made outside the orchestrator.
            return
        ev = self._stack[-1]
        ev.llm_calls.append(LLMCall(
            step=ev.name,
            model=model,
            label=label,
            duration_ms=duration_ms,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        ))

    # ── Summary ─────────────────────────────────────────────────────────────
    @property
    def total_input_tokens(self) -> int:
        """Prompt tokens summed over every step."""
        return sum(s.total_input_tokens for s in self.steps)

    @property
    def total_output_tokens(self) -> int:
        """Completion tokens summed over every step."""
        return sum(s.total_output_tokens for s in self.steps)

    @property
    def total_tokens(self) -> int:
        """Prompt + completion tokens for the whole run."""
        return self.total_input_tokens + self.total_output_tokens

    @property
    def total_llm_calls(self) -> int:
        """Number of LLM calls recorded across every step."""
        return sum(len(s.llm_calls) for s in self.steps)

    def to_dict(self) -> dict:
        """Return a JSON-serializable view of the whole run."""
        return {
            "run_id":         self.run_id,
            "started_at":     self.started_at,
            "duration_ms":    round(self.duration_ms, 1) if self.duration_ms is not None else None,
            "duration_s":     round(self.duration_ms / 1000, 2) if self.duration_ms is not None else None,
            "step_count":     len(self.steps),
            "llm_call_count": self.total_llm_calls,
            "input_tokens":   self.total_input_tokens,
            "output_tokens":  self.total_output_tokens,
            "total_tokens":   self.total_tokens,
            "summary":        self.summary,
            "steps":          [s.to_dict() for s in self.steps],
        }

    # ── Markdown report ─────────────────────────────────────────────────────
    def to_markdown(self) -> str:
        """Render the run as a human-readable markdown report.

        Layout: header bullets, optional project summary, a per-step table,
        then one LLM-call detail table per step that made any calls.
        """
        lines: list[str] = []
        lines.append(f"# Run log — {self.run_id}\n")
        lines.append(f"- Started: `{self.started_at}`")
        if self.duration_ms is not None:
            lines.append(f"- Total duration: **{self.duration_ms / 1000:,.2f} s**")
        lines.append(f"- LLM calls: {self.total_llm_calls}")
        lines.append(
            f"- Tokens — in: {self.total_input_tokens:,} · out: "
            f"{self.total_output_tokens:,} · total: **{self.total_tokens:,}**"
        )
        if self.summary:
            lines.append("")
            lines.append("## Project")
            for k, v in self.summary.items():
                lines.append(f"- **{k}**: {v}")
        lines.append("")

        # Per-step table
        lines.append("## Steps")
        lines.append("")
        lines.append("| # | Step | Status | Duration | LLM calls | In tok | Out tok | Total tok | Models |")
        lines.append("|---|------|--------|----------|-----------|--------|---------|-----------|--------|")
        for i, s in enumerate(self.steps, 1):
            dur = f"{s.duration_ms / 1000:,.2f} s" if s.duration_ms is not None else "—"
            models = ", ".join(_short(m) for m in s.models_used) or "—"
            lines.append(
                f"| {i} | `{s.name}` | {s.status} | {dur} | {len(s.llm_calls)} | "
                f"{s.total_input_tokens:,} | {s.total_output_tokens:,} | "
                f"{s.total_tokens:,} | {models} |"
            )
        lines.append("")

        # LLM call detail per step
        for s in self.steps:
            if not s.llm_calls:
                continue
            lines.append(f"### `{s.name}` — LLM calls")
            lines.append("")
            lines.append("| Model | Label | Duration | In tok | Out tok | Total |")
            lines.append("|-------|-------|----------|--------|---------|-------|")
            for c in s.llm_calls:
                lines.append(
                    f"| {_short(c.model)} | {c.label or '—'} | "
                    f"{c.duration_ms / 1000:,.2f} s | "
                    f"{c.input_tokens:,} | {c.output_tokens:,} | "
                    f"{c.total_tokens:,} |"
                )
            lines.append("")

        return "\n".join(lines) + "\n"

    # ── Finalize ────────────────────────────────────────────────────────────
    def finalize(self, output_dir: Path) -> None:
        """Stamp total duration and write run_log.json / run_log.md to *output_dir*."""
        self.duration_ms = (perf_counter() - self._t0) * 1000
        output_dir.mkdir(parents=True, exist_ok=True)
        # ensure_ascii=False keeps non-ASCII content (model names, note text)
        # readable in the UTF-8 file instead of \uXXXX escapes — consistent
        # with the markdown report, which already emits "—"/"·" literally.
        (output_dir / "run_log.json").write_text(
            json.dumps(self.to_dict(), indent=2, ensure_ascii=False) + "\n",
            encoding="utf-8",
        )
        (output_dir / "run_log.md").write_text(self.to_markdown(), encoding="utf-8")


# ─────────────────────────────────────────────────────────────────────────────
# Active-run accessor (module-level singleton — simple and explicit)
# ─────────────────────────────────────────────────────────────────────────────

_active: RunLog | None = None


def set_active(log: RunLog | None) -> None:
    """Install *log* as the process-wide active run (pass None to clear it)."""
    global _active
    _active = log


def get_active() -> RunLog | None:
    """Return the currently-active RunLog, or None if no run is in progress."""
    return _active


def _short(model_id: str) -> str:
    """Short label for tables — strips bedrock prefixes + version tail."""
    import re
    name = model_id.split("anthropic.")[-1] if "anthropic." in model_id else model_id
    return re.sub(r"-\d{8}.*", "", name)[:40]
