"""Deterministic SQL renderer: converts a validated MantaraSchema to PostgreSQL DDL.

This is the core architectural win — the LLM generates structured JSON, and
Python renders SQL deterministically. No LLM involvement in SQL generation.
"""

import re
from models import MantaraSchema, Table


def _esc(s: str) -> str:
    """Escape single quotes for SQL string literals."""
    return s.replace("'", "''")


def _section(title: str) -> str:
    """Render a section separator comment."""
    bar = "-- " + "=" * 76
    return f"{bar}\n-- {title}\n{bar}"


def _table_title(table_name: str) -> str:
    """Convert snake_case table name to a readable title."""
    return table_name.replace("_", " ").title()


def _get_all_table_names(schema: MantaraSchema) -> set[str]:
    """Return the names of every table attached to any submenu of *schema*.

    Submenus whose ``tables`` attribute is empty or None contribute nothing.
    """
    return {
        table.table_name
        for menu in schema.menus
        for submenu in menu.submenus
        for table in (submenu.tables or ())
    }


# Column names that receive an automatic ``CHECK (<col> >= 0)`` constraint
# (consulted by _needs_non_negative_check). The pattern is anchored with ^...$
# and used with .match(), so only whole, case-insensitive name matches qualify:
# ``price`` matches, ``list_price`` does not.
# NOTE(review): some names here (e.g. ``balance``) can legitimately be negative
# in certain domains — confirm the list against the target data model.
_NON_NEGATIVE_PATTERNS = re.compile(
    r"^(price|cost|amount|total_amount|total_cost|total_price|budget|"
    r"salary|fee|charge|revenue|balance|discount|tax_amount|"
    r"quantity|qty|credit_limit|unit_cost|unit_price|line_total|"
    r"rental_cost|hourly_rate|daily_cost|allocated_amount|spent_amount|"
    r"invoice_amount|order_total|price_per_night|price_per_unit|"
    r"total_budget|net_amount|gross_amount|"
    r"hours_allocated|hours_worked|actual_hours|"
    r"planned_quantity|used_quantity|reserved_quantity|"
    r"reorder_level|room_count|number_of_rooms|guest_count|"
    r"max_occupancy|weight|volume|rate)$",
    re.IGNORECASE,
)

# Column name patterns for start/end date pairs
_START_DATE_NAMES = {"start_date", "check_in_date", "usage_start_date", "planned_start_date", "hire_date"}
_END_DATE_NAMES = {"end_date", "check_out_date", "usage_end_date", "planned_end_date", "end_date"}
_DATE_PAIRS = [
    ("start_date", "end_date"),
    ("check_in_date", "check_out_date"),
    ("usage_start_date", "usage_end_date"),
    ("planned_start_date", "planned_end_date"),
    ("planned_date", "actual_date"),
]


def _needs_non_negative_check(col_name: str, constraints: str) -> bool:
    """Decide whether to append ``CHECK (<col> >= 0)`` to a column definition.

    True only when *col_name* is one of the known non-negative quantity names
    in ``_NON_NEGATIVE_PATTERNS`` and *constraints* does not already contain a
    ``CHECK (`` clause of any kind (case-insensitive).
    """
    if _NON_NEGATIVE_PATTERNS.match(col_name) is None:
        return False
    existing_check = re.search(r"\bCHECK\s*\(", constraints, re.IGNORECASE)
    return existing_check is None


def _build_table_check_constraints(table: Table) -> list[str]:
    """Build table-level date-range CHECK constraints for *table*.

    For every ``(start, end)`` pair in ``_DATE_PAIRS`` where both columns exist
    on the table, emit ``    CHECK (end >= start)`` unless some column's
    constraint text already expresses that ordering. Both orientations are
    recognised (``end >= start`` / ``end > start`` as well as
    ``start <= end`` / ``start < end``), so an existing check is not duplicated.

    Returns:
        Indented constraint lines, ready to be joined into a CREATE TABLE body.
    """
    col_names = {col.name for col in table.columns}
    # Only columns that actually carry constraint text can already hold the check.
    existing = [col.constraints for col in table.columns if col.constraints]
    checks: list[str] = []
    for start_col, end_col in _DATE_PAIRS:
        if start_col not in col_names or end_col not in col_names:
            continue
        # \b on both ends keeps e.g. "end_date" from matching inside "usage_end_date".
        already_present = re.compile(
            rf"\b{end_col}\s*>=?\s*{start_col}\b|\b{start_col}\s*<=?\s*{end_col}\b",
            re.IGNORECASE,
        )
        if not any(already_present.search(text) for text in existing):
            checks.append(f"    CHECK ({end_col} >= {start_col})")
    return checks


def _sanitize_constraints(constraint_str: str) -> str:
    """Strip MySQL-specific syntax that is invalid in PostgreSQL."""
    # ON UPDATE CURRENT_TIMESTAMP is MySQL — not valid in PostgreSQL
    return re.sub(
        r"\s*ON\s+UPDATE\s+CURRENT_TIMESTAMP\b", "", constraint_str, flags=re.IGNORECASE
    ).strip()


def _extract_ref_table(constraint_str: str) -> str | None:
    """Extract the table name from a REFERENCES clause, if present."""
    m = re.search(r"REFERENCES\s+\w+\.(\w+)\(", constraint_str, re.IGNORECASE)
    return m.group(1) if m else None


def _strip_references(constraint_str: str) -> str:
    """Remove the REFERENCES clause from a constraint string."""
    return re.sub(
        r"\s*REFERENCES\s+\w+\.\w+\(\w+\)", "", constraint_str, flags=re.IGNORECASE
    ).strip()


def _extract_full_ref(constraint_str: str) -> str | None:
    """Extract the full 'schema.table(col)' from a REFERENCES clause."""
    m = re.search(r"REFERENCES\s+(\w+\.\w+\(\w+\))", constraint_str, re.IGNORECASE)
    return m.group(1) if m else None


# ---------------------------------------------------------------------------
# Section renderers
# ---------------------------------------------------------------------------

def _render_header(schema: MantaraSchema) -> str:
    """Return the banner naming the system and its upper-cased schema name.

    NOTE(review): render_sql in this module does not call this function.
    """
    title = f"MANTARA SCHEMA: {schema.system_name} ({schema.schema_name.upper()})"
    banner = _section(title)
    return banner + "\n" + "-- Generated following mantara.schema.v1 specification\n"


def _render_create_schema(schema_name: str) -> str:
    return f"-- Schema creation\nCREATE SCHEMA {schema_name};\n"


def _render_menu_table(schema: MantaraSchema) -> str:
    """Render the ``<schema>.menu`` table, its table comment, and its seed rows.

    Menus are inserted in ascending ``menu_id`` order with
    ``ON CONFLICT (menu_id) DO NOTHING``. A ``None`` description is written as
    SQL ``NULL`` (the column is nullable). When the schema has no menus the
    INSERT is omitted, because ``VALUES`` with no rows is invalid SQL.
    """
    sn = schema.schema_name
    lines = [
        f"CREATE TABLE {sn}.menu (",
        "    menu_id INT PRIMARY KEY,",
        "    menu_name VARCHAR(255) NOT NULL,",
        "    sequence_number INT NOT NULL,",
        "    description TEXT",
        ");",
        f"COMMENT ON TABLE {sn}.menu IS "
        f"'Stores the top-level navigation menus for the {_esc(schema.system_name)}.';",
    ]

    rows: list[str] = []
    for menu in sorted(schema.menus, key=lambda m: m.menu_id):
        desc = "NULL" if menu.description is None else f"'{_esc(menu.description)}'"
        rows.append(
            f"    ({menu.menu_id}, '{_esc(menu.menu_name)}', "
            f"{menu.sequence_number}, {desc})"
        )

    if rows:
        lines += [
            f"INSERT INTO {sn}.menu (menu_id, menu_name, sequence_number, description)",
            "VALUES",
            # Every row but the last carries a trailing comma.
            ",\n".join(rows),
            "ON CONFLICT (menu_id) DO NOTHING;",
        ]

    return "\n".join(lines)


def _render_submenu_table(schema: MantaraSchema) -> str:
    """Render ``<schema>.submenu`` (FK to ``<schema>.menu``) and its seed rows.

    Rows are ordered by the parent menu's ``menu_id`` and then by
    ``submenu_id``; each row's ``menu_id`` is the id of the menu that contains
    the submenu. Rows are inserted with ``ON CONFLICT (submenu_id) DO NOTHING``.
    ``None`` descriptions become SQL ``NULL``. When there are no submenus the
    INSERT is omitted, because an empty ``VALUES`` list is invalid SQL.
    """
    sn = schema.schema_name
    lines = [
        f"CREATE TABLE {sn}.submenu (",
        "    submenu_id INT PRIMARY KEY,",
        "    submenu_name VARCHAR(255) NOT NULL,",
        f"    menu_id INT NOT NULL REFERENCES {sn}.menu(menu_id),",
        "    sequence_number INT NOT NULL,",
        "    description TEXT",
        ");",
    ]

    rows: list[str] = []
    for menu in sorted(schema.menus, key=lambda m: m.menu_id):
        for sub in sorted(menu.submenus, key=lambda s: s.submenu_id):
            desc = "NULL" if sub.description is None else f"'{_esc(sub.description)}'"
            rows.append(
                f"    ({sub.submenu_id}, '{_esc(sub.submenu_name)}', "
                f"{menu.menu_id}, {sub.sequence_number}, {desc})"
            )

    if rows:
        lines += [
            f"INSERT INTO {sn}.submenu (submenu_id, submenu_name, menu_id, "
            f"sequence_number, description)",
            "VALUES",
            # Every row but the last carries a trailing comma.
            ",\n".join(rows),
            "ON CONFLICT (submenu_id) DO NOTHING;",
        ]

    return "\n".join(lines)


def _render_enums(schema: MantaraSchema) -> str:
    """Render one ``CREATE TYPE ... AS ENUM`` statement per enum type.

    Enum labels are SQL string literals, so embedded single quotes are
    escaped with ``_esc`` (an unescaped quote would terminate the literal).
    Type names are emitted as given, not schema-qualified, matching how column
    types are emitted verbatim in business tables.

    Returns:
        The statements joined by newlines, or "" when there are no enum types.
    """
    if not schema.enum_types:
        return ""

    statements: list[str] = []
    for enum in schema.enum_types:
        labels = ", ".join(f"'{_esc(str(v))}'" for v in enum.values)
        statements.append(f"CREATE TYPE {enum.type_name} AS ENUM ({labels});")

    return "\n".join(statements)


# Columns that always point at the platform user table. Their FKs are emitted
# as standard named constraints (in this order) instead of model-supplied ones.
_STD_USER_FK_COLS = ("created_by", "updated_by", "user_id")

# Inline REFERENCES clause (any target, including quoted identifiers) plus any
# trailing ON DELETE/ON UPDATE actions; stripped from the standard user columns
# so that no orphaned action is left in the column definition.
_USER_COL_REFERENCES_RE = re.compile(
    r"\s*REFERENCES\s+\S+\s*\([^)]+\)"
    r"(?:\s+ON\s+(?:DELETE|UPDATE)\s+(?:CASCADE|RESTRICT|NO\s+ACTION|SET\s+NULL|SET\s+DEFAULT))*",
    re.IGNORECASE,
)

# Column types that can be compared with 0; the automatic CHECK (col >= 0) is
# only added to these (comparing e.g. VARCHAR or an enum with 0 fails in PostgreSQL).
_NUMERIC_TYPE_RE = re.compile(
    r"^\s*(?:SMALLINT|BIGINT|INT(?:EGER)?[248]?|DECIMAL|NUMERIC|REAL|"
    r"DOUBLE\s+PRECISION|FLOAT[48]?|MONEY|(?:SMALL|BIG)?SERIAL[248]?)\b",
    re.IGNORECASE,
)


def _render_business_table(
    table: Table,
    schema_name: str,
    created_tables: set[str],
    all_table_names: set[str],
) -> tuple[str, list[str]]:
    """Render the CREATE TABLE statement for one business (or cfg_*) table.

    Column handling:
      * ``id`` is always ``serial4 NOT NULL``; its primary key is the named
        table-level constraint ``<table>_pkey``.
      * ``created_by`` / ``updated_by`` / ``user_id`` lose any inline
        REFERENCES clause (and its actions). Named FKs to ``public."user"(id)``
        are added instead, and ``table.foreign_keys`` entries for those
        columns are ignored so a column never gets two foreign keys.
      * Any other inline ``REFERENCES <schema>.<table>(col)`` whose target is
        a table of this schema not yet in *created_tables* is a forward
        reference. The clause is removed from the column and re-emitted as a
        deferred ``ALTER TABLE ... ADD CONSTRAINT fk_<table>_<col>``.
      * A column whose name is in the non-negative list and whose type is
        numeric gets ``CHECK (<col> >= 0)``, unless it already has a CHECK.

    Table-level constraints are ordered: primary key, model foreign keys and
    date-range CHECKs, then the standard user foreign keys. Column comments
    are not rendered.

    Args:
        table: The table definition to render.
        schema_name: Target PostgreSQL schema used to qualify table names.
        created_tables: Tables already emitted; references to them stay inline.
        all_table_names: Every table defined in the schema; references to
            tables outside this set are never deferred.

    Returns:
        ``(create_table_sql, deferred_alter_statements)``.
    """
    sn = schema_name
    tname = table.table_name
    deferred: list[str] = []
    cols_with_inline_fk: set[str] = set()
    col_names_set = {col.name for col in table.columns}

    def is_forward_ref(ref_table: str) -> bool:
        # Defer only references to schema tables that do not exist yet;
        # targets outside the schema definition are left inline.
        return ref_table not in created_tables and ref_table in all_table_names

    # --- Column definitions ---
    col_defs: list[str] = []
    for col in table.columns:
        if col.name == "id":
            col_defs.append("    id serial4 NOT NULL")
            continue

        line = f"    {col.name} {col.type}"
        constraints = _sanitize_constraints(col.constraints or "")

        if col.name in _STD_USER_FK_COLS:
            clean = _USER_COL_REFERENCES_RE.sub("", constraints).strip()
            if clean:
                line += f" {clean}"
            col_defs.append(line)
            continue

        if constraints:
            ref_table = _extract_ref_table(constraints)
            if ref_table is not None:
                cols_with_inline_fk.add(col.name)

            if ref_table is not None and is_forward_ref(ref_table):
                # Keep the non-FK constraints inline; move the FK to an ALTER.
                full_ref = _extract_full_ref(constraints)
                clean = _strip_references(constraints)
                if clean:
                    line += f" {clean}"
                deferred.append(
                    f"ALTER TABLE {sn}.{tname}\n"
                    f"    ADD CONSTRAINT fk_{tname}_{col.name} "
                    f"FOREIGN KEY ({col.name}) REFERENCES {full_ref};"
                )
            else:
                line += f" {constraints}"

        if _needs_non_negative_check(col.name, constraints) and _NUMERIC_TYPE_RE.match(
            str(col.type)
        ):
            line += f" CHECK ({col.name} >= 0)"

        col_defs.append(line)

    # --- Table-level constraints from the model's foreign_keys list ---
    extra_constraints: list[str] = []
    for fk in table.foreign_keys or []:
        if fk.column in cols_with_inline_fk:
            continue  # already handled via the column's inline REFERENCES
        if fk.column in _STD_USER_FK_COLS and fk.column in col_names_set:
            continue  # covered by the standard public."user" FK below

        ref_match = re.match(r"(\w+)\.(\w+)\((\w+)\)", fk.references)
        if not ref_match:
            # NOTE(review): unqualified or quoted targets are silently skipped.
            continue

        fk_clause = f"FOREIGN KEY ({fk.column}) REFERENCES {fk.references}"
        if fk.on_delete:
            fk_clause += f" ON DELETE {fk.on_delete}"
        if fk.on_update:
            fk_clause += f" ON UPDATE {fk.on_update}"

        if is_forward_ref(ref_match.group(2)):
            deferred.append(
                f"ALTER TABLE {sn}.{tname}\n"
                f"    ADD CONSTRAINT fk_{tname}_{fk.column} {fk_clause};"
            )
        else:
            extra_constraints.append(f"    CONSTRAINT fk_{tname}_{fk.column} {fk_clause}")

    # --- Auto-injected date-range CHECK constraints ---
    extra_constraints.extend(_build_table_check_constraints(table))

    # --- Standard named constraints ---
    std_pkey = (
        [f"    CONSTRAINT {tname}_pkey PRIMARY KEY (id)"] if "id" in col_names_set else []
    )
    std_user_fks = [
        f'    CONSTRAINT {tname}_{name}_fkey '
        f'FOREIGN KEY ({name}) REFERENCES public."user"(id)'
        for name in _STD_USER_FK_COLS
        if name in col_names_set
    ]

    # Order: columns, named PK, other FK/CHECK constraints, then user FKs.
    entries = col_defs + std_pkey + extra_constraints + std_user_fks
    body = ",\n".join(entries)
    sql = f"CREATE TABLE {sn}.{tname} (\n{body}\n);"

    return sql, deferred


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def render_sql(schema: MantaraSchema) -> str:
    """Render a complete PostgreSQL DDL script from a MantaraSchema object.

    Order of emission (Mantara v8 — guarantees inline FKs to cfg_*):
        1. CREATE SCHEMA
        2. menu + submenu core tables, with their seed rows
        3. ENUM types (legacy — cfg_enforcer normally clears these)
        4. Configuration lookup tables: every ``cfg_*`` table first, so that
           business-table columns referencing ``cfg_*`` keep their inline
           REFERENCES clause (no forward-ref deferral)
        5. Business tables in menu / submenu order, ``cfg_*`` skipped
           (already emitted); a blank line follows each non-empty submenu
        6. Deferred ``ALTER TABLE ... ADD CONSTRAINT`` statements for the
           remaining forward references between tables
        7. ``END OF SCHEMA`` footer banner

    NOTE(review): no header banner is emitted (``_render_header`` is not
    called here) — confirm that is intentional.
    """
    sn = schema.schema_name
    created_tables: set[str] = {"menu", "submenu"}
    all_table_names = _get_all_table_names(schema)
    all_deferred: list[str] = []

    parts: list[str] = [
        _render_create_schema(sn),
        _render_menu_table(schema),
        _render_submenu_table(schema),
    ]

    enum_sql = _render_enums(schema)
    if enum_sql:
        parts.append(enum_sql)

    # Partition tables: cfg_* in encounter order, business tables per submenu.
    cfg_tables: list[Table] = []
    biz_groups: list[list[Table]] = []
    for menu in sorted(schema.menus, key=lambda m: m.menu_id):
        for submenu in sorted(menu.submenus, key=lambda s: s.submenu_id):
            group: list[Table] = []
            for table in submenu.tables or []:
                if table.table_name.startswith("cfg_"):
                    cfg_tables.append(table)
                else:
                    group.append(table)
            if group:
                biz_groups.append(group)

    def emit(table: Table) -> None:
        # Render one table, collect its deferred FKs, and mark it as created so
        # later references to it stay inline.
        sql, deferred = _render_business_table(table, sn, created_tables, all_table_names)
        parts.append(sql)
        all_deferred.extend(deferred)
        created_tables.add(table.table_name)

    # cfg_* tables first, each followed by a blank line.
    for table in cfg_tables:
        emit(table)
        parts.append("")

    # Business tables; one blank line after each submenu group.
    for group in biz_groups:
        for table in group:
            emit(table)
        parts.append("")

    # Deferred foreign keys
    if all_deferred:
        parts.append("")
        for stmt in all_deferred:
            parts.extend((stmt, ""))

    # Footer
    parts.append(_section("END OF SCHEMA"))
    parts.append("")

    return "\n".join(parts)
