"""Unit tests for WebSocket Hub — connection lifecycle, routing, and edge cases.

Uses unittest.mock for WebSocket mocking, no live server needed.

Run: python -m pytest tests/test_hub_unit.py -v
"""

import asyncio
import json
import os
import sys
import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import pytest_asyncio

# Add project root to sys.path so the project-local ``auth`` and ``hub``
# modules can be imported regardless of the directory pytest is started from.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Async tests are driven by pytest-asyncio through their explicit
# ``@pytest.mark.asyncio`` markers. No module-wide ``pytest.mark.anyio`` is
# applied: that would also hand the same coroutines to the anyio pytest
# plugin, mixing two async test runners in one module.

from auth import TokenManager  # noqa: E402  (requires the sys.path tweak above)
from hub import WSHub, WSConnection, MsgType, PER_CONN_RATE_LIMIT  # noqa: E402


# ─── Fixtures ───

def make_token_manager(secret="test-secret-64chars-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
                       reg_code="test-reg-code"):
    """Build a TokenManager with a fixed test secret and registration code."""
    return TokenManager(secret=secret, registration_code=reg_code)


def make_mock_ws(closed=False):
    """Create a mock WebSocket with async send_json/close and a ``closed`` flag."""
    ws = AsyncMock()
    ws.closed = closed
    ws.send_json = AsyncMock()
    ws.close = AsyncMock()
    return ws


def make_connection(conn_id, project="test_project", pc_name="test_pc",
                    instance_number=1, authenticated=True, ws=None):
    """Create a WSConnection backed by a mock WebSocket (a fresh one if ``ws`` is None)."""
    if ws is None:
        ws = make_mock_ws()
    return WSConnection(
        conn_id=conn_id,
        ws=ws,
        project=project,
        pc_name=pc_name,
        instance_number=instance_number,
        authenticated=authenticated,
    )


def make_hub():
    """Create a Hub with a TokenManager (no auth enforcement)."""
    tm = make_token_manager()
    return WSHub(tm)


def register_conn(hub, conn):
    """Manually register a connection in the hub (skip auth flow).

    Writes straight into ``hub.connections`` and ``hub.project_connections``
    without going through ``hub._register_connection``, so none of that
    method's extra bookkeeping (e.g. orphaned-pending reassignment) runs.
    """
    hub.connections[conn.conn_id] = conn
    hub.project_connections.setdefault(conn.project, set()).add(conn.conn_id)


# ─── Connection Tracking ───

class TestConnectionTracking:
    """Registration and disconnection bookkeeping in the hub's tracking maps."""

    @pytest.mark.asyncio
    async def test_register_connection_adds_to_tracking(self):
        hub = make_hub()
        conn = make_connection("c1")
        hub._register_connection(conn)
        assert "c1" in hub.connections
        assert "c1" in hub.project_connections["test_project"]

    @pytest.mark.asyncio
    async def test_register_multiple_connections_same_project(self):
        hub = make_hub()
        c1 = make_connection("c1")
        c2 = make_connection("c2")
        hub._register_connection(c1)
        hub._register_connection(c2)
        assert hub.get_active_count("test_project") == 2

    @pytest.mark.asyncio
    async def test_register_connections_different_projects(self):
        hub = make_hub()
        c1 = make_connection("c1", project="proj_a")
        c2 = make_connection("c2", project="proj_b")
        hub._register_connection(c1)
        hub._register_connection(c2)
        assert hub.get_active_count("proj_a") == 1
        assert hub.get_active_count("proj_b") == 1
        assert len(hub.connections) == 2

    @pytest.mark.asyncio
    async def test_disconnect_removes_from_tracking(self):
        hub = make_hub()
        conn = make_connection("c1")
        register_conn(hub, conn)
        await hub._disconnect(conn)
        assert "c1" not in hub.connections
        assert hub.get_active_count("test_project") == 0

    @pytest.mark.asyncio
    async def test_disconnect_removes_empty_project(self):
        hub = make_hub()
        conn = make_connection("c1")
        register_conn(hub, conn)
        await hub._disconnect(conn)
        assert "test_project" not in hub.project_connections

    @pytest.mark.asyncio
    async def test_disconnect_keeps_other_connections(self):
        hub = make_hub()
        c1 = make_connection("c1")
        c2 = make_connection("c2")
        register_conn(hub, c1)
        register_conn(hub, c2)
        await hub._disconnect(c1)
        assert "c1" not in hub.connections
        assert "c2" in hub.connections
        assert hub.get_active_count("test_project") == 1


# ─── Instance Number Assignment ───

class TestInstanceNumbers:
    """Per-project instance numbering: lowest free positive integer."""

    def test_first_instance_gets_1(self):
        hub = make_hub()
        num = hub._assign_instance_number("proj", "c1")
        assert num == 1

    def test_second_instance_gets_2(self):
        """With instance #1 live in a project, the next one in that project gets #2."""
        hub = make_hub()
        c1 = make_connection("c1", instance_number=1)
        register_conn(hub, c1)
        num = hub._assign_instance_number("test_project", "c2")
        assert num == 2

    def test_other_project_does_not_affect_numbering(self):
        """Connections in one project do not consume numbers in another."""
        hub = make_hub()
        c1 = make_connection("c1", instance_number=1)
        register_conn(hub, c1)
        # project "proj" has no connections (c1 is in "test_project"), so it gets 1
        num = hub._assign_instance_number("proj", "c2")
        assert num == 1

    def test_gap_filling(self):
        """If instance #1 disconnects, next connection gets #1 again."""
        hub = make_hub()
        c2 = make_connection("c2", instance_number=2)
        register_conn(hub, c2)
        num = hub._assign_instance_number("test_project", "c3")
        assert num == 1  # #1 is available

    def test_sequential_assignment(self):
        hub = make_hub()
        c1 = make_connection("c1", instance_number=1)
        c2 = make_connection("c2", instance_number=2)
        register_conn(hub, c1)
        register_conn(hub, c2)
        num = hub._assign_instance_number("test_project", "c3")
        assert num == 3


# ─── Pending Owners Lifecycle ───

class TestPendingOwners:
    """Tracking, reassignment and cleanup of ``hub.pending_owners``."""

    @pytest.mark.asyncio
    async def test_pending_tracks_owner(self):
        """Processing a pending message tracks the owner connection."""
        hub = make_hub()
        conn = make_connection("c1")
        register_conn(hub, conn)
        await hub._handle_message(conn, MsgType.PENDING, {
            "data": {"request_id": "req-001", "command": "ls"}
        })
        assert hub.pending_owners["req-001"] == "c1"

    @pytest.mark.asyncio
    async def test_auto_resolve_clears_owner(self):
        """Auto-resolve removes the pending owner tracking."""
        hub = make_hub()
        conn = make_connection("c1")
        register_conn(hub, conn)
        hub.pending_owners["req-001"] = "c1"
        await hub._handle_message(conn, MsgType.AUTO_RESOLVE, {
            "data": {"request_id": "req-001"}
        })
        assert "req-001" not in hub.pending_owners

    @pytest.mark.asyncio
    async def test_disconnect_reassigns_pending_to_same_project(self):
        """When owner disconnects, pending is reassigned to another connection
        in the same project (not deleted)."""
        hub = make_hub()
        c1 = make_connection("c1")
        c2 = make_connection("c2")
        register_conn(hub, c1)
        register_conn(hub, c2)
        hub.pending_owners["req-001"] = "c1"
        await hub._disconnect(c1)
        # req-001 should now be owned by c2
        assert hub.pending_owners.get("req-001") == "c2"

    @pytest.mark.asyncio
    async def test_disconnect_deletes_pending_when_no_remaining(self):
        """When owner disconnects and no other connections exist, pending is deleted."""
        hub = make_hub()
        conn = make_connection("c1")
        register_conn(hub, conn)
        hub.pending_owners["req-001"] = "c1"
        await hub._disconnect(conn)
        assert "req-001" not in hub.pending_owners

    @pytest.mark.asyncio
    async def test_register_reassigns_orphaned_pending(self):
        """When a new connection registers, orphaned pending_owners (pointing to
        dead conn_ids) are reassigned to it."""
        hub = make_hub()
        # Simulate orphaned pending (old conn_id "dead-conn" not in connections)
        hub.pending_owners["req-001"] = "dead-conn"
        hub.pending_owners["req-002"] = "dead-conn"
        new_conn = make_connection("c-new")
        hub._register_connection(new_conn)
        assert hub.pending_owners["req-001"] == "c-new"
        assert hub.pending_owners["req-002"] == "c-new"

    @pytest.mark.asyncio
    async def test_disconnect_reassigns_only_to_authenticated(self):
        """Pending should only be reassigned to authenticated connections."""
        hub = make_hub()
        c1 = make_connection("c1")
        c2 = make_connection("c2", authenticated=False)  # Not authenticated
        register_conn(hub, c1)
        register_conn(hub, c2)
        hub.pending_owners["req-001"] = "c1"
        await hub._disconnect(c1)
        # c2 is not authenticated, so pending should be deleted
        assert "req-001" not in hub.pending_owners


# ─── Response Routing ───

class TestResponseRouting:
    """Delivery of responses to the connection that owns a pending request."""

    @pytest.mark.asyncio
    async def test_send_response_to_owner(self):
        hub = make_hub()
        conn = make_connection("c1")
        register_conn(hub, conn)
        hub.pending_owners["req-001"] = "c1"
        result = await hub.send_response_to_pending_owner(
            "req-001", {"type": "response", "approved": True}
        )
        assert result is True
        # Owner should be cleaned up after successful delivery
        assert "req-001" not in hub.pending_owners

    @pytest.mark.asyncio
    async def test_send_response_owner_dead_fallback(self):
        """When original owner is dead, response routes to another connection
        in the same project."""
        hub = make_hub()
        dead_conn = make_connection("c-dead", ws=make_mock_ws(closed=True))
        alive_conn = make_connection("c-alive")
        register_conn(hub, dead_conn)
        register_conn(hub, alive_conn)
        hub.pending_owners["req-001"] = "c-dead"
        result = await hub.send_response_to_pending_owner(
            "req-001", {"type": "response", "approved": True}
        )
        assert result is True  # Rerouted successfully
        assert "req-001" not in hub.pending_owners

    @pytest.mark.asyncio
    async def test_send_response_no_owner(self):
        """When no owner exists for a request_id, return False."""
        hub = make_hub()
        result = await hub.send_response_to_pending_owner(
            "req-unknown", {"type": "response"}
        )
        assert result is False

    @pytest.mark.asyncio
    async def test_send_response_all_dead(self):
        """When owner and all project connections are dead, return False."""
        hub = make_hub()
        dead = make_connection("c-dead", ws=make_mock_ws(closed=True))
        register_conn(hub, dead)
        hub.pending_owners["req-001"] = "c-dead"
        result = await hub.send_response_to_pending_owner(
            "req-001", {"type": "response"}
        )
        assert result is False
        assert "req-001" not in hub.pending_owners  # Cleaned up


# ─── Message Routing ───

class TestMessageRouting:
    """Project broadcast and per-instance sends via each connection's send_queue."""

    @pytest.mark.asyncio
    async def test_broadcast_to_project(self):
        hub = make_hub()
        c1 = make_connection("c1")
        c2 = make_connection("c2")
        c3 = make_connection("c3", project="other_project")
        register_conn(hub, c1)
        register_conn(hub, c2)
        register_conn(hub, c3)
        msg = {"type": "test", "data": "hello"}
        await hub.broadcast_to_project("test_project", msg)
        # c1 and c2 should have received the message (via queue)
        assert c1.send_queue.qsize() == 1
        assert c2.send_queue.qsize() == 1
        assert c3.send_queue.qsize() == 0  # Different project

    @pytest.mark.asyncio
    async def test_send_to_instance(self):
        hub = make_hub()
        c1 = make_connection("c1", instance_number=1)
        c2 = make_connection("c2", instance_number=2)
        register_conn(hub, c1)
        register_conn(hub, c2)
        result = await hub.send_to_instance("test_project", 2, {"type": "cmd"})
        assert result is True
        assert c1.send_queue.qsize() == 0
        assert c2.send_queue.qsize() == 1

    @pytest.mark.asyncio
    async def test_send_to_instance_not_found(self):
        hub = make_hub()
        c1 = make_connection("c1", instance_number=1)
        register_conn(hub, c1)
        result = await hub.send_to_instance("test_project", 99, {"type": "cmd"})
        assert result is False

    @pytest.mark.asyncio
    async def test_broadcast_skips_unauthenticated(self):
        hub = make_hub()
        c1 = make_connection("c1", authenticated=True)
        c2 = make_connection("c2", authenticated=False)
        register_conn(hub, c1)
        register_conn(hub, c2)
        await hub.broadcast_to_project("test_project", {"type": "msg"})
        assert c1.send_queue.qsize() == 1
        assert c2.send_queue.qsize() == 0


# ─── Message Handling (Dispatch) ───

class TestMessageHandling:
    """Dispatch of incoming message types to bot handlers and internal state."""

    @pytest.mark.asyncio
    async def test_pending_sets_metadata(self):
        """Pending messages should have source metadata injected."""
        hub = make_hub()
        received = []
        hub.set_bot_handlers(on_pending=AsyncMock(side_effect=lambda p, d: received.append(d)))
        conn = make_connection("c1", pc_name="my_pc", instance_number=3)
        register_conn(hub, conn)
        await hub._handle_message(conn, MsgType.PENDING, {
            "data": {"request_id": "req-x", "command": "npm install"}
        })
        assert len(received) == 1
        payload = received[0]
        assert payload["_conn_id"] == "c1"
        assert payload["_instance_number"] == 3
        assert payload["_pc_name"] == "my_pc"
        assert payload["project_name"] == "test_project"

    @pytest.mark.asyncio
    async def test_chat_sets_project_name(self):
        hub = make_hub()
        received = []
        hub.set_bot_handlers(on_chat=AsyncMock(side_effect=lambda p, d: received.append(d)))
        conn = make_connection("c1", project="my_proj")
        register_conn(hub, conn)
        await hub._handle_message(conn, MsgType.CHAT, {
            "data": {"content": "hello"}
        })
        assert received[0]["project_name"] == "my_proj"

    @pytest.mark.asyncio
    async def test_no_handler_doesnt_crash(self):
        """If no bot handlers are set, messages should be silently dropped."""
        hub = make_hub()
        conn = make_connection("c1")
        register_conn(hub, conn)
        # Should not raise
        await hub._handle_message(conn, MsgType.PENDING, {"data": {"request_id": "r1"}})
        await hub._handle_message(conn, MsgType.CHAT, {"data": {"content": "hi"}})
        await hub._handle_message(conn, MsgType.REGISTER, {"data": {}})
        await hub._handle_message(conn, MsgType.AUTO_RESOLVE, {"data": {}})

    @pytest.mark.asyncio
    async def test_heartbeat_updates_timestamp(self):
        hub = make_hub()
        conn = make_connection("c1")
        register_conn(hub, conn)
        # Backdate the heartbeat instead of sleeping: a 10 ms sleep can be
        # shorter than the clock's resolution on some platforms, which made
        # "newer than before" flaky.
        stale_ts = conn.last_heartbeat - 60.0
        conn.last_heartbeat = stale_ts
        await hub._handle_message(conn, MsgType.HEARTBEAT, {})
        assert conn.last_heartbeat > stale_ts


# ─── Rate Limiting ───

class TestRateLimiting:
    """Per-connection sliding-window rate limit (PER_CONN_RATE_LIMIT messages)."""

    def test_under_limit_allows(self):
        hub = make_hub()
        conn = make_connection("c1")
        for _ in range(PER_CONN_RATE_LIMIT - 1):
            assert hub._check_rate_limit(conn) is True

    def test_at_limit_blocks(self):
        hub = make_hub()
        conn = make_connection("c1")
        # Fill up to limit
        for _ in range(PER_CONN_RATE_LIMIT):
            hub._check_rate_limit(conn)
        # Next should be blocked
        assert hub._check_rate_limit(conn) is False

    def test_old_timestamps_expire(self):
        hub = make_hub()
        conn = make_connection("c1")
        # Add old timestamps (beyond RATE_WINDOW)
        old_time = time.time() - 20.0
        conn._msg_timestamps = [old_time] * PER_CONN_RATE_LIMIT
        # Should be allowed (old timestamps are pruned)
        assert hub._check_rate_limit(conn) is True


# ─── Deduplication ───

class TestDeduplication:
    """Message-id deduplication via ``hub._is_duplicate``."""

    def test_first_message_not_duplicate(self):
        hub = make_hub()
        assert hub._is_duplicate("msg-001") is False

    def test_same_id_is_duplicate(self):
        hub = make_hub()
        hub._is_duplicate("msg-001")
        assert hub._is_duplicate("msg-001") is True

    def test_different_id_not_duplicate(self):
        hub = make_hub()
        hub._is_duplicate("msg-001")
        assert hub._is_duplicate("msg-002") is False

    def test_cleanup_on_overflow(self):
        hub = make_hub()
        # Add many entries to trigger cleanup
        now = time.time()
        for i in range(10001):
            hub._recent_msg_ids[f"msg-{i}"] = now
        # Cleanup must not break dedup: an unseen id is still not a duplicate.
        assert hub._is_duplicate("new-msg") is False
        # NOTE(review): every seeded entry is fresh, so this bound (10001 + 1)
        # only shows the map did not grow unexpectedly; it passes even if
        # nothing was evicted.
        assert len(hub._recent_msg_ids) <= 10002  # +1 for new-msg


# ─── Queue Backpressure ───

class TestQueueBackpressure:
    """Behaviour of ``hub._queue_send`` when a connection's send_queue fills up."""

    @pytest.mark.asyncio
    async def test_queue_send_normal(self):
        hub = make_hub()
        conn = make_connection("c1")
        await hub._queue_send(conn, {"type": "test"})
        assert conn.send_queue.qsize() == 1

    @pytest.mark.asyncio
    async def test_queue_full_drops_oldest(self):
        hub = make_hub()
        conn = make_connection("c1")
        # Fill the queue
        for i in range(100):
            await hub._queue_send(conn, {"type": "test", "i": i})
        # Queue should be full (100 items)
        assert conn.send_queue.qsize() == 100
        # Adding one more should drop the oldest
        await hub._queue_send(conn, {"type": "test", "i": "new"})
        assert conn.send_queue.qsize() == 100  # Still 100
        # The newest message must be the one kept (drop-oldest, not drop-newest).
        # Checked via repr so it holds whether items are queued as dicts or as
        # serialized strings.
        queued = [conn.send_queue.get_nowait() for _ in range(conn.send_queue.qsize())]
        assert "new" in repr(queued[-1])


# ─── Diagnostics ───

class TestDiagnostics:
    """Shape and counts reported by ``hub.get_status``."""

    def test_get_status_empty(self):
        hub = make_hub()
        status = hub.get_status()
        assert status["total_connections"] == 0
        assert status["projects"] == {}
        assert status["pending_owners"] == 0

    def test_get_status_with_connections(self):
        hub = make_hub()
        c1 = make_connection("c1", project="proj_a", instance_number=1)
        c2 = make_connection("c2", project="proj_a", instance_number=2)
        register_conn(hub, c1)
        register_conn(hub, c2)
        hub.pending_owners["req-001"] = "c1"
        status = hub.get_status()
        assert status["total_connections"] == 2
        assert status["projects"]["proj_a"]["count"] == 2
        assert status["pending_owners"] == 1
        assert len(status["projects"]["proj_a"]["instances"]) == 2


# ─── Auth (TokenManager) ───

class TestTokenManager:
    """Token creation/verification and registration-code checks."""

    def test_create_and_verify_token(self):
        tm = make_token_manager()
        token = tm.create_token("my_project", "my_pc")
        payload = tm.verify_token(token)
        assert payload is not None
        assert payload["project"] == "my_project"
        assert payload["pc"] == "my_pc"

    def test_expired_token_rejected(self):
        tm = make_token_manager()
        token = tm.create_token("proj", "pc", ttl=-1)  # Already expired
        assert tm.verify_token(token) is None

    def test_tampered_token_rejected(self):
        tm = make_token_manager()
        token = tm.create_token("proj", "pc")
        # NOTE(review): only the final character is changed. If the token ends
        # in unpadded base64 and verify_token compares *decoded* signature
        # bytes, some last-character swaps (e.g. "B" -> "A") decode to the same
        # bytes and this test could flake. Confirm against auth.TokenManager's
        # encoding.
        tampered = token[:-1] + ("A" if token[-1] != "A" else "B")
        assert tm.verify_token(tampered) is None

    def test_wrong_secret_rejected(self):
        tm1 = make_token_manager(secret="secret-1" + "a" * 56)
        tm2 = make_token_manager(secret="secret-2" + "b" * 56)
        token = tm1.create_token("proj", "pc")
        assert tm2.verify_token(token) is None

    def test_registration_code_valid(self):
        tm = make_token_manager(reg_code="my-code")
        assert tm.validate_registration_code("my-code") is True
        assert tm.validate_registration_code("wrong") is False

    def test_registration_code_empty_allows_all(self):
        tm = make_token_manager(reg_code="")
        assert tm.validate_registration_code("anything") is True