from __future__ import annotations

import base64
import mimetypes
from dataclasses import dataclass
from pathlib import Path


# MIME types this module accepts: collect_image_paths keeps only files whose
# guessed type is listed here, and the encode_* helpers raise ValueError for
# any other media type.
_SUPPORTED_IMAGE_MIME_TYPES: set[str] = {
    "image/png",
    "image/jpeg",
    "image/webp",
    "image/gif",
}


@dataclass(frozen=True)
class EncodedImage:
    """Immutable record pairing base64 image data with its media type."""

    # Source file of the image; encode_image_from_base64 stores the
    # placeholder Path("inline") here because no file is involved.
    path: Path
    # MIME type of the image, e.g. "image/png".
    media_type: str
    # Image bytes encoded as base64 text.
    base64_data: str


def collect_image_paths(images_dir: Path, limit: int | None = None) -> list[Path]:
    """Return the supported image files directly inside ``images_dir``.

    Only regular files whose name maps (via ``mimetypes``) to a supported
    image MIME type are returned; subdirectories are not descended into.
    The result is sorted case-insensitively by file name, with the exact
    name as a tie-breaker so the order -- and therefore the files selected
    by ``limit`` -- does not depend on directory iteration order.

    Args:
        images_dir: Directory to scan.
        limit: Maximum number of paths to return, or ``None`` for all.

    Returns:
        The sorted list of matching paths, truncated to ``limit`` entries.

    Raises:
        FileNotFoundError: If ``images_dir`` does not exist.
        NotADirectoryError: If ``images_dir`` is not a directory.
        ValueError: If ``limit`` is given and is less than 1.
    """
    if not images_dir.exists():
        raise FileNotFoundError(f"Images directory not found: {images_dir}")
    if not images_dir.is_dir():
        raise NotADirectoryError(f"Images path is not a directory: {images_dir}")
    if limit is not None and limit <= 0:
        raise ValueError("limit must be >= 1 or None")

    candidates = sorted(
        (
            entry
            for entry in images_dir.iterdir()
            if entry.is_file()
            and mimetypes.guess_type(str(entry))[0] in _SUPPORTED_IMAGE_MIME_TYPES
        ),
        # iterdir() yields entries in arbitrary order; the second key element
        # makes names that differ only in case ("A.png" / "a.png") sort stably.
        key=lambda entry: (entry.name.lower(), entry.name),
    )
    return candidates if limit is None else candidates[:limit]


def _detect_media_type(data: bytes) -> str | None:
    """Detect actual image format from file header bytes."""
    if data[:8] == b"\x89PNG\r\n\x1a\n":
        return "image/png"
    if data[:2] == b"\xff\xd8":
        return "image/jpeg"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    if data[:6] in (b"GIF87a", b"GIF89a"):
        return "image/gif"
    return None


def encode_image(path: Path) -> EncodedImage:
    """Read the image at ``path`` and return it base64-encoded.

    The media type is taken from the file's leading bytes when they match a
    known signature; otherwise it is guessed from the file name.

    Raises:
        FileNotFoundError: If ``path`` does not exist or is not a regular file.
        ValueError: If no supported media type could be determined.
    """
    if not path.exists():
        raise FileNotFoundError(f"Image not found: {path}")
    if not path.is_file():
        raise FileNotFoundError(f"Image is not a file: {path}")

    payload = path.read_bytes()

    # Sniffed content wins over the extension, which may not match the bytes.
    resolved_type = _detect_media_type(payload) or mimetypes.guess_type(str(path))[0]
    if resolved_type not in _SUPPORTED_IMAGE_MIME_TYPES:
        raise ValueError(f"Unsupported image type for {path.name}: {resolved_type or 'unknown'}")

    return EncodedImage(
        path=path,
        media_type=resolved_type,
        base64_data=base64.b64encode(payload).decode("ascii"),
    )


def encode_image_from_base64(data_url: str) -> EncodedImage:
    """Create an EncodedImage from a base64 data URL or raw base64 string.

    Accepts:
        "data:image/png;base64,<data>"  — standard data URL
        "<raw base64>"                  — assumed image/png

    Surrounding and embedded whitespace (e.g. line-wrapped base64) is removed
    from the payload. The ``data:`` scheme and the media type are matched
    case-insensitively, and the media type is stored in lower case.

    Args:
        data_url: A base64 ``data:`` URL or a bare base64 string.

    Returns:
        An ``EncodedImage`` whose ``path`` is the placeholder ``Path("inline")``.

    Raises:
        ValueError: If a data URL has no ``,`` separator, is not flagged as
            ``;base64``, or names an unsupported media type; or if the
            payload is empty or is not valid base64.
    """
    text = data_url.strip()

    if text[:5].lower() == "data:":
        header, separator, payload = text.partition(",")
        # Header layout after "data:" is "<media type>[;param]...".
        media_type, *params = [part.strip().lower() for part in header[5:].split(";")]
        if not separator:
            raise ValueError("Malformed data URL: missing ',' before the payload")
        if "base64" not in params:
            # Without the flag the payload is not base64 text, so it cannot be
            # passed through as base64 data.
            raise ValueError("Data URL payload must be base64-encoded (missing ';base64')")
        if media_type not in _SUPPORTED_IMAGE_MIME_TYPES:
            raise ValueError(f"Unsupported image media type: {media_type}")
    else:
        media_type = "image/png"
        payload = text

    # Line-wrapped base64 is common; drop all whitespace before validating.
    base64_data = "".join(payload.split())
    if not base64_data:
        raise ValueError("Image payload is empty")
    try:
        # Decoded bytes are discarded; this only verifies the text is base64.
        # binascii.Error is a ValueError subclass, so this also catches it.
        base64.b64decode(base64_data, validate=True)
    except ValueError as err:
        raise ValueError("Image payload is not valid base64") from err

    return EncodedImage(path=Path("inline"), media_type=media_type, base64_data=base64_data)


def encoded_image_to_bedrock_block(image: EncodedImage) -> dict:
    """Wrap ``image`` in a base64 image content block.

    The layout is the Bedrock Converse multimodal block format (supported by
    langchain-aws ChatBedrockConverse).
    """
    source = {
        "type": "base64",
        "media_type": image.media_type,
        "data": image.base64_data,
    }
    return {"type": "image", "source": source}
