"""Tests for CHECK constraint enricher."""
from __future__ import annotations

import copy

from pipeline.constraint_enricher import enforce_constraints


def _table(table_name="orders", columns=None):
    return {
        "schema_name": "demo",
        "menus": [{
            "menu_id": 1, "menu_name": "Ops", "sequence_number": 1,
            "description": "ops",
            "submenus": [{
                "submenu_id": 101, "submenu_name": "Orders",
                "sequence_number": 1, "description": "x",
                "tables": [{
                    "table_name": table_name,
                    "comment": "test",
                    "columns": columns or [],
                    "foreign_keys": [],
                }],
            }],
        }],
    }


def test_adds_non_negative_check_for_amount():
    """A NUMERIC ``total_amount`` column should receive a ``>= 0`` CHECK."""
    pk_col = {"name": "id", "type": "SERIAL", "constraints": "PRIMARY KEY"}
    amount_col = {"name": "total_amount", "type": "NUMERIC(10,2)"}
    schema = _table(columns=[pk_col, amount_col])

    enforce_constraints(schema)

    # Read the column back through the schema rather than via the local dict.
    table = schema["menus"][0]["submenus"][0]["tables"][0]
    enriched = table["columns"][1]
    assert "CHECK (total_amount >= 0)" in enriched["constraints"]


def test_adds_non_negative_check_for_qty():
    """An INT ``qty`` column should receive a ``>= 0`` CHECK."""
    schema = _table(columns=[{"name": "qty", "type": "INT"}])

    enforce_constraints(schema)

    table = schema["menus"][0]["submenus"][0]["tables"][0]
    qty_col = table["columns"][0]
    assert "CHECK (qty >= 0)" in qty_col["constraints"]


def test_does_not_add_to_non_numeric():
    """A column named ``amount`` but typed VARCHAR must not get a CHECK."""
    # Unusual pairing of name and type, but a schema may contain it.
    text_amount = {"name": "amount", "type": "VARCHAR(20)"}
    schema = _table(columns=[text_amount])

    enforce_constraints(schema)

    table = schema["menus"][0]["submenus"][0]["tables"][0]
    result_col = table["columns"][0]
    # "constraints" may be absent or None, so fall back to an empty string.
    constraint_text = result_col.get("constraints") or ""
    assert "CHECK" not in constraint_text


def test_adds_percentage_check():
    """A NUMERIC ``discount_pct`` column should be bounded to 0..100."""
    schema = _table(columns=[{"name": "discount_pct", "type": "NUMERIC(5,2)"}])

    enforce_constraints(schema)

    table = schema["menus"][0]["submenus"][0]["tables"][0]
    pct_col = table["columns"][0]
    assert "BETWEEN 0 AND 100" in pct_col["constraints"]


def test_adds_email_check():
    """A VARCHAR ``email`` column should get a ``LIKE '%@%'`` CHECK."""
    schema = _table(columns=[{"name": "email", "type": "VARCHAR(255)"}])

    enforce_constraints(schema)

    table = schema["menus"][0]["submenus"][0]["tables"][0]
    email_col = table["columns"][0]
    assert "CHECK (email LIKE '%@%')" in email_col["constraints"]


def test_adds_date_pair_check():
    """With ``start_date`` and ``end_date`` present, ``end_date`` gets an ordering CHECK."""
    column_specs = [
        {"name": "id", "type": "SERIAL"},
        {"name": "start_date", "type": "DATE"},
        {"name": "end_date", "type": "DATE"},
    ]
    schema = _table(columns=column_specs)

    enforce_constraints(schema)

    table = schema["menus"][0]["submenus"][0]["tables"][0]
    # The assertion targets the end_date column (index 2).
    end_date_col = table["columns"][2]
    assert "CHECK (end_date >= start_date)" in end_date_col["constraints"]


def test_idempotent():
    """Running the enricher twice must leave the columns as the first run left them."""
    schema = _table(columns=[
        {"name": "qty", "type": "INT"},
        {"name": "discount_pct", "type": "NUMERIC(5,2)"},
    ])

    enforce_constraints(schema)
    after_first_run = copy.deepcopy(schema)
    enforce_constraints(schema)

    # Only the columns are compared; other top-level entries are deliberately
    # left out (presumably the enricher's stats are rewritten on each call --
    # confirm against the enricher).
    columns_now = schema["menus"][0]["submenus"][0]["tables"][0]["columns"]
    columns_before = after_first_run["menus"][0]["submenus"][0]["tables"][0]["columns"]
    assert columns_now == columns_before


def test_skips_existing_check():
    """A column that already has a CHECK keeps it and gains no second one."""
    schema = _table(columns=[
        {"name": "qty", "type": "INT", "constraints": "NOT NULL CHECK (qty > 0)"},
    ])

    enforce_constraints(schema)

    table = schema["menus"][0]["submenus"][0]["tables"][0]
    constraint_text = table["columns"][0]["constraints"]
    # Exactly one CHECK: no additional ">= 0" check may be appended.
    assert constraint_text.count("CHECK") == 1
    # Counting alone would still pass if the enricher swapped the existing
    # check for its own, so also require the original "> 0" check verbatim.
    assert "CHECK (qty > 0)" in constraint_text


def test_skips_audit_columns():
    """created_at, updated_at and is_active must not receive CHECK constraints."""
    audit_columns = [
        {"name": "created_at", "type": "TIMESTAMP"},
        {"name": "updated_at", "type": "TIMESTAMP"},
        {"name": "is_active", "type": "BOOLEAN"},
    ]
    schema = _table(columns=audit_columns)

    enforce_constraints(schema)

    table = schema["menus"][0]["submenus"][0]["tables"][0]
    for column in table["columns"]:
        # "constraints" may be missing or None; treat both as empty text.
        constraint_text = column.get("constraints") or ""
        assert "CHECK" not in constraint_text


def test_stats_recorded():
    """The enricher records per-category counts under ``_constraint_enricher``."""
    name_type_pairs = [
        ("qty", "INT"),
        ("amount", "NUMERIC(10,2)"),
        ("discount_pct", "NUMERIC(5,2)"),
        ("email", "VARCHAR(255)"),
        ("start_date", "DATE"),
        ("end_date", "DATE"),
    ]
    schema = _table(columns=[
        {"name": col_name, "type": col_type}
        for col_name, col_type in name_type_pairs
    ])

    enforce_constraints(schema)

    recorded = schema["_constraint_enricher"]
    # Expected counters, checked in the same order as listed here.
    expected_counts = {
        "non_negative_added": 2,
        "percentage_added": 1,
        "email_added": 1,
        "date_pair_added": 1,
    }
    for counter_name, expected in expected_counts.items():
        assert recorded[counter_name] == expected
