"""Mantara v8 — schema validator (the gate).

Runs AFTER normaliser.normalise() and BEFORE renderer.render_sql(). Catches:

    • FK references to tables that don't exist
    • FK references to columns that don't exist on the target table
    • cfg_* tables missing the canonical column set
    • orphan ENUM types (should have been lifted)
    • duplicate table_names
    • duplicate column names within a table

Optional: if `pglast` is installed, parses the rendered SQL to catch any
DDL syntax issues the structural checks missed.

`validate()` returns a `ValidationReport` wrapping the list of `Violation`
objects. The caller decides whether to raise or continue with warnings.
"""
from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Iterable


@dataclass
class Violation:
    """A single finding produced by one of the validator checks.

    ``str(violation)`` renders as ``SEVERITY CODE [location]: message``; the
    bracketed location is left out when ``location`` is empty.
    """

    code: str           # machine-readable identifier, e.g. "FK_TARGET_MISSING"
    severity: str       # "error" | "warning"
    message: str        # human-readable explanation
    location: str = ""  # where it was found, e.g. "table.column"

    def __str__(self) -> str:
        """Format the violation as a single log-friendly line."""
        header = f"{self.severity.upper()} {self.code}"
        if self.location:
            header = f"{header} [{self.location}]"
        return f"{header}: {self.message}"


@dataclass
class ValidationReport:
    """Accumulates the violations found during one validation run."""

    violations: list[Violation] = field(default_factory=list)
    parse_ok: bool | None = None     # None = not attempted

    def _with_severity(self, level: str) -> list[Violation]:
        """Return the violations whose ``severity`` equals *level*."""
        matching = []
        for item in self.violations:
            if item.severity == level:
                matching.append(item)
        return matching

    @property
    def errors(self) -> list[Violation]:
        """Violations with severity ``"error"``, in recorded order."""
        return self._with_severity("error")

    @property
    def warnings(self) -> list[Violation]:
        """Violations with severity ``"warning"``, in recorded order."""
        return self._with_severity("warning")

    @property
    def is_clean(self) -> bool:
        """True when no error-severity violation was recorded."""
        return len(self.errors) == 0

    def summary(self) -> str:
        """Return a one-line summary: error/warning counts plus parse status.

        The parse status is only mentioned when ``parse_ok`` is exactly
        ``True`` or ``False``.
        """
        pieces = [
            f"{len(self.errors)} errors",
            f"{len(self.warnings)} warnings",
        ]
        if self.parse_ok is False:
            pieces.append("DDL parse FAILED")
        elif self.parse_ok is True:
            pieces.append("DDL parses ✓")
        return ", ".join(pieces)


# Column names every cfg_* table must declare; check_cfg_canonical_shape
# reports whichever of these are absent.
_REQUIRED_CFG_COLS = {
    "code",
    "label",
    "description",
    "is_active",
}


def _all_tables(schema_json: dict) -> list[dict]:
    """Collect every table dict found under ``menus[].submenus[].tables[]``.

    Missing or ``None`` values for ``menus``, ``submenus`` and ``tables`` are
    treated as empty, and any menu, submenu or table entry that is not a dict
    is skipped, so a partially-formed schema never makes the walk raise.

    Returns:
        The table dicts (the same objects, not copies) in document order.
    """
    tables: list[dict] = []
    # `or []` rather than a .get() default: the key may be present but None.
    for menu in schema_json.get("menus") or []:
        if not isinstance(menu, dict):
            continue
        for sub in menu.get("submenus") or []:
            if not isinstance(sub, dict):
                continue
            tables.extend(
                tbl for tbl in sub.get("tables") or [] if isinstance(tbl, dict)
            )
    return tables


def _table_index(schema_json: dict) -> dict[str, dict]:
    """Build a ``table_name -> table dict`` lookup for the whole schema.

    Two virtual entries, ``menu`` and ``submenu``, are seeded first so FKs
    pointing at them resolve. Tables from the schema are added afterwards, so
    a schema table with one of those names replaces the virtual entry, and a
    later duplicate name replaces an earlier one. Tables without a truthy
    ``table_name`` are not indexed.
    """
    virtual_columns = {
        "menu": ("menu_id", "menu_name", "sequence_number", "description"),
        "submenu": (
            "submenu_id",
            "submenu_name",
            "menu_id",
            "sequence_number",
            "description",
        ),
    }
    index: dict[str, dict] = {
        vname: {
            "table_name": vname,
            "columns": [{"name": col_name} for col_name in col_names],
        }
        for vname, col_names in virtual_columns.items()
    }
    index.update(
        (table["table_name"], table)
        for table in _all_tables(schema_json)
        if table.get("table_name")
    )
    return index


def _parse_fk_ref(ref: str) -> tuple[str | None, str | None]:
    """Split an FK reference string into ``(table, column)``.

    Accepts ``table(col)`` or ``schema.table(col)``, with optional whitespace
    around the name and inside the parentheses; a schema qualifier is
    discarded.

    Returns:
        ``(table, column)`` on success, otherwise ``(None, None)``. Anything
        that is not a string (``None``, numbers, lists, dicts) also yields
        ``(None, None)`` so the caller can report it as unparseable instead
        of crashing inside the regex engine.
    """
    if not isinstance(ref, str):
        return None, None
    m = re.match(r"\s*(?:\w+\.)?(\w+)\s*\(\s*(\w+)\s*\)\s*$", ref)
    if m is None:
        return None, None
    table, column = m.groups()
    return table, column


def _column_names(tbl: dict) -> set[str]:
    """Return the set of truthy ``name`` values among *tbl*'s columns.

    A missing or ``None`` ``columns`` entry yields an empty set.
    """
    names: set[str] = set()
    for column in tbl.get("columns") or []:
        col_name = column.get("name")
        if col_name:
            names.add(col_name)
    return names


# --------------------------------------------------------------------------- #
# Structural checks
# --------------------------------------------------------------------------- #
def check_duplicate_tables(schema_json: dict, report: ValidationReport) -> None:
    """Record a ``DUPLICATE_TABLE`` error for each table name defined twice.

    Tables without a truthy ``table_name`` are ignored. One violation is
    appended per duplicated name, in order of first appearance.
    """
    declared = [table.get("table_name") for table in _all_tables(schema_json)]
    tally: dict[str, int] = {}
    for table_name in filter(None, declared):
        tally.setdefault(table_name, 0)
        tally[table_name] += 1
    report.violations.extend(
        Violation(
            code="DUPLICATE_TABLE",
            severity="error",
            message=f"Table '{table_name}' is defined {occurrences} times",
            location=table_name,
        )
        for table_name, occurrences in tally.items()
        if occurrences > 1
    )


def check_duplicate_columns(schema_json: dict, report: ValidationReport) -> None:
    """Record a ``DUPLICATE_COLUMN`` error for each repeated column name.

    Columns are counted per table; entries without a truthy ``name`` are
    ignored. The violation location is ``<table_name>.<column>``.
    """
    for table in _all_tables(schema_json):
        owner = table.get("table_name")
        occurrences: dict[str, int] = {}
        for column in table.get("columns") or []:
            label = column.get("name")
            if label:
                occurrences[label] = 1 + occurrences.get(label, 0)
        repeated = [(n, c) for n, c in occurrences.items() if c > 1]
        for col_name, count in repeated:
            report.violations.append(Violation(
                code="DUPLICATE_COLUMN",
                severity="error",
                message=f"Column '{col_name}' defined {count} times",
                location=f"{owner}.{col_name}",
            ))


def check_fk_targets(schema_json: dict, report: ValidationReport) -> None:
    """Verify that every foreign key points at a table/column that exists.

    Two FK sources are inspected for each table:

    * inline ``REFERENCES ...`` clauses found in a column's ``constraints``
      string (case-insensitive, optional schema qualifier);
    * entries of the table-level ``foreign_keys`` list, whose ``references``
      string is parsed with ``_parse_fk_ref``. Entries that cannot be parsed
      are reported as ``FK_REF_UNPARSEABLE``.

    An inline clause may omit the column list (``REFERENCES users``), which
    PostgreSQL resolves to the target's primary key. Such clauses are still
    checked for the existence of the target table; only the column check is
    skipped.

    Violations are appended to ``report.violations``; nothing is returned.
    """
    idx = _table_index(schema_json)
    # The column list is optional. `\b(?!\s*\.)` stops the table group from
    # settling on a schema/database qualifier when the real table name that
    # follows cannot be matched (e.g. a quoted identifier or a three-part
    # name), which would otherwise yield a bogus "table missing" error.
    inline_pat = re.compile(
        r"REFERENCES\s+(?:\w+\.)?(\w+)\b(?!\s*\.)(?:\s*\(\s*(\w+)\s*\))?",
        re.IGNORECASE,
    )
    for tbl in _all_tables(schema_json):
        tname = tbl.get("table_name", "")
        # Inline column FKs
        for col in tbl.get("columns") or []:
            cons = col.get("constraints") or ""
            for m in inline_pat.finditer(cons):
                # group(2) is None when the clause has no column list.
                target_table, target_col = m.group(1), m.group(2)
                _check_fk_pair(report, idx, tname, col.get("name", "?"),
                               target_table, target_col)
        # Table-level foreign_keys[]
        for fk in tbl.get("foreign_keys") or []:
            target_table, target_col = _parse_fk_ref(fk.get("references", ""))
            if target_table is None:
                report.violations.append(Violation(
                    code="FK_REF_UNPARSEABLE",
                    severity="error",
                    message=f"Cannot parse FK references string: {fk.get('references')!r}",
                    location=f"{tname}.{fk.get('column', '?')}",
                ))
                continue
            _check_fk_pair(report, idx, tname, fk.get("column", "?"),
                           target_table, target_col)


def _check_fk_pair(
    report: ValidationReport,
    idx: dict[str, dict],
    src_table: str,
    src_col: str,
    target_table: str,
    target_col: str | None,
) -> None:
    """Check one FK edge against the table index and record any problem.

    Appends ``FK_TARGET_TABLE_MISSING`` when *target_table* is not a key of
    *idx*. Otherwise, when *target_col* is given, appends
    ``FK_TARGET_COLUMN_MISSING`` if the target table has no such column.
    A ``None`` *target_col* skips the column check. Violations are located at
    ``<src_table>.<src_col>``.
    """
    where = f"{src_table}.{src_col}"
    try:
        target_def = idx[target_table]
    except KeyError:
        report.violations.append(Violation(
            code="FK_TARGET_TABLE_MISSING",
            severity="error",
            message=f"FK references table '{target_table}' which is not defined",
            location=where,
        ))
        return

    if target_col is None or target_col in _column_names(target_def):
        return

    problem = (
        f"FK references column '{target_table}.{target_col}' which "
        f"does not exist on the target table"
    )
    report.violations.append(Violation(
        code="FK_TARGET_COLUMN_MISSING",
        severity="error",
        message=problem,
        location=where,
    ))


def check_cfg_canonical_shape(schema_json: dict, report: ValidationReport) -> None:
    """Check that each ``cfg_*`` table has the canonical PK and column set.

    For every table whose name starts with ``cfg_``:

    * ``CFG_PK_NON_CANONICAL`` is recorded unless a column named
      ``<table_name>_id`` carries ``PRIMARY KEY`` (any case) in its
      ``constraints`` string;
    * ``CFG_MISSING_CANONICAL_COLUMNS`` is recorded when any name from
      ``_REQUIRED_CFG_COLS`` is absent, listing the missing names sorted.
    """
    for table in _all_tables(schema_json):
        cfg_name = table.get("table_name", "")
        if not cfg_name.startswith("cfg_"):
            continue

        declared = _column_names(table)
        # PK must be `<cfg_name>_id`
        pk_name = f"{cfg_name}_id"
        pk_declared = any(
            column.get("name") == pk_name
            and "PRIMARY KEY" in (column.get("constraints") or "").upper()
            for column in table.get("columns") or []
        )
        if not pk_declared:
            report.violations.append(Violation(
                code="CFG_PK_NON_CANONICAL",
                severity="error",
                message=f"cfg_* table must have PK column named '{pk_name}'",
                location=cfg_name,
            ))

        absent = _REQUIRED_CFG_COLS - declared
        if absent:
            report.violations.append(Violation(
                code="CFG_MISSING_CANONICAL_COLUMNS",
                severity="error",
                message=f"cfg_* table missing required columns: {sorted(absent)}",
                location=cfg_name,
            ))


def check_no_residual_enums(schema_json: dict, report: ValidationReport) -> None:
    """Warn when the schema still carries top-level ``enum_types`` entries.

    A single ``RESIDUAL_ENUM_TYPES`` warning (with no location) is appended
    if ``enum_types`` is non-empty; a missing, ``None`` or empty value is
    silently accepted.
    """
    leftover = schema_json.get("enum_types") or []
    if not leftover:
        return
    note = (
        f"{len(leftover)} enum_types[] entries remain — should have been "
        f"lifted to cfg_* tables. Did normaliser run?"
    )
    report.violations.append(Violation(
        code="RESIDUAL_ENUM_TYPES",
        severity="warning",
        message=note,
    ))


# --------------------------------------------------------------------------- #
# Optional DDL parse smoke-test (pglast)
# --------------------------------------------------------------------------- #
def parse_ddl(sql: str) -> tuple[bool, str]:
    """Try to parse *sql* with pglast and report the outcome.

    Returns:
        ``(True, "")`` when pglast parsed the text, ``(False, <error text>)``
        when pglast raised, and ``(True, "pglast not installed — parse
        skipped")`` when the optional dependency cannot be imported, so a
        missing pglast never blocks the gate.
    """
    try:
        import pglast  # type: ignore
    except ImportError:
        return True, "pglast not installed — parse skipped"

    try:
        pglast.parse_sql(sql)
    except Exception as exc:  # noqa: BLE001
        outcome = (False, str(exc))
    else:
        outcome = (True, "")
    return outcome


# --------------------------------------------------------------------------- #
# Public entry point
# --------------------------------------------------------------------------- #
def validate(schema_json: dict, sql: str | None = None) -> ValidationReport:
    """Run all structural checks and, optionally, a DDL parse smoke-test.

    Args:
        schema_json: The schema dict to inspect.
        sql: Rendered DDL. When given, it is handed to ``parse_ddl``.

    Returns:
        A ``ValidationReport`` with every violation found. ``parse_ok`` is
        ``True`` only when the DDL was actually parsed successfully,
        ``False`` when parsing failed (a ``DDL_PARSE_FAILED`` error is also
        recorded), and stays ``None`` when no parse was attempted — either
        because *sql* was not supplied or because ``parse_ddl`` skipped it.
    """
    report = ValidationReport()
    check_duplicate_tables(schema_json, report)
    check_duplicate_columns(schema_json, report)
    check_fk_targets(schema_json, report)
    check_cfg_canonical_shape(schema_json, report)
    check_no_residual_enums(schema_json, report)

    if sql is not None:
        ok, err = parse_ddl(sql)
        if not ok:
            report.parse_ok = False
            report.violations.append(Violation(
                code="DDL_PARSE_FAILED",
                severity="error",
                message=f"pglast parse error: {err}",
            ))
        elif err:
            # parse_ddl returns ok=True with a non-empty note when the parse
            # was skipped (pglast unavailable). Nothing was verified, so do
            # not let summary() claim the DDL parses.
            report.parse_ok = None
        else:
            report.parse_ok = True

    return report


# Public API: only the entry point and its result types are exported. The
# individual check_* functions and parse_ddl are not part of a star-import.
__all__ = ["validate", "Violation", "ValidationReport"]
