from typing import List, Optional

from sqlalchemy.orm import Session, joinedload, selectinload, subqueryload

from .models import Exchange, Sector, Stock, Stockhistory


# Sector repository functions
def get_by_id(db: Session, entity_id: str) -> Optional[Sector]:
    """Fetch a single Sector by its ``id``.

    Args:
        db: Active SQLAlchemy session.
        entity_id: Value compared against ``Sector.id``.

    Returns:
        The first matching Sector, or ``None`` when no row matches.
    """
    return db.query(Sector).filter_by(id=entity_id).first()


def list_all(db: Session, limit: int = 20, offset: int = 0, **filters) -> List[Sector]:
    """Return a page of sectors, optionally filtered by a name search.

    Args:
        db: Active SQLAlchemy session.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip (for pagination).
        **filters: Supported key: ``search``, a case-insensitive substring
            matched literally against ``Sector.name``. Falsy values are ignored.

    Returns:
        Sectors ordered by ``display_order``, then ``name``, then ``id``.
    """
    query = db.query(Sector)
    if filters.get("search"):
        # Escape LIKE metacharacters so the user's text is matched literally
        # (e.g. "50%" or "a_b" must not behave as wildcards).
        literal = (
            str(filters["search"])
            .replace("/", "//")
            .replace("%", "/%")
            .replace("_", "/_")
        )
        query = query.filter(Sector.name.ilike(f"%{literal}%", escape="/"))
    # ``id`` is a final tiebreaker so LIMIT/OFFSET pages are deterministic when
    # display_order/name collide (assumes id is unique, e.g. the primary key).
    return (
        query.order_by(Sector.display_order, Sector.name, Sector.id)
        .limit(limit)
        .offset(offset)
        .all()
    )


def create(db: Session, data: dict) -> Sector:
    """Build a Sector from *data*, add it to the session and flush.

    The session is flushed but not committed; committing is left to the caller.

    Args:
        db: Active SQLAlchemy session.
        data: Keyword arguments passed straight to the ``Sector`` constructor.

    Returns:
        The newly added Sector instance.
    """
    new_sector = Sector(**data)
    db.add(new_sector)
    db.flush()  # emit the INSERT inside the caller's transaction
    return new_sector


def update(db: Session, entity_id: str, data: dict) -> Optional[Sector]:
    """Apply *data* to the Sector identified by *entity_id* and flush.

    Every key/value pair in *data* is assigned with ``setattr``; keys are not
    validated here. NOTE(review): an unknown key presumably becomes a plain,
    unpersisted Python attribute, so consider whitelisting at the caller.

    Returns:
        The updated Sector, or ``None`` if no Sector has that id.
    """
    sector = get_by_id(db, entity_id)
    if sector:
        for attr_name, new_value in data.items():
            setattr(sector, attr_name, new_value)
        db.flush()
        return sector
    return None


def delete(db: Session, entity_id: str) -> bool:
    """Delete the Sector identified by *entity_id* and flush.

    Returns:
        ``True`` if a Sector was found and deleted, ``False`` otherwise.
    """
    doomed = get_by_id(db, entity_id)
    if doomed:
        db.delete(doomed)
        db.flush()
        return True
    return False


def get_sector_by_name(db: Session, name: str) -> Optional[Sector]:
    """Return the first Sector whose ``name`` equals *name*, or ``None``.

    The comparison is a plain ``==``, so case sensitivity depends on the
    database collation.
    """
    by_name = db.query(Sector).filter_by(name=name)
    return by_name.first()


# Exchange repository functions
def get_exchange_by_id(db: Session, entity_id: str) -> Optional[Exchange]:
    """Fetch a single Exchange by its ``id``.

    Returns:
        The first matching Exchange, or ``None`` when no row matches.
    """
    return db.query(Exchange).filter_by(id=entity_id).first()


def list_exchanges(db: Session, limit: int = 20, offset: int = 0, **filters) -> List[Exchange]:
    """Return a page of exchanges, optionally filtered.

    Args:
        db: Active SQLAlchemy session.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip (for pagination).
        **filters: Optional keys (falsy values are ignored):
            ``search`` is a case-insensitive substring matched literally
            against ``Exchange.name`` OR ``Exchange.code``.
            ``country`` is an exact match on ``Exchange.country``.

    Returns:
        Exchanges ordered by ``name``, then ``id``.
    """
    query = db.query(Exchange)
    if filters.get("search"):
        # Escape LIKE metacharacters so "%" / "_" in the input match literally.
        literal = (
            str(filters["search"])
            .replace("/", "//")
            .replace("%", "/%")
            .replace("_", "/_")
        )
        pattern = f"%{literal}%"
        query = query.filter(
            Exchange.name.ilike(pattern, escape="/")
            | Exchange.code.ilike(pattern, escape="/")
        )
    if filters.get("country"):
        query = query.filter(Exchange.country == filters["country"])
    # ``id`` breaks ties between equal names so LIMIT/OFFSET paging is stable.
    return (
        query.order_by(Exchange.name, Exchange.id)
        .limit(limit)
        .offset(offset)
        .all()
    )


def create_exchange(db: Session, data: dict) -> Exchange:
    """Build an Exchange from *data*, add it to the session and flush.

    Nothing is committed here; the caller owns the transaction.

    Returns:
        The newly added Exchange instance.
    """
    new_exchange = Exchange(**data)
    db.add(new_exchange)
    db.flush()  # emit the INSERT without committing
    return new_exchange


def update_exchange(db: Session, entity_id: str, data: dict) -> Optional[Exchange]:
    """Apply *data* to the Exchange identified by *entity_id* and flush.

    Each key/value in *data* is assigned with ``setattr`` without validation.

    Returns:
        The updated Exchange, or ``None`` if no Exchange has that id.
    """
    exchange = get_exchange_by_id(db, entity_id)
    if exchange:
        for attr_name, new_value in data.items():
            setattr(exchange, attr_name, new_value)
        db.flush()
        return exchange
    return None


def delete_exchange(db: Session, entity_id: str) -> bool:
    """Delete the Exchange identified by *entity_id* and flush.

    Returns:
        ``True`` if an Exchange was found and deleted, ``False`` otherwise.
    """
    doomed = get_exchange_by_id(db, entity_id)
    if doomed:
        db.delete(doomed)
        db.flush()
        return True
    return False


def get_exchange_by_code(db: Session, code: str) -> Optional[Exchange]:
    """Return the first Exchange whose ``code`` equals *code*, or ``None``."""
    by_code = db.query(Exchange).filter_by(code=code)
    return by_code.first()


# Stock repository functions
def get_stock_by_id(db: Session, entity_id: str) -> Optional[Stock]:
    """Fetch a single Stock by its ``id``.

    Returns:
        The first matching Stock, or ``None`` when no row matches.
    """
    return db.query(Stock).filter_by(id=entity_id).first()


def list_stocks(db: Session, limit: int = 20, offset: int = 0, **filters) -> List[Stock]:
    """Return a page of stocks matching the given filters.

    Args:
        db: Active SQLAlchemy session.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip (for pagination).
        **filters: Optional keys:
            ``search`` is a case-insensitive substring matched literally
            against ``ticker_symbol`` OR ``company_name`` (ignored if falsy).
            ``sector_id``, ``exchange_id`` and ``status`` are exact matches
            (ignored if falsy).
            ``min_price`` / ``max_price`` are inclusive bounds on
            ``current_price`` (ignored only if ``None``, so ``0`` is honored).

    Returns:
        Stocks ordered by ``ticker_symbol``, then ``id``.
    """
    query = db.query(Stock)
    if filters.get("search"):
        # Escape LIKE metacharacters so "%" / "_" in the input match literally.
        literal = (
            str(filters["search"])
            .replace("/", "//")
            .replace("%", "/%")
            .replace("_", "/_")
        )
        pattern = f"%{literal}%"
        query = query.filter(
            Stock.ticker_symbol.ilike(pattern, escape="/")
            | Stock.company_name.ilike(pattern, escape="/")
        )
    if filters.get("sector_id"):
        query = query.filter(Stock.sector_id == filters["sector_id"])
    if filters.get("exchange_id"):
        query = query.filter(Stock.exchange_id == filters["exchange_id"])
    if filters.get("status"):
        query = query.filter(Stock.status == filters["status"])
    if filters.get("min_price") is not None:
        query = query.filter(Stock.current_price >= filters["min_price"])
    if filters.get("max_price") is not None:
        query = query.filter(Stock.current_price <= filters["max_price"])
    # ``id`` breaks ties (e.g. one ticker listed on several exchanges) so that
    # LIMIT/OFFSET pages are deterministic.
    return (
        query.order_by(Stock.ticker_symbol, Stock.id)
        .limit(limit)
        .offset(offset)
        .all()
    )


def create_stock(db: Session, data: dict) -> Stock:
    """Build a Stock from *data*, add it to the session and flush.

    Nothing is committed here; the caller owns the transaction.

    Returns:
        The newly added Stock instance.
    """
    new_stock = Stock(**data)
    db.add(new_stock)
    db.flush()  # emit the INSERT without committing
    return new_stock


def update_stock(db: Session, entity_id: str, data: dict) -> Optional[Stock]:
    """Apply *data* to the Stock identified by *entity_id* and flush.

    Each key/value in *data* is assigned with ``setattr`` without validation.

    Returns:
        The updated Stock, or ``None`` if no Stock has that id.
    """
    stock = get_stock_by_id(db, entity_id)
    if stock:
        for attr_name, new_value in data.items():
            setattr(stock, attr_name, new_value)
        db.flush()
        return stock
    return None


def delete_stock(db: Session, entity_id: str) -> bool:
    """Delete the Stock identified by *entity_id* and flush.

    Returns:
        ``True`` if a Stock was found and deleted, ``False`` otherwise.
    """
    doomed = get_stock_by_id(db, entity_id)
    if doomed:
        db.delete(doomed)
        db.flush()
        return True
    return False


def get_stock_by_ticker(db: Session, ticker_symbol: str) -> Optional[Stock]:
    """Return the first Stock whose ``ticker_symbol`` equals the argument, or ``None``."""
    by_ticker = db.query(Stock).filter_by(ticker_symbol=ticker_symbol)
    return by_ticker.first()


def get_stock_by_company_name(db: Session, company_name: str) -> Optional[Stock]:
    """Return the first Stock whose ``company_name`` equals the argument, or ``None``.

    If several rows share the name, whichever the database returns first wins
    (no ORDER BY is applied).
    """
    by_company = db.query(Stock).filter_by(company_name=company_name)
    return by_company.first()


def get_stock_with_details(db: Session, stock_id: str) -> Optional[Stock]:
    """Fetch a Stock by ``id`` with its ``sector`` and ``exchange`` joined-eager-loaded.

    Both relationships are fetched in the same SELECT via ``joinedload``.

    Returns:
        The matching Stock, or ``None`` when no row matches.
    """
    eager_loads = (joinedload(Stock.sector), joinedload(Stock.exchange))
    stock_query = db.query(Stock).options(*eager_loads).filter(Stock.id == stock_id)
    return stock_query.first()


def get_stock_with_history(db: Session, stock_id: str) -> Optional[Stock]:
    """Fetch a Stock by ``id`` with its ``stock_history`` collection eagerly loaded.

    Args:
        db: Active SQLAlchemy session.
        stock_id: Value compared against ``Stock.id``.

    Returns:
        The matching Stock, or ``None`` if no row matches.
    """
    return (
        db.query(Stock)
        # selectinload is SQLAlchemy's recommended eager loader for
        # collections. It issues one follow-up SELECT keyed by the parent
        # primary key(s) (IN clause). subqueryload instead re-embeds the whole
        # parent query, including the LIMIT added by .first(), as a subquery.
        .options(selectinload(Stock.stock_history))
        .filter(Stock.id == stock_id)
        .first()
    )


def search_stocks_autocomplete(db: Session, search_term: str, limit: int = 10) -> List[Stock]:
    """Return up to *limit* ACTIVE stocks whose ticker or company name contains *search_term*.

    Matching is case-insensitive and literal: ``%`` and ``_`` typed by the
    user are escaped rather than treated as LIKE wildcards. An empty
    *search_term* matches every ACTIVE stock.

    Args:
        db: Active SQLAlchemy session.
        search_term: Raw text to search for (typically user-typed).
        limit: Maximum number of results.

    Returns:
        Matching stocks ordered by ``ticker_symbol``, then ``id``.
    """
    literal = (
        str(search_term)
        .replace("/", "//")
        .replace("%", "/%")
        .replace("_", "/_")
    )
    search_pattern = f"%{literal}%"
    return (
        db.query(Stock)
        .filter(
            Stock.ticker_symbol.ilike(search_pattern, escape="/")
            | Stock.company_name.ilike(search_pattern, escape="/")
        )
        .filter(Stock.status == "ACTIVE")
        # ``id`` tiebreaker keeps the truncated result set deterministic.
        .order_by(Stock.ticker_symbol, Stock.id)
        .limit(limit)
        .all()
    )


# StockHistory repository functions
def get_stockhistory_by_id(db: Session, entity_id: str) -> Optional[Stockhistory]:
    """Fetch a single Stockhistory row by its ``id``.

    Returns:
        The first matching row, or ``None`` when nothing matches.
    """
    return db.query(Stockhistory).filter_by(id=entity_id).first()


def list_stockhistory(db: Session, limit: int = 20, offset: int = 0, **filters) -> List[Stockhistory]:
    """Return a page of Stockhistory rows, newest ``date`` first.

    Args:
        db: Active SQLAlchemy session.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip (for pagination).
        **filters: Optional keys (falsy values are ignored):
            ``stock_id`` keeps only rows for that stock.
            ``start_date`` keeps rows with ``date >= start_date``.
            ``end_date`` keeps rows with ``date <= end_date`` (both bounds inclusive).

    Returns:
        Matching rows ordered by ``date`` descending, then ``id`` descending.
    """
    query = db.query(Stockhistory)
    if filters.get("stock_id"):
        query = query.filter(Stockhistory.stock_id == filters["stock_id"])
    if filters.get("start_date"):
        query = query.filter(Stockhistory.date >= filters["start_date"])
    if filters.get("end_date"):
        query = query.filter(Stockhistory.date <= filters["end_date"])
    # Rows for different stocks can share a date. The id tiebreaker gives a
    # total order so LIMIT/OFFSET pages neither repeat nor skip rows.
    return (
        query.order_by(Stockhistory.date.desc(), Stockhistory.id.desc())
        .limit(limit)
        .offset(offset)
        .all()
    )


def create_stockhistory(db: Session, data: dict) -> Stockhistory:
    """Build a Stockhistory row from *data*, add it to the session and flush.

    Nothing is committed here; the caller owns the transaction.

    Returns:
        The newly added Stockhistory instance.
    """
    new_entry = Stockhistory(**data)
    db.add(new_entry)
    db.flush()  # emit the INSERT without committing
    return new_entry


def update_stockhistory(db: Session, entity_id: str, data: dict) -> Optional[Stockhistory]:
    """Apply *data* to the Stockhistory row identified by *entity_id* and flush.

    Each key/value in *data* is assigned with ``setattr`` without validation.

    Returns:
        The updated row, or ``None`` if no row has that id.
    """
    entry = get_stockhistory_by_id(db, entity_id)
    if entry:
        for attr_name, new_value in data.items():
            setattr(entry, attr_name, new_value)
        db.flush()
        return entry
    return None


def delete_stockhistory(db: Session, entity_id: str) -> bool:
    """Delete the Stockhistory row identified by *entity_id* and flush.

    Returns:
        ``True`` if a row was found and deleted, ``False`` otherwise.
    """
    doomed = get_stockhistory_by_id(db, entity_id)
    if doomed:
        db.delete(doomed)
        db.flush()
        return True
    return False


def get_stockhistory_by_stock_and_date(db: Session, stock_id: str, date) -> Optional[Stockhistory]:
    """Return the first Stockhistory row for *stock_id* on exactly *date*, or ``None``.

    NOTE(review): *date* is compared with ``==`` against ``Stockhistory.date``;
    its expected type (date vs datetime) is not visible here, so confirm it
    matches the column type.
    """
    matches = (
        db.query(Stockhistory)
        .filter(Stockhistory.stock_id == stock_id)
        .filter(Stockhistory.date == date)
    )
    return matches.first()