"""
Step 07 — Save DDL to Database

Reads a Postgres DDL (schema.sql) produced by Step 5, creates a dedicated
schema in the target database, and executes the DDL inside that schema.
Connects via a two-hop SSH tunnel: local → bastion → DB server → postgres.

Usage:
    python main.py                                   # auto-detect latest run's schema.sql
    python main.py --schema-path /path/to/schema.sql # explicit path
"""

from __future__ import annotations

import argparse
from typing import Any
import logging
import os
import re
import select
import socket
import sys
import threading
from contextlib import contextmanager
from pathlib import Path

import paramiko
import requests
from dotenv import load_dotenv
from sqlalchemy import create_engine, text

# ── Config ───────────────────────────────────────────────────────────────────

# Project root (this file lives one level below it); the .env sits at the root.
BASE_DIR = Path(__file__).resolve().parent.parent
load_dotenv(BASE_DIR / ".env")

# Hop 1 — bastion / jump server (publicly reachable)
BASTION_HOST     = os.getenv("BASTION_HOST")  # unset → _tunnel connects directly, no SSH hops
BASTION_PORT     = int(os.getenv("BASTION_PORT", "22"))
BASTION_USER     = os.getenv("BASTION_USER")
BASTION_KEY_PATH = os.getenv("BASTION_KEY_PATH")  # private key file for hop 1

# Hop 2 — DB server (reachable only from inside the bastion via SSH)
DB_SSH_HOST     = os.getenv("DB_SSH_HOST") or os.getenv("DB_HOST", "")  # falls back to DB_HOST
DB_SSH_PORT     = int(os.getenv("DB_SSH_PORT", "22"))
DB_SSH_USER     = os.getenv("DB_SSH_USER")
DB_SSH_KEY_PATH = os.getenv("DB_SSH_KEY_PATH")  # private key file for hop 2

# Step-5 output location: runs/<run_id>/schema/schema.sql (see _latest_schema_sql).
RUNS_ROOT = BASE_DIR / "runs"

log = logging.getLogger(__name__)


# ── Helpers ───────────────────────────────────────────────────────────────────
def _latest_schema_sql() -> Path | None:
    """Find the most recently modified ``runs/*/schema/schema.sql``.

    Returns the newest matching path by file mtime, or None when no run
    has produced a schema.sql yet.
    """
    newest: Path | None = None
    newest_mtime = float("-inf")
    for candidate in RUNS_ROOT.glob("*/schema/schema.sql"):
        mtime = candidate.stat().st_mtime
        if mtime > newest_mtime:
            newest, newest_mtime = candidate, mtime
    return newest


def _free_port() -> int:
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def _pipe(src: socket.socket | paramiko.Channel,
          dst: socket.socket | paramiko.Channel,
          stop: threading.Event) -> None:
    """Bidirectionally forward bytes between src and dst until one closes."""
    while not stop.is_set():
        try:
            r, _, _ = select.select([src, dst], [], [], 1.0)
        except Exception:
            break
        for readable in r:
            other = dst if readable is src else src
            try:
                data = readable.recv(4096)
            except Exception:
                data = b""
            if not data:
                stop.set()
                return
            try:
                other.sendall(data)
            except Exception:
                stop.set()
                return


def _resolve_creds(db_info: dict | None) -> tuple[str, int, str, str, str]:
    """Return (host, port, user, password, dbname) from db_info or .env fallback."""
    if db_info:
        return (
            db_info["host"],
            int(db_info.get("port", 5432)),
            db_info["username"],
            db_info["password"],
            db_info["database"],
        )
    return (
        os.environ["DB_HOST"],
        int(os.getenv("DB_PORT", "5432")),
        os.environ["DB_USER"],
        os.environ["DB_PASSWORD"],
        os.environ["DB_NAME"],
    )


@contextmanager
def _db_engine(local_host: str, local_port: int, *, db_user: str, db_password: str, db_name: str):
    """Yield a SQLAlchemy engine for Postgres at *local_host*:*local_port*.

    The host/port pair is usually the local end of the SSH tunnel.  The
    engine is disposed (all pooled connections closed) when the context
    exits, success or failure.
    """
    dsn = (
        f"postgresql+psycopg2://{db_user}:{db_password}"
        f"@{local_host}:{local_port}/{db_name}"
    )
    engine = create_engine(dsn, future=True, connect_args={"connect_timeout": 10})
    try:
        yield engine
    finally:
        engine.dispose()


@contextmanager
def _tunnel(db_host: str, db_port: int):
    """Two-hop SSH tunnel using pure paramiko (compatible with paramiko 4+).

    Hop 1 — paramiko SSH → BASTION_HOST
    Hop 2 — direct-tcpip channel bastion → DB_SSH_HOST:DB_SSH_PORT,
             second paramiko client SSHes into DB server over that channel
    Forward — direct-tcpip channel DB server → db_host:db_port,
              a local TCP accept loop bridges it to a free local port.

    Yields the (host, port) a Postgres client should connect to: the local
    tunnel endpoint, or (db_host, db_port) unchanged when BASTION_HOST is
    unset.  Raises ValueError if a bastion is configured without key paths.
    """
    if not BASTION_HOST:
        # No bastion configured — assume the DB is directly reachable.
        log.info("BASTION_HOST not set — connecting directly to %s:%s", db_host, db_port)
        yield db_host, db_port
        return

    if not BASTION_KEY_PATH or not DB_SSH_KEY_PATH:
        raise ValueError("Set BASTION_KEY_PATH and DB_SSH_KEY_PATH in .env.")

    # ── Hop 1: SSH into bastion ───────────────────────────────────────────────
    log.info("Hop 1 — SSH  local → %s@%s:%s", BASTION_USER, BASTION_HOST, BASTION_PORT)
    bastion = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy skips host-key verification (MITM exposure);
    # presumably acceptable for automation against known hosts — confirm.
    bastion.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    bastion.connect(
        BASTION_HOST,
        port=BASTION_PORT,
        username=BASTION_USER,
        key_filename=os.path.expanduser(BASTION_KEY_PATH),
        timeout=10,
    )
    log.info("Hop 1 — bastion connected")

    try:
        # ── Hop 2: SSH into DB server through the bastion ─────────────────────
        log.info("Hop 2 — SSH  bastion → %s@%s:%s", DB_SSH_USER, DB_SSH_HOST, DB_SSH_PORT)
        # The direct-tcpip channel serves as the raw TCP transport for the
        # second SSH session (the source-address tuple is informational only).
        proxy = bastion.get_transport().open_channel(
            "direct-tcpip",
            (DB_SSH_HOST, DB_SSH_PORT),
            ("127.0.0.1", 0),
        )
        db_server = paramiko.SSHClient()
        db_server.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        db_server.connect(
            DB_SSH_HOST,
            port=DB_SSH_PORT,
            username=DB_SSH_USER,
            key_filename=os.path.expanduser(DB_SSH_KEY_PATH),
            sock=proxy,  # run the hop-2 SSH session over the hop-1 channel
            timeout=10,
        )
        log.info("Hop 2 — DB server connected")

        try:
            # ── Forward: local TCP → DB server channel → db_host:db_port ─────
            # Small race: the port is free at probe time but only bound below.
            local_port = _free_port()
            transport  = db_server.get_transport()
            stop_event = threading.Event()

            server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server_sock.bind(("127.0.0.1", local_port))
            server_sock.listen(5)
            server_sock.settimeout(1.0)  # lets the accept loop poll stop_event

            def _accept_loop() -> None:
                # Accept local client connections; bridge each one to a fresh
                # direct-tcpip channel opened on the DB server's transport.
                while not stop_event.is_set():
                    try:
                        client_sock, _ = server_sock.accept()
                    except socket.timeout:
                        continue  # periodic wake-up to re-check stop_event
                    except Exception:
                        break  # server_sock was closed — shut the loop down
                    try:
                        channel = transport.open_channel(
                            "direct-tcpip",
                            (db_host, db_port),
                            ("127.0.0.1", local_port),
                        )
                    except Exception as e:
                        log.warning("Forward channel failed: %s", e)
                        client_sock.close()
                        continue
                    # Per-connection event: only _pipe itself sets it (EOF/error);
                    # tunnel teardown does not signal in-flight pipes.
                    per_conn_stop = threading.Event()
                    threading.Thread(
                        target=_pipe,
                        args=(client_sock, channel, per_conn_stop),
                        daemon=True,  # never blocks interpreter exit
                    ).start()

            threading.Thread(target=_accept_loop, daemon=True).start()

            log.info(
                "Tunnel ready — 127.0.0.1:%s → bastion → DB server → %s:%s",
                local_port, db_host, db_port,
            )
            yield "127.0.0.1", local_port

        finally:
            # Teardown in reverse order; closing server_sock also unblocks
            # the accept loop's accept() call.
            stop_event.set()
            server_sock.close()
            db_server.close()
    finally:
        bastion.close()


# ── Public API ────────────────────────────────────────────────────────────────

def _extract_table_names(ddl: str) -> list[str]:
    """Return table names from CREATE TABLE statements in DDL order (respects FK deps)."""
    return re.findall(r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:"\w+"\.)?["\']?(\w+)["\']?', ddl, re.IGNORECASE)


def _deploy_ddl_to_schema(conn: Any, ddl: str, schema_name: str, drop_first: bool = False) -> None:
    """Run *ddl* inside *schema_name* on an already-open connection.

    Creates the schema if needed and points search_path at it so unqualified
    statements land there.  When *drop_first* is true, every table named in
    the DDL is dropped first (reverse order, CASCADE) so stale definitions
    cannot survive a redeploy.
    """
    conn.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
    conn.execute(text(f'SET search_path TO "{schema_name}", public'))

    if drop_first:
        # Reverse of CREATE order ≈ dependents dropped before dependencies.
        for table_name in reversed(_extract_table_names(ddl)):
            conn.execute(text(f'DROP TABLE IF EXISTS "{table_name}" CASCADE'))

    # Naive split on ';' — adequate for generated DDL that never embeds
    # semicolons inside string literals or function bodies.
    for statement in (chunk.strip() for chunk in ddl.split(";")):
        if statement:
            conn.execute(text(statement))


def save_ddl_to_db(sql_path: Path, schema_name: str | None = None, db_info: dict | None = None) -> bool:
    """Execute the DDL at *sql_path* against the configured database.

    Deploys to two places:
    1. A run-specific schema (ai_sch_<run_id>) for isolation / audit trail.
    2. The public schema — with DROP+CASCADE of existing tables so the
       generated backend always finds a schema that matches the current DDL.
       This prevents stale-table mismatches when the same DB is reused across runs.

    Credentials come from *db_info* (domain-info API response) or .env fallback.
    Returns True on success, False on any error.
    """
    db_host, db_port, db_user, db_password, db_name = _resolve_creds(db_info)
    ddl = sql_path.read_text(encoding="utf-8")

    # Derive the run-specific schema name.  The standard layout (see
    # _latest_schema_sql) is runs/<run_id>/schema/schema.sql, so the DDL's
    # immediate parent is the literal directory "schema" — the original
    # used it verbatim, collapsing every run into one "ai_sch_schema"
    # schema and defeating the per-run isolation documented above.  Step
    # over that directory to reach the run id; non-standard paths keep the
    # old parent-name behavior.
    if schema_name is None:
        parent = sql_path.parent
        run_dir = parent.parent if parent.name == "schema" else parent
        schema_name = f"ai_sch_{run_dir.name}"

    # Strip comment lines and CREATE SCHEMA lines — we create the schema ourselves.
    cleaned_lines = [
        line for line in ddl.splitlines()
        if not line.strip().startswith("--")
        and not re.match(r"^\s*CREATE\s+SCHEMA\b", line, re.IGNORECASE)
    ]
    ddl = "\n".join(cleaned_lines)

    # Remove any schema prefixes the LLM may have added so statements are schema-agnostic.
    # NOTE(review): this strips ANY identifier followed by a dot — it would
    # also mangle intentionally-qualified names like pg_catalog.func;
    # confirm the generated DDL never relies on qualified references.
    ddl = re.sub(r'\b[a-zA-Z_]\w*\.(?=[a-zA-Z_])', '', ddl)

    # Make CREATE TABLE / INDEX idempotent for the isolated schema copy.
    isolated_ddl = re.sub(r'\bCREATE TABLE\b', 'CREATE TABLE IF NOT EXISTS', ddl, flags=re.IGNORECASE)
    isolated_ddl = re.sub(r'\bCREATE UNIQUE INDEX\b', 'CREATE UNIQUE INDEX IF NOT EXISTS', isolated_ddl, flags=re.IGNORECASE)
    isolated_ddl = re.sub(r'\bCREATE INDEX\b', 'CREATE INDEX IF NOT EXISTS', isolated_ddl, flags=re.IGNORECASE)

    try:
        with _tunnel(db_host, db_port) as (local_host, local_port):
            with _db_engine(local_host, local_port, db_user=db_user, db_password=db_password, db_name=db_name) as engine:
                with engine.begin() as conn:
                    # 1. Isolated run schema — idempotent, never destructive.
                    _deploy_ddl_to_schema(conn, isolated_ddl, schema_name, drop_first=False)

                # Verify the isolated schema was actually populated with tables.
                # If the transaction rolled back silently (e.g. FK ordering issue,
                # type conflict) the schema would be empty and the backend would
                # silently fall through to public with a different schema.
                with engine.connect() as conn:
                    row = conn.execute(
                        text(
                            "SELECT COUNT(*) FROM information_schema.tables "
                            "WHERE table_schema = :s"
                        ),
                        {"s": schema_name},
                    ).fetchone()
                    table_count = row[0] if row else 0
                if table_count == 0:
                    raise RuntimeError(
                        f"Isolated schema '{schema_name}' was created but contains no tables. "
                        "The DDL transaction likely rolled back. "
                        "Check the DDL for FK ordering issues or type conflicts."
                    )
                log.info("Verified: schema '%s' has %d table(s).", schema_name, table_count)

                with engine.begin() as conn:
                    # 2. Public schema — drop stale tables so the backend always matches.
                    _deploy_ddl_to_schema(conn, ddl, "public", drop_first=True)
        log.info("Done — DDL from %s applied to schema %s and public.", sql_path, schema_name)
        return True
    except Exception as exc:
        log.error("Failed: %s", exc)
        return False


def save_dalfin_to_public(dalfin_payload: str, domain_url: str, db_info: dict | None = None) -> bool:
    """POST the compiled dalfin JSON to the system-builder import API.

    The *domain_url* is sent as the HTTP Origin header so the backend can
    route the payload to the correct tenant's public schema.  The target
    endpoint comes from DALFIN_IMPORT_API_URL; when it is unset the import
    is skipped.  Returns True on success, False on any error.
    """
    api_url = os.environ.get("DALFIN_IMPORT_API_URL", "").strip()
    if not api_url:
        log.warning("DALFIN_IMPORT_API_URL not set — skipping dalfin import.")
        return False

    try:
        resp = requests.post(
            api_url,
            data=dalfin_payload,
            headers={
                "Content-Type": "application/json",
                "Origin": domain_url,  # backend routes tenants by Origin
            },
            timeout=30,
        )
        if resp.status_code in (200, 201):
            log.info("Dalfin JSON saved to public schema via %s", api_url)
            return True
        log.error("Dalfin import failed (%s): %s", resp.status_code, resp.text[:200])
        return False
    except Exception as exc:
        log.error("Dalfin import request error: %s", exc)
        return False


# ── CLI entry point ───────────────────────────────────────────────────────────

def main() -> int:
    """CLI entry point: locate a schema.sql and apply it to the database.

    Returns a process exit code: 0 on success, 1 when no DDL file can be
    found or the deployment fails.
    """
    logging.basicConfig(
        level=os.getenv("STEP07_LOG_LEVEL", "INFO"),
        format="%(asctime)s %(levelname)-7s %(message)s",
    )

    parser = argparse.ArgumentParser(description="Step 07 — Save DDL to database")
    parser.add_argument(
        "--schema-path",
        type=Path,
        default=None,
        help="explicit path to schema.sql (default: auto-detect latest)",
    )
    opts = parser.parse_args()

    if not opts.schema_path:
        # Auto-detect: newest schema.sql produced by a previous Step 5 run.
        sql_path = _latest_schema_sql()
        if sql_path is None:
            log.error("No schema.sql found under %s. Run Step 5 first.", RUNS_ROOT)
            return 1
    else:
        sql_path = opts.schema_path.resolve()
        if not sql_path.exists():
            log.error("schema.sql not found at %s", sql_path)
            return 1
    log.info("DDL file: %s", sql_path)

    return 0 if save_ddl_to_db(sql_path) else 1


if __name__ == "__main__":
    sys.exit(main())
