from sqlalchemy.orm import Session, joinedload, subqueryload
from typing import Optional, List
from .models import Watchlist, Watchlistitem
from stock_management.models import Stock


def get_by_id(db: Session, entity_id: str) -> Optional[Watchlist]:
    """Look up one watchlist by its ``id``.

    Args:
        db: Active SQLAlchemy session.
        entity_id: Value compared against ``Watchlist.id``.

    Returns:
        The first matching watchlist, or ``None`` when no row matches.
    """
    matching = db.query(Watchlist).filter(Watchlist.id == entity_id)
    return matching.first()


def list_all(db: Session, limit: int = 20, offset: int = 0, **filters) -> List[Watchlist]:
    """Return one page of watchlists, optionally filtered.

    Args:
        db: Active SQLAlchemy session.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip before the page starts.
        **filters: Optional criteria. Recognised keys:
            ``user_id`` -- applied only when truthy.
            ``is_default`` -- applied whenever it is not ``None`` (so
            ``False`` is a meaningful filter value).
            Any other keys are ignored.

    Returns:
        Watchlists ordered by ``sort_order`` ascending, then by ``id`` so that
        rows with equal ``sort_order`` have a stable position across pages.
    """
    query = db.query(Watchlist)

    user_id = filters.get("user_id")
    if user_id:
        query = query.filter(Watchlist.user_id == user_id)

    is_default = filters.get("is_default")
    if is_default is not None:
        query = query.filter(Watchlist.is_default == is_default)

    # Nothing here makes sort_order unique. Without a unique tie-breaker
    # (id, presumably the primary key), the database may return tied rows
    # in any order, so consecutive LIMIT/OFFSET pages could repeat or skip
    # rows.
    return (
        query.order_by(Watchlist.sort_order.asc(), Watchlist.id.asc())
        .limit(limit)
        .offset(offset)
        .all()
    )


def create(db: Session, data: dict) -> Watchlist:
    """Construct a watchlist from ``data``, add it to the session and flush.

    Args:
        db: Active SQLAlchemy session.
        data: Keyword arguments passed straight to the ``Watchlist``
            constructor.

    Returns:
        The new, flushed (but not committed) watchlist instance.
    """
    new_watchlist = Watchlist(**data)
    db.add(new_watchlist)
    # Flush emits the pending INSERT inside the current transaction;
    # committing is left to the caller.
    db.flush()
    return new_watchlist


def update(db: Session, entity_id: str, data: dict) -> Optional[Watchlist]:
    """Assign every key/value in ``data`` onto an existing watchlist and flush.

    Args:
        db: Active SQLAlchemy session.
        entity_id: ``Watchlist.id`` of the row to modify.
        data: Attribute names mapped to their new values; each pair is applied
            with ``setattr`` with no whitelist.

    Returns:
        The updated watchlist, or ``None`` if no watchlist has that id.
    """
    target = get_by_id(db, entity_id)
    if not target:
        return None

    for attr_name in data:
        setattr(target, attr_name, data[attr_name])

    db.flush()
    return target


def delete(db: Session, entity_id: str) -> bool:
    """Delete the watchlist with the given id and flush.

    Args:
        db: Active SQLAlchemy session.
        entity_id: ``Watchlist.id`` of the row to remove.

    Returns:
        ``True`` if a watchlist was found and deleted, ``False`` otherwise.
    """
    target = get_by_id(db, entity_id)
    if target:
        db.delete(target)
        db.flush()
        return True
    return False


def get_with_details(db: Session, watchlist_id: str) -> Optional[Watchlist]:
    """Fetch one watchlist with its items and their stock details eager-loaded.

    Joined eager loading is requested along two paths:
    ``watchlist_items -> stock -> sector`` and
    ``watchlist_items -> stock -> exchange``.

    Args:
        db: Active SQLAlchemy session.
        watchlist_id: Value compared against ``Watchlist.id``.

    Returns:
        The matching watchlist, or ``None`` when no row matches.
    """
    load_sector = (
        joinedload(Watchlist.watchlist_items)
        .joinedload(Watchlistitem.stock)
        .joinedload(Stock.sector)
    )
    load_exchange = (
        joinedload(Watchlist.watchlist_items)
        .joinedload(Watchlistitem.stock)
        .joinedload(Stock.exchange)
    )

    detailed = db.query(Watchlist).options(load_sector, load_exchange)
    return detailed.filter(Watchlist.id == watchlist_id).first()


def get_by_user_and_name(db: Session, user_id: str, name: str) -> Optional[Watchlist]:
    """Find a watchlist owned by ``user_id`` whose name equals ``name``.

    Args:
        db: Active SQLAlchemy session.
        user_id: Value compared against ``Watchlist.user_id``.
        name: Value compared against ``Watchlist.name``.

    Returns:
        The first watchlist matching both criteria, or ``None``.
    """
    owned_by_user = Watchlist.user_id == user_id
    named = Watchlist.name == name
    return db.query(Watchlist).filter(owned_by_user, named).first()


def count_by_user(db: Session, user_id: str) -> int:
    """Return how many watchlists have ``user_id`` as their owner.

    Args:
        db: Active SQLAlchemy session.
        user_id: Value compared against ``Watchlist.user_id``.
    """
    users_watchlists = db.query(Watchlist).filter(Watchlist.user_id == user_id)
    return users_watchlists.count()


def get_watchlistitem_by_id(db: Session, item_id: str) -> Optional[Watchlistitem]:
    """Look up one watchlist item by its ``id``.

    Args:
        db: Active SQLAlchemy session.
        item_id: Value compared against ``Watchlistitem.id``.

    Returns:
        The first matching item, or ``None`` when no row matches.
    """
    matching = db.query(Watchlistitem).filter(Watchlistitem.id == item_id)
    return matching.first()


def list_watchlistitems(db: Session, limit: int = 20, offset: int = 0, **filters) -> List[Watchlistitem]:
    """Return one page of watchlist items, optionally filtered.

    Args:
        db: Active SQLAlchemy session.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip before the page starts.
        **filters: Optional criteria. Recognised keys, each applied only when
            truthy: ``watchlist_id`` and ``stock_id``. Any other keys are
            ignored.

    Returns:
        Items ordered by ``sort_order`` ascending, then by ``id`` so that rows
        with equal ``sort_order`` have a stable position across pages.
    """
    query = db.query(Watchlistitem)

    watchlist_id = filters.get("watchlist_id")
    if watchlist_id:
        query = query.filter(Watchlistitem.watchlist_id == watchlist_id)

    stock_id = filters.get("stock_id")
    if stock_id:
        query = query.filter(Watchlistitem.stock_id == stock_id)

    # Nothing here makes sort_order unique. Without a unique tie-breaker
    # (id, presumably the primary key), the database may return tied rows
    # in any order, so consecutive LIMIT/OFFSET pages could repeat or skip
    # rows.
    return (
        query.order_by(Watchlistitem.sort_order.asc(), Watchlistitem.id.asc())
        .limit(limit)
        .offset(offset)
        .all()
    )


def create_watchlistitem(db: Session, data: dict) -> Watchlistitem:
    """Construct a watchlist item from ``data``, add it to the session and flush.

    Args:
        db: Active SQLAlchemy session.
        data: Keyword arguments passed straight to the ``Watchlistitem``
            constructor.

    Returns:
        The new, flushed (but not committed) item instance.
    """
    new_item = Watchlistitem(**data)
    db.add(new_item)
    # Flush emits the pending INSERT inside the current transaction;
    # committing is left to the caller.
    db.flush()
    return new_item


def update_watchlistitem(db: Session, item_id: str, data: dict) -> Optional[Watchlistitem]:
    """Assign every key/value in ``data`` onto an existing item and flush.

    Args:
        db: Active SQLAlchemy session.
        item_id: ``Watchlistitem.id`` of the row to modify.
        data: Attribute names mapped to their new values; each pair is applied
            with ``setattr`` with no whitelist.

    Returns:
        The updated item, or ``None`` if no item has that id.
    """
    target = get_watchlistitem_by_id(db, item_id)
    if not target:
        return None

    for attr_name in data:
        setattr(target, attr_name, data[attr_name])

    db.flush()
    return target


def delete_watchlistitem(db: Session, item_id: str) -> bool:
    """Delete the watchlist item with the given id and flush.

    Args:
        db: Active SQLAlchemy session.
        item_id: ``Watchlistitem.id`` of the row to remove.

    Returns:
        ``True`` if an item was found and deleted, ``False`` otherwise.
    """
    target = get_watchlistitem_by_id(db, item_id)
    if target:
        db.delete(target)
        db.flush()
        return True
    return False


def get_watchlistitem_by_watchlist_and_stock(
    db: Session, watchlist_id: str, stock_id: str
) -> Optional[Watchlistitem]:
    """Find the item linking ``stock_id`` to ``watchlist_id``, if any.

    Args:
        db: Active SQLAlchemy session.
        watchlist_id: Value compared against ``Watchlistitem.watchlist_id``.
        stock_id: Value compared against ``Watchlistitem.stock_id``.

    Returns:
        The first item matching both ids, or ``None``.
    """
    in_watchlist = Watchlistitem.watchlist_id == watchlist_id
    for_stock = Watchlistitem.stock_id == stock_id
    return db.query(Watchlistitem).filter(in_watchlist, for_stock).first()


def count_items_in_watchlist(db: Session, watchlist_id: str) -> int:
    """Return how many items belong to the watchlist ``watchlist_id``.

    Args:
        db: Active SQLAlchemy session.
        watchlist_id: Value compared against ``Watchlistitem.watchlist_id``.
    """
    belongs_to_watchlist = Watchlistitem.watchlist_id == watchlist_id
    return db.query(Watchlistitem).filter(belongs_to_watchlist).count()