"""
Step 06 — Save DDL to Database

Reads a Postgres DDL (schema.sql) produced by Step 5, creates a dedicated
schema in the target database, and executes the DDL inside that schema.
Connects via a two-hop SSH tunnel: local → bastion → DB server → postgres.

Usage:
    python main.py                          # auto-detect latest run's schema.sql
"""

from __future__ import annotations

import logging
import os
import select
import socket
import sys
import threading
from contextlib import ExitStack, contextmanager
from pathlib import Path

import paramiko
from dotenv import load_dotenv
from sqlalchemy import create_engine, text
from sqlalchemy.engine import URL
# ── Config ───────────────────────────────────────────────────────────────────
# Two .env loads, neither of which overrides variables that are already set
# (python-dotenv's default is override=False). The first is whatever .env
# load_dotenv() discovers on its own; the second is the project-root .env.
# Values already in the process environment therefore win, then the first
# file found, then BASE_DIR/.env.
load_dotenv()

BASE_DIR = Path(__file__).resolve().parent.parent
load_dotenv(BASE_DIR / ".env")

# Integer settings use `getenv(...) or default` rather than a getenv default.
# That way a blank entry such as `DB_PORT=` in .env falls back to the default
# instead of crashing on int("") at import time.
DB_HOST     = os.getenv("DB_HOST")
DB_PORT     = int(os.getenv("DB_PORT") or "5432")
DB_NAME     = os.getenv("DB_NAME")
DB_USER     = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")

# Hop 1 — bastion / jump server (publicly reachable).
# Leaving BASTION_HOST unset (or blank) makes _tunnel() connect directly.
BASTION_HOST     = os.getenv("BASTION_HOST")
BASTION_PORT     = int(os.getenv("BASTION_PORT") or "22")
BASTION_USER     = os.getenv("BASTION_USER")
BASTION_KEY_PATH = os.getenv("BASTION_KEY_PATH")

# Hop 2 — DB server (reachable only from inside the bastion via SSH).
# A missing or blank DB_SSH_HOST falls back to DB_HOST.
DB_SSH_HOST     = os.getenv("DB_SSH_HOST") or DB_HOST
DB_SSH_PORT     = int(os.getenv("DB_SSH_PORT") or "22")
DB_SSH_USER     = os.getenv("DB_SSH_USER")
DB_SSH_KEY_PATH = os.getenv("DB_SSH_KEY_PATH")

# Directory searched by _latest_schema_sql() for */schema/schema.sql.
RUNS_ROOT = BASE_DIR / "step-01-input-ingestion" / "output" / "runs"

log = logging.getLogger(__name__)


# ── Helpers ───────────────────────────────────────────────────────────────────
def _latest_schema_sql() -> Path | None:
    candidates = list(RUNS_ROOT.glob("*/schema/schema.sql"))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


def _free_port() -> int:
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def _pipe(src: socket.socket | paramiko.Channel,
          dst: socket.socket | paramiko.Channel,
          stop: threading.Event) -> None:
    """Bidirectionally forward bytes between src and dst until one closes."""
    while not stop.is_set():
        try:
            r, _, _ = select.select([src, dst], [], [], 1.0)
        except Exception:
            break
        for readable in r:
            other = dst if readable is src else src
            try:
                data = readable.recv(4096)
            except Exception:
                data = b""
            if not data:
                stop.set()
                return
            try:
                other.sendall(data)
            except Exception:
                stop.set()
                return


@contextmanager
def _db_engine(local_host: str, local_port: int):
    """Yield a SQLAlchemy engine for ``DB_NAME`` at *local_host*:*local_port*.

    Credentials come from ``DB_USER`` / ``DB_PASSWORD``. The URL is built
    with ``URL.create`` instead of string formatting. That way special
    characters in the user name or password (``@``, ``:``, ``/``, ``%``,
    ``#`` …) are escaped rather than corrupting the URL. The engine's
    connection pool is disposed on exit, even if the body raises.

    Args:
        local_host: Host to connect to (the tunnel's local end, or DB_HOST).
        local_port: TCP port to connect to.

    Yields:
        A ``sqlalchemy.engine.Engine`` using the psycopg2 driver, with a
        10-second connect timeout.
    """
    url = URL.create(
        "postgresql+psycopg2",
        username=DB_USER,
        password=DB_PASSWORD,
        host=local_host,
        port=local_port,
        database=DB_NAME,
    )
    engine = create_engine(
        url,
        future=True,
        connect_args={"connect_timeout": 10},
    )
    try:
        yield engine
    finally:
        engine.dispose()


@contextmanager
def _tunnel():
    """Yield a ``(host, port)`` pair through which Postgres is reachable.

    If ``BASTION_HOST`` is not set, this yields ``(DB_HOST, DB_PORT)``
    unchanged and opens no SSH connections. Otherwise it builds a two-hop
    tunnel using pure paramiko (compatible with paramiko 4+):

    Hop 1   — paramiko SSH → BASTION_HOST:BASTION_PORT
    Hop 2   — direct-tcpip channel bastion → DB_SSH_HOST:DB_SSH_PORT; a second
              paramiko client SSHes into the DB server over that channel
    Forward — a local listener on 127.0.0.1:<OS-chosen port>. For every
              accepted connection, a direct-tcpip channel DB server →
              DB_HOST:DB_PORT is opened and bridged to it by ``_pipe``.

    Everything acquired (SSH clients, proxy channel, listening socket,
    forwarding threads) is released on exit. This also happens when setup
    fails part-way, so the original error is never masked by cleanup.

    Yields:
        ``(host, port)`` to hand to the database driver.

    Raises:
        ValueError: if the bastion is configured but a key path is missing.
        Exceptions from paramiko / socket if either SSH hop fails.
    """
    if not BASTION_HOST:
        log.info("BASTION_HOST not set — connecting directly to %s:%s", DB_HOST, DB_PORT)
        yield DB_HOST, DB_PORT
        return

    if not BASTION_KEY_PATH or not DB_SSH_KEY_PATH:
        raise ValueError("Set BASTION_KEY_PATH and DB_SSH_KEY_PATH in .env.")

    # ExitStack runs the registered cleanups in reverse order. Only what was
    # actually set up gets torn down, whichever step failed.
    with ExitStack() as stack:
        # NOTE(security): AutoAddPolicy accepts any unknown host key, so
        # neither hop verifies the server's identity (MITM risk). Consider
        # load_system_host_keys() + RejectPolicy for production use.

        # ── Hop 1: SSH into bastion ───────────────────────────────────────────
        log.info("Hop 1 — SSH  local → %s@%s:%s", BASTION_USER, BASTION_HOST, BASTION_PORT)
        bastion = paramiko.SSHClient()
        stack.callback(bastion.close)
        bastion.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        bastion.connect(
            BASTION_HOST,
            port=BASTION_PORT,
            username=BASTION_USER,
            key_filename=os.path.expanduser(BASTION_KEY_PATH),
            timeout=10,
        )
        log.info("Hop 1 — bastion connected")

        # ── Hop 2: SSH into DB server through the bastion ─────────────────────
        log.info("Hop 2 — SSH  bastion → %s@%s:%s", DB_SSH_USER, DB_SSH_HOST, DB_SSH_PORT)
        proxy = bastion.get_transport().open_channel(
            "direct-tcpip",
            (DB_SSH_HOST, DB_SSH_PORT),
            ("127.0.0.1", 0),
        )
        stack.callback(proxy.close)
        db_server = paramiko.SSHClient()
        stack.callback(db_server.close)
        db_server.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        db_server.connect(
            DB_SSH_HOST,
            port=DB_SSH_PORT,
            username=DB_SSH_USER,
            key_filename=os.path.expanduser(DB_SSH_KEY_PATH),
            sock=proxy,
            timeout=10,
        )
        log.info("Hop 2 — DB server connected")

        # ── Forward: local TCP → DB server channel → DB_HOST:DB_PORT ─────────
        transport = db_server.get_transport()
        stop_event = threading.Event()
        conn_stops: list[threading.Event] = []

        server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        stack.callback(server_sock.close)
        server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Bind to port 0 so the OS picks a free port atomically. This avoids
        # the race between choosing a port and binding it.
        server_sock.bind(("127.0.0.1", 0))
        local_port = server_sock.getsockname()[1]
        server_sock.listen(5)
        # accept() wakes every second so stop_event is honoured.
        server_sock.settimeout(1.0)

        def _accept_loop() -> None:
            while not stop_event.is_set():
                try:
                    client_sock, _ = server_sock.accept()
                except socket.timeout:
                    continue
                except OSError:
                    break  # listening socket closed during teardown
                try:
                    channel = transport.open_channel(
                        "direct-tcpip",
                        (DB_HOST, DB_PORT),
                        ("127.0.0.1", local_port),
                    )
                except Exception as e:
                    log.warning("Forward channel failed: %s", e)
                    client_sock.close()
                    continue
                # Per-connection event: one connection ending must not stop
                # the accept loop or the other connections.
                per_conn_stop = threading.Event()
                conn_stops.append(per_conn_stop)
                threading.Thread(
                    target=_pipe,
                    args=(client_sock, channel, per_conn_stop),
                    daemon=True,
                ).start()

        def _stop_forwarding() -> None:
            # Stop accepting and signal every live bridge to wind down.
            stop_event.set()
            for conn_stop in conn_stops:
                conn_stop.set()

        # Registered last so it runs first on exit, before the listening
        # socket and SSH clients are closed.
        stack.callback(_stop_forwarding)
        threading.Thread(target=_accept_loop, daemon=True).start()

        log.info(
            "Tunnel ready — 127.0.0.1:%s → bastion → DB server → %s:%s",
            local_port, DB_HOST, DB_PORT,
        )
        yield "127.0.0.1", local_port


# ── Main ──────────────────────────────────────────────────────────────────────

def main(argv: list[str] | None = None) -> int:
    """Execute a schema.sql DDL file against the target Postgres database.

    The DDL file is chosen in this order of precedence:

    1. The first positional command-line argument.
    2. The ``STEP06_SCHEMA_SQL`` environment variable.
    3. The built-in default path below.

    The whole file is executed in one transaction (``engine.begin()`` commits
    on success and rolls back on error). It goes through the SSH tunnel when
    one is configured.

    Args:
        argv: Command-line arguments without the program name. Defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit code. 0 on success; 1 if the file is missing, unreadable
        or empty, or if connecting/executing fails.
    """
    logging.basicConfig(
        level=os.getenv("STEP06_LOG_LEVEL", "INFO"),
        format="%(asctime)s %(levelname)-7s %(message)s",
    )

    # NOTE(review): auto-detection via _latest_schema_sql() was disabled in
    # favour of this fixed, machine-specific path. It is kept only as the
    # last-resort default so existing invocations behave as before.
    default_sql = (
        "/home/ubuntu/dpg/pipeline/step-02-prd-generation/output/"
        "20260507_102457/schema.sql"
    )
    args = sys.argv[1:] if argv is None else argv
    raw_path = args[0] if args else (os.getenv("STEP06_SCHEMA_SQL") or default_sql)
    sql_path = Path(raw_path).expanduser()

    # is_file() rather than exists(): a directory would otherwise pass and
    # then crash in read_text().
    if not sql_path.is_file():
        log.error("No schema.sql found at %s. Run Step 5 first.", sql_path)
        return 1

    log.info("DDL file: %s", sql_path)

    try:
        ddl = sql_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        log.error("Cannot read %s: %s", sql_path, exc)
        return 1

    if not ddl.strip():
        log.error("DDL file %s is empty — nothing to execute.", sql_path)
        return 1

    try:
        with _tunnel() as (local_host, local_port):
            with _db_engine(local_host, local_port) as engine:
                with engine.begin() as conn:
                    # NOTE(review): text() treats `:name` tokens as bind
                    # parameters. A DDL containing one (e.g. inside a string
                    # literal) would fail here with a missing-parameter error.
                    conn.execute(text(ddl))
    except Exception as exc:
        # Top-level boundary: keep the traceback and report via exit code.
        log.exception("Failed: %s", exc)
        return 1

    log.info("Done — schema created successfully.")
    return 0


if __name__ == "__main__":
    # main() returns the process exit status; SystemExit propagates it.
    raise SystemExit(main())
