"""Mantara — Web UI for schema generation and comparison.

Run: streamlit run app.py
"""

import streamlit as st
import json
import time
import pandas as pd

from generator import generate
from generator_v2 import generate_v2
from comparator import compare_schemas, assess_schema
from config import MODEL, BACKEND, OLLAMA_MODEL
from fsd_analyzer import analyze_fsd, enrich_user_input, FSDAnalysis
from business_validator import (
    validate_naming,
    validate_submenu_ids,
    validate_table_structure,
    validate_ref_consistency,
    validate_referenced_tables,
    validate_postgresql_syntax,
    validate_numeric_types,
    validate_check_constraints,
    validate_schema_compliance,
    validate_submenu_granularity,
    validate_entity_coverage,
    validate_generic_names,
    validate_enum_usage,
)

# Configure the Streamlit page; must be the first st.* call in the script.
st.set_page_config(
    page_title="Mantara — Schema Generator",
    page_icon="🏗️",
    layout="wide",
)

st.title("Mantara")

# Backend indicator
# Ollama and llama.cpp backends share OLLAMA_MODEL; only the OpenAI backend uses MODEL.
_backend_model = OLLAMA_MODEL if BACKEND in ("ollama", "llamacpp") else MODEL
_backend_label = {"openai": "OpenAI", "ollama": "Ollama", "llamacpp": "llama.cpp (Local)"}.get(BACKEND, BACKEND)
st.caption(f"Backend: **{_backend_label}** · Model: `{_backend_model}`")

st.markdown("**Generate production-ready PostgreSQL schemas from natural language in under 90 seconds.**  \nDescribe a business system → get validated JSON + executable SQL with menus, tables, ENUMs, foreign keys, and CHECK constraints.")

st.divider()

# --- Mode selector ---
# Top-level navigation: each option renders a different page section below.
mode = st.radio(
    "Mode",
    ["Generate Schema", "Compare Schemas", "Architecture"],
    horizontal=True,
)

st.divider()

# ============================================================
# GENERATE SCHEMA MODE
# ============================================================
if mode == "Generate Schema":

    # --- Settings row ---
    col_input, col_settings = st.columns([3, 1])

    with col_settings:
        st.markdown("**Settings**")

        version = st.radio(
            "Pipeline",
            ["v1 — Single Call", "v2 — Multi-Step"],
            help="v1: One LLM call. Fast, good for simple schemas.\nv2: Analyze → Plan → Generate → Repair. Slower, better for complex schemas.",
        )

        # Show all available models — both OpenAI and local Qwen
        _all_models = ["gpt-4o", "gpt-4o-mini", "qwen2.5-coder:32b-instruct-q4_0", "qwen2.5-coder:7b"]
        model = st.selectbox("Model", _all_models, index=0)
        if "v2" in version:
            if "qwen" in model:
                st.caption(f"v2 uses {model} for all steps (local Ollama inference).")
            else:
                st.caption("v2 uses gpt-4o-mini for analysis, your selected model for planning + generation.")

    with col_input:
        input_method = st.radio(
            "Input method",
            ["Type a description", "Upload FSD / document"],
            horizontal=True,
        )

        is_upload = input_method != "Type a description"

        if not is_upload:
            user_input = st.text_area(
                "Describe your system",
                placeholder="e.g. Build a salon operating system for managing appointments, clients, stylists, services, inventory, billing, and loyalty rewards",
                height=120,
            )
        else:
            uploaded_file = st.file_uploader(
                "Upload a Functional Specification Document",
                type=["txt", "md", "pdf", "docx", "csv"],
                help="Supports .txt, .md, .pdf, .docx, .csv — the full document text is sent as the prompt.",
            )
            user_input = ""
            fsd_analysis = None
            if uploaded_file is not None:
                if uploaded_file.name.endswith(".pdf"):
                    try:
                        import fitz  # PyMuPDF
                        raw_bytes = uploaded_file.read()
                        doc = fitz.open(stream=raw_bytes, filetype="pdf")
                        pages = []
                        for page in doc:
                            # Use layout-preserving extraction for better structure
                            text = page.get_text("text", sort=True)
                            if text.strip():
                                pages.append(f"--- Page {page.number + 1} ---\n{text}")
                        user_input = "\n\n".join(pages)
                        doc.close()
                    except ImportError:
                        st.error("PDF support requires PyMuPDF. Run: `pip install PyMuPDF`")
                elif uploaded_file.name.endswith(".docx"):
                    try:
                        import docx
                        from io import BytesIO
                        doc = docx.Document(BytesIO(uploaded_file.read()))
                        parts = []
                        for para in doc.paragraphs:
                            text = para.text.strip()
                            if not text:
                                continue
                            # Preserve heading structure from DOCX styles
                            style = para.style.name if para.style else ""
                            if "Heading 1" in style:
                                parts.append(f"\n# {text}")
                            elif "Heading 2" in style:
                                parts.append(f"\n## {text}")
                            elif "Heading 3" in style:
                                parts.append(f"\n### {text}")
                            elif "Heading 4" in style:
                                parts.append(f"\n#### {text}")
                            elif "List" in style:
                                parts.append(f"- {text}")
                            else:
                                parts.append(text)
                        # Also extract tables from DOCX, rendered as markdown tables.
                        # First row is treated as the header row.
                        for table in doc.tables:
                            if not table.rows:
                                # Guard: an empty table would make rows[0] raise IndexError.
                                continue
                            header = [cell.text.strip() for cell in table.rows[0].cells]
                            parts.append(f"\n| {' | '.join(header)} |")
                            parts.append(f"| {' | '.join(['---'] * len(header))} |")
                            for row in table.rows[1:]:
                                cells = [cell.text.strip() for cell in row.cells]
                                parts.append(f"| {' | '.join(cells)} |")
                        # Join extracted paragraphs and tables into one document string.
                        user_input = "\n".join(parts)
                    except ImportError:
                        st.error("DOCX support requires python-docx. Run: `pip install python-docx`")
                elif uploaded_file.name.endswith(".csv"):
                    import csv
                    from io import StringIO
                    raw = uploaded_file.read().decode("utf-8")
                    reader = csv.DictReader(StringIO(raw))
                    rows = list(reader)
                    # Convert CSV rows to a structured FSD prompt
                    lines = ["Generate a Mantara schema from this Functional Specification Document:\n"]
                    current_module = ""
                    for row in rows:
                        module = row.get("Module", "").strip()
                        if module and module != current_module:
                            lines.append(f"\n## Module: {module}")
                            current_module = module
                        feature = row.get("Feature", "").strip()
                        desc = row.get("Description", "").strip()
                        entities = row.get("Entities", "").strip()
                        relationships = row.get("Relationships", "").strip()
                        enums = row.get("ENUM Candidates", "").strip()
                        rules = row.get("Business Rules", row.get("Rules", "")).strip() if "Business Rules" in row or "Rules" in row else ""
                        priority = row.get("Priority", "").strip()
                        if feature:
                            lines.append(f"\n### {feature}")
                        if priority:
                            lines.append(f"Priority: {priority}")
                        if desc:
                            lines.append(f"{desc}")
                        if entities:
                            lines.append(f"Entities: {entities}")
                        if relationships:
                            lines.append(f"Relationships: {relationships}")
                        if enums:
                            lines.append(f"ENUMs: {enums}")
                        if rules:
                            lines.append(f"Business Rules: {rules}")
                    user_input = "\n".join(lines)
                else:
                    # Plain-text formats (.txt, .md): use the raw decoded content as-is.
                    user_input = uploaded_file.read().decode("utf-8")

                if user_input:
                    # --- FSD Analysis ---
                    # Heuristic FSD detection; stored in session_state so a later
                    # rerun (e.g. the Generate click) can reuse the same analysis.
                    fsd_analysis = analyze_fsd(user_input)
                    st.session_state["fsd_analysis"] = fsd_analysis

                    if fsd_analysis.is_fsd:
                        st.success(
                            f"FSD detected (confidence: {fsd_analysis.confidence:.0%}) -- "
                            f"{fsd_analysis.summary_text}"
                        )

                        # Summary display: show what was extracted
                        with st.expander("FSD Analysis Summary", expanded=True):
                            # Metrics row
                            a_cols = st.columns(5)
                            a_cols[0].metric("Modules", fsd_analysis.module_count)
                            a_cols[1].metric("Features", len(fsd_analysis.features))
                            a_cols[2].metric("Entities", fsd_analysis.entity_count)
                            a_cols[3].metric("ENUMs", len(fsd_analysis.enum_candidates))
                            a_cols[4].metric("Rules", len(fsd_analysis.business_rules))

                            # Modules and features
                            if fsd_analysis.modules:
                                st.markdown("**Modules** (will become menus):")
                                for mod in fsd_analysis.modules:
                                    st.markdown(f"- {mod}")

                            if fsd_analysis.features:
                                st.markdown("**Features** (will become submenus):")
                                # Three-column layout; features are distributed round-robin.
                                feat_cols = st.columns(3)
                                for i, feat in enumerate(fsd_analysis.features):
                                    feat_cols[i % 3].markdown(f"- {feat}")

                            # Entities with attributes
                            if fsd_analysis.entities:
                                st.markdown("**Entities** (will become tables):")
                                for ent in fsd_analysis.entities:
                                    # Show at most the first 8 attributes per entity.
                                    if ent.attributes:
                                        st.markdown(f"- **{ent.name}**: {', '.join(ent.attributes[:8])}")
                                    else:
                                        st.markdown(f"- **{ent.name}**")

                            # Enum candidates
                            if fsd_analysis.enum_candidates:
                                st.markdown("**ENUM Candidates:**")
                                for enum in fsd_analysis.enum_candidates:
                                    st.markdown(f"- `{enum.name}`: {' | '.join(enum.values[:8])}")

                            # Relationships (capped at 15 to keep the panel short)
                            if fsd_analysis.relationships:
                                st.markdown("**Relationships:**")
                                for rel in fsd_analysis.relationships[:15]:
                                    card = f" ({rel.cardinality})" if rel.cardinality else ""
                                    st.markdown(f"- {rel.from_entity} -> {rel.to_entity}{card}")

                            # Business rules (first 10, with a "more" note for the rest)
                            if fsd_analysis.business_rules:
                                st.markdown(f"**Business Rules** ({len(fsd_analysis.business_rules)} detected):")
                                for rule in fsd_analysis.business_rules[:10]:
                                    st.markdown(f"- {rule.rule[:120]}")
                                if len(fsd_analysis.business_rules) > 10:
                                    st.caption(f"... and {len(fsd_analysis.business_rules) - 10} more rules")

                    with st.expander(f"Raw Document Preview ({len(user_input):,} chars)", expanded=False):
                        st.text(user_input[:3000] + ("..." if len(user_input) > 3000 else ""))

    generate_btn = st.button("Generate Schema", type="primary", use_container_width=True)

    # --- Generate ---
    if generate_btn and user_input.strip():

        use_v2 = "v2" in version

        # Enrich input with FSD analysis if available
        fsd_analysis = st.session_state.get("fsd_analysis")
        if fsd_analysis and fsd_analysis.is_fsd:
            enriched_input = enrich_user_input(user_input.strip(), fsd_analysis)
        else:
            # Run analysis on typed input too (may be a pasted FSD)
            # 800 chars is the heuristic threshold for "long enough to be an FSD".
            if len(user_input.strip()) > 800:
                fsd_analysis = analyze_fsd(user_input.strip())
                if fsd_analysis.is_fsd:
                    enriched_input = enrich_user_input(user_input.strip(), fsd_analysis)
                else:
                    enriched_input = user_input.strip()
            else:
                enriched_input = user_input.strip()

        # Auto-switch to v2 for large inputs with warning
        if not use_v2 and len(user_input) > 2000:
            st.warning(
                f"Input is large ({len(user_input):,} chars). Switching to **v2 (Multi-Step)** — "
                f"v1 may timeout or produce incomplete results for complex inputs."
            )
            use_v2 = True

        if use_v2:
            # ---- V2: Multi-step with live progress ----
            status = st.status("Running v2 pipeline...", expanded=True)

            with status:
                step_container = st.empty()
                timer_display = st.empty()
                pipeline_start = time.time()

                step_logs = []
                # Names of the pipeline steps that follow "Input Analysis",
                # indexed by how many steps have already completed.
                # (Previously rebuilt unused inside the callback while the same
                # names were duplicated in two hardcoded branches.)
                remaining = ["Schema Planning", "JSON Generation"]

                def on_step(name, step_result):
                    """Progress callback from generate_v2: redraw finished steps,
                    the next pending step, and the elapsed-time metric."""
                    step_logs.append(step_result)
                    elapsed_so_far = round(time.time() - pipeline_start, 1)
                    lines = []
                    for s in step_logs:
                        lines.append(f"✅ **{s.name}** — {s.latency}s ({s.tokens:,} tokens)")
                    done = len(step_logs)
                    if 1 <= done <= len(remaining):
                        lines.append(f"⏳ **{remaining[done - 1]}** — running...")
                    step_container.markdown("\n\n".join(lines))
                    timer_display.metric("Elapsed", f"{elapsed_so_far}s")

                # Show initial state
                step_container.markdown("⏳ **Input Analysis** — running...")
                timer_display.metric("Elapsed", "0.0s")

                try:
                    result = generate_v2(enriched_input, model=model, on_step=on_step)

                    # Final step display: all LLM steps plus the deterministic SQL render.
                    lines = []
                    for s in result.steps:
                        lines.append(f"✅ **{s.name}** — {s.latency}s ({s.tokens:,} tokens)")
                    lines.append("✅ **SQL Rendering** — deterministic (instant)")
                    step_container.markdown("\n\n".join(lines))
                    timer_display.metric("Elapsed", f"{result.elapsed_seconds}s")

                    if result.validation["is_valid"]:
                        status.update(label=f"Pipeline complete — {result.elapsed_seconds}s ✅", state="complete")
                    else:
                        status.update(label=f"Pipeline complete with warnings — {result.elapsed_seconds}s", state="complete")

                    # Shared names consumed by the common output section below.
                    schema = result.schema
                    validation = result.validation

                except Exception as e:
                    status.update(label="Error", state="error")
                    st.error(f"Error: {type(e).__name__}: {e}")
                    st.stop()

        else:
            # ---- V1: Single call with timer ----
            status = st.status("Running v1 pipeline...", expanded=True)

            with status:
                step_container = st.empty()
                timer_display = st.empty()

                step_container.markdown("⏳ **LLM Call** (Structured Outputs) — running...")
                timer_display.metric("Elapsed", "0.0s")

                try:
                    start = time.time()
                    result_dict = result = generate(enriched_input, model=model)
                    elapsed = round(time.time() - start, 1)

                    step_container.markdown(
                        f"✅ **LLM Call** — {elapsed}s\n\n"
                        f"✅ **Business Validation** — instant\n\n"
                        f"✅ **SQL Rendering** — deterministic (instant)"
                    )
                    timer_display.metric("Elapsed", f"{elapsed}s")

                    schema = result_dict["schema"]
                    validation = result_dict["validation"]
                    result = type("V1Result", (), {
                        "schema": schema,
                        "json_str": result_dict["json_str"],
                        "sql_str": result_dict["sql_str"],
                        "validation": validation,
                        "elapsed_seconds": elapsed,
                        "steps": [],
                        "repair_attempts": 0,
                    })()

                    if validation["is_valid"]:
                        status.update(label=f"Pipeline complete — {elapsed}s ✅", state="complete")
                    else:
                        status.update(label=f"Pipeline complete with warnings — {elapsed}s", state="complete")

                except Exception as e:
                    status.update(label="Error", state="error")
                    st.error(f"Error: {type(e).__name__}: {e}")
                    st.stop()

        # ---- Common output section ----
        # From here on `schema`, `validation`, and `result` have the same shape
        # for both pipelines (v1 wraps its dict in a duck-typed shim above).

        # Count tables/columns by walking the menu → submenu → table hierarchy.
        table_count = sum(
            len(sub.tables or [])
            for menu in schema.menus
            for sub in menu.submenus
        )
        enum_count = len(schema.enum_types or [])
        comment_count = sum(
            1 for menu in schema.menus
            for sub in menu.submenus
            for table in (sub.tables or [])
            for col in table.columns
            if col.comment
        )
        # NOTE(review): total_tokens is not used anywhere in the visible code —
        # confirm it is consumed later in the file before removing.
        total_tokens = sum(s.tokens for s in result.steps) if hasattr(result, "steps") and result.steps else 0

        # Metrics row (v2 gets an extra "LLM Calls" column)
        st.divider()
        n_cols = 6 if use_v2 else 5
        cols = st.columns(n_cols)
        cols[0].metric("Tables", table_count)
        cols[1].metric("ENUMs", enum_count)
        cols[2].metric("Menus", len(schema.menus))
        cols[3].metric("Col Comments", comment_count)
        cols[4].metric("Total Time", f"{result.elapsed_seconds}s")
        if use_v2:
            cols[5].metric("LLM Calls", len(result.steps))

        # ---- Validation Checklist ----
        st.divider()
        st.subheader("Validation Checklist")

        # Hard checks: each validator returns a list of error strings (empty = pass).
        checks = [
            ("Snake_case naming (schema, tables, columns)", validate_naming),
            ("Submenu ID convention (menu_id × 100 + seq)", validate_submenu_ids),
            ("Table structure (id first, submenu_id second)", validate_table_structure),
            ("REFERENCES schema consistency", validate_ref_consistency),
            ("Referenced tables exist", validate_referenced_tables),
            ("PostgreSQL syntax (no MySQL-isms)", validate_postgresql_syntax),
            ("NUMERIC for money/quantities (no FLOAT)", validate_numeric_types),
            ("CHECK constraints on amounts/quantities", validate_check_constraints),
            ("ENUM usage for status/type/role columns", validate_enum_usage),
            ("JSON Schema compliance (mantara.schema.v1)", validate_schema_compliance),
        ]

        all_check_errors = []
        for check_name, check_fn in checks:
            errors = check_fn(schema)
            if errors:
                st.markdown(f"❌ **{check_name}** — {len(errors)} issue(s)")
                for e in errors:
                    st.caption(f"    → {e}")
                all_check_errors.extend(errors)
            else:
                st.markdown(f"✅ **{check_name}**")

        # --- Decomposition Quality Warnings ---
        # Soft checks: reported as warnings, not failures.
        decomp_checks = [
            ("Submenu granularity (2-6 per menu, tables present)", validate_submenu_granularity),
            ("Entity coverage (no duplicate tables across submenus)", validate_entity_coverage),
            ("Naming quality (no generic menu/submenu names)", validate_generic_names),
        ]

        decomp_warnings = []
        for check_name, check_fn in decomp_checks:
            warnings = check_fn(schema)
            if warnings:
                st.markdown(f"⚠️ **{check_name}** — {len(warnings)} warning(s)")
                for w in warnings:
                    st.caption(f"    → {w}")
                decomp_warnings.extend(warnings)
            else:
                st.markdown(f"✅ **{check_name}**")

        # Show assumptions and open questions if present
        if schema.assumptions:
            with st.expander(f"Assumptions ({len(schema.assumptions)})", expanded=False):
                for a in schema.assumptions:
                    st.markdown(f"- {a}")

        if schema.open_questions:
            with st.expander(f"Open Questions ({len(schema.open_questions)})", expanded=False):
                for q in schema.open_questions:
                    st.markdown(f"- {q}")

        # Structural checks (not from validator functions)
        # NOTE(review): sql_text is recomputed identically below at the output
        # tabs; the "+ 2" in the COMMENT check presumably accounts for the menu
        # and submenu tables — confirm against the SQL renderer.
        sql_text = result.sql_str if hasattr(result, "sql_str") else result_dict["sql_str"]
        st.markdown(f"✅ **CREATE SCHEMA present**" if f"CREATE SCHEMA {schema.schema_name};" in sql_text else "❌ **CREATE SCHEMA missing**")
        st.markdown(f"✅ **Menu table with INSERT data**" if "INSERT INTO" in sql_text and ".menu" in sql_text else "❌ **Menu INSERT missing**")
        st.markdown(f"✅ **COMMENT ON TABLE for every table**" if sql_text.count("COMMENT ON TABLE") >= table_count + 2 else f"⚠️ **Some tables may be missing COMMENT ON TABLE**")

        # Test it yourself
        st.divider()
        st.subheader("Test It Yourself")
        st.markdown(
            "Copy the SQL below and paste it into an online PostgreSQL executor to verify it runs clean:"
        )

        test_col1, test_col2, test_col3 = st.columns(3)
        with test_col1:
            st.link_button("Open DB Fiddle (PostgreSQL 16)", "https://www.db-fiddle.com/")
        with test_col2:
            st.link_button("Open OneCompiler", "https://onecompiler.com/postgresql")
        with test_col3:
            st.link_button("Open Programiz", "https://www.programiz.com/postgresql/online-compiler/")

        # ---- Output tabs ----
        st.divider()
        tab_sql, tab_json = st.tabs(["SQL Output", "JSON Output"])

        sql_text = result.sql_str if hasattr(result, "sql_str") else result_dict["sql_str"]
        json_text = result.json_str if hasattr(result, "json_str") else result_dict["json_str"]

        with tab_sql:
            st.code(sql_text, language="sql")
            col_dl, col_copy = st.columns(2)
            with col_dl:
                st.download_button(
                    "Download .sql",
                    sql_text,
                    file_name=f"{schema.schema_name}_schema.sql",
                    mime="text/plain",
                )

        with tab_json:
            st.code(json_text, language="json")
            col_dl2, col_copy2 = st.columns(2)
            with col_dl2:
                st.download_button(
                    "Download .json",
                    json_text,
                    file_name=f"{schema.schema_name}_schema.json",
                    mime="application/json",
                )

        # v2 step details (per-step latency, tokens, and raw output preview)
        if use_v2 and result.steps:
            st.divider()
            with st.expander("Pipeline Step Details", expanded=False):
                for s in result.steps:
                    st.markdown(f"**{s.name}** — {s.latency}s, {s.tokens:,} tokens")
                    if s.output and s.output != "(structured output)":
                        st.code(s.output[:3000], language="json")
                if result.repair_attempts > 0:
                    st.info(f"Repair loop ran {result.repair_attempts} time(s)")

        # ============================================================
        # SELF-ASSESSMENT SCORECARD
        # ============================================================
        st.divider()
        st.subheader("Quality Self-Assessment")
        st.caption("The generated schema is automatically rated against 10 production-quality criteria.")

        assess_status = st.status("Running self-assessment...", expanded=True)
        with assess_status:
            assess_step = st.empty()
            assess_step.markdown("⏳ **Quality Assessment** (Structured Outputs) — running...")

            try:
                # Extra LLM call that grades the generated SQL against the
                # original user description.
                assessment, assess_elapsed, assess_tokens = assess_schema(
                    sql_text,
                    user_description=user_input.strip(),
                    model=model,
                )
                assess_step.markdown(f"✅ **Quality Assessment** — {assess_elapsed}s ({assess_tokens:,} tokens)")
                assess_status.update(label=f"Self-assessment complete — {assess_elapsed}s ✅", state="complete")

                # Store in session state for regenerate
                st.session_state["last_assessment"] = assessment
                st.session_state["last_user_input"] = user_input.strip()
                st.session_state["last_model"] = model
                st.session_state["last_version"] = version

            except Exception as e:
                assess_status.update(label="Assessment error", state="error")
                st.error(f"Self-assessment failed: {type(e).__name__}: {e}")
                assessment = None  # sentinel: the scorecard below is skipped

        if assessment:
            # Grade + Overall metrics
            grade_color = {
                "A+": "🟢", "A": "🟢", "B+": "🟡", "B": "🟡",
                "C+": "🟠", "C": "🟠", "D": "🔴", "F": "🔴",
            }
            grade_icon = grade_color.get(assessment.grade, "⚪")

            g_cols = st.columns(4)
            g_cols[0].metric("Overall Score", f"{assessment.overall:.1f} / 10")
            g_cols[1].metric("Grade", f"{grade_icon} {assessment.grade}")
            g_cols[2].metric("Assessment Time", f"{assess_elapsed}s")
            g_cols[3].metric("Assessment Tokens", f"{assess_tokens:,}")

            st.markdown(f"**{assessment.summary}**")

            # Scorecard table
            st.markdown("---")
            scorecard_data = []
            for c in assessment.criteria:
                scorecard_data.append({
                    "Criterion": c.criterion,
                    "Score": c.score,
                    "Notes": c.notes,
                    "How to Improve": c.improvement,
                })

            df_assess = pd.DataFrame(scorecard_data)

            def color_score(val):
                """Pandas Styler callback: color cell background by score band
                (green ≥ 8.0, yellow ≥ 6.0, red below; non-numeric unstyled)."""
                if isinstance(val, (int, float)):
                    if val >= 8.0:
                        return "background-color: #d4edda"  # green
                    elif val >= 6.0:
                        return "background-color: #fff3cd"  # yellow
                    else:
                        return "background-color: #f8d7da"  # red
                return ""

            styled_assess = df_assess.style.map(
                color_score, subset=["Score"]
            ).format({"Score": "{:.1f}"})
            st.dataframe(styled_assess, use_container_width=True, hide_index=True)

            # Strengths & Weaknesses
            str_col, weak_col = st.columns(2)
            with str_col:
                st.markdown("**Strengths**")
                for s in assessment.strengths:
                    st.markdown(f"- {s}")
            with weak_col:
                st.markdown("**Weaknesses**")
                for w in assessment.weaknesses:
                    st.markdown(f"- {w}")

            # Verdict
            st.markdown("---")
            st.markdown(f"**Verdict:** {assessment.verdict}")

            # Download assessment JSON (assessment is a pydantic model)
            assess_json = assessment.model_dump_json(indent=2)
            st.download_button(
                "Download Assessment JSON",
                assess_json,
                file_name="schema_self_assessment.json",
                mime="application/json",
                use_container_width=True,
            )

            # ============================================================
            # REGENERATE BUTTON — feeds weaknesses back to improve
            # ============================================================
            if assessment.overall < 9.5 and assessment.weaknesses:
                st.divider()
                st.subheader("Improve Schema")

                # Build improvement hints from low-scoring criteria
                low_criteria = [c for c in assessment.criteria if c.score < 8.0]
                if low_criteria:
                    st.markdown("The following criteria scored below 8.0 and can be improved:")
                    for c in low_criteria:
                        st.markdown(f"- **{c.criterion}** ({c.score:.1f}/10): {c.improvement}")

                st.markdown("Click **Regenerate** to re-generate the schema with these improvements applied.")

                if st.button("Regenerate (with improvements)", type="primary", use_container_width=True):
                    # Build an enhanced prompt that folds the assessment feedback back
                    # into the original request: one bullet per improvement hint for any
                    # criterion scoring below 8.0, plus the assessor's listed weaknesses.
                    improvement_instructions = "\n".join(
                        f"- {c.improvement}" for c in assessment.criteria if c.score < 8.0
                    )
                    weakness_notes = "\n".join(f"- {w}" for w in assessment.weaknesses)

                    # The enhanced input is the user's original description followed by
                    # explicit "fix these" instructions for the regeneration call.
                    enhanced_input = (
                        f"{st.session_state['last_user_input']}\n\n"
                        f"IMPORTANT — Apply these mandatory improvements to the schema:\n"
                        f"{improvement_instructions}\n\n"
                        f"Address these weaknesses:\n"
                        f"{weakness_notes}\n\n"
                        f"Ensure the schema scores 8.0+ on ALL of these criteria: "
                        f"Domain Coverage, Data Model Clarity, Normalization, Referential Integrity, "
                        f"Data Type Precision, Constraint Rigor, ENUM Design, Scalability & Indexing, "
                        f"PostgreSQL Best Practices, Completeness & Production-Readiness."
                    )

                    # Reuse the pipeline/model from the original run, stored in session
                    # state by the generation step (presumably as "v1"/"v2" strings —
                    # TODO confirm the exact values written by the generator UI).
                    regen_use_v2 = "v2" in st.session_state.get("last_version", "v1")
                    _default_model = OLLAMA_MODEL if BACKEND in ("ollama", "llamacpp") else "gpt-4o"
                    regen_model = st.session_state.get("last_model", _default_model)

                    # Live status panel for the regeneration call.
                    regen_status = st.status("Regenerating with improvements...", expanded=True)
                    with regen_status:
                        regen_step = st.empty()
                        regen_step.markdown("⏳ **Regenerating schema** — running...")

                        try:
                            # V2 returns an object with attributes; V1 returns a dict.
                            if regen_use_v2:
                                regen_result = generate_v2(enhanced_input, model=regen_model)
                                regen_sql = regen_result.sql_str
                                regen_json = regen_result.json_str
                            else:
                                regen_dict = generate(enhanced_input, model=regen_model)
                                regen_sql = regen_dict["sql_str"]
                                regen_json = regen_dict["json_str"]

                            regen_step.markdown("✅ **Schema regenerated**")
                            regen_status.update(label="Regeneration complete ✅", state="complete")

                        except Exception as e:
                            # Any backend failure: surface the error and halt the script
                            # run so no stale output is rendered below.
                            regen_status.update(label="Regeneration error", state="error")
                            st.error(f"Regeneration failed: {type(e).__name__}: {e}")
                            st.stop()

                    # Show regenerated output: SQL and JSON in separate tabs, each with
                    # its own download button.
                    st.subheader("Improved Schema")
                    regen_tab_sql, regen_tab_json = st.tabs(["SQL Output", "JSON Output"])

                    with regen_tab_sql:
                        st.code(regen_sql, language="sql")
                        st.download_button(
                            "Download Improved .sql",
                            regen_sql,
                            file_name="improved_schema.sql",
                            mime="text/plain",
                        )

                    with regen_tab_json:
                        st.code(regen_json, language="json")
                        st.download_button(
                            "Download Improved .json",
                            regen_json,
                            file_name="improved_schema.json",
                            mime="application/json",
                        )

                    # Re-assess the improved schema with the same assessor used for the
                    # original, so the before/after scores are comparable.
                    st.markdown("---")
                    st.subheader("Improved Schema Assessment")
                    regen_assess_status = st.status("Assessing improved schema...", expanded=True)
                    with regen_assess_status:
                        regen_assess_step = st.empty()
                        regen_assess_step.markdown("⏳ **Quality Assessment** — running...")

                        try:
                            # assess_schema returns (assessment, elapsed_seconds, token_count).
                            regen_assessment, regen_a_elapsed, regen_a_tokens = assess_schema(
                                regen_sql,
                                user_description=st.session_state["last_user_input"],
                                model=regen_model,
                            )
                            regen_assess_step.markdown(
                                f"✅ **Quality Assessment** — {regen_a_elapsed}s ({regen_a_tokens:,} tokens)"
                            )
                            regen_assess_status.update(
                                label=f"Assessment complete — {regen_a_elapsed}s ✅", state="complete"
                            )
                        except Exception as e:
                            # Non-fatal: set the sentinel to None so the comparison
                            # section below is skipped instead of crashing the page.
                            regen_assess_status.update(label="Assessment error", state="error")
                            st.error(f"Re-assessment failed: {type(e).__name__}: {e}")
                            regen_assessment = None

                    if regen_assessment:
                        # Before/after comparison: overall delta plus per-criterion table.
                        delta = regen_assessment.overall - assessment.overall
                        delta_str = f"+{delta:.1f}" if delta > 0 else f"{delta:.1f}"
                        # grade_color (defined earlier in the file) presumably maps a
                        # letter grade to an emoji; "⚪" is the explicit fallback.
                        regen_grade_icon = grade_color.get(regen_assessment.grade, "⚪")

                        rg_cols = st.columns(4)
                        rg_cols[0].metric("New Overall", f"{regen_assessment.overall:.1f} / 10", delta=delta_str)
                        rg_cols[1].metric("New Grade", f"{regen_grade_icon} {regen_assessment.grade}")
                        rg_cols[2].metric("Previous", f"{assessment.overall:.1f} / 10")
                        rg_cols[3].metric("Improvement", delta_str)

                        st.markdown(f"**{regen_assessment.summary}**")

                        # Side-by-side score comparison. zip() pairs criteria positionally
                        # — assumes both assessments return criteria in the same order;
                        # TODO confirm the assessor guarantees stable ordering.
                        compare_data = [
                            {
                                "Criterion": orig_c.criterion,
                                "Before": orig_c.score,
                                "After": new_c.score,
                                "Change": new_c.score - orig_c.score,
                                "Notes": new_c.notes,
                            }
                            for orig_c, new_c in zip(assessment.criteria, regen_assessment.criteria)
                        ]

                        df_compare = pd.DataFrame(compare_data)

                        def color_change(val):
                            """Cell styler for the 'Change' column: green background for a
                            positive delta, red for a negative one, no styling for zero or
                            non-numeric values."""
                            if not isinstance(val, (int, float)):
                                return ""
                            if val > 0:
                                return "background-color: #d4edda; color: #155724"
                            if val < 0:
                                return "background-color: #f8d7da; color: #721c24"
                            return ""

                        # Style the comparison table: color the Change column by sign and
                        # the Before/After columns via color_score (defined earlier in the
                        # file, not shown here — presumably maps a score to a color).
                        styled_compare = df_compare.style.map(
                            color_change, subset=["Change"]
                        ).map(
                            color_score, subset=["Before", "After"]
                        ).format({"Before": "{:.1f}", "After": "{:.1f}", "Change": "{:+.1f}"})
                        st.dataframe(styled_compare, use_container_width=True, hide_index=True)

                        st.markdown(f"**Verdict:** {regen_assessment.verdict}")

                        # Offer the raw assessment (Pydantic model) as downloadable JSON.
                        regen_assess_json = regen_assessment.model_dump_json(indent=2)
                        st.download_button(
                            "Download Improved Assessment JSON",
                            regen_assess_json,
                            file_name="improved_schema_assessment.json",
                            mime="application/json",
                            use_container_width=True,
                        )

    # Generate button clicked without a description — prompt instead of calling the LLM.
    elif generate_btn:
        st.warning("Please enter a description first.")

# ============================================================
# COMPARE SCHEMAS MODE
# ============================================================
elif mode == "Compare Schemas":

    # --- Settings ---
    col_compare_main, col_compare_settings = st.columns([3, 1])

    with col_compare_settings:
        st.markdown("**Settings**")
        # Local backends expose only the single configured model; the OpenAI
        # backend offers a choice of two.
        if BACKEND in ("ollama", "llamacpp"):
            compare_model = st.selectbox("Model", [OLLAMA_MODEL], index=0, key="compare_model")
        else:
            compare_model = st.selectbox("Model", ["gpt-4o", "gpt-4o-mini"], index=0, key="compare_model")

    with col_compare_main:
        st.markdown("Paste two SQL schemas side by side to get a quality scorecard with 10 criteria scored /10.")

    # --- Schema input columns ---
    # Two side-by-side panes, each with a free-text label and a SQL paste area.
    col_v1, col_v2 = st.columns(2)

    with col_v1:
        label_v1 = st.text_input("Label", value="V1", key="label_v1")
        schema_v1 = st.text_area(
            "Paste SQL schema",
            placeholder="CREATE TABLE ...",
            height=300,
            key="schema_v1",
        )

    with col_v2:
        label_v2 = st.text_input("Label", value="V2", key="label_v2")
        schema_v2 = st.text_area(
            "Paste SQL schema",
            placeholder="CREATE TABLE ...",
            height=300,
            key="schema_v2",
        )

    compare_btn = st.button("Compare Schemas", type="primary", use_container_width=True)

    # --- Compare ---
    # Run only when the button was clicked AND both panes contain text.
    if compare_btn and schema_v1.strip() and schema_v2.strip():

        status = st.status("Comparing schemas...", expanded=True)

        with status:
            step_container = st.empty()
            step_container.markdown("⏳ **LLM Analysis** (Structured Outputs) — running...")

            try:
                # compare_schemas returns (comparison_model, elapsed_seconds, tokens).
                comparison, elapsed, total_tokens = compare_schemas(
                    schema_v1.strip(),
                    schema_v2.strip(),
                    label_v1=label_v1,
                    label_v2=label_v2,
                    model=compare_model,
                )

                step_container.markdown(f"✅ **LLM Analysis** — {elapsed}s ({total_tokens:,} tokens)")
                status.update(label=f"Comparison complete — {elapsed}s ✅", state="complete")

            except Exception as e:
                # Surface the failure and stop the script run so nothing stale renders.
                status.update(label="Error", state="error")
                st.error(f"Error: {type(e).__name__}: {e}")
                st.stop()

        # ---- Metrics row ----
        st.divider()
        winner = label_v1 if comparison.v1_overall > comparison.v2_overall else (
            label_v2 if comparison.v2_overall > comparison.v1_overall else "Tie"
        )
        m_cols = st.columns(5)
        m_cols[0].metric(f"{label_v1} Overall", f"{comparison.v1_overall:.1f} / 10")
        m_cols[1].metric(f"{label_v2} Overall", f"{comparison.v2_overall:.1f} / 10")
        m_cols[2].metric("Winner", winner)
        m_cols[3].metric("Time", f"{elapsed}s")
        m_cols[4].metric("Tokens", f"{total_tokens:,}")

        # ---- Summary ----
        st.divider()
        st.markdown(f"**Summary:** {comparison.summary}")

        # ---- Scorecard Table ----
        st.divider()
        st.subheader("Scorecard")

        # One row per criterion; user-supplied labels become the score columns.
        # NOTE(review): if the user enters the same label in both panes the two
        # score keys collide and one column silently overwrites the other —
        # consider de-duplicating labels upstream.
        scorecard_data = [
            {
                "Criterion": c.criterion,
                label_v1: c.v1_score,
                label_v2: c.v2_score,
                "Notes": c.notes,
            }
            for c in comparison.criteria
        ]

        df = pd.DataFrame(scorecard_data)

        def highlight_winner(row):
            """Row styler: tint the cell holding the strictly higher score green.

            Ties (and incomparable values) leave both score cells unstyled.
            Reads the pane labels (label_v1 / label_v2) from the enclosing scope.
            """
            styles = [""] * len(row)
            score_a = row[label_v1]
            score_b = row[label_v2]
            if score_a > score_b:
                styles[row.index.get_loc(label_v1)] = "background-color: #d4edda"
            elif score_b > score_a:
                styles[row.index.get_loc(label_v2)] = "background-color: #d4edda"
            return styles

        # Render the scorecard with the winner-per-row highlight and 1-decimal scores.
        styled_df = df.style.apply(highlight_winner, axis=1).format(
            {label_v1: "{:.1f}", label_v2: "{:.1f}"}
        )
        st.dataframe(styled_df, use_container_width=True, hide_index=True)

        # ---- Profiles ----
        # Qualitative strengths/weaknesses for each schema, side by side.
        st.divider()
        st.subheader("Schema Profiles")

        prof_col1, prof_col2 = st.columns(2)

        with prof_col1:
            st.markdown(f"### {comparison.v1_profile.label}")
            st.markdown(f"**Best for:** {comparison.v1_profile.best_for}")
            st.markdown("**Strengths**")
            for s in comparison.v1_profile.strengths:
                st.markdown(f"- {s}")
            st.markdown("**Weaknesses**")
            for w in comparison.v1_profile.weaknesses:
                st.markdown(f"- {w}")

        with prof_col2:
            st.markdown(f"### {comparison.v2_profile.label}")
            st.markdown(f"**Best for:** {comparison.v2_profile.best_for}")
            st.markdown("**Strengths**")
            for s in comparison.v2_profile.strengths:
                st.markdown(f"- {s}")
            st.markdown("**Weaknesses**")
            for w in comparison.v2_profile.weaknesses:
                st.markdown(f"- {w}")

        # ---- Verdict ----
        st.divider()
        st.subheader("Verdict")
        st.markdown(comparison.verdict)

        # ---- Download JSON ----
        # Full comparison (Pydantic model) as downloadable JSON.
        st.divider()
        scorecard_json = comparison.model_dump_json(indent=2)
        st.download_button(
            "Download Scorecard JSON",
            scorecard_json,
            file_name="schema_comparison_scorecard.json",
            mime="application/json",
            use_container_width=True,
        )

    # Compare clicked but at least one pane was empty — prompt the user.
    elif compare_btn:
        st.warning("Please paste both schemas before comparing.")

# ============================================================
# ARCHITECTURE MODE
# ============================================================
elif mode == "Architecture":
    # Static documentation page — no LLM calls, only rendered markdown
    # describing the two generation pipelines.
    st.subheader("Pipeline Architecture")
    st.markdown(
        "How Mantara generates production-ready PostgreSQL schemas from natural language."
    )

    st.divider()

    # --- V1 Pipeline ---
    # ASCII flow diagram inside a fenced code block (rendered monospaced).
    st.markdown("### V1 Pipeline — Single Call")
    st.markdown("""
```
User Input (natural language)
        │
        ▼
┌─────────────────────────────────┐
│  System Prompt (system_prompt.md)│
│  + JSON-Only Addendum           │
│  + Chain-of-Thought guidance    │
└────────────┬────────────────────┘
             │
             ▼
┌─────────────────────────────────┐
│  OpenAI Structured Outputs      │
│  model: gpt-4o                  │
│  response_format: MantaraSchema │
│  temperature: 0.2               │
│  max_tokens: 16,000             │
│                                 │
│  The LLM generates ONLY JSON   │
│  (no SQL) — Pydantic model      │
│  guarantees valid structure     │
└────────────┬────────────────────┘
             │
             ▼
┌─────────────────────────────────┐
│  Business Validator (Python)    │
│  12 validation checks:         │
│  • Snake_case naming            │
│  • Submenu ID convention        │
│  • Table structure (id first)   │
│  • REFERENCES consistency       │
│  • Referenced tables exist      │
│  • PostgreSQL syntax            │
│  • NUMERIC for money            │
│  • CHECK constraints            │
│  • JSON Schema compliance       │
│  • Submenu granularity (NEW)    │
│  • Entity coverage (NEW)        │
│  • Generic name detection (NEW) │
└────────────┬────────────────────┘
             │
             ▼
┌─────────────────────────────────┐
│  Deterministic SQL Renderer     │
│  (renderer.py — zero LLM)      │
│                                 │
│  • Resolves table creation      │
│    order (topological sort)     │
│  • Defers forward-ref FKs      │
│    via ALTER TABLE              │
│  • Auto-injects CHECK >= 0     │
│  • Auto-injects date range     │
│    CHECK constraints            │
│  • Strips MySQL syntax          │
│  • Renders COMMENT ON TABLE     │
│  • Section separators           │
└────────────┬────────────────────┘
             │
             ▼
    JSON + SQL Output
```
    """)

    st.markdown("**Latency:** ~8-15s (single LLM call)")
    st.markdown("**Best for:** Simple to moderate systems (< 2000 char input)")

    st.divider()

    # --- V2 Pipeline ---
    # Multi-step pipeline diagram: analyze → plan → generate → validate/repair → render.
    st.markdown("### V2 Pipeline — Multi-Step (Tree of Thought)")
    st.markdown("""
```
User Input (natural language)
        │
        ▼
┌──────────────────────────────────┐
│  Step 1: INPUT ANALYZER          │
│  model: gpt-4o-mini (cheap)     │
│  max_tokens: 2,000              │
│  temperature: 0.1               │
│                                  │
│  Extracts:                       │
│  • system_type (F/S/Mixed)      │
│  • entities list                 │
│  • relationships                 │
│  • enum_candidates               │
│  • complexity estimate           │
│                                  │
│  Purpose: Structured analysis    │
│  before expensive generation     │
└────────────┬─────────────────────┘
             │
             ▼
┌──────────────────────────────────┐
│  Step 2: SCHEMA PLANNER          │
│  model: gpt-4o (strong)         │
│  max_tokens: 8,000              │
│  temperature: 0.2               │
│                                  │
│  Designs:                        │
│  • Menu → Submenu hierarchy     │
│  • Table assignments per submenu │
│  • Column types + constraints    │
│  • FK relationships              │
│  • ENUM types                    │
│  • Entity decomposition rules    │
│                                  │
│  Purpose: "Think before coding"  │
│  — plan the schema architecture  │
│  before committing to JSON       │
└────────────┬─────────────────────┘
             │
             ▼
┌──────────────────────────────────┐
│  Step 3: JSON GENERATOR          │
│  model: gpt-4o                   │
│  Structured Outputs              │
│  response_format: MantaraSchema  │
│  max_tokens: 16,000             │
│                                  │
│  Follows the plan from Step 2   │
│  to produce exact JSON structure │
└────────────┬─────────────────────┘
             │
             ▼
┌──────────────────────────────────┐
│  Step 4: VALIDATE + REPAIR LOOP  │
│  Python validation (12 checks)   │
│                                  │
│  If errors found:                │
│  ┌────────────────────────┐     │
│  │  Re-call LLM with      │     │
│  │  error feedback         │     │
│  │  (max 2 repair cycles)  │◄───┘
│  └────────────────────────┘     │
│                                  │
│  Errors fed back as explicit    │
│  instructions for the next try   │
└────────────┬─────────────────────┘
             │
             ▼
┌──────────────────────────────────┐
│  Step 5: SQL RENDERER            │
│  (same deterministic renderer)   │
└────────────┬─────────────────────┘
             │
             ▼
    JSON + SQL Output
    + Step-by-step metrics
```
    """)

    st.markdown("**Latency:** ~25-60s (3-5 LLM calls)")
    st.markdown("**Best for:** Complex systems, large FSDs, 2000+ char inputs")

    st.divider()

    # --- Key Design Decisions ---
    # Two-column FAQ-style rationale for the architecture choices above.
    st.markdown("### Key Design Decisions")

    col_arch1, col_arch2 = st.columns(2)

    with col_arch1:
        st.markdown("**Why LLM generates JSON, not SQL?**")
        st.markdown("""
- SQL is fragile — one typo = entire script fails
- JSON is structurally validated by Pydantic before rendering
- Deterministic SQL renderer = zero syntax errors
- Easier to repair — fix JSON structure, not SQL strings
- Business validation catches semantic errors (wrong IDs, missing FKs)
        """)

        st.markdown("**Why Structured Outputs?**")
        st.markdown("""
- OpenAI guarantees response matches Pydantic model
- No parsing/regex needed — response is always valid JSON
- Eliminates "forgot to close a brace" class of errors
- `response_format=MantaraSchema` = type-safe output
        """)

    with col_arch2:
        st.markdown("**Why V2 uses 3 steps?**")
        st.markdown("""
- **Chain-of-Thought**: Planning before generating catches decomposition errors early
- **Cheap analysis first**: gpt-4o-mini for Step 1 saves cost on initial classification
- **Repair loop**: Self-healing — validation errors get fed back as explicit fix instructions
- **Separation of concerns**: Analyzer, Planner, Generator each have focused prompts
        """)

        st.markdown("**New: Decomposition Guards**")
        st.markdown("""
- **Submenu granularity**: Flags menus with too few/many submenus
- **Entity coverage**: Catches duplicate tables across submenus
- **Generic name detection**: Flags "General", "Misc", "Other" names
- **Step-by-step prompt**: LLM mentally lists entities before generating
- **Domain patterns**: System prompt includes reference decompositions for common domains
        """)

    st.divider()

    # --- File Map ---
    # Repository layout shown as a plain-text tree.
    st.markdown("### File Map")
    st.code("""
v1/
├── app.py                 # Streamlit UI (this app)
├── generator.py           # V1 pipeline orchestrator
├── generator_v2.py        # V2 multi-step pipeline
├── llm_client.py          # OpenAI API wrapper + system prompt loading
├── business_validator.py  # 12 validation checks (Python, no LLM)
├── renderer.py            # Deterministic SQL renderer (Python, no LLM)
├── models.py              # Pydantic models (MantaraSchema)
├── comparator.py          # Schema comparison + self-assessment
├── config.py              # Environment config (.env)
├── prompts/
│   ├── system_prompt.md   # Main LLM system prompt (Mantara Standard)
│   ├── comparison_prompt.md
│   └── self_assessment_prompt.md
├── schemas/
│   └── mantara_schema_v1.json  # JSON Schema for validation
└── backends/              # Pluggable LLM backends
    """, language="text")

# --- Footer ---
# Rendered on every page regardless of mode.
st.divider()
st.caption("Mantara Schema Generator — Local Instance")
# Removed leftover debug statement: print('testing')