Files
gravity_control/bot.py

595 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Discord bot — relays Antigravity brain events to Discord channels.
Dynamic channel management:
- Scans brain/ for active sessions on startup
- Creates `gravity-{project_name}` channels per active session
- Archives channels when sessions become inactive
- Project name extracted from artifact content or short conversation ID
"""
import asyncio
import json
import logging
import re
import time
from datetime import datetime, timezone
from pathlib import Path
import discord
from discord.ext import commands, tasks
from config import Config
from parser import (
parse_task_progress,
md_to_discord_text,
format_task_embed_text,
)
from watcher import BrainEvent, EventType
from bridge import BridgeProtocol, ApprovalRequest, UserResponse
logger = logging.getLogger(__name__)
class ApprovalView(discord.ui.View):
    """Discord buttons for approving/rejecting Antigravity actions.

    The first button press writes one ``UserResponse`` for ``request``
    through ``bridge`` and removes the buttons from the message. Any later
    press is answered with an ephemeral "already responded" notice. If nobody
    answers before the view times out, a rejection is written instead.
    """

    def __init__(self, bridge: BridgeProtocol, request: ApprovalRequest):
        super().__init__(timeout=300)  # 5 min timeout
        self.bridge = bridge
        self.request = request
        # Flipped once a decision has been written so it is never written twice.
        self.responded = False

    async def _record_decision(
        self,
        interaction: discord.Interaction,
        *,
        approved: bool,
        color: discord.Color,
        verdict: str,
    ) -> None:
        """Write the decision to the bridge and update the request message.

        Args:
            interaction: The button interaction being answered.
            approved: Decision forwarded in the ``UserResponse``.
            color: Embed colour applied to the request message.
            verdict: Footer prefix shown before the responder's display name.
        """
        if self.responded:
            await interaction.response.send_message("이미 응답됨", ephemeral=True)
            return
        # Set the flag before any await so a concurrent press cannot also win.
        self.responded = True
        self.bridge.write_response(UserResponse(
            request_id=self.request.request_id,
            approved=approved,
        ))
        embed = interaction.message.embeds[0] if interaction.message.embeds else None
        if embed:
            embed.color = color
            embed.set_footer(text=f"{verdict} by {interaction.user.display_name}")
        await interaction.response.edit_message(embed=embed, view=None)
        # The decision is final: cancel the pending timeout for this view.
        self.stop()

    @discord.ui.button(label="✅ 승인", style=discord.ButtonStyle.green)
    async def approve(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Approve the request on behalf of the clicking user."""
        await self._record_decision(
            interaction,
            approved=True,
            color=discord.Color.green(),
            verdict="✅ 승인됨",
        )

    @discord.ui.button(label="❌ 거부", style=discord.ButtonStyle.red)
    async def reject(self, interaction: discord.Interaction, button: discord.ui.Button):
        """Reject the request on behalf of the clicking user."""
        await self._record_decision(
            interaction,
            approved=False,
            color=discord.Color.red(),
            verdict="❌ 거부됨",
        )

    async def on_timeout(self):
        """Auto-reject when nobody answered within the 5 minute timeout."""
        if not self.responded:
            self.responded = True
            self.bridge.write_response(UserResponse(
                request_id=self.request.request_id,
                approved=False,
            ))
def detect_project_name(conv_dir: Path) -> str:
    """Extract a human-readable project name from conversation artifacts.

    Strategy:
    1. Use the first ``# `` heading within the first five lines of task.md
       (a heading that sanitizes to just "task" is skipped)
    2. Otherwise use the first such heading of implementation_plan.md
    3. Otherwise use the first two words of the ``summary`` field of any
       ``*.metadata.json`` file (files are tried in sorted order so the result
       is stable between calls)
    4. Fall back to the first eight characters of the directory name

    Unreadable, undecodable or malformed files are skipped, never raised.

    Args:
        conv_dir: Conversation directory that holds the artifacts.

    Returns:
        Name in lowercase_with_underscores form (e.g. 'gravity_control'),
        at most 30 characters unless the short-ID fallback is used.
    """
    short_id = conv_dir.name[:8]

    def _sanitize(raw: str) -> str:
        """Convert raw title to Discord-friendly channel name part."""
        # Remove common suffixes
        for suffix in ("Task Tracker", "— Task Tracker", "태스크", "구현 계획"):
            raw = raw.replace(suffix, "")
        raw = raw.strip(" —-")
        # Keep only alphanumeric, Korean, spaces, hyphens
        raw = re.sub(r'[^a-zA-Z0-9가-힣\s\-]', '', raw)
        # Convert to underscore format
        raw = re.sub(r'[\s\-]+', '_', raw).strip('_').lower()
        return raw[:30] if raw else ""

    def _heading_name(md_file: Path, rejected: tuple = ()) -> str:
        """Return the first usable sanitized heading of md_file, or ""."""
        try:
            first_lines = md_file.read_text(encoding="utf-8").splitlines()[:5]
        except (OSError, UnicodeDecodeError):
            # Missing or unreadable file: let the caller try the next source.
            return ""
        for line in first_lines:
            match = re.match(r'^#\s+(.+)', line)
            if match:
                candidate = _sanitize(match.group(1))
                if candidate and candidate not in rejected:
                    return candidate
        return ""

    name = (
        _heading_name(conv_dir / "task.md", rejected=("task",))
        or _heading_name(conv_dir / "implementation_plan.md")
    )
    if name:
        return name

    # Try metadata summary (first 2 words)
    for meta_file in sorted(conv_dir.glob("*.metadata.json")):
        try:
            meta = json.loads(meta_file.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            continue
        summary = meta.get("summary") if isinstance(meta, dict) else None
        if not isinstance(summary, str):
            continue
        words = summary.split()[:2]
        name = re.sub(r'[^a-z0-9가-힣_]', '', "_".join(words).lower())
        name = name[:30].strip("_")
        if name:
            return name
    return short_id
def is_session_active(conv_dir: Path) -> bool:
    """Check if a session is active based on file modification time.

    A session counts as active when at least one regular file directly inside
    ``conv_dir`` was modified less than ``Config.ACTIVE_TIMEOUT_SECONDS``
    seconds ago. Subdirectories are not descended into.

    Args:
        conv_dir: Conversation directory to inspect.

    Returns:
        True if a recently modified file exists; False otherwise, including
        when the directory is missing or cannot be listed.
    """
    cutoff = time.time() - Config.ACTIVE_TIMEOUT_SECONDS
    try:
        entries = list(conv_dir.iterdir())
    except OSError:
        # The directory vanished or is unreadable: nothing can be active.
        return False
    for entry in entries:
        try:
            if entry.is_file() and entry.stat().st_mtime > cutoff:
                return True
        except OSError:
            # File removed between listing and stat; ignore it.
            continue
    return False
class GravityBot(commands.Bot):
    """Discord bot for Antigravity session monitoring.

    Consumes ``BrainEvent`` objects from ``event_queue`` and mirrors them into
    one text channel per conversation, inside the "Antigravity Sessions"
    category. Two background loops run alongside the event processor: one
    archives channels of inactive sessions, the other posts approval requests
    found by the bridge. Messages typed in a session channel are relayed back
    through the bridge as commands.
    """

    # Discord rejects embeds whose description is longer than this.
    _EMBED_DESCRIPTION_LIMIT = 4096

    def __init__(self, event_queue: asyncio.Queue):
        intents = discord.Intents.default()
        intents.message_content = True
        intents.guilds = True
        super().__init__(command_prefix="!", intents=intents)
        self.event_queue = event_queue
        # conversation_id -> channel object
        self.session_channels: dict[str, discord.TextChannel] = {}
        # conversation_id -> status message id (to edit in-place)
        self.session_status_messages: dict[str, int] = {}
        # conversation_id -> project name
        self.session_names: dict[str, str] = {}
        # Locks to prevent duplicate channel creation
        self._channel_locks: dict[str, asyncio.Lock] = {}
        # Bridge protocol for bidirectional communication
        self.bridge = BridgeProtocol()
        # Cache: "conversation_id:file_name" -> last metadata summary (skip unchanged)
        self._last_meta_summary: dict[str, str] = {}
        # Approval requests already posted by this process. Guards against
        # re-posting when the pending file could not be updated afterwards.
        self._announced_requests: set = set()
        # Handle of the event-processing task (kept so it is not collected)
        self._event_task = None
        # Category for session channels
        self.session_category: discord.CategoryChannel | None = None
        # Guild reference
        self.guild: discord.Guild | None = None

    async def setup_hook(self):
        """Called after login, before processing events."""
        # Keep a reference: the event loop itself only holds tasks weakly.
        self._event_task = self.loop.create_task(self._process_events())
        self.session_cleanup_loop.start()
        self.pending_approval_scanner.start()
        logger.info("Bot setup complete, event processor + approval scanner started")

    async def on_ready(self):
        """Resolve the guild and category, then sync already-active sessions."""
        logger.info(f"Bot connected as {self.user} (ID: {self.user.id})")
        # Get guild
        self.guild = self.get_guild(Config.DISCORD_GUILD_ID)
        if not self.guild:
            logger.error(f"Guild {Config.DISCORD_GUILD_ID} not found!")
            return
        # Find or create session category
        category_name = "Antigravity Sessions"
        self.session_category = discord.utils.get(
            self.guild.categories, name=category_name
        )
        if not self.session_category:
            try:
                self.session_category = await self.guild.create_category(category_name)
                logger.info(f"Created category: {category_name}")
            except discord.errors.Forbidden:
                logger.error("No permission to create category!")
                return
            except discord.HTTPException as e:
                logger.error(f"Failed to create category {category_name}: {e}")
                return
        # Sync existing active sessions
        await self._sync_active_sessions()

    async def _sync_active_sessions(self):
        """Scan brain/ and create channels for currently active sessions."""
        brain_path = Config.BRAIN_PATH
        if not brain_path.exists():
            return
        active_count = 0
        for entry in brain_path.iterdir():
            if not (entry.is_dir() and self._is_conversation_id(entry.name)):
                continue
            if is_session_active(entry):
                project_name = detect_project_name(entry)
                await self._ensure_channel(entry.name, project_name)
                active_count += 1
        logger.info(f"Synced {active_count} active sessions on startup")

    def _is_conversation_id(self, name: str) -> bool:
        """Check if directory name looks like a UUID.

        This is a loose shape check only: five dash-separated parts of at
        least four characters each.
        """
        parts = name.split("-")
        return len(parts) == 5 and all(len(p) >= 4 for p in parts)

    async def _channel_for(self, conversation_id: str) -> "discord.TextChannel | None":
        """Return the session's channel, creating or re-attaching it if needed.

        The project name is only detected (file I/O) when the channel is not
        already mapped, because a mapped channel ignores the name anyway.
        """
        channel = self.session_channels.get(conversation_id)
        if channel is not None:
            return channel
        conv_dir = Config.BRAIN_PATH / conversation_id
        return await self._ensure_channel(
            conversation_id, detect_project_name(conv_dir)
        )

    def _find_existing_channel(self, conversation_id: str) -> "discord.TextChannel | None":
        """Find a channel from a previous run whose topic names the session."""
        if not self.session_category:
            return None
        for ch in self.session_category.text_channels:
            if ch.topic and conversation_id in ch.topic:
                return ch
        return None

    async def _ensure_channel(
        self, conversation_id: str, project_name: str
    ) -> "discord.TextChannel | None":
        """Get or create a Discord channel for a session.

        Concurrent calls for the same conversation are serialized with an
        asyncio lock so only one channel is ever created.

        Args:
            conversation_id: Session identifier; also stored in the channel topic.
            project_name: Name used for the channel when it has to be
                created or renamed.

        Returns:
            The channel, or None when it could not be created.
        """
        # Fast path: already mapped
        if conversation_id in self.session_channels:
            return self.session_channels[conversation_id]
        # Get or create a lock for this conversation
        if conversation_id not in self._channel_locks:
            self._channel_locks[conversation_id] = asyncio.Lock()
        async with self._channel_locks[conversation_id]:
            # Double-check after acquiring lock
            if conversation_id in self.session_channels:
                return self.session_channels[conversation_id]
            channel_name = f"{Config.CHANNEL_PREFIX}-{project_name}"

            # Check if channel already exists in category (from previous run)
            existing = self._find_existing_channel(conversation_id)
            if existing is not None:
                self.session_channels[conversation_id] = existing
                self.session_names[conversation_id] = project_name
                # Rename back from closed- if needed
                if existing.name != channel_name.lower():
                    old_name = existing.name
                    try:
                        await existing.edit(name=channel_name)
                        logger.info(f"Renamed channel {old_name} -> {channel_name}")
                    except discord.HTTPException as e:
                        # Best effort: the channel is usable under its old name.
                        logger.warning(f"Could not rename channel {old_name}: {e}")
                # Recover last task embed message ID
                await self._recover_task_message(existing, conversation_id)
                logger.info(f"Reconnected to existing channel #{existing.name}")
                return existing

            # Create new channel
            if self.guild is None:
                logger.error(f"No guild available to create channel: {channel_name}")
                return None
            try:
                channel = await self.guild.create_text_channel(
                    name=channel_name,
                    category=self.session_category,
                    topic=f"Antigravity Session: {conversation_id}",
                )
            except discord.errors.Forbidden:
                logger.error(f"No permission to create channel: {channel_name}")
                return None
            except discord.HTTPException as e:
                logger.error(f"Failed to create channel {channel_name}: {e}")
                return None
            self.session_channels[conversation_id] = channel
            self.session_names[conversation_id] = project_name
            logger.info(f"Created channel #{channel_name}")
            # Welcome embed
            embed = discord.Embed(
                title=f"🚀 {project_name}",
                description=(
                    f"Antigravity 세션 연결됨\n"
                    f"Session: `{conversation_id}`"
                ),
                color=discord.Color.blue(),
                timestamp=datetime.now(timezone.utc),
            )
            try:
                await channel.send(embed=embed)
            except discord.HTTPException as e:
                # The channel exists and is mapped; only the greeting failed.
                logger.warning(f"Could not send welcome embed to #{channel_name}: {e}")
            return channel

    async def _recover_task_message(
        self, channel: discord.TextChannel, conversation_id: str
    ):
        """Find the last task embed in channel history to reuse for editing."""
        if conversation_id in self.session_status_messages:
            return  # Already have it
        try:
            async for msg in channel.history(limit=20):
                if msg.author == self.user and msg.embeds:
                    embed = msg.embeds[0]
                    if embed.title and "Task 진행 현황" in embed.title:
                        self.session_status_messages[conversation_id] = msg.id
                        logger.info(f"Recovered task embed msg {msg.id} for {conversation_id[:8]}")
                        return
        except (discord.Forbidden, discord.HTTPException):
            # Best effort: a fresh status message is sent on the next update.
            pass

    async def _process_events(self):
        """Main event processing loop."""
        await self.wait_until_ready()
        while not self.is_closed():
            try:
                # Wake up periodically so a closed bot leaves the loop.
                event = await asyncio.wait_for(
                    self.event_queue.get(), timeout=5.0
                )
                await self._handle_event(event)
            except asyncio.TimeoutError:
                continue
            except Exception as e:
                # One bad event must not stop the relay.
                logger.error(f"Error processing event: {e}", exc_info=True)

    async def _handle_event(self, event: BrainEvent):
        """Route brain events to handlers."""
        if event.event_type == EventType.SESSION_START:
            await self._channel_for(event.conversation_id)
            return
        if event.event_type not in (EventType.FILE_CREATED, EventType.FILE_CHANGED):
            return
        # Ensure channel exists
        channel = await self._channel_for(event.conversation_id)
        if channel is None:
            return
        if event.file_name == "task.md":
            await self._send_task_update(channel, event)
        elif event.file_name.endswith(".metadata.json"):
            await self._send_metadata_update(channel, event)
        else:
            await self._send_artifact_content(channel, event)

    async def _send_task_update(
        self, channel: discord.TextChannel, event: BrainEvent
    ):
        """Send task progress update as an embed.

        The session's previous status message is edited in place when it can
        still be fetched; otherwise a new message is sent and remembered.
        """
        progress = parse_task_progress(event.content)
        if progress.in_progress > 0:
            color = discord.Color.gold()
        elif progress.done == progress.total:
            color = discord.Color.green()
        else:
            color = discord.Color.greyple()
        description = format_task_embed_text(progress)
        limit = self._EMBED_DESCRIPTION_LIMIT
        if len(description) > limit:
            # Truncate instead of letting Discord reject the whole embed.
            description = description[:limit - 1] + "…"
        embed = discord.Embed(
            title="📋 Task 진행 현황",
            description=description,
            color=color,
            timestamp=datetime.now(timezone.utc),
        )
        short_id = event.conversation_id[:8]
        embed.set_footer(text=f"Session: {short_id}")
        # Edit existing or send new
        msg_id = self.session_status_messages.get(event.conversation_id)
        if msg_id:
            try:
                msg = await channel.fetch_message(msg_id)
                await msg.edit(embed=embed)
                return
            except (discord.NotFound, discord.HTTPException):
                # Message deleted or not editable: fall through to a new one.
                pass
        msg = await channel.send(embed=embed)
        self.session_status_messages[event.conversation_id] = msg.id

    async def _send_artifact_content(
        self, channel: discord.TextChannel, event: BrainEvent
    ):
        """Send artifact change notification (compact — metadata handler sends summary)."""
        labels = {
            "implementation_plan.md": "📐 구현 계획",
            "walkthrough.md": "📝 작업 결과 요약",
        }
        label = labels.get(event.file_name, f"📄 {event.file_name}")
        event_label = "생성" if event.event_type == EventType.FILE_CREATED else "업데이트"
        # Only send first few lines as preview instead of full content
        lines = event.content.strip().splitlines()
        preview = "\n".join(line for line in lines[:8] if line.strip())
        if len(lines) > 8:
            preview += f"\n... (+{len(lines) - 8} lines)"
        embed = discord.Embed(
            title=f"{label} ({event_label}됨)",
            description=preview[:1000],
            color=discord.Color.blue(),
            timestamp=datetime.now(timezone.utc),
        )
        await channel.send(embed=embed)

    async def _send_metadata_update(
        self, channel: discord.TextChannel, event: BrainEvent
    ):
        """Send artifact metadata summary changes as compact embed.

        Nothing is sent when the content is not a JSON object, has no string
        ``summary``, or the summary equals the last one sent for this file.
        """
        try:
            meta = json.loads(event.content)
        except (json.JSONDecodeError, TypeError):
            return
        if not isinstance(meta, dict):
            return
        summary = meta.get("summary", "")
        artifact_type = meta.get("artifactType", "")
        if not summary or not isinstance(summary, str):
            return
        # Dedup: skip if summary unchanged since last notification
        cache_key = f"{event.conversation_id}:{event.file_name}"
        if self._last_meta_summary.get(cache_key) == summary:
            return
        self._last_meta_summary[cache_key] = summary
        # Map artifact types to emoji
        type_emoji = {
            "ARTIFACT_TYPE_TASK": "📋",
            "ARTIFACT_TYPE_IMPLEMENTATION_PLAN": "📐",
            "ARTIFACT_TYPE_WALKTHROUGH": "📝",
            "ARTIFACT_TYPE_OTHER": "📄",
        }
        emoji = "📄"
        if isinstance(artifact_type, str):
            emoji = type_emoji.get(artifact_type, "📄")
        # Artifact name from filename (strip .metadata.json)
        artifact_name = event.file_name.replace(".metadata.json", "")
        embed = discord.Embed(
            description=f"{emoji} **{artifact_name}** 업데이트\n\n{summary[:500]}",
            color=discord.Color.dark_teal(),
            timestamp=datetime.now(timezone.utc),
        )
        await channel.send(embed=embed)

    @staticmethod
    def _session_is_live(conversation_id: str) -> bool:
        """Return True while the session directory exists and is active."""
        conv_dir = Config.BRAIN_PATH / conversation_id
        try:
            return conv_dir.exists() and is_session_active(conv_dir)
        except OSError:
            # A filesystem error must not stop the cleanup loop.
            return False

    def _forget_session(self, conversation_id: str) -> None:
        """Drop the lock and dedup-cache entries kept for a session."""
        lock = self._channel_locks.get(conversation_id)
        if lock is not None and not lock.locked():
            del self._channel_locks[conversation_id]
        prefix = f"{conversation_id}:"
        for key in [k for k in self._last_meta_summary if k.startswith(prefix)]:
            del self._last_meta_summary[key]

    @tasks.loop(minutes=5)
    async def session_cleanup_loop(self):
        """Periodically check for inactive sessions and archive their channels."""
        if not self.guild:
            return
        # Snapshot first: the mapping is mutated while channels are archived.
        to_remove = [
            conv_id for conv_id in self.session_channels
            if not self._session_is_live(conv_id)
        ]
        for conv_id in to_remove:
            channel = self.session_channels.pop(conv_id, None)
            self.session_status_messages.pop(conv_id, None)
            name = self.session_names.pop(conv_id, conv_id[:8])
            self._forget_session(conv_id)
            if not channel:
                continue
            try:
                embed = discord.Embed(
                    title="🔴 세션 비활성",
                    description=f"`{name}` 세션이 비활성 상태입니다.",
                    color=discord.Color.red(),
                    timestamp=datetime.now(timezone.utc),
                )
                await channel.send(embed=embed)
                # The topic still names the session, so a later event can
                # find this channel again and rename it back.
                await channel.edit(name=f"closed-{name[:20]}")
                logger.info(f"Archived channel for {conv_id[:8]}")
            except discord.errors.Forbidden:
                logger.warning(f"No permission to archive {conv_id[:8]}")
            except Exception as e:
                logger.warning(f"Failed to archive {conv_id[:8]}: {e}")

    @session_cleanup_loop.before_loop
    async def before_cleanup(self):
        """Delay the first cleanup pass until the bot is ready."""
        await self.wait_until_ready()

    @tasks.loop(seconds=3)
    async def pending_approval_scanner(self):
        """Scan bridge/pending/ for new approval requests and send Discord buttons."""
        try:
            requests = self.bridge.get_pending_requests()
        except Exception as e:
            logger.error(f"Error scanning pending approvals: {e}")
            return
        # Forget requests that are no longer pending so the set stays small.
        self._announced_requests &= {req.request_id for req in requests}
        for req in requests:
            # A non-zero message id means the request was already posted.
            if req.discord_message_id != 0:
                continue
            if req.request_id in self._announced_requests:
                continue
            try:
                # Find the right channel
                channel = self.session_channels.get(req.conversation_id)
                if channel is None:
                    conv_dir = Config.BRAIN_PATH / req.conversation_id
                    if not conv_dir.exists():
                        continue
                    channel = await self._ensure_channel(
                        req.conversation_id, detect_project_name(conv_dir)
                    )
                if channel:
                    await self._send_approval_request(channel, req)
            except Exception as e:
                # Handled per request so one failure does not block the rest.
                logger.error(f"Error scanning pending approvals: {e}")

    @pending_approval_scanner.before_loop
    async def before_scanner(self):
        """Delay the first scan until the bot is ready."""
        await self.wait_until_ready()

    async def _send_approval_request(
        self, channel: discord.TextChannel, request: ApprovalRequest
    ):
        """Send an approval request with buttons to Discord.

        After sending, the Discord message id is written back into the
        request's pending file so later scans skip it.
        """
        embed = discord.Embed(
            title="⚠️ 승인 요청",
            description=(
                f"**명령어:**\n```\n{request.command[:1000]}\n```\n"
                f"{request.description[:500]}"
            ),
            color=discord.Color.orange(),
            timestamp=datetime.now(timezone.utc),
        )
        embed.set_footer(text=f"ID: {request.request_id}")
        view = ApprovalView(self.bridge, request)
        msg = await channel.send(embed=embed, view=view)
        # Remember in memory too, in case the file update below fails.
        self._announced_requests.add(request.request_id)
        # Update pending file with discord message id
        pending_file = self.bridge.pending_dir / f"{request.request_id}.json"
        if pending_file.exists():
            try:
                data = json.loads(pending_file.read_text(encoding="utf-8"))
                data["discord_message_id"] = msg.id
                pending_file.write_text(
                    json.dumps(data, ensure_ascii=False, indent=2),
                    encoding="utf-8"
                )
            except (json.JSONDecodeError, OSError) as e:
                logger.warning(f"Could not update pending file {pending_file.name}: {e}")
        logger.info(f"Sent approval request: {request.request_id[:8]}")

    async def on_message(self, message: discord.Message):
        """Handle user messages in AG channels → relay as text input."""
        # Ignore bot's own messages
        if message.author == self.user:
            return
        # Check if message is in an AG session channel. DM channels have no
        # name attribute, so look it up defensively.
        channel_name = getattr(message.channel, "name", None)
        if not channel_name or not channel_name.startswith(Config.CHANNEL_PREFIX + "-"):
            return
        # Find conversation_id for this channel
        conv_id = next(
            (
                cid for cid, ch in self.session_channels.items()
                if ch.id == message.channel.id
            ),
            None,
        )
        text = message.content.strip()
        if conv_id and text:
            # Write user input to bridge commands
            self.bridge.write_command(conv_id, text)
            try:
                await message.add_reaction("📨")
            except discord.HTTPException as e:
                # The input is already relayed; the reaction is only feedback.
                logger.warning(f"Could not add relay reaction: {e}")
            logger.info(f"User input relayed: {message.content[:50]}... → {conv_id[:8]}")
        # Process commands (e.g., !approve, !reject)
        # NOTE(review): prefixed command messages are also relayed above as
        # plain text. Confirm this is intended.
        await self.process_commands(message)