from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from fastapi import HTTPException, status
from typing import Optional
from decimal import Decimal, ROUND_HALF_UP
from datetime import date, datetime
from dateutil.relativedelta import relativedelta
from . import repository
from .schema import CalculationCreate, CalculationUpdate
from .models import AmortizationScheduleEntry


def calculate_monthly_payment(principal: Decimal, annual_rate: Decimal, months: int) -> Decimal:
    """Return the fixed monthly installment for a fully amortizing loan.

    ``annual_rate`` is divided by 1200, so it is treated as a yearly
    percentage (``6`` means 6 %), converted to a monthly fraction. A zero
    rate degenerates to an even split of the principal.

    The result is rounded half-up to whole cents.
    """
    cent = Decimal("0.01")
    if annual_rate == 0:
        raw_payment = principal / Decimal(months)
    else:
        rate = annual_rate / Decimal("1200")
        growth = (Decimal("1") + rate) ** months
        # Standard annuity formula: P * r * (1+r)^n / ((1+r)^n - 1)
        raw_payment = principal * rate * growth / (growth - Decimal("1"))
    return raw_payment.quantize(cent, rounding=ROUND_HALF_UP)


def generate_amortization_schedule(
    principal: Decimal,
    annual_rate: Decimal,
    months: int,
    monthly_payment: Decimal,
    start_date: date
) -> list:
    """Build a month-by-month amortization table.

    Args:
        principal: Amount borrowed; the opening balance of row 1.
        annual_rate: Yearly rate as a percentage (it is divided by 1200 to
            obtain the monthly fraction, so ``6`` means 6 %).
        months: Number of rows to produce (one per payment).
        monthly_payment: Scheduled installment for every row except the last.
        start_date: Date of the first payment; row ``k`` is dated
            ``start_date + (k - 1)`` calendar months.

    Returns:
        A list of ``months`` dicts with keys ``payment_number``,
        ``payment_date``, ``beginning_balance``, ``payment_amount``,
        ``principal_payment``, ``interest_payment``, ``ending_balance`` and
        ``cumulative_interest``.

    Interest for each row is the opening balance times the monthly rate,
    rounded half-up to cents. The last row pays off whatever balance remains,
    which absorbs rounding drift. No row retires more principal than is still
    owed, so the principal payments never sum to more than ``principal``.
    """
    cent = Decimal("0.01")
    zero = Decimal("0.00")
    monthly_rate = annual_rate / Decimal("1200")

    schedule = []
    balance = principal
    cumulative_interest = zero

    for payment_num in range(1, months + 1):
        payment_date = start_date + relativedelta(months=payment_num - 1)
        beginning_balance = balance

        if annual_rate == 0:
            interest_payment = zero
        else:
            interest_payment = (balance * monthly_rate).quantize(cent, rounding=ROUND_HALF_UP)

        scheduled_principal = monthly_payment - interest_payment
        if payment_num == months or scheduled_principal > balance:
            # Final row, or the cent-rounded installment would overpay the
            # remaining balance: retire exactly what is left. After payoff this
            # yields zero-amount rows instead of repeated phantom payments.
            principal_payment = balance
            payment_amount = principal_payment + interest_payment
        else:
            principal_payment = scheduled_principal
            payment_amount = monthly_payment

        balance = balance - principal_payment
        if balance < cent:
            # Snap sub-cent residue to an exact zero balance.
            balance = zero

        cumulative_interest = cumulative_interest + interest_payment

        schedule.append({
            "payment_number": payment_num,
            "payment_date": payment_date,
            "beginning_balance": beginning_balance,
            "payment_amount": payment_amount,
            "principal_payment": principal_payment,
            "interest_payment": interest_payment,
            "ending_balance": balance,
            "cumulative_interest": cumulative_interest,
        })

    return schedule


def _integrity_error_to_http(exc: IntegrityError) -> HTTPException:
    """Translate a database IntegrityError into the HTTP error sent to the client."""
    # exc.orig is the underlying driver exception; its message text is the only
    # signal used here to tell which kind of constraint failed.
    msg = str(exc.orig).lower() if exc.orig else ""
    if "unique" in msg or "duplicate" in msg:
        return HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Calculation already exists.")
    if "foreign key" in msg:
        return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Referenced resource does not exist.")
    return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Data integrity error.")


def create_calculation(db: Session, data: CalculationCreate):
    """Compute a loan calculation and persist it with its amortization schedule.

    If ``data.start_date`` is not provided, the first payment falls on the
    first day of next month.

    Returns:
        The object returned by ``repository.create``, refreshed after commit.

    Raises:
        HTTPException: 409 when the database reports a unique/duplicate
            violation, and 400 for a foreign-key or any other integrity error.
        Exception: any other failure is re-raised unchanged after rollback.
    """
    principal = data.principal
    annual_rate = data.annual_interest_rate
    months = data.loan_term_months

    if data.start_date:
        start_date = data.start_date
    else:
        start_date = date.today().replace(day=1) + relativedelta(months=1)

    monthly_payment = calculate_monthly_payment(principal, annual_rate, months)
    schedule_entries = generate_amortization_schedule(
        principal, annual_rate, months, monthly_payment, start_date
    )

    # Start from a Decimal so the total keeps its type even for an empty schedule.
    total_interest = sum(
        (entry["interest_payment"] for entry in schedule_entries), Decimal("0.00")
    )
    total_amount_paid = principal + total_interest

    calc_data = {
        "principal": principal,
        "annual_interest_rate": annual_rate,
        "loan_term_months": months,
        "monthly_payment": monthly_payment,
        "total_interest": total_interest,
        "total_amount_paid": total_amount_paid,
        "calculation_method": data.calculation_method,
        "start_date": start_date,
    }

    try:
        # NOTE(review): relies on repository.create assigning calculation.id
        # (e.g. via flush). If it also commits, the rollback below cannot undo
        # the parent row when inserting the schedule fails — confirm.
        calculation = repository.create(db, calc_data)
        db.add_all([
            AmortizationScheduleEntry(**entry, calculation_id=calculation.id)
            for entry in schedule_entries
        ])
        db.commit()
        db.refresh(calculation)
    except IntegrityError as exc:
        db.rollback()
        raise _integrity_error_to_http(exc) from exc
    except Exception:
        # Covers HTTPException and any other failure: undo pending work, re-raise.
        db.rollback()
        raise

    return calculation


def get_calculation(db: Session, calculation_id: str):
    """Look up a calculation via ``repository.get_by_id``.

    Raises:
        HTTPException: 404 when the repository returns a falsy result.
    """
    found = repository.get_by_id(db, calculation_id)
    if found:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Calculation not found")


def get_calculation_details(db: Session, calculation_id: str):
    """Look up a calculation via ``repository.get_with_details``.

    Raises:
        HTTPException: 404 when the repository returns a falsy result.
    """
    detailed = repository.get_with_details(db, calculation_id)
    if detailed:
        return detailed
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Calculation not found")


def list_calculations(db: Session, limit: int, offset: int) -> dict:
    """Return one page of calculations together with the overall row count.

    The result echoes ``limit`` and ``offset`` back so callers can paginate.
    """
    page = repository.list_all(db, limit=limit, offset=offset)
    return dict(
        items=page,
        total=repository.count_all(db),
        limit=limit,
        offset=offset,
    )


def delete_calculation(db: Session, calculation_id: str) -> bool:
    """Delete a calculation and commit; return ``True`` on success.

    Raises:
        HTTPException: 404 when ``repository.delete`` reports nothing deleted.
        Exception: any repository/database error, re-raised after rollback.
    """
    try:
        deleted = repository.delete(db, calculation_id)
        if not deleted:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Calculation not found")
        db.commit()
    except Exception:
        # Handles the 404 above as well as DB failures: discard pending
        # changes, then propagate the original exception untouched.
        db.rollback()
        raise
    return True