refactor(extension): 모듈 분리 + Hub 통합 테스트 #task-395
- extension.ts 3,446→1,289줄 (-63%) - step-probe.ts (1,435줄): setupMonitor, processResponseFile, tryApprovalStrategies - observer-script.ts (687줄): DOM observer script - ws-client.ts (390줄): WSBridgeClient - step-utils.ts (114줄): step 파싱 유틸 - auth.py (115줄): JWT + registration code - hub.py (581줄): WSHub + per-client queue - Hub WS 연동 테스트 통과 (auth, chat, register) - VSIX v0.4.0 빌드
This commit is contained in:
@@ -1,35 +1,59 @@
|
||||
# Architecture
|
||||
|
||||
> 이 프로젝트의 아키텍처를 설명하는 문서입니다.
|
||||
> AI 에이전트는 구현 전 이 문서를 반드시 확인합니다.
|
||||
|
||||
## 프로젝트 개요
|
||||
|
||||
<!-- 프로젝트의 목적과 핵심 기능을 간략히 서술 -->
|
||||
|
||||
(프로젝트 설명을 여기에 작성하세요)
|
||||
Antigravity AI 코딩 에이전트의 Discord 연동 시스템.
|
||||
- AG Extension ↔ WebSocket Hub ↔ Discord Bot (실시간)
|
||||
- AG Extension ↔ 파일 bridge ↔ Collector ↔ Gateway ↔ Discord Bot (레거시)
|
||||
|
||||
## 디렉토리 구조
|
||||
|
||||
```
|
||||
project-root/
|
||||
├── src/ # 소스 코드
|
||||
├── tests/ # 테스트
|
||||
├── docs/ # 문서
|
||||
├── .agents/ # AI 에이전트 설정
|
||||
└── ...
|
||||
gravity_control/
|
||||
├── auth.py # JWT 토큰 관리
|
||||
├── hub.py # WebSocket Hub (메시지 라우팅, 인스턴스 관리)
|
||||
├── bot.py # Discord 봇 (승인 UI, 채널 관리, Hub 핸들러)
|
||||
├── gateway.py # HTTP REST + /ws endpoint
|
||||
├── bridge.py # 파일 기반 IPC (레거시)
|
||||
├── collector.py # 원격 파일→HTTP 릴레이 (Phase 2 삭제 예정)
|
||||
├── watcher.py # Brain 디렉토리 변경 감시
|
||||
├── config.py # 환경변수 설정
|
||||
├── main.py # 진입점 (Bot + Hub + Watcher 시작)
|
||||
├── parser.py # Markdown→Discord 파서
|
||||
├── extension/src/
|
||||
│ ├── extension.ts # 메인 Extension (3,300줄, 점진적 모듈화 진행)
|
||||
│ ├── ws-client.ts # WSBridgeClient (Hub 연결, 재연결, 메시지 큐)
|
||||
│ └── sdk/ # Antigravity SDK (로컬 임베드)
|
||||
└── .agents/references/ # AI 에이전트 레퍼런스 문서
|
||||
```
|
||||
|
||||
## 핵심 모듈
|
||||
|
||||
<!-- 각 모듈의 역할과 의존 관계를 설명 -->
|
||||
|
||||
| 모듈 | 역할 | 의존성 |
|
||||
|------|------|--------|
|
||||
| (모듈명) | (역할 설명) | (의존하는 모듈) |
|
||||
| hub.py | WS 연결 관리, 메시지 라우팅, 인스턴스 번호 | auth.py |
|
||||
| auth.py | JWT 토큰 생성/검증, registration code | - |
|
||||
| bot.py | Discord UI, 승인 처리, 채팅 릴레이 | hub.py, bridge.py, parser.py |
|
||||
| gateway.py | HTTP REST API + /ws endpoint | hub.py |
|
||||
| ws-client.ts | Extension→Hub WS 클라이언트 | - |
|
||||
| extension.ts | AG SDK 연동, step 감시, DOM observer | ws-client.ts, sdk/ |
|
||||
|
||||
## 데이터 흐름
|
||||
|
||||
<!-- 주요 데이터 흐름을 Mermaid 다이어그램이나 텍스트로 설명 -->
|
||||
|
||||
(데이터 흐름을 여기에 작성하세요)
|
||||
```
|
||||
[Extension]
|
||||
│
|
||||
┌────┴────┐
|
||||
│ WS Hub │ ← 실시간 (preferred)
|
||||
│ (ws-client.ts → hub.py)
|
||||
└────┬────┘
|
||||
│ ┌─────────────┐
|
||||
├───────────────────→│ Discord Bot │→ Discord
|
||||
│ └─────────────┘
|
||||
┌────┴────┐
|
||||
│파일 bridge│ ← 레거시 fallback
|
||||
│(Collector → Gateway)
|
||||
└─────────┘
|
||||
```
|
||||
|
||||
@@ -42,5 +42,8 @@
|
||||
| DISCORD_TOKEN | Discord 봇 토큰 | (필수) |
|
||||
| DISCORD_GUILD_ID | Discord 서버 ID | (필수) |
|
||||
| BRAIN_PATH | AG 브레인 경로 | `~/.gemini/antigravity/brain` |
|
||||
| BOT_MODE | 봇 모드 (local/remote) | `local` |
|
||||
| BOT_MODE | 봇 모드 (local/remote/gateway) | `local` |
|
||||
| REMOTE_BRIDGE_URL | 원격 브릿지 URL | (remote 모드 전용) |
|
||||
| GATEWAY_API_KEY | Gateway REST API 인증 키 | (gateway 모드) |
|
||||
| GRAVITY_HUB_SECRET | WS Hub JWT 서명 시크릿 | (자동생성 가능) |
|
||||
| GRAVITY_REGISTRATION_CODE | Extension 등록 코드 | (미설정 시 인증 생략) |
|
||||
|
||||
127
auth.py
Normal file
127
auth.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Authentication module — JWT token management for WebSocket Hub.
|
||||
|
||||
Two-stage auth:
|
||||
1. Extension connects with a registration code (built into .vsix at build time)
|
||||
2. Hub validates the code and issues a short-lived JWT session token
|
||||
3. Subsequent reconnections use the JWT token directly
|
||||
|
||||
The master secret is stored server-side only (GRAVITY_HUB_SECRET env var).
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from base64 import urlsafe_b64decode, urlsafe_b64encode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults
|
||||
DEFAULT_TOKEN_TTL = 86400 # 24 hours
|
||||
DEFAULT_REGISTRATION_CODE = "" # Set via GRAVITY_REGISTRATION_CODE env var
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Manages JWT-like token creation and verification.
|
||||
|
||||
Uses HMAC-SHA256 for signing. Tokens contain:
|
||||
- project: project name scope
|
||||
- pc: PC identifier (hostname or custom name)
|
||||
- iat: issued at (unix timestamp)
|
||||
- exp: expiration (unix timestamp)
|
||||
"""
|
||||
|
||||
def __init__(self, secret: str = "", registration_code: str = ""):
|
||||
self.secret = secret or os.getenv("GRAVITY_HUB_SECRET", "")
|
||||
self.registration_code = registration_code or os.getenv(
|
||||
"GRAVITY_REGISTRATION_CODE", DEFAULT_REGISTRATION_CODE
|
||||
)
|
||||
if not self.secret:
|
||||
# Auto-generate a secret if not set (ephemeral — tokens invalid after restart)
|
||||
self.secret = hashlib.sha256(os.urandom(32)).hexdigest()
|
||||
logger.warning(
|
||||
"[AUTH] No GRAVITY_HUB_SECRET set — generated ephemeral secret. "
|
||||
"Tokens will be invalid after server restart."
|
||||
)
|
||||
|
||||
def validate_registration_code(self, code: str) -> bool:
|
||||
"""Check if the provided registration code matches."""
|
||||
if not self.registration_code:
|
||||
# No registration code configured → allow all connections
|
||||
logger.warning("[AUTH] No registration code configured — accepting all")
|
||||
return True
|
||||
return hmac.compare_digest(code, self.registration_code)
|
||||
|
||||
def create_token(
|
||||
self, project: str, pc_name: str, ttl: int = DEFAULT_TOKEN_TTL
|
||||
) -> str:
|
||||
"""Create a signed token for a specific project and PC.
|
||||
|
||||
Returns a base64-encoded string: {header}.{payload}.{signature}
|
||||
"""
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
"project": project,
|
||||
"pc": pc_name,
|
||||
"iat": now,
|
||||
"exp": now + ttl,
|
||||
}
|
||||
payload_b64 = _b64_encode(json.dumps(payload))
|
||||
signature = self._sign(payload_b64)
|
||||
return f"{payload_b64}.{signature}"
|
||||
|
||||
def verify_token(self, token: str) -> dict | None:
|
||||
"""Verify and decode a token.
|
||||
|
||||
Returns the payload dict if valid, None if invalid or expired.
|
||||
"""
|
||||
try:
|
||||
parts = token.split(".")
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
payload_b64, signature = parts
|
||||
expected_sig = self._sign(payload_b64)
|
||||
|
||||
if not hmac.compare_digest(signature, expected_sig):
|
||||
logger.warning("[AUTH] Invalid token signature")
|
||||
return None
|
||||
|
||||
payload = json.loads(_b64_decode(payload_b64))
|
||||
|
||||
# Check expiration
|
||||
if payload.get("exp", 0) < time.time():
|
||||
logger.info(f"[AUTH] Token expired for {payload.get('pc', '?')}")
|
||||
return None
|
||||
|
||||
return payload
|
||||
except (json.JSONDecodeError, ValueError, KeyError) as e:
|
||||
logger.warning(f"[AUTH] Token decode error: {e}")
|
||||
return None
|
||||
|
||||
def _sign(self, data: str) -> str:
|
||||
"""HMAC-SHA256 sign and return base64."""
|
||||
sig = hmac.new(
|
||||
self.secret.encode("utf-8"),
|
||||
data.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
return _b64_encode_bytes(sig)
|
||||
|
||||
|
||||
def _b64_encode(data: str) -> str:
|
||||
"""URL-safe base64 encode a string, no padding."""
|
||||
return urlsafe_b64encode(data.encode("utf-8")).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def _b64_encode_bytes(data: bytes) -> str:
|
||||
"""URL-safe base64 encode bytes, no padding."""
|
||||
return urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def _b64_decode(data: str) -> str:
|
||||
"""URL-safe base64 decode, handles missing padding."""
|
||||
padded = data + "=" * (4 - len(data) % 4)
|
||||
return urlsafe_b64decode(padded).decode("utf-8")
|
||||
322
bot.py
322
bot.py
@@ -5,9 +5,15 @@ Multi-project channel architecture:
|
||||
- Each conversation maps to a project via conv_to_project dict
|
||||
- Extension registers projects via bridge/pending/ files
|
||||
- Commands include project_name for routing to correct IDE window
|
||||
|
||||
Multi-PC UX:
|
||||
- When multiple AG instances are active, messages get instance numbers (PC #1, #2)
|
||||
- Users can target specific instances with !N <message> (e.g. !2 hello)
|
||||
- When only one instance is active, natural conversation without numbers
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
@@ -184,17 +190,41 @@ class GravityBot(commands.Bot):
|
||||
self._processed_message_ids: deque[int] = deque(maxlen=200) # dedup for Gateway event replay
|
||||
self._approval_messages: dict[str, int] = {} # FIX #4: request_id → discord message_id (for auto_resolved lookup)
|
||||
self.gateway = None # Set by main.py in gateway mode
|
||||
self.hub = None # Set by main.py in gateway mode (WSHub instance)
|
||||
|
||||
def _write_command(self, project: str, text: str, **kwargs):
|
||||
"""Write command to bridge AND push to gateway (if gateway mode)."""
|
||||
self.bridge.write_command(project, text, **kwargs)
|
||||
if self.gateway:
|
||||
import time
|
||||
self.gateway.push_command(project, {
|
||||
"id": str(int(time.time() * 1000)),
|
||||
def _write_command(self, project: str, text: str, *,
|
||||
target_instance: int | None = None, **kwargs):
|
||||
"""Write command to bridge AND push to gateway/hub.
|
||||
|
||||
Args:
|
||||
target_instance: If set, send only to this instance number (via Hub).
|
||||
If None, broadcast to all instances.
|
||||
"""
|
||||
cmd_data = {
|
||||
"text": text,
|
||||
"project_name": kwargs.get('project_name', project),
|
||||
})
|
||||
}
|
||||
|
||||
# Hub route (preferred if available)
|
||||
if self.hub:
|
||||
import time as _time
|
||||
cmd_data["id"] = str(int(_time.time() * 1000))
|
||||
msg = {"type": "command", "data": cmd_data}
|
||||
if target_instance is not None:
|
||||
asyncio.create_task(
|
||||
self.hub.send_to_instance(project, target_instance, msg)
|
||||
)
|
||||
else:
|
||||
asyncio.create_task(
|
||||
self.hub.broadcast_to_project(project, msg)
|
||||
)
|
||||
|
||||
# Legacy routes (file bridge + gateway HTTP)
|
||||
self.bridge.write_command(project, text, **kwargs)
|
||||
if self.gateway:
|
||||
import time as _time
|
||||
cmd_data["id"] = cmd_data.get("id", str(int(_time.time() * 1000)))
|
||||
self.gateway.push_command(project, cmd_data)
|
||||
|
||||
@staticmethod
|
||||
def _make_channel_name(project_name: str) -> str:
|
||||
@@ -206,6 +236,8 @@ class GravityBot(commands.Bot):
|
||||
self.pending_approval_scanner.start()
|
||||
self.chat_snapshot_scanner.start()
|
||||
self._register_slash_commands()
|
||||
# Register Hub handlers (if Hub is available, set after setup_hook by main.py)
|
||||
asyncio.get_event_loop().call_soon(self._register_hub_handlers)
|
||||
logger.info("Bot setup complete")
|
||||
|
||||
def _register_slash_commands(self):
|
||||
@@ -826,7 +858,33 @@ class GravityBot(commands.Bot):
|
||||
logger.info(f"Sent approval request: {request.request_id[:12]}")
|
||||
self._approval_messages[request.request_id] = msg.id # FIX #4: Track msg_id for auto_resolved lookup
|
||||
|
||||
# ─── Discord → IDE Text Relay ─────────────────────────────────────
|
||||
# ─── Discord → IDE Text Relay + Multi-PC UX ───────────────────────────
|
||||
|
||||
def _get_instance_header(self, project: str, instance_number: int) -> str:
|
||||
"""Format instance header based on active count.
|
||||
|
||||
Single instance: empty string (natural conversation)
|
||||
Multiple instances: **[PC #N]** prefix
|
||||
"""
|
||||
if not self.hub:
|
||||
return ""
|
||||
active = self.hub.get_active_count(project)
|
||||
if active <= 1:
|
||||
return ""
|
||||
return f"**[PC #{instance_number}]** "
|
||||
|
||||
def _parse_instance_target(self, text: str) -> tuple[int | None, str]:
|
||||
"""Parse !N prefix from message text.
|
||||
|
||||
Returns (target_instance, remaining_text).
|
||||
'!2 hello' -> (2, 'hello')
|
||||
'hello' -> (None, 'hello')
|
||||
'!stop' -> (None, '!stop') # special commands not treated as targeting
|
||||
"""
|
||||
match = re.match(r'^!(\d+)\s+(.+)', text, re.DOTALL)
|
||||
if match:
|
||||
return int(match.group(1)), match.group(2).strip()
|
||||
return None, text
|
||||
|
||||
async def on_message(self, message: discord.Message):
|
||||
if message.author == self.user:
|
||||
@@ -845,19 +903,24 @@ class GravityBot(commands.Bot):
|
||||
|
||||
text = message.content.strip()
|
||||
|
||||
# Parse !N instance targeting (before special commands)
|
||||
target_instance, actual_text = self._parse_instance_target(text)
|
||||
|
||||
# Special command: !stop — cancel AI work
|
||||
if text == "!stop":
|
||||
self._write_command(project, "!stop", project_name=project)
|
||||
if actual_text == "!stop":
|
||||
self._write_command(project, "!stop", target_instance=target_instance,
|
||||
project_name=project)
|
||||
target_label = f" (PC #{target_instance})" if target_instance else ""
|
||||
embed = discord.Embed(
|
||||
title="⏹️ AI 작업 중지",
|
||||
description=f"프로젝트: **{project}**\n중지 요청을 Extension에 전달했습니다.",
|
||||
description=f"프로젝트: **{project}**{target_label}\n중지 요청을 Extension에 전달했습니다.",
|
||||
color=discord.Color.orange(),
|
||||
)
|
||||
await message.channel.send(embed=embed)
|
||||
return
|
||||
|
||||
# Special command: !auto — toggle auto-approve
|
||||
if text == "!auto":
|
||||
if actual_text == "!auto":
|
||||
# Toggle per-project auto-approve
|
||||
if project in self.auto_approve_projects:
|
||||
self.auto_approve_projects.discard(project)
|
||||
@@ -865,7 +928,8 @@ class GravityBot(commands.Bot):
|
||||
else:
|
||||
self.auto_approve_projects.add(project)
|
||||
enabled = True
|
||||
self._write_command(project, f"!auto {'on' if enabled else 'off'}", project_name=project)
|
||||
self._write_command(project, f"!auto {'on' if enabled else 'off'}",
|
||||
target_instance=target_instance, project_name=project)
|
||||
emoji = "🟢" if enabled else "🔴"
|
||||
mode = "자동 승인" if enabled else "수동 승인"
|
||||
embed = discord.Embed(
|
||||
@@ -877,18 +941,240 @@ class GravityBot(commands.Bot):
|
||||
await message.channel.send(embed=embed)
|
||||
return
|
||||
|
||||
# General text relay — routed by project
|
||||
if text:
|
||||
self._write_command(project, text, project_name=project)
|
||||
# General text relay — routed by project (+ optional instance targeting)
|
||||
if actual_text:
|
||||
self._write_command(project, actual_text, target_instance=target_instance,
|
||||
project_name=project)
|
||||
await message.add_reaction("📨")
|
||||
target_label = f" PC #{target_instance}" if target_instance else ""
|
||||
embed = discord.Embed(
|
||||
description=f"📨 → **{project}** IDE에 전달됨\n`{text[:100]}`",
|
||||
description=f"📨 → **{project}**{target_label} IDE에 전달됨\n`{actual_text[:100]}`",
|
||||
color=discord.Color.blurple(),
|
||||
)
|
||||
await message.channel.send(embed=embed, delete_after=10)
|
||||
|
||||
await self.process_commands(message)
|
||||
|
||||
# ─── Hub Event Handlers ──────────────────────────────────────────
|
||||
|
||||
def _register_hub_handlers(self):
|
||||
"""Register callbacks on the Hub for Extension->Bot messages."""
|
||||
if not self.hub:
|
||||
return
|
||||
self.hub.set_bot_handlers(
|
||||
on_pending=self._hub_on_pending,
|
||||
on_chat=self._hub_on_chat,
|
||||
on_register=self._hub_on_register,
|
||||
on_auto_resolve=self._hub_on_auto_resolve,
|
||||
on_brain_event=self._hub_on_brain_event,
|
||||
)
|
||||
logger.info("[BOT] Hub handlers registered")
|
||||
|
||||
async def _hub_on_pending(self, project: str, data: dict):
|
||||
"""Handle pending approval from Hub (Extension->Hub->Bot)."""
|
||||
try:
|
||||
request_id = data.get("request_id", "")
|
||||
if not request_id:
|
||||
return
|
||||
|
||||
# Skip if already sent
|
||||
if request_id in self._sent_approval_ids:
|
||||
return
|
||||
|
||||
# Check auto_resolved status
|
||||
status = data.get("status", "pending")
|
||||
if status in ("auto_resolved", "expired"):
|
||||
await self._handle_auto_resolved(request_id, status)
|
||||
return
|
||||
|
||||
instance_number = data.get("_instance_number", 0)
|
||||
pc_name = data.get("_pc_name", "")
|
||||
header = self._get_instance_header(project, instance_number)
|
||||
|
||||
# Build approval request
|
||||
request = ApprovalRequest(
|
||||
request_id=request_id,
|
||||
command=data.get("command", ""),
|
||||
description=data.get("description", ""),
|
||||
project_name=project,
|
||||
step_type=data.get("step_type", ""),
|
||||
status=status,
|
||||
)
|
||||
|
||||
# Auto-approve check
|
||||
if project in self.auto_approve_projects:
|
||||
await self._auto_approve_via_hub(request)
|
||||
return
|
||||
|
||||
# Send to Discord
|
||||
channel = await self._get_channel(project)
|
||||
if not channel:
|
||||
logger.warning(f"[HUB-PENDING] No channel for project={project}")
|
||||
return
|
||||
|
||||
buttons = data.get("buttons", [])
|
||||
desc_parts = []
|
||||
if header:
|
||||
desc_parts.append(header)
|
||||
desc_parts.append(f"**명령:** `{request.command[:200]}`")
|
||||
if buttons:
|
||||
btn_names = [b.get("text", "?") for b in buttons]
|
||||
desc_parts.append(f"**선택지:** {' / '.join(btn_names)}")
|
||||
if request.description:
|
||||
desc_parts.append(request.description[:500])
|
||||
|
||||
embed = discord.Embed(
|
||||
title="⚠️ 승인 요청",
|
||||
description="\n".join(desc_parts),
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_footer(text=f"ID: {request_id}")
|
||||
|
||||
view = ApprovalView(self.bridge, request, buttons=buttons)
|
||||
msg = await channel.send(embed=embed, view=view)
|
||||
|
||||
self._sent_approval_ids.add(request_id)
|
||||
self._approval_messages[request_id] = msg.id
|
||||
logger.info(f"[HUB-PENDING] Sent approval: {request_id[:12]} project={project}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[HUB-PENDING] Error: {e}")
|
||||
|
||||
async def _auto_approve_via_hub(self, request: ApprovalRequest):
|
||||
"""Auto-approve a pending request via Hub."""
|
||||
if self.hub:
|
||||
await self.hub.send_response_to_pending_owner(request.request_id, {
|
||||
"type": "response",
|
||||
"data": {
|
||||
"request_id": request.request_id,
|
||||
"approved": True,
|
||||
"button_index": 0,
|
||||
"step_type": request.step_type,
|
||||
"project_name": request.project_name,
|
||||
},
|
||||
})
|
||||
# Also write via legacy bridge
|
||||
self.bridge.write_response(UserResponse(
|
||||
request_id=request.request_id, approved=True,
|
||||
step_type=request.step_type,
|
||||
project_name=request.project_name,
|
||||
))
|
||||
logger.info(f"[HUB-AUTO] Auto-approved: {request.request_id[:12]}")
|
||||
|
||||
async def _hub_on_chat(self, project: str, data: dict):
|
||||
"""Handle chat snapshot from Hub (Extension->Hub->Bot->Discord)."""
|
||||
try:
|
||||
content = data.get("content", "")
|
||||
attached_files = data.get("attached_files", [])
|
||||
if not content and not attached_files:
|
||||
return
|
||||
|
||||
instance_number = data.get("_instance_number", 0)
|
||||
header = self._get_instance_header(project, instance_number)
|
||||
|
||||
channel = await self._get_channel(project)
|
||||
if not channel:
|
||||
return
|
||||
|
||||
import io as _io
|
||||
discord_files = []
|
||||
for af in attached_files:
|
||||
af_name = af.get("name", "document.md")
|
||||
af_content = af.get("content", "")
|
||||
if af_content:
|
||||
discord_files.append(discord.File(
|
||||
_io.BytesIO(af_content.encode("utf-8")),
|
||||
filename=af_name,
|
||||
))
|
||||
|
||||
display_content = f"{header}{content}" if header else content
|
||||
|
||||
FILE_ATTACH_THRESHOLD = 4000
|
||||
if len(display_content) > FILE_ATTACH_THRESHOLD:
|
||||
summary = display_content[:500].rsplit('\n', 1)[0]
|
||||
embed = discord.Embed(
|
||||
title="💬 AI 대화 내용",
|
||||
description=f"{summary}\n\n📎 *전체 내용은 첨부 파일 참조* ({len(content):,}자)",
|
||||
color=discord.Color.purple(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
discord_files.append(discord.File(
|
||||
_io.BytesIO(content.encode("utf-8")),
|
||||
filename="chat_message.md",
|
||||
))
|
||||
await channel.send(embed=embed, files=discord_files)
|
||||
else:
|
||||
embed = discord.Embed(
|
||||
title="💬 AI 대화 내용",
|
||||
description=display_content,
|
||||
color=discord.Color.purple(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
await channel.send(
|
||||
embed=embed,
|
||||
files=discord_files if discord_files else discord.utils.MISSING,
|
||||
)
|
||||
|
||||
logger.info(f"[HUB-CHAT] Sent to #{channel.name} ({len(content)} chars)")
|
||||
except Exception as e:
|
||||
logger.error(f"[HUB-CHAT] Error: {e}")
|
||||
|
||||
async def _hub_on_register(self, data: dict):
|
||||
"""Handle session registration from Hub."""
|
||||
conv_id = data.get("conversation_id", "")
|
||||
project = data.get("project_name", "")
|
||||
if conv_id and project:
|
||||
self.conv_to_project[conv_id] = project
|
||||
logger.info(f"[HUB-REG] {conv_id[:8]} → {project}")
|
||||
|
||||
async def _hub_on_auto_resolve(self, project: str, data: dict):
|
||||
"""Handle auto_resolve notification from Hub."""
|
||||
request_id = data.get("request_id", "")
|
||||
if request_id:
|
||||
await self._handle_auto_resolved(request_id, "auto_resolved")
|
||||
|
||||
async def _hub_on_brain_event(self, project: str, data: dict):
|
||||
"""Handle brain event from Hub (Extension->Hub->Bot->Discord)."""
|
||||
try:
|
||||
from watcher import BrainEvent, EventType
|
||||
event = BrainEvent(
|
||||
event_type=EventType(data.get("event_type", "file_changed")),
|
||||
conversation_id=data.get("conversation_id", ""),
|
||||
file_name=data.get("file_name", ""),
|
||||
file_path=None,
|
||||
content=data.get("content", ""),
|
||||
timestamp=data.get("timestamp", time.time()),
|
||||
)
|
||||
await self.event_queue.put(event)
|
||||
except Exception as e:
|
||||
logger.error(f"[HUB-EVENT] Error: {e}")
|
||||
|
||||
async def _handle_auto_resolved(self, request_id: str, status: str):
|
||||
"""Edit Discord message to show auto-resolved/expired status."""
|
||||
msg_id = self._approval_messages.get(request_id)
|
||||
if not msg_id:
|
||||
return
|
||||
# Find the channel containing this message
|
||||
for channel in self.project_channels.values():
|
||||
try:
|
||||
msg = await channel.fetch_message(msg_id)
|
||||
embed = msg.embeds[0] if msg.embeds else None
|
||||
if embed:
|
||||
if status == "auto_resolved":
|
||||
embed.color = discord.Color.green()
|
||||
embed.set_footer(text="✅ 자동 해결됨")
|
||||
else:
|
||||
embed.color = discord.Color.greyple()
|
||||
embed.set_footer(text="⏰ 만료됨")
|
||||
await msg.edit(embed=embed, view=None)
|
||||
self._approval_messages.pop(request_id, None)
|
||||
break
|
||||
except (discord.NotFound, discord.Forbidden):
|
||||
continue
|
||||
except Exception:
|
||||
break
|
||||
|
||||
# ─── Chat Snapshot Scanner ─────────────────────────────────────────
|
||||
|
||||
@tasks.loop(seconds=5)
|
||||
|
||||
@@ -48,6 +48,10 @@ class Config:
|
||||
REMOTE_BRIDGE_URL: str = os.getenv("REMOTE_BRIDGE_URL", "")
|
||||
GATEWAY_API_KEY: str = os.getenv("GATEWAY_API_KEY", "")
|
||||
|
||||
# WebSocket Hub
|
||||
GRAVITY_HUB_SECRET: str = os.getenv("GRAVITY_HUB_SECRET", "") # JWT signing secret
|
||||
GRAVITY_REGISTRATION_CODE: str = os.getenv("GRAVITY_REGISTRATION_CODE", "") # Extension auth
|
||||
|
||||
@classmethod
|
||||
def validate(cls) -> list[str]:
|
||||
"""Return list of configuration errors."""
|
||||
|
||||
5
docs/devlog/2026-03-17.md
Normal file
5
docs/devlog/2026-03-17.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# Devlog — 2026-03-17
|
||||
|
||||
| # | 시간 | 작업 | 커밋 | 상태 |
|
||||
|---|------|------|------|------|
|
||||
| 009 | 00:00~06:38 | Extension 모듈 분리 + Hub 통합 테스트 + VSIX v0.4.0 빌드 | `TBD` | ✅ |
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -2,7 +2,7 @@
|
||||
"name": "gravity-bridge",
|
||||
"displayName": "Gravity Bridge",
|
||||
"description": "Antigravity ↔ Discord 브리지 연동 확장",
|
||||
"version": "0.3.16",
|
||||
"version": "0.4.0",
|
||||
"publisher": "variet",
|
||||
"engines": {
|
||||
"vscode": "^1.100.0"
|
||||
@@ -68,6 +68,16 @@
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "프로젝트 이름 (기본: git remote 레포명)"
|
||||
},
|
||||
"gravityBridge.hubUrl": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "WebSocket Hub URL (예: wss://your-server.com/ws)"
|
||||
},
|
||||
"gravityBridge.registrationCode": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "Hub 등록 코드 (서버에서 발급)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
698
extension/src/observer-script.ts
Normal file
698
extension/src/observer-script.ts
Normal file
@@ -0,0 +1,698 @@
|
||||
/**
|
||||
* Approval Observer Script — injected into AG's renderer process.
|
||||
*
|
||||
* This is a self-contained JavaScript string template that runs in the
|
||||
* browser context (no Node.js APIs). It scans the DOM for approval buttons,
|
||||
* reports them to the HTTP bridge, and handles trigger clicks.
|
||||
*
|
||||
* Extracted from extension.ts for maintainability.
|
||||
*/
|
||||
|
||||
export function generateApprovalObserverScript(_port: number): string {
|
||||
// Port is hardcoded as fallback, but renderer also reads ag-bridge-ports.json for multi-bridge
|
||||
return `
|
||||
// ── Gravity Bridge v3: Approval Observer (deep DOM traversal — iframes, webviews, shadow DOMs) ──
|
||||
(function(){
|
||||
'use strict';
|
||||
var BASE='',_obs=false,_sent={},_ready=false;
|
||||
var _scanScheduled=false,_lastScanTs=0;
|
||||
var THROTTLE_MS=100;
|
||||
var CLEANUP_MS=300000;
|
||||
var _domDumped=false;
|
||||
|
||||
function log(m){console.log('[GB Observer] '+m);}
|
||||
log('v3 Script loaded — deep DOM traversal enabled');
|
||||
|
||||
// ── Deep DOM Traversal: find buttons across ALL boundaries ──
|
||||
// Searches: main document → iframes (contentDocument) → webview elements → shadow DOMs
|
||||
function deepFindButtons(patterns){
|
||||
var results=[];
|
||||
// 1. Main document buttons
|
||||
collectButtons(document,results,patterns,'main');
|
||||
// 2. Iframe traversal (try contentDocument — works if same-origin or webSecurity off)
|
||||
var iframes=document.querySelectorAll('iframe');
|
||||
for(var i=0;i<iframes.length;i++){
|
||||
try{
|
||||
var idoc=iframes[i].contentDocument||iframes[i].contentWindow.document;
|
||||
if(idoc){collectButtons(idoc,results,patterns,'iframe#'+i+'('+iframes[i].className.substring(0,30)+')');}
|
||||
}catch(e){
|
||||
// Cross-origin — can't access. Log only on first dom dump
|
||||
if(!_domDumped)log('iframe#'+i+' cross-origin: '+e.message.substring(0,60));
|
||||
}
|
||||
}
|
||||
// 3. Webview elements (Electron <webview> tag — has executeJavaScript)
|
||||
var webviews=document.querySelectorAll('webview');
|
||||
for(var w=0;w<webviews.length;w++){
|
||||
try{
|
||||
var wvDoc=webviews[w].contentDocument;
|
||||
if(wvDoc){collectButtons(wvDoc,results,patterns,'webview#'+w);}
|
||||
}catch(e){
|
||||
if(!_domDumped)log('webview#'+w+' access error: '+e.message.substring(0,60));
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
function collectButtons(doc,results,patterns,source){
|
||||
if(!doc||!doc.querySelectorAll)return;
|
||||
var btns=doc.querySelectorAll('button');
|
||||
for(var i=0;i<btns.length;i++){
|
||||
var b=btns[i];
|
||||
if(b.disabled||b.hidden)continue;
|
||||
try{if(!b.offsetParent&&b.style.display!=='fixed')continue;}catch(e){}
|
||||
var txt=(b.textContent||'').trim();
|
||||
if(!txt)continue;
|
||||
for(var p=0;p<patterns.length;p++){
|
||||
if(patterns[p].test(txt)){
|
||||
results.push({btn:b,text:txt,source:source});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// 4. Recurse into shadow DOMs
|
||||
try{
|
||||
var allEls=doc.querySelectorAll('*');
|
||||
for(var j=0;j<allEls.length;j++){
|
||||
var sr=allEls[j].shadowRoot;
|
||||
if(sr)collectButtons(sr,results,patterns,source+'>shadow');
|
||||
}
|
||||
}catch(e){}
|
||||
}
|
||||
|
||||
// ── Deep DOM Inspector (recursive, POSTs results to bridge) ──
|
||||
function runDeepInspect(){
|
||||
var result={timestamp:new Date().toISOString(),windowURL:window.location.href,windowOrigin:window.location.origin,windowProtocol:window.location.protocol,framesCount:window.frames.length,nodes:[]};
|
||||
log('DEEP-INSPECT: starting recursive DOM analysis...');
|
||||
|
||||
function inspectDoc(doc,depth,label){
|
||||
var node={label:label,depth:depth,accessible:true,url:'',buttons:[],roleBtns:[],iframes:[],webviews:[],shadowDOMs:0,totalElements:0};
|
||||
if(!doc){node.accessible=false;node.error='null document';result.nodes.push(node);return;}
|
||||
try{node.url=(doc.URL||doc.documentURI||'unknown').substring(0,200);}catch(e){node.url='blocked';}
|
||||
try{node.title=(doc.title||'').substring(0,100);}catch(e){}
|
||||
try{node.readyState=doc.readyState;}catch(e){}
|
||||
|
||||
// CSP
|
||||
try{
|
||||
var csp=doc.querySelectorAll('meta[http-equiv="Content-Security-Policy"]');
|
||||
if(csp.length>0){node.csp=[];for(var c=0;c<csp.length;c++){node.csp.push((csp[c].content||'').substring(0,200));}}
|
||||
}catch(e){}
|
||||
|
||||
try{
|
||||
var allEls=doc.querySelectorAll('*');
|
||||
node.totalElements=allEls.length;
|
||||
// Buttons
|
||||
var btns=doc.querySelectorAll('button');
|
||||
for(var i=0;i<btns.length;i++){
|
||||
var b=btns[i];
|
||||
var txt=(b.textContent||'').trim().substring(0,80);
|
||||
if(!txt)continue;
|
||||
var cls=(b.className||'').substring(0,60);
|
||||
var disabled=b.disabled;
|
||||
var hidden=b.hidden||false;
|
||||
try{if(!b.offsetParent&&b.style.display!=='fixed')hidden=true;}catch(e){}
|
||||
var aria=b.getAttribute('aria-label')||'';
|
||||
var ttl=b.getAttribute('title')||'';
|
||||
node.buttons.push({text:txt,class:cls,disabled:disabled,hidden:hidden,aria:aria,title:ttl});
|
||||
}
|
||||
// role=button
|
||||
var rbs=doc.querySelectorAll('[role="button"]');
|
||||
for(var r=0;r<rbs.length;r++){
|
||||
if(rbs[r].tagName==='BUTTON')continue;
|
||||
var rtxt=(rbs[r].textContent||'').trim().substring(0,60);
|
||||
node.roleBtns.push({tag:rbs[r].tagName.toLowerCase(),text:rtxt});
|
||||
}
|
||||
// Shadow DOMs
|
||||
for(var s=0;s<allEls.length;s++){
|
||||
var sr=allEls[s].shadowRoot;
|
||||
if(sr){node.shadowDOMs++;inspectDoc(sr,depth+1,'shadow(<'+allEls[s].tagName.toLowerCase()+' class="'+(allEls[s].className||'').substring(0,30)+'">)');}
|
||||
}
|
||||
// Iframes
|
||||
var ifs=doc.querySelectorAll('iframe');
|
||||
for(var fi=0;fi<ifs.length;fi++){
|
||||
var f=ifs[fi];
|
||||
var finfo={index:fi,class:(f.className||'').substring(0,60),src:(f.src||'').substring(0,150),id:f.id||'',sandbox:f.getAttribute('sandbox')||'',allow:f.getAttribute('allow')||'',accessible:false,cwExists:false,cwFrames:0};
|
||||
try{
|
||||
var idoc=f.contentDocument||(f.contentWindow&&f.contentWindow.document);
|
||||
if(idoc){finfo.accessible=true;inspectDoc(idoc,depth+1,'iframe#'+fi+'('+finfo.class.substring(0,30)+')');
|
||||
}else{finfo.error='contentDocument=null';}
|
||||
}catch(e){finfo.error=e.message.substring(0,80);}
|
||||
try{var cw=f.contentWindow;if(cw){finfo.cwExists=true;finfo.cwFrames=cw.length;try{finfo.cwLocation=cw.location.href;}catch(e2){finfo.cwLocation='blocked: '+e2.message.substring(0,40);}}}
|
||||
catch(e){}
|
||||
node.iframes.push(finfo);
|
||||
}
|
||||
// Webviews
|
||||
var wvs=doc.querySelectorAll('webview');
|
||||
for(var wi=0;wi<wvs.length;wi++){
|
||||
var wv=wvs[wi];
|
||||
var winfo={index:wi,src:(wv.src||'').substring(0,150),class:(wv.className||'').substring(0,60),partition:wv.getAttribute('partition')||'',preload:wv.getAttribute('preload')||'',nodeintegration:wv.getAttribute('nodeintegration')||'',webpreferences:wv.getAttribute('webpreferences')||'',hasExecJS:typeof wv.executeJavaScript==='function',contentDocAccessible:false};
|
||||
try{var wdoc=wv.contentDocument;if(wdoc){winfo.contentDocAccessible=true;inspectDoc(wdoc,depth+1,'webview#'+wi+'.contentDocument');}}catch(e){winfo.contentDocError=e.message.substring(0,60);}
|
||||
node.webviews.push(winfo);
|
||||
}
|
||||
}catch(e){node.error=e.message;}
|
||||
result.nodes.push(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
inspectDoc(document,0,'MainDocument');
|
||||
|
||||
// Webview executeJavaScript probe (async)
|
||||
var webviews=document.querySelectorAll('webview');
|
||||
var probesPending=webviews.length;
|
||||
result.webviewProbes=[];
|
||||
if(probesPending===0)postResults();
|
||||
for(var pw=0;pw<webviews.length;pw++){
|
||||
(function(wv,idx){
|
||||
if(typeof wv.executeJavaScript!=='function'){result.webviewProbes.push({index:idx,error:'executeJavaScript not available'});probesPending--;if(probesPending<=0)postResults();return;}
|
||||
try{
|
||||
wv.executeJavaScript('(function(){var btns=document.querySelectorAll("button");var allEls=document.querySelectorAll("*");var ifs=document.querySelectorAll("iframe");var wvs=document.querySelectorAll("webview");var btnArr=[];for(var i=0;i<btns.length;i++){var b=btns[i];var txt=(b.textContent||"").trim();var cls=(b.className||"").substring(0,50);var dis=b.disabled;var hid=b.hidden||!b.offsetParent;btnArr.push({text:txt.substring(0,60),class:cls,disabled:dis,hidden:hid,aria:b.getAttribute("aria-label")||"",title:b.getAttribute("title")||""});}var rbs=document.querySelectorAll("[role=button]");var rbArr=[];for(var j=0;j<rbs.length;j++){if(rbs[j].tagName!=="BUTTON")rbArr.push({tag:rbs[j].tagName.toLowerCase(),text:(rbs[j].textContent||"").trim().substring(0,40)});}var sc=0;for(var k=0;k<allEls.length;k++){if(allEls[k].shadowRoot)sc++;}return JSON.stringify({url:document.URL,title:document.title,totalElements:allEls.length,buttons:btnArr,roleBtns:rbArr,iframes:ifs.length,webviews:wvs.length,shadowDOMs:sc});})()')
|
||||
.then(function(r){
|
||||
try{var d=JSON.parse(r);result.webviewProbes.push({index:idx,success:true,data:d});log('DEEP-INSPECT: webview#'+idx+' probe OK: '+d.buttons.length+' buttons, '+d.totalElements+' elements');}catch(e){result.webviewProbes.push({index:idx,parseError:e.message,raw:r});}
|
||||
probesPending--;if(probesPending<=0)postResults();
|
||||
})
|
||||
.catch(function(e){
|
||||
result.webviewProbes.push({index:idx,execError:e.message});
|
||||
log('DEEP-INSPECT: webview#'+idx+' execJS error: '+e.message);
|
||||
probesPending--;if(probesPending<=0)postResults();
|
||||
});
|
||||
}catch(e){
|
||||
result.webviewProbes.push({index:idx,callError:e.message});
|
||||
probesPending--;if(probesPending<=0)postResults();
|
||||
}
|
||||
})(webviews[pw],pw);
|
||||
}
|
||||
|
||||
function postResults(){
|
||||
var summary='nodes='+result.nodes.length;
|
||||
var totalBtns=0;for(var n=0;n<result.nodes.length;n++)totalBtns+=result.nodes[n].buttons.length;
|
||||
summary+=' totalButtons='+totalBtns+' webviewProbes='+result.webviewProbes.length;
|
||||
log('DEEP-INSPECT complete: '+summary);
|
||||
// Also log buttons from each node
|
||||
for(var n2=0;n2<result.nodes.length;n2++){
|
||||
var nd=result.nodes[n2];
|
||||
if(nd.buttons.length>0){
|
||||
log(' '+nd.label+': '+nd.buttons.length+' buttons');
|
||||
for(var bi=0;bi<Math.min(15,nd.buttons.length);bi++){
|
||||
log(' ['+bi+'] "'+nd.buttons[bi].text+'"'+(nd.buttons[bi].disabled?' DISABLED':'')+(nd.buttons[bi].hidden?' HIDDEN':''));
|
||||
}
|
||||
}
|
||||
}
|
||||
// POST to bridge
|
||||
fetch(BASE+'/deep-inspect-result',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify(result)})
|
||||
.then(function(){log('DEEP-INSPECT results posted to bridge');})
|
||||
.catch(function(e){log('DEEP-INSPECT post error: '+e.message);});
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-dump on startup (3s delay)
|
||||
function dumpDOMStructure(){runDeepInspect();}
|
||||
|
||||
// ── Port Discovery: async fetch()-based (sync XHR blocked in Electron renderer) ──
|
||||
var HARDCODED_PORT=${_port};
|
||||
|
||||
function tryPingAsync(port){
|
||||
return fetch('http://127.0.0.1:'+port+'/ping?t='+Date.now(),{signal:AbortSignal.timeout(2000)})
|
||||
.then(function(r){return r.text();})
|
||||
.then(function(t){return t==='pong';})
|
||||
.catch(function(){return false;});
|
||||
}
|
||||
|
||||
function discoverPort(cb){
|
||||
log('Trying hardcoded port '+HARDCODED_PORT+'...');
|
||||
tryPingAsync(HARDCODED_PORT).then(function(ok){
|
||||
if(ok){log('Port discovered (hardcoded): '+HARDCODED_PORT);cb(HARDCODED_PORT);return;}
|
||||
log('Hardcoded port failed, retrying with backoff...');
|
||||
|
||||
var attempts=0;
|
||||
var timer=setInterval(function(){
|
||||
attempts++;
|
||||
if(attempts>60){clearInterval(timer);log('Port discovery timeout after 2min');return;}
|
||||
tryPingAsync(HARDCODED_PORT).then(function(ok2){
|
||||
if(ok2){clearInterval(timer);log('Port discovered (retry #'+attempts+'): '+HARDCODED_PORT);cb(HARDCODED_PORT);}
|
||||
});
|
||||
},2000);
|
||||
});
|
||||
}
|
||||
|
||||
discoverPort(function(port){
|
||||
BASE='http://127.0.0.1:'+port;
|
||||
fetch(BASE+'/ping').then(function(r){return r.text();}).then(function(t){
|
||||
if(t==='pong'){log('Bridge connected on port '+port);_ready=true;startObserver();setTimeout(dumpDOMStructure,3000);}
|
||||
else log('Bridge ping failed: '+t);
|
||||
}).catch(function(e){log('Bridge unreachable: '+e.message);});
|
||||
});
|
||||
|
||||
// ── Button patterns to detect (order matters: first match wins per scan) ──
|
||||
// ONLY positive triggers should initiate a pending request group.
|
||||
// Negative/secondary buttons (Deny, Reject, Dismiss) will be collected as siblings.
|
||||
var PATS=[
|
||||
{re:/^Run/i, type:'terminal_command'},
|
||||
{re:/^Accept all$/i, type:'diff_review'},
|
||||
{re:/^Accept$/i, type:'agent_step'},
|
||||
{re:/^Allow/i, type:'permission'},
|
||||
{re:/^Approve/i, type:'agent_step'},
|
||||
{re:/^Retry$/i, type:'error_recovery'},
|
||||
];
|
||||
|
||||
// ALL actionable button patterns (for grouping siblings in same container)
|
||||
var ALL_ACTION_RE=[/^Run/i,/^Accept/i,/^Reject/i,/^Allow/i,/^Deny/i,/^Approve/i,/^Cancel$/i,/^Retry$/i,/^Dismiss$/i,/^Stop$/i,/^Decline$/i];
|
||||
|
||||
// Reject button patterns for finding the counterpart
|
||||
var REJECT_RE=[/^reject$/i,/^reject all$/i,/^cancel$/i,/^deny$/i,/^stop$/i,/^decline$/i,/^dismiss$/i];
|
||||
|
||||
// ── Stable button fingerprint (no getBoundingClientRect — scroll-safe) ──
|
||||
function btnId(b,type){
|
||||
// Use: type + button text + parent's first 40 chars of text content
|
||||
var txt=(b.textContent||'').trim();
|
||||
var parent=b.parentElement;
|
||||
var pctx=parent?(parent.textContent||'').substring(0,40).replace(/\\s+/g,' '):'';
|
||||
// Also use DOM position: nth-child among sibling buttons
|
||||
var idx=0;
|
||||
if(parent){
|
||||
var siblings=parent.querySelectorAll('button');
|
||||
for(var i=0;i<siblings.length;i++){if(siblings[i]===b){idx=i;break;}}
|
||||
}
|
||||
return type+'|'+txt+'|'+idx+'|'+pctx.substring(0,20);
|
||||
}
|
||||
|
||||
// ── Context extraction — walk up DOM to find command/code description ──
|
||||
function extractContext(b){
|
||||
// Strategy 1: Look for code/pre/terminal blocks near the button
|
||||
var container=b.closest('[class*="step"]')
|
||||
||b.closest('[class*="action"]')
|
||||
||b.closest('[class*="tool"]')
|
||||
||b.closest('[class*="cascade"]')
|
||||
||b.closest('[class*="message"]');
|
||||
if(!container)container=b.parentElement;
|
||||
if(!container)return '';
|
||||
|
||||
// Look for code blocks
|
||||
var codeEl=container.querySelector('pre,code,[class*="command"],[class*="terminal"],[class*="code-block"]');
|
||||
if(codeEl){
|
||||
var codeText=(codeEl.textContent||'').trim();
|
||||
if(codeText.length>0)return codeText.substring(0,500);
|
||||
}
|
||||
|
||||
// Strategy 2: Get surrounding text (exclude button text itself)
|
||||
var full=(container.textContent||'');
|
||||
var btnText=(b.textContent||'');
|
||||
var desc=full.replace(btnText,'').trim();
|
||||
// Trim to reasonable length
|
||||
return desc.substring(0,500);
|
||||
}
|
||||
|
||||
// ── Find common container of related buttons ──
|
||||
function findButtonContainer(btn){
|
||||
return btn.closest('[class*="step"]')
|
||||
||btn.closest('[class*="action"]')
|
||||
||btn.closest('[class*="tool"]')
|
||||
||btn.closest('[class*="cascade"]')
|
||||
||btn.closest('[class*="message"]')
|
||||
||btn.closest('[class*="dialog"]')
|
||||
||btn.closest('[class*="notification"]')
|
||||
||btn.parentElement;
|
||||
}
|
||||
|
||||
// ── Collect all actionable sibling buttons from a container ──
|
||||
function collectSiblingButtons(container,triggerBtn){
|
||||
if(!container)return [];
|
||||
var siblings=container.querySelectorAll('button');
|
||||
var result=[];
|
||||
for(var i=0;i<siblings.length;i++){
|
||||
var sb=siblings[i];
|
||||
if(sb.disabled||sb.hidden)continue;
|
||||
try{if(!sb.offsetParent&&sb.style.display!=='fixed')continue;}catch(e){}
|
||||
var stxt=(sb.textContent||'').trim();
|
||||
stxt=stxt.replace(/(Alt|Ctrl|Shift|Meta)\+.*/,'').trim();
|
||||
if(!stxt)continue;
|
||||
// Check if this button matches any actionable pattern
|
||||
var isAction=false;
|
||||
for(var a=0;a<ALL_ACTION_RE.length;a++){
|
||||
if(ALL_ACTION_RE[a].test(stxt)){isAction=true;break;}
|
||||
}
|
||||
if(!isAction)continue;
|
||||
result.push({btn:sb,text:stxt,isPrimary:(sb===triggerBtn)});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// ── Find the React app container (Antigravity's main UI root) ──
|
||||
function findPanel(){
|
||||
// Priority order of panel selectors (most specific first)
|
||||
var selectors=[
|
||||
'.antigravity-agent-side-panel',
|
||||
'#jetski-agent-panel',
|
||||
'.react-app-container',
|
||||
'[class*="agent-panel"]',
|
||||
'[class*="agentPanel"]',
|
||||
];
|
||||
for(var i=0;i<selectors.length;i++){
|
||||
var el=document.querySelector(selectors[i]);
|
||||
if(el)return el;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// ── Core scan — finds actionable buttons and reports to bridge ──
|
||||
// Groups related buttons from same container into a single pending
|
||||
function scan(){
|
||||
if(!_ready)return;
|
||||
var now=Date.now();
|
||||
|
||||
var panel=findPanel();
|
||||
// Expand search: panel-scoped first, then full body for review bars
|
||||
var searchRoots=[];
|
||||
if(panel)searchRoots.push(panel);
|
||||
// Always also scan body for diff review bar (Accept all/Reject all)
|
||||
// which lives outside the agent panel in the editor notification area
|
||||
if(document.body)searchRoots.push(document.body);
|
||||
if(!searchRoots.length)return;
|
||||
|
||||
var seen={}; // dedupe buttons across search roots
|
||||
for(var r=0;r<searchRoots.length;r++){
|
||||
var allBtns=searchRoots[r].querySelectorAll('button');
|
||||
if(!allBtns.length)continue;
|
||||
|
||||
for(var j=0;j<allBtns.length;j++){
|
||||
var b=allBtns[j];
|
||||
if(b.disabled||b.hidden)continue;
|
||||
// Check visibility (offsetParent null = hidden via CSS)
|
||||
if(!b.offsetParent&&b.style.display!=='fixed')continue;
|
||||
|
||||
var txt=(b.textContent||'').trim();
|
||||
if(!txt)continue;
|
||||
// Strip keyboard shortcut suffixes (e.g. "RunAlt+↵" → "Run")
|
||||
txt=txt.replace(/(Alt|Ctrl|Shift|Meta)\+.*/,'').trim();
|
||||
if(!txt)continue;
|
||||
|
||||
// Match against patterns
|
||||
var matchedType=null;
|
||||
for(var p=0;p<PATS.length;p++){
|
||||
if(PATS[p].re.test(txt)){matchedType=PATS[p].type;break;}
|
||||
}
|
||||
if(!matchedType)continue;
|
||||
|
||||
// Generate stable ID for the GROUP (use container-based key)
|
||||
var container=findButtonContainer(b);
|
||||
var groupKey=matchedType+'|group|'+(container?(container.textContent||'').substring(0,40).replace(/\\s+/g,' '):'none');
|
||||
if(_sent[groupKey])continue;
|
||||
|
||||
// Collect ALL related buttons from the same container
|
||||
var siblings=collectSiblingButtons(container,b);
|
||||
if(siblings.length===0)siblings=[{btn:b,text:txt,isPrimary:true}];
|
||||
|
||||
// Build buttons array for multi-choice support
|
||||
var buttonsArr=[];
|
||||
var btnRefs=[];
|
||||
var bidList=[];
|
||||
for(var si=0;si<siblings.length;si++){
|
||||
var sb=siblings[si];
|
||||
var sbid=btnId(sb.btn,matchedType);
|
||||
buttonsArr.push({text:sb.text,index:si,is_primary:sb.isPrimary});
|
||||
btnRefs.push(sb.btn);
|
||||
bidList.push(sbid);
|
||||
}
|
||||
|
||||
// Extract context from trigger button
|
||||
var desc=extractContext(b);
|
||||
var rid=now.toString()+'_'+Math.random().toString(36).substring(2,6);
|
||||
|
||||
// Mark entire group as sent
|
||||
_sent[groupKey]={rid:rid,ts:now};
|
||||
for(var mk=0;mk<bidList.length;mk++){_sent[bidList[mk]]={rid:rid,ts:now};}
|
||||
|
||||
log('DETECTED GROUP '+matchedType+': '+buttonsArr.map(function(x){return '"'+x.text+'"';}).join(', ')+' → pending to bridge');
|
||||
|
||||
// Send to bridge (closure to capture refs)
|
||||
(function(rid2,btnRefs2,bidList2,groupKey2,txt2,desc2,type2,buttonsArr2){
|
||||
var payload={
|
||||
request_id:rid2,
|
||||
command:txt2,
|
||||
description:desc2,
|
||||
step_type:type2,
|
||||
buttons:buttonsArr2
|
||||
};
|
||||
fetch(BASE+'/pending',{
|
||||
method:'POST',
|
||||
headers:{'Content-Type':'application/json'},
|
||||
body:JSON.stringify(payload)
|
||||
}).then(function(r){return r.json();}).then(function(d){
|
||||
log('Pending created: '+d.request_id+' for group ['+buttonsArr2.map(function(x){return x.text;}).join(', ')+']');
|
||||
pollResponseGroup(d.request_id,btnRefs2,bidList2,groupKey2);
|
||||
}).catch(function(e){
|
||||
log('POST error: '+e.message);
|
||||
delete _sent[groupKey2];
|
||||
for(var di=0;di<bidList2.length;di++){delete _sent[bidList2[di]];}
|
||||
});
|
||||
})(rid,btnRefs,bidList,groupKey,txt,desc,matchedType,buttonsArr);
|
||||
|
||||
// Process ONE button GROUP per scan cycle (avoid flooding)
|
||||
return;
|
||||
}
|
||||
} // end searchRoots loop
|
||||
}
|
||||
|
||||
// ── Poll for Discord response (multi-button group aware) ──
|
||||
function pollResponseGroup(rid,btnRefs,bidList,groupKey){
|
||||
var polls=0;
|
||||
var maxPolls=200; // 5 minutes at 1500ms interval
|
||||
var timer=setInterval(function(){
|
||||
polls++;
|
||||
// Check if ANY button in the group is still in DOM
|
||||
var anyAlive=false;
|
||||
for(var ai=0;ai<btnRefs.length;ai++){
|
||||
if(document.body.contains(btnRefs[ai])){anyAlive=true;break;}
|
||||
}
|
||||
if(!anyAlive){
|
||||
log('All buttons removed from DOM — stopping poll for '+rid);
|
||||
clearInterval(timer);
|
||||
delete _sent[groupKey];
|
||||
for(var ci=0;ci<bidList.length;ci++){delete _sent[bidList[ci]];}
|
||||
return;
|
||||
}
|
||||
if(polls>maxPolls){
|
||||
log('Poll timeout for '+rid);
|
||||
clearInterval(timer);
|
||||
delete _sent[groupKey];
|
||||
for(var ti=0;ti<bidList.length;ti++){delete _sent[bidList[ti]];}
|
||||
return;
|
||||
}
|
||||
fetch(BASE+'/response/'+rid).then(function(r){return r.json();}).then(function(d){
|
||||
if(d.waiting)return;
|
||||
clearInterval(timer);
|
||||
var btnIdx=(typeof d.button_index==='number'&&d.button_index>=0)?d.button_index:-1;
|
||||
if(btnIdx>=0&&btnIdx<btnRefs.length){
|
||||
// Multi-choice: click specific button by index
|
||||
var targetBtn=btnRefs[btnIdx];
|
||||
var targetTxt=(targetBtn.textContent||'').trim();
|
||||
log((d.approved?'✅':'❌')+' CHOICE '+rid+' → clicking button['+btnIdx+'] "'+targetTxt+'"');
|
||||
targetBtn.click();
|
||||
} else if(d.approved){
|
||||
// Legacy single-button: click first (primary) button
|
||||
var primaryBtn=btnRefs[0];
|
||||
log('✅ APPROVED '+rid+' → clicking "'+((primaryBtn.textContent||'').trim())+'"');
|
||||
primaryBtn.click();
|
||||
} else {
|
||||
// Legacy reject: find and click reject/deny button
|
||||
log('❌ REJECTED '+rid+' → finding reject button');
|
||||
clickRejectButton(btnRefs[0]);
|
||||
}
|
||||
delete _sent[groupKey];
|
||||
for(var ri=0;ri<bidList.length;ri++){delete _sent[bidList[ri]];}
|
||||
}).catch(function(){});
|
||||
},1500);
|
||||
}
|
||||
|
||||
// Legacy pollResponse for backward compatibility (single button)
|
||||
function pollResponse(rid,btn,bid){
|
||||
pollResponseGroup(rid,[btn],[bid],bid);
|
||||
}
|
||||
|
||||
// ── Find and click the reject/cancel counterpart button ──
|
||||
function clickRejectButton(approveBtn){
|
||||
// Walk up to find the container, then search for reject buttons
|
||||
var container=approveBtn.closest('[class*="step"]')
|
||||
||approveBtn.closest('[class*="action"]')
|
||||
||approveBtn.closest('[class*="tool"]')
|
||||
||approveBtn.parentElement;
|
||||
if(!container){log('No container for reject');return;}
|
||||
|
||||
var siblings=container.querySelectorAll('button');
|
||||
for(var i=0;i<siblings.length;i++){
|
||||
var t=(siblings[i].textContent||'').trim();
|
||||
for(var r=0;r<REJECT_RE.length;r++){
|
||||
if(REJECT_RE[r].test(t)){
|
||||
log('Clicking reject: "'+t+'"');
|
||||
siblings[i].click();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
log('No reject button found near approve button');
|
||||
}
|
||||
|
||||
// ── Throttled scan — leading-edge: fires immediately, then locks ──
|
||||
function scheduleScan(){
|
||||
if(!_ready)return;
|
||||
var now=Date.now();
|
||||
if(now-_lastScanTs>=THROTTLE_MS){
|
||||
_lastScanTs=now;
|
||||
scan();
|
||||
} else if(!_scanScheduled){
|
||||
_scanScheduled=true;
|
||||
setTimeout(function(){
|
||||
_scanScheduled=false;
|
||||
_lastScanTs=Date.now();
|
||||
scan();
|
||||
},THROTTLE_MS-(now-_lastScanTs));
|
||||
}
|
||||
}
|
||||
|
||||
// ── Periodic cleanup of stale _sent entries ──
|
||||
setInterval(function(){
|
||||
var now=Date.now();
|
||||
var keys=Object.keys(_sent);
|
||||
for(var i=0;i<keys.length;i++){
|
||||
var entry=_sent[keys[i]];
|
||||
if(entry&&entry.ts&&(now-entry.ts)>CLEANUP_MS){
|
||||
log('Cleanup stale entry: '+keys[i]);
|
||||
delete _sent[keys[i]];
|
||||
}
|
||||
}
|
||||
},60000);
|
||||
|
||||
// ── Start observation ──
|
||||
function startObserver(){
|
||||
if(_obs)return;
|
||||
// PRIMARY: MutationObserver — reacts instantly to DOM changes
|
||||
new MutationObserver(function(mutations){
|
||||
// Only scan if mutations contain added nodes (new buttons potentially)
|
||||
for(var i=0;i<mutations.length;i++){
|
||||
if(mutations[i].addedNodes.length>0){
|
||||
scheduleScan();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}).observe(document.body,{childList:true,subtree:true});
|
||||
|
||||
// FALLBACK: periodic scan every 3s for any missed mutations
|
||||
setInterval(scheduleScan,3000);
|
||||
|
||||
// ── Adaptive idle detection for HTTP polls ──
|
||||
var _lastActivity=Date.now();
|
||||
var _idleThreshold=60000; // 60s without DOM changes → slow mode
|
||||
new MutationObserver(function(){_lastActivity=Date.now();}).observe(document.body,{childList:true,subtree:true,attributes:true});
|
||||
function getAdaptiveInterval(){return (Date.now()-_lastActivity>_idleThreshold)?10000:2000;}
|
||||
|
||||
// ── DEEP-INSPECT POLLING: curl→Bridge→Renderer→Results ──
|
||||
(function pollDeepInspect(){
|
||||
if(_ready&&BASE){
|
||||
fetch(BASE+'/deep-inspect-trigger?t='+Date.now()).then(function(r){return r.json();}).then(function(d){
|
||||
if(d.inspect){log('🔍 Deep inspect triggered via HTTP');runDeepInspect();}
|
||||
}).catch(function(){});
|
||||
}
|
||||
setTimeout(pollDeepInspect,getAdaptiveInterval());
|
||||
})();
|
||||
|
||||
// ── TRIGGER-CLICK: Extension→Renderer bridge for programmatic button clicks ──
|
||||
// Extension sets clickTrigger via tryApprovalStrategies → renderer polls and clicks
|
||||
// v3: uses deepFindButtons() to traverse iframes, webviews, shadow DOMs
|
||||
(function pollTriggerClick(){
|
||||
if(_ready&&BASE){
|
||||
fetch(BASE+'/trigger-click?t='+Date.now()).then(function(r){return r.json();}).then(function(d){
|
||||
if(!d.action)return;
|
||||
log('🔔 TRIGGER-CLICK received: action='+d.action);
|
||||
|
||||
var approveRe=[/^Run$/i,/^Run /i,/^Accept/i,/^Allow/i,/^Approve/i,/^Continue$/i,/^Proceed$/i,/^Retry$/i];
|
||||
var rejectRe=[/^Reject/i,/^Cancel$/i,/^Deny$/i,/^Stop$/i,/^Decline$/i,/^Dismiss$/i];
|
||||
var patterns=(d.action==='approve')?approveRe:rejectRe;
|
||||
var emoji=(d.action==='approve')?'✅':'❌';
|
||||
|
||||
// Phase 1: deepFindButtons in main doc + accessible iframes + shadow DOMs
|
||||
var found=deepFindButtons(patterns);
|
||||
if(found.length>0){
|
||||
log(emoji+' TRIGGER-CLICK: clicking "'+found[0].text+'" from '+found[0].source);
|
||||
found[0].btn.click();
|
||||
return;
|
||||
}
|
||||
|
||||
// Phase 2: Try <webview>.executeJavaScript for inaccessible webviews
|
||||
var webviews=document.querySelectorAll('webview');
|
||||
if(webviews.length>0){
|
||||
log('TRIGGER-CLICK: trying '+webviews.length+' webview(s) via executeJavaScript...');
|
||||
var patternsStr=patterns.map(function(re){return re.source;}).join('|');
|
||||
var clickScript='(function(){'+
|
||||
'var re=new RegExp("'+patternsStr+'","i");'+
|
||||
'var btns=document.querySelectorAll("button");'+
|
||||
'for(var i=0;i<btns.length;i++){'+
|
||||
'var b=btns[i];if(b.disabled||b.hidden)continue;'+
|
||||
'var t=(b.textContent||"").trim();'+
|
||||
'if(re.test(t)){b.click();return "CLICKED:"+t;}'+
|
||||
'}'+
|
||||
'return "NOT_FOUND:"+btns.length+"_buttons";'+
|
||||
'})()';
|
||||
for(var w=0;w<webviews.length;w++){
|
||||
(function(wv,idx){
|
||||
try{
|
||||
if(typeof wv.executeJavaScript==='function'){
|
||||
wv.executeJavaScript(clickScript).then(function(result){
|
||||
log(emoji+' TRIGGER-CLICK webview#'+idx+': '+result);
|
||||
}).catch(function(e){
|
||||
log('TRIGGER-CLICK webview#'+idx+' execJS error: '+e.message);
|
||||
});
|
||||
}
|
||||
}catch(e){
|
||||
log('TRIGGER-CLICK webview#'+idx+' error: '+e.message);
|
||||
}
|
||||
})(webviews[w],w);
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Try iframes via postMessage (cross-origin fallback)
|
||||
var iframes=document.querySelectorAll('iframe');
|
||||
if(iframes.length>0){
|
||||
log('TRIGGER-CLICK: trying '+iframes.length+' iframe(s) — checking accessibility...');
|
||||
var clickedAny=false;
|
||||
for(var fi=0;fi<iframes.length;fi++){
|
||||
try{
|
||||
var idoc=iframes[fi].contentDocument||iframes[fi].contentWindow.document;
|
||||
if(!idoc)continue;
|
||||
var ibtns=idoc.querySelectorAll('button');
|
||||
for(var bi=0;bi<ibtns.length;bi++){
|
||||
var ib=ibtns[bi];
|
||||
if(ib.disabled||ib.hidden)continue;
|
||||
var itxt=(ib.textContent||'').trim();
|
||||
for(var pi=0;pi<patterns.length;pi++){
|
||||
if(patterns[pi].test(itxt)){
|
||||
log(emoji+' TRIGGER-CLICK iframe#'+fi+': clicking "'+itxt+'"');
|
||||
ib.click();
|
||||
clickedAny=true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}catch(e){}
|
||||
}
|
||||
}
|
||||
|
||||
if(!found.length){
|
||||
// Log what we DID find for debugging
|
||||
var allBtns=document.querySelectorAll('button');
|
||||
var btnTexts=[];
|
||||
for(var di=0;di<Math.min(10,allBtns.length);di++){
|
||||
btnTexts.push('"'+((allBtns[di].textContent||'').trim()).substring(0,30)+'"');
|
||||
}
|
||||
log('⚠️ TRIGGER-CLICK: no '+d.action+' button found. Main DOM has '+allBtns.length+' btns: ['+btnTexts.join(',')+']');
|
||||
log('⚠️ iframes='+document.querySelectorAll('iframe').length+' webviews='+document.querySelectorAll('webview').length);
|
||||
}
|
||||
}).catch(function(){});
|
||||
}
|
||||
setTimeout(pollTriggerClick,getAdaptiveInterval());
|
||||
})();
|
||||
|
||||
_obs=true;
|
||||
log('v3 Observer active — deep DOM traversal + MutationObserver + trigger-click polling');
|
||||
}
|
||||
})();
|
||||
`;
|
||||
}
|
||||
|
||||
1435
extension/src/step-probe.ts
Normal file
1435
extension/src/step-probe.ts
Normal file
File diff suppressed because it is too large
Load Diff
114
extension/src/step-utils.ts
Normal file
114
extension/src/step-utils.ts
Normal file
@@ -0,0 +1,114 @@
|
||||
/**
|
||||
* Step Utilities — pure functions for parsing step data, planner responses,
|
||||
* and tool call information. No external state dependencies.
|
||||
*
|
||||
* Extracted from extension.ts for maintainability.
|
||||
*/
|
||||
|
||||
export function extractPlannerText(step: any): string | null {
|
||||
if (!step) { return null; }
|
||||
|
||||
// Fields to SKIP — not user-facing content
|
||||
const SKIP_FIELDS = new Set([
|
||||
'thinking', 'thinkingSignature', 'stopReason', 'type', 'status', 'metadata',
|
||||
'ephemeralMessage', 'generatorModel', 'requestedModel',
|
||||
'executionId', 'sourceTrajectoryStepInfo', 'stepIndex',
|
||||
'viewableAt', 'createdAt', 'finishedGeneratingAt',
|
||||
'lastCompletedChunkAt', 'source', 'stepGenerationVersion'
|
||||
]);
|
||||
|
||||
// plannerResponse can be string or object
|
||||
const pr = step.plannerResponse;
|
||||
if (typeof pr === 'string' && pr.length > 10) {
|
||||
return filterEphemeral(pr);
|
||||
}
|
||||
if (pr && typeof pr === 'object') {
|
||||
// Try known content fields first (NOT thinking/stopReason)
|
||||
const text = pr.content || pr.text || pr.summary || pr.message || pr.response || pr.output;
|
||||
if (typeof text === 'string' && text.length > 10) {
|
||||
return filterEphemeral(text);
|
||||
}
|
||||
// Search other fields, but skip non-content ones
|
||||
for (const key of Object.keys(pr)) {
|
||||
if (SKIP_FIELDS.has(key)) continue;
|
||||
const val = pr[key];
|
||||
if (typeof val === 'string' && val.length > 50) { // Higher threshold
|
||||
const filtered = filterEphemeral(val);
|
||||
if (filtered) {
|
||||
console.log(`Gravity Bridge: [DEBUG] planner text in plannerResponse.${key} (${filtered.length} chars)`);
|
||||
return filtered;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try other step fields (skip known non-content)
|
||||
for (const key of Object.keys(step)) {
|
||||
if (SKIP_FIELDS.has(key) || key === 'plannerResponse') continue;
|
||||
const val = step[key];
|
||||
if (typeof val === 'string' && val.length > 50) {
|
||||
const filtered = filterEphemeral(val);
|
||||
if (filtered) {
|
||||
console.log(`Gravity Bridge: [DEBUG] planner text in step.${key}`);
|
||||
return filtered;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/** Filter out system ephemeral messages and non-content strings. */
|
||||
export function filterEphemeral(text: string): string | null {
|
||||
if (!text || text.length < 10) { return null; }
|
||||
// Skip system prompt metadata
|
||||
if (text.includes('<EPHEMERAL_MESSAGE>') || text.includes('<ephemeral_message>')) { return null; }
|
||||
if (text.includes('artifact_reminder') || text.includes('active_task_reminder')) { return null; }
|
||||
if (text.includes('no_active_task_reminder')) { return null; }
|
||||
// Skip base64/crypto strings (no spaces, mostly alphanumeric)
|
||||
if (!text.includes(' ') && /^[A-Za-z0-9+/=_\-]{50,}$/.test(text)) { return null; }
|
||||
return text;
|
||||
}
|
||||
|
||||
/** Extract human-readable command from a tool call step's data. */
|
||||
export function extractToolCommand(stepData: any): string {
|
||||
// Try common step data shapes from protobuf
|
||||
if (stepData.runCommand) {
|
||||
return stepData.runCommand.commandLine || stepData.runCommand.command || 'Run Command';
|
||||
}
|
||||
if (stepData.writeToFile) {
|
||||
const target = stepData.writeToFile.targetFile || stepData.writeToFile.filePath || 'file';
|
||||
return `Write: ${target.split(/[\\/]/).pop()}`;
|
||||
}
|
||||
if (stepData.codeAction) {
|
||||
const fp = stepData.codeAction.filePath || '';
|
||||
return `Edit: ${fp.split(/[\\/]/).pop() || 'file'}`;
|
||||
}
|
||||
if (stepData.replaceFileContent || stepData.multiReplaceFileContent) {
|
||||
const d = stepData.replaceFileContent || stepData.multiReplaceFileContent;
|
||||
const fp = d.targetFile || d.filePath || '';
|
||||
return `Edit: ${fp.split(/[\\/]/).pop() || 'file'}`;
|
||||
}
|
||||
if (stepData.sendCommandInput) {
|
||||
return `Send Input: ${(stepData.sendCommandInput.input || '').substring(0, 50)}`;
|
||||
}
|
||||
// Generic fallback: use first key name
|
||||
const keys = Object.keys(stepData).filter(k => k !== 'status' && k !== 'stepStatus');
|
||||
return keys.length > 0 ? keys[0] : 'Unknown tool call';
|
||||
}
|
||||
|
||||
/** Extract description from a tool call step for Discord display. */
|
||||
export function extractToolDescription(stepData: any, sessionTitle: string, stepIndex: number): string {
|
||||
const parts = [`Step #${stepIndex}`, `Session: "${sessionTitle}"`];
|
||||
// Try to get code/command content for context
|
||||
if (stepData.runCommand) {
|
||||
const cmd = stepData.runCommand.commandLine || stepData.runCommand.command || '';
|
||||
if (cmd) parts.push(`Command: ${cmd.substring(0, 200)}`);
|
||||
}
|
||||
if (stepData.writeToFile?.targetFile) {
|
||||
parts.push(`File: ${stepData.writeToFile.targetFile}`);
|
||||
}
|
||||
if (stepData.codeAction?.filePath) {
|
||||
parts.push(`File: ${stepData.codeAction.filePath}`);
|
||||
}
|
||||
return parts.join('\n');
|
||||
}
|
||||
505
extension/src/ws-client.ts
Normal file
505
extension/src/ws-client.ts
Normal file
@@ -0,0 +1,505 @@
|
||||
/**
|
||||
* WebSocket Bridge Client — connects Extension to the Hub server.
|
||||
*
|
||||
* Replaces file-based IPC for:
|
||||
* - Pending approvals (Extension → Hub → Bot → Discord)
|
||||
* - User responses (Discord → Bot → Hub → Extension)
|
||||
* - Chat snapshots (Extension → Hub → Bot → Discord)
|
||||
* - Commands (Discord → Bot → Hub → Extension)
|
||||
* - Session registration
|
||||
* - Auto-resolve notifications
|
||||
*
|
||||
* Features:
|
||||
* - Exponential backoff + jitter reconnection
|
||||
* - Message queue (survives reconnection)
|
||||
* - Heartbeat ping/pong
|
||||
* - First-message JWT authentication
|
||||
*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
|
||||
// ─── Types ───
|
||||
|
||||
export interface WSMessage {
|
||||
type: string;
|
||||
data?: any;
|
||||
msg_id?: string;
|
||||
}
|
||||
|
||||
export interface WSAuthMessage {
|
||||
type: 'auth';
|
||||
token?: string;
|
||||
registration_code?: string;
|
||||
project: string;
|
||||
pc: string;
|
||||
}
|
||||
|
||||
export interface WSAuthOkResponse {
|
||||
type: 'auth_ok';
|
||||
conn_id: string;
|
||||
instance_number: number;
|
||||
session_token: string;
|
||||
active_count: number;
|
||||
}
|
||||
|
||||
export interface WSPendingData {
|
||||
request_id: string;
|
||||
command: string;
|
||||
description?: string;
|
||||
step_type?: string;
|
||||
status?: string;
|
||||
buttons?: Array<{ text: string; index: number }>;
|
||||
project_name?: string;
|
||||
// diff_review metadata
|
||||
edit_step_indices?: number[];
|
||||
modified_files?: string[];
|
||||
}
|
||||
|
||||
export interface WSResponseData {
|
||||
request_id: string;
|
||||
approved: boolean;
|
||||
button_index?: number;
|
||||
step_type?: string;
|
||||
project_name?: string;
|
||||
}
|
||||
|
||||
export interface WSCommandData {
|
||||
text: string;
|
||||
project_name?: string;
|
||||
action?: string;
|
||||
}
|
||||
|
||||
export interface WSChatData {
|
||||
content: string;
|
||||
attached_files?: Array<{ name: string; content: string }>;
|
||||
conversation_id?: string;
|
||||
project_name?: string;
|
||||
}
|
||||
|
||||
export interface WSRegisterData {
|
||||
conversation_id: string;
|
||||
project_name: string;
|
||||
}
|
||||
|
||||
// ─── Event Handlers ───
|
||||
|
||||
export interface WSBridgeHandlers {
|
||||
onResponse?: (data: WSResponseData) => void;
|
||||
onCommand?: (data: WSCommandData) => void;
|
||||
onInstanceUpdate?: (activeCount: number, instances: Array<{ instance_number: number; pc: string }>) => void;
|
||||
onConnected?: (connId: string, instanceNumber: number, sessionToken: string) => void;
|
||||
onDisconnected?: () => void;
|
||||
onError?: (error: string) => void;
|
||||
}
|
||||
|
||||
// ─── Constants ───
|
||||
|
||||
const INITIAL_RECONNECT_DELAY = 1000; // 1s
|
||||
const MAX_RECONNECT_DELAY = 60000; // 60s
|
||||
const RECONNECT_JITTER = 0.3; // ±30%
|
||||
const HEARTBEAT_INTERVAL = 25000; // 25s (server expects 30s)
|
||||
const MAX_QUEUE_SIZE = 200;
|
||||
const AUTH_TIMEOUT = 10000; // 10s
|
||||
|
||||
// ─── WSBridgeClient ───
|
||||
|
||||
export class WSBridgeClient {
|
||||
private ws: any = null; // WebSocket instance (Node.js ws module)
|
||||
private hubUrl: string;
|
||||
private registrationCode: string;
|
||||
private project: string;
|
||||
private pcName: string;
|
||||
private handlers: WSBridgeHandlers;
|
||||
private logFn: (msg: string) => void;
|
||||
|
||||
// Connection state
|
||||
private connected = false;
|
||||
private authenticated = false;
|
||||
private connId = '';
|
||||
private instanceNumber = 0;
|
||||
private sessionToken = '';
|
||||
private shouldReconnect = true;
|
||||
private reconnectDelay = INITIAL_RECONNECT_DELAY;
|
||||
private reconnectTimer: NodeJS.Timeout | null = null;
|
||||
private heartbeatTimer: NodeJS.Timeout | null = null;
|
||||
private authTimer: NodeJS.Timeout | null = null;
|
||||
|
||||
// Message queue (survives reconnection)
|
||||
private messageQueue: WSMessage[] = [];
|
||||
private msgIdCounter = 0;
|
||||
|
||||
constructor(
|
||||
hubUrl: string,
|
||||
registrationCode: string,
|
||||
project: string,
|
||||
pcName: string,
|
||||
handlers: WSBridgeHandlers,
|
||||
logFn: (msg: string) => void,
|
||||
) {
|
||||
this.hubUrl = hubUrl;
|
||||
this.registrationCode = registrationCode;
|
||||
this.project = project;
|
||||
this.pcName = pcName;
|
||||
this.handlers = handlers;
|
||||
this.logFn = logFn;
|
||||
}
|
||||
|
||||
// ─── Public API ───
|
||||
|
||||
/** Start the WebSocket connection. */
|
||||
async connect(): Promise<void> {
|
||||
if (!this.hubUrl) {
|
||||
this.logFn('[WS] No hub URL configured — WS disabled');
|
||||
return;
|
||||
}
|
||||
this.shouldReconnect = true;
|
||||
await this._connect();
|
||||
}
|
||||
|
||||
/** Gracefully disconnect. */
|
||||
disconnect(): void {
|
||||
this.shouldReconnect = false;
|
||||
this._cleanup();
|
||||
this.logFn('[WS] Disconnected (intentional)');
|
||||
}
|
||||
|
||||
/** Check if connected and authenticated. */
|
||||
isConnected(): boolean {
|
||||
return this.connected && this.authenticated;
|
||||
}
|
||||
|
||||
/** Get the instance number assigned by the Hub. */
|
||||
getInstanceNumber(): number {
|
||||
return this.instanceNumber;
|
||||
}
|
||||
|
||||
/** Send a pending approval to the Hub. */
|
||||
sendPending(data: WSPendingData): boolean {
|
||||
return this._send({ type: 'pending', data });
|
||||
}
|
||||
|
||||
/** Send a chat snapshot to the Hub. */
|
||||
sendChat(data: WSChatData): boolean {
|
||||
return this._send({ type: 'chat', data });
|
||||
}
|
||||
|
||||
/** Send a session registration. */
|
||||
sendRegister(data: WSRegisterData): boolean {
|
||||
return this._send({ type: 'register', data });
|
||||
}
|
||||
|
||||
/** Send an auto_resolve notification. */
|
||||
sendAutoResolve(requestId: string): boolean {
|
||||
return this._send({ type: 'auto_resolve', data: { request_id: requestId } });
|
||||
}
|
||||
|
||||
/** Send a brain event. */
|
||||
sendBrainEvent(data: any): boolean {
|
||||
return this._send({ type: 'brain_event', data });
|
||||
}
|
||||
|
||||
// ─── Internal Connection ───
|
||||
|
||||
private async _connect(): Promise<void> {
|
||||
try {
|
||||
// Dynamic import of ws module (Node.js built-in or npm package)
|
||||
const WebSocket = await this._getWebSocketClass();
|
||||
if (!WebSocket) {
|
||||
this.logFn('[WS] WebSocket module not available');
|
||||
return;
|
||||
}
|
||||
|
||||
this.logFn(`[WS] Connecting to ${this.hubUrl}...`);
|
||||
const ws = new WebSocket(this.hubUrl);
|
||||
|
||||
ws.on('open', () => {
|
||||
this.logFn('[WS] Connection opened, authenticating...');
|
||||
this.ws = ws;
|
||||
this.connected = true;
|
||||
this._authenticate();
|
||||
});
|
||||
|
||||
ws.on('message', (raw: Buffer | string) => {
|
||||
try {
|
||||
const data = JSON.parse(typeof raw === 'string' ? raw : raw.toString('utf-8'));
|
||||
this._handleMessage(data);
|
||||
} catch (e: any) {
|
||||
this.logFn(`[WS] Parse error: ${e.message}`);
|
||||
}
|
||||
});
|
||||
|
||||
ws.on('close', (code: number, reason: Buffer) => {
|
||||
const reasonStr = reason ? reason.toString('utf-8') : '';
|
||||
this.logFn(`[WS] Connection closed: code=${code} reason=${reasonStr}`);
|
||||
this._onDisconnect();
|
||||
});
|
||||
|
||||
ws.on('error', (err: Error) => {
|
||||
this.logFn(`[WS] Error: ${err.message}`);
|
||||
// close event will follow
|
||||
});
|
||||
|
||||
ws.on('pong', () => {
|
||||
// Server responded to our ping — connection is alive
|
||||
});
|
||||
|
||||
} catch (e: any) {
|
||||
this.logFn(`[WS] Connect failed: ${e.message}`);
|
||||
this._scheduleReconnect();
|
||||
}
|
||||
}
|
||||
|
||||
private async _getWebSocketClass(): Promise<any> {
|
||||
try {
|
||||
// Try Node.js built-in WebSocket (v21+)
|
||||
if (typeof globalThis.WebSocket !== 'undefined') {
|
||||
return globalThis.WebSocket;
|
||||
}
|
||||
// Try require('ws') — should be available in VS Code's Node.js
|
||||
const ws = require('ws');
|
||||
return ws;
|
||||
} catch {
|
||||
// ws module not available
|
||||
try {
|
||||
// Fallback: try the built-in undici WebSocket
|
||||
const { WebSocket } = require('undici');
|
||||
return WebSocket;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Authentication ───
|
||||
|
||||
private _authenticate(): void {
|
||||
if (!this.ws) return;
|
||||
|
||||
const authMsg: WSAuthMessage = {
|
||||
type: 'auth',
|
||||
project: this.project,
|
||||
pc: this.pcName,
|
||||
};
|
||||
|
||||
// Use session token if available (from previous connection)
|
||||
if (this.sessionToken) {
|
||||
authMsg.token = this.sessionToken;
|
||||
} else if (this.registrationCode) {
|
||||
authMsg.registration_code = this.registrationCode;
|
||||
}
|
||||
|
||||
this._sendRaw(authMsg);
|
||||
|
||||
// Timeout for auth response
|
||||
this.authTimer = setTimeout(() => {
|
||||
if (!this.authenticated) {
|
||||
this.logFn('[WS] Auth timeout — closing connection');
|
||||
this._cleanup();
|
||||
this._scheduleReconnect();
|
||||
}
|
||||
}, AUTH_TIMEOUT);
|
||||
}
|
||||
|
||||
// ─── Message Handling ───
|
||||
|
||||
private _handleMessage(msg: WSMessage): void {
|
||||
switch (msg.type) {
|
||||
case 'auth_ok': {
|
||||
const authOk = msg as unknown as WSAuthOkResponse;
|
||||
this.authenticated = true;
|
||||
this.connId = authOk.conn_id;
|
||||
this.instanceNumber = authOk.instance_number;
|
||||
this.sessionToken = authOk.session_token;
|
||||
this.reconnectDelay = INITIAL_RECONNECT_DELAY;
|
||||
|
||||
if (this.authTimer) {
|
||||
clearTimeout(this.authTimer);
|
||||
this.authTimer = null;
|
||||
}
|
||||
|
||||
this.logFn(`[WS] Authenticated: conn=${this.connId} instance=#${this.instanceNumber} active=${authOk.active_count}`);
|
||||
this._startHeartbeat();
|
||||
this._flushQueue();
|
||||
this.handlers.onConnected?.(this.connId, this.instanceNumber, this.sessionToken);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'auth_fail': {
|
||||
const reason = (msg as any).reason || 'Unknown';
|
||||
this.logFn(`[WS] Auth failed: ${reason}`);
|
||||
// Clear session token if it was rejected
|
||||
this.sessionToken = '';
|
||||
this._cleanup();
|
||||
// Don't reconnect on auth failure (needs manual fix)
|
||||
this.handlers.onError?.(`Auth failed: ${reason}`);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'response': {
|
||||
const data = msg.data as WSResponseData;
|
||||
if (data) {
|
||||
this.logFn(`[WS] Response received: ${data.request_id?.substring(0, 12)} approved=${data.approved}`);
|
||||
this.handlers.onResponse?.(data);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 'command': {
|
||||
const data = msg.data as WSCommandData;
|
||||
if (data) {
|
||||
this.logFn(`[WS] Command received: ${data.text?.substring(0, 50)}`);
|
||||
this.handlers.onCommand?.(data);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 'instance_update': {
|
||||
const activeCount = (msg as any).active_count || 0;
|
||||
const instances = (msg as any).instances || [];
|
||||
this.logFn(`[WS] Instance update: ${activeCount} active`);
|
||||
this.handlers.onInstanceUpdate?.(activeCount, instances);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'error': {
|
||||
const error = (msg as any).error || 'Unknown error';
|
||||
this.logFn(`[WS] Server error: ${error}`);
|
||||
this.handlers.onError?.(error);
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
this.logFn(`[WS] Unknown message type: ${msg.type}`);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Send ───
|
||||
|
||||
private _send(msg: WSMessage): boolean {
|
||||
// Add unique message ID for dedup
|
||||
msg.msg_id = `${this.project}-${Date.now()}-${++this.msgIdCounter}`;
|
||||
|
||||
if (this.isConnected()) {
|
||||
return this._sendRaw(msg);
|
||||
}
|
||||
|
||||
// Queue for later
|
||||
if (this.messageQueue.length >= MAX_QUEUE_SIZE) {
|
||||
// Drop oldest
|
||||
this.messageQueue.shift();
|
||||
this.logFn('[WS] Queue full — dropped oldest message');
|
||||
}
|
||||
this.messageQueue.push(msg);
|
||||
this.logFn(`[WS] Queued message (type=${msg.type}, queue=${this.messageQueue.length})`);
|
||||
return false;
|
||||
}
|
||||
|
||||
private _sendRaw(msg: any): boolean {
|
||||
try {
|
||||
if (this.ws && this.connected) {
|
||||
this.ws.send(JSON.stringify(msg));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} catch (e: any) {
|
||||
this.logFn(`[WS] Send error: ${e.message}`);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private _flushQueue(): void {
|
||||
if (this.messageQueue.length === 0) return;
|
||||
this.logFn(`[WS] Flushing ${this.messageQueue.length} queued messages`);
|
||||
const queue = [...this.messageQueue];
|
||||
this.messageQueue = [];
|
||||
for (const msg of queue) {
|
||||
this._sendRaw(msg);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Heartbeat ───
|
||||
|
||||
private _startHeartbeat(): void {
|
||||
this._stopHeartbeat();
|
||||
this.heartbeatTimer = setInterval(() => {
|
||||
if (this.ws && this.connected) {
|
||||
try {
|
||||
this.ws.ping();
|
||||
} catch {
|
||||
// ping failure will trigger close event
|
||||
}
|
||||
}
|
||||
}, HEARTBEAT_INTERVAL);
|
||||
}
|
||||
|
||||
private _stopHeartbeat(): void {
|
||||
if (this.heartbeatTimer) {
|
||||
clearInterval(this.heartbeatTimer);
|
||||
this.heartbeatTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Reconnection ───
|
||||
|
||||
private _onDisconnect(): void {
|
||||
const wasAuthenticated = this.authenticated;
|
||||
this.connected = false;
|
||||
this.authenticated = false;
|
||||
this.ws = null;
|
||||
|
||||
this._stopHeartbeat();
|
||||
if (this.authTimer) {
|
||||
clearTimeout(this.authTimer);
|
||||
this.authTimer = null;
|
||||
}
|
||||
|
||||
if (wasAuthenticated) {
|
||||
this.handlers.onDisconnected?.();
|
||||
}
|
||||
|
||||
if (this.shouldReconnect) {
|
||||
this._scheduleReconnect();
|
||||
}
|
||||
}
|
||||
|
||||
private _scheduleReconnect(): void {
|
||||
if (this.reconnectTimer) return;
|
||||
|
||||
// Exponential backoff with jitter
|
||||
const jitter = 1 + (Math.random() * 2 - 1) * RECONNECT_JITTER;
|
||||
const delay = Math.min(this.reconnectDelay * jitter, MAX_RECONNECT_DELAY);
|
||||
this.logFn(`[WS] Reconnecting in ${Math.round(delay)}ms...`);
|
||||
|
||||
this.reconnectTimer = setTimeout(() => {
|
||||
this.reconnectTimer = null;
|
||||
this.reconnectDelay = Math.min(this.reconnectDelay * 2, MAX_RECONNECT_DELAY);
|
||||
this._connect();
|
||||
}, delay);
|
||||
}
|
||||
|
||||
// ─── Cleanup ───
|
||||
|
||||
private _cleanup(): void {
|
||||
this._stopHeartbeat();
|
||||
|
||||
if (this.authTimer) {
|
||||
clearTimeout(this.authTimer);
|
||||
this.authTimer = null;
|
||||
}
|
||||
|
||||
if (this.reconnectTimer) {
|
||||
clearTimeout(this.reconnectTimer);
|
||||
this.reconnectTimer = null;
|
||||
}
|
||||
|
||||
if (this.ws) {
|
||||
try {
|
||||
this.ws.close();
|
||||
} catch { }
|
||||
this.ws = null;
|
||||
}
|
||||
|
||||
this.connected = false;
|
||||
this.authenticated = false;
|
||||
}
|
||||
}
|
||||
47
gateway.py
47
gateway.py
@@ -1,10 +1,12 @@
|
||||
"""Gateway HTTP API — receives data from remote Collectors and routes to Discord bot.
|
||||
"""Gateway HTTP API + WebSocket Hub — receives data from Collectors and Extensions.
|
||||
|
||||
Runs alongside the Discord bot in the server Docker container.
|
||||
Collectors (local PCs) push pending approvals, chat snapshots, and registrations
|
||||
to this API, and poll for responses.
|
||||
Supports both:
|
||||
- REST API: for legacy Collectors (HTTP polling)
|
||||
- WebSocket: for direct Extension connections (real-time)
|
||||
|
||||
Endpoints:
|
||||
GET /ws — WebSocket endpoint (Extension direct connection)
|
||||
POST /api/pending — Collector pushes a new approval request
|
||||
GET /api/pending — List all pending requests (for diagnostics)
|
||||
POST /api/response/{rid} — Collector polls for response (or Gateway pushes)
|
||||
@@ -14,6 +16,7 @@ Endpoints:
|
||||
POST /api/command — Gateway pushes command to specific collector
|
||||
GET /api/commands/{project} — Collector polls for commands
|
||||
GET /health — Health check
|
||||
GET /hub/status — WebSocket Hub diagnostics
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -34,13 +37,15 @@ COMMAND_TTL = 1800 # 30 min — stale commands auto-deleted
|
||||
|
||||
|
||||
class GatewayAPI:
|
||||
"""HTTP API server for Collector ↔ Gateway communication."""
|
||||
"""HTTP API + WebSocket Hub server."""
|
||||
|
||||
def __init__(self, bot, host: str = "0.0.0.0", port: int = 8585, api_key: str = ""):
|
||||
def __init__(self, bot, host: str = "0.0.0.0", port: int = 8585,
|
||||
api_key: str = "", hub=None):
|
||||
self.bot = bot
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.api_key = api_key
|
||||
self.hub = hub # WSHub instance (None = WS disabled)
|
||||
self.app = web.Application(
|
||||
middlewares=[self._auth_middleware],
|
||||
client_max_size=1024 * 1024, # Security: 1MB max request body
|
||||
@@ -52,6 +57,10 @@ class GatewayAPI:
|
||||
self._rate_limits: dict[str, list[float]] = defaultdict(list) # IP → [timestamps]
|
||||
|
||||
def _setup_routes(self):
|
||||
# WebSocket endpoint (no auth middleware — Hub handles its own auth)
|
||||
self.app.router.add_get("/ws", self._ws_handler)
|
||||
self.app.router.add_get("/hub/status", self._hub_status)
|
||||
# Legacy REST endpoints (Collector compatibility)
|
||||
self.app.router.add_get("/health", self._health)
|
||||
self.app.router.add_post("/api/pending", self._post_pending)
|
||||
self.app.router.add_get("/api/pending", self._list_pending)
|
||||
@@ -61,13 +70,29 @@ class GatewayAPI:
|
||||
self.app.router.add_get("/api/commands/{project}", self._get_commands)
|
||||
self.app.router.add_post("/api/event", self._post_event)
|
||||
|
||||
# ─── WebSocket Handler ───
|
||||
|
||||
async def _ws_handler(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""WebSocket endpoint for direct Extension connections."""
|
||||
if not self.hub:
|
||||
return web.json_response(
|
||||
{"error": "WebSocket Hub not enabled"}, status=503
|
||||
)
|
||||
return await self.hub.handle_ws(request)
|
||||
|
||||
async def _hub_status(self, request: web.Request) -> web.Response:
|
||||
"""WebSocket Hub diagnostics."""
|
||||
if not self.hub:
|
||||
return web.json_response({"hub": "disabled"})
|
||||
return web.json_response(self.hub.get_status())
|
||||
|
||||
# ─── Auth Middleware ───
|
||||
|
||||
@web.middleware
|
||||
async def _auth_middleware(self, request: web.Request, handler):
|
||||
"""Reject requests without valid API key on /api/* routes."""
|
||||
# Health endpoint is public
|
||||
if request.path == "/health":
|
||||
# WebSocket and public endpoints skip API key auth
|
||||
if request.path in ("/health", "/ws", "/hub/status"):
|
||||
return await handler(request)
|
||||
|
||||
# All /api/* routes require auth + rate limit
|
||||
@@ -109,11 +134,15 @@ class GatewayAPI:
|
||||
# ─── Health ───
|
||||
|
||||
async def _health(self, request: web.Request) -> web.Response:
|
||||
return web.json_response({
|
||||
status = {
|
||||
"status": "ok",
|
||||
"bot_ready": self.bot.is_ready() if self.bot else False,
|
||||
"timestamp": time.time(),
|
||||
})
|
||||
"hub_enabled": self.hub is not None,
|
||||
}
|
||||
if self.hub:
|
||||
status["hub_connections"] = len(self.hub.connections)
|
||||
return web.json_response(status)
|
||||
|
||||
# ─── Pending Approvals (Collector → Gateway → Discord) ───
|
||||
|
||||
|
||||
580
hub.py
Normal file
580
hub.py
Normal file
@@ -0,0 +1,580 @@
|
||||
"""WebSocket Hub — real-time bidirectional communication between Extensions and Bot.
|
||||
|
||||
Replaces file-based IPC (bridge/) and HTTP polling (Collector/Gateway) with
|
||||
persistent WebSocket connections.
|
||||
|
||||
Architecture:
|
||||
Extension ↔ WSS ↔ Hub ↔ Bot (in-process) ↔ Discord
|
||||
|
||||
Each Extension connects via WebSocket, authenticates, and receives a unique
|
||||
instance number within its project. Messages are routed by project.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
from aiohttp import web, WSMsgType
|
||||
|
||||
from auth import TokenManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ─── Constants ───
|
||||
|
||||
HEARTBEAT_INTERVAL = 30.0 # seconds between server→client pings
|
||||
HEARTBEAT_TIMEOUT = 10.0 # seconds to wait for pong before disconnect
|
||||
CLIENT_QUEUE_SIZE = 100 # max queued messages per client (backpressure)
|
||||
MAX_MSG_SIZE = 1024 * 1024 # 1MB max WebSocket message size
|
||||
PER_CONN_RATE_LIMIT = 60 # max messages per 10s window per connection
|
||||
RATE_WINDOW = 10.0 # seconds
|
||||
|
||||
|
||||
class MsgType(str, Enum):
|
||||
"""WebSocket protocol message types."""
|
||||
# Extension → Hub (upstream)
|
||||
AUTH = "auth"
|
||||
PENDING = "pending"
|
||||
CHAT = "chat"
|
||||
REGISTER = "register"
|
||||
HEARTBEAT = "heartbeat"
|
||||
AUTO_RESOLVE = "auto_resolve"
|
||||
BRAIN_EVENT = "brain_event"
|
||||
|
||||
# Hub → Extension (downstream)
|
||||
AUTH_OK = "auth_ok"
|
||||
AUTH_FAIL = "auth_fail"
|
||||
RESPONSE = "response"
|
||||
COMMAND = "command"
|
||||
INSTANCE_UPDATE = "instance_update"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WSConnection:
|
||||
"""Represents a connected Extension instance."""
|
||||
conn_id: str
|
||||
ws: web.WebSocketResponse
|
||||
project: str = ""
|
||||
pc_name: str = ""
|
||||
instance_number: int = 0
|
||||
connected_at: float = field(default_factory=time.time)
|
||||
last_heartbeat: float = field(default_factory=time.time)
|
||||
authenticated: bool = False
|
||||
send_queue: asyncio.Queue = field(default_factory=lambda: asyncio.Queue(maxsize=CLIENT_QUEUE_SIZE))
|
||||
_sender_task: asyncio.Task | None = field(default=None, repr=False)
|
||||
# Rate limiting
|
||||
_msg_timestamps: list[float] = field(default_factory=list, repr=False)
|
||||
|
||||
|
||||
class WSHub:
|
||||
"""WebSocket Hub for routing messages between Extensions and Bot.
|
||||
|
||||
Responsibilities:
|
||||
- Connection lifecycle (auth, heartbeat, disconnect)
|
||||
- Message routing (project-scoped, instance-targeted)
|
||||
- Instance number management (auto-assign, reassign on disconnect)
|
||||
- Rate limiting per connection
|
||||
- Backpressure via per-client queues
|
||||
"""
|
||||
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
self.token_manager = token_manager
|
||||
self.connections: dict[str, WSConnection] = {} # conn_id → connection
|
||||
self.project_connections: dict[str, set[str]] = {} # project → {conn_ids}
|
||||
self.pending_owners: dict[str, str] = {} # request_id → conn_id
|
||||
self._recent_msg_ids: dict[str, float] = {} # msg_id → timestamp (dedup)
|
||||
self._msg_id_ttl = 300.0 # 5 min dedup window
|
||||
|
||||
# Bot callbacks — set by bot.py during initialization
|
||||
self._on_pending: Callable[..., Coroutine] | None = None
|
||||
self._on_chat: Callable[..., Coroutine] | None = None
|
||||
self._on_register: Callable[..., Coroutine] | None = None
|
||||
self._on_auto_resolve: Callable[..., Coroutine] | None = None
|
||||
self._on_brain_event: Callable[..., Coroutine] | None = None
|
||||
|
||||
# ─── Bot Integration ───
|
||||
|
||||
def set_bot_handlers(
|
||||
self,
|
||||
on_pending: Callable | None = None,
|
||||
on_chat: Callable | None = None,
|
||||
on_register: Callable | None = None,
|
||||
on_auto_resolve: Callable | None = None,
|
||||
on_brain_event: Callable | None = None,
|
||||
):
|
||||
"""Register bot callback functions for incoming Extension messages."""
|
||||
self._on_pending = on_pending
|
||||
self._on_chat = on_chat
|
||||
self._on_register = on_register
|
||||
self._on_auto_resolve = on_auto_resolve
|
||||
self._on_brain_event = on_brain_event
|
||||
|
||||
# ─── Connection Management ───
|
||||
|
||||
async def handle_ws(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle a new WebSocket connection.
|
||||
|
||||
Protocol:
|
||||
1. Client connects
|
||||
2. Client sends {type:"auth", token:"...", project:"...", pc:"..."}
|
||||
OR {type:"auth", registration_code:"...", project:"...", pc:"..."}
|
||||
3. Server responds {type:"auth_ok", conn_id, instance_number}
|
||||
OR {type:"auth_fail", reason:"..."}
|
||||
4. Bidirectional message exchange
|
||||
"""
|
||||
ws = web.WebSocketResponse(
|
||||
max_msg_size=MAX_MSG_SIZE,
|
||||
heartbeat=HEARTBEAT_INTERVAL,
|
||||
)
|
||||
await ws.prepare(request)
|
||||
|
||||
conn_id = uuid.uuid4().hex[:12]
|
||||
conn = WSConnection(conn_id=conn_id, ws=ws)
|
||||
remote = request.remote or "unknown"
|
||||
logger.info(f"[HUB] New WS connection: {conn_id} from {remote}")
|
||||
|
||||
try:
|
||||
# Wait for auth message (first message must be auth)
|
||||
auth_ok = await self._handle_auth(conn, timeout=10.0)
|
||||
if not auth_ok:
|
||||
return ws
|
||||
|
||||
# Register connection
|
||||
self._register_connection(conn)
|
||||
|
||||
# Start sender task (per-client queue → ws.send)
|
||||
conn._sender_task = asyncio.create_task(
|
||||
self._sender_loop(conn), name=f"sender-{conn_id}"
|
||||
)
|
||||
|
||||
# Message loop
|
||||
await self._message_loop(conn)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[HUB] Connection {conn_id} error: {e}")
|
||||
finally:
|
||||
await self._disconnect(conn)
|
||||
|
||||
return ws
|
||||
|
||||
async def _handle_auth(self, conn: WSConnection, timeout: float = 10.0) -> bool:
|
||||
"""Wait for and process the auth message."""
|
||||
try:
|
||||
msg = await asyncio.wait_for(conn.ws.receive(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
await self._send_direct(conn.ws, {
|
||||
"type": MsgType.AUTH_FAIL, "reason": "Auth timeout"
|
||||
})
|
||||
await conn.ws.close()
|
||||
return False
|
||||
|
||||
if msg.type != WSMsgType.TEXT:
|
||||
await conn.ws.close()
|
||||
return False
|
||||
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
except json.JSONDecodeError:
|
||||
await self._send_direct(conn.ws, {
|
||||
"type": MsgType.AUTH_FAIL, "reason": "Invalid JSON"
|
||||
})
|
||||
await conn.ws.close()
|
||||
return False
|
||||
|
||||
if data.get("type") != MsgType.AUTH:
|
||||
await self._send_direct(conn.ws, {
|
||||
"type": MsgType.AUTH_FAIL, "reason": "First message must be auth"
|
||||
})
|
||||
await conn.ws.close()
|
||||
return False
|
||||
|
||||
project = data.get("project", "")
|
||||
pc_name = data.get("pc", "unknown")
|
||||
|
||||
if not project:
|
||||
await self._send_direct(conn.ws, {
|
||||
"type": MsgType.AUTH_FAIL, "reason": "Project name required"
|
||||
})
|
||||
await conn.ws.close()
|
||||
return False
|
||||
|
||||
# Try token auth first, then registration code
|
||||
token = data.get("token", "")
|
||||
reg_code = data.get("registration_code", "")
|
||||
|
||||
if token:
|
||||
payload = self.token_manager.verify_token(token)
|
||||
if not payload:
|
||||
await self._send_direct(conn.ws, {
|
||||
"type": MsgType.AUTH_FAIL, "reason": "Invalid or expired token"
|
||||
})
|
||||
await conn.ws.close()
|
||||
return False
|
||||
# Token is valid — use project from token (overrides client)
|
||||
project = payload.get("project", project)
|
||||
pc_name = payload.get("pc", pc_name)
|
||||
elif reg_code:
|
||||
if not self.token_manager.validate_registration_code(reg_code):
|
||||
await self._send_direct(conn.ws, {
|
||||
"type": MsgType.AUTH_FAIL, "reason": "Invalid registration code"
|
||||
})
|
||||
await conn.ws.close()
|
||||
return False
|
||||
else:
|
||||
# No auth provided — check if registration code is configured
|
||||
if self.token_manager.registration_code:
|
||||
await self._send_direct(conn.ws, {
|
||||
"type": MsgType.AUTH_FAIL, "reason": "Auth required"
|
||||
})
|
||||
await conn.ws.close()
|
||||
return False
|
||||
# No registration code configured → allow (dev mode)
|
||||
|
||||
# Auth successful
|
||||
conn.project = project
|
||||
conn.pc_name = pc_name
|
||||
conn.authenticated = True
|
||||
|
||||
# Issue a session token (for reconnection)
|
||||
session_token = self.token_manager.create_token(project, pc_name)
|
||||
|
||||
# Assign instance number
|
||||
instance_number = self._assign_instance_number(project, conn.conn_id)
|
||||
conn.instance_number = instance_number
|
||||
|
||||
await self._send_direct(conn.ws, {
|
||||
"type": MsgType.AUTH_OK,
|
||||
"conn_id": conn.conn_id,
|
||||
"instance_number": instance_number,
|
||||
"session_token": session_token,
|
||||
"active_count": self.get_active_count(project) + 1, # +1 for self (not yet registered)
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f"[HUB] Auth OK: {conn.conn_id} project={project} pc={pc_name} "
|
||||
f"instance=#{instance_number}"
|
||||
)
|
||||
return True
|
||||
|
||||
def _register_connection(self, conn: WSConnection):
|
||||
"""Add connection to tracking structures."""
|
||||
self.connections[conn.conn_id] = conn
|
||||
if conn.project not in self.project_connections:
|
||||
self.project_connections[conn.project] = set()
|
||||
self.project_connections[conn.project].add(conn.conn_id)
|
||||
|
||||
# Broadcast instance update to all project connections
|
||||
asyncio.create_task(self._broadcast_instance_update(conn.project))
|
||||
|
||||
async def _disconnect(self, conn: WSConnection):
|
||||
"""Clean up after a connection closes."""
|
||||
conn_id = conn.conn_id
|
||||
project = conn.project
|
||||
|
||||
# Cancel sender task
|
||||
if conn._sender_task and not conn._sender_task.done():
|
||||
conn._sender_task.cancel()
|
||||
try:
|
||||
await conn._sender_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Remove from tracking
|
||||
self.connections.pop(conn_id, None)
|
||||
if project in self.project_connections:
|
||||
self.project_connections[project].discard(conn_id)
|
||||
if not self.project_connections[project]:
|
||||
del self.project_connections[project]
|
||||
|
||||
# Clean up pending ownership
|
||||
stale = [rid for rid, cid in self.pending_owners.items() if cid == conn_id]
|
||||
for rid in stale:
|
||||
del self.pending_owners[rid]
|
||||
|
||||
# Close WebSocket if still open
|
||||
if not conn.ws.closed:
|
||||
await conn.ws.close()
|
||||
|
||||
logger.info(f"[HUB] Disconnected: {conn_id} project={project}")
|
||||
|
||||
# Broadcast instance update (remaining connections get notified)
|
||||
if project:
|
||||
await self._broadcast_instance_update(project)
|
||||
|
||||
# ─── Instance Number Management ───
|
||||
|
||||
def _assign_instance_number(self, project: str, conn_id: str) -> int:
|
||||
"""Assign the next available instance number for a project."""
|
||||
used = set()
|
||||
for cid in self.project_connections.get(project, set()):
|
||||
c = self.connections.get(cid)
|
||||
if c:
|
||||
used.add(c.instance_number)
|
||||
|
||||
# Find lowest available number starting from 1
|
||||
num = 1
|
||||
while num in used:
|
||||
num += 1
|
||||
return num
|
||||
|
||||
def get_active_count(self, project: str) -> int:
|
||||
"""Get the number of active connections for a project."""
|
||||
return len(self.project_connections.get(project, set()))
|
||||
|
||||
async def _broadcast_instance_update(self, project: str):
|
||||
"""Notify all project connections of instance count change."""
|
||||
count = self.get_active_count(project)
|
||||
msg = {
|
||||
"type": MsgType.INSTANCE_UPDATE,
|
||||
"active_count": count,
|
||||
"instances": [],
|
||||
}
|
||||
# Include instance list for UI
|
||||
for cid in self.project_connections.get(project, set()):
|
||||
c = self.connections.get(cid)
|
||||
if c:
|
||||
msg["instances"].append({
|
||||
"instance_number": c.instance_number,
|
||||
"pc": c.pc_name,
|
||||
})
|
||||
await self.broadcast_to_project(project, msg)
|
||||
|
||||
# ─── Message Routing ───
|
||||
|
||||
async def broadcast_to_project(self, project: str, message: dict):
|
||||
"""Send a message to all connections in a project."""
|
||||
conn_ids = self.project_connections.get(project, set()).copy()
|
||||
for cid in conn_ids:
|
||||
conn = self.connections.get(cid)
|
||||
if conn and conn.authenticated:
|
||||
await self._queue_send(conn, message)
|
||||
|
||||
async def send_to_connection(self, conn_id: str, message: dict):
|
||||
"""Send a message to a specific connection."""
|
||||
conn = self.connections.get(conn_id)
|
||||
if conn and conn.authenticated:
|
||||
await self._queue_send(conn, message)
|
||||
|
||||
async def send_to_instance(self, project: str, instance_number: int, message: dict):
|
||||
"""Send a message to a specific instance number within a project."""
|
||||
for cid in self.project_connections.get(project, set()):
|
||||
conn = self.connections.get(cid)
|
||||
if conn and conn.instance_number == instance_number:
|
||||
await self._queue_send(conn, message)
|
||||
return True
|
||||
logger.warning(
|
||||
f"[HUB] Instance #{instance_number} not found for project={project}"
|
||||
)
|
||||
return False
|
||||
|
||||
async def send_response_to_pending_owner(self, request_id: str, message: dict):
|
||||
"""Route a response to the Extension that created the pending request."""
|
||||
conn_id = self.pending_owners.get(request_id)
|
||||
if conn_id:
|
||||
await self.send_to_connection(conn_id, message)
|
||||
# Clean up after response delivered
|
||||
self.pending_owners.pop(request_id, None)
|
||||
return True
|
||||
logger.warning(f"[HUB] No owner for pending {request_id[:12]}")
|
||||
return False
|
||||
|
||||
# ─── Per-Client Send Queue (backpressure) ───
|
||||
|
||||
async def _queue_send(self, conn: WSConnection, message: dict):
|
||||
"""Queue a message for sending via the per-client sender loop."""
|
||||
try:
|
||||
conn.send_queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
f"[HUB] Queue full for {conn.conn_id} — dropping oldest message"
|
||||
)
|
||||
# Drop oldest to make room
|
||||
try:
|
||||
conn.send_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
try:
|
||||
conn.send_queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
async def _sender_loop(self, conn: WSConnection):
|
||||
"""Dedicated sender coroutine for a single connection.
|
||||
|
||||
Reads from the per-client queue and sends via WebSocket.
|
||||
This prevents slow clients from blocking message routing.
|
||||
"""
|
||||
try:
|
||||
while not conn.ws.closed:
|
||||
try:
|
||||
msg = await asyncio.wait_for(
|
||||
conn.send_queue.get(), timeout=HEARTBEAT_INTERVAL
|
||||
)
|
||||
await conn.ws.send_json(msg)
|
||||
except asyncio.TimeoutError:
|
||||
continue # No message to send — loop continues
|
||||
except ConnectionResetError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[HUB] Send error {conn.conn_id}: {e}")
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _send_direct(self, ws: web.WebSocketResponse, message: dict):
|
||||
"""Send directly (bypasses queue — for auth messages only)."""
|
||||
try:
|
||||
await ws.send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"[HUB] Direct send error: {e}")
|
||||
|
||||
# ─── Message Loop ───
|
||||
|
||||
async def _message_loop(self, conn: WSConnection):
|
||||
"""Main message receive loop for an authenticated connection."""
|
||||
async for msg in conn.ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
# Rate limit check
|
||||
if not self._check_rate_limit(conn):
|
||||
await self._queue_send(conn, {
|
||||
"type": MsgType.ERROR,
|
||||
"error": "Rate limited — too many messages",
|
||||
})
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
msg_type = data.get("type", "")
|
||||
|
||||
# Dedup check
|
||||
msg_id = data.get("msg_id")
|
||||
if msg_id and self._is_duplicate(msg_id):
|
||||
continue
|
||||
|
||||
await self._handle_message(conn, msg_type, data)
|
||||
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
logger.error(
|
||||
f"[HUB] WS error {conn.conn_id}: {conn.ws.exception()}"
|
||||
)
|
||||
break
|
||||
|
||||
elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
|
||||
break
|
||||
|
||||
async def _handle_message(self, conn: WSConnection, msg_type: str, data: dict):
|
||||
"""Route an incoming message to the appropriate handler."""
|
||||
conn.last_heartbeat = time.time() # Any message counts as heartbeat
|
||||
|
||||
if msg_type == MsgType.PENDING:
|
||||
payload = data.get("data", {})
|
||||
request_id = payload.get("request_id", "")
|
||||
if request_id:
|
||||
# Track ownership for response routing
|
||||
self.pending_owners[request_id] = conn.conn_id
|
||||
# Add source metadata
|
||||
payload["_conn_id"] = conn.conn_id
|
||||
payload["_instance_number"] = conn.instance_number
|
||||
payload["_pc_name"] = conn.pc_name
|
||||
payload.setdefault("project_name", conn.project)
|
||||
if self._on_pending:
|
||||
await self._on_pending(conn.project, payload)
|
||||
|
||||
elif msg_type == MsgType.CHAT:
|
||||
payload = data.get("data", {})
|
||||
payload.setdefault("project_name", conn.project)
|
||||
payload["_instance_number"] = conn.instance_number
|
||||
payload["_pc_name"] = conn.pc_name
|
||||
if self._on_chat:
|
||||
await self._on_chat(conn.project, payload)
|
||||
|
||||
elif msg_type == MsgType.REGISTER:
|
||||
payload = data.get("data", {})
|
||||
payload.setdefault("project_name", conn.project)
|
||||
if self._on_register:
|
||||
await self._on_register(payload)
|
||||
|
||||
elif msg_type == MsgType.AUTO_RESOLVE:
|
||||
payload = data.get("data", {})
|
||||
request_id = payload.get("request_id", "")
|
||||
if request_id:
|
||||
self.pending_owners.pop(request_id, None)
|
||||
if self._on_auto_resolve:
|
||||
await self._on_auto_resolve(conn.project, payload)
|
||||
|
||||
elif msg_type == MsgType.BRAIN_EVENT:
|
||||
payload = data.get("data", {})
|
||||
payload.setdefault("project_name", conn.project)
|
||||
if self._on_brain_event:
|
||||
await self._on_brain_event(conn.project, payload)
|
||||
|
||||
elif msg_type == MsgType.HEARTBEAT:
|
||||
pass # last_heartbeat already updated above
|
||||
|
||||
else:
|
||||
logger.warning(f"[HUB] Unknown message type: {msg_type} from {conn.conn_id}")
|
||||
|
||||
# ─── Rate Limiting ───
|
||||
|
||||
def _check_rate_limit(self, conn: WSConnection) -> bool:
|
||||
"""Per-connection rate limiting. Returns True if allowed."""
|
||||
now = time.time()
|
||||
conn._msg_timestamps = [
|
||||
t for t in conn._msg_timestamps if now - t < RATE_WINDOW
|
||||
]
|
||||
if len(conn._msg_timestamps) >= PER_CONN_RATE_LIMIT:
|
||||
logger.warning(f"[HUB] Rate limited: {conn.conn_id}")
|
||||
return False
|
||||
conn._msg_timestamps.append(now)
|
||||
return True
|
||||
|
||||
# ─── Deduplication ───
|
||||
|
||||
def _is_duplicate(self, msg_id: str) -> bool:
|
||||
"""Check if a message ID was recently seen."""
|
||||
now = time.time()
|
||||
# Cleanup old entries
|
||||
if len(self._recent_msg_ids) > 10000:
|
||||
cutoff = now - self._msg_id_ttl
|
||||
self._recent_msg_ids = {
|
||||
k: v for k, v in self._recent_msg_ids.items() if v > cutoff
|
||||
}
|
||||
if msg_id in self._recent_msg_ids:
|
||||
return True
|
||||
self._recent_msg_ids[msg_id] = now
|
||||
return False
|
||||
|
||||
# ─── Diagnostics ───
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""Return hub status for health/diagnostics."""
|
||||
projects = {}
|
||||
for project, conn_ids in self.project_connections.items():
|
||||
instances = []
|
||||
for cid in conn_ids:
|
||||
c = self.connections.get(cid)
|
||||
if c:
|
||||
instances.append({
|
||||
"conn_id": c.conn_id,
|
||||
"instance": c.instance_number,
|
||||
"pc": c.pc_name,
|
||||
"connected": int(time.time() - c.connected_at),
|
||||
})
|
||||
projects[project] = {
|
||||
"count": len(conn_ids),
|
||||
"instances": instances,
|
||||
}
|
||||
return {
|
||||
"total_connections": len(self.connections),
|
||||
"projects": projects,
|
||||
"pending_owners": len(self.pending_owners),
|
||||
}
|
||||
21
main.py
21
main.py
@@ -102,14 +102,29 @@ async def main():
|
||||
else:
|
||||
logger.info("Gateway mode — watcher disabled (data via HTTP API)")
|
||||
|
||||
# Start Gateway HTTP API (gateway mode)
|
||||
# Start Gateway HTTP API + WebSocket Hub (gateway mode)
|
||||
if Config.BOT_MODE == 'gateway':
|
||||
from gateway import GatewayAPI
|
||||
from hub import WSHub
|
||||
from auth import TokenManager
|
||||
|
||||
# Initialize Hub
|
||||
token_mgr = TokenManager(
|
||||
secret=Config.GRAVITY_HUB_SECRET,
|
||||
registration_code=Config.GRAVITY_REGISTRATION_CODE,
|
||||
)
|
||||
hub = WSHub(token_mgr)
|
||||
|
||||
gateway_port = int(os.environ.get('GATEWAY_PORT', '8585'))
|
||||
gateway = GatewayAPI(bot, port=gateway_port, api_key=Config.GATEWAY_API_KEY)
|
||||
gateway = GatewayAPI(
|
||||
bot, port=gateway_port,
|
||||
api_key=Config.GATEWAY_API_KEY,
|
||||
hub=hub,
|
||||
)
|
||||
bot.gateway = gateway # Enable _write_command → gateway.push_command
|
||||
bot.hub = hub # Enable Hub-based message routing
|
||||
await gateway.start()
|
||||
logger.info(f"Gateway API running on port {gateway_port}")
|
||||
logger.info(f"Gateway API + WS Hub running on port {gateway_port}")
|
||||
|
||||
# Run Discord bot (blocks until bot disconnects)
|
||||
await bot.start(Config.DISCORD_TOKEN)
|
||||
|
||||
28
tests/test_syntax.py
Normal file
28
tests/test_syntax.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Quick syntax check for all modified/new Python files."""
|
||||
import ast
|
||||
import sys
|
||||
|
||||
files = [
|
||||
'auth.py',
|
||||
'hub.py',
|
||||
'config.py',
|
||||
'gateway.py',
|
||||
'main.py',
|
||||
'bot.py',
|
||||
]
|
||||
|
||||
errors = []
|
||||
for f in files:
|
||||
try:
|
||||
with open(f'../{f}', encoding='utf-8') as fh:
|
||||
ast.parse(fh.read())
|
||||
print(f' OK: {f}')
|
||||
except SyntaxError as e:
|
||||
print(f' FAIL: {f} - {e}')
|
||||
errors.append(f)
|
||||
|
||||
if errors:
|
||||
print(f'\nFailed: {errors}')
|
||||
sys.exit(1)
|
||||
else:
|
||||
print('\nAll files parse OK')
|
||||
74
tests/test_ws_hub.py
Normal file
74
tests/test_ws_hub.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Quick test for WebSocket Hub connection and auth protocol."""
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import urllib.request
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def test_ws():
|
||||
import websockets
|
||||
|
||||
reg_code = os.getenv("GRAVITY_REGISTRATION_CODE")
|
||||
uri = "ws://localhost:8586/ws"
|
||||
|
||||
print(f"1. Connecting to {uri}...")
|
||||
async with websockets.connect(uri) as ws:
|
||||
# Correct auth protocol: type=auth with registration_code
|
||||
auth_msg = {
|
||||
"type": "auth",
|
||||
"registration_code": reg_code,
|
||||
"project": "gravity_control",
|
||||
"pc": "test_pc",
|
||||
}
|
||||
await ws.send(json.dumps(auth_msg))
|
||||
print("2. Sent auth message")
|
||||
|
||||
resp = await asyncio.wait_for(ws.recv(), timeout=5)
|
||||
resp_data = json.loads(resp)
|
||||
print(f"3. Auth response: {json.dumps(resp_data, indent=2)}")
|
||||
|
||||
if resp_data.get("type") == "auth_ok":
|
||||
print("4. AUTH SUCCESS!")
|
||||
conn_id = resp_data.get("conn_id", "")
|
||||
instance = resp_data.get("instance_number", "")
|
||||
print(f" conn_id={conn_id}, instance=#{instance}")
|
||||
|
||||
# Send a chat message (Hub expects type='chat', not 'chat_snapshot')
|
||||
snap = {
|
||||
"type": "chat",
|
||||
"data": {
|
||||
"content": "[TEST] Hub WS test message",
|
||||
"project_name": "gravity_control",
|
||||
"conversation_id": "test-ws-hub-123",
|
||||
},
|
||||
}
|
||||
await ws.send(json.dumps(snap))
|
||||
print("5. Sent chat_snapshot via WS")
|
||||
|
||||
# Check hub status via HTTP
|
||||
r = urllib.request.urlopen("http://localhost:8586/hub/status", timeout=3)
|
||||
status = json.loads(r.read())
|
||||
print(f"6. Hub status: {json.dumps(status, indent=2)}")
|
||||
|
||||
# Send register message
|
||||
reg = {
|
||||
"type": "register",
|
||||
"conversation_id": "test-ws-hub-123",
|
||||
"project_name": "gravity_control",
|
||||
}
|
||||
await ws.send(json.dumps(reg))
|
||||
print("7. Sent register")
|
||||
|
||||
else:
|
||||
reason = resp_data.get("reason", "unknown")
|
||||
print(f"4. AUTH FAILED: {reason}")
|
||||
|
||||
await ws.close()
|
||||
print("8. Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_ws())
|
||||
Reference in New Issue
Block a user