diff --git a/tests/test_hub_unit.py b/tests/test_hub_unit.py new file mode 100644 index 0000000..115c400 --- /dev/null +++ b/tests/test_hub_unit.py @@ -0,0 +1,621 @@ +"""Unit tests for WebSocket Hub — connection lifecycle, routing, and edge cases. + +Uses unittest.mock for WebSocket mocking, no live server needed. +Run: python -m pytest tests/test_hub_unit.py -v +""" + +import asyncio +import json +import time +import pytest +import pytest_asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +# Add project root to sys.path +import sys +import os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Configure pytest-asyncio +pytestmark = pytest.mark.anyio + +from auth import TokenManager +from hub import WSHub, WSConnection, MsgType, PER_CONN_RATE_LIMIT + + +# ─── Fixtures ─── + +def make_token_manager(secret="test-secret-64chars-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + reg_code="test-reg-code"): + return TokenManager(secret=secret, registration_code=reg_code) + + +def make_mock_ws(closed=False): + """Create a mock WebSocket with async send_json.""" + ws = AsyncMock() + ws.closed = closed + ws.send_json = AsyncMock() + ws.close = AsyncMock() + return ws + + +def make_connection(conn_id, project="test_project", pc_name="test_pc", + instance_number=1, authenticated=True, ws=None): + """Create a WSConnection with a mock WebSocket.""" + if ws is None: + ws = make_mock_ws() + conn = WSConnection( + conn_id=conn_id, + ws=ws, + project=project, + pc_name=pc_name, + instance_number=instance_number, + authenticated=authenticated, + ) + return conn + + +def make_hub(): + """Create a Hub with a TokenManager (no auth enforcement).""" + tm = make_token_manager() + return WSHub(tm) + + +def register_conn(hub, conn): + """Manually register a connection in the hub (skip auth flow).""" + hub.connections[conn.conn_id] = conn + if conn.project not in hub.project_connections: + hub.project_connections[conn.project] = set() + hub.project_connections[conn.project].add(conn.conn_id) + + +# ─── Connection Tracking ─── + +class TestConnectionTracking: + + @pytest.mark.asyncio + async def test_register_connection_adds_to_tracking(self): + hub = make_hub() + conn = make_connection("c1") + hub._register_connection(conn) + + assert "c1" in hub.connections + assert "c1" in hub.project_connections["test_project"] + + @pytest.mark.asyncio + async def test_register_multiple_connections_same_project(self): + hub = make_hub() + c1 = make_connection("c1") + c2 = make_connection("c2") + hub._register_connection(c1) + hub._register_connection(c2) + + assert hub.get_active_count("test_project") == 2 + + @pytest.mark.asyncio + async def test_register_connections_different_projects(self): + hub = make_hub() + c1 = make_connection("c1", project="proj_a") + c2 = make_connection("c2", project="proj_b") + hub._register_connection(c1) + hub._register_connection(c2) + + assert hub.get_active_count("proj_a") == 1 + assert hub.get_active_count("proj_b") == 1 + assert len(hub.connections) == 2 + + @pytest.mark.asyncio + async def test_disconnect_removes_from_tracking(self): + hub = make_hub() + conn = make_connection("c1") + register_conn(hub, conn) + + await hub._disconnect(conn) + + assert "c1" not in hub.connections + assert hub.get_active_count("test_project") == 0 + + @pytest.mark.asyncio + async def test_disconnect_removes_empty_project(self): + hub = make_hub() + conn = make_connection("c1") + register_conn(hub, conn) + + await hub._disconnect(conn) + + assert "test_project" not in hub.project_connections + + @pytest.mark.asyncio + async def test_disconnect_keeps_other_connections(self): + hub = make_hub() + c1 = make_connection("c1") + c2 = make_connection("c2") + register_conn(hub, c1) + register_conn(hub, c2) + + await hub._disconnect(c1) + + assert "c1" not in hub.connections + assert "c2" in hub.connections + assert hub.get_active_count("test_project") == 1 + + +# ─── Instance Number Assignment ─── + +class TestInstanceNumbers: + + def test_first_instance_gets_1(self): + hub = make_hub() + num = hub._assign_instance_number("proj", "c1") + assert num == 1 + + def test_second_instance_gets_2(self): + hub = make_hub() + c1 = make_connection("c1", instance_number=1) + register_conn(hub, c1) + + num = hub._assign_instance_number("proj", "c2") + # project "proj" has no connections (c1 is in "test_project"), so it gets 1 + assert num == 1 + + def test_gap_filling(self): + """If instance #1 disconnects, next connection gets #1 again.""" + hub = make_hub() + c2 = make_connection("c2", instance_number=2) + register_conn(hub, c2) + + num = hub._assign_instance_number("test_project", "c3") + assert num == 1 # #1 is available + + def test_sequential_assignment(self): + hub = make_hub() + c1 = make_connection("c1", instance_number=1) + c2 = make_connection("c2", instance_number=2) + register_conn(hub, c1) + register_conn(hub, c2) + + num = hub._assign_instance_number("test_project", "c3") + assert num == 3 + + +# ─── Pending Owners Lifecycle ─── + +class TestPendingOwners: + + @pytest.mark.asyncio + async def test_pending_tracks_owner(self): + """Processing a pending message tracks the owner connection.""" + hub = make_hub() + conn = make_connection("c1") + register_conn(hub, conn) + + await hub._handle_message(conn, MsgType.PENDING, { + "data": {"request_id": "req-001", "command": "ls"} + }) + + assert hub.pending_owners["req-001"] == "c1" + + @pytest.mark.asyncio + async def test_auto_resolve_clears_owner(self): + """Auto-resolve removes the pending owner tracking.""" + hub = make_hub() + conn = make_connection("c1") + register_conn(hub, conn) + hub.pending_owners["req-001"] = "c1" + + await hub._handle_message(conn, MsgType.AUTO_RESOLVE, { + "data": {"request_id": "req-001"} + }) + + assert "req-001" not in hub.pending_owners + + @pytest.mark.asyncio + async def test_disconnect_reassigns_pending_to_same_project(self): + """When owner disconnects, pending is reassigned to another connection + in the same project (not deleted).""" + hub = make_hub() + c1 = make_connection("c1") + c2 = make_connection("c2") + register_conn(hub, c1) + register_conn(hub, c2) + hub.pending_owners["req-001"] = "c1" + + await hub._disconnect(c1) + + # req-001 should now be owned by c2 + assert hub.pending_owners.get("req-001") == "c2" + + @pytest.mark.asyncio + async def test_disconnect_deletes_pending_when_no_remaining(self): + """When owner disconnects and no other connections exist, pending is deleted.""" + hub = make_hub() + conn = make_connection("c1") + register_conn(hub, conn) + hub.pending_owners["req-001"] = "c1" + + await hub._disconnect(conn) + + assert "req-001" not in hub.pending_owners + + @pytest.mark.asyncio + async def test_register_reassigns_orphaned_pending(self): + """When a new connection registers, orphaned pending_owners + (pointing to dead conn_ids) are reassigned to it.""" + hub = make_hub() + # Simulate orphaned pending (old conn_id "dead" not in connections) + hub.pending_owners["req-001"] = "dead-conn" + hub.pending_owners["req-002"] = "dead-conn" + + new_conn = make_connection("c-new") + hub._register_connection(new_conn) + + assert hub.pending_owners["req-001"] == "c-new" + assert hub.pending_owners["req-002"] == "c-new" + + @pytest.mark.asyncio + async def test_disconnect_reassigns_only_to_authenticated(self): + """Pending should only be reassigned to authenticated connections.""" + hub = make_hub() + c1 = make_connection("c1") + c2 = make_connection("c2", authenticated=False) # Not authenticated + register_conn(hub, c1) + register_conn(hub, c2) + hub.pending_owners["req-001"] = "c1" + + await hub._disconnect(c1) + + # c2 is not authenticated, so pending should be deleted + assert "req-001" not in hub.pending_owners + + +# ─── Response Routing ─── + +class TestResponseRouting: + + @pytest.mark.asyncio + async def test_send_response_to_owner(self): + hub = make_hub() + conn = make_connection("c1") + register_conn(hub, conn) + hub.pending_owners["req-001"] = "c1" + + result = await hub.send_response_to_pending_owner( + "req-001", {"type": "response", "approved": True} + ) + + assert result is True + # Owner should be cleaned up after successful delivery + assert "req-001" not in hub.pending_owners + + @pytest.mark.asyncio + async def test_send_response_owner_dead_fallback(self): + """When original owner is dead, response routes to another connection + in the same project.""" + hub = make_hub() + dead_conn = make_connection("c-dead", ws=make_mock_ws(closed=True)) + alive_conn = make_connection("c-alive") + register_conn(hub, dead_conn) + register_conn(hub, alive_conn) + hub.pending_owners["req-001"] = "c-dead" + + result = await hub.send_response_to_pending_owner( + "req-001", {"type": "response", "approved": True} + ) + + assert result is True # Rerouted successfully + assert "req-001" not in hub.pending_owners + + @pytest.mark.asyncio + async def test_send_response_no_owner(self): + """When no owner exists for a request_id, return False.""" + hub = make_hub() + + result = await hub.send_response_to_pending_owner( + "req-unknown", {"type": "response"} + ) + + assert result is False + + @pytest.mark.asyncio + async def test_send_response_all_dead(self): + """When owner and all project connections are dead, return False.""" + hub = make_hub() + dead = make_connection("c-dead", ws=make_mock_ws(closed=True)) + register_conn(hub, dead) + hub.pending_owners["req-001"] = "c-dead" + + result = await hub.send_response_to_pending_owner( + "req-001", {"type": "response"} + ) + + assert result is False + assert "req-001" not in hub.pending_owners # Cleaned up + + +# ─── Message Routing ─── + +class TestMessageRouting: + + @pytest.mark.asyncio + async def test_broadcast_to_project(self): + hub = make_hub() + c1 = make_connection("c1") + c2 = make_connection("c2") + c3 = make_connection("c3", project="other_project") + register_conn(hub, c1) + register_conn(hub, c2) + register_conn(hub, c3) + + msg = {"type": "test", "data": "hello"} + await hub.broadcast_to_project("test_project", msg) + + # c1 and c2 should have received the message (via queue) + assert c1.send_queue.qsize() == 1 + assert c2.send_queue.qsize() == 1 + assert c3.send_queue.qsize() == 0 # Different project + + @pytest.mark.asyncio + async def test_send_to_instance(self): + hub = make_hub() + c1 = make_connection("c1", instance_number=1) + c2 = make_connection("c2", instance_number=2) + register_conn(hub, c1) + register_conn(hub, c2) + + result = await hub.send_to_instance("test_project", 2, {"type": "cmd"}) + + assert result is True + assert c1.send_queue.qsize() == 0 + assert c2.send_queue.qsize() == 1 + + @pytest.mark.asyncio + async def test_send_to_instance_not_found(self): + hub = make_hub() + c1 = make_connection("c1", instance_number=1) + register_conn(hub, c1) + + result = await hub.send_to_instance("test_project", 99, {"type": "cmd"}) + + assert result is False + + @pytest.mark.asyncio + async def test_broadcast_skips_unauthenticated(self): + hub = make_hub() + c1 = make_connection("c1", authenticated=True) + c2 = make_connection("c2", authenticated=False) + register_conn(hub, c1) + register_conn(hub, c2) + + await hub.broadcast_to_project("test_project", {"type": "msg"}) + + assert c1.send_queue.qsize() == 1 + assert c2.send_queue.qsize() == 0 + + +# ─── Message Handling (Dispatch) ─── + +class TestMessageHandling: + + @pytest.mark.asyncio + async def test_pending_sets_metadata(self): + """Pending messages should have source metadata injected.""" + hub = make_hub() + received = [] + hub.set_bot_handlers(on_pending=AsyncMock(side_effect=lambda p, d: received.append(d))) + + conn = make_connection("c1", pc_name="my_pc", instance_number=3) + register_conn(hub, conn) + + await hub._handle_message(conn, MsgType.PENDING, { + "data": {"request_id": "req-x", "command": "npm install"} + }) + + assert len(received) == 1 + payload = received[0] + assert payload["_conn_id"] == "c1" + assert payload["_instance_number"] == 3 + assert payload["_pc_name"] == "my_pc" + assert payload["project_name"] == "test_project" + + @pytest.mark.asyncio + async def test_chat_sets_project_name(self): + hub = make_hub() + received = [] + hub.set_bot_handlers(on_chat=AsyncMock(side_effect=lambda p, d: received.append(d))) + + conn = make_connection("c1", project="my_proj") + register_conn(hub, conn) + + await hub._handle_message(conn, MsgType.CHAT, { + "data": {"content": "hello"} + }) + + assert received[0]["project_name"] == "my_proj" + + @pytest.mark.asyncio + async def test_no_handler_doesnt_crash(self): + """If no bot handlers are set, messages should be silently dropped.""" + hub = make_hub() + conn = make_connection("c1") + register_conn(hub, conn) + + # Should not raise + await hub._handle_message(conn, MsgType.PENDING, {"data": {"request_id": "r1"}}) + await hub._handle_message(conn, MsgType.CHAT, {"data": {"content": "hi"}}) + await hub._handle_message(conn, MsgType.REGISTER, {"data": {}}) + await hub._handle_message(conn, MsgType.AUTO_RESOLVE, {"data": {}}) + + @pytest.mark.asyncio + async def test_heartbeat_updates_timestamp(self): + hub = make_hub() + conn = make_connection("c1") + register_conn(hub, conn) + old_ts = conn.last_heartbeat + + # Small delay to ensure timestamp difference + await asyncio.sleep(0.01) + await hub._handle_message(conn, MsgType.HEARTBEAT, {}) + + assert conn.last_heartbeat > old_ts + + +# ─── Rate Limiting ─── + +class TestRateLimiting: + + def test_under_limit_allows(self): + hub = make_hub() + conn = make_connection("c1") + + for _ in range(PER_CONN_RATE_LIMIT - 1): + assert hub._check_rate_limit(conn) is True + + def test_at_limit_blocks(self): + hub = make_hub() + conn = make_connection("c1") + + # Fill up to limit + for _ in range(PER_CONN_RATE_LIMIT): + hub._check_rate_limit(conn) + + # Next should be blocked + assert hub._check_rate_limit(conn) is False + + def test_old_timestamps_expire(self): + hub = make_hub() + conn = make_connection("c1") + + # Add old timestamps (beyond RATE_WINDOW) + old_time = time.time() - 20.0 + conn._msg_timestamps = [old_time] * PER_CONN_RATE_LIMIT + + # Should be allowed (old timestamps are pruned) + assert hub._check_rate_limit(conn) is True + + +# ─── Deduplication ─── + +class TestDeduplication: + + def test_first_message_not_duplicate(self): + hub = make_hub() + assert hub._is_duplicate("msg-001") is False + + def test_same_id_is_duplicate(self): + hub = make_hub() + hub._is_duplicate("msg-001") + assert hub._is_duplicate("msg-001") is True + + def test_different_id_not_duplicate(self): + hub = make_hub() + hub._is_duplicate("msg-001") + assert hub._is_duplicate("msg-002") is False + + def test_cleanup_on_overflow(self): + hub = make_hub() + # Add many entries to trigger cleanup + for i in range(10001): + hub._recent_msg_ids[f"msg-{i}"] = time.time() + + # Should trigger cleanup and still work + hub._is_duplicate("new-msg") + assert len(hub._recent_msg_ids) <= 10002 # +1 for new-msg + + +# ─── Queue Backpressure ─── + +class TestQueueBackpressure: + + @pytest.mark.asyncio + async def test_queue_send_normal(self): + hub = make_hub() + conn = make_connection("c1") + + await hub._queue_send(conn, {"type": "test"}) + + assert conn.send_queue.qsize() == 1 + + @pytest.mark.asyncio + async def test_queue_full_drops_oldest(self): + hub = make_hub() + conn = make_connection("c1") + + # Fill the queue + for i in range(100): + await hub._queue_send(conn, {"type": "test", "i": i}) + + # Queue should be full (100 items) + assert conn.send_queue.qsize() == 100 + + # Adding one more should drop the oldest + await hub._queue_send(conn, {"type": "test", "i": "new"}) + + assert conn.send_queue.qsize() == 100 # Still 100 + + +# ─── Diagnostics ─── + +class TestDiagnostics: + + def test_get_status_empty(self): + hub = make_hub() + status = hub.get_status() + + assert status["total_connections"] == 0 + assert status["projects"] == {} + assert status["pending_owners"] == 0 + + def test_get_status_with_connections(self): + hub = make_hub() + c1 = make_connection("c1", project="proj_a", instance_number=1) + c2 = make_connection("c2", project="proj_a", instance_number=2) + register_conn(hub, c1) + register_conn(hub, c2) + hub.pending_owners["req-001"] = "c1" + + status = hub.get_status() + + assert status["total_connections"] == 2 + assert status["projects"]["proj_a"]["count"] == 2 + assert status["pending_owners"] == 1 + assert len(status["projects"]["proj_a"]["instances"]) == 2 + + +# ─── Auth (TokenManager) ─── + +class TestTokenManager: + + def test_create_and_verify_token(self): + tm = make_token_manager() + token = tm.create_token("my_project", "my_pc") + payload = tm.verify_token(token) + + assert payload is not None + assert payload["project"] == "my_project" + assert payload["pc"] == "my_pc" + + def test_expired_token_rejected(self): + tm = make_token_manager() + token = tm.create_token("proj", "pc", ttl=-1) # Already expired + assert tm.verify_token(token) is None + + def test_tampered_token_rejected(self): + tm = make_token_manager() + token = tm.create_token("proj", "pc") + tampered = token[:-1] + ("A" if token[-1] != "A" else "B") + assert tm.verify_token(tampered) is None + + def test_wrong_secret_rejected(self): + tm1 = make_token_manager(secret="secret-1" + "a" * 56) + tm2 = make_token_manager(secret="secret-2" + "b" * 56) + token = tm1.create_token("proj", "pc") + assert tm2.verify_token(token) is None + + def test_registration_code_valid(self): + tm = make_token_manager(reg_code="my-code") + assert tm.validate_registration_code("my-code") is True + assert tm.validate_registration_code("wrong") is False + + def test_registration_code_empty_allows_all(self): + tm = make_token_manager(reg_code="") + assert tm.validate_registration_code("anything") is True