"""Mantara v2 — Multi-step pipeline: Analyze → Plan → Validate Plan → Generate → Validate → Render.

Approach 2 from plan.md: 3 LLM calls with focused prompts + repair loop.
"""

import json
import re
import sys
import time
from dataclasses import dataclass, field

from openai import OpenAI
import os
from config import (
    OPENAI_API_KEY, MODEL, MAX_TOKENS, TEMPERATURE,
    TIMEOUT_SECONDS, SYSTEM_PROMPT_PATH, BACKEND,
    OLLAMA_BASE_URL, OLLAMA_MODEL,
)
from models import MantaraSchema
from business_validator import validate_all
from renderer import render_sql
from fsd_analyzer import FSDAnalysis


def _current_temperature() -> float:
    """Return the sampling temperature, re-read from the environment on each call.

    ``MANTARA_TEMPERATURE`` overrides the configured ``TEMPERATURE`` default.
    Because the value is looked up per call, Best-of-N runs can give parallel
    candidates different temperatures without reloading this module.
    """
    fallback = str(TEMPERATURE)
    raw_value = os.environ.get("MANTARA_TEMPERATURE", fallback)
    return float(raw_value)


def _log(step: str, msg: str):
    """Write one indented ``[v2:<step>]`` diagnostic line to stderr."""
    line = "  [v2:{}] {}".format(step, msg)
    print(line, file=sys.stderr)


def _is_local() -> bool:
    return BACKEND in ("ollama", "llamacpp")


def _chat_complete(client, model: str, messages: list, max_tokens: int = 4000,
                   temperature: float = 0.2, json_mode: bool = False):
    """Run one chat completion and return an OpenAI-shaped response object.

    With ``BACKEND == "llamacpp"`` the client's ``chat`` result (a dict) is
    wrapped so that callers can uniformly read
    ``resp.choices[0].message.content`` and ``resp.usage.total_tokens``.
    Every other backend goes through ``client.chat.completions.create``.

    Args:
        client: Client object produced by ``_make_client``.
        model: Model name sent to the backend.
        messages: OpenAI-style chat messages.
        max_tokens: Forwarded to the backend.
        temperature: Forwarded to the backend.
        json_mode: Request ``response_format={"type": "json_object"}``. It is
            only applied for local backends that reach the ``create`` path
            (i.e. Ollama); OpenAI calls do not receive it.
            NOTE(review): the reason OpenAI is excluded is not visible here.

    Returns:
        An object exposing ``choices[0].message.content`` and ``usage``.
    """
    if BACKEND == "llamacpp":
        from types import SimpleNamespace

        output = client.chat(model=model, messages=messages,
                             max_tokens=max_tokens, temperature=temperature)
        content = output["choices"][0]["message"]["content"]
        # Guard against a missing *or null* "usage" entry; a null value would
        # otherwise break the chained .get().
        usage = output.get("usage") or {}
        return SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content=content))],
            usage=SimpleNamespace(total_tokens=usage.get("total_tokens") or 0),
        )

    kwargs = {}
    if json_mode and _is_local():
        kwargs["response_format"] = {"type": "json_object"}
    return client.chat.completions.create(
        model=model, messages=messages,
        max_tokens=max_tokens, temperature=temperature, **kwargs
    )


def _make_client():
    """Build the client object for the configured BACKEND.

    * ``llamacpp``: a ``LlamaCppBackend`` instance.
    * ``ollama``: an ``OpenAI`` client pointed at ``OLLAMA_BASE_URL`` with a
      placeholder API key.
    * anything else: an ``OpenAI`` client using ``OPENAI_API_KEY``.

    OpenAI-compatible clients share ``TIMEOUT_SECONDS`` as their timeout.
    """
    if BACKEND == "llamacpp":
        # Imported inside the branch, presumably so the llama.cpp backend is
        # only required when it is actually selected.
        from backends.llamacpp_backend import LlamaCppBackend
        return LlamaCppBackend()

    client_kwargs = {"timeout": TIMEOUT_SECONDS}
    if BACKEND == "ollama":
        client_kwargs.update(base_url=OLLAMA_BASE_URL, api_key="ollama")
    else:
        client_kwargs["api_key"] = OPENAI_API_KEY
    return OpenAI(**client_kwargs)


def _resolve_model(model: str) -> str:
    """Pick the model name to send to the active backend.

    Remote (OpenAI) backends use *model* unchanged. Any local backend
    (``ollama`` or ``llamacpp``) gets the configured ``OLLAMA_MODEL`` instead,
    whichever OpenAI model name was requested.
    """
    return OLLAMA_MODEL if _is_local() else model


def _parse_structured(client, system_prompt: str, user_content: str,
                      model: str) -> tuple[MantaraSchema, object]:
    """Generate and parse a ``MantaraSchema`` using the configured backend.

    * llama.cpp: delegates to ``client.generate`` and pairs the schema with a
      placeholder completion whose ``usage.total_tokens`` is 0.
    * Ollama: appends the JSON-schema instruction to the system prompt,
      requests ``json_object`` output, extracts the JSON, drops a top-level
      ``$schema`` key and validates it with ``MantaraSchema.model_validate``.
    * Otherwise (OpenAI): uses Structured Outputs through
      ``client.beta.chat.completions.parse`` with
      ``response_format=MantaraSchema``.

    Args:
        client: Client object produced by ``_make_client``.
        system_prompt: System prompt for the generation call.
        user_content: User message (request plus plan).
        model: Model name passed to the backend.

    Returns:
        ``(schema, completion)``: the parsed schema and a completion-like
        object; callers read ``completion.usage.total_tokens``.

    Raises:
        ValueError: Ollama returned an empty reply, undecodable JSON, or JSON
            that is not an object; OpenAI refused or returned no parsed output.
    """
    if BACKEND == "llamacpp":
        from types import SimpleNamespace

        # client is a LlamaCppBackend instance; it returns the schema directly.
        schema = client.generate(system_prompt, user_content, model=model)
        # No usage data on this path, so report zero tokens.
        return schema, SimpleNamespace(usage=SimpleNamespace(total_tokens=0))

    if BACKEND == "ollama":
        from backends.ollama_backend import _load_json_schema, _SCHEMA_INSTRUCTION, _extract_json
        full_prompt = system_prompt + _SCHEMA_INSTRUCTION.format(json_schema=_load_json_schema())

        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": full_prompt},
                {"role": "user", "content": user_content},
            ],
            max_tokens=MAX_TOKENS,
            temperature=_current_temperature(),
            response_format={"type": "json_object"},
        )
        raw = completion.choices[0].message.content
        if not raw:
            raise ValueError("Model returned empty response")
        data = json.loads(_extract_json(raw))  # JSONDecodeError is a ValueError
        if not isinstance(data, dict):
            # A top-level array/scalar cannot be a MantaraSchema; fail clearly
            # instead of with an AttributeError on .pop().
            raise ValueError(
                f"Model returned JSON {type(data).__name__}, expected an object"
            )
        data.pop("$schema", None)
        return MantaraSchema.model_validate(data), completion

    completion = client.beta.chat.completions.parse(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ],
        response_format=MantaraSchema,
        max_tokens=MAX_TOKENS,
        temperature=_current_temperature(),
    )
    message = completion.choices[0].message
    if message.refusal:
        raise ValueError(f"Model refused: {message.refusal}")
    if message.parsed is None:
        raise ValueError("Model returned no structured output")
    return message.parsed, completion


@dataclass
class StepResult:
    """Outcome of one pipeline step: its label, raw output, timing and token usage."""

    name: str  # human-readable step label, e.g. "Input Analysis"
    output: str  # raw model text, or a placeholder for structured-output steps
    latency: float  # elapsed seconds for the step
    tokens: int = field(default=0)  # total tokens reported by the backend; 0 if unknown

@dataclass
class V2Result:
    """Final artefacts and bookkeeping for one run of the v2 pipeline."""

    schema: MantaraSchema  # the generated schema object
    json_str: str  # presumably the serialized schema JSON -- confirm at the call site
    sql_str: str  # presumably the rendered SQL text -- confirm at the call site
    validation: dict  # validation report; its structure is not visible here
    elapsed_seconds: float  # total run time in seconds
    steps: list = field(default_factory=list)  # per-step records (presumably StepResult)
    repair_attempts: int = field(default=0)  # presumably the number of repair-loop iterations


# ---------------------------------------------------------------------------
# Step 1: Input Analyzer (cheap model)
# ---------------------------------------------------------------------------

# System prompt for step 1 (_step_analyze). It asks the model for a single JSON
# object with system_type, system_name, schema_abbreviation, explicit/implicit
# entities, relationships, enum_candidates, complexity_score and
# suggested_menu_groups. _validate_plan later reads explicit_entities,
# implicit_entities and complexity_score (min_menus / min_submenus / min_tables)
# from that object, so those key names must stay in sync with it.
_ANALYZER_PROMPT = """You are an input analyzer for a database schema generator.

Given a user's description of a business system, extract ALL of the following fields:

1. **system_type**: "functional" (user describes what the system does) or "structured" (user provides menu/submenu/field layout)
2. **system_name**: A clear name for the system (e.g., "Hospital Management System")
3. **schema_abbreviation**: A short snake_case abbreviation (e.g., "hms", "sos", "rms")

4. **explicit_entities**: List of entities DIRECTLY mentioned in the input (e.g., for "hospital" -> ["patients", "doctors"])

5. **implicit_entities**: List of entities NOT mentioned but REQUIRED by the domain. Use these expansion rules:
   - Every core entity needs: a master table, a history/audit table, a settings/preferences table
   - "hospital" implies: patients, doctors, nurses, departments, appointments, appointment_history, prescriptions, prescription_items, medical_records, billing, billing_items, payments, rooms, bed_assignments, wards, pharmacy, lab_tests, lab_results, vital_signs, users, audit_log
   - "restaurant" implies: menu_items, menu_categories, tables, reservations, orders, order_items, customers, staff, shifts, inventory, suppliers, supplier_items, bills, payments, loyalty_points, loyalty_rewards
   - "e-commerce" implies: products, categories, customers, carts, cart_items, orders, order_items, payments, shipping, addresses, reviews, wishlists, coupons, returns
   - "school/university" implies: students, faculty, courses, enrollments, grades, departments, classrooms, schedules, attendance, fees, fee_payments, transcripts
   - For ANY system, always add: users, roles, permissions, audit_log, system_settings, notifications
   - Junction tables for every M:N relationship (e.g., doctor_specialties, staff_services)
   - Line-item tables for every header entity (e.g., order_items, invoice_lines, prescription_items)
   - Schedule/availability tables (e.g., doctor_schedule, room_availability, staff_shifts)
   - The implicit list should be AT LEAST 2x the explicit list

6. **relationships**: List of relationships with explicit cardinality. Each entry must follow this format:
   {"from": "entity_a", "to": "entity_b", "type": "1:N" or "M:N" or "1:1", "description": "...", "junction_table": "name_if_MN_else_null"}
   Include at least 8-10 relationships. Every M:N MUST name its junction table.

7. **enum_candidates**: List of columns that should be ENUMs -- be generous, at least 6-8. Each entry:
   {"column_name": "...", "suggested_values": ["val_a", "val_b", "val_c"], "used_in_tables": ["table1", "table2"]}

8. **complexity_score**: An object with:
   {"level": "simple"|"moderate"|"complex", "min_menus": N, "min_submenus": N, "min_tables": N}
   Rules:
   - "simple" (e.g., todo app): min_menus=3, min_submenus=7, min_tables=10
   - "moderate" (e.g., restaurant, clinic): min_menus=5, min_submenus=12, min_tables=18
   - "complex" (e.g., hospital, ERP, e-commerce): min_menus=7, min_submenus=18, min_tables=25

9. **suggested_menu_groups**: Group ALL entities (explicit + implicit) into logical menu groups.
   Each entry: {"menu_name": "...", "entities": ["entity_a", "entity_b", ...]}
   Every entity must appear in exactly one group. Minimum 3 groups, each with 2+ entities.

CRITICAL: Do NOT under-count entities. A "hospital management system" needs at minimum: patients, doctors, departments, appointments, appointment_history, prescriptions, prescription_items, billing, billing_items, payments, rooms, bed_assignments, medical_records, users, audit_log. That's 15 entities, not 4.

CRITICAL: A "todo app" is NOT just 2 tables. It needs: tasks, categories, tags, task_tags, task_comments, task_attachments, users, user_preferences, notifications, audit_log. That's 10 entities minimum.

Respond in JSON format only. No markdown, no code fences."""

# Appended to _ANALYZER_PROMPT by _step_analyze when _is_fsd_input() finds the
# "=== FSD ANALYSIS" marker in the user input.
_ANALYZER_FSD_ADDENDUM = """

ADDITIONAL CONTEXT: The input is a Functional Specification Document (FSD).
The FSD has been pre-analyzed and key sections have been marked with:
- "=== FSD ANALYSIS ===" header with pre-extracted modules, features, entities, relationships, and ENUMs.

When processing an FSD:
1. Trust the pre-extracted entity list but ADD any implied entities the FSD missed (junction tables, audit tables, line-item tables).
2. The DETECTED MODULES should map directly to your suggested_menu_groups.
3. The DETECTED FEATURES should map to submenus within those menu groups.
4. The DETECTED ENUM CANDIDATES should all appear in your enum_candidates (plus any additional ones you identify).
5. The DETECTED RELATIONSHIPS should all appear in your relationships list.
6. Set system_type to "structured" since the FSD provides explicit structure.
7. Pay attention to priority indicators (P0, P1, P2) -- P0 entities are core and must not be missed."""


def _is_fsd_input(user_input: str) -> bool:
    """Return True when *user_input* embeds the pre-computed FSD analysis block.

    The analyze, plan, replan and generate steps use this to decide whether to
    append their FSD-specific prompt addenda.
    """
    fsd_marker = "=== FSD ANALYSIS"
    return fsd_marker in user_input


def _step_analyze(client: OpenAI, user_input: str) -> StepResult:
    """Step 1: classify the request and extract entities, relationships and complexity.

    Sends *user_input* to the cheap analyzer model (``gpt-4o-mini``, or the
    local model for local backends) with ``_ANALYZER_PROMPT``. The FSD addendum
    is appended when the input carries the FSD-analysis marker.

    Args:
        client: Client object produced by ``_make_client``.
        user_input: The original user request, possibly containing FSD context.

    Returns:
        StepResult whose ``output`` is the raw analyzer text. It is ``""`` if the
        model returned no content, and it is not guaranteed to be valid JSON;
        consumers parse it best-effort.
    """
    start = time.perf_counter()

    # Use FSD-enhanced prompt when FSD context is detected
    prompt = _ANALYZER_PROMPT
    if _is_fsd_input(user_input):
        prompt += _ANALYZER_FSD_ADDENDUM

    use_model = _resolve_model("gpt-4o-mini")

    resp = _chat_complete(client, model=use_model, messages=[
        {"role": "system", "content": prompt},
        {"role": "user", "content": user_input},
    ], max_tokens=4000, temperature=0.1, json_mode=True)

    # perf_counter is monotonic, so the latency cannot be skewed by wall-clock jumps.
    elapsed = round(time.perf_counter() - start, 1)
    # content may be None; keep StepResult.output a str.
    output = resp.choices[0].message.content or ""
    tokens = resp.usage.total_tokens if resp.usage else 0

    _log("analyze", f"model={use_model}  tokens={tokens}  latency={elapsed}s")

    return StepResult(name="Input Analysis", output=output, latency=elapsed, tokens=tokens)


# ---------------------------------------------------------------------------
# Step 2: Schema Planner (strong model)
# ---------------------------------------------------------------------------

# System prompt for step 2 (_step_plan, and _step_replan when repairing a plan).
# The hard rules are also enforced in code by _validate_plan:
#   HR-1 -> Check 1, HR-2 -> Check 2, HR-3..HR-5 -> Check 4, HR-7 -> Check 3.
# HR-6 (1-3 tables per submenu) is a prompt-only instruction.
# The output keys "menus" / "submenus" / "tables" / "table_name" / "name" are the
# ones _validate_plan reads, so keep them in sync if this prompt changes.
_PLANNER_PROMPT = """You are a database schema planner for the Mantara schema standard.

Given the analysis below (which includes explicit_entities, implicit_entities, relationships, and complexity_score), design a complete schema plan.

HARD RULES (violations cause rejection and re-planning):
HR-1: Every menu MUST have AT LEAST 2 submenus. A menu with 1 submenu is ALWAYS wrong.
      You MUST decompose each menu into at least 2 meaningful submenus.
      If you cannot think of 2 submenus, merge the menu into another or decompose further.
HR-2: Every entity from explicit_entities AND implicit_entities MUST appear as a table.
      You MUST NOT silently drop any entity. After planning, list any uncovered entity -- this list MUST be empty.
HR-3: Total menus MUST be >= complexity_score.min_menus from the analysis.
HR-4: Total submenus MUST be >= complexity_score.min_submenus from the analysis.
HR-5: Total tables MUST be >= complexity_score.min_tables from the analysis.
HR-6: Each submenu should have 1-3 tables. A submenu with 0 tables is pointless.
HR-7: No generic submenu names: "General", "Misc", "Other", "Main" are BANNED.

OUTPUT FORMAT -- produce a JSON object with these fields:

1. **entity_coverage_checklist**: Before planning, list EVERY entity from the analysis (explicit + implicit) and which menu/submenu you will assign it to. Format: [{"entity": "...", "menu": "...", "submenu": "..."}]

2. **menus**: List of top-level menus with menu_id (starting at 1), name, and description. Each menu object must include its submenus array inline.

3. For each menu, include **submenus** with submenu_id (convention: menu_id * 100 + seq), name, description, and tables list.

4. For each submenu, include **tables** with:
   - table_name (snake_case)
   - columns with types, constraints, and inline comments for domain context
   - foreign key relationships to other tables

5. **enum_types**: Full list with type_name (schema.name_enum), values (snake_case only), and description

6. **uncovered_entities**: Entities from analysis NOT assigned to any table. THIS MUST BE EMPTY.

7. **validation_summary**: {"total_menus": N, "total_submenus": N, "total_tables": N, "min_submenus_per_menu": N, "all_entities_covered": true/false}

8. **assumptions**: Domain assumptions you made
9. **open_questions**: Questions that a real implementation would need answered

DECOMPOSITION TEMPLATES -- use these patterns for common menu types:

When you see "X Management", decompose into:
  -> "X Registration/Setup" (master data entry)
  -> "X Directory/Listing" (browse/search)
  -> "X History/Activity Log" (audit trail)

When you see "X Processing" or "X Operations", decompose into:
  -> "X Input/Intake" (create new records)
  -> "X Workflow/Processing" (status changes, approvals)
  -> "X Completion/Output" (finalization, results)

When you see "Orders" or "Transactions", decompose into:
  -> "Create Order/Transaction"
  -> "Order Tracking/Status"
  -> "Returns/Refunds/Adjustments"

When you see "Billing" or "Finance", decompose into:
  -> "Invoice Generation"
  -> "Payment Processing"
  -> "Payment History/Reports"

"Reports & Analytics" should ALWAYS be its own menu with submenus per domain:
  -> "Operational Reports"
  -> "Financial Reports"
  -> "Performance Analytics"

"Settings & Administration" should ALWAYS be its own menu:
  -> "User Management" (users, roles, permissions)
  -> "System Configuration" (settings, preferences)
  -> "Audit & Logs" (audit_log, system_log)

ADDITIONAL RULES:
- Every table must have: id SERIAL PRIMARY KEY as first column, submenu_id INT as second column. NO EXCEPTIONS -- line-item tables and junction tables MUST also have submenu_id.
- Every table should have created_at and updated_at TIMESTAMP columns
- All names must be snake_case
- ENUM values must be snake_case (e.g., a_positive not A+, grade_a not A)
- Foreign keys reference schema.table(column) format
- FK COLUMN NAME RULE: Every business table's PK column is named exactly `id`. When writing a REFERENCES constraint, the referenced column MUST be `id`. WRONG: `REFERENCES schema.bmiresult(bmi_result_id)`. RIGHT: `REFERENCES schema.bmiresult(id)`. Using any name other than `id` will produce "column does not exist" errors.
- FK TYPE RULE: Every FK column must be INT (matching the SERIAL/integer PK it references). NEVER use UUID, BIGINT, or VARCHAR for FK columns. All PKs are SERIAL (integer), so all FK columns referencing them must be INT.
- submenu_id convention: menu_id * 100 + sequence (Menu 1 -> 101, 102; Menu 2 -> 201, 202)
- Orders/invoices MUST have a line-items table. M:N relationships MUST have junction tables.
- NORMALIZATION: Separate master data from assignments. Create master tables and assignment tables.
- Use NUMERIC(12,2) for monetary values, NUMERIC(10,2) for quantities.
- Add CHECK constraints on money/quantity columns.
- Add indexes on FK columns. Add UNIQUE constraints where business rules demand it.

THINK STEP BY STEP:
  a) Read complexity_score from the analysis -- note minimums for menus, submenus, tables
  b) Read ALL entities from explicit_entities + implicit_entities -- write entity_coverage_checklist
  c) Group entities into domain modules (menus) -- MUST meet min_menus target
  d) For EACH menu, apply decomposition templates to create 2-4 submenus -- MUST meet min_submenus target
  e) Assign tables to submenus (1-3 per submenu), adding junction/line-item/audit tables -- MUST meet min_tables target
  f) VERIFY: count submenus per menu -- if ANY menu has only 1 submenu, STOP and split/merge
  g) VERIFY: check entity_coverage_checklist -- if ANY entity is unassigned, add it
  h) VERIFY: check for banned names ("General", "Misc", "Other", "Main") -- rename if found
  i) Fill in validation_summary and confirm all checks pass

Respond in detailed JSON format. No markdown, no code fences."""

# Appended to _PLANNER_PROMPT (in both _step_plan and _step_replan) when the
# user input contains the "=== FSD ANALYSIS" marker.
_PLANNER_FSD_ADDENDUM = """

ADDITIONAL FSD CONTEXT:
The input is a Functional Specification Document with pre-extracted structure.
- Map DETECTED MODULES directly to menus (use the module names as menu names).
- Map DETECTED FEATURES directly to submenus under their parent module/menu.
- Every DETECTED ENTITY must have a corresponding table -- do not drop any.
- Every DETECTED ENUM CANDIDATE must appear as an ENUM type in the schema.
- Every DETECTED RELATIONSHIP must be implemented as a foreign key.
- DETECTED BUSINESS RULES should inform CHECK constraints and validation logic.
- The FSD structure takes precedence over your default decomposition heuristics.
  Only add additional menus/submenus if the FSD modules need splitting or merging
  to meet the 2-4 submenus per menu requirement."""


def _step_plan(client: OpenAI, analysis: str, user_input: str) -> StepResult:
    """Step 2: turn the analysis into a full menu/submenu/table plan.

    Sends the original request plus the analyzer output to the strong planner
    model (``gpt-4o``, or the local model for local backends) with
    ``_PLANNER_PROMPT``. The FSD addendum is appended when the input carries the
    FSD-analysis marker.

    Args:
        client: Client object produced by ``_make_client``.
        analysis: Raw output of step 1.
        user_input: The original user request.

    Returns:
        StepResult whose ``output`` is the raw plan text (``""`` if the model
        returned no content).
    """
    start = time.perf_counter()

    prompt = _PLANNER_PROMPT
    if _is_fsd_input(user_input):
        prompt += _PLANNER_FSD_ADDENDUM

    use_model = _resolve_model("gpt-4o")

    resp = _chat_complete(client, model=use_model, messages=[
        {"role": "system", "content": prompt},
        {"role": "user", "content": f"Original request:\n{user_input}\n\nAnalysis:\n{analysis}"},
    ], max_tokens=12000, temperature=0.2, json_mode=True)

    # perf_counter is monotonic, so the latency cannot be skewed by wall-clock jumps.
    elapsed = round(time.perf_counter() - start, 1)
    # content may be None; keep StepResult.output a str.
    output = resp.choices[0].message.content or ""
    tokens = resp.usage.total_tokens if resp.usage else 0

    _log("plan", f"model={use_model}  tokens={tokens}  latency={elapsed}s")

    return StepResult(name="Schema Planning", output=output, latency=elapsed, tokens=tokens)


# ---------------------------------------------------------------------------
# Step 2.5: Pre-Generation Plan Validation
# ---------------------------------------------------------------------------

# Generic submenu names rejected by plan validation; names are compared in
# lower case after stripping surrounding whitespace.
_BANNED_SUBMENU_NAMES = set("general misc other main miscellaneous default".split())


def _try_parse_json(text: str | None) -> dict | None:
    """Try to parse JSON from text, handling extra text around the JSON object."""
    if not text:
        return None
    try:
        return json.loads(text)
    except (json.JSONDecodeError, TypeError):
        json_match = re.search(r'\{[\s\S]*\}', text)
        if json_match:
            try:
                return json.loads(json_match.group())
            except (json.JSONDecodeError, TypeError):
                return None
        return None


def _plan_as_list(value: object) -> list:
    """Return *value* if it is a list, otherwise an empty list.

    Planner output is LLM-generated JSON, so a key that should hold an array
    may be missing, ``null`` or a scalar. Treating those as empty keeps the
    validator from raising ``TypeError`` on ``len()`` or iteration.
    """
    return value if isinstance(value, list) else []


def _plan_norm_name(name: str) -> str:
    """Lower-case and strip *name*, collapsing runs of spaces/hyphens to ``_``."""
    return re.sub(r"[\s\-]+", "_", name.strip().lower())


def _plan_menu_name(menu: dict) -> str:
    """Display name of a plan menu: ``name``, then ``menu_name``, else ``"Unknown"``."""
    name = menu.get("name", menu.get("menu_name"))
    return name if isinstance(name, str) and name else "Unknown"


def _plan_min_count(value: object, default: int) -> int:
    """Coerce an LLM-provided minimum (int, float or numeric string) to ``int``.

    Falls back to *default* for ``None``, booleans and non-numeric values, so
    the ``<`` comparisons in ``_validate_plan`` can never raise ``TypeError``.
    """
    if value is None or isinstance(value, bool):
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _validate_plan(plan_json: str, analysis_json: str) -> list[str]:
    """Check the planner's output against the hard rules before generation.

    Runs four structural checks on the parsed plan:

    1. every menu has at least 2 submenus;
    2. every explicit/implicit entity from the analysis is covered by some table
       name (fuzzy substring match, also ignoring underscores);
    3. no submenu uses a banned generic name (``_BANNED_SUBMENU_NAMES``);
    4. menu/submenu/table totals meet the analysis ``complexity_score`` minimums
       (defaults 3 / 6 / 8 when a minimum is missing or not numeric).

    Imperfect LLM JSON is tolerated rather than raising. This covers ``null``
    lists, non-object menus, submenus given as bare strings, table names given
    as bare strings, and numeric strings in ``complexity_score``.

    Args:
        plan_json: Raw planner output.
        analysis_json: Raw analyzer output. If it cannot be parsed as an
            object, check 2 finds no entities and check 4 uses the defaults.

    Returns:
        Human-readable ``PLAN_VIOLATION`` messages. The list is empty when the
        plan passes, or when it could not be parsed as a JSON object at all.
    """
    plan = _try_parse_json(plan_json)
    if not isinstance(plan, dict):
        _log("validate_plan", "Could not parse plan JSON -- skipping validation")
        return []

    analysis = _try_parse_json(analysis_json)
    if not isinstance(analysis, dict):
        analysis = {}

    issues: list[str] = []
    # Only object-shaped menus can be inspected; anything else is ignored.
    menus = [m for m in _plan_as_list(plan.get("menus")) if isinstance(m, dict)]

    # --- Check 1: Every menu has >= 2 submenus ---
    for menu in menus:
        submenu_count = len(_plan_as_list(menu.get("submenus")))
        if submenu_count < 2:
            issues.append(
                f"PLAN_VIOLATION: Menu '{_plan_menu_name(menu)}' has only {submenu_count} submenu(s) -- minimum is 2. "
                f"Decompose it using templates: Setup/Registration, Operations, History/Tracking."
            )

    # --- Check 2: All analyzer entities are covered ---
    explicit = analysis.get("explicit_entities", analysis.get("entities", []))
    implicit = analysis.get("implicit_entities", [])
    # Normalizing spaces/hyphens lets e.g. "Medical Records" match "medical_records".
    all_entities = {
        _plan_norm_name(e)
        for e in _plan_as_list(explicit) + _plan_as_list(implicit)
        if isinstance(e, str) and e.strip()
    }

    if all_entities:
        # (normalized_name, name_without_underscores) for every table in the plan.
        plan_tables = []
        for menu in menus:
            for submenu in _plan_as_list(menu.get("submenus")):
                if not isinstance(submenu, dict):
                    continue
                for table in _plan_as_list(submenu.get("tables")):
                    # Tables are normally objects, but accept a bare name string too.
                    if isinstance(table, dict):
                        tname = table.get("table_name", table.get("name", ""))
                    else:
                        tname = table
                    if isinstance(tname, str) and tname.strip():
                        norm = _plan_norm_name(tname)
                        plan_tables.append((norm, norm.replace("_", "")))

        # Fuzzy coverage: the entity (with or without underscores) must appear
        # inside at least one table name, e.g. "audit_log" matches "audit_logs".
        uncovered = sorted(
            entity
            for entity in all_entities
            if not any(
                entity in tname or entity.replace("_", "") in squashed
                for tname, squashed in plan_tables
            )
        )

        if uncovered:
            issues.append(
                f"PLAN_VIOLATION: {len(uncovered)} entities from analysis are NOT covered by any table: "
                f"{', '.join(uncovered)}. Each entity must map to at least one table."
            )

    # --- Check 3: No banned submenu names ---
    for menu in menus:
        for submenu in _plan_as_list(menu.get("submenus")):
            if isinstance(submenu, dict):
                sm_name = submenu.get("name", submenu.get("submenu_name", ""))
            else:
                sm_name = submenu  # submenu given as a bare name string
            if isinstance(sm_name, str) and sm_name.lower().strip() in _BANNED_SUBMENU_NAMES:
                issues.append(
                    f"PLAN_VIOLATION: Submenu '{sm_name}' in menu '{_plan_menu_name(menu)}' uses a banned generic name. "
                    f"Rename to a specific business function (e.g., 'Task Categories' instead of 'General')."
                )

    # --- Check 4: Minimum counts from complexity score ---
    complexity = analysis.get("complexity_score", {})
    if isinstance(complexity, dict):
        min_menus = _plan_min_count(complexity.get("min_menus"), 3)
        min_submenus = _plan_min_count(complexity.get("min_submenus"), 6)
        min_tables = _plan_min_count(complexity.get("min_tables"), 8)

        total_menus = len(menus)
        total_submenus = sum(len(_plan_as_list(m.get("submenus"))) for m in menus)
        total_tables = sum(
            len(_plan_as_list(sm.get("tables")))
            for m in menus
            for sm in _plan_as_list(m.get("submenus"))
            if isinstance(sm, dict)
        )

        if total_menus < min_menus:
            issues.append(
                f"PLAN_VIOLATION: Plan has {total_menus} menus but complexity_score requires >= {min_menus}."
            )
        if total_submenus < min_submenus:
            issues.append(
                f"PLAN_VIOLATION: Plan has {total_submenus} submenus but complexity_score requires >= {min_submenus}."
            )
        if total_tables < min_tables:
            issues.append(
                f"PLAN_VIOLATION: Plan has {total_tables} tables but complexity_score requires >= {min_tables}."
            )

    return issues


def _step_replan(client: OpenAI, analysis: str, user_input: str,
                 previous_plan: str, issues: list[str]) -> StepResult:
    """Re-run the planner with explicit fix instructions for plan violations.

    Uses the same system prompt and model as step 2. The user message adds the
    violation list and the previous plan so the model can correct it rather
    than start from scratch.

    Args:
        client: Client object produced by ``_make_client``.
        analysis: Raw output of step 1.
        user_input: The original user request.
        previous_plan: Raw planner output that failed validation.
        issues: Violation messages (e.g. from ``_validate_plan``).

    Returns:
        StepResult whose ``output`` is the corrected plan text (``""`` if the
        model returned no content).
    """
    start = time.perf_counter()

    prompt = _PLANNER_PROMPT
    if _is_fsd_input(user_input):
        prompt += _PLANNER_FSD_ADDENDUM

    use_model = _resolve_model("gpt-4o")

    fix_instructions = "\n".join(f"- {issue}" for issue in issues)
    replan_msg = (
        f"Original request:\n{user_input}\n\n"
        f"Analysis:\n{analysis}\n\n"
        f"Your PREVIOUS plan had these violations:\n{fix_instructions}\n\n"
        f"PREVIOUS PLAN (for reference -- fix the violations above):\n{previous_plan}\n\n"
        f"Generate a CORRECTED plan that fixes ALL violations. "
        f"Remember: every menu MUST have >= 2 submenus, every entity must be covered, "
        f"no generic names like 'General' or 'Misc'."
    )

    resp = _chat_complete(client, model=use_model, messages=[
        {"role": "system", "content": prompt},
        {"role": "user", "content": replan_msg},
    ], max_tokens=12000, temperature=0.2, json_mode=True)

    # perf_counter is monotonic, so the latency cannot be skewed by wall-clock jumps.
    elapsed = round(time.perf_counter() - start, 1)
    # content may be None; keep StepResult.output a str.
    output = resp.choices[0].message.content or ""
    tokens = resp.usage.total_tokens if resp.usage else 0

    _log("replan", f"model={use_model}  tokens={tokens}  latency={elapsed}s  fixed={len(issues)} issues")

    return StepResult(name="Plan Repair", output=output, latency=elapsed, tokens=tokens)


# ---------------------------------------------------------------------------
# Step 3: JSON Generator (structured outputs)
# ---------------------------------------------------------------------------

def _load_system_prompt() -> str:
    with open(SYSTEM_PROMPT_PATH) as f:
        return f.read()


# Appended by _step_generate to the base system prompt read from
# SYSTEM_PROMPT_PATH. Rules 1-10 are phrased as rejection criteria;
# NOTE(review): they presumably mirror checks in business_validator.validate_all,
# which is not visible here -- keep both in sync. Contains non-ASCII text ("→").
_GENERATOR_ADDENDUM = """

---

IMPORTANT: Generate ONLY the structured JSON (Part A). Do NOT generate SQL.
You have been given a detailed schema plan. Follow it closely -- use the exact
menu structure, submenu IDs, table names, columns, and ENUM types from the plan.
All ENUM values MUST be snake_case (no symbols like + or -).
Include column-level comments for domain context (units, examples, security notes).
Include assumptions and open_questions arrays in the root object.
Focus on producing accurate mantara.schema.v1 JSON.

CRITICAL VALIDATION RULES (your output will be rejected if violated):
1. Every menu MUST have at least 2 submenus -- NEVER a 1:1 menu-to-submenu mapping.
2. Every business table's SECOND column must be submenu_id (including line-item and junction tables).
3. Use ENUM types for all status/type/role/category/method columns -- never bare VARCHAR.
4. Total tables should be at least (number_of_menus + 5).
5. Include junction tables for M:N relationships and line-item tables for header entities.
6. MINIMUM SCHEMA SIZE: at least 3 menus, 6 submenus, 8 tables.
7. Any column name containing status/type/role/category/priority/severity/level/mode/method/gender/payment MUST use an ENUM type, not VARCHAR/TEXT.
8. Every table must have at least 5 business columns (excluding id, submenu_id, created_at, updated_at). Thin tables will be rejected.
9. FK TYPE CONSISTENCY: All PKs are SERIAL (integer). Every FK column that references another table's id MUST be type INT -- never UUID, BIGINT, or TEXT. Mismatched types (e.g. uuid FK → integer PK) will cause a Postgres error and will be rejected.
10. FK REFERENCED COLUMN: The column inside REFERENCES schema.table(...) MUST be `id` for every business table. NEVER write `tablename_id`, `table_id`, or any other variant -- those columns do not exist. Example: `REFERENCES schema.bmiresult(id)`, NEVER `REFERENCES schema.bmiresult(bmi_result_id)`. This causes a hard Postgres error and will be rejected."""

# Appended after _GENERATOR_ADDENDUM by _step_generate when the user input
# contains the "=== FSD ANALYSIS" marker.
_GENERATOR_FSD_ADDENDUM = """

FSD-SPECIFIC INSTRUCTIONS:
The input is a Functional Specification Document. The pre-extracted FSD analysis
(marked with === FSD ANALYSIS ===) provides authoritative module, entity, and ENUM
information. Ensure:
- Every module from the FSD maps to a menu.
- Every feature from the FSD maps to a submenu.
- Every entity from the FSD has a corresponding table with ALL listed attributes as columns.
- Every ENUM candidate from the FSD is created as a proper ENUM type.
- Business rules from the FSD are implemented as CHECK constraints where possible."""


def _build_pregeneration_rules(plan: str) -> str:
    """Derive extra, plan-specific reminders to append to the generator prompt.

    Heuristically scans the planner output for three common failure modes:
    enum-like column concepts, junction/bridge tables (which tend to be thin),
    and a low menu count.

    Args:
        plan: Raw planner output, normally a JSON string. ``None`` or ``""``
            is tolerated and treated as an empty plan.

    Returns:
        A "DYNAMIC PRE-GENERATION CHECKS" section prefixed with two newlines,
        or ``""`` when no rule applies.
    """
    plan_text = plan or ""
    plan_lower = plan_text.lower()
    rules: list[str] = []

    # Check for enum-candidate keywords in the plan
    enum_keywords = ["status", "type", "role", "category", "priority", "severity",
                     "level", "mode", "method", "gender", "payment"]
    found_keywords = [kw for kw in enum_keywords if kw in plan_lower]
    if found_keywords:
        rules.append(
            f"PRE-CHECK: The plan mentions these enum-like concepts: {', '.join(found_keywords)}. "
            f"Make sure EVERY column with these words in its name uses an ENUM type, not VARCHAR/TEXT. "
            f"Define the enum in enum_types with snake_case values."
        )

    # Check for potential thin entities
    if "junction" in plan_lower or "bridge" in plan_lower:
        rules.append(
            "PRE-CHECK: Junction/bridge tables still need at least 5 business columns. "
            "Add fields like quantity, notes, effective_date, status, is_active beyond just the two FKs."
        )

    # Count menus from the parsed "menus" array. Plain substring counting is
    # unreliable: "submenu_id" contains "menu_id", and every
    # entity_coverage_checklist row carries a "menu" key.
    try:
        parsed = json.loads(plan_text)
    except ValueError:
        parsed = None
    menus = parsed.get("menus") if isinstance(parsed, dict) else None
    if isinstance(menus, list):
        menu_count = len(menus)
    else:
        # Not clean JSON (e.g. fenced): count quoted menu_id keys only; the
        # leading quote keeps "submenu_id" from matching.
        menu_count = len(re.findall(r"[\"']menu_id[\"']", plan_lower))
    if menu_count < 3:
        rules.append(
            "PRE-CHECK: The plan may have too few menus. Ensure at least 3 menus with 2+ submenus each."
        )

    if rules:
        return "\n\nDYNAMIC PRE-GENERATION CHECKS:\n" + "\n".join(f"- {r}" for r in rules)
    return ""


def _step_generate(client: OpenAI, plan: str, user_input: str, model: str) -> tuple[MantaraSchema, StepResult]:
    """Step 3: produce the final MantaraSchema from the plan.

    Builds the system prompt from the prompt file plus ``_GENERATOR_ADDENDUM``
    (and ``_GENERATOR_FSD_ADDENDUM`` for FSD input). It appends any dynamic
    pre-generation checks derived from *plan*, then delegates to
    ``_parse_structured``.

    Args:
        client: Client object produced by ``_make_client``.
        plan: Raw planner output to follow.
        user_input: The original user request.
        model: Model name, passed through unchanged to ``_parse_structured``.

    Returns:
        ``(schema, step_result)``. ``step_result.output`` is a placeholder,
        since the payload is the schema object itself.

    Raises:
        Whatever ``_parse_structured`` raises (e.g. ``ValueError``).
    """
    start = time.perf_counter()
    system_prompt = _load_system_prompt() + _GENERATOR_ADDENDUM
    if _is_fsd_input(user_input):
        system_prompt += _GENERATOR_FSD_ADDENDUM

    # Add dynamic pre-generation rules based on plan analysis
    pregeneration_rules = _build_pregeneration_rules(plan)
    generation_prompt = f"Original request:\n{user_input}\n\nSchema plan to follow:\n{plan}{pregeneration_rules}"

    schema, completion = _parse_structured(client, system_prompt, generation_prompt, model)

    # perf_counter is monotonic, so the latency cannot be skewed by wall-clock jumps.
    elapsed = round(time.perf_counter() - start, 1)
    tokens = completion.usage.total_tokens if completion.usage else 0

    _log("generate", f"model={model}  tokens={tokens}  latency={elapsed}s")

    return schema, StepResult(name="JSON Generation", output="(structured output)", latency=elapsed, tokens=tokens)


# ---------------------------------------------------------------------------
# Repair loop
# ---------------------------------------------------------------------------

def _build_repair_instructions(errors: list[str]) -> str:
    """Build specific, actionable repair instructions from categorized errors."""
    sections = []

    # Categorize errors for targeted instructions
    submenu_errors = [e for e in errors if ("submenu" in e.lower() and ("only 1 submenu" in e.lower() or "0 submenu" in e.lower() or "at least 2" in e.lower()))]
    enum_errors = [e for e in errors if "ENUM NEEDED" in e or "should be an ENUM" in e.lower() or "enum type" in e.lower()]
    thin_errors = [e for e in errors if "THIN ENTITY" in e]
    size_errors = [e for e in errors if "MINIMUM SIZE" in e]
    categorized = set(submenu_errors + enum_errors + thin_errors + size_errors)
    other_errors = [e for e in errors if e not in categorized]

    if submenu_errors:
        menu_names = []
        for e in submenu_errors:
            m = re.search(r"Menu '([^']+)'", e)
            if m:
                menu_names.append(m.group(1))
        sections.append(
            "SUBMENU GRANULARITY FIXES (CRITICAL):\n"
            + "\n".join(f"  - {e}" for e in submenu_errors)
            + "\n\n"
            + "  ACTION REQUIRED: Each of these menus MUST have at least 2 submenus.\n"
            + "  For each menu listed above, decompose it into 2-4 submenus representing "
            + "distinct user-facing screens/functions.\n"
            + "  Examples of good decomposition:\n"
            + "    - 'Customer Management' -> 'Customer Registration', 'Customer Directory', 'Customer Groups'\n"
            + "    - 'Orders' -> 'Create Order', 'Order History', 'Returns & Refunds'\n"
            + "    - 'Settings' -> 'User Management', 'System Config', 'Roles & Permissions'\n"
        )

    if enum_errors:
        col_list = []
        for e in enum_errors:
            m = re.search(r"Table '([^']+)'\.(\w+)", e)
            if m:
                col_list.append(f"    - {m.group(1)}.{m.group(2)}")
        sections.append(
            "ENUM TYPE FIXES:\n"
            + "\n".join(f"  - {e}" for e in enum_errors)
            + "\n\n"
            + "  ACTION REQUIRED: Change these columns from VARCHAR/TEXT to ENUM types.\n"
            + "  For each column above:\n"
            + "    1. Create a new enum type in enum_types (e.g., 'schema.column_name_enum')\n"
            + "    2. Add appropriate snake_case values to the enum\n"
            + "    3. Change the column type to reference the new enum\n"
            + "  Columns that need ENUMs:\n"
            + "\n".join(col_list) + "\n"
        )

    if thin_errors:
        table_list = []
        for e in thin_errors:
            m = re.search(r"Table '([^']+)' has only (\d+) business column", e)
            if m:
                table_list.append(f"    - {m.group(1)} (currently {m.group(2)} business columns)")
        sections.append(
            "THIN ENTITY FIXES:\n"
            + "\n".join(f"  - {e}" for e in thin_errors)
            + "\n\n"
            + "  ACTION REQUIRED: Add more domain-relevant columns to these tables.\n"
            + "  Each table should have at least 5 business columns (excluding id, submenu_id, created_at, updated_at).\n"
            + "  Think about what attributes a real business would need to track for each entity.\n"
            + "  Tables needing more columns:\n"
            + "\n".join(table_list) + "\n"
        )

    if size_errors:
        sections.append(
            "MINIMUM SCHEMA SIZE FIXES:\n"
            + "\n".join(f"  - {e}" for e in size_errors)
            + "\n\n"
            + "  ACTION REQUIRED: The schema is too small for a real application.\n"
            + "  Minimum requirements: 3 menus, 6 submenus, 8 tables.\n"
            + "  Add domain-appropriate menus (e.g., Settings, Reports, Administration)\n"
            + "  and ensure each menu has 2+ submenus with relevant tables.\n"
        )

    if other_errors:
        sections.append(
            "OTHER FIXES:\n"
            + "\n".join(f"  - {e}" for e in other_errors)
        )

    return "\n\n".join(sections)


def _normalise_in_loop(schema: MantaraSchema) -> tuple[MantaraSchema, dict]:
    """Run the deterministic normaliser over *schema* and rebuild the typed model.

    The schema is converted to a dict with ``to_json_dict()`` and stripped of
    its ``$schema`` marker. ``normalise`` then mutates the dict in place. Any
    private ``_*`` keys are removed, at the top level and on every table, and
    a fresh ``MantaraSchema`` is constructed. The schema name is kept stable
    (``regenerate_name=False``) so repair attempts see a consistent identifier.

    Returns:
        (new_schema, stats). ``stats`` holds only the action counters worth
        logging. If anything fails (including the import), the error is
        logged and ``(schema, {})`` is returned unchanged — this step is
        best-effort.
    """
    # Normaliser counters that are surfaced to the caller for logging.
    reported_counters = {"lifted_enums", "auto_created_cfg", "normalised_repairs",
                         "rewritten_columns", "stripped_dangling_fks"}
    try:
        from normaliser import normalise  # local import to avoid circular at module load

        payload = schema.to_json_dict()
        payload.pop("$schema", None)  # marker is not a model field
        # The final unique sch_<date>_* name is assigned after the loop.
        normalise(payload, cir=None, regenerate_name=False)

        public = {key: val for key, val in payload.items() if not key.startswith("_")}
        menus = public.get("menus", []) or []
        every_table = (
            table
            for menu in menus
            for sub in (menu.get("submenus") or [])
            for table in (sub.get("tables") or [])
        )
        for table in every_table:
            private_keys = [key for key in table if key.startswith("_")]
            for key in private_keys:
                del table[key]

        rebuilt = MantaraSchema(**public)
        counters = payload.get("_normaliser") or {}
        return rebuilt, {key: val for key, val in counters.items() if key in reported_counters}
    except Exception as exc:  # noqa: BLE001
        _log("normalise", f"skipped (non-fatal): {exc}")
        return schema, {}


def _step_repair(client: OpenAI, schema: MantaraSchema, errors: list[str],
                 plan: str, user_input: str, model: str, attempt: int) -> tuple[MantaraSchema, StepResult]:
    """Re-run JSON generation with specific, categorized error feedback.

    The repair is a fresh structured generation from the original request and
    plan, with the validation errors and targeted fix instructions appended.

    Args:
        client: LLM client passed through to ``_parse_structured``.
        schema: The schema that failed validation. NOTE(review): it is not
            currently sent to the model; the repair regenerates from scratch.
        errors: Validation error messages to fix.
        plan: The schema plan text.
        user_input: The original user request.
        model: Model used for the repair call.
        attempt: 1-based repair attempt number (used in the prompt and logs).

    Returns:
        (repaired_schema, step_result) with latency and token usage.
    """
    start = time.time()
    # Mirror _step_generate's system prompt, including the FSD addendum, so a
    # from-scratch repair keeps the same FSD-specific guidance as generation.
    system_prompt = _load_system_prompt() + _GENERATOR_ADDENDUM
    if _is_fsd_input(user_input):
        system_prompt += _GENERATOR_FSD_ADDENDUM

    repair_instructions = _build_repair_instructions(errors)
    error_feedback = "\n".join(f"- {e}" for e in errors)
    repair_msg = (
        f"Original request:\n{user_input}\n\n"
        f"Schema plan:\n{plan}\n\n"
        f"REPAIR ATTEMPT {attempt}: The previous output FAILED validation with {len(errors)} error(s).\n\n"
        f"=== VALIDATION ERRORS ===\n{error_feedback}\n\n"
        f"=== SPECIFIC FIX INSTRUCTIONS ===\n{repair_instructions}\n\n"
        f"Fix ALL of these errors in the new output. Do NOT introduce new errors while fixing these."
    )

    repaired, completion = _parse_structured(client, system_prompt, repair_msg, model)

    elapsed = round(time.time() - start, 1)
    tokens = completion.usage.total_tokens if completion.usage else 0

    _log("repair", f"attempt={attempt}  model={model}  tokens={tokens}  latency={elapsed}s")

    return repaired, StepResult(name=f"Repair (attempt {attempt})", output="(structured output)", latency=elapsed, tokens=tokens)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def generate_v2(user_input: str, model: str | None = None,
                on_step=None, *, max_repairs: int = 3) -> V2Result:
    """Full v2 pipeline: Analyze -> Plan -> Validate Plan -> Generate -> Validate/Repair -> Render.

    Args:
        user_input: Natural language description.
        model: Model for JSON generation and repair (defaults to ``MODEL`` from
            config, resolved via ``_resolve_model``). The analysis and planning
            steps are called without this model argument.
        on_step: Optional callback(step_name, step_result) invoked after each
            LLM step (analysis, plan, replan, generation, each repair).
        max_repairs: Maximum number of LLM repair attempts when validation
            fails. Defaults to 3, the previously hard-coded limit; values <= 0
            disable the repair loop.

    Returns:
        V2Result with schema, JSON, SQL, validation, steps, and timing. The
        schema is rendered even if validation still fails after all repairs;
        callers should check ``validation["is_valid"]``.
    """
    use_model = _resolve_model(model or MODEL)
    client = _make_client()
    overall_start = time.time()
    steps = []

    def _notify(step_result):
        # Record every step and forward it to the optional progress callback.
        steps.append(step_result)
        if on_step:
            on_step(step_result.name, step_result)

    # Step 1: Analyze
    analysis_result = _step_analyze(client, user_input)
    _notify(analysis_result)

    # Step 2: Plan
    plan_result = _step_plan(client, analysis_result.output, user_input)
    _notify(plan_result)

    # Step 2.5: Validate the plan before generation (max 1 replan attempt)
    plan_issues = _validate_plan(plan_result.output, analysis_result.output)
    if plan_issues:
        _log("validate_plan", f"FAIL -- {len(plan_issues)} issues found, re-planning")
        for issue in plan_issues:
            _log("validate_plan", f"  {issue}")
        plan_result = _step_replan(
            client, analysis_result.output, user_input,
            plan_result.output, plan_issues
        )
        _notify(plan_result)

        # Check again (but don't loop endlessly -- one replan attempt is enough)
        plan_issues_2 = _validate_plan(plan_result.output, analysis_result.output)
        if plan_issues_2:
            _log("validate_plan", f"WARN -- {len(plan_issues_2)} issues remain after replan (proceeding anyway)")
        else:
            _log("validate_plan", "PASS after replan")
    else:
        _log("validate_plan", "PASS")

    # Step 3: Generate
    schema, gen_result = _step_generate(client, plan_result.output, user_input, use_model)
    _notify(gen_result)

    # Step 3.5: NORMALISE before validation
    # Deterministically repairs cfg shape mismatches, auto-creates missing FK
    # target tables, strips dangling FKs, and rewrites business→cfg refs to the
    # canonical <cfg>(<cfg>_id) form. Runs BEFORE the legacy validator so the
    # repair loop doesn't waste 3 LLM calls (~$0.50, ~3 min) fixing structural
    # bugs the normaliser handles in milliseconds.
    schema, norm_stats = _normalise_in_loop(schema)
    # Only log counters that actually fired; skip the line entirely otherwise.
    active_stats = {k: v for k, v in norm_stats.items() if v}
    if active_stats:
        _log("normalise", " ".join(f"{k}={v}" for k, v in active_stats.items()))

    # Step 4: Validate + Repair loop (up to max_repairs LLM repair attempts)
    validation = validate_all(schema)
    repair_attempts = 0

    while not validation["is_valid"] and repair_attempts < max_repairs:
        repair_attempts += 1
        _log("validate", f"FAIL -- {len(validation['errors'])} errors, repair attempt {repair_attempts}")
        schema, repair_result = _step_repair(
            client, schema, validation["errors"],
            plan_result.output, user_input, use_model, repair_attempts
        )
        _notify(repair_result)
        # Normalise again after each LLM repair — same reason as above.
        schema, _ = _normalise_in_loop(schema)
        validation = validate_all(schema)

    if validation["is_valid"]:
        _log("validate", "PASS")
    else:
        _log("validate", f"FAIL after {repair_attempts} repairs -- {len(validation['errors'])} errors remain")

    # Step 5: Render SQL
    sql_str = render_sql(schema)
    json_str = json.dumps(schema.to_json_dict(), indent=2)

    elapsed = round(time.time() - overall_start, 1)

    return V2Result(
        schema=schema,
        json_str=json_str,
        sql_str=sql_str,
        validation=validation,
        elapsed_seconds=elapsed,
        steps=steps,
        repair_attempts=repair_attempts,
    )
