from sqlalchemy.orm import Session, joinedload
from typing import Optional, List
from .models import Payment

def get_by_id(db: Session, payment_id: str) -> Optional[Payment]:
    """Look up a single payment by its ``id``.

    Args:
        db: Active SQLAlchemy session.
        payment_id: Value compared against ``Payment.id``.

    Returns:
        The first matching ``Payment``, or ``None`` when no row matches.
    """
    matching = db.query(Payment).filter(Payment.id == payment_id)
    return matching.first()

def get_by_gateway_reference(db: Session, gateway_reference: str) -> Optional[Payment]:
    """Look up a payment by the reference string stored from the payment gateway.

    Args:
        db: Active SQLAlchemy session.
        gateway_reference: Value compared against
            ``Payment.payment_gateway_reference``.

    Returns:
        The first matching ``Payment``, or ``None`` when no row matches.
    """
    by_reference = Payment.payment_gateway_reference == gateway_reference
    return db.query(Payment).filter(by_reference).first()

def list_all(db: Session, limit: int = 20, offset: int = 0, **filters) -> List[Payment]:
    """Return a page of payments, newest first, optionally filtered.

    Args:
        db: Active SQLAlchemy session.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip before the page starts.
        **filters: Optional equality filters. Recognised keys are
            ``user_id``, ``booking_id``, ``status`` and ``payment_method``;
            each is matched against the same-named ``Payment`` attribute.
            A falsy value (``None``, ``""``, ``0``) means "do not filter on
            this field". Unrecognised keys are ignored.

    Returns:
        Matching payments ordered by ``created_at`` descending, with ``id``
        descending as a tie-breaker.
    """
    query = db.query(Payment)

    # Same truthiness semantics as before: only apply filters with a truthy value.
    for field_name in ("user_id", "booking_id", "status", "payment_method"):
        value = filters.get(field_name)
        if value:
            query = query.filter(getattr(Payment, field_name) == value)

    # created_at alone is not unique. Without a unique tie-breaker, rows that
    # share a timestamp have no defined order, so LIMIT/OFFSET pages can
    # repeat or skip them.
    return (
        query.order_by(Payment.created_at.desc(), Payment.id.desc())
        .limit(limit)
        .offset(offset)
        .all()
    )

def create(db: Session, data: dict) -> Payment:
    """Construct a ``Payment`` from ``data``, add it to the session and flush.

    This function flushes but does not commit; transaction control stays
    with whoever owns ``db``.

    Args:
        db: Active SQLAlchemy session.
        data: Keyword arguments passed straight to the ``Payment`` constructor.

    Returns:
        The newly added ``Payment`` instance.
    """
    new_payment = Payment(**data)
    db.add(new_payment)
    # Flush so the pending INSERT is emitted now, inside the current transaction.
    db.flush()
    return new_payment

def update(db: Session, payment_id: str, data: dict) -> Optional[Payment]:
    """Apply ``data`` as attribute assignments on the payment with ``payment_id``.

    Flushes the session after assigning; does not commit.

    Args:
        db: Active SQLAlchemy session.
        payment_id: Identifier passed to ``get_by_id``.
        data: Mapping of attribute name to new value.

    Returns:
        The updated ``Payment``, or ``None`` if no payment was found.
    """
    target = get_by_id(db, payment_id)
    if target:
        # NOTE(review): keys are not checked against Payment's mapped columns.
        # An unknown key presumably becomes a plain Python attribute that is
        # never persisted. Confirm callers only pass column names.
        for attr_name, new_value in data.items():
            setattr(target, attr_name, new_value)
        db.flush()
        return target
    return None

def delete(db: Session, payment_id: str) -> bool:
    """Mark the payment with ``payment_id`` for deletion and flush the session.

    Does not commit.

    Args:
        db: Active SQLAlchemy session.
        payment_id: Identifier passed to ``get_by_id``.

    Returns:
        ``True`` if a payment was found and deleted, ``False`` otherwise.
    """
    target = get_by_id(db, payment_id)
    if target:
        db.delete(target)
        db.flush()
        return True
    return False

def get_with_details(db: Session, payment_id: str) -> Optional[Payment]:
    """Fetch a payment by ``id`` with its ``booking`` and ``user`` eagerly loaded.

    Both relationships are loaded via ``joinedload``, so accessing them on the
    returned object does not trigger additional lazy-load queries.

    Args:
        db: Active SQLAlchemy session.
        payment_id: Value compared against ``Payment.id``.

    Returns:
        The first matching ``Payment``, or ``None`` when no row matches.
    """
    eager_relations = (joinedload(Payment.booking), joinedload(Payment.user))
    detailed = db.query(Payment).options(*eager_relations)
    return detailed.filter(Payment.id == payment_id).first()

def list_by_booking_id(db: Session, booking_id: str) -> List[Payment]:
    """Return every payment for ``booking_id``, newest first.

    Args:
        db: Active SQLAlchemy session.
        booking_id: Value compared against ``Payment.booking_id``.

    Returns:
        All matching payments ordered by ``created_at`` descending, with
        ``id`` descending as a tie-breaker so the order is deterministic.
    """
    return (
        db.query(Payment)
        .filter(Payment.booking_id == booking_id)
        # Tie-break on id so payments sharing a created_at have a stable order,
        # matching the paginated listers in this module.
        .order_by(Payment.created_at.desc(), Payment.id.desc())
        .all()
    )

def list_by_user_id(db: Session, user_id: str, limit: int = 20, offset: int = 0) -> List[Payment]:
    """Return a page of payments belonging to ``user_id``, newest first.

    Args:
        db: Active SQLAlchemy session.
        user_id: Value compared against ``Payment.user_id``.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip before the page starts.

    Returns:
        Matching payments ordered by ``created_at`` descending, with ``id``
        descending as a tie-breaker.
    """
    return (
        db.query(Payment)
        .filter(Payment.user_id == user_id)
        # created_at is not unique. Without the id tie-breaker, rows with equal
        # timestamps can appear on two pages or on none under LIMIT/OFFSET.
        .order_by(Payment.created_at.desc(), Payment.id.desc())
        .limit(limit)
        .offset(offset)
        .all()
    )