"""Tests for audit column enforcer."""
from __future__ import annotations

import copy

from pipeline.audit_column_enforcer import enforce_audit_columns


def _schema(tables):
    return {
        "schema_name": "demo",
        "menus": [{
            "menu_id": 1, "menu_name": "Ops", "sequence_number": 1,
            "description": "ops",
            "submenus": [{
                "submenu_id": 101, "submenu_name": "Things",
                "sequence_number": 1, "description": "x",
                "tables": tables,
            }],
        }],
    }


def test_adds_three_audit_columns():
    """A plain business table gains created_at, updated_at and is_active."""
    orders = {
        "table_name": "orders",
        "comment": "test",
        "columns": [{"name": "id", "type": "SERIAL", "constraints": "PRIMARY KEY"}],
        "foreign_keys": [],
    }
    schema = _schema([orders])

    enforce_audit_columns(schema)

    table = schema["menus"][0]["submenus"][0]["tables"][0]
    column_names = {col["name"] for col in table["columns"]}
    for expected in ("created_at", "updated_at", "is_active"):
        assert expected in column_names


def test_does_not_duplicate_existing():
    """An audit column that already exists is not appended a second time."""
    existing_columns = [
        {"name": "id", "type": "SERIAL"},
        {"name": "created_at", "type": "TIMESTAMP"},
    ]
    schema = _schema([{
        "table_name": "orders",
        "comment": "test",
        "columns": existing_columns,
        "foreign_keys": [],
    }])

    enforce_audit_columns(schema)

    table = schema["menus"][0]["submenus"][0]["tables"][0]
    names = [col["name"] for col in table["columns"]]
    assert names.count("created_at") == 1


def test_skips_cfg_tables():
    """Tables prefixed ``cfg_`` keep exactly their original columns."""
    lookup_table = {
        "table_name": "cfg_status",
        "comment": "lookup",
        "columns": [{"name": "id", "type": "SERIAL"}],
        "foreign_keys": [],
    }
    schema = _schema([lookup_table])

    enforce_audit_columns(schema)

    columns = schema["menus"][0]["submenus"][0]["tables"][0]["columns"]
    remaining = {col["name"] for col in columns}
    assert remaining == {"id"}, "cfg_* should not get audit cols injected here"


def test_skips_menu_submenu():
    """The ``menu`` and ``submenu`` tables receive none of the audit columns.

    Checks all three audit column names (not just ``created_at``) so a
    partial injection is caught, and guards against the loop passing
    vacuously if no tables are present.
    """
    s = _schema([
        {"table_name": "menu", "comment": "x", "columns": [{"name": "id", "type": "SERIAL"}], "foreign_keys": []},
        {"table_name": "submenu", "comment": "x", "columns": [{"name": "id", "type": "SERIAL"}], "foreign_keys": []},
    ])
    enforce_audit_columns(s)
    tables = s["menus"][0]["submenus"][0]["tables"]
    # Without this guard an emptied tables list would make the loop a no-op pass.
    assert len(tables) == 2
    audit_names = {"created_at", "updated_at", "is_active"}
    for tbl in tables:
        names = {c["name"] for c in tbl["columns"]}
        injected = names & audit_names
        assert not injected, f"{tbl['table_name']} got audit columns: {sorted(injected)}"


def test_idempotent():
    """A second enforcer pass leaves the columns exactly as the first pass did."""
    def columns_of(doc):
        return doc["menus"][0]["submenus"][0]["tables"][0]["columns"]

    schema = _schema([{
        "table_name": "orders",
        "comment": "test",
        "columns": [{"name": "id", "type": "SERIAL"}],
        "foreign_keys": [],
    }])

    enforce_audit_columns(schema)
    after_first_pass = copy.deepcopy(columns_of(schema))
    enforce_audit_columns(schema)

    assert columns_of(schema) == after_first_pass


def test_stats_recorded():
    """The enforcer stores per-column addition counts on the schema."""
    orders = {
        "table_name": "orders",
        "comment": "x",
        "columns": [{"name": "id", "type": "SERIAL"}],
        "foreign_keys": [],
    }
    items = {
        "table_name": "items",
        "comment": "x",
        "columns": [
            {"name": "id", "type": "SERIAL"},
            {"name": "is_active", "type": "BOOLEAN"},
        ],
        "foreign_keys": [],
    }
    schema = _schema([orders, items])

    enforce_audit_columns(schema)

    stats = schema["_audit_column_enforcer"]
    expected_counts = {
        "tables_processed": 2,
        "created_at_added": 2,
        "updated_at_added": 2,
        "is_active_added": 1,  # items already had it
    }
    for key, value in expected_counts.items():
        assert stats[key] == value
