wip: [01-stabilize] paused at task 1/1 - OCR Hallucination Immune logic via Semantic delta window and fret-isolation
114
.agent/vendor/mini-swe/sweagent/__init__.py
vendored
Normal file
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from functools import partial
|
||||
from logging import WARNING, getLogger
|
||||
from pathlib import Path
|
||||
|
||||
import swerex.utils.log as log_swerex
|
||||
from git import Repo
|
||||
from packaging import version
|
||||
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
__version__ = "1.1.0"
|
||||
PYTHON_MINIMUM_VERSION = (3, 11)
|
||||
SWEREX_MINIMUM_VERSION = "1.2.0"
|
||||
SWEREX_RECOMMENDED_VERSION = "1.2.1"
|
||||
|
||||
# Monkey patch the logger to use our implementation
|
||||
log_swerex.get_logger = partial(get_logger, emoji="🦖")
|
||||
|
||||
# See https://github.com/SWE-agent/SWE-agent/issues/585
|
||||
getLogger("datasets").setLevel(WARNING)
|
||||
getLogger("numexpr.utils").setLevel(WARNING)
|
||||
getLogger("LiteLLM").setLevel(WARNING)
|
||||
|
||||
PACKAGE_DIR = Path(__file__).resolve().parent
|
||||
|
||||
if sys.version_info < PYTHON_MINIMUM_VERSION:
|
||||
msg = (
|
||||
f"Python {sys.version_info.major}.{sys.version_info.minor} is not supported. "
|
||||
"SWE-agent requires Python 3.11 or higher."
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
assert PACKAGE_DIR.is_dir(), PACKAGE_DIR
|
||||
REPO_ROOT = PACKAGE_DIR.parent
|
||||
assert REPO_ROOT.is_dir(), REPO_ROOT
|
||||
CONFIG_DIR = Path(os.getenv("SWE_AGENT_CONFIG_DIR", PACKAGE_DIR.parent / "config"))
|
||||
assert CONFIG_DIR.is_dir(), CONFIG_DIR
|
||||
|
||||
TOOLS_DIR = Path(os.getenv("SWE_AGENT_TOOLS_DIR", PACKAGE_DIR.parent / "tools"))
|
||||
assert TOOLS_DIR.is_dir(), TOOLS_DIR
|
||||
|
||||
TRAJECTORY_DIR = Path(os.getenv("SWE_AGENT_TRAJECTORY_DIR", PACKAGE_DIR.parent / "trajectories"))
|
||||
assert TRAJECTORY_DIR.is_dir(), TRAJECTORY_DIR
|
||||
|
||||
|
||||
def get_agent_commit_hash() -> str:
|
||||
"""Get the commit hash of the current SWE-agent commit.
|
||||
|
||||
If we cannot get the hash, we return an empty string.
|
||||
"""
|
||||
try:
|
||||
repo = Repo(REPO_ROOT, search_parent_directories=False)
|
||||
except Exception:
|
||||
return "unavailable"
|
||||
return repo.head.object.hexsha
|
||||
|
||||
|
||||
def get_rex_commit_hash() -> str:
|
||||
import swerex
|
||||
|
||||
try:
|
||||
repo = Repo(Path(swerex.__file__).resolve().parent.parent.parent, search_parent_directories=False)
|
||||
except Exception:
|
||||
return "unavailable"
|
||||
return repo.head.object.hexsha
|
||||
|
||||
|
||||
def get_rex_version() -> str:
|
||||
from swerex import __version__ as rex_version
|
||||
|
||||
return rex_version
|
||||
|
||||
|
||||
def get_agent_version_info() -> str:
|
||||
hash = get_agent_commit_hash()
|
||||
rex_hash = get_rex_commit_hash()
|
||||
rex_version = get_rex_version()
|
||||
return f"This is SWE-agent version {__version__} ({hash=}) with SWE-ReX version {rex_version} ({rex_hash=})."
|
||||
|
||||
|
||||
def impose_rex_lower_bound() -> None:
|
||||
rex_version = get_rex_version()
|
||||
minimal_rex_version = "1.2.0"
|
||||
if version.parse(rex_version) < version.parse(minimal_rex_version):
|
||||
msg = (
|
||||
f"SWE-ReX version {rex_version} is too old. Please update to at least {minimal_rex_version} by "
|
||||
"running `pip install --upgrade swe-rex`."
|
||||
"You can also rerun `pip install -e .` in this repository to install the latest version."
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
if version.parse(rex_version) < version.parse(SWEREX_RECOMMENDED_VERSION):
|
||||
msg = (
|
||||
f"SWE-ReX version {rex_version} is not recommended. Please update to at least {SWEREX_RECOMMENDED_VERSION} by "
|
||||
"running `pip install --upgrade swe-rex`."
|
||||
"You can also rerun `pip install -e .` in this repository to install the latest version."
|
||||
)
|
||||
get_logger("swe-agent", emoji="👋").warning(msg)
|
||||
|
||||
|
||||
impose_rex_lower_bound()
|
||||
get_logger("swe-agent", emoji="👋").info(get_agent_version_info())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PACKAGE_DIR",
|
||||
"CONFIG_DIR",
|
||||
"get_agent_commit_hash",
|
||||
"get_agent_version_info",
|
||||
"__version__",
|
||||
]
|
||||
4
.agent/vendor/mini-swe/sweagent/__main__.py
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
from sweagent.run.run import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
.agent/vendor/mini-swe/sweagent/agent/__init__.py
vendored
Normal file
317
.agent/vendor/mini-swe/sweagent/agent/action_sampler.py
vendored
Normal file
@@ -0,0 +1,317 @@
|
||||
from abc import abstractmethod
|
||||
from textwrap import dedent
|
||||
from typing import Any, Literal
|
||||
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sweagent.agent.models import AbstractModel
|
||||
from sweagent.agent.problem_statement import ProblemStatement
|
||||
from sweagent.exceptions import FormatError
|
||||
from sweagent.tools.tools import ToolHandler
|
||||
from sweagent.types import Trajectory
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
|
||||
class ActionSamplerOutput(BaseModel):
|
||||
completion: dict[str, Any]
|
||||
messages: list[dict[str, Any]] = []
|
||||
trajectory_items: list[dict[str, Any]] = []
|
||||
extra_info: dict[str, Any] = {}
|
||||
|
||||
|
||||
class AbstractActionSampler:
|
||||
def __init__(self, model: AbstractModel, tools: ToolHandler):
|
||||
self._model = model
|
||||
self._tools = tools
|
||||
self._logger = get_logger("action_sampler", emoji="👥")
|
||||
|
||||
@abstractmethod
|
||||
def get_action(
|
||||
self,
|
||||
problem_statement: ProblemStatement,
|
||||
trajectory: Trajectory,
|
||||
history: list[dict[str, Any]],
|
||||
) -> ActionSamplerOutput:
|
||||
"""Returns action with tool calls"""
|
||||
pass
|
||||
|
||||
|
||||
class AskColleaguesConfig(BaseModel):
|
||||
type: Literal["ask_colleagues"] = "ask_colleagues"
|
||||
|
||||
n_samples: int = 2
|
||||
|
||||
def get(self, model: AbstractModel, tools: ToolHandler) -> "AskColleagues":
|
||||
return AskColleagues(self, model, tools)
|
||||
|
||||
|
||||
class AskColleagues(AbstractActionSampler):
|
||||
def __init__(self, config: AskColleaguesConfig, model: AbstractModel, tools: ToolHandler):
|
||||
super().__init__(model, tools)
|
||||
self.config = config
|
||||
|
||||
def get_colleague_discussion(self, completions: list[dict[str, Any]]) -> str:
|
||||
"""Concat all completions into a single string"""
|
||||
out = "Your colleagues had the following ideas: \n\n"
|
||||
n_parsed_ok = 0
|
||||
for i, completion in enumerate(completions):
|
||||
try:
|
||||
thought, action = self._tools.parse_actions(completion)
|
||||
except FormatError:
|
||||
self._logger.warning("Could not parse completion %s, skipping.", completion)
|
||||
continue
|
||||
n_parsed_ok += 1
|
||||
out += f"Thought (colleague {i}): {thought}\nProposed Action (colleague {i}): {action}\n\n"
|
||||
if n_parsed_ok == 0:
|
||||
msg = "No completions could be parsed."
|
||||
raise FormatError(msg)
|
||||
out += (
|
||||
"Please summarize and compare the ideas and propose and action to take. "
|
||||
"Finally choose one action to perform and explain it in detail and include it as a tool call. "
|
||||
"<important>You must include a thought and action (as a tool/function call). Do not try to invoke commands with triple backticks, use function calls instead.</important>"
|
||||
)
|
||||
return out
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
problem_statement: ProblemStatement,
|
||||
trajectory: Trajectory,
|
||||
history: list[dict[str, Any]],
|
||||
) -> ActionSamplerOutput:
|
||||
"""Returns action with tool calls"""
|
||||
completions = self._model.query(history, n=self.config.n_samples) # type: ignore
|
||||
discussion = self.get_colleague_discussion(completions)
|
||||
self._logger.info(f"COLLEAGUE DISCUSSION:\n{discussion}")
|
||||
new_messages = [
|
||||
{"role": "user", "content": discussion},
|
||||
]
|
||||
final_completion = self._model.query(history + new_messages) # type: ignore
|
||||
return ActionSamplerOutput(
|
||||
completion=final_completion,
|
||||
extra_info={"colleagues": discussion},
|
||||
)
|
||||
|
||||
|
||||
class BinaryTrajectoryComparisonConfig(BaseModel):
|
||||
type: Literal["binary_trajectory_comparison"] = "binary_trajectory_comparison"
|
||||
|
||||
min_n_samples: int = 4
|
||||
max_n_samples: int = 10
|
||||
|
||||
comparison_temperature: float | None = None
|
||||
"""Override the model's temperature. If None, take the temperature configured for the model."""
|
||||
|
||||
system_template: str = """<setting>You are an expert software engineer overseeing junior developers. They suggest actions to take to solve a problem. You must choose the best action to take. </setting>"""
|
||||
instance_template: str = dedent("""
|
||||
We're solving the following problem
|
||||
|
||||
<problem_statement>
|
||||
{{problem_statement}}
|
||||
</problem_statement>
|
||||
|
||||
So far, we've performed the following actions:
|
||||
|
||||
<trajectory>
|
||||
{{traj}}
|
||||
</trajectory>
|
||||
""")
|
||||
|
||||
comparison_template: str = dedent("""
|
||||
Two junior developers suggested the following actions:
|
||||
|
||||
<thought1>
|
||||
{{thought1}}
|
||||
</thought1>
|
||||
|
||||
<action1>
|
||||
{{action1}}
|
||||
</action1>
|
||||
|
||||
<thought2>
|
||||
{{thought2}}
|
||||
</thought2>
|
||||
|
||||
<action2>
|
||||
{{action2}}
|
||||
</action2>
|
||||
|
||||
Please compare the two actions in detail.
|
||||
|
||||
Which action should we take?
|
||||
|
||||
If you think the first action is better, respond with "first".
|
||||
If you think the second action is better, respond with "second".
|
||||
|
||||
The last line of your response MUST be "first" or "second".
|
||||
""")
|
||||
|
||||
def get(self, model: AbstractModel, tools: ToolHandler) -> "BinaryTrajectoryComparison":
|
||||
return BinaryTrajectoryComparison(self, model, tools)
|
||||
|
||||
|
||||
class BinaryTrajectoryComparison(AbstractActionSampler):
|
||||
def __init__(self, config: BinaryTrajectoryComparisonConfig, model: AbstractModel, tools: ToolHandler):
|
||||
super().__init__(model, tools)
|
||||
self.config = config
|
||||
|
||||
def _format_trajectory(self, trajectory: Trajectory) -> str:
|
||||
steps = []
|
||||
for i, step in enumerate(trajectory):
|
||||
steps.append(f"Action {i}: {step['action']}\n Observation {i}: {step['observation']}")
|
||||
return "\n".join(steps)
|
||||
|
||||
def format_messages(
|
||||
self,
|
||||
*,
|
||||
problem_statement: ProblemStatement,
|
||||
trajectory: Trajectory,
|
||||
thought1: str,
|
||||
action1: str,
|
||||
thought2: str,
|
||||
action2: str,
|
||||
use_cache_control: bool = False,
|
||||
) -> list[dict]:
|
||||
system_message = self.config.system_template
|
||||
self._logger.debug(f"MODEL INPUT (system)\n{system_message}")
|
||||
ps_format_dict = {
|
||||
"problem_statement": problem_statement.get_problem_statement(),
|
||||
**problem_statement.get_extra_fields(),
|
||||
}
|
||||
user_message = Template(self.config.instance_template).render(
|
||||
**ps_format_dict,
|
||||
traj=self._format_trajectory(trajectory),
|
||||
)
|
||||
self._logger.debug(f"MODEL INPUT (instance)\n{user_message}")
|
||||
comparison_message = Template(self.config.comparison_template).render(
|
||||
thought1=thought1,
|
||||
action1=action1,
|
||||
thought2=thought2,
|
||||
action2=action2,
|
||||
)
|
||||
self._logger.debug(f"MODEL INPUT (comparison)\n{comparison_message}")
|
||||
cache_control_kwargs = {"cache_control": {"type": "ephemeral"}} if use_cache_control else {}
|
||||
return [
|
||||
{"role": "system", "content": system_message},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_message, **cache_control_kwargs}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": comparison_message,
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def filter_duplicates(self, completions: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Filter out duplicate actions"""
|
||||
thoughts: list[str] = []
|
||||
actions: list[str] = []
|
||||
filtered_completions: list[dict[str, Any]] = []
|
||||
for pc in completions:
|
||||
thought, action = self._tools.parse_actions(pc)
|
||||
if action not in actions:
|
||||
thoughts.append(thought)
|
||||
actions.append(action)
|
||||
filtered_completions.append(pc)
|
||||
|
||||
if len(filtered_completions) < len(completions):
|
||||
self._logger.debug("Filtering duplicates: %d -> %d", len(completions), len(filtered_completions))
|
||||
|
||||
return filtered_completions
|
||||
|
||||
def filter_parseable_completions(self, completions: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
filtered_completions = []
|
||||
for completion in completions:
|
||||
try:
|
||||
self._tools.parse_actions(completion)
|
||||
except FormatError:
|
||||
self._logger.warning("Could not parse completion %s, skipping.", completion)
|
||||
continue
|
||||
filtered_completions.append(completion)
|
||||
if len(filtered_completions) == 0:
|
||||
msg = "No completions could be parsed."
|
||||
raise FormatError(msg)
|
||||
return filtered_completions
|
||||
|
||||
def contains_edits(self, completions: list[dict[str, Any]]) -> bool:
|
||||
keywords = ["edit", "str_replace_editor insert", "str_replace_editor str_replace"]
|
||||
for completion in completions:
|
||||
_, action = self._tools.parse_actions(completion)
|
||||
if any(action.startswith(keyword) for keyword in keywords):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_completions(self, history: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
completions = self._model.query(history, n=self.config.min_n_samples) # type: ignore
|
||||
completions = self.filter_parseable_completions(completions)
|
||||
completions = self.filter_duplicates(completions)
|
||||
if not completions:
|
||||
msg = "No completions could be parsed."
|
||||
raise FormatError(msg)
|
||||
if self.contains_edits(completions) and self.config.min_n_samples < self.config.max_n_samples:
|
||||
self._logger.debug("Edits were proposed, will sample more")
|
||||
new_completions = self._model.query(history, n=self.config.max_n_samples - self.config.min_n_samples) # type: ignore
|
||||
completions = self.filter_duplicates(self.filter_parseable_completions(completions + new_completions))
|
||||
if len(completions) == 1:
|
||||
_, action = self._tools.parse_actions(completions[0])
|
||||
self._logger.warning("Only identical actions were proposed (action=%s)", action)
|
||||
return completions
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
*,
|
||||
problem_statement: ProblemStatement,
|
||||
trajectory: Trajectory,
|
||||
history: list[dict[str, Any]],
|
||||
) -> ActionSamplerOutput:
|
||||
completions = self.get_completions(history)
|
||||
best_idx = 0
|
||||
comparison_log = []
|
||||
for i in range(1, len(completions)):
|
||||
thought1, action1 = self._tools.parse_actions(completions[best_idx])
|
||||
thought2, action2 = self._tools.parse_actions(completions[i])
|
||||
messages = self.format_messages(
|
||||
problem_statement=problem_statement,
|
||||
trajectory=trajectory,
|
||||
thought1=thought1,
|
||||
action1=action1,
|
||||
thought2=thought2,
|
||||
action2=action2,
|
||||
use_cache_control=len(completions) >= 3,
|
||||
)
|
||||
response = self._model.query(messages, temperature=self.config.comparison_temperature)["message"] # type: ignore
|
||||
self._logger.info(f"RESPONSE: {response}")
|
||||
idx = self.interpret(response)
|
||||
comparison_log.append(
|
||||
{
|
||||
"comparison_between": (best_idx, i),
|
||||
"messages": messages,
|
||||
"response": response,
|
||||
"idx": idx,
|
||||
}
|
||||
)
|
||||
best_idx = i if idx == 1 else best_idx
|
||||
|
||||
return ActionSamplerOutput(
|
||||
completion=completions[best_idx],
|
||||
extra_info={"comparison_log": comparison_log},
|
||||
)
|
||||
|
||||
def interpret(self, response: str) -> Literal[0, 1]:
|
||||
"""Interpret response from LM. Note: 1-based indexing"""
|
||||
last_line = response.strip().split("\n")[-1].strip()
|
||||
if "first" in last_line.lower():
|
||||
return 0
|
||||
elif "second" in last_line.lower():
|
||||
return 1
|
||||
self._logger.warning("Could not interpret response: %s, will choose first submission.", response)
|
||||
return 0
|
||||
|
||||
|
||||
ActionSamplerConfig = BinaryTrajectoryComparisonConfig | AskColleaguesConfig
|
||||
1294
.agent/vendor/mini-swe/sweagent/agent/agents.py
vendored
Normal file
106
.agent/vendor/mini-swe/sweagent/agent/extra/shell_agent.py
vendored
Normal file
@@ -0,0 +1,106 @@
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
|
||||
from sweagent.agent.agents import DefaultAgent, ShellAgentConfig
|
||||
from sweagent.agent.models import HumanModel, HumanModelConfig, get_model
|
||||
from sweagent.agent.problem_statement import ProblemStatement, ProblemStatementConfig
|
||||
from sweagent.environment.swe_env import SWEEnv
|
||||
from sweagent.tools.parsing import ActionOnlyParser
|
||||
from sweagent.tools.tools import ToolHandler
|
||||
from sweagent.types import AgentRunResult, StepOutput
|
||||
|
||||
|
||||
class ShellAgent(DefaultAgent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ShellAgentConfig) -> Self:
|
||||
# To ensure that all models stay completely independent, we deepcopy the
|
||||
# model config, because it lives on as a property in the model, tools, etc.
|
||||
config = config.model_copy(deep=True)
|
||||
model = get_model(config.model, config.tools)
|
||||
return cls(
|
||||
templates=config.templates,
|
||||
tools=ToolHandler(config.tools),
|
||||
history_processors=config.history_processors,
|
||||
model=model,
|
||||
max_requeries=config.max_requeries,
|
||||
)
|
||||
|
||||
def human_step_in(self) -> None:
|
||||
"""Replace the current model with a HumanModel instance.
|
||||
This allows for human intervention during agent execution.
|
||||
"""
|
||||
self._original_model = self.model
|
||||
self._original_parser = self.tools.config.parse_function
|
||||
|
||||
human_config = HumanModelConfig(name="human", catch_eof=False)
|
||||
self.model = get_model(human_config, self.tools.config)
|
||||
self.tools.config.parse_function = ActionOnlyParser()
|
||||
|
||||
self.logger.info("Switched to human mode. Agent will now accept human input. Press ^D to switch back.")
|
||||
|
||||
def human_step_out(self) -> None:
|
||||
"""Switch back to the original model from human mode.
|
||||
This is called when ^D is pressed in human mode.
|
||||
"""
|
||||
if not hasattr(self, "_original_model") or self._original_model is None:
|
||||
self.logger.info("No previous model to switch back to. Remaining in current mode.")
|
||||
return
|
||||
|
||||
self.model = self._original_model
|
||||
self.tools.config.parse_function = self._original_parser # type: ignore
|
||||
self._original_model = None
|
||||
self._original_parser = None
|
||||
|
||||
self.logger.info("Switched back to AI model mode.")
|
||||
|
||||
def run(
|
||||
self,
|
||||
env: SWEEnv,
|
||||
problem_statement: ProblemStatement | ProblemStatementConfig,
|
||||
*,
|
||||
output_dir: Path = Path("."),
|
||||
) -> AgentRunResult:
|
||||
"""Run the agent on a problem instance. This method contains the
|
||||
main loop that repeatedly calls `self._step` until the problem is solved.
|
||||
|
||||
Args:
|
||||
setup_args: Arguments to pass to the agent's setup method.
|
||||
env: The environment to run the agent on.
|
||||
traj_dir: Directory to save the trajectory to
|
||||
interruptible: Whether the human can jump in by pressing ^C
|
||||
"""
|
||||
self.setup(env=env, problem_statement=problem_statement, output_dir=output_dir)
|
||||
|
||||
# Run action/observation loop
|
||||
self._chook.on_run_start()
|
||||
step_output = StepOutput()
|
||||
while not step_output.done:
|
||||
try:
|
||||
step_output = self.step()
|
||||
self.save_trajectory()
|
||||
except KeyboardInterrupt:
|
||||
if not isinstance(self.model, HumanModel):
|
||||
self.human_step_in()
|
||||
continue
|
||||
raise
|
||||
except EOFError:
|
||||
# Can only happen if we have a human model, so switch back
|
||||
self.logger.info("Detected ^D - switching back to AI mode")
|
||||
self.human_step_out()
|
||||
continue
|
||||
if step_output.done and not isinstance(self.model, HumanModel):
|
||||
# Human has to submit the solution
|
||||
self.logger.info("Robot is done! Please submit the solution.")
|
||||
self.human_step_in()
|
||||
step_output.done = False
|
||||
self._chook.on_run_done(trajectory=self.trajectory, info=self.info)
|
||||
|
||||
self.logger.info("Trajectory saved to %s", self.traj_path)
|
||||
|
||||
# Here we want to return the "global" information (e.g., submission should
|
||||
# be the best submission instead of the last one, etc.), so we get it from the traj file
|
||||
data = self.get_trajectory_data()
|
||||
return AgentRunResult(info=data["info"], trajectory=data["trajectory"])
|
||||
399
.agent/vendor/mini-swe/sweagent/agent/history_processors.py
vendored
Normal file
@@ -0,0 +1,399 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from typing import Annotated, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from sweagent.types import History, HistoryItem
|
||||
|
||||
|
||||
class AbstractHistoryProcessor(Protocol):
|
||||
@abstractmethod
|
||||
def __call__(self, history: History) -> History:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Utility functions
|
||||
# -----------------
|
||||
|
||||
|
||||
def _get_content_stats(entry: HistoryItem) -> tuple[int, int]:
|
||||
if isinstance(entry["content"], str):
|
||||
return len(entry["content"].splitlines()), 0
|
||||
n_text_lines = sum(len(item["text"].splitlines()) for item in entry["content"] if item.get("type") == "text")
|
||||
n_images = sum(1 for item in entry["content"] if item.get("type") == "image_url")
|
||||
return n_text_lines, n_images
|
||||
|
||||
|
||||
def _get_content_text(entry: HistoryItem) -> str:
|
||||
if isinstance(entry["content"], str):
|
||||
return entry["content"]
|
||||
assert len(entry["content"]) == 1, "Expected single message in content"
|
||||
return entry["content"][0]["text"]
|
||||
|
||||
|
||||
def _set_content_text(entry: HistoryItem, text: str) -> None:
|
||||
if isinstance(entry["content"], str):
|
||||
entry["content"] = text
|
||||
else:
|
||||
assert len(entry["content"]) == 1, "Expected single message in content"
|
||||
entry["content"][0]["text"] = text
|
||||
|
||||
|
||||
def _clear_cache_control(entry: HistoryItem) -> None:
|
||||
if isinstance(entry["content"], list):
|
||||
for item in entry["content"]:
|
||||
item.pop("cache_control", None)
|
||||
entry.pop("cache_control", None)
|
||||
|
||||
|
||||
def _set_cache_control(entry: HistoryItem) -> None:
|
||||
if not isinstance(entry["content"], list):
|
||||
entry["content"] = [ # type: ignore
|
||||
{
|
||||
"type": "text",
|
||||
"text": _get_content_text(entry),
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
]
|
||||
else:
|
||||
entry["content"][0]["cache_control"] = {"type": "ephemeral"}
|
||||
if entry["role"] == "tool":
|
||||
# Workaround for weird bug
|
||||
entry["content"][0].pop("cache_control", None)
|
||||
entry["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
|
||||
# History processors
|
||||
# ------------------
|
||||
|
||||
|
||||
class DefaultHistoryProcessor(BaseModel):
|
||||
type: Literal["default"] = "default"
|
||||
"""Do not change. Used for (de)serialization."""
|
||||
|
||||
# pydantic config
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def __call__(self, history: History) -> History:
|
||||
return history
|
||||
|
||||
|
||||
class LastNObservations(BaseModel):
|
||||
"""Elide all but the last n observations or remove tagged observations.
|
||||
|
||||
This is our most classic history processor, used in the original paper
|
||||
to elide but the last 5 observations.
|
||||
Elided observations are replaced by "Old environment output: (n lines omitted)".
|
||||
|
||||
Typical configuration:
|
||||
|
||||
```yaml
|
||||
agent:
|
||||
history_processors:
|
||||
- type: last_n_observations
|
||||
n: 5
|
||||
```
|
||||
|
||||
as for example in use in the SWE-agent 0.7 config at
|
||||
https://github.com/SWE-agent/SWE-agent/blob/main/config/sweagent_0_7/07.yaml
|
||||
|
||||
For most use cases, you only need to set `n`.
|
||||
|
||||
Note that using this history processor will break prompt caching (as the
|
||||
history of every query will change every time due to the elided observations).
|
||||
There are some workarounds possible with the `polling` parameter.
|
||||
|
||||
However, most SotA models can now fit a lot of context, so generally this
|
||||
history processor is not always needed anymore.
|
||||
"""
|
||||
|
||||
n: int
|
||||
"""Number of observations to keep."""
|
||||
|
||||
polling: int = 1
|
||||
"""How many steps to keep between updating the number of observations to keep.
|
||||
This is useful for caching, as we want to remove more and more messages, but every
|
||||
time we change the history, we need to cache everything again.
|
||||
Effectively, we will now keep between `n` and `n+polling` observations.
|
||||
"""
|
||||
|
||||
always_remove_output_for_tags: set[str] = {"remove_output"}
|
||||
"""Any observation with a `tags` field containing one of these strings will be elided,
|
||||
even if it is one of the last n observations.
|
||||
"""
|
||||
|
||||
always_keep_output_for_tags: set[str] = {"keep_output"}
|
||||
"""Any observation with a `tags` field containing one of these strings will be kept,
|
||||
even if it is not one of the last n observations.
|
||||
"""
|
||||
|
||||
type: Literal["last_n_observations"] = "last_n_observations"
|
||||
"""Do not change. Used for (de)serialization."""
|
||||
|
||||
# pydantic config
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("n")
|
||||
def validate_n(cls, n: int) -> int:
|
||||
if n <= 0:
|
||||
msg = "n must be a positive integer"
|
||||
raise ValueError(msg)
|
||||
return n
|
||||
|
||||
def _get_omit_indices(self, history: History) -> list[int]:
|
||||
observation_indices = [
|
||||
idx
|
||||
for idx, entry in enumerate(history)
|
||||
if entry.get("message_type") == "observation" and not entry.get("is_demo", False)
|
||||
]
|
||||
last_removed_idx = max(0, (len(observation_indices) // self.polling) * self.polling - self.n)
|
||||
# Note: We never remove the first observation, as it is the instance template
|
||||
return observation_indices[1:last_removed_idx]
|
||||
|
||||
def __call__(self, history: History) -> History:
|
||||
new_history = []
|
||||
omit_content_idxs = self._get_omit_indices(history)
|
||||
for idx, entry in enumerate(history):
|
||||
tags = set(entry.get("tags", []))
|
||||
if ((idx not in omit_content_idxs) or (tags & self.always_keep_output_for_tags)) and not (
|
||||
tags & self.always_remove_output_for_tags
|
||||
):
|
||||
new_history.append(entry)
|
||||
else:
|
||||
data = entry.copy()
|
||||
assert data.get("message_type") == "observation", (
|
||||
f"Expected observation for dropped entry, got: {data.get('message_type')}"
|
||||
)
|
||||
num_text_lines, num_images = _get_content_stats(data)
|
||||
data["content"] = f"Old environment output: ({num_text_lines} lines omitted)"
|
||||
if num_images > 0:
|
||||
data["content"] += f" ({num_images} images omitted)"
|
||||
new_history.append(data)
|
||||
return new_history
|
||||
|
||||
|
||||
class TagToolCallObservations(BaseModel):
|
||||
"""Adds tags to history items for specific tool calls."""
|
||||
|
||||
type: Literal["tag_tool_call_observations"] = "tag_tool_call_observations"
|
||||
"""Do not change. Used for (de)serialization."""
|
||||
|
||||
tags: set[str] = {"keep_output"}
|
||||
"""Add the following tag to all observations matching the search criteria."""
|
||||
|
||||
function_names: set[str] = set()
|
||||
"""Only consider observations made by tools with these names."""
|
||||
|
||||
# pydantic config
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def _add_tags(self, entry: HistoryItem) -> None:
|
||||
tags = set(entry.get("tags", []))
|
||||
tags.update(self.tags)
|
||||
entry["tags"] = list(tags)
|
||||
|
||||
def _should_add_tags(self, entry: HistoryItem) -> bool:
|
||||
if entry.get("message_type") != "action":
|
||||
return False
|
||||
function_calls = entry.get("tool_calls", [])
|
||||
if not function_calls:
|
||||
return False
|
||||
function_names = {call["function"]["name"] for call in function_calls} # type: ignore
|
||||
return bool(self.function_names & function_names)
|
||||
|
||||
def __call__(self, history: History) -> History:
|
||||
for entry in history:
|
||||
if self._should_add_tags(entry):
|
||||
self._add_tags(entry)
|
||||
return history
|
||||
|
||||
|
||||
class ClosedWindowHistoryProcessor(BaseModel):
|
||||
"""For each value in history, keep track of which windows have been shown.
|
||||
We want to mark windows that should stay open (they're the last window for a particular file)
|
||||
Then we'll replace all other windows with a simple summary of the window (i.e. number of lines)
|
||||
"""
|
||||
|
||||
type: Literal["closed_window"] = "closed_window"
|
||||
"""Do not change. Used for (de)serialization."""
|
||||
|
||||
_pattern = re.compile(r"^(\d+)\:.*?(\n|$)", re.MULTILINE)
|
||||
_file_pattern = re.compile(r"\[File:\s+(.*)\s+\(\d+\s+lines\ total\)\]")
|
||||
|
||||
# pydantic config
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def __call__(self, history):
|
||||
new_history = list()
|
||||
windows = set()
|
||||
for entry in reversed(history):
|
||||
data = entry.copy()
|
||||
if data["role"] != "user":
|
||||
new_history.append(entry)
|
||||
continue
|
||||
if data.get("is_demo", False):
|
||||
new_history.append(entry)
|
||||
continue
|
||||
matches = list(self._pattern.finditer(entry["content"]))
|
||||
if len(matches) >= 1:
|
||||
file_match = self._file_pattern.search(entry["content"])
|
||||
if file_match:
|
||||
file = file_match.group(1)
|
||||
else:
|
||||
continue
|
||||
if file in windows:
|
||||
start = matches[0].start()
|
||||
end = matches[-1].end()
|
||||
data["content"] = (
|
||||
entry["content"][:start]
|
||||
+ f"Outdated window with {len(matches)} lines omitted...\n"
|
||||
+ entry["content"][end:]
|
||||
)
|
||||
windows.add(file)
|
||||
new_history.append(data)
|
||||
return list(reversed(new_history))
|
||||
|
||||
|
||||
class CacheControlHistoryProcessor(BaseModel):
|
||||
"""This history processor adds manual cache control marks to the history.
|
||||
Use this when running with anthropic claude.
|
||||
"""
|
||||
|
||||
type: Literal["cache_control"] = "cache_control"
|
||||
"""Do not change. Used for (de)serialization."""
|
||||
|
||||
last_n_messages: int = 2
|
||||
"""Add cache control to the last n user messages (and clear it for anything else).
|
||||
In most cases this should be set to 2 (caching for multi-turn conversations).
|
||||
When resampling and running concurrent instances, you want to set it to 1.
|
||||
If set to <= 0, any set cache control will be removed from all messages.
|
||||
"""
|
||||
|
||||
last_n_messages_offset: int = 0
|
||||
"""E.g., set to 1 to start cache control after the second to last user message.
|
||||
This can be useful in rare cases, when you want to modify the last message after
|
||||
we've got the completion and you want to avoid cache mismatch.
|
||||
"""
|
||||
|
||||
tagged_roles: list[str] = ["user", "tool"]
|
||||
"""Only add cache control to messages with these roles."""
|
||||
|
||||
# pydantic config
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def __call__(self, history: History) -> History:
|
||||
new_history = []
|
||||
n_tagged = 0
|
||||
for i_entry, entry in enumerate(reversed(history)):
|
||||
# Clear cache control from previous messages
|
||||
_clear_cache_control(entry)
|
||||
if (
|
||||
n_tagged < self.last_n_messages
|
||||
and entry["role"] in self.tagged_roles
|
||||
and i_entry >= self.last_n_messages_offset
|
||||
):
|
||||
_set_cache_control(entry)
|
||||
n_tagged += 1
|
||||
new_history.append(entry)
|
||||
return list(reversed(new_history))
|
||||
|
||||
|
||||
class RemoveRegex(BaseModel):
|
||||
"""This history processor can remove arbitrary content from history items"""
|
||||
|
||||
remove: list[str] = ["<diff>.*</diff>"]
|
||||
"""Regex patterns to remove from history items"""
|
||||
|
||||
keep_last: int = 0
|
||||
"""Keep the last n history items unchanged"""
|
||||
|
||||
type: Literal["remove_regex"] = "remove_regex"
|
||||
"""Do not change. Used for (de)serialization."""
|
||||
|
||||
# pydantic config
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def __call__(self, history: History) -> History:
|
||||
new_history = []
|
||||
for i_entry, entry in enumerate(reversed(history)):
|
||||
entry = copy.deepcopy(entry)
|
||||
if i_entry < self.keep_last:
|
||||
new_history.append(entry)
|
||||
else:
|
||||
if isinstance(entry["content"], list):
|
||||
for item in entry["content"]:
|
||||
if item["type"] == "text":
|
||||
for pattern in self.remove:
|
||||
item["text"] = re.sub(pattern, "", item["text"], flags=re.DOTALL)
|
||||
else:
|
||||
assert isinstance(entry["content"], str), "Expected string content"
|
||||
for pattern in self.remove:
|
||||
entry["content"] = re.sub(pattern, "", entry["content"], flags=re.DOTALL)
|
||||
new_history.append(entry)
|
||||
return list(reversed(new_history))
|
||||
|
||||
|
||||
class ImageParsingHistoryProcessor(BaseModel):
|
||||
"""Parse embedded base64 images from markdown and convert to multi-modal format."""
|
||||
|
||||
type: Literal["image_parsing"] = "image_parsing"
|
||||
allowed_mime_types: set[str] = {"image/png", "image/jpeg", "image/webp"}
|
||||
|
||||
_pattern = re.compile(r"(!\[([^\]]*)\]\(data:)([^;]+);base64,([^)]+)(\))")
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def __call__(self, history: History) -> History:
|
||||
return [self._process_entry(entry) for entry in history]
|
||||
|
||||
def _process_entry(self, entry: HistoryItem) -> HistoryItem:
|
||||
if entry.get("role") not in ["user", "tool"]:
|
||||
return entry
|
||||
entry = copy.deepcopy(entry)
|
||||
content = _get_content_text(entry)
|
||||
segments = self._parse_images(content)
|
||||
if any(seg["type"] == "image_url" for seg in segments):
|
||||
entry["content"] = segments
|
||||
return entry
|
||||
|
||||
def _parse_images(self, content: str) -> list[dict]:
|
||||
segments = []
|
||||
last_end = 0
|
||||
has_images = False
|
||||
|
||||
def add_text(text: str) -> None:
|
||||
"""Add text to the last segment if it's text, otherwise create new text segment."""
|
||||
if text and segments and segments[-1]["type"] == "text":
|
||||
segments[-1]["text"] += text
|
||||
elif text:
|
||||
segments.append({"type": "text", "text": text})
|
||||
|
||||
for match in self._pattern.finditer(content):
|
||||
markdown_prefix, alt_text, mime_type, base64_data, markdown_suffix = match.groups()
|
||||
add_text(content[last_end : match.start()])
|
||||
mime_type = "image/jpeg" if mime_type == "image/jpg" else mime_type
|
||||
if mime_type in self.allowed_mime_types:
|
||||
add_text(markdown_prefix)
|
||||
segments.append({"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_data}"}})
|
||||
add_text(markdown_suffix)
|
||||
has_images = True
|
||||
else:
|
||||
add_text(match.group(0))
|
||||
last_end = match.end()
|
||||
add_text(content[last_end:])
|
||||
return segments if has_images else [{"type": "text", "text": content}]
|
||||
|
||||
|
||||
HistoryProcessor = Annotated[
|
||||
DefaultHistoryProcessor
|
||||
| LastNObservations
|
||||
| ClosedWindowHistoryProcessor
|
||||
| TagToolCallObservations
|
||||
| CacheControlHistoryProcessor
|
||||
| RemoveRegex
|
||||
| ImageParsingHistoryProcessor,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
0
.agent/vendor/mini-swe/sweagent/agent/hooks/__init__.py
vendored
Normal file
139
.agent/vendor/mini-swe/sweagent/agent/hooks/abstract.py
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sweagent.types import AgentInfo, StepOutput, Trajectory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# avoid circular import
|
||||
from sweagent.agent.agents import DefaultAgent
|
||||
|
||||
|
||||
class AbstractAgentHook:
|
||||
def on_init(self, *, agent: "DefaultAgent"):
|
||||
"""Note: Depending on the internals of `Agent` should be done with care,
|
||||
it's best to use this as little as possible.
|
||||
"""
|
||||
|
||||
def on_run_start(
|
||||
self,
|
||||
): ...
|
||||
|
||||
def on_step_start(self): ...
|
||||
|
||||
def on_actions_generated(self, *, step: StepOutput): ...
|
||||
|
||||
def on_action_started(self, *, step: StepOutput): ...
|
||||
|
||||
def on_action_executed(self, *, step: StepOutput): ...
|
||||
|
||||
def on_step_done(self, *, step: StepOutput, info: AgentInfo): ...
|
||||
|
||||
def on_run_done(self, *, trajectory: Trajectory, info: AgentInfo): ...
|
||||
|
||||
def on_setup_attempt(self): ...
|
||||
|
||||
def on_model_query(self, *, messages: list[dict[str, str]], agent: str):
|
||||
"""Actually query the model with the complete history."""
|
||||
|
||||
def on_query_message_added(
|
||||
self,
|
||||
*,
|
||||
agent: str,
|
||||
role: str,
|
||||
content: str,
|
||||
message_type: str,
|
||||
is_demo: bool = False,
|
||||
thought: str = "",
|
||||
action: str = "",
|
||||
tool_calls: list[dict[str, str]] | None = None,
|
||||
tool_call_ids: list[str] | None = None,
|
||||
): ...
|
||||
|
||||
def on_setup_done(self): ...
|
||||
|
||||
def on_tools_installation_started(self): ...
|
||||
|
||||
|
||||
class CombinedAgentHook(AbstractAgentHook):
|
||||
def __init__(self, hooks: list[AbstractAgentHook] | None = None):
|
||||
self._hooks = hooks or []
|
||||
|
||||
def add_hook(self, hook: AbstractAgentHook):
|
||||
self._hooks.append(hook)
|
||||
|
||||
@property
|
||||
def hooks(self) -> list[AbstractAgentHook]:
|
||||
return self._hooks
|
||||
|
||||
def on_init(self, *, agent: "DefaultAgent"):
|
||||
for hook in self.hooks:
|
||||
hook.on_init(agent=agent)
|
||||
|
||||
def on_run_start(self):
|
||||
for hook in self.hooks:
|
||||
hook.on_run_start()
|
||||
|
||||
def on_step_start(self):
|
||||
for hook in self.hooks:
|
||||
hook.on_step_start()
|
||||
|
||||
def on_actions_generated(self, *, step: StepOutput):
|
||||
for hook in self.hooks:
|
||||
hook.on_actions_generated(step=step)
|
||||
|
||||
def on_action_started(self, *, step: StepOutput):
|
||||
for hook in self.hooks:
|
||||
hook.on_action_started(step=step)
|
||||
|
||||
def on_action_executed(self, *, step: StepOutput):
|
||||
for hook in self.hooks:
|
||||
hook.on_action_executed(step=step)
|
||||
|
||||
def on_step_done(self, *, step: StepOutput, info: AgentInfo):
|
||||
for hook in self.hooks:
|
||||
hook.on_step_done(step=step, info=info)
|
||||
|
||||
def on_run_done(self, *, trajectory: Trajectory, info: AgentInfo):
|
||||
for hook in self.hooks:
|
||||
hook.on_run_done(trajectory=trajectory, info=info)
|
||||
|
||||
def on_setup_attempt(self):
|
||||
for hook in self.hooks:
|
||||
hook.on_setup_attempt()
|
||||
|
||||
def on_model_query(self, *, messages: list[dict[str, str]], agent: str):
|
||||
for hook in self.hooks:
|
||||
hook.on_model_query(messages=messages, agent=agent)
|
||||
|
||||
def on_query_message_added(
|
||||
self,
|
||||
*,
|
||||
agent: str,
|
||||
role: str,
|
||||
content: str,
|
||||
message_type: str,
|
||||
is_demo: bool = False,
|
||||
thought: str = "",
|
||||
action: str = "",
|
||||
tool_calls: list[dict[str, str]] | None = None,
|
||||
tool_call_ids: list[str] | None = None,
|
||||
thinking_blocks: list[dict[str, str]] | None = None,
|
||||
):
|
||||
for hook in self.hooks:
|
||||
hook.on_query_message_added(
|
||||
agent=agent,
|
||||
role=role,
|
||||
content=content,
|
||||
message_type=message_type,
|
||||
is_demo=is_demo,
|
||||
thought=thought,
|
||||
action=action,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_ids=tool_call_ids,
|
||||
)
|
||||
|
||||
def on_setup_done(self):
|
||||
return super().on_setup_done()
|
||||
|
||||
def on_tools_installation_started(self):
|
||||
for hook in self.hooks:
|
||||
hook.on_tools_installation_started()
|
||||
34
.agent/vendor/mini-swe/sweagent/agent/hooks/status.py
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from sweagent.agent.hooks.abstract import AbstractAgentHook
|
||||
from sweagent.types import AgentInfo, StepOutput
|
||||
|
||||
|
||||
class SetStatusAgentHook(AbstractAgentHook):
|
||||
def __init__(self, id: str, callable: Callable[[str, str], None]):
|
||||
self._callable = callable
|
||||
self._id = id
|
||||
self._i_step = 0
|
||||
self._cost = 0.0
|
||||
self._i_attempt = 0
|
||||
self._previous_cost = 0.0
|
||||
|
||||
def on_setup_attempt(self):
|
||||
self._i_attempt += 1
|
||||
self._i_step = 0
|
||||
# Costs will be reset for the next attempt
|
||||
self._previous_cost += self._cost
|
||||
|
||||
def _update(self, message: str):
|
||||
self._callable(self._id, message)
|
||||
|
||||
def on_step_start(self):
|
||||
self._i_step += 1
|
||||
attempt_str = f"Attempt {self._i_attempt} " if self._i_attempt > 1 else ""
|
||||
self._update(f"{attempt_str}Step {self._i_step:>3} (${self._previous_cost + self._cost:.2f})")
|
||||
|
||||
def on_step_done(self, *, step: StepOutput, info: AgentInfo):
|
||||
self._cost = info["model_stats"]["instance_cost"] # type: ignore
|
||||
|
||||
def on_tools_installation_started(self):
|
||||
self._update("Installing tools")
|
||||
903
.agent/vendor/mini-swe/sweagent/agent/models.py
vendored
Normal file
@@ -0,0 +1,903 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import shlex
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
import litellm
|
||||
import litellm.types.utils
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from pydantic import ConfigDict, Field, SecretStr
|
||||
from swerex.exceptions import SwerexException
|
||||
from tenacity import (
|
||||
RetryCallState,
|
||||
Retrying,
|
||||
retry_if_not_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from sweagent import REPO_ROOT, __version__
|
||||
from sweagent.exceptions import (
|
||||
ContentPolicyViolationError,
|
||||
ContextWindowExceededError,
|
||||
CostLimitExceededError,
|
||||
FunctionCallingFormatError,
|
||||
InstanceCallLimitExceededError,
|
||||
InstanceCostLimitExceededError,
|
||||
ModelConfigurationError,
|
||||
TotalCostLimitExceededError,
|
||||
)
|
||||
from sweagent.tools.tools import ToolConfig
|
||||
from sweagent.types import History, HistoryItem
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
try:
|
||||
import readline # noqa: F401
|
||||
except ImportError:
|
||||
readline = None
|
||||
|
||||
litellm.suppress_debug_info = True
|
||||
|
||||
|
||||
_THREADS_THAT_USED_API_KEYS = []
|
||||
"""Keeps track of thread orders so that we can choose the same API key for the same thread."""
|
||||
|
||||
|
||||
class RetryConfig(PydanticBaseModel):
|
||||
"""This configuration object specifies how many times to retry a failed LM API call."""
|
||||
|
||||
retries: int = 20
|
||||
"""Number of retries"""
|
||||
min_wait: float = 10
|
||||
"""Minimum wait time between retries (random exponential wait)"""
|
||||
max_wait: float = 120
|
||||
"""Maximum wait time between retries (random exponential wait)"""
|
||||
|
||||
|
||||
class GenericAPIModelConfig(PydanticBaseModel):
|
||||
"""This configuration object specifies a LM like GPT4 or similar.
|
||||
The model will be served with the help of the `litellm` library.
|
||||
"""
|
||||
|
||||
name: str = Field(description="Name of the model.")
|
||||
|
||||
per_instance_cost_limit: float = Field(
|
||||
default=3.0,
|
||||
description="Cost limit for every instance (task).",
|
||||
)
|
||||
total_cost_limit: float = Field(default=0.0, description="Total cost limit.")
|
||||
per_instance_call_limit: int = Field(default=0, description="Per instance call limit.")
|
||||
temperature: float = 0.0
|
||||
"""Sampling temperature"""
|
||||
top_p: float | None = 1.0
|
||||
"""Sampling top-p"""
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
api_key: SecretStr | None = None
|
||||
"""API key to the model. We recommend using environment variables to set this instead
|
||||
or putting your environment variables in a `.env` file.
|
||||
You can concatenate more than one key by separating them with `:::`, e.g.,
|
||||
`key1:::key2`.
|
||||
If field starts with `$`, it will be interpreted as an environment variable.
|
||||
"""
|
||||
stop: list[str] = []
|
||||
"""Custom stop sequences"""
|
||||
|
||||
completion_kwargs: dict[str, Any] = {}
|
||||
"""Additional kwargs to pass to `litellm.completion`"""
|
||||
|
||||
convert_system_to_user: bool = False
|
||||
"""Whether to convert system messages to user messages. This is useful for
|
||||
models that do not support system messages like o1.
|
||||
"""
|
||||
|
||||
retry: RetryConfig = RetryConfig()
|
||||
"""Retry configuration: How often to retry after a failure (e.g., from a rate limit)
|
||||
etc.
|
||||
"""
|
||||
|
||||
delay: float = 0.0
|
||||
"""Minimum delay before querying (this can help to avoid overusing the API if sharing
|
||||
it with other people).
|
||||
"""
|
||||
|
||||
fallbacks: list[dict[str, Any]] = []
|
||||
"""List of fallbacks to try if the main model fails
|
||||
See https://docs.litellm.ai/docs/completion/reliable_completions#fallbacks-sdk
|
||||
for more information.
|
||||
"""
|
||||
|
||||
choose_api_key_by_thread: bool = True
|
||||
"""Whether to choose the API key based on the thread name (if multiple are configured).
|
||||
This ensures that with
|
||||
run-batch, we use the same API key within a single-thread so that prompt caching still works.
|
||||
"""
|
||||
|
||||
max_input_tokens: int | None = None
|
||||
"""If set, this will override the max input tokens for the model that we usually look
|
||||
up from `litellm.model_cost`.
|
||||
Use this for local models or if you want to set a custom max input token limit.
|
||||
If this value is exceeded, a `ContextWindowExceededError` will be raised.
|
||||
Set this to 0 to disable this check.
|
||||
"""
|
||||
|
||||
max_output_tokens: int | None = None
|
||||
"""If set, this will override the max output tokens for the model that we usually look
|
||||
up from `litellm.model_cost`.
|
||||
Use this for local models or if you want to set a custom max output token limit.
|
||||
If this value is exceeded, a `ContextWindowExceededError` will be raised.
|
||||
Set this to 0 to disable this check.
|
||||
"""
|
||||
|
||||
litellm_model_registry: str | None = None
|
||||
"""If set, this will override the default model registry for litellm.
|
||||
Use this for local models or models not (yet) in the default litellm model registry for tracking costs.
|
||||
"""
|
||||
|
||||
custom_tokenizer: dict[str, Any] | None = None
|
||||
"""Override the default tokenizer for the model.
|
||||
Use the arguments of `litellm.create_pretrained_tokenizer`.
|
||||
Basic example: `{"identifier": "hf-internal-testing/llama-tokenizer"}`
|
||||
"""
|
||||
|
||||
# pydantic
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def get_api_keys(self) -> list[str]:
|
||||
"""Returns a list of API keys that were explicitly set in this config.
|
||||
Does not return API keys that were set via environment variables/.env
|
||||
"""
|
||||
if self.api_key is None:
|
||||
return []
|
||||
api_key = self.api_key.get_secret_value()
|
||||
if not api_key:
|
||||
return []
|
||||
if api_key.startswith("$"):
|
||||
env_var_name = api_key[1:]
|
||||
api_key = os.getenv(env_var_name, "")
|
||||
if not api_key:
|
||||
get_logger("swea-config", emoji="🔧").warning(f"Environment variable {env_var_name} not set")
|
||||
return []
|
||||
return api_key.split(":::")
|
||||
|
||||
def choose_api_key(self) -> str | None:
|
||||
"""Chooses an API key based on the API keys explicitly set in this config.
|
||||
If no API keys are set, returns None (which means that the API key will be
|
||||
taken from the environment variables/.env file).
|
||||
"""
|
||||
api_keys = self.get_api_keys()
|
||||
if not api_keys:
|
||||
return None
|
||||
if not self.choose_api_key_by_thread:
|
||||
return random.choice(api_keys)
|
||||
thread_name = threading.current_thread().name
|
||||
if thread_name not in _THREADS_THAT_USED_API_KEYS:
|
||||
_THREADS_THAT_USED_API_KEYS.append(thread_name)
|
||||
thread_idx = _THREADS_THAT_USED_API_KEYS.index(thread_name)
|
||||
key_idx = thread_idx % len(api_keys)
|
||||
get_logger("config", emoji="🔧").debug(
|
||||
f"Choosing API key {key_idx} for thread {thread_name} (idx {thread_idx})"
|
||||
)
|
||||
return api_keys[key_idx]
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
name = self.name.replace("/", "--")
|
||||
if self.top_p is not None:
|
||||
top_p = f"{self.top_p:.2f}"
|
||||
else:
|
||||
top_p = "None"
|
||||
temperature = f"{self.temperature:.2f}"
|
||||
per_instance_cost_limit = f"{self.per_instance_cost_limit:.2f}"
|
||||
return f"{name}__t-{temperature}__p-{top_p}__c-{per_instance_cost_limit}"
|
||||
|
||||
|
||||
class ReplayModelConfig(GenericAPIModelConfig):
|
||||
replay_path: Path = Field(description="Path to replay file when using the replay model.")
|
||||
|
||||
per_instance_cost_limit: float = Field(
|
||||
default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
|
||||
)
|
||||
total_cost_limit: float = Field(
|
||||
default=0.0, description="Cost limit for all instances (tasks). This is a dummy value here."
|
||||
)
|
||||
|
||||
name: Literal["replay"] = Field(default="replay", description="Model name.")
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class InstantEmptySubmitModelConfig(GenericAPIModelConfig):
|
||||
"""Model that immediately submits an empty patch"""
|
||||
|
||||
name: Literal["instant_empty_submit"] = Field(default="instant_empty_submit", description="Model name.")
|
||||
|
||||
per_instance_cost_limit: float = Field(
|
||||
default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
|
||||
)
|
||||
total_cost_limit: float = Field(
|
||||
default=0.0, description="Cost limit for all instances (tasks). This is a dummy value here."
|
||||
)
|
||||
delay: float = 0.0
|
||||
"""Delay before answering"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class HumanModelConfig(GenericAPIModelConfig):
|
||||
name: Literal["human"] = Field(default="human", description="Model name.")
|
||||
|
||||
per_instance_cost_limit: float = Field(
|
||||
default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
|
||||
)
|
||||
total_cost_limit: float = Field(default=0.0, description="Cost limit for all instances (tasks).")
|
||||
cost_per_call: float = 0.0
|
||||
catch_eof: bool = True
|
||||
"""Whether to catch EOF and return 'exit' when ^D is pressed. Set to False when used in human_step_in mode."""
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class HumanThoughtModelConfig(HumanModelConfig):
|
||||
name: Literal["human_thought"] = Field(default="human_thought", description="Model name.")
|
||||
|
||||
per_instance_cost_limit: float = Field(
|
||||
default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
|
||||
)
|
||||
total_cost_limit: float = Field(
|
||||
default=0.0, description="Cost limit for all instances (tasks). This is a dummy value here."
|
||||
)
|
||||
cost_per_call: float = 0.0
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
ModelConfig = Annotated[
|
||||
GenericAPIModelConfig
|
||||
| ReplayModelConfig
|
||||
| InstantEmptySubmitModelConfig
|
||||
| HumanModelConfig
|
||||
| HumanThoughtModelConfig,
|
||||
Field(union_mode="left_to_right"),
|
||||
]
|
||||
|
||||
|
||||
class GlobalStats(PydanticBaseModel):
|
||||
"""This class tracks usage numbers (costs etc.) across all instances."""
|
||||
|
||||
total_cost: float = 0
|
||||
"""Cumulative cost for all instances so far"""
|
||||
|
||||
last_query_timestamp: float = 0
|
||||
"""Timestamp of the last query. Currently only used with API models."""
|
||||
|
||||
|
||||
GLOBAL_STATS = GlobalStats()
|
||||
"""This object tracks usage numbers (costs etc.) across all instances.
|
||||
Please use the `GLOBAL_STATS_LOCK` lock when accessing this object to avoid race conditions.
|
||||
"""
|
||||
|
||||
GLOBAL_STATS_LOCK = Lock()
|
||||
"""Lock for accessing `GLOBAL_STATS` without race conditions"""
|
||||
|
||||
|
||||
class InstanceStats(PydanticBaseModel):
|
||||
"""This object tracks usage numbers (costs etc.) for a single instance."""
|
||||
|
||||
instance_cost: float = 0
|
||||
tokens_sent: int = 0
|
||||
tokens_received: int = 0
|
||||
api_calls: int = 0
|
||||
|
||||
def __add__(self, other: InstanceStats) -> InstanceStats:
|
||||
return InstanceStats(
|
||||
**{field: getattr(self, field) + getattr(other, field) for field in self.model_fields.keys()},
|
||||
)
|
||||
|
||||
def __sub__(self, other: InstanceStats) -> InstanceStats:
|
||||
return InstanceStats(
|
||||
**{field: getattr(self, field) - getattr(other, field) for field in self.model_fields.keys()},
|
||||
)
|
||||
|
||||
|
||||
class AbstractModel(ABC):
|
||||
def __init__(self, config: ModelConfig, tools: ToolConfig):
|
||||
self.config: ModelConfig
|
||||
self.stats: InstanceStats
|
||||
|
||||
def reset_stats(self):
|
||||
self.stats = InstanceStats()
|
||||
|
||||
@abstractmethod
|
||||
def query(self, history: History, action_prompt: str = "> ") -> dict: ...
|
||||
|
||||
@property
|
||||
def instance_cost_limit(self) -> float:
|
||||
"""Cost limit for the model. Returns 0 if there is no limit."""
|
||||
return 0
|
||||
|
||||
|
||||
def _handle_raise_commands(action: str) -> None:
|
||||
if action == "raise_runtime":
|
||||
raise SwerexException()
|
||||
elif action == "raise_cost":
|
||||
raise CostLimitExceededError()
|
||||
elif action == "raise_context":
|
||||
raise ContextWindowExceededError()
|
||||
elif action.startswith("raise_function_calling"):
|
||||
parts = shlex.split(action)
|
||||
error_code = parts[1]
|
||||
if len(parts) == 3:
|
||||
error_message = parts[2]
|
||||
assert len(parts) < 4
|
||||
raise FunctionCallingFormatError(error_message, error_code) # type: ignore
|
||||
|
||||
|
||||
class HumanModel(AbstractModel):
|
||||
def __init__(self, config: HumanModelConfig, tools: ToolConfig):
|
||||
"""Model that allows for human-in-the-loop"""
|
||||
self.logger = get_logger("swea-lm", emoji="🤖")
|
||||
self.config: HumanModelConfig = config
|
||||
self.stats = InstanceStats()
|
||||
|
||||
# Determine which commands require multi-line input
|
||||
self.multi_line_command_endings = {
|
||||
command.name: command.end_name for command in tools.commands if command.end_name is not None
|
||||
}
|
||||
self._readline_histfile = REPO_ROOT / ".swe-agent-human-history"
|
||||
self._load_readline_history()
|
||||
|
||||
def _load_readline_history(self) -> None:
|
||||
"""Load autocomplete history from file"""
|
||||
if readline is None:
|
||||
return
|
||||
if self._readline_histfile.is_file():
|
||||
self.logger.debug(f"Loading readline history from {self._readline_histfile}")
|
||||
readline.read_history_file(self._readline_histfile)
|
||||
|
||||
def _save_readline_history(self) -> None:
|
||||
"""Save autocomplete history to file"""
|
||||
if readline is None:
|
||||
return
|
||||
readline.write_history_file(self._readline_histfile)
|
||||
|
||||
def _update_stats(
|
||||
self,
|
||||
) -> None:
|
||||
self.stats.instance_cost += self.config.cost_per_call
|
||||
self.stats.api_calls += 1
|
||||
if 0 < self.config.per_instance_cost_limit < self.stats.instance_cost:
|
||||
msg = f"Instance cost limit exceeded: {self.stats.instance_cost} > {self.config.per_instance_cost_limit}"
|
||||
raise InstanceCostLimitExceededError(msg)
|
||||
if 0 < self.config.total_cost_limit < self.stats.instance_cost:
|
||||
msg = f"Total cost limit exceeded: {self.stats.instance_cost} > {self.config.total_cost_limit}"
|
||||
raise TotalCostLimitExceededError(msg)
|
||||
|
||||
def _query(
|
||||
self,
|
||||
history: History,
|
||||
action_prompt: str = "> ",
|
||||
) -> dict:
|
||||
"""Logic for handling user input to pass to SWEEnv"""
|
||||
action = input(action_prompt)
|
||||
self._save_readline_history()
|
||||
command_name = action.split()[0] if action.strip() else ""
|
||||
|
||||
# Special handling for multi-line input actions (i.e. edit)
|
||||
if command_name in self.multi_line_command_endings:
|
||||
buffer = [action]
|
||||
end_keyword = self.multi_line_command_endings[command_name]
|
||||
while True:
|
||||
action = input("... ")
|
||||
buffer.append(action)
|
||||
if action.rstrip() == end_keyword:
|
||||
# Continue reading input until terminating keyword inputted
|
||||
break
|
||||
action = "\n".join(buffer)
|
||||
elif action.strip() == "start_multiline_command": # do arbitrary multi-line input
|
||||
buffer = []
|
||||
while True:
|
||||
action = input("... ")
|
||||
if action.rstrip() == "end_multiline_command":
|
||||
break
|
||||
buffer.append(action)
|
||||
action = "\n".join(buffer)
|
||||
else:
|
||||
# Input has escaped things like \n, so we need to unescape it
|
||||
action = action.encode("utf8").decode("unicode_escape")
|
||||
if action.strip() and action.strip().split()[0] == "spend_money":
|
||||
money = float(action.strip().split()[1])
|
||||
self.stats.instance_cost += money
|
||||
action = f"echo 'Spent {money} dollars'"
|
||||
_handle_raise_commands(action)
|
||||
self._update_stats()
|
||||
return {"message": action}
|
||||
|
||||
def query(self, history: History, action_prompt: str = "> ", n: int | None = None, **kwargs) -> dict | list[dict]:
|
||||
"""Wrapper to separate action prompt from formatting"""
|
||||
out = []
|
||||
n_samples = n or 1
|
||||
for _ in range(n_samples):
|
||||
try:
|
||||
out.append(self._query(history, action_prompt))
|
||||
except KeyboardInterrupt:
|
||||
print("^C (exit with ^D)")
|
||||
out.append(self.query(history, action_prompt))
|
||||
except EOFError:
|
||||
if self.config.catch_eof:
|
||||
print("\nGoodbye!")
|
||||
out.append({"message": "exit"})
|
||||
else:
|
||||
# Re-raise EOFError when catch_eof is disabled
|
||||
raise
|
||||
if n is None:
|
||||
return out[0]
|
||||
return out
|
||||
|
||||
|
||||
class HumanThoughtModel(HumanModel):
|
||||
def query(self, history: History, **kwargs) -> dict:
|
||||
"""Logic for handling user input (both thought + action) to pass to SWEEnv"""
|
||||
thought_all = ""
|
||||
thought = input("Thought (end w/ END_THOUGHT): ")
|
||||
while True:
|
||||
if "END_THOUGHT" in thought:
|
||||
thought = thought.split("END_THOUGHT")[0]
|
||||
thought_all += thought
|
||||
break
|
||||
thought_all += thought
|
||||
thought = input("... ")
|
||||
|
||||
action = super()._query(history, action_prompt="Action: ")["message"]
|
||||
|
||||
return {"message": f"{thought_all}\n```\n{action}\n```"}
|
||||
|
||||
|
||||
class ReplayModel(AbstractModel):
|
||||
def __init__(self, config: ReplayModelConfig, tools: ToolConfig):
|
||||
"""Model used for replaying a trajectory (i.e., taking all the actions for the `.traj` file
|
||||
and re-issuing them.
|
||||
"""
|
||||
self.config = config
|
||||
self.stats = InstanceStats()
|
||||
|
||||
if not self.config.replay_path.exists():
|
||||
msg = f"Replay file {self.config.replay_path} not found"
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
self._replays = [
|
||||
list(json.loads(x).values())[0] for x in Path(self.config.replay_path).read_text().splitlines(keepends=True)
|
||||
]
|
||||
self._replay_idx = 0
|
||||
self._action_idx = 0
|
||||
self.use_function_calling = tools.use_function_calling
|
||||
self.submit_command = tools.submit_command
|
||||
self.logger = get_logger("swea-lm", emoji="🤖")
|
||||
|
||||
def _next_replay(self) -> None:
|
||||
"""Called after last action"""
|
||||
self._replay_idx += 1
|
||||
self._action_idx = 0
|
||||
|
||||
def query(self, history: History) -> dict:
|
||||
"""Logic for tracking which replay action to pass to SWEEnv"""
|
||||
self.stats.api_calls += 1
|
||||
actions = self._replays[self._replay_idx]
|
||||
try:
|
||||
action = actions[self._action_idx]
|
||||
except IndexError:
|
||||
# log error
|
||||
self.logger.error("Reached end of replay trajectory without submitting. Submitting now.")
|
||||
if self.use_function_calling:
|
||||
action = {
|
||||
"message": f"Calling `{self.submit_command}` to submit.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"id": "call_submit",
|
||||
"function": {
|
||||
"name": self.submit_command,
|
||||
"arguments": "{}",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
action = f"```\n{self.submit_command}\n```"
|
||||
|
||||
self._action_idx += 1
|
||||
|
||||
# Assuming `submit` is always last action of replay trajectory
|
||||
if isinstance(action, str) and action == "submit":
|
||||
self._next_replay()
|
||||
return {"message": action}
|
||||
|
||||
# Handle both dict and string actions
|
||||
if isinstance(action, dict):
|
||||
return action
|
||||
return {"message": action}
|
||||
|
||||
|
||||
class PredeterminedTestModel(AbstractModel):
|
||||
def __init__(self, outputs: list[dict | str]):
|
||||
"""Model that outputs a predetermined sequence of messages. Useful for testing."""
|
||||
self._outputs = outputs
|
||||
self._idx = -1
|
||||
self.stats = InstanceStats()
|
||||
|
||||
def query(self, *args, **kwargs) -> dict:
|
||||
self._idx += 1
|
||||
output = self._outputs[self._idx]
|
||||
if isinstance(output, str):
|
||||
_handle_raise_commands(output)
|
||||
return {"message": output}
|
||||
if not isinstance(output, dict):
|
||||
msg = f"Output must be string or dict, got {type(output)}"
|
||||
raise ValueError(msg)
|
||||
result = {"message": output["message"]}
|
||||
if "tool_calls" in output:
|
||||
result["tool_calls"] = output["tool_calls"]
|
||||
return result
|
||||
|
||||
|
||||
class InstantEmptySubmitTestModel(AbstractModel):
|
||||
def __init__(self, args: InstantEmptySubmitModelConfig, tools: ToolConfig):
|
||||
"""This model immediately submits. Useful for testing purposes"""
|
||||
super().__init__(args, tools)
|
||||
self.config: InstantEmptySubmitModelConfig = args
|
||||
self.stats = InstanceStats()
|
||||
self._action_idx = 0
|
||||
|
||||
def query(self, history: list[dict[str, str]]) -> dict:
|
||||
time.sleep(random.uniform(0, self.config.delay))
|
||||
# Need to at least do _something_ to submit
|
||||
if self._action_idx == 0:
|
||||
self._action_idx = 1
|
||||
action = (
|
||||
"DISCUSSION\n"
|
||||
"Let's reproduce the bug by creating a `reproduce.py` file.\n\n"
|
||||
"```\n"
|
||||
"touch reproduce.py\n"
|
||||
"```\n"
|
||||
)
|
||||
elif self._action_idx == 1:
|
||||
self._action_idx = 0
|
||||
action = "DISCUSSION\nThe task should be resolved, so let's submit the patch.\n\n```\nsubmit\n```\n"
|
||||
self.stats.api_calls += 1
|
||||
return {"message": action}
|
||||
|
||||
|
||||
class LiteLLMModel(AbstractModel):
|
||||
def __init__(self, args: GenericAPIModelConfig, tools: ToolConfig):
|
||||
"""Model served by the `litellm` library."""
|
||||
# Always copy config to avoid shared state between different instances
|
||||
self.config: GenericAPIModelConfig = args.model_copy(deep=True)
|
||||
self.stats = InstanceStats()
|
||||
self.tools = tools
|
||||
self.logger = get_logger("swea-lm", emoji="🤖")
|
||||
|
||||
if tools.use_function_calling:
|
||||
if not litellm.utils.supports_function_calling(model=self.config.name):
|
||||
msg = (
|
||||
f"Model {self.config.name} does not support function calling. If your model"
|
||||
" does not support function calling, you can use `parse_function='thought_action'` instead. "
|
||||
"See https://swe-agent.com/latest/faq/ for more information."
|
||||
)
|
||||
self.logger.warning(msg)
|
||||
if self.config.litellm_model_registry is not None:
|
||||
with open(self.config.litellm_model_registry) as f:
|
||||
model_costs = json.load(f)
|
||||
litellm.register_model(model_costs)
|
||||
if self.config.max_input_tokens is not None:
|
||||
self.model_max_input_tokens = self.config.max_input_tokens
|
||||
else:
|
||||
self.model_max_input_tokens = litellm.model_cost.get(self.config.name, {}).get("max_input_tokens")
|
||||
|
||||
if self.config.max_output_tokens is not None:
|
||||
self.model_max_output_tokens = self.config.max_output_tokens
|
||||
else:
|
||||
self.model_max_output_tokens = litellm.model_cost.get(self.config.name, {}).get("max_output_tokens")
|
||||
# Special handling for Claude 3.7 models to set 64k context by default when beta header not present
|
||||
# See https://github.com/SWE-agent/SWE-agent/pull/1016
|
||||
is_claude_3_7 = "claude-3-7-sonnet" in self.config.name or "claude-sonnet-4" in self.config.name
|
||||
has_128k_beta_header = (
|
||||
self.config.completion_kwargs.get("extra_headers", {}).get("anthropic-beta") == "output-128k-2025-02-19"
|
||||
)
|
||||
if is_claude_3_7 and not has_128k_beta_header:
|
||||
self.model_max_output_tokens = 64000
|
||||
self.logger.warning(
|
||||
"Claude 3.7/4 models do not support 128k context by default. "
|
||||
"Setting max output tokens to 64k. To enable 128k context, please set the "
|
||||
"completion_kwargs to {'extra_headers': {'anthropic-beta': 'output-128k-2025-02-19'}}."
|
||||
)
|
||||
|
||||
self.lm_provider = litellm.model_cost.get(self.config.name, {}).get("litellm_provider", self.config.name)
|
||||
self.custom_tokenizer = None
|
||||
if self.config.custom_tokenizer is not None:
|
||||
self.custom_tokenizer = litellm.utils.create_pretrained_tokenizer(**self.config.custom_tokenizer)
|
||||
|
||||
@property
|
||||
def instance_cost_limit(self) -> float:
|
||||
"""Cost limit for the model. Returns 0 if there is no limit."""
|
||||
return self.config.per_instance_cost_limit
|
||||
|
||||
def _update_stats(self, *, input_tokens: int, output_tokens: int, cost: float) -> None:
|
||||
with GLOBAL_STATS_LOCK:
|
||||
GLOBAL_STATS.total_cost += cost
|
||||
self.stats.instance_cost += cost
|
||||
self.stats.tokens_sent += input_tokens
|
||||
self.stats.tokens_received += output_tokens
|
||||
self.stats.api_calls += 1
|
||||
|
||||
# Log updated cost values to std. err
|
||||
self.logger.debug(
|
||||
f"input_tokens={input_tokens:,}, "
|
||||
f"output_tokens={output_tokens:,}, "
|
||||
f"instance_cost={self.stats.instance_cost:.2f}, "
|
||||
f"cost={cost:.2f}",
|
||||
)
|
||||
self.logger.debug(
|
||||
f"total_tokens_sent={self.stats.tokens_sent:,}, "
|
||||
f"total_tokens_received={self.stats.tokens_received:,}, "
|
||||
f"total_cost={GLOBAL_STATS.total_cost:.2f}, "
|
||||
f"total_api_calls={self.stats.api_calls:,}",
|
||||
)
|
||||
|
||||
# Check whether total cost or instance cost limits have been exceeded
|
||||
if 0 < self.config.total_cost_limit < GLOBAL_STATS.total_cost:
|
||||
self.logger.warning(f"Cost {GLOBAL_STATS.total_cost:.2f} exceeds limit {self.config.total_cost_limit:.2f}")
|
||||
msg = "Total cost limit exceeded"
|
||||
raise TotalCostLimitExceededError(msg)
|
||||
|
||||
if 0 < self.config.per_instance_cost_limit < self.stats.instance_cost:
|
||||
self.logger.warning(
|
||||
f"Cost {self.stats.instance_cost:.2f} exceeds limit {self.config.per_instance_cost_limit:.2f}"
|
||||
)
|
||||
msg = "Instance cost limit exceeded"
|
||||
raise InstanceCostLimitExceededError(msg)
|
||||
|
||||
if 0 < self.config.per_instance_call_limit < self.stats.api_calls:
|
||||
self.logger.warning(f"API calls {self.stats.api_calls} exceeds limit {self.config.per_instance_call_limit}")
|
||||
msg = "Per instance call limit exceeded"
|
||||
raise InstanceCallLimitExceededError(msg)
|
||||
|
||||
def _sleep(self) -> None:
|
||||
elapsed_time = time.time() - GLOBAL_STATS.last_query_timestamp
|
||||
if elapsed_time < self.config.delay:
|
||||
time.sleep(self.config.delay - elapsed_time)
|
||||
with GLOBAL_STATS_LOCK:
|
||||
GLOBAL_STATS.last_query_timestamp = time.time()
|
||||
|
||||
def _single_query(
|
||||
self, messages: list[dict[str, str]], n: int | None = None, temperature: float | None = None
|
||||
) -> list[dict]:
|
||||
self._sleep()
|
||||
# Workaround for litellm bug https://github.com/SWE-agent/SWE-agent/issues/1109
|
||||
messages_no_cache_control = copy.deepcopy(messages)
|
||||
for message in messages_no_cache_control:
|
||||
if "cache_control" in message:
|
||||
del message["cache_control"]
|
||||
if "thinking_blocks" in message:
|
||||
del message["thinking_blocks"]
|
||||
input_tokens: int = litellm.utils.token_counter(
|
||||
messages=messages_no_cache_control,
|
||||
model=self.custom_tokenizer["identifier"] if self.custom_tokenizer is not None else self.config.name,
|
||||
custom_tokenizer=self.custom_tokenizer,
|
||||
)
|
||||
if self.model_max_input_tokens is None:
|
||||
msg = (
|
||||
f"No max input tokens found for model {self.config.name!r}. "
|
||||
"If you are using a local model, you can set `max_input_token` in the model config to override this."
|
||||
)
|
||||
self.logger.warning(msg)
|
||||
elif input_tokens > self.model_max_input_tokens > 0:
|
||||
msg = f"Input tokens {input_tokens} exceed max tokens {self.model_max_input_tokens}"
|
||||
raise ContextWindowExceededError(msg)
|
||||
extra_args = {}
|
||||
if self.config.api_base:
|
||||
# Not assigned a default value in litellm, so only pass this if it's set
|
||||
extra_args["api_base"] = self.config.api_base
|
||||
if self.tools.use_function_calling:
|
||||
extra_args["tools"] = self.tools.tools
|
||||
# We need to always set max_tokens for anthropic models
|
||||
completion_kwargs = copy.deepcopy(self.config.completion_kwargs)
|
||||
if self.lm_provider == "anthropic":
|
||||
completion_kwargs["max_tokens"] = self.model_max_output_tokens
|
||||
|
||||
# Add User-Agent header (don't override user-provided headers)
|
||||
if "extra_headers" not in completion_kwargs:
|
||||
completion_kwargs["extra_headers"] = {}
|
||||
if "User-Agent" not in completion_kwargs["extra_headers"]:
|
||||
completion_kwargs["extra_headers"]["User-Agent"] = f"swe-agent/{__version__}"
|
||||
|
||||
try:
|
||||
response: litellm.types.utils.ModelResponse = litellm.completion( # type: ignore
|
||||
model=self.config.name,
|
||||
messages=messages,
|
||||
temperature=self.config.temperature if temperature is None else temperature,
|
||||
top_p=self.config.top_p,
|
||||
api_version=self.config.api_version,
|
||||
api_key=self.config.choose_api_key(),
|
||||
fallbacks=self.config.fallbacks,
|
||||
**completion_kwargs,
|
||||
**extra_args,
|
||||
n=n,
|
||||
)
|
||||
except litellm.exceptions.ContextWindowExceededError as e:
|
||||
raise ContextWindowExceededError from e
|
||||
except litellm.exceptions.ContentPolicyViolationError as e:
|
||||
raise ContentPolicyViolationError from e
|
||||
except litellm.exceptions.BadRequestError as e:
|
||||
if "is longer than the model's context length" in str(e):
|
||||
raise ContextWindowExceededError from e
|
||||
raise
|
||||
self.logger.debug(f"Response: {response}")
|
||||
try:
|
||||
cost = litellm.cost_calculator.completion_cost(response, model=self.config.name)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Error calculating cost: {e}, setting cost to 0.")
|
||||
if self.config.per_instance_cost_limit > 0 or self.config.total_cost_limit > 0:
|
||||
msg = (
|
||||
f"Error calculating cost: {e} for your model {self.config.name}. If this is ok "
|
||||
"(local models, etc.), please make sure you set `per_instance_cost_limit` and "
|
||||
"`total_cost_limit` to 0 to disable this safety check."
|
||||
)
|
||||
self.logger.error(msg)
|
||||
raise ModelConfigurationError(msg)
|
||||
cost = 0
|
||||
choices: litellm.types.utils.Choices = response.choices # type: ignore
|
||||
n_choices = n if n is not None else 1
|
||||
outputs = []
|
||||
output_tokens = 0
|
||||
for i in range(n_choices):
|
||||
output = choices[i].message.content or ""
|
||||
output_tokens += litellm.utils.token_counter(
|
||||
text=output,
|
||||
model=self.custom_tokenizer["identifier"] if self.custom_tokenizer is not None else self.config.name,
|
||||
custom_tokenizer=self.custom_tokenizer,
|
||||
)
|
||||
output_dict = {"message": output}
|
||||
if self.tools.use_function_calling:
|
||||
if response.choices[i].message.tool_calls: # type: ignore
|
||||
tool_calls = [call.to_dict() for call in response.choices[i].message.tool_calls] # type: ignore
|
||||
else:
|
||||
tool_calls = []
|
||||
output_dict["tool_calls"] = tool_calls
|
||||
if (
|
||||
hasattr(response.choices[i].message, "thinking_blocks") # type: ignore
|
||||
and response.choices[i].message.thinking_blocks # type: ignore
|
||||
):
|
||||
output_dict["thinking_blocks"] = response.choices[i].message.thinking_blocks # type: ignore
|
||||
outputs.append(output_dict)
|
||||
self._update_stats(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost)
|
||||
return outputs
|
||||
|
||||
def _query(
|
||||
self, messages: list[dict[str, str]], n: int | None = None, temperature: float | None = None
|
||||
) -> list[dict]:
|
||||
if n is None:
|
||||
return self._single_query(messages, temperature=temperature)
|
||||
outputs = []
|
||||
# not needed for openai, but oh well.
|
||||
for _ in range(n):
|
||||
outputs.extend(self._single_query(messages))
|
||||
return outputs
|
||||
|
||||
def query(self, history: History, n: int = 1, temperature: float | None = None) -> list[dict] | dict:
|
||||
messages = self._history_to_messages(history)
|
||||
|
||||
def retry_warning(retry_state: RetryCallState):
|
||||
exception_info = ""
|
||||
if attempt.retry_state.outcome is not None and attempt.retry_state.outcome.exception() is not None:
|
||||
exception = attempt.retry_state.outcome.exception()
|
||||
exception_info = f" due to {exception.__class__.__name__}: {str(exception)}"
|
||||
|
||||
self.logger.warning(
|
||||
f"Retrying LM query: attempt {attempt.retry_state.attempt_number} "
|
||||
f"(slept for {attempt.retry_state.idle_for:.2f}s)"
|
||||
f"{exception_info}"
|
||||
)
|
||||
|
||||
for attempt in Retrying(
|
||||
stop=stop_after_attempt(self.config.retry.retries),
|
||||
wait=wait_random_exponential(min=self.config.retry.min_wait, max=self.config.retry.max_wait),
|
||||
reraise=True,
|
||||
retry=retry_if_not_exception_type(
|
||||
(
|
||||
ContextWindowExceededError,
|
||||
CostLimitExceededError,
|
||||
RuntimeError,
|
||||
litellm.exceptions.UnsupportedParamsError,
|
||||
litellm.exceptions.NotFoundError,
|
||||
litellm.exceptions.PermissionDeniedError,
|
||||
litellm.exceptions.ContextWindowExceededError,
|
||||
litellm.exceptions.APIError,
|
||||
litellm.exceptions.ContentPolicyViolationError,
|
||||
TypeError,
|
||||
litellm.exceptions.AuthenticationError,
|
||||
ContentPolicyViolationError,
|
||||
ModelConfigurationError,
|
||||
KeyboardInterrupt,
|
||||
IndexError,
|
||||
)
|
||||
),
|
||||
before_sleep=retry_warning,
|
||||
):
|
||||
with attempt:
|
||||
result = self._query(messages, n=n, temperature=temperature)
|
||||
if n is None or n == 1:
|
||||
return result[0]
|
||||
return result
|
||||
|
||||
def _history_to_messages(
|
||||
self,
|
||||
history: History,
|
||||
) -> list[dict[str, str]]:
|
||||
history = copy.deepcopy(history)
|
||||
|
||||
def get_role(history_item: HistoryItem) -> str:
|
||||
if history_item["role"] == "system":
|
||||
return "user" if self.config.convert_system_to_user else "system"
|
||||
return history_item["role"]
|
||||
|
||||
messages = []
|
||||
for history_item in history:
|
||||
role = get_role(history_item)
|
||||
if role == "tool":
|
||||
message = {
|
||||
"role": role,
|
||||
"content": history_item["content"],
|
||||
# Only one tool call per observations
|
||||
"tool_call_id": history_item["tool_call_ids"][0], # type: ignore
|
||||
}
|
||||
elif (tool_calls := history_item.get("tool_calls")) is not None:
|
||||
message = {"role": role, "content": history_item["content"], "tool_calls": tool_calls}
|
||||
if thinking_blocks := history_item.get("thinking_blocks"):
|
||||
message["thinking_blocks"] = thinking_blocks
|
||||
else:
|
||||
message = {"role": role, "content": history_item["content"]}
|
||||
if "cache_control" in history_item:
|
||||
message["cache_control"] = history_item["cache_control"]
|
||||
messages.append(message)
|
||||
n_cache_control = str(messages).count("cache_control")
|
||||
self.logger.debug(f"n_cache_control: {n_cache_control}")
|
||||
return messages
|
||||
|
||||
|
||||
def get_model(args: ModelConfig, tools: ToolConfig) -> AbstractModel:
|
||||
"""Returns correct model object given arguments and commands"""
|
||||
# Convert GenericAPIModelConfig to specific model config if needed
|
||||
if isinstance(args, GenericAPIModelConfig) and not isinstance(
|
||||
args, HumanModelConfig | HumanThoughtModelConfig | ReplayModelConfig | InstantEmptySubmitModelConfig
|
||||
):
|
||||
if args.name == "human":
|
||||
args = HumanModelConfig(**args.model_dump())
|
||||
elif args.name == "human_thought":
|
||||
args = HumanThoughtModelConfig(**args.model_dump())
|
||||
elif args.name == "replay":
|
||||
args = ReplayModelConfig(**args.model_dump())
|
||||
elif args.name == "instant_empty_submit":
|
||||
args = InstantEmptySubmitModelConfig(**args.model_dump())
|
||||
|
||||
if args.name == "human":
|
||||
assert isinstance(args, HumanModelConfig), f"Expected {HumanModelConfig}, got {args}"
|
||||
return HumanModel(args, tools)
|
||||
if args.name == "human_thought":
|
||||
assert isinstance(args, HumanThoughtModelConfig), f"Expected {HumanThoughtModelConfig}, got {args}"
|
||||
return HumanThoughtModel(args, tools)
|
||||
if args.name == "replay":
|
||||
assert isinstance(args, ReplayModelConfig), f"Expected {ReplayModelConfig}, got {args}"
|
||||
return ReplayModel(args, tools)
|
||||
elif args.name == "instant_empty_submit":
|
||||
assert isinstance(args, InstantEmptySubmitModelConfig), f"Expected {InstantEmptySubmitModelConfig}, got {args}"
|
||||
return InstantEmptySubmitTestModel(args, tools)
|
||||
assert isinstance(args, GenericAPIModelConfig), f"Expected {GenericAPIModelConfig}, got {args}"
|
||||
return LiteLLMModel(args, tools)
|
||||
312
.agent/vendor/mini-swe/sweagent/agent/problem_statement.py
vendored
Normal file
@@ -0,0 +1,312 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
|
||||
from sweagent.utils.github import _get_problem_statement_from_github_issue, _parse_gh_issue_url
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
logger = get_logger("swea-config", emoji="🔧")
|
||||
|
||||
# Constants for image processing
|
||||
VALID_IMAGE_MIME_TYPES = {
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg", # Some servers return jpg instead of jpeg
|
||||
"image/webp",
|
||||
}
|
||||
|
||||
|
||||
class ProblemStatement(Protocol):
|
||||
"""A problem statement for a task. Any class that implements this protocol
|
||||
can be used as a problem statement.
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
||||
def get_problem_statement(self) -> str: ...
|
||||
|
||||
def get_problem_statement_for_env(self) -> str:
|
||||
"""Used for setting environment variables in the container.
|
||||
|
||||
By default, this is the same as get_problem_statement().
|
||||
"""
|
||||
return self.get_problem_statement()
|
||||
|
||||
def get_extra_fields(self) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
class _BuiltinProblemStatementBase(BaseModel):
|
||||
"""A base class for the builtin problem statements to avoid typing much"""
|
||||
|
||||
def get_problem_statement(self) -> str: ...
|
||||
|
||||
def get_problem_statement_for_env(self) -> str:
|
||||
return self.get_problem_statement()
|
||||
|
||||
def get_extra_fields(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
class EmptyProblemStatement(_BuiltinProblemStatementBase):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
type: Literal["empty"] = "empty"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def get_problem_statement(self) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
class TextProblemStatement(_BuiltinProblemStatementBase):
|
||||
text: str
|
||||
|
||||
extra_fields: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Any additional data to be added to the instance.
|
||||
This data will be available when formatting prompt templates.
|
||||
"""
|
||||
|
||||
type: Literal["text"] = "text"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
id: str = None # type: ignore
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.id is None:
|
||||
logger.info("Setting problem statement id to hash of text")
|
||||
self.id = hashlib.sha256(self.text.encode()).hexdigest()[:6]
|
||||
|
||||
def get_problem_statement(self) -> str:
|
||||
return self.text
|
||||
|
||||
def get_extra_fields(self) -> dict[str, Any]:
|
||||
return self.extra_fields
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"TextProblemStatement(id={self.id}, text={self.text[:30]}...)"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"id={self.id}, text={self.text[:30]}..."
|
||||
|
||||
|
||||
class FileProblemStatement(_BuiltinProblemStatementBase):
|
||||
path: Path
|
||||
|
||||
extra_fields: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Any additional data to be added to the instance.
|
||||
This data will be available when formatting prompt templates.
|
||||
"""
|
||||
|
||||
type: Literal["text_file"] = "text_file"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
id: str = None # type: ignore
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.id is None:
|
||||
logger.info("Setting problem statement id to hash of file contents (path: %s)", self.path)
|
||||
self.id = hashlib.sha256(self.get_problem_statement().encode()).hexdigest()[:6]
|
||||
|
||||
def get_problem_statement(self) -> str:
|
||||
return self.path.read_text()
|
||||
|
||||
def get_extra_fields(self) -> dict[str, Any]:
|
||||
return self.extra_fields
|
||||
|
||||
|
||||
class GithubIssue(_BuiltinProblemStatementBase):
|
||||
github_url: str
|
||||
|
||||
extra_fields: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Any additional data to be added to the instance.
|
||||
This data will be available when formatting prompt templates.
|
||||
"""
|
||||
|
||||
type: Literal["github"] = "github"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
id: str = None # type: ignore
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.id is None:
|
||||
logger.info("Setting problem statement based on github issue url")
|
||||
owner, repo, issue_number = _parse_gh_issue_url(self.github_url)
|
||||
self.id = f"{owner}__{repo}-i{issue_number}"
|
||||
|
||||
def get_problem_statement(self) -> str:
|
||||
owner, repo, issue_number = _parse_gh_issue_url(self.github_url)
|
||||
return _get_problem_statement_from_github_issue(owner, repo, issue_number, token=os.getenv("GITHUB_TOKEN"))
|
||||
|
||||
def get_extra_fields(self) -> dict[str, Any]:
|
||||
return self.extra_fields
|
||||
|
||||
|
||||
class SWEBenchMultimodalProblemStatement(_BuiltinProblemStatementBase):
|
||||
text: str
|
||||
|
||||
issue_images: list[str] = Field(default_factory=list)
|
||||
"""List of image asset URLs.
|
||||
"""
|
||||
|
||||
disable_image_processing: bool = False
|
||||
"""If True, skip image downloading and processing, treating this as a text-only problem statement.
|
||||
"""
|
||||
|
||||
extra_fields: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Any additional data to be added to the instance.
|
||||
This data will be available when formatting prompt templates.
|
||||
"""
|
||||
|
||||
type: Literal["swe_bench_multimodal"] = "swe_bench_multimodal"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
id: str = None # type: ignore
|
||||
|
||||
_cached_problem_statement: str | None = PrivateAttr(default=None)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.id is None:
|
||||
logger.info("Setting problem statement id to hash of text")
|
||||
self.id = hashlib.sha256(self.text.encode()).hexdigest()[:6]
|
||||
|
||||
def get_problem_statement_for_env(self) -> str:
|
||||
"""Return the problem statement without images.
|
||||
|
||||
Images are not supported in the environment.
|
||||
"""
|
||||
return self.text
|
||||
|
||||
def get_problem_statement(self) -> str:
|
||||
if self.disable_image_processing:
|
||||
logger.info("Image processing disabled, returning text-only problem statement")
|
||||
return self.text
|
||||
|
||||
if self._cached_problem_statement is not None:
|
||||
return self._cached_problem_statement
|
||||
|
||||
processed_text = self.text
|
||||
for link in self.issue_images:
|
||||
try:
|
||||
image_markdown = self._download_and_convert_image(link)
|
||||
if image_markdown:
|
||||
processed_text += f"\n\n{image_markdown}"
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process image from {link}: {e}")
|
||||
|
||||
# cache to avoid re-processing images
|
||||
self._cached_problem_statement = processed_text
|
||||
return processed_text
|
||||
|
||||
def get_extra_fields(self) -> dict[str, Any]:
|
||||
return self.extra_fields
|
||||
|
||||
def _download_and_convert_image(self, url: str) -> str | None:
|
||||
"""Download an image from URL and convert it to base64 markdown format.
|
||||
|
||||
Args:
|
||||
url: The URL of the image to download
|
||||
|
||||
Returns:
|
||||
Base64 markdown string if successful, None if failed
|
||||
|
||||
Raises:
|
||||
Various exceptions for network/processing errors
|
||||
"""
|
||||
try:
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme or not parsed_url.netloc:
|
||||
logger.warning(f"Invalid URL format: {url}")
|
||||
return None
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.5735.133 Safari/537.36"
|
||||
}
|
||||
response = requests.get(url, headers=headers, timeout=30, stream=True)
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
if content_type == "image/jpg":
|
||||
content_type = "image/jpeg"
|
||||
if content_type not in VALID_IMAGE_MIME_TYPES:
|
||||
logger.warning(f"Unsupported image MIME type '{content_type}' for URL: {url}. Not encoding image.")
|
||||
return None
|
||||
max_size = 10 * 1024 * 1024 # 10MB
|
||||
content_length = response.headers.get("content-length")
|
||||
if content_length and int(content_length) > max_size:
|
||||
logger.warning(f"Image too large ({content_length} bytes) for URL: {url}")
|
||||
return None
|
||||
image_data = b""
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
image_data += chunk
|
||||
if len(image_data) > max_size:
|
||||
logger.warning(f"Image too large (>{max_size} bytes) for URL: {url}")
|
||||
return None
|
||||
if not image_data:
|
||||
logger.warning(f"Empty image data for URL: {url}")
|
||||
return None
|
||||
b64_data = base64.b64encode(image_data).decode("ascii")
|
||||
markdown = f""
|
||||
logger.info(f"Successfully processed image from {url} ({len(image_data)} bytes, {content_type})")
|
||||
return markdown
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
logger.warning(f"Timeout downloading image from {url}")
|
||||
return None
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Network error downloading image from {url}: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected error processing image from {url}: {e}")
|
||||
return None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
n_images = len(self.issue_images)
|
||||
return f"SWEBenchMultimodalProblemStatement(id={self.id}, text={self.text[:30]}..., images={n_images})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
n_images = len(self.issue_images)
|
||||
return f"id={self.id}, text={self.text[:30]}..., images={n_images}"
|
||||
|
||||
|
||||
ProblemStatementConfig = (
|
||||
TextProblemStatement
|
||||
| SWEBenchMultimodalProblemStatement
|
||||
| GithubIssue
|
||||
| EmptyProblemStatement
|
||||
| FileProblemStatement
|
||||
)
|
||||
|
||||
|
||||
def problem_statement_from_simplified_input(
|
||||
*, input: str, type: Literal["text", "text_file", "github_issue", "swe_bench_multimodal"]
|
||||
) -> ProblemStatementConfig:
|
||||
"""Get a problem statement from an `input` string and a `type`.
|
||||
|
||||
Args:
|
||||
input: Url/path/text
|
||||
type: The type of problem statement
|
||||
"""
|
||||
if type == "text":
|
||||
return TextProblemStatement(text=input)
|
||||
elif type == "text_file":
|
||||
return FileProblemStatement(path=Path(input))
|
||||
elif type == "github_issue":
|
||||
return GithubIssue(github_url=input)
|
||||
elif type == "swe_bench_multimodal":
|
||||
return SWEBenchMultimodalProblemStatement(text=input)
|
||||
else:
|
||||
msg = f"Unknown problem statement type: {type}"
|
||||
raise ValueError(msg)
|
||||
664
.agent/vendor/mini-swe/sweagent/agent/reviewer.py
vendored
Normal file
@@ -0,0 +1,664 @@
|
||||
"""The reviewer implements a retry loop for the agent to retry
|
||||
solving the issue and to select the best solution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Literal
|
||||
|
||||
import numpy as np
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from sweagent.agent.history_processors import _set_cache_control
|
||||
from sweagent.agent.models import (
|
||||
AbstractModel,
|
||||
InstanceStats,
|
||||
ModelConfig,
|
||||
get_model,
|
||||
)
|
||||
from sweagent.agent.problem_statement import ProblemStatement
|
||||
from sweagent.tools.parsing import ActionParser
|
||||
from sweagent.tools.tools import ToolConfig
|
||||
from sweagent.types import AgentInfo, Trajectory, TrajectoryStep
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
|
||||
class ReviewSubmission(BaseModel):
|
||||
"""Information that's passed to the reviewer"""
|
||||
|
||||
#: Total trajectory (including several retries)
|
||||
trajectory: Trajectory
|
||||
#: Aggregate info dict (including several retries)
|
||||
info: AgentInfo
|
||||
#: Model stats for this attempt
|
||||
model_stats: InstanceStats
|
||||
|
||||
def to_format_dict(self, *, suffix="") -> dict[str, Any]:
|
||||
"""Return all the data that is used to format the
|
||||
messages. Trajectory is excluded because it needs special treatment.
|
||||
"""
|
||||
out = {}
|
||||
info = copy.deepcopy(self.info)
|
||||
if not info.get("submission"):
|
||||
# Observed that not all exit_cost lead to autosubmission
|
||||
# so sometimes this might be missing.
|
||||
info["submission"] = ""
|
||||
for k, v in info.items():
|
||||
if isinstance(v, str):
|
||||
out[f"{k}{suffix}"] = v
|
||||
elif isinstance(v, dict):
|
||||
for k2, v2 in v.items():
|
||||
out[f"{k}_{k2}{suffix}"] = v2
|
||||
return out
|
||||
|
||||
|
||||
class ReviewerResult(BaseModel):
|
||||
accept: bool | float
|
||||
outputs: list[str]
|
||||
messages: list[dict[str, Any]]
|
||||
|
||||
|
||||
class PreselectorOutput(BaseModel):
|
||||
chosen_idx: list[int]
|
||||
response: str
|
||||
messages: list[dict[str, Any]]
|
||||
|
||||
|
||||
class ChooserOutput(BaseModel):
|
||||
chosen_idx: int
|
||||
response: str
|
||||
preselector_output: PreselectorOutput | None = None
|
||||
messages: list[dict[str, Any]]
|
||||
|
||||
|
||||
# --- INTERFACES ---
|
||||
|
||||
|
||||
class AbstractReviewer(ABC):
|
||||
"""The reviewer checks a single solution and tries to predict
|
||||
if it successfully solves the issue.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def review(self, instance: ProblemStatement, submission: ReviewSubmission) -> ReviewerResult:
|
||||
"""Returns True if the submission is believed to be correct"""
|
||||
|
||||
|
||||
class AbstractRetryLoop(ABC):
|
||||
"""The review loop controls how often the agent tries to solve
|
||||
the issue and how it selects the best solution.
|
||||
"""
|
||||
|
||||
def retry(self) -> bool:
|
||||
"""Returns True if the agent should retry solving the issue"""
|
||||
return False
|
||||
|
||||
def on_submit(self, submission: ReviewSubmission) -> None:
|
||||
"""Called when the agent submits a solution"""
|
||||
|
||||
def on_model_query(self, attempt_stats: InstanceStats):
|
||||
"""Called before the model is queried. Can be used to implement
|
||||
stop conditions based on attempt cost etc.
|
||||
"""
|
||||
|
||||
def on_attempt_started(self, i_attempt: int, agent):
|
||||
"""Called when a new attempt is started"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_best(self) -> int:
|
||||
"""Returns the best solution"""
|
||||
|
||||
def get_forwarded_vars(self) -> dict[str, Any]:
|
||||
"""Get the variables that should be forwarded to the next iteration.
|
||||
|
||||
Returns:
|
||||
A dictionary of variables that should be forwarded to the next iteration.
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
# --- CONFIGS ---
|
||||
|
||||
|
||||
class PreselectorConfig(BaseModel):
|
||||
model: ModelConfig
|
||||
system_template: str
|
||||
instance_template: str
|
||||
submission_template: str
|
||||
max_len_submission: int = 5000
|
||||
|
||||
|
||||
class ChooserConfig(BaseModel):
|
||||
model: ModelConfig
|
||||
system_template: str
|
||||
instance_template: str
|
||||
submission_template: str
|
||||
max_len_submission: int = 5000
|
||||
preselector: PreselectorConfig | None = None
|
||||
|
||||
|
||||
class TrajFormatterConfig(BaseModel):
|
||||
#: Filter the following actions from the trajectory
|
||||
filter: list[str] = []
|
||||
#: Filter outputs from the following actions from the trajectory
|
||||
output_filter: list[str] = []
|
||||
#: Format of the trajectory item
|
||||
item_template: str = "Model: {{response}}\n\nObservation: {{observation}}"
|
||||
only_show_last_n_output: int = 0
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ReviewerConfig(BaseModel):
|
||||
"""The configuration for the reviewer"""
|
||||
|
||||
system_template: str
|
||||
instance_template: str
|
||||
#: If a submission autosubmits because of total cost or a similar exit status,
|
||||
#: it will get this malus to its score
|
||||
failure_score_penalty: float = 0.0
|
||||
traj_formatter: TrajFormatterConfig
|
||||
n_sample: int = 5
|
||||
reduce_by_std: float = 0.0
|
||||
score_range: tuple[float | None, float | None] = (None, None)
|
||||
#: If set, we assume that the score is in the range [score_range[0], score_range[1]]
|
||||
#: Reviews that are outside this range will be ignored
|
||||
|
||||
type: Literal["reviewer"] = "reviewer"
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def get_reviewer(self, model: AbstractModel) -> AbstractReviewer:
|
||||
return Reviewer(self, model)
|
||||
|
||||
|
||||
class ChooserRetryLoopConfig(BaseModel):
|
||||
type: Literal["chooser"] = "chooser"
|
||||
chooser: ChooserConfig
|
||||
|
||||
max_attempts: int
|
||||
min_budget_for_new_attempt: float = 0.0
|
||||
"""Minimal $ that need to be left in order for us to start a new attempt.
|
||||
If set to 0: Always.
|
||||
"""
|
||||
|
||||
cost_limit: float
|
||||
"""The maximum cost to spend on all attempts. Does not include cost of choosing.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def get_retry_loop(self, problem_statement: ProblemStatement) -> ChooserRetryLoop:
|
||||
return ChooserRetryLoop(self, problem_statement)
|
||||
|
||||
|
||||
class ScoreRetryLoopConfig(BaseModel):
|
||||
"""The configuration for the review loop"""
|
||||
|
||||
type: Literal["score"] = "score"
|
||||
|
||||
reviewer_config: ReviewerConfig
|
||||
|
||||
accept_score: float
|
||||
max_accepts: int = 1
|
||||
max_attempts: int
|
||||
|
||||
min_budget_for_new_attempt: float = 0.0
|
||||
"""Minimal $ that need to be left in order for us to start a new attempt.
|
||||
If set to 0: Always.
|
||||
"""
|
||||
|
||||
cost_limit: float
|
||||
"""The maximum cost to spend on all attempts and reviews except the last review.
|
||||
The last review is not included in the cost limit, because we would waste the last
|
||||
attempt if we couldn't score it.
|
||||
"""
|
||||
|
||||
model: ModelConfig
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def validate(self):
|
||||
"""Checks config. Raises `ValueError` in case of misconfiguration"""
|
||||
...
|
||||
|
||||
def __post_init__(self):
|
||||
self.validate()
|
||||
|
||||
def get_retry_loop(self, problem_statement: ProblemStatement) -> ScoreRetryLoop:
|
||||
return ScoreRetryLoop(self, problem_statement)
|
||||
|
||||
|
||||
RetryLoopConfig = ScoreRetryLoopConfig | ChooserRetryLoopConfig
|
||||
|
||||
# --- IMPLEMENTATIONS ---
|
||||
|
||||
|
||||
class Preselector:
|
||||
def __init__(self, config: PreselectorConfig):
|
||||
self.config = config
|
||||
self.model = get_model(config.model, ToolConfig(parse_function=ActionParser()))
|
||||
self.logger = get_logger("chooser", emoji="🧠")
|
||||
|
||||
def interpret(self, response: str) -> list[int]:
|
||||
if not response:
|
||||
self.logger.warning("No response from preselector")
|
||||
return []
|
||||
# Use regex to extract the last number of the response
|
||||
last_line = response.splitlines()[-1]
|
||||
try:
|
||||
return [int(i) for i in re.findall(r"\d+", last_line)]
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error interpreting response: {e}")
|
||||
return []
|
||||
|
||||
def format_submission(self, problem_statement: str, submission: ReviewSubmission) -> str:
|
||||
if (
|
||||
submission.info.get("submission") is None
|
||||
or len(submission.info.get("submission", "")) > self.config.max_len_submission > 0 # type: ignore
|
||||
):
|
||||
return "Solution invalid."
|
||||
return Template(self.config.submission_template).render(
|
||||
**submission.to_format_dict(),
|
||||
# summary=self.summarizer.summarize(problem_statement, submission.trajectory) if self.summarizer else "",
|
||||
)
|
||||
|
||||
def build_messages(self, problem_statement: str, input: list[ReviewSubmission]) -> list[dict[str, Any]]:
|
||||
instance_message = Template(self.config.instance_template).render(
|
||||
problem_statement=problem_statement,
|
||||
submissions=[self.format_submission(problem_statement, s) for s in input],
|
||||
)
|
||||
self.logger.debug(f"MODEL INPUT (user)\n{instance_message}")
|
||||
return [
|
||||
{"role": "system", "content": self.config.system_template},
|
||||
{"role": "user", "content": instance_message},
|
||||
]
|
||||
|
||||
def choose(self, problem_statement: str, input: list[ReviewSubmission]) -> PreselectorOutput:
|
||||
messages = self.build_messages(problem_statement, input)
|
||||
response = self.model.query(messages)["message"] # type: ignore
|
||||
indices = self.interpret(response)
|
||||
if not indices:
|
||||
self.logger.warning("No indices found in response, using all indices")
|
||||
indices = list(range(len(input)))
|
||||
return PreselectorOutput(chosen_idx=indices, response=response, messages=messages)
|
||||
|
||||
|
||||
class Chooser:
|
||||
def __init__(self, config: ChooserConfig):
|
||||
self.config = config
|
||||
self.model = get_model(config.model, ToolConfig(parse_function=ActionParser()))
|
||||
self.logger = get_logger("chooser", emoji="🧠")
|
||||
# self.summarizer = Summarizer(config.summarizer, self.model) if config.summarizer else None
|
||||
|
||||
def interpret(self, response: str) -> int:
|
||||
# Use regex to extract the last number of the response
|
||||
try:
|
||||
return int(re.findall(r"\d+", response)[-1])
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error interpreting response: {e}")
|
||||
return 0
|
||||
|
||||
def format_submission(self, problem_statement: str, submission: ReviewSubmission) -> str:
|
||||
if (
|
||||
submission.info.get("submission") is None
|
||||
or len(submission.info.get("submission", "")) > self.config.max_len_submission > 0 # type: ignore
|
||||
):
|
||||
return "Solution invalid."
|
||||
return Template(self.config.submission_template).render(
|
||||
**submission.to_format_dict(),
|
||||
# summary=self.summarizer.summarize(problem_statement, submission.trajectory) if self.summarizer else "",
|
||||
)
|
||||
|
||||
def build_messages(self, problem_statement: str, input: list[ReviewSubmission]) -> list[dict[str, Any]]:
|
||||
instance_message = Template(self.config.instance_template).render(
|
||||
problem_statement=problem_statement,
|
||||
submissions=[self.format_submission(problem_statement, s) for s in input],
|
||||
)
|
||||
self.logger.debug(f"MODEL INPUT (user)\n{instance_message}")
|
||||
return [
|
||||
{"role": "system", "content": self.config.system_template},
|
||||
{"role": "user", "content": instance_message},
|
||||
]
|
||||
|
||||
def choose(self, problem_statement: str, input: list[ReviewSubmission]) -> ChooserOutput:
|
||||
preselector_output = None
|
||||
selected_indices = list(range(len(input)))
|
||||
n_submitted = sum(s.info.get("exit_status", "") == "submitted" for s in input)
|
||||
if n_submitted >= 2:
|
||||
self.logger.debug(f"Got {n_submitted} submitted submissions, only using them")
|
||||
selected_indices = [i for i, s in enumerate(input) if s.info.get("exit_status", "") == "submitted"]
|
||||
else:
|
||||
self.logger.debug(f"Got only {n_submitted} submitted submissions, disabling exit status filtering")
|
||||
if self.config.preselector and len(selected_indices) > 2:
|
||||
preselector = Preselector(self.config.preselector)
|
||||
try:
|
||||
preselector_output = preselector.choose(problem_statement, [input[i] for i in selected_indices])
|
||||
except Exception as e:
|
||||
self.logger.critical(f"Preselector failed: {e}", exc_info=True)
|
||||
preselector_output = None
|
||||
if preselector_output and preselector_output.chosen_idx:
|
||||
try:
|
||||
_preselected_indices = [selected_indices[i] for i in preselector_output.chosen_idx]
|
||||
except IndexError:
|
||||
_preselected_indices = []
|
||||
self.logger.error("Preselector gave invalid indices, ignoring it.")
|
||||
if not _preselected_indices:
|
||||
self.logger.error("Preselector gave no valid indices, ignoring it.")
|
||||
else:
|
||||
selected_indices = _preselected_indices
|
||||
else:
|
||||
self.logger.error("Preselector must have failed, ignoring it.")
|
||||
messages = self.build_messages(problem_statement, [input[i] for i in selected_indices])
|
||||
chosen_idx = None
|
||||
try:
|
||||
response = self.model.query(messages)["message"] # type: ignore
|
||||
chosen_idx = self.interpret(response)
|
||||
except Exception as e:
|
||||
self.logger.critical(f"Chooser failed: {e}", exc_info=True)
|
||||
chosen_idx = None
|
||||
if chosen_idx is None or not (0 <= chosen_idx < len(selected_indices)):
|
||||
self.logger.error(f"Invalid chosen index: {chosen_idx}, using first index")
|
||||
chosen_idx = selected_indices[0]
|
||||
else:
|
||||
chosen_idx = selected_indices[chosen_idx]
|
||||
return ChooserOutput(
|
||||
chosen_idx=chosen_idx, response=response, preselector_output=preselector_output, messages=messages
|
||||
)
|
||||
|
||||
|
||||
class Reviewer(AbstractReviewer):
|
||||
def __init__(self, config: ReviewerConfig, model):
|
||||
self._config = config
|
||||
self._model = model
|
||||
self._traj_formatter = TrajectoryFormatter(config=config.traj_formatter)
|
||||
self.logger = get_logger("reviewer", emoji="🧑⚖️")
|
||||
|
||||
def format_messages(self, instance: ProblemStatement, submission: ReviewSubmission):
|
||||
system_message = self._config.system_template
|
||||
self.logger.debug(f"MODEL INPUT (system)\n{system_message}")
|
||||
ps_format_dict = {
|
||||
"problem_statement": instance.get_problem_statement(),
|
||||
**instance.get_extra_fields(),
|
||||
}
|
||||
user_message = Template(self._config.instance_template).render(
|
||||
**ps_format_dict,
|
||||
**submission.to_format_dict(),
|
||||
traj=self._traj_formatter.format_trajectory(submission.trajectory),
|
||||
)
|
||||
self.logger.debug(f"MODEL INPUT (user)\n{user_message}")
|
||||
return [
|
||||
{"role": "system", "content": system_message},
|
||||
{"role": "user", "content": user_message},
|
||||
]
|
||||
|
||||
def interpret(self, response: str) -> bool | float:
|
||||
last_line = response.strip().split("\n")[-1].strip()
|
||||
# Find all numbers in the last line and take the last one
|
||||
numbers = re.findall(r"-?\d+\.?\d*", last_line)
|
||||
if not numbers:
|
||||
msg = f"Could not interpret response: {last_line!r}"
|
||||
raise ValueError(msg)
|
||||
number = float(numbers[-1])
|
||||
if self._config.score_range[0] is not None and number < self._config.score_range[0]:
|
||||
msg = f"Score {number} is below the minimum score {self._config.score_range[0]}"
|
||||
raise ValueError(msg)
|
||||
if self._config.score_range[1] is not None and number > self._config.score_range[1]:
|
||||
msg = f"Score {number} is above the maximum score {self._config.score_range[1]}"
|
||||
raise ValueError(msg)
|
||||
return number
|
||||
|
||||
def review(self, instance: ProblemStatement, submission: ReviewSubmission) -> ReviewerResult:
|
||||
exit_status = submission.info.get("exit_status")
|
||||
messages = []
|
||||
penalty = 0.0
|
||||
if not exit_status or exit_status.strip() != "submitted":
|
||||
penalty = self._config.failure_score_penalty
|
||||
messages = self.format_messages(instance, submission)
|
||||
if self._config.n_sample > 1:
|
||||
_set_cache_control(messages[-1]) # type: ignore
|
||||
answers = []
|
||||
accepts = []
|
||||
for _ in range(self._config.n_sample):
|
||||
try:
|
||||
answer = self._model.query(messages)["message"]
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Query failed: {e}", exc_info=True)
|
||||
continue
|
||||
try:
|
||||
score = self.interpret(answer)
|
||||
except ValueError as e:
|
||||
self.logger.warning(f"Could not interpret response: {answer!r}, got {e}")
|
||||
continue
|
||||
answers.append(answer)
|
||||
accepts.append(score)
|
||||
if not accepts:
|
||||
answers = ["No valid scores found, failing submission"]
|
||||
accepts = [-100.0]
|
||||
accept = sum(accepts) / len(accepts) - penalty
|
||||
std = np.std(accepts).item()
|
||||
if self._config.reduce_by_std > 0:
|
||||
accept -= std * self._config.reduce_by_std
|
||||
self.logger.info(f"First answer: {answers[0]}")
|
||||
self.logger.info(f"Final score: {accept} (penalty: {penalty}, std: {std}), individual: {accepts}")
|
||||
return ReviewerResult(accept=accept, outputs=answers, messages=messages)
|
||||
|
||||
|
||||
# todo: Couldn't I just replace the whole thing with Jinja templates?
|
||||
|
||||
|
||||
class TrajectoryFormatter:
|
||||
def __init__(
|
||||
self,
|
||||
config: TrajFormatterConfig,
|
||||
):
|
||||
"""Formats trajectories for the use in prompts"""
|
||||
self._config = config
|
||||
|
||||
def _include_step(self, item: TrajectoryStep) -> bool:
|
||||
action = item["action"].strip()
|
||||
for f in self._config.filter:
|
||||
if action.startswith(f):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _include_step_output(self, item: TrajectoryStep, i_step: int, n_steps: int) -> bool:
|
||||
if self._config.only_show_last_n_output > 0 and i_step < n_steps - self._config.only_show_last_n_output:
|
||||
return False
|
||||
action = item["action"].strip()
|
||||
for f in self._config.output_filter:
|
||||
if action.startswith(f):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _format_trajectory_step(self, step: TrajectoryStep, i_step: int, *, n_steps: int, i_traj: int = 1) -> str:
|
||||
step = copy.deepcopy(step)
|
||||
if not self._include_step_output(step, i_step, n_steps=n_steps):
|
||||
step["observation"] = "[Output omitted]"
|
||||
return Template(self._config.item_template).render(
|
||||
**step,
|
||||
i_step=i_step,
|
||||
i_traj=i_traj,
|
||||
)
|
||||
|
||||
def format_trajectory(self, trajectory: Trajectory, i_traj: int = 1) -> str:
|
||||
traj_messages = [step for step in trajectory if self._include_step(step)]
|
||||
return "\n\n".join(
|
||||
[
|
||||
self._format_trajectory_step(step, i_step, i_traj=i_traj, n_steps=len(traj_messages))
|
||||
for i_step, step in enumerate(traj_messages)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ChooserRetryLoop(AbstractRetryLoop):
|
||||
def __init__(self, config: ChooserRetryLoopConfig, problem_statement: ProblemStatement):
|
||||
self._config = config
|
||||
self._problem_statement = problem_statement
|
||||
self._chooser = Chooser(config.chooser)
|
||||
self._submissions: list[ReviewSubmission] = []
|
||||
self._n_consec_exit_cost: int = 0
|
||||
self.logger = get_logger("chooser_loop", emoji="🔄")
|
||||
self._chooser_output: ChooserOutput | None = None
|
||||
|
||||
@property
|
||||
def _total_stats(self) -> InstanceStats:
|
||||
return sum((s.model_stats for s in self._submissions), start=InstanceStats())
|
||||
|
||||
@property
|
||||
def review_model_stats(self) -> InstanceStats:
|
||||
return InstanceStats()
|
||||
|
||||
@property
|
||||
def _n_attempts(self) -> int:
|
||||
return len(self._submissions)
|
||||
|
||||
def on_submit(self, submission: ReviewSubmission) -> None:
|
||||
self._submissions.append(submission)
|
||||
|
||||
def retry(self) -> bool:
|
||||
stat_str = f"n_samples={self._n_attempts}"
|
||||
if self._total_stats.instance_cost > self._config.cost_limit > 0:
|
||||
self.logger.info(
|
||||
f"Exiting retry loop ({stat_str}): Total attempt cost ({self._total_stats.instance_cost}) "
|
||||
f"exceeds cost limit ({self._config.cost_limit})"
|
||||
)
|
||||
return False
|
||||
|
||||
if self._n_attempts >= self._config.max_attempts > 0:
|
||||
self.logger.info(f"Exiting retry loop ({stat_str}): max_attempts={self._config.max_attempts} reached")
|
||||
return False
|
||||
|
||||
remaining_budget = self._config.cost_limit - self._total_stats.instance_cost
|
||||
if self._config.min_budget_for_new_attempt > 0 and remaining_budget < self._config.min_budget_for_new_attempt:
|
||||
msg = (
|
||||
f"Exiting retry loop ({stat_str}): Not enough budget left for a new attempt "
|
||||
f"({remaining_budget} remaining, {self._config.min_budget_for_new_attempt} required)"
|
||||
)
|
||||
self.logger.info(msg)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_best(self) -> int | None:
|
||||
"""Important note: This is cached. Only call this at the end."""
|
||||
if self._chooser_output is not None:
|
||||
return self._chooser_output.chosen_idx
|
||||
if len(self._submissions) == 0:
|
||||
return None
|
||||
self._chooser_output = self._chooser.choose(self._problem_statement.get_problem_statement(), self._submissions)
|
||||
return self._chooser_output.chosen_idx
|
||||
|
||||
|
||||
# todo: The model shouldn't be defined here, it should be defined as part of the scorer
|
||||
class ScoreRetryLoop(AbstractRetryLoop):
|
||||
def __init__(
|
||||
self,
|
||||
config: ScoreRetryLoopConfig,
|
||||
problem_statement: ProblemStatement,
|
||||
):
|
||||
# This model will not share instance cost with the parent agent
|
||||
self._model = get_model(config.model, tools=ToolConfig())
|
||||
self._problem_statement = problem_statement
|
||||
self._reviewer: AbstractReviewer = config.reviewer_config.get_reviewer(self._model)
|
||||
self._config = config
|
||||
# Note: These are "cumulative" submissions, i.e., they include all retries
|
||||
# up to that point.
|
||||
self._submissions: list[ReviewSubmission] = []
|
||||
self._reviews: list[ReviewerResult] = []
|
||||
#: Number of consecutive exit cost submissions
|
||||
self._n_consec_exit_cost: int = 0
|
||||
self.logger = get_logger("review_loop", emoji="🔄")
|
||||
|
||||
# Properties
|
||||
# ----------
|
||||
|
||||
@property
|
||||
def review_model_stats(self) -> InstanceStats:
|
||||
return self._model.stats
|
||||
|
||||
@property
|
||||
def reviews(self) -> list[ReviewerResult]:
|
||||
return self._reviews
|
||||
|
||||
@property
|
||||
def _n_attempts(self) -> int:
|
||||
return len(self._submissions)
|
||||
|
||||
@property
|
||||
def _n_accepted(self) -> int:
|
||||
return sum(r.accept >= self._config.accept_score for r in self._reviews)
|
||||
|
||||
@property
|
||||
def _total_stats(self) -> InstanceStats:
|
||||
return sum((s.model_stats for s in self._submissions), start=InstanceStats()) + self._model.stats
|
||||
|
||||
# -------
|
||||
|
||||
def on_submit(self, submission: ReviewSubmission) -> None:
|
||||
self._submissions.append(submission)
|
||||
self._review()
|
||||
|
||||
def _review(self) -> float:
|
||||
review = self._reviewer.review(self._problem_statement, self._submissions[-1])
|
||||
self._reviews.append(review)
|
||||
exit_status = self._submissions[-1].info.get("exit_status", "")
|
||||
if exit_status and "exit_cost" in exit_status.lower():
|
||||
self._n_consec_exit_cost += 1
|
||||
else:
|
||||
self._n_consec_exit_cost = 0
|
||||
return review.accept
|
||||
|
||||
def retry(self) -> bool:
|
||||
max_score = max([r.accept for r in self._reviews], default=-100.0)
|
||||
stat_str = f"n_samples={self._n_attempts}, max_score={max_score}, n_accepted={self._n_accepted}"
|
||||
|
||||
if self._total_stats.instance_cost > self._config.cost_limit > 0:
|
||||
self.logger.info(
|
||||
f"Exiting retry loop ({stat_str}): Total attempt cost ({self._total_stats.instance_cost}) "
|
||||
f"exceeds cost limit ({self._config.cost_limit})"
|
||||
)
|
||||
return False
|
||||
|
||||
if self._n_attempts >= self._config.max_attempts > 0:
|
||||
self.logger.info(f"Exiting retry loop ({stat_str}): max_attempts={self._config.max_attempts} reached")
|
||||
return False
|
||||
|
||||
if self._n_accepted >= self._config.max_accepts > 0:
|
||||
self.logger.info(f"Exiting retry loop ({stat_str}): max_accepts={self._config.max_accepts} reached")
|
||||
return False
|
||||
|
||||
remaining_budget = self._config.cost_limit - self._total_stats.instance_cost
|
||||
if self._config.min_budget_for_new_attempt > 0 and remaining_budget < self._config.min_budget_for_new_attempt:
|
||||
msg = (
|
||||
f"Exiting retry loop ({stat_str}): Not enough budget left for a new attempt "
|
||||
f"({remaining_budget} remaining, {self._config.min_budget_for_new_attempt} required)"
|
||||
)
|
||||
self.logger.info(msg)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_best(self) -> int | None:
|
||||
if len(self._reviews) == 0:
|
||||
return None
|
||||
scores = [r.accept for r in self._reviews]
|
||||
self.logger.debug(f"Scores: {scores}")
|
||||
max_score = np.max(scores)
|
||||
max_indices = [i for i, s in enumerate(scores) if np.isclose(s, max_score)]
|
||||
# If there are multiple submissions with the same score, choose the shortest one
|
||||
max_indices = sorted(max_indices, key=lambda i: self._submissions[i].model_stats.api_calls or float("inf"))
|
||||
chosen_idx = max_indices[0]
|
||||
self.logger.info(f"Best submission: {chosen_idx}")
|
||||
return chosen_idx
|
||||
|
||||
|
||||
def get_retry_loop_from_config(
|
||||
config: RetryLoopConfig, problem_statement: ProblemStatement
|
||||
) -> ScoreRetryLoop | ChooserRetryLoop:
|
||||
return config.get_retry_loop(problem_statement=problem_statement)
|
||||
0
.agent/vendor/mini-swe/sweagent/environment/__init__.py
vendored
Normal file
0
.agent/vendor/mini-swe/sweagent/environment/hooks/__init__.py
vendored
Normal file
60
.agent/vendor/mini-swe/sweagent/environment/hooks/abstract.py
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
from sweagent.environment.repo import Repo, RepoConfig
|
||||
|
||||
|
||||
class EnvHook:
|
||||
"""Hook to be used in `SWEEnv`.
|
||||
|
||||
Subclass this class, add functionality and add it with `SWEEEnv.add_hook(hook)`.
|
||||
This allows to inject custom functionality at different stages of the environment
|
||||
lifecycle, in particular to connect SWE-agent to a new interface (like a GUI).
|
||||
"""
|
||||
|
||||
def on_init(self, *, env) -> None:
|
||||
"""Gets called when the hook is added"""
|
||||
|
||||
def on_copy_repo_started(self, repo: RepoConfig | Repo) -> None:
|
||||
"""Gets called when the repository is being cloned to the container"""
|
||||
|
||||
def on_start_deployment(self) -> None:
|
||||
"""Gets called when the deployment is being started"""
|
||||
|
||||
def on_install_env_started(self) -> None:
|
||||
"""Called when we start installing the environment"""
|
||||
|
||||
def on_close(self):
|
||||
"""Called when the environment is closed"""
|
||||
|
||||
def on_environment_startup(self) -> None:
|
||||
"""Called when the environment is started"""
|
||||
|
||||
|
||||
class CombinedEnvHooks(EnvHook):
|
||||
def __init__(self):
|
||||
self._hooks = []
|
||||
|
||||
def add_hook(self, hook: EnvHook) -> None:
|
||||
self._hooks.append(hook)
|
||||
|
||||
def on_init(self, *, env) -> None:
|
||||
for hook in self._hooks:
|
||||
hook.on_init(env=env)
|
||||
|
||||
def on_copy_repo_started(self, repo: RepoConfig | Repo) -> None:
|
||||
for hook in self._hooks:
|
||||
hook.on_copy_repo_started(repo=repo)
|
||||
|
||||
def on_start_deployment(self) -> None:
|
||||
for hook in self._hooks:
|
||||
hook.on_start_deployment()
|
||||
|
||||
def on_install_env_started(self) -> None:
|
||||
for hook in self._hooks:
|
||||
hook.on_install_env_started()
|
||||
|
||||
def on_close(self):
|
||||
for hook in self._hooks:
|
||||
hook.on_close()
|
||||
|
||||
def on_environment_startup(self) -> None:
|
||||
for hook in self._hooks:
|
||||
hook.on_environment_startup()
|
||||
28
.agent/vendor/mini-swe/sweagent/environment/hooks/status.py
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from sweagent.environment.hooks.abstract import EnvHook
|
||||
from sweagent.environment.repo import Repo, RepoConfig
|
||||
|
||||
|
||||
class SetStatusEnvironmentHook(EnvHook):
|
||||
def __init__(self, id: str, callable: Callable[[str, str], None]):
|
||||
self._callable = callable
|
||||
self._id = id
|
||||
|
||||
def _update(self, message: str):
|
||||
self._callable(self._id, message)
|
||||
|
||||
def on_copy_repo_started(self, repo: RepoConfig | Repo):
|
||||
self._update(f"Copying repo {repo.repo_name}")
|
||||
|
||||
def on_start_deployment(self):
|
||||
self._update("Starting deployment")
|
||||
|
||||
def on_install_env_started(self):
|
||||
self._update("Installing environment")
|
||||
|
||||
def on_environment_startup(self):
|
||||
self._update("Starting environment")
|
||||
|
||||
def on_close(self):
|
||||
self._update("Closing environment")
|
||||
258
.agent/vendor/mini-swe/sweagent/environment/repo.py
vendored
Normal file
@@ -0,0 +1,258 @@
|
||||
import asyncio
|
||||
import os
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from git import InvalidGitRepositoryError
|
||||
from git import Repo as GitRepo
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from swerex.deployment.abstract import AbstractDeployment
|
||||
from swerex.runtime.abstract import Command, UploadRequest
|
||||
from typing_extensions import Self
|
||||
|
||||
from sweagent.utils.github import _parse_gh_repo_url
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
logger = get_logger("swea-config", emoji="🔧")
|
||||
|
||||
|
||||
class Repo(Protocol):
|
||||
"""Protocol for repository configurations."""
|
||||
|
||||
base_commit: str
|
||||
repo_name: str
|
||||
|
||||
def copy(self, deployment: AbstractDeployment): ...
|
||||
|
||||
def get_reset_commands(self) -> list[str]: ...
|
||||
|
||||
|
||||
def _get_git_reset_commands(base_commit: str) -> list[str]:
|
||||
return [
|
||||
"git fetch",
|
||||
"git status",
|
||||
"git restore .",
|
||||
"git reset --hard",
|
||||
f"git checkout {shlex.quote(base_commit)}",
|
||||
"git clean -fdq",
|
||||
]
|
||||
|
||||
|
||||
class PreExistingRepoConfig(BaseModel):
|
||||
"""Use this to specify a repository that already exists on the deployment.
|
||||
This is important because we need to cd to the repo before running the agent.
|
||||
|
||||
Note: The repository must be at the root of the deployment.
|
||||
"""
|
||||
|
||||
repo_name: str
|
||||
"""The repo name (the repository must be located at the root of the deployment)."""
|
||||
base_commit: str = Field(default="HEAD")
|
||||
"""The commit to reset the repository to. The default is HEAD,
|
||||
i.e., the latest commit. You can also set this to a branch name (e.g., `dev`),
|
||||
a tag (e.g., `v0.1.0`), or a commit hash (e.g., `a4464baca1f`).
|
||||
SWE-agent will then start from this commit when trying to solve the problem.
|
||||
"""
|
||||
|
||||
type: Literal["preexisting"] = "preexisting"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
reset: bool = True
|
||||
"""If True, reset the repository to the base commit after the copy operation."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def copy(self, deployment: AbstractDeployment):
|
||||
"""Does nothing."""
|
||||
pass
|
||||
|
||||
def get_reset_commands(self) -> list[str]:
|
||||
"""Issued after the copy operation or when the environment is reset."""
|
||||
if self.reset:
|
||||
return _get_git_reset_commands(self.base_commit)
|
||||
return []
|
||||
|
||||
|
||||
class LocalRepoConfig(BaseModel):
|
||||
path: Path
|
||||
base_commit: str = Field(default="HEAD")
|
||||
"""The commit to reset the repository to. The default is HEAD,
|
||||
i.e., the latest commit. You can also set this to a branch name (e.g., `dev`),
|
||||
a tag (e.g., `v0.1.0`), or a commit hash (e.g., `a4464baca1f`).
|
||||
SWE-agent will then start from this commit when trying to solve the problem.
|
||||
"""
|
||||
|
||||
type: Literal["local"] = "local"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@property
|
||||
def repo_name(self) -> str:
|
||||
"""Set automatically based on the repository name. Cannot be set."""
|
||||
return Path(self.path).resolve().name.replace(" ", "-").replace("'", "")
|
||||
|
||||
# Let's not make this a model validator, because it leads to cryptic errors.
|
||||
# Let's just check during copy instead.
|
||||
def check_valid_repo(self) -> Self:
|
||||
try:
|
||||
repo = GitRepo(self.path, search_parent_directories=True)
|
||||
except InvalidGitRepositoryError as e:
|
||||
msg = f"Could not find git repository at {self.path=}."
|
||||
raise ValueError(msg) from e
|
||||
if repo.is_dirty() and "PYTEST_CURRENT_TEST" not in os.environ:
|
||||
msg = f"Local git repository {self.path} is dirty. Please commit or stash changes."
|
||||
raise ValueError(msg)
|
||||
return self
|
||||
|
||||
def copy(self, deployment: AbstractDeployment):
|
||||
self.check_valid_repo()
|
||||
asyncio.run(
|
||||
deployment.runtime.upload(UploadRequest(source_path=str(self.path), target_path=f"/{self.repo_name}"))
|
||||
)
|
||||
r = asyncio.run(
|
||||
deployment.runtime.execute(Command(command=f"chown -R root:root /{self.repo_name}", shell=True))
|
||||
)
|
||||
if r.exit_code != 0:
|
||||
msg = f"Failed to change permissions on copied repository (exit code: {r.exit_code}, stdout: {r.stdout}, stderr: {r.stderr})"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def get_reset_commands(self) -> list[str]:
|
||||
"""Issued after the copy operation or when the environment is reset."""
|
||||
return _get_git_reset_commands(self.base_commit)
|
||||
|
||||
|
||||
class GithubRepoConfig(BaseModel):
|
||||
github_url: str
|
||||
|
||||
base_commit: str = Field(default="HEAD")
|
||||
"""The commit to reset the repository to. The default is HEAD,
|
||||
i.e., the latest commit. You can also set this to a branch name (e.g., `dev`),
|
||||
a tag (e.g., `v0.1.0`), or a commit hash (e.g., `a4464baca1f`).
|
||||
SWE-agent will then start from this commit when trying to solve the problem.
|
||||
"""
|
||||
|
||||
clone_timeout: float = 500
|
||||
"""Timeout for git clone operation."""
|
||||
|
||||
type: Literal["github"] = "github"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.github_url.count("/") == 1:
|
||||
self.github_url = f"https://github.com/{self.github_url}"
|
||||
|
||||
@property
|
||||
def repo_name(self) -> str:
|
||||
org, repo = _parse_gh_repo_url(self.github_url)
|
||||
return f"{org}__{repo}"
|
||||
|
||||
def _get_url_with_token(self, token: str) -> str:
|
||||
"""Prepend github token to URL"""
|
||||
if not token:
|
||||
return self.github_url
|
||||
if "@" in self.github_url:
|
||||
logger.warning("Cannot prepend token to URL. '@' found in URL")
|
||||
return self.github_url
|
||||
_, _, url_no_protocol = self.github_url.partition("://")
|
||||
return f"https://{token}@{url_no_protocol}"
|
||||
|
||||
def copy(self, deployment: AbstractDeployment):
|
||||
"""Clones the repository to the sandbox."""
|
||||
base_commit = self.base_commit
|
||||
github_token = os.getenv("GITHUB_TOKEN", "")
|
||||
url = self._get_url_with_token(github_token)
|
||||
asyncio.run(
|
||||
deployment.runtime.execute(
|
||||
Command(
|
||||
command=" && ".join(
|
||||
(
|
||||
f"mkdir /{self.repo_name}",
|
||||
f"cd /{self.repo_name}",
|
||||
"git init",
|
||||
f"git remote add origin {shlex.quote(url)}",
|
||||
f"git fetch --depth 1 origin {shlex.quote(base_commit)}",
|
||||
"git checkout FETCH_HEAD",
|
||||
"cd ..",
|
||||
)
|
||||
),
|
||||
timeout=self.clone_timeout,
|
||||
shell=True,
|
||||
check=True,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def get_reset_commands(self) -> list[str]:
|
||||
"""Issued after the copy operation or when the environment is reset."""
|
||||
return _get_git_reset_commands(self.base_commit)
|
||||
|
||||
|
||||
class SWESmithRepoConfig(BaseModel):
|
||||
"""Repository config for SWE-Smith instances that handles targeted fetch
|
||||
from a GitHub mirror, authenticating via GITHUB_TOKEN when needed.
|
||||
"""
|
||||
|
||||
repo_name: str
|
||||
base_commit: str = Field(default="HEAD")
|
||||
mirror_url: str = ""
|
||||
"""HTTPS URL of the GitHub mirror to fetch the bug branch from."""
|
||||
|
||||
type: Literal["swesmith_preexisting"] = "swesmith_preexisting"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
def copy(self, deployment: AbstractDeployment):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _get_url_with_token(url: str, token: str) -> str:
|
||||
if not token or not url:
|
||||
return url
|
||||
_, _, url_no_protocol = url.partition("://")
|
||||
return f"https://{token}@{url_no_protocol}"
|
||||
|
||||
def get_reset_commands(self) -> list[str]:
|
||||
if self.mirror_url:
|
||||
github_token = os.getenv("GITHUB_TOKEN", "")
|
||||
url = self._get_url_with_token(self.mirror_url, github_token)
|
||||
return [
|
||||
"git restore .",
|
||||
"git reset --hard",
|
||||
f"git fetch {shlex.quote(url)} {shlex.quote(self.base_commit)}",
|
||||
"git checkout FETCH_HEAD",
|
||||
"git clean -fdq",
|
||||
]
|
||||
return _get_git_reset_commands(self.base_commit)
|
||||
|
||||
|
||||
RepoConfig = LocalRepoConfig | GithubRepoConfig | PreExistingRepoConfig | SWESmithRepoConfig
|
||||
|
||||
|
||||
def repo_from_simplified_input(
|
||||
*, input: str, base_commit: str = "HEAD", type: Literal["local", "github", "preexisting", "auto"] = "auto"
|
||||
) -> RepoConfig:
|
||||
"""Get repo config from a simplified input.
|
||||
|
||||
Args:
|
||||
input: Local path or GitHub URL
|
||||
type: The type of repo. Set to "auto" to automatically detect the type
|
||||
(does not work for preexisting repos).
|
||||
"""
|
||||
if type == "local":
|
||||
return LocalRepoConfig(path=Path(input), base_commit=base_commit)
|
||||
if type == "github":
|
||||
return GithubRepoConfig(github_url=input, base_commit=base_commit)
|
||||
if type == "preexisting":
|
||||
return PreExistingRepoConfig(repo_name=input, base_commit=base_commit)
|
||||
if type == "auto":
|
||||
if input.startswith("https://github.com/"):
|
||||
return GithubRepoConfig(github_url=input, base_commit=base_commit)
|
||||
else:
|
||||
return LocalRepoConfig(path=Path(input), base_commit=base_commit)
|
||||
msg = f"Unknown repo type: {type}"
|
||||
raise ValueError(msg)
|
||||
276
.agent/vendor/mini-swe/sweagent/environment/swe_env.py
vendored
Normal file
@@ -0,0 +1,276 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import shlex
|
||||
from pathlib import PurePath
|
||||
from typing import Literal, Self
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from swerex.deployment.abstract import AbstractDeployment
|
||||
from swerex.deployment.config import DeploymentConfig, DockerDeploymentConfig, get_deployment
|
||||
from swerex.runtime.abstract import (
|
||||
BashAction,
|
||||
BashInterruptAction,
|
||||
CreateBashSessionRequest,
|
||||
ReadFileRequest,
|
||||
WriteFileRequest,
|
||||
)
|
||||
from swerex.runtime.abstract import Command as RexCommand
|
||||
|
||||
from sweagent.environment.hooks.abstract import CombinedEnvHooks, EnvHook
|
||||
from sweagent.environment.repo import Repo, RepoConfig
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
|
||||
class EnvironmentConfig(BaseModel):
|
||||
"""Configure data sources and setup instructions for the environment in which we solve the tasks."""
|
||||
|
||||
deployment: DeploymentConfig = Field(
|
||||
default_factory=lambda: DockerDeploymentConfig(image="python:3.11", python_standalone_dir="/root"),
|
||||
description="Deployment options.",
|
||||
)
|
||||
repo: RepoConfig | None = Field(
|
||||
default=None,
|
||||
description="Repository options.",
|
||||
)
|
||||
post_startup_commands: list[str] = []
|
||||
"""Execute these commands before starting to run the agent but after all other setup steps.
|
||||
They will be executed in the same shell as the agent.
|
||||
Note: Every command is passed as a string, not a list of arguments.
|
||||
"""
|
||||
post_startup_command_timeout: int = 500
|
||||
"""Timeout for the post-startup commands.
|
||||
NOTE: The timeout applies to every command in `post_startup_commands` separately.
|
||||
"""
|
||||
|
||||
# pydantic config
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str = "main"
|
||||
|
||||
|
||||
class SWEEnv:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
deployment: AbstractDeployment,
|
||||
repo: Repo | RepoConfig | None,
|
||||
post_startup_commands: list[str],
|
||||
post_startup_command_timeout: int = 500,
|
||||
hooks: list[EnvHook] | None = None,
|
||||
name: str = "main",
|
||||
):
|
||||
"""This class represents the environment in which we solve the tasks.
|
||||
|
||||
Args:
|
||||
deployment: SWE-ReX deployment instance
|
||||
repo: Repository configuration object, or anything following the `Repo` protocol
|
||||
post_startup_commands: Commands to execute before starting the agent
|
||||
hooks: Environment hooks (used to inject custom functionality)
|
||||
Equivalent to calling `add_hook` for each hook after initialization.
|
||||
name: Name of the environment
|
||||
"""
|
||||
super().__init__()
|
||||
self.deployment = deployment
|
||||
self.repo = repo
|
||||
self._post_startup_commands = post_startup_commands
|
||||
self.post_startup_command_timeout = post_startup_command_timeout
|
||||
self.logger = get_logger("swea-env", emoji="🪴")
|
||||
self.name = name
|
||||
self.clean_multi_line_functions = lambda x: x
|
||||
self._chook = CombinedEnvHooks()
|
||||
for hook in hooks or []:
|
||||
self.add_hook(hook)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: EnvironmentConfig) -> Self:
|
||||
"""Create an environment instance from a configuration object.
|
||||
This is the recommended way to create an environment instance, unless you need
|
||||
more flexibility.
|
||||
"""
|
||||
# Always copy config to avoid shared state between different instances
|
||||
config = config.model_copy(deep=True)
|
||||
return cls(
|
||||
deployment=get_deployment(config.deployment),
|
||||
repo=config.repo,
|
||||
post_startup_commands=config.post_startup_commands,
|
||||
post_startup_command_timeout=config.post_startup_command_timeout,
|
||||
name=config.name,
|
||||
)
|
||||
|
||||
def add_hook(self, hook: EnvHook) -> None:
|
||||
"""Add `EnvHook` to the environment.
|
||||
|
||||
This allows to inject custom functionality at different stages of the environment
|
||||
lifecycle, in particular to connect SWE-agent to a new interface (like a GUI).
|
||||
"""
|
||||
hook.on_init(env=self)
|
||||
self._chook.add_hook(hook)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the environment and reset it to a clean state."""
|
||||
self._init_deployment()
|
||||
self.reset()
|
||||
for command in self._post_startup_commands:
|
||||
self.communicate(command, check="raise", timeout=self.post_startup_command_timeout)
|
||||
|
||||
def _copy_repo(self) -> None:
|
||||
"""Clone/copy repository/codebase in container"""
|
||||
if self.repo is None:
|
||||
return
|
||||
|
||||
folders = self.communicate(input="ls", check="raise").split("\n")
|
||||
if self.repo.repo_name in folders:
|
||||
return
|
||||
|
||||
self._chook.on_copy_repo_started(repo=self.repo)
|
||||
self.repo.copy(self.deployment)
|
||||
|
||||
def hard_reset(self):
|
||||
"""Resets the environment and deployment, i.e., completely restarts the
|
||||
deployment.
|
||||
"""
|
||||
self.close()
|
||||
self.start()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the environment to a clean state.
|
||||
Gets called by `start`, but can also be called independently to reset the
|
||||
environment to a clean state before a new attempt.
|
||||
|
||||
Returns:
|
||||
observation: output from container
|
||||
info: additional information (e.g. debugging information)
|
||||
"""
|
||||
self.communicate(input="cd /", check="raise")
|
||||
self._copy_repo()
|
||||
self._reset_repository()
|
||||
self._chook.on_environment_startup()
|
||||
|
||||
def _reset_repository(self) -> None:
|
||||
"""Clean repository of any modifications + Checkout base commit"""
|
||||
if self.repo is not None:
|
||||
self.logger.debug("Resetting repository %s to commit %s", self.repo.repo_name, self.repo.base_commit)
|
||||
# todo: Currently has swe-ft specific change: The original repo.copy isn't called, because the repo is already
|
||||
# present. However, reset --hard <BRANCH> also doesn't work. So modified it here to do a checkout instead.
|
||||
startup_commands = [
|
||||
f"cd /{self.repo.repo_name}",
|
||||
"export ROOT=$(pwd -P)",
|
||||
*self.repo.get_reset_commands(),
|
||||
]
|
||||
self.communicate(
|
||||
input=" && ".join(startup_commands),
|
||||
check="raise",
|
||||
error_msg="Failed to clean repository",
|
||||
# Sometimes this is slow because it rebuilds some index
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Shutdown SWE-ReX deployment etc."""
|
||||
self.logger.info("Beginning environment shutdown...")
|
||||
asyncio.run(self.deployment.stop())
|
||||
self._chook.on_close()
|
||||
|
||||
# MARK: Helper functions #
|
||||
|
||||
def _init_deployment(
|
||||
self,
|
||||
) -> None:
|
||||
"""Handles container initialization. Defines container name and creates it.
|
||||
If cached_image is provided, it will use that image name instead of the default.
|
||||
"""
|
||||
self._chook.on_start_deployment()
|
||||
asyncio.run(self.deployment.start())
|
||||
asyncio.run(
|
||||
self.deployment.runtime.create_session(
|
||||
CreateBashSessionRequest(startup_source=["/root/.bashrc"], startup_timeout=10)
|
||||
)
|
||||
)
|
||||
self.set_env_variables({"LANG": "C.UTF-8", "LC_ALL": "C.UTF-8", "PIP_PROGRESS_BAR": "off", "PAGER": "cat"})
|
||||
self.logger.info("Environment Initialized")
|
||||
|
||||
def interrupt_session(self):
|
||||
self.logger.info("Interrupting session")
|
||||
asyncio.run(self.deployment.runtime.run_in_session(BashInterruptAction()))
|
||||
|
||||
# todo: return exit code?
|
||||
def communicate(
|
||||
self,
|
||||
input: str,
|
||||
timeout: int | float = 25,
|
||||
*,
|
||||
check: Literal["warn", "ignore", "raise"] = "ignore",
|
||||
error_msg: str = "Command failed",
|
||||
) -> str:
|
||||
"""Executes a command in the running shell. The details of this are handled by
|
||||
the SWE-ReX deployment/runtime.
|
||||
|
||||
Args:
|
||||
input: input to send to container
|
||||
timeout_duration: duration to wait for output
|
||||
check: `ignore`: do not extract exit code (more stable), `warn`: extract exit code and log error if
|
||||
exit code is non-zero, `raise`: raise error if exit code is non-zero
|
||||
error_msg: error message to raise if the command fails
|
||||
|
||||
Returns:
|
||||
output: output from container
|
||||
"""
|
||||
self.logger.log(logging.TRACE, "Input:\n%s", input) # type: ignore
|
||||
rex_check = "silent" if check else "ignore"
|
||||
r = asyncio.run(
|
||||
self.deployment.runtime.run_in_session(BashAction(command=input, timeout=timeout, check=rex_check))
|
||||
)
|
||||
output = r.output
|
||||
self.logger.log(logging.TRACE, "Output:\n%s", output) # type: ignore
|
||||
if check != "ignore" and r.exit_code != 0:
|
||||
self.logger.error(f"{error_msg}:\n{output}")
|
||||
msg = f"Command {input!r} failed ({r.exit_code=}): {error_msg}"
|
||||
self.logger.error(msg)
|
||||
if check == "raise":
|
||||
self.close()
|
||||
raise RuntimeError(msg)
|
||||
return output
|
||||
|
||||
def read_file(self, path: str | PurePath, encoding: str | None = None, errors: str | None = None) -> str:
|
||||
"""Read file contents from container
|
||||
|
||||
Args:
|
||||
path: Absolute path to file
|
||||
encoding: Encoding to use when reading the file. None means default encoding.
|
||||
This is the same as the `encoding` argument of `Path.read_text()`
|
||||
errors: Error handling to use when reading the file. None means default error handling.
|
||||
This is the same as the `errors` argument of `Path.read_text()`
|
||||
|
||||
Returns:
|
||||
file_contents: Contents of file as string
|
||||
"""
|
||||
r = asyncio.run(
|
||||
self.deployment.runtime.read_file(ReadFileRequest(path=str(path), encoding=encoding, errors=errors))
|
||||
)
|
||||
return r.content
|
||||
|
||||
def write_file(self, path: str | PurePath, content: str) -> None:
|
||||
"""Write content to file in container"""
|
||||
asyncio.run(self.deployment.runtime.write_file(WriteFileRequest(path=str(path), content=content)))
|
||||
|
||||
def set_env_variables(self, env_variables: dict[str, str]) -> None:
|
||||
"""Set environment variables in the environment."""
|
||||
if not env_variables:
|
||||
self.logger.debug("No environment variables to set")
|
||||
return
|
||||
_env_setters = [f"export {k}={shlex.quote(str(v))}" for k, v in env_variables.items()]
|
||||
command = " && ".join(_env_setters)
|
||||
self.communicate(command, check="raise")
|
||||
|
||||
def execute_command(
|
||||
self,
|
||||
command: str,
|
||||
shell: bool = True,
|
||||
check: bool = False,
|
||||
env: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> None:
|
||||
"""Execute a command in the environment independent of the session (i.e., as a subprocess)"""
|
||||
asyncio.run(
|
||||
self.deployment.runtime.execute(RexCommand(command=command, shell=shell, check=check, env=env, cwd=cwd))
|
||||
)
|
||||
54
.agent/vendor/mini-swe/sweagent/exceptions.py
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
"""This module contains all custom exceptions used by the SWE-agent."""
|
||||
|
||||
|
||||
class FormatError(Exception):
|
||||
"""Raised when the model response cannot properly be parsed into thought and actions."""
|
||||
|
||||
|
||||
class FunctionCallingFormatError(FormatError):
|
||||
"""Format error exception used by the function
|
||||
calling parser."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
error_code: Literal[
|
||||
"missing", "multiple", "incorrect_args", "invalid_json", "invalid_command", "missing_arg", "unexpected_arg"
|
||||
],
|
||||
**extra_info: Any,
|
||||
):
|
||||
super().__init__(message + f" [error_code={error_code}]")
|
||||
self.message = message
|
||||
self.extra_info = {"error_code": error_code, **extra_info}
|
||||
|
||||
|
||||
class ContextWindowExceededError(Exception):
|
||||
"""Raised when the context window of a LM is exceeded"""
|
||||
|
||||
|
||||
class CostLimitExceededError(Exception):
|
||||
"""Raised when we exceed a cost limit"""
|
||||
|
||||
|
||||
class InstanceCostLimitExceededError(CostLimitExceededError):
|
||||
"""Raised when we exceed the cost limit set for one task instance"""
|
||||
|
||||
|
||||
class TotalCostLimitExceededError(CostLimitExceededError):
|
||||
"""Raised when we exceed the total cost limit"""
|
||||
|
||||
|
||||
class InstanceCallLimitExceededError(CostLimitExceededError):
|
||||
"""Raised when we exceed the per instance call limit"""
|
||||
|
||||
|
||||
class ContentPolicyViolationError(Exception):
|
||||
"""Raised when the model response violates a content policy"""
|
||||
|
||||
|
||||
class ModelConfigurationError(Exception):
|
||||
"""Raised when the model configuration is invalid/no further retries
|
||||
should be made.
|
||||
"""
|
||||
6
.agent/vendor/mini-swe/sweagent/inspector/README.md
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
🔗 For more information on the trajectory inspector, visit [our documentation website][docs].
|
||||
|
||||
You can also find the corresponding markdown files in the [`docs/` folder][source].
|
||||
|
||||
[docs]: https://swe-agent.com/latest/usage/inspector/
|
||||
[source]: https://github.com/SWE-agent/SWE-agent/tree/main/docs
|
||||
0
.agent/vendor/mini-swe/sweagent/inspector/__init__.py
vendored
Normal file
BIN
.agent/vendor/mini-swe/sweagent/inspector/favicon.ico
vendored
Normal file
|
After Width: | Height: | Size: 264 KiB |
354
.agent/vendor/mini-swe/sweagent/inspector/fileViewer.js
vendored
Normal file
@@ -0,0 +1,354 @@
|
||||
let currentFileName = null;
|
||||
let trajectoryDirectory = "";
|
||||
let timeoutIds = [];
|
||||
|
||||
function getBaseUrl() {
|
||||
const protocol = window.location.protocol;
|
||||
const host = window.location.hostname;
|
||||
const port = window.location.port;
|
||||
const defaultPort =
|
||||
protocol === "http:" && !port
|
||||
? "80"
|
||||
: protocol === "https:" && !port
|
||||
? "443"
|
||||
: port;
|
||||
return `${protocol}//${host}:${defaultPort}`;
|
||||
}
|
||||
|
||||
function fetchFiles() {
|
||||
const baseUrl = getBaseUrl();
|
||||
fetch(`${baseUrl}/files`)
|
||||
.then((response) => response.json())
|
||||
.then((files) => {
|
||||
const fileList = document.getElementById("fileList");
|
||||
fileList.innerHTML = "";
|
||||
files.forEach((file) => {
|
||||
const fileElement = document.createElement("li");
|
||||
fileElement.textContent = file;
|
||||
fileElement.onclick = () => viewFile(file.split(" ")[0]);
|
||||
fileList.appendChild(fileElement);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function createTrajectoryItem(item, index) {
|
||||
const elementId = `trajectoryItem${index}`;
|
||||
|
||||
// Check for old format and log a warning
|
||||
const isOldFormat = item.messages && !item.query;
|
||||
if (isOldFormat) {
|
||||
console.log(
|
||||
`Found old format using 'messages' instead of 'query' in item ${index}`,
|
||||
);
|
||||
// Migrate old format to new format
|
||||
item.query = item.messages;
|
||||
}
|
||||
|
||||
const hasMessages = item.query && item.query.length > 0;
|
||||
|
||||
const escapeHtml = (text) => {
|
||||
if (!text) {
|
||||
return "";
|
||||
}
|
||||
return text
|
||||
.replace(/&/g, "&")
|
||||
.replace(/</g, "<")
|
||||
.replace(/>/g, ">")
|
||||
.replace(/"/g, """)
|
||||
.replace(/'/g, "'");
|
||||
};
|
||||
|
||||
const processImagesInObservation = (observation) => {
|
||||
if (!observation) {
|
||||
return { processedText: "", images: [] };
|
||||
}
|
||||
|
||||
// regex to match markdown-style base64 images: 
|
||||
const imageRegex = /!\[([^\]]*)\]\(data:image\/([^;]+);base64,([^)]+)\)/g;
|
||||
const images = [];
|
||||
let processedText = observation;
|
||||
let match;
|
||||
|
||||
while ((match = imageRegex.exec(observation)) !== null) {
|
||||
const [fullMatch, altText, format, base64Data] = match;
|
||||
|
||||
// create image object
|
||||
const imageObj = {
|
||||
altText: altText || "Image",
|
||||
format: format,
|
||||
dataUrl: `data:image/${format};base64,${base64Data}`,
|
||||
id: `img_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
||||
};
|
||||
|
||||
images.push(imageObj);
|
||||
|
||||
// replace the full base64 string with a placeholder
|
||||
processedText = processedText.replace(
|
||||
fullMatch,
|
||||
`[IMAGE: ${imageObj.altText}]`,
|
||||
);
|
||||
}
|
||||
|
||||
return { processedText, images };
|
||||
};
|
||||
|
||||
const getMessageContent = (msg) => {
|
||||
if (!msg.content) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Handle content as a string
|
||||
if (typeof msg.content === "string") {
|
||||
return msg.content;
|
||||
}
|
||||
|
||||
// Handle content as an array with a dictionary containing 'text' key
|
||||
if (
|
||||
Array.isArray(msg.content) &&
|
||||
msg.content.length > 0 &&
|
||||
msg.content[0].text
|
||||
) {
|
||||
return msg.content[0].text;
|
||||
}
|
||||
|
||||
// Fallback to stringifying the content
|
||||
return JSON.stringify(msg.content);
|
||||
};
|
||||
|
||||
const messagesContent = hasMessages
|
||||
? item.query
|
||||
.map((msg, msgIndex) => {
|
||||
let content = `----Item ${msgIndex}-----\n`;
|
||||
content += `role: ${msg.role}\n`;
|
||||
content += `content: |\n${escapeHtml(getMessageContent(msg))}\n`;
|
||||
|
||||
if (msg.tool_calls && msg.tool_calls.length > 0) {
|
||||
msg.tool_calls.forEach((tool, idx) => {
|
||||
content += `- tool call ${idx + 1}:\n`;
|
||||
if (tool.function) {
|
||||
content += ` - name: ${tool.function.name}\n`;
|
||||
// Handle arguments based on type
|
||||
let args = tool.function.arguments;
|
||||
try {
|
||||
if (typeof args === "string") {
|
||||
args = JSON.parse(args);
|
||||
}
|
||||
content += ` - arguments: ${JSON.stringify(args, null, 2).replace(/\n/g, "\n ")}\n`;
|
||||
} catch (e) {
|
||||
content += ` - arguments: ${escapeHtml(String(args))}\n`;
|
||||
}
|
||||
}
|
||||
content += ` - id: ${tool.id}\n`;
|
||||
});
|
||||
}
|
||||
|
||||
if (msg.is_demo) {
|
||||
return `<span class="demo-message">${content}</span>`;
|
||||
}
|
||||
return content;
|
||||
})
|
||||
.join("\n")
|
||||
: "";
|
||||
|
||||
// Process images in observation
|
||||
const { processedText: processedObservation, images: observationImages } =
|
||||
processImagesInObservation(item.observation);
|
||||
|
||||
// Create separate image pane HTML if there are images
|
||||
const observationImagesPane =
|
||||
observationImages.length > 0
|
||||
? `<div class="observation-images-section" data-title="Observation Images">
|
||||
<div class="content-wrapper">
|
||||
<div class="observation-images">
|
||||
${observationImages
|
||||
.map(
|
||||
(img) =>
|
||||
`<div class="observation-image-container">
|
||||
<img src="${img.dataUrl}" alt="${escapeHtml(img.altText)}" class="observation-image" id="${img.id}">
|
||||
<div class="image-caption">${escapeHtml(img.altText)}</div>
|
||||
</div>`,
|
||||
)
|
||||
.join("")}
|
||||
</div>
|
||||
</div>
|
||||
</div>`
|
||||
: "";
|
||||
|
||||
return `
|
||||
<div class="trajectory-item fade-in" id="${elementId}">
|
||||
<div class="trajectory-main">
|
||||
<div class="response-section" data-title="Response">
|
||||
<div class="content-wrapper">
|
||||
<pre><code class="language-python">Response:
|
||||
${escapeHtml(item.response)}
|
||||
|
||||
Action:
|
||||
${escapeHtml(item.action)}</code></pre>
|
||||
</div>
|
||||
</div>
|
||||
<div class="observation-section" data-title="Environment Observation">
|
||||
<div class="content-wrapper">
|
||||
<pre><code class="language-python">${escapeHtml(processedObservation)}</code></pre>
|
||||
</div>
|
||||
</div>
|
||||
${observationImagesPane}
|
||||
${
|
||||
item.execution_time
|
||||
? `<div class="execution-time">Execution time: ${item.execution_time}s</div>`
|
||||
: ""
|
||||
}
|
||||
</div>
|
||||
${
|
||||
hasMessages
|
||||
? `
|
||||
<div class="messages-section" data-title="Messages">
|
||||
<div class="content-wrapper">
|
||||
<pre>${messagesContent}</pre>
|
||||
</div>
|
||||
</div>
|
||||
`
|
||||
: ""
|
||||
}
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
function viewFile(fileName) {
|
||||
currentFileName = fileName;
|
||||
timeoutIds.forEach((timeoutId) => clearTimeout(timeoutId));
|
||||
timeoutIds = [];
|
||||
|
||||
const baseUrl = getBaseUrl();
|
||||
const showDemos = document.getElementById("showDemos").checked;
|
||||
|
||||
fetch(`${baseUrl}/trajectory/${fileName}`)
|
||||
.then((response) => {
|
||||
if (!response.ok) {
|
||||
throw new Error("Network response was not ok");
|
||||
}
|
||||
return response.json();
|
||||
})
|
||||
.then((content) => {
|
||||
const container = document.getElementById("fileContent");
|
||||
container.innerHTML = "";
|
||||
|
||||
if (content.trajectory && Array.isArray(content.trajectory)) {
|
||||
content.trajectory.forEach((item, index) => {
|
||||
container.innerHTML += createTrajectoryItem(item, index);
|
||||
|
||||
// Highlight code blocks after adding them
|
||||
const newItem = document.getElementById(`trajectoryItem${index}`);
|
||||
newItem.querySelectorAll("pre code").forEach((block) => {
|
||||
hljs.highlightElement(block);
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize image click handlers after all items are added
|
||||
initializeImageHandlers();
|
||||
} else {
|
||||
container.textContent = "No trajectory content found.";
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error("Error fetching file:", error);
|
||||
document.getElementById("fileContent").textContent =
|
||||
"Error loading content. " + error;
|
||||
});
|
||||
|
||||
// Highlight selected file
|
||||
document.querySelectorAll("#fileList li").forEach((li) => {
|
||||
li.classList.remove("selected");
|
||||
if (li.textContent.split(" ")[0] === fileName) {
|
||||
li.classList.add("selected");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function initializeImageHandlers() {
|
||||
// Remove existing overlay if present
|
||||
const existingOverlay = document.querySelector(".image-overlay");
|
||||
if (existingOverlay) {
|
||||
existingOverlay.remove();
|
||||
}
|
||||
|
||||
// Create overlay element
|
||||
const overlay = document.createElement("div");
|
||||
overlay.className = "image-overlay";
|
||||
document.body.appendChild(overlay);
|
||||
|
||||
// Add click handlers to all observation images
|
||||
document.querySelectorAll(".observation-image").forEach((img) => {
|
||||
img.addEventListener("click", function (e) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
|
||||
// Toggle expanded state
|
||||
if (this.classList.contains("expanded")) {
|
||||
this.classList.remove("expanded");
|
||||
overlay.classList.remove("active");
|
||||
} else {
|
||||
// Remove expanded class from all other images
|
||||
document
|
||||
.querySelectorAll(".observation-image.expanded")
|
||||
.forEach((otherImg) => {
|
||||
otherImg.classList.remove("expanded");
|
||||
});
|
||||
|
||||
// Add expanded class to clicked image
|
||||
this.classList.add("expanded");
|
||||
overlay.classList.add("active");
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Close expanded image when clicking overlay
|
||||
overlay.addEventListener("click", function () {
|
||||
document.querySelectorAll(".observation-image.expanded").forEach((img) => {
|
||||
img.classList.remove("expanded");
|
||||
});
|
||||
overlay.classList.remove("active");
|
||||
});
|
||||
|
||||
// Close expanded image when pressing Escape key
|
||||
document.addEventListener("keydown", function (e) {
|
||||
if (e.key === "Escape") {
|
||||
document
|
||||
.querySelectorAll(".observation-image.expanded")
|
||||
.forEach((img) => {
|
||||
img.classList.remove("expanded");
|
||||
});
|
||||
overlay.classList.remove("active");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function refreshCurrentFile() {
|
||||
if (currentFileName) {
|
||||
const currentScrollPosition =
|
||||
document.documentElement.scrollTop || document.body.scrollTop;
|
||||
viewFile(currentFileName.split(" ")[0]);
|
||||
setTimeout(() => {
|
||||
window.scrollTo(0, currentScrollPosition);
|
||||
}, 100);
|
||||
}
|
||||
}
|
||||
|
||||
function fetchDirectoryInfo() {
|
||||
const baseUrl = getBaseUrl();
|
||||
fetch(`${baseUrl}/directory_info`)
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
if (data.directory) {
|
||||
trajectoryDirectory = data.directory;
|
||||
document.title = `Trajectory Viewer: ${data.directory}`;
|
||||
document.getElementById("directoryInfo").textContent =
|
||||
`Directory: ${data.directory}`;
|
||||
}
|
||||
})
|
||||
.catch((error) => console.error("Error fetching directory info:", error));
|
||||
}
|
||||
|
||||
window.onload = function () {
|
||||
fetchFiles();
|
||||
fetchDirectoryInfo();
|
||||
};
|
||||
BIN
.agent/vendor/mini-swe/sweagent/inspector/icons/computer.png
vendored
Normal file
|
After Width: | Height: | Size: 14 KiB |
11
.agent/vendor/mini-swe/sweagent/inspector/icons/edit_icon.svg
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
<svg width="21" height="20" viewBox="0 0 21 20" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M3.64645 11.7772L3.29289 12.1308L3.64645 12.4843L8.59619 17.4341L8.94975 17.7876L9.3033 17.4341L18.8512 7.88615C20.4133 6.32405 20.4133 3.79139 18.8512 2.22929C17.2891 0.667195 14.7565 0.667193 13.1944 2.22929L3.64645 11.7772Z" fill="black" stroke="white"/>
|
||||
<path d="M3.09429 14.2581C3.27091 13.4928 4.22041 13.2204 4.77579 13.7758L7.2242 16.2242C7.77958 16.7796 7.50727 17.7291 6.74194 17.9057L3.559 18.6402C2.83893 18.8064 2.19359 18.1611 2.35976 17.441L3.09429 14.2581Z" fill="black"/>
|
||||
<path d="M3.09429 14.258C3.27091 13.4927 4.22041 13.2204 4.77579 13.7758L7.2242 16.2242C7.77958 16.7796 7.50727 17.7291 6.74194 17.9057L3.559 18.6402C2.83893 18.8064 2.19359 18.161 2.35976 17.441L3.09429 14.258Z" fill="black"/>
|
||||
<mask id="mask0_8_32" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="2" y="12" width="7" height="7">
|
||||
<path d="M3.5 12.5L8.5 17.5L1.99998 19L3.5 12.5Z" fill="black"/>
|
||||
</mask>
|
||||
<g mask="url(#mask0_8_32)">
|
||||
<rect x="1.89587" y="14.1084" width="2.27835" height="7.05693" transform="rotate(-45 1.89587 14.1084)" fill="black"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.1 KiB |
BIN
.agent/vendor/mini-swe/sweagent/inspector/icons/swe-agent-logo-50.png
vendored
Normal file
|
After Width: | Height: | Size: 2.4 KiB |
BIN
.agent/vendor/mini-swe/sweagent/inspector/icons/swellama_blue.png
vendored
Normal file
|
After Width: | Height: | Size: 35 KiB |
BIN
.agent/vendor/mini-swe/sweagent/inspector/icons/swellama_brown.png
vendored
Normal file
|
After Width: | Height: | Size: 36 KiB |
BIN
.agent/vendor/mini-swe/sweagent/inspector/icons/swellama_grey.png
vendored
Normal file
|
After Width: | Height: | Size: 34 KiB |
BIN
.agent/vendor/mini-swe/sweagent/inspector/icons/swellama_tan.png
vendored
Normal file
|
After Width: | Height: | Size: 33 KiB |
25
.agent/vendor/mini-swe/sweagent/inspector/index.html
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
<html>
|
||||
<head>
|
||||
<title>Trajectory Viewer</title>
|
||||
<link rel="stylesheet" type="text/css" href="style.css">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/github.min.css">
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/python.min.js"></script>
|
||||
<script src="fileViewer.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Trajectory File Viewer</h1>
|
||||
<h3 id="directoryInfo"></h3>
|
||||
<ul id="fileList"></ul>
|
||||
<label style="margin-right: 20px;">
|
||||
<input type="checkbox" id="showDemos" onchange="refreshCurrentFile()"> Show demonstrations
|
||||
</label>
|
||||
<h2>Conversation History</h2>
|
||||
<div id="fileContent" class="trajectory-container">No file selected.</div>
|
||||
<div class="button-container">
|
||||
<button id="refreshButton" onclick="refreshCurrentFile()">Refresh Current File</button>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
354
.agent/vendor/mini-swe/sweagent/inspector/server.py
vendored
Normal file
@@ -0,0 +1,354 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import http.server
|
||||
import json
|
||||
import os
|
||||
import socketserver
|
||||
from argparse import ArgumentParser
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def add_problem_statement(content):
|
||||
"""The problem statement is the first 'user' message in the history.
|
||||
|
||||
We'll prepend the trajectory with the problem statement.
|
||||
"""
|
||||
problem_statement = ""
|
||||
for item in content["history"]:
|
||||
if item["role"] == "user":
|
||||
problem_statement = item["content"]
|
||||
break
|
||||
if problem_statement:
|
||||
content["trajectory"].insert(
|
||||
0,
|
||||
{
|
||||
"thought": "",
|
||||
"action": "",
|
||||
"response": "",
|
||||
"observation": problem_statement,
|
||||
"messages": [{"role": "system", "content": "Problem Statement placeholder"}],
|
||||
},
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def append_exit(content):
|
||||
exit_status = content.get("info", {}).get("exit_status", None)
|
||||
if exit_status is None:
|
||||
return content
|
||||
|
||||
if exit_status.startswith("submitted"):
|
||||
if "submission" in content["info"]:
|
||||
content["trajectory"].append(
|
||||
{
|
||||
"thought": "Submitting solution",
|
||||
"action": "Model Submission",
|
||||
"response": "Submitting solution",
|
||||
"observation": content["info"]["submission"],
|
||||
"messages": [{"role": "system", "content": f"Submission generated - {exit_status}"}],
|
||||
}
|
||||
)
|
||||
else:
|
||||
msg = "No submission in history or info"
|
||||
raise ValueError(msg)
|
||||
return content
|
||||
|
||||
|
||||
def append_patch(instance_id, content, patches, patch_type):
|
||||
if content.get("info", {}).get("exit_status", None) is not None:
|
||||
if instance_id in patches:
|
||||
content["trajectory"].append(
|
||||
{
|
||||
"thought": f"Showing {patch_type} patch",
|
||||
"response": f"Showing {patch_type} patch",
|
||||
"action": f"{patch_type} Patch",
|
||||
"observation": patches[instance_id],
|
||||
}
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def append_results(traj_path: Path, instance_id: str, content, results, results_file):
|
||||
stats: list[str] = []
|
||||
model_stats = {}
|
||||
if traj_path.exists():
|
||||
data = json.loads(traj_path.read_text())
|
||||
info = data.get("info", {})
|
||||
model_stats = info.get("model_stats", {})
|
||||
|
||||
# Build stats section
|
||||
exit_status = info.get("exit_status", "N/A")
|
||||
instance_cost = model_stats.get("instance_cost", None)
|
||||
instance_cost = f"{instance_cost:.2f}" if instance_cost is not None else "N/A"
|
||||
tokens_sent = model_stats.get("tokens_sent", None)
|
||||
tokens_sent = f"{tokens_sent:,}" if tokens_sent is not None else "N/A"
|
||||
tokens_received = model_stats.get("tokens_received", None)
|
||||
tokens_received = f"{tokens_received:,}" if tokens_received is not None else "N/A"
|
||||
api_calls = model_stats.get("api_calls", None)
|
||||
api_calls = f"{api_calls:,}" if api_calls is not None else "N/A"
|
||||
|
||||
stats.append("**** Run Stats ****")
|
||||
stats.append(f"Exit Status: {exit_status}")
|
||||
stats.append(f"Instance Cost: ${instance_cost}")
|
||||
stats.append(f"Tokens Sent: {tokens_sent}")
|
||||
stats.append(f"Tokens Received: {tokens_received}")
|
||||
stats.append(f"API Calls: {api_calls}\n")
|
||||
|
||||
# Build status section
|
||||
status = []
|
||||
if results is None:
|
||||
status.append("Evaluation results not found")
|
||||
elif "completed_ids" in results and "submitted_ids" in results and "resolved_ids" in results:
|
||||
is_completed = instance_id in results["completed_ids"]
|
||||
is_submitted = instance_id in results["submitted_ids"]
|
||||
is_resolved = instance_id in results["resolved_ids"]
|
||||
|
||||
status.append("**** Statuses ****")
|
||||
status.append(f" {'✅' if is_completed else '❌'} Completed (The agent successfully ran)")
|
||||
status.append(f" {'✅' if is_submitted else '❌'} Submitted (The agent successfully submitted a pull request)")
|
||||
status.append(
|
||||
f" {'✅' if is_resolved else '❌'} Resolved (The pull request {'' if is_resolved else 'has not '}"
|
||||
"successfully resolved the issue during eval)"
|
||||
)
|
||||
else:
|
||||
status.append("Results format not recognized")
|
||||
|
||||
if status == []:
|
||||
status.append("Instance not found in results")
|
||||
else:
|
||||
status.append("---------------------------")
|
||||
status.append(
|
||||
"Note that the evaluation results here may not be accurate or up to date, since they are computed separately from the agent run itself."
|
||||
)
|
||||
status.append(f"Check {results_file} for the most accurate evaluation results.")
|
||||
status.append("")
|
||||
status.append(f"Instance ID: {instance_id}")
|
||||
|
||||
# Add evaluation report as first and last items in trajectory
|
||||
eval_report = {
|
||||
"thought": "Evaluation Report",
|
||||
"action": "Showing evaluation results",
|
||||
"response": "Showing evaluation results",
|
||||
"observation": "\n".join([*stats, *status]),
|
||||
"messages": [{"role": "system", "content": "Showing evaluation results and statistics"}],
|
||||
}
|
||||
|
||||
if not content.get("trajectory"):
|
||||
content["trajectory"] = []
|
||||
content["trajectory"].insert(0, eval_report)
|
||||
content["trajectory"].append(eval_report)
|
||||
return content
|
||||
|
||||
|
||||
def get_action_summary(content):
|
||||
out = ""
|
||||
i = 0
|
||||
for item in content["history"]:
|
||||
if item["role"] != "assistant":
|
||||
continue
|
||||
if item.get("is_demo"):
|
||||
continue
|
||||
i += 1
|
||||
try:
|
||||
action = item["action"]
|
||||
except KeyError:
|
||||
print(f"No action for step {i}")
|
||||
print(item)
|
||||
raise
|
||||
if len(action) > 70:
|
||||
action = action[:67] + "..."
|
||||
out += f"Step {i}: {action}\n"
|
||||
return out
|
||||
|
||||
|
||||
def load_content(file_name, gold_patches, test_patches) -> dict[str, Any]:
|
||||
with open(file_name) as infile:
|
||||
content = json.load(infile)
|
||||
results_file = Path(file_name).parent / "results.json"
|
||||
results = load_results(results_file)
|
||||
|
||||
content = add_problem_statement(content)
|
||||
content = append_exit(content)
|
||||
content = append_patch(Path(file_name).stem, content, gold_patches, "Gold")
|
||||
content = append_patch(Path(file_name).stem, content, test_patches, "Test")
|
||||
content["history"].insert(0, {"role": "Action Summary", "content": get_action_summary(content)})
|
||||
return append_results(
|
||||
Path(file_name),
|
||||
Path(file_name).stem,
|
||||
content,
|
||||
results,
|
||||
results_file,
|
||||
)
|
||||
|
||||
|
||||
def load_results(results_path: Path) -> dict[str, Any] | None:
|
||||
"""Load results from results.json.
|
||||
|
||||
If file is not found, return None.
|
||||
"""
|
||||
if not results_path.exists():
|
||||
return None
|
||||
with open(results_path) as infile:
|
||||
results = json.load(infile)
|
||||
# Different versions of the code used "not_generated" or "no_generation".
|
||||
# Let's standardize this here
|
||||
if "no_generation" in results:
|
||||
results["not_generated"] = results["no_generation"]
|
||||
del results["no_generation"]
|
||||
return results
|
||||
|
||||
|
||||
def get_status(traj_path) -> str:
|
||||
"""Return results emoji for single trajectory"""
|
||||
results = load_results(Path(traj_path).parent / "results.json")
|
||||
info = json.loads(Path(traj_path).read_text()).get("info", {})
|
||||
n_steps = info.get("model_stats", {}).get("api_calls", "N/A")
|
||||
exit_status = info.get("exit_status", "N/A")
|
||||
exit_status_str = f" ({exit_status} after {n_steps} steps)"
|
||||
instance_id = Path(traj_path).stem
|
||||
if results is None:
|
||||
return f"❓ {exit_status_str}"
|
||||
elif instance_id in results["resolved_ids"]:
|
||||
return "✅"
|
||||
else:
|
||||
return f"❌ {exit_status_str}"
|
||||
|
||||
|
||||
class Handler(http.server.SimpleHTTPRequestHandler):
|
||||
file_mod_times = {} # Dictionary to keep track of file modification times
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.gold_patches = {}
|
||||
self.test_patches = {}
|
||||
if "gold_patches" in kwargs:
|
||||
self.gold_patches = kwargs.pop("gold_patches")
|
||||
if "test_patches" in kwargs:
|
||||
self.test_patches = kwargs.pop("test_patches")
|
||||
self.traj_dir = kwargs.pop("directory", ".") # Extract directory
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def serve_directory_info(self):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"directory": self.traj_dir}).encode())
|
||||
|
||||
def serve_file_content(self, file_path):
|
||||
try:
|
||||
content = load_content(
|
||||
Path(self.traj_dir) / file_path,
|
||||
self.gold_patches,
|
||||
self.test_patches,
|
||||
)
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/plain")
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(content).encode())
|
||||
except FileNotFoundError:
|
||||
self.send_error(404, f"File {file_path} not found")
|
||||
|
||||
def do_GET(self):
|
||||
if self.path == "/directory_info":
|
||||
self.serve_directory_info()
|
||||
elif self.path.startswith("/files"):
|
||||
self.handle_files_request()
|
||||
elif self.path.startswith("/trajectory/"):
|
||||
file_path = self.path[len("/trajectory/") :]
|
||||
self.serve_file_content(file_path)
|
||||
elif self.path.startswith("/check_update"):
|
||||
self.check_for_updates()
|
||||
else:
|
||||
super().do_GET()
|
||||
|
||||
def handle_files_request(self):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
files = sorted(
|
||||
(
|
||||
str(file.relative_to(Path(self.traj_dir))) + " " * 4 + get_status(file)
|
||||
for file in Path(self.traj_dir).glob("**/*.traj")
|
||||
),
|
||||
key=lambda x: str(Path(self.traj_dir) / x),
|
||||
reverse=True,
|
||||
)
|
||||
self.wfile.write(json.dumps(files).encode())
|
||||
|
||||
def check_for_updates(self):
|
||||
current_mod_times = {str(file): file.stat().st_mtime for file in Path(self.traj_dir).glob("**/*.traj")}
|
||||
if current_mod_times != Handler.file_mod_times:
|
||||
Handler.file_mod_times = current_mod_times
|
||||
self.send_response(200) # Send response that there's an update
|
||||
else:
|
||||
self.send_response(204) # Send no content response if no update
|
||||
self.end_headers()
|
||||
|
||||
def end_headers(self):
|
||||
self.send_header("Access-Control-Allow-Origin", "*")
|
||||
super().end_headers()
|
||||
|
||||
|
||||
def main(data_path, directory, port):
|
||||
data = []
|
||||
if data_path is not None:
|
||||
if data_path.endswith(".jsonl"):
|
||||
data = [json.loads(x) for x in Path(data_path).read_text().splitlines(keepends=True)]
|
||||
elif data_path.endswith(".json"):
|
||||
with open(data_path) as f:
|
||||
data = json.load(f)
|
||||
elif "args.yaml" in os.listdir(directory):
|
||||
with open(Path(directory) / "args.yaml") as file:
|
||||
args = yaml.safe_load(file)
|
||||
if "environment" in args and "data_path" in args["environment"]:
|
||||
data_path = Path(__file__).parent.parent / args["environment"]["data_path"]
|
||||
if data_path.exists:
|
||||
with open(data_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
gold_patches = {d["instance_id"]: d["patch"] if "patch" in d else None for d in data}
|
||||
test_patches = {d["instance_id"]: d["test_patch"] if "test_patch" in d else None for d in data}
|
||||
|
||||
handler_with_directory = partial(
|
||||
Handler,
|
||||
directory=directory,
|
||||
gold_patches=gold_patches,
|
||||
test_patches=test_patches,
|
||||
)
|
||||
try:
|
||||
with socketserver.TCPServer(("", port), handler_with_directory) as httpd:
|
||||
print(f"Serving at http://localhost:{port}")
|
||||
httpd.serve_forever()
|
||||
except OSError as e:
|
||||
if e.errno == 48:
|
||||
print(f"ERROR: Port ({port}) is already in use. Try another port with the --port flag.")
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data_path",
|
||||
type=str,
|
||||
help="Path to dataset that was used for the trajectories. Necessary to display gold patches.",
|
||||
)
|
||||
parser.add_argument("--directory", type=str, help="Directory to serve", default=os.getcwd(), nargs="?")
|
||||
parser.add_argument("--port", type=int, help="Port to serve", default=8000)
|
||||
return parser
|
||||
|
||||
|
||||
def run_from_cli(args: list[str] | None = None):
|
||||
# Hack to make sure all the templates and all are found
|
||||
parsed_args = get_parser().parse_args(args)
|
||||
# convert directory, relative to the absolute path
|
||||
parsed_args.directory = str(Path(parsed_args.directory).resolve().absolute())
|
||||
os.chdir(Path(__file__).parent)
|
||||
main(**vars(parsed_args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_from_cli()
|
||||
169
.agent/vendor/mini-swe/sweagent/inspector/static.py
vendored
Normal file
@@ -0,0 +1,169 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
try:
|
||||
from .server import load_content
|
||||
except ImportError:
|
||||
from server import load_content
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger("simple_parsing").setLevel(logging.INFO)
|
||||
|
||||
|
||||
TEMPLATE = """
|
||||
<html>
|
||||
<head>
|
||||
<title>Trajectory Viewer</title>
|
||||
<style>
|
||||
{style_sheet}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
{file_path_tree}
|
||||
<h2>Conversation History</h2>
|
||||
<pre id="fileContent">{file_content}</pre>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
try:
|
||||
with open(Path(__file__).parent / "style.css") as infile:
|
||||
STYLE_SHEET = infile.read()
|
||||
except Exception as e:
|
||||
style_file = Path(__file__).parent / "style.css"
|
||||
logger.error(f"Failed to load style sheet from {style_file}: {traceback.format_exc()}")
|
||||
raise e
|
||||
|
||||
|
||||
def _load_file(file_name, gold_patches, test_patches):
|
||||
try:
|
||||
role_map = {
|
||||
"user": "Computer",
|
||||
"assistant": "SWE-Agent",
|
||||
"subroutine": "SWE-Agent subroutine",
|
||||
"default": "Default",
|
||||
"system": "System",
|
||||
"demo": "Demonstration",
|
||||
}
|
||||
content = load_content(file_name, gold_patches, test_patches)
|
||||
if "history" in content and isinstance(content["history"], list):
|
||||
history_content = ""
|
||||
for index, item in enumerate(content["history"]):
|
||||
item_content = item.get("content", "").replace("<", "<").replace(">", ">")
|
||||
if item.get("agent") and item["agent"] != "primary":
|
||||
role_class = "subroutine"
|
||||
else:
|
||||
role_class = item.get("role", "default").lower().replace(" ", "-")
|
||||
element_id = f"historyItem{index}"
|
||||
role_name = role_map.get(item.get("role", ""), item.get("role", ""))
|
||||
history_content += (
|
||||
f"""<div class="history-item {role_class}" id="{element_id}">"""
|
||||
f"""<div class="role-bar {role_class}"><strong><span>{role_name}</span></strong></div>"""
|
||||
f"""<div class="content-container">"""
|
||||
f"""<pre>{item_content}</pre>"""
|
||||
f"""</div>"""
|
||||
f"""<div class="shadow"></div>"""
|
||||
f"""</div>"""
|
||||
)
|
||||
return history_content
|
||||
else:
|
||||
return "No history content found."
|
||||
except Exception:
|
||||
return f"Error loading content. {traceback.format_exc()}"
|
||||
|
||||
|
||||
def _make_file_path_tree(file_path):
|
||||
path_parts = file_path.split("/")
|
||||
relevant_parts = path_parts[-3:]
|
||||
html_string = '<div class="filepath">\n'
|
||||
for part in relevant_parts:
|
||||
html_string += f'<div class="part">{part}</div>\n'
|
||||
html_string += "</div>"
|
||||
return html_string
|
||||
|
||||
|
||||
def save_static_viewer(file_path):
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
data = []
|
||||
if "args.yaml" in list(map(lambda x: x.name, file_path.parent.iterdir())):
|
||||
args = yaml.safe_load(Path(file_path.parent / "args.yaml").read_text())
|
||||
if "environment" in args and "data_path" in args["environment"]:
|
||||
data_path = Path(__file__).parent.parent / args["environment"]["data_path"]
|
||||
if data_path.exists():
|
||||
with open(data_path) as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, list) or not data or "patch" not in data[0] or "test_patch" not in data[0]:
|
||||
data = []
|
||||
gold_patches = {x["instance_id"]: x["patch"] for x in data}
|
||||
test_patches = {x["instance_id"]: x["test_patch"] for x in data}
|
||||
content = _load_file(file_path, gold_patches, test_patches)
|
||||
file_path_tree = _make_file_path_tree(file_path.absolute().as_posix())
|
||||
icons_path = Path(__file__).parent / "icons"
|
||||
relative_icons_path = find_relative_path(file_path, icons_path)
|
||||
style_sheet = STYLE_SHEET.replace("url('icons/", f"url('{relative_icons_path.as_posix()}/").replace(
|
||||
'url("icons/',
|
||||
f'url("{relative_icons_path.as_posix()}/',
|
||||
)
|
||||
data = TEMPLATE.format(file_content=content, style_sheet=style_sheet, file_path_tree=file_path_tree)
|
||||
output_file = file_path.with_suffix(".html")
|
||||
with open(output_file, "w+") as outfile:
|
||||
print(data, file=outfile)
|
||||
logger.info(f"Saved static viewer to {output_file}")
|
||||
|
||||
|
||||
def find_relative_path(from_path, to_path):
|
||||
# Convert paths to absolute for uniformity
|
||||
from_path = from_path.resolve()
|
||||
to_path = to_path.resolve()
|
||||
if from_path.is_file():
|
||||
from_path = from_path.parent
|
||||
if to_path.is_file():
|
||||
to_path = to_path.parent
|
||||
if not from_path.is_dir() or not to_path.is_dir():
|
||||
msg = f"Both from_path and to_path must be directories, but got {from_path} and {to_path}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Identify the common ancestor and the parts of each path beyond it
|
||||
common_parts = 0
|
||||
for from_part, to_part in zip(from_path.parts, to_path.parts):
|
||||
if from_part != to_part:
|
||||
break
|
||||
common_parts += 1
|
||||
|
||||
# Calculate the '../' needed to get back from from_path to the common ancestor
|
||||
back_to_ancestor = [".."] * (len(from_path.parts) - common_parts)
|
||||
|
||||
# Direct path from common ancestor to to_path
|
||||
to_target = to_path.parts[common_parts:]
|
||||
|
||||
# Combine to get the relative path
|
||||
return Path(*back_to_ancestor, *to_target)
|
||||
|
||||
|
||||
def save_all_trajectories(directory):
|
||||
if not isinstance(directory, Path):
|
||||
directory = Path(directory)
|
||||
all_files = list(directory.glob("**/*.traj"))
|
||||
logger.info(f"Found {len(all_files)} trajectory files in {directory}")
|
||||
for file_path in tqdm(all_files, desc="Saving static viewers"):
|
||||
save_static_viewer(file_path)
|
||||
logger.info(f"Saved static viewers for all trajectories in {args.directory}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("directory", type=str, help="Directory containing trajectory files")
|
||||
args = parser.parse_args()
|
||||
save_all_trajectories(args.directory)
|
||||
454
.agent/vendor/mini-swe/sweagent/inspector/style.css
vendored
Normal file
@@ -0,0 +1,454 @@
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background-color: #1e1e1e;
|
||||
color: #d4d4d4;
|
||||
}
|
||||
h1,
|
||||
h2 {
|
||||
color: #d4d4d4;
|
||||
}
|
||||
#fileList {
|
||||
list-style-type: none;
|
||||
padding: 10px;
|
||||
max-height: 400px;
|
||||
overflow-y: auto;
|
||||
margin: 0;
|
||||
background-color: #252526;
|
||||
border: 3px solid #1697e2;
|
||||
border-radius: 10px;
|
||||
}
|
||||
#fileList li {
|
||||
cursor: pointer;
|
||||
padding: 10px;
|
||||
background-color: #2d2d2d;
|
||||
margin-bottom: 5px;
|
||||
border: 2px solid #454545;
|
||||
border-radius: 5px;
|
||||
transition: background-color 0.3s;
|
||||
color: #d4d4d4;
|
||||
}
|
||||
#fileList li:hover {
|
||||
background-color: #3e3e3e;
|
||||
}
|
||||
#fileList li.selected {
|
||||
border-color: #fabb00;
|
||||
}
|
||||
#fileContent {
|
||||
background-color: #1e1e1e;
|
||||
border: 1px solid #454545;
|
||||
padding: 10px;
|
||||
margin-top: 20px;
|
||||
white-space: pre-wrap;
|
||||
color: #d4d4d4;
|
||||
}
|
||||
.button-container {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
#refreshButton {
|
||||
padding: 4px 10px;
|
||||
min-width: 80px;
|
||||
border: none;
|
||||
font: inherit;
|
||||
color: #373030;
|
||||
border-radius: 10px;
|
||||
outline: none;
|
||||
text-decoration: none;
|
||||
cursor: default;
|
||||
font-weight: 400;
|
||||
background: #fff;
|
||||
box-shadow:
|
||||
0px 0px 1px #0000004d,
|
||||
0px 1px 1px #00000066;
|
||||
}
|
||||
#refreshButton:hover {
|
||||
/* hover MUST come before active */
|
||||
background: linear-gradient(#0000004d, #00000066);
|
||||
color: #fff;
|
||||
position: relative;
|
||||
}
|
||||
#refreshButton:active {
|
||||
background: linear-gradient(#4faefc, #006bff);
|
||||
color: #fff;
|
||||
position: relative;
|
||||
}
|
||||
.history-item {
|
||||
border: 3px solid black;
|
||||
border-radius: 5px;
|
||||
padding: 0px;
|
||||
/* padding-bottom: 5%; */
|
||||
margin-bottom: 15px;
|
||||
overflow-x: hidden;
|
||||
white-space: normal;
|
||||
/* overflow-x: auto; Enables horizontal scrolling */
|
||||
/* white-space: nowrap; Keeps content in a single line */
|
||||
max-height: 450px; /* Adjust as needed for 25 lines */
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
.shadow {
|
||||
height: 30px; /* Height of the shadow */
|
||||
background: linear-gradient(to bottom, transparent, rgba(0, 0, 0, 0.4));
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
pointer-events: none; /* Ensures the shadow doesn't interfere with interaction */
|
||||
display: none; /* Initially hidden */
|
||||
}
|
||||
.has-shadow .shadow {
|
||||
display: block;
|
||||
}
|
||||
.content-container {
|
||||
max-height: 400px; /* Adjust as needed */
|
||||
overflow-y: auto;
|
||||
position: relative;
|
||||
padding: 10px;
|
||||
}
|
||||
.content-container pre {
|
||||
white-space: pre-wrap; /* Wrap lines and preserve whitespace */
|
||||
overflow-wrap: break-word; /* Handle long words */
|
||||
}
|
||||
.container {
|
||||
max-width: 1000px;
|
||||
margin: 0 auto; /* Centers the container */
|
||||
padding: 20px; /* Optional: for some inner spacing */
|
||||
}
|
||||
.history-item.user {
|
||||
border-color: #1697e2;
|
||||
}
|
||||
.history-item.tool {
|
||||
border-color: #1483c3;
|
||||
}
|
||||
.history-item.system {
|
||||
border-color: #004b80;
|
||||
}
|
||||
.history-item.subroutine {
|
||||
border-color: #006b00;
|
||||
}
|
||||
.history-item.gold-patch {
|
||||
border-color: #fabb00;
|
||||
}
|
||||
.history-item.assistant {
|
||||
border-color: rgb(0, 0, 0);
|
||||
}
|
||||
.history-item.test-patch {
|
||||
border-color: #7373d9;
|
||||
}
|
||||
.history-item.evaluation-report {
|
||||
border-color: #35614b;
|
||||
}
|
||||
|
||||
/* filepath-tree stuff */
|
||||
.filepath {
|
||||
display: flex;
|
||||
flex-direction: column; /* Changes layout to one part per line */
|
||||
align-items: flex-start; /* Aligns parts to the start of the container */
|
||||
font-size: 16px;
|
||||
gap: 10px;
|
||||
padding: 5px;
|
||||
background-color: #2d2d2d;
|
||||
}
|
||||
.part {
|
||||
border: 1px solid #454545;
|
||||
white-space: nowrap; /* Prevents wrapping within parts */
|
||||
padding: 5px;
|
||||
background-color: #3e3e3e;
|
||||
border-radius: 5px;
|
||||
color: #d4d4d4;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
.fade-in {
|
||||
animation: fadeIn 1s ease-out;
|
||||
}
|
||||
|
||||
.trajectory-container {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.trajectory-item {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
margin-bottom: 0px;
|
||||
border: 3px solid #1697e2;
|
||||
border-radius: 5px;
|
||||
position: relative;
|
||||
background: #2d2d2d;
|
||||
min-height: fit-content; /* Added to ensure proper expansion */
|
||||
height: auto; /* Added to ensure proper expansion */
|
||||
overflow: visible; /* Changed from hidden/auto to visible */
|
||||
}
|
||||
|
||||
.trajectory-main {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0;
|
||||
}
|
||||
|
||||
.response-section,
|
||||
.observation-section,
|
||||
.messages-section {
|
||||
padding: 0 4px;
|
||||
position: relative;
|
||||
min-height: 30px;
|
||||
height: 400px;
|
||||
max-height: min-content;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
resize: vertical;
|
||||
overflow: hidden;
|
||||
width: calc(100% - 16px);
|
||||
}
|
||||
|
||||
.response-section {
|
||||
background-color: #252526;
|
||||
border-bottom: 1px solid #454545;
|
||||
color: #d4d4d4;
|
||||
}
|
||||
|
||||
.observation-section {
|
||||
background-color: #2d2d2d;
|
||||
color: #d4d4d4;
|
||||
}
|
||||
|
||||
.observation-images-section {
|
||||
padding: 4px;
|
||||
position: relative;
|
||||
min-height: 30px;
|
||||
height: 200px;
|
||||
max-height: 200px;
|
||||
background: #2d2d30;
|
||||
border-top: 1px solid #454545;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
resize: vertical;
|
||||
overflow: hidden;
|
||||
width: calc(100% - 16px);
|
||||
color: #d4d4d4;
|
||||
}
|
||||
|
||||
/* Add section headers */
|
||||
.response-section::before,
|
||||
.observation-section::before,
|
||||
.observation-images-section::before,
|
||||
.messages-section::before {
|
||||
content: attr(data-title);
|
||||
font-weight: bold;
|
||||
padding: 8px 12px; /* Keep this comfortable padding */
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
margin: 0 0 8px 0; /* Add some bottom margin to separate from content */
|
||||
border-bottom: 1px solid #454545;
|
||||
position: sticky;
|
||||
top: 0;
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
color: #d4d4d4;
|
||||
}
|
||||
|
||||
/* Scrollable content containers */
|
||||
.content-wrapper {
|
||||
overflow-y: auto;
|
||||
flex: 1;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
display: flex; /* Add flex display */
|
||||
flex-direction: column; /* Stack children vertically */
|
||||
}
|
||||
|
||||
.content-wrapper pre {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
flex: 1; /* Allow pre to fill the space */
|
||||
display: flex; /* Make pre a flex container */
|
||||
flex-direction: column; /* Stack children vertically */
|
||||
}
|
||||
|
||||
.content-wrapper pre code {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
flex: 1; /* Allow code to fill the space */
|
||||
}
|
||||
|
||||
.messages-section {
|
||||
padding: 4px;
|
||||
position: relative;
|
||||
min-height: 30px;
|
||||
height: 30px;
|
||||
max-height: min-content; /* Only expand to fit content */
|
||||
background: #252526;
|
||||
border-top: 1px solid #454545;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
resize: vertical;
|
||||
overflow: hidden;
|
||||
width: calc(100% - 16px);
|
||||
color: #d4d4d4;
|
||||
}
|
||||
|
||||
.messages-section::before {
|
||||
content: attr(data-title);
|
||||
font-weight: bold;
|
||||
padding: 8px 12px; /* Increased padding around title text */
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
margin: -8px -8px 5px -8px; /* Reduced bottom margin */
|
||||
border-bottom: 1px solid #454545;
|
||||
position: sticky;
|
||||
top: -8px;
|
||||
width: calc(100% + 16px);
|
||||
box-sizing: border-box;
|
||||
color: #d4d4d4;
|
||||
}
|
||||
|
||||
.messages-toggle {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.execution-time {
|
||||
font-size: 0.8em;
|
||||
color: #666;
|
||||
text-align: right;
|
||||
padding: 5px;
|
||||
border-top: 1px solid #ddd;
|
||||
}
|
||||
|
||||
/* Add a visual indicator for resizable areas */
|
||||
.response-section:hover,
|
||||
.observation-section:hover,
|
||||
.observation-images-section:hover,
|
||||
.messages-section.expanded:hover {
|
||||
outline: 1px dashed #999;
|
||||
}
|
||||
|
||||
.demo-message {
|
||||
color: #2ecc71;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
/* Syntax highlighting overrides for dark theme */
|
||||
.hljs {
|
||||
background: #1e1e1e !important;
|
||||
color: #d4d4d4 !important;
|
||||
}
|
||||
|
||||
.hljs-keyword,
|
||||
.hljs-selector-tag,
|
||||
.hljs-built_in,
|
||||
.hljs-name,
|
||||
.hljs-tag {
|
||||
color: #569cd6 !important; /* Blue */
|
||||
}
|
||||
|
||||
.hljs-string,
|
||||
.hljs-title,
|
||||
.hljs-section,
|
||||
.hljs-attribute,
|
||||
.hljs-literal,
|
||||
.hljs-template-tag,
|
||||
.hljs-template-variable,
|
||||
.hljs-type {
|
||||
color: #ce9178 !important; /* Orange/Pink */
|
||||
}
|
||||
|
||||
.hljs-comment,
|
||||
.hljs-quote {
|
||||
color: #6a9955 !important; /* Green */
|
||||
}
|
||||
|
||||
.hljs-number,
|
||||
.hljs-regexp,
|
||||
.hljs-symbol,
|
||||
.hljs-variable,
|
||||
.hljs-template-variable,
|
||||
.hljs-link,
|
||||
.hljs-selector-attr,
|
||||
.hljs-selector-pseudo {
|
||||
color: #b5cea8 !important; /* Light green */
|
||||
}
|
||||
|
||||
label {
|
||||
color: #d4d4d4; /* Matches the theme's text color */
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 5px;
|
||||
}
|
||||
|
||||
input[type="checkbox"] {
|
||||
margin: 0;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
/* observation images styles */
|
||||
.observation-images {
|
||||
margin-top: 10px;
|
||||
padding: 10px 0;
|
||||
}
|
||||
|
||||
.observation-image-container {
|
||||
margin-bottom: 15px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.observation-image {
|
||||
max-width: 100%;
|
||||
max-height: 400px;
|
||||
border: 1px solid #454545;
|
||||
border-radius: 5px;
|
||||
background-color: #1e1e1e;
|
||||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3);
|
||||
cursor: pointer;
|
||||
transition: transform 0.2s ease;
|
||||
}
|
||||
|
||||
.observation-image:hover {
|
||||
transform: scale(1.02);
|
||||
border-color: #1697e2;
|
||||
}
|
||||
|
||||
.image-caption {
|
||||
margin-top: 5px;
|
||||
font-size: 0.9em;
|
||||
color: #b3b3b3;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.observation-image.expanded {
|
||||
position: fixed;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%) scale(1);
|
||||
max-width: 90vw;
|
||||
max-height: 90vh;
|
||||
z-index: 1000;
|
||||
border: 2px solid #1697e2;
|
||||
box-shadow: 0 0 20px rgba(0, 0, 0, 0.8);
|
||||
}
|
||||
|
||||
.image-overlay {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background-color: rgba(0, 0, 0, 0.8);
|
||||
z-index: 999;
|
||||
display: none;
|
||||
}
|
||||
|
||||
.image-overlay.active {
|
||||
display: block;
|
||||
}
|
||||
0
.agent/vendor/mini-swe/sweagent/run/__init__.py
vendored
Normal file
158
.agent/vendor/mini-swe/sweagent/run/_progress.py
vendored
Normal file
@@ -0,0 +1,158 @@
|
||||
"""This module contains an auxiliary class for rendering progress of the batch run."""
|
||||
|
||||
import collections
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import yaml
|
||||
from rich.console import Group
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskID,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
)
|
||||
from rich.table import Table
|
||||
|
||||
from sweagent.agent.models import GLOBAL_STATS
|
||||
|
||||
|
||||
def _shorten_str(s: str, max_len: int, shorten_left=False) -> str:
|
||||
if not shorten_left:
|
||||
s = s[: max_len - 3] + "..." if len(s) > max_len else s
|
||||
else:
|
||||
s = "..." + s[-max_len + 3 :] if len(s) > max_len else s
|
||||
return f"{s:<{max_len}}"
|
||||
|
||||
|
||||
class RunBatchProgressManager:
|
||||
def __init__(
|
||||
self,
|
||||
num_instances: int,
|
||||
yaml_report_path: Path | None = None,
|
||||
):
|
||||
"""This class manages a progress bar/UI for run-batch
|
||||
|
||||
Args:
|
||||
num_instances: Number of task instances
|
||||
yaml_report_path: Path to save a yaml report of the instances and their exit statuses
|
||||
"""
|
||||
|
||||
self._spinner_tasks: dict[str, TaskID] = {}
|
||||
"""We need to map instance ID to the task ID that is used by the rich progress bar."""
|
||||
|
||||
self._lock = Lock()
|
||||
|
||||
self._instances_by_exit_status = collections.defaultdict(list)
|
||||
self._main_progress_bar = Progress(
|
||||
SpinnerColumn(spinner_name="dots2"),
|
||||
TextColumn("[progress.description]{task.description} (${task.fields[total_cost]})"),
|
||||
BarColumn(),
|
||||
MofNCompleteColumn(),
|
||||
TaskProgressColumn(),
|
||||
TimeElapsedColumn(),
|
||||
TextColumn("[cyan]eta:[/cyan]"),
|
||||
TimeRemainingColumn(),
|
||||
# Wait 5 min before estimating speed
|
||||
speed_estimate_period=60 * 5,
|
||||
)
|
||||
self._task_progress_bar = Progress(
|
||||
SpinnerColumn(spinner_name="dots2"),
|
||||
TextColumn("{task.fields[instance_id]}"),
|
||||
TextColumn("{task.fields[status]}"),
|
||||
TimeElapsedColumn(),
|
||||
)
|
||||
"""Task progress bar for individual instances. There's only one progress bar
|
||||
with one task for each instance.
|
||||
"""
|
||||
|
||||
self._main_task_id = self._main_progress_bar.add_task(
|
||||
"[cyan]Overall Progress", total=num_instances, total_cost=0
|
||||
)
|
||||
|
||||
self.render_group = Group(Table(), self._task_progress_bar, self._main_progress_bar)
|
||||
self._yaml_report_path = yaml_report_path
|
||||
|
||||
@property
|
||||
def n_completed(self) -> int:
|
||||
return sum(len(instances) for instances in self._instances_by_exit_status.values())
|
||||
|
||||
def update_exit_status_table(self):
|
||||
# We cannot update the existing table, so we need to create a new one and
|
||||
# assign it back to the render group.
|
||||
t = Table()
|
||||
t.add_column("Exit Status")
|
||||
t.add_column("Count", justify="right", style="bold cyan")
|
||||
t.add_column("Most recent instances")
|
||||
t.show_header = False
|
||||
with self._lock:
|
||||
t.show_header = True
|
||||
# Sort by number of instances in descending order
|
||||
sorted_items = sorted(self._instances_by_exit_status.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
for status, instances in sorted_items:
|
||||
instances_str = _shorten_str(", ".join(reversed(instances)), 55)
|
||||
t.add_row(status, str(len(instances)), instances_str)
|
||||
assert self.render_group is not None
|
||||
self.render_group.renderables[0] = t
|
||||
|
||||
def _update_total_costs(self) -> None:
|
||||
with self._lock:
|
||||
self._main_progress_bar.update(self._main_task_id, total_cost=f"{GLOBAL_STATS.total_cost:.2f}")
|
||||
|
||||
def update_instance_status(self, instance_id: str, message: str):
|
||||
assert self._task_progress_bar is not None
|
||||
assert self._main_progress_bar is not None
|
||||
with self._lock:
|
||||
self._task_progress_bar.update(
|
||||
self._spinner_tasks[instance_id],
|
||||
status=_shorten_str(message, 30),
|
||||
instance_id=_shorten_str(instance_id, 25, shorten_left=True),
|
||||
)
|
||||
self._update_total_costs()
|
||||
|
||||
def on_instance_start(self, instance_id: str):
|
||||
with self._lock:
|
||||
self._spinner_tasks[instance_id] = self._task_progress_bar.add_task(
|
||||
description=f"Task {instance_id}",
|
||||
status="Task initialized",
|
||||
total=None,
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
def on_instance_end(self, instance_id: str, exit_status: str | None) -> None:
|
||||
self._instances_by_exit_status[exit_status].append(instance_id)
|
||||
with self._lock:
|
||||
self._task_progress_bar.remove_task(self._spinner_tasks[instance_id])
|
||||
self._main_progress_bar.update(TaskID(0), advance=1)
|
||||
self.update_exit_status_table()
|
||||
self._update_total_costs()
|
||||
if self._yaml_report_path is not None:
|
||||
self._save_overview_data_yaml(self._yaml_report_path)
|
||||
|
||||
def on_uncaught_exception(self, instance_id: str, exception: Exception) -> None:
|
||||
self.on_instance_end(instance_id, f"Uncaught {type(exception).__name__}")
|
||||
|
||||
def print_report(self) -> None:
|
||||
"""Print complete list of instances and their exit statuses."""
|
||||
for status, instances in self._instances_by_exit_status.items():
|
||||
print(f"{status}: {len(instances)}")
|
||||
for instance in instances:
|
||||
print(f" {instance}")
|
||||
|
||||
def _get_overview_data(self) -> dict:
|
||||
"""Get data like exit statuses, total costs, etc."""
|
||||
return {
|
||||
# convert defaultdict to dict because of serialization
|
||||
"instances_by_exit_status": dict(self._instances_by_exit_status),
|
||||
"total_cost": GLOBAL_STATS.total_cost,
|
||||
}
|
||||
|
||||
def _save_overview_data_yaml(self, path: Path) -> None:
|
||||
"""Save a yaml report of the instances and their exit statuses."""
|
||||
with self._lock:
|
||||
path.write_text(yaml.dump(self._get_overview_data(), indent=4))
|
||||
449
.agent/vendor/mini-swe/sweagent/run/batch_instances.py
vendored
Normal file
@@ -0,0 +1,449 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from swerex.deployment.config import (
|
||||
DeploymentConfig,
|
||||
DockerDeploymentConfig,
|
||||
DummyDeploymentConfig,
|
||||
LocalDeploymentConfig,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from sweagent.agent.problem_statement import (
|
||||
ProblemStatementConfig,
|
||||
SWEBenchMultimodalProblemStatement,
|
||||
TextProblemStatement,
|
||||
)
|
||||
from sweagent.environment.repo import GithubRepoConfig, LocalRepoConfig, PreExistingRepoConfig, SWESmithRepoConfig
|
||||
from sweagent.environment.swe_env import EnvironmentConfig
|
||||
from sweagent.utils.files import load_file
|
||||
from sweagent.utils.github import _is_repo_private
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
logger = get_logger("swea-config", emoji="🔧")
|
||||
|
||||
|
||||
class AbstractInstanceSource(ABC):
|
||||
"""Anything that adheres to this standard can be used to load instances."""
|
||||
|
||||
@abstractmethod
|
||||
def get_instance_configs(self) -> list[EnvironmentConfig]: ...
|
||||
|
||||
|
||||
class BatchInstance(BaseModel):
|
||||
"""A single instance in a batch of instances.
|
||||
This specifies both the environment configuration and the problem statement.
|
||||
"""
|
||||
|
||||
env: EnvironmentConfig
|
||||
problem_statement: ProblemStatementConfig
|
||||
|
||||
|
||||
def _slice_spec_to_slice(slice_spec: str) -> slice:
|
||||
if slice_spec == "":
|
||||
return slice(None)
|
||||
parts = slice_spec.split(":")
|
||||
values = [None if p == "" else int(p) for p in parts]
|
||||
if len(parts) == 1:
|
||||
return slice(values[0])
|
||||
if len(parts) == 2:
|
||||
return slice(values[0], values[1])
|
||||
if len(parts) == 3:
|
||||
return slice(values[0], values[1], values[2])
|
||||
msg = (
|
||||
f"Invalid slice specification: {slice_spec!r}. "
|
||||
"Here's the expected format: stop or start:stop or start:stop:step "
|
||||
"(i.e., it behaves exactly like python's list slicing `list[slice]`)."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def _filter_batch_items(
|
||||
instances: list[BatchInstance], *, filter_: str, slice_: str = "", shuffle: bool = False
|
||||
) -> list[BatchInstance]:
|
||||
if shuffle:
|
||||
instances = sorted(instances.copy(), key=lambda x: x.problem_statement.id)
|
||||
random.seed(42)
|
||||
random.shuffle(instances)
|
||||
before_filter = len(instances)
|
||||
instances = [instance for instance in instances if re.match(filter_, instance.problem_statement.id)]
|
||||
after_filter = len(instances)
|
||||
if before_filter != after_filter:
|
||||
logger.info("Instance filter: %d -> %d instances", before_filter, after_filter)
|
||||
if slice_:
|
||||
instances = instances[_slice_spec_to_slice(slice_)]
|
||||
after_slice = len(instances)
|
||||
if before_filter != after_slice:
|
||||
logger.info("Instance slice: %d -> %d instances", before_filter, after_slice)
|
||||
return instances
|
||||
|
||||
|
||||
class SimpleBatchInstance(BaseModel):
|
||||
"""A simple way to configure a single instance in a batch of instances that all
|
||||
use similar deployment configurations.
|
||||
|
||||
Predominantly used for benchmarking purposes. Assumes that the repository is already
|
||||
present in the docker container.
|
||||
"""
|
||||
|
||||
image_name: str
|
||||
problem_statement: str
|
||||
instance_id: str
|
||||
repo_name: str = ""
|
||||
"""Specifies the repository to use. If empty, no repository is used.
|
||||
If the string does not contain a slash, it is interpreted as an already existing repository at the root
|
||||
of the docker container. If it contains the word "github", it is interpreted as a github repository.
|
||||
Else, it is interpreted as a local repository.
|
||||
"""
|
||||
base_commit: str = "HEAD"
|
||||
"""Used to reset repo."""
|
||||
extra_fields: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Any additional data to be added to the instance.
|
||||
This data will be available when formatting prompt templates.
|
||||
"""
|
||||
|
||||
# Ignore instead of allow because they should be added as `extra_fields`
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
def to_full_batch_instance(self, deployment: DeploymentConfig) -> BatchInstance:
|
||||
"""Merge the deployment options into the `SimpleBatchInstance` object to get a full `BatchInstance`."""
|
||||
# Very important: Make a copy of the deployment config because it will be shared among instances!!!
|
||||
deployment = deployment.model_copy(deep=True)
|
||||
|
||||
if "issue_images" in self.extra_fields:
|
||||
problem_statement = SWEBenchMultimodalProblemStatement(
|
||||
text=self.problem_statement,
|
||||
issue_images=self.extra_fields.pop("issue_images"),
|
||||
id=self.instance_id,
|
||||
extra_fields=self.extra_fields,
|
||||
)
|
||||
else:
|
||||
problem_statement = TextProblemStatement(
|
||||
text=self.problem_statement, id=self.instance_id, extra_fields=self.extra_fields
|
||||
)
|
||||
|
||||
if not self.repo_name:
|
||||
repo = None
|
||||
elif "github" in self.repo_name:
|
||||
repo = GithubRepoConfig(github_url=self.repo_name, base_commit=self.base_commit)
|
||||
elif "/" not in self.repo_name:
|
||||
repo = PreExistingRepoConfig(repo_name=self.repo_name, base_commit=self.base_commit)
|
||||
else:
|
||||
repo = LocalRepoConfig(path=Path(self.repo_name), base_commit=self.base_commit)
|
||||
if isinstance(deployment, LocalDeploymentConfig):
|
||||
if self.image_name:
|
||||
msg = "Local deployment does not support image_name"
|
||||
raise ValueError(msg)
|
||||
return BatchInstance(
|
||||
env=EnvironmentConfig(deployment=deployment, repo=repo), problem_statement=problem_statement
|
||||
)
|
||||
if isinstance(deployment, DummyDeploymentConfig):
|
||||
return BatchInstance(
|
||||
env=EnvironmentConfig(deployment=deployment, repo=repo), problem_statement=problem_statement
|
||||
)
|
||||
|
||||
deployment.image = self.image_name # type: ignore
|
||||
|
||||
if isinstance(deployment, DockerDeploymentConfig) and deployment.python_standalone_dir is None:
|
||||
# Note: you can disable this by setting python_standalone_dir to ""
|
||||
deployment.python_standalone_dir = "/root" # type: ignore
|
||||
|
||||
return BatchInstance(
|
||||
env=EnvironmentConfig(deployment=deployment, repo=repo), problem_statement=problem_statement
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def handle_legacy_id(cls, data):
|
||||
# Handling compatibility with swe-agent <= 1.0.1
|
||||
if isinstance(data, dict):
|
||||
if "id" in data and "instance_id" not in data:
|
||||
data["instance_id"] = data["id"]
|
||||
data.pop("id")
|
||||
return data
|
||||
|
||||
# todo: Maybe populate extra fields?
|
||||
@classmethod
|
||||
def from_swe_bench(cls, instance: dict[str, Any]) -> Self:
|
||||
"""Convert instances from the classical SWE-bench dataset to the `SimpleBatchInstance` format."""
|
||||
iid = instance["instance_id"]
|
||||
image_name = instance.get("image_name", None)
|
||||
if image_name is None:
|
||||
# Docker doesn't allow double underscore, so we replace them with a magic token
|
||||
id_docker_compatible = iid.replace("__", "_1776_")
|
||||
image_name = f"docker.io/swebench/sweb.eval.x86_64.{id_docker_compatible}:latest".lower()
|
||||
extra_fields = {}
|
||||
if "image_assets" in instance:
|
||||
issue_images = json.loads(instance["image_assets"])["problem_statement"]
|
||||
extra_fields["issue_images"] = issue_images
|
||||
return cls(
|
||||
image_name=image_name,
|
||||
problem_statement=instance["problem_statement"],
|
||||
instance_id=iid,
|
||||
repo_name="testbed",
|
||||
base_commit=instance["base_commit"],
|
||||
extra_fields=extra_fields,
|
||||
)
|
||||
|
||||
|
||||
class InstancesFromFile(BaseModel, AbstractInstanceSource):
|
||||
"""Load instances from a file."""
|
||||
|
||||
path: Path
|
||||
filter: str = ".*"
|
||||
"""Regular expression to filter the instances by instance id."""
|
||||
slice: str = ""
|
||||
"""Select only a slice of the instances (after filtering by `filter`).
|
||||
Possible values are stop or start:stop or start:stop:step
|
||||
(i.e., it behaves exactly like python's list slicing `list[slice]`).
|
||||
"""
|
||||
shuffle: bool = False
|
||||
"""Shuffle the instances (before filtering and slicing)."""
|
||||
|
||||
deployment: DeploymentConfig = Field(
|
||||
default_factory=lambda: DockerDeploymentConfig(image="python:3.11"),
|
||||
description="Deployment options.",
|
||||
)
|
||||
"""Note that the image_name option is overwritten by the images specified in the task instances."""
|
||||
|
||||
simple: Literal[True] = True
|
||||
"""Convenience discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
type: Literal["file"] = "file"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
def get_instance_configs(self) -> list[BatchInstance]:
|
||||
instance_dicts = load_file(self.path)
|
||||
simple_instances = [SimpleBatchInstance.model_validate(instance_dict) for instance_dict in instance_dicts]
|
||||
instances = [instance.to_full_batch_instance(self.deployment) for instance in simple_instances]
|
||||
return _filter_batch_items(instances, filter_=self.filter, slice_=self.slice, shuffle=self.shuffle)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.path.stem
|
||||
|
||||
|
||||
class InstancesFromHuggingFace(BaseModel, AbstractInstanceSource):
|
||||
"""Load instances from HuggingFace."""
|
||||
|
||||
dataset_name: str
|
||||
"""Name of the HuggingFace dataset. Same as when using `datasets.load_dataset`."""
|
||||
split: str = "dev"
|
||||
filter: str = ".*"
|
||||
"""Regular expression to filter the instances by instance id."""
|
||||
slice: str = ""
|
||||
"""Select only a slice of the instances (after filtering by `filter`).
|
||||
Possible values are stop or start:stop or start:stop:step.
|
||||
(i.e., it behaves exactly like python's list slicing `list[slice]`).
|
||||
"""
|
||||
shuffle: bool = False
|
||||
"""Shuffle the instances (before filtering and slicing)."""
|
||||
|
||||
deployment: DeploymentConfig = Field(
|
||||
default_factory=lambda: DockerDeploymentConfig(image="python:3.11"),
|
||||
)
|
||||
"""Deployment configuration. Note that the `image_name` option is overwritten by the images specified in the task instances.
|
||||
"""
|
||||
type: Literal["huggingface"] = "huggingface"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
def get_instance_configs(self) -> list[BatchInstance]:
|
||||
from datasets import load_dataset
|
||||
|
||||
ds: list[dict[str, Any]] = load_dataset(self.dataset_name, split=self.split) # type: ignore
|
||||
simple_instances: list[SimpleBatchInstance] = [SimpleBatchInstance.model_validate(instance) for instance in ds]
|
||||
instances = [instance.to_full_batch_instance(self.deployment) for instance in simple_instances]
|
||||
return _filter_batch_items(instances, filter_=self.filter, slice_=self.slice, shuffle=self.shuffle)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
ds_name = "".join(l for l in self.dataset_name if l.isalnum() or l in ["-", "_"])
|
||||
return f"{ds_name}_{self.split}"
|
||||
|
||||
|
||||
class SWEBenchInstances(BaseModel, AbstractInstanceSource):
|
||||
"""Load instances from SWE-bench."""
|
||||
|
||||
subset: Literal["lite", "verified", "full", "multimodal", "multilingual"] = "lite"
|
||||
"""Subset of swe-bench to use"""
|
||||
|
||||
# IMPORTANT: Do not call this `path`, because then if people do not specify instance.type,
|
||||
# it might be resolved to ExpertInstancesFromFile or something like that.
|
||||
path_override: str | Path | None = None
|
||||
"""Allow to specify a different huggingface dataset name or path to a huggingface
|
||||
dataset. This will override the automatic path set by `subset`.
|
||||
"""
|
||||
|
||||
split: Literal["dev", "test"] = "dev"
|
||||
|
||||
deployment: DeploymentConfig = Field(
|
||||
default_factory=lambda: DockerDeploymentConfig(image="python:3.11"),
|
||||
)
|
||||
"""Deployment configuration. Note that the image_name option is overwritten by the images specified in the task instances.
|
||||
"""
|
||||
|
||||
type: Literal["swe_bench"] = "swe_bench"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
filter: str = ".*"
|
||||
"""Regular expression to filter the instances by instance id."""
|
||||
slice: str = ""
|
||||
"""Select only a slice of the instances (after filtering by `filter`).
|
||||
Possible values are stop or start:stop or start:stop:step.
|
||||
(i.e., it behaves exactly like python's list slicing `list[slice]`).
|
||||
"""
|
||||
shuffle: bool = False
|
||||
"""Shuffle the instances (before filtering and slicing)."""
|
||||
|
||||
evaluate: bool = False
|
||||
"""Run sb-cli to evaluate"""
|
||||
|
||||
def _get_dataset_path(self) -> str:
|
||||
if self.path_override is not None:
|
||||
return str(self.path_override)
|
||||
dataset_mapping = {
|
||||
"full": "princeton-nlp/SWE-Bench",
|
||||
"verified": "princeton-nlp/SWE-Bench_Verified",
|
||||
"lite": "princeton-nlp/SWE-Bench_Lite",
|
||||
"multimodal": "princeton-nlp/SWE-Bench_Multimodal",
|
||||
"multilingual": "swe-bench/SWE-Bench_Multilingual",
|
||||
}
|
||||
|
||||
if self.subset not in dataset_mapping:
|
||||
msg = f"Unsupported subset: {self.subset}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return dataset_mapping[self.subset]
|
||||
|
||||
def get_instance_configs(self) -> list[BatchInstance]:
|
||||
from datasets import load_dataset
|
||||
|
||||
ds: list[dict[str, Any]] = load_dataset(self._get_dataset_path(), split=self.split) # type: ignore
|
||||
|
||||
if isinstance(self.deployment, DockerDeploymentConfig):
|
||||
self.deployment.platform = "linux/amd64"
|
||||
|
||||
instances = [
|
||||
SimpleBatchInstance.from_swe_bench(instance).to_full_batch_instance(self.deployment) for instance in ds
|
||||
]
|
||||
return _filter_batch_items(instances, filter_=self.filter, slice_=self.slice, shuffle=self.shuffle)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return f"swe_bench_{self.subset}_{self.split}"
|
||||
|
||||
|
||||
class ExpertInstancesFromFile(BaseModel, AbstractInstanceSource):
|
||||
"""Load instances from a file. The difference to `InstancesFromFile` is that the instances are configured as full
|
||||
`EnvironmentInstanceConfig` objects, i.e., we could specify separate deployment configurations etc.
|
||||
"""
|
||||
|
||||
path: Path
|
||||
filter: str = ".*"
|
||||
"""Regular expression to filter the instances by instance id."""
|
||||
slice: str = ""
|
||||
"""Select only a slice of the instances (after filtering by `filter`).
|
||||
Possible values are stop or start:stop or start:stop:step.
|
||||
(i.e., it behaves exactly like python's list slicing `list[slice]`).
|
||||
"""
|
||||
shuffle: bool = False
|
||||
"""Shuffle the instances (before filtering and slicing)."""
|
||||
|
||||
type: Literal["expert_file"] = "expert_file"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
def get_instance_configs(self) -> list[BatchInstance]:
|
||||
instance_dicts = load_file(self.path)
|
||||
instances = [BatchInstance.model_validate(instance_dict) for instance_dict in instance_dicts]
|
||||
return _filter_batch_items(instances, filter_=self.filter, slice_=self.slice, shuffle=self.shuffle)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.path.stem
|
||||
|
||||
|
||||
class SWESmithInstances(BaseModel, AbstractInstanceSource):
|
||||
"""Load instances from SWE-smith."""
|
||||
|
||||
path: Path
|
||||
|
||||
deployment: DeploymentConfig = Field(
|
||||
default_factory=lambda: DockerDeploymentConfig(image="python:3.11"),
|
||||
)
|
||||
"""Deployment configuration. Note that the image_name option is overwritten by the images specified in the task instances.
|
||||
"""
|
||||
|
||||
filter: str = ".*"
|
||||
"""Regular expression to filter the instances by instance id."""
|
||||
slice: str = ""
|
||||
"""Select only a slice of the instances (after filtering by `filter`).
|
||||
Possible values are stop or start:stop or start:stop:step.
|
||||
(i.e., it behaves exactly like python's list slicing `list[slice]`).
|
||||
"""
|
||||
shuffle: bool = False
|
||||
"""Shuffle the instances (before filtering and slicing)."""
|
||||
|
||||
type: Literal["swesmith"] = "swesmith"
|
||||
"""Discriminator for (de)serialization/CLI. Do not change."""
|
||||
|
||||
def get_instance_configs(self) -> list[BatchInstance]:
|
||||
github_token = os.getenv("GITHUB_TOKEN", "")
|
||||
|
||||
instance_dicts = load_file(self.path)
|
||||
instances = []
|
||||
|
||||
for instance_dict in instance_dicts:
|
||||
deployment = self.deployment.model_copy(deep=True)
|
||||
deployment.image = instance_dict["image_name"] # type: ignore
|
||||
|
||||
if isinstance(deployment, DockerDeploymentConfig) and deployment.python_standalone_dir is None:
|
||||
deployment.python_standalone_dir = "/root" # type: ignore
|
||||
|
||||
instance_id = instance_dict["instance_id"]
|
||||
repo_field = instance_dict.get("repo", "")
|
||||
|
||||
mirror_url = ""
|
||||
if repo_field and _is_repo_private(repo_field, github_token):
|
||||
if not github_token:
|
||||
msg = (
|
||||
f"Repo '{repo_field}' appears to be private but GITHUB_TOKEN is not set. "
|
||||
"Set GITHUB_TOKEN with 'repo' scope to access private repositories."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
mirror_url = f"https://github.com/{repo_field}.git"
|
||||
|
||||
repo = SWESmithRepoConfig(
|
||||
repo_name="testbed",
|
||||
base_commit=instance_id,
|
||||
mirror_url=mirror_url,
|
||||
)
|
||||
|
||||
problem_statement = TextProblemStatement(
|
||||
text=instance_dict.get("problem_statement", ""),
|
||||
id=instance_id,
|
||||
extra_fields={"fail_to_pass": instance_dict.get("FAIL_TO_PASS", [])},
|
||||
)
|
||||
|
||||
instances.append(
|
||||
BatchInstance(
|
||||
env=EnvironmentConfig(deployment=deployment, repo=repo),
|
||||
problem_statement=problem_statement,
|
||||
)
|
||||
)
|
||||
|
||||
return _filter_batch_items(instances, filter_=self.filter, slice_=self.slice, shuffle=self.shuffle)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return f"swesmith_{self.path.stem}"
|
||||
|
||||
|
||||
BatchInstanceSourceConfig = (
|
||||
InstancesFromHuggingFace | InstancesFromFile | SWEBenchInstances | ExpertInstancesFromFile | SWESmithInstances
|
||||
)
|
||||
387
.agent/vendor/mini-swe/sweagent/run/common.py
vendored
Normal file
@@ -0,0 +1,387 @@
|
||||
"""Common functionality for the run scripts."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from types import UnionType
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
from pydantic_settings import BaseSettings, CliApp, SettingsError
|
||||
from rich import print as rich_print
|
||||
from rich.panel import Panel
|
||||
|
||||
from sweagent import CONFIG_DIR
|
||||
from sweagent.types import AgentInfo, AgentRunResult
|
||||
from sweagent.utils.log import get_logger
|
||||
from sweagent.utils.serialization import merge_nested_dicts
|
||||
|
||||
|
||||
def _shorten_strings(data, *, max_length=30):
|
||||
"""
|
||||
Recursively shortens all strings in a nested data structure to a maximum length.
|
||||
|
||||
Args:
|
||||
data: The nested data structure (dicts, lists, and strings).
|
||||
max_length: The maximum length for strings.
|
||||
|
||||
Returns:
|
||||
The modified data structure with shortened strings.
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
# Shorten the string if it exceeds the max length
|
||||
data = data.replace("\n", "\\n")
|
||||
return data[: max_length - 3] + "..."
|
||||
elif isinstance(data, list):
|
||||
# Recursively process each item in the list
|
||||
return [_shorten_strings(item, max_length=max_length) for item in data]
|
||||
elif isinstance(data, dict):
|
||||
# Recursively process each value in the dictionary
|
||||
return {key: _shorten_strings(value, max_length=max_length) for key, value in data.items()}
|
||||
else:
|
||||
# Return the data as is if it's neither a string, list, nor dict
|
||||
return data
|
||||
|
||||
|
||||
_VALIDATION_ERROR_HELP_TEXT = """
|
||||
The following errors are raised by Pydantic, trying to instantiate the configuration based on
|
||||
the merged configuration dictionary [bold](see above)[/bold].
|
||||
|
||||
Every new indented block corresponds to a different error from Pydantic.
|
||||
The first line of each block is the attribute that failed validation, the following lines are the error messages.
|
||||
|
||||
If you see many lines of errors, there are probably different ways to instantiate the same object (a union type).
|
||||
For example, there are different deployments with different options each. Pydantic is then trying
|
||||
one after the other and reporting the failures for each of them.
|
||||
More on union types: [link=https://swe-agent.com/latest/usage/cl_tutorial/#union-types]https://swe-agent.com/latest/usage/cl_tutorial/#union-types[/link]
|
||||
"""
|
||||
|
||||
_SETTING_ERROR_HINTS = """
|
||||
[red][bold]Hints:[/bold][/red]
|
||||
Run `sweagent <subcommand> --help` for usage examples.
|
||||
|
||||
[red][bold]Common mistakes:[/bold][/red]
|
||||
- You used dashes instead of underscores (wrong: `--num-workers`, correct: `--num_workers`).
|
||||
- You forgot about part of the hierarchy (wrong: `--model.name`, correct: `--agent.model.name`).
|
||||
"""
|
||||
|
||||
|
||||
class AutoCorrectSuggestion:
|
||||
def __init__(
|
||||
self, original: str, alternative: str = "", *, condition: Callable | None = None, help: str | None = None
|
||||
):
|
||||
self.original = original
|
||||
self.alternative = alternative
|
||||
self.condition = condition
|
||||
self.help = help
|
||||
if self.help and self.alternative:
|
||||
msg = "Cannot set both help and alternative"
|
||||
raise ValueError(msg)
|
||||
|
||||
def show(self, args: list[str]) -> bool:
|
||||
no_equal = []
|
||||
for arg in args:
|
||||
if "=" in arg:
|
||||
no_equal.extend(arg.split("="))
|
||||
else:
|
||||
no_equal.append(arg)
|
||||
if self.condition is not None:
|
||||
return self.condition(no_equal)
|
||||
return f"--{self.original}" in no_equal
|
||||
|
||||
def format(self) -> str:
|
||||
if self.help:
|
||||
return self.help
|
||||
return f"You wrote [red]--{self.original}[/red]. Did you mean [green]--{self.alternative}[/green]?"
|
||||
|
||||
|
||||
class ConfigHelper:
|
||||
"""Produce easy-to-read help text from pydantic setting objects."""
|
||||
|
||||
def _get_type_name(self, item: Any, full: bool = False):
|
||||
"""Given a config type, return a string that is either the full name or just the class name."""
|
||||
full_name = str(item).removeprefix("<class '").removesuffix("'>")
|
||||
if full:
|
||||
return full_name
|
||||
return full_name.split(".")[-1]
|
||||
|
||||
def _get_value_help_string(self, item: Any, description: str | None):
|
||||
"""Given an item, document it"""
|
||||
if hasattr(item, "model_fields"):
|
||||
# It's a pydantic config class
|
||||
full_name = self._get_type_name(item, full=True)
|
||||
name = self._get_type_name(item)
|
||||
out = f"[green]{name}[/green]\n"
|
||||
if description:
|
||||
out += f" {description}\n"
|
||||
out += f" Run [green]--help_option {full_name}[/green] for more info"
|
||||
return out
|
||||
if isinstance(item, UnionType):
|
||||
name = self._get_type_name(item)
|
||||
out = ""
|
||||
if description:
|
||||
out += f" {description}\n"
|
||||
out += " This config item can be one of the following things (run [green]--help_option <name>[/green] for more info):\n"
|
||||
things = str(item).split("|")
|
||||
for thing in things:
|
||||
out += f" [green]{thing.strip()}[/green]\n"
|
||||
return out.strip()
|
||||
return self._get_type_name(item)
|
||||
|
||||
def get_help(self, config_type: type[BaseSettings]) -> str:
|
||||
lines = []
|
||||
for name, field_info in config_type.model_fields.items():
|
||||
line = f"[green][bold]{name}[/bold][/green]: "
|
||||
line += self._get_value_help_string(field_info.annotation, field_info.description)
|
||||
lines.append(line)
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
def _nested_dict():
|
||||
"""Helper function to create nested dictionaries."""
|
||||
return defaultdict(_nested_dict)
|
||||
|
||||
|
||||
def _parse_args_to_nested_dict(args):
|
||||
"""Parse the command-line arguments into a nested dictionary."""
|
||||
result = _nested_dict()
|
||||
|
||||
i = 0
|
||||
while i < len(args):
|
||||
arg = args[i]
|
||||
if not arg.startswith("--"):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Handle --key=value format
|
||||
if "=" in arg:
|
||||
key, value = arg[2:].split("=", 1)
|
||||
# Handle --key value format
|
||||
else:
|
||||
key = arg[2:]
|
||||
i += 1
|
||||
if i >= len(args):
|
||||
break
|
||||
value = args[i]
|
||||
|
||||
# Convert value to int if possible
|
||||
value = int(value) if value.isdigit() else value
|
||||
|
||||
# Build nested dict structure
|
||||
keys = key.split(".")
|
||||
current = result
|
||||
for k in keys[:-1]:
|
||||
current = current[k]
|
||||
current[keys[-1]] = value
|
||||
|
||||
i += 1
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# todo: Parameterize type hints
|
||||
class BasicCLI:
|
||||
def __init__(
|
||||
self,
|
||||
config_type: type[BaseSettings],
|
||||
*,
|
||||
default_settings: bool = True,
|
||||
help_text: str | None = None,
|
||||
default_config_file: Path = CONFIG_DIR / "default.yaml",
|
||||
):
|
||||
"""This class implements a basic CLI for SWE-agent. It is based on pydantic-settings, i.e., takes
|
||||
a `BaseSettings` object. In principle you could just initialize these via `pydantic-settings`'s `CliApp.run`,
|
||||
however, we also want to add a `--config` option to load additional config files and some other things.
|
||||
We also try to improve a bit on the pydantic error messages in here.
|
||||
|
||||
Args:
|
||||
config_type: The type of the configuration object to instantiate.
|
||||
default_settings: Whether to load the default settings.
|
||||
help_text: If given, this will override the default help text that would usually be shown
|
||||
by argparse.
|
||||
"""
|
||||
self.arg_type = config_type
|
||||
self.default_settings = default_settings
|
||||
self.logger = get_logger("swea-cli", emoji="🔧")
|
||||
self.help_text = help_text
|
||||
self.default_config_file = default_config_file
|
||||
|
||||
def maybe_show_auto_correct(self, args: list[str]):
|
||||
auto_correct = []
|
||||
if hasattr(self.arg_type, "_get_auto_correct"):
|
||||
for ac in self.arg_type._get_auto_correct(): # type: ignore
|
||||
if ac.show(args):
|
||||
auto_correct.append(ac)
|
||||
if auto_correct:
|
||||
rich_print(
|
||||
Panel.fit(
|
||||
"[red][bold]Auto-correct suggestions[/bold][/red]\n\n"
|
||||
+ "\n".join(ac.format() for ac in auto_correct),
|
||||
)
|
||||
)
|
||||
|
||||
def get_config(self, args: list[str] | None = None) -> BaseSettings:
|
||||
"""Get the configuration object from defaults and command arguments."""
|
||||
|
||||
# >>> Step 1: Use argparse to add a --config option to load whole config files
|
||||
|
||||
# The defaults if no config file is provided
|
||||
# Otherwise, the configs from the respective classes will be used
|
||||
parser = ArgumentParser(description=__doc__, add_help=False)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=Path,
|
||||
action="append",
|
||||
default=[],
|
||||
help=(
|
||||
"Load additional config files. Use this option multiple times to load "
|
||||
"multiple files, e.g., --config config1.yaml --config config2.yaml"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-h",
|
||||
"--help",
|
||||
help="Show help text and exit",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--help_option",
|
||||
help="Show help text for a specific option",
|
||||
)
|
||||
if self.default_settings:
|
||||
parser.add_argument(
|
||||
"--no_config_file",
|
||||
action="store_true",
|
||||
help="Do not load default config file when no config file is provided",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_config",
|
||||
action="store_true",
|
||||
help="Print the final config and exit",
|
||||
)
|
||||
|
||||
# >>> Step 2: Parse argparse arguments but keep all the remaining arguments.
|
||||
# Explicitly handle --help and --print-options
|
||||
|
||||
cli_args, remaining_args = parser.parse_known_args(args)
|
||||
|
||||
if cli_args.help:
|
||||
if self.help_text:
|
||||
rich_print(self.help_text)
|
||||
else:
|
||||
parser.print_help()
|
||||
exit(0)
|
||||
if cli_args.help_option:
|
||||
module, _, name = cli_args.help_option.rpartition(".")
|
||||
if module not in sys.modules:
|
||||
__import__(module)
|
||||
type_ = getattr(sys.modules[module], name)
|
||||
rich_print(ConfigHelper().get_help(type_))
|
||||
exit(0)
|
||||
|
||||
# >>> Step 3: Load config files and merge them in a big nested data structure
|
||||
|
||||
config_merged = {}
|
||||
config_files = []
|
||||
if cli_args.config:
|
||||
config_files.extend(cli_args.config)
|
||||
for _f in cli_args.config:
|
||||
txt = Path(_f).read_text()
|
||||
if not txt.strip():
|
||||
self.logger.warning(f"Config file {_f} is empty")
|
||||
continue
|
||||
_loaded = yaml.safe_load(txt)
|
||||
merge_nested_dicts(config_merged, _loaded)
|
||||
elif self.default_settings and not cli_args.no_config_file:
|
||||
config_file = self.default_config_file
|
||||
config_files.append(config_file)
|
||||
msg = (
|
||||
f"Loading default config from {config_file}, because no other "
|
||||
"config file is specified. Specify --no_config_file to disable this."
|
||||
)
|
||||
self.logger.info(msg)
|
||||
txt = config_file.read_text()
|
||||
if not txt.strip():
|
||||
self.logger.warning(f"Default config file {config_file} is empty")
|
||||
config_merged = {}
|
||||
else:
|
||||
config_merged = yaml.safe_load(txt)
|
||||
else:
|
||||
config_merged = {}
|
||||
|
||||
# For informational purposes, we also merge in the command line options
|
||||
cl_options_dict = _parse_args_to_nested_dict(remaining_args)
|
||||
|
||||
# >>> Step 4: Bring together remaining arguments and the merged config to initialize the config object
|
||||
# This is done by CliApp.run from pydantic-settings
|
||||
|
||||
try:
|
||||
config: BaseSettings = CliApp.run(self.arg_type, remaining_args, **config_merged, cli_exit_on_error=False) # type: ignore
|
||||
except ValidationError as e:
|
||||
rich_print(
|
||||
Panel.fit(
|
||||
"[red][bold]Configuration from config files\n[/bold]"
|
||||
"This is all the configuration that was provided from defaults, --config, and CLI arguments[/red]\n\n"
|
||||
+ yaml.dump(_shorten_strings(config_merged))
|
||||
)
|
||||
)
|
||||
rich_print(
|
||||
Panel.fit(
|
||||
"[red][bold]Configuration from CLI arguments\n[/bold]"
|
||||
"This is all the configuration that was provided from the command line arguments[/red]\n\n"
|
||||
+ yaml.dump(_shorten_strings(cl_options_dict))
|
||||
)
|
||||
)
|
||||
rich_print(
|
||||
Panel.fit(
|
||||
"[red][bold]Merged configuration\n[/bold]"
|
||||
"This is the merged configuration that was used to instantiate the config object[/red]\n\n"
|
||||
+ yaml.dump(_shorten_strings(merge_nested_dicts(config_merged, cl_options_dict)))
|
||||
)
|
||||
)
|
||||
rich_print(
|
||||
Panel.fit(
|
||||
"[red][bold]Validation error[/bold]\n" + _VALIDATION_ERROR_HELP_TEXT + "[/red]\n" + str(e),
|
||||
)
|
||||
)
|
||||
self.maybe_show_auto_correct(remaining_args)
|
||||
msg = "Invalid configuration. Please check the above output."
|
||||
raise RuntimeError(msg) from None
|
||||
except SettingsError as e:
|
||||
rich_print(Panel.fit("[red][bold]SettingsError[/bold][/red]\n\n" + str(e) + "\n\n" + _SETTING_ERROR_HINTS))
|
||||
self.maybe_show_auto_correct(remaining_args)
|
||||
msg = "Invalid command line arguments. Please check the above output in the box."
|
||||
raise RuntimeError(msg) from None
|
||||
|
||||
if cli_args.print_config: # type: ignore
|
||||
print(yaml.dump(config.model_dump()))
|
||||
exit(0)
|
||||
|
||||
# Attach config files to the arg object, because we need them for file naming purposes
|
||||
# (the output traj directory is named after the last config file)
|
||||
config._config_files = config_files # type: ignore
|
||||
return config
|
||||
|
||||
|
||||
def save_predictions(traj_dir: Path, instance_id: str, result: AgentRunResult):
|
||||
"""Save predictions in a file readable by SWE-bench"""
|
||||
output_file = traj_dir / instance_id / (instance_id + ".pred")
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
datum = {
|
||||
"model_name_or_path": traj_dir.name,
|
||||
"instance_id": instance_id,
|
||||
"model_patch": result.info.get("submission"),
|
||||
}
|
||||
output_file.write_text(json.dumps(datum))
|
||||
|
||||
|
||||
def _is_promising_patch(info: AgentInfo) -> bool:
|
||||
"""Do we actually believe that the patch will solve the issue?
|
||||
Or are we just submitting the last patch we generated before hitting an error?
|
||||
"""
|
||||
# The exit status can also be `submitted (exit_cost)` etc.
|
||||
return info.get("exit_status") == "submitted" and info.get("submission") is not None
|
||||
123
.agent/vendor/mini-swe/sweagent/run/compare_runs.py
vendored
Normal file
@@ -0,0 +1,123 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
|
||||
def get_resolved(path: Path) -> set[str]:
|
||||
data = json.loads(path.read_text())
|
||||
if "resolved" in data:
|
||||
data["resolved_ids"] = data["resolved"]
|
||||
return set(data["resolved_ids"])
|
||||
|
||||
|
||||
def get_submitted(path: Path) -> set[str]:
|
||||
return set(json.loads(path.read_text())["submitted_ids"])
|
||||
|
||||
|
||||
def stats_single(path: Path) -> None:
|
||||
evaluated_ids = sorted(get_submitted(path))
|
||||
resolved_ids = sorted(get_resolved(path))
|
||||
print(f"Total evaluated: {len(evaluated_ids)}")
|
||||
print(f"Total resolved: {len(resolved_ids)}")
|
||||
|
||||
|
||||
def compare_many(paths: list[Path]) -> None:
|
||||
evaluated_ids = {}
|
||||
resolved_ids = {}
|
||||
for path in paths:
|
||||
evaluated_ids[path] = sorted(get_submitted(path))
|
||||
resolved_ids[path] = sorted(get_resolved(path))
|
||||
header: list[str] = ["ID"] + [str(i) for i in range(len(paths))] + ["Success rate"]
|
||||
table: list[list[str | float | int]] = []
|
||||
|
||||
def get_emoji(id: str, path: Path) -> str:
|
||||
if id not in evaluated_ids[path]:
|
||||
return "❓"
|
||||
if id in resolved_ids[path]:
|
||||
return "✅"
|
||||
return "❌"
|
||||
|
||||
ids_to_compare = set(evaluated_ids[paths[0]])
|
||||
for id in sorted(ids_to_compare):
|
||||
row = [id] + [get_emoji(id, path) for path in paths]
|
||||
n_success = sum(id in resolved_ids[path] for path in paths)
|
||||
n_evaluated = sum(id in evaluated_ids[path] for path in paths)
|
||||
row.append(f"{n_success / n_evaluated:.2f}")
|
||||
table.append(row)
|
||||
successes: list[str | float] = ["Successes"]
|
||||
success_rates: list[str | float] = ["Success rates"]
|
||||
for path in paths:
|
||||
n_success = sum(id in resolved_ids[path] for id in ids_to_compare)
|
||||
n_evaluated = sum(id in evaluated_ids[path] for id in ids_to_compare)
|
||||
successes.append(n_success)
|
||||
success_rates.append(f"{n_success / n_evaluated:.2f}")
|
||||
table.append(successes)
|
||||
table.append(success_rates)
|
||||
print(tabulate(table, headers=header))
|
||||
print()
|
||||
|
||||
header: list[str] = ["#", "ID", "Successes", "Success rate"]
|
||||
table: list[list[str | float | int]] = []
|
||||
for i, path in enumerate(paths):
|
||||
row = [i, path.parent.name, successes[i + 1], success_rates[i + 1]]
|
||||
table.append(row)
|
||||
print(tabulate(table, headers=header))
|
||||
|
||||
|
||||
def compare_pair(new_path: Path, old_path: Path, *, show_same=False) -> None:
|
||||
evaluated_ids = sorted(get_submitted(new_path))
|
||||
resolved_ids = sorted(get_resolved(new_path))
|
||||
old_evaluated_ids = sorted(get_submitted(old_path))
|
||||
old_resolved_ids = sorted(get_resolved(old_path))
|
||||
print(f"Total evaluated: new {len(evaluated_ids)}, old {len(old_evaluated_ids)}")
|
||||
print(f"Total resolved: new {len(resolved_ids)}, old {len(old_resolved_ids)}")
|
||||
print("-" * 80)
|
||||
print("Emoji legend:")
|
||||
print("❓: Not evaluated in old version, so guessing it's either 😀 or 👾")
|
||||
print("😀: Newly resolved in new version")
|
||||
print("✅: Resolved in both")
|
||||
print("❌: Resolved in old, not in new")
|
||||
print("👾: Unresolved in both")
|
||||
print("-" * 80)
|
||||
|
||||
for id in evaluated_ids:
|
||||
resolved_now = id in resolved_ids
|
||||
resolved_before = id in old_resolved_ids
|
||||
if id not in old_evaluated_ids and resolved_now:
|
||||
emoji = "😀❓"
|
||||
elif id not in old_evaluated_ids and not resolved_now:
|
||||
emoji = "👾❓"
|
||||
elif resolved_now and not resolved_before:
|
||||
emoji = "😀"
|
||||
elif resolved_now and resolved_before:
|
||||
emoji = "✅"
|
||||
if not show_same:
|
||||
continue
|
||||
elif not resolved_now and resolved_before:
|
||||
emoji = "❌"
|
||||
else:
|
||||
emoji = "👾"
|
||||
if not show_same:
|
||||
continue
|
||||
print(f"{emoji} {id}")
|
||||
|
||||
|
||||
def run_from_cli(_args: list[str] | None = None) -> None:
|
||||
def get_preds_path(path: Path) -> Path:
|
||||
if path.is_dir():
|
||||
return path / "results.json"
|
||||
return path
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("paths", type=Path, nargs="+")
|
||||
parser.add_argument("--show-same", action="store_true")
|
||||
args = parser.parse_args(_args)
|
||||
args.paths = [get_preds_path(path) for path in args.paths]
|
||||
if len(args.paths) == 1:
|
||||
stats_single(args.paths[0])
|
||||
elif len(args.paths) == 2:
|
||||
compare_pair(args.paths[0], args.paths[1], show_same=args.show_same)
|
||||
else:
|
||||
compare_many(args.paths)
|
||||
19
.agent/vendor/mini-swe/sweagent/run/extract_pred.py
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
"""If for some reason the .pred file isn't saved, we can extract it from the .traj file."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_from_cli(_args: list[str] | None = None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("traj_path", type=Path)
|
||||
args = parser.parse_args(_args)
|
||||
data = json.loads(args.traj_path.read_text())
|
||||
pred_path = args.traj_path.with_suffix(".pred")
|
||||
pred_data = {
|
||||
"model_name_or_path": args.traj_path.resolve().parent.parent.name,
|
||||
"model_patch": data["info"]["submission"],
|
||||
"instance_id": args.traj_path.resolve().parent.name,
|
||||
}
|
||||
pred_path.write_text(json.dumps(pred_data))
|
||||
0
.agent/vendor/mini-swe/sweagent/run/hooks/__init__.py
vendored
Normal file
67
.agent/vendor/mini-swe/sweagent/run/hooks/abstract.py
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
from sweagent.agent.problem_statement import ProblemStatement, ProblemStatementConfig
|
||||
from sweagent.environment.swe_env import SWEEnv
|
||||
from sweagent.types import AgentRunResult
|
||||
|
||||
|
||||
class RunHook:
|
||||
"""Hook structure for the web server or other addons to interface with"""
|
||||
|
||||
def on_init(self, *, run):
|
||||
"""Called when hook is initialized"""
|
||||
|
||||
def on_start(self):
|
||||
"""Called at the beginning of `Main.main`"""
|
||||
|
||||
def on_end(self):
|
||||
"""Called at the end of `Main.main`"""
|
||||
|
||||
def on_instance_start(
|
||||
self, *, index: int, env: SWEEnv, problem_statement: ProblemStatement | ProblemStatementConfig
|
||||
):
|
||||
"""Called at the beginning of each instance loop in `Main.run`"""
|
||||
|
||||
def on_instance_skipped(
|
||||
self,
|
||||
):
|
||||
"""Called when an instance is skipped in `Main.run`"""
|
||||
|
||||
def on_instance_completed(self, *, result: AgentRunResult):
|
||||
"""Called when an instance is completed in `Main.run`"""
|
||||
|
||||
|
||||
class CombinedRunHooks(RunHook):
|
||||
def __init__(self):
|
||||
self._hooks = []
|
||||
|
||||
def add_hook(self, hook: RunHook) -> None:
|
||||
self._hooks.append(hook)
|
||||
|
||||
@property
|
||||
def hooks(self) -> list[RunHook]:
|
||||
return self._hooks
|
||||
|
||||
def on_init(self, *, run):
|
||||
for hook in self._hooks:
|
||||
hook.on_init(run=run)
|
||||
|
||||
def on_start(self):
|
||||
for hook in self._hooks:
|
||||
hook.on_start()
|
||||
|
||||
def on_end(self):
|
||||
for hook in self._hooks:
|
||||
hook.on_end()
|
||||
|
||||
def on_instance_start(
|
||||
self, *, index: int, env: SWEEnv, problem_statement: ProblemStatement | ProblemStatementConfig
|
||||
):
|
||||
for hook in self._hooks:
|
||||
hook.on_instance_start(index=index, env=env, problem_statement=problem_statement)
|
||||
|
||||
def on_instance_skipped(self):
|
||||
for hook in self._hooks:
|
||||
hook.on_instance_skipped()
|
||||
|
||||
def on_instance_completed(self, *, result: AgentRunResult):
|
||||
for hook in self._hooks:
|
||||
hook.on_instance_completed(result=result)
|
||||
110
.agent/vendor/mini-swe/sweagent/run/hooks/apply_patch.py
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
import subprocess
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import rich
|
||||
import rich.markdown
|
||||
import rich.panel
|
||||
|
||||
from sweagent.agent.problem_statement import ProblemStatementConfig
|
||||
from sweagent.environment.repo import LocalRepoConfig
|
||||
from sweagent.environment.swe_env import SWEEnv
|
||||
from sweagent.run.common import _is_promising_patch
|
||||
from sweagent.run.hooks.abstract import RunHook
|
||||
from sweagent.types import AgentRunResult
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
|
||||
class SaveApplyPatchHook(RunHook):
|
||||
"""This hook saves patches to a separate directory and optionally applies them to a local repository."""
|
||||
|
||||
def __init__(self, apply_patch_locally: bool = False, show_success_message: bool = True):
|
||||
self.logger = get_logger("swea-save_apply_patch", emoji="⚡️")
|
||||
self._apply_patch_locally = apply_patch_locally
|
||||
self._show_success_message = show_success_message
|
||||
# Thread-local storage so that concurrent workers in run-batch do not
|
||||
# overwrite each other's per-instance state (_env, _problem_statement).
|
||||
self._local = threading.local()
|
||||
|
||||
def on_init(self, *, run):
|
||||
self._output_dir = Path(run.output_dir)
|
||||
|
||||
def on_instance_start(self, *, index: int, env: SWEEnv, problem_statement: ProblemStatementConfig):
|
||||
self._local.env = env
|
||||
self._local.problem_statement = problem_statement
|
||||
|
||||
def on_instance_completed(self, *, result: AgentRunResult):
|
||||
instance_id = self._local.problem_statement.id
|
||||
patch_path = self._save_patch(instance_id, result.info)
|
||||
if patch_path:
|
||||
if not self._apply_patch_locally:
|
||||
return
|
||||
if not _is_promising_patch(result.info):
|
||||
return
|
||||
if self._local.env.repo is None:
|
||||
return
|
||||
if not isinstance(self._local.env.repo, LocalRepoConfig):
|
||||
return
|
||||
local_dir = Path(self._local.env.repo.path)
|
||||
self._apply_patch(patch_path, local_dir)
|
||||
|
||||
@staticmethod
|
||||
def _print_patch_message(patch_output_file: Path):
|
||||
console = rich.console.Console()
|
||||
msg = [
|
||||
"SWE-agent has produced a patch that it believes will solve the issue you submitted!",
|
||||
"Use the code snippet below to inspect or apply it!",
|
||||
]
|
||||
panel = rich.panel.Panel.fit(
|
||||
"\n".join(msg),
|
||||
title="🎉 Submission successful 🎉",
|
||||
)
|
||||
console.print(panel)
|
||||
content = [
|
||||
"```bash",
|
||||
"# The patch has been saved to your local filesystem at:",
|
||||
f"PATCH_FILE_PATH='{patch_output_file.resolve()}'",
|
||||
"# Inspect it:",
|
||||
'cat "${PATCH_FILE_PATH}"',
|
||||
"# Apply it to a local repository:",
|
||||
"cd <your local repo root>",
|
||||
'git apply "${PATCH_FILE_PATH}"',
|
||||
"```",
|
||||
]
|
||||
console.print(rich.markdown.Markdown("\n".join(content)))
|
||||
|
||||
def _save_patch(self, instance_id: str, info) -> Path | None:
|
||||
"""Create patch files that can be applied with `git am`.
|
||||
|
||||
Returns:
|
||||
The path to the patch file, if it was saved. Otherwise, returns None.
|
||||
"""
|
||||
patch_output_dir = self._output_dir / instance_id
|
||||
patch_output_dir.mkdir(exist_ok=True, parents=True)
|
||||
patch_output_file = patch_output_dir / f"{instance_id}.patch"
|
||||
if info.get("submission") is None:
|
||||
self.logger.info("No patch to save.")
|
||||
return None
|
||||
model_patch = info["submission"]
|
||||
patch_output_file.write_text(model_patch)
|
||||
if _is_promising_patch(info):
|
||||
# Only print big congratulations if we actually believe
|
||||
# the patch will solve the issue
|
||||
if self._show_success_message:
|
||||
self._print_patch_message(patch_output_file)
|
||||
return patch_output_file
|
||||
|
||||
def _apply_patch(self, patch_file: Path, local_dir: Path) -> None:
|
||||
"""Apply a patch to a local directory."""
|
||||
|
||||
assert local_dir.is_dir()
|
||||
assert patch_file.exists()
|
||||
# The resolve() is important, because we're gonna run the cmd
|
||||
# somewhere else
|
||||
cmd = ["git", "apply", str(patch_file.resolve())]
|
||||
try:
|
||||
subprocess.run(cmd, cwd=local_dir, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
self.logger.error(f"Failed to apply patch {patch_file} to {local_dir}: {e}")
|
||||
return
|
||||
self.logger.info(f"Applied patch {patch_file} to {local_dir}")
|
||||
244
.agent/vendor/mini-swe/sweagent/run/hooks/open_pr.py
vendored
Normal file
@@ -0,0 +1,244 @@
|
||||
import os
|
||||
import random
|
||||
import shlex
|
||||
|
||||
from ghapi.all import GhApi
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sweagent.environment.swe_env import SWEEnv
|
||||
from sweagent.run.hooks.abstract import RunHook
|
||||
from sweagent.types import AgentRunResult
|
||||
from sweagent.utils.github import (
|
||||
InvalidGithubURL,
|
||||
_get_associated_commit_urls,
|
||||
_get_gh_issue_data,
|
||||
_parse_gh_issue_url,
|
||||
)
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
# NOTE
|
||||
# THE IMPLEMENTATION DETAILS HERE WILL CHANGE SOON!
|
||||
|
||||
|
||||
# fixme: Bring back the ability to open the PR to a fork
|
||||
def open_pr(*, logger, token, env: SWEEnv, github_url, trajectory, _dry_run: bool = False) -> None:
|
||||
"""Create PR to repository
|
||||
|
||||
Args:
|
||||
trajectory: Trajectory of actions taken by the agent
|
||||
_dry_run: Whether to actually push anything or just simulate it
|
||||
"""
|
||||
|
||||
issue_url = github_url
|
||||
logger.info("Opening PR")
|
||||
try:
|
||||
issue = _get_gh_issue_data(issue_url, token=token)
|
||||
except InvalidGithubURL as e:
|
||||
msg = "Data path must be a github issue URL if open_pr is set to True."
|
||||
raise ValueError(msg) from e
|
||||
branch_name = f"swe-agent-fix-#{issue.number}-" + str(random.random())[2:10]
|
||||
env.communicate(
|
||||
input="git config user.email 'noemail@swe-agent.com' && git config user.name 'SWE-agent'",
|
||||
error_msg="Failed to set git user",
|
||||
timeout=10,
|
||||
check="raise",
|
||||
)
|
||||
env.communicate(input="rm -f model.patch", error_msg="Failed to remove model patch", timeout=10, check="raise")
|
||||
env.communicate(
|
||||
input=f"git checkout -b {branch_name}", error_msg="Failed to switch to new branch", timeout=10, check="raise"
|
||||
)
|
||||
env.communicate(input="git add .", error_msg="Failed to add commits", timeout=10, check="raise")
|
||||
dry_run_flag = "--allow-empty" if _dry_run else ""
|
||||
commit_msg = [
|
||||
shlex.quote(f"Fix: {issue.title}"),
|
||||
shlex.quote(f"Closes #{issue.number}"),
|
||||
]
|
||||
out = env.communicate(
|
||||
input=f"git commit -m {commit_msg[0]} -m {commit_msg[1]} {dry_run_flag}",
|
||||
error_msg="Failed to commit changes",
|
||||
timeout=10,
|
||||
check="raise",
|
||||
)
|
||||
logger.debug(f"Committed changes: {out}")
|
||||
|
||||
owner, repo, _ = _parse_gh_issue_url(issue_url)
|
||||
# fixme: bring this back
|
||||
# If `--repo_path` was specified with a different github URL, then the record will contain
|
||||
# the forking user
|
||||
forker = owner
|
||||
head = branch_name
|
||||
remote = "origin"
|
||||
if forker != owner:
|
||||
head = f"{forker}:{branch_name}"
|
||||
token_prefix = ""
|
||||
if token:
|
||||
token_prefix = f"{token}@"
|
||||
fork_url = f"https://{token_prefix}github.com/{forker}/{repo}.git"
|
||||
logger.debug(f"Using fork: {fork_url}")
|
||||
env.communicate(
|
||||
input=f"git remote add fork {fork_url}",
|
||||
error_msg="Failed to create new git remote",
|
||||
timeout=10,
|
||||
)
|
||||
remote = "fork"
|
||||
dry_run_prefix = "echo " if _dry_run else ""
|
||||
out = env.communicate(
|
||||
input=f"{dry_run_prefix} git push {remote} {branch_name}",
|
||||
error_msg=(
|
||||
"Failed to push branch to remote. Please check your token and permissions. "
|
||||
"You might want to push to a fork with the push_gh_repo_url option."
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
logger.debug(f"Pushed commit to {remote=} {branch_name=}: {out}")
|
||||
body = (
|
||||
f"This is a PR opened by AI tool [SWE Agent](https://github.com/SWE-agent/SWE-agent/) "
|
||||
f"to close [#{issue.number}]({issue_url}) ({issue.title}).\n\nCloses #{issue.number}."
|
||||
)
|
||||
body += "\n\n" + format_trajectory_markdown(trajectory, char_limit=60_000)
|
||||
api = GhApi(token=token)
|
||||
default_branch = api.repos.get(owner, repo).default_branch
|
||||
if not _dry_run:
|
||||
args = dict(
|
||||
owner=owner,
|
||||
repo=repo,
|
||||
title=f"SWE-agent[bot] PR to fix: {issue.title}",
|
||||
head=head,
|
||||
base=default_branch,
|
||||
body=body,
|
||||
draft=True,
|
||||
)
|
||||
logger.debug(f"Creating PR with args: {args}")
|
||||
pr_info = api.pulls.create(**args) # type: ignore
|
||||
logger.info(
|
||||
f"🎉 PR created as a draft at {pr_info.html_url}. Please review it carefully, push "
|
||||
"any required changes onto the branch and then click "
|
||||
"'Ready for Review' to bring it to the attention of the maintainers.",
|
||||
)
|
||||
|
||||
|
||||
class OpenPRConfig(BaseModel):
|
||||
# Option to be used with open_pr: Skip action if there are already commits claiming
|
||||
# to fix the issue. Please only set this to False if you are sure the commits are
|
||||
# not fixes or if this is your own repository!
|
||||
skip_if_commits_reference_issue: bool = True
|
||||
|
||||
|
||||
class OpenPRHook(RunHook):
|
||||
"""This hook opens a PR if the issue is solved and the user has enabled the option."""
|
||||
|
||||
def __init__(self, config: OpenPRConfig):
|
||||
self.logger = get_logger("swea-open_pr", emoji="⚡️")
|
||||
self._config = config
|
||||
|
||||
def on_init(self, *, run):
|
||||
self._env = run.env
|
||||
self._token: str = os.getenv("GITHUB_TOKEN", "")
|
||||
self._problem_statement = run.problem_statement
|
||||
|
||||
def on_instance_completed(self, result: AgentRunResult):
|
||||
if self.should_open_pr(result):
|
||||
open_pr(
|
||||
logger=self.logger,
|
||||
token=self._token,
|
||||
env=self._env,
|
||||
github_url=self._problem_statement.github_url,
|
||||
trajectory=result.trajectory,
|
||||
)
|
||||
|
||||
def should_open_pr(self, result: AgentRunResult) -> bool:
|
||||
"""Does opening a PR make sense?"""
|
||||
if not result.info.get("submission"):
|
||||
self.logger.info("Not opening PR because no submission was made.")
|
||||
return False
|
||||
if result.info.get("exit_status") != "submitted":
|
||||
self.logger.info(
|
||||
"Not opening PR because exit status was %s and not submitted.", result.info.get("exit_status")
|
||||
)
|
||||
return False
|
||||
try:
|
||||
issue = _get_gh_issue_data(self._problem_statement.github_url, token=self._token)
|
||||
except InvalidGithubURL:
|
||||
self.logger.info("Currently only GitHub is supported to open PRs to. Skipping PR creation.")
|
||||
return False
|
||||
if issue.state != "open":
|
||||
self.logger.info(f"Issue is not open (state={issue.state}. Skipping PR creation.")
|
||||
return False
|
||||
if issue.assignee:
|
||||
self.logger.info("Issue is already assigned. Skipping PR creation. Be nice :)")
|
||||
return False
|
||||
if issue.locked:
|
||||
self.logger.info("Issue is locked. Skipping PR creation.")
|
||||
return False
|
||||
org, repo, issue_number = _parse_gh_issue_url(self._problem_statement.github_url)
|
||||
associated_commits = _get_associated_commit_urls(org, repo, issue_number, token=self._token)
|
||||
if associated_commits:
|
||||
commit_url_strs = ", ".join(associated_commits)
|
||||
if self._config.skip_if_commits_reference_issue:
|
||||
self.logger.info(f"Issue already has associated commits (see {commit_url_strs}). Skipping PR creation.")
|
||||
return False
|
||||
else:
|
||||
self.logger.warning(
|
||||
"Proceeding with PR creation even though there are already commits "
|
||||
f"({commit_url_strs}) associated with the issue. Please only do this for your own repositories "
|
||||
"or after verifying that the existing commits do not fix the issue.",
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _remove_triple_backticks(text: str) -> str:
|
||||
return "\n".join(line.removeprefix("```") for line in text.splitlines())
|
||||
|
||||
|
||||
def format_trajectory_markdown(trajectory: list[dict[str, str]], char_limit: int | None = None):
|
||||
"""Format a trajectory as a markdown string for use in gh PR description.
|
||||
|
||||
Args:
|
||||
char_limit: If not None, truncate the trajectory to this many characters.
|
||||
"""
|
||||
prefix = [
|
||||
"<details>",
|
||||
"<summary>Thought process ('trajectory') of SWE-agent (click to expand)</summary>",
|
||||
"",
|
||||
"",
|
||||
]
|
||||
prefix_text = "\n".join(prefix)
|
||||
suffix = [
|
||||
"",
|
||||
"</details>",
|
||||
]
|
||||
suffix_text = "\n".join(suffix)
|
||||
|
||||
steps = []
|
||||
current_length = len(prefix_text) + len(suffix_text)
|
||||
|
||||
for i, step in enumerate(trajectory):
|
||||
step_strs = [
|
||||
f"**🧑🚒 Response ({i})**: ",
|
||||
f"{step['response'].strip()}",
|
||||
f"**👀 Observation ({i})**:",
|
||||
"```",
|
||||
f"{_remove_triple_backticks(step['observation']).strip()}",
|
||||
"```",
|
||||
]
|
||||
step_text = "\n".join(step_strs)
|
||||
|
||||
# Calculate separator length (only needed for steps after the first one)
|
||||
separator_length = 0
|
||||
if steps:
|
||||
separator_length = len("\n\n---\n\n")
|
||||
|
||||
# Check if adding this step would exceed the character limit
|
||||
if char_limit is not None and current_length + separator_length + len(step_text) > char_limit:
|
||||
if i > 0:
|
||||
steps.append("\n\n... (truncated due to length limit)")
|
||||
break
|
||||
|
||||
if steps:
|
||||
steps.append("\n\n---\n\n")
|
||||
current_length += separator_length
|
||||
|
||||
steps.append(step_text)
|
||||
current_length += len(step_text)
|
||||
|
||||
return prefix_text + "".join(steps) + suffix_text
|
||||
113
.agent/vendor/mini-swe/sweagent/run/hooks/swe_bench_evaluate.py
vendored
Normal file
@@ -0,0 +1,113 @@
|
||||
"""SweBench evaluation hook.
|
||||
|
||||
Will be automatically added to `run_batch` if `SWEBenchInstances.evaluate` is set to true
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from time import time
|
||||
|
||||
from sweagent.run.hooks.abstract import RunHook
|
||||
from sweagent.run.merge_predictions import merge_predictions
|
||||
from sweagent.types import AgentRunResult
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
|
||||
class SweBenchEvaluate(RunHook):
|
||||
_SUBSET_MAP = {"lite": "swe-bench_lite", "verified": "swe-bench_verified", "multimodal": "swe-bench_multimodal"}
|
||||
|
||||
def __init__(self, output_dir: Path, subset: str, split: str, continuous_submission_every: int = 0) -> None:
|
||||
super().__init__()
|
||||
self.output_dir = output_dir
|
||||
self.subset = subset
|
||||
self.split = split
|
||||
self.continuous_submission_every = continuous_submission_every
|
||||
self.logger = get_logger("SB-evaluate", emoji="😬")
|
||||
self.merge_lock = Lock()
|
||||
self.last_evaluation_time = time()
|
||||
self.evaluation_interval = continuous_submission_every
|
||||
self._running_calls = []
|
||||
# We need to add a suffix to the run_id to avoid collisions when you reuse the name of your run
|
||||
self._time_suffix = datetime.now().strftime("%Y%m%d%H%M%S%f")
|
||||
|
||||
@property
|
||||
def run_id(self) -> str:
|
||||
return f"{self.output_dir.name}_{self._time_suffix}"
|
||||
|
||||
def _get_sb_call(self, preds_path: Path, submit_only: bool = False) -> list[str]:
|
||||
args = [
|
||||
"sb-cli",
|
||||
"submit",
|
||||
self._SUBSET_MAP[self.subset],
|
||||
self.split,
|
||||
"--predictions_path",
|
||||
str(preds_path),
|
||||
"--run_id",
|
||||
self.run_id,
|
||||
"--output_dir",
|
||||
str(self.output_dir / "sb-cli-reports"),
|
||||
]
|
||||
if submit_only:
|
||||
args.extend(["--wait_for_evaluation", "0", "--gen_report", "0", "--verify_submission", "0"])
|
||||
return args
|
||||
|
||||
def check_running_calls(self) -> None:
|
||||
"""Warn if one of the running calls failed."""
|
||||
for call in self._running_calls:
|
||||
if call.poll() is not None:
|
||||
if call.returncode != 0:
|
||||
self.logger.error("Failed to submit results to SweBench eval: %s", call.stderr.read())
|
||||
self._running_calls.remove(call)
|
||||
|
||||
def on_instance_completed(self, *, result: AgentRunResult):
|
||||
if self.evaluation_interval == 0:
|
||||
return
|
||||
|
||||
current_time = time()
|
||||
if current_time - self.last_evaluation_time < self.evaluation_interval:
|
||||
return
|
||||
|
||||
with self.merge_lock:
|
||||
merge_predictions([self.output_dir], self.output_dir / "tmppreds.json")
|
||||
self.last_evaluation_time = current_time
|
||||
|
||||
self._running_calls.append(
|
||||
subprocess.Popen(
|
||||
self._get_sb_call(preds_path=self.output_dir / "tmppreds.json", submit_only=True),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
)
|
||||
|
||||
def move_sb_cli_report(self) -> None:
|
||||
"""Move report from `sb-cli-reports` to `results.json`."""
|
||||
output_dir = self.output_dir / "sb-cli-reports"
|
||||
if not output_dir.exists():
|
||||
self.logger.warning("No SweBench report found at %s", output_dir)
|
||||
return
|
||||
(self.output_dir / "results.json").unlink(missing_ok=True)
|
||||
reports = list(output_dir.glob("*.json"))
|
||||
if len(reports) != 1:
|
||||
self.logger.warning("Expected 1 SweBench report at %s, found %d. Cannot rename.", output_dir, len(reports))
|
||||
return
|
||||
reports[0].rename(self.output_dir / "results.json")
|
||||
|
||||
def on_end(self) -> None:
|
||||
self.logger.info("Submitting results to SWE-Bench")
|
||||
try:
|
||||
subprocess.run(
|
||||
self._get_sb_call(preds_path=self.output_dir / "preds.json"),
|
||||
check=True,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
self.logger.error("Failed to submit results to SweBench eval: %s", e)
|
||||
else:
|
||||
# remove temporary predictions if they exist
|
||||
if (self.output_dir / "tmppreds.json").exists():
|
||||
(self.output_dir / "tmppreds.json").unlink()
|
||||
self.move_sb_cli_report()
|
||||
493
.agent/vendor/mini-swe/sweagent/run/inspector_cli.py
vendored
Normal file
@@ -0,0 +1,493 @@
|
||||
"""This is a command line tool to inspect trajectory JSON files."""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
from rich.syntax import Syntax
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Container, Vertical, VerticalScroll
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Footer, Header, Input, ListItem, ListView, Static
|
||||
|
||||
from sweagent.utils.files import load_file
|
||||
from sweagent.utils.serialization import _yaml_serialization_with_linebreaks
|
||||
|
||||
|
||||
def _move_items_top(d: dict, keys: list[str]) -> dict:
|
||||
"""Reorder items in a dictionary.
|
||||
|
||||
The first keys will be those specified in `keys`, the rest will
|
||||
be in the same order as in the original dictionary.
|
||||
"""
|
||||
new_d = {}
|
||||
for key in keys:
|
||||
if key in d:
|
||||
new_d[key] = d[key]
|
||||
for key in d.keys():
|
||||
if key not in keys:
|
||||
new_d[key] = d[key]
|
||||
return new_d
|
||||
|
||||
|
||||
class TrajectoryViewer(Static):
|
||||
BINDINGS = [
|
||||
Binding("right,l", "next_item", "Step++"),
|
||||
Binding("left,h", "previous_item", "Step--"),
|
||||
Binding("0", "first_item", "Step=0"),
|
||||
Binding("$", "last_item", "Step=-1"),
|
||||
Binding("v", "toggle_view", "Toggle view"),
|
||||
Binding("j,down", "scroll_down", "Scroll down"),
|
||||
Binding("k,up", "scroll_up", "Scroll up"),
|
||||
]
|
||||
|
||||
def __init__(self, path: Path, title: str, overview_stats: dict, *, gold_patch: str | None = None):
|
||||
"""View a single trajectory."""
|
||||
super().__init__()
|
||||
self.i_step = -1
|
||||
self.trajectory = json.loads(path.read_text())
|
||||
self.show_full = False
|
||||
self.title = title
|
||||
self.overview_stats = overview_stats
|
||||
self.gold_patch = gold_patch
|
||||
|
||||
def load_trajectory(self, path: Path, title: str, overview_stats: dict, *, gold_patch: str | None = None):
|
||||
"""Load a new trajectory and update the viewer."""
|
||||
print("Loading", path)
|
||||
self.trajectory = json.loads(path.read_text())
|
||||
self.title = title
|
||||
self.gold_patch = gold_patch
|
||||
self.overview_stats = overview_stats
|
||||
self.scroll_top()
|
||||
self.i_step = -1
|
||||
self.update_content()
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
with VerticalScroll():
|
||||
yield Static(id="content", markup=False)
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self.update_content()
|
||||
|
||||
@property
|
||||
def n_steps(self) -> int:
|
||||
return len(self.trajectory["trajectory"])
|
||||
|
||||
def _show_step_yaml(self, item: dict) -> None:
|
||||
"""Show full yaml of trajectory item"""
|
||||
content_str = _yaml_serialization_with_linebreaks(
|
||||
_move_items_top(item, ["thought", "action", "observation", "response", "execution_time"])
|
||||
)
|
||||
syntax = Syntax(content_str, "yaml", theme="monokai", word_wrap=True)
|
||||
content = self.query_one("#content")
|
||||
content.update(syntax) # type: ignore
|
||||
self.app.sub_title = f"{self.title} - Step {self.i_step + 1}/{self.n_steps} - Full View"
|
||||
|
||||
def _show_step_simple(self, item: dict) -> None:
|
||||
# Simplified view - show action and observation as plain text
|
||||
thought = item.get("thought", "")
|
||||
action = item.get("action", "")
|
||||
observation = item.get("observation", "")
|
||||
|
||||
content_str = f"THOUGHT:\n{thought}\n\nACTION:\n{action}\n\nOBSERVATION:\n{observation}"
|
||||
content = self.query_one("#content")
|
||||
content.update(content_str) # type: ignore
|
||||
|
||||
self.app.sub_title = f"{self.title} - Step {self.i_step + 1}/{self.n_steps} - Simple View"
|
||||
|
||||
def _show_info(self):
|
||||
info = copy.deepcopy(self.trajectory["info"])
|
||||
info["result"] = self.overview_stats["result"]
|
||||
info["gold_patch"] = self.gold_patch
|
||||
info = _move_items_top(info, ["result", "exit_status", "model_stats", "submission", "gold_patch"])
|
||||
syntax = Syntax(_yaml_serialization_with_linebreaks(info), "yaml", theme="monokai", word_wrap=True)
|
||||
content = self.query_one("#content")
|
||||
content.update(syntax) # type: ignore
|
||||
next_help = "Press l to see step 1" if self.i_step < 0 else f"Press h to see step {self.n_steps}"
|
||||
self.app.sub_title = f"{self.title} - Info ({next_help})"
|
||||
|
||||
def update_content(self) -> None:
|
||||
print(self.i_step)
|
||||
if self.i_step < 0 or self.i_step >= self.n_steps:
|
||||
return self._show_info()
|
||||
|
||||
item = self.trajectory["trajectory"][self.i_step]
|
||||
|
||||
if self.show_full:
|
||||
return self._show_step_yaml(item)
|
||||
|
||||
return self._show_step_simple(item)
|
||||
|
||||
def action_next_item(self) -> None:
|
||||
if self.i_step < self.n_steps:
|
||||
self.i_step += 1
|
||||
self.scroll_top()
|
||||
self.update_content()
|
||||
|
||||
def action_previous_item(self) -> None:
|
||||
if self.i_step > -1:
|
||||
self.i_step -= 1
|
||||
self.scroll_top()
|
||||
self.update_content()
|
||||
|
||||
def action_toggle_view(self) -> None:
|
||||
self.show_full = not self.show_full
|
||||
self.update_content()
|
||||
|
||||
def action_first_item(self) -> None:
|
||||
self.i_step = 0
|
||||
self.update_content()
|
||||
|
||||
def action_last_item(self) -> None:
|
||||
self.i_step = self.n_steps - 1
|
||||
self.update_content()
|
||||
|
||||
def scroll_top(self) -> None:
|
||||
"""Resets scrolling viewport"""
|
||||
vs = self.query_one(VerticalScroll)
|
||||
vs.scroll_home(animate=False)
|
||||
|
||||
def action_scroll_down(self) -> None:
|
||||
vs = self.query_one(VerticalScroll)
|
||||
vs.scroll_to(y=vs.scroll_target_y + 15)
|
||||
|
||||
def action_scroll_up(self) -> None:
|
||||
vs = self.query_one(VerticalScroll)
|
||||
vs.scroll_to(y=vs.scroll_target_y - 15)
|
||||
|
||||
|
||||
class TrajectorySelectorScreen(ModalScreen[int]):
|
||||
BINDINGS = [
|
||||
Binding("escape", "dismiss(None)", "Cancel"),
|
||||
]
|
||||
|
||||
def __init__(self, paths: list[Path], current_index: int, overview_stats: dict):
|
||||
super().__init__()
|
||||
self.paths = paths
|
||||
self.current_index = current_index
|
||||
self.overview_stats = overview_stats
|
||||
self.all_items = [] # Store all items for filtering
|
||||
self.filtered_indices = []
|
||||
|
||||
def _get_list_item_texts(self, paths: list[Path]) -> list[str]:
|
||||
"""Remove the common prefix from a list of paths."""
|
||||
prefix = os.path.commonpath([str(p) for p in paths])
|
||||
labels = []
|
||||
for p in paths:
|
||||
ostat = self.overview_stats[p.stem]
|
||||
ostat_str = f"{ostat['exit_status']} {ostat['result']} ${ostat['cost']:.2f} {ostat['api_calls']} calls"
|
||||
shortened_path = str(p)[len(prefix) :].lstrip("/\\")
|
||||
if Path(shortened_path).stem == Path(shortened_path).parent.name:
|
||||
# We have the instance ID twice (in the folder and the traj)
|
||||
shortened_path = Path(shortened_path).stem
|
||||
labels.append(f"{shortened_path} - {ostat_str}")
|
||||
|
||||
return labels
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
with Vertical(id="dialog"):
|
||||
yield Static(
|
||||
"Press <TAB> to switch between search and list. Use <ARROW KEY>/<ENTER> to select.",
|
||||
id="title",
|
||||
markup=False,
|
||||
)
|
||||
yield Input(placeholder="Type to filter (auto-select if only one item remains)...", id="filter-input")
|
||||
yield ListView(
|
||||
*[ListItem(Static(p, markup=False)) for p in self._get_list_item_texts(self.paths)],
|
||||
id="trajectory-list",
|
||||
initial_index=self.current_index,
|
||||
)
|
||||
# Store all items for later filtering
|
||||
self.all_items = self._get_list_item_texts(self.paths)
|
||||
self.filtered_indices = list(range(len(self.all_items)))
|
||||
|
||||
def on_input_changed(self, event: Input.Changed) -> None:
|
||||
"""Filter list items based on input"""
|
||||
filter_text = event.value.lower()
|
||||
list_view = self.query_one("#trajectory-list", ListView)
|
||||
|
||||
# Filter items and keep track of original indices
|
||||
self.filtered_indices = [i for i, item in enumerate(self.all_items) if filter_text in item.lower()]
|
||||
filtered_items = [self.all_items[i] for i in self.filtered_indices]
|
||||
|
||||
if len(filtered_items) == 1:
|
||||
# Find the index of the filtered item in the original list
|
||||
selected_index = self.all_items.index(filtered_items[0])
|
||||
self.dismiss(selected_index)
|
||||
return
|
||||
|
||||
# Update ListView with filtered items
|
||||
list_view.clear()
|
||||
for item in filtered_items:
|
||||
list_view.append(ListItem(Static(item, markup=False)))
|
||||
|
||||
def on_list_view_selected(self, event: ListView.Selected) -> None:
|
||||
# Map the filtered index back to the original index
|
||||
original_index = self.filtered_indices[event.list_view.index]
|
||||
print(f"Selected index: {original_index}")
|
||||
self.dismiss(original_index)
|
||||
|
||||
CSS = """
|
||||
#dialog {
|
||||
background: $surface;
|
||||
padding: 1;
|
||||
border: thick $primary;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
#title {
|
||||
text-align: center;
|
||||
padding: 1;
|
||||
}
|
||||
|
||||
#filter-input {
|
||||
dock: top;
|
||||
margin: 1 0;
|
||||
}
|
||||
|
||||
ListView {
|
||||
height: 100%;
|
||||
border: solid $primary;
|
||||
}
|
||||
|
||||
ListItem {
|
||||
padding: 0 1;
|
||||
}
|
||||
|
||||
ListItem:hover {
|
||||
background: $accent;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class FileViewerScreen(ModalScreen):
|
||||
BINDINGS = [
|
||||
Binding("q,escape", "dismiss", "Back"),
|
||||
Binding("j,down", "scroll_down", "Scroll down"),
|
||||
Binding("k,up", "scroll_up", "Scroll up"),
|
||||
Binding("e", "open_editor", "Open in $EDITOR"),
|
||||
]
|
||||
|
||||
def __init__(self, path: Path):
|
||||
super().__init__()
|
||||
self.path = path
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
with VerticalScroll():
|
||||
text = self.path.read_text()
|
||||
truncated = False
|
||||
if len(text) > 10_000:
|
||||
# More than ~1000 lines
|
||||
self.app.notify(
|
||||
"File is too large to display. Showing first 10k chars. Use e to open in editor.",
|
||||
severity="warning",
|
||||
)
|
||||
text = text[:10_000]
|
||||
truncated = True
|
||||
if self.path.exists():
|
||||
if self.path.suffix == ".traj" and not truncated:
|
||||
# Syntax highlighting breaks if we truncate
|
||||
content_str = _yaml_serialization_with_linebreaks(json.loads(text))
|
||||
syntax = Syntax(content_str, "yaml", theme="monokai", word_wrap=True)
|
||||
yield Static(syntax, markup=False)
|
||||
else:
|
||||
yield Static(text, markup=False)
|
||||
else:
|
||||
yield Static(f"No file found at {self.path}", markup=False)
|
||||
|
||||
def action_scroll_down(self) -> None:
|
||||
vs = self.query_one(VerticalScroll)
|
||||
vs.scroll_to(y=vs.scroll_target_y + 15)
|
||||
|
||||
def action_scroll_up(self) -> None:
|
||||
vs = self.query_one(VerticalScroll)
|
||||
vs.scroll_to(y=vs.scroll_target_y - 15)
|
||||
|
||||
async def action_open_editor(self) -> None:
|
||||
editor = os.environ.get("EDITOR")
|
||||
if not editor:
|
||||
self.app.notify("No editor found in $EDITOR environment variable, cannot perform action", severity="error")
|
||||
return
|
||||
try:
|
||||
# Suspend the TUI app to restore terminal state before launching editor
|
||||
with self.app.suspend():
|
||||
subprocess.run([editor, str(self.path)], check=True)
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
CSS = """
|
||||
ScrollableContainer {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background: $surface;
|
||||
padding: 1;
|
||||
border: thick $primary;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class TrajectoryInspectorApp(App):
|
||||
BINDINGS = [
|
||||
Binding("q", "quit", "Quit"),
|
||||
Binding("L", "next_traj", "Traj++"),
|
||||
Binding("H", "previous_traj", "Traj--"),
|
||||
Binding("t", "show_traj_selector", "Select Traj"),
|
||||
Binding("o", "show_log", "View Log"),
|
||||
Binding("r", "show_full", "Show full"),
|
||||
]
|
||||
|
||||
CSS = """
|
||||
Screen {
|
||||
layout: grid;
|
||||
grid-size: 1;
|
||||
}
|
||||
|
||||
#viewer {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
ScrollView {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
border: solid green;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, input_path: str | Path, data_path: Path | None = None):
|
||||
super().__init__()
|
||||
self.input_path = Path(input_path)
|
||||
if not self.input_path.exists():
|
||||
msg = f"{self.input_path} doesn't exist"
|
||||
raise FileNotFoundError(msg)
|
||||
self.available_traj_paths = self._get_available_trajs()
|
||||
if not self.available_traj_paths:
|
||||
msg = "No trajectory *.traj files available"
|
||||
raise ValueError(msg)
|
||||
self.trajectory_index = 0
|
||||
self.overview_stats = collections.defaultdict(dict)
|
||||
self._build_overview_stats()
|
||||
self._data = load_file(data_path)
|
||||
|
||||
def get_gold_patch(self, instance_id: str) -> str | None:
|
||||
if self._data is None:
|
||||
return None
|
||||
return self._data.get(instance_id, {}).get("patch", None)
|
||||
|
||||
def _build_overview_stats(self):
|
||||
results_path = self.input_path / "results.json"
|
||||
results = None
|
||||
if results_path.exists():
|
||||
results = json.loads(results_path.read_text())
|
||||
for traj in self.available_traj_paths:
|
||||
instance_id = traj.stem
|
||||
if results is None:
|
||||
result = "❓"
|
||||
elif instance_id in results["resolved_ids"]:
|
||||
result = "✅"
|
||||
else:
|
||||
result = "❌"
|
||||
self.overview_stats[instance_id]["result"] = result
|
||||
|
||||
def _get_info(traj: Path) -> tuple[str, dict]:
|
||||
traj_info = json.loads(traj.read_text()).get("info", {})
|
||||
return traj.stem, traj_info
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
# Map returns results in the same order as inputs
|
||||
all_infos = executor.map(_get_info, self.available_traj_paths)
|
||||
|
||||
for instance_id, info in all_infos:
|
||||
self.overview_stats[instance_id]["info"] = info
|
||||
self.overview_stats[instance_id]["exit_status"] = info.get("exit_status", "?")
|
||||
self.overview_stats[instance_id]["api_calls"] = info.get("model_stats", {}).get("api_calls", 0)
|
||||
self.overview_stats[instance_id]["cost"] = info.get("model_stats", {}).get("instance_cost", 0)
|
||||
|
||||
def _get_viewer_title(self, index: int) -> str:
|
||||
instance_id = self.available_traj_paths[index].stem
|
||||
if len(instance_id) > 20:
|
||||
instance_id = "..." + instance_id[-17:]
|
||||
return f"Traj {index + 1}/{len(self.available_traj_paths)} - {instance_id}"
|
||||
|
||||
def _load_traj(self):
|
||||
instance_id = self.available_traj_paths[self.trajectory_index].stem
|
||||
traj_viewer = self.query_one(TrajectoryViewer)
|
||||
traj_viewer.load_trajectory(
|
||||
self.available_traj_paths[self.trajectory_index],
|
||||
self._get_viewer_title(self.trajectory_index),
|
||||
self.overview_stats[instance_id],
|
||||
gold_patch=self.get_gold_patch(instance_id),
|
||||
)
|
||||
|
||||
def _get_available_trajs(self) -> list[Path]:
|
||||
if self.input_path.is_file():
|
||||
return [self.input_path]
|
||||
elif self.input_path.is_dir():
|
||||
return sorted(self.input_path.rglob("*.traj"))
|
||||
raise ValueError
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
with Container():
|
||||
yield TrajectoryViewer(
|
||||
self.available_traj_paths[self.trajectory_index],
|
||||
self._get_viewer_title(self.trajectory_index),
|
||||
self.overview_stats[self.available_traj_paths[self.trajectory_index].stem],
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
def action_next_traj(self):
|
||||
self.trajectory_index = (self.trajectory_index + 1) % len(self.available_traj_paths)
|
||||
self._load_traj()
|
||||
|
||||
def action_previous_traj(self):
|
||||
self.trajectory_index = (self.trajectory_index - 1) % len(self.available_traj_paths)
|
||||
self._load_traj()
|
||||
|
||||
async def action_show_traj_selector(self) -> None:
|
||||
selector = TrajectorySelectorScreen(self.available_traj_paths, self.trajectory_index, self.overview_stats)
|
||||
|
||||
def handler(index: int | None):
|
||||
if index is not None:
|
||||
self.trajectory_index = index
|
||||
self._load_traj()
|
||||
|
||||
await self.push_screen(selector, handler) # This returns when the modal is dismissed
|
||||
|
||||
async def action_show_log(self) -> None:
|
||||
current_traj = self.available_traj_paths[self.trajectory_index]
|
||||
log_path = current_traj.with_suffix(".debug.log")
|
||||
log_viewer = FileViewerScreen(log_path)
|
||||
await self.push_screen(log_viewer)
|
||||
|
||||
async def action_show_full(self) -> None:
|
||||
"""Show full yaml of trajectory file"""
|
||||
current_traj = self.available_traj_paths[self.trajectory_index]
|
||||
viewer = FileViewerScreen(current_traj)
|
||||
await self.push_screen(viewer)
|
||||
|
||||
|
||||
def main(args: list[str] | None = None):
|
||||
parser = argparse.ArgumentParser(description="Inspect trajectory JSON files")
|
||||
parser.add_argument(
|
||||
"trajectory_path",
|
||||
help="Path to the trajectory JSON file or directory containing trajectories",
|
||||
default=os.getcwd(),
|
||||
nargs="?",
|
||||
)
|
||||
parser.add_argument("-d", "--data_path", type=Path, help="Path to the data file to load gold patches from")
|
||||
parsed_args = parser.parse_args(args)
|
||||
|
||||
app = TrajectoryInspectorApp(parsed_args.trajectory_path)
|
||||
app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
64
.agent/vendor/mini-swe/sweagent/run/merge_predictions.py
vendored
Normal file
@@ -0,0 +1,64 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
"""Merge multiple predictions into a single file."""
|
||||
|
||||
|
||||
logger = get_logger("merge", emoji="➕")
|
||||
|
||||
|
||||
def merge_predictions(directories: list[Path], output: Path | None = None) -> None:
|
||||
"""Merge predictions found in `directories` into a single JSON file.
|
||||
|
||||
Args:
|
||||
directory: Directory containing predictions.
|
||||
output: Output file. If not provided, the merged predictions will be
|
||||
written to `directory/preds.json`.
|
||||
"""
|
||||
preds = []
|
||||
for directory in directories:
|
||||
new = list(directory.rglob("*.pred"))
|
||||
preds.extend(new)
|
||||
logger.debug("Found %d predictions in %s", len(new), directory)
|
||||
logger.info("Found %d predictions", len(preds))
|
||||
if not preds:
|
||||
logger.warning("No predictions found in %s", directory)
|
||||
return
|
||||
if output is None:
|
||||
output = directories[0] / "preds.json"
|
||||
data = {}
|
||||
for pred in preds:
|
||||
_data = json.loads(pred.read_text())
|
||||
instance_id = _data["instance_id"]
|
||||
if "model_patch" not in _data:
|
||||
logger.warning("Prediction %s does not contain a model patch. SKIPPING", pred)
|
||||
continue
|
||||
# Ensure model_patch is a string
|
||||
_data["model_patch"] = str(_data["model_patch"]) if _data["model_patch"] is not None else ""
|
||||
if instance_id in data:
|
||||
msg = f"Duplicate instance ID found: {instance_id}"
|
||||
raise ValueError(msg)
|
||||
data[instance_id] = _data
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
output.write_text(json.dumps(data, indent=4))
|
||||
logger.info("Wrote merged predictions to %s", output)
|
||||
|
||||
|
||||
def get_cli_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("directories", type=Path, help="Directory containing predictions", nargs="+")
|
||||
parser.add_argument("--output", type=Path, help="Output file")
|
||||
return parser
|
||||
|
||||
|
||||
def run_from_cli(args: list[str] | None = None) -> None:
|
||||
cli_parser = get_cli_parser()
|
||||
cli_args = cli_parser.parse_args(args)
|
||||
merge_predictions(cli_args.directories, cli_args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_from_cli()
|
||||
96
.agent/vendor/mini-swe/sweagent/run/quick_stats.py
vendored
Normal file
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import collections
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
"""Calculate statistics from .traj files."""
|
||||
|
||||
logger = get_logger("quick-stats", emoji="📊")
|
||||
|
||||
|
||||
def quick_stats(directory: Path | str = ".") -> str:
|
||||
"""Calculate statistics from .traj files.
|
||||
|
||||
Args:
|
||||
directory: Directory to search for .traj files (default: current directory)
|
||||
|
||||
Returns:
|
||||
str: Summary of statistics
|
||||
"""
|
||||
directory = Path(directory)
|
||||
# Find all .traj files
|
||||
traj_files = list(directory.glob("**/*.traj"))
|
||||
|
||||
if not traj_files:
|
||||
logger.warning("No .traj files found in %s", directory)
|
||||
return "No .traj files found."
|
||||
|
||||
# Extract api_calls from each file
|
||||
api_calls = []
|
||||
files_by_exit_status = collections.defaultdict(list)
|
||||
|
||||
for file_path in traj_files:
|
||||
try:
|
||||
data = json.loads(file_path.read_text())
|
||||
# Extract the api_calls value using dictionary path
|
||||
if "info" in data and "model_stats" in data["info"] and "api_calls" in data["info"]["model_stats"]:
|
||||
api_calls.append(data["info"]["model_stats"]["api_calls"])
|
||||
if "info" in data and "exit_status" in data["info"]:
|
||||
status = data["info"]["exit_status"]
|
||||
files_by_exit_status[status].append(file_path)
|
||||
except Exception as e:
|
||||
logger.error("Error processing %s: %s", file_path, e)
|
||||
|
||||
files_by_exit_status = dict(sorted(files_by_exit_status.items(), key=lambda x: len(x[1]), reverse=True))
|
||||
|
||||
if not api_calls:
|
||||
logger.warning("No valid api_calls data found in the .traj files")
|
||||
return "No valid api_calls data found in the .traj files."
|
||||
|
||||
# Calculate and return the average
|
||||
logger.info("Exit statuses:")
|
||||
# Sort exit statuses by count (highest to lowest)
|
||||
for status, files in files_by_exit_status.items():
|
||||
logger.info("%s: %d", status, len(files))
|
||||
|
||||
average_api_calls = np.mean(api_calls)
|
||||
logger.info("Avg api calls: %s", average_api_calls)
|
||||
|
||||
# Print exit statuses in the requested format
|
||||
result = []
|
||||
for status, files in files_by_exit_status.items():
|
||||
result.append(f"\n## `{status}`\n")
|
||||
# Extract unique subdirectories instead of full paths
|
||||
subdirs = {str(Path(file_path).parent) for file_path in files}
|
||||
result.append(" ".join(subdirs))
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def get_cli_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"directory",
|
||||
type=Path,
|
||||
nargs="?",
|
||||
default=Path("."),
|
||||
help="Directory to search for .traj files (default: current directory)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def run_from_cli(args: list[str] | None = None) -> None:
|
||||
cli_parser = get_cli_parser()
|
||||
cli_args = cli_parser.parse_args(args)
|
||||
|
||||
result = quick_stats(cli_args.directory)
|
||||
print(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_from_cli()
|
||||
63
.agent/vendor/mini-swe/sweagent/run/remove_unfinished.py
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Remove unfinished trajectories."""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from sweagent.utils.files import load_file
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
logger = get_logger("remove_unfinished")
|
||||
|
||||
|
||||
def remove_unfinished(base_dir: Path, dry_run: bool = True) -> None:
|
||||
"""Remove unfinished trajectories."""
|
||||
to_remove = []
|
||||
for directory in base_dir.iterdir():
|
||||
if not directory.is_dir():
|
||||
continue
|
||||
if "__" not in directory.name:
|
||||
continue
|
||||
trajs = list(directory.glob("*.traj"))
|
||||
if not trajs:
|
||||
logger.info("No trajectories found in %s", directory)
|
||||
continue
|
||||
if len(trajs) > 1:
|
||||
logger.warning("Found multiple trajectories in %s. Skipping.", directory)
|
||||
continue
|
||||
try:
|
||||
traj = load_file(trajs[0])
|
||||
except Exception as e:
|
||||
logger.warning("Error loading trajectory %s: %s. Adding to remove list.", trajs[0], e)
|
||||
to_remove.append(directory)
|
||||
continue
|
||||
submission = traj.get("info", {}).get("submission", None)
|
||||
if submission is None:
|
||||
logger.warning("No submission found in %s. Adding to remove list.", directory)
|
||||
to_remove.append(directory)
|
||||
continue
|
||||
if dry_run:
|
||||
logger.info("Would remove %d unfinished trajectories.", len(to_remove))
|
||||
for directory in to_remove:
|
||||
logger.info(directory)
|
||||
else:
|
||||
for directory in to_remove:
|
||||
logger.info("Removing %s", directory)
|
||||
shutil.rmtree(directory)
|
||||
|
||||
|
||||
def get_cli_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--base_dir", type=Path, help="Base directory", default=Path("."))
|
||||
parser.add_argument("--remove", action="store_true", help="Remove unfinished trajectories")
|
||||
return parser
|
||||
|
||||
|
||||
def run_from_cli(args: list[str] | None = None) -> None:
|
||||
cli_parser = get_cli_parser()
|
||||
cli_args = cli_parser.parse_args(args)
|
||||
remove_unfinished(cli_args.base_dir, dry_run=not cli_args.remove)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_from_cli()
|
||||
91
.agent/vendor/mini-swe/sweagent/run/rich_test.py
vendored
Normal file
@@ -0,0 +1,91 @@
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from random import random
|
||||
from threading import Lock
|
||||
|
||||
from rich.console import Group
|
||||
from rich.live import Live
|
||||
from rich.logging import RichHandler
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskID,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
)
|
||||
|
||||
logging.basicConfig(level="NOTSET", handlers=[RichHandler(level="NOTSET")])
|
||||
logger = logging.getLogger("rich")
|
||||
|
||||
# Lock for thread-safe progress updates
|
||||
progress_lock = Lock()
|
||||
|
||||
|
||||
class RunBatch:
|
||||
def __init__(self):
|
||||
self.tasks = list(range(10)) # Reduced to 10 tasks for example clarity
|
||||
self._main_progress_bar: Progress | None = None
|
||||
self._task_progress_bar: Progress | None = None
|
||||
self._spinner_tasks: dict[TaskID, TaskID] = {}
|
||||
|
||||
def do_task(self, task_id: TaskID):
|
||||
assert self._main_progress_bar is not None
|
||||
assert self._task_progress_bar is not None
|
||||
|
||||
# Create a spinner for this task
|
||||
with progress_lock:
|
||||
spinner_task_id = self._task_progress_bar.add_task(f"Task {task_id}", total=None)
|
||||
|
||||
logger.info("Starting task %d", task_id)
|
||||
# Startup
|
||||
time.sleep(random() * 4.5)
|
||||
# Work
|
||||
with progress_lock:
|
||||
self._task_progress_bar.update(spinner_task_id, description=f"Task {task_id} (working)")
|
||||
time.sleep(random() * 4.5 + 2)
|
||||
logger.info("Finished task %d", task_id)
|
||||
|
||||
# Remove spinner and update main progress
|
||||
with progress_lock:
|
||||
self._task_progress_bar.remove_task(spinner_task_id)
|
||||
self._main_progress_bar.update(TaskID(0), advance=1)
|
||||
|
||||
def main(self):
|
||||
# Custom progress columns
|
||||
self._main_progress_bar = Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
TimeElapsedColumn(),
|
||||
TimeRemainingColumn(),
|
||||
)
|
||||
self._task_progress_bar = Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
TimeElapsedColumn(),
|
||||
)
|
||||
|
||||
group = Group(self._main_progress_bar, self._task_progress_bar)
|
||||
|
||||
with Live(group):
|
||||
# Add main progress bar
|
||||
self._main_task_id = self._main_progress_bar.add_task("[cyan]Overall Progress", total=len(self.tasks))
|
||||
|
||||
# Create thread pool and run tasks
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
# Submit all tasks
|
||||
futures = [executor.submit(self.do_task, task_id) for task_id in self.tasks]
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_batch = RunBatch()
|
||||
run_batch.main()
|
||||
147
.agent/vendor/mini-swe/sweagent/run/run.py
vendored
Normal file
@@ -0,0 +1,147 @@
|
||||
"""[cyan][bold]Main command line interface for SWE-agent.[/bold][/cyan]
|
||||
|
||||
[cyan][bold]=== USAGE ===[/bold][/cyan]
|
||||
|
||||
[green]sweagent <command> [options][/green]
|
||||
|
||||
Display usage instructions for a specific command:
|
||||
|
||||
[green]sweagent <command> [bold]--help[/bold][/green]
|
||||
|
||||
[cyan][bold]=== SUBCOMMANDS TO RUN SWE-AGENT ===[/bold][/cyan]
|
||||
|
||||
[bold][green]run[/green][/bold] or [bold][green]r[/green][/bold]: Run swe-agent on a single problem statement, for example a github issue.
|
||||
[bold][green]run-batch[/green][/bold] or [bold][green]b[/green][/bold]: Run swe-agent on a batch of problem statements, e.g., on SWE-Bench.
|
||||
|
||||
[cyan][bold]=== MISC SUBCOMMANDS ===[/bold][/cyan]
|
||||
|
||||
[bold][green]merge-preds[/green][/bold]: Merge multiple prediction files into a single file. In most cases
|
||||
[green]run-batch[/green] will already do this, but you can use this to merge predictions
|
||||
from multiple directories.
|
||||
[bold][green]inspect[/green][/bold] or [bold][green]i[/green][/bold]: Open a single trajectory file in a terminal-based viewer.
|
||||
[bold][green]inspector[/green][/bold] or [bold][green]I[/green][/bold]: Open trajectories in a web-based viewer.
|
||||
[bold][green]run-replay[/green][/bold]: Replay a trajectory file or a demo file.
|
||||
This can be useful to fill in environment output when creating demonstrations.
|
||||
[bold][green]traj-to-demo[/green][/bold]: Convert a trajectory file to an easy to edit demo file.
|
||||
[bold][green]run-api[/green][/bold]: Run swe-agent as a backend for a GUI
|
||||
[bold][green]remove-unfinished[/green][/bold] or [bold][green]ru[/green][/bold]: Remove unfinished trajectories
|
||||
[bold][green]quick-stats[/green][/bold] or [bold][green]qs[/green][/bold]: Calculate quick stats from a directory of trajectories
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import rich
|
||||
|
||||
|
||||
def get_cli():
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument(
|
||||
"command",
|
||||
choices=[
|
||||
"run",
|
||||
"run-batch",
|
||||
"run-replay",
|
||||
"traj-to-demo",
|
||||
"run-api",
|
||||
"merge-preds",
|
||||
"inspect",
|
||||
"inspector",
|
||||
"r",
|
||||
"b",
|
||||
"i",
|
||||
"I",
|
||||
"extract-pred",
|
||||
"compare-runs",
|
||||
"cr",
|
||||
"remove-unfinished",
|
||||
"ru",
|
||||
"quick-stats",
|
||||
"qs",
|
||||
"shell",
|
||||
"sh",
|
||||
],
|
||||
nargs="?",
|
||||
)
|
||||
parser.add_argument("-h", "--help", action="store_true", help="Show this help message and exit")
|
||||
return parser
|
||||
|
||||
|
||||
def main(args: list[str] | None = None):
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
cli = get_cli()
|
||||
parsed_args, remaining_args = cli.parse_known_args(args) # type: ignore
|
||||
command = parsed_args.command
|
||||
show_help = parsed_args.help
|
||||
if show_help:
|
||||
if not command:
|
||||
# Show main help
|
||||
rich.print(__doc__)
|
||||
sys.exit(0)
|
||||
else:
|
||||
# Add to remaining_args
|
||||
remaining_args.append("--help")
|
||||
elif not command:
|
||||
cli.print_help()
|
||||
sys.exit(2)
|
||||
# Defer imports to avoid unnecessary long loading times
|
||||
if command in ["run", "r"]:
|
||||
from sweagent.run.run_single import run_from_cli as run_single_main
|
||||
|
||||
run_single_main(remaining_args)
|
||||
elif command in ["run-batch", "b"]:
|
||||
from sweagent.run.run_batch import run_from_cli as run_batch_main
|
||||
|
||||
run_batch_main(remaining_args)
|
||||
elif command == "run-replay":
|
||||
from sweagent.run.run_replay import run_from_cli as run_replay_main
|
||||
|
||||
run_replay_main(remaining_args)
|
||||
elif command == "traj-to-demo":
|
||||
from sweagent.run.run_traj_to_demo import run_from_cli as convert_traj_to_demo_main
|
||||
|
||||
convert_traj_to_demo_main(remaining_args)
|
||||
elif command == "run-api":
|
||||
from sweagent.api.server import run_from_cli as run_api_main
|
||||
|
||||
run_api_main(remaining_args)
|
||||
elif command == "merge-preds":
|
||||
from sweagent.run.merge_predictions import run_from_cli as merge_predictions_main
|
||||
|
||||
merge_predictions_main(remaining_args)
|
||||
elif command in ["inspector", "I"]:
|
||||
from sweagent.inspector.server import run_from_cli as inspector_main
|
||||
|
||||
inspector_main(remaining_args)
|
||||
elif command in ["inspect", "i"]:
|
||||
from sweagent.run.inspector_cli import main as inspect_main
|
||||
|
||||
inspect_main(remaining_args)
|
||||
elif command == "extract-pred":
|
||||
from sweagent.run.extract_pred import run_from_cli as extract_pred_main
|
||||
|
||||
extract_pred_main(remaining_args)
|
||||
elif command in ["compare-runs", "cr"]:
|
||||
from sweagent.run.compare_runs import run_from_cli as compare_runs_main
|
||||
|
||||
compare_runs_main(remaining_args)
|
||||
elif command in ["remove-unfinished", "ru"]:
|
||||
from sweagent.run.remove_unfinished import run_from_cli as remove_unfinished_main
|
||||
|
||||
remove_unfinished_main(remaining_args)
|
||||
elif command in ["quick-stats", "qs"]:
|
||||
from sweagent.run.quick_stats import run_from_cli as quick_stats_main
|
||||
|
||||
quick_stats_main(remaining_args)
|
||||
elif command in ["shell", "sh"]:
|
||||
from sweagent.run.run_shell import run_from_cli as run_shell_main
|
||||
|
||||
run_shell_main(remaining_args)
|
||||
else:
|
||||
msg = f"Unknown command: {command}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
442
.agent/vendor/mini-swe/sweagent/run/run_batch.py
vendored
Normal file
@@ -0,0 +1,442 @@
|
||||
"""
|
||||
Run on a batch of instances/issues, e.g., SWE-bench.
|
||||
|
||||
[cyan][bold]=== BASIC OPTIONS ===[/bold][/cyan]
|
||||
|
||||
-h --help Show help text and exit
|
||||
--help_option Print specific help text and exit
|
||||
|
||||
[cyan][bold]=== EXAMPLES ===[/bold][/cyan]
|
||||
|
||||
Basic usage: Run over a [bold][cyan]SWE-bench lite[/bold][/cyan][green]:
|
||||
|
||||
sweagent run-batch \\
|
||||
--instances.type swe_bench \\ # configure instances
|
||||
--instances.subset lite \\
|
||||
--instances.split dev \\
|
||||
--instances.slice :50 \\ # first 50 instances
|
||||
--instances.shuffle=True \\ # shuffle instances (with fixed seed)
|
||||
--config config/default.yaml \\
|
||||
--agent.model.name gpt-4o # configure model
|
||||
[/green]
|
||||
|
||||
[cyan][bold]=== LOADING INSTANCES ===[/bold][/cyan]
|
||||
|
||||
[cyan][bold]From a file[/bold][/cyan] [green]--instances.type file --instances.path /path/to/file[/green].
|
||||
[cyan][bold]From huggingface[/bold][/cyan] [green]--instances.type huggingface --instances.dataset_name=SWE_Bench_lite --instances.split=dev[/green].
|
||||
|
||||
All instance specifications support the [green]filter[/green], [green]slice[/green], and [green]shuffle[/green] options.
|
||||
With [green]filter[/green], you can select specific instances, e.g., [green]--instances.filter='instance_id_1|instance_id_2'[/green].
|
||||
"""
|
||||
|
||||
import getpass
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from contextlib import ExitStack
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
|
||||
import yaml
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from rich.live import Live
|
||||
from swerex.deployment.hooks.status import SetStatusDeploymentHook
|
||||
|
||||
from sweagent import TRAJECTORY_DIR
|
||||
from sweagent.agent.agents import AgentConfig, get_agent_from_config
|
||||
from sweagent.agent.hooks.status import SetStatusAgentHook
|
||||
from sweagent.environment.hooks.status import SetStatusEnvironmentHook
|
||||
from sweagent.environment.swe_env import SWEEnv
|
||||
from sweagent.exceptions import ModelConfigurationError, TotalCostLimitExceededError
|
||||
from sweagent.run._progress import RunBatchProgressManager
|
||||
from sweagent.run.batch_instances import BatchInstance, BatchInstanceSourceConfig, SWEBenchInstances
|
||||
from sweagent.run.common import BasicCLI, ConfigHelper, save_predictions
|
||||
from sweagent.run.hooks.abstract import CombinedRunHooks, RunHook
|
||||
from sweagent.run.hooks.apply_patch import SaveApplyPatchHook
|
||||
from sweagent.run.merge_predictions import merge_predictions
|
||||
from sweagent.run.run_single import RunSingleConfig
|
||||
from sweagent.types import AgentRunResult
|
||||
from sweagent.utils.config import load_environment_variables
|
||||
from sweagent.utils.log import (
|
||||
add_file_handler,
|
||||
add_logger_names_to_stream_handlers,
|
||||
get_logger,
|
||||
register_thread_name,
|
||||
remove_file_handler,
|
||||
set_stream_handler_levels,
|
||||
)
|
||||
|
||||
|
||||
class RunBatchConfig(BaseSettings, cli_implicit_flags=False):
|
||||
instances: BatchInstanceSourceConfig = Field(description="Instances to run.")
|
||||
agent: AgentConfig = Field(description="Agent options.")
|
||||
output_dir: Path = Field(default=Path("DEFAULT"), description="Output directory.")
|
||||
suffix: str = ""
|
||||
"""Suffix to add to the output directory. Only used if `output_dir` is `DEFAULT`."""
|
||||
raise_exceptions: bool = False
|
||||
"""Raise exceptions instead of skipping instances."""
|
||||
redo_existing: bool = False
|
||||
"""Do not skip instances that already have a trajectory."""
|
||||
env_var_path: Path | None = None
|
||||
"""Path to a .env file to load environment variables from."""
|
||||
num_workers: int = Field(default=1)
|
||||
"""Number of parallel workers to use."""
|
||||
random_delay_multiplier: float = 0.3
|
||||
"""We will wait for a random amount of time between 0 and `random_delay_multiplier`
|
||||
times the number of workers at the start of each instance. This is to avoid any
|
||||
potential race condition or issues with bottlenecks, e.g., when running on a platform
|
||||
with few CPUs that cannot handle the startup of all containers in time.
|
||||
"""
|
||||
progress_bar: bool = True
|
||||
"""Whether to show a progress bar. Progress bar is never shown for human models.
|
||||
Progress bar is always shown for multi-worker runs.
|
||||
"""
|
||||
|
||||
# pydantic config
|
||||
model_config = SettingsConfigDict(extra="forbid", env_prefix="SWE_AGENT_")
|
||||
|
||||
def set_default_output_dir(self) -> None:
|
||||
# Needs to be called explicitly, because self._config_files will be setup
|
||||
# post-init.
|
||||
if self.output_dir == Path("DEFAULT"):
|
||||
user_id = getpass.getuser()
|
||||
source_id = self.instances.id
|
||||
try:
|
||||
model_id = self.agent.model.id # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
model_id = "unknown"
|
||||
config_file = getattr(self, "_config_files", ["no_config"])[0]
|
||||
if config_file != "no_config":
|
||||
config_file = Path(config_file).stem
|
||||
suffix = f"__{self.suffix}" if self.suffix else ""
|
||||
self.output_dir = TRAJECTORY_DIR / user_id / f"{config_file}__{model_id}___{source_id}{suffix}"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def evaluate_and_redo_existing(self) -> Self:
|
||||
if not isinstance(self.instances, SWEBenchInstances):
|
||||
return self
|
||||
if self.instances.evaluate and self.redo_existing:
|
||||
msg = (
|
||||
"Cannot evaluate and redo existing at the same time. This would cause invalid results, because "
|
||||
"after the first merge_preds gives you a preds.json, this file would be submitted to SB-CLI, causing"
|
||||
"evaluation of old instances, which could then not be overwritten by the new ones."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return self
|
||||
|
||||
|
||||
class _BreakLoop(Exception):
|
||||
"""Used for internal control flow"""
|
||||
|
||||
|
||||
class RunBatch:
|
||||
def __init__(
|
||||
self,
|
||||
instances: list[BatchInstance],
|
||||
agent_config: AgentConfig,
|
||||
*,
|
||||
output_dir: Path = Path("."),
|
||||
hooks: list[RunHook] | None = None,
|
||||
raise_exceptions: bool = False,
|
||||
redo_existing: bool = False,
|
||||
num_workers: int = 1,
|
||||
progress_bar: bool = True,
|
||||
random_delay_multiplier: float = 0.3,
|
||||
):
|
||||
"""Note: When initializing this class, make sure to add the hooks that are required by your actions.
|
||||
See `from_config` for an example.
|
||||
|
||||
Args:
|
||||
hooks: If not specified, the default hooks will be used.
|
||||
num_workers: Number of parallel workers to use. Default is 1 (sequential execution).
|
||||
progress_bar: Whether to show a progress bar. Progress bar is never shown for human models.
|
||||
Progress bar is always shown for multi-worker runs.
|
||||
random_delay_multiplier: We will wait for a random amount of time between 0 and `random_delay_multiplier`
|
||||
times the number of workers at the start of each instance. This is to avoid any
|
||||
potential race conditions.
|
||||
"""
|
||||
if self._model_id in ["human", "human_thought"] and num_workers > 1:
|
||||
msg = "Cannot run with human model in parallel"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.logger = get_logger("swea-run", emoji="🏃")
|
||||
add_file_handler(
|
||||
output_dir / "run_batch.log",
|
||||
id_="progress",
|
||||
filter=lambda name: "swea-run" in name or "config" in name,
|
||||
)
|
||||
self.instances = instances
|
||||
self.agent_config = agent_config
|
||||
self.output_dir = output_dir
|
||||
self._raise_exceptions = raise_exceptions
|
||||
self._chooks = CombinedRunHooks()
|
||||
self._redo_existing = redo_existing
|
||||
self._num_workers = min(num_workers, len(instances))
|
||||
for hook in hooks or [SaveApplyPatchHook(show_success_message=False)]:
|
||||
self.add_hook(hook)
|
||||
self._progress_manager = RunBatchProgressManager(
|
||||
num_instances=len(instances), yaml_report_path=output_dir / "run_batch_exit_statuses.yaml"
|
||||
)
|
||||
self._show_progress_bar = progress_bar
|
||||
self._random_delay_multiplier = random_delay_multiplier
|
||||
|
||||
@property
|
||||
def _model_id(self) -> str:
|
||||
try:
|
||||
return self.agent_config.model.id # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
return "unknown"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: RunBatchConfig) -> Self:
|
||||
load_environment_variables(config.env_var_path)
|
||||
config.set_default_output_dir()
|
||||
config.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
(config.output_dir / "run_batch.config.yaml").write_text(yaml.dump(config.model_dump_json(), indent=2))
|
||||
logger = get_logger("run", emoji="🏃")
|
||||
logger.debug("Loading instances from %s", f"{config.instances!r}")
|
||||
instances = config.instances.get_instance_configs()
|
||||
logger.info("Loaded %d instances", len(instances))
|
||||
if not instances:
|
||||
msg = (
|
||||
"No instances to run. Here are a few things to check:\n"
|
||||
"- With huggingface data: Check that you have the right split (test or dev)\n"
|
||||
"- Check your filter does not exclude all instances (check the info log messages)"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
logger.debug("The first instance is %s", f"{instances[0]!r}")
|
||||
rb = cls(
|
||||
instances=instances,
|
||||
agent_config=config.agent,
|
||||
output_dir=config.output_dir,
|
||||
raise_exceptions=config.raise_exceptions,
|
||||
redo_existing=config.redo_existing,
|
||||
num_workers=config.num_workers,
|
||||
progress_bar=config.progress_bar,
|
||||
random_delay_multiplier=config.random_delay_multiplier,
|
||||
)
|
||||
if isinstance(config.instances, SWEBenchInstances) and config.instances.evaluate:
|
||||
from sweagent.run.hooks.swe_bench_evaluate import SweBenchEvaluate
|
||||
|
||||
rb.add_hook(
|
||||
SweBenchEvaluate(
|
||||
output_dir=config.output_dir,
|
||||
subset=config.instances.subset,
|
||||
split=config.instances.split,
|
||||
continuous_submission_every=30,
|
||||
)
|
||||
)
|
||||
return rb
|
||||
|
||||
def add_hook(self, hook: RunHook) -> None:
|
||||
hook.on_init(run=self)
|
||||
self._chooks.add_hook(hook)
|
||||
|
||||
def main(self) -> None:
|
||||
self.logger.info("Starting run. Find output files at %s", self.output_dir)
|
||||
self._chooks.on_start()
|
||||
|
||||
if self._num_workers <= 1:
|
||||
self.main_single_worker()
|
||||
else:
|
||||
self.main_multi_worker()
|
||||
|
||||
output_dirs = []
|
||||
for instance in self.instances:
|
||||
output_dirs.append(self.output_dir / instance.problem_statement.id)
|
||||
merge_predictions(output_dirs, self.output_dir / "preds.json")
|
||||
|
||||
self._chooks.on_end()
|
||||
|
||||
def main_single_worker(self) -> None:
|
||||
with ExitStack() as stack:
|
||||
# Conditionally add progress bar
|
||||
if self._model_id not in ["human", "human_thought"] and self._show_progress_bar:
|
||||
stack.enter_context(Live(self._progress_manager.render_group))
|
||||
for instance in self.instances:
|
||||
try:
|
||||
self.run_instance(instance)
|
||||
except _BreakLoop:
|
||||
self.logger.info("Stopping loop over instances")
|
||||
break
|
||||
|
||||
def main_multi_worker(self) -> None:
|
||||
add_logger_names_to_stream_handlers()
|
||||
# Set all stream handlers to WARNING and set everything where we want to have
|
||||
# more verbosity explicitly
|
||||
set_stream_handler_levels(logging.WARNING)
|
||||
self.logger.setLevel(logging.TRACE) # type: ignore
|
||||
|
||||
with Live(self._progress_manager.render_group):
|
||||
with ThreadPoolExecutor(max_workers=self._num_workers) as executor:
|
||||
futures = [executor.submit(self.run_instance, instance) for instance in self.instances]
|
||||
try:
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
except (KeyboardInterrupt, _BreakLoop):
|
||||
msg = (
|
||||
"Received keyboard interrupt, waiting for running instances "
|
||||
"to finish, but cancelled everything else"
|
||||
)
|
||||
self.logger.info(msg)
|
||||
executor.shutdown(wait=False, cancel_futures=True)
|
||||
finally:
|
||||
self._progress_manager.print_report()
|
||||
|
||||
def run_instance(self, instance: BatchInstance) -> None:
|
||||
self.logger.info("Running on instance %s", instance.problem_statement.id)
|
||||
register_thread_name(instance.problem_statement.id)
|
||||
self._add_instance_log_file_handlers(instance.problem_statement.id, multi_worker=self._num_workers > 1)
|
||||
# Let's add some randomness to avoid any potential race conditions or thundering herd
|
||||
if self._progress_manager.n_completed < self._num_workers:
|
||||
time.sleep(random.random() * self._random_delay_multiplier * (self._num_workers - 1))
|
||||
|
||||
self._progress_manager.on_instance_start(instance.problem_statement.id)
|
||||
|
||||
if previous_exit_status := self.should_skip(instance):
|
||||
self._progress_manager.on_instance_end(
|
||||
instance.problem_statement.id, exit_status=f"skipped ({previous_exit_status})"
|
||||
)
|
||||
self._remove_instance_log_file_handlers(instance.problem_statement.id)
|
||||
return
|
||||
|
||||
# Either catch and silence exception, or raise _BreakLoop to stop the loop
|
||||
# over the instances
|
||||
try:
|
||||
result = self._run_instance(instance)
|
||||
except KeyboardInterrupt:
|
||||
raise _BreakLoop
|
||||
except (SystemExit, ModelConfigurationError, TotalCostLimitExceededError) as e:
|
||||
if self._raise_exceptions:
|
||||
raise
|
||||
self.logger.critical(f"❌ Exiting because {e.__class__.__name__} was called")
|
||||
raise _BreakLoop
|
||||
except Exception as e:
|
||||
self.logger.error(traceback.format_exc())
|
||||
self.logger.error(f"❌ Failed on {instance.problem_statement.id}: {e}")
|
||||
self._progress_manager.on_uncaught_exception(instance.problem_statement.id, e)
|
||||
if self._raise_exceptions:
|
||||
raise
|
||||
else:
|
||||
self._progress_manager.on_instance_end(
|
||||
instance.problem_statement.id, exit_status=result.info.get("exit_status", "unknown_exit")
|
||||
)
|
||||
finally:
|
||||
self._progress_manager.update_exit_status_table()
|
||||
self._remove_instance_log_file_handlers(instance.problem_statement.id)
|
||||
|
||||
def _run_instance(self, instance: BatchInstance) -> AgentRunResult:
|
||||
output_dir = Path(self.output_dir) / instance.problem_statement.id
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.agent_config.name = f"{instance.problem_statement.id}"
|
||||
agent = get_agent_from_config(self.agent_config)
|
||||
single_run_replay_config = RunSingleConfig(
|
||||
agent=self.agent_config,
|
||||
problem_statement=instance.problem_statement,
|
||||
env=instance.env,
|
||||
)
|
||||
(output_dir / f"{instance.problem_statement.id}.config.yaml").write_text(
|
||||
yaml.dump(single_run_replay_config.model_dump_json(), indent=2)
|
||||
)
|
||||
agent.replay_config = single_run_replay_config # type: ignore[attr-defined]
|
||||
agent.add_hook(SetStatusAgentHook(instance.problem_statement.id, self._progress_manager.update_instance_status))
|
||||
self._progress_manager.update_instance_status(instance.problem_statement.id, "Starting environment")
|
||||
instance.env.name = f"{instance.problem_statement.id}"
|
||||
env = SWEEnv.from_config(instance.env)
|
||||
env.add_hook(
|
||||
SetStatusEnvironmentHook(instance.problem_statement.id, self._progress_manager.update_instance_status)
|
||||
)
|
||||
env.deployment.add_hook(
|
||||
SetStatusDeploymentHook(instance.problem_statement.id, self._progress_manager.update_instance_status)
|
||||
)
|
||||
try:
|
||||
env.start()
|
||||
self._chooks.on_instance_start(index=0, env=env, problem_statement=instance.problem_statement)
|
||||
result = agent.run(
|
||||
problem_statement=instance.problem_statement,
|
||||
env=env,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
except Exception:
|
||||
# The actual handling is happening in `run_instance`, but we need to make sure that
|
||||
# we log it to the agent specific logger as well
|
||||
agent.logger.error(traceback.format_exc()) # type: ignore[attr-defined]
|
||||
raise
|
||||
finally:
|
||||
env.close()
|
||||
save_predictions(self.output_dir, instance.problem_statement.id, result)
|
||||
self._chooks.on_instance_completed(result=result)
|
||||
return result
|
||||
|
||||
def should_skip(self, instance: BatchInstance) -> bool | str:
|
||||
"""Check if we should skip this instance.
|
||||
Returns previous exit status if the instance should be skipped.
|
||||
"""
|
||||
if self._redo_existing:
|
||||
return False
|
||||
|
||||
# Check if there's an existing trajectory for this instance
|
||||
log_path = self.output_dir / instance.problem_statement.id / (instance.problem_statement.id + ".traj")
|
||||
if not log_path.exists():
|
||||
return False
|
||||
|
||||
content = log_path.read_text()
|
||||
if not content.strip():
|
||||
self.logger.warning("Found empty trajectory: %s. Removing.", log_path)
|
||||
log_path.unlink()
|
||||
return False
|
||||
|
||||
try:
|
||||
data = json.loads(content)
|
||||
# If the trajectory has no exit status, it's incomplete and we will redo it
|
||||
exit_status = data["info"].get("exit_status", None)
|
||||
if exit_status == "early_exit" or exit_status is None:
|
||||
self.logger.warning(f"Found existing trajectory with no exit status: {log_path}. Removing.")
|
||||
log_path.unlink()
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to check existing trajectory: {log_path}: {e}. Removing.")
|
||||
# If we can't check the trajectory, we will redo it
|
||||
log_path.unlink()
|
||||
return False
|
||||
# otherwise, we will skip it
|
||||
self.logger.info(f"⏭️ Skipping existing trajectory: {log_path}")
|
||||
return exit_status
|
||||
|
||||
def _add_instance_log_file_handlers(self, instance_id: str, multi_worker: bool = False) -> None:
|
||||
filename_template = f"{instance_id}.{{level}}.log"
|
||||
for level in ["trace", "debug", "info"]:
|
||||
filter = instance_id if multi_worker else ""
|
||||
add_file_handler(
|
||||
self.output_dir / instance_id / filename_template.format(level=level),
|
||||
filter=filter,
|
||||
level=level,
|
||||
id_=f"{instance_id}-{level}",
|
||||
)
|
||||
|
||||
def _remove_instance_log_file_handlers(self, instance_id: str) -> None:
|
||||
for level in ["trace", "debug", "info"]:
|
||||
remove_file_handler(f"{instance_id}-{level}")
|
||||
|
||||
|
||||
def run_from_config(config: RunBatchConfig):
|
||||
RunBatch.from_config(config).main()
|
||||
|
||||
|
||||
def run_from_cli(args: list[str] | None = None):
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
assert __doc__ is not None
|
||||
help_text = ( # type: ignore
|
||||
__doc__ + "\n[cyan][bold]=== ALL THE OPTIONS ===[/bold][/cyan]\n\n" + ConfigHelper().get_help(RunBatchConfig)
|
||||
)
|
||||
run_from_config(BasicCLI(RunBatchConfig, help_text=help_text).get_config(args)) # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_from_cli()
|
||||
219
.agent/vendor/mini-swe/sweagent/run/run_replay.py
vendored
Normal file
@@ -0,0 +1,219 @@
|
||||
"""[cyan][bold]Replay a trajectory file.[/bold][/cyan]
|
||||
|
||||
[cyan][bold]=== DESCRIPTION ===[/bold][/cyan]
|
||||
|
||||
We will take all actions in the trajectory and execute them in an environment.
|
||||
|
||||
This has two main use cases:
|
||||
|
||||
1. Create a demo from a yaml file containing actions (can also be created from a trajectory file with [green]sweagent run traj-to-demo[/green]).
|
||||
[green]run-replay[/green] will execute the actions to get the environment output and produce a full trajectory to be used as a demo.
|
||||
2. Debugging and testing of tools and environment behavior.
|
||||
|
||||
[cyan][bold]=== EXAMPLES ===[/bold][/cyan]
|
||||
|
||||
Replay a trajectory file:
|
||||
|
||||
[green]sweagent run replay --traj_path mytraj.traj[/green]
|
||||
|
||||
Replay a demo file:
|
||||
|
||||
[green]sweagent run replay --traj_path mydemo.demo.yaml[/green]
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from getpass import getuser
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from swerex.deployment.abstract import AbstractDeployment
|
||||
from swerex.deployment.config import DeploymentConfig, get_deployment
|
||||
from typing_extensions import Self
|
||||
|
||||
from sweagent.agent.agents import DefaultAgent
|
||||
from sweagent.agent.models import ReplayModelConfig
|
||||
from sweagent.environment.swe_env import SWEEnv
|
||||
from sweagent.run.common import BasicCLI, ConfigHelper
|
||||
from sweagent.run.run_single import RunSingle, RunSingleConfig
|
||||
from sweagent.utils.config import load_environment_variables
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
|
||||
class RunReplayConfig(BaseSettings, cli_implicit_flags=False):
|
||||
traj_path: Path
|
||||
deployment: DeploymentConfig | None = None
|
||||
"""Override the deployment in the trajectory."""
|
||||
output_dir: Path = Path("DEFAULT")
|
||||
env_var_path: Path | None = None
|
||||
"""Path to a .env file to load environment variables from."""
|
||||
update_config: list[Path] = []
|
||||
"""Additional config files to merge with the replay config."""
|
||||
|
||||
# pydantic config
|
||||
model_config = SettingsConfigDict(extra="forbid", env_prefix="SWE_AGENT_")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.output_dir == Path("DEFAULT"):
|
||||
user_id = getuser()
|
||||
self.output_dir = Path.cwd() / "trajectories" / user_id / f"replay___{self.traj_path.stem}"
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class RunReplay:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
traj_path: Path,
|
||||
deployment: AbstractDeployment | None,
|
||||
output_dir: Path,
|
||||
update_config: list[Path] | None = None,
|
||||
_catch_errors: bool = False,
|
||||
_require_zero_exit_code: bool = False,
|
||||
):
|
||||
self.traj_path = traj_path
|
||||
self.output_dir = output_dir
|
||||
self._replay_action_trajs_path = Path(tempfile.NamedTemporaryFile(suffix=".json").name)
|
||||
self.logger = get_logger("swea-run", emoji="🏃")
|
||||
self._catch_errors = _catch_errors
|
||||
self._require_zero_exit_code = _require_zero_exit_code
|
||||
self._update_config = update_config if update_config is not None else []
|
||||
|
||||
if traj_path.suffix == ".yaml":
|
||||
self._traj_data = yaml.safe_load(traj_path.read_text())
|
||||
else:
|
||||
self._traj_data = json.loads(traj_path.read_text())
|
||||
self.config = self._get_config_from_agent(self._traj_data)
|
||||
|
||||
if deployment is None:
|
||||
self.deployment = get_deployment(self.config.env.deployment)
|
||||
else:
|
||||
self.deployment = deployment
|
||||
|
||||
def _get_config_from_agent(self, traj_data):
|
||||
try:
|
||||
if isinstance(traj_data["replay_config"], str):
|
||||
traj_data["replay_config"] = json.loads(traj_data["replay_config"])
|
||||
config = RunSingleConfig.model_validate(traj_data["replay_config"])
|
||||
except KeyError:
|
||||
msg = "Replay config not found in trajectory. Are you running on an old trajectory?"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Merge any additional config files
|
||||
for config_path in self._update_config:
|
||||
update_data = yaml.safe_load(config_path.read_text())
|
||||
# Store the current model config before merging
|
||||
current_model = config.agent.model
|
||||
# Convert the merged data back to a RunSingleConfig
|
||||
config_dict = config.model_dump(mode="json")
|
||||
merged_dict = config_dict | update_data
|
||||
|
||||
# Ensure agent.model is preserved if not explicitly updated
|
||||
if "agent" in merged_dict and "model" not in merged_dict["agent"]:
|
||||
merged_dict["agent"]["model"] = current_model.model_dump(mode="json")
|
||||
|
||||
config = RunSingleConfig.model_validate(merged_dict)
|
||||
|
||||
config.agent.model = ReplayModelConfig(replay_path=self._replay_action_trajs_path)
|
||||
return config
|
||||
|
||||
@property
|
||||
def instance_id(self) -> str:
|
||||
return Path(self.traj_path).stem
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: RunReplayConfig, **kwargs) -> Self:
|
||||
load_environment_variables(config.env_var_path)
|
||||
return cls(
|
||||
traj_path=config.traj_path,
|
||||
deployment=get_deployment(config.deployment) if config.deployment else None,
|
||||
output_dir=config.output_dir,
|
||||
update_config=config.update_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _create_actions_file(self) -> None:
|
||||
# Verify config compatibility with tool calls
|
||||
has_tool_calls = any(
|
||||
"tool_calls" in item and item["tool_calls"] is not None
|
||||
for item in self._traj_data["history"]
|
||||
if item["role"] == "assistant"
|
||||
)
|
||||
|
||||
agent_config = self.config.agent
|
||||
parse_function = agent_config.tools.parse_function.type
|
||||
use_function_calling = parse_function == "function_calling"
|
||||
|
||||
if has_tool_calls and not use_function_calling:
|
||||
msg = (
|
||||
"Trajectory contains tool calls but config is not set up for function calling. "
|
||||
"Check that the config you want to use has agent.tools.parse_function.type set to 'function_calling'."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
actions = []
|
||||
for ix, item in enumerate(self._traj_data["history"]):
|
||||
if item["role"] != "assistant":
|
||||
continue
|
||||
action = {"message": item["content"]}
|
||||
if use_function_calling:
|
||||
assert "tool_calls" in item and item["tool_calls"] is not None, (
|
||||
f"Config is set to use `function_calling` but trajectory item {ix} is missing a tool call "
|
||||
f"or has tool_calls set to None"
|
||||
)
|
||||
action["tool_calls"] = item["tool_calls"]
|
||||
actions.append(action)
|
||||
if len(actions) == 0:
|
||||
msg = "No actions found in trajectory"
|
||||
raise ValueError(msg)
|
||||
self._replay_action_trajs_path.write_text(json.dumps({self.instance_id: actions}))
|
||||
|
||||
def _get_env(self) -> SWEEnv:
|
||||
return SWEEnv(
|
||||
deployment=self.deployment,
|
||||
repo=self.config.env.repo,
|
||||
post_startup_commands=[],
|
||||
)
|
||||
|
||||
def _get_agent(self) -> DefaultAgent:
|
||||
agent = DefaultAgent.from_config(self.config.agent)
|
||||
agent._catch_errors = self._catch_errors
|
||||
agent._always_require_zero_exit_code = self._require_zero_exit_code
|
||||
return agent
|
||||
|
||||
def _get_run_single(self) -> RunSingle:
|
||||
return RunSingle(
|
||||
self._get_env(),
|
||||
self._get_agent(),
|
||||
problem_statement=self.config.problem_statement,
|
||||
output_dir=Path(self.output_dir),
|
||||
)
|
||||
|
||||
def main(self):
|
||||
self._create_actions_file()
|
||||
run_single = self._get_run_single()
|
||||
run_single.agent.replay_config = RunSingleConfig(
|
||||
agent=self.config.agent,
|
||||
problem_statement=run_single.problem_statement,
|
||||
env=self.config.env,
|
||||
)
|
||||
run_single.run()
|
||||
|
||||
|
||||
def run_from_config(config: RunReplayConfig):
|
||||
RunReplay.from_config(config).main()
|
||||
|
||||
|
||||
def run_from_cli(args: list[str] | None = None):
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
help_text = ( # type: ignore
|
||||
__doc__ + "\n[cyan][bold]=== ALL THE OPTIONS ===[/bold][/cyan]\n\n" + ConfigHelper().get_help(RunReplayConfig)
|
||||
)
|
||||
run_from_config(BasicCLI(RunReplayConfig, help_text=help_text, default_settings=False).get_config(args)) # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_from_cli()
|
||||
155
.agent/vendor/mini-swe/sweagent/run/run_shell.py
vendored
Normal file
@@ -0,0 +1,155 @@
|
||||
"""[cyan][bold]Run SWE-agent in semi-interactive mode.[/bold][/cyan]
|
||||
|
||||
[cyan][bold]sweagen-sh is EXPERIMENTAL[/bold][/cyan]
|
||||
|
||||
[cyan][bold]=== BASIC OPTIONS ===[/bold][/cyan]
|
||||
|
||||
-h --help Show help text and exit
|
||||
--help_option Print specific help text and exit
|
||||
--config CONFIG Load additional config files. Use this option multiple times to load
|
||||
multiple files, e.g., --config config1.yaml --config config2.yaml
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from rich.prompt import Prompt
|
||||
from swerex.deployment.config import DockerDeploymentConfig
|
||||
|
||||
from sweagent import CONFIG_DIR
|
||||
from sweagent.agent.agents import AbstractAgent, ShellAgentConfig
|
||||
from sweagent.agent.extra.shell_agent import ShellAgent
|
||||
from sweagent.agent.problem_statement import (
|
||||
GithubIssue,
|
||||
ProblemStatement,
|
||||
ProblemStatementConfig,
|
||||
TextProblemStatement,
|
||||
)
|
||||
from sweagent.environment.repo import PreExistingRepoConfig
|
||||
from sweagent.environment.swe_env import EnvironmentConfig, SWEEnv
|
||||
from sweagent.run.common import save_predictions
|
||||
from sweagent.run.hooks.abstract import CombinedRunHooks, RunHook
|
||||
from sweagent.utils.config import load_environment_variables
|
||||
from sweagent.utils.github import _is_github_issue_url
|
||||
from sweagent.utils.log import add_file_handler, get_logger, set_stream_handler_levels
|
||||
|
||||
|
||||
class RunShell:
|
||||
def __init__(
|
||||
self,
|
||||
env: SWEEnv,
|
||||
agent: AbstractAgent,
|
||||
problem_statement: ProblemStatement | ProblemStatementConfig,
|
||||
*,
|
||||
output_dir: Path = Path("."),
|
||||
hooks: list[RunHook] | None = None,
|
||||
):
|
||||
"""Note: When initializing this class, make sure to add the hooks that are required by your actions.
|
||||
See `from_config` for an example.
|
||||
"""
|
||||
self.logger = get_logger("swea-run", emoji="🏃")
|
||||
instance_id = problem_statement.id
|
||||
_log_filename_template = f"{instance_id}.{{level}}.log"
|
||||
for level in ["trace", "debug", "info"]:
|
||||
add_file_handler(
|
||||
output_dir / instance_id / _log_filename_template.format(level=level),
|
||||
level=level,
|
||||
id_=f"{instance_id}-{level}",
|
||||
)
|
||||
self.env = env
|
||||
self.agent = agent
|
||||
self.output_dir = output_dir
|
||||
self._hooks = []
|
||||
self._chooks = CombinedRunHooks()
|
||||
self.problem_statement = problem_statement
|
||||
for hook in hooks or []:
|
||||
self.add_hook(hook)
|
||||
|
||||
@property
|
||||
def hooks(self) -> list[RunHook]:
|
||||
return self._chooks.hooks
|
||||
|
||||
def add_hook(self, hook: RunHook) -> None:
|
||||
hook.on_init(run=self)
|
||||
self._chooks.add_hook(hook)
|
||||
|
||||
def run(self):
|
||||
self._chooks.on_start()
|
||||
self.logger.info("Starting environment")
|
||||
self.env.start()
|
||||
self.logger.info("Running agent")
|
||||
self._chooks.on_instance_start(index=0, env=self.env, problem_statement=self.problem_statement)
|
||||
output_dir = self.output_dir / self.problem_statement.id
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
result = self.agent.run(
|
||||
problem_statement=self.problem_statement,
|
||||
env=self.env,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
self._chooks.on_instance_completed(result=result)
|
||||
self.logger.info("Done")
|
||||
self._chooks.on_end()
|
||||
save_predictions(self.output_dir, self.problem_statement.id, result)
|
||||
self.env.close()
|
||||
|
||||
|
||||
def get_cli():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("-r", "--repo", type=Path, help="Path to the repository.", default=None)
|
||||
# parser.add_argument(dest="--model", type=str, help="Model to use.", default="claude-sonnet-4-20250514")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=Path,
|
||||
help="Path to the agent config file.",
|
||||
default=CONFIG_DIR / "exotic" / "default_shell.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
type=str,
|
||||
help="Problem statement.",
|
||||
default="",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def run_from_cli(args: list[str] | None = None):
|
||||
set_stream_handler_levels(logging.INFO)
|
||||
cli_args = get_cli().parse_args(args)
|
||||
try:
|
||||
load_environment_variables(Path(".env"))
|
||||
except FileNotFoundError:
|
||||
print("Env file .env not found, please set API key as env variables.")
|
||||
env_config = EnvironmentConfig(
|
||||
repo=PreExistingRepoConfig(repo_name="repo", reset=False),
|
||||
deployment=DockerDeploymentConfig(
|
||||
image="python:3.11",
|
||||
docker_args=[
|
||||
"-v",
|
||||
f"{cli_args.repo}:/repo",
|
||||
],
|
||||
python_standalone_dir="/root",
|
||||
),
|
||||
)
|
||||
agent_config = ShellAgentConfig.model_validate(yaml.safe_load(cli_args.config.read_text())["agent"])
|
||||
agent = ShellAgent.from_config(agent_config)
|
||||
env = SWEEnv.from_config(env_config)
|
||||
if cli_args.repo is None:
|
||||
cli_args.repo = Path(Prompt.ask("[cyan]Repository path[/cyan]", default="", show_default=False))
|
||||
problem_input = cli_args.p
|
||||
if not problem_input:
|
||||
problem_input = Prompt.ask("[cyan]Problem statement or GitHub issue URL[/cyan]", default="", show_default=False)
|
||||
if _is_github_issue_url(problem_input):
|
||||
problem_statement = GithubIssue(github_url=problem_input)
|
||||
else:
|
||||
problem_statement = TextProblemStatement(
|
||||
text=problem_input,
|
||||
)
|
||||
run_shell = RunShell(env, agent, problem_statement=problem_statement, output_dir=Path.home() / "sweagent_shell")
|
||||
run_shell.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_from_cli()
|
||||
225
.agent/vendor/mini-swe/sweagent/run/run_single.py
vendored
Normal file
@@ -0,0 +1,225 @@
|
||||
"""[cyan][bold]Run SWE-agent on a single instance taken from github or similar.[/bold][/cyan]
|
||||
|
||||
[cyan][bold]=== BASIC OPTIONS ===[/bold][/cyan]
|
||||
|
||||
-h --help Show help text and exit
|
||||
--help_option Print specific help text and exit
|
||||
--config CONFIG Load additional config files. Use this option multiple times to load
|
||||
multiple files, e.g., --config config1.yaml --config config2.yaml
|
||||
|
||||
[cyan][bold]=== EXAMPLES ===[/bold][/cyan]
|
||||
|
||||
Basic usage: Run over a [bold][cyan]github issue[/bold][/cyan][green]:
|
||||
|
||||
sweagent run --config config/default.yaml --agent.model.name "gpt-4o" \\
|
||||
--env.repo.github_url=https://github.com/SWE-agent/test-repo/ \\
|
||||
--problem_statement.github_url=https://github.com/SWE-agent/test-repo/issues/1
|
||||
[/green]
|
||||
|
||||
By default this will start a docker container and run the agent in there.
|
||||
You can set the image with [green]--env.docker.image[/green].
|
||||
|
||||
Here's an example that uses [bold][cyan]modal[/bold][/cyan] instead of docker and also a [bold][cyan]local repository[/bold][/cyan]:
|
||||
|
||||
[green]sweagent run --config config/default.yaml --agent.model.name "gpt-4o" \\
|
||||
--env.deployment.type=modal --env.repo.path /path/to/repo \\
|
||||
--problem_statement.path=path/to/problem_statement.md
|
||||
[/green]
|
||||
"""
|
||||
|
||||
import getpass
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from sweagent.agent.agents import AbstractAgent, AgentConfig, get_agent_from_config
|
||||
from sweagent.agent.problem_statement import (
|
||||
EmptyProblemStatement,
|
||||
ProblemStatement,
|
||||
ProblemStatementConfig,
|
||||
)
|
||||
from sweagent.environment.swe_env import EnvironmentConfig, SWEEnv
|
||||
from sweagent.run.common import AutoCorrectSuggestion as ACS
|
||||
from sweagent.run.common import BasicCLI, ConfigHelper, save_predictions
|
||||
from sweagent.run.hooks.abstract import CombinedRunHooks, RunHook
|
||||
from sweagent.run.hooks.apply_patch import SaveApplyPatchHook
|
||||
from sweagent.run.hooks.open_pr import OpenPRConfig, OpenPRHook
|
||||
from sweagent.utils.config import load_environment_variables
|
||||
from sweagent.utils.log import add_file_handler, get_logger
|
||||
|
||||
|
||||
class RunSingleActionConfig(BaseModel):
|
||||
"""Run real-life actions (opening PRs, etc.) if we can solve the issue."""
|
||||
|
||||
# Open a PR with the patch if we can solve the issue
|
||||
open_pr: bool = False
|
||||
pr_config: OpenPRConfig = Field(default_factory=OpenPRConfig)
|
||||
# When working with local repository: Apply patch
|
||||
apply_patch_locally: bool = False
|
||||
|
||||
# pydantic config
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
def _get_default_output_dir(output_dir: Path, problem_statement: ProblemStatement, agent: AgentConfig) -> Path:
|
||||
if output_dir == Path("DEFAULT"):
|
||||
user_id = getpass.getuser()
|
||||
problem_id = problem_statement.id
|
||||
try:
|
||||
model_id = agent.model.id # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
model_id = "unknown_model"
|
||||
config_file = getattr(agent, "_config_files", ["no_config"])[0]
|
||||
if isinstance(config_file, Path):
|
||||
config_file = config_file.stem
|
||||
return Path.cwd() / "trajectories" / user_id / f"{config_file}__{model_id}___{problem_id}"
|
||||
return output_dir
|
||||
|
||||
|
||||
class RunSingleConfig(BaseSettings, cli_implicit_flags=False):
|
||||
env: EnvironmentConfig = Field(default_factory=EnvironmentConfig, description="Environment options.")
|
||||
agent: AgentConfig = Field(description="Agent options.")
|
||||
problem_statement: ProblemStatementConfig = Field(
|
||||
default_factory=EmptyProblemStatement, description="Problem statement options."
|
||||
)
|
||||
output_dir: Path = Field(default=Path("DEFAULT"), description="Output directory.")
|
||||
|
||||
actions: RunSingleActionConfig = Field(default_factory=RunSingleActionConfig)
|
||||
|
||||
env_var_path: Path | None = None
|
||||
"""Path to a .env file to load environment variables from."""
|
||||
|
||||
# pydantic config
|
||||
model_config = SettingsConfigDict(extra="forbid", env_prefix="SWE_AGENT_")
|
||||
|
||||
def set_default_output_dir(self) -> None:
|
||||
# Needs to be called explicitly, because self._config_files will be setup
|
||||
# post-init.
|
||||
self.output_dir = _get_default_output_dir(self.output_dir, self.problem_statement, self.agent)
|
||||
|
||||
@classmethod
|
||||
def _get_auto_correct(cls) -> list[ACS]:
|
||||
return [
|
||||
ACS("model", "agent.model.name"),
|
||||
ACS("agent.model", "agent.model.name"),
|
||||
ACS("model.name", "agent.model.name"),
|
||||
ACS("per_instance_cost_limit", "agent.model.per_instance_cost_limit"),
|
||||
ACS("model.per_instance_cost_limit", "agent.model.per_instance_cost_limit"),
|
||||
ACS("config_file", "config"),
|
||||
ACS(
|
||||
"data_path",
|
||||
help="--data_path is no longer support for SWE-A 1.0. Please check the tutorial and use one of the --problem_statement options, e.g., --problem_statement.github_url or --problem_statement.path",
|
||||
),
|
||||
ACS(
|
||||
"repo_path",
|
||||
help="--repo_path is no longer support for SWE-A 1.0. Please check the tutorial and use one of the --env.repo options, e.g., --env.repo.github_url or --env.repo.path",
|
||||
),
|
||||
ACS("repo.path", "env.repo.path"),
|
||||
]
|
||||
|
||||
|
||||
class RunSingle:
|
||||
def __init__(
|
||||
self,
|
||||
env: SWEEnv,
|
||||
agent: AbstractAgent,
|
||||
problem_statement: ProblemStatement | ProblemStatementConfig,
|
||||
*,
|
||||
output_dir: Path = Path("."),
|
||||
hooks: list[RunHook] | None = None,
|
||||
actions: RunSingleActionConfig | None = None,
|
||||
):
|
||||
"""Note: When initializing this class, make sure to add the hooks that are required by your actions.
|
||||
See `from_config` for an example.
|
||||
"""
|
||||
self.logger = get_logger("swea-run", emoji="🏃")
|
||||
instance_id = problem_statement.id
|
||||
_log_filename_template = f"{instance_id}.{{level}}.log"
|
||||
for level in ["trace", "debug", "info"]:
|
||||
add_file_handler(
|
||||
output_dir / instance_id / _log_filename_template.format(level=level),
|
||||
level=level,
|
||||
id_=f"{instance_id}-{level}",
|
||||
)
|
||||
self.env = env
|
||||
self.agent = agent
|
||||
self.output_dir = output_dir
|
||||
self._hooks = []
|
||||
if actions is not None:
|
||||
actions = RunSingleActionConfig()
|
||||
self.actions = actions
|
||||
self._chooks = CombinedRunHooks()
|
||||
self.problem_statement = problem_statement
|
||||
for hook in hooks or []:
|
||||
self.add_hook(hook)
|
||||
|
||||
@property
|
||||
def hooks(self) -> list[RunHook]:
|
||||
return self._chooks.hooks
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: RunSingleConfig) -> Self:
|
||||
load_environment_variables(config.env_var_path)
|
||||
config.set_default_output_dir()
|
||||
config.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
agent = get_agent_from_config(config.agent)
|
||||
agent.replay_config = config # type: ignore[attr-defined]
|
||||
self = cls(
|
||||
env=SWEEnv.from_config(config.env),
|
||||
agent=agent,
|
||||
problem_statement=config.problem_statement,
|
||||
output_dir=config.output_dir,
|
||||
actions=config.actions,
|
||||
)
|
||||
self.add_hook(SaveApplyPatchHook(apply_patch_locally=config.actions.apply_patch_locally))
|
||||
if config.actions.open_pr:
|
||||
self.logger.debug("Adding OpenPRHook")
|
||||
self.add_hook(OpenPRHook(config.actions.pr_config))
|
||||
return self
|
||||
|
||||
def add_hook(self, hook: RunHook) -> None:
|
||||
hook.on_init(run=self)
|
||||
self._chooks.add_hook(hook)
|
||||
|
||||
def run(self):
|
||||
self._chooks.on_start()
|
||||
self.logger.info("Starting environment")
|
||||
self.env.start()
|
||||
self.logger.info("Running agent")
|
||||
self._chooks.on_instance_start(index=0, env=self.env, problem_statement=self.problem_statement)
|
||||
output_dir = self.output_dir / self.problem_statement.id
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
if self.agent.replay_config is not None: # type: ignore[attr-defined]
|
||||
(output_dir / "config.yaml").write_text(yaml.dump(self.agent.replay_config.model_dump_json(), indent=2)) # type: ignore[attr-defined]
|
||||
result = self.agent.run(
|
||||
problem_statement=self.problem_statement,
|
||||
env=self.env,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
self._chooks.on_instance_completed(result=result)
|
||||
self.logger.info("Done")
|
||||
self._chooks.on_end()
|
||||
save_predictions(self.output_dir, self.problem_statement.id, result)
|
||||
self.env.close()
|
||||
|
||||
|
||||
def run_from_config(config: RunSingleConfig):
|
||||
RunSingle.from_config(config).run()
|
||||
|
||||
|
||||
def run_from_cli(args: list[str] | None = None):
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
assert __doc__ is not None
|
||||
help_text = ( # type: ignore
|
||||
__doc__ + "\n[cyan][bold]=== ALL THE OPTIONS ===[/bold][/cyan]\n\n" + ConfigHelper().get_help(RunSingleConfig)
|
||||
)
|
||||
run_from_config(BasicCLI(RunSingleConfig, help_text=help_text).get_config(args)) # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_from_cli()
|
||||
85
.agent/vendor/mini-swe/sweagent/run/run_traj_to_demo.py
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Convert a trajectory file to a yaml file for editing of demos.
|
||||
You can then load the yaml file with `run_replay.py` to replay the actions in an environment to get
|
||||
environment output.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
from sweagent.utils.log import get_logger
|
||||
from sweagent.utils.serialization import _yaml_serialization_with_linebreaks
|
||||
|
||||
logger = get_logger("traj2demo")
|
||||
|
||||
DEMO_COMMENT = """# This is a demo file generated from trajectory file:
|
||||
# {traj_path}
|
||||
# You can use this demo file to replay the actions in the trajectory with run_replay.py.
|
||||
# You can edit the content of the actions in this file to modify the replay behavior.
|
||||
# NOTICE:
|
||||
# Only the actions of the assistant will be replayed.
|
||||
# You do not need to modify the observation's contents or any other fields.
|
||||
# You can add or remove actions to modify the replay behavior."""
|
||||
|
||||
|
||||
def save_demo(data: str | dict | list, file: Path, traj_path: Path) -> None:
|
||||
"""Save demo data as a yaml file. Takes care of multi-line strings and adds a header."""
|
||||
content = _yaml_serialization_with_linebreaks(data)
|
||||
header = DEMO_COMMENT.format(traj_path=str(traj_path))
|
||||
with open(file, "w") as f:
|
||||
f.write(f"{header}\n{content}")
|
||||
|
||||
|
||||
def convert_traj_to_action_demo(traj_path: Path, output_file: Path, include_user: bool = False) -> None:
|
||||
with open(traj_path) as file:
|
||||
traj = json.load(file)
|
||||
replay_config = traj["replay_config"]
|
||||
if isinstance(traj["replay_config"], str):
|
||||
replay_config = json.loads(traj["replay_config"])
|
||||
history = traj["history"]
|
||||
|
||||
copy_fields = {"content", "role", "tool_calls", "agent", "message_type", "tool_call_ids"}
|
||||
|
||||
admissible_roles = {"assistant", "user", "tool"} if include_user else {"assistant"}
|
||||
filtered_history = [
|
||||
{k: v for k, v in step.items() if k in copy_fields}
|
||||
for step in history
|
||||
if step["role"] in admissible_roles
|
||||
and step.get("agent", "main") in {"main", "primary"}
|
||||
and not step.get("is_demo")
|
||||
]
|
||||
|
||||
output_data = {"history": filtered_history, "replay_config": replay_config}
|
||||
save_demo(output_data, output_file, traj_path)
|
||||
logger.info(f"Saved demo to {output_file}")
|
||||
|
||||
|
||||
def main(traj_path: Path, output_dir: Path, suffix: str = "", overwrite: bool = False, include_user: bool = False):
|
||||
output_file = output_dir / (traj_path.parent.name + suffix) / (traj_path.stem.removesuffix(".traj") + ".demo.yaml")
|
||||
if output_file.exists() and not overwrite:
|
||||
msg = f"Output file already exists: {output_file}. Use --overwrite to overwrite."
|
||||
raise FileExistsError(msg)
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
convert_traj_to_action_demo(traj_path, output_file, include_user)
|
||||
|
||||
|
||||
def run_from_cli(args: list[str] | None = None):
|
||||
"""Convert a trajectory file to a demo file."""
|
||||
parser = ArgumentParser(description=__doc__)
|
||||
parser.add_argument("traj_path", type=Path, help="Path to trajectory file")
|
||||
parser.add_argument("--output_dir", type=Path, help="Output directory for action demos", default=Path("./demos"))
|
||||
parser.add_argument("--suffix", type=str, help="Suffix for the output file", default="")
|
||||
parser.add_argument("--overwrite", help="Overwrite existing files", action="store_true")
|
||||
parser.add_argument(
|
||||
"--include_user",
|
||||
help="Include user responses (computer)",
|
||||
action="store_true",
|
||||
)
|
||||
parsed_args = parser.parse_args(args)
|
||||
main(**vars(parsed_args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_from_cli()
|
||||
0
.agent/vendor/mini-swe/sweagent/tools/__init__.py
vendored
Normal file
57
.agent/vendor/mini-swe/sweagent/tools/bundle.py
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||
|
||||
from sweagent.tools.commands import Command
|
||||
from sweagent.utils.config import _convert_path_to_abspath
|
||||
|
||||
|
||||
class BundleConfig(BaseModel):
|
||||
tools: dict[str, dict]
|
||||
state_command: str | None = None
|
||||
|
||||
|
||||
class Bundle(BaseModel):
|
||||
path: Path
|
||||
hidden_tools: list[str] = Field(default_factory=list)
|
||||
_config: BundleConfig = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_tools(self):
|
||||
self.path = _convert_path_to_abspath(self.path)
|
||||
if not self.path.exists():
|
||||
msg = f"Bundle path '{self.path}' does not exist."
|
||||
raise ValueError(msg)
|
||||
|
||||
config_path = self.path / "config.yaml"
|
||||
if not config_path.exists():
|
||||
msg = f"Bundle config file '{config_path}' does not exist."
|
||||
raise ValueError(msg)
|
||||
|
||||
config_data = yaml.safe_load(config_path.read_text())
|
||||
self._config = BundleConfig(**config_data)
|
||||
|
||||
invalid_hidden_tools = set(self.hidden_tools) - set(self._config.tools.keys())
|
||||
if invalid_hidden_tools:
|
||||
msg = f"Hidden tools {invalid_hidden_tools} do not exist in available tools"
|
||||
raise ValueError(msg)
|
||||
return self
|
||||
|
||||
@property
|
||||
def state_command(self) -> str | None:
|
||||
return self.config.state_command
|
||||
|
||||
@property
|
||||
def config(self) -> BundleConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def commands(self) -> list[Command]:
|
||||
return [
|
||||
Command(name=tool, **tool_config.model_dump() if isinstance(tool_config, Command) else tool_config)
|
||||
for tool, tool_config in self.config.tools.items()
|
||||
if tool not in self.hidden_tools
|
||||
]
|
||||
220
.agent/vendor/mini-swe/sweagent/tools/commands.py
vendored
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Core module for defining and parsing commands in the SWE Agent system.
|
||||
|
||||
This module provides the foundational classes and utilities for defining commands that can be executed by the agent.
|
||||
It is used extensively by:
|
||||
|
||||
- tools.py: For command installation, execution and environment management
|
||||
- parsing.py: For parsing model outputs into executable commands
|
||||
- utils.py: For handling multi-line commands and argument quoting
|
||||
|
||||
Key Classes:
|
||||
- Command: Represents an executable command with arguments and documentation
|
||||
- Argument: Defines an argument that can be passed to a command
|
||||
|
||||
The module supports both simple bash commands and complex multi-line commands with typed arguments.
|
||||
Commands can be defined either in bash scripts with YAML docstrings or as bash functions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import string
|
||||
from collections import Counter
|
||||
from functools import cached_property
|
||||
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
from sweagent.utils.jinja_warnings import _warn_probably_wrong_jinja_syntax
|
||||
|
||||
ARGUMENT_NAME_PATTERN = r"[a-zA-Z_][a-zA-Z0-9_-]*"
|
||||
|
||||
|
||||
def _extract_keys(format_string: str) -> set[str]:
|
||||
"""Given a format string, returns a set of all the keys in the format string.
|
||||
|
||||
Used for validating that command signatures match their argument definitions.
|
||||
|
||||
Args:
|
||||
format_string: A Python format string containing named fields
|
||||
|
||||
Returns:
|
||||
Set of field names found in the format string
|
||||
"""
|
||||
formatter = string.Formatter()
|
||||
keys = set()
|
||||
for _, field_name, _, _ in formatter.parse(format_string):
|
||||
if field_name is not None:
|
||||
keys.add(field_name)
|
||||
return keys
|
||||
|
||||
|
||||
class Argument(BaseModel):
|
||||
f"""Defines an argument that can be passed to a command.
|
||||
|
||||
Attributes:
|
||||
name: The argument name, must match {ARGUMENT_NAME_PATTERN!r}
|
||||
type: The argument type (e.g. "string", "integer")
|
||||
description: Human readable description of the argument
|
||||
required: Whether this argument must be provided
|
||||
enum: Optional list of allowed values
|
||||
argument_format: Format string for how to render the argument value in the command
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
items: dict[str, str] | None = None
|
||||
description: str
|
||||
required: bool
|
||||
enum: list[str] | None = None
|
||||
argument_format: str = "{{value}}"
|
||||
"""How to invoke the argument in the command. Make sure to use jinja syntax ({{value}}) instead of {value})."""
|
||||
|
||||
@field_validator("argument_format")
|
||||
def validate_argument_format(cls, value: str) -> str:
|
||||
_warn_probably_wrong_jinja_syntax(value)
|
||||
return value
|
||||
|
||||
|
||||
class Command(BaseModel):
|
||||
"""Represents an executable command with arguments and documentation.
|
||||
|
||||
A command can be either a simple bash command or a multi-line command terminated by an end marker.
|
||||
|
||||
Attributes:
|
||||
name: The command name
|
||||
docstring: Human readable description of what the command does
|
||||
signature: Optional custom signature override
|
||||
end_name: For multi-line commands, the terminating marker
|
||||
arguments: List of arguments accepted by the command
|
||||
|
||||
Properties:
|
||||
invoke_format: Format string for constructing the full command invocation
|
||||
"""
|
||||
|
||||
name: str
|
||||
docstring: str | None
|
||||
signature: str | None = None
|
||||
# if there is an end_name, then it is a multi-line command
|
||||
end_name: str | None = None
|
||||
arguments: list[Argument] = []
|
||||
|
||||
@cached_property
|
||||
def invoke_format(self) -> str:
|
||||
"""Gets the format string for invoking this command with arguments.
|
||||
|
||||
Returns either the custom signature with argument placeholders replaced,
|
||||
or a default format of "command arg1 arg2 ...".
|
||||
"""
|
||||
if self.signature:
|
||||
# First validate that all arguments are present in the original signature
|
||||
for arg in self.arguments:
|
||||
if not (
|
||||
f"<{arg.name}>" in self.signature
|
||||
or f"[<{arg.name}>]" in self.signature
|
||||
or f"{{{arg.name}}}" in self.signature
|
||||
or f"--{arg.name}" in self.signature
|
||||
):
|
||||
msg = (
|
||||
f"Missing argument {arg.name} in signature: {self.signature}. Did you format the signature correctly? "
|
||||
f"You must include all argument names in the signature with <{arg.name}>, [<{arg.name}>], {{{arg.name}}}, or --{arg.name} notation."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Then do the replacement
|
||||
return re.sub(rf"\[?<({ARGUMENT_NAME_PATTERN})>\]?", r"{\1}", self.signature)
|
||||
else:
|
||||
# cmd arg_format_1 arg_format_2 ...
|
||||
_invoke_format = f"{self.name} "
|
||||
for arg in self.arguments:
|
||||
_invoke_format += f"{{{arg.name}}} "
|
||||
return _invoke_format
|
||||
|
||||
def get_function_calling_tool(self) -> dict:
|
||||
"""Converts this command into an OpenAI function calling tool definition.
|
||||
|
||||
Returns:
|
||||
Dict containing the OpenAI function schema for this command
|
||||
"""
|
||||
tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.docstring or "",
|
||||
},
|
||||
}
|
||||
properties = {}
|
||||
required = []
|
||||
if self.arguments:
|
||||
for arg in self.arguments:
|
||||
properties[arg.name] = {"type": arg.type, "description": arg.description}
|
||||
|
||||
if arg.items:
|
||||
properties[arg.name]["items"] = arg.items
|
||||
|
||||
if arg.required:
|
||||
required.append(arg.name)
|
||||
|
||||
# Handle enum if present
|
||||
if arg.enum:
|
||||
properties[arg.name]["enum"] = arg.enum
|
||||
tool["function"]["parameters"] = {"type": "object", "properties": properties, "required": required}
|
||||
return tool
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_arguments(self) -> Command:
|
||||
"""Validates command argument configuration.
|
||||
|
||||
Checks:
|
||||
- Required arguments come before optional ones
|
||||
- Argument names are unique
|
||||
- Argument names match the pattern
|
||||
- Arguments match the signature
|
||||
|
||||
Returns:
|
||||
The validated Command instance
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails
|
||||
"""
|
||||
if not self.arguments:
|
||||
return self
|
||||
found_optional = False
|
||||
for arg in self.arguments:
|
||||
if found_optional and arg.required:
|
||||
msg = f"Command '{self.name}': Required argument '{arg.name}' cannot come after optional arguments"
|
||||
raise ValueError(msg)
|
||||
if not arg.required:
|
||||
found_optional = True
|
||||
|
||||
name_counts = Counter(arg.name for arg in self.arguments)
|
||||
duplicates = {name for name, count in name_counts.items() if count > 1}
|
||||
if duplicates:
|
||||
msg = f"Command '{self.name}': Duplicate argument names: {duplicates}"
|
||||
raise ValueError(msg)
|
||||
for arg in self.arguments:
|
||||
if not re.match(ARGUMENT_NAME_PATTERN, arg.name):
|
||||
msg = f"Command '{self.name}': Invalid argument name: '{arg.name}'"
|
||||
raise ValueError(msg)
|
||||
if (invoke_keys := _extract_keys(self.invoke_format)) != {arg.name for arg in self.arguments}:
|
||||
msg = f"Command '{self.name}': Argument names ({invoke_keys}) in signature / invoke_format {self.invoke_format!r} do not match argument names"
|
||||
raise ValueError(msg)
|
||||
return self
|
||||
|
||||
|
||||
# Default Bash tool
|
||||
BASH_COMMAND = Command(
|
||||
name="bash",
|
||||
# name="execute_bash",
|
||||
signature="<command>",
|
||||
# signature="echo '<command>'\n<command>\necho \"root@workspace:${{PWD}} #\n[Command finished with exit code ${{?}}]\"",
|
||||
docstring="runs the given command directly in bash",
|
||||
arguments=[
|
||||
Argument(
|
||||
name="command",
|
||||
type="string",
|
||||
description="The bash command to execute.",
|
||||
required=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
619
.agent/vendor/mini-swe/sweagent/tools/parsing.py
vendored
Normal file
@@ -0,0 +1,619 @@
|
||||
"""Our parsers parse output from the LM into thoughts and actions.
|
||||
|
||||
For example, our most basic parser is the `ThoughtActionParser`.
|
||||
It expects the model response to be a discussion followed by a command wrapped in backticks like so:
|
||||
|
||||
```
|
||||
Let's look at the files in the current directory.
|
||||
|
||||
Action:
|
||||
```
|
||||
ls -l
|
||||
```
|
||||
```
|
||||
|
||||
For models that support function calling, we instead recommend using the `FunctionCallingParser`.
|
||||
|
||||
To use a specific parser, set the `parse_function` key in your tool config to the `type` field of the parser.
|
||||
|
||||
```yaml
|
||||
agent:
|
||||
tools:
|
||||
...
|
||||
parse_function:
|
||||
type: "thought_action"
|
||||
```
|
||||
|
||||
Or from the command line: `--agent.tools.parse_function.type=thought_action`.
|
||||
|
||||
!!! note "Describing available tools"
|
||||
If you do not use the `FunctionCallingParser`, you need to include documentation about the available tools
|
||||
in your system prompt. You can use the `{{command_docs}}` variable to include the automatically generated
|
||||
documentation or explicitly describe the available tools.
|
||||
Also see [#1130](https://github.com/SWE-agent/SWE-agent/issues/1130).
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import textwrap
|
||||
from abc import ABC, abstractmethod
|
||||
from shlex import quote
|
||||
from textwrap import dedent
|
||||
from typing import Any, Literal
|
||||
|
||||
from jinja2 import Template
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sweagent.exceptions import FormatError, FunctionCallingFormatError
|
||||
from sweagent.tools.commands import Command
|
||||
from sweagent.tools.utils import _should_quote
|
||||
|
||||
|
||||
class AbstractParseFunction(ABC):
|
||||
"""
|
||||
Abstract class for parsing functions.
|
||||
We use get to generate the right parser based on the name of the parser.
|
||||
"""
|
||||
|
||||
error_message: str
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, model_response, commands: list[Command], strict=False) -> tuple[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def format_error_template(self):
|
||||
return textwrap.dedent(self.error_message)
|
||||
|
||||
|
||||
# DEFINE NEW PARSING FUNCTIONS BELOW THIS LINE
|
||||
|
||||
|
||||
class ActionParser(AbstractParseFunction, BaseModel):
|
||||
"""
|
||||
Expects the model response to be a single command.
|
||||
Example: "ls -l"
|
||||
"""
|
||||
|
||||
error_message: str = """\
|
||||
The command you provided was not recognized. Please specify one of the commands (+ any necessary arguments) from the following list in your response. Do not include any other text.
|
||||
|
||||
COMMANDS:
|
||||
{command_docs}
|
||||
"""
|
||||
|
||||
type: Literal["action"] = "action"
|
||||
"""Type for (de)serialization. Do not change."""
|
||||
|
||||
def __call__(self, model_response: dict, commands: list[Command], strict=False):
|
||||
if model_response["message"].split():
|
||||
action = model_response["message"].strip().split()[0]
|
||||
if action in {command.name for command in commands}:
|
||||
return model_response["message"], model_response["message"]
|
||||
msg = "First word in model response is not a valid command."
|
||||
raise FormatError(msg)
|
||||
|
||||
|
||||
class ActionOnlyParser(AbstractParseFunction, BaseModel):
|
||||
"""Expects the model response to be a single command."""
|
||||
|
||||
error_message: str = "No message found in model response."
|
||||
|
||||
type: Literal["action_only"] = "action_only"
|
||||
"""Type for (de)serialization. Do not change."""
|
||||
|
||||
def __call__(self, model_response: dict, commands: list[Command], strict=False):
|
||||
return "", model_response["message"]
|
||||
|
||||
|
||||
class ThoughtActionParser(AbstractParseFunction, BaseModel):
|
||||
"""
|
||||
Expects the model response to be a discussion followed by a command wrapped in backticks.
|
||||
Example:
|
||||
Let's look at the files in the current directory.
|
||||
```
|
||||
ls -l
|
||||
```
|
||||
"""
|
||||
|
||||
error_message: str = dedent("""\
|
||||
Your output was not formatted correctly. You must always include one discussion and one command as part of your response. Make sure you do not have multiple discussion/command tags.
|
||||
Please make sure your output precisely matches the following format:
|
||||
DISCUSSION
|
||||
Discuss here with yourself about what your planning and what you're going to do in this step.
|
||||
|
||||
```
|
||||
command(s) that you're going to run
|
||||
```
|
||||
""")
|
||||
|
||||
type: Literal["thought_action"] = "thought_action"
|
||||
"""Type for (de)serialization. Do not change."""
|
||||
|
||||
def __call__(self, model_response: dict, commands: list[Command], strict=False):
|
||||
"""
|
||||
Parses the action from the output of the API call.
|
||||
We assume that the action is the last code block in the model_response.
|
||||
We also assume that the action is not nested within another code block.
|
||||
This is problematic if the model_response includes many unnamed ``` blocks.
|
||||
For instance:
|
||||
```
|
||||
This is a code block.
|
||||
```
|
||||
```
|
||||
This is another code block.
|
||||
```
|
||||
|
||||
In this case, only the second code block will be parsed as the action.
|
||||
"""
|
||||
code_block_pat = re.compile(r"^```(\S*)\s*\n|^```\s*$", re.MULTILINE)
|
||||
stack = []
|
||||
last_valid_block = None
|
||||
for match in code_block_pat.finditer(model_response["message"]):
|
||||
if stack and not match.group(1): # Closing of a code block
|
||||
start = stack.pop()
|
||||
# Check if it's not nested within another block
|
||||
if not stack:
|
||||
last_valid_block = (start, match)
|
||||
elif match.group(1) is not None: # Opening of a code block
|
||||
stack.append(match)
|
||||
if last_valid_block:
|
||||
start, end = last_valid_block
|
||||
thought = model_response["message"][: start.start()] + model_response["message"][end.end() :]
|
||||
return thought, model_response["message"][start.end() : end.start()]
|
||||
msg = "No action found in model response."
|
||||
raise FormatError(msg)
|
||||
|
||||
|
||||
class XMLThoughtActionParser(AbstractParseFunction, BaseModel):
|
||||
"""
|
||||
Expects the model response to be a discussion followed by a command wrapped in XML tags.
|
||||
Example:
|
||||
Let's look at the files in the current directory.
|
||||
<command>
|
||||
ls -l
|
||||
</command>
|
||||
"""
|
||||
|
||||
error_message: str = dedent("""\
|
||||
Your output was not formatted correctly. You must always include one discussion and one command as part of your response. Make sure you do not have multiple discussion/command tags.
|
||||
Please make sure your output precisely matches the following format:
|
||||
""")
|
||||
|
||||
type: Literal["xml_thought_action"] = "xml_thought_action"
|
||||
"""Type for (de)serialization. Do not change."""
|
||||
|
||||
def __call__(self, model_response: dict, commands: list[Command], strict=False) -> tuple[str, str]:
|
||||
"""
|
||||
Parses the action from the output of the API call.
|
||||
We assume that the action is the last code block in the model_response.
|
||||
We also assume that the action is not nested within another code block.
|
||||
This is problematic if the model_response includes many unnamed ``` blocks.
|
||||
For instance:
|
||||
<command>
|
||||
This is a code block.
|
||||
</command>
|
||||
<command>
|
||||
This is another code block.
|
||||
</command>
|
||||
|
||||
In this case, only the second code block will be parsed as the action.
|
||||
"""
|
||||
if "<command>" not in model_response["message"] or "</command>" not in model_response["message"]:
|
||||
msg = "No action found in model response."
|
||||
raise FormatError(msg)
|
||||
# `action` is everything between the last <command> and </command> tags
|
||||
start_action = model_response["message"].rfind("<command>") + len(
|
||||
"<command>"
|
||||
) # start after the last <command> tag
|
||||
end_thought = model_response["message"].rfind("<command>") # end before the last <command> tag
|
||||
end_action = model_response["message"].rfind("</command>") # end before the last </command> tag
|
||||
restart_thought = model_response["message"].rfind("</command>") + len(
|
||||
"</command>"
|
||||
) # start after the last </command> tag
|
||||
# `thought` is everything not in between <command> and </command> tags (includes after the last </command> tag)
|
||||
action = model_response["message"][start_action:end_action]
|
||||
thought = model_response["message"][:end_thought] + model_response["message"][restart_thought:]
|
||||
|
||||
return thought.strip(), action.strip()
|
||||
|
||||
|
||||
FN_REGEX_PATTERN = r"<function=([^>]+)>\n(.*?)</function>"
|
||||
FN_PARAM_REGEX_PATTERN = r"<parameter=([^>]+)>(.*?)</parameter>"
|
||||
|
||||
|
||||
class XMLFunctionCallingParser(AbstractParseFunction, BaseModel):
|
||||
"""
|
||||
Expects the model response to be a tool calling format, where the command and parameters are specified
|
||||
in XML tags.
|
||||
Example:
|
||||
Let's look at the files in the current directory.
|
||||
<function=bash>
|
||||
<parameter=command>find /testbed -type f -name "_discovery.py"</parameter>
|
||||
</function>
|
||||
"""
|
||||
|
||||
error_message: str = dedent("""\
|
||||
{%- if error_code == "missing" -%}
|
||||
Your last output did not use any tool calls!
|
||||
Please make sure your output includes exactly _ONE_ function call!
|
||||
If you think you have already resolved the issue, please submit your changes by running the `submit` command.
|
||||
If you think you cannot solve the problem, please run `submit`.
|
||||
Else, please continue with a new tool call!
|
||||
{%- elif error_code == "multiple" -%}
|
||||
Your last output included multiple tool calls!
|
||||
Please make sure your output includes a thought and exactly _ONE_ function call.
|
||||
{%- elif error_code == "unexpected_arg" -%}
|
||||
Your action could not be parsed properly: {{exception_message}}.
|
||||
Make sure your function call doesn't include any extra arguments that are not in the allowed arguments, and only use the allowed commands.
|
||||
{%- else -%}
|
||||
Your action could not be parsed properly: {{exception_message}}.
|
||||
{% endif %}
|
||||
""")
|
||||
|
||||
type: Literal["xml_function_calling"] = "xml_function_calling"
|
||||
|
||||
def __call__(self, model_response: dict, commands: list[Command], strict=False) -> tuple[str, str]:
|
||||
fn_match = re.search(FN_REGEX_PATTERN, model_response["message"], re.DOTALL)
|
||||
if not fn_match:
|
||||
msg = "No function found in model response."
|
||||
raise FormatError(msg)
|
||||
fn_name = fn_match.group(1).strip()
|
||||
|
||||
# Handle different names in SWE-agent vs. SWE-gym
|
||||
if fn_name == "execute_bash":
|
||||
fn_name = "bash"
|
||||
if fn_name == "finish":
|
||||
fn_name = "submit"
|
||||
|
||||
fn_body = fn_match.group(2)
|
||||
thought = model_response["message"][: fn_match.start()] + model_response["message"][fn_match.end() :]
|
||||
thought = thought.strip()
|
||||
|
||||
commands_dict = {c.name: c for c in commands}
|
||||
command = commands_dict.get(fn_name)
|
||||
if not command:
|
||||
msg = f"Command '{fn_name}' not found in list of available commands."
|
||||
raise FormatError(msg)
|
||||
|
||||
params_dict = {
|
||||
param[0]: re.sub(r"^\n|\n$", "", param[1])
|
||||
for param in re.findall(FN_PARAM_REGEX_PATTERN, fn_body, re.DOTALL)
|
||||
}
|
||||
|
||||
if "view_range" in params_dict:
|
||||
# Check that value is format as [x, y]
|
||||
v = params_dict["view_range"]
|
||||
if isinstance(v, str):
|
||||
if not re.match(r"\[\d+,\s*\d+\]", v):
|
||||
msg = f"view_range must be in the format [<start>, <end>], got {v}."
|
||||
raise FormatError(msg)
|
||||
params_dict["view_range"] = json.loads(v)
|
||||
|
||||
# Check if all required arguments are there
|
||||
required_args = {arg.name for arg in command.arguments if arg.required}
|
||||
missing_args = required_args - params_dict.keys()
|
||||
if missing_args:
|
||||
msg = f"Required argument(s) missing: {', '.join(missing_args)}"
|
||||
raise FormatError(msg)
|
||||
|
||||
# Check if all arguments are valid
|
||||
valid_args = {arg.name for arg in command.arguments}
|
||||
extra_args = set(params_dict.keys()) - valid_args
|
||||
if command.end_name:
|
||||
# sometimes the model will include the end_name in the arguments - just ignore it
|
||||
extra_args.discard(command.end_name)
|
||||
if extra_args:
|
||||
msg = f"Unexpected argument(s): {', '.join(extra_args)}"
|
||||
raise FormatError(msg)
|
||||
|
||||
# Format arguments using their individual argument_format
|
||||
formatted_args = {
|
||||
arg.name: Template(arg.argument_format).render(
|
||||
value=quote(params_dict[arg.name])
|
||||
if _should_quote(params_dict[arg.name], command)
|
||||
else params_dict[arg.name]
|
||||
)
|
||||
if arg.name in params_dict
|
||||
else ""
|
||||
for arg in command.arguments
|
||||
}
|
||||
return thought, command.invoke_format.format(**formatted_args).strip()
|
||||
|
||||
|
||||
class EditFormat(ThoughtActionParser, BaseModel):
|
||||
"""
|
||||
Expects the model response to be a discussion followed by a command wrapped in backticks.
|
||||
Example:
|
||||
We'll replace the contents of the current window with the following:
|
||||
```
|
||||
import os
|
||||
os.listdir()
|
||||
```
|
||||
"""
|
||||
|
||||
error_message: str = dedent("""\
|
||||
Your output was not formatted correctly. You must wrap the replacement text in backticks (```).
|
||||
Please make sure your output precisely matches the following format:
|
||||
COMMENTS
|
||||
You can write comments here about what you're going to do if you want.
|
||||
|
||||
```
|
||||
New window contents.
|
||||
Make sure you copy the entire contents of the window here, with the required indentation.
|
||||
Make the changes to the window above directly in this window.
|
||||
Remember that all of the window's contents will be replaced with the contents of this window.
|
||||
Don't include line numbers in your response.
|
||||
```
|
||||
""")
|
||||
|
||||
type: Literal["edit_format"] = "edit_format"
|
||||
"""Type for (de)serialization. Do not change."""
|
||||
|
||||
|
||||
class Identity(AbstractParseFunction, BaseModel):
|
||||
"""This parser does not do any parsing. It just returns the model response as both the thought and action."""
|
||||
|
||||
error_message: str = """\
|
||||
It seems like something went wrong with your output. Please try again.
|
||||
"""
|
||||
|
||||
type: Literal["identity"] = "identity"
|
||||
"""Type for (de)serialization. Do not change."""
|
||||
|
||||
def __call__(self, model_response: dict, commands: list[Command], strict=False) -> tuple[str, str]:
|
||||
"""
|
||||
This doesn't do any parsing. It just returns the model response as the thought and action.
|
||||
"""
|
||||
return model_response["message"], model_response["message"]
|
||||
|
||||
|
||||
class FunctionCallingParser(AbstractParseFunction, BaseModel):
|
||||
"""Expects the model response to be a LiteLLM tool call."""
|
||||
|
||||
error_message: str = dedent("""\
|
||||
{%- if error_code == "missing" -%}
|
||||
Your last output did not use any tool calls!
|
||||
Please make sure your output includes exactly _ONE_ function call!
|
||||
You must invoke the function directly using the function call format.
|
||||
You cannot invoke commands with ```, you have to use the function call format.
|
||||
If you think you have already resolved the issue, please submit your changes by running the `submit` command.
|
||||
If you think you cannot solve the problem, please run `exit_forfeit` (if available) or `submit`.
|
||||
Else, please continue with a new tool call!
|
||||
{%- elif error_code == "multiple" -%}
|
||||
Your last output included multiple tool calls!
|
||||
Please make sure your output includes a thought and exactly _ONE_ function call.
|
||||
{%- elif error_code == "unexpected_arg" -%}
|
||||
Your action could not be parsed properly: {{exception_message}}.
|
||||
Make sure your function call doesn't include any extra arguments that are not in the allowed arguments, and only use the allowed commands.
|
||||
{%- else -%}
|
||||
Your action could not be parsed properly: {{exception_message}}.
|
||||
{% endif %}
|
||||
""")
|
||||
|
||||
type: Literal["function_calling"] = "function_calling"
|
||||
"""Type for (de)serialization. Do not change."""
|
||||
|
||||
def _parse_tool_call(self, tool_call: dict, commands: list[Command]):
|
||||
name = tool_call["function"]["name"]
|
||||
command = {c.name: c for c in commands}.get(name)
|
||||
if not command:
|
||||
msg = f"Command '{name}' not found in list of available commands."
|
||||
raise FunctionCallingFormatError(msg, "invalid_command")
|
||||
if not isinstance(tool_call["function"]["arguments"], dict):
|
||||
try:
|
||||
values = json.loads(tool_call["function"]["arguments"])
|
||||
except json.JSONDecodeError:
|
||||
msg = "Tool call arguments are not valid JSON."
|
||||
raise FunctionCallingFormatError(msg, "invalid_json")
|
||||
required_args = {arg.name for arg in command.arguments if arg.required}
|
||||
missing_args = required_args - values.keys()
|
||||
if missing_args:
|
||||
msg = f"Required argument(s) missing: {', '.join(missing_args)}"
|
||||
raise FunctionCallingFormatError(msg, "missing_arg")
|
||||
valid_args = {arg.name for arg in command.arguments}
|
||||
extra_args = set(values.keys()) - valid_args
|
||||
if command.end_name:
|
||||
# sometimes the model will include the end_name in the arguments - just ignore it
|
||||
extra_args.discard(command.end_name)
|
||||
if extra_args:
|
||||
msg = f"Unexpected argument(s): {', '.join(extra_args)}"
|
||||
raise FunctionCallingFormatError(msg, "unexpected_arg")
|
||||
|
||||
def get_quoted_arg(value: Any) -> str:
|
||||
if isinstance(value, str):
|
||||
return quote(value) if _should_quote(value, command) else value
|
||||
# See https://github.com/SWE-agent/SWE-agent/issues/1159
|
||||
if value is None:
|
||||
return ""
|
||||
return value
|
||||
|
||||
formatted_args = {
|
||||
arg.name: Template(arg.argument_format).render(value=get_quoted_arg(values[arg.name]))
|
||||
if arg.name in values
|
||||
else ""
|
||||
for arg in command.arguments
|
||||
}
|
||||
return command.invoke_format.format(**formatted_args).strip()
|
||||
|
||||
def __call__(self, model_response: dict, commands: list[Command], strict=False):
|
||||
message = model_response["message"]
|
||||
tool_calls = model_response.get("tool_calls", None)
|
||||
if tool_calls is None or len(tool_calls) != 1:
|
||||
num_tools = len(tool_calls) if tool_calls else 0
|
||||
msg = (
|
||||
f"Expected exactly one tool call in model response - received {num_tools} "
|
||||
f"tool calls with message: {message}"
|
||||
)
|
||||
error_code = "missing" if num_tools == 0 else "multiple"
|
||||
raise FunctionCallingFormatError(msg, error_code, num_tools=num_tools)
|
||||
tool_call = tool_calls[0]
|
||||
action = self._parse_tool_call(tool_call, commands)
|
||||
return message, action
|
||||
|
||||
|
||||
class JsonParser(AbstractParseFunction, BaseModel):
|
||||
"""Expects the model response to be a JSON object."""
|
||||
|
||||
error_message: str = dedent("""\
|
||||
Your output could not be parsed as JSON. Please make sure your output 1) is valid JSON and
|
||||
2) Includes the "thought" and "command" fields.
|
||||
|
||||
""")
|
||||
|
||||
type: Literal["json"] = "json"
|
||||
"""Type for (de)serialization. Do not change."""
|
||||
|
||||
def __call__(self, model_response: dict, commands: list[Command], strict=False):
|
||||
"""Parses the action from the output of the API call.
|
||||
We assume that model output is a JSON object with the following fields:
|
||||
{
|
||||
"thought": "discussion text here.",
|
||||
"command": {
|
||||
"arguments": {
|
||||
"arg1": "value1",
|
||||
"arg2": "value2",
|
||||
...
|
||||
},
|
||||
"name": "command_name"
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = json.loads(model_response["message"])
|
||||
if not isinstance(data, dict):
|
||||
msg = "Model output is not a JSON object."
|
||||
raise FormatError(msg)
|
||||
|
||||
# Check if required keys are present
|
||||
required_keys = ["thought", "command"]
|
||||
for key in required_keys:
|
||||
if key not in data:
|
||||
msg = f"Key '{key}' is missing from model output."
|
||||
raise FormatError(msg)
|
||||
|
||||
# Check structure of 'command' key
|
||||
data_command = data["command"]
|
||||
if not isinstance(data_command, dict):
|
||||
msg = "Value of 'command' key is not a JSON object."
|
||||
raise FormatError(msg)
|
||||
|
||||
# Check if required keys are present in 'command' object
|
||||
command_keys = ["name"]
|
||||
for key in command_keys:
|
||||
if key not in data_command:
|
||||
msg = f"Key '{key}' is missing from 'command' object."
|
||||
raise FormatError(msg)
|
||||
|
||||
thought = data["thought"]
|
||||
commands_dict = {c.name: c for c in commands}
|
||||
command = commands_dict.get(data_command["name"])
|
||||
|
||||
# Handle command parsing based on strict mode
|
||||
if command is None:
|
||||
if strict:
|
||||
msg = f"Command '{data_command['name']}' not found in list of available commands."
|
||||
raise FormatError(msg)
|
||||
# In non-strict mode, just join command name with argument values
|
||||
return thought, " ".join([data_command["name"], *data_command.get("arguments", {}).values()])
|
||||
|
||||
# Format arguments using their individual argument_format
|
||||
formatted_args = {}
|
||||
if command.arguments:
|
||||
for arg in command.arguments:
|
||||
if arg.name in data_command.get("arguments", {}):
|
||||
value = data_command["arguments"][arg.name]
|
||||
if _should_quote(value, command):
|
||||
value = quote(value)
|
||||
formatted_args[arg.name] = Template(arg.argument_format).render(value=value)
|
||||
elif strict and arg.required:
|
||||
msg = f"Required argument '{arg.name}' missing for command '{command.name}'"
|
||||
raise FormatError(msg)
|
||||
|
||||
# Use the formatted arguments with invoke_format
|
||||
action = command.invoke_format.format(**formatted_args).strip()
|
||||
return thought, action
|
||||
except json.JSONDecodeError:
|
||||
msg = "Model output is not valid JSON."
|
||||
raise FormatError(msg)
|
||||
|
||||
|
||||
class BashCodeBlockParser(AbstractParseFunction, BaseModel):
|
||||
"""Executes all commands in ```bash code blocks."""
|
||||
|
||||
error_message: str = dedent("""\
|
||||
No bash code blocks were detected in your output.
|
||||
You need to include at least one bash code block in your output.
|
||||
|
||||
It must follow this format exactly to be valid:
|
||||
```bash
|
||||
cmd arg1 arg2 ...
|
||||
...
|
||||
|
||||
Other types of code blocks (e.g. python, rust, none, etc.) won't be executed. Only bash.
|
||||
""")
|
||||
|
||||
type: Literal["all_bash_code_blocks"] = "all_bash_code_blocks"
|
||||
|
||||
def __call__(self, model_response: dict, commands: list[Command], strict=False):
|
||||
"""Parses the action from the output of the API call.
|
||||
We assume that model output is a JSON object with the following fields:
|
||||
"""
|
||||
pattern = re.compile(r"```bash\n(.*?)\n```", re.DOTALL)
|
||||
matches = pattern.findall(model_response["message"])
|
||||
if not matches:
|
||||
msg = "No bash code blocks were detected in your output."
|
||||
raise FormatError(msg)
|
||||
thought = pattern.sub("<extracted_code_block>", model_response["message"])
|
||||
action = "\n".join(matches)
|
||||
return thought, action
|
||||
|
||||
|
||||
class SingleBashCodeBlockParser(AbstractParseFunction, BaseModel):
|
||||
"""Executes all commands in ```bash code blocks."""
|
||||
|
||||
error_message: str = dedent("""\
|
||||
We did not detect the right number of bash code blocks in your output.
|
||||
You need to include EXACTLY ONE bash code block in your output.
|
||||
|
||||
It must follow this format exactly to be valid:
|
||||
```bash
|
||||
cmd arg1 arg2 ...
|
||||
```
|
||||
""")
|
||||
|
||||
type: Literal["single_bash_code_block"] = "single_bash_code_block"
|
||||
|
||||
def __call__(self, model_response: dict, commands: list[Command], strict=False):
|
||||
"""Parses the action from the output of the API call.
|
||||
We assume that model output is a JSON object with the following fields:
|
||||
"""
|
||||
pattern = re.compile(r"```bash\n(.*?)\n```", re.DOTALL)
|
||||
matches = pattern.findall(model_response["message"])
|
||||
if not matches:
|
||||
msg = "No bash code blocks were detected in your output."
|
||||
raise FormatError(msg)
|
||||
if len(matches) > 1:
|
||||
msg = (
|
||||
"We detected multiple bash code blocks in your output. "
|
||||
"You need to include EXACTLY ONE bash code block in your output."
|
||||
)
|
||||
raise FormatError(msg)
|
||||
thought = pattern.sub("<extracted_code_block>", model_response["message"])
|
||||
action = "\n".join(matches)
|
||||
return thought, action
|
||||
|
||||
|
||||
ParseFunction = (
|
||||
ActionParser
|
||||
| ThoughtActionParser
|
||||
| ActionOnlyParser
|
||||
| XMLThoughtActionParser
|
||||
| XMLFunctionCallingParser
|
||||
| FunctionCallingParser
|
||||
| EditFormat
|
||||
| Identity
|
||||
| JsonParser
|
||||
| BashCodeBlockParser
|
||||
| SingleBashCodeBlockParser
|
||||
)
|
||||
430
.agent/vendor/mini-swe/sweagent/tools/tools.py
vendored
Normal file
@@ -0,0 +1,430 @@
|
||||
"""
|
||||
This module contains the configuration for the tools that are made available to the agent.
|
||||
|
||||
The `ToolConfig` class is used to configure the tools that are available to the agent.
|
||||
The `ToolHandler` class is used to handle the tools that are available to the agent.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from swerex.runtime.abstract import Command as RexCommand
|
||||
from swerex.runtime.abstract import UploadRequest
|
||||
from typing_extensions import Self
|
||||
|
||||
from sweagent.environment.swe_env import SWEEnv
|
||||
from sweagent.tools.bundle import Bundle
|
||||
from sweagent.tools.commands import BASH_COMMAND, Command
|
||||
from sweagent.tools.parsing import FunctionCallingParser, JsonParser, ParseFunction
|
||||
from sweagent.tools.utils import _guard_multiline_input, generate_command_docs
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
|
||||
class ToolFilterConfig(BaseModel):
|
||||
"""Filter out commands that are blocked by the environment
|
||||
(for example interactive commands like `vim`).
|
||||
"""
|
||||
|
||||
blocklist_error_template: str = "Operation '{{action}}' is not supported by this environment."
|
||||
"""The error template to use when a command is blocked."""
|
||||
|
||||
blocklist: list[str] = [
|
||||
"vim",
|
||||
"vi",
|
||||
"emacs",
|
||||
"nano",
|
||||
"nohup",
|
||||
"gdb",
|
||||
"less",
|
||||
"tail -f",
|
||||
"python -m venv",
|
||||
"make",
|
||||
]
|
||||
"""Block any command that starts with one of these"""
|
||||
|
||||
blocklist_standalone: list[str] = [
|
||||
"python",
|
||||
"python3",
|
||||
"ipython",
|
||||
"bash",
|
||||
"sh",
|
||||
"/bin/bash",
|
||||
"/bin/sh",
|
||||
"nohup",
|
||||
"vi",
|
||||
"vim",
|
||||
"emacs",
|
||||
"nano",
|
||||
"su",
|
||||
]
|
||||
"""Block any command that matches one of these exactly"""
|
||||
|
||||
block_unless_regex: dict[str, str] = {
|
||||
"radare2": r"\b(?:radare2)\b.*\s+-c\s+.*",
|
||||
"r2": r"\b(?:radare2)\b.*\s+-c\s+.*",
|
||||
}
|
||||
"""Block any command that matches one of these names unless it also matches the regex"""
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
"""Configuration for the tools that are made available to the agent."""
|
||||
|
||||
filter: ToolFilterConfig = ToolFilterConfig()
|
||||
"""Filter out commands that are blocked by the environment
|
||||
(for example interactive commands like `vim`).
|
||||
"""
|
||||
|
||||
bundles: list[Bundle] = Field(default_factory=list)
|
||||
"""The tool bundles to load."""
|
||||
|
||||
propagate_env_variables: list[str] = []
|
||||
"""Environment variables to propagate to the environment.
|
||||
This is useful if you want to propagate API keys or similar from your own environment to the
|
||||
environment in which the tools run.
|
||||
IMPORTANT NOTE: The value of the environment variables can be read in debug log files,
|
||||
so be careful with your API keys!
|
||||
"""
|
||||
|
||||
env_variables: dict[str, Any] = {
|
||||
"PAGER": "cat",
|
||||
"MANPAGER": "cat",
|
||||
"LESS": "-R",
|
||||
"PIP_PROGRESS_BAR": "off",
|
||||
"TQDM_DISABLE": "1",
|
||||
"GIT_PAGER": "cat",
|
||||
}
|
||||
"""Shorthand to set environment variables for the tools, effectively
|
||||
equivalent to adding `export VARNAME=value` to the `reset_commands`.
|
||||
"""
|
||||
|
||||
registry_variables: dict[str, Any] = {}
|
||||
"""Populate the registry with these variables. Will be written out as json in the registry file."""
|
||||
|
||||
submit_command: str = "submit"
|
||||
"""The command/tool to use to submit the solution."""
|
||||
|
||||
parse_function: ParseFunction = Field(default_factory=FunctionCallingParser)
|
||||
"""The action parser that is responsible for parsing the model output into a thought and action.
|
||||
"""
|
||||
|
||||
enable_bash_tool: bool = True
|
||||
"""Whether to enable the bash tool in addition to the other tools specified in bundles."""
|
||||
|
||||
format_error_template: str = None # type: ignore
|
||||
"""Defaults to format_error_template in ParseFunction"""
|
||||
|
||||
command_docs: str = None # type: ignore
|
||||
"""Automatically generated documentation generated based on
|
||||
the loaded tool bundles.
|
||||
"""
|
||||
|
||||
multi_line_command_endings: dict[str, str] = {}
|
||||
submit_command_end_name: str | None = None
|
||||
|
||||
"""Commands to install dependencies and tools.
|
||||
These commands are executed in a subprocess and are not part of the environment state.
|
||||
"""
|
||||
|
||||
reset_commands: list[str | list[str]] = []
|
||||
"""Commands to reset the environment. They will also be called when we start the environment.
|
||||
Unlike `install_commands`, these commands are part of the environment state.
|
||||
"""
|
||||
|
||||
execution_timeout: int = 30
|
||||
"""Timeout for executing commands in the environment"""
|
||||
|
||||
install_timeout: int = 300
|
||||
"""Timeout used for each of the installation commands"""
|
||||
|
||||
total_execution_timeout: int = 1800
|
||||
"""Timeout for executing all commands in the environment.
|
||||
Note: Does not interrupt running commands, but will stop the agent for the next step.
|
||||
"""
|
||||
|
||||
max_consecutive_execution_timeouts: int = 3
|
||||
"""Maximum number of consecutive execution timeouts before the agent exits.
|
||||
"""
|
||||
|
||||
@cached_property
|
||||
def use_function_calling(self) -> bool:
|
||||
return isinstance(self.parse_function, FunctionCallingParser)
|
||||
|
||||
@cached_property
|
||||
def state_commands(self) -> list[str]:
|
||||
"""This property returns the state commands from all bundles.
|
||||
State commands are commands that are used to get the state of the environment
|
||||
(e.g., the current working directory).
|
||||
"""
|
||||
return [bundle.state_command for bundle in self.bundles if bundle.state_command]
|
||||
|
||||
# todo: move to ToolHandler?
|
||||
@cached_property
|
||||
def commands(self) -> list[Command]:
|
||||
"""Read command files and return parsed command objects"""
|
||||
commands = []
|
||||
tool_sources: dict[str, Path] = {} # Track which file each tool comes from
|
||||
# Add bash command if enabled
|
||||
if self.enable_bash_tool:
|
||||
commands.append(BASH_COMMAND)
|
||||
tool_sources[BASH_COMMAND.name] = Path("<builtin>")
|
||||
|
||||
# Collect commands from all bundles
|
||||
for bundle in self.bundles:
|
||||
for command in bundle.commands:
|
||||
if command.name in tool_sources:
|
||||
existing_source = tool_sources[command.name]
|
||||
msg = (
|
||||
f"Tool '{command.name}' is defined multiple times:\n"
|
||||
f" - First definition in: {existing_source}\n"
|
||||
f" - Duplicate definition in: {bundle.path}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
commands.append(command)
|
||||
tool_sources[command.name] = bundle.path
|
||||
|
||||
return commands
|
||||
|
||||
@cached_property
|
||||
def tools(self) -> list[dict]:
|
||||
return [command.get_function_calling_tool() for command in self.commands]
|
||||
|
||||
# todo: can some of these be moved to ToolHandler?
|
||||
def model_post_init(self, __context):
|
||||
# for caching:
|
||||
commands = self.commands
|
||||
multi_line_command_endings = {
|
||||
command.name: command.end_name for command in commands if command.end_name is not None
|
||||
}
|
||||
self.tools
|
||||
|
||||
# assert not self.enable_bash_tool and parse_function is FunctionCallingParser or JsonParser
|
||||
if not self.enable_bash_tool and not (
|
||||
isinstance(self.parse_function, FunctionCallingParser) or isinstance(self.parse_function, JsonParser)
|
||||
):
|
||||
msg = f"Bash tool can only be disabled if {FunctionCallingParser.type} parser or {JsonParser.type} parser is used."
|
||||
raise ValueError(msg)
|
||||
|
||||
self.multi_line_command_endings = multi_line_command_endings
|
||||
self.command_docs = generate_command_docs(
|
||||
self.commands,
|
||||
[],
|
||||
**self.env_variables,
|
||||
)
|
||||
if self.format_error_template is None:
|
||||
self.format_error_template = self.parse_function.format_error_template
|
||||
for command in commands:
|
||||
if command.name == self.submit_command:
|
||||
self.submit_command_end_name = command.end_name
|
||||
break
|
||||
|
||||
|
||||
class ToolHandler:
|
||||
def __init__(self, tools: ToolConfig):
|
||||
"""This class handles most of the tool usage. It has the following responsibilities:
|
||||
|
||||
- Install the tools
|
||||
- Parse commands and handle multiline commands
|
||||
- Decide if an action should be blocked
|
||||
- Get the current state of the environment
|
||||
"""
|
||||
# Always copy config to avoid shared state between different instances across threads
|
||||
self.config = tools.model_copy(deep=True)
|
||||
# partially initialized in `install_commands`.
|
||||
self._reset_commands = []
|
||||
self._command_patterns = self._get_command_patterns()
|
||||
self.logger = get_logger("swea-tools", emoji="🧰")
|
||||
# For testing: Return this state instead of querying the environment
|
||||
self.mock_state: dict[str, str] | None = None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ToolConfig) -> Self:
|
||||
return cls(config)
|
||||
|
||||
# Installation & Reset
|
||||
# --------------------
|
||||
|
||||
def install(self, env: SWEEnv) -> None:
|
||||
self._install_commands(env)
|
||||
self.reset(env)
|
||||
|
||||
def reset(self, env: SWEEnv) -> None:
|
||||
self.logger.info("Resetting tools")
|
||||
env_variables = self.config.env_variables.copy() | {
|
||||
var: os.getenv(var) for var in self.config.propagate_env_variables
|
||||
}
|
||||
env.set_env_variables(env_variables)
|
||||
env.write_file("/root/.swe-agent-env", json.dumps(self.config.registry_variables))
|
||||
env.write_file("/root/state.json", "{}")
|
||||
env.communicate(" && ".join(self._reset_commands), check="raise", timeout=self.config.install_timeout)
|
||||
|
||||
async def _upload_bundles(self, env: SWEEnv) -> None:
|
||||
await asyncio.gather(
|
||||
*(
|
||||
env.deployment.runtime.upload(
|
||||
UploadRequest(source_path=bundle.path.as_posix(), target_path=f"/root/tools/{bundle.path.name}")
|
||||
)
|
||||
for bundle in self.config.bundles
|
||||
)
|
||||
)
|
||||
|
||||
async def _is_command_available(self, env, command: str, env_vars: dict[str, str]) -> None:
|
||||
if command == "bash":
|
||||
return
|
||||
try:
|
||||
await env.deployment.runtime.execute(
|
||||
RexCommand(command=f"which {command}", shell=True, check=True, env=env_vars)
|
||||
)
|
||||
except Exception:
|
||||
msg = f"Tool {command} is not available in the container."
|
||||
raise RuntimeError(msg) from None
|
||||
|
||||
async def _check_available_commands(self, env: SWEEnv, env_vars: dict[str, str]) -> None:
|
||||
await asyncio.gather(
|
||||
*(self._is_command_available(env, command.name, env_vars) for command in self.config.commands)
|
||||
)
|
||||
|
||||
def _install_commands(self, env: SWEEnv) -> None:
|
||||
"""Make sure all commands are available in the container"""
|
||||
env.set_env_variables(self.config.env_variables)
|
||||
cwd = env.communicate("pwd", check="raise").strip()
|
||||
asyncio.run(self._upload_bundles(env))
|
||||
for bundle in self.config.bundles:
|
||||
cmds = [
|
||||
f"export PATH=/root/tools/{bundle.path.name}/bin:$PATH",
|
||||
f"chmod +x /root/tools/{bundle.path.name}/bin/*",
|
||||
]
|
||||
if (bundle.path / "install.sh").exists():
|
||||
cmds.append(f"cd /root/tools/{bundle.path.name} && source install.sh")
|
||||
cmds.append(f"chmod +x /root/tools/{bundle.path.name}/bin/*")
|
||||
env.communicate(
|
||||
" && ".join(cmds),
|
||||
check="raise",
|
||||
timeout=self.config.install_timeout,
|
||||
)
|
||||
env.communicate(f"cd {cwd}", check="raise")
|
||||
path = env.communicate("echo $PATH", check="raise").strip()
|
||||
asyncio.run(self._check_available_commands(env, {"PATH": path}))
|
||||
|
||||
# Getting state
|
||||
# -------------
|
||||
|
||||
def _get_state(self, env: SWEEnv) -> dict[str, str]:
|
||||
"""Retrieve the state from the environment"""
|
||||
try:
|
||||
state_str = env.read_file("/root/state.json")
|
||||
except FileNotFoundError:
|
||||
self.logger.warning("State file not found, returning empty state")
|
||||
return {}
|
||||
if not state_str.strip():
|
||||
self.logger.warning("State file is empty, returning empty state")
|
||||
return {}
|
||||
try:
|
||||
state = json.loads(state_str)
|
||||
except json.JSONDecodeError as e:
|
||||
msg = f"State {state_str!r} is not valid json. This is an internal error, please report it."
|
||||
raise ValueError(msg) from e
|
||||
if not isinstance(state, dict):
|
||||
msg = f"State commands must return a dictionary. Got {state!r} instead."
|
||||
raise ValueError(msg)
|
||||
return state
|
||||
|
||||
def get_state(self, env: SWEEnv) -> dict[str, str]:
|
||||
"""Execute state commands from all bundles and combine their results.
|
||||
This can be used to extract environment variables etc. from the environment.
|
||||
"""
|
||||
if self.mock_state is not None:
|
||||
return self.mock_state
|
||||
|
||||
for state_command in self.config.state_commands:
|
||||
env.communicate(state_command, check="warn")
|
||||
combined_state = self._get_state(env)
|
||||
self.logger.debug(f"Retrieved state from environment: {combined_state}")
|
||||
return combined_state
|
||||
|
||||
# Blocking
|
||||
# --------
|
||||
|
||||
def should_block_action(self, action: str) -> bool:
|
||||
"""Check if the command should be blocked."""
|
||||
action = action.strip()
|
||||
if not action:
|
||||
return False
|
||||
if any(action.startswith(f) for f in self.config.filter.blocklist):
|
||||
return True
|
||||
if action in self.config.filter.blocklist_standalone:
|
||||
return True
|
||||
name = action.split()[0]
|
||||
if name in self.config.filter.block_unless_regex and not re.search(
|
||||
self.config.filter.block_unless_regex[name], action
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
# Parsing & multiline commands
|
||||
# -----------------------------
|
||||
|
||||
def check_for_submission_cmd(self, output: str) -> bool:
|
||||
"""Function for checking submission request."""
|
||||
if r"<<SWE_AGENT_SUBMISSION>>" in output:
|
||||
return True
|
||||
return False
|
||||
|
||||
def parse_actions(self, output: dict) -> tuple[str, str]:
|
||||
"""Parse the model output into a thought and action."""
|
||||
return self.config.parse_function(output, self.config.commands)
|
||||
|
||||
def guard_multiline_input(self, action: str) -> str:
|
||||
"""Split action by multiline commands, then append the first line in each multiline command with "<< '{end_name}'".
|
||||
Multiline commands (which are specified by an end_name) are commands that span multiple lines and are terminated by a specific end_name.
|
||||
|
||||
Their multi-line argument is sent using a heredoc, which is a way to send a multi-line string to a command in bash.
|
||||
"""
|
||||
return _guard_multiline_input(action, self._get_first_multiline_cmd)
|
||||
|
||||
def _get_first_multiline_cmd(self, action: str) -> re.Match | None:
|
||||
"""Return the first match of a command pattern in the action string.
|
||||
Where first match is defined by the start of the match.
|
||||
|
||||
The match object has three groups: (1) command name, (2) command arguments, (3) end name
|
||||
"""
|
||||
patterns = {
|
||||
k: v
|
||||
for k, v in self._command_patterns.items()
|
||||
if k in self.config.multi_line_command_endings or k == self.config.submit_command
|
||||
}
|
||||
matches = list()
|
||||
for _, pat in patterns.items():
|
||||
match = pat.search(action)
|
||||
if match:
|
||||
matches.append(match)
|
||||
if len(matches) == 0:
|
||||
return None
|
||||
matches = sorted(matches, key=lambda x: x.start())
|
||||
return matches[0]
|
||||
|
||||
def _get_command_patterns(self) -> dict[str, re.Pattern]:
|
||||
"""Creates regular expressions for the commands"""
|
||||
|
||||
_command_patterns = {}
|
||||
for command in self.config.commands:
|
||||
if command.end_name is not None:
|
||||
pat = re.compile(
|
||||
rf"^\s*({command.name})\s*(.*?)^({command.end_name})\s*$",
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
_command_patterns[command.name] = pat
|
||||
else:
|
||||
pat = re.compile(rf"^\s*({command.name})\s*(.*?)$", re.MULTILINE)
|
||||
_command_patterns[command.name] = pat
|
||||
submit_pat = re.compile(
|
||||
rf"^\s*({self.config.submit_command})\s*(.*?)^({self.config.submit_command_end_name})\s*$",
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
_command_patterns[self.config.submit_command] = submit_pat
|
||||
return _command_patterns
|
||||
108
.agent/vendor/mini-swe/sweagent/tools/utils.py
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from sweagent.tools.commands import Command
|
||||
|
||||
|
||||
def _guard_multiline_input(action: str, match_fct: Callable[[str], re.Match | None]) -> str:
|
||||
"""Split action by multiline commands, then append the first line in each multiline command with "<< '{end_name}'".
|
||||
Multiline commands (which are specified by an end_name) are commands that span multiple lines and are terminated by a specific end_name.
|
||||
|
||||
Their multi-line argument is sent using a heredoc, which is a way to send a multi-line string to a command in bash.
|
||||
"""
|
||||
parsed_action = []
|
||||
rem_action = action
|
||||
while rem_action.strip():
|
||||
first_match = match_fct(rem_action)
|
||||
if first_match:
|
||||
pre_action = rem_action[: first_match.start()]
|
||||
match_action = rem_action[first_match.start() : first_match.end()]
|
||||
rem_action = rem_action[first_match.end() :]
|
||||
if pre_action.strip():
|
||||
parsed_action.append(pre_action)
|
||||
if match_action.strip():
|
||||
eof = first_match.group(3).strip()
|
||||
if not match_action.split("\n")[0].strip().endswith(f"<< '{eof}'"):
|
||||
guarded_command = match_action[first_match.start() :]
|
||||
first_line = guarded_command.split("\n")[0]
|
||||
guarded_command = guarded_command.replace(first_line, first_line + f" << '{eof}'", 1)
|
||||
parsed_action.append(guarded_command)
|
||||
else:
|
||||
parsed_action.append(match_action)
|
||||
else:
|
||||
parsed_action.append(rem_action)
|
||||
rem_action = ""
|
||||
return "\n".join(parsed_action)
|
||||
|
||||
|
||||
def _should_quote(value: Any, command: Command) -> bool:
|
||||
"""Returns True if the value should be quoted, False otherwise."""
|
||||
if command.name == "bash":
|
||||
return False
|
||||
return isinstance(value, str) and command.end_name is None
|
||||
|
||||
|
||||
def get_signature(cmd):
|
||||
"""Generate a command signature from its arguments.
|
||||
|
||||
Args:
|
||||
cmd: Command object to generate signature for
|
||||
|
||||
Returns:
|
||||
Formatted signature string
|
||||
"""
|
||||
signature = cmd.name
|
||||
if "arguments" in cmd.__dict__ and cmd.arguments is not None:
|
||||
if cmd.end_name is None:
|
||||
for argument in cmd.arguments:
|
||||
param = argument.name
|
||||
if argument.required:
|
||||
signature += f" <{param}>"
|
||||
else:
|
||||
signature += f" [<{param}>]"
|
||||
else:
|
||||
for argument in cmd.arguments[:-1]:
|
||||
param = argument.name
|
||||
if argument.required:
|
||||
signature += f" <{param}>"
|
||||
else:
|
||||
signature += f" [<{param}>]"
|
||||
signature += f"\n{list(cmd.arguments[-1].keys())[0]}\n{cmd.end_name}"
|
||||
return signature
|
||||
|
||||
|
||||
def generate_command_docs(
|
||||
commands: list[Command],
|
||||
subroutine_types,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Generate detailed command documentation.
|
||||
|
||||
Format includes docstring, signature and argument details.
|
||||
|
||||
Args:
|
||||
commands: List of commands to document
|
||||
subroutine_types: List of subroutines to document
|
||||
**kwargs: Additional format variables for docstrings
|
||||
|
||||
Returns:
|
||||
Formatted documentation string
|
||||
"""
|
||||
docs = ""
|
||||
for cmd in commands + subroutine_types:
|
||||
docs += f"{cmd.name}:\n"
|
||||
if cmd.docstring is not None:
|
||||
docs += f" docstring: {cmd.docstring.format(**kwargs)}\n"
|
||||
if cmd.signature is not None:
|
||||
docs += f" signature: {cmd.signature}\n"
|
||||
else:
|
||||
docs += f" signature: {get_signature(cmd)}\n"
|
||||
if cmd.arguments:
|
||||
docs += " arguments:\n"
|
||||
for argument in cmd.arguments:
|
||||
param = argument.name
|
||||
req_string = "required" if argument.required else "optional"
|
||||
docs += f" - {param} ({argument.type}) [{req_string}]: {argument.description}\n"
|
||||
docs += "\n"
|
||||
return docs
|
||||
102
.agent/vendor/mini-swe/sweagent/types.py
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
"""This file has types/dataclass definitions that are used in the SWE agent
|
||||
for exchanging data between different modules/functions/classes.
|
||||
They oftentimes cannot be defined in the same file where they are used
|
||||
because of circular dependencies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class StepOutput(BaseModel):
|
||||
query: list[dict] = [{}]
|
||||
thought: str = ""
|
||||
action: str = ""
|
||||
output: str = ""
|
||||
observation: str = ""
|
||||
execution_time: float = 0.0
|
||||
done: bool = False
|
||||
exit_status: int | str | None = None
|
||||
submission: str | None = None
|
||||
state: dict[str, str] = {}
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
tool_call_ids: list[str] | None = None
|
||||
thinking_blocks: list[dict[str, Any]] | None = None
|
||||
|
||||
"""State of the environment at the end of the step"""
|
||||
extra_info: dict[str, Any] = {}
|
||||
|
||||
def to_template_format_dict(self) -> dict[str, str | int | float | bool | None]:
|
||||
"""Used for formatting (error) prompt templates"""
|
||||
out = {}
|
||||
for k, v in self.model_dump().items():
|
||||
if k in ("tool_calls", "tool_call_ids", "state"):
|
||||
continue
|
||||
out[k] = v
|
||||
out |= self.state
|
||||
return out
|
||||
|
||||
|
||||
class TrajectoryStep(TypedDict):
|
||||
action: str
|
||||
observation: str
|
||||
response: str
|
||||
state: dict[str, str]
|
||||
thought: str
|
||||
execution_time: float
|
||||
query: list[dict[str, Any]]
|
||||
extra_info: dict[str, Any]
|
||||
|
||||
|
||||
# required fields go here
|
||||
class _HistoryItem(TypedDict):
|
||||
role: str
|
||||
content: str | list[dict[str, Any]]
|
||||
message_type: Literal["thought", "action", "observation"]
|
||||
|
||||
|
||||
# see _HistoryItem for required fields
|
||||
class HistoryItem(_HistoryItem, total=False):
|
||||
agent: str
|
||||
is_demo: bool
|
||||
thought: str
|
||||
action: str | None
|
||||
tool_calls: list[dict[str, str]] | None
|
||||
tool_call_ids: list[str] | None
|
||||
tags: list[str]
|
||||
cache_control: dict[str, Any] | None
|
||||
thinking_blocks: list[dict[str, Any]] | None
|
||||
|
||||
"""HistoryProcessors can add these tags to enable special processing"""
|
||||
|
||||
|
||||
History = list[HistoryItem]
|
||||
Trajectory = list[TrajectoryStep]
|
||||
|
||||
|
||||
# todo: Make this actually have the dataclasses instead of dict versions
|
||||
class AgentInfo(TypedDict, total=False):
|
||||
# same as `APIStats` from models.py
|
||||
model_stats: dict[str, float]
|
||||
exit_status: str | None
|
||||
submission: str | None
|
||||
# same as `ReviewerResult`
|
||||
review: dict[str, Any]
|
||||
edited_files30: str
|
||||
edited_files50: str
|
||||
edited_files70: str
|
||||
# only if summarizer is used
|
||||
summarizer: dict
|
||||
swe_agent_hash: str
|
||||
swe_agent_version: str
|
||||
swe_rex_version: str
|
||||
swe_rex_hash: str
|
||||
|
||||
|
||||
class AgentRunResult(BaseModel):
|
||||
info: AgentInfo
|
||||
trajectory: Trajectory
|
||||
0
.agent/vendor/mini-swe/sweagent/utils/__init__.py
vendored
Normal file
80
.agent/vendor/mini-swe/sweagent/utils/config.py
vendored
Normal file
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from sweagent import REPO_ROOT
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
logger = get_logger("swea-config", emoji="🔧")
|
||||
|
||||
|
||||
def _convert_path_relative_to_repo_root(path: Path | str, root: Path | None = None) -> Path | str:
|
||||
original_type = type(path)
|
||||
path = Path(path).resolve()
|
||||
root = Path(root or os.getenv("SWE_AGENT_CONFIG_ROOT", REPO_ROOT))
|
||||
relative_path = path.relative_to(root) if root in path.parents else path
|
||||
return relative_path if original_type is Path else str(relative_path)
|
||||
|
||||
|
||||
def _could_be_a_path(v: Any) -> bool:
|
||||
try:
|
||||
return Path(v).exists()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _strip_abspath_from_dict(value: dict | list | str, root: Path | None = None) -> dict | list | str:
|
||||
root = Path(root or os.getenv("SWE_AGENT_CONFIG_ROOT", REPO_ROOT))
|
||||
if isinstance(value, dict):
|
||||
return {k: _strip_abspath_from_dict(v, root) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [_strip_abspath_from_dict(v, root) for v in value]
|
||||
elif isinstance(value, str) and _could_be_a_path(value):
|
||||
return _convert_path_relative_to_repo_root(value, root)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
def _convert_path_to_abspath(path: Path | str) -> Path:
|
||||
"""If path is not absolute, convert it to an absolute path
|
||||
using the SWE_AGENT_CONFIG_ROOT environment variable (if set) or
|
||||
REPO_ROOT as base.
|
||||
"""
|
||||
path = Path(path)
|
||||
root = Path(os.getenv("SWE_AGENT_CONFIG_ROOT", REPO_ROOT))
|
||||
assert root.is_dir()
|
||||
if not path.is_absolute():
|
||||
path = root / path
|
||||
assert path.is_absolute()
|
||||
return path.resolve()
|
||||
|
||||
|
||||
def _convert_paths_to_abspath(paths: list[Path] | list[str]) -> list[Path]:
|
||||
return [_convert_path_to_abspath(p) for p in paths]
|
||||
|
||||
|
||||
def load_environment_variables(path: Path | None = None):
|
||||
"""Load environment variables from a .env file.
|
||||
If path is not provided, we first look for a .env file in the current working
|
||||
directory and then in the repository root.
|
||||
"""
|
||||
if path is None:
|
||||
cwd_path = Path.cwd() / ".env"
|
||||
repo_path = REPO_ROOT / ".env"
|
||||
if cwd_path.exists():
|
||||
path = cwd_path
|
||||
elif repo_path.exists():
|
||||
path = REPO_ROOT / ".env"
|
||||
else:
|
||||
logger.debug("No .env file found")
|
||||
return
|
||||
if not path.is_file():
|
||||
msg = f"No .env file found at {path}"
|
||||
raise FileNotFoundError(msg)
|
||||
anything_loaded = load_dotenv(dotenv_path=path)
|
||||
if anything_loaded:
|
||||
logger.info(f"Loaded environment variables from {path}")
|
||||
27
.agent/vendor/mini-swe/sweagent/utils/files.py
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def load_file(path: Path | str | None) -> Any:
|
||||
"""Load files based on their extension."""
|
||||
if path is None:
|
||||
return None
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(path)
|
||||
if path.is_dir():
|
||||
from datasets import load_from_disk
|
||||
|
||||
return load_from_disk(path)
|
||||
if path.suffix in [".json", ".traj"]:
|
||||
return json.loads(path.read_text())
|
||||
if path.suffix == ".jsonl":
|
||||
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
|
||||
if path.suffix == ".yaml":
|
||||
return yaml.safe_load(path.read_text())
|
||||
msg = f"Unsupported file extension: {path.suffix}"
|
||||
raise NotImplementedError(msg)
|
||||
155
.agent/vendor/mini-swe/sweagent/utils/github.py
vendored
Normal file
@@ -0,0 +1,155 @@
|
||||
import json
|
||||
import re
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
from ghapi.all import GhApi
|
||||
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
_logger = get_logger("swea-github", emoji="🔧")
|
||||
|
||||
_repo_privacy_cache: dict[str, bool] = {}
|
||||
|
||||
GITHUB_ISSUE_URL_PATTERN = re.compile(r"github\.com\/(.*?)\/(.*?)\/issues\/(\d+)")
|
||||
|
||||
|
||||
class InvalidGithubURL(Exception):
|
||||
"""Raised when a github URL is invalid"""
|
||||
|
||||
|
||||
GITHUB_REPO_URL_PATTERN = re.compile(r".*[/@]?github\.com\/([^/]+)\/([^/]+)")
|
||||
|
||||
|
||||
def _is_github_repo_url(data_path: str) -> bool:
|
||||
"""Check if data_path is an URL pointing to a github repository.
|
||||
Paths to issues or PRs will also match this pattern.
|
||||
"""
|
||||
return GITHUB_REPO_URL_PATTERN.search(data_path) is not None
|
||||
|
||||
|
||||
def _is_github_issue_url(data_path: str) -> bool:
|
||||
"""Check if data_path is an URL pointing to a github issue"""
|
||||
return GITHUB_ISSUE_URL_PATTERN.search(data_path) is not None
|
||||
|
||||
|
||||
def _get_commit(api: GhApi, owner: str, repo: str, ref: str | None = None):
|
||||
"""Get commit object from github api
|
||||
|
||||
Args:
|
||||
api (GhApi):
|
||||
owner (str): Repo owner, e.g., "SWE-agent"
|
||||
repo (str): Repo, e.g., "SWE-agent"
|
||||
ref (str, optional): Branch, tag or commit hash
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
if ref:
|
||||
return api.repos.get_commit(owner, repo, ref) # type: ignore
|
||||
return api.repos.list_commits(owner, repo)[0] # type: ignore
|
||||
|
||||
|
||||
def _parse_gh_issue_url(issue_url: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
Returns:
|
||||
owner: Repo owner
|
||||
repo: Repo name
|
||||
issue number: Issue number as str
|
||||
|
||||
Raises:
|
||||
InvalidGithubURL: If the URL is not a valid github issue URL
|
||||
"""
|
||||
match = GITHUB_ISSUE_URL_PATTERN.search(issue_url)
|
||||
if not match:
|
||||
msg = f"Invalid GitHub issue URL: {issue_url}"
|
||||
raise InvalidGithubURL(msg)
|
||||
res = match.groups()
|
||||
assert len(res) == 3
|
||||
return tuple(res) # type: ignore
|
||||
|
||||
|
||||
def _parse_gh_repo_url(repo_url: str) -> tuple[str, str]:
|
||||
"""
|
||||
Returns:
|
||||
owner: Repo owner/org
|
||||
repo: Repo name
|
||||
|
||||
Raises:
|
||||
InvalidGithubURL: If the URL is not a valid github repo URL
|
||||
"""
|
||||
match = GITHUB_REPO_URL_PATTERN.search(repo_url)
|
||||
if not match:
|
||||
msg = f"Invalid GitHub issue URL: {repo_url}"
|
||||
raise InvalidGithubURL(msg)
|
||||
res = match.groups()
|
||||
assert len(res) == 2
|
||||
return tuple(res) # type: ignore
|
||||
|
||||
|
||||
def _get_gh_issue_data(issue_url: str, *, token: str = ""):
|
||||
"""Returns github issue data in the form of a dictionary.
|
||||
See https://docs.github.com/en/rest/issues/issues?apiVersion=2022-11-28#get-an-issue
|
||||
for return format
|
||||
"""
|
||||
owner, repo, issue_number = _parse_gh_issue_url(issue_url)
|
||||
api = GhApi(token=token)
|
||||
return api.issues.get(owner, repo, issue_number) # type: ignore
|
||||
|
||||
|
||||
def _get_problem_statement_from_github_issue(
|
||||
owner: str, repo: str, issue_number: str, *, token: str | None = ""
|
||||
) -> str:
|
||||
"""Return problem statement from github issue"""
|
||||
api = GhApi(token=token)
|
||||
issue = api.issues.get(owner, repo, issue_number) # type: ignore
|
||||
title = issue.title if issue.title else ""
|
||||
body = issue.body if issue.body else ""
|
||||
return f"{title}\n{body}\n"
|
||||
|
||||
|
||||
def _get_associated_commit_urls(org: str, repo: str, issue_number: str, *, token: str = "") -> list[str]:
|
||||
"""Return the URLs of commits that would close an issue."""
|
||||
api = GhApi(token=token)
|
||||
# Strangely the "pull_request" field of api.issues.get is often not set
|
||||
# so we have to go through the events to check if there's a commit
|
||||
events = api.issues.list_events(org, repo, issue_number) # type: ignore
|
||||
commit_urls = []
|
||||
for event in events:
|
||||
if event.event != "referenced":
|
||||
continue
|
||||
if not event.commit_id:
|
||||
continue
|
||||
commit = api.repos.get_commit(org, repo, event.commit_id) # type: ignore
|
||||
message = commit.commit.message
|
||||
if f"fixes #{issue_number}" in message.lower() or f"closes #{issue_number}" in message.lower():
|
||||
commit_urls.append(commit.html_url)
|
||||
return commit_urls
|
||||
|
||||
|
||||
def _is_repo_private(owner_repo: str, token: str) -> bool:
|
||||
"""Check if a GitHub repository is private via the GitHub API.
|
||||
|
||||
Returns True if the repo is private or if a 404 is returned (GitHub returns
|
||||
404 for private repos when the token lacks access). Any other HTTP or
|
||||
network error is raised so callers can handle it explicitly.
|
||||
"""
|
||||
if owner_repo in _repo_privacy_cache:
|
||||
return _repo_privacy_cache[owner_repo]
|
||||
url = f"https://api.github.com/repos/{owner_repo}"
|
||||
headers = {"User-Agent": "sweagent"}
|
||||
if token:
|
||||
headers["Authorization"] = f"token {token}"
|
||||
req = urllib.request.Request(url, headers=headers)
|
||||
try:
|
||||
with urllib.request.urlopen(req) as resp:
|
||||
data = json.loads(resp.read())
|
||||
private = data.get("private", False)
|
||||
except urllib.error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
_logger.warning("Repo '%s' returned 404 — assuming private", owner_repo)
|
||||
private = True
|
||||
else:
|
||||
raise
|
||||
_repo_privacy_cache[owner_repo] = private
|
||||
return private
|
||||
14
.agent/vendor/mini-swe/sweagent/utils/jinja_warnings.py
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
from sweagent.utils.log import get_logger
|
||||
|
||||
|
||||
def _warn_probably_wrong_jinja_syntax(template: str | None) -> None:
|
||||
"""Warn if the template uses {var} instead of {{var}}."""
|
||||
if template is None:
|
||||
return
|
||||
if "{" not in template:
|
||||
return
|
||||
for s in ["{%", "{ %", "{{"]:
|
||||
if s in template:
|
||||
return
|
||||
logger = get_logger("swea-config", emoji="🔧")
|
||||
logger.warning("Probably wrong Jinja syntax in template: %s. Make sure to use {{var}} instead of {var}.", template)
|
||||
175
.agent/vendor/mini-swe/sweagent/utils/log.py
vendored
Normal file
@@ -0,0 +1,175 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path, PurePath
|
||||
|
||||
from rich.logging import RichHandler
|
||||
from rich.text import Text
|
||||
|
||||
_SET_UP_LOGGERS: set[str] = set()
|
||||
_ADDITIONAL_HANDLERS: dict[str, logging.Handler] = {}
|
||||
_LOG_LOCK = threading.Lock()
|
||||
|
||||
logging.TRACE = 5 # type: ignore
|
||||
logging.addLevelName(logging.TRACE, "TRACE") # type: ignore
|
||||
|
||||
|
||||
def _interpret_level(level: int | str | None, *, default=logging.DEBUG) -> int:
|
||||
if not level:
|
||||
return default
|
||||
if isinstance(level, int):
|
||||
return level
|
||||
if level.isnumeric():
|
||||
return int(level)
|
||||
return getattr(logging, level.upper())
|
||||
|
||||
|
||||
_STREAM_LEVEL = _interpret_level(os.environ.get("SWE_AGENT_LOG_STREAM_LEVEL"))
|
||||
_INCLUDE_LOGGER_NAME_IN_STREAM_HANDLER = False
|
||||
|
||||
_THREAD_NAME_TO_LOG_SUFFIX: dict[str, str] = {}
|
||||
"""Mapping from thread name to suffix to add to the logger name."""
|
||||
|
||||
|
||||
def register_thread_name(name: str) -> None:
|
||||
"""Register a suffix to add to the logger name for the current thread."""
|
||||
thread_name = threading.current_thread().name
|
||||
_THREAD_NAME_TO_LOG_SUFFIX[thread_name] = name
|
||||
|
||||
|
||||
class _RichHandlerWithEmoji(RichHandler):
|
||||
def __init__(self, emoji: str, *args, **kwargs):
|
||||
"""Subclass of RichHandler that adds an emoji to the log message."""
|
||||
super().__init__(*args, **kwargs)
|
||||
if not emoji.endswith(" "):
|
||||
emoji += " "
|
||||
self.emoji = emoji
|
||||
|
||||
def get_level_text(self, record: logging.LogRecord) -> Text:
|
||||
level_name = record.levelname.replace("WARNING", "WARN")
|
||||
return Text.styled((self.emoji + level_name).ljust(10), f"logging.level.{level_name.lower()}")
|
||||
|
||||
|
||||
def get_logger(name: str, *, emoji: str = "") -> logging.Logger:
|
||||
"""Get logger. Use this instead of `logging.getLogger` to ensure
|
||||
that the logger is set up with the correct handlers.
|
||||
"""
|
||||
thread_name = threading.current_thread().name
|
||||
if thread_name != "MainThread":
|
||||
name = name + "-" + _THREAD_NAME_TO_LOG_SUFFIX.get(thread_name, thread_name)
|
||||
logger = logging.getLogger(name)
|
||||
if logger.hasHandlers():
|
||||
# Already set up
|
||||
return logger
|
||||
handler = _RichHandlerWithEmoji(
|
||||
emoji=emoji,
|
||||
show_time=bool(os.environ.get("SWE_AGENT_LOG_TIME", False)),
|
||||
show_path=False,
|
||||
)
|
||||
handler.setLevel(_STREAM_LEVEL)
|
||||
# Set to lowest level and only use stream handlers to adjust levels
|
||||
logger.setLevel(logging.TRACE) # type: ignore
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
_SET_UP_LOGGERS.add(name)
|
||||
with _LOG_LOCK:
|
||||
for handler in _ADDITIONAL_HANDLERS.values():
|
||||
my_filter = getattr(handler, "my_filter", None)
|
||||
if my_filter is None:
|
||||
logger.addHandler(handler)
|
||||
elif isinstance(my_filter, str) and my_filter in name:
|
||||
logger.addHandler(handler)
|
||||
elif callable(my_filter) and my_filter(name):
|
||||
logger.addHandler(handler)
|
||||
if _INCLUDE_LOGGER_NAME_IN_STREAM_HANDLER:
|
||||
_add_logger_name_to_stream_handler(logger)
|
||||
return logger
|
||||
|
||||
|
||||
def add_file_handler(
|
||||
path: PurePath | str,
|
||||
*,
|
||||
filter: str | Callable[[str], bool] | None = None,
|
||||
level: int | str = logging.TRACE, # type: ignore[attr-defined]
|
||||
id_: str = "",
|
||||
) -> str:
|
||||
"""Adds a file handler to all loggers that we have set up
|
||||
and all future loggers that will be set up with `get_logger`.
|
||||
|
||||
Args:
|
||||
filter: If str: Check that the logger name contains the filter string.
|
||||
If callable: Check that the logger name satisfies the condition returned by the callable.
|
||||
level: The level of the handler.
|
||||
id_: The id of the handler. If not provided, a random id will be generated.
|
||||
|
||||
Returns:
|
||||
The id of the handler. This can be used to remove the handler later.
|
||||
"""
|
||||
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
||||
handler = logging.FileHandler(path, encoding="utf-8")
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
handler.setLevel(_interpret_level(level))
|
||||
with _LOG_LOCK:
|
||||
# Lock because other thread might be modifying the _SET_UP_LOGGERS set
|
||||
for name in _SET_UP_LOGGERS:
|
||||
if filter is not None:
|
||||
if isinstance(filter, str) and filter not in name:
|
||||
continue
|
||||
if callable(filter) and not filter(name):
|
||||
continue
|
||||
logger = logging.getLogger(name)
|
||||
logger.addHandler(handler)
|
||||
handler.my_filter = filter # type: ignore
|
||||
if not id_:
|
||||
id_ = str(uuid.uuid4())
|
||||
_ADDITIONAL_HANDLERS[id_] = handler
|
||||
return id_
|
||||
|
||||
|
||||
def remove_file_handler(id_: str) -> None:
|
||||
"""Remove a file handler by its id."""
|
||||
handler = _ADDITIONAL_HANDLERS.pop(id_)
|
||||
with _LOG_LOCK:
|
||||
# Lock because other thread might be modifying the _SET_UP_LOGGERS set
|
||||
for log_name in _SET_UP_LOGGERS:
|
||||
logger = logging.getLogger(log_name)
|
||||
logger.removeHandler(handler)
|
||||
|
||||
|
||||
def _add_logger_name_to_stream_handler(logger: logging.Logger) -> None:
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, _RichHandlerWithEmoji):
|
||||
formatter = logging.Formatter("[%(name)s] %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
|
||||
def add_logger_names_to_stream_handlers() -> None:
|
||||
"""Add the logger name to the stream handler for all loggers that we have set up."""
|
||||
global _INCLUDE_LOGGER_NAME_IN_STREAM_HANDLER
|
||||
_INCLUDE_LOGGER_NAME_IN_STREAM_HANDLER = True
|
||||
with _LOG_LOCK:
|
||||
for logger in _SET_UP_LOGGERS:
|
||||
_add_logger_name_to_stream_handler(logging.getLogger(logger))
|
||||
|
||||
|
||||
def set_stream_handler_levels(level: int) -> None:
|
||||
"""Set the default stream level and adjust the levels of all stream handlers
|
||||
to be at most the given level.
|
||||
|
||||
Note: Can only be used to lower the level, not raise it.
|
||||
"""
|
||||
global _STREAM_LEVEL
|
||||
_STREAM_LEVEL = level
|
||||
with _LOG_LOCK:
|
||||
for name in _SET_UP_LOGGERS:
|
||||
logger = logging.getLogger(name)
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, _RichHandlerWithEmoji):
|
||||
current_level = handler.level
|
||||
if current_level < level:
|
||||
handler.setLevel(level)
|
||||
152
.agent/vendor/mini-swe/sweagent/utils/patch_formatter.py
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from unidiff import PatchSet
|
||||
|
||||
|
||||
class PatchFormatter:
|
||||
def __init__(
|
||||
self,
|
||||
patch: str,
|
||||
read_method: Callable[[str], str],
|
||||
):
|
||||
"""Given the final patch and access to the container that contains the repository,
|
||||
extract relevant lines from the modified file.
|
||||
|
||||
Args:
|
||||
patch: The patch as a string.
|
||||
read_method: Callable with path to file (relative to repository root) as argument
|
||||
that returns the file content as a string.
|
||||
"""
|
||||
self._patch = PatchSet(patch)
|
||||
self._patched_files: dict[str, str] = {}
|
||||
self._original_files: dict[str, str] = {}
|
||||
self._patch_applied = True
|
||||
self._read_file = read_method
|
||||
self._read_files(original=False)
|
||||
|
||||
@staticmethod
|
||||
def _merge_intervals(starts: list[int], stops: list[int]) -> tuple[list[int], list[int]]:
|
||||
"""Given two lists of integers, starts and stops, merges all overlapping intervals.
|
||||
|
||||
For example `starts=[1, 5, 18]`, `stops=[10, 13, 20]`
|
||||
should return `starts=[1, 18]`, `stops=[13, 20]`
|
||||
"""
|
||||
if not starts:
|
||||
assert not stops
|
||||
return [], []
|
||||
|
||||
intervals = sorted(zip(starts, stops))
|
||||
merged = []
|
||||
for start, stop in intervals:
|
||||
if not merged or merged[-1][1] < start:
|
||||
# No overlap
|
||||
merged.append([start, stop])
|
||||
else:
|
||||
# Overlap
|
||||
merged[-1][1] = max(merged[-1][1], stop)
|
||||
# Unzip again
|
||||
merged_starts, merged_stops = zip(*merged)
|
||||
return list(merged_starts), list(merged_stops)
|
||||
|
||||
def format_file(self, text: str, starts: list[int], stops: list[int], *, linenos: bool = True) -> str:
|
||||
"""Reads file and returns string representation of the relevant lines.
|
||||
|
||||
Args:
|
||||
path: The path to the file within the repo location
|
||||
starts: The starting line numbers of the relevant lines. The first line is line 1.
|
||||
stops: The stopping line numbers of the relevant lines. The stop is not inclusive.
|
||||
The first line is line 1.
|
||||
linenos: Whether to include line numbers
|
||||
"""
|
||||
if not starts:
|
||||
assert not stops
|
||||
return ""
|
||||
|
||||
assert len(starts) == len(stops)
|
||||
assert all(start >= 1 for start in starts)
|
||||
assert all(start < stop for start, stop in zip(starts, stops))
|
||||
starts, stops = self._merge_intervals(starts, stops)
|
||||
assert all(hunk1_start < hunk2_start for hunk1_start, hunk2_start in zip(starts, starts[1:]))
|
||||
out: list[str] = []
|
||||
if starts[0] > 1:
|
||||
# Count from 1
|
||||
out.append(f"[{starts[0] - 1} lines above omitted]")
|
||||
last_stop: int | None = None
|
||||
lines = text.splitlines()
|
||||
for start, stop in zip(starts, stops):
|
||||
assert start >= 1
|
||||
if last_stop is not None:
|
||||
n_omitted = start - last_stop
|
||||
# Check that we have non-overlapping hunks
|
||||
assert n_omitted >= 0
|
||||
if n_omitted:
|
||||
out.append(f"\n[{n_omitted} lines omitted]\n")
|
||||
# Count from 1
|
||||
these_lines = lines[start - 1 : stop - 1]
|
||||
if linenos:
|
||||
out.append("\n".join([f"{i:6d}: {l}" for i, l in enumerate(these_lines, start=start)]))
|
||||
else:
|
||||
out.append("\n".join(these_lines))
|
||||
last_stop = stop
|
||||
if last_stop < len(lines):
|
||||
# Stop is not inclusive
|
||||
omitted = len(lines) - last_stop
|
||||
assert omitted > 0
|
||||
out.append(f"[{omitted} lines below omitted]")
|
||||
return "\n".join(out)
|
||||
|
||||
def _get_hunk_lines(self, original: bool, *, context_length: int) -> dict[str, tuple[list[int], list[int]]]:
|
||||
"""Get the starts and stops for all files in the patch.
|
||||
|
||||
Args:
|
||||
original: Whether to read the original file or the patched file
|
||||
context_length: The number of lines to include above and below the hunk
|
||||
|
||||
Returns:
|
||||
A dictionary with the file path as key and a tuple of lists of starts and stops as value.
|
||||
"""
|
||||
out: dict[str, tuple[list[int], list[int]]] = {}
|
||||
for patch in self._patch:
|
||||
if not patch.is_modified_file:
|
||||
continue
|
||||
starts: list[int] = []
|
||||
stops: list[int] = []
|
||||
for hunk in patch:
|
||||
if original:
|
||||
# 1 is the lowest line number
|
||||
start = max(1, hunk.source_start - context_length)
|
||||
stop = hunk.source_start + hunk.source_length + context_length
|
||||
else:
|
||||
start = max(1, hunk.target_start - context_length)
|
||||
stop = hunk.target_start + hunk.target_length + context_length
|
||||
starts.append(start)
|
||||
stops.append(stop)
|
||||
out[patch.path] = (starts, stops)
|
||||
return out
|
||||
|
||||
def _read_files(self, original: bool) -> None:
|
||||
for patch in self._patch:
|
||||
path = patch.path
|
||||
if not patch.is_modified_file:
|
||||
continue
|
||||
if original:
|
||||
msg = "Original file reading not implemented"
|
||||
raise NotImplementedError(msg)
|
||||
else:
|
||||
assert self._patch_applied
|
||||
self._patched_files[path] = self._read_file(path)
|
||||
|
||||
@staticmethod
|
||||
def concat_files_strings(files: dict[str, str]) -> str:
|
||||
"""Concatenate multiple `read_files` outputs into a single string."""
|
||||
out = []
|
||||
for path, content in files.items():
|
||||
out.append(f"[File: {path}]\n{content}")
|
||||
return "\n\n".join(out)
|
||||
|
||||
def get_files_str(self, *, original: bool, context_length: int | None = 50, linenos: bool = True) -> str:
|
||||
hunk_lines = self._get_hunk_lines(original=original, context_length=context_length)
|
||||
sources = self._original_files if original else self._patched_files
|
||||
return self.concat_files_strings(
|
||||
{path: self.format_file(text, *hunk_lines[path], linenos=linenos) for path, text in sources.items()}
|
||||
)
|
||||
45
.agent/vendor/mini-swe/sweagent/utils/serialization.py
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
import io
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.scalarstring import LiteralScalarString as LSS
|
||||
|
||||
|
||||
def _convert_to_yaml_literal_string(d: Any) -> Any:
|
||||
"""Convert any multi-line strings in nested data object to LiteralScalarString.
|
||||
This will then use the `|-` syntax of yaml.
|
||||
"""
|
||||
d = deepcopy(d)
|
||||
if isinstance(d, dict):
|
||||
for key, value in d.items():
|
||||
d[key] = _convert_to_yaml_literal_string(value)
|
||||
elif isinstance(d, list):
|
||||
for i, item in enumerate(d):
|
||||
d[i] = _convert_to_yaml_literal_string(item)
|
||||
elif isinstance(d, str) and "\n" in d:
|
||||
d = LSS(d.replace("\r\n", "\n").replace("\r", "\n"))
|
||||
return d
|
||||
|
||||
|
||||
def _yaml_serialization_with_linebreaks(data: Any) -> str:
|
||||
data = _convert_to_yaml_literal_string(data)
|
||||
yaml = YAML()
|
||||
yaml.indent(mapping=2, sequence=4, offset=2)
|
||||
yaml.width = float("inf")
|
||||
yaml.default_flow_style = False
|
||||
buffer = io.StringIO()
|
||||
yaml.dump(data, buffer)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def merge_nested_dicts(d1: dict, d2: dict) -> dict:
|
||||
"""Merge two nested dictionaries, updating d1 in place.
|
||||
If a key exists in both dictionaries, the value from d2 will be used.
|
||||
"""
|
||||
for key, value in d2.items():
|
||||
if isinstance(value, dict):
|
||||
d1[key] = merge_nested_dicts(d1.get(key, {}), value)
|
||||
else:
|
||||
d1[key] = value
|
||||
return d1
|
||||