Files
gravity_control/bridge.py

459 lines
17 KiB
Python

"""Bridge protocol — communication between Discord bot and Antigravity.
Bridge directory: ~/.gemini/antigravity/bridge/
Structure:
bridge/
pending/ ← Bot writes approval requests for Discord
response/ ← Bot writes user responses from Discord
commands/ ← Bot writes user text input from Discord
Protocol:
1. VS Code Extension detects pending approval → writes JSON to pending/
2. Bot reads pending/ → sends Discord message with ✅/❌ buttons
3. User clicks button → Bot writes JSON to response/
4. VS Code Extension reads response/ → executes action
Transport layer:
LocalTransport — file-based (default, single-PC)
RemoteTransport — HTTP-based (future: multi-PC collector mode)
"""
import json
import time
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from dataclasses import dataclass, asdict
from enum import Enum
from config import Config
# Module-level logger shared by the transports and BridgeProtocol below.
logger = logging.getLogger(__name__)
class ApprovalStatus(Enum):
    """Lifecycle states of an approval request.

    The values mirror the strings used in the ``status`` field of pending
    request JSON (e.g. ``"pending"``, ``"expired"``).
    """

    PENDING = "pending"
    APPROVED = "approved"
    REJECTED = "rejected"
    TIMEOUT = "timeout"
    # Written by BridgeProtocol.get_pending_requests() for pending files that
    # exceeded its maximum age; previously missing from this enum.
    EXPIRED = "expired"
@dataclass
class ApprovalRequest:
    """One approval request raised by Antigravity.

    Field order is part of the positional constructor signature and of the
    ``asdict()`` key order, so new fields must be appended with defaults.
    """

    # Unique id of the request.
    request_id: str
    # Conversation the request belongs to.
    conversation_id: str
    # The command/action that needs approval.
    command: str
    # Human-readable description shown to the user.
    description: str
    # Creation time in seconds since the epoch.
    timestamp: float
    # Request state; "pending" until something changes it.
    status: str = "pending"
    # Discord message carrying the buttons (0 = not set).
    discord_message_id: int = 0
    # Project routing key.
    project_name: str = ""
    # Step kind such as 'diff_review'; passed through to the response.
    step_type: str = ""
@dataclass
class UserResponse:
    """A user's answer from Discord to a pending approval request.

    Field order is part of the positional constructor signature and of the
    ``asdict()`` key order, so new fields must be appended with defaults.
    """

    # Id of the request being answered.
    request_id: str
    # True = approve, False = reject.
    approved: bool
    # Free-form text supplied with the answer, if any.
    user_input: str = ""
    # Time the response was written (seconds since the epoch).
    timestamp: float = 0
    # -1 = legacy approve/reject; 0 or more = index of the specific button clicked.
    button_index: int = -1
    # Passed through from the pending request for extension-side routing.
    step_type: str = ""
    # Multi-project routing: the extension uses it when the pending file is missing.
    project_name: str = ""
# ─── Transport Abstraction ───
class BridgeTransport(ABC):
"""Abstract transport for bridge I/O.
Implementations handle reading/writing JSON files for the bridge protocol,
regardless of whether the storage is local filesystem or remote HTTP.
"""
@abstractmethod
def list_json_files(self, subdir: str) -> list[str]:
"""List JSON filenames in a subdirectory (e.g. 'pending', 'response')."""
...
@abstractmethod
def read_json(self, subdir: str, filename: str) -> dict | None:
"""Read and parse a JSON file. Returns None if not found or corrupt."""
...
@abstractmethod
def write_json(self, subdir: str, filename: str, data: dict) -> None:
"""Write data as JSON to a file in the given subdirectory."""
...
@abstractmethod
def delete_file(self, subdir: str, filename: str) -> bool:
"""Delete a file. Returns True if deleted, False if not found."""
...
@abstractmethod
def ensure_dirs(self) -> None:
"""Ensure all required subdirectories exist."""
...
class LocalTransport(BridgeTransport):
"""File-system based transport (default, single-PC mode).
Reads/writes directly to the bridge directory on local disk.
This is the existing behavior, extracted into a transport class.
"""
def __init__(self, bridge_dir: Path):
self.bridge_dir = bridge_dir
def list_json_files(self, subdir: str) -> list[str]:
d = self.bridge_dir / subdir
if not d.exists():
return []
return [f.name for f in d.glob("*.json")]
def read_json(self, subdir: str, filename: str) -> dict | None:
fp = self.bridge_dir / subdir / filename
if not fp.exists():
return None
try:
return json.loads(fp.read_text(encoding="utf-8-sig"))
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"LocalTransport: bad file {subdir}/{filename}: {e}")
return None
def write_json(self, subdir: str, filename: str, data: dict) -> None:
d = self.bridge_dir / subdir
d.mkdir(parents=True, exist_ok=True)
fp = d / filename
fp.write_text(
json.dumps(data, ensure_ascii=False, indent=2),
encoding="utf-8",
)
def delete_file(self, subdir: str, filename: str) -> bool:
fp = self.bridge_dir / subdir / filename
if fp.exists():
try:
fp.unlink()
return True
except OSError:
return False
return False
def ensure_dirs(self) -> None:
for sub in ("pending", "response", "commands"):
(self.bridge_dir / sub).mkdir(parents=True, exist_ok=True)
class RemoteTransport(BridgeTransport):
    """HTTP-based transport for Collector → Gateway communication.

    The async ``a*`` methods map onto Gateway API endpoints:
        awrite_json("pending", ...)   → POST /api/pending
        awrite_json("response", ...)  → POST /api/response/{rid}
        aread_json("response", ...)   → GET  /api/response/{rid}
        apoll_commands(project)       → GET  /api/commands/{project}
        aregister_session(...)        → POST /api/register
        asend_chat(...)               → POST /api/chat
        asend_event(...)              → POST /api/event
        health_check()                → GET  /health
    The synchronous BridgeTransport methods are inert stubs kept only for ABC
    compliance. Commands are not written by the Collector; the Gateway pushes them.
    """

    def __init__(self, base_url: str, api_key: str = ""):
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self._headers = {"Content-Type": "application/json"}
        if api_key:
            self._headers["Authorization"] = f"Bearer {api_key}"
        self._session = None  # aiohttp.ClientSession, created lazily by _get_session()
        # Connection health
        self.connected = False
        self._consecutive_failures = 0
        self._max_failures_before_warning = 3
        # Rate-limit backoff
        self._rate_limited_until = 0.0  # time.time() before which non-health requests are skipped
        self._backoff_seconds = 0.0  # current backoff duration (grows exponentially)
        self._BACKOFF_BASE = 1.0
        self._BACKOFF_MAX = 60.0
        # Whether the most recent failed _arequest() is worth retrying later
        # (network error, backoff skip, 408/429/5xx) rather than a permanent rejection.
        self._last_failure_retryable = True
        # Retry queue of failed POSTs as (method, path, data) tuples, oldest first.
        self._retry_queue: list[tuple[str, str, dict | None]] = []
        self._retry_queue_max = 100
        logger.info(f"RemoteTransport: {self.base_url} (auth={'yes' if api_key else 'no'})")

    async def _get_session(self):
        """Return the aiohttp session, (re)creating it if absent or closed."""
        if self._session is None or self._session.closed:
            import aiohttp
            timeout = aiohttp.ClientTimeout(total=10)
            self._session = aiohttp.ClientSession(
                headers=self._headers, timeout=timeout
            )
        return self._session

    async def close(self):
        """Close the HTTP session."""
        if self._session and not self._session.closed:
            await self._session.close()

    @property
    def is_rate_limited(self) -> bool:
        """True while a rate-limit/failure backoff period is active."""
        return time.time() < self._rate_limited_until

    def _apply_backoff(self, retry_after: float = 0):
        """Start a backoff period.

        Uses *retry_after* seconds (capped at _BACKOFF_MAX) when positive,
        otherwise doubles the previous backoff, starting at _BACKOFF_BASE.
        """
        if retry_after > 0:
            self._backoff_seconds = min(retry_after, self._BACKOFF_MAX)
        elif self._backoff_seconds == 0:
            self._backoff_seconds = self._BACKOFF_BASE
        else:
            self._backoff_seconds = min(self._backoff_seconds * 2, self._BACKOFF_MAX)
        self._rate_limited_until = time.time() + self._backoff_seconds
        logger.warning(f"RemoteTransport: backing off {self._backoff_seconds:.0f}s (until +{self._backoff_seconds:.0f}s)")

    def _reset_backoff(self):
        """Clear any backoff after a successful request."""
        if self._backoff_seconds > 0:
            self._backoff_seconds = 0
            self._rate_limited_until = 0

    @staticmethod
    def _parse_retry_after(value: str | None) -> float:
        """Parse a ``Retry-After`` header given as delta-seconds.

        Returns 0 (meaning "use exponential backoff") when the header is
        absent, not a positive number, or in the HTTP-date form, instead of
        raising as a bare ``float()`` would.
        """
        if not value:
            return 0.0
        try:
            seconds = float(value)
        except ValueError:
            return 0.0
        return seconds if seconds > 0 else 0.0  # also rejects NaN

    async def _arequest(self, method: str, path: str, data: dict | None = None) -> dict | None:
        """Send one non-blocking HTTP request to the Gateway API.

        Returns the decoded JSON body for a 2xx response, otherwise None.
        While a backoff is active, every path except ``/health`` is skipped
        (returns None). Each failure records in ``_last_failure_retryable``
        whether retrying later could succeed.
        """
        if self.is_rate_limited and path != "/health":
            self._last_failure_retryable = True
            return None
        session = await self._get_session()
        url = f"{self.base_url}{path}"
        kwargs = {} if data is None else {"json": data}
        try:
            async with session.request(method, url, **kwargs) as resp:
                if resp.status >= 400:
                    # Timeouts, rate limits and server errors are transient;
                    # other 4xx (bad request, unauthorized, ...) will fail again.
                    self._last_failure_retryable = resp.status in (408, 429) or resp.status >= 500
                    if resp.status == 401:
                        logger.error("RemoteTransport: 401 Unauthorized — check GATEWAY_API_KEY")
                    elif resp.status == 429:
                        self._apply_backoff(self._parse_retry_after(resp.headers.get("Retry-After")))
                    else:
                        logger.warning(f"RemoteTransport: {method} {path} → {resp.status}")
                    return None
                result = await resp.json()
        except Exception as e:
            # Network error, timeout, or a body that is not JSON: count as unreachable.
            self._last_failure_retryable = True
            self._consecutive_failures += 1
            if self._consecutive_failures == self._max_failures_before_warning:
                logger.error(f"RemoteTransport: ❌ Gateway unreachable ({self._consecutive_failures} failures): {e}")
            elif self._consecutive_failures < self._max_failures_before_warning:
                logger.warning(f"RemoteTransport: {method} {path} → {e}")
            self.connected = False
            # Repeated connection failures also trigger backoff.
            if self._consecutive_failures >= self._max_failures_before_warning:
                self._apply_backoff()
            return None
        if not self.connected:
            logger.info("RemoteTransport: ✅ Gateway connected")
        self.connected = True
        self._consecutive_failures = 0
        self._reset_backoff()
        return result

    def _enqueue_retry(self, item: tuple[str, str, dict | None]) -> None:
        """Append *item* to the retry queue, or drop it with a warning if full."""
        if len(self._retry_queue) < self._retry_queue_max:
            self._retry_queue.append(item)
        else:
            logger.warning(f"RemoteTransport: retry queue full ({self._retry_queue_max}), dropping {item[0]} {item[1]}")

    async def _arequest_retry(self, method: str, path: str, data: dict | None = None) -> dict | None:
        """Like _arequest(), but queue a failed POST for flush_retry_queue().

        Only transiently failed POSTs are queued; requests the Gateway
        rejected permanently would fail identically on every retry.
        """
        result = await self._arequest(method, path, data)
        if result is None and method == "POST" and data is not None and self._last_failure_retryable:
            self._enqueue_retry((method, path, data))
        return result

    async def flush_retry_queue(self):
        """Re-send queued requests in order while the Gateway is connected.

        Stops at the first transient failure and puts that request and every
        request after it back at the front of the queue, preserving order
        (previously the requests after it were silently lost). Requests the
        Gateway rejects permanently are dropped so they cannot block the queue.
        """
        if not self._retry_queue or not self.connected:
            return
        queue = self._retry_queue[:]
        self._retry_queue.clear()
        succeeded = 0
        for i, (method, path, data) in enumerate(queue):
            result = await self._arequest(method, path, data)
            if result is not None:
                succeeded += 1
                continue
            if self._last_failure_retryable:
                # Ahead of anything queued while we were awaiting; trim to capacity.
                self._retry_queue[:0] = queue[i:]
                del self._retry_queue[self._retry_queue_max:]
                break
            logger.warning(f"RemoteTransport: dropping queued {method} {path} (rejected by Gateway)")
        if succeeded:
            logger.info(f"[RETRY] flushed {succeeded}/{len(queue)} queued requests")

    async def health_check(self) -> bool:
        """Return True if GET /health answers with a JSON object whose status is "ok"."""
        result = await self._arequest("GET", "/health")
        return isinstance(result, dict) and result.get("status") == "ok"

    # ─── Async methods (used by Collector) ───

    async def awrite_json(self, subdir: str, filename: str, data: dict) -> None:
        """POST a pending request or a response to the Gateway (queued on failure).

        Other subdirs are ignored. The response id comes from
        ``data["request_id"]``, falling back to the filename without ``.json``.
        """
        if subdir == "pending":
            await self._arequest_retry("POST", "/api/pending", data)
        elif subdir == "response":
            rid = data.get("request_id", filename.removesuffix(".json"))
            await self._arequest_retry("POST", f"/api/response/{rid}", data)

    async def aread_json(self, subdir: str, filename: str) -> dict | None:
        """Fetch a response by id from the Gateway; other subdirs return None."""
        if subdir != "response":
            return None
        rid = filename.removesuffix(".json")
        return await self._arequest("GET", f"/api/response/{rid}")

    async def apoll_commands(self, project: str) -> list[dict]:
        """Return the commands the Gateway holds for *project* ([] on failure)."""
        result = await self._arequest("GET", f"/api/commands/{project}")
        if result and isinstance(result, dict):
            return result.get("commands", [])
        return []

    async def aregister_session(self, conv_id: str, project: str) -> None:
        """Register a conversation/project pair with the Gateway (queued on failure)."""
        await self._arequest_retry("POST", "/api/register", {
            "conversation_id": conv_id, "project_name": project,
        })

    async def asend_chat(self, project: str, content: str) -> None:
        """Send chat content for *project* to the Gateway (queued on failure)."""
        await self._arequest_retry("POST", "/api/chat", {
            "project_name": project, "content": content,
        })

    async def asend_event(self, event_data: dict) -> None:
        """Send an arbitrary event payload to the Gateway (queued on failure)."""
        await self._arequest_retry("POST", "/api/event", event_data)

    # ─── Sync stubs (ABC compliance, not used in Collector) ───

    def list_json_files(self, subdir: str) -> list[str]:
        return []

    def read_json(self, subdir: str, filename: str) -> dict | None:
        return None

    def write_json(self, subdir: str, filename: str, data: dict) -> None:
        pass

    def delete_file(self, subdir: str, filename: str) -> bool:
        return True

    def ensure_dirs(self) -> None:
        pass
# ─── Bridge Protocol (uses Transport) ───
class BridgeProtocol:
    """Manages the bridge protocol via a pluggable transport."""

    # Pending requests older than this (seconds) are marked expired by
    # get_pending_requests(); matches the Discord button timeout.
    PENDING_MAX_AGE = 1800
    # Last command id issued by write_command(), in ms; keeps ids unique.
    _last_cmd_ms = 0

    def __init__(self, transport: BridgeTransport | None = None):
        """Set up the protocol over *transport* (default: LocalTransport in Config's bridge dir).

        Creates the bridge subdirectories and purges stale pending files.
        """
        if transport is None:
            bridge_dir = Config.BRAIN_PATH.parent / "bridge"
            transport = LocalTransport(bridge_dir)
        self.transport = transport
        # Legacy attributes for backward compatibility
        # (bot.py uses self.bridge.pending_dir etc. in some places)
        if isinstance(transport, LocalTransport):
            self.bridge_dir = transport.bridge_dir
            self.pending_dir = transport.bridge_dir / "pending"
            self.response_dir = transport.bridge_dir / "response"
            self.commands_dir = transport.bridge_dir / "commands"
        # Ensure directories exist
        self.transport.ensure_dirs()
        # Startup cleanup: purge stale pending files (> 5 min old)
        self._cleanup_stale_pending()
        logger.info(f"Bridge protocol initialized: transport={type(transport).__name__}")

    @staticmethod
    def _age_seconds(data: dict, now: float) -> float:
        """Seconds elapsed since ``data["timestamp"]``.

        A missing or non-numeric timestamp counts as epoch 0 (i.e. very old)
        instead of raising TypeError.
        """
        try:
            return now - float(data.get("timestamp", 0))
        except (TypeError, ValueError):
            return now

    @staticmethod
    def _to_request(data: dict) -> ApprovalRequest:
        """Build an ApprovalRequest from *data*, ignoring unknown keys.

        Raises:
            TypeError: if a required field is missing.
        """
        known = ApprovalRequest.__dataclass_fields__
        return ApprovalRequest(**{k: v for k, v in data.items() if k in known})

    def _cleanup_stale_pending(self, max_age_seconds: int = 300):
        """Remove unreadable pending files and those older than *max_age_seconds*."""
        now = time.time()
        cleaned = 0
        for fname in self.transport.list_json_files("pending"):
            data = self.transport.read_json("pending", fname)
            if data is None or self._age_seconds(data, now) > max_age_seconds:
                # Count only files that were actually removed.
                if self.transport.delete_file("pending", fname):
                    cleaned += 1
        if cleaned:
            logger.info(f"Startup cleanup: removed {cleaned} stale pending files")

    def get_pending_requests(self) -> list[ApprovalRequest]:
        """Return every request whose status is "pending" and that is not too old.

        A pending request older than PENDING_MAX_AGE is rewritten once with
        status "expired" and skipped. Files in any other status are left
        untouched (previously they were rewritten as "expired" on every poll).
        Unreadable files and entries missing required fields are skipped.
        """
        requests: list[ApprovalRequest] = []
        now = time.time()
        for fname in self.transport.list_json_files("pending"):
            data = self.transport.read_json("pending", fname)
            if data is None or data.get("status") != "pending":
                continue
            if self._age_seconds(data, now) > self.PENDING_MAX_AGE:
                data["status"] = "expired"
                self.transport.write_json("pending", fname, data)
                continue
            try:
                requests.append(self._to_request(data))
            except TypeError as e:
                logger.warning(f"Bad pending request {fname}: {e}")
        return requests

    def read_pending_request(self, request_id: str) -> ApprovalRequest | None:
        """Re-read ``pending/<request_id>.json`` (to get merged data).

        Returns None if the file is missing/unreadable or lacks required fields.
        """
        data = self.transport.read_json("pending", f"{request_id}.json")
        if data is None:
            return None
        try:
            return self._to_request(data)
        except TypeError:
            return None

    def write_response(self, response: UserResponse):
        """Write *response* to ``response/<request_id>.json`` and delete its pending file.

        Side effect: sets ``response.timestamp`` to the current time.
        """
        response.timestamp = time.time()
        fname = f"{response.request_id}.json"
        self.transport.write_json("response", fname, asdict(response))
        logger.info(f"Response written: {fname} (approved={response.approved})")
        # Delete pending file after processing (prevents re-processing and accumulation)
        self.transport.delete_file("pending", fname)

    def write_command(self, conversation_id: str, text: str, *, project_name: str = ""):
        """Write a user text command for Antigravity to consume and return its id.

        The id is the current time in milliseconds, bumped past the previous
        id when needed so that two commands written within the same
        millisecond get distinct files instead of overwriting each other.
        """
        now = time.time()
        cmd_ms = max(int(now * 1000), self._last_cmd_ms + 1)
        self._last_cmd_ms = cmd_ms
        cmd_id = str(cmd_ms)
        data = {
            "id": cmd_id,
            "conversation_id": conversation_id,
            "project_name": project_name,
            "text": text,
            "timestamp": now,
            "consumed": False,
        }
        self.transport.write_json("commands", f"{cmd_id}.json", data)
        logger.info(f"Command written: {cmd_id} → project={project_name}")
        return cmd_id