from typing import Optional

from fastapi import Depends, HTTPException, Security, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import ExpiredSignatureError, JWTError, jwt
from pydantic import BaseModel

from src.config.settings import settings_instance
from src.config.logger_manager import get_logger

# Module logger, obtained from the project's logger manager.
logger = get_logger(__name__)

# Bearer-token security scheme injected into get_current_user via Depends.
# NOTE(review): HTTPBearer() uses auto_error=True by default, so (per the
# FastAPI docs) the scheme itself rejects requests without a Bearer
# Authorization header before get_current_user runs; confirm which status
# code that produces for the pinned FastAPI version.
security = HTTPBearer()


# ---------------------- Token Data Model ----------------------
class TokenData(BaseModel):
    """
    Identity data extracted from a verified JWT.

    In this module, instances are built by ``verify_token``: ``login_id`` is
    taken from the standard ``sub`` claim and every other field from the
    claim of the same name. All fields are optional and keep their default
    when the corresponding claim is absent from the token.
    """

    # Each field below mirrors the JWT claim of the same name, except
    # login_id, which verify_token fills from the "sub" claim (and requires).
    id: Optional[int] = None
    login_id: Optional[str] = None
    client_id: Optional[str] = None
    org_id: Optional[int] = None
    org_user_id: Optional[int] = None
    role: Optional[str] = None
    domain: Optional[str] = None
    # Defaults to 0 rather than None; verify_token also substitutes 0 when
    # the claim is missing. NOTE(review): looks like a 0/1 ownership flag —
    # confirm before treating it as a boolean.
    is_owner: Optional[int] = 0


# ---------------------- Token Verification ----------------------
def verify_token(token: str) -> Optional[TokenData]:
    """
    Verify and decode a JWT token.

    The token is decoded with ``jwt.decode`` using
    ``settings_instance.SECRET_KEY`` and ``settings_instance.ALGORITHM``,
    which verifies the signature (an expired token surfaces as
    ``ExpiredSignatureError``). The ``sub`` claim becomes
    ``TokenData.login_id``; the remaining fields are copied from the claims
    of the same name when present.

    Args:
        token: JWT token string, optionally prefixed with the ``Bearer``
            auth scheme. The scheme is matched case-insensitively and
            surrounding whitespace is ignored.

    Returns:
        TokenData if the token is valid and carries a non-empty ``sub``
        claim; None otherwise. Decoding and validation failures are logged
        and reported as None rather than raised.
    """
    if not token:
        return None

    # HTTP auth-schemes are case-insensitive, so accept "Bearer", "bearer",
    # etc., and tolerate extra whitespace around the scheme and the token.
    # A raw JWT contains no spaces, so it passes through unchanged.
    token = token.strip()
    scheme, _, credentials = token.partition(" ")
    if scheme.lower() == "bearer":
        token = credentials.strip()
    if not token:
        return None

    try:
        payload = jwt.decode(
            token,
            settings_instance.SECRET_KEY,
            algorithms=[settings_instance.ALGORITHM],
        )

        # "sub" identifies the principal; a missing or empty subject cannot
        # be mapped to a user, so the token is rejected.
        login_id: Optional[str] = payload.get("sub")
        if not login_id:
            logger.warning("Token missing 'sub' claim")
            return None

        return TokenData(
            login_id=login_id,
            id=payload.get("id"),
            client_id=payload.get("client_id"),
            org_id=payload.get("org_id"),
            org_user_id=payload.get("org_user_id"),
            role=payload.get("role"),
            domain=payload.get("domain"),
            is_owner=payload.get("is_owner", 0),
        )

    # ExpiredSignatureError is a subclass of JWTError, so it must be
    # handled first to get its own (debug-level) log line.
    except ExpiredSignatureError:
        logger.debug("Token has expired")
        return None
    except JWTError as e:
        logger.warning(f"JWT verification failed: {e}")
        return None
    except Exception as e:
        # Unexpected failure (e.g. a claim value rejected by TokenData
        # validation, or misconfigured settings): fail closed, but keep the
        # traceback so the root cause can be diagnosed.
        logger.exception(f"Unexpected token verification error: {e}")
        return None


def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
) -> TokenData:
    """
    FastAPI dependency that returns the current user decoded from a JWT.

    Args:
        credentials: Parsed ``Authorization`` header supplied by the
            module-level ``HTTPBearer`` scheme. A value of None (e.g. when
            that scheme is configured with ``auto_error=False``) is treated
            as "no token".

    Returns:
        TokenData for the authenticated user.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when no
            credentials reach this function or ``verify_token`` rejects the
            token as invalid or expired.

    NOTE(review): with the default ``HTTPBearer()`` (``auto_error=True``), a
    request lacking a Bearer ``Authorization`` header is rejected by the
    scheme itself before this function runs, and that response may not be a
    401 depending on the FastAPI version — confirm against the pinned
    version.
    """
    # Guard against None so that an auto_error=False scheme yields a 401
    # here instead of an AttributeError (which would surface as HTTP 500).
    token_data = (
        verify_token(credentials.credentials) if credentials is not None else None
    )

    if token_data is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail={"message": "Invalid or expired token"},
            headers={"WWW-Authenticate": "Bearer"},
        )

    return token_data
