Files
gravity_control/bot.py

488 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Discord bot — relays Antigravity brain events to Discord channels.
Dynamic channel management:
- Creates `AG-{project_name}` channels only when file events arrive
- NO startup channel creation — only reconnects to existing Discord channels
- Archives channels after 10 minutes of inactivity
"""
import asyncio
import json
import logging
import re
import time
from datetime import datetime, timezone
from pathlib import Path
import discord
from discord.ext import commands, tasks
from config import Config
from parser import (
parse_task_progress,
md_to_discord_text,
format_task_embed_text,
)
from watcher import BrainEvent, EventType
from bridge import BridgeProtocol, ApprovalRequest, UserResponse
logger = logging.getLogger(__name__)
# ─── Discord UI Components ──────────────────────────────────────────
class ApprovalView(discord.ui.View):
    """Discord buttons for approving/rejecting an Antigravity approval request.

    The first button press decides the request; any later press only gets an
    ephemeral "already answered" notice. If nobody answers before the view's
    300-second timeout, the request is rejected automatically.
    """

    def __init__(self, bridge: BridgeProtocol, request: ApprovalRequest):
        super().__init__(timeout=300)
        self.bridge = bridge
        self.request = request
        self.responded = False
        # Optionally assigned by the sender after posting, so on_timeout can
        # remove the buttons from the message; stays None otherwise.
        self.message: discord.Message | None = None

    @discord.ui.button(label="✅ 승인", style=discord.ButtonStyle.green)
    async def approve(self, interaction: discord.Interaction, button: discord.ui.Button):
        await self._resolve(interaction, approved=True)

    @discord.ui.button(label="❌ 거부", style=discord.ButtonStyle.red)
    async def reject(self, interaction: discord.Interaction, button: discord.ui.Button):
        await self._resolve(interaction, approved=False)

    async def _resolve(self, interaction: discord.Interaction, *, approved: bool) -> None:
        """Write the user's decision to the bridge once and update the message.

        Args:
            interaction: The button interaction being answered.
            approved: True for the approve button, False for reject.
        """
        if self.responded:
            await interaction.response.send_message("이미 응답됨", ephemeral=True)
            return
        # Flip the flag before any await so a second click cannot slip through.
        self.responded = True
        self.bridge.write_response(UserResponse(
            request_id=self.request.request_id, approved=approved,
        ))
        # The decision is final: stop listening (this also cancels on_timeout).
        self.stop()
        embed = interaction.message.embeds[0] if interaction.message.embeds else None
        if embed:
            if approved:
                embed.color = discord.Color.green()
                embed.set_footer(text=f"✅ 승인됨 by {interaction.user.display_name}")
            else:
                embed.color = discord.Color.red()
                embed.set_footer(text=f"❌ 거부됨 by {interaction.user.display_name}")
        await interaction.response.edit_message(embed=embed, view=None)

    async def on_timeout(self):
        """Reject the request automatically when nobody answered in time."""
        if self.responded:
            return
        self.responded = True
        self.bridge.write_response(UserResponse(
            request_id=self.request.request_id, approved=False,
        ))
        if self.message is None:
            return
        # Best effort: the buttons no longer work, so take them off the message.
        embed = self.message.embeds[0] if self.message.embeds else None
        if embed:
            embed.color = discord.Color.red()
            embed.set_footer(text="⏰ 시간 초과 — 자동 거부됨")
        try:
            await self.message.edit(embed=embed, view=None)
        except discord.HTTPException:
            pass  # message may have been deleted; the rejection is already written
# ─── Project Name Detection ─────────────────────────────────────────
# Title fragments stripped from artifact headings before sanitizing. The longer
# "— Task Tracker" must come before "Task Tracker" or it can never match.
_TITLE_SUFFIXES = (
    "— Task Tracker", "Task Tracker", "태스크", "구현 계획",
    "Implementation Plan", "Walkthrough",
)
_HEADING_SCAN_LINES = 5     # only the first lines of an artifact are searched
_MIN_PROJECT_NAME_LEN = 5   # shorter names are treated as too generic
_MAX_PROJECT_NAME_LEN = 30
# conv_dir -> first successfully detected name. The short-id fallback is NOT
# cached, so a conversation is re-scanned until its artifacts carry a title.
_project_name_cache: "dict[Path, str]" = {}


def _sanitize_title(raw: str) -> str:
    """Turn a Markdown heading into ``lowercase_with_underscores`` (may be empty)."""
    for suffix in _TITLE_SUFFIXES:
        raw = raw.replace(suffix, "")
    raw = raw.strip(" —-")
    raw = re.sub(r'[^a-zA-Z0-9가-힣\s\-]', '', raw)
    raw = re.sub(r'[\s\-]+', '_', raw).strip('_').lower()
    # Truncation can expose a separator at the end (e.g. "foo_bar" -> "foo_").
    return raw[:_MAX_PROJECT_NAME_LEN].rstrip('_')


def detect_project_name(conv_dir: Path) -> str:
    """Extract a project name from a conversation's artifact headings.

    Searches the first five lines of ``task.md``, then ``implementation_plan.md``,
    for a level-1 Markdown heading (``# ...``) and sanitizes it into
    ``lowercase_with_underscores`` (e.g. ``'gravity_control'``). Names that are
    empty, equal to ``'task'`` or shorter than five characters are skipped.

    The FIRST successful extraction for a directory is cached and returned on
    every later call, so the name stays stable even if the heading changes.

    Args:
        conv_dir: The conversation's directory.

    Returns:
        The sanitized name (at most 30 characters), or the first eight
        characters of ``conv_dir.name`` when no usable heading is found.
    """
    cached = _project_name_cache.get(conv_dir)
    if cached is not None:
        return cached

    for fname in ("task.md", "implementation_plan.md"):
        try:
            with (conv_dir / fname).open(encoding="utf-8") as fh:
                # zip() stops after the range is exhausted, so only the first
                # few lines are read and decoded.
                head = [line for _, line in zip(range(_HEADING_SCAN_LINES), fh)]
        except (OSError, UnicodeDecodeError):
            # Missing or unreadable artifact: best effort, try the next one.
            continue
        for line in head:
            match = re.match(r'^#\s+(.+)', line)
            if not match:
                continue
            name = _sanitize_title(match.group(1))
            if name and name != "task" and len(name) >= _MIN_PROJECT_NAME_LEN:
                _project_name_cache[conv_dir] = name
                return name

    return conv_dir.name[:8]
# ─── Bot ─────────────────────────────────────────────────────────────
class GravityBot(commands.Bot):
    """Discord bot for Antigravity session monitoring.

    Consumes ``BrainEvent`` objects from ``event_queue`` one at a time and relays
    them to per-project text channels in the "Antigravity Sessions" category,
    posts bridge approval requests with buttons, and relays text typed in a
    session channel back through the bridge.
    """

    # Discord rejects embed descriptions longer than this many characters.
    _EMBED_DESCRIPTION_LIMIT = 4096
    _ARTIFACT_PREVIEW_LINES = 6
    _ARTIFACT_PREVIEW_CHARS = 1000

    def __init__(self, event_queue: asyncio.Queue):
        intents = discord.Intents.default()
        intents.message_content = True
        intents.guilds = True
        super().__init__(command_prefix="!", intents=intents)
        self.event_queue = event_queue
        # conversation_id -> channel. Several conversations may share a channel
        # when their project names resolve to the same channel name.
        self.session_channels: dict[str, discord.TextChannel] = {}
        # conversation_id -> id of the task-progress message edited in place.
        self.session_status_messages: dict[str, int] = {}
        # conversation_id -> project name the channel was bound under.
        self.session_names: dict[str, str] = {}
        self._channel_create_lock = asyncio.Lock()  # SINGLE global lock
        self._sent_approval_ids: set[str] = set()  # Approvals posted by this process
        self._ready_event = asyncio.Event()  # Gate: set once on_ready has mapped channels
        # Strong reference to the event-processor task: the event loop only
        # keeps weak references, so an unreferenced task may be collected.
        self._event_task: asyncio.Task | None = None
        self.bridge = BridgeProtocol()
        self.session_category: discord.CategoryChannel | None = None
        self.guild: discord.Guild | None = None

    async def setup_hook(self):
        """Start the event processor and the approval scanner."""
        self._event_task = asyncio.create_task(self._process_events())
        self.pending_approval_scanner.start()
        logger.info("Bot setup complete, event processor started")

    async def on_ready(self):
        """Resolve guild and category, remap existing channels, then open the gate.

        discord.py may dispatch on_ready more than once (e.g. after reconnects);
        the steps below look up before creating, so they tolerate re-running.
        """
        logger.info(f"Bot connected as {self.user} (ID: {self.user.id})")
        self.guild = self.get_guild(Config.DISCORD_GUILD_ID)
        if not self.guild:
            logger.error(f"Guild {Config.DISCORD_GUILD_ID} not found!")
            return
        # Find or create category
        category_name = "Antigravity Sessions"
        self.session_category = discord.utils.get(
            self.guild.categories, name=category_name
        )
        if not self.session_category:
            try:
                self.session_category = await self.guild.create_category(category_name)
                logger.info(f"Created category: {category_name}")
            except discord.Forbidden:
                logger.error("No permission to create category!")
                return
        # ONLY reconnect to existing Discord channels (NO new creation)
        await self._reconnect_existing_channels()
        # NOW allow event processing and approval scanning to begin
        self._ready_event.set()
        logger.info("Ready gate opened — event processing enabled")

    @staticmethod
    def _topic_conversation_id(channel: discord.TextChannel) -> str:
        """Return the conversation id written into a session channel's topic."""
        return (channel.topic or "").replace("Antigravity Session:", "").strip()

    async def _reconnect_existing_channels(self):
        """Map existing session channels to conversations and MERGE same-name ones.

        For each channel name the first channel in category order is kept.
        Later channels with the same name are deleted, and their conversation is
        routed to the kept channel so its events keep a destination.
        """
        if not self.session_category:
            return
        primary_by_name: dict[str, discord.TextChannel] = {}
        duplicates: list[discord.TextChannel] = []
        for ch in self.session_category.text_channels:
            if not (ch.topic and "Antigravity Session:" in ch.topic):
                continue
            if ch.name in primary_by_name:
                duplicates.append(ch)  # DUPLICATE — merged and deleted below
            else:
                primary_by_name[ch.name] = ch

        count = 0
        for ch in primary_by_name.values():
            conv_id = self._topic_conversation_id(ch)
            if conv_id:
                self.session_channels[conv_id] = ch
                await self._recover_task_message(ch, conv_id)
                count += 1

        for ch in duplicates:
            conv_id = self._topic_conversation_id(ch)
            if conv_id:
                current = self.session_channels.get(conv_id)
                # Only redirect if it is unmapped or mapped to the doomed channel.
                if current is None or current.id == ch.id:
                    self.session_channels[conv_id] = primary_by_name[ch.name]
            try:
                await ch.delete(reason="Duplicate channel cleanup")
                logger.info(f"Deleted duplicate channel: #{ch.name}")
            except (discord.Forbidden, discord.HTTPException) as e:
                logger.warning(f"Failed to delete duplicate #{ch.name}: {e}")
        logger.info(f"Reconnected to {count} channels, cleaned {len(duplicates)} duplicates")

    async def _recover_task_message(
        self, channel: discord.TextChannel, conversation_id: str
    ):
        """Find this bot's latest task embed among the channel's last 10 messages.

        The found message id is stored so later task updates edit it instead of
        posting a new one. Lookup failures are ignored (best effort).
        """
        if conversation_id in self.session_status_messages:
            return
        try:
            async for msg in channel.history(limit=10):
                if msg.author == self.user and msg.embeds:
                    embed = msg.embeds[0]
                    if embed.title and "Task" in embed.title:
                        self.session_status_messages[conversation_id] = msg.id
                        return
        except (discord.Forbidden, discord.HTTPException):
            pass

    # ─── Channel Management ──────────────────────────────────────────

    def _bind(
        self, conversation_id: str, project_name: str, channel: discord.TextChannel
    ) -> None:
        """Record that events of ``conversation_id`` go to ``channel``."""
        self.session_channels[conversation_id] = channel
        self.session_names[conversation_id] = project_name

    def _forget_channel(self, channel: discord.TextChannel) -> None:
        """Drop every conversation mapping that points at ``channel``."""
        stale = [cid for cid, ch in self.session_channels.items() if ch.id == channel.id]
        for cid in stale:
            del self.session_channels[cid]
            self.session_status_messages.pop(cid, None)

    async def _ensure_channel(
        self, conversation_id: str, project_name: str, *, create: bool = True
    ) -> "discord.TextChannel | None":
        """Get or create a channel. ONE channel per conv_id, guaranteed.

        Lookup order: the conversation's existing mapping, then any mapped
        channel with the target name, then any channel with that name in the
        session category, and finally (only when ``create``) a new channel.

        Returns:
            The channel, or None when none exists and it could not be created
            (or ``create`` is False).
        """
        # Fast path: this conv_id already has a channel — ALWAYS return it
        # (even if project name changed; name changes are cosmetic, not worth a new channel)
        channel = self.session_channels.get(conversation_id)
        if channel is not None:
            return channel
        async with self._channel_create_lock:
            # Double-check after lock: another task may have bound it meanwhile
            channel = self.session_channels.get(conversation_id)
            if channel is not None:
                return channel
            channel_name = f"{Config.CHANNEL_PREFIX}-{project_name}"
            # Compare against the lowercase, hyphenated form Discord stores
            target_name = channel_name.lower().replace(" ", "-")
            # Check ALL mapped channels for same name
            for ch in self.session_channels.values():
                if ch.name == target_name:
                    self._bind(conversation_id, project_name, ch)
                    logger.info(f"Reusing mapped channel #{ch.name} for {conversation_id[:8]}")
                    return ch
            # Check Discord API — maybe channel exists but isn't in our dict
            if self.session_category:
                for ch in self.session_category.text_channels:
                    if ch.name == target_name:
                        self._bind(conversation_id, project_name, ch)
                        logger.info(f"Found existing Discord channel #{ch.name} for {conversation_id[:8]}")
                        return ch
            if not create:
                return None
            if self.guild is None:
                logger.error(f"Guild not available, cannot create channel: {channel_name}")
                return None
            # Create new channel (truly no match anywhere)
            try:
                channel = await self.guild.create_text_channel(
                    name=channel_name,
                    category=self.session_category,
                    topic=f"Antigravity Session: {conversation_id}",
                )
            except discord.Forbidden:
                logger.error(f"No permission to create channel: {channel_name}")
                return None
            except discord.HTTPException as e:
                # e.g. the category is full or Discord rejected the request
                logger.error(f"Failed to create channel {channel_name}: {e}")
                return None
            self._bind(conversation_id, project_name, channel)
            logger.info(f"Created channel #{channel_name}")
            embed = discord.Embed(
                title=f"🚀 {project_name}",
                description=f"Antigravity 세션 연결됨\nSession: `{conversation_id}`",
                color=discord.Color.blue(),
                timestamp=datetime.now(timezone.utc),
            )
            try:
                await channel.send(embed=embed)
            except discord.HTTPException as e:
                # The channel exists and is mapped; only the greeting failed.
                logger.warning(f"Failed to send greeting to #{channel_name}: {e}")
            return channel

    # ─── Event Processing (SINGLE ROUTE) ─────────────────────────────

    async def _process_events(self):
        """Main event loop — ALL events go through here sequentially."""
        await self.wait_until_ready()
        await self._ready_event.wait()  # Wait until on_ready + reconnect completes
        logger.info("Event processor started (ready gate passed)")
        while not self.is_closed():
            try:
                # Time out periodically so is_closed() is re-checked.
                event = await asyncio.wait_for(
                    self.event_queue.get(), timeout=5.0
                )
            except asyncio.TimeoutError:
                continue
            try:
                await self._handle_event(event)
            except Exception as e:
                logger.error(f"Error processing event: {e}", exc_info=True)
            finally:
                self.event_queue.task_done()

    async def _handle_event(self, event: BrainEvent):
        """Route brain events to handlers — single entry point."""
        conv_dir = Config.BRAIN_PATH / event.conversation_id
        project_name = detect_project_name(conv_dir)
        if event.event_type == EventType.SESSION_START:
            # Channels are only *created* for file events: at session start the
            # conversation may have no titled artifact yet, and the fast path
            # in _ensure_channel would then pin a short-id name permanently.
            await self._ensure_channel(event.conversation_id, project_name, create=False)
            return
        # FILE_CREATED or FILE_CHANGED
        channel = await self._ensure_channel(event.conversation_id, project_name)
        if not channel:
            return
        try:
            if event.file_name == "task.md":
                await self._send_task_update(channel, event)
            else:
                await self._send_artifact_update(channel, event)
        except discord.NotFound:
            # Channel was deleted while we held a reference; it may be shared
            # by several conversations, so drop every mapping to it.
            self._forget_channel(channel)
            logger.warning(f"Channel deleted, cleared: {event.conversation_id[:8]}")

    # ─── Message Senders ─────────────────────────────────────────────

    async def _send_task_update(
        self, channel: discord.TextChannel, event: BrainEvent
    ):
        """Send/edit task progress embed (SINGLE message, always edited)."""
        progress = parse_task_progress(event.content)
        if progress.in_progress > 0:
            color = discord.Color.gold()
        elif progress.total > 0 and progress.done == progress.total:
            color = discord.Color.green()  # every task done
        else:
            color = discord.Color.greyple()  # nothing in progress, not complete, or no tasks
        embed = discord.Embed(
            title="📋 Task 진행 현황",
            description=format_task_embed_text(progress)[:self._EMBED_DESCRIPTION_LIMIT],
            color=color,
            timestamp=datetime.now(timezone.utc),
        )
        embed.set_footer(text=f"Session: {event.conversation_id[:8]}")
        # Always try to edit existing message first
        msg_id = self.session_status_messages.get(event.conversation_id)
        if msg_id:
            try:
                msg = await channel.fetch_message(msg_id)
                await msg.edit(embed=embed)
                return
            except (discord.NotFound, discord.HTTPException):
                pass  # fall back to posting a fresh message below
        msg = await channel.send(embed=embed)
        self.session_status_messages[event.conversation_id] = msg.id

    async def _send_artifact_update(
        self, channel: discord.TextChannel, event: BrainEvent
    ):
        """Send artifact update as single compact embed (preview only)."""
        labels = {
            "implementation_plan.md": "📐 구현 계획",
            "walkthrough.md": "📝 작업 결과 요약",
        }
        label = labels.get(event.file_name, f"📄 {event.file_name}")
        event_label = "생성" if event.event_type == EventType.FILE_CREATED else "업데이트"
        # Preview: first 6 non-empty lines only
        non_empty = [line for line in event.content.strip().splitlines() if line.strip()]
        shown = non_empty[:self._ARTIFACT_PREVIEW_LINES]
        hidden = len(non_empty) - len(shown)
        suffix = f"\n... (+{hidden} lines)" if hidden > 0 else ""
        # Truncate the body rather than the suffix so the "+N lines" hint survives.
        preview = "\n".join(shown)[:self._ARTIFACT_PREVIEW_CHARS - len(suffix)] + suffix
        embed = discord.Embed(
            title=f"{label} ({event_label}됨)",
            description=preview,
            color=discord.Color.blue(),
            timestamp=datetime.now(timezone.utc),
        )
        await channel.send(embed=embed)

    # ─── Approval Scanner ────────────────────────────────────────────

    @tasks.loop(seconds=3)
    async def pending_approval_scanner(self):
        """Scan bridge/pending/ and post each new approval request once.

        Errors are isolated per request, and a request is only marked as sent
        after it was posted, so a failed post is retried on the next tick.
        """
        try:
            requests = self.bridge.get_pending_requests()
        except Exception as e:
            logger.error(f"Error scanning approvals: {e}", exc_info=True)
            return
        for req in requests:
            if req.request_id in self._sent_approval_ids:
                continue  # Already sent by this process
            if req.discord_message_id != 0:
                continue  # Pending file already records a posted message
            channel = self.session_channels.get(req.conversation_id)
            try:
                if channel is None:
                    conv_dir = Config.BRAIN_PATH / req.conversation_id
                    if conv_dir.exists():
                        project_name = detect_project_name(conv_dir)
                        channel = await self._ensure_channel(
                            req.conversation_id, project_name
                        )
                if channel:
                    await self._send_approval_request(channel, req)
                    self._sent_approval_ids.add(req.request_id)
            except discord.NotFound:
                # Channel vanished; unmap it so the next tick recreates/refinds one.
                if channel is not None:
                    self._forget_channel(channel)
                logger.warning(f"Channel deleted, cleared: {req.conversation_id[:8]}")
            except Exception as e:
                logger.error(f"Error scanning approvals: {e}", exc_info=True)

    @pending_approval_scanner.before_loop
    async def before_scanner(self):
        # Wait for the same gate as the event processor, otherwise the scanner
        # could create channels before the category and existing channels are mapped.
        await self.wait_until_ready()
        await self._ready_event.wait()

    async def _send_approval_request(
        self, channel: discord.TextChannel, request: ApprovalRequest
    ):
        """Post an approval embed with buttons and record its message id in the pending file."""
        embed = discord.Embed(
            title="⚠️ 승인 요청",
            description=(
                f"**명령어:**\n```\n{request.command[:1000]}\n```\n"
                f"{request.description[:500]}"
            ),
            color=discord.Color.orange(),
            timestamp=datetime.now(timezone.utc),
        )
        embed.set_footer(text=f"ID: {request.request_id}")
        view = ApprovalView(self.bridge, request)
        msg = await channel.send(embed=embed, view=view)
        # Give the view a handle to its message (e.g. for editing on timeout).
        view.message = msg
        # Update pending file with discord message id
        pending_file = self.bridge.pending_dir / f"{request.request_id}.json"
        if pending_file.exists():
            try:
                data = json.loads(pending_file.read_text(encoding="utf-8"))
                data["discord_message_id"] = msg.id
                pending_file.write_text(
                    json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8"
                )
            except (json.JSONDecodeError, OSError):
                pass  # best effort; _sent_approval_ids still prevents re-posting
        logger.info(f"Sent approval request: {request.request_id[:8]}")

    # ─── Discord → Antigravity Text Relay ─────────────────────────────

    async def on_message(self, message: discord.Message):
        """Relay text typed in a session channel to its conversation via the bridge."""
        if message.author == self.user:
            return
        # DM channels have no ``name`` attribute (group DMs may have None), so
        # fall back to "" instead of raising AttributeError.
        channel_name = getattr(message.channel, "name", None) or ""
        if not channel_name.startswith(Config.CHANNEL_PREFIX.lower() + "-"):
            return
        # First conversation mapped to this channel (several may share it).
        conv_id = next(
            (cid for cid, ch in self.session_channels.items()
             if ch.id == message.channel.id),
            None,
        )
        text = message.content.strip()
        # Special command: !auto on/off
        if text in ("!auto on", "!auto off"):
            if not conv_id:
                await message.reply("⚠️ 채널에 연결된 세션이 없습니다.")
                return
            self.bridge.write_command(conv_id, text)
            enabled = text == "!auto on"
            emoji = "🟢" if enabled else "🔴"
            mode = "자동 승인" if enabled else "수동 승인"
            embed = discord.Embed(
                title=f"{emoji} {mode} 모드",
                description=f"Antigravity IDE 설정이 변경됩니다.\n"
                f"`chat.tools.autoApprove = {enabled}`\n"
                f"`chat.agent.autoApprove = {enabled}`",
                color=discord.Color.green() if enabled else discord.Color.red(),
            )
            await message.channel.send(embed=embed)
            return
        # General text relay
        if conv_id and text:
            self.bridge.write_command(conv_id, text)
            await message.add_reaction("📨")
        await self.process_commands(message)