from sqlalchemy.orm import Session, joinedload
from typing import Optional, List
from .models import Calculation, AmortizationScheduleEntry


def get_by_id(db: Session, entity_id: str) -> Optional[Calculation]:
    """Look up a single Calculation by its ``id``.

    Returns the first matching row, or None when no Calculation has that id.
    """
    by_id = db.query(Calculation).filter_by(id=entity_id)
    return by_id.first()


def get_with_details(db: Session, entity_id: str) -> Optional[Calculation]:
    """Fetch a Calculation by ``id`` with its amortization schedule entries.

    The ``amortization_schedule_entries`` relationship is eagerly loaded with a
    joined load, so reading it afterwards does not issue a second query.
    Returns None when no Calculation has the given id.
    """
    eager_schedule = joinedload(Calculation.amortization_schedule_entries)
    matching = db.query(Calculation).filter(Calculation.id == entity_id)
    return matching.options(eager_schedule).first()


def _apply_filters(q, filters: dict):
    """Narrow query *q* by equality on each filter that names a mapped column.

    Only keys that are mapped column attributes of ``Calculation`` are applied.
    A filter whose value is None is treated as "not provided" and skipped, as is
    any key that is not a column. Before this helper existed every filter was
    ignored, so skipping these keeps callers that pass optional or unrelated
    keyword arguments working instead of raising.
    """
    # Local import keeps the module's top-level imports unchanged.
    from sqlalchemy import inspect as sa_inspect

    column_attrs = sa_inspect(Calculation).column_attrs
    for key, value in filters.items():
        if value is None or key not in column_attrs:
            continue
        q = q.filter(getattr(Calculation, key) == value)
    return q


def list_all(db: Session, limit: int = 20, offset: int = 0, **filters) -> List[Calculation]:
    """Return one page of Calculations, newest first.

    Args:
        db: Active session.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip before the page starts.
        **filters: Column-name/value pairs matched with equality. None values
            and non-column keys are ignored.

    Returns:
        The matching Calculations ordered by ``created_at`` descending.
    """
    q = _apply_filters(db.query(Calculation), filters)
    # Secondary sort on id gives rows with identical created_at a stable order.
    # Without it, LIMIT/OFFSET paging can repeat or skip rows between pages.
    return (
        q.order_by(Calculation.created_at.desc(), Calculation.id.desc())
        .limit(limit)
        .offset(offset)
        .all()
    )


def count_all(db: Session, **filters) -> int:
    """Count Calculations matching *filters*.

    Uses the same filter semantics as ``list_all``, so the total agrees with the
    pages that function returns.
    """
    q = _apply_filters(db.query(Calculation.id), filters)
    return q.count()


def create(db: Session, data: dict) -> Calculation:
    """Build a Calculation from ``data``, add it to the session and flush.

    The session is flushed but not committed. Committing (or rolling back)
    is left to the caller.

    Returns:
        The newly added Calculation instance.
    """
    new_calculation = Calculation(**data)
    db.add(new_calculation)
    # Flushing emits the INSERT now, so database-assigned values can be
    # loaded onto the instance before it is returned.
    db.flush()
    return new_calculation


def update(db: Session, entity_id: str, data: dict) -> Optional[Calculation]:
    """Assign every key/value in ``data`` onto the Calculation with ``entity_id``.

    Returns the updated instance after a flush, or None if no row has that id.
    Nothing is flushed when the row is missing.

    NOTE(review): keys are set with ``setattr`` without validation. A key that
    is not a mapped attribute is set on the Python object but not persisted.
    ``create`` instead goes through the constructor. Confirm callers only pass
    column names.
    """
    target = get_by_id(db, entity_id)
    if target is not None:
        for attr_name, new_value in data.items():
            setattr(target, attr_name, new_value)
        db.flush()
    return target


def delete(db: Session, entity_id: str) -> bool:
    """Delete the Calculation with ``entity_id`` and flush the session.

    Returns True if a row was found and marked for deletion, False otherwise.
    No flush happens when nothing matches.
    """
    target = get_by_id(db, entity_id)
    found = target is not None
    if found:
        db.delete(target)
        db.flush()
    return found


def list_schedule_entries_by_calculation(db: Session, calculation_id: str) -> List[AmortizationScheduleEntry]:
    """Return all schedule entries belonging to ``calculation_id``.

    Entries are ordered by ``payment_number`` using the database's default
    (ascending) ordering.
    """
    belongs_to_calculation = AmortizationScheduleEntry.calculation_id == calculation_id
    entries = db.query(AmortizationScheduleEntry).filter(belongs_to_calculation)
    return entries.order_by(AmortizationScheduleEntry.payment_number).all()