from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from fastapi import HTTPException, status
from typing import Optional, List
from decimal import Decimal
from datetime import datetime
from . import repository
from .models import CalculationResult, ChartDataPoint
from .schema import (
    CalculationinputCreate,
    CalculationinputUpdate,
    CalculationresultResponse,
    ChartdatapointResponse,
    CalculationWithResultAndChartResponse,
)


def calculate_sip_maturity(
    monthly_investment: Decimal, annual_return_rate: Decimal, time_period: int
) -> tuple[Decimal, Decimal, Decimal]:
    """
    Compute the maturity figures of a SIP (systematic investment plan).

    Uses the formula
        M = P × ({[1 + i]^n – 1} / i) × (1 + i)
    with P = monthly investment, i = monthly rate (annual percentage / 12 / 100)
    and n = number of months (time_period years × 12).

    Args:
        monthly_investment: Amount contributed every month.
        annual_return_rate: Expected annual return, expressed as a percentage.
        time_period: Investment duration in years.

    Returns:
        A tuple ``(total_investment, estimated_returns, maturity_value)``.
        ``maturity_value`` is rounded to two decimal places; ``total_investment``
        is returned as computed (not rounded).
    """
    one = Decimal("1")
    n_months = time_period * 12
    rate_per_month = annual_return_rate / Decimal("12") / Decimal("100")
    principal = monthly_investment * Decimal(n_months)

    if rate_per_month != 0:
        growth_factor = one + rate_per_month
        compounded = growth_factor ** n_months
        future_value = (
            monthly_investment * ((compounded - one) / rate_per_month) * growth_factor
        )
    else:
        # With a zero rate the formula would divide by zero; no growth means
        # the maturity value is simply the sum of contributions.
        future_value = principal

    future_value = future_value.quantize(Decimal("0.01"))
    gains = future_value - principal

    return principal, gains, future_value


def _sip_point_values(
    monthly_investment: Decimal, monthly_rate: Decimal, months: int
) -> tuple[Decimal, Decimal]:
    """Return ``(invested_amount, projected_value)`` after ``months``, both rounded to 0.01."""
    one = Decimal("1")
    cent = Decimal("0.01")
    invested = monthly_investment * Decimal(months)

    if monthly_rate == 0:
        # Zero rate: avoid dividing by zero; value equals contributions.
        projected = invested
    else:
        power_term = (one + monthly_rate) ** months
        projected = (
            monthly_investment * ((power_term - one) / monthly_rate) * (one + monthly_rate)
        )

    return invested.quantize(cent), projected.quantize(cent)


def generate_chart_data_points(
    calculation_input_id: str,
    monthly_investment: Decimal,
    annual_return_rate: Decimal,
    time_period: int,
    *,
    long_period_step: int = 2,
) -> List[dict]:
    """
    Generate chart data points showing invested amount vs. projected value.

    For periods of 10 years or fewer, one point is produced per year.
    For longer periods, one point is produced every ``long_period_step`` years
    starting at year 1, and the final year is always included as the last
    point so the chart ends on the maturity value.

    Args:
        calculation_input_id: Identifier copied into every point.
        monthly_investment: Amount contributed every month.
        annual_return_rate: Expected annual return, expressed as a percentage.
        time_period: Investment duration in years. A value below 1 yields an
            empty list.
        long_period_step: Spacing in years between points when ``time_period``
            exceeds 10. Defaults to 2.

    Returns:
        A list of dicts with keys ``calculation_input_id``, ``year``,
        ``invested_amount`` and ``projected_value`` (amounts rounded to 0.01),
        ordered by ascending year.

    Raises:
        ValueError: If ``long_period_step`` is less than 1.
    """
    if long_period_step < 1:
        raise ValueError("long_period_step must be at least 1")

    monthly_rate = annual_return_rate / Decimal("12") / Decimal("100")
    step = 1 if time_period <= 10 else long_period_step

    years = list(range(1, time_period + 1, step))
    # A step greater than 1 can skip the final year; always chart it.
    if years and years[-1] != time_period:
        years.append(time_period)

    points = []
    for year in years:
        invested_amount, projected_value = _sip_point_values(
            monthly_investment, monthly_rate, year * 12
        )
        points.append({
            "calculation_input_id": calculation_input_id,
            "year": year,
            "invested_amount": invested_amount,
            "projected_value": projected_value,
        })

    return points


def create_calculation_with_chart_data(
    db: Session, data: CalculationinputCreate
) -> dict:
    """
    Create a calculation input together with its result and chart points.

    The input is created through the repository, the SIP result is computed
    with ``calculate_sip_maturity``, chart points are produced with
    ``generate_chart_data_points``, and everything is committed in a single
    transaction.

    Args:
        db: Active database session; committed on success, rolled back on
            any error.
        data: Validated creation payload; its ``model_dump()`` is passed to
            ``repository.create``.

    Returns:
        A dict with keys ``calculation_input``, ``calculation_result`` and
        ``chart_data_points`` holding the refreshed ORM objects.

    Raises:
        HTTPException: 409 when the integrity error message mentions a
            unique/duplicate violation; 400 for a foreign-key violation or
            any other integrity error.
        Exception: Any other error is re-raised unchanged after rollback.
    """
    try:
        calc_input = repository.create(db, data.model_dump())
        # Flush before reading calc_input.id below — presumably so a
        # database-generated key is populated; confirm against the model.
        db.flush()

        total_investment, estimated_returns, maturity_value = calculate_sip_maturity(
            calc_input.monthly_investment,
            calc_input.expected_return_rate,
            calc_input.time_period,
        )

        calc_result = CalculationResult(
            calculation_input_id=calc_input.id,
            total_investment=total_investment,
            estimated_returns=estimated_returns,
            maturity_value=maturity_value,
        )
        db.add(calc_result)
        db.flush()

        chart_points_data = generate_chart_data_points(
            calc_input.id,
            calc_input.monthly_investment,
            calc_input.expected_return_rate,
            calc_input.time_period,
        )
        chart_points = [ChartDataPoint(**point_data) for point_data in chart_points_data]
        db.add_all(chart_points)

        db.commit()

        # Reload every object after the commit before handing them back.
        db.refresh(calc_input)
        db.refresh(calc_result)
        for cp in chart_points:
            db.refresh(cp)

        return {
            "calculation_input": calc_input,
            "calculation_result": calc_result,
            "chart_data_points": chart_points,
        }

    except IntegrityError as exc:
        db.rollback()
        # Classify the violation from the driver's message text.
        msg = str(exc.orig).lower() if exc.orig else ""
        if "unique" in msg or "duplicate" in msg:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Resource already exists.",
            ) from exc
        if "foreign key" in msg:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Referenced resource does not exist.",
            ) from exc
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Data integrity error.",
        ) from exc
    except Exception:
        # Covers HTTPException as well: roll back and propagate unchanged.
        db.rollback()
        raise


def get_calculation_by_id(db: Session, entity_id: str):
    """
    Return the calculation found by ``repository.get_by_id``.

    Raises:
        HTTPException: 404 when the repository returns ``None``.
    """
    found = repository.get_by_id(db, entity_id)
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Calculation not found.",
    )


def get_calculation_details(db: Session, entity_id: str):
    """
    Return the calculation found by ``repository.get_with_details``.

    Raises:
        HTTPException: 404 when the repository returns ``None``.
    """
    detailed = repository.get_with_details(db, entity_id)
    if detailed is not None:
        return detailed
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Calculation not found.",
    )


def list_calculations(
    db: Session,
    limit: int,
    offset: int,
    timestamp_start: Optional[datetime] = None,
    timestamp_end: Optional[datetime] = None,
) -> dict:
    """
    Return one page of calculations plus the total matching count.

    Only the timestamp bounds that are not ``None`` are forwarded to the
    repository as keyword filters; the same filters are used for both the
    page query and the count.

    Returns:
        A dict with keys ``items``, ``total``, ``limit`` and ``offset``.
    """
    candidate_filters = (
        ("timestamp_start", timestamp_start),
        ("timestamp_end", timestamp_end),
    )
    active_filters = {
        name: bound for name, bound in candidate_filters if bound is not None
    }

    page = repository.list_all(db, limit=limit, offset=offset, **active_filters)
    matching = repository.count_all(db, **active_filters)
    return {"items": page, "total": matching, "limit": limit, "offset": offset}


def delete_calculation(db: Session, entity_id: str) -> bool:
    """
    Delete the calculation identified by ``entity_id`` and commit.

    Args:
        db: Active database session; committed on success, rolled back on
            any error raised by the delete or the commit.
        entity_id: Identifier of the calculation to delete.

    Returns:
        The value returned by ``repository.delete``.

    Raises:
        HTTPException: 404 when no calculation with ``entity_id`` exists;
            409 when the integrity error message mentions a foreign key;
            400 for any other integrity error.
        Exception: Any other error is re-raised unchanged after rollback.
    """
    if repository.get_by_id(db, entity_id) is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Calculation not found.",
        )

    try:
        success = repository.delete(db, entity_id)
        db.commit()
        return success
    except IntegrityError as exc:
        db.rollback()
        # Classify the violation from the driver's message text.
        msg = str(exc.orig).lower() if exc.orig else ""
        if "foreign key" in msg:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Cannot delete calculation while referenced by other records.",
            ) from exc
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Data integrity error.",
        ) from exc
    except Exception:
        # Covers HTTPException as well: roll back and propagate unchanged.
        db.rollback()
        raise