"""Shared PRD/image loader used by both backend_generation_system and ui_generation.

Returns text + base64-encoded embedded images so the same PRD can drive backend
generation, page detection, and IR generation without re-parsing.
"""
from __future__ import annotations

import base64
from dataclasses import dataclass, field
from pathlib import Path

# Lower-case file suffixes (leading dot included) accepted as PRD documents.
_ALLOWED_PRD_EXTENSIONS = set((".txt", ".md", ".pdf", ".pptx", ".docx"))

# Media type reported for each standalone image suffix.
_IMAGE_MEDIA_TYPES = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".gif": "image/gif",
    ".webp": "image/webp",
}

# Exactly the suffixes that have a media type above are accepted as images.
_ALLOWED_IMAGE_EXTENSIONS = set(_IMAGE_MEDIA_TYPES)


@dataclass
class PRDImage:
    """One image belonging to a PRD, either embedded in it or supplied next to it."""

    media_type: str  # media type string, e.g. "image/png"
    data: str  # image bytes, base64-encoded

    def to_dict(self) -> dict[str, str]:
        """Return the image as a plain mapping with "media_type" and "data" keys."""
        as_mapping = dict(media_type=self.media_type, data=self.data)
        return as_mapping


@dataclass
class LoadedPRD:
    """Everything obtained from loading one PRD file."""

    path: Path  # location the PRD was loaded from
    text: str  # text extracted from the PRD
    images: list[PRDImage] = field(default_factory=list)  # images found in the PRD
    # Set when step-02 generates a DDL from images; defaults to None.
    ddl_path: Path | None = None

    @property
    def images_as_dicts(self) -> list[dict[str, str]]:
        """Return each image converted via its ``to_dict`` method, in order."""
        converted: list[dict[str, str]] = []
        for image in self.images:
            converted.append(image.to_dict())
        return converted


def _extract_from_pdf(path: Path) -> tuple[str, list[PRDImage]]:
    """Extract page text and embedded images from a PDF file.

    Returns a ``(text, images)`` pair: the non-empty page texts joined by blank
    lines, and one base64-encoded PRDImage per distinct image xref found on any
    page. The document is closed before returning, including when extraction
    fails part-way through.
    Raises ImportError if pymupdf is not installed.
    """
    try:
        import fitz  # pymupdf
    except ImportError as e:
        raise ImportError(
            "pymupdf is required to read PDF files. Install with: pip install pymupdf"
        ) from e

    text_parts: list[str] = []
    images: list[PRDImage] = []
    seen_xrefs: set[int] = set()

    # Use the document as a context manager so it is closed even if a page or
    # image fails to extract; a trailing close() call would be skipped then.
    with fitz.open(str(path)) as doc:
        for page in doc:
            page_text = page.get_text().strip()
            if page_text:
                text_parts.append(page_text)

            for img_info in page.get_images(full=True):
                # The first item of each image entry is its xref; the same xref
                # can be listed on several pages, so emit each one only once.
                xref = img_info[0]
                if xref in seen_xrefs:
                    continue
                seen_xrefs.add(xref)
                base_image = doc.extract_image(xref)
                ext = base_image["ext"].lower()
                # NOTE(review): any extension other than jpg/jpeg is passed
                # through as "image/<ext>" — confirm consumers accept those.
                media_type = "image/jpeg" if ext in {"jpg", "jpeg"} else f"image/{ext}"
                data = base64.standard_b64encode(base_image["image"]).decode("utf-8")
                images.append(PRDImage(media_type=media_type, data=data))

    return "\n\n".join(text_parts), images


def _extract_from_pptx(path: Path) -> tuple[str, list[PRDImage]]:
    """Extract slide text and embedded pictures from a PPTX file.

    Returns a ``(text, images)`` pair. Text is grouped per slide under a
    ``--- Slide N ---`` header (slides without text are omitted) and the groups
    are joined by blank lines. Images are the base64-encoded blobs of picture
    shapes, one per distinct image content across the whole presentation.
    Raises ImportError if python-pptx is not installed.
    """
    try:
        from pptx import Presentation
        from pptx.enum.shapes import MSO_SHAPE_TYPE
    except ImportError as e:
        raise ImportError(
            "python-pptx is required to read PPTX files. Install with: pip install python-pptx"
        ) from e

    prs = Presentation(str(path))
    text_parts: list[str] = []
    images: list[PRDImage] = []
    seen_hashes: set[str] = set()

    for slide_num, slide in enumerate(prs.slides, 1):
        slide_texts: list[str] = []
        for shape in slide.shapes:
            if shape.has_text_frame:
                for para in shape.text_frame.paragraphs:
                    para_text = para.text.strip()
                    if para_text:
                        slide_texts.append(para_text)

            if shape.shape_type != MSO_SHAPE_TYPE.PICTURE:
                continue

            try:
                img = shape.image
            except ValueError:
                # NOTE(review): python-pptx is understood to raise ValueError
                # for a picture without an embedded image (e.g. a linked one);
                # there are no bytes to encode, so skip it instead of aborting.
                continue

            # Deduplicate on the image's content hash rather than id(img):
            # the wrapper object may be created anew on each access, and id()
            # values can be reused once an earlier object has been collected,
            # which could make a distinct picture look like a duplicate.
            digest = img.sha1
            if digest in seen_hashes:
                continue
            seen_hashes.add(digest)
            data = base64.standard_b64encode(img.blob).decode("utf-8")
            images.append(PRDImage(media_type=img.content_type, data=data))

        if slide_texts:
            text_parts.append(f"--- Slide {slide_num} ---\n" + "\n".join(slide_texts))

    return "\n\n".join(text_parts), images


def _extract_from_docx(path: Path) -> tuple[str, list[PRDImage]]:
    """Pull text and embedded images out of a DOCX file.

    Text comes from the document's paragraphs first and then from its tables
    (one line per row, non-empty cells joined with " | "); all lines are joined
    with single newlines. Images are taken from the document part's internal
    image relationships, one per distinct target part.
    Raises ImportError if python-docx is not installed.
    """
    try:
        from docx import Document
    except ImportError as exc:
        raise ImportError(
            "python-docx is required to read DOCX files. Install with: pip install python-docx"
        ) from exc

    document = Document(str(path))

    # Paragraph text, skipping paragraphs that are blank once stripped.
    lines: list[str] = []
    for paragraph in document.paragraphs:
        content = paragraph.text.strip()
        if content:
            lines.append(content)

    # Table rows, each flattened to a single " | "-separated line.
    for table in document.tables:
        for row in table.rows:
            stripped_cells = (cell.text.strip() for cell in row.cells)
            filled = [value for value in stripped_cells if value]
            if filled:
                lines.append(" | ".join(filled))

    # Embedded images: several relationships may point at the same part, so
    # each target part is emitted only once.
    pictures: list[PRDImage] = []
    visited: set[int] = set()
    for relationship in document.part.rels.values():
        if relationship.is_external or "image" not in relationship.reltype:
            continue
        target = relationship.target_part
        key = id(target)
        if key in visited:
            continue
        visited.add(key)
        encoded = base64.standard_b64encode(target.blob).decode("utf-8")
        pictures.append(PRDImage(media_type=target.content_type, data=encoded))

    return "\n".join(lines), pictures


def validate_and_load_prd(prd_file: str | Path) -> LoadedPRD:
    """Load a PRD file (.md/.txt/.pdf/.pptx/.docx).

    Returns a LoadedPRD with extracted text and any embedded images. Plain-text
    PRDs (.md/.txt) are decoded as UTF-8, with a leading byte-order mark
    removed if present, and carry no images.
    Raises FileNotFoundError if the file is missing, ValueError if the
    extension is unsupported or the file is empty.
    """
    path = Path(prd_file)
    if not path.exists():
        raise FileNotFoundError(f"PRD file not found: {prd_file}")

    ext = path.suffix.lower()
    if ext not in _ALLOWED_PRD_EXTENSIONS:
        raise ValueError(
            f"PRD file must be one of {sorted(_ALLOWED_PRD_EXTENSIONS)}, got: '{path.suffix}'."
        )

    if ext == ".pdf":
        text, images = _extract_from_pdf(path)
    elif ext == ".pptx":
        text, images = _extract_from_pptx(path)
    elif ext == ".docx":
        text, images = _extract_from_docx(path)
    else:
        # "utf-8-sig" reads BOM-less UTF-8 exactly like "utf-8" but drops a
        # leading BOM. str.strip() does not remove U+FEFF, so with plain
        # "utf-8" the BOM would leak into the text and a BOM-only file would
        # slip past the emptiness check below.
        text = path.read_text(encoding="utf-8-sig").strip()
        images = []

    # A document with no text is still valid if it carries at least one image.
    if not text.strip() and not images:
        raise ValueError(f"PRD file is empty: {prd_file}")

    return LoadedPRD(path=path, text=text, images=images)


def validate_and_load_images(image_files: list[str] | list[Path]) -> list[PRDImage]:
    """Read standalone image files and wrap each one as a PRDImage.

    Entries that are blank once converted to ``str`` and stripped are skipped.
    Raises FileNotFoundError for a path that does not exist and ValueError for
    a suffix outside the supported image formats.
    """
    results: list[PRDImage] = []
    cleaned = (str(item).strip() for item in image_files)
    for location in filter(None, cleaned):
        image_path = Path(location)
        if not image_path.exists():
            raise FileNotFoundError(f"Image file not found: {location}")

        suffix = image_path.suffix.lower()
        if suffix not in _ALLOWED_IMAGE_EXTENSIONS:
            raise ValueError(
                f"Unsupported image format '{image_path.suffix}' for {location}. "
                f"Supported: {sorted(_ALLOWED_IMAGE_EXTENSIONS)}"
            )

        raw_bytes = image_path.read_bytes()
        encoded = base64.standard_b64encode(raw_bytes).decode("utf-8")
        results.append(PRDImage(media_type=_IMAGE_MEDIA_TYPES[suffix], data=encoded))

    return results
