from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from typing import List, Optional
from datetime import datetime, timedelta
from decimal import Decimal, ROUND_HALF_EVEN
from . import repository
from .schema import (
    CalculationsessionCreate,
    CalculationsessionUpdate,
    CalculationsessionResponse,
    CalculationCreate,
    CalculationUpdate,
    CalculationResponse,
    CalculationDetailsResponse,
    AmortizationentryCreate,
    AmortizationentryUpdate,
    AmortizationentryResponse,
    LoanCalculationInput,
    SessionTrackInput,
)


def _get_or_raise(db: Session, entity_id: str, repo_module):
    """Fetch an entity via ``repo_module.get_by_id`` or fail with HTTP 404.

    Args:
        db: Database session forwarded to the repository call.
        entity_id: Identifier forwarded to ``repo_module.get_by_id``.
        repo_module: Any object exposing ``get_by_id(db, entity_id)``.

    Returns:
        Whatever the repository lookup returned, when it is truthy.

    Raises:
        HTTPException: 404 when the lookup result is falsy.
    """
    found = repo_module.get_by_id(db, entity_id)
    if found:
        return found
    raise HTTPException(
        status_code=404,
        detail=f"Entity with id {entity_id} not found",
    )


def _round_currency(value: float) -> float:
    """Round a currency amount to 2 decimal places with banker's rounding.

    The value is converted through ``str`` before building the ``Decimal`` so
    the rounding operates on the float's printed representation rather than
    on its full binary expansion.

    Args:
        value: Amount to round.

    Returns:
        The amount rounded half-to-even to the nearest cent, as a ``float``.
    """
    cent = Decimal("0.01")
    exact = Decimal(str(value))
    rounded = exact.quantize(cent, rounding=ROUND_HALF_EVEN)
    return float(rounded)


# CalculationSession handlers
def create_calculationsession(db: Session, data: CalculationsessionCreate) -> CalculationsessionResponse:
    """Create a calculation session, rejecting duplicate ``session_id`` values.

    Args:
        db: Database session; committed on success, rolled back on any error.
        data: Creation payload; its ``model_dump()`` is handed to the repository.

    Returns:
        The newly stored session validated as ``CalculationsessionResponse``.

    Raises:
        HTTPException: 409 when a session with the same ``session_id`` exists.
    """
    if repository.get_session_by_session_id(db, data.session_id):
        raise HTTPException(
            status_code=409,
            detail=f"Session with session_id {data.session_id} already exists",
        )

    try:
        created = repository.create(db, data.model_dump())
        db.commit()
        db.refresh(created)
        response = CalculationsessionResponse.model_validate(created)
    except Exception:
        # Undo any pending work before letting the error propagate.
        db.rollback()
        raise
    else:
        return response


def list_calculationsessions(
    db: Session, limit: int, offset: int, session_id: Optional[str] = None
) -> List[CalculationsessionResponse]:
    """List calculation sessions, optionally narrowed to one ``session_id``.

    Args:
        db: Database session forwarded to the repository.
        limit: Forwarded to ``repository.list_all``.
        offset: Forwarded to ``repository.list_all``.
        session_id: When truthy, passed to the repository as a keyword filter.

    Returns:
        Each row returned by the repository, validated as a response model.
    """
    criteria = {"session_id": session_id} if session_id else {}
    rows = repository.list_all(db, limit, offset, **criteria)
    return [CalculationsessionResponse.model_validate(row) for row in rows]


def get_calculationsession(db: Session, entity_id: str) -> CalculationsessionResponse:
    """Return one calculation session by its id.

    Args:
        db: Database session forwarded to the lookup helper.
        entity_id: Id of the session to load.

    Returns:
        The session validated as ``CalculationsessionResponse``.

    Raises:
        HTTPException: 404 (raised by ``_get_or_raise``) when nothing is found.
    """
    return CalculationsessionResponse.model_validate(
        _get_or_raise(db, entity_id, repository)
    )


def update_calculationsession(
    db: Session, entity_id: str, data: CalculationsessionUpdate
) -> CalculationsessionResponse:
    """Apply a partial update to a calculation session.

    Only fields explicitly set on ``data`` are sent to the repository. When
    ``session_id`` is among them, it must not already belong to a different
    session.

    Args:
        db: Database session; committed on success, rolled back on any error.
        entity_id: Id of the session to update.
        data: Update payload; unset fields are ignored.

    Returns:
        The updated session validated as ``CalculationsessionResponse``.

    Raises:
        HTTPException: 400 when no fields were supplied, 409 when the new
            ``session_id`` is taken by another session, 404 when the
            repository update returns a falsy result.
    """
    changes = data.model_dump(exclude_unset=True)
    if not changes:
        raise HTTPException(status_code=400, detail="No fields to update")

    if "session_id" in changes:
        clash = repository.get_session_by_session_id(db, changes["session_id"])
        if clash and clash.id != entity_id:
            raise HTTPException(
                status_code=409,
                detail=f"Session with session_id {changes['session_id']} already exists",
            )

    try:
        updated = repository.update(db, entity_id, changes)
        if not updated:
            raise HTTPException(
                status_code=404,
                detail=f"Session with id {entity_id} not found",
            )
        db.commit()
        db.refresh(updated)
        return CalculationsessionResponse.model_validate(updated)
    except Exception:
        # HTTPException derives from Exception, so this single handler rolls
        # back for both the 404 above and unexpected failures.
        db.rollback()
        raise


def delete_calculationsession(db: Session, entity_id: str) -> dict:
    """Delete a calculation session by id.

    Args:
        db: Database session; committed on success, rolled back on any error.
        entity_id: Id of the session to delete.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when the repository reports nothing was deleted.
    """
    try:
        if not repository.delete(db, entity_id):
            raise HTTPException(
                status_code=404,
                detail=f"Session with id {entity_id} not found",
            )
        db.commit()
    except Exception:
        # Covers the HTTPException above as well as unexpected failures.
        db.rollback()
        raise
    return {"message": "Session deleted successfully"}


# Calculation handlers
def create_calculation(db: Session, data: CalculationCreate) -> CalculationResponse:
    """Create a calculation attached to an existing session.

    Args:
        db: Database session; committed on success, rolled back on any error.
        data: Creation payload; ``data.session_id`` is looked up with
            ``repository.get_by_id`` before anything is written.

    Returns:
        The stored calculation validated as ``CalculationResponse``.

    Raises:
        HTTPException: 404 when the referenced session is not found.
    """
    if not repository.get_by_id(db, data.session_id):
        raise HTTPException(
            status_code=404,
            detail=f"Session with id {data.session_id} not found",
        )

    try:
        created = repository.create_calculation(db, data.model_dump())
        db.commit()
        db.refresh(created)
        response = CalculationResponse.model_validate(created)
    except Exception:
        db.rollback()
        raise
    else:
        return response


def list_calculations(
    db: Session, limit: int, offset: int, session_id: Optional[str] = None
) -> List[CalculationResponse]:
    """List calculations, optionally narrowed to one ``session_id``.

    Args:
        db: Database session forwarded to the repository.
        limit: Forwarded to ``repository.list_calculations``.
        offset: Forwarded to ``repository.list_calculations``.
        session_id: When truthy, passed to the repository as a keyword filter.

    Returns:
        Each row returned by the repository, validated as a response model.
    """
    criteria = {"session_id": session_id} if session_id else {}
    rows = repository.list_calculations(db, limit, offset, **criteria)
    return [CalculationResponse.model_validate(row) for row in rows]


def get_calculation(db: Session, entity_id: str) -> CalculationResponse:
    """Return one calculation by its id.

    Args:
        db: Database session forwarded to the repository.
        entity_id: Id of the calculation to load.

    Returns:
        The calculation validated as ``CalculationResponse``.

    Raises:
        HTTPException: 404 when the repository lookup is falsy.
    """
    found = repository.get_calculation_by_id(db, entity_id)
    if found:
        return CalculationResponse.model_validate(found)
    raise HTTPException(
        status_code=404,
        detail=f"Calculation with id {entity_id} not found",
    )


def update_calculation(db: Session, entity_id: str, data: CalculationUpdate) -> CalculationResponse:
    """Apply a partial update to a calculation.

    Only fields explicitly set on ``data`` are sent to the repository. When
    ``session_id`` is among them, the referenced session must exist.

    Args:
        db: Database session; committed on success, rolled back on any error.
        entity_id: Id of the calculation to update.
        data: Update payload; unset fields are ignored.

    Returns:
        The updated calculation validated as ``CalculationResponse``.

    Raises:
        HTTPException: 400 when no fields were supplied, 404 when the new
            session or the calculation itself is not found.
    """
    changes = data.model_dump(exclude_unset=True)
    if not changes:
        raise HTTPException(status_code=400, detail="No fields to update")

    if "session_id" in changes and not repository.get_by_id(db, changes["session_id"]):
        raise HTTPException(
            status_code=404,
            detail=f"Session with id {changes['session_id']} not found",
        )

    try:
        updated = repository.update_calculation(db, entity_id, changes)
        if not updated:
            raise HTTPException(
                status_code=404,
                detail=f"Calculation with id {entity_id} not found",
            )
        db.commit()
        db.refresh(updated)
        return CalculationResponse.model_validate(updated)
    except Exception:
        # HTTPException derives from Exception, so this single handler rolls
        # back for both the 404 above and unexpected failures.
        db.rollback()
        raise


def delete_calculation(db: Session, entity_id: str) -> dict:
    """Delete a calculation by id.

    Args:
        db: Database session; committed on success, rolled back on any error.
        entity_id: Id of the calculation to delete.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when the repository reports nothing was deleted.
    """
    try:
        if not repository.delete_calculation(db, entity_id):
            raise HTTPException(
                status_code=404,
                detail=f"Calculation with id {entity_id} not found",
            )
        db.commit()
    except Exception:
        # Covers the HTTPException above as well as unexpected failures.
        db.rollback()
        raise
    return {"message": "Calculation deleted successfully"}


# AmortizationEntry handlers
def create_amortizationentry(db: Session, data: AmortizationentryCreate) -> AmortizationentryResponse:
    """Create an amortization entry attached to an existing calculation.

    Args:
        db: Database session; committed on success, rolled back on any error.
        data: Creation payload; ``data.calculation_id`` is looked up before
            anything is written.

    Returns:
        The stored entry validated as ``AmortizationentryResponse``.

    Raises:
        HTTPException: 404 when the referenced calculation is not found.
    """
    if not repository.get_calculation_by_id(db, data.calculation_id):
        raise HTTPException(
            status_code=404,
            detail=f"Calculation with id {data.calculation_id} not found",
        )

    try:
        created = repository.create_amortization(db, data.model_dump())
        db.commit()
        db.refresh(created)
        response = AmortizationentryResponse.model_validate(created)
    except Exception:
        db.rollback()
        raise
    else:
        return response


def list_amortizationentries(
    db: Session, limit: int, offset: int, calculation_id: Optional[str] = None
) -> List[AmortizationentryResponse]:
    """List amortization entries, optionally narrowed to one calculation.

    Args:
        db: Database session forwarded to the repository.
        limit: Forwarded to ``repository.list_amortization_entries``.
        offset: Forwarded to ``repository.list_amortization_entries``.
        calculation_id: When truthy, passed to the repository as a keyword filter.

    Returns:
        Each row returned by the repository, validated as a response model.
    """
    criteria = {"calculation_id": calculation_id} if calculation_id else {}
    rows = repository.list_amortization_entries(db, limit, offset, **criteria)
    return [AmortizationentryResponse.model_validate(row) for row in rows]


def get_amortizationentry(db: Session, entity_id: str) -> AmortizationentryResponse:
    """Return one amortization entry by its id.

    Args:
        db: Database session forwarded to the repository.
        entity_id: Id of the entry to load.

    Returns:
        The entry validated as ``AmortizationentryResponse``.

    Raises:
        HTTPException: 404 when the repository lookup is falsy.
    """
    found = repository.get_amortization_by_id(db, entity_id)
    if found:
        return AmortizationentryResponse.model_validate(found)
    raise HTTPException(
        status_code=404,
        detail=f"Amortization entry with id {entity_id} not found",
    )


def update_amortizationentry(
    db: Session, entity_id: str, data: AmortizationentryUpdate
) -> AmortizationentryResponse:
    """Apply a partial update to an amortization entry.

    Only fields explicitly set on ``data`` are sent to the repository. When
    ``calculation_id`` is among them, the referenced calculation must exist.

    Args:
        db: Database session; committed on success, rolled back on any error.
        entity_id: Id of the entry to update.
        data: Update payload; unset fields are ignored.

    Returns:
        The updated entry validated as ``AmortizationentryResponse``.

    Raises:
        HTTPException: 400 when no fields were supplied, 404 when the new
            calculation or the entry itself is not found.
    """
    changes = data.model_dump(exclude_unset=True)
    if not changes:
        raise HTTPException(status_code=400, detail="No fields to update")

    if "calculation_id" in changes and not repository.get_calculation_by_id(
        db, changes["calculation_id"]
    ):
        raise HTTPException(
            status_code=404,
            detail=f"Calculation with id {changes['calculation_id']} not found",
        )

    try:
        updated = repository.update_amortization(db, entity_id, changes)
        if not updated:
            raise HTTPException(
                status_code=404,
                detail=f"Amortization entry with id {entity_id} not found",
            )
        db.commit()
        db.refresh(updated)
        return AmortizationentryResponse.model_validate(updated)
    except Exception:
        # HTTPException derives from Exception, so this single handler rolls
        # back for both the 404 above and unexpected failures.
        db.rollback()
        raise


def delete_amortizationentry(db: Session, entity_id: str) -> dict:
    """Delete an amortization entry by id.

    Args:
        db: Database session; committed on success, rolled back on any error.
        entity_id: Id of the entry to delete.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when the repository reports nothing was deleted.
    """
    try:
        if not repository.delete_amortization(db, entity_id):
            raise HTTPException(
                status_code=404,
                detail=f"Amortization entry with id {entity_id} not found",
            )
        db.commit()
    except Exception:
        # Covers the HTTPException above as well as unexpected failures.
        db.rollback()
        raise
    return {"message": "Amortization entry deleted successfully"}


# Workflow handlers
def calculate_loan(db: Session, data: LoanCalculationInput) -> CalculationDetailsResponse:
    """
    Perform a loan calculation and persist its complete amortization schedule.

    Looks up the owning session, computes the fixed monthly payment with the
    standard amortization formula (plain division for a 0% rate), stores the
    calculation record plus one amortization entry per payment, increments the
    session's calculation counter, and returns the calculation with its schedule.

    Args:
        db: Database session; committed on success, rolled back on any error
            raised while writing.
        data: Loan inputs: ``session_id``, ``principal``,
            ``annual_interest_rate`` (a percentage, e.g. 5 for 5%), and
            ``loan_term_months``.

    Returns:
        The stored calculation, reloaded with its entries and validated as
        ``CalculationDetailsResponse``.

    Raises:
        HTTPException: 404 when the session is not found; 422 when
            ``loan_term_months`` is less than 1.

    Notes:
        * Payment dates advance in fixed 30-day steps from "today + 30 days",
          not by calendar month.
        * ``total_amount`` is ``monthly_payment * loan_term_months``; the last
          scheduled payment is adjusted to clear the balance, so the schedule's
          summed payments can differ from ``total_amount`` by rounding residue.
    """
    session = repository.get_by_id(db, data.session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session with id {data.session_id} not found")

    principal = data.principal
    annual_rate = data.annual_interest_rate
    term_months = data.loan_term_months

    # Guard the divisions below: a zero term would raise ZeroDivisionError and
    # a negative term would store a calculation with an empty schedule.
    if term_months < 1:
        raise HTTPException(status_code=422, detail="loan_term_months must be at least 1")

    # Monthly rate and payment are loop-invariant, so compute them once.
    if annual_rate == 0:
        monthly_rate = 0.0
        monthly_payment = _round_currency(principal / term_months)
    else:
        monthly_rate = annual_rate / 100 / 12
        growth = (1 + monthly_rate) ** term_months
        monthly_payment = _round_currency(
            principal * (monthly_rate * growth) / (growth - 1)
        )

    # Calculate totals
    total_amount = _round_currency(monthly_payment * term_months)
    total_interest = _round_currency(total_amount - principal)

    try:
        # Create calculation record
        calculation_data = {
            "session_id": data.session_id,
            "principal": principal,
            "annual_interest_rate": annual_rate,
            "loan_term_months": term_months,
            "monthly_payment": monthly_payment,
            "total_interest": total_interest,
            "total_amount": total_amount,
        }
        calculation = repository.create_calculation(db, calculation_data)
        # Flush before the entries reference calculation.id (presumably this
        # is what populates the id; confirm against the model definition).
        db.flush()

        # Generate amortization schedule
        balance = principal
        cumulative_interest = 0.0
        cumulative_principal = 0.0
        first_payment_date = datetime.utcnow().date() + timedelta(days=30)

        for payment_num in range(1, term_months + 1):
            payment_date = first_payment_date + timedelta(days=30 * (payment_num - 1))
            beginning_balance = balance

            if annual_rate == 0:
                interest_portion = 0.0
            else:
                interest_portion = _round_currency(balance * monthly_rate)

            principal_portion = _round_currency(monthly_payment - interest_portion)

            # Adjust final payment to ensure balance reaches exactly 0
            if payment_num == term_months:
                principal_portion = balance
                payment_amount = _round_currency(principal_portion + interest_portion)
                ending_balance = 0.0
            else:
                payment_amount = monthly_payment
                ending_balance = _round_currency(balance - principal_portion)

            cumulative_interest = _round_currency(cumulative_interest + interest_portion)
            cumulative_principal = _round_currency(cumulative_principal + principal_portion)

            entry_data = {
                "calculation_id": calculation.id,
                "payment_number": payment_num,
                "payment_date": payment_date,
                "beginning_balance": beginning_balance,
                "payment_amount": payment_amount,
                "principal_portion": principal_portion,
                "interest_portion": interest_portion,
                "ending_balance": ending_balance,
                "cumulative_interest": cumulative_interest,
                "cumulative_principal": cumulative_principal,
            }
            repository.create_amortization(db, entry_data)

            balance = ending_balance

        # Update session bookkeeping; treat a missing counter as zero so a
        # session stored without a count does not break the calculation.
        session.calculation_count = (session.calculation_count or 0) + 1
        session.last_access_at = datetime.utcnow()

        db.commit()

        # Retrieve calculation with all amortization entries
        calculation_with_details = repository.get_calculation_with_details(db, calculation.id)
        return CalculationDetailsResponse.model_validate(calculation_with_details)

    except Exception:
        db.rollback()
        raise


def get_calculation_details(db: Session, calculation_id: str) -> CalculationDetailsResponse:
    """
    Return a calculation together with its amortization entries.

    Args:
        db: Database session forwarded to the repository.
        calculation_id: Id of the calculation to load.

    Returns:
        The result of ``repository.get_calculation_with_details`` validated as
        ``CalculationDetailsResponse``.

    Raises:
        HTTPException: 404 when the repository lookup is falsy.
    """
    detailed = repository.get_calculation_with_details(db, calculation_id)
    if detailed:
        return CalculationDetailsResponse.model_validate(detailed)
    raise HTTPException(
        status_code=404,
        detail=f"Calculation with id {calculation_id} not found",
    )


def get_amortization_schedule(db: Session, calculation_id: str) -> List[AmortizationentryResponse]:
    """
    Return the amortization entries belonging to one calculation.

    The calculation is looked up first so that an unknown id yields 404
    instead of an empty list. Entries come back in whatever order
    ``repository.list_amortization_by_calculation`` provides.

    Args:
        db: Database session forwarded to the repository.
        calculation_id: Id of the calculation whose entries are requested.

    Returns:
        Each entry validated as ``AmortizationentryResponse``.

    Raises:
        HTTPException: 404 when the calculation is not found.
    """
    if not repository.get_calculation_by_id(db, calculation_id):
        raise HTTPException(
            status_code=404,
            detail=f"Calculation with id {calculation_id} not found",
        )

    rows = repository.list_amortization_by_calculation(db, calculation_id)
    return [AmortizationentryResponse.model_validate(row) for row in rows]


def track_session(db: Session, data: SessionTrackInput) -> CalculationsessionResponse:
    """
    Record a visit for a tracking session, creating it on first sight.

    If a session with ``data.session_id`` already exists, its
    ``last_access_at`` is refreshed and ``user_agent`` / ``ip_address`` are
    overwritten only when the incoming values are truthy. Otherwise a new
    session is created with both access timestamps set to now and a
    calculation count of 0.

    Args:
        db: Database session; committed on success, rolled back on any error.
        data: Tracking payload carrying ``session_id``, ``user_agent`` and
            ``ip_address``.

    Returns:
        The created or updated session validated as
        ``CalculationsessionResponse``.
    """
    known = repository.get_session_by_session_id(db, data.session_id)

    try:
        if known:
            # Existing visitor: bump the access time, keep stored client
            # details unless fresh non-empty values were supplied.
            changes = {"last_access_at": datetime.utcnow()}
            for field in ("user_agent", "ip_address"):
                incoming = getattr(data, field)
                if incoming:
                    changes[field] = incoming
            tracked = repository.update(db, known.id, changes)
        else:
            # First visit: both timestamps share the same instant.
            moment = datetime.utcnow()
            tracked = repository.create(
                db,
                {
                    "session_id": data.session_id,
                    "first_access_at": moment,
                    "last_access_at": moment,
                    "calculation_count": 0,
                    "user_agent": data.user_agent,
                    "ip_address": data.ip_address,
                },
            )

        db.commit()
        db.refresh(tracked)
        return CalculationsessionResponse.model_validate(tracked)
    except Exception:
        db.rollback()
        raise


def get_session_calculations(db: Session, session_id: str, limit: int, offset: int) -> List[CalculationResponse]:
    """
    List the calculations that belong to one tracking session.

    The session is resolved through ``repository.get_session_by_session_id``
    and the calculations are then filtered by that record's ``id``.

    Args:
        db: Database session forwarded to the repository.
        session_id: Value matched by ``get_session_by_session_id``.
        limit: Forwarded to ``repository.list_calculations``.
        offset: Forwarded to ``repository.list_calculations``.

    Returns:
        Each matching calculation validated as ``CalculationResponse``.

    Raises:
        HTTPException: 404 when no session matches ``session_id``.
    """
    owner = repository.get_session_by_session_id(db, session_id)
    if not owner:
        raise HTTPException(
            status_code=404,
            detail=f"Session with session_id {session_id} not found",
        )

    rows = repository.list_calculations(db, limit, offset, session_id=owner.id)
    return [CalculationResponse.model_validate(row) for row in rows]