#!/usr/bin/env python3
"""Mantara Eval Harness — run test prompts and score outputs.

Usage:
    python eval.py                      # Run all prompts in test_prompts/
    python eval.py --model gpt-4o-mini  # Use a different model
    python eval.py --prompt hospital    # Run a single prompt by name
    python eval.py --json               # Output results as JSON
"""

import argparse
import json
import sys
import time
from pathlib import Path

from config import MODEL, OUTPUT_DIR
from generator import generate


PROMPTS_DIR = Path(__file__).parent / "test_prompts"
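# One .txt file per eval prompt; the file stem (e.g. hospital.txt -> "hospital") is the prompt name.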


def discover_prompts(filter_name: str | None = None) -> list[tuple[str, str]]:
    """Return list of (name, prompt_text) from test_prompts/ directory."""
    prompts = []
    for f in sorted(PROMPTS_DIR.glob("*.txt")):
        name = f.stem
        if filter_name and filter_name != name:
            continue
        prompts.append((name, f.read_text().strip()))
    return prompts


def count_tables(schema) -> int:
    """Count tables across every submenu of every menu in the schema."""
    return sum(
        len(sub.tables or [])
        for menu in schema.menus
        for sub in menu.submenus
    )


def count_enums(schema) -> int:
    """Count top-level ENUM type definitions in the schema."""
    return len(schema.enum_types or [])


def run_eval(prompts: list[tuple[str, str]], model: str | None = None) -> list[dict]:
    """Run each prompt through the pipeline and collect results."""
    results = []
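    # An explicit --model argument wins; otherwise fall back to the default from config.MODEL.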
    use_model = model or MODEL

    for name, prompt_text in prompts:
        print(f"\n{'='*60}", file=sys.stderr)
        print(f"  EVAL: {name}", file=sys.stderr)
        print(f"  Model: {use_model}", file=sys.stderr)
        print(f"{'='*60}", file=sys.stderr)

        result_entry = {
            "name": name,
            "model": use_model,
            "prompt_length": len(prompt_text),
            "status": "error",
            "error": None,
            "latency_seconds": None,
            "tables": None,
            "enums": None,
            "menus": None,
            "validation_errors": None,
            "validation_pass": None,
        }

        try:
            start = time.time()
            result = generate(prompt_text, model=use_model)
            elapsed = round(time.time() - start, 1)

            schema = result["schema"]
            validation = result["validation"]
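            # generate() returns a dict; besides "schema" and "validation" it also
            # carries "json_str" and "sql_str", which are written to disk below.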

            result_entry.update({
                "status": "pass" if validation["is_valid"] else "fail",
                "latency_seconds": elapsed,
                "tables": count_tables(schema),
                "enums": count_enums(schema),
                "menus": len(schema.menus),
                "validation_errors": validation["errors"],
                "validation_pass": validation["is_valid"],
            })

            # Save outputs
            OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
            json_path = OUTPUT_DIR / f"eval_{name}.json"
            sql_path = OUTPUT_DIR / f"eval_{name}.sql"
            json_path.write_text(result["json_str"])
            sql_path.write_text(result["sql_str"])

            print(f"  Result: {'PASS' if validation['is_valid'] else 'FAIL'}", file=sys.stderr)
            print(f"  Tables: {count_tables(schema)}, ENUMs: {count_enums(schema)}, Menus: {len(schema.menus)}", file=sys.stderr)
            print(f"  Latency: {elapsed}s", file=sys.stderr)
            if validation["errors"]:
                print(f"  Errors:", file=sys.stderr)
                for e in validation["errors"]:
                    print(f"    - {e}", file=sys.stderr)

        except Exception as e:
            result_entry["error"] = f"{type(e).__name__}: {e}"
            print(f"  ERROR: {result_entry['error']}", file=sys.stderr)

        results.append(result_entry)

    return results


def print_summary_table(results: list[dict]):
    """Print a formatted summary table to stdout."""
    print()
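    # Fixed-width columns; the field widths plus separators add up to the 62-char dividers below.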
    print(f"{'Prompt':<15} {'Status':<8} {'Tables':>6} {'ENUMs':>6} {'Menus':>6} {'Errors':>7} {'Latency':>8}")
    print("-" * 62)

    for r in results:
        name = r["name"][:14]
        status = r["status"].upper()
        tables = str(r["tables"] or "-")
        enums = str(r["enums"] or "-")
        menus = str(r["menus"] or "-")
        n_errors = str(len(r["validation_errors"])) if r["validation_errors"] is not None else "-"
        latency = f"{r['latency_seconds']}s" if r["latency_seconds"] else "-"

        print(f"{name:<15} {status:<8} {tables:>6} {enums:>6} {menus:>6} {n_errors:>7} {latency:>8}")

    # Summary line
    total = len(results)
    passed = sum(1 for r in results if r["status"] == "pass")
    print("-" * 62)
    print(f"{'TOTAL':<15} {passed}/{total} pass")
    print()


def main():
    parser = argparse.ArgumentParser(description="Mantara Eval Harness")
    parser.add_argument("--model", "-m", type=str, default=None, help=f"Model to use (default: {MODEL})")
    parser.add_argument("--prompt", "-p", type=str, default=None, help="Run a single prompt by name")
    parser.add_argument("--json", action="store_true", help="Output results as JSON")
    args = parser.parse_args()

    prompts = discover_prompts(args.prompt)
    if not prompts:
        print(f"No prompts found in {PROMPTS_DIR}/", file=sys.stderr)
        if args.prompt:
            print(f"  (filter: '{args.prompt}')", file=sys.stderr)
        sys.exit(1)

    print(f"\nMantara Eval — {len(prompts)} prompt(s)", file=sys.stderr)

    results = run_eval(prompts, model=args.model)

    if args.json:
        print(json.dumps(results, indent=2))
    else:
        print_summary_table(results)


if __name__ == "__main__":
    main()
