from sqlalchemy.orm import Session
from typing import Optional, List
from .models import Timelog
from datetime import datetime


def get_by_id(db: Session, entity_id: str) -> Optional[Timelog]:
    """Look up the Timelog whose ``id`` equals *entity_id*.

    Returns the first matching row, or ``None`` when nothing matches.
    """
    matching = db.query(Timelog).filter(Timelog.id == entity_id)
    return matching.first()


def _apply_filters(q, filters: dict):
    """Narrow query *q* by the optional Timelog filters shared by list_all/count_all.

    Recognised keys (each one is skipped when absent or ``None``):
      - ``job_id``: exact match on ``Timelog.job_id``
      - ``user_id``: exact match on ``Timelog.user_id``
      - ``start_date``: keeps rows with ``Timelog.start_time >= start_date``
      - ``end_date``: keeps rows with ``Timelog.start_time <= end_date``
        (inclusive, and compared against the *start* time only)
    Any other keys are ignored.
    """
    job_id = filters.get("job_id")
    if job_id is not None:
        q = q.filter(Timelog.job_id == job_id)

    user_id = filters.get("user_id")
    if user_id is not None:
        q = q.filter(Timelog.user_id == user_id)

    start_date = filters.get("start_date")
    if start_date is not None:
        q = q.filter(Timelog.start_time >= start_date)

    end_date = filters.get("end_date")
    if end_date is not None:
        q = q.filter(Timelog.start_time <= end_date)

    return q


def list_all(db: Session, limit: int = 20, offset: int = 0, **filters) -> List[Timelog]:
    """Return one page of Timelogs, newest ``created_at`` first.

    Args:
        db: Active SQLAlchemy session.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip before the page starts.
        **filters: Optional ``job_id``, ``user_id``, ``start_date``, ``end_date``
            (see ``_apply_filters``); unknown keys are ignored.

    Returns:
        The matching Timelogs for the requested page.
    """
    q = _apply_filters(db.query(Timelog), filters)
    # created_at alone is not unique; the id tiebreaker gives a stable total
    # order so consecutive pages neither repeat nor skip rows that share a timestamp.
    return (
        q.order_by(Timelog.created_at.desc(), Timelog.id.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )


def count_all(db: Session, **filters) -> int:
    """Count Timelogs matching the same filters accepted by ``list_all``.

    Args:
        db: Active SQLAlchemy session.
        **filters: Optional ``job_id``, ``user_id``, ``start_date``, ``end_date``
            (see ``_apply_filters``); unknown keys are ignored.

    Returns:
        Total number of matching rows, ignoring any pagination.
    """
    # Sharing _apply_filters with list_all keeps the count consistent with the listing.
    return _apply_filters(db.query(Timelog.id), filters).count()


def create(db: Session, data: dict) -> Timelog:
    """Construct a Timelog from the keyword mapping *data* and stage it.

    The new object is added to the session and flushed, so the INSERT is
    emitted (and any database error surfaces) before this returns. No
    commit is performed; the caller owns the transaction.
    """
    new_entry = Timelog(**data)
    db.add(new_entry)
    db.flush()
    return new_entry


def update(db: Session, entity_id: str, data: dict) -> Optional[Timelog]:
    """Apply the attribute/value pairs in *data* to the Timelog *entity_id*.

    Every key in *data* is assigned with ``setattr`` as-is; keys are not
    validated against the model's columns. The session is flushed but not
    committed.

    Returns:
        The updated Timelog, or ``None`` if no Timelog has that id.
    """
    record = get_by_id(db, entity_id)
    if record:
        for attr_name, new_value in data.items():
            setattr(record, attr_name, new_value)
        db.flush()
        return record
    return None


def delete(db: Session, entity_id: str) -> bool:
    """Remove the Timelog *entity_id* from the session and flush.

    Returns:
        ``True`` if a Timelog was found and deleted, ``False`` if none existed.
        The deletion is flushed but not committed.
    """
    found = get_by_id(db, entity_id)
    if found:
        db.delete(found)
        db.flush()
        return True
    return False


def get_active_timelog_for_user(db: Session, user_id: str) -> Optional[Timelog]:
    """Return the user's open Timelog (one whose ``end_time`` is NULL), or ``None``.

    If the user somehow has several open timelogs, the one with the latest
    ``start_time`` is returned (ties broken by highest ``id``), so the choice
    is deterministic rather than left to the database's row order.
    """
    return (
        db.query(Timelog)
        .filter(Timelog.user_id == user_id, Timelog.end_time.is_(None))
        .order_by(Timelog.start_time.desc(), Timelog.id.desc())
        .first()
    )


def get_overlapping_timelogs(
    db: Session,
    user_id: str,
    start_time: datetime,
    end_time: Optional[datetime] = None,
    *,
    exclude_id: Optional[str] = None,
) -> List[Timelog]:
    """Return the user's timelogs that overlap the interval ``[start_time, end_time)``.

    Intervals are compared as half-open, with strict inequalities: an
    existing timelog that ends exactly at ``start_time``, or starts exactly at
    ``end_time``, does not count as overlapping. An existing timelog whose
    ``end_time`` is NULL is treated as still running (unbounded on the right).

    Args:
        db: Active SQLAlchemy session.
        user_id: Only this user's timelogs are considered.
        start_time: Start of the interval being checked.
        end_time: End of the interval being checked; ``None`` means the
            interval is open-ended, so only the start bound is applied.
        exclude_id: If given, the timelog with this ``id`` is left out. This is
            useful when re-validating an existing timelog against its siblings.

    Returns:
        All matching Timelogs (unordered).
    """
    q = db.query(Timelog).filter(Timelog.user_id == user_id)

    if exclude_id is not None:
        q = q.filter(Timelog.id != exclude_id)

    # An existing log can only overlap if it is still running or ends after our start...
    q = q.filter(Timelog.end_time.is_(None) | (Timelog.end_time > start_time))

    # ...and, when our interval is bounded, it must also start before our end.
    if end_time is not None:
        q = q.filter(Timelog.start_time < end_time)

    return q.all()