Files
gravity_control/hub.py

654 lines
25 KiB
Python

"""WebSocket Hub — real-time bidirectional communication between Extensions and Bot.
Replaces file-based IPC (bridge/) and HTTP polling (Collector/Gateway) with
persistent WebSocket connections.
Architecture:
Extension ↔ WSS ↔ Hub ↔ Bot (in-process) ↔ Discord
Each Extension connects via WebSocket, authenticates, and receives a unique
instance number within its project. Messages are routed by project.
"""
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Coroutine
from aiohttp import web, WSMsgType
from auth import TokenManager
logger = logging.getLogger(__name__)

# ─── Constants ───
# Liveness, in seconds. HEARTBEAT_INTERVAL is the server→client ping period.
HEARTBEAT_INTERVAL = 30.0
# Seconds to wait for a pong before disconnecting.
# NOTE(review): not referenced by the hub code in this file — confirm it is used elsewhere.
HEARTBEAT_TIMEOUT = 10.0

# Backpressure and size limits
CLIENT_QUEUE_SIZE = 100  # max queued outbound messages per client
MAX_MSG_SIZE = 1 << 20  # 1 MiB max WebSocket message size

# Sliding-window rate limit: at most PER_CONN_RATE_LIMIT inbound messages
# per RATE_WINDOW seconds on each connection.
PER_CONN_RATE_LIMIT = 60
RATE_WINDOW = 10.0
class MsgType(str, Enum):
    """WebSocket protocol message types.

    Each value is the literal ``"type"`` string carried in a JSON frame.
    Because the enum mixes in ``str``, members compare equal to those raw
    strings (``MsgType.AUTH == "auth"``), so an incoming ``type`` field can be
    compared against members directly.
    """
    # Extension → Hub (upstream)
    AUTH = "auth"  # handshake; must be the first frame on a new connection
    PENDING = "pending"  # new pending request; the hub records its owner for response routing
    CHAT = "chat"
    REGISTER = "register"
    HEARTBEAT = "heartbeat"  # client keep-alive; the hub answers with {"type": "pong"}
    AUTO_RESOLVE = "auto_resolve"  # request resolved client-side; the hub drops its ownership entry
    BRAIN_EVENT = "brain_event"
    # Hub → Extension (downstream)
    AUTH_OK = "auth_ok"  # handshake accepted (carries conn_id, instance number, session token)
    AUTH_FAIL = "auth_fail"  # handshake rejected (carries a human-readable "reason")
    RESPONSE = "response"  # NOTE(review): not built by the hub code in this file — presumably produced by the bot side; confirm
    COMMAND = "command"  # NOTE(review): not referenced in this file either — confirm who sends it
    INSTANCE_UPDATE = "instance_update"  # active count + instance list for a project
    ERROR = "error"  # error notice to the client, e.g. rate limiting
@dataclass
class WSConnection:
    """Represents a connected Extension instance.

    Holds the per-connection state the hub tracks: identity within a project,
    liveness timestamps, the bounded outbound queue, and rate-limit history.
    """
    conn_id: str  # hub-assigned connection id
    ws: web.WebSocketResponse  # the underlying aiohttp server-side socket
    project: str = ""  # empty until authentication succeeds
    pc_name: str = ""  # machine name supplied by the client or its token
    instance_number: int = 0  # 0 = not assigned yet; assigned numbers start at 1
    connected_at: float = field(default_factory=time.time)  # epoch seconds when the object was created
    last_heartbeat: float = field(default_factory=time.time)  # refreshed whenever the hub dispatches an inbound message
    authenticated: bool = False  # set True once the auth handshake succeeds
    # Outbound messages, drained by a per-connection sender task. Bounded by
    # CLIENT_QUEUE_SIZE; when it is full the hub drops the oldest queued item.
    send_queue: asyncio.Queue = field(default_factory=lambda: asyncio.Queue(maxsize=CLIENT_QUEUE_SIZE))
    # Task running the sender loop for this connection; cancelled on disconnect.
    _sender_task: asyncio.Task | None = field(default=None, repr=False)
    # Rate limiting
    # Arrival times of recent inbound messages, pruned to the current rate window.
    _msg_timestamps: list[float] = field(default_factory=list, repr=False)
class WSHub:
"""WebSocket Hub for routing messages between Extensions and Bot.
Responsibilities:
- Connection lifecycle (auth, heartbeat, disconnect)
- Message routing (project-scoped, instance-targeted)
- Instance number management (auto-assign, reassign on disconnect)
- Rate limiting per connection
- Backpressure via per-client queues
"""
def __init__(self, token_manager: TokenManager):
self.token_manager = token_manager
self.connections: dict[str, WSConnection] = {} # conn_id → connection
self.project_connections: dict[str, set[str]] = {} # project → {conn_ids}
self.pending_owners: dict[str, str] = {} # request_id → conn_id
self._recent_msg_ids: dict[str, float] = {} # msg_id → timestamp (dedup)
self._msg_id_ttl = 300.0 # 5 min dedup window
# Bot callbacks — set by bot.py during initialization
self._on_pending: Callable[..., Coroutine] | None = None
self._on_chat: Callable[..., Coroutine] | None = None
self._on_register: Callable[..., Coroutine] | None = None
self._on_auto_resolve: Callable[..., Coroutine] | None = None
self._on_brain_event: Callable[..., Coroutine] | None = None
# ─── Bot Integration ───
def set_bot_handlers(
self,
on_pending: Callable | None = None,
on_chat: Callable | None = None,
on_register: Callable | None = None,
on_auto_resolve: Callable | None = None,
on_brain_event: Callable | None = None,
):
"""Register bot callback functions for incoming Extension messages."""
self._on_pending = on_pending
self._on_chat = on_chat
self._on_register = on_register
self._on_auto_resolve = on_auto_resolve
self._on_brain_event = on_brain_event
# ─── Connection Management ───
async def handle_ws(self, request: web.Request) -> web.WebSocketResponse:
"""Handle a new WebSocket connection.
Protocol:
1. Client connects
2. Client sends {type:"auth", token:"...", project:"...", pc:"..."}
OR {type:"auth", registration_code:"...", project:"...", pc:"..."}
3. Server responds {type:"auth_ok", conn_id, instance_number}
OR {type:"auth_fail", reason:"..."}
4. Bidirectional message exchange
"""
ws = web.WebSocketResponse(
max_msg_size=MAX_MSG_SIZE,
heartbeat=HEARTBEAT_INTERVAL,
)
await ws.prepare(request)
conn_id = uuid.uuid4().hex[:12]
conn = WSConnection(conn_id=conn_id, ws=ws)
remote = request.remote or "unknown"
logger.info(f"[HUB] New WS connection: {conn_id} from {remote}")
try:
# Wait for auth message (first message must be auth)
auth_ok = await self._handle_auth(conn, timeout=10.0)
if not auth_ok:
return ws
# Register connection
self._register_connection(conn)
# Start sender task (per-client queue → ws.send)
conn._sender_task = asyncio.create_task(
self._sender_loop(conn), name=f"sender-{conn_id}"
)
# Message loop
await self._message_loop(conn)
except Exception as e:
logger.error(f"[HUB] Connection {conn_id} error: {e}")
finally:
await self._disconnect(conn)
return ws
async def _handle_auth(self, conn: WSConnection, timeout: float = 10.0) -> bool:
"""Wait for and process the auth message."""
try:
msg = await asyncio.wait_for(conn.ws.receive(), timeout=timeout)
except asyncio.TimeoutError:
await self._send_direct(conn.ws, {
"type": MsgType.AUTH_FAIL, "reason": "Auth timeout"
})
await conn.ws.close()
return False
if msg.type != WSMsgType.TEXT:
await conn.ws.close()
return False
try:
data = json.loads(msg.data)
except json.JSONDecodeError:
await self._send_direct(conn.ws, {
"type": MsgType.AUTH_FAIL, "reason": "Invalid JSON"
})
await conn.ws.close()
return False
if data.get("type") != MsgType.AUTH:
await self._send_direct(conn.ws, {
"type": MsgType.AUTH_FAIL, "reason": "First message must be auth"
})
await conn.ws.close()
return False
project = data.get("project", "")
pc_name = data.get("pc", "unknown")
if not project:
await self._send_direct(conn.ws, {
"type": MsgType.AUTH_FAIL, "reason": "Project name required"
})
await conn.ws.close()
return False
# Try token auth first, then registration code
token = data.get("token", "")
reg_code = data.get("registration_code", "")
if token:
payload = self.token_manager.verify_token(token)
if not payload:
await self._send_direct(conn.ws, {
"type": MsgType.AUTH_FAIL, "reason": "Invalid or expired token"
})
await conn.ws.close()
return False
# Token is valid — use project from token (overrides client)
project = payload.get("project", project)
pc_name = payload.get("pc", pc_name)
elif reg_code:
if not self.token_manager.validate_registration_code(reg_code):
await self._send_direct(conn.ws, {
"type": MsgType.AUTH_FAIL, "reason": "Invalid registration code"
})
await conn.ws.close()
return False
else:
# No auth provided — check if registration code is configured
if self.token_manager.registration_code:
await self._send_direct(conn.ws, {
"type": MsgType.AUTH_FAIL, "reason": "Auth required"
})
await conn.ws.close()
return False
# No registration code configured → allow (dev mode)
# Auth successful
conn.project = project
conn.pc_name = pc_name
conn.authenticated = True
# Issue a session token (for reconnection)
session_token = self.token_manager.create_token(project, pc_name)
# Assign instance number
instance_number = self._assign_instance_number(project, conn.conn_id)
conn.instance_number = instance_number
await self._send_direct(conn.ws, {
"type": MsgType.AUTH_OK,
"conn_id": conn.conn_id,
"instance_number": instance_number,
"session_token": session_token,
"active_count": self.get_active_count(project) + 1, # +1 for self (not yet registered)
})
logger.info(
f"[HUB] Auth OK: {conn.conn_id} project={project} pc={pc_name} "
f"instance=#{instance_number}"
)
return True
def _register_connection(self, conn: WSConnection):
"""Add connection to tracking structures."""
self.connections[conn.conn_id] = conn
if conn.project not in self.project_connections:
self.project_connections[conn.project] = set()
self.project_connections[conn.project].add(conn.conn_id)
# FIX: Reassign orphaned pending_owners (from dead conn_ids or orphan markers)
# to this new connection. When Extension reconnects, old entries become stale.
reassigned = 0
for rid, cid in list(self.pending_owners.items()):
if cid.startswith("orphan:"):
# Only reassign orphans from the same project
if cid != f"orphan:{conn.project}":
continue
self.pending_owners[rid] = conn.conn_id
reassigned += 1
elif cid not in self.connections:
self.pending_owners[rid] = conn.conn_id
reassigned += 1
if reassigned:
logger.info(
f"[HUB] Reassigned {reassigned} orphaned pending_owners "
f"to new conn {conn.conn_id} (project={conn.project})"
)
# Broadcast instance update to all project connections
asyncio.create_task(self._broadcast_instance_update(conn.project))
async def _disconnect(self, conn: WSConnection):
"""Clean up after a connection closes."""
conn_id = conn.conn_id
project = conn.project
# Cancel sender task
if conn._sender_task and not conn._sender_task.done():
conn._sender_task.cancel()
try:
await conn._sender_task
except asyncio.CancelledError:
pass
# Remove from tracking
self.connections.pop(conn_id, None)
if project in self.project_connections:
self.project_connections[project].discard(conn_id)
if not self.project_connections[project]:
del self.project_connections[project]
# Reassign pending ownership to another connection in same project
# (instead of deleting — prevents approval responses from being lost)
stale = [rid for rid, cid in self.pending_owners.items() if cid == conn_id]
if stale:
remaining = self.project_connections.get(project, set()) - {conn_id}
new_owner = None
for cid in remaining:
c = self.connections.get(cid)
if c and c.authenticated:
new_owner = cid
break
for rid in stale:
if new_owner:
self.pending_owners[rid] = new_owner
logger.info(
f"[HUB] Reassigned pending {rid[:12]}{new_owner} "
f"(disconnected {conn_id})"
)
else:
# Preserve as orphan marker instead of deleting —
# will be reassigned when Extension reconnects
self.pending_owners[rid] = f"orphan:{project}"
logger.info(
f"[HUB] Orphaned pending {rid[:12]} "
f"(no active connections in {project})"
)
# Close WebSocket if still open
if not conn.ws.closed:
await conn.ws.close()
logger.info(f"[HUB] Disconnected: {conn_id} project={project}")
# Broadcast instance update (remaining connections get notified)
if project:
await self._broadcast_instance_update(project)
# ─── Instance Number Management ───
def _assign_instance_number(self, project: str, conn_id: str) -> int:
"""Assign the next available instance number for a project."""
used = set()
for cid in self.project_connections.get(project, set()):
c = self.connections.get(cid)
if c:
used.add(c.instance_number)
# Find lowest available number starting from 1
num = 1
while num in used:
num += 1
return num
def get_active_count(self, project: str) -> int:
"""Get the number of active connections for a project."""
return len(self.project_connections.get(project, set()))
async def _broadcast_instance_update(self, project: str):
"""Notify all project connections of instance count change."""
count = self.get_active_count(project)
msg = {
"type": MsgType.INSTANCE_UPDATE,
"active_count": count,
"instances": [],
}
# Include instance list for UI
for cid in self.project_connections.get(project, set()):
c = self.connections.get(cid)
if c:
msg["instances"].append({
"instance_number": c.instance_number,
"pc": c.pc_name,
})
await self.broadcast_to_project(project, msg)
# ─── Message Routing ───
async def broadcast_to_project(self, project: str, message: dict):
"""Send a message to all connections in a project."""
conn_ids = self.project_connections.get(project, set()).copy()
for cid in conn_ids:
conn = self.connections.get(cid)
if conn and conn.authenticated:
await self._queue_send(conn, message)
async def send_to_connection(self, conn_id: str, message: dict):
"""Send a message to a specific connection."""
conn = self.connections.get(conn_id)
if conn and conn.authenticated:
await self._queue_send(conn, message)
async def send_to_instance(self, project: str, instance_number: int, message: dict):
"""Send a message to a specific instance number within a project."""
for cid in self.project_connections.get(project, set()):
conn = self.connections.get(cid)
if conn and conn.instance_number == instance_number:
await self._queue_send(conn, message)
return True
logger.warning(
f"[HUB] Instance #{instance_number} not found for project={project}"
)
return False
async def send_response_to_pending_owner(self, request_id: str, message: dict):
"""Route a response to the Extension that created the pending request.
Falls back to any active connection in the same project if the
original owner disconnected (e.g. Extension WS reconnected with
a new conn_id).
"""
conn_id = self.pending_owners.get(request_id)
if conn_id:
conn = self.connections.get(conn_id)
if conn and conn.authenticated and not conn.ws.closed:
await self.send_to_connection(conn_id, message)
self.pending_owners.pop(request_id, None)
return True
# Original owner dead — try to find any active connection in same project
project = conn.project if conn else None
if project:
for cid in self.project_connections.get(project, set()):
c = self.connections.get(cid)
if c and c.authenticated and not c.ws.closed:
await self.send_to_connection(cid, message)
self.pending_owners.pop(request_id, None)
logger.info(
f"[HUB] Rerouted response {request_id[:12]}{cid} "
f"(original {conn_id} dead)"
)
return True
# No active connection found — clean up
self.pending_owners.pop(request_id, None)
logger.warning(
f"[HUB] Response {request_id[:12]} lost: owner {conn_id} dead, "
f"no active connections in project"
)
return False
logger.warning(f"[HUB] No owner for pending {request_id[:12]}")
return False
# ─── Per-Client Send Queue (backpressure) ───
async def _queue_send(self, conn: WSConnection, message: dict):
"""Queue a message for sending via the per-client sender loop."""
try:
conn.send_queue.put_nowait(message)
except asyncio.QueueFull:
logger.warning(
f"[HUB] Queue full for {conn.conn_id} — dropping oldest message"
)
# Drop oldest to make room
try:
conn.send_queue.get_nowait()
except asyncio.QueueEmpty:
pass
try:
conn.send_queue.put_nowait(message)
except asyncio.QueueFull:
pass
async def _sender_loop(self, conn: WSConnection):
"""Dedicated sender coroutine for a single connection.
Reads from the per-client queue and sends via WebSocket.
This prevents slow clients from blocking message routing.
"""
try:
while not conn.ws.closed:
try:
msg = await asyncio.wait_for(
conn.send_queue.get(), timeout=HEARTBEAT_INTERVAL
)
await conn.ws.send_json(msg)
except asyncio.TimeoutError:
continue # No message to send — loop continues
except ConnectionResetError:
break
except Exception as e:
logger.error(f"[HUB] Send error {conn.conn_id}: {e}")
break
except asyncio.CancelledError:
pass
async def _send_direct(self, ws: web.WebSocketResponse, message: dict):
"""Send directly (bypasses queue — for auth messages only)."""
try:
await ws.send_json(message)
except Exception as e:
logger.error(f"[HUB] Direct send error: {e}")
# ─── Message Loop ───
async def _message_loop(self, conn: WSConnection):
"""Main message receive loop for an authenticated connection."""
async for msg in conn.ws:
if msg.type == WSMsgType.TEXT:
# Rate limit check
if not self._check_rate_limit(conn):
await self._queue_send(conn, {
"type": MsgType.ERROR,
"error": "Rate limited — too many messages",
})
continue
try:
data = json.loads(msg.data)
except json.JSONDecodeError:
continue
msg_type = data.get("type", "")
# Dedup check
msg_id = data.get("msg_id")
if msg_id and self._is_duplicate(msg_id):
continue
await self._handle_message(conn, msg_type, data)
elif msg.type == WSMsgType.ERROR:
logger.error(
f"[HUB] WS error {conn.conn_id}: {conn.ws.exception()}"
)
break
elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
break
async def _handle_message(self, conn: WSConnection, msg_type: str, data: dict):
"""Route an incoming message to the appropriate handler."""
conn.last_heartbeat = time.time() # Any message counts as heartbeat
if msg_type == MsgType.PENDING:
payload = data.get("data", {})
request_id = payload.get("request_id", "")
if request_id:
# Prevent slow memory leak for stranded requests
if len(self.pending_owners) > 10000:
oldest = next(iter(self.pending_owners))
self.pending_owners.pop(oldest, None)
# Track ownership for response routing
self.pending_owners[request_id] = conn.conn_id
# Add source metadata
payload["_conn_id"] = conn.conn_id
payload["_instance_number"] = conn.instance_number
payload["_pc_name"] = conn.pc_name
payload.setdefault("project_name", conn.project)
if self._on_pending:
await self._on_pending(conn.project, payload)
elif msg_type == MsgType.CHAT:
payload = data.get("data", {})
payload.setdefault("project_name", conn.project)
payload["_instance_number"] = conn.instance_number
payload["_pc_name"] = conn.pc_name
if self._on_chat:
await self._on_chat(conn.project, payload)
elif msg_type == MsgType.REGISTER:
payload = data.get("data", {})
payload.setdefault("project_name", conn.project)
if self._on_register:
await self._on_register(payload)
elif msg_type == MsgType.AUTO_RESOLVE:
payload = data.get("data", {})
request_id = payload.get("request_id", "")
if request_id:
self.pending_owners.pop(request_id, None)
if self._on_auto_resolve:
await self._on_auto_resolve(conn.project, payload)
elif msg_type == MsgType.BRAIN_EVENT:
payload = data.get("data", {})
payload.setdefault("project_name", conn.project)
if self._on_brain_event:
await self._on_brain_event(conn.project, payload)
elif msg_type == MsgType.HEARTBEAT:
# Echo back a "pong" so clients without native ping/pong can update their timers
await conn.ws.send_json({"type": "pong"})
else:
logger.warning(f"[HUB] Unknown message type: {msg_type} from {conn.conn_id}")
# ─── Rate Limiting ───
def _check_rate_limit(self, conn: WSConnection) -> bool:
"""Per-connection rate limiting. Returns True if allowed."""
now = time.time()
conn._msg_timestamps = [
t for t in conn._msg_timestamps if now - t < RATE_WINDOW
]
if len(conn._msg_timestamps) >= PER_CONN_RATE_LIMIT:
logger.warning(f"[HUB] Rate limited: {conn.conn_id}")
return False
conn._msg_timestamps.append(now)
return True
# ─── Deduplication ───
def _is_duplicate(self, msg_id: str) -> bool:
"""Check if a message ID was recently seen."""
now = time.time()
# Cleanup old entries
if len(self._recent_msg_ids) > 10000:
cutoff = now - self._msg_id_ttl
self._recent_msg_ids = {
k: v for k, v in self._recent_msg_ids.items() if v > cutoff
}
if msg_id in self._recent_msg_ids:
return True
self._recent_msg_ids[msg_id] = now
return False
# ─── Diagnostics ───
def get_status(self) -> dict:
"""Return hub status for health/diagnostics."""
projects = {}
for project, conn_ids in self.project_connections.items():
instances = []
for cid in conn_ids:
c = self.connections.get(cid)
if c:
instances.append({
"conn_id": c.conn_id,
"instance": c.instance_number,
"pc": c.pc_name,
"connected": int(time.time() - c.connected_at),
})
projects[project] = {
"count": len(conn_ids),
"instances": instances,
}
return {
"total_connections": len(self.connections),
"projects": projects,
"pending_owners": len(self.pending_owners),
}