"""Mantara v8 — DomainKnowledgeBag.

Aggregates domain values from upstream artifacts (Step 2 vision, Step 3 CIR,
Step 4 PRD entities + glossary) into a single lookup-friendly bag.

The bag answers:
  - "What status values does this entity have, per the input?"
  - "What other enum-like values are described for this entity's fields?"
  - "What's the human-readable description of this entity?"

Used by the normaliser to seed cfg_* tables with REAL domain values pulled
from the input (instead of generic placeholders), and to attach glossary
definitions as COMMENT ON TABLE.

Generic across domains — relies only on the CIR/PRD data structure, not
on any domain-specific keywords.
"""
from __future__ import annotations

import json
import re
from pathlib import Path
from typing import Any


def _snake(name: str) -> str:
    """`AdvancedShipmentNotice` → `advanced_shipment_notice`."""
    if not name:
        return ""
    s = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", name)
    return re.sub(r"[^A-Za-z0-9]+", "_", s).strip("_").lower()


# Status field synonyms — any field whose snake_case name is in this set is
# treated as the entity's "status" field, so its values are also registered
# under the per-entity status index used by lookup_status_values().
_STATUS_FIELD_NAMES = {"status", "state", "lifecycle", "stage", "phase"}


class DomainKnowledgeBag:
    """Aggregator + lookup over upstream domain artifacts.

    Indexes three kinds of knowledge, all keyed by snake_case names:
      - field-level domain values: (entity, field) -> enum-like value list,
      - status values: entity -> workflow states + status-synonym field values,
      - entity descriptions: entity -> glossary / JSON Schema text.
    """

    def __init__(self) -> None:
        # (snake_entity, snake_field) -> list[str]   domain values for that field
        self._field_values: dict[tuple[str, str], list[str]] = {}
        # snake_entity -> list[str]                  status values (workflow states or status field values)
        self._status_values: dict[str, list[str]] = {}
        # snake_entity -> str                        glossary definition / description
        self._comments: dict[str, str] = {}
        # snake_entity -> list[str]                  workflow names bound to this entity
        self._workflows: dict[str, list[str]] = {}

    # --------------------------------------------------------------------- #
    # Construction
    # --------------------------------------------------------------------- #
    @classmethod
    def from_run_dir(cls, run_dir: Path) -> "DomainKnowledgeBag":
        """Build a bag by loading every relevant artifact from a Step-1 run dir.

        Tolerant of missing files — each loader is independent.
        """
        run_dir = Path(run_dir)
        bag = cls()
        bag._load_cir(run_dir / "cir" / "enriched_cir.json")
        bag._load_vision(run_dir / "cir" / "vision_extract.json")
        bag._load_entities(run_dir / "prd" / "entities.json")
        bag._load_glossary(run_dir / "prd" / "glossary.json")
        return bag

    @classmethod
    def from_dicts(cls, *,
                   cir: dict | None = None,
                   vision: dict | None = None,
                   entities: dict | None = None,
                   glossary: dict | list | None = None) -> "DomainKnowledgeBag":
        """Construct from already-loaded dicts (used in tests + when caller has them in memory)."""
        bag = cls()
        if cir:
            bag._ingest_cir(cir)
        if vision:
            bag._ingest_vision(vision)
        if entities:
            bag._ingest_entities(entities)
        if glossary:
            bag._ingest_glossary(glossary)
        return bag

    # --------------------------------------------------------------------- #
    # Loaders (file-aware, tolerant)
    # --------------------------------------------------------------------- #
    def _load_json(self, path: Path, ingest: Any) -> None:
        """Read `path` as JSON and feed it to `ingest` (a one-arg callable).

        Deliberately best-effort: a missing, unreadable, or malformed artifact
        simply contributes nothing to the bag (matches from_run_dir's contract
        that each loader is independent).
        """
        if not path.exists():
            return
        try:
            ingest(json.loads(path.read_text(encoding="utf-8")))
        except Exception:
            # Broad on purpose — covers I/O, decode, and shape errors alike.
            pass

    def _load_cir(self, path: Path) -> None:
        self._load_json(path, self._ingest_cir)

    def _load_vision(self, path: Path) -> None:
        self._load_json(path, self._ingest_vision)

    def _load_entities(self, path: Path) -> None:
        self._load_json(path, self._ingest_entities)

    def _load_glossary(self, path: Path) -> None:
        self._load_json(path, self._ingest_glossary)

    # --------------------------------------------------------------------- #
    # Ingesters (one per source)
    # --------------------------------------------------------------------- #
    def _register_field_values(self, ename: str, fname: str, values: Any) -> None:
        """Record domain values for `ename.fname`; mirror into the status index
        when the field name is a status synonym. No-op on empty inputs."""
        cleaned = [str(v).strip() for v in (values or []) if v]
        if not (ename and fname and cleaned):
            return
        self._field_values.setdefault((ename, fname), []).extend(cleaned)
        if fname in _STATUS_FIELD_NAMES:
            self._status_values.setdefault(ename, []).extend(cleaned)

    def _ingest_cir(self, cir: dict) -> None:
        """CIR: workflows -> per-entity status values; entities -> field values."""
        for wf in cir.get("workflows") or []:
            if not isinstance(wf, dict):
                continue
            bound = wf.get("bound_entity") or wf.get("name") or ""
            states = [str(s).strip() for s in (wf.get("states") or []) if s]
            entity_key = _snake(bound)
            if entity_key and states:
                # NOTE(review): workflow step names ("Create", "RCR", ...) can
                # slip in alongside true statuses; there is no reliable generic
                # heuristic to separate them, so keep everything and rely on
                # the lookup-time dedup.
                self._status_values.setdefault(entity_key, []).extend(states)
                wf_name = wf.get("name") or ""
                if wf_name:  # don't index nameless workflows (keeps summary() honest)
                    self._workflows.setdefault(entity_key, []).append(wf_name)

        for ent in cir.get("entities") or []:
            if not isinstance(ent, dict):
                continue
            ename = _snake(ent.get("name") or ent.get("entity_name") or "")
            if not ename:
                continue
            for field in ent.get("fields") or []:
                if not isinstance(field, dict):
                    continue
                self._register_field_values(
                    ename,
                    _snake(field.get("name") or ""),
                    field.get("values") or field.get("enum") or field.get("options"),
                )

    def _ingest_vision(self, vision: dict) -> None:
        """Vision extract: same entities[].fields[].values shape as the CIR."""
        for ent in vision.get("entities") or []:
            if not isinstance(ent, dict):
                continue
            ename = _snake(ent.get("name") or "")
            for field in ent.get("fields") or []:
                if not isinstance(field, dict):
                    continue
                self._register_field_values(
                    ename,
                    _snake(field.get("name") or ""),
                    field.get("values") or field.get("enum") or field.get("options"),
                )

    def _ingest_entities(self, entities_json: dict) -> None:
        """PRD entities.json — JSON Schema with $defs[entity].properties[field].enum."""
        defs = entities_json.get("$defs") or entities_json.get("definitions") or {}
        for ename, espec in defs.items():
            if not isinstance(espec, dict):
                continue
            entity_key = _snake(ename)
            for fname, fspec in (espec.get("properties") or {}).items():
                if not isinstance(fspec, dict):
                    continue
                self._register_field_values(entity_key, _snake(fname), fspec.get("enum"))
            # Entity-level description from the schema; first writer wins.
            desc = espec.get("description") or espec.get("title")
            if entity_key and desc:
                self._comments.setdefault(entity_key, str(desc).strip())

    def _ingest_glossary(self, glossary: Any) -> None:
        """Glossary: {term: defn} dict, [{term, definition}] list, or {"terms": [...]}."""
        terms_list: list = []
        if isinstance(glossary, list):
            terms_list = glossary
        elif isinstance(glossary, dict):
            inner = glossary.get("terms")
            if isinstance(inner, list):
                terms_list = inner
            else:
                # Treat the dict itself as {term: definition}.
                terms_list = [{"term": k, "definition": v} for k, v in glossary.items()]
        for entry in terms_list:
            if isinstance(entry, dict):
                term = entry.get("term") or entry.get("name") or ""
                defn = entry.get("definition") or entry.get("description") or ""
            elif isinstance(entry, (list, tuple)) and len(entry) >= 2:
                term, defn = str(entry[0]), str(entry[1])
            else:
                continue
            if not (term and defn):
                continue
            entity_key = _snake(term)
            # Don't overwrite schema-provided descriptions already in place.
            if entity_key and entity_key not in self._comments:
                self._comments[entity_key] = str(defn).strip()

    # --------------------------------------------------------------------- #
    # Lookups (the public API the normaliser uses)
    # --------------------------------------------------------------------- #
    def lookup_field_values(self, entity: str, field: str) -> list[str]:
        """Return de-duplicated domain values for `entity.field`, or []."""
        key = (_snake(entity), _snake(field))
        return _dedupe_preserve_order(self._field_values.get(key) or [])

    def lookup_status_values(self, entity: str) -> list[str]:
        """Return de-duplicated status values for the entity, or [].

        Pulls from workflow.states + any field whose name is in _STATUS_FIELD_NAMES.
        Falls back to fuzzy entity-name matching (singular/plural pivot plus
        substring containment, first match wins).
        """
        key = _snake(entity)
        vals = list(self._status_values.get(key) or [])
        # Guard on truthy key: with key == "" the substring test `key in k`
        # matches everything and would return arbitrary statuses.
        if not vals and key:
            # removesuffix, not rstrip("s"): rstrip strips ALL trailing "s"
            # characters ("address" -> "addre"), breaking the plural pivot.
            singular = key.removesuffix("s")
            for k, v in self._status_values.items():
                if k == singular or k.removesuffix("s") == key or k in key or key in k:
                    vals.extend(v)
                    break
        return _dedupe_preserve_order(vals)

    def lookup_entity_comment(self, entity: str) -> str | None:
        """Return a human-readable description for the entity, or None."""
        return self._comments.get(_snake(entity))

    def has_data(self) -> bool:
        """True if any values or comments were ingested (workflows alone don't count)."""
        return bool(self._field_values or self._status_values or self._comments)

    def summary(self) -> dict[str, int]:
        """Small counters for logging/diagnostics."""
        return {
            "entities_with_status_values": len(self._status_values),
            "field_values": len(self._field_values),
            "entity_comments": len(self._comments),
            "workflows_indexed": sum(len(v) for v in self._workflows.values()),
        }


def _dedupe_preserve_order(items: list[str]) -> list[str]:
    seen: set[str] = set()
    out: list[str] = []
    for item in items:
        key = str(item).lower().strip()
        if key and key not in seen:
            seen.add(key)
            out.append(item)
    return out


# Public API: only the bag itself — module helpers above are implementation detail.
__all__ = ["DomainKnowledgeBag"]
