"""Tests for the business rule validator."""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from models import (
    MantaraSchema, Menu, Submenu, Table, Column, ForeignKey, EnumType,
)
from business_validator import (
    validate_all,
    validate_naming,
    validate_submenu_ids,
    validate_table_structure,
    validate_ref_consistency,
    validate_referenced_tables,
    validate_postgresql_syntax,
    validate_numeric_types,
    validate_check_constraints,
    validate_schema_compliance,
)


class TestValidSchema:
    """The reference AMS fixture schema must satisfy every business rule at once."""

    def test_all_pass(self, ams_schema):
        # validate_all aggregates every individual check into one report dict.
        outcome = validate_all(ams_schema)
        problems = outcome["errors"]
        assert outcome["is_valid"], f"Expected valid, got errors: {problems}"
        assert problems == []


class TestValidateNaming:
    """Naming rules: snake_case identifiers and an ``_enum`` suffix on enum type names."""

    def test_valid_names(self, ams_schema):
        assert validate_naming(ams_schema) == []

    def test_bad_schema_name(self, ams_schema):
        """An upper-case schema name is reported as not snake_case."""
        ams_schema.schema_name = "AMS"
        found = validate_naming(ams_schema)
        assert any("schema_name" in msg and "snake_case" in msg for msg in found)

    def test_bad_table_name(self, ams_schema):
        """A PascalCase table name is reported by name."""
        first_table = ams_schema.menus[0].submenus[0].tables[0]
        first_table.table_name = "FarmOverview"
        found = validate_naming(ams_schema)
        assert any("FarmOverview" in msg for msg in found)

    def test_bad_column_name(self, ams_schema):
        """A camelCase column name is reported by name."""
        first_table = ams_schema.menus[0].submenus[0].tables[0]
        first_table.columns[2].name = "farmId"
        found = validate_naming(ams_schema)
        assert any("farmId" in msg for msg in found)

    def test_enum_missing_suffix(self, ams_schema):
        """An enum type whose name lacks ``_enum`` is flagged."""
        first_enum = ams_schema.enum_types[0]
        first_enum.type_name = "ams.season"
        found = validate_naming(ams_schema)
        assert any("season" in msg and "_enum" in msg for msg in found)

    def test_enum_value_not_snake_case(self, ams_schema):
        """A capitalised enum value is reported by value."""
        first_enum = ams_schema.enum_types[0]
        first_enum.values[0] = "Kharif"
        found = validate_naming(ams_schema)
        assert any("Kharif" in msg for msg in found)


class TestValidateSubmenuIds:
    """Submenu IDs must lie within their parent menu's ID block (101-199 for menu_id=1)."""

    def test_valid_ids(self, ams_schema):
        errors = validate_submenu_ids(ams_schema)
        assert errors == []

    def test_bad_submenu_id(self, ams_schema):
        # submenu_id 301 is wrong for menu_id=1 (should be 101-199)
        ams_schema.menus[0].submenus[0].submenu_id = 301
        errors = validate_submenu_ids(ams_schema)
        assert len(errors) >= 1
        assert any("301" in e for e in errors)

    def test_submenu_id_zero_offset(self, ams_schema):
        # submenu_id 100 is invalid (range is 101-199 for menu_id=1).
        ams_schema.menus[0].submenus[0].submenu_id = 100
        errors = validate_submenu_ids(ams_schema)
        assert len(errors) >= 1
        # Pin the error to the offending id, as test_bad_submenu_id does, so an
        # unrelated error cannot make this test pass.
        assert any("100" in e for e in errors)


class TestValidateTableStructure:
    """Tables must begin with ``id`` then ``submenu_id`` and hold at least two columns."""

    def test_valid_structure(self, ams_schema):
        assert validate_table_structure(ams_schema) == []

    def test_missing_id_column(self, ams_schema):
        # Rename the leading 'id' column so the first column is something else.
        target = ams_schema.menus[0].submenus[0].tables[0]
        target.columns[0].name = "farm_overview_id"
        problems = validate_table_structure(ams_schema)
        assert any("first column must be 'id'" in p for p in problems)

    def test_missing_submenu_id_column(self, ams_schema):
        # Rename the second column so it is no longer 'submenu_id'.
        target = ams_schema.menus[0].submenus[0].tables[0]
        target.columns[1].name = "menu_id"
        problems = validate_table_structure(ams_schema)
        assert any("second column must be 'submenu_id'" in p for p in problems)

    def test_too_few_columns(self):
        # Build the schema bottom-up around a table holding only its 'id' column.
        lonely_table = Table(
            table_name="tiny",
            comment="A table with only one column for testing.",
            columns=[
                Column(name="id", type="SERIAL", constraints="PRIMARY KEY"),
            ],
        )
        only_submenu = Submenu(
            submenu_id=101,
            submenu_name="Sub",
            sequence_number=1,
            description="Sub.",
            tables=[lonely_table],
        )
        only_menu = Menu(
            menu_id=1,
            menu_name="Main",
            sequence_number=1,
            description="Main.",
            submenus=[only_submenu],
        )
        schema = MantaraSchema(
            system_name="Test",
            schema_name="tst",
            description="Test.",
            menus=[only_menu],
        )
        problems = validate_table_structure(schema)
        assert any("fewer than 2 columns" in p for p in problems)


class TestValidateRefConsistency:
    """REFERENCES targets (inline and in the FK list) must use the schema's own name."""

    def test_valid_refs(self, ams_schema):
        errors = validate_ref_consistency(ams_schema)
        assert errors == []

    def test_wrong_schema_in_column_constraint(self, ams_schema):
        # Point an inline REFERENCES clause at a schema other than the fixture's.
        table = ams_schema.menus[0].submenus[0].tables[0]
        table.columns[1].constraints = "DEFAULT 101 NOT NULL REFERENCES wrong_schema.submenu(submenu_id)"
        errors = validate_ref_consistency(ams_schema)
        assert any("wrong_schema" in e for e in errors)

    def test_wrong_schema_in_fk_list(self, ams_schema):
        # Use a distinctive schema name: asserting on a bare "wrong" substring
        # would also match generic wording in any error message, making the
        # check near-vacuous.
        table = ams_schema.menus[0].submenus[0].tables[0]
        table.foreign_keys[0].references = "wrong_fk_schema.farms(id)"
        errors = validate_ref_consistency(ams_schema)
        assert any("wrong_fk_schema" in e for e in errors)


class TestValidateReferencedTables:
    """REFERENCES targets must be tables defined in the schema; menu/submenu are exempt."""

    def test_valid_refs(self, ams_schema):
        """Every table the fixture references exists, so nothing is reported."""
        assert validate_referenced_tables(ams_schema) == []

    def test_missing_table_in_column_constraint(self, ams_schema):
        """An inline REFERENCES to an undefined table is reported."""
        dangling = Column(
            name="customer_id",
            type="INT",
            constraints="REFERENCES ams.customers(id)",
        )
        ams_schema.menus[0].submenus[0].tables[0].columns.append(dangling)
        problems = validate_referenced_tables(ams_schema)
        assert any("customers" in p and "not defined" in p for p in problems)

    def test_missing_table_in_fk_list(self, ams_schema):
        """An FK-list entry pointing at an undefined table is reported."""
        dangling = ForeignKey(column="supplier_id", references="ams.suppliers(id)")
        ams_schema.menus[0].submenus[0].tables[0].foreign_keys.append(dangling)
        problems = validate_referenced_tables(ams_schema)
        assert any("suppliers" in p and "not defined" in p for p in problems)

    def test_menu_submenu_tables_allowed(self):
        """References to the built-in 'menu' and 'submenu' tables are accepted."""
        items = Table(
            table_name="items",
            comment="Items.",
            columns=[
                Column(name="id", type="SERIAL", constraints="PRIMARY KEY"),
                Column(name="submenu_id", type="INT", constraints="DEFAULT 101 NOT NULL REFERENCES tst.submenu(submenu_id)"),
            ],
        )
        sub = Submenu(
            submenu_id=101,
            submenu_name="Sub",
            sequence_number=1,
            description="Sub.",
            tables=[items],
        )
        main = Menu(
            menu_id=1,
            menu_name="Main",
            sequence_number=1,
            description="Main.",
            submenus=[sub],
        )
        schema = MantaraSchema(
            system_name="Test",
            schema_name="tst",
            description="Test.",
            menus=[main],
        )
        assert validate_referenced_tables(schema) == []


class TestValidatePostgresqlSyntax:
    """MySQL-only syntax (ON UPDATE CURRENT_TIMESTAMP, AUTO_INCREMENT) must be rejected."""

    @staticmethod
    def _single_table_schema(table):
        """Wrap *table* in a minimal ``tst`` schema: one menu (1) with one submenu (101)."""
        return MantaraSchema(
            system_name="Test",
            schema_name="tst",
            description="Test.",
            menus=[
                Menu(
                    menu_id=1,
                    menu_name="Main",
                    sequence_number=1,
                    description="Main.",
                    submenus=[
                        Submenu(
                            submenu_id=101,
                            submenu_name="Sub",
                            sequence_number=1,
                            description="Sub.",
                            tables=[table],
                        ),
                    ],
                ),
            ],
        )

    def test_valid_syntax(self, ams_schema):
        assert validate_postgresql_syntax(ams_schema) == []

    def test_on_update_current_timestamp(self):
        schema = self._single_table_schema(
            Table(
                table_name="items",
                comment="Items table with MySQL syntax for testing.",
                columns=[
                    Column(name="id", type="SERIAL", constraints="PRIMARY KEY"),
                    Column(name="submenu_id", type="INT", constraints="DEFAULT 101 NOT NULL"),
                    Column(name="updated_at", type="TIMESTAMP",
                           constraints="DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
                ],
            )
        )
        reported = validate_postgresql_syntax(schema)
        # Exactly one problem, and it names the offending MySQL clause.
        assert len(reported) == 1
        assert "ON UPDATE CURRENT_TIMESTAMP" in reported[0]

    def test_auto_increment(self):
        schema = self._single_table_schema(
            Table(
                table_name="items",
                comment="Items table with AUTO_INCREMENT for testing.",
                columns=[
                    Column(name="id", type="INT", constraints="AUTO_INCREMENT PRIMARY KEY"),
                    Column(name="submenu_id", type="INT", constraints="DEFAULT 101 NOT NULL"),
                ],
            )
        )
        reported = validate_postgresql_syntax(schema)
        assert len(reported) == 1
        assert "AUTO_INCREMENT" in reported[0]


class TestValidateNumericTypes:
    """Money-like columns must not use binary floating point; other columns may."""

    @staticmethod
    def _single_table_schema(table):
        """Wrap *table* in a minimal ``tst`` schema: one menu (1) with one submenu (101)."""
        return MantaraSchema(
            system_name="Test",
            schema_name="tst",
            description="Test.",
            menus=[
                Menu(
                    menu_id=1,
                    menu_name="Main",
                    sequence_number=1,
                    description="Main.",
                    submenus=[
                        Submenu(
                            submenu_id=101,
                            submenu_name="Sub",
                            sequence_number=1,
                            description="Sub.",
                            tables=[table],
                        ),
                    ],
                ),
            ],
        )

    def test_valid_numeric(self, ams_schema):
        """AMS uses FLOAT for hectares — that's a measurement, not money, so it's ok."""
        assert validate_numeric_types(ams_schema) == []

    def test_float_for_price(self):
        schema = self._single_table_schema(
            Table(
                table_name="products",
                comment="Products table with FLOAT price for testing.",
                columns=[
                    Column(name="id", type="SERIAL", constraints="PRIMARY KEY"),
                    Column(name="submenu_id", type="INT", constraints="DEFAULT 101 NOT NULL"),
                    Column(name="unit_price", type="FLOAT"),
                    Column(name="total_amount", type="DOUBLE PRECISION"),
                ],
            )
        )
        flagged = validate_numeric_types(schema)
        # One error per money column, each naming the column and its float type.
        assert len(flagged) == 2
        assert any("unit_price" in f and "FLOAT" in f for f in flagged)
        assert any("total_amount" in f and "DOUBLE PRECISION" in f for f in flagged)

    def test_float_for_non_money_ok(self):
        """FLOAT for non-money columns like latitude/temperature should not error."""
        schema = self._single_table_schema(
            Table(
                table_name="sensors",
                comment="Sensor readings table for testing FLOAT allowance.",
                columns=[
                    Column(name="id", type="SERIAL", constraints="PRIMARY KEY"),
                    Column(name="submenu_id", type="INT", constraints="DEFAULT 101 NOT NULL"),
                    Column(name="latitude", type="FLOAT"),
                    Column(name="temperature", type="FLOAT"),
                ],
            )
        )
        assert validate_numeric_types(schema) == []


class TestValidateCheckConstraints:
    """Money and quantity columns need a CHECK constraint; unrelated columns do not."""

    @staticmethod
    def _single_table_schema(table):
        """Wrap *table* in a minimal ``tst`` schema: one menu (1) with one submenu (101)."""
        return MantaraSchema(
            system_name="Test",
            schema_name="tst",
            description="Test.",
            menus=[
                Menu(
                    menu_id=1,
                    menu_name="Main",
                    sequence_number=1,
                    description="Main.",
                    submenus=[
                        Submenu(
                            submenu_id=101,
                            submenu_name="Sub",
                            sequence_number=1,
                            description="Sub.",
                            tables=[table],
                        ),
                    ],
                ),
            ],
        )

    def test_missing_check_on_price(self):
        """price column without CHECK should be flagged."""
        schema = self._single_table_schema(
            Table(
                table_name="products",
                comment="Products with price for testing CHECK validation.",
                columns=[
                    Column(name="id", type="SERIAL", constraints="PRIMARY KEY"),
                    Column(name="submenu_id", type="INT", constraints="DEFAULT 101 NOT NULL"),
                    Column(name="price", type="NUMERIC(12,2)", constraints="NOT NULL"),
                    Column(name="quantity", type="INT", constraints="NOT NULL"),
                ],
            )
        )
        flagged = validate_check_constraints(schema)
        # Both price and quantity lack CHECK, so both are reported.
        assert len(flagged) == 2
        assert any("price" in f and "CHECK" in f for f in flagged)
        assert any("quantity" in f and "CHECK" in f for f in flagged)

    def test_check_present_passes(self):
        """Columns with CHECK constraints should pass."""
        schema = self._single_table_schema(
            Table(
                table_name="orders",
                comment="Orders with proper CHECK constraints.",
                columns=[
                    Column(name="id", type="SERIAL", constraints="PRIMARY KEY"),
                    Column(name="submenu_id", type="INT", constraints="DEFAULT 101 NOT NULL"),
                    Column(name="amount", type="NUMERIC(12,2)", constraints="NOT NULL CHECK (amount >= 0)"),
                    Column(name="quantity", type="INT", constraints="NOT NULL CHECK (quantity > 0)"),
                ],
            )
        )
        assert validate_check_constraints(schema) == []

    def test_non_money_columns_ignored(self):
        """Columns like 'name', 'status' should not require CHECK."""
        schema = self._single_table_schema(
            Table(
                table_name="items",
                comment="Items without money columns.",
                columns=[
                    Column(name="id", type="SERIAL", constraints="PRIMARY KEY"),
                    Column(name="submenu_id", type="INT", constraints="DEFAULT 101 NOT NULL"),
                    Column(name="name", type="VARCHAR(255)", constraints="NOT NULL"),
                    Column(name="status", type="VARCHAR(50)"),
                ],
            )
        )
        assert validate_check_constraints(schema) == []

    def test_ams_schema_passes(self, ams_schema):
        """AMS schema doesn't have money columns, should pass."""
        assert validate_check_constraints(ams_schema) == []


class TestValidateSchemaCompliance:
    """The reference AMS fixture schema must pass the compliance check cleanly."""

    def test_valid_compliance(self, ams_schema):
        reported = validate_schema_compliance(ams_schema)
        assert not reported and reported == []
