import asyncio
import logging
import os
from collections import defaultdict
from typing import Dict, List, Optional, Set

from fastapi import WebSocket
from redis import asyncio as aioredis
from starlette.websockets import WebSocketDisconnect, WebSocketState
from uvicorn.protocols.utils import ClientDisconnected

# Configure structured logging for production.
# NOTE(review): logging.basicConfig at import time mutates the process-wide
# root logger. That is fine if this module owns logging setup, but confirm no
# other module configures logging first (basicConfig is a no-op in that case).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


class ConnectionManager:
    """Manage grouped WebSocket connections with Redis-backed fan-out.

    Responsibilities:
      * Track sockets per group name and, optionally, per authenticated user.
      * Broadcast to a group with per-send timeouts, evicting dead sockets.
      * Publish messages to Redis so sibling workers can relay them.
      * Batch rapid broadcasts through a short flush window so bursts do not
        overwhelm clients or the event loop.
    """

    def __init__(self):
        # group_name -> ordered list of sockets subscribed to that group.
        self.active_connections: Dict[str, List[WebSocket]] = {}
        # socket -> {"user_id", "user_name", "group_name", ...}; only sockets
        # that identified a user appear here.
        self.socket_user_map: Dict[WebSocket, Dict] = {}
        # user_id -> set of that user's live sockets (multiple tabs/devices).
        self.user_sockets: Dict[int, Set[WebSocket]] = {}
        self.redis_client = None
        self.pubsub_client = None  # Dedicated connection for pubsub
        # Batch notification settings
        self._pending_broadcasts: Dict[str, List[str]] = defaultdict(list)
        self._broadcast_lock = asyncio.Lock()
        self._flush_task: Optional[asyncio.Task] = None
        # True while a flush task is alive and draining the queue.  Flipped
        # under _broadcast_lock on the normal paths: checking Task.done()
        # alone could strand a message queued while the previous flush task
        # was winding down (snapshot taken, task not yet marked done).
        self._flush_scheduled = False
        self._flush_interval = 0.5  # 500ms batching window (increased for stability)
        # Shutdown flag checked by every long-running operation.
        self._shutdown = False

    async def connect(
        self,
        websocket: WebSocket,
        group_name: str,
        user_id: Optional[int] = None,
        user_name: Optional[str] = None,
    ):
        """
        Accept a websocket and add it to a group.

        user_id and user_name are optional for backward compatibility; when
        given, the socket is also registered for per-user tracking.
        """
        await websocket.accept()
        self.active_connections.setdefault(group_name, []).append(websocket)

        # Track user info for this socket (only if provided).  Delegates to
        # the same registration path used by post-connect auth so the two
        # code paths cannot drift apart.
        if user_id is not None:
            self.register_socket_user(
                websocket=websocket,
                user_id=user_id,
                user_name=user_name,
                group_name=group_name,
            )

    async def disconnect(self, websocket: WebSocket, group_name: str):
        """Close a websocket (best effort) and drop it from all tracking."""
        try:
            await websocket.close()
        except Exception:
            # The peer may already have closed the socket; nothing to do.
            pass
        await self.remove_client(websocket=websocket, group_name=group_name)

    async def remove_client(
        self, websocket: WebSocket, group_name: str
    ) -> Optional[Dict]:
        """
        Remove a client from tracking without touching the socket itself.

        Returns the user_info dict if the socket was user-tracked (callers
        use it for slot locking cleanup), otherwise None.
        """
        # Remove from active connections.
        connections = self.active_connections.get(group_name)
        if connections is not None and websocket in connections:
            connections.remove(websocket)

        # Get and remove user tracking (if exists).
        user_info = self.socket_user_map.pop(websocket, None)

        if user_info:
            user_id = user_info.get("user_id")
            if user_id and user_id in self.user_sockets:
                self.user_sockets[user_id].discard(websocket)
                if not self.user_sockets[user_id]:
                    del self.user_sockets[user_id]

        return user_info

    def get_user_info(self, websocket: WebSocket) -> Optional[Dict]:
        """Return the user info registered for a websocket, or None."""
        return self.socket_user_map.get(websocket)

    def is_user_connected(self, user_id: int) -> bool:
        """Return True if the user has at least one live socket."""
        return bool(self.user_sockets.get(user_id))

    def register_socket_user(
        self,
        websocket: WebSocket,
        user_id: int,
        user_name: str,
        group_name: str,
        **extra_data,
    ):
        """
        Register user info for an already-connected socket.

        Used when auth happens after connection (e.g., via subscribe message).
        Any **extra_data is stored alongside the standard fields.
        """
        self.socket_user_map[websocket] = {
            "user_id": user_id,
            "user_name": user_name,
            "group_name": group_name,
            **extra_data,
        }
        self.user_sockets.setdefault(user_id, set()).add(websocket)

    async def broadcast(self, data: str, group_name: str):
        """Broadcast message to all connections in a group with timeout protection.

        Sockets that time out, disconnect, or error are removed from tracking
        after the send pass.
        """
        if self._shutdown:
            return

        connections = self.active_connections.get(group_name, [])

        if not connections:
            logger.info(f"No connections in group {group_name}")
            return

        dead_connections = []
        success_count = 0

        # Iterate over a copy: disconnect/remove may mutate the live list.
        for connection in list(connections):
            if self._shutdown:
                break

            try:
                if connection.client_state != WebSocketState.CONNECTED:
                    dead_connections.append(connection)
                    continue

                # Timeout prevents one slow client from stalling the group.
                await asyncio.wait_for(connection.send_text(data), timeout=5.0)
                success_count += 1

            except asyncio.TimeoutError:
                logger.warning(f"Timeout sending to client in {group_name}")
                dead_connections.append(connection)

            except (WebSocketDisconnect, ClientDisconnected):
                dead_connections.append(connection)

            except RuntimeError as e:
                # "close"/"write" errors mean the transport is gone; anything
                # else is still fatal for this socket, so drop it either way
                # and only log the unexpected variants.
                if "close" not in str(e).lower() and "write" not in str(e).lower():
                    logger.error(f"RuntimeError during broadcast: {e}")
                dead_connections.append(connection)

            except Exception as e:
                logger.error(
                    f"Unexpected error during broadcast: {type(e).__name__}: {e}"
                )
                dead_connections.append(connection)

        # Clean up dead connections.
        for conn in dead_connections:
            try:
                await self.disconnect(websocket=conn, group_name=group_name)
            except Exception:
                # Best-effort cleanup; the socket is already unusable.
                pass

        fail_count = len(dead_connections)
        if fail_count > 0:
            logger.warning(
                f"Broadcast to {group_name}: {success_count} success, {fail_count} removed"
            )
        elif success_count > 0:
            logger.info(f"Broadcast to {group_name}: {success_count} clients")

    async def send_to_socket(self, websocket: WebSocket, data: str):
        """Send a message to one socket.  Returns True on success."""
        try:
            if websocket.client_state == WebSocketState.CONNECTED:
                await websocket.send_text(data)
                return True
        except Exception as e:
            logger.error(f"Error sending to socket: {e}")
        return False

    async def publish_data(self, data: str, group_name: str):
        """
        Publish to Redis for cross-worker broadcast.

        Lazily connects, and uses a timeout so a stalled Redis cannot block
        the caller indefinitely.  Errors are logged, not raised.
        """
        if self._shutdown:
            logger.info("[PUBLISH] Skipped - shutdown in progress")
            return

        try:
            if not self.redis_client:
                await self.connect_to_redis()

            group_name = str(group_name)
            logger.info(f"[PUBLISH] Channel: {group_name}")

            # Publish to Redis with timeout to prevent blocking.
            result = await asyncio.wait_for(
                self.redis_client.publish(group_name, data), timeout=5.0
            )
            logger.info(f"[PUBLISH] Subscribers notified: {result}")

        except asyncio.TimeoutError:
            logger.warning(f"[PUBLISH] Timeout publishing to {group_name}")
        except asyncio.CancelledError:
            logger.info("[PUBLISH] Cancelled")
            raise
        except Exception as e:
            logger.error(f"[PUBLISH ERROR] {type(e).__name__}: {e}", exc_info=True)

    async def connect_to_redis(self):
        """Connect both Redis clients (idempotent) with pub/sub-tuned settings.

        The URL comes from the REDIS_URL environment variable, defaulting to
        the local instance for backward compatibility.

        Raises:
            Any redis connection error; callers decide whether to retry.
        """
        if self.redis_client:
            return

        redis_url = os.environ.get("REDIS_URL", "redis://localhost:6379")
        try:
            # Main client for general operations (publish etc.).
            self.redis_client = aioredis.Redis.from_url(
                redis_url,
                decode_responses=True,
                socket_connect_timeout=5,
                socket_keepalive=True,
                health_check_interval=30,
                socket_timeout=5,
            )
            # Test connection before declaring success.
            await self.redis_client.ping()
            logger.info("[REDIS] Main client connected successfully")

            # Dedicated client for pubsub (separate connection): health checks
            # off and no socket timeout because a subscriber legitimately
            # waits indefinitely for messages.
            self.pubsub_client = aioredis.Redis.from_url(
                redis_url,
                decode_responses=True,
                socket_connect_timeout=5,
                socket_keepalive=True,
                health_check_interval=0,  # Disable for pubsub
                socket_timeout=None,  # Pubsub needs to wait indefinitely
            )
            await self.pubsub_client.ping()
            logger.info("[REDIS] Pubsub client connected successfully")

            self._shutdown = False
        except Exception as e:
            logger.error(f"[REDIS] Connection failed: {e}")
            raise

    async def close_redis(self):
        """Close Redis connections gracefully and stop pending flushes."""
        self._shutdown = True
        logger.info("[REDIS] Initiating shutdown...")

        # Cancel flush task if running; awaiting it lets its cleanup finish.
        if self._flush_task and not self._flush_task.done():
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        # Close pubsub client first so the listener stops receiving.
        if self.pubsub_client:
            try:
                await self.pubsub_client.aclose()
            except Exception as e:
                logger.warning(f"[REDIS] Error closing pubsub client: {e}")
            self.pubsub_client = None
            logger.info("[REDIS] Pubsub client closed")

        # Close main client.
        if self.redis_client:
            try:
                await self.redis_client.aclose()
            except Exception as e:
                logger.warning(f"[REDIS] Error closing main client: {e}")
            self.redis_client = None
            logger.info("[REDIS] Main client closed")

    def is_shutting_down(self) -> bool:
        """Check if the manager is shutting down"""
        return self._shutdown

    async def queue_broadcast(self, data: str, group_name: str):
        """
        Queue a broadcast message for batched delivery.

        Uses a timeout on the lock to prevent deadlocks; on timeout it falls
        back to broadcasting directly.  Scheduling is decided via the
        _flush_scheduled flag *under the lock* (not Task.done()), which
        guarantees every queued message has a flush task responsible for it.
        """
        if self._shutdown:
            return

        try:
            # Bound the wait so a stuck flush cannot block callers forever.
            async with asyncio.timeout(1.0):
                async with self._broadcast_lock:
                    self._pending_broadcasts[group_name].append(data)

                    # Start a drain task only if none is currently active.
                    if not self._flush_scheduled:
                        self._flush_scheduled = True
                        self._flush_task = asyncio.create_task(self._flush_broadcasts())
        except asyncio.TimeoutError:
            # If we can't acquire lock, broadcast directly (fallback).
            logger.warning(
                f"[QUEUE] Lock timeout, broadcasting directly to {group_name}"
            )
            await self.broadcast(data=data, group_name=group_name)
        except Exception as e:
            logger.error(f"[QUEUE ERROR] {type(e).__name__}: {e}")

    async def _flush_broadcasts(self):
        """Drain pending broadcasts in batches until the queue is empty.

        Loops one batching window at a time; the exit decision (queue empty)
        is made while holding the lock, and the _flush_scheduled flag is
        cleared in the same critical section, so a concurrently queued
        message either lands in the next snapshot or triggers a new task.
        """
        finished_cleanly = False
        try:
            while not self._shutdown:
                await asyncio.sleep(self._flush_interval)

                # Snapshot and clear pending messages under the lock.
                async with asyncio.timeout(1.0):
                    async with self._broadcast_lock:
                        pending = dict(self._pending_broadcasts)
                        self._pending_broadcasts.clear()
                        if not pending:
                            # Nothing arrived during the window: hand the
                            # scheduling baton back while still locked.
                            self._flush_scheduled = False
                            finished_cleanly = True
                            return

                # Deliver outside the lock so queuing is never blocked.
                for group_name, messages in pending.items():
                    if self._shutdown:
                        return

                    # Send every message (important for real-time updates).
                    for data in messages:
                        if self._shutdown:
                            return

                        try:
                            await self.broadcast(data=data, group_name=group_name)
                        except Exception as e:
                            logger.error(f"[FLUSH ERROR] Broadcasting to {group_name}: {e}")

        except asyncio.TimeoutError:
            logger.warning("[FLUSH] Lock timeout during flush")
        except asyncio.CancelledError:
            logger.info("[FLUSH] Cancelled")
        except Exception as e:
            logger.error(f"[FLUSH ERROR] {type(e).__name__}: {e}")
        finally:
            # On any abnormal exit, allow a later queue_broadcast to start a
            # replacement flush task.
            if not finished_cleanly:
                self._flush_scheduled = False


redis_manager = ConnectionManager()


async def listen_to_redis():
    """
    Per-worker Redis subscriber loop.

      - Subscribes to the ``table_*`` and ``slot_locks:*`` patterns (plus the
        ``test`` channel — presumably a smoke-test hook; confirm before removing).
      - For each message whose channel has local connections, queues a batched
        broadcast via the shared ConnectionManager.
      - Polls with timeouts (instead of blocking ``listen()``) so the shutdown
        flag is honoured promptly, and reconnects with exponential backoff
        after failures.
    """
    pubsub = None
    reconnect_delay = 1  # Start with 1 second, will increase on failures

    while not redis_manager.is_shutting_down():
        try:
            # Lazily (re)establish the dedicated pubsub connection.
            if not redis_manager.pubsub_client:
                await redis_manager.connect_to_redis()

            # Create pubsub from dedicated client.
            pubsub = redis_manager.pubsub_client.pubsub()

            # Subscribe to patterns.
            await pubsub.psubscribe("table_*", "slot_locks:*", "test")
            logger.info(
                "[REDIS SUB] Connected and subscribed (table_* and slot_locks:*)"
            )

            # Reset reconnect delay on successful connection.
            reconnect_delay = 1

            # Use get_message with timeout instead of blocking listen().
            while not redis_manager.is_shutting_down():
                try:
                    # Outer wait_for guards against the client itself hanging.
                    msg = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0),
                        timeout=2.0,
                    )

                    if msg is None:
                        # No message received, continue loop (allows shutdown check).
                        continue

                    if msg["type"] != "pmessage":
                        continue

                    channel = msg["channel"]
                    data = msg["data"]

                    logger.info(f"[REDIS MSG] Channel: {channel}")

                    # Only fan out when this worker has sockets in that group.
                    if redis_manager.active_connections.get(channel):
                        # Queue broadcast for batched delivery (prevents freezing).
                        await redis_manager.queue_broadcast(
                            data=data, group_name=channel
                        )

                except asyncio.TimeoutError:
                    # Timeout is expected, just continue to check shutdown flag.
                    continue
                except asyncio.CancelledError:
                    logger.info("[REDIS SUB] Listener cancelled")
                    raise

        except asyncio.CancelledError:
            logger.info("[REDIS SUB] Cancelled, shutting down...")
            break

        except Exception as e:
            if redis_manager.is_shutting_down():
                break
            logger.error(f"[REDIS SUB ERROR] {type(e).__name__}: {e}", exc_info=True)
            # Exponential backoff with max of 30 seconds.
            reconnect_delay = min(reconnect_delay * 2, 30)

        finally:
            if pubsub is not None:
                try:
                    # No-arg punsubscribe drops *all* patterns — an explicit
                    # list previously omitted the "test" channel and leaked it.
                    await pubsub.punsubscribe()
                    # aclose() matches the clients' shutdown path (close() is
                    # the deprecated alias in redis-py >= 5).
                    await pubsub.aclose()
                except Exception as e:
                    logger.info(
                        f"[REDIS SUB] Cleanup error (expected during shutdown): {e}"
                    )
                pubsub = None

        if not redis_manager.is_shutting_down():
            logger.warning(f"[REDIS SUB] Reconnecting in {reconnect_delay} seconds...")
            try:
                await asyncio.sleep(reconnect_delay)
            except asyncio.CancelledError:
                break

    logger.info("[REDIS SUB] Listener stopped")