from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from typing import List, Optional
from datetime import datetime, timedelta
from decimal import Decimal, ROUND_HALF_EVEN
from . import repository
from .schema import (
    CalculationsessionCreate,
    CalculationsessionUpdate,
    CalculationInputSchema,
    AmortizationentryCreate,
    AmortizationentryUpdate,
)


def _get_or_raise(db: Session, entity_id: str, repo_module):
    obj = repo_module.get_by_id(db, entity_id)
    if not obj:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Entity with id {entity_id} not found")
    return obj


def _round_currency(value: float) -> float:
    """Round currency values using banker's rounding (round half to even) to 2 decimal places."""
    return float(Decimal(str(value)).quantize(Decimal("0.01"), rounding=ROUND_HALF_EVEN))


# CalculationSession Handlers
def create_calculationsession(db: Session, data: CalculationsessionCreate):
    """Persist a new CalculationSession built from ``data`` and return it.

    On success the transaction is committed and the instance refreshed. Any
    exception rolls the transaction back and is re-raised unchanged.
    """
    try:
        payload = data.model_dump()
        created = repository.create(db, payload)
        db.commit()
        db.refresh(created)
    except Exception:
        db.rollback()
        raise
    else:
        return created


def list_calculationsessions(db: Session, limit: int, offset: int, session_id: Optional[str] = None):
    """Return a page of CalculationSessions from ``repository.list_all``.

    When ``session_id`` is truthy it is passed as a ``session_id`` filter. A
    ``None`` or empty value applies no filter.
    """
    criteria = {"session_id": session_id} if session_id else {}
    return repository.list_all(db, limit, offset, **criteria)


def get_calculationsession(db: Session, entity_id: str):
    """Return the CalculationSession with ``entity_id``.

    Raises:
        HTTPException: 404 when ``repository.get_by_id`` returns a falsy value.
    """
    found = repository.get_by_id(db, entity_id)
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"CalculationSession with id {entity_id} not found",
    )


def update_calculationsession(db: Session, entity_id: str, data: CalculationsessionUpdate):
    """Partially update the CalculationSession with ``entity_id``.

    Only fields explicitly set on ``data`` are forwarded
    (``model_dump(exclude_unset=True)``). On success the change is committed
    and the refreshed instance returned.

    Raises:
        HTTPException: 404 when ``repository.update`` returns a falsy value.
            That path, and every other failure, rolls the transaction back
            before re-raising.
    """
    try:
        changes = data.model_dump(exclude_unset=True)
        updated = repository.update(db, entity_id, changes)
        if not updated:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"CalculationSession with id {entity_id} not found",
            )
        db.commit()
        db.refresh(updated)
        return updated
    except Exception:
        # HTTPException is an Exception subclass, so the 404 path rolls back here too.
        db.rollback()
        raise


def delete_calculationsession(db: Session, entity_id: str):
    """Delete the CalculationSession with ``entity_id`` and commit.

    Returns:
        A confirmation dict of the form ``{"message": ...}``.

    Raises:
        HTTPException: 404 when ``repository.delete`` returns a falsy value.
            Every failure path rolls the transaction back before re-raising.
    """
    try:
        deleted = repository.delete(db, entity_id)
        if not deleted:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"CalculationSession with id {entity_id} not found",
            )
        db.commit()
    except Exception:
        db.rollback()
        raise
    return {"message": "CalculationSession deleted successfully"}


# Calculation Handlers
def list_calculations(db: Session, limit: int, offset: int, session_id: Optional[str] = None):
    """Return a page of calculations from ``repository.list_calculations``.

    A truthy ``session_id`` is forwarded as a ``session_id`` filter. A ``None``
    or empty value lists without filtering.
    """
    criteria = {"session_id": session_id} if session_id else {}
    return repository.list_calculations(db, limit, offset, **criteria)


def get_calculation(db: Session, entity_id: str):
    """Return the calculation with ``entity_id``.

    Raises:
        HTTPException: 404 when ``repository.get_calculation_by_id`` returns a
            falsy value.
    """
    found = repository.get_calculation_by_id(db, entity_id)
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Calculation with id {entity_id} not found",
    )


def get_calculation_details(db: Session, calculation_id: str):
    """Return the calculation as loaded by ``repository.get_calculation_with_details``.

    The repository presumably includes related amortization entries; the
    exact shape is defined there.

    Raises:
        HTTPException: 404 when the repository returns a falsy value.
    """
    detailed = repository.get_calculation_with_details(db, calculation_id)
    if detailed:
        return detailed
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Calculation with id {calculation_id} not found",
    )


def delete_calculation(db: Session, entity_id: str):
    """Delete the calculation with ``entity_id`` and commit.

    Returns:
        A confirmation dict of the form ``{"message": ...}``.

    Raises:
        HTTPException: 404 when ``repository.delete_calculation`` returns a
            falsy value. Every failure path rolls back before re-raising.
    """
    try:
        deleted = repository.delete_calculation(db, entity_id)
        if not deleted:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Calculation with id {entity_id} not found",
            )
        db.commit()
    except Exception:
        db.rollback()
        raise
    return {"message": "Calculation deleted successfully"}


# AmortizationEntry Handlers
def list_amortization_entries(db: Session, limit: int, offset: int, calculation_id: Optional[str] = None):
    """Return a page of amortization entries from the repository.

    A truthy ``calculation_id`` is forwarded as a ``calculation_id`` filter. A
    ``None`` or empty value lists without filtering.
    """
    criteria = {"calculation_id": calculation_id} if calculation_id else {}
    return repository.list_amortization_entries(db, limit, offset, **criteria)


def get_amortization_entry(db: Session, entity_id: str):
    """Return the amortization entry with ``entity_id``.

    Raises:
        HTTPException: 404 when ``repository.get_amortization_entry_by_id``
            returns a falsy value.
    """
    found = repository.get_amortization_entry_by_id(db, entity_id)
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"AmortizationEntry with id {entity_id} not found",
    )


def create_amortization_entry(db: Session, data: AmortizationentryCreate):
    """Create an amortization entry under an existing calculation.

    The parent calculation (``data.calculation_id``) is checked first,
    outside the transaction guard. The insert is then committed and the
    refreshed entry returned. Failures during the insert roll back and
    re-raise.

    Raises:
        HTTPException: 404 when the parent calculation does not exist.
    """
    parent = repository.get_calculation_by_id(db, data.calculation_id)
    if not parent:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Calculation with id {data.calculation_id} not found",
        )
    try:
        payload = data.model_dump()
        entry = repository.create_amortization_entry(db, payload)
        db.commit()
        db.refresh(entry)
    except Exception:
        db.rollback()
        raise
    else:
        return entry


def update_amortization_entry(db: Session, entity_id: str, data: AmortizationentryUpdate):
    """Partially update the amortization entry with ``entity_id``.

    Only fields explicitly set on ``data`` are forwarded
    (``model_dump(exclude_unset=True)``). On success the change is committed
    and the refreshed entry returned.

    Raises:
        HTTPException: 404 when ``repository.update_amortization_entry``
            returns a falsy value. Every failure path rolls back before
            re-raising.
    """
    try:
        changes = data.model_dump(exclude_unset=True)
        updated = repository.update_amortization_entry(db, entity_id, changes)
        if not updated:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"AmortizationEntry with id {entity_id} not found",
            )
        db.commit()
        db.refresh(updated)
        return updated
    except Exception:
        # HTTPException is an Exception subclass, so the 404 path rolls back here too.
        db.rollback()
        raise


def delete_amortization_entry(db: Session, entity_id: str):
    """Delete the amortization entry with ``entity_id`` and commit.

    Returns:
        A confirmation dict of the form ``{"message": ...}``.

    Raises:
        HTTPException: 404 when ``repository.delete_amortization_entry``
            returns a falsy value. Every failure path rolls back before
            re-raising.
    """
    try:
        deleted = repository.delete_amortization_entry(db, entity_id)
        if not deleted:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"AmortizationEntry with id {entity_id} not found",
            )
        db.commit()
    except Exception:
        db.rollback()
        raise
    return {"message": "AmortizationEntry deleted successfully"}


# Workflow: Perform Loan Calculation
def perform_loan_calculation(db: Session, data: CalculationInputSchema):
    """
    Perform loan calculation and generate complete amortization schedule.

    Steps:
        1. Look up the owning session by ``data.session_id``.
        2. Compute the level monthly payment.
        3. Build one amortization entry per month.
        4. Persist the calculation and all entries in a single transaction,
           and bump the session's ``calculation_count`` and
           ``last_access_at``.
        5. Return the calculation as loaded by
           ``repository.get_calculation_with_details``.

    ``total_interest`` and ``total_amount`` are taken from the generated
    schedule, including its adjusted final payment. The calculation record
    therefore always agrees with the sum of its entries.

    Raises:
        HTTPException: 404 if the session does not exist; 422 if
            ``loan_term_months`` is less than 1.
    """
    # Validate session exists
    session = repository.get_session_by_session_id(db, data.session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session with session_id {data.session_id} not found")

    principal = data.principal
    annual_rate = data.annual_interest_rate
    term_months = data.loan_term_months

    # A term below one month would divide by zero (0% rate) or produce an
    # empty/nonsensical schedule, so reject it up front.
    if term_months < 1:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="loan_term_months must be at least 1",
        )

    monthly_payment = _compute_monthly_payment(principal, annual_rate, term_months)
    first_payment_date = datetime.utcnow().date() + timedelta(days=30)
    schedule = _build_amortization_schedule(
        principal, annual_rate, term_months, monthly_payment, first_payment_date
    )

    # Derive totals from the schedule itself. The final payment absorbs
    # rounding drift, so ``monthly_payment * term_months`` would not match
    # the interest actually recorded in the entries.
    total_interest = schedule[-1]["cumulative_interest"]
    total_amount = _round_currency(principal + total_interest)

    try:
        # Create calculation record
        calc_data = {
            "session_id": session.id,
            "principal": principal,
            "annual_interest_rate": annual_rate,
            "loan_term_months": term_months,
            "monthly_payment": monthly_payment,
            "total_interest": total_interest,
            "total_amount": total_amount,
        }
        calculation = repository.create_calculation(db, calc_data)
        # Flush before reading calculation.id for the child entries
        # (presumably the id is populated by the INSERT).
        db.flush()

        entries_data = [{"calculation_id": calculation.id, **row} for row in schedule]
        repository.create_amortization_entries_bulk(db, entries_data)

        # Update session calculation count and last access time
        session.calculation_count += 1
        session.last_access_at = datetime.utcnow()

        db.commit()
        db.refresh(calculation)

        # Return calculation with amortization entries
        return repository.get_calculation_with_details(db, calculation.id)

    except Exception:
        db.rollback()
        raise


def _compute_monthly_payment(principal, annual_rate, term_months: int) -> float:
    """Return the level monthly payment rounded to cents.

    Uses M = P * [i(1 + i)^n] / [(1 + i)^n - 1] with i = annual_rate / 12 / 100,
    so ``annual_rate`` is a percentage. Falls back to straight-line P / n in
    two cases: the rate is zero, or the rate is so small that (1 + i)^n
    evaluates to exactly 1.0, which would otherwise divide by zero.
    ``term_months`` must be >= 1.
    """
    if annual_rate == 0.0:
        return _round_currency(principal / term_months)
    monthly_rate = annual_rate / 12 / 100
    growth = (1 + monthly_rate) ** term_months
    numerator = monthly_rate * growth
    denominator = growth - 1
    if denominator == 0:
        return _round_currency(principal / term_months)
    return _round_currency(principal * (numerator / denominator))


def _build_amortization_schedule(principal, annual_rate, term_months: int, monthly_payment: float, first_payment_date) -> List[dict]:
    """Return one dict per payment, without ``calculation_id``, for the loan.

    Spacing and amounts:
        - Payments are spaced 30 days apart, starting at
          ``first_payment_date``. NOTE(review): fixed 30-day spacing drifts
          from calendar months; confirm this is intended.
        - Every payment is ``monthly_payment`` except the one that retires
          the loan. That is normally the last payment, or an earlier one if
          rounding drift would push the balance below zero.
        - The retiring payment pays exactly the remaining balance plus that
          month's interest. The balance therefore ends at 0.00 and never goes
          negative.
    """
    monthly_rate = None if annual_rate == 0.0 else annual_rate / 12 / 100

    rows = []
    balance = principal
    cumulative_interest = 0.0
    cumulative_principal = 0.0

    for payment_num in range(1, term_months + 1):
        payment_date = first_payment_date + timedelta(days=30 * (payment_num - 1))
        beginning_balance = balance

        interest_portion = 0.0 if monthly_rate is None else _round_currency(balance * monthly_rate)
        scheduled_principal = _round_currency(monthly_payment - interest_portion)

        if payment_num == term_months or scheduled_principal > balance:
            # Retire the exact remaining balance so it ends at $0.00 (never negative).
            principal_portion = _round_currency(balance)
            payment_amount = _round_currency(principal_portion + interest_portion)
            ending_balance = 0.0
        else:
            principal_portion = scheduled_principal
            payment_amount = monthly_payment
            ending_balance = _round_currency(balance - principal_portion)

        cumulative_interest = _round_currency(cumulative_interest + interest_portion)
        cumulative_principal = _round_currency(cumulative_principal + principal_portion)

        rows.append({
            "payment_number": payment_num,
            "payment_date": payment_date,
            "beginning_balance": beginning_balance,
            "payment_amount": payment_amount,
            "principal_portion": principal_portion,
            "interest_portion": interest_portion,
            "ending_balance": ending_balance,
            "cumulative_interest": cumulative_interest,
            "cumulative_principal": cumulative_principal,
        })
        balance = ending_balance

    return rows


# Workflow: Create or Retrieve Session
def create_or_retrieve_session(db: Session, session_id: str, user_agent: Optional[str] = None, ip_address: Optional[str] = None):
    """
    Return the session identified by ``session_id``, creating it if absent.

    An existing session only has ``last_access_at`` set to the current UTC
    time. A new session is created with:
        - ``first_access_at`` and ``last_access_at`` set to the same timestamp;
        - ``calculation_count`` of 0;
        - the supplied ``user_agent`` and ``ip_address``.

    Either way the change is committed and the instance refreshed. On any
    error the transaction is rolled back and the exception re-raised.
    """
    existing = repository.get_session_by_session_id(db, session_id)
    try:
        if existing:
            existing.last_access_at = datetime.utcnow()
            result = existing
        else:
            timestamp = datetime.utcnow()
            result = repository.create(
                db,
                {
                    "session_id": session_id,
                    "first_access_at": timestamp,
                    "last_access_at": timestamp,
                    "calculation_count": 0,
                    "user_agent": user_agent,
                    "ip_address": ip_address,
                },
            )
        db.commit()
        db.refresh(result)
        return result
    except Exception:
        db.rollback()
        raise


# Workflow: Get Session Calculations
def get_session_calculations(db: Session, session_id: str, limit: int, offset: int):
    """
    Return a page of calculations belonging to the session with ``session_id``.

    Raises:
        HTTPException: 404 when no session matches ``session_id``.
    """
    owner = repository.get_session_by_session_id(db, session_id)
    if not owner:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session with session_id {session_id} not found",
        )
    # Calculations are filtered by the looked-up row's ``id``, not by the
    # ``session_id`` argument string itself.
    return repository.list_calculations(db, limit, offset, session_id=owner.id)