"""WebSocket Hub — real-time bidirectional communication between Extensions and Bot.

Replaces file-based IPC (bridge/) and HTTP polling (Collector/Gateway)
with persistent WebSocket connections.

Architecture:
    Extension ↔ WSS ↔ Hub ↔ Bot (in-process) ↔ Discord

Each Extension connects via WebSocket, authenticates, and receives
a unique instance number within its project. Messages are routed by project.
"""

import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Coroutine

from aiohttp import web, WSMsgType

from auth import TokenManager

logger = logging.getLogger(__name__)

# ─── Constants ───
HEARTBEAT_INTERVAL = 30.0  # seconds between server→client pings
HEARTBEAT_TIMEOUT = 10.0  # seconds to wait for pong before disconnect
CLIENT_QUEUE_SIZE = 100  # max queued messages per client (backpressure)
MAX_MSG_SIZE = 1024 * 1024  # 1MB max WebSocket message size
PER_CONN_RATE_LIMIT = 60  # max messages per 10s window per connection
RATE_WINDOW = 10.0  # seconds
DEDUP_MAX_ENTRIES = 10000  # dedup table size that triggers a purge of expired ids


class MsgType(str, Enum):
    """WebSocket protocol message types."""

    # Extension → Hub (upstream)
    AUTH = "auth"
    PENDING = "pending"
    CHAT = "chat"
    REGISTER = "register"
    HEARTBEAT = "heartbeat"
    AUTO_RESOLVE = "auto_resolve"
    BRAIN_EVENT = "brain_event"

    # Hub → Extension (downstream)
    AUTH_OK = "auth_ok"
    AUTH_FAIL = "auth_fail"
    RESPONSE = "response"
    COMMAND = "command"
    INSTANCE_UPDATE = "instance_update"
    ERROR = "error"


@dataclass
class WSConnection:
    """Represents a connected Extension instance."""

    conn_id: str
    ws: web.WebSocketResponse
    project: str = ""  # set on successful auth
    pc_name: str = ""  # set on successful auth
    instance_number: int = 0  # 0 until a number is assigned during auth
    connected_at: float = field(default_factory=time.time)
    last_heartbeat: float = field(default_factory=time.time)
    authenticated: bool = False
    # Outbound messages drained by the hub's per-connection sender task
    send_queue: asyncio.Queue = field(
        default_factory=lambda: asyncio.Queue(maxsize=CLIENT_QUEUE_SIZE)
    )
    _sender_task: asyncio.Task | None = field(default=None, repr=False)
    # Rate limiting: arrival times of messages inside the current window
    _msg_timestamps: list[float] = field(default_factory=list, repr=False)


class WSHub:
    """WebSocket Hub for routing messages between Extensions and Bot.

    Responsibilities:
    - Connection lifecycle (auth, heartbeat, disconnect)
    - Message routing (project-scoped, instance-targeted)
    - Instance number management (auto-assign, reassign on disconnect)
    - Rate limiting per connection
    - Backpressure via per-client queues
    """

    # Upstream message types whose "data" field is forwarded to a bot handler
    _PAYLOAD_TYPES = (
        MsgType.PENDING,
        MsgType.CHAT,
        MsgType.REGISTER,
        MsgType.AUTO_RESOLVE,
        MsgType.BRAIN_EVENT,
    )

    def __init__(self, token_manager: TokenManager):
        self.token_manager = token_manager
        self.connections: dict[str, WSConnection] = {}  # conn_id → connection
        self.project_connections: dict[str, set[str]] = {}  # project → {conn_ids}
        self.pending_owners: dict[str, str] = {}  # request_id → conn_id
        self._recent_msg_ids: dict[str, float] = {}  # msg_id → timestamp (dedup)
        self._msg_id_ttl = 300.0  # 5 min dedup window

        # conn_id → (project, instance_number) handed out during auth but not
        # yet visible through self.connections; keeps concurrent auths for the
        # same project from receiving the same number.
        self._reserved_instances: dict[str, tuple[str, int]] = {}
        # Strong references to fire-and-forget tasks so they are not
        # garbage-collected before they finish.
        self._background_tasks: set[asyncio.Task] = set()

        # Bot callbacks — set by bot.py during initialization
        self._on_pending: Callable[..., Coroutine] | None = None
        self._on_chat: Callable[..., Coroutine] | None = None
        self._on_register: Callable[..., Coroutine] | None = None
        self._on_auto_resolve: Callable[..., Coroutine] | None = None
        self._on_brain_event: Callable[..., Coroutine] | None = None

    # ─── Bot Integration ───

    def set_bot_handlers(
        self,
        on_pending: Callable | None = None,
        on_chat: Callable | None = None,
        on_register: Callable | None = None,
        on_auto_resolve: Callable | None = None,
        on_brain_event: Callable | None = None,
    ):
        """Register bot callback functions for incoming Extension messages.

        Every handler is overwritten on each call; an argument left as
        ``None`` clears the corresponding handler.
        """
        self._on_pending = on_pending
        self._on_chat = on_chat
        self._on_register = on_register
        self._on_auto_resolve = on_auto_resolve
        self._on_brain_event = on_brain_event

    # ─── Connection Management ───

    async def handle_ws(self, request: web.Request) -> web.WebSocketResponse:
        """Handle a new WebSocket connection.

        Protocol:
        1. Client connects
        2. Client sends {type:"auth", token:"...", project:"...", pc:"..."}
           OR {type:"auth", registration_code:"...", project:"...", pc:"..."}
        3. Server responds {type:"auth_ok", conn_id, instance_number}
           OR {type:"auth_fail", reason:"..."}
        4. Bidirectional message exchange

        Returns the WebSocketResponse once the connection has ended.
        """
        ws = web.WebSocketResponse(
            max_msg_size=MAX_MSG_SIZE,
            heartbeat=HEARTBEAT_INTERVAL,
        )
        await ws.prepare(request)

        conn_id = uuid.uuid4().hex[:12]
        conn = WSConnection(conn_id=conn_id, ws=ws)
        remote = request.remote or "unknown"
        logger.info(f"[HUB] New WS connection: {conn_id} from {remote}")

        try:
            # Wait for auth message (first message must be auth)
            auth_ok = await self._handle_auth(conn, timeout=10.0)
            if not auth_ok:
                return ws

            # Register connection
            self._register_connection(conn)

            # Start sender task (per-client queue → ws.send)
            conn._sender_task = asyncio.create_task(
                self._sender_loop(conn), name=f"sender-{conn_id}"
            )

            # Message loop
            await self._message_loop(conn)
        except Exception as e:
            # logger.exception keeps the traceback for diagnosis
            logger.exception(f"[HUB] Connection {conn_id} error: {e}")
        finally:
            # Runs on every exit path, including failed auth and cancellation
            await self._disconnect(conn)

        return ws

    async def _reject_auth(self, conn: WSConnection, reason: str) -> bool:
        """Send an auth_fail with *reason*, close the socket, and return False."""
        await self._send_direct(conn.ws, {
            "type": MsgType.AUTH_FAIL,
            "reason": reason,
        })
        await conn.ws.close()
        return False

    async def _handle_auth(self, conn: WSConnection, timeout: float = 10.0) -> bool:
        """Wait for and process the auth message.

        On success the connection's project, pc name, authenticated flag and
        instance number are filled in, an auth_ok (including a fresh session
        token) is sent, and True is returned. On any failure the socket is
        closed and False is returned.
        """
        try:
            msg = await asyncio.wait_for(conn.ws.receive(), timeout=timeout)
        except asyncio.TimeoutError:
            return await self._reject_auth(conn, "Auth timeout")

        if msg.type != WSMsgType.TEXT:
            # Non-text first frame: close without sending an auth_fail
            await conn.ws.close()
            return False

        try:
            data = json.loads(msg.data)
        except json.JSONDecodeError:
            data = None
        # Valid JSON that is not an object (list, string, null, …) is unusable
        if not isinstance(data, dict):
            return await self._reject_auth(conn, "Invalid JSON")

        if data.get("type") != MsgType.AUTH:
            return await self._reject_auth(conn, "First message must be auth")

        project = data.get("project", "")
        pc_name = data.get("pc", "unknown")

        # The project is used as a dict key for routing, so it must be a
        # non-empty string.
        if not isinstance(project, str) or not project:
            return await self._reject_auth(conn, "Project name required")

        # Try token auth first, then registration code
        token = data.get("token", "")
        reg_code = data.get("registration_code", "")

        if token:
            payload = self.token_manager.verify_token(token)
            if not payload:
                return await self._reject_auth(conn, "Invalid or expired token")
            # Token is valid — use project from token (overrides client)
            project = payload.get("project", project)
            pc_name = payload.get("pc", pc_name)
        elif reg_code:
            if not self.token_manager.validate_registration_code(reg_code):
                return await self._reject_auth(conn, "Invalid registration code")
        else:
            # No auth provided — check if registration code is configured
            if self.token_manager.registration_code:
                return await self._reject_auth(conn, "Auth required")
            # No registration code configured → allow (dev mode)

        # Auth successful
        conn.project = project
        conn.pc_name = pc_name
        conn.authenticated = True

        # Issue a session token (for reconnection)
        session_token = self.token_manager.create_token(project, pc_name)

        # Assign instance number (reserved until the connection is registered)
        instance_number = self._assign_instance_number(project, conn.conn_id)
        conn.instance_number = instance_number

        await self._send_direct(conn.ws, {
            "type": MsgType.AUTH_OK,
            "conn_id": conn.conn_id,
            "instance_number": instance_number,
            "session_token": session_token,
            # +1 for self (not yet registered)
            "active_count": self.get_active_count(project) + 1,
        })

        logger.info(
            f"[HUB] Auth OK: {conn.conn_id} project={project} pc={pc_name} "
            f"instance=#{instance_number}"
        )
        return True

    def _register_connection(self, conn: WSConnection):
        """Add connection to tracking structures.

        Schedules an instance-update broadcast, so it must be called while an
        event loop is running.
        """
        self.connections[conn.conn_id] = conn
        self.project_connections.setdefault(conn.project, set()).add(conn.conn_id)

        # The instance number is now visible through self.connections, so the
        # reservation made during auth is no longer needed.
        self._reserved_instances.pop(conn.conn_id, None)

        # Broadcast instance update to all project connections. Keep a strong
        # reference until the task finishes so it cannot be garbage-collected
        # mid-execution.
        task = asyncio.create_task(self._broadcast_instance_update(conn.project))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _disconnect(self, conn: WSConnection):
        """Clean up after a connection closes.

        Safe to call for connections that never authenticated or registered.
        """
        conn_id = conn.conn_id
        project = conn.project

        # Cancel sender task
        if conn._sender_task and not conn._sender_task.done():
            conn._sender_task.cancel()
            try:
                await conn._sender_task
            except asyncio.CancelledError:
                pass

        # Remove from tracking
        self.connections.pop(conn_id, None)
        self._reserved_instances.pop(conn_id, None)
        if project in self.project_connections:
            self.project_connections[project].discard(conn_id)
            if not self.project_connections[project]:
                del self.project_connections[project]

        # Clean up pending ownership
        stale = [rid for rid, cid in self.pending_owners.items() if cid == conn_id]
        for rid in stale:
            del self.pending_owners[rid]

        # Close WebSocket if still open
        if not conn.ws.closed:
            await conn.ws.close()

        logger.info(f"[HUB] Disconnected: {conn_id} project={project}")

        # Broadcast instance update (remaining connections get notified)
        if project:
            await self._broadcast_instance_update(project)

    # ─── Instance Number Management ───

    def _assign_instance_number(self, project: str, conn_id: str) -> int:
        """Assign the next available instance number for a project.

        The lowest number (starting at 1) not used by a registered connection
        of *project* and not reserved by another connection still completing
        auth is returned. The result is reserved under *conn_id* until that
        connection is registered or disconnected.
        """
        used = set()
        for cid in self.project_connections.get(project, set()):
            c = self.connections.get(cid)
            if c:
                used.add(c.instance_number)
        for cid, (reserved_project, reserved_num) in self._reserved_instances.items():
            if reserved_project == project and cid != conn_id:
                used.add(reserved_num)

        # Find lowest available number starting from 1
        num = 1
        while num in used:
            num += 1

        self._reserved_instances[conn_id] = (project, num)
        return num

    def get_active_count(self, project: str) -> int:
        """Get the number of active connections for a project."""
        return len(self.project_connections.get(project, set()))

    async def _broadcast_instance_update(self, project: str):
        """Notify all project connections of instance count change."""
        # Include instance list for UI
        instances = []
        for cid in self.project_connections.get(project, set()):
            c = self.connections.get(cid)
            if c:
                instances.append({
                    "instance_number": c.instance_number,
                    "pc": c.pc_name,
                })

        msg = {
            "type": MsgType.INSTANCE_UPDATE,
            "active_count": self.get_active_count(project),
            "instances": instances,
        }
        await self.broadcast_to_project(project, msg)

    # ─── Message Routing ───

    async def broadcast_to_project(self, project: str, message: dict):
        """Send a message to all connections in a project."""
        # Iterate over a snapshot: the set may change while we await
        conn_ids = self.project_connections.get(project, set()).copy()
        for cid in conn_ids:
            conn = self.connections.get(cid)
            if conn and conn.authenticated:
                await self._queue_send(conn, message)

    async def send_to_connection(self, conn_id: str, message: dict):
        """Send a message to a specific connection."""
        conn = self.connections.get(conn_id)
        if conn and conn.authenticated:
            await self._queue_send(conn, message)

    async def send_to_instance(
        self, project: str, instance_number: int, message: dict
    ) -> bool:
        """Send a message to a specific instance number within a project.

        Returns True if a matching connection was found and the message was
        queued, False otherwise.
        """
        for cid in self.project_connections.get(project, set()):
            conn = self.connections.get(cid)
            if conn and conn.instance_number == instance_number:
                await self._queue_send(conn, message)
                return True
        logger.warning(
            f"[HUB] Instance #{instance_number} not found for project={project}"
        )
        return False

    async def send_response_to_pending_owner(
        self, request_id: str, message: dict
    ) -> bool:
        """Route a response to the Extension that created the pending request.

        Returns True if an owner was recorded for *request_id* (the ownership
        entry is removed afterwards), False if no owner is known.
        """
        conn_id = self.pending_owners.get(request_id)
        if conn_id:
            await self.send_to_connection(conn_id, message)
            # Clean up after response delivered
            self.pending_owners.pop(request_id, None)
            return True
        logger.warning(f"[HUB] No owner for pending {request_id[:12]}")
        return False

    # ─── Per-Client Send Queue (backpressure) ───

    async def _queue_send(self, conn: WSConnection, message: dict):
        """Queue a message for sending via the per-client sender loop.

        When the queue is full the oldest queued message is discarded to make
        room for the new one.
        """
        try:
            conn.send_queue.put_nowait(message)
        except asyncio.QueueFull:
            logger.warning(
                f"[HUB] Queue full for {conn.conn_id} — dropping oldest message"
            )
            # Drop oldest to make room
            try:
                conn.send_queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
            try:
                conn.send_queue.put_nowait(message)
            except asyncio.QueueFull:
                pass

    async def _sender_loop(self, conn: WSConnection):
        """Dedicated sender coroutine for a single connection.

        Reads from the per-client queue and sends via WebSocket.
        This prevents slow clients from blocking message routing.
        The loop ends when the socket is closed, a send fails, or the task
        is cancelled.
        """
        try:
            while not conn.ws.closed:
                try:
                    msg = await asyncio.wait_for(
                        conn.send_queue.get(), timeout=HEARTBEAT_INTERVAL
                    )
                    await conn.ws.send_json(msg)
                except asyncio.TimeoutError:
                    continue  # No message to send — re-check ws.closed
                except ConnectionResetError:
                    break
                except Exception as e:
                    logger.error(f"[HUB] Send error {conn.conn_id}: {e}")
                    break
        except asyncio.CancelledError:
            pass

    async def _send_direct(self, ws: web.WebSocketResponse, message: dict):
        """Send directly (bypasses queue — for auth messages only).

        Send failures are logged and not raised.
        """
        try:
            await ws.send_json(message)
        except Exception as e:
            logger.error(f"[HUB] Direct send error: {e}")

    # ─── Message Loop ───

    async def _message_loop(self, conn: WSConnection):
        """Main message receive loop for an authenticated connection."""
        async for msg in conn.ws:
            if msg.type == WSMsgType.TEXT:
                # Rate limit check
                if not self._check_rate_limit(conn):
                    await self._queue_send(conn, {
                        "type": MsgType.ERROR,
                        "error": "Rate limited — too many messages",
                    })
                    continue

                try:
                    data = json.loads(msg.data)
                except json.JSONDecodeError:
                    continue
                # Well-formed JSON that is not an object cannot carry a
                # message type; ignore it like malformed JSON.
                if not isinstance(data, dict):
                    continue

                msg_type = data.get("type", "")

                # Dedup check
                msg_id = data.get("msg_id")
                if msg_id and self._is_duplicate(msg_id):
                    continue

                await self._handle_message(conn, msg_type, data)

            elif msg.type == WSMsgType.ERROR:
                logger.error(
                    f"[HUB] WS error {conn.conn_id}: {conn.ws.exception()}"
                )
                break
            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
                break

    async def _handle_message(self, conn: WSConnection, msg_type: str, data: dict):
        """Route an incoming message to the appropriate handler.

        Messages whose "data" field is present but not a JSON object are
        logged and ignored instead of raising.
        """
        conn.last_heartbeat = time.time()  # Any message counts as heartbeat

        if msg_type == MsgType.HEARTBEAT:
            return  # last_heartbeat already updated above

        if msg_type not in self._PAYLOAD_TYPES:
            logger.warning(f"[HUB] Unknown message type: {msg_type} from {conn.conn_id}")
            return

        payload = data.get("data", {})
        if not isinstance(payload, dict):
            logger.warning(
                f"[HUB] Ignoring {msg_type} with non-object data from {conn.conn_id}"
            )
            return

        if msg_type == MsgType.PENDING:
            request_id = payload.get("request_id", "")
            if request_id:
                # Track ownership for response routing
                self.pending_owners[request_id] = conn.conn_id
            # Add source metadata
            payload["_conn_id"] = conn.conn_id
            payload["_instance_number"] = conn.instance_number
            payload["_pc_name"] = conn.pc_name
            payload.setdefault("project_name", conn.project)
            if self._on_pending:
                await self._on_pending(conn.project, payload)

        elif msg_type == MsgType.CHAT:
            payload.setdefault("project_name", conn.project)
            payload["_instance_number"] = conn.instance_number
            payload["_pc_name"] = conn.pc_name
            if self._on_chat:
                await self._on_chat(conn.project, payload)

        elif msg_type == MsgType.REGISTER:
            payload.setdefault("project_name", conn.project)
            if self._on_register:
                await self._on_register(payload)

        elif msg_type == MsgType.AUTO_RESOLVE:
            request_id = payload.get("request_id", "")
            if request_id:
                self.pending_owners.pop(request_id, None)
            if self._on_auto_resolve:
                await self._on_auto_resolve(conn.project, payload)

        elif msg_type == MsgType.BRAIN_EVENT:
            payload.setdefault("project_name", conn.project)
            if self._on_brain_event:
                await self._on_brain_event(conn.project, payload)

    # ─── Rate Limiting ───

    def _check_rate_limit(self, conn: WSConnection) -> bool:
        """Per-connection rate limiting. Returns True if allowed.

        Rejected messages are not recorded, so they do not extend the window.
        """
        now = time.time()
        conn._msg_timestamps = [
            t for t in conn._msg_timestamps if now - t < RATE_WINDOW
        ]
        if len(conn._msg_timestamps) >= PER_CONN_RATE_LIMIT:
            logger.warning(f"[HUB] Rate limited: {conn.conn_id}")
            return False
        conn._msg_timestamps.append(now)
        return True

    # ─── Deduplication ───

    def _is_duplicate(self, msg_id: str) -> bool:
        """Check if a message ID was seen within the dedup window.

        Returns True only when *msg_id* was recorded less than
        ``self._msg_id_ttl`` seconds ago. Otherwise the id is (re)recorded
        with the current time and False is returned.
        """
        now = time.time()
        cutoff = now - self._msg_id_ttl

        # Cleanup old entries once the table grows large
        if len(self._recent_msg_ids) > DEDUP_MAX_ENTRIES:
            self._recent_msg_ids = {
                k: v for k, v in self._recent_msg_ids.items() if v > cutoff
            }

        # Honour the TTL on lookup too: an expired entry that has not been
        # purged yet must not suppress a new message.
        seen_at = self._recent_msg_ids.get(msg_id)
        if seen_at is not None and seen_at > cutoff:
            return True

        self._recent_msg_ids[msg_id] = now
        return False

    # ─── Diagnostics ───

    def get_status(self) -> dict:
        """Return hub status for health/diagnostics."""
        now = time.time()
        projects = {}
        for project, conn_ids in self.project_connections.items():
            instances = []
            for cid in conn_ids:
                c = self.connections.get(cid)
                if c:
                    instances.append({
                        "conn_id": c.conn_id,
                        "instance": c.instance_number,
                        "pc": c.pc_name,
                        "connected": int(now - c.connected_at),
                    })
            projects[project] = {
                "count": len(conn_ids),
                "instances": instances,
            }

        return {
            "total_connections": len(self.connections),
            "projects": projects,
            "pending_owners": len(self.pending_owners),
        }