453 lines
18 KiB
Python
453 lines
18 KiB
Python
"""Discord bot — relays Antigravity brain events to Discord channels.
|
||
|
||
Dynamic channel management:
|
||
- Creates `AG-{project_name}` channels only when file events arrive
|
||
- NO startup channel creation — only reconnects to existing Discord channels
|
||
- Archives channels after 10 minutes of inactivity
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import re
|
||
import time
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
|
||
import discord
|
||
from discord.ext import commands, tasks
|
||
|
||
from config import Config
|
||
from parser import (
|
||
parse_task_progress,
|
||
md_to_discord_text,
|
||
format_task_embed_text,
|
||
)
|
||
from watcher import BrainEvent, EventType
|
||
from bridge import BridgeProtocol, ApprovalRequest, UserResponse
|
||
|
||
# Module-level logger named after this module; used by the bot and UI classes below.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ─── Discord UI Components ──────────────────────────────────────────
|
||
|
||
class ApprovalView(discord.ui.View):
    """Discord buttons for approving/rejecting an Antigravity action.

    The first button press writes a ``UserResponse`` through the bridge, recolours
    the request embed, and removes the buttons. Later presses get an ephemeral
    "already answered" reply. If nobody answers within 300 seconds, the request
    is written back as rejected.
    """

    def __init__(self, bridge: BridgeProtocol, request: ApprovalRequest):
        super().__init__(timeout=300)
        self.bridge = bridge
        self.request = request
        # True once a decision (click or timeout) has been written to the bridge.
        self.responded = False

    @discord.ui.button(label="✅ 승인", style=discord.ButtonStyle.green)
    async def approve(self, interaction: discord.Interaction, button: discord.ui.Button):
        await self._record_decision(
            interaction,
            approved=True,
            color=discord.Color.green(),
            verdict="✅ 승인됨",
        )

    @discord.ui.button(label="❌ 거부", style=discord.ButtonStyle.red)
    async def reject(self, interaction: discord.Interaction, button: discord.ui.Button):
        await self._record_decision(
            interaction,
            approved=False,
            color=discord.Color.red(),
            verdict="❌ 거부됨",
        )

    async def _record_decision(
        self,
        interaction: discord.Interaction,
        *,
        approved: bool,
        color: discord.Color,
        verdict: str,
    ):
        """Write the user's decision once, then update the approval message.

        Args:
            interaction: The button interaction being answered.
            approved: Decision to send to the bridge.
            color: New embed colour reflecting the decision.
            verdict: Footer prefix, followed by " by <display name>".
        """
        if self.responded:
            await interaction.response.send_message("이미 응답됨", ephemeral=True)
            return

        # write_response is called synchronously, so there is no await between
        # the check above and the flag below and two clicks cannot both pass.
        # The flag is set only after the write succeeds; a failed write can
        # therefore be retried, and the timeout can still auto-reject.
        self.bridge.write_response(UserResponse(
            request_id=self.request.request_id, approved=approved,
        ))
        self.responded = True
        # The decision is final, so stop listening now instead of waiting for
        # the 300 s timeout.
        self.stop()

        embed = interaction.message.embeds[0] if interaction.message.embeds else None
        if embed:
            embed.color = color
            embed.set_footer(text=f"{verdict} by {interaction.user.display_name}")
        await interaction.response.edit_message(embed=embed, view=None)

    async def on_timeout(self):
        """Reject the request automatically if nobody answered in time."""
        if not self.responded:
            self.bridge.write_response(UserResponse(
                request_id=self.request.request_id, approved=False,
            ))
            self.responded = True
|
||
|
||
|
||
# ─── Project Name Detection ─────────────────────────────────────────
|
||
|
||
# Heading decorations removed before slugifying (applied in this order).
_TITLE_SUFFIXES = ("Task Tracker", "— Task Tracker", "태스크", "구현 계획",
                   "Implementation Plan", "Walkthrough")
# A level-1 markdown heading: "# Title".
_HEADING_RE = re.compile(r'^#\s+(.+)')
# Only this many leading lines of each artifact are inspected.
_HEADING_SCAN_LINES = 5


def _sanitize_project_name(raw: str) -> str:
    """Turn a heading into a lowercase_with_underscores slug of at most 30 chars."""
    for suffix in _TITLE_SUFFIXES:
        raw = raw.replace(suffix, "")
    raw = raw.strip(" —-–")
    raw = re.sub(r'[^a-zA-Z0-9가-힣\s\-]', '', raw)
    raw = re.sub(r'[\s\-]+', '_', raw).strip('_').lower()
    return raw[:30]


def detect_project_name(conv_dir: Path) -> str:
    """Derive a channel-friendly project name from a conversation directory.

    Looks for a ``# Title`` heading in the first five lines of ``task.md``,
    then ``implementation_plan.md``. The first heading whose sanitized form
    is at least 5 characters long (and is not ``"task"``) wins. Nothing is
    cached: every call re-reads the files. A UTF-8 byte-order mark at the
    start of a file is ignored.

    Args:
        conv_dir: Directory holding the conversation's markdown artifacts.

    Returns:
        A lowercase_with_underscores name (e.g. ``'gravity_control'``), or the
        first 8 characters of the directory name when no usable heading exists.
    """
    for fname in ("task.md", "implementation_plan.md"):
        try:
            # utf-8-sig strips a leading BOM so "^#" can match on the first line.
            with (conv_dir / fname).open(encoding="utf-8-sig") as fh:
                head = [fh.readline().rstrip("\r\n") for _ in range(_HEADING_SCAN_LINES)]
        except (OSError, UnicodeDecodeError):
            continue  # missing or unreadable artifact: try the next one

        for line in head:
            match = _HEADING_RE.match(line)
            if not match:
                continue
            name = _sanitize_project_name(match.group(1))
            # Require at least 5 chars to avoid short generic names.
            if name and name != "task" and len(name) >= 5:
                return name

    return conv_dir.name[:8]
|
||
|
||
|
||
# ─── Bot ─────────────────────────────────────────────────────────────
|
||
|
||
class GravityBot(commands.Bot):
    """Discord bot for Antigravity session monitoring.

    * Brain events are read from ``event_queue`` by one background task and
      handled strictly one at a time.
    * Each conversation is mapped to a ``{Config.CHANNEL_PREFIX}-{project_name}``
      text channel in the "Antigravity Sessions" category. Conversations that
      resolve to the same channel name share one channel.
    * Pending approval requests are polled from the bridge every 3 seconds and
      posted with approve/reject buttons.
    * Text typed into a mapped session channel is forwarded to the bridge.
    """

    def __init__(self, event_queue: asyncio.Queue):
        intents = discord.Intents.default()
        intents.message_content = True  # needed to relay channel text to the bridge
        intents.guilds = True
        super().__init__(command_prefix="!", intents=intents)

        self.event_queue = event_queue
        # conversation_id -> channel; several ids may share one channel.
        self.session_channels: dict[str, discord.TextChannel] = {}
        # conversation_id -> message id of the editable task-progress embed.
        self.session_status_messages: dict[str, int] = {}
        # conversation_id -> project name the channel was resolved for.
        self.session_names: dict[str, str] = {}
        # Single global lock serialising channel lookup/creation, so two
        # conversations with the same project name cannot create two channels.
        self._channel_create_lock = asyncio.Lock()
        # Approval request ids already posted to Discord by this process.
        self._sent_approval_ids: set[str] = set()
        self.bridge = BridgeProtocol()
        self.session_category: discord.CategoryChannel | None = None
        self.guild: discord.Guild | None = None
        # Strong reference to the event-processing task created in setup_hook.
        self._brain_event_task: asyncio.Task | None = None

    async def setup_hook(self):
        """Start the event processor and the approval scanner."""
        # Keep a reference: the event loop holds only weak references to
        # tasks, so an unreferenced task can be garbage-collected mid-run.
        self._brain_event_task = self.loop.create_task(self._process_events())
        self.pending_approval_scanner.start()
        logger.info("Bot setup complete, event processor started")

    async def on_ready(self):
        """Resolve the guild and session category, then remap existing channels.

        No session channel is created here; only existing ones are reconnected.
        """
        logger.info(f"Bot connected as {self.user} (ID: {self.user.id})")

        self.guild = self.get_guild(Config.DISCORD_GUILD_ID)
        if not self.guild:
            logger.error(f"Guild {Config.DISCORD_GUILD_ID} not found!")
            return

        # Find or create category
        category_name = "Antigravity Sessions"
        self.session_category = discord.utils.get(
            self.guild.categories, name=category_name
        )
        if not self.session_category:
            try:
                self.session_category = await self.guild.create_category(category_name)
            except discord.Forbidden:
                logger.error("No permission to create category!")
                return
            except discord.HTTPException as e:
                logger.error(f"Failed to create category {category_name}: {e}")
                return
            logger.info(f"Created category: {category_name}")

        # ONLY reconnect to existing Discord channels (NO new creation)
        await self._reconnect_existing_channels()

    @staticmethod
    def _topic_conversation_id(channel: discord.TextChannel) -> str:
        """Return the conversation id stored in a session channel's topic ('' if none)."""
        return (channel.topic or "").replace("Antigravity Session:", "").strip()

    async def _reconnect_existing_channels(self):
        """Map existing session channels back to conversation ids and merge duplicates.

        Only channels in the session category whose topic contains
        "Antigravity Session:" are considered. When several share a name, the
        first one is kept. The others' conversation ids are pointed at it (if
        not already mapped), and the extra channels are deleted.
        """
        if not self.session_category:
            return

        # Group channels by name: first occurrence is primary, the rest duplicates.
        name_to_channel: dict[str, discord.TextChannel] = {}
        duplicates: list[discord.TextChannel] = []
        for ch in self.session_category.text_channels:
            if not (ch.topic and "Antigravity Session:" in ch.topic):
                continue
            if ch.name in name_to_channel:
                duplicates.append(ch)
            else:
                name_to_channel[ch.name] = ch

        # Map the primary channel for each name.
        count = 0
        for ch in name_to_channel.values():
            conv_id = self._topic_conversation_id(ch)
            if conv_id:
                self.session_channels[conv_id] = ch
                await self._recover_task_message(ch, conv_id)
                count += 1

        # Merge duplicates into their primary, then delete them.
        for ch in duplicates:
            conv_id = self._topic_conversation_id(ch)
            if conv_id and conv_id not in self.session_channels:
                self.session_channels[conv_id] = name_to_channel[ch.name]
            try:
                await ch.delete(reason="Duplicate channel cleanup")
                logger.info(f"Deleted duplicate channel: #{ch.name}")
            except discord.HTTPException as e:
                logger.warning(f"Failed to delete duplicate #{ch.name}: {e}")

        logger.info(f"Reconnected to {count} channels, cleaned {len(duplicates)} duplicates")

    async def _recover_task_message(
        self, channel: discord.TextChannel, conversation_id: str
    ):
        """Remember an existing task-progress embed in *channel*, if one is found.

        Scans the 10 most recent messages for one sent by this bot whose first
        embed title contains "Task". Later task updates then edit that message
        instead of posting a new one. Discord errors are ignored (best effort).
        """
        if conversation_id in self.session_status_messages:
            return
        try:
            async for msg in channel.history(limit=10):
                if msg.author != self.user or not msg.embeds:
                    continue
                title = msg.embeds[0].title
                if title and "Task" in title:
                    self.session_status_messages[conversation_id] = msg.id
                    return
        except discord.HTTPException:
            pass  # without a recovered id, the next task update posts a new message

    # ─── Channel Management ──────────────────────────────────────────

    async def _ensure_channel(
        self, conversation_id: str, project_name: str
    ) -> "discord.TextChannel | None":
        """Get or create the channel for a conversation, one channel per name.

        Returns:
            The mapped or newly created channel, or None when the guild is not
            resolved yet or Discord refuses to create the channel.
        """
        channel_name = f"{Config.CHANNEL_PREFIX}-{project_name}"
        # Discord stores text-channel names lower-cased with spaces as hyphens.
        target_name = channel_name.lower().replace(" ", "-")

        # Fast path: this conversation is mapped and its name still matches
        # (the project name may have changed since it was mapped).
        ch = self.session_channels.get(conversation_id)
        if ch is not None and ch.name == target_name:
            return ch

        async with self._channel_create_lock:
            # Double-check after acquiring the lock.
            ch = self.session_channels.get(conversation_id)
            if ch is not None and ch.name == target_name:
                return ch

            # Reuse any already-mapped channel with the same name.
            existing = next(
                (c for c in self.session_channels.values() if c.name == target_name),
                None,
            )
            if existing is not None:
                self.session_channels[conversation_id] = existing
                self.session_names[conversation_id] = project_name
                logger.info(f"Reusing channel #{existing.name} for {conversation_id[:8]}")
                return existing

            if self.guild is None:
                logger.error(f"Guild not ready; cannot create channel: {channel_name}")
                return None

            # Create new channel (truly no match anywhere).
            try:
                channel = await self.guild.create_text_channel(
                    name=channel_name,
                    category=self.session_category,
                    topic=f"Antigravity Session: {conversation_id}",
                )
            except discord.Forbidden:
                logger.error(f"No permission to create channel: {channel_name}")
                return None
            except discord.HTTPException as e:
                logger.error(f"Failed to create channel {channel_name}: {e}")
                return None

            self.session_channels[conversation_id] = channel
            self.session_names[conversation_id] = project_name
            logger.info(f"Created channel #{channel_name}")

            embed = discord.Embed(
                title=f"🚀 {project_name}",
                description=f"Antigravity 세션 연결됨\nSession: `{conversation_id}`",
                color=discord.Color.blue(),
                timestamp=datetime.now(timezone.utc),
            )
            try:
                await channel.send(embed=embed)
            except discord.HTTPException as e:
                # The channel exists and is mapped; a missing greeting is not fatal.
                logger.warning(f"Failed to send welcome message to #{channel_name}: {e}")
            return channel

    def _forget_channel(self, channel: discord.TextChannel) -> None:
        """Drop every conversation mapping (and status message) that uses *channel*."""
        stale = [cid for cid, ch in self.session_channels.items() if ch.id == channel.id]
        for cid in stale:
            del self.session_channels[cid]
            self.session_status_messages.pop(cid, None)

    # ─── Event Processing (SINGLE ROUTE) ─────────────────────────────

    async def _process_events(self):
        """Main event loop: every brain event is handled here, sequentially."""
        await self.wait_until_ready()

        while not self.is_closed():
            try:
                event = await asyncio.wait_for(
                    self.event_queue.get(), timeout=5.0
                )
            except asyncio.TimeoutError:
                continue  # wake up periodically to re-check is_closed()

            try:
                await self._handle_event(event)
            except Exception as e:
                logger.error(f"Error processing event: {e}", exc_info=True)
            finally:
                # Keep Queue.join() usable by producers.
                self.event_queue.task_done()

    async def _handle_event(self, event: BrainEvent):
        """Route one brain event to its channel and message sender."""
        conversation_id = event.conversation_id
        project_name = detect_project_name(Config.BRAIN_PATH / conversation_id)

        # NOTE(review): SESSION_START also creates the channel, although the
        # module docstring says channels are created only when file events
        # arrive. Confirm which behavior is intended.
        channel = await self._ensure_channel(conversation_id, project_name)
        if event.event_type == EventType.SESSION_START or channel is None:
            return

        # FILE_CREATED or FILE_CHANGED
        try:
            if event.file_name == "task.md":
                await self._send_task_update(channel, event)
            else:
                await self._send_artifact_update(channel, event)
        except discord.NotFound:
            # The channel was deleted while still mapped. Forget it for every
            # conversation sharing it, so the next event creates a new one.
            self._forget_channel(channel)
            self.session_channels.pop(conversation_id, None)
            self.session_status_messages.pop(conversation_id, None)
            logger.warning(f"Channel deleted, cleared: {conversation_id[:8]}")

    # ─── Message Senders ─────────────────────────────────────────────

    async def _send_task_update(
        self, channel: discord.TextChannel, event: BrainEvent
    ):
        """Post or edit the single task-progress embed for the event's conversation.

        The embed is gold while any task is in progress. It is green once every
        task is done (and there is at least one task), and greyple otherwise.
        An existing progress message is edited in place. If it cannot be
        fetched or edited, a new message is posted and remembered.
        """
        progress = parse_task_progress(event.content)
        if progress.in_progress > 0:
            color = discord.Color.gold()
        elif progress.total > 0 and progress.done == progress.total:
            color = discord.Color.green()
        else:
            color = discord.Color.greyple()

        embed = discord.Embed(
            title="📋 Task 진행 현황",
            description=format_task_embed_text(progress),
            color=color,
            timestamp=datetime.now(timezone.utc),
        )
        embed.set_footer(text=f"Session: {event.conversation_id[:8]}")

        # Always try to edit the existing message first.
        msg_id = self.session_status_messages.get(event.conversation_id)
        if msg_id:
            try:
                msg = await channel.fetch_message(msg_id)
                await msg.edit(embed=embed)
                return
            except discord.HTTPException:
                pass  # message gone or not editable: post a fresh one below

        msg = await channel.send(embed=embed)
        self.session_status_messages[event.conversation_id] = msg.id

    async def _send_artifact_update(
        self, channel: discord.TextChannel, event: BrainEvent
    ):
        """Post an artifact update as one compact embed with a short preview."""
        labels = {
            "implementation_plan.md": "📐 구현 계획",
            "walkthrough.md": "📝 작업 결과 요약",
        }
        label = labels.get(event.file_name, f"📄 {event.file_name}")
        event_label = "생성" if event.event_type == EventType.FILE_CREATED else "업데이트"

        # Preview: the first 6 non-empty lines, plus a count of the remaining ones.
        content_lines = [
            line for line in (event.content or "").strip().splitlines() if line.strip()
        ]
        preview = "\n".join(content_lines[:6])
        if len(content_lines) > 6:
            preview += f"\n... (+{len(content_lines) - 6} lines)"

        embed = discord.Embed(
            title=f"{label} ({event_label}됨)",
            description=preview[:1000],
            color=discord.Color.blue(),
            timestamp=datetime.now(timezone.utc),
        )
        await channel.send(embed=embed)

    # ─── Approval Scanner ────────────────────────────────────────────

    @tasks.loop(seconds=3)
    async def pending_approval_scanner(self):
        """Post bridge approval requests that have not reached Discord yet.

        Each request is handled in isolation. A failure is logged and the
        request is retried on the next scan. Exceptions never escape, because
        an escaping exception would end the tasks.loop.
        """
        try:
            requests = self.bridge.get_pending_requests()
        except Exception as e:
            logger.error(f"Error scanning approvals: {e}", exc_info=True)
            return

        for req in requests:
            if req.request_id in self._sent_approval_ids:
                continue  # already posted by this process
            if req.discord_message_id != 0:
                continue  # already posted (recorded in the pending file)

            try:
                channel = self.session_channels.get(req.conversation_id)
                if not channel:
                    conv_dir = Config.BRAIN_PATH / req.conversation_id
                    if conv_dir.exists():
                        project_name = detect_project_name(conv_dir)
                        channel = await self._ensure_channel(
                            req.conversation_id, project_name
                        )
                if channel:
                    await self._send_approval_request(channel, req)
            except Exception as e:
                logger.error(f"Error scanning approvals: {e}", exc_info=True)

    @pending_approval_scanner.before_loop
    async def before_scanner(self):
        await self.wait_until_ready()

    async def _send_approval_request(
        self, channel: discord.TextChannel, request: ApprovalRequest
    ):
        """Post *request* with approve/reject buttons and record the message id.

        The request id is marked as sent as soon as Discord accepts the message,
        so this process never posts it twice. The posted message id is then
        written into the request's pending JSON file as ``discord_message_id``.
        This write is best effort; a missing or unreadable file is ignored.
        """
        embed = discord.Embed(
            title="⚠️ 승인 요청",
            description=(
                f"**명령어:**\n```\n{request.command[:1000]}\n```\n"
                f"{request.description[:500]}"
            ),
            color=discord.Color.orange(),
            timestamp=datetime.now(timezone.utc),
        )
        embed.set_footer(text=f"ID: {request.request_id}")

        view = ApprovalView(self.bridge, request)
        msg = await channel.send(embed=embed, view=view)
        self._sent_approval_ids.add(request.request_id)

        # Update the pending file with the discord message id.
        # NOTE(review): read-modify-write is not atomic; if the bridge removes
        # the file between read and write, the write re-creates it. Confirm
        # whether the bridge tolerates that.
        pending_file = self.bridge.pending_dir / f"{request.request_id}.json"
        try:
            data = json.loads(pending_file.read_text(encoding="utf-8"))
            data["discord_message_id"] = msg.id
            pending_file.write_text(
                json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8"
            )
        except (json.JSONDecodeError, OSError):
            pass  # file already gone or unreadable: nothing to update

        logger.info(f"Sent approval request: {request.request_id[:8]}")

    # ─── Discord → Antigravity Text Relay ─────────────────────────────

    async def on_message(self, message: discord.Message):
        """Relay text typed in a mapped session channel to the bridge.

        Messages from this bot are ignored. So are messages in channels whose
        name lacks the session prefix, which includes DMs, since they have no
        channel name. Relayed messages get a 📨 reaction. Commands are then
        processed for every message that passed the prefix check.
        """
        if message.author == self.user:
            return

        channel_name = getattr(message.channel, "name", None) or ""
        if not channel_name.startswith(Config.CHANNEL_PREFIX.lower() + "-"):
            return

        conv_id = next(
            (cid for cid, ch in self.session_channels.items()
             if ch.id == message.channel.id),
            None,
        )

        content = message.content.strip()
        if conv_id and content:
            self.bridge.write_command(conv_id, content)
            await message.add_reaction("📨")

        await self.process_commands(message)