Files
gravity_control/tests/test_hub_unit.py

622 lines
19 KiB
Python

"""Unit tests for WebSocket Hub — connection lifecycle, routing, and edge cases.
Uses unittest.mock for WebSocket mocking, no live server needed.
Run: python -m pytest tests/test_hub_unit.py -v
"""
import asyncio
import json
import time
import pytest
import pytest_asyncio
from unittest.mock import AsyncMock, MagicMock, patch
# Add project root to sys.path
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Module-level marker: pytest applies it to every test collected from this file.
# NOTE(review): this is the *anyio* marker, while the async tests below also
# carry pytest-asyncio's ``pytest.mark.asyncio`` (``pytest_asyncio`` is imported
# above). Confirm which async plugin is meant to drive these tests — letting
# both plugins claim the same coroutine test may conflict.
pytestmark = pytest.mark.anyio
from auth import TokenManager
from hub import WSHub, WSConnection, MsgType, PER_CONN_RATE_LIMIT
# ─── Fixtures ───
def make_token_manager(secret="test-secret-64chars-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
                       reg_code="test-reg-code"):
    """Build a TokenManager for tests.

    Args:
        secret: Signing secret passed through to ``TokenManager``.
        reg_code: Registration code passed through as ``registration_code``.

    Returns:
        A freshly constructed ``TokenManager``.
    """
    manager = TokenManager(secret=secret, registration_code=reg_code)
    return manager
def make_mock_ws(closed=False):
    """Return an ``AsyncMock`` standing in for a WebSocket.

    ``send_json`` and ``close`` are explicit ``AsyncMock`` attributes so tests
    can await them and inspect calls; ``closed`` mirrors the requested state.
    """
    mock_socket = AsyncMock()
    mock_socket.send_json = AsyncMock()
    mock_socket.close = AsyncMock()
    mock_socket.closed = closed
    return mock_socket
def make_connection(conn_id, project="test_project", pc_name="test_pc",
                    instance_number=1, authenticated=True, ws=None):
    """Build a ``WSConnection`` backed by a mock WebSocket.

    Args:
        conn_id: Connection identifier.
        project: Project the connection belongs to.
        pc_name: Reported PC name.
        instance_number: Instance number within the project.
        authenticated: Whether the connection is treated as authenticated.
        ws: WebSocket to attach; a fresh ``make_mock_ws()`` when omitted.
    """
    return WSConnection(
        conn_id=conn_id,
        ws=make_mock_ws() if ws is None else ws,
        project=project,
        pc_name=pc_name,
        instance_number=instance_number,
        authenticated=authenticated,
    )
def make_hub():
    """Return a ``WSHub`` wired to the default test ``TokenManager``."""
    return WSHub(make_token_manager())
def register_conn(hub, conn):
    """Insert ``conn`` straight into the hub's tracking maps.

    Skips the auth flow and does not call ``hub._register_connection``; it only
    writes ``hub.connections`` and ``hub.project_connections``.
    """
    hub.connections[conn.conn_id] = conn
    # Create the project's id-set on first use, then add this connection to it.
    hub.project_connections.setdefault(conn.project, set()).add(conn.conn_id)
# ─── Connection Tracking ───
class TestConnectionTracking:
    """Bookkeeping in ``hub.connections`` / ``hub.project_connections``."""

    # Every test here is a coroutine, so apply the asyncio mark once.
    pytestmark = pytest.mark.asyncio

    async def test_register_connection_adds_to_tracking(self):
        hub = make_hub()
        hub._register_connection(make_connection("c1"))
        assert "c1" in hub.connections
        assert "c1" in hub.project_connections["test_project"]

    async def test_register_multiple_connections_same_project(self):
        hub = make_hub()
        for conn_id in ("c1", "c2"):
            hub._register_connection(make_connection(conn_id))
        assert hub.get_active_count("test_project") == 2

    async def test_register_connections_different_projects(self):
        hub = make_hub()
        for conn_id, project in (("c1", "proj_a"), ("c2", "proj_b")):
            hub._register_connection(make_connection(conn_id, project=project))
        for project in ("proj_a", "proj_b"):
            assert hub.get_active_count(project) == 1
        assert len(hub.connections) == 2

    async def test_disconnect_removes_from_tracking(self):
        hub = make_hub()
        sole = make_connection("c1")
        register_conn(hub, sole)
        await hub._disconnect(sole)
        assert "c1" not in hub.connections
        assert hub.get_active_count("test_project") == 0

    async def test_disconnect_removes_empty_project(self):
        hub = make_hub()
        sole = make_connection("c1")
        register_conn(hub, sole)
        await hub._disconnect(sole)
        # The last connection leaving drops the project key entirely.
        assert "test_project" not in hub.project_connections

    async def test_disconnect_keeps_other_connections(self):
        hub = make_hub()
        leaving, staying = make_connection("c1"), make_connection("c2")
        register_conn(hub, leaving)
        register_conn(hub, staying)
        await hub._disconnect(leaving)
        assert "c1" not in hub.connections
        assert "c2" in hub.connections
        assert hub.get_active_count("test_project") == 1
# ─── Instance Number Assignment ───
class TestInstanceNumbers:
    """``_assign_instance_number``: lowest free number, scoped per project."""

    def test_first_instance_gets_1(self):
        hub = make_hub()
        num = hub._assign_instance_number("proj", "c1")
        assert num == 1

    def test_second_instance_gets_2(self):
        """With #1 already taken in a project, the next connection there gets #2."""
        hub = make_hub()
        c1 = make_connection("c1", project="proj", instance_number=1)
        register_conn(hub, c1)
        num = hub._assign_instance_number("proj", "c2")
        assert num == 2

    def test_instance_numbers_are_per_project(self):
        """A connection in another project does not consume numbers here."""
        hub = make_hub()
        c1 = make_connection("c1", instance_number=1)  # lives in "test_project"
        register_conn(hub, c1)
        num = hub._assign_instance_number("proj", "c2")
        assert num == 1

    def test_gap_filling(self):
        """If instance #1 disconnects, next connection gets #1 again."""
        hub = make_hub()
        c2 = make_connection("c2", instance_number=2)
        register_conn(hub, c2)
        num = hub._assign_instance_number("test_project", "c3")
        assert num == 1  # #1 is available

    def test_sequential_assignment(self):
        hub = make_hub()
        c1 = make_connection("c1", instance_number=1)
        c2 = make_connection("c2", instance_number=2)
        register_conn(hub, c1)
        register_conn(hub, c2)
        num = hub._assign_instance_number("test_project", "c3")
        assert num == 3
# ─── Pending Owners Lifecycle ───
class TestPendingOwners:
    """Lifecycle of ``hub.pending_owners`` (request_id -> owning conn_id)."""

    # Every test here is a coroutine, so apply the asyncio mark once.
    pytestmark = pytest.mark.asyncio

    async def test_pending_tracks_owner(self):
        """Handling a PENDING message records its sender as the owner."""
        hub = make_hub()
        owner = make_connection("c1")
        register_conn(hub, owner)
        message = {"data": {"request_id": "req-001", "command": "ls"}}
        await hub._handle_message(owner, MsgType.PENDING, message)
        assert hub.pending_owners["req-001"] == "c1"

    async def test_auto_resolve_clears_owner(self):
        """Handling AUTO_RESOLVE drops the ownership entry."""
        hub = make_hub()
        owner = make_connection("c1")
        register_conn(hub, owner)
        hub.pending_owners["req-001"] = "c1"
        message = {"data": {"request_id": "req-001"}}
        await hub._handle_message(owner, MsgType.AUTO_RESOLVE, message)
        assert "req-001" not in hub.pending_owners

    async def test_disconnect_reassigns_pending_to_same_project(self):
        """A departing owner hands its pending request to another connection
        in the same project instead of dropping it."""
        hub = make_hub()
        departing, survivor = make_connection("c1"), make_connection("c2")
        register_conn(hub, departing)
        register_conn(hub, survivor)
        hub.pending_owners["req-001"] = "c1"
        await hub._disconnect(departing)
        assert hub.pending_owners.get("req-001") == "c2"

    async def test_disconnect_deletes_pending_when_no_remaining(self):
        """With no other connection left, the pending entry is deleted."""
        hub = make_hub()
        sole = make_connection("c1")
        register_conn(hub, sole)
        hub.pending_owners["req-001"] = "c1"
        await hub._disconnect(sole)
        assert "req-001" not in hub.pending_owners

    async def test_register_reassigns_orphaned_pending(self):
        """Registering a connection adopts entries whose owner conn_id is not
        a tracked connection."""
        hub = make_hub()
        orphaned = ("req-001", "req-002")
        for request_id in orphaned:
            hub.pending_owners[request_id] = "dead-conn"
        hub._register_connection(make_connection("c-new"))
        for request_id in orphaned:
            assert hub.pending_owners[request_id] == "c-new"

    async def test_disconnect_reassigns_only_to_authenticated(self):
        """Unauthenticated connections are never chosen as the new owner."""
        hub = make_hub()
        departing = make_connection("c1")
        unauthenticated = make_connection("c2", authenticated=False)
        register_conn(hub, departing)
        register_conn(hub, unauthenticated)
        hub.pending_owners["req-001"] = "c1"
        await hub._disconnect(departing)
        # c2 is not eligible, so the request is dropped rather than reassigned.
        assert "req-001" not in hub.pending_owners
# ─── Response Routing ───
class TestResponseRouting:
    """``send_response_to_pending_owner``: delivery, fallback and cleanup."""

    # Every test here is a coroutine, so apply the asyncio mark once.
    pytestmark = pytest.mark.asyncio

    async def test_send_response_to_owner(self):
        hub = make_hub()
        owner = make_connection("c1")
        register_conn(hub, owner)
        hub.pending_owners["req-001"] = "c1"
        delivered = await hub.send_response_to_pending_owner(
            "req-001", {"type": "response", "approved": True}
        )
        assert delivered is True
        # A successful delivery releases the ownership entry.
        assert "req-001" not in hub.pending_owners

    async def test_send_response_owner_dead_fallback(self):
        """A response for a closed owner is rerouted to a live connection in
        the same project."""
        hub = make_hub()
        candidates = (
            make_connection("c-dead", ws=make_mock_ws(closed=True)),
            make_connection("c-alive"),
        )
        for conn in candidates:
            register_conn(hub, conn)
        hub.pending_owners["req-001"] = "c-dead"
        delivered = await hub.send_response_to_pending_owner(
            "req-001", {"type": "response", "approved": True}
        )
        assert delivered is True  # rerouted
        assert "req-001" not in hub.pending_owners

    async def test_send_response_no_owner(self):
        """An unknown request_id yields False."""
        hub = make_hub()
        delivered = await hub.send_response_to_pending_owner(
            "req-unknown", {"type": "response"}
        )
        assert delivered is False

    async def test_send_response_all_dead(self):
        """With the owner and every project connection closed, return False."""
        hub = make_hub()
        register_conn(hub, make_connection("c-dead", ws=make_mock_ws(closed=True)))
        hub.pending_owners["req-001"] = "c-dead"
        delivered = await hub.send_response_to_pending_owner(
            "req-001", {"type": "response"}
        )
        assert delivered is False
        assert "req-001" not in hub.pending_owners  # cleaned up even on failure
# ─── Message Routing ───
class TestMessageRouting:
    """Project broadcast and per-instance delivery via send queues."""

    # Every test here is a coroutine, so apply the asyncio mark once.
    pytestmark = pytest.mark.asyncio

    async def test_broadcast_to_project(self):
        hub = make_hub()
        peer_a, peer_b = make_connection("c1"), make_connection("c2")
        outsider = make_connection("c3", project="other_project")
        for conn in (peer_a, peer_b, outsider):
            register_conn(hub, conn)
        await hub.broadcast_to_project("test_project", {"type": "test", "data": "hello"})
        # Messages land in each recipient's send queue.
        assert peer_a.send_queue.qsize() == 1
        assert peer_b.send_queue.qsize() == 1
        assert outsider.send_queue.qsize() == 0  # different project

    async def test_send_to_instance(self):
        hub = make_hub()
        first = make_connection("c1", instance_number=1)
        second = make_connection("c2", instance_number=2)
        register_conn(hub, first)
        register_conn(hub, second)
        sent = await hub.send_to_instance("test_project", 2, {"type": "cmd"})
        assert sent is True
        assert first.send_queue.qsize() == 0
        assert second.send_queue.qsize() == 1

    async def test_send_to_instance_not_found(self):
        hub = make_hub()
        register_conn(hub, make_connection("c1", instance_number=1))
        sent = await hub.send_to_instance("test_project", 99, {"type": "cmd"})
        assert sent is False

    async def test_broadcast_skips_unauthenticated(self):
        hub = make_hub()
        trusted = make_connection("c1", authenticated=True)
        untrusted = make_connection("c2", authenticated=False)
        register_conn(hub, trusted)
        register_conn(hub, untrusted)
        await hub.broadcast_to_project("test_project", {"type": "msg"})
        assert trusted.send_queue.qsize() == 1
        assert untrusted.send_queue.qsize() == 0
# ─── Message Handling (Dispatch) ───
class TestMessageHandling:
    """``_handle_message`` dispatch to bot handlers and heartbeat bookkeeping."""

    @pytest.mark.asyncio
    async def test_pending_sets_metadata(self):
        """Pending messages should have source metadata injected."""
        hub = make_hub()
        received = []
        # Handler records its second positional argument (the dispatched payload).
        hub.set_bot_handlers(on_pending=AsyncMock(side_effect=lambda p, d: received.append(d)))
        conn = make_connection("c1", pc_name="my_pc", instance_number=3)
        register_conn(hub, conn)
        await hub._handle_message(conn, MsgType.PENDING, {
            "data": {"request_id": "req-x", "command": "npm install"}
        })
        assert len(received) == 1
        payload = received[0]
        assert payload["_conn_id"] == "c1"
        assert payload["_instance_number"] == 3
        assert payload["_pc_name"] == "my_pc"
        assert payload["project_name"] == "test_project"

    @pytest.mark.asyncio
    async def test_chat_sets_project_name(self):
        hub = make_hub()
        received = []
        hub.set_bot_handlers(on_chat=AsyncMock(side_effect=lambda p, d: received.append(d)))
        conn = make_connection("c1", project="my_proj")
        register_conn(hub, conn)
        await hub._handle_message(conn, MsgType.CHAT, {
            "data": {"content": "hello"}
        })
        assert received[0]["project_name"] == "my_proj"

    @pytest.mark.asyncio
    async def test_no_handler_doesnt_crash(self):
        """If no bot handlers are set, messages should be silently dropped."""
        hub = make_hub()
        conn = make_connection("c1")
        register_conn(hub, conn)
        # Should not raise
        await hub._handle_message(conn, MsgType.PENDING, {"data": {"request_id": "r1"}})
        await hub._handle_message(conn, MsgType.CHAT, {"data": {"content": "hi"}})
        await hub._handle_message(conn, MsgType.REGISTER, {"data": {}})
        await hub._handle_message(conn, MsgType.AUTO_RESOLVE, {"data": {}})

    @pytest.mark.asyncio
    async def test_heartbeat_updates_timestamp(self):
        hub = make_hub()
        conn = make_connection("c1")
        register_conn(hub, conn)
        # Backdate the timestamp instead of sleeping: a 10 ms sleep does not
        # guarantee that a coarse-resolution clock has advanced, which made the
        # old version flaky. Assumes last_heartbeat is a numeric seconds
        # timestamp — TODO confirm against WSConnection.
        conn.last_heartbeat -= 60.0
        stale_ts = conn.last_heartbeat
        await hub._handle_message(conn, MsgType.HEARTBEAT, {})
        assert conn.last_heartbeat > stale_ts
# ─── Rate Limiting ───
class TestRateLimiting:
    """Per-connection, time-windowed limit enforced by ``_check_rate_limit``."""

    def test_under_limit_allows(self):
        hub, conn = make_hub(), make_connection("c1")
        for _attempt in range(PER_CONN_RATE_LIMIT - 1):
            assert hub._check_rate_limit(conn) is True

    def test_at_limit_blocks(self):
        hub, conn = make_hub(), make_connection("c1")
        # Consume the whole allowance; results are irrelevant here.
        for _attempt in range(PER_CONN_RATE_LIMIT):
            hub._check_rate_limit(conn)
        # One past the allowance is rejected.
        assert hub._check_rate_limit(conn) is False

    def test_old_timestamps_expire(self):
        hub, conn = make_hub(), make_connection("c1")
        # 20 s in the past is taken to be outside the hub's RATE_WINDOW.
        stale = time.time() - 20.0
        conn._msg_timestamps = [stale] * PER_CONN_RATE_LIMIT
        # Stale entries are pruned, so the next message is accepted.
        assert hub._check_rate_limit(conn) is True
# ─── Deduplication ───
class TestDeduplication:
    """Message-id deduplication via ``_is_duplicate``."""

    def test_first_message_not_duplicate(self):
        hub = make_hub()
        assert hub._is_duplicate("msg-001") is False

    def test_same_id_is_duplicate(self):
        hub = make_hub()
        hub._is_duplicate("msg-001")
        assert hub._is_duplicate("msg-001") is True

    def test_different_id_not_duplicate(self):
        hub = make_hub()
        hub._is_duplicate("msg-001")
        assert hub._is_duplicate("msg-002") is False

    def test_cleanup_on_overflow(self):
        hub = make_hub()
        # Pre-fill past 10 000 entries so the next lookup exercises cleanup.
        now = time.time()
        for i in range(10001):
            hub._recent_msg_ids[f"msg-{i}"] = now
        # Lookups must keep working: a never-seen id is not a duplicate.
        assert hub._is_duplicate("new-msg") is False
        # NOTE(review): 10001 pre-filled entries + 1 new one can never exceed
        # 10002, so this bound holds with or without cleanup. Tighten it once
        # the hub's cleanup policy (size cap vs. age-based pruning) is pinned down.
        assert len(hub._recent_msg_ids) <= 10002  # +1 for new-msg
# ─── Queue Backpressure ───
class TestQueueBackpressure:
    """``_queue_send`` behaviour with an empty and a saturated send queue."""

    # Every test here is a coroutine, so apply the asyncio mark once.
    pytestmark = pytest.mark.asyncio

    async def test_queue_send_normal(self):
        hub, conn = make_hub(), make_connection("c1")
        await hub._queue_send(conn, {"type": "test"})
        assert conn.send_queue.qsize() == 1

    async def test_queue_full_drops_oldest(self):
        hub, conn = make_hub(), make_connection("c1")
        capacity = 100  # expected send-queue bound
        for index in range(capacity):
            await hub._queue_send(conn, {"type": "test", "i": index})
        assert conn.send_queue.qsize() == capacity
        # One more message must not grow the queue past its bound.
        # NOTE(review): only the size is checked here, not which item was evicted.
        await hub._queue_send(conn, {"type": "test", "i": "new"})
        assert conn.send_queue.qsize() == capacity
# ─── Diagnostics ───
class TestDiagnostics:
    """Shape of the ``get_status()`` snapshot."""

    def test_get_status_empty(self):
        snapshot = make_hub().get_status()
        assert snapshot["total_connections"] == 0
        assert snapshot["projects"] == {}
        assert snapshot["pending_owners"] == 0

    def test_get_status_with_connections(self):
        hub = make_hub()
        for conn_id, number in (("c1", 1), ("c2", 2)):
            register_conn(hub, make_connection(conn_id, project="proj_a", instance_number=number))
        hub.pending_owners["req-001"] = "c1"
        snapshot = hub.get_status()
        assert snapshot["total_connections"] == 2
        project_view = snapshot["projects"]["proj_a"]
        assert project_view["count"] == 2
        assert snapshot["pending_owners"] == 1
        assert len(project_view["instances"]) == 2
# ─── Auth (TokenManager) ───
class TestTokenManager:
    """Token issue/verify round-trips and registration-code validation."""

    def test_create_and_verify_token(self):
        manager = make_token_manager()
        claims = manager.verify_token(manager.create_token("my_project", "my_pc"))
        assert claims is not None
        assert claims["project"] == "my_project"
        assert claims["pc"] == "my_pc"

    def test_expired_token_rejected(self):
        manager = make_token_manager()
        expired = manager.create_token("proj", "pc", ttl=-1)  # negative TTL: already expired
        assert manager.verify_token(expired) is None

    def test_tampered_token_rejected(self):
        manager = make_token_manager()
        genuine = manager.create_token("proj", "pc")
        # Replace the final character with a different one.
        substitute = "B" if genuine[-1] == "A" else "A"
        tampered = genuine[:-1] + substitute
        # NOTE(review): if the token ends in unpadded base64, its last character
        # can carry ignored low bits (e.g. "A" <-> "B"), so this edit might not
        # change the decoded bytes. Tampering a fully significant character would
        # be more robust; confirm against the token format in the auth module.
        assert manager.verify_token(tampered) is None

    def test_wrong_secret_rejected(self):
        issuer = make_token_manager(secret="secret-1" + "a" * 56)
        verifier = make_token_manager(secret="secret-2" + "b" * 56)
        assert verifier.verify_token(issuer.create_token("proj", "pc")) is None

    def test_registration_code_valid(self):
        manager = make_token_manager(reg_code="my-code")
        assert manager.validate_registration_code("my-code") is True
        assert manager.validate_registration_code("wrong") is False

    def test_registration_code_empty_allows_all(self):
        manager = make_token_manager(reg_code="")
        assert manager.validate_registration_code("anything") is True