# gravity_control/bot.py

"""Discord bot — relays Antigravity brain events to Discord channels.
Multi-project channel architecture:
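- One channel per project, named "{CHANNEL_PREFIX}-{project_name}" in lowercase
  (e.g. ag-gravity_control, ag-deriva)
- Each conversation maps to a project via the conv_to_project dict
- The extension registers conversation→project mappings via bridge/register/ files;
  mappings are also learned from approval requests in bridge/pending/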
- Commands include project_name for routing to correct IDE window
"""
import asyncio
import json
import logging
import time
from datetime import datetime, timezone
from pathlib import Path
import discord
from discord.ext import commands, tasks
from config import Config
from parser import (
parse_task_progress,
md_to_discord_text,
format_task_embed_text,
)
from watcher import BrainEvent, EventType
from bridge import BridgeProtocol, ApprovalRequest, UserResponse
logger = logging.getLogger(__name__)
# ─── Discord UI Components ──────────────────────────────────────────
class ApprovalView(discord.ui.View):
    """Approve/reject buttons attached to a Discord approval-request message.

    At most one UserResponse is written to the bridge per view. It comes either
    from the first button click, or, if nobody clicks within 30 minutes, from
    on_timeout as a rejection.
    """

    def __init__(self, bridge: BridgeProtocol, request: ApprovalRequest):
        super().__init__(timeout=1800)  # 30 minutes
        self.bridge = bridge
        self.request = request
        self.responded = False  # True once a response has been written to the bridge

    @discord.ui.button(label="✅ 승인", style=discord.ButtonStyle.green)
    async def approve(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Approve the request and mark the message green."""
        await self._respond(
            interaction, approved=True, color=discord.Color.green(), status="✅ 승인됨"
        )

    @discord.ui.button(label="❌ 거부", style=discord.ButtonStyle.red)
    async def reject(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Reject the request and mark the message red."""
        await self._respond(
            interaction, approved=False, color=discord.Color.red(), status="❌ 거부됨"
        )

    async def _respond(
        self,
        interaction: discord.Interaction,
        *,
        approved: bool,
        color: discord.Color,
        status: str,
    ):
        """Write the user's decision once, then recolor the embed and drop the buttons.

        Args:
            interaction: the button-click interaction.
            approved: decision written to the bridge.
            color: new embed color.
            status: footer prefix; the footer becomes "<status> by <user>".
        """
        if self.responded:
            await interaction.response.send_message("이미 응답됨", ephemeral=True)
            return
        # Set the flag before any await so a concurrent second click sees it.
        self.responded = True
        self.bridge.write_response(UserResponse(
            request_id=self.request.request_id, approved=approved,
        ))
        # The decision is final: stop listening, which also cancels on_timeout.
        self.stop()
        embed = interaction.message.embeds[0] if interaction.message.embeds else None
        if embed:
            embed.color = color
            embed.set_footer(text=f"{status} by {interaction.user.display_name}")
        await interaction.response.edit_message(embed=embed, view=None)

    async def on_timeout(self):
        """Reject the request if nobody answered before the view timed out."""
        if self.responded:
            return
        self.responded = True
        self.bridge.write_response(UserResponse(
            request_id=self.request.request_id, approved=False,
        ))
# ─── Bot ─────────────────────────────────────────────────────────────
class GravityBot(commands.Bot):
    """Discord bot for Antigravity session monitoring.

    Multi-project architecture:
    - project_channels: project_name → TextChannel (ag-gravity_control, ag-deriva, etc.)
    - conv_to_project: conversation_id → project_name (read from bridge/register/
      files and learned from pending approvals)
    - channel_to_project: channel_id → project_name (for Discord→IDE routing)

    Work is gated on ``_ready_event``. The brain-event processor and both bridge
    scanners wait until ``on_ready`` has resolved the guild and session category.
    """

    def __init__(self, event_queue: asyncio.Queue):
        """Create the bot.

        Args:
            event_queue: queue of BrainEvent objects consumed by _process_events.
        """
        intents = discord.Intents.default()
        intents.message_content = True  # needed to relay channel text to the IDE
        intents.guilds = True
        super().__init__(command_prefix="!", intents=intents)
        self.event_queue = event_queue
        self.project_channels: dict[str, discord.TextChannel] = {}  # project → channel
        self.conv_to_project: dict[str, str] = {}  # conv_id → project
        self.channel_to_project: dict[int, str] = {}  # channel.id → project
        self.session_status_messages: dict[str, int] = {}  # conv_id → task embed msg_id
        self._sent_approval_ids: set[str] = set()  # approvals posted by this process
        self._deferred_ids: dict[str, int] = {}  # request_id → defer count
        self._last_reg_count = -1  # registration count last logged (-1: never)
        self._event_task: asyncio.Task | None = None  # strong ref to _process_events
        self._ready_event = asyncio.Event()  # opened at the end of on_ready
        self._channel_lock = asyncio.Lock()  # serializes channel lookup/creation
        self.bridge = BridgeProtocol()
        self.session_category: discord.CategoryChannel | None = None
        self.guild: discord.Guild | None = None

    @staticmethod
    def _make_channel_name(project_name: str) -> str:
        """Return the lowercase channel name for a project, e.g. ag-gravity_control."""
        return f"{Config.CHANNEL_PREFIX}-{project_name}".lower()

    @staticmethod
    def _split_chunks(text: str, size: int = 4000) -> list[str]:
        """Split text into consecutive pieces of at most size characters.

        Returns [] for an empty string. 4000 stays under Discord's 4096-char
        embed description limit.
        """
        return [text[i:i + size] for i in range(0, len(text), size)]

    async def setup_hook(self):
        """Start the event processor and scanners and register slash commands."""
        # Keep a reference: the event loop only holds weak references to tasks,
        # so an unreferenced task may be garbage-collected mid-execution.
        self._event_task = self.loop.create_task(self._process_events())
        self.pending_approval_scanner.start()
        self.chat_snapshot_scanner.start()
        self._register_slash_commands()
        logger.info("Bot setup complete")

    async def _project_for_interaction(
        self, interaction: discord.Interaction
    ) -> str | None:
        """Return the project bound to the interaction's channel.

        When the channel is not a project channel, replies with an ephemeral
        warning and returns a falsy value.
        """
        project = self.channel_to_project.get(interaction.channel_id)
        if not project:
            await interaction.response.send_message(
                "⚠️ 프로젝트 채널이 아닙니다.", ephemeral=True
            )
        return project

    def _register_slash_commands(self):
        """Register the /stop, /auto and /send slash commands on the command tree."""

        @self.tree.command(name="stop", description="AI 작업 중지")
        async def slash_stop(interaction: discord.Interaction):
            project = await self._project_for_interaction(interaction)
            if not project:
                return
            self.bridge.write_command(project, "!stop", project_name=project)
            await interaction.response.send_message(
                embed=discord.Embed(
                    title="⏹️ AI 작업 중지",
                    description=f"**{project}** IDE에 중지 요청 전달됨",
                    color=discord.Color.orange(),
                )
            )

        @self.tree.command(name="auto", description="자동 승인 토글")
        async def slash_auto(interaction: discord.Interaction, mode: str):
            project = await self._project_for_interaction(interaction)
            if not project:
                return
            # Anything other than on/true/1 (case-insensitive) turns auto-approve off.
            enabled = mode.lower() in ("on", "true", "1")
            self.bridge.write_command(
                project, f"!auto {'on' if enabled else 'off'}", project_name=project
            )
            emoji = "🟢" if enabled else "🔴"
            await interaction.response.send_message(
                embed=discord.Embed(
                    title=f"{emoji} {'자동 승인' if enabled else '수동 승인'} 모드",
                    description=f"프로젝트: **{project}**",
                    color=discord.Color.green() if enabled else discord.Color.red(),
                )
            )

        @self.tree.command(name="send", description="IDE 채팅에 메시지 전송")
        async def slash_send(interaction: discord.Interaction, message: str):
            project = await self._project_for_interaction(interaction)
            if not project:
                return
            self.bridge.write_command(project, message, project_name=project)
            await interaction.response.send_message(
                embed=discord.Embed(
                    description=f"📨 → **{project}** IDE에 전달됨\n`{message[:100]}`",
                    color=discord.Color.blurple(),
                ),
                delete_after=10,
            )

    async def on_ready(self):
        """Resolve guild/category, discover channels, sync commands, open the ready gate.

        Returns early (leaving the gate closed) if the configured guild is not
        found or the category cannot be created.
        """
        logger.info(f"Bot connected as {self.user} (ID: {self.user.id})")
        self.guild = self.get_guild(Config.DISCORD_GUILD_ID)
        if not self.guild:
            logger.error(f"Guild {Config.DISCORD_GUILD_ID} not found!")
            return
        # Find or create category
        category_name = "Antigravity Sessions"
        self.session_category = discord.utils.get(
            self.guild.categories, name=category_name
        )
        if not self.session_category:
            try:
                self.session_category = await self.guild.create_category(category_name)
                logger.info(f"Created category: {category_name}")
            except discord.errors.Forbidden:
                logger.error("No permission to create category!")
                return
        # Discover existing project channels
        await self._discover_channels()
        # Load conversation → project registrations from Extension
        self._load_registrations()
        # Sync slash commands to guild
        try:
            self.tree.copy_global_to(guild=self.guild)
            synced = await self.tree.sync(guild=self.guild)
            logger.info(f"Synced {len(synced)} slash commands to guild")
        except Exception as e:
            logger.warning(f"Slash command sync failed: {e}")
        # Open the gate
        self._ready_event.set()
        logger.info("Ready gate opened — event processing enabled")
        # Restart scanner loops if they are not running (e.g. they stopped on an error)
        if not self.pending_approval_scanner.is_running():
            self.pending_approval_scanner.start()
        if not self.chat_snapshot_scanner.is_running():
            self.chat_snapshot_scanner.start()
        logger.info("Scanner loops started")

    # ─── Registrations ───────────────────────────────────────────────
    def _load_registrations(self):
        """Merge bridge/register/*.json conversation → project mappings.

        Each file is expected to be a JSON object with "conversation_id" and
        "project_name". Unreadable or malformed files are skipped; they are
        re-read on the next call.
        """
        register_dir = self.bridge.bridge_dir / "register"
        if not register_dir.exists():
            return
        count = 0
        for f in register_dir.glob("*.json"):
            try:
                data = json.loads(f.read_text(encoding="utf-8-sig"))
            except (ValueError, OSError):
                # ValueError covers JSONDecodeError and UnicodeDecodeError
                # (e.g. the file is still being written).
                continue
            if not isinstance(data, dict):
                continue
            conv_id = data.get("conversation_id", "")
            project = data.get("project_name", "")
            if conv_id and project:
                self.conv_to_project[conv_id] = project
                count += 1
        # Only log when count changes (this runs every scanner cycle)
        if count != self._last_reg_count:
            self._last_reg_count = count
            if count:
                logger.info(f"Loaded {count} conversation→project registrations")

    # ─── Channel Management ──────────────────────────────────────────
    def _remember_channel(self, project_name: str, ch: discord.TextChannel):
        """Record ch as the channel for project_name in both lookup maps."""
        self.project_channels[project_name] = ch
        self.channel_to_project[ch.id] = project_name

    def _forget_channel(self, project_name: str):
        """Drop a (deleted) project channel from both lookup maps."""
        ch = self.project_channels.pop(project_name, None)
        if ch is not None:
            self.channel_to_project.pop(ch.id, None)

    async def _discover_channels(self):
        """Find existing project channels via Discord API (not cache)."""
        all_channels = await self.guild.fetch_channels()
        prefix = Config.CHANNEL_PREFIX.lower() + "-"
        for ch in all_channels:
            if (isinstance(ch, discord.TextChannel)
                    and ch.category_id == self.session_category.id
                    and ch.name.startswith(prefix)):
                project = ch.name[len(prefix):]
                self._remember_channel(project, ch)
                logger.info(f"Found channel: #{ch.name} → project={project}")
        logger.info(f"Discovered {len(self.project_channels)} project channels")

    async def _get_channel(self, project_name: str) -> discord.TextChannel | None:
        """Get or create the text channel for a project.

        Lookup order: in-memory cache, then an existing channel with the expected
        name in the session category, then a newly created one. Lookup and
        creation run under a lock so concurrent callers do not create duplicates.

        Returns:
            The channel, or None if the guild is not resolved yet or the bot
            lacks permission to create the channel.
        """
        if project_name in self.project_channels:
            return self.project_channels[project_name]
        if self.guild is None:
            # on_ready has not resolved the guild (yet); nothing can be created.
            return None
        async with self._channel_lock:
            # Double-check after lock
            if project_name in self.project_channels:
                return self.project_channels[project_name]
            channel_name = self._make_channel_name(project_name)
            # Search existing channels FIRST (prevents duplicates)
            try:
                all_channels = await self.guild.fetch_channels()
                for ch in all_channels:
                    if (isinstance(ch, discord.TextChannel)
                            and ch.name == channel_name
                            and ch.category_id == self.session_category.id):
                        self._remember_channel(project_name, ch)
                        logger.info(f"Found existing channel: #{channel_name}")
                        return ch
            except Exception as e:
                logger.warning(f"fetch_channels failed: {e}")
            # No existing channel — create new
            try:
                ch = await self.guild.create_text_channel(
                    name=channel_name,
                    category=self.session_category,
                    topic=f"Antigravity Bridge — {project_name}",
                )
            except discord.errors.Forbidden:
                logger.error(f"No permission to create channel: {channel_name}")
                return None
            self._remember_channel(project_name, ch)
            logger.info(f"Created channel: #{channel_name}")
            embed = discord.Embed(
                title=f"🚀 {project_name}",
                description="Antigravity Bridge 연결됨",
                color=discord.Color.blue(),
                timestamp=datetime.now(timezone.utc),
            )
            try:
                await ch.send(embed=embed)
            except discord.HTTPException as e:
                # The channel exists and is cached; a failed greeting is not fatal.
                logger.warning(f"Could not post greeting in #{channel_name}: {e}")
            return ch

    def _resolve_project(self, conversation_id: str) -> str:
        """Get project name for a conversation. Falls back to Config.PROJECT_NAME."""
        return self.conv_to_project.get(
            conversation_id, Config.PROJECT_NAME
        )

    # ─── Event Processing ─────────────────────────────────────────────
    async def _process_events(self):
        """Main event loop — ALL brain events go through here sequentially."""
        await self.wait_until_ready()
        await self._ready_event.wait()
        logger.info("Event processor started (ready gate passed)")
        while not self.is_closed():
            try:
                event = await asyncio.wait_for(
                    self.event_queue.get(), timeout=5.0
                )
            except asyncio.TimeoutError:
                # Periodic wake-up so is_closed() is re-checked.
                continue
            # Handled separately so a TimeoutError raised while handling an
            # event is logged rather than mistaken for an empty queue.
            try:
                await self._handle_event(event)
            except Exception as e:
                logger.error(f"Error processing event: {e}", exc_info=True)

    async def _handle_event(self, event: BrainEvent):
        """Route a brain event to its project channel.

        SESSION_START events only ensure the channel exists. task.md events
        update the session's task embed, and any other file is posted as an
        artifact.
        """
        project = self._resolve_project(event.conversation_id)
        channel = await self._get_channel(project)
        if not channel:
            return
        if event.event_type == EventType.SESSION_START:
            return
        try:
            if event.file_name == "task.md":
                await self._send_task_update(channel, event)
            else:
                await self._send_artifact_update(channel, event)
        except discord.NotFound:
            self._forget_channel(project)
            logger.warning(f"Channel deleted for project {project}, will recreate")

    # ─── Message Senders ─────────────────────────────────────────────
    async def _send_task_update(
        self, channel: discord.TextChannel, event: BrainEvent
    ):
        """Post or edit the per-session task-progress embed.

        One message per conversation is edited in place. A new one is sent if
        there is none yet or the old one can no longer be fetched or edited.
        """
        progress = parse_task_progress(event.content)
        # Full task content (truncated to embed limit)
        full_content = event.content.strip()
        description = format_task_embed_text(progress) + "\n\n" + full_content
        if len(description) > 4000:
            description = description[:4000] + "\n…(truncated)"
        embed = discord.Embed(
            title="📋 Task 진행 현황",
            description=description,
            color=discord.Color.gold() if progress.in_progress > 0
            else discord.Color.green() if progress.done == progress.total
            else discord.Color.greyple(),
            timestamp=datetime.now(timezone.utc),
        )
        embed.set_footer(text=f"Session: {event.conversation_id[:8]}")
        msg_id = self.session_status_messages.get(event.conversation_id)
        if msg_id:
            try:
                msg = await channel.fetch_message(msg_id)
                await msg.edit(embed=embed)
                return
            except discord.HTTPException:  # includes NotFound
                pass
        msg = await channel.send(embed=embed)
        self.session_status_messages[event.conversation_id] = msg.id

    async def _send_artifact_update(
        self, channel: discord.TextChannel, event: BrainEvent
    ):
        """Post an artifact file's content as one or more blue embeds."""
        labels = {
            "implementation_plan.md": "📐 구현 계획",
            "walkthrough.md": "📝 작업 결과 요약",
        }
        label = labels.get(event.file_name, f"📄 {event.file_name}")
        event_label = "생성" if event.event_type == EventType.FILE_CREATED else "업데이트"
        # Split into chunks for long content (Discord embed desc limit is 4096)
        chunks = self._split_chunks(event.content.strip()) or ["(빈 파일)"]
        # First chunk with title
        embed = discord.Embed(
            title=f"{label} ({event_label}됨)",
            description=chunks[0],
            color=discord.Color.blue(),
            timestamp=datetime.now(timezone.utc),
        )
        embed.set_footer(text=f"Session: {event.conversation_id[:8]}")
        await channel.send(embed=embed)
        # Additional chunks if content is long
        for i, chunk in enumerate(chunks[1:], 2):
            embed = discord.Embed(
                title=f"{label} (계속 {i}/{len(chunks)})",
                description=chunk,
                color=discord.Color.blue(),
            )
            await channel.send(embed=embed)

    # ─── Approval Scanner ────────────────────────────────────────────
    @tasks.loop(seconds=3)
    async def pending_approval_scanner(self):
        """Reload registrations, post new approval requests, finalize auto-resolved ones."""
        try:
            # Reload conv→project registrations each cycle
            self._load_registrations()
            await self._ensure_registered_channels()
            await self._post_new_approvals()
            await self._finalize_auto_resolved()
        except Exception as e:
            logger.error(f"Error scanning approvals: {e}")

    @pending_approval_scanner.before_loop
    async def before_scanner(self):
        # Wait for the ready gate too: until on_ready resolves the guild and
        # category, _get_channel cannot look up or create channels.
        await self.wait_until_ready()
        await self._ready_event.wait()

    async def _ensure_registered_channels(self):
        """Ensure a channel exists for every project seen in conv_to_project."""
        for project in set(self.conv_to_project.values()):
            if project in self.project_channels:
                continue
            if await self._get_channel(project):
                logger.info(f"Auto-created channel for registered project: {project}")

    async def _post_new_approvals(self):
        """Post pending approval requests that have no Discord message yet."""
        requests = self.bridge.get_pending_requests()
        # Drop defer counters for requests that are no longer pending.
        live_ids = {req.request_id for req in requests}
        for stale_id in self._deferred_ids.keys() - live_ids:
            del self._deferred_ids[stale_id]
        for req in requests:
            if req.request_id in self._sent_approval_ids:
                continue
            if req.discord_message_id != 0:
                continue
            # Defer short-command pendings (e.g. "Run") by 4 cycles (~12s)
            # to give step_probe time to merge detailed command info
            # (step_probe MERGE happens ~10s after pending creation)
            if len(req.command) <= 15:
                count = self._deferred_ids.get(req.request_id)
                if count is None:
                    self._deferred_ids[req.request_id] = 1
                    continue  # skip this cycle
                if count < 4:
                    self._deferred_ids[req.request_id] = count + 1
                    # Re-read from file (step_probe may have merged)
                    fresh = self.bridge.read_pending_request(req.request_id)
                    if fresh and len(fresh.command) > 15:
                        req = fresh  # use merged version — send now!
                    else:
                        continue  # wait one more cycle
            # Clean up defer tracking
            self._deferred_ids.pop(req.request_id, None)
            # Learn project mapping from pending approval
            project = req.project_name or Config.PROJECT_NAME
            if req.conversation_id and req.conversation_id != '__global__':
                self.conv_to_project[req.conversation_id] = project
            channel = await self._get_channel(project)
            if channel:
                self._sent_approval_ids.add(req.request_id)
                await self._send_approval_request(channel, req)

    async def _finalize_auto_resolved(self):
        """Mark approvals resolved directly in AG and delete their pending files."""
        for f in self.bridge.pending_dir.glob("*.json"):
            try:
                data = json.loads(f.read_text(encoding="utf-8-sig"))
            except (ValueError, OSError):
                continue  # unreadable right now; retried next cycle
            if not isinstance(data, dict) or data.get("status") != "auto_resolved":
                continue
            msg_id = data.get("discord_message_id", 0)
            project = data.get("project_name", Config.PROJECT_NAME)
            if msg_id:
                channel = await self._get_channel(project)
                if channel:
                    try:
                        msg = await channel.fetch_message(msg_id)
                        embed = discord.Embed(
                            title="✅ AG에서 직접 승인됨",
                            description=f"```\n{data.get('command', '')[:500]}\n```",
                            color=discord.Color.green(),
                        )
                        embed.set_footer(text=f"ID: {data.get('request_id', '')}")
                        await msg.edit(embed=embed, view=None)
                    except discord.NotFound:
                        pass  # message already gone; nothing to update
                    except discord.HTTPException as e:
                        # Keep the file so the edit is retried, without
                        # aborting the scan of the other pending files.
                        logger.warning(f"Could not update approval message {msg_id}: {e}")
                        continue
            try:
                f.unlink()
            except OSError:
                continue  # already removed or locked by the extension
            self._deferred_ids.pop(data.get("request_id", ""), None)

    async def _send_approval_request(
        self, channel: discord.TextChannel, request: ApprovalRequest
    ):
        """Post an approval embed with buttons and record its message id in the pending file."""
        embed = discord.Embed(
            title="⚠️ 승인 요청",
            description=(
                f"**명령어:**\n```\n{request.command[:1000]}\n```\n"
                f"{request.description[:500]}"
            ),
            color=discord.Color.orange(),
            timestamp=datetime.now(timezone.utc),
        )
        embed.set_footer(text=f"ID: {request.request_id}")
        view = ApprovalView(self.bridge, request)
        msg = await channel.send(embed=embed, view=view)
        pending_file = self.bridge.pending_dir / f"{request.request_id}.json"
        if pending_file.exists():
            try:
                data = json.loads(pending_file.read_text(encoding="utf-8-sig"))
                data["discord_message_id"] = msg.id
                pending_file.write_text(
                    json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8"
                )
            except (ValueError, TypeError, OSError) as e:
                # Not fatal within this process: the scanner's _sent_approval_ids
                # check still prevents a re-post.
                logger.warning(f"Could not record message id in {pending_file.name}: {e}")
        logger.info(f"Sent approval request: {request.request_id[:12]}")

    # ─── Discord → IDE Text Relay ─────────────────────────────────────
    async def on_message(self, message: discord.Message):
        """Relay text typed in a project channel to that project's IDE.

        "!stop" and "!auto on"/"!auto off" are forwarded with a confirmation
        embed. Other non-empty text is forwarded and acknowledged with a 📨
        reaction and a short-lived embed. Messages outside project channels
        only go through normal command processing.
        """
        if message.author == self.user:
            return  # never relay the bot's own messages
        # Determine project from channel
        project = self.channel_to_project.get(message.channel.id)
        if not project:
            await self.process_commands(message)
            return
        text = message.content.strip()
        # Special command: !stop — cancel AI work
        if text == "!stop":
            self.bridge.write_command(project, "!stop", project_name=project)
            embed = discord.Embed(
                title="⏹️ AI 작업 중지",
                description=f"프로젝트: **{project}**\n중지 요청을 Extension에 전달했습니다.",
                color=discord.Color.orange(),
            )
            await message.channel.send(embed=embed)
            return
        # Special command: !auto on/off
        if text in ("!auto on", "!auto off"):
            self.bridge.write_command(project, text, project_name=project)
            enabled = text == "!auto on"
            emoji = "🟢" if enabled else "🔴"
            mode = "자동 승인" if enabled else "수동 승인"
            embed = discord.Embed(
                title=f"{emoji} {mode} 모드",
                description=f"프로젝트: **{project}**\n"
                            f"`chat.tools.autoApprove = {enabled}`\n"
                            f"`chat.agent.autoApprove = {enabled}`",
                color=discord.Color.green() if enabled else discord.Color.red(),
            )
            await message.channel.send(embed=embed)
            return
        # General text relay — routed by project
        if text:
            self.bridge.write_command(project, text, project_name=project)
            await message.add_reaction("📨")
            embed = discord.Embed(
                description=f"📨 → **{project}** IDE에 전달됨\n`{text[:100]}`",
                color=discord.Color.blurple(),
            )
            await message.channel.send(embed=embed, delete_after=10)
        await self.process_commands(message)

    # ─── Chat Snapshot Scanner ─────────────────────────────────────────
    @tasks.loop(seconds=5)
    async def chat_snapshot_scanner(self):
        """Post bridge/chat_snapshots/*.json AI response dumps, then delete them.

        Each snapshot's "content" goes to its project's channel. Unreadable
        snapshots are logged and left in place.
        """
        try:
            snapshot_dir = self.bridge.bridge_dir / "chat_snapshots"
            if not snapshot_dir.exists():
                return
            for f in snapshot_dir.glob("*.json"):
                try:
                    data = json.loads(f.read_text(encoding="utf-8-sig"))
                    project = data.get("project_name", Config.PROJECT_NAME)
                    content = data.get("content", "")
                    if content:
                        channel = await self._get_channel(project)
                        if channel:
                            await self._post_chat_snapshot(project, channel, content)
                    f.unlink()  # Cleanup
                except (ValueError, OSError) as e:
                    logger.warning(f"Bad chat snapshot {f.name}: {e}")
        except Exception as e:
            logger.error(f"Error scanning chat snapshots: {e}")

    @chat_snapshot_scanner.before_loop
    async def before_chat_scanner(self):
        # Same ready gate as the approval scanner (see before_scanner).
        await self.wait_until_ready()
        await self._ready_event.wait()

    async def _post_chat_snapshot(
        self, project: str, channel: discord.TextChannel, content: str
    ):
        """Post content as purple embeds, split into 4000-char chunks.

        If the channel turns out to be deleted, it is forgotten and re-created
        once. The current chunk and all remaining chunks are then posted to the
        new channel; previously the remaining chunks were dropped.
        """
        chunks = self._split_chunks(content)
        for i, chunk in enumerate(chunks):
            title = "💬 AI 대화 내용" if i == 0 else f"💬 (계속 {i+1}/{len(chunks)})"
            embed = discord.Embed(
                title=title,
                description=chunk,
                color=discord.Color.purple(),
                timestamp=datetime.now(timezone.utc),
            )
            try:
                await channel.send(embed=embed)
            except discord.NotFound:
                # Channel was deleted — invalidate cache and retry once
                logger.warning(f"Channel deleted for {project}, re-creating...")
                self._forget_channel(project)
                channel = await self._get_channel(project)
                if channel is None:
                    return
                await channel.send(embed=embed)