from sqlalchemy.orm import Session, joinedload
from typing import Optional, List
from .models import Document

def get_by_id(db: Session, entity_id: str) -> Optional[Document]:
    """Look up one Document whose ``id`` equals ``entity_id``.

    Args:
        db: SQLAlchemy session used to run the query.
        entity_id: Value compared against ``Document.id``.

    Returns:
        The first matching Document, or ``None`` when nothing matches.
    """
    matches = db.query(Document).filter(Document.id == entity_id)
    return matches.first()

def list_all(
    db: Session,
    limit: int = 20,
    offset: int = 0,
    booking_id: Optional[str] = None,
    customer_id: Optional[str] = None,
    document_type: Optional[str] = None,
    uploaded_by: Optional[str] = None
) -> List[Document]:
    """Return one page of Documents, newest first, with optional filters.

    Each filter is applied as an equality match only when its value is
    truthy, so both ``None`` and ``""`` leave that column unfiltered.

    Args:
        db: SQLAlchemy session used to run the query.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip before the page starts.
        booking_id: Keep only documents with this ``booking_id``.
        customer_id: Keep only documents with this ``customer_id``.
        document_type: Keep only documents with this ``document_type``.
        uploaded_by: Keep only documents with this ``uploaded_by``.

    Returns:
        Up to ``limit`` Documents ordered by ``created_at`` descending, with
        ``id`` descending as a tie-breaker.
    """
    query = db.query(Document)

    # (column, requested value) pairs; falsy values mean "do not filter".
    filters = (
        (Document.booking_id, booking_id),
        (Document.customer_id, customer_id),
        (Document.document_type, document_type),
        (Document.uploaded_by, uploaded_by),
    )
    for column, value in filters:
        if value:
            query = query.filter(column == value)

    # created_at is presumably not unique. Without a deterministic
    # tie-breaker the database may return rows sharing a timestamp in any
    # order, so LIMIT/OFFSET pages can repeat or skip rows. Adding id makes
    # the ordering total and the page boundaries stable.
    return (
        query.order_by(Document.created_at.desc(), Document.id.desc())
        .limit(limit)
        .offset(offset)
        .all()
    )

def create(db: Session, data: dict) -> Document:
    """Construct a Document from ``data``, add it to the session, and flush.

    The flush emits the pending INSERT but does not commit; committing (or
    rolling back) is left to the caller.

    Args:
        db: SQLAlchemy session that will own the new object.
        data: Keyword arguments passed straight to the ``Document`` constructor.

    Returns:
        The newly added Document instance.
    """
    new_doc = Document(**data)
    db.add(new_doc)
    db.flush()
    return new_doc

def update(db: Session, entity_id: str, data: dict) -> Optional[Document]:
    """Assign each ``data`` item onto the Document with ``id == entity_id``.

    The session is flushed after the assignments; nothing is committed here.

    Args:
        db: SQLAlchemy session used for lookup and flush.
        entity_id: Value matched against ``Document.id``.
        data: Mapping of attribute name to new value.

    Returns:
        The updated Document, or ``None`` if no Document matched (in which
        case no flush happens).
    """
    target = get_by_id(db, entity_id)
    if target:
        # NOTE(review): keys are not validated against the model; every
        # name in ``data`` (including ``id``) is assigned exactly as given.
        for attr_name, new_value in data.items():
            setattr(target, attr_name, new_value)
        db.flush()
        return target
    return None

def delete(db: Session, entity_id: str) -> bool:
    """Delete the Document with ``id == entity_id`` and flush the session.

    Args:
        db: SQLAlchemy session used for lookup, delete and flush.
        entity_id: Value matched against ``Document.id``.

    Returns:
        ``True`` if a Document was found and deleted, ``False`` otherwise.
    """
    found = get_by_id(db, entity_id)
    if found:
        db.delete(found)
        db.flush()
        return True
    return False

def get_with_details(db: Session, entity_id: str) -> Optional[Document]:
    """Fetch one Document by ``id`` with its related objects eager-loaded.

    ``uploader``, ``booking``, ``customer`` and the customer's ``user`` are
    loaded in the same query via ``joinedload`` rather than lazily on access.

    Args:
        db: SQLAlchemy session used to run the query.
        entity_id: Value matched against ``Document.id``.

    Returns:
        The first matching Document, or ``None`` when nothing matches.
    """
    # Loader options require class-bound attributes. String names such as
    # "user" were deprecated in SQLAlchemy 1.4 and raise ArgumentError in 2.0.
    # The customer class is taken from the relationship's target mapper so
    # this module does not need to import it directly.
    customer_cls = Document.customer.property.mapper.class_
    return (
        db.query(Document)
        .options(
            joinedload(Document.uploader),
            joinedload(Document.booking),
            joinedload(Document.customer).joinedload(customer_cls.user),
        )
        .filter(Document.id == entity_id)
        .first()
    )

def count_by_booking(db: Session, booking_id: str) -> int:
    """Return how many Documents have ``booking_id`` equal to the given value.

    Args:
        db: SQLAlchemy session used to run the query.
        booking_id: Value matched against ``Document.booking_id``.

    Returns:
        The number of matching rows.
    """
    for_booking = db.query(Document).filter_by(booking_id=booking_id)
    return for_booking.count()

def count_by_customer(db: Session, customer_id: str) -> int:
    """Return how many Documents have ``customer_id`` equal to the given value.

    Args:
        db: SQLAlchemy session used to run the query.
        customer_id: Value matched against ``Document.customer_id``.

    Returns:
        The number of matching rows.
    """
    for_customer = db.query(Document).filter_by(customer_id=customer_id)
    return for_customer.count()