wip: [01-stabilize] paused at task 1/1 - OCR Hallucination Immune logic via Semantic delta window and fret-isolation

This commit is contained in:
2026-03-29 22:08:40 +09:00
parent aca7bf592a
commit 2507de45d3
4289 changed files with 732689 additions and 28672 deletions

View File

@@ -0,0 +1,114 @@
from __future__ import annotations
import os
import sys
from functools import partial
from logging import WARNING, getLogger
from pathlib import Path
import swerex.utils.log as log_swerex
from git import Repo
from packaging import version
from sweagent.utils.log import get_logger
__version__ = "1.1.0"
PYTHON_MINIMUM_VERSION = (3, 11)
SWEREX_MINIMUM_VERSION = "1.2.0"
SWEREX_RECOMMENDED_VERSION = "1.2.1"
# Monkey patch the logger to use our implementation
log_swerex.get_logger = partial(get_logger, emoji="🦖")
# See https://github.com/SWE-agent/SWE-agent/issues/585
getLogger("datasets").setLevel(WARNING)
getLogger("numexpr.utils").setLevel(WARNING)
getLogger("LiteLLM").setLevel(WARNING)
PACKAGE_DIR = Path(__file__).resolve().parent
if sys.version_info < PYTHON_MINIMUM_VERSION:
msg = (
f"Python {sys.version_info.major}.{sys.version_info.minor} is not supported. "
"SWE-agent requires Python 3.11 or higher."
)
raise RuntimeError(msg)
assert PACKAGE_DIR.is_dir(), PACKAGE_DIR
REPO_ROOT = PACKAGE_DIR.parent
assert REPO_ROOT.is_dir(), REPO_ROOT
CONFIG_DIR = Path(os.getenv("SWE_AGENT_CONFIG_DIR", PACKAGE_DIR.parent / "config"))
assert CONFIG_DIR.is_dir(), CONFIG_DIR
TOOLS_DIR = Path(os.getenv("SWE_AGENT_TOOLS_DIR", PACKAGE_DIR.parent / "tools"))
assert TOOLS_DIR.is_dir(), TOOLS_DIR
TRAJECTORY_DIR = Path(os.getenv("SWE_AGENT_TRAJECTORY_DIR", PACKAGE_DIR.parent / "trajectories"))
assert TRAJECTORY_DIR.is_dir(), TRAJECTORY_DIR
def get_agent_commit_hash() -> str:
"""Get the commit hash of the current SWE-agent commit.
If we cannot get the hash, we return an empty string.
"""
try:
repo = Repo(REPO_ROOT, search_parent_directories=False)
except Exception:
return "unavailable"
return repo.head.object.hexsha
def get_rex_commit_hash() -> str:
import swerex
try:
repo = Repo(Path(swerex.__file__).resolve().parent.parent.parent, search_parent_directories=False)
except Exception:
return "unavailable"
return repo.head.object.hexsha
def get_rex_version() -> str:
from swerex import __version__ as rex_version
return rex_version
def get_agent_version_info() -> str:
hash = get_agent_commit_hash()
rex_hash = get_rex_commit_hash()
rex_version = get_rex_version()
return f"This is SWE-agent version {__version__} ({hash=}) with SWE-ReX version {rex_version} ({rex_hash=})."
def impose_rex_lower_bound() -> None:
rex_version = get_rex_version()
minimal_rex_version = "1.2.0"
if version.parse(rex_version) < version.parse(minimal_rex_version):
msg = (
f"SWE-ReX version {rex_version} is too old. Please update to at least {minimal_rex_version} by "
"running `pip install --upgrade swe-rex`."
"You can also rerun `pip install -e .` in this repository to install the latest version."
)
raise RuntimeError(msg)
if version.parse(rex_version) < version.parse(SWEREX_RECOMMENDED_VERSION):
msg = (
f"SWE-ReX version {rex_version} is not recommended. Please update to at least {SWEREX_RECOMMENDED_VERSION} by "
"running `pip install --upgrade swe-rex`."
"You can also rerun `pip install -e .` in this repository to install the latest version."
)
get_logger("swe-agent", emoji="👋").warning(msg)
impose_rex_lower_bound()
get_logger("swe-agent", emoji="👋").info(get_agent_version_info())
__all__ = [
"PACKAGE_DIR",
"CONFIG_DIR",
"get_agent_commit_hash",
"get_agent_version_info",
"__version__",
]

View File

@@ -0,0 +1,4 @@
from sweagent.run.run import main
if __name__ == "__main__":
main()

View File

View File

@@ -0,0 +1,317 @@
from abc import abstractmethod
from textwrap import dedent
from typing import Any, Literal
from jinja2 import Template
from pydantic import BaseModel
from sweagent.agent.models import AbstractModel
from sweagent.agent.problem_statement import ProblemStatement
from sweagent.exceptions import FormatError
from sweagent.tools.tools import ToolHandler
from sweagent.types import Trajectory
from sweagent.utils.log import get_logger
class ActionSamplerOutput(BaseModel):
completion: dict[str, Any]
messages: list[dict[str, Any]] = []
trajectory_items: list[dict[str, Any]] = []
extra_info: dict[str, Any] = {}
class AbstractActionSampler:
def __init__(self, model: AbstractModel, tools: ToolHandler):
self._model = model
self._tools = tools
self._logger = get_logger("action_sampler", emoji="👥")
@abstractmethod
def get_action(
self,
problem_statement: ProblemStatement,
trajectory: Trajectory,
history: list[dict[str, Any]],
) -> ActionSamplerOutput:
"""Returns action with tool calls"""
pass
class AskColleaguesConfig(BaseModel):
type: Literal["ask_colleagues"] = "ask_colleagues"
n_samples: int = 2
def get(self, model: AbstractModel, tools: ToolHandler) -> "AskColleagues":
return AskColleagues(self, model, tools)
class AskColleagues(AbstractActionSampler):
def __init__(self, config: AskColleaguesConfig, model: AbstractModel, tools: ToolHandler):
super().__init__(model, tools)
self.config = config
def get_colleague_discussion(self, completions: list[dict[str, Any]]) -> str:
"""Concat all completions into a single string"""
out = "Your colleagues had the following ideas: \n\n"
n_parsed_ok = 0
for i, completion in enumerate(completions):
try:
thought, action = self._tools.parse_actions(completion)
except FormatError:
self._logger.warning("Could not parse completion %s, skipping.", completion)
continue
n_parsed_ok += 1
out += f"Thought (colleague {i}): {thought}\nProposed Action (colleague {i}): {action}\n\n"
if n_parsed_ok == 0:
msg = "No completions could be parsed."
raise FormatError(msg)
out += (
"Please summarize and compare the ideas and propose and action to take. "
"Finally choose one action to perform and explain it in detail and include it as a tool call. "
"<important>You must include a thought and action (as a tool/function call). Do not try to invoke commands with triple backticks, use function calls instead.</important>"
)
return out
def get_action(
self,
problem_statement: ProblemStatement,
trajectory: Trajectory,
history: list[dict[str, Any]],
) -> ActionSamplerOutput:
"""Returns action with tool calls"""
completions = self._model.query(history, n=self.config.n_samples) # type: ignore
discussion = self.get_colleague_discussion(completions)
self._logger.info(f"COLLEAGUE DISCUSSION:\n{discussion}")
new_messages = [
{"role": "user", "content": discussion},
]
final_completion = self._model.query(history + new_messages) # type: ignore
return ActionSamplerOutput(
completion=final_completion,
extra_info={"colleagues": discussion},
)
class BinaryTrajectoryComparisonConfig(BaseModel):
type: Literal["binary_trajectory_comparison"] = "binary_trajectory_comparison"
min_n_samples: int = 4
max_n_samples: int = 10
comparison_temperature: float | None = None
"""Override the model's temperature. If None, take the temperature configured for the model."""
system_template: str = """<setting>You are an expert software engineer overseeing junior developers. They suggest actions to take to solve a problem. You must choose the best action to take. </setting>"""
instance_template: str = dedent("""
We're solving the following problem
<problem_statement>
{{problem_statement}}
</problem_statement>
So far, we've performed the following actions:
<trajectory>
{{traj}}
</trajectory>
""")
comparison_template: str = dedent("""
Two junior developers suggested the following actions:
<thought1>
{{thought1}}
</thought1>
<action1>
{{action1}}
</action1>
<thought2>
{{thought2}}
</thought2>
<action2>
{{action2}}
</action2>
Please compare the two actions in detail.
Which action should we take?
If you think the first action is better, respond with "first".
If you think the second action is better, respond with "second".
The last line of your response MUST be "first" or "second".
""")
def get(self, model: AbstractModel, tools: ToolHandler) -> "BinaryTrajectoryComparison":
return BinaryTrajectoryComparison(self, model, tools)
class BinaryTrajectoryComparison(AbstractActionSampler):
def __init__(self, config: BinaryTrajectoryComparisonConfig, model: AbstractModel, tools: ToolHandler):
super().__init__(model, tools)
self.config = config
def _format_trajectory(self, trajectory: Trajectory) -> str:
steps = []
for i, step in enumerate(trajectory):
steps.append(f"Action {i}: {step['action']}\n Observation {i}: {step['observation']}")
return "\n".join(steps)
def format_messages(
self,
*,
problem_statement: ProblemStatement,
trajectory: Trajectory,
thought1: str,
action1: str,
thought2: str,
action2: str,
use_cache_control: bool = False,
) -> list[dict]:
system_message = self.config.system_template
self._logger.debug(f"MODEL INPUT (system)\n{system_message}")
ps_format_dict = {
"problem_statement": problem_statement.get_problem_statement(),
**problem_statement.get_extra_fields(),
}
user_message = Template(self.config.instance_template).render(
**ps_format_dict,
traj=self._format_trajectory(trajectory),
)
self._logger.debug(f"MODEL INPUT (instance)\n{user_message}")
comparison_message = Template(self.config.comparison_template).render(
thought1=thought1,
action1=action1,
thought2=thought2,
action2=action2,
)
self._logger.debug(f"MODEL INPUT (comparison)\n{comparison_message}")
cache_control_kwargs = {"cache_control": {"type": "ephemeral"}} if use_cache_control else {}
return [
{"role": "system", "content": system_message},
{
"role": "user",
"content": [{"type": "text", "text": user_message, **cache_control_kwargs}],
},
{
"role": "user",
"content": [
{
"type": "text",
"text": comparison_message,
}
],
},
]
def filter_duplicates(self, completions: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Filter out duplicate actions"""
thoughts: list[str] = []
actions: list[str] = []
filtered_completions: list[dict[str, Any]] = []
for pc in completions:
thought, action = self._tools.parse_actions(pc)
if action not in actions:
thoughts.append(thought)
actions.append(action)
filtered_completions.append(pc)
if len(filtered_completions) < len(completions):
self._logger.debug("Filtering duplicates: %d -> %d", len(completions), len(filtered_completions))
return filtered_completions
def filter_parseable_completions(self, completions: list[dict[str, Any]]) -> list[dict[str, Any]]:
filtered_completions = []
for completion in completions:
try:
self._tools.parse_actions(completion)
except FormatError:
self._logger.warning("Could not parse completion %s, skipping.", completion)
continue
filtered_completions.append(completion)
if len(filtered_completions) == 0:
msg = "No completions could be parsed."
raise FormatError(msg)
return filtered_completions
def contains_edits(self, completions: list[dict[str, Any]]) -> bool:
keywords = ["edit", "str_replace_editor insert", "str_replace_editor str_replace"]
for completion in completions:
_, action = self._tools.parse_actions(completion)
if any(action.startswith(keyword) for keyword in keywords):
return True
return False
def get_completions(self, history: list[dict[str, Any]]) -> list[dict[str, Any]]:
completions = self._model.query(history, n=self.config.min_n_samples) # type: ignore
completions = self.filter_parseable_completions(completions)
completions = self.filter_duplicates(completions)
if not completions:
msg = "No completions could be parsed."
raise FormatError(msg)
if self.contains_edits(completions) and self.config.min_n_samples < self.config.max_n_samples:
self._logger.debug("Edits were proposed, will sample more")
new_completions = self._model.query(history, n=self.config.max_n_samples - self.config.min_n_samples) # type: ignore
completions = self.filter_duplicates(self.filter_parseable_completions(completions + new_completions))
if len(completions) == 1:
_, action = self._tools.parse_actions(completions[0])
self._logger.warning("Only identical actions were proposed (action=%s)", action)
return completions
def get_action(
self,
*,
problem_statement: ProblemStatement,
trajectory: Trajectory,
history: list[dict[str, Any]],
) -> ActionSamplerOutput:
completions = self.get_completions(history)
best_idx = 0
comparison_log = []
for i in range(1, len(completions)):
thought1, action1 = self._tools.parse_actions(completions[best_idx])
thought2, action2 = self._tools.parse_actions(completions[i])
messages = self.format_messages(
problem_statement=problem_statement,
trajectory=trajectory,
thought1=thought1,
action1=action1,
thought2=thought2,
action2=action2,
use_cache_control=len(completions) >= 3,
)
response = self._model.query(messages, temperature=self.config.comparison_temperature)["message"] # type: ignore
self._logger.info(f"RESPONSE: {response}")
idx = self.interpret(response)
comparison_log.append(
{
"comparison_between": (best_idx, i),
"messages": messages,
"response": response,
"idx": idx,
}
)
best_idx = i if idx == 1 else best_idx
return ActionSamplerOutput(
completion=completions[best_idx],
extra_info={"comparison_log": comparison_log},
)
def interpret(self, response: str) -> Literal[0, 1]:
"""Interpret response from LM. Note: 1-based indexing"""
last_line = response.strip().split("\n")[-1].strip()
if "first" in last_line.lower():
return 0
elif "second" in last_line.lower():
return 1
self._logger.warning("Could not interpret response: %s, will choose first submission.", response)
return 0
ActionSamplerConfig = BinaryTrajectoryComparisonConfig | AskColleaguesConfig

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,106 @@
from pathlib import Path
from typing import Self
from sweagent.agent.agents import DefaultAgent, ShellAgentConfig
from sweagent.agent.models import HumanModel, HumanModelConfig, get_model
from sweagent.agent.problem_statement import ProblemStatement, ProblemStatementConfig
from sweagent.environment.swe_env import SWEEnv
from sweagent.tools.parsing import ActionOnlyParser
from sweagent.tools.tools import ToolHandler
from sweagent.types import AgentRunResult, StepOutput
class ShellAgent(DefaultAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def from_config(cls, config: ShellAgentConfig) -> Self:
# To ensure that all models stay completely independent, we deepcopy the
# model config, because it lives on as a property in the model, tools, etc.
config = config.model_copy(deep=True)
model = get_model(config.model, config.tools)
return cls(
templates=config.templates,
tools=ToolHandler(config.tools),
history_processors=config.history_processors,
model=model,
max_requeries=config.max_requeries,
)
def human_step_in(self) -> None:
"""Replace the current model with a HumanModel instance.
This allows for human intervention during agent execution.
"""
self._original_model = self.model
self._original_parser = self.tools.config.parse_function
human_config = HumanModelConfig(name="human", catch_eof=False)
self.model = get_model(human_config, self.tools.config)
self.tools.config.parse_function = ActionOnlyParser()
self.logger.info("Switched to human mode. Agent will now accept human input. Press ^D to switch back.")
def human_step_out(self) -> None:
"""Switch back to the original model from human mode.
This is called when ^D is pressed in human mode.
"""
if not hasattr(self, "_original_model") or self._original_model is None:
self.logger.info("No previous model to switch back to. Remaining in current mode.")
return
self.model = self._original_model
self.tools.config.parse_function = self._original_parser # type: ignore
self._original_model = None
self._original_parser = None
self.logger.info("Switched back to AI model mode.")
def run(
self,
env: SWEEnv,
problem_statement: ProblemStatement | ProblemStatementConfig,
*,
output_dir: Path = Path("."),
) -> AgentRunResult:
"""Run the agent on a problem instance. This method contains the
main loop that repeatedly calls `self._step` until the problem is solved.
Args:
setup_args: Arguments to pass to the agent's setup method.
env: The environment to run the agent on.
traj_dir: Directory to save the trajectory to
interruptible: Whether the human can jump in by pressing ^C
"""
self.setup(env=env, problem_statement=problem_statement, output_dir=output_dir)
# Run action/observation loop
self._chook.on_run_start()
step_output = StepOutput()
while not step_output.done:
try:
step_output = self.step()
self.save_trajectory()
except KeyboardInterrupt:
if not isinstance(self.model, HumanModel):
self.human_step_in()
continue
raise
except EOFError:
# Can only happen if we have a human model, so switch back
self.logger.info("Detected ^D - switching back to AI mode")
self.human_step_out()
continue
if step_output.done and not isinstance(self.model, HumanModel):
# Human has to submit the solution
self.logger.info("Robot is done! Please submit the solution.")
self.human_step_in()
step_output.done = False
self._chook.on_run_done(trajectory=self.trajectory, info=self.info)
self.logger.info("Trajectory saved to %s", self.traj_path)
# Here we want to return the "global" information (e.g., submission should
# be the best submission instead of the last one, etc.), so we get it from the traj file
data = self.get_trajectory_data()
return AgentRunResult(info=data["info"], trajectory=data["trajectory"])

View File

@@ -0,0 +1,399 @@
from __future__ import annotations
import copy
import re
from abc import abstractmethod
from typing import Annotated, Literal, Protocol
from pydantic import BaseModel, ConfigDict, Field, field_validator
from sweagent.types import History, HistoryItem
class AbstractHistoryProcessor(Protocol):
@abstractmethod
def __call__(self, history: History) -> History:
raise NotImplementedError
# Utility functions
# -----------------
def _get_content_stats(entry: HistoryItem) -> tuple[int, int]:
if isinstance(entry["content"], str):
return len(entry["content"].splitlines()), 0
n_text_lines = sum(len(item["text"].splitlines()) for item in entry["content"] if item.get("type") == "text")
n_images = sum(1 for item in entry["content"] if item.get("type") == "image_url")
return n_text_lines, n_images
def _get_content_text(entry: HistoryItem) -> str:
if isinstance(entry["content"], str):
return entry["content"]
assert len(entry["content"]) == 1, "Expected single message in content"
return entry["content"][0]["text"]
def _set_content_text(entry: HistoryItem, text: str) -> None:
if isinstance(entry["content"], str):
entry["content"] = text
else:
assert len(entry["content"]) == 1, "Expected single message in content"
entry["content"][0]["text"] = text
def _clear_cache_control(entry: HistoryItem) -> None:
if isinstance(entry["content"], list):
for item in entry["content"]:
item.pop("cache_control", None)
entry.pop("cache_control", None)
def _set_cache_control(entry: HistoryItem) -> None:
if not isinstance(entry["content"], list):
entry["content"] = [ # type: ignore
{
"type": "text",
"text": _get_content_text(entry),
"cache_control": {"type": "ephemeral"},
}
]
else:
entry["content"][0]["cache_control"] = {"type": "ephemeral"}
if entry["role"] == "tool":
# Workaround for weird bug
entry["content"][0].pop("cache_control", None)
entry["cache_control"] = {"type": "ephemeral"}
# History processors
# ------------------
class DefaultHistoryProcessor(BaseModel):
type: Literal["default"] = "default"
"""Do not change. Used for (de)serialization."""
# pydantic config
model_config = ConfigDict(extra="forbid")
def __call__(self, history: History) -> History:
return history
class LastNObservations(BaseModel):
"""Elide all but the last n observations or remove tagged observations.
This is our most classic history processor, used in the original paper
to elide but the last 5 observations.
Elided observations are replaced by "Old environment output: (n lines omitted)".
Typical configuration:
```yaml
agent:
history_processors:
- type: last_n_observations
n: 5
```
as for example in use in the SWE-agent 0.7 config at
https://github.com/SWE-agent/SWE-agent/blob/main/config/sweagent_0_7/07.yaml
For most use cases, you only need to set `n`.
Note that using this history processor will break prompt caching (as the
history of every query will change every time due to the elided observations).
There are some workarounds possible with the `polling` parameter.
However, most SotA models can now fit a lot of context, so generally this
history processor is not always needed anymore.
"""
n: int
"""Number of observations to keep."""
polling: int = 1
"""How many steps to keep between updating the number of observations to keep.
This is useful for caching, as we want to remove more and more messages, but every
time we change the history, we need to cache everything again.
Effectively, we will now keep between `n` and `n+polling` observations.
"""
always_remove_output_for_tags: set[str] = {"remove_output"}
"""Any observation with a `tags` field containing one of these strings will be elided,
even if it is one of the last n observations.
"""
always_keep_output_for_tags: set[str] = {"keep_output"}
"""Any observation with a `tags` field containing one of these strings will be kept,
even if it is not one of the last n observations.
"""
type: Literal["last_n_observations"] = "last_n_observations"
"""Do not change. Used for (de)serialization."""
# pydantic config
model_config = ConfigDict(extra="forbid")
@field_validator("n")
def validate_n(cls, n: int) -> int:
if n <= 0:
msg = "n must be a positive integer"
raise ValueError(msg)
return n
def _get_omit_indices(self, history: History) -> list[int]:
observation_indices = [
idx
for idx, entry in enumerate(history)
if entry.get("message_type") == "observation" and not entry.get("is_demo", False)
]
last_removed_idx = max(0, (len(observation_indices) // self.polling) * self.polling - self.n)
# Note: We never remove the first observation, as it is the instance template
return observation_indices[1:last_removed_idx]
def __call__(self, history: History) -> History:
new_history = []
omit_content_idxs = self._get_omit_indices(history)
for idx, entry in enumerate(history):
tags = set(entry.get("tags", []))
if ((idx not in omit_content_idxs) or (tags & self.always_keep_output_for_tags)) and not (
tags & self.always_remove_output_for_tags
):
new_history.append(entry)
else:
data = entry.copy()
assert data.get("message_type") == "observation", (
f"Expected observation for dropped entry, got: {data.get('message_type')}"
)
num_text_lines, num_images = _get_content_stats(data)
data["content"] = f"Old environment output: ({num_text_lines} lines omitted)"
if num_images > 0:
data["content"] += f" ({num_images} images omitted)"
new_history.append(data)
return new_history
class TagToolCallObservations(BaseModel):
"""Adds tags to history items for specific tool calls."""
type: Literal["tag_tool_call_observations"] = "tag_tool_call_observations"
"""Do not change. Used for (de)serialization."""
tags: set[str] = {"keep_output"}
"""Add the following tag to all observations matching the search criteria."""
function_names: set[str] = set()
"""Only consider observations made by tools with these names."""
# pydantic config
model_config = ConfigDict(extra="forbid")
def _add_tags(self, entry: HistoryItem) -> None:
tags = set(entry.get("tags", []))
tags.update(self.tags)
entry["tags"] = list(tags)
def _should_add_tags(self, entry: HistoryItem) -> bool:
if entry.get("message_type") != "action":
return False
function_calls = entry.get("tool_calls", [])
if not function_calls:
return False
function_names = {call["function"]["name"] for call in function_calls} # type: ignore
return bool(self.function_names & function_names)
def __call__(self, history: History) -> History:
for entry in history:
if self._should_add_tags(entry):
self._add_tags(entry)
return history
class ClosedWindowHistoryProcessor(BaseModel):
"""For each value in history, keep track of which windows have been shown.
We want to mark windows that should stay open (they're the last window for a particular file)
Then we'll replace all other windows with a simple summary of the window (i.e. number of lines)
"""
type: Literal["closed_window"] = "closed_window"
"""Do not change. Used for (de)serialization."""
_pattern = re.compile(r"^(\d+)\:.*?(\n|$)", re.MULTILINE)
_file_pattern = re.compile(r"\[File:\s+(.*)\s+\(\d+\s+lines\ total\)\]")
# pydantic config
model_config = ConfigDict(extra="forbid")
def __call__(self, history):
new_history = list()
windows = set()
for entry in reversed(history):
data = entry.copy()
if data["role"] != "user":
new_history.append(entry)
continue
if data.get("is_demo", False):
new_history.append(entry)
continue
matches = list(self._pattern.finditer(entry["content"]))
if len(matches) >= 1:
file_match = self._file_pattern.search(entry["content"])
if file_match:
file = file_match.group(1)
else:
continue
if file in windows:
start = matches[0].start()
end = matches[-1].end()
data["content"] = (
entry["content"][:start]
+ f"Outdated window with {len(matches)} lines omitted...\n"
+ entry["content"][end:]
)
windows.add(file)
new_history.append(data)
return list(reversed(new_history))
class CacheControlHistoryProcessor(BaseModel):
"""This history processor adds manual cache control marks to the history.
Use this when running with anthropic claude.
"""
type: Literal["cache_control"] = "cache_control"
"""Do not change. Used for (de)serialization."""
last_n_messages: int = 2
"""Add cache control to the last n user messages (and clear it for anything else).
In most cases this should be set to 2 (caching for multi-turn conversations).
When resampling and running concurrent instances, you want to set it to 1.
If set to <= 0, any set cache control will be removed from all messages.
"""
last_n_messages_offset: int = 0
"""E.g., set to 1 to start cache control after the second to last user message.
This can be useful in rare cases, when you want to modify the last message after
we've got the completion and you want to avoid cache mismatch.
"""
tagged_roles: list[str] = ["user", "tool"]
"""Only add cache control to messages with these roles."""
# pydantic config
model_config = ConfigDict(extra="forbid")
def __call__(self, history: History) -> History:
new_history = []
n_tagged = 0
for i_entry, entry in enumerate(reversed(history)):
# Clear cache control from previous messages
_clear_cache_control(entry)
if (
n_tagged < self.last_n_messages
and entry["role"] in self.tagged_roles
and i_entry >= self.last_n_messages_offset
):
_set_cache_control(entry)
n_tagged += 1
new_history.append(entry)
return list(reversed(new_history))
class RemoveRegex(BaseModel):
"""This history processor can remove arbitrary content from history items"""
remove: list[str] = ["<diff>.*</diff>"]
"""Regex patterns to remove from history items"""
keep_last: int = 0
"""Keep the last n history items unchanged"""
type: Literal["remove_regex"] = "remove_regex"
"""Do not change. Used for (de)serialization."""
# pydantic config
model_config = ConfigDict(extra="forbid")
def __call__(self, history: History) -> History:
new_history = []
for i_entry, entry in enumerate(reversed(history)):
entry = copy.deepcopy(entry)
if i_entry < self.keep_last:
new_history.append(entry)
else:
if isinstance(entry["content"], list):
for item in entry["content"]:
if item["type"] == "text":
for pattern in self.remove:
item["text"] = re.sub(pattern, "", item["text"], flags=re.DOTALL)
else:
assert isinstance(entry["content"], str), "Expected string content"
for pattern in self.remove:
entry["content"] = re.sub(pattern, "", entry["content"], flags=re.DOTALL)
new_history.append(entry)
return list(reversed(new_history))
class ImageParsingHistoryProcessor(BaseModel):
"""Parse embedded base64 images from markdown and convert to multi-modal format."""
type: Literal["image_parsing"] = "image_parsing"
allowed_mime_types: set[str] = {"image/png", "image/jpeg", "image/webp"}
_pattern = re.compile(r"(!\[([^\]]*)\]\(data:)([^;]+);base64,([^)]+)(\))")
model_config = ConfigDict(extra="forbid")
def __call__(self, history: History) -> History:
return [self._process_entry(entry) for entry in history]
def _process_entry(self, entry: HistoryItem) -> HistoryItem:
if entry.get("role") not in ["user", "tool"]:
return entry
entry = copy.deepcopy(entry)
content = _get_content_text(entry)
segments = self._parse_images(content)
if any(seg["type"] == "image_url" for seg in segments):
entry["content"] = segments
return entry
def _parse_images(self, content: str) -> list[dict]:
segments = []
last_end = 0
has_images = False
def add_text(text: str) -> None:
"""Add text to the last segment if it's text, otherwise create new text segment."""
if text and segments and segments[-1]["type"] == "text":
segments[-1]["text"] += text
elif text:
segments.append({"type": "text", "text": text})
for match in self._pattern.finditer(content):
markdown_prefix, alt_text, mime_type, base64_data, markdown_suffix = match.groups()
add_text(content[last_end : match.start()])
mime_type = "image/jpeg" if mime_type == "image/jpg" else mime_type
if mime_type in self.allowed_mime_types:
add_text(markdown_prefix)
segments.append({"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_data}"}})
add_text(markdown_suffix)
has_images = True
else:
add_text(match.group(0))
last_end = match.end()
add_text(content[last_end:])
return segments if has_images else [{"type": "text", "text": content}]
HistoryProcessor = Annotated[
DefaultHistoryProcessor
| LastNObservations
| ClosedWindowHistoryProcessor
| TagToolCallObservations
| CacheControlHistoryProcessor
| RemoveRegex
| ImageParsingHistoryProcessor,
Field(discriminator="type"),
]

View File

@@ -0,0 +1,139 @@
from typing import TYPE_CHECKING
from sweagent.types import AgentInfo, StepOutput, Trajectory
if TYPE_CHECKING:
# avoid circular import
from sweagent.agent.agents import DefaultAgent
class AbstractAgentHook:
def on_init(self, *, agent: "DefaultAgent"):
"""Note: Depending on the internals of `Agent` should be done with care,
it's best to use this as little as possible.
"""
def on_run_start(
self,
): ...
def on_step_start(self): ...
def on_actions_generated(self, *, step: StepOutput): ...
def on_action_started(self, *, step: StepOutput): ...
def on_action_executed(self, *, step: StepOutput): ...
def on_step_done(self, *, step: StepOutput, info: AgentInfo): ...
def on_run_done(self, *, trajectory: Trajectory, info: AgentInfo): ...
def on_setup_attempt(self): ...
def on_model_query(self, *, messages: list[dict[str, str]], agent: str):
"""Actually query the model with the complete history."""
def on_query_message_added(
self,
*,
agent: str,
role: str,
content: str,
message_type: str,
is_demo: bool = False,
thought: str = "",
action: str = "",
tool_calls: list[dict[str, str]] | None = None,
tool_call_ids: list[str] | None = None,
): ...
def on_setup_done(self): ...
def on_tools_installation_started(self): ...
class CombinedAgentHook(AbstractAgentHook):
def __init__(self, hooks: list[AbstractAgentHook] | None = None):
self._hooks = hooks or []
def add_hook(self, hook: AbstractAgentHook):
self._hooks.append(hook)
@property
def hooks(self) -> list[AbstractAgentHook]:
return self._hooks
def on_init(self, *, agent: "DefaultAgent"):
for hook in self.hooks:
hook.on_init(agent=agent)
def on_run_start(self):
for hook in self.hooks:
hook.on_run_start()
def on_step_start(self):
for hook in self.hooks:
hook.on_step_start()
def on_actions_generated(self, *, step: StepOutput):
for hook in self.hooks:
hook.on_actions_generated(step=step)
def on_action_started(self, *, step: StepOutput):
for hook in self.hooks:
hook.on_action_started(step=step)
def on_action_executed(self, *, step: StepOutput):
for hook in self.hooks:
hook.on_action_executed(step=step)
def on_step_done(self, *, step: StepOutput, info: AgentInfo):
for hook in self.hooks:
hook.on_step_done(step=step, info=info)
def on_run_done(self, *, trajectory: Trajectory, info: AgentInfo):
for hook in self.hooks:
hook.on_run_done(trajectory=trajectory, info=info)
def on_setup_attempt(self):
for hook in self.hooks:
hook.on_setup_attempt()
def on_model_query(self, *, messages: list[dict[str, str]], agent: str):
for hook in self.hooks:
hook.on_model_query(messages=messages, agent=agent)
def on_query_message_added(
self,
*,
agent: str,
role: str,
content: str,
message_type: str,
is_demo: bool = False,
thought: str = "",
action: str = "",
tool_calls: list[dict[str, str]] | None = None,
tool_call_ids: list[str] | None = None,
thinking_blocks: list[dict[str, str]] | None = None,
):
for hook in self.hooks:
hook.on_query_message_added(
agent=agent,
role=role,
content=content,
message_type=message_type,
is_demo=is_demo,
thought=thought,
action=action,
tool_calls=tool_calls,
tool_call_ids=tool_call_ids,
)
def on_setup_done(self):
return super().on_setup_done()
def on_tools_installation_started(self):
for hook in self.hooks:
hook.on_tools_installation_started()

View File

@@ -0,0 +1,34 @@
from collections.abc import Callable
from sweagent.agent.hooks.abstract import AbstractAgentHook
from sweagent.types import AgentInfo, StepOutput
class SetStatusAgentHook(AbstractAgentHook):
def __init__(self, id: str, callable: Callable[[str, str], None]):
self._callable = callable
self._id = id
self._i_step = 0
self._cost = 0.0
self._i_attempt = 0
self._previous_cost = 0.0
def on_setup_attempt(self):
self._i_attempt += 1
self._i_step = 0
# Costs will be reset for the next attempt
self._previous_cost += self._cost
def _update(self, message: str):
self._callable(self._id, message)
def on_step_start(self):
self._i_step += 1
attempt_str = f"Attempt {self._i_attempt} " if self._i_attempt > 1 else ""
self._update(f"{attempt_str}Step {self._i_step:>3} (${self._previous_cost + self._cost:.2f})")
def on_step_done(self, *, step: StepOutput, info: AgentInfo):
self._cost = info["model_stats"]["instance_cost"] # type: ignore
def on_tools_installation_started(self):
self._update("Installing tools")

View File

@@ -0,0 +1,903 @@
from __future__ import annotations
import copy
import json
import os
import random
import shlex
import threading
import time
from abc import ABC, abstractmethod
from pathlib import Path
from threading import Lock
from typing import Annotated, Any, Literal
import litellm
import litellm.types.utils
from pydantic import BaseModel as PydanticBaseModel
from pydantic import ConfigDict, Field, SecretStr
from swerex.exceptions import SwerexException
from tenacity import (
RetryCallState,
Retrying,
retry_if_not_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from sweagent import REPO_ROOT, __version__
from sweagent.exceptions import (
ContentPolicyViolationError,
ContextWindowExceededError,
CostLimitExceededError,
FunctionCallingFormatError,
InstanceCallLimitExceededError,
InstanceCostLimitExceededError,
ModelConfigurationError,
TotalCostLimitExceededError,
)
from sweagent.tools.tools import ToolConfig
from sweagent.types import History, HistoryItem
from sweagent.utils.log import get_logger
try:
import readline # noqa: F401
except ImportError:
readline = None
litellm.suppress_debug_info = True
_THREADS_THAT_USED_API_KEYS = []
"""Keeps track of thread orders so that we can choose the same API key for the same thread."""
class RetryConfig(PydanticBaseModel):
"""This configuration object specifies how many times to retry a failed LM API call."""
retries: int = 20
"""Number of retries"""
min_wait: float = 10
"""Minimum wait time between retries (random exponential wait)"""
max_wait: float = 120
"""Maximum wait time between retries (random exponential wait)"""
class GenericAPIModelConfig(PydanticBaseModel):
"""This configuration object specifies a LM like GPT4 or similar.
The model will be served with the help of the `litellm` library.
"""
name: str = Field(description="Name of the model.")
per_instance_cost_limit: float = Field(
default=3.0,
description="Cost limit for every instance (task).",
)
total_cost_limit: float = Field(default=0.0, description="Total cost limit.")
per_instance_call_limit: int = Field(default=0, description="Per instance call limit.")
temperature: float = 0.0
"""Sampling temperature"""
top_p: float | None = 1.0
"""Sampling top-p"""
api_base: str | None = None
api_version: str | None = None
api_key: SecretStr | None = None
"""API key to the model. We recommend using environment variables to set this instead
or putting your environment variables in a `.env` file.
You can concatenate more than one key by separating them with `:::`, e.g.,
`key1:::key2`.
If field starts with `$`, it will be interpreted as an environment variable.
"""
stop: list[str] = []
"""Custom stop sequences"""
completion_kwargs: dict[str, Any] = {}
"""Additional kwargs to pass to `litellm.completion`"""
convert_system_to_user: bool = False
"""Whether to convert system messages to user messages. This is useful for
models that do not support system messages like o1.
"""
retry: RetryConfig = RetryConfig()
"""Retry configuration: How often to retry after a failure (e.g., from a rate limit)
etc.
"""
delay: float = 0.0
"""Minimum delay before querying (this can help to avoid overusing the API if sharing
it with other people).
"""
fallbacks: list[dict[str, Any]] = []
"""List of fallbacks to try if the main model fails
See https://docs.litellm.ai/docs/completion/reliable_completions#fallbacks-sdk
for more information.
"""
choose_api_key_by_thread: bool = True
"""Whether to choose the API key based on the thread name (if multiple are configured).
This ensures that with
run-batch, we use the same API key within a single-thread so that prompt caching still works.
"""
max_input_tokens: int | None = None
"""If set, this will override the max input tokens for the model that we usually look
up from `litellm.model_cost`.
Use this for local models or if you want to set a custom max input token limit.
If this value is exceeded, a `ContextWindowExceededError` will be raised.
Set this to 0 to disable this check.
"""
max_output_tokens: int | None = None
"""If set, this will override the max output tokens for the model that we usually look
up from `litellm.model_cost`.
Use this for local models or if you want to set a custom max output token limit.
If this value is exceeded, a `ContextWindowExceededError` will be raised.
Set this to 0 to disable this check.
"""
litellm_model_registry: str | None = None
"""If set, this will override the default model registry for litellm.
Use this for local models or models not (yet) in the default litellm model registry for tracking costs.
"""
custom_tokenizer: dict[str, Any] | None = None
"""Override the default tokenizer for the model.
Use the arguments of `litellm.create_pretrained_tokenizer`.
Basic example: `{"identifier": "hf-internal-testing/llama-tokenizer"}`
"""
# pydantic
model_config = ConfigDict(extra="forbid")
def get_api_keys(self) -> list[str]:
"""Returns a list of API keys that were explicitly set in this config.
Does not return API keys that were set via environment variables/.env
"""
if self.api_key is None:
return []
api_key = self.api_key.get_secret_value()
if not api_key:
return []
if api_key.startswith("$"):
env_var_name = api_key[1:]
api_key = os.getenv(env_var_name, "")
if not api_key:
get_logger("swea-config", emoji="🔧").warning(f"Environment variable {env_var_name} not set")
return []
return api_key.split(":::")
def choose_api_key(self) -> str | None:
"""Chooses an API key based on the API keys explicitly set in this config.
If no API keys are set, returns None (which means that the API key will be
taken from the environment variables/.env file).
"""
api_keys = self.get_api_keys()
if not api_keys:
return None
if not self.choose_api_key_by_thread:
return random.choice(api_keys)
thread_name = threading.current_thread().name
if thread_name not in _THREADS_THAT_USED_API_KEYS:
_THREADS_THAT_USED_API_KEYS.append(thread_name)
thread_idx = _THREADS_THAT_USED_API_KEYS.index(thread_name)
key_idx = thread_idx % len(api_keys)
get_logger("config", emoji="🔧").debug(
f"Choosing API key {key_idx} for thread {thread_name} (idx {thread_idx})"
)
return api_keys[key_idx]
@property
def id(self) -> str:
name = self.name.replace("/", "--")
if self.top_p is not None:
top_p = f"{self.top_p:.2f}"
else:
top_p = "None"
temperature = f"{self.temperature:.2f}"
per_instance_cost_limit = f"{self.per_instance_cost_limit:.2f}"
return f"{name}__t-{temperature}__p-{top_p}__c-{per_instance_cost_limit}"
class ReplayModelConfig(GenericAPIModelConfig):
replay_path: Path = Field(description="Path to replay file when using the replay model.")
per_instance_cost_limit: float = Field(
default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
)
total_cost_limit: float = Field(
default=0.0, description="Cost limit for all instances (tasks). This is a dummy value here."
)
name: Literal["replay"] = Field(default="replay", description="Model name.")
model_config = ConfigDict(extra="forbid")
class InstantEmptySubmitModelConfig(GenericAPIModelConfig):
"""Model that immediately submits an empty patch"""
name: Literal["instant_empty_submit"] = Field(default="instant_empty_submit", description="Model name.")
per_instance_cost_limit: float = Field(
default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
)
total_cost_limit: float = Field(
default=0.0, description="Cost limit for all instances (tasks). This is a dummy value here."
)
delay: float = 0.0
"""Delay before answering"""
model_config = ConfigDict(extra="forbid")
class HumanModelConfig(GenericAPIModelConfig):
name: Literal["human"] = Field(default="human", description="Model name.")
per_instance_cost_limit: float = Field(
default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
)
total_cost_limit: float = Field(default=0.0, description="Cost limit for all instances (tasks).")
cost_per_call: float = 0.0
catch_eof: bool = True
"""Whether to catch EOF and return 'exit' when ^D is pressed. Set to False when used in human_step_in mode."""
model_config = ConfigDict(extra="forbid")
class HumanThoughtModelConfig(HumanModelConfig):
name: Literal["human_thought"] = Field(default="human_thought", description="Model name.")
per_instance_cost_limit: float = Field(
default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
)
total_cost_limit: float = Field(
default=0.0, description="Cost limit for all instances (tasks). This is a dummy value here."
)
cost_per_call: float = 0.0
model_config = ConfigDict(extra="forbid")
ModelConfig = Annotated[
GenericAPIModelConfig
| ReplayModelConfig
| InstantEmptySubmitModelConfig
| HumanModelConfig
| HumanThoughtModelConfig,
Field(union_mode="left_to_right"),
]
class GlobalStats(PydanticBaseModel):
"""This class tracks usage numbers (costs etc.) across all instances."""
total_cost: float = 0
"""Cumulative cost for all instances so far"""
last_query_timestamp: float = 0
"""Timestamp of the last query. Currently only used with API models."""
GLOBAL_STATS = GlobalStats()
"""This object tracks usage numbers (costs etc.) across all instances.
Please use the `GLOBAL_STATS_LOCK` lock when accessing this object to avoid race conditions.
"""
GLOBAL_STATS_LOCK = Lock()
"""Lock for accessing `GLOBAL_STATS` without race conditions"""
class InstanceStats(PydanticBaseModel):
"""This object tracks usage numbers (costs etc.) for a single instance."""
instance_cost: float = 0
tokens_sent: int = 0
tokens_received: int = 0
api_calls: int = 0
def __add__(self, other: InstanceStats) -> InstanceStats:
return InstanceStats(
**{field: getattr(self, field) + getattr(other, field) for field in self.model_fields.keys()},
)
def __sub__(self, other: InstanceStats) -> InstanceStats:
return InstanceStats(
**{field: getattr(self, field) - getattr(other, field) for field in self.model_fields.keys()},
)
class AbstractModel(ABC):
def __init__(self, config: ModelConfig, tools: ToolConfig):
self.config: ModelConfig
self.stats: InstanceStats
def reset_stats(self):
self.stats = InstanceStats()
@abstractmethod
def query(self, history: History, action_prompt: str = "> ") -> dict: ...
@property
def instance_cost_limit(self) -> float:
"""Cost limit for the model. Returns 0 if there is no limit."""
return 0
def _handle_raise_commands(action: str) -> None:
if action == "raise_runtime":
raise SwerexException()
elif action == "raise_cost":
raise CostLimitExceededError()
elif action == "raise_context":
raise ContextWindowExceededError()
elif action.startswith("raise_function_calling"):
parts = shlex.split(action)
error_code = parts[1]
if len(parts) == 3:
error_message = parts[2]
assert len(parts) < 4
raise FunctionCallingFormatError(error_message, error_code) # type: ignore
class HumanModel(AbstractModel):
def __init__(self, config: HumanModelConfig, tools: ToolConfig):
"""Model that allows for human-in-the-loop"""
self.logger = get_logger("swea-lm", emoji="🤖")
self.config: HumanModelConfig = config
self.stats = InstanceStats()
# Determine which commands require multi-line input
self.multi_line_command_endings = {
command.name: command.end_name for command in tools.commands if command.end_name is not None
}
self._readline_histfile = REPO_ROOT / ".swe-agent-human-history"
self._load_readline_history()
def _load_readline_history(self) -> None:
"""Load autocomplete history from file"""
if readline is None:
return
if self._readline_histfile.is_file():
self.logger.debug(f"Loading readline history from {self._readline_histfile}")
readline.read_history_file(self._readline_histfile)
def _save_readline_history(self) -> None:
"""Save autocomplete history to file"""
if readline is None:
return
readline.write_history_file(self._readline_histfile)
def _update_stats(
self,
) -> None:
self.stats.instance_cost += self.config.cost_per_call
self.stats.api_calls += 1
if 0 < self.config.per_instance_cost_limit < self.stats.instance_cost:
msg = f"Instance cost limit exceeded: {self.stats.instance_cost} > {self.config.per_instance_cost_limit}"
raise InstanceCostLimitExceededError(msg)
if 0 < self.config.total_cost_limit < self.stats.instance_cost:
msg = f"Total cost limit exceeded: {self.stats.instance_cost} > {self.config.total_cost_limit}"
raise TotalCostLimitExceededError(msg)
def _query(
self,
history: History,
action_prompt: str = "> ",
) -> dict:
"""Logic for handling user input to pass to SWEEnv"""
action = input(action_prompt)
self._save_readline_history()
command_name = action.split()[0] if action.strip() else ""
# Special handling for multi-line input actions (i.e. edit)
if command_name in self.multi_line_command_endings:
buffer = [action]
end_keyword = self.multi_line_command_endings[command_name]
while True:
action = input("... ")
buffer.append(action)
if action.rstrip() == end_keyword:
# Continue reading input until terminating keyword inputted
break
action = "\n".join(buffer)
elif action.strip() == "start_multiline_command": # do arbitrary multi-line input
buffer = []
while True:
action = input("... ")
if action.rstrip() == "end_multiline_command":
break
buffer.append(action)
action = "\n".join(buffer)
else:
# Input has escaped things like \n, so we need to unescape it
action = action.encode("utf8").decode("unicode_escape")
if action.strip() and action.strip().split()[0] == "spend_money":
money = float(action.strip().split()[1])
self.stats.instance_cost += money
action = f"echo 'Spent {money} dollars'"
_handle_raise_commands(action)
self._update_stats()
return {"message": action}
def query(self, history: History, action_prompt: str = "> ", n: int | None = None, **kwargs) -> dict | list[dict]:
"""Wrapper to separate action prompt from formatting"""
out = []
n_samples = n or 1
for _ in range(n_samples):
try:
out.append(self._query(history, action_prompt))
except KeyboardInterrupt:
print("^C (exit with ^D)")
out.append(self.query(history, action_prompt))
except EOFError:
if self.config.catch_eof:
print("\nGoodbye!")
out.append({"message": "exit"})
else:
# Re-raise EOFError when catch_eof is disabled
raise
if n is None:
return out[0]
return out
class HumanThoughtModel(HumanModel):
def query(self, history: History, **kwargs) -> dict:
"""Logic for handling user input (both thought + action) to pass to SWEEnv"""
thought_all = ""
thought = input("Thought (end w/ END_THOUGHT): ")
while True:
if "END_THOUGHT" in thought:
thought = thought.split("END_THOUGHT")[0]
thought_all += thought
break
thought_all += thought
thought = input("... ")
action = super()._query(history, action_prompt="Action: ")["message"]
return {"message": f"{thought_all}\n```\n{action}\n```"}
class ReplayModel(AbstractModel):
def __init__(self, config: ReplayModelConfig, tools: ToolConfig):
"""Model used for replaying a trajectory (i.e., taking all the actions for the `.traj` file
and re-issuing them.
"""
self.config = config
self.stats = InstanceStats()
if not self.config.replay_path.exists():
msg = f"Replay file {self.config.replay_path} not found"
raise FileNotFoundError(msg)
self._replays = [
list(json.loads(x).values())[0] for x in Path(self.config.replay_path).read_text().splitlines(keepends=True)
]
self._replay_idx = 0
self._action_idx = 0
self.use_function_calling = tools.use_function_calling
self.submit_command = tools.submit_command
self.logger = get_logger("swea-lm", emoji="🤖")
def _next_replay(self) -> None:
"""Called after last action"""
self._replay_idx += 1
self._action_idx = 0
def query(self, history: History) -> dict:
"""Logic for tracking which replay action to pass to SWEEnv"""
self.stats.api_calls += 1
actions = self._replays[self._replay_idx]
try:
action = actions[self._action_idx]
except IndexError:
# log error
self.logger.error("Reached end of replay trajectory without submitting. Submitting now.")
if self.use_function_calling:
action = {
"message": f"Calling `{self.submit_command}` to submit.",
"tool_calls": [
{
"type": "function",
"id": "call_submit",
"function": {
"name": self.submit_command,
"arguments": "{}",
},
}
],
}
else:
action = f"```\n{self.submit_command}\n```"
self._action_idx += 1
# Assuming `submit` is always last action of replay trajectory
if isinstance(action, str) and action == "submit":
self._next_replay()
return {"message": action}
# Handle both dict and string actions
if isinstance(action, dict):
return action
return {"message": action}
class PredeterminedTestModel(AbstractModel):
def __init__(self, outputs: list[dict | str]):
"""Model that outputs a predetermined sequence of messages. Useful for testing."""
self._outputs = outputs
self._idx = -1
self.stats = InstanceStats()
def query(self, *args, **kwargs) -> dict:
self._idx += 1
output = self._outputs[self._idx]
if isinstance(output, str):
_handle_raise_commands(output)
return {"message": output}
if not isinstance(output, dict):
msg = f"Output must be string or dict, got {type(output)}"
raise ValueError(msg)
result = {"message": output["message"]}
if "tool_calls" in output:
result["tool_calls"] = output["tool_calls"]
return result
class InstantEmptySubmitTestModel(AbstractModel):
def __init__(self, args: InstantEmptySubmitModelConfig, tools: ToolConfig):
"""This model immediately submits. Useful for testing purposes"""
super().__init__(args, tools)
self.config: InstantEmptySubmitModelConfig = args
self.stats = InstanceStats()
self._action_idx = 0
def query(self, history: list[dict[str, str]]) -> dict:
time.sleep(random.uniform(0, self.config.delay))
# Need to at least do _something_ to submit
if self._action_idx == 0:
self._action_idx = 1
action = (
"DISCUSSION\n"
"Let's reproduce the bug by creating a `reproduce.py` file.\n\n"
"```\n"
"touch reproduce.py\n"
"```\n"
)
elif self._action_idx == 1:
self._action_idx = 0
action = "DISCUSSION\nThe task should be resolved, so let's submit the patch.\n\n```\nsubmit\n```\n"
self.stats.api_calls += 1
return {"message": action}
class LiteLLMModel(AbstractModel):
def __init__(self, args: GenericAPIModelConfig, tools: ToolConfig):
"""Model served by the `litellm` library."""
# Always copy config to avoid shared state between different instances
self.config: GenericAPIModelConfig = args.model_copy(deep=True)
self.stats = InstanceStats()
self.tools = tools
self.logger = get_logger("swea-lm", emoji="🤖")
if tools.use_function_calling:
if not litellm.utils.supports_function_calling(model=self.config.name):
msg = (
f"Model {self.config.name} does not support function calling. If your model"
" does not support function calling, you can use `parse_function='thought_action'` instead. "
"See https://swe-agent.com/latest/faq/ for more information."
)
self.logger.warning(msg)
if self.config.litellm_model_registry is not None:
with open(self.config.litellm_model_registry) as f:
model_costs = json.load(f)
litellm.register_model(model_costs)
if self.config.max_input_tokens is not None:
self.model_max_input_tokens = self.config.max_input_tokens
else:
self.model_max_input_tokens = litellm.model_cost.get(self.config.name, {}).get("max_input_tokens")
if self.config.max_output_tokens is not None:
self.model_max_output_tokens = self.config.max_output_tokens
else:
self.model_max_output_tokens = litellm.model_cost.get(self.config.name, {}).get("max_output_tokens")
# Special handling for Claude 3.7 models to set 64k context by default when beta header not present
# See https://github.com/SWE-agent/SWE-agent/pull/1016
is_claude_3_7 = "claude-3-7-sonnet" in self.config.name or "claude-sonnet-4" in self.config.name
has_128k_beta_header = (
self.config.completion_kwargs.get("extra_headers", {}).get("anthropic-beta") == "output-128k-2025-02-19"
)
if is_claude_3_7 and not has_128k_beta_header:
self.model_max_output_tokens = 64000
self.logger.warning(
"Claude 3.7/4 models do not support 128k context by default. "
"Setting max output tokens to 64k. To enable 128k context, please set the "
"completion_kwargs to {'extra_headers': {'anthropic-beta': 'output-128k-2025-02-19'}}."
)
self.lm_provider = litellm.model_cost.get(self.config.name, {}).get("litellm_provider", self.config.name)
self.custom_tokenizer = None
if self.config.custom_tokenizer is not None:
self.custom_tokenizer = litellm.utils.create_pretrained_tokenizer(**self.config.custom_tokenizer)
@property
def instance_cost_limit(self) -> float:
"""Cost limit for the model. Returns 0 if there is no limit."""
return self.config.per_instance_cost_limit
def _update_stats(self, *, input_tokens: int, output_tokens: int, cost: float) -> None:
with GLOBAL_STATS_LOCK:
GLOBAL_STATS.total_cost += cost
self.stats.instance_cost += cost
self.stats.tokens_sent += input_tokens
self.stats.tokens_received += output_tokens
self.stats.api_calls += 1
# Log updated cost values to std. err
self.logger.debug(
f"input_tokens={input_tokens:,}, "
f"output_tokens={output_tokens:,}, "
f"instance_cost={self.stats.instance_cost:.2f}, "
f"cost={cost:.2f}",
)
self.logger.debug(
f"total_tokens_sent={self.stats.tokens_sent:,}, "
f"total_tokens_received={self.stats.tokens_received:,}, "
f"total_cost={GLOBAL_STATS.total_cost:.2f}, "
f"total_api_calls={self.stats.api_calls:,}",
)
# Check whether total cost or instance cost limits have been exceeded
if 0 < self.config.total_cost_limit < GLOBAL_STATS.total_cost:
self.logger.warning(f"Cost {GLOBAL_STATS.total_cost:.2f} exceeds limit {self.config.total_cost_limit:.2f}")
msg = "Total cost limit exceeded"
raise TotalCostLimitExceededError(msg)
if 0 < self.config.per_instance_cost_limit < self.stats.instance_cost:
self.logger.warning(
f"Cost {self.stats.instance_cost:.2f} exceeds limit {self.config.per_instance_cost_limit:.2f}"
)
msg = "Instance cost limit exceeded"
raise InstanceCostLimitExceededError(msg)
if 0 < self.config.per_instance_call_limit < self.stats.api_calls:
self.logger.warning(f"API calls {self.stats.api_calls} exceeds limit {self.config.per_instance_call_limit}")
msg = "Per instance call limit exceeded"
raise InstanceCallLimitExceededError(msg)
def _sleep(self) -> None:
elapsed_time = time.time() - GLOBAL_STATS.last_query_timestamp
if elapsed_time < self.config.delay:
time.sleep(self.config.delay - elapsed_time)
with GLOBAL_STATS_LOCK:
GLOBAL_STATS.last_query_timestamp = time.time()
def _single_query(
self, messages: list[dict[str, str]], n: int | None = None, temperature: float | None = None
) -> list[dict]:
self._sleep()
# Workaround for litellm bug https://github.com/SWE-agent/SWE-agent/issues/1109
messages_no_cache_control = copy.deepcopy(messages)
for message in messages_no_cache_control:
if "cache_control" in message:
del message["cache_control"]
if "thinking_blocks" in message:
del message["thinking_blocks"]
input_tokens: int = litellm.utils.token_counter(
messages=messages_no_cache_control,
model=self.custom_tokenizer["identifier"] if self.custom_tokenizer is not None else self.config.name,
custom_tokenizer=self.custom_tokenizer,
)
if self.model_max_input_tokens is None:
msg = (
f"No max input tokens found for model {self.config.name!r}. "
"If you are using a local model, you can set `max_input_token` in the model config to override this."
)
self.logger.warning(msg)
elif input_tokens > self.model_max_input_tokens > 0:
msg = f"Input tokens {input_tokens} exceed max tokens {self.model_max_input_tokens}"
raise ContextWindowExceededError(msg)
extra_args = {}
if self.config.api_base:
# Not assigned a default value in litellm, so only pass this if it's set
extra_args["api_base"] = self.config.api_base
if self.tools.use_function_calling:
extra_args["tools"] = self.tools.tools
# We need to always set max_tokens for anthropic models
completion_kwargs = copy.deepcopy(self.config.completion_kwargs)
if self.lm_provider == "anthropic":
completion_kwargs["max_tokens"] = self.model_max_output_tokens
# Add User-Agent header (don't override user-provided headers)
if "extra_headers" not in completion_kwargs:
completion_kwargs["extra_headers"] = {}
if "User-Agent" not in completion_kwargs["extra_headers"]:
completion_kwargs["extra_headers"]["User-Agent"] = f"swe-agent/{__version__}"
try:
response: litellm.types.utils.ModelResponse = litellm.completion( # type: ignore
model=self.config.name,
messages=messages,
temperature=self.config.temperature if temperature is None else temperature,
top_p=self.config.top_p,
api_version=self.config.api_version,
api_key=self.config.choose_api_key(),
fallbacks=self.config.fallbacks,
**completion_kwargs,
**extra_args,
n=n,
)
except litellm.exceptions.ContextWindowExceededError as e:
raise ContextWindowExceededError from e
except litellm.exceptions.ContentPolicyViolationError as e:
raise ContentPolicyViolationError from e
except litellm.exceptions.BadRequestError as e:
if "is longer than the model's context length" in str(e):
raise ContextWindowExceededError from e
raise
self.logger.debug(f"Response: {response}")
try:
cost = litellm.cost_calculator.completion_cost(response, model=self.config.name)
except Exception as e:
self.logger.debug(f"Error calculating cost: {e}, setting cost to 0.")
if self.config.per_instance_cost_limit > 0 or self.config.total_cost_limit > 0:
msg = (
f"Error calculating cost: {e} for your model {self.config.name}. If this is ok "
"(local models, etc.), please make sure you set `per_instance_cost_limit` and "
"`total_cost_limit` to 0 to disable this safety check."
)
self.logger.error(msg)
raise ModelConfigurationError(msg)
cost = 0
choices: litellm.types.utils.Choices = response.choices # type: ignore
n_choices = n if n is not None else 1
outputs = []
output_tokens = 0
for i in range(n_choices):
output = choices[i].message.content or ""
output_tokens += litellm.utils.token_counter(
text=output,
model=self.custom_tokenizer["identifier"] if self.custom_tokenizer is not None else self.config.name,
custom_tokenizer=self.custom_tokenizer,
)
output_dict = {"message": output}
if self.tools.use_function_calling:
if response.choices[i].message.tool_calls: # type: ignore
tool_calls = [call.to_dict() for call in response.choices[i].message.tool_calls] # type: ignore
else:
tool_calls = []
output_dict["tool_calls"] = tool_calls
if (
hasattr(response.choices[i].message, "thinking_blocks") # type: ignore
and response.choices[i].message.thinking_blocks # type: ignore
):
output_dict["thinking_blocks"] = response.choices[i].message.thinking_blocks # type: ignore
outputs.append(output_dict)
self._update_stats(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost)
return outputs
def _query(
self, messages: list[dict[str, str]], n: int | None = None, temperature: float | None = None
) -> list[dict]:
if n is None:
return self._single_query(messages, temperature=temperature)
outputs = []
# not needed for openai, but oh well.
for _ in range(n):
outputs.extend(self._single_query(messages))
return outputs
def query(self, history: History, n: int = 1, temperature: float | None = None) -> list[dict] | dict:
messages = self._history_to_messages(history)
def retry_warning(retry_state: RetryCallState):
exception_info = ""
if attempt.retry_state.outcome is not None and attempt.retry_state.outcome.exception() is not None:
exception = attempt.retry_state.outcome.exception()
exception_info = f" due to {exception.__class__.__name__}: {str(exception)}"
self.logger.warning(
f"Retrying LM query: attempt {attempt.retry_state.attempt_number} "
f"(slept for {attempt.retry_state.idle_for:.2f}s)"
f"{exception_info}"
)
for attempt in Retrying(
stop=stop_after_attempt(self.config.retry.retries),
wait=wait_random_exponential(min=self.config.retry.min_wait, max=self.config.retry.max_wait),
reraise=True,
retry=retry_if_not_exception_type(
(
ContextWindowExceededError,
CostLimitExceededError,
RuntimeError,
litellm.exceptions.UnsupportedParamsError,
litellm.exceptions.NotFoundError,
litellm.exceptions.PermissionDeniedError,
litellm.exceptions.ContextWindowExceededError,
litellm.exceptions.APIError,
litellm.exceptions.ContentPolicyViolationError,
TypeError,
litellm.exceptions.AuthenticationError,
ContentPolicyViolationError,
ModelConfigurationError,
KeyboardInterrupt,
IndexError,
)
),
before_sleep=retry_warning,
):
with attempt:
result = self._query(messages, n=n, temperature=temperature)
if n is None or n == 1:
return result[0]
return result
def _history_to_messages(
self,
history: History,
) -> list[dict[str, str]]:
history = copy.deepcopy(history)
def get_role(history_item: HistoryItem) -> str:
if history_item["role"] == "system":
return "user" if self.config.convert_system_to_user else "system"
return history_item["role"]
messages = []
for history_item in history:
role = get_role(history_item)
if role == "tool":
message = {
"role": role,
"content": history_item["content"],
# Only one tool call per observations
"tool_call_id": history_item["tool_call_ids"][0], # type: ignore
}
elif (tool_calls := history_item.get("tool_calls")) is not None:
message = {"role": role, "content": history_item["content"], "tool_calls": tool_calls}
if thinking_blocks := history_item.get("thinking_blocks"):
message["thinking_blocks"] = thinking_blocks
else:
message = {"role": role, "content": history_item["content"]}
if "cache_control" in history_item:
message["cache_control"] = history_item["cache_control"]
messages.append(message)
n_cache_control = str(messages).count("cache_control")
self.logger.debug(f"n_cache_control: {n_cache_control}")
return messages
def get_model(args: ModelConfig, tools: ToolConfig) -> AbstractModel:
"""Returns correct model object given arguments and commands"""
# Convert GenericAPIModelConfig to specific model config if needed
if isinstance(args, GenericAPIModelConfig) and not isinstance(
args, HumanModelConfig | HumanThoughtModelConfig | ReplayModelConfig | InstantEmptySubmitModelConfig
):
if args.name == "human":
args = HumanModelConfig(**args.model_dump())
elif args.name == "human_thought":
args = HumanThoughtModelConfig(**args.model_dump())
elif args.name == "replay":
args = ReplayModelConfig(**args.model_dump())
elif args.name == "instant_empty_submit":
args = InstantEmptySubmitModelConfig(**args.model_dump())
if args.name == "human":
assert isinstance(args, HumanModelConfig), f"Expected {HumanModelConfig}, got {args}"
return HumanModel(args, tools)
if args.name == "human_thought":
assert isinstance(args, HumanThoughtModelConfig), f"Expected {HumanThoughtModelConfig}, got {args}"
return HumanThoughtModel(args, tools)
if args.name == "replay":
assert isinstance(args, ReplayModelConfig), f"Expected {ReplayModelConfig}, got {args}"
return ReplayModel(args, tools)
elif args.name == "instant_empty_submit":
assert isinstance(args, InstantEmptySubmitModelConfig), f"Expected {InstantEmptySubmitModelConfig}, got {args}"
return InstantEmptySubmitTestModel(args, tools)
assert isinstance(args, GenericAPIModelConfig), f"Expected {GenericAPIModelConfig}, got {args}"
return LiteLLMModel(args, tools)

View File

@@ -0,0 +1,312 @@
import base64
import hashlib
import os
import uuid
from pathlib import Path
from typing import Any, Literal, Protocol
from urllib.parse import urlparse
import requests
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from sweagent.utils.github import _get_problem_statement_from_github_issue, _parse_gh_issue_url
from sweagent.utils.log import get_logger
logger = get_logger("swea-config", emoji="🔧")
# Constants for image processing
VALID_IMAGE_MIME_TYPES = {
"image/png",
"image/jpeg",
"image/jpg", # Some servers return jpg instead of jpeg
"image/webp",
}
class ProblemStatement(Protocol):
"""A problem statement for a task. Any class that implements this protocol
can be used as a problem statement.
"""
id: str
def get_problem_statement(self) -> str: ...
def get_problem_statement_for_env(self) -> str:
"""Used for setting environment variables in the container.
By default, this is the same as get_problem_statement().
"""
return self.get_problem_statement()
def get_extra_fields(self) -> dict[str, Any]: ...
class _BuiltinProblemStatementBase(BaseModel):
"""A base class for the builtin problem statements to avoid typing much"""
def get_problem_statement(self) -> str: ...
def get_problem_statement_for_env(self) -> str:
return self.get_problem_statement()
def get_extra_fields(self) -> dict[str, Any]:
return {}
class EmptyProblemStatement(_BuiltinProblemStatementBase):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
type: Literal["empty"] = "empty"
"""Discriminator for (de)serialization/CLI. Do not change."""
model_config = ConfigDict(extra="forbid")
def get_problem_statement(self) -> str:
return ""
class TextProblemStatement(_BuiltinProblemStatementBase):
text: str
extra_fields: dict[str, Any] = Field(default_factory=dict)
"""Any additional data to be added to the instance.
This data will be available when formatting prompt templates.
"""
type: Literal["text"] = "text"
"""Discriminator for (de)serialization/CLI. Do not change."""
id: str = None # type: ignore
model_config = ConfigDict(extra="forbid")
def model_post_init(self, __context: Any) -> None:
if self.id is None:
logger.info("Setting problem statement id to hash of text")
self.id = hashlib.sha256(self.text.encode()).hexdigest()[:6]
def get_problem_statement(self) -> str:
return self.text
def get_extra_fields(self) -> dict[str, Any]:
return self.extra_fields
def __repr__(self) -> str:
return f"TextProblemStatement(id={self.id}, text={self.text[:30]}...)"
def __str__(self) -> str:
return f"id={self.id}, text={self.text[:30]}..."
class FileProblemStatement(_BuiltinProblemStatementBase):
path: Path
extra_fields: dict[str, Any] = Field(default_factory=dict)
"""Any additional data to be added to the instance.
This data will be available when formatting prompt templates.
"""
type: Literal["text_file"] = "text_file"
"""Discriminator for (de)serialization/CLI. Do not change."""
id: str = None # type: ignore
model_config = ConfigDict(extra="forbid")
def model_post_init(self, __context: Any) -> None:
if self.id is None:
logger.info("Setting problem statement id to hash of file contents (path: %s)", self.path)
self.id = hashlib.sha256(self.get_problem_statement().encode()).hexdigest()[:6]
def get_problem_statement(self) -> str:
return self.path.read_text()
def get_extra_fields(self) -> dict[str, Any]:
return self.extra_fields
class GithubIssue(_BuiltinProblemStatementBase):
github_url: str
extra_fields: dict[str, Any] = Field(default_factory=dict)
"""Any additional data to be added to the instance.
This data will be available when formatting prompt templates.
"""
type: Literal["github"] = "github"
"""Discriminator for (de)serialization/CLI. Do not change."""
id: str = None # type: ignore
model_config = ConfigDict(extra="forbid")
def model_post_init(self, __context: Any) -> None:
if self.id is None:
logger.info("Setting problem statement based on github issue url")
owner, repo, issue_number = _parse_gh_issue_url(self.github_url)
self.id = f"{owner}__{repo}-i{issue_number}"
def get_problem_statement(self) -> str:
owner, repo, issue_number = _parse_gh_issue_url(self.github_url)
return _get_problem_statement_from_github_issue(owner, repo, issue_number, token=os.getenv("GITHUB_TOKEN"))
def get_extra_fields(self) -> dict[str, Any]:
return self.extra_fields
class SWEBenchMultimodalProblemStatement(_BuiltinProblemStatementBase):
text: str
issue_images: list[str] = Field(default_factory=list)
"""List of image asset URLs.
"""
disable_image_processing: bool = False
"""If True, skip image downloading and processing, treating this as a text-only problem statement.
"""
extra_fields: dict[str, Any] = Field(default_factory=dict)
"""Any additional data to be added to the instance.
This data will be available when formatting prompt templates.
"""
type: Literal["swe_bench_multimodal"] = "swe_bench_multimodal"
"""Discriminator for (de)serialization/CLI. Do not change."""
id: str = None # type: ignore
_cached_problem_statement: str | None = PrivateAttr(default=None)
model_config = ConfigDict(extra="forbid")
def model_post_init(self, __context: Any) -> None:
if self.id is None:
logger.info("Setting problem statement id to hash of text")
self.id = hashlib.sha256(self.text.encode()).hexdigest()[:6]
def get_problem_statement_for_env(self) -> str:
"""Return the problem statement without images.
Images are not supported in the environment.
"""
return self.text
def get_problem_statement(self) -> str:
if self.disable_image_processing:
logger.info("Image processing disabled, returning text-only problem statement")
return self.text
if self._cached_problem_statement is not None:
return self._cached_problem_statement
processed_text = self.text
for link in self.issue_images:
try:
image_markdown = self._download_and_convert_image(link)
if image_markdown:
processed_text += f"\n\n{image_markdown}"
except Exception as e:
logger.warning(f"Failed to process image from {link}: {e}")
# cache to avoid re-processing images
self._cached_problem_statement = processed_text
return processed_text
def get_extra_fields(self) -> dict[str, Any]:
return self.extra_fields
def _download_and_convert_image(self, url: str) -> str | None:
"""Download an image from URL and convert it to base64 markdown format.
Args:
url: The URL of the image to download
Returns:
Base64 markdown string if successful, None if failed
Raises:
Various exceptions for network/processing errors
"""
try:
parsed_url = urlparse(url)
if not parsed_url.scheme or not parsed_url.netloc:
logger.warning(f"Invalid URL format: {url}")
return None
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.5735.133 Safari/537.36"
}
response = requests.get(url, headers=headers, timeout=30, stream=True)
response.raise_for_status()
content_type = response.headers.get("content-type", "").lower()
if content_type == "image/jpg":
content_type = "image/jpeg"
if content_type not in VALID_IMAGE_MIME_TYPES:
logger.warning(f"Unsupported image MIME type '{content_type}' for URL: {url}. Not encoding image.")
return None
max_size = 10 * 1024 * 1024 # 10MB
content_length = response.headers.get("content-length")
if content_length and int(content_length) > max_size:
logger.warning(f"Image too large ({content_length} bytes) for URL: {url}")
return None
image_data = b""
for chunk in response.iter_content(chunk_size=8192):
image_data += chunk
if len(image_data) > max_size:
logger.warning(f"Image too large (>{max_size} bytes) for URL: {url}")
return None
if not image_data:
logger.warning(f"Empty image data for URL: {url}")
return None
b64_data = base64.b64encode(image_data).decode("ascii")
markdown = f"![{url}](data:{content_type};base64,{b64_data})"
logger.info(f"Successfully processed image from {url} ({len(image_data)} bytes, {content_type})")
return markdown
except requests.exceptions.Timeout:
logger.warning(f"Timeout downloading image from {url}")
return None
except requests.exceptions.RequestException as e:
logger.warning(f"Network error downloading image from {url}: {e}")
return None
except Exception as e:
logger.warning(f"Unexpected error processing image from {url}: {e}")
return None
def __repr__(self) -> str:
n_images = len(self.issue_images)
return f"SWEBenchMultimodalProblemStatement(id={self.id}, text={self.text[:30]}..., images={n_images})"
def __str__(self) -> str:
n_images = len(self.issue_images)
return f"id={self.id}, text={self.text[:30]}..., images={n_images}"
ProblemStatementConfig = (
TextProblemStatement
| SWEBenchMultimodalProblemStatement
| GithubIssue
| EmptyProblemStatement
| FileProblemStatement
)
def problem_statement_from_simplified_input(
*, input: str, type: Literal["text", "text_file", "github_issue", "swe_bench_multimodal"]
) -> ProblemStatementConfig:
"""Get a problem statement from an `input` string and a `type`.
Args:
input: Url/path/text
type: The type of problem statement
"""
if type == "text":
return TextProblemStatement(text=input)
elif type == "text_file":
return FileProblemStatement(path=Path(input))
elif type == "github_issue":
return GithubIssue(github_url=input)
elif type == "swe_bench_multimodal":
return SWEBenchMultimodalProblemStatement(text=input)
else:
msg = f"Unknown problem statement type: {type}"
raise ValueError(msg)

View File

@@ -0,0 +1,664 @@
"""The reviewer implements a retry loop for the agent to retry
solving the issue and to select the best solution.
"""
from __future__ import annotations
import copy
import re
from abc import ABC, abstractmethod
from typing import Any, Literal
import numpy as np
from jinja2 import Template
from pydantic import BaseModel, ConfigDict
from sweagent.agent.history_processors import _set_cache_control
from sweagent.agent.models import (
AbstractModel,
InstanceStats,
ModelConfig,
get_model,
)
from sweagent.agent.problem_statement import ProblemStatement
from sweagent.tools.parsing import ActionParser
from sweagent.tools.tools import ToolConfig
from sweagent.types import AgentInfo, Trajectory, TrajectoryStep
from sweagent.utils.log import get_logger
class ReviewSubmission(BaseModel):
"""Information that's passed to the reviewer"""
#: Total trajectory (including several retries)
trajectory: Trajectory
#: Aggregate info dict (including several retries)
info: AgentInfo
#: Model stats for this attempt
model_stats: InstanceStats
def to_format_dict(self, *, suffix="") -> dict[str, Any]:
"""Return all the data that is used to format the
messages. Trajectory is excluded because it needs special treatment.
"""
out = {}
info = copy.deepcopy(self.info)
if not info.get("submission"):
# Observed that not all exit_cost lead to autosubmission
# so sometimes this might be missing.
info["submission"] = ""
for k, v in info.items():
if isinstance(v, str):
out[f"{k}{suffix}"] = v
elif isinstance(v, dict):
for k2, v2 in v.items():
out[f"{k}_{k2}{suffix}"] = v2
return out
class ReviewerResult(BaseModel):
accept: bool | float
outputs: list[str]
messages: list[dict[str, Any]]
class PreselectorOutput(BaseModel):
chosen_idx: list[int]
response: str
messages: list[dict[str, Any]]
class ChooserOutput(BaseModel):
chosen_idx: int
response: str
preselector_output: PreselectorOutput | None = None
messages: list[dict[str, Any]]
# --- INTERFACES ---
class AbstractReviewer(ABC):
"""The reviewer checks a single solution and tries to predict
if it successfully solves the issue.
"""
@abstractmethod
def review(self, instance: ProblemStatement, submission: ReviewSubmission) -> ReviewerResult:
"""Returns True if the submission is believed to be correct"""
class AbstractRetryLoop(ABC):
"""The review loop controls how often the agent tries to solve
the issue and how it selects the best solution.
"""
def retry(self) -> bool:
"""Returns True if the agent should retry solving the issue"""
return False
def on_submit(self, submission: ReviewSubmission) -> None:
"""Called when the agent submits a solution"""
def on_model_query(self, attempt_stats: InstanceStats):
"""Called before the model is queried. Can be used to implement
stop conditions based on attempt cost etc.
"""
def on_attempt_started(self, i_attempt: int, agent):
"""Called when a new attempt is started"""
pass
@abstractmethod
def get_best(self) -> int:
"""Returns the best solution"""
def get_forwarded_vars(self) -> dict[str, Any]:
"""Get the variables that should be forwarded to the next iteration.
Returns:
A dictionary of variables that should be forwarded to the next iteration.
"""
return {}
# --- CONFIGS ---
class PreselectorConfig(BaseModel):
model: ModelConfig
system_template: str
instance_template: str
submission_template: str
max_len_submission: int = 5000
class ChooserConfig(BaseModel):
model: ModelConfig
system_template: str
instance_template: str
submission_template: str
max_len_submission: int = 5000
preselector: PreselectorConfig | None = None
class TrajFormatterConfig(BaseModel):
#: Filter the following actions from the trajectory
filter: list[str] = []
#: Filter outputs from the following actions from the trajectory
output_filter: list[str] = []
#: Format of the trajectory item
item_template: str = "Model: {{response}}\n\nObservation: {{observation}}"
only_show_last_n_output: int = 0
model_config = ConfigDict(extra="forbid")
class ReviewerConfig(BaseModel):
"""The configuration for the reviewer"""
system_template: str
instance_template: str
#: If a submission autosubmits because of total cost or a similar exit status,
#: it will get this malus to its score
failure_score_penalty: float = 0.0
traj_formatter: TrajFormatterConfig
n_sample: int = 5
reduce_by_std: float = 0.0
score_range: tuple[float | None, float | None] = (None, None)
#: If set, we assume that the score is in the range [score_range[0], score_range[1]]
#: Reviews that are outside this range will be ignored
type: Literal["reviewer"] = "reviewer"
model_config = ConfigDict(extra="forbid")
def get_reviewer(self, model: AbstractModel) -> AbstractReviewer:
return Reviewer(self, model)
class ChooserRetryLoopConfig(BaseModel):
type: Literal["chooser"] = "chooser"
chooser: ChooserConfig
max_attempts: int
min_budget_for_new_attempt: float = 0.0
"""Minimal $ that need to be left in order for us to start a new attempt.
If set to 0: Always.
"""
cost_limit: float
"""The maximum cost to spend on all attempts. Does not include cost of choosing.
"""
model_config = ConfigDict(extra="forbid")
def get_retry_loop(self, problem_statement: ProblemStatement) -> ChooserRetryLoop:
return ChooserRetryLoop(self, problem_statement)
class ScoreRetryLoopConfig(BaseModel):
"""The configuration for the review loop"""
type: Literal["score"] = "score"
reviewer_config: ReviewerConfig
accept_score: float
max_accepts: int = 1
max_attempts: int
min_budget_for_new_attempt: float = 0.0
"""Minimal $ that need to be left in order for us to start a new attempt.
If set to 0: Always.
"""
cost_limit: float
"""The maximum cost to spend on all attempts and reviews except the last review.
The last review is not included in the cost limit, because we would waste the last
attempt if we couldn't score it.
"""
model: ModelConfig
model_config = ConfigDict(extra="forbid")
def validate(self):
"""Checks config. Raises `ValueError` in case of misconfiguration"""
...
def __post_init__(self):
self.validate()
def get_retry_loop(self, problem_statement: ProblemStatement) -> ScoreRetryLoop:
return ScoreRetryLoop(self, problem_statement)
RetryLoopConfig = ScoreRetryLoopConfig | ChooserRetryLoopConfig
# --- IMPLEMENTATIONS ---
class Preselector:
def __init__(self, config: PreselectorConfig):
self.config = config
self.model = get_model(config.model, ToolConfig(parse_function=ActionParser()))
self.logger = get_logger("chooser", emoji="🧠")
def interpret(self, response: str) -> list[int]:
if not response:
self.logger.warning("No response from preselector")
return []
# Use regex to extract the last number of the response
last_line = response.splitlines()[-1]
try:
return [int(i) for i in re.findall(r"\d+", last_line)]
except Exception as e:
self.logger.error(f"Error interpreting response: {e}")
return []
def format_submission(self, problem_statement: str, submission: ReviewSubmission) -> str:
if (
submission.info.get("submission") is None
or len(submission.info.get("submission", "")) > self.config.max_len_submission > 0 # type: ignore
):
return "Solution invalid."
return Template(self.config.submission_template).render(
**submission.to_format_dict(),
# summary=self.summarizer.summarize(problem_statement, submission.trajectory) if self.summarizer else "",
)
def build_messages(self, problem_statement: str, input: list[ReviewSubmission]) -> list[dict[str, Any]]:
instance_message = Template(self.config.instance_template).render(
problem_statement=problem_statement,
submissions=[self.format_submission(problem_statement, s) for s in input],
)
self.logger.debug(f"MODEL INPUT (user)\n{instance_message}")
return [
{"role": "system", "content": self.config.system_template},
{"role": "user", "content": instance_message},
]
def choose(self, problem_statement: str, input: list[ReviewSubmission]) -> PreselectorOutput:
messages = self.build_messages(problem_statement, input)
response = self.model.query(messages)["message"] # type: ignore
indices = self.interpret(response)
if not indices:
self.logger.warning("No indices found in response, using all indices")
indices = list(range(len(input)))
return PreselectorOutput(chosen_idx=indices, response=response, messages=messages)
class Chooser:
def __init__(self, config: ChooserConfig):
self.config = config
self.model = get_model(config.model, ToolConfig(parse_function=ActionParser()))
self.logger = get_logger("chooser", emoji="🧠")
# self.summarizer = Summarizer(config.summarizer, self.model) if config.summarizer else None
def interpret(self, response: str) -> int:
# Use regex to extract the last number of the response
try:
return int(re.findall(r"\d+", response)[-1])
except Exception as e:
self.logger.error(f"Error interpreting response: {e}")
return 0
def format_submission(self, problem_statement: str, submission: ReviewSubmission) -> str:
if (
submission.info.get("submission") is None
or len(submission.info.get("submission", "")) > self.config.max_len_submission > 0 # type: ignore
):
return "Solution invalid."
return Template(self.config.submission_template).render(
**submission.to_format_dict(),
# summary=self.summarizer.summarize(problem_statement, submission.trajectory) if self.summarizer else "",
)
def build_messages(self, problem_statement: str, input: list[ReviewSubmission]) -> list[dict[str, Any]]:
instance_message = Template(self.config.instance_template).render(
problem_statement=problem_statement,
submissions=[self.format_submission(problem_statement, s) for s in input],
)
self.logger.debug(f"MODEL INPUT (user)\n{instance_message}")
return [
{"role": "system", "content": self.config.system_template},
{"role": "user", "content": instance_message},
]
def choose(self, problem_statement: str, input: list[ReviewSubmission]) -> ChooserOutput:
preselector_output = None
selected_indices = list(range(len(input)))
n_submitted = sum(s.info.get("exit_status", "") == "submitted" for s in input)
if n_submitted >= 2:
self.logger.debug(f"Got {n_submitted} submitted submissions, only using them")
selected_indices = [i for i, s in enumerate(input) if s.info.get("exit_status", "") == "submitted"]
else:
self.logger.debug(f"Got only {n_submitted} submitted submissions, disabling exit status filtering")
if self.config.preselector and len(selected_indices) > 2:
preselector = Preselector(self.config.preselector)
try:
preselector_output = preselector.choose(problem_statement, [input[i] for i in selected_indices])
except Exception as e:
self.logger.critical(f"Preselector failed: {e}", exc_info=True)
preselector_output = None
if preselector_output and preselector_output.chosen_idx:
try:
_preselected_indices = [selected_indices[i] for i in preselector_output.chosen_idx]
except IndexError:
_preselected_indices = []
self.logger.error("Preselector gave invalid indices, ignoring it.")
if not _preselected_indices:
self.logger.error("Preselector gave no valid indices, ignoring it.")
else:
selected_indices = _preselected_indices
else:
self.logger.error("Preselector must have failed, ignoring it.")
messages = self.build_messages(problem_statement, [input[i] for i in selected_indices])
chosen_idx = None
try:
response = self.model.query(messages)["message"] # type: ignore
chosen_idx = self.interpret(response)
except Exception as e:
self.logger.critical(f"Chooser failed: {e}", exc_info=True)
chosen_idx = None
if chosen_idx is None or not (0 <= chosen_idx < len(selected_indices)):
self.logger.error(f"Invalid chosen index: {chosen_idx}, using first index")
chosen_idx = selected_indices[0]
else:
chosen_idx = selected_indices[chosen_idx]
return ChooserOutput(
chosen_idx=chosen_idx, response=response, preselector_output=preselector_output, messages=messages
)
class Reviewer(AbstractReviewer):
def __init__(self, config: ReviewerConfig, model):
self._config = config
self._model = model
self._traj_formatter = TrajectoryFormatter(config=config.traj_formatter)
self.logger = get_logger("reviewer", emoji="🧑‍⚖️")
def format_messages(self, instance: ProblemStatement, submission: ReviewSubmission):
system_message = self._config.system_template
self.logger.debug(f"MODEL INPUT (system)\n{system_message}")
ps_format_dict = {
"problem_statement": instance.get_problem_statement(),
**instance.get_extra_fields(),
}
user_message = Template(self._config.instance_template).render(
**ps_format_dict,
**submission.to_format_dict(),
traj=self._traj_formatter.format_trajectory(submission.trajectory),
)
self.logger.debug(f"MODEL INPUT (user)\n{user_message}")
return [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
]
def interpret(self, response: str) -> bool | float:
last_line = response.strip().split("\n")[-1].strip()
# Find all numbers in the last line and take the last one
numbers = re.findall(r"-?\d+\.?\d*", last_line)
if not numbers:
msg = f"Could not interpret response: {last_line!r}"
raise ValueError(msg)
number = float(numbers[-1])
if self._config.score_range[0] is not None and number < self._config.score_range[0]:
msg = f"Score {number} is below the minimum score {self._config.score_range[0]}"
raise ValueError(msg)
if self._config.score_range[1] is not None and number > self._config.score_range[1]:
msg = f"Score {number} is above the maximum score {self._config.score_range[1]}"
raise ValueError(msg)
return number
def review(self, instance: ProblemStatement, submission: ReviewSubmission) -> ReviewerResult:
exit_status = submission.info.get("exit_status")
messages = []
penalty = 0.0
if not exit_status or exit_status.strip() != "submitted":
penalty = self._config.failure_score_penalty
messages = self.format_messages(instance, submission)
if self._config.n_sample > 1:
_set_cache_control(messages[-1]) # type: ignore
answers = []
accepts = []
for _ in range(self._config.n_sample):
try:
answer = self._model.query(messages)["message"]
except Exception as e:
self.logger.warning(f"Query failed: {e}", exc_info=True)
continue
try:
score = self.interpret(answer)
except ValueError as e:
self.logger.warning(f"Could not interpret response: {answer!r}, got {e}")
continue
answers.append(answer)
accepts.append(score)
if not accepts:
answers = ["No valid scores found, failing submission"]
accepts = [-100.0]
accept = sum(accepts) / len(accepts) - penalty
std = np.std(accepts).item()
if self._config.reduce_by_std > 0:
accept -= std * self._config.reduce_by_std
self.logger.info(f"First answer: {answers[0]}")
self.logger.info(f"Final score: {accept} (penalty: {penalty}, std: {std}), individual: {accepts}")
return ReviewerResult(accept=accept, outputs=answers, messages=messages)
# todo: Couldn't I just replace the whole thing with Jinja templates?
class TrajectoryFormatter:
def __init__(
self,
config: TrajFormatterConfig,
):
"""Formats trajectories for the use in prompts"""
self._config = config
def _include_step(self, item: TrajectoryStep) -> bool:
action = item["action"].strip()
for f in self._config.filter:
if action.startswith(f):
return False
return True
def _include_step_output(self, item: TrajectoryStep, i_step: int, n_steps: int) -> bool:
if self._config.only_show_last_n_output > 0 and i_step < n_steps - self._config.only_show_last_n_output:
return False
action = item["action"].strip()
for f in self._config.output_filter:
if action.startswith(f):
return False
return True
def _format_trajectory_step(self, step: TrajectoryStep, i_step: int, *, n_steps: int, i_traj: int = 1) -> str:
step = copy.deepcopy(step)
if not self._include_step_output(step, i_step, n_steps=n_steps):
step["observation"] = "[Output omitted]"
return Template(self._config.item_template).render(
**step,
i_step=i_step,
i_traj=i_traj,
)
def format_trajectory(self, trajectory: Trajectory, i_traj: int = 1) -> str:
traj_messages = [step for step in trajectory if self._include_step(step)]
return "\n\n".join(
[
self._format_trajectory_step(step, i_step, i_traj=i_traj, n_steps=len(traj_messages))
for i_step, step in enumerate(traj_messages)
]
)
class ChooserRetryLoop(AbstractRetryLoop):
def __init__(self, config: ChooserRetryLoopConfig, problem_statement: ProblemStatement):
self._config = config
self._problem_statement = problem_statement
self._chooser = Chooser(config.chooser)
self._submissions: list[ReviewSubmission] = []
self._n_consec_exit_cost: int = 0
self.logger = get_logger("chooser_loop", emoji="🔄")
self._chooser_output: ChooserOutput | None = None
@property
def _total_stats(self) -> InstanceStats:
return sum((s.model_stats for s in self._submissions), start=InstanceStats())
@property
def review_model_stats(self) -> InstanceStats:
return InstanceStats()
@property
def _n_attempts(self) -> int:
return len(self._submissions)
def on_submit(self, submission: ReviewSubmission) -> None:
self._submissions.append(submission)
def retry(self) -> bool:
stat_str = f"n_samples={self._n_attempts}"
if self._total_stats.instance_cost > self._config.cost_limit > 0:
self.logger.info(
f"Exiting retry loop ({stat_str}): Total attempt cost ({self._total_stats.instance_cost}) "
f"exceeds cost limit ({self._config.cost_limit})"
)
return False
if self._n_attempts >= self._config.max_attempts > 0:
self.logger.info(f"Exiting retry loop ({stat_str}): max_attempts={self._config.max_attempts} reached")
return False
remaining_budget = self._config.cost_limit - self._total_stats.instance_cost
if self._config.min_budget_for_new_attempt > 0 and remaining_budget < self._config.min_budget_for_new_attempt:
msg = (
f"Exiting retry loop ({stat_str}): Not enough budget left for a new attempt "
f"({remaining_budget} remaining, {self._config.min_budget_for_new_attempt} required)"
)
self.logger.info(msg)
return False
return True
def get_best(self) -> int | None:
"""Important note: This is cached. Only call this at the end."""
if self._chooser_output is not None:
return self._chooser_output.chosen_idx
if len(self._submissions) == 0:
return None
self._chooser_output = self._chooser.choose(self._problem_statement.get_problem_statement(), self._submissions)
return self._chooser_output.chosen_idx
# todo: The model shouldn't be defined here, it should be defined as part of the scorer
class ScoreRetryLoop(AbstractRetryLoop):
def __init__(
self,
config: ScoreRetryLoopConfig,
problem_statement: ProblemStatement,
):
# This model will not share instance cost with the parent agent
self._model = get_model(config.model, tools=ToolConfig())
self._problem_statement = problem_statement
self._reviewer: AbstractReviewer = config.reviewer_config.get_reviewer(self._model)
self._config = config
# Note: These are "cumulative" submissions, i.e., they include all retries
# up to that point.
self._submissions: list[ReviewSubmission] = []
self._reviews: list[ReviewerResult] = []
#: Number of consecutive exit cost submissions
self._n_consec_exit_cost: int = 0
self.logger = get_logger("review_loop", emoji="🔄")
# Properties
# ----------
@property
def review_model_stats(self) -> InstanceStats:
return self._model.stats
@property
def reviews(self) -> list[ReviewerResult]:
return self._reviews
@property
def _n_attempts(self) -> int:
return len(self._submissions)
@property
def _n_accepted(self) -> int:
return sum(r.accept >= self._config.accept_score for r in self._reviews)
@property
def _total_stats(self) -> InstanceStats:
return sum((s.model_stats for s in self._submissions), start=InstanceStats()) + self._model.stats
# -------
def on_submit(self, submission: ReviewSubmission) -> None:
self._submissions.append(submission)
self._review()
def _review(self) -> float:
review = self._reviewer.review(self._problem_statement, self._submissions[-1])
self._reviews.append(review)
exit_status = self._submissions[-1].info.get("exit_status", "")
if exit_status and "exit_cost" in exit_status.lower():
self._n_consec_exit_cost += 1
else:
self._n_consec_exit_cost = 0
return review.accept
def retry(self) -> bool:
max_score = max([r.accept for r in self._reviews], default=-100.0)
stat_str = f"n_samples={self._n_attempts}, max_score={max_score}, n_accepted={self._n_accepted}"
if self._total_stats.instance_cost > self._config.cost_limit > 0:
self.logger.info(
f"Exiting retry loop ({stat_str}): Total attempt cost ({self._total_stats.instance_cost}) "
f"exceeds cost limit ({self._config.cost_limit})"
)
return False
if self._n_attempts >= self._config.max_attempts > 0:
self.logger.info(f"Exiting retry loop ({stat_str}): max_attempts={self._config.max_attempts} reached")
return False
if self._n_accepted >= self._config.max_accepts > 0:
self.logger.info(f"Exiting retry loop ({stat_str}): max_accepts={self._config.max_accepts} reached")
return False
remaining_budget = self._config.cost_limit - self._total_stats.instance_cost
if self._config.min_budget_for_new_attempt > 0 and remaining_budget < self._config.min_budget_for_new_attempt:
msg = (
f"Exiting retry loop ({stat_str}): Not enough budget left for a new attempt "
f"({remaining_budget} remaining, {self._config.min_budget_for_new_attempt} required)"
)
self.logger.info(msg)
return False
return True
def get_best(self) -> int | None:
if len(self._reviews) == 0:
return None
scores = [r.accept for r in self._reviews]
self.logger.debug(f"Scores: {scores}")
max_score = np.max(scores)
max_indices = [i for i, s in enumerate(scores) if np.isclose(s, max_score)]
# If there are multiple submissions with the same score, choose the shortest one
max_indices = sorted(max_indices, key=lambda i: self._submissions[i].model_stats.api_calls or float("inf"))
chosen_idx = max_indices[0]
self.logger.info(f"Best submission: {chosen_idx}")
return chosen_idx
def get_retry_loop_from_config(
config: RetryLoopConfig, problem_statement: ProblemStatement
) -> ScoreRetryLoop | ChooserRetryLoop:
return config.get_retry_loop(problem_statement=problem_statement)

View File

@@ -0,0 +1,60 @@
from sweagent.environment.repo import Repo, RepoConfig
class EnvHook:
"""Hook to be used in `SWEEnv`.
Subclass this class, add functionality and add it with `SWEEEnv.add_hook(hook)`.
This allows to inject custom functionality at different stages of the environment
lifecycle, in particular to connect SWE-agent to a new interface (like a GUI).
"""
def on_init(self, *, env) -> None:
"""Gets called when the hook is added"""
def on_copy_repo_started(self, repo: RepoConfig | Repo) -> None:
"""Gets called when the repository is being cloned to the container"""
def on_start_deployment(self) -> None:
"""Gets called when the deployment is being started"""
def on_install_env_started(self) -> None:
"""Called when we start installing the environment"""
def on_close(self):
"""Called when the environment is closed"""
def on_environment_startup(self) -> None:
"""Called when the environment is started"""
class CombinedEnvHooks(EnvHook):
def __init__(self):
self._hooks = []
def add_hook(self, hook: EnvHook) -> None:
self._hooks.append(hook)
def on_init(self, *, env) -> None:
for hook in self._hooks:
hook.on_init(env=env)
def on_copy_repo_started(self, repo: RepoConfig | Repo) -> None:
for hook in self._hooks:
hook.on_copy_repo_started(repo=repo)
def on_start_deployment(self) -> None:
for hook in self._hooks:
hook.on_start_deployment()
def on_install_env_started(self) -> None:
for hook in self._hooks:
hook.on_install_env_started()
def on_close(self):
for hook in self._hooks:
hook.on_close()
def on_environment_startup(self) -> None:
for hook in self._hooks:
hook.on_environment_startup()

View File

@@ -0,0 +1,28 @@
from collections.abc import Callable
from sweagent.environment.hooks.abstract import EnvHook
from sweagent.environment.repo import Repo, RepoConfig
class SetStatusEnvironmentHook(EnvHook):
def __init__(self, id: str, callable: Callable[[str, str], None]):
self._callable = callable
self._id = id
def _update(self, message: str):
self._callable(self._id, message)
def on_copy_repo_started(self, repo: RepoConfig | Repo):
self._update(f"Copying repo {repo.repo_name}")
def on_start_deployment(self):
self._update("Starting deployment")
def on_install_env_started(self):
self._update("Installing environment")
def on_environment_startup(self):
self._update("Starting environment")
def on_close(self):
self._update("Closing environment")

View File

@@ -0,0 +1,258 @@
import asyncio
import os
import shlex
from pathlib import Path
from typing import Any, Literal, Protocol
from git import InvalidGitRepositoryError
from git import Repo as GitRepo
from pydantic import BaseModel, ConfigDict, Field
from swerex.deployment.abstract import AbstractDeployment
from swerex.runtime.abstract import Command, UploadRequest
from typing_extensions import Self
from sweagent.utils.github import _parse_gh_repo_url
from sweagent.utils.log import get_logger
logger = get_logger("swea-config", emoji="🔧")
class Repo(Protocol):
"""Protocol for repository configurations."""
base_commit: str
repo_name: str
def copy(self, deployment: AbstractDeployment): ...
def get_reset_commands(self) -> list[str]: ...
def _get_git_reset_commands(base_commit: str) -> list[str]:
return [
"git fetch",
"git status",
"git restore .",
"git reset --hard",
f"git checkout {shlex.quote(base_commit)}",
"git clean -fdq",
]
class PreExistingRepoConfig(BaseModel):
"""Use this to specify a repository that already exists on the deployment.
This is important because we need to cd to the repo before running the agent.
Note: The repository must be at the root of the deployment.
"""
repo_name: str
"""The repo name (the repository must be located at the root of the deployment)."""
base_commit: str = Field(default="HEAD")
"""The commit to reset the repository to. The default is HEAD,
i.e., the latest commit. You can also set this to a branch name (e.g., `dev`),
a tag (e.g., `v0.1.0`), or a commit hash (e.g., `a4464baca1f`).
SWE-agent will then start from this commit when trying to solve the problem.
"""
type: Literal["preexisting"] = "preexisting"
"""Discriminator for (de)serialization/CLI. Do not change."""
reset: bool = True
"""If True, reset the repository to the base commit after the copy operation."""
model_config = ConfigDict(extra="forbid")
def copy(self, deployment: AbstractDeployment):
"""Does nothing."""
pass
def get_reset_commands(self) -> list[str]:
"""Issued after the copy operation or when the environment is reset."""
if self.reset:
return _get_git_reset_commands(self.base_commit)
return []
class LocalRepoConfig(BaseModel):
path: Path
base_commit: str = Field(default="HEAD")
"""The commit to reset the repository to. The default is HEAD,
i.e., the latest commit. You can also set this to a branch name (e.g., `dev`),
a tag (e.g., `v0.1.0`), or a commit hash (e.g., `a4464baca1f`).
SWE-agent will then start from this commit when trying to solve the problem.
"""
type: Literal["local"] = "local"
"""Discriminator for (de)serialization/CLI. Do not change."""
model_config = ConfigDict(extra="forbid")
@property
def repo_name(self) -> str:
"""Set automatically based on the repository name. Cannot be set."""
return Path(self.path).resolve().name.replace(" ", "-").replace("'", "")
# Let's not make this a model validator, because it leads to cryptic errors.
# Let's just check during copy instead.
def check_valid_repo(self) -> Self:
try:
repo = GitRepo(self.path, search_parent_directories=True)
except InvalidGitRepositoryError as e:
msg = f"Could not find git repository at {self.path=}."
raise ValueError(msg) from e
if repo.is_dirty() and "PYTEST_CURRENT_TEST" not in os.environ:
msg = f"Local git repository {self.path} is dirty. Please commit or stash changes."
raise ValueError(msg)
return self
def copy(self, deployment: AbstractDeployment):
self.check_valid_repo()
asyncio.run(
deployment.runtime.upload(UploadRequest(source_path=str(self.path), target_path=f"/{self.repo_name}"))
)
r = asyncio.run(
deployment.runtime.execute(Command(command=f"chown -R root:root /{self.repo_name}", shell=True))
)
if r.exit_code != 0:
msg = f"Failed to change permissions on copied repository (exit code: {r.exit_code}, stdout: {r.stdout}, stderr: {r.stderr})"
raise RuntimeError(msg)
def get_reset_commands(self) -> list[str]:
"""Issued after the copy operation or when the environment is reset."""
return _get_git_reset_commands(self.base_commit)
class GithubRepoConfig(BaseModel):
github_url: str
base_commit: str = Field(default="HEAD")
"""The commit to reset the repository to. The default is HEAD,
i.e., the latest commit. You can also set this to a branch name (e.g., `dev`),
a tag (e.g., `v0.1.0`), or a commit hash (e.g., `a4464baca1f`).
SWE-agent will then start from this commit when trying to solve the problem.
"""
clone_timeout: float = 500
"""Timeout for git clone operation."""
type: Literal["github"] = "github"
"""Discriminator for (de)serialization/CLI. Do not change."""
model_config = ConfigDict(extra="forbid")
def model_post_init(self, __context: Any) -> None:
if self.github_url.count("/") == 1:
self.github_url = f"https://github.com/{self.github_url}"
@property
def repo_name(self) -> str:
org, repo = _parse_gh_repo_url(self.github_url)
return f"{org}__{repo}"
def _get_url_with_token(self, token: str) -> str:
"""Prepend github token to URL"""
if not token:
return self.github_url
if "@" in self.github_url:
logger.warning("Cannot prepend token to URL. '@' found in URL")
return self.github_url
_, _, url_no_protocol = self.github_url.partition("://")
return f"https://{token}@{url_no_protocol}"
def copy(self, deployment: AbstractDeployment):
"""Clones the repository to the sandbox."""
base_commit = self.base_commit
github_token = os.getenv("GITHUB_TOKEN", "")
url = self._get_url_with_token(github_token)
asyncio.run(
deployment.runtime.execute(
Command(
command=" && ".join(
(
f"mkdir /{self.repo_name}",
f"cd /{self.repo_name}",
"git init",
f"git remote add origin {shlex.quote(url)}",
f"git fetch --depth 1 origin {shlex.quote(base_commit)}",
"git checkout FETCH_HEAD",
"cd ..",
)
),
timeout=self.clone_timeout,
shell=True,
check=True,
)
),
)
def get_reset_commands(self) -> list[str]:
"""Issued after the copy operation or when the environment is reset."""
return _get_git_reset_commands(self.base_commit)
class SWESmithRepoConfig(BaseModel):
"""Repository config for SWE-Smith instances that handles targeted fetch
from a GitHub mirror, authenticating via GITHUB_TOKEN when needed.
"""
repo_name: str
base_commit: str = Field(default="HEAD")
mirror_url: str = ""
"""HTTPS URL of the GitHub mirror to fetch the bug branch from."""
type: Literal["swesmith_preexisting"] = "swesmith_preexisting"
"""Discriminator for (de)serialization/CLI. Do not change."""
model_config = ConfigDict(extra="forbid")
def copy(self, deployment: AbstractDeployment):
pass
@staticmethod
def _get_url_with_token(url: str, token: str) -> str:
if not token or not url:
return url
_, _, url_no_protocol = url.partition("://")
return f"https://{token}@{url_no_protocol}"
def get_reset_commands(self) -> list[str]:
if self.mirror_url:
github_token = os.getenv("GITHUB_TOKEN", "")
url = self._get_url_with_token(self.mirror_url, github_token)
return [
"git restore .",
"git reset --hard",
f"git fetch {shlex.quote(url)} {shlex.quote(self.base_commit)}",
"git checkout FETCH_HEAD",
"git clean -fdq",
]
return _get_git_reset_commands(self.base_commit)
RepoConfig = LocalRepoConfig | GithubRepoConfig | PreExistingRepoConfig | SWESmithRepoConfig
def repo_from_simplified_input(
*, input: str, base_commit: str = "HEAD", type: Literal["local", "github", "preexisting", "auto"] = "auto"
) -> RepoConfig:
"""Get repo config from a simplified input.
Args:
input: Local path or GitHub URL
type: The type of repo. Set to "auto" to automatically detect the type
(does not work for preexisting repos).
"""
if type == "local":
return LocalRepoConfig(path=Path(input), base_commit=base_commit)
if type == "github":
return GithubRepoConfig(github_url=input, base_commit=base_commit)
if type == "preexisting":
return PreExistingRepoConfig(repo_name=input, base_commit=base_commit)
if type == "auto":
if input.startswith("https://github.com/"):
return GithubRepoConfig(github_url=input, base_commit=base_commit)
else:
return LocalRepoConfig(path=Path(input), base_commit=base_commit)
msg = f"Unknown repo type: {type}"
raise ValueError(msg)

View File

@@ -0,0 +1,276 @@
import asyncio
import logging
import shlex
from pathlib import PurePath
from typing import Literal, Self
from pydantic import BaseModel, ConfigDict, Field
from swerex.deployment.abstract import AbstractDeployment
from swerex.deployment.config import DeploymentConfig, DockerDeploymentConfig, get_deployment
from swerex.runtime.abstract import (
BashAction,
BashInterruptAction,
CreateBashSessionRequest,
ReadFileRequest,
WriteFileRequest,
)
from swerex.runtime.abstract import Command as RexCommand
from sweagent.environment.hooks.abstract import CombinedEnvHooks, EnvHook
from sweagent.environment.repo import Repo, RepoConfig
from sweagent.utils.log import get_logger
class EnvironmentConfig(BaseModel):
"""Configure data sources and setup instructions for the environment in which we solve the tasks."""
deployment: DeploymentConfig = Field(
default_factory=lambda: DockerDeploymentConfig(image="python:3.11", python_standalone_dir="/root"),
description="Deployment options.",
)
repo: RepoConfig | None = Field(
default=None,
description="Repository options.",
)
post_startup_commands: list[str] = []
"""Execute these commands before starting to run the agent but after all other setup steps.
They will be executed in the same shell as the agent.
Note: Every command is passed as a string, not a list of arguments.
"""
post_startup_command_timeout: int = 500
"""Timeout for the post-startup commands.
NOTE: The timeout applies to every command in `post_startup_commands` separately.
"""
# pydantic config
model_config = ConfigDict(extra="forbid")
name: str = "main"
class SWEEnv:
def __init__(
self,
*,
deployment: AbstractDeployment,
repo: Repo | RepoConfig | None,
post_startup_commands: list[str],
post_startup_command_timeout: int = 500,
hooks: list[EnvHook] | None = None,
name: str = "main",
):
"""This class represents the environment in which we solve the tasks.
Args:
deployment: SWE-ReX deployment instance
repo: Repository configuration object, or anything following the `Repo` protocol
post_startup_commands: Commands to execute before starting the agent
hooks: Environment hooks (used to inject custom functionality)
Equivalent to calling `add_hook` for each hook after initialization.
name: Name of the environment
"""
super().__init__()
self.deployment = deployment
self.repo = repo
self._post_startup_commands = post_startup_commands
self.post_startup_command_timeout = post_startup_command_timeout
self.logger = get_logger("swea-env", emoji="🪴")
self.name = name
self.clean_multi_line_functions = lambda x: x
self._chook = CombinedEnvHooks()
for hook in hooks or []:
self.add_hook(hook)
@classmethod
def from_config(cls, config: EnvironmentConfig) -> Self:
"""Create an environment instance from a configuration object.
This is the recommended way to create an environment instance, unless you need
more flexibility.
"""
# Always copy config to avoid shared state between different instances
config = config.model_copy(deep=True)
return cls(
deployment=get_deployment(config.deployment),
repo=config.repo,
post_startup_commands=config.post_startup_commands,
post_startup_command_timeout=config.post_startup_command_timeout,
name=config.name,
)
def add_hook(self, hook: EnvHook) -> None:
"""Add `EnvHook` to the environment.
This allows to inject custom functionality at different stages of the environment
lifecycle, in particular to connect SWE-agent to a new interface (like a GUI).
"""
hook.on_init(env=self)
self._chook.add_hook(hook)
def start(self) -> None:
"""Start the environment and reset it to a clean state."""
self._init_deployment()
self.reset()
for command in self._post_startup_commands:
self.communicate(command, check="raise", timeout=self.post_startup_command_timeout)
def _copy_repo(self) -> None:
"""Clone/copy repository/codebase in container"""
if self.repo is None:
return
folders = self.communicate(input="ls", check="raise").split("\n")
if self.repo.repo_name in folders:
return
self._chook.on_copy_repo_started(repo=self.repo)
self.repo.copy(self.deployment)
def hard_reset(self):
"""Resets the environment and deployment, i.e., completely restarts the
deployment.
"""
self.close()
self.start()
def reset(self):
"""Reset the environment to a clean state.
Gets called by `start`, but can also be called independently to reset the
environment to a clean state before a new attempt.
Returns:
observation: output from container
info: additional information (e.g. debugging information)
"""
self.communicate(input="cd /", check="raise")
self._copy_repo()
self._reset_repository()
self._chook.on_environment_startup()
def _reset_repository(self) -> None:
"""Clean repository of any modifications + Checkout base commit"""
if self.repo is not None:
self.logger.debug("Resetting repository %s to commit %s", self.repo.repo_name, self.repo.base_commit)
# todo: Currently has swe-ft specific change: The original repo.copy isn't called, because the repo is already
# present. However, reset --hard <BRANCH> also doesn't work. So modified it here to do a checkout instead.
startup_commands = [
f"cd /{self.repo.repo_name}",
"export ROOT=$(pwd -P)",
*self.repo.get_reset_commands(),
]
self.communicate(
input=" && ".join(startup_commands),
check="raise",
error_msg="Failed to clean repository",
# Sometimes this is slow because it rebuilds some index
timeout=120,
)
def close(self) -> None:
"""Shutdown SWE-ReX deployment etc."""
self.logger.info("Beginning environment shutdown...")
asyncio.run(self.deployment.stop())
self._chook.on_close()
# MARK: Helper functions #
def _init_deployment(
self,
) -> None:
"""Handles container initialization. Defines container name and creates it.
If cached_image is provided, it will use that image name instead of the default.
"""
self._chook.on_start_deployment()
asyncio.run(self.deployment.start())
asyncio.run(
self.deployment.runtime.create_session(
CreateBashSessionRequest(startup_source=["/root/.bashrc"], startup_timeout=10)
)
)
self.set_env_variables({"LANG": "C.UTF-8", "LC_ALL": "C.UTF-8", "PIP_PROGRESS_BAR": "off", "PAGER": "cat"})
self.logger.info("Environment Initialized")
def interrupt_session(self):
self.logger.info("Interrupting session")
asyncio.run(self.deployment.runtime.run_in_session(BashInterruptAction()))
# todo: return exit code?
def communicate(
self,
input: str,
timeout: int | float = 25,
*,
check: Literal["warn", "ignore", "raise"] = "ignore",
error_msg: str = "Command failed",
) -> str:
"""Executes a command in the running shell. The details of this are handled by
the SWE-ReX deployment/runtime.
Args:
input: input to send to container
timeout_duration: duration to wait for output
check: `ignore`: do not extract exit code (more stable), `warn`: extract exit code and log error if
exit code is non-zero, `raise`: raise error if exit code is non-zero
error_msg: error message to raise if the command fails
Returns:
output: output from container
"""
self.logger.log(logging.TRACE, "Input:\n%s", input) # type: ignore
rex_check = "silent" if check else "ignore"
r = asyncio.run(
self.deployment.runtime.run_in_session(BashAction(command=input, timeout=timeout, check=rex_check))
)
output = r.output
self.logger.log(logging.TRACE, "Output:\n%s", output) # type: ignore
if check != "ignore" and r.exit_code != 0:
self.logger.error(f"{error_msg}:\n{output}")
msg = f"Command {input!r} failed ({r.exit_code=}): {error_msg}"
self.logger.error(msg)
if check == "raise":
self.close()
raise RuntimeError(msg)
return output
def read_file(self, path: str | PurePath, encoding: str | None = None, errors: str | None = None) -> str:
"""Read file contents from container
Args:
path: Absolute path to file
encoding: Encoding to use when reading the file. None means default encoding.
This is the same as the `encoding` argument of `Path.read_text()`
errors: Error handling to use when reading the file. None means default error handling.
This is the same as the `errors` argument of `Path.read_text()`
Returns:
file_contents: Contents of file as string
"""
r = asyncio.run(
self.deployment.runtime.read_file(ReadFileRequest(path=str(path), encoding=encoding, errors=errors))
)
return r.content
def write_file(self, path: str | PurePath, content: str) -> None:
"""Write content to file in container"""
asyncio.run(self.deployment.runtime.write_file(WriteFileRequest(path=str(path), content=content)))
def set_env_variables(self, env_variables: dict[str, str]) -> None:
"""Set environment variables in the environment."""
if not env_variables:
self.logger.debug("No environment variables to set")
return
_env_setters = [f"export {k}={shlex.quote(str(v))}" for k, v in env_variables.items()]
command = " && ".join(_env_setters)
self.communicate(command, check="raise")
def execute_command(
self,
command: str,
shell: bool = True,
check: bool = False,
env: dict[str, str] | None = None,
cwd: str | None = None,
) -> None:
"""Execute a command in the environment independent of the session (i.e., as a subprocess)"""
asyncio.run(
self.deployment.runtime.execute(RexCommand(command=command, shell=shell, check=check, env=env, cwd=cwd))
)

View File

@@ -0,0 +1,54 @@
from typing import Any, Literal
"""This module contains all custom exceptions used by the SWE-agent."""
class FormatError(Exception):
"""Raised when the model response cannot properly be parsed into thought and actions."""
class FunctionCallingFormatError(FormatError):
"""Format error exception used by the function
calling parser."""
def __init__(
self,
message: str,
error_code: Literal[
"missing", "multiple", "incorrect_args", "invalid_json", "invalid_command", "missing_arg", "unexpected_arg"
],
**extra_info: Any,
):
super().__init__(message + f" [error_code={error_code}]")
self.message = message
self.extra_info = {"error_code": error_code, **extra_info}
class ContextWindowExceededError(Exception):
"""Raised when the context window of a LM is exceeded"""
class CostLimitExceededError(Exception):
"""Raised when we exceed a cost limit"""
class InstanceCostLimitExceededError(CostLimitExceededError):
"""Raised when we exceed the cost limit set for one task instance"""
class TotalCostLimitExceededError(CostLimitExceededError):
"""Raised when we exceed the total cost limit"""
class InstanceCallLimitExceededError(CostLimitExceededError):
"""Raised when we exceed the per instance call limit"""
class ContentPolicyViolationError(Exception):
"""Raised when the model response violates a content policy"""
class ModelConfigurationError(Exception):
"""Raised when the model configuration is invalid/no further retries
should be made.
"""

View File

@@ -0,0 +1,6 @@
🔗 For more information on the trajectory inspector, visit [our documentation website][docs].
You can also find the corresponding markdown files in the [`docs/` folder][source].
[docs]: https://swe-agent.com/latest/usage/inspector/
[source]: https://github.com/SWE-agent/SWE-agent/tree/main/docs

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 264 KiB

View File

@@ -0,0 +1,354 @@
let currentFileName = null;
let trajectoryDirectory = "";
let timeoutIds = [];
function getBaseUrl() {
const protocol = window.location.protocol;
const host = window.location.hostname;
const port = window.location.port;
const defaultPort =
protocol === "http:" && !port
? "80"
: protocol === "https:" && !port
? "443"
: port;
return `${protocol}//${host}:${defaultPort}`;
}
function fetchFiles() {
const baseUrl = getBaseUrl();
fetch(`${baseUrl}/files`)
.then((response) => response.json())
.then((files) => {
const fileList = document.getElementById("fileList");
fileList.innerHTML = "";
files.forEach((file) => {
const fileElement = document.createElement("li");
fileElement.textContent = file;
fileElement.onclick = () => viewFile(file.split(" ")[0]);
fileList.appendChild(fileElement);
});
});
}
function createTrajectoryItem(item, index) {
const elementId = `trajectoryItem${index}`;
// Check for old format and log a warning
const isOldFormat = item.messages && !item.query;
if (isOldFormat) {
console.log(
`Found old format using 'messages' instead of 'query' in item ${index}`,
);
// Migrate old format to new format
item.query = item.messages;
}
const hasMessages = item.query && item.query.length > 0;
const escapeHtml = (text) => {
if (!text) {
return "";
}
return text
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#039;");
};
const processImagesInObservation = (observation) => {
if (!observation) {
return { processedText: "", images: [] };
}
// regex to match markdown-style base64 images: ![alt text](data:image/<format>;base64,<base64-data>)
const imageRegex = /!\[([^\]]*)\]\(data:image\/([^;]+);base64,([^)]+)\)/g;
const images = [];
let processedText = observation;
let match;
while ((match = imageRegex.exec(observation)) !== null) {
const [fullMatch, altText, format, base64Data] = match;
// create image object
const imageObj = {
altText: altText || "Image",
format: format,
dataUrl: `data:image/${format};base64,${base64Data}`,
id: `img_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
};
images.push(imageObj);
// replace the full base64 string with a placeholder
processedText = processedText.replace(
fullMatch,
`[IMAGE: ${imageObj.altText}]`,
);
}
return { processedText, images };
};
const getMessageContent = (msg) => {
if (!msg.content) {
return "";
}
// Handle content as a string
if (typeof msg.content === "string") {
return msg.content;
}
// Handle content as an array with a dictionary containing 'text' key
if (
Array.isArray(msg.content) &&
msg.content.length > 0 &&
msg.content[0].text
) {
return msg.content[0].text;
}
// Fallback to stringifying the content
return JSON.stringify(msg.content);
};
const messagesContent = hasMessages
? item.query
.map((msg, msgIndex) => {
let content = `----Item ${msgIndex}-----\n`;
content += `role: ${msg.role}\n`;
content += `content: |\n${escapeHtml(getMessageContent(msg))}\n`;
if (msg.tool_calls && msg.tool_calls.length > 0) {
msg.tool_calls.forEach((tool, idx) => {
content += `- tool call ${idx + 1}:\n`;
if (tool.function) {
content += ` - name: ${tool.function.name}\n`;
// Handle arguments based on type
let args = tool.function.arguments;
try {
if (typeof args === "string") {
args = JSON.parse(args);
}
content += ` - arguments: ${JSON.stringify(args, null, 2).replace(/\n/g, "\n ")}\n`;
} catch (e) {
content += ` - arguments: ${escapeHtml(String(args))}\n`;
}
}
content += ` - id: ${tool.id}\n`;
});
}
if (msg.is_demo) {
return `<span class="demo-message">${content}</span>`;
}
return content;
})
.join("\n")
: "";
// Process images in observation
const { processedText: processedObservation, images: observationImages } =
processImagesInObservation(item.observation);
// Create separate image pane HTML if there are images
const observationImagesPane =
observationImages.length > 0
? `<div class="observation-images-section" data-title="Observation Images">
<div class="content-wrapper">
<div class="observation-images">
${observationImages
.map(
(img) =>
`<div class="observation-image-container">
<img src="${img.dataUrl}" alt="${escapeHtml(img.altText)}" class="observation-image" id="${img.id}">
<div class="image-caption">${escapeHtml(img.altText)}</div>
</div>`,
)
.join("")}
</div>
</div>
</div>`
: "";
return `
<div class="trajectory-item fade-in" id="${elementId}">
<div class="trajectory-main">
<div class="response-section" data-title="Response">
<div class="content-wrapper">
<pre><code class="language-python">Response:
${escapeHtml(item.response)}
Action:
${escapeHtml(item.action)}</code></pre>
</div>
</div>
<div class="observation-section" data-title="Environment Observation">
<div class="content-wrapper">
<pre><code class="language-python">${escapeHtml(processedObservation)}</code></pre>
</div>
</div>
${observationImagesPane}
${
item.execution_time
? `<div class="execution-time">Execution time: ${item.execution_time}s</div>`
: ""
}
</div>
${
hasMessages
? `
<div class="messages-section" data-title="Messages">
<div class="content-wrapper">
<pre>${messagesContent}</pre>
</div>
</div>
`
: ""
}
</div>
`;
}
function viewFile(fileName) {
currentFileName = fileName;
timeoutIds.forEach((timeoutId) => clearTimeout(timeoutId));
timeoutIds = [];
const baseUrl = getBaseUrl();
const showDemos = document.getElementById("showDemos").checked;
fetch(`${baseUrl}/trajectory/${fileName}`)
.then((response) => {
if (!response.ok) {
throw new Error("Network response was not ok");
}
return response.json();
})
.then((content) => {
const container = document.getElementById("fileContent");
container.innerHTML = "";
if (content.trajectory && Array.isArray(content.trajectory)) {
content.trajectory.forEach((item, index) => {
container.innerHTML += createTrajectoryItem(item, index);
// Highlight code blocks after adding them
const newItem = document.getElementById(`trajectoryItem${index}`);
newItem.querySelectorAll("pre code").forEach((block) => {
hljs.highlightElement(block);
});
});
// Initialize image click handlers after all items are added
initializeImageHandlers();
} else {
container.textContent = "No trajectory content found.";
}
})
.catch((error) => {
console.error("Error fetching file:", error);
document.getElementById("fileContent").textContent =
"Error loading content. " + error;
});
// Highlight selected file
document.querySelectorAll("#fileList li").forEach((li) => {
li.classList.remove("selected");
if (li.textContent.split(" ")[0] === fileName) {
li.classList.add("selected");
}
});
}
function initializeImageHandlers() {
// Remove existing overlay if present
const existingOverlay = document.querySelector(".image-overlay");
if (existingOverlay) {
existingOverlay.remove();
}
// Create overlay element
const overlay = document.createElement("div");
overlay.className = "image-overlay";
document.body.appendChild(overlay);
// Add click handlers to all observation images
document.querySelectorAll(".observation-image").forEach((img) => {
img.addEventListener("click", function (e) {
e.preventDefault();
e.stopPropagation();
// Toggle expanded state
if (this.classList.contains("expanded")) {
this.classList.remove("expanded");
overlay.classList.remove("active");
} else {
// Remove expanded class from all other images
document
.querySelectorAll(".observation-image.expanded")
.forEach((otherImg) => {
otherImg.classList.remove("expanded");
});
// Add expanded class to clicked image
this.classList.add("expanded");
overlay.classList.add("active");
}
});
});
// Close expanded image when clicking overlay
overlay.addEventListener("click", function () {
document.querySelectorAll(".observation-image.expanded").forEach((img) => {
img.classList.remove("expanded");
});
overlay.classList.remove("active");
});
// Close expanded image when pressing Escape key
document.addEventListener("keydown", function (e) {
if (e.key === "Escape") {
document
.querySelectorAll(".observation-image.expanded")
.forEach((img) => {
img.classList.remove("expanded");
});
overlay.classList.remove("active");
}
});
}
function refreshCurrentFile() {
if (currentFileName) {
const currentScrollPosition =
document.documentElement.scrollTop || document.body.scrollTop;
viewFile(currentFileName.split(" ")[0]);
setTimeout(() => {
window.scrollTo(0, currentScrollPosition);
}, 100);
}
}
function fetchDirectoryInfo() {
const baseUrl = getBaseUrl();
fetch(`${baseUrl}/directory_info`)
.then((response) => response.json())
.then((data) => {
if (data.directory) {
trajectoryDirectory = data.directory;
document.title = `Trajectory Viewer: ${data.directory}`;
document.getElementById("directoryInfo").textContent =
`Directory: ${data.directory}`;
}
})
.catch((error) => console.error("Error fetching directory info:", error));
}
window.onload = function () {
fetchFiles();
fetchDirectoryInfo();
};

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

View File

@@ -0,0 +1,11 @@
<svg width="21" height="20" viewBox="0 0 21 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M3.64645 11.7772L3.29289 12.1308L3.64645 12.4843L8.59619 17.4341L8.94975 17.7876L9.3033 17.4341L18.8512 7.88615C20.4133 6.32405 20.4133 3.79139 18.8512 2.22929C17.2891 0.667195 14.7565 0.667193 13.1944 2.22929L3.64645 11.7772Z" fill="black" stroke="white"/>
<path d="M3.09429 14.2581C3.27091 13.4928 4.22041 13.2204 4.77579 13.7758L7.2242 16.2242C7.77958 16.7796 7.50727 17.7291 6.74194 17.9057L3.559 18.6402C2.83893 18.8064 2.19359 18.1611 2.35976 17.441L3.09429 14.2581Z" fill="black"/>
<path d="M3.09429 14.258C3.27091 13.4927 4.22041 13.2204 4.77579 13.7758L7.2242 16.2242C7.77958 16.7796 7.50727 17.7291 6.74194 17.9057L3.559 18.6402C2.83893 18.8064 2.19359 18.161 2.35976 17.441L3.09429 14.258Z" fill="black"/>
<mask id="mask0_8_32" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="2" y="12" width="7" height="7">
<path d="M3.5 12.5L8.5 17.5L1.99998 19L3.5 12.5Z" fill="black"/>
</mask>
<g mask="url(#mask0_8_32)">
<rect x="1.89587" y="14.1084" width="2.27835" height="7.05693" transform="rotate(-45 1.89587 14.1084)" fill="black"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

View File

@@ -0,0 +1,25 @@
<html>
<head>
<title>Trajectory Viewer</title>
<link rel="stylesheet" type="text/css" href="style.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/github.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/python.min.js"></script>
<script src="fileViewer.js"></script>
</head>
<body>
<div class="container">
<h1>Trajectory File Viewer</h1>
<h3 id="directoryInfo"></h3>
<ul id="fileList"></ul>
<label style="margin-right: 20px;">
<input type="checkbox" id="showDemos" onchange="refreshCurrentFile()"> Show demonstrations
</label>
<h2>Conversation History</h2>
<div id="fileContent" class="trajectory-container">No file selected.</div>
<div class="button-container">
<button id="refreshButton" onclick="refreshCurrentFile()">Refresh Current File</button>
</div>
</div>
</body>
</html>

View File

@@ -0,0 +1,354 @@
from __future__ import annotations
import http.server
import json
import os
import socketserver
from argparse import ArgumentParser
from functools import partial
from pathlib import Path
from typing import Any
import yaml
def add_problem_statement(content):
"""The problem statement is the first 'user' message in the history.
We'll prepend the trajectory with the problem statement.
"""
problem_statement = ""
for item in content["history"]:
if item["role"] == "user":
problem_statement = item["content"]
break
if problem_statement:
content["trajectory"].insert(
0,
{
"thought": "",
"action": "",
"response": "",
"observation": problem_statement,
"messages": [{"role": "system", "content": "Problem Statement placeholder"}],
},
)
return content
def append_exit(content):
exit_status = content.get("info", {}).get("exit_status", None)
if exit_status is None:
return content
if exit_status.startswith("submitted"):
if "submission" in content["info"]:
content["trajectory"].append(
{
"thought": "Submitting solution",
"action": "Model Submission",
"response": "Submitting solution",
"observation": content["info"]["submission"],
"messages": [{"role": "system", "content": f"Submission generated - {exit_status}"}],
}
)
else:
msg = "No submission in history or info"
raise ValueError(msg)
return content
def append_patch(instance_id, content, patches, patch_type):
if content.get("info", {}).get("exit_status", None) is not None:
if instance_id in patches:
content["trajectory"].append(
{
"thought": f"Showing {patch_type} patch",
"response": f"Showing {patch_type} patch",
"action": f"{patch_type} Patch",
"observation": patches[instance_id],
}
)
return content
def append_results(traj_path: Path, instance_id: str, content, results, results_file):
stats: list[str] = []
model_stats = {}
if traj_path.exists():
data = json.loads(traj_path.read_text())
info = data.get("info", {})
model_stats = info.get("model_stats", {})
# Build stats section
exit_status = info.get("exit_status", "N/A")
instance_cost = model_stats.get("instance_cost", None)
instance_cost = f"{instance_cost:.2f}" if instance_cost is not None else "N/A"
tokens_sent = model_stats.get("tokens_sent", None)
tokens_sent = f"{tokens_sent:,}" if tokens_sent is not None else "N/A"
tokens_received = model_stats.get("tokens_received", None)
tokens_received = f"{tokens_received:,}" if tokens_received is not None else "N/A"
api_calls = model_stats.get("api_calls", None)
api_calls = f"{api_calls:,}" if api_calls is not None else "N/A"
stats.append("**** Run Stats ****")
stats.append(f"Exit Status: {exit_status}")
stats.append(f"Instance Cost: ${instance_cost}")
stats.append(f"Tokens Sent: {tokens_sent}")
stats.append(f"Tokens Received: {tokens_received}")
stats.append(f"API Calls: {api_calls}\n")
# Build status section
status = []
if results is None:
status.append("Evaluation results not found")
elif "completed_ids" in results and "submitted_ids" in results and "resolved_ids" in results:
is_completed = instance_id in results["completed_ids"]
is_submitted = instance_id in results["submitted_ids"]
is_resolved = instance_id in results["resolved_ids"]
status.append("**** Statuses ****")
status.append(f" {'' if is_completed else ''} Completed (The agent successfully ran)")
status.append(f" {'' if is_submitted else ''} Submitted (The agent successfully submitted a pull request)")
status.append(
f" {'' if is_resolved else ''} Resolved (The pull request {'' if is_resolved else 'has not '}"
"successfully resolved the issue during eval)"
)
else:
status.append("Results format not recognized")
if status == []:
status.append("Instance not found in results")
else:
status.append("---------------------------")
status.append(
"Note that the evaluation results here may not be accurate or up to date, since they are computed separately from the agent run itself."
)
status.append(f"Check {results_file} for the most accurate evaluation results.")
status.append("")
status.append(f"Instance ID: {instance_id}")
# Add evaluation report as first and last items in trajectory
eval_report = {
"thought": "Evaluation Report",
"action": "Showing evaluation results",
"response": "Showing evaluation results",
"observation": "\n".join([*stats, *status]),
"messages": [{"role": "system", "content": "Showing evaluation results and statistics"}],
}
if not content.get("trajectory"):
content["trajectory"] = []
content["trajectory"].insert(0, eval_report)
content["trajectory"].append(eval_report)
return content
def get_action_summary(content):
out = ""
i = 0
for item in content["history"]:
if item["role"] != "assistant":
continue
if item.get("is_demo"):
continue
i += 1
try:
action = item["action"]
except KeyError:
print(f"No action for step {i}")
print(item)
raise
if len(action) > 70:
action = action[:67] + "..."
out += f"Step {i}: {action}\n"
return out
def load_content(file_name, gold_patches, test_patches) -> dict[str, Any]:
with open(file_name) as infile:
content = json.load(infile)
results_file = Path(file_name).parent / "results.json"
results = load_results(results_file)
content = add_problem_statement(content)
content = append_exit(content)
content = append_patch(Path(file_name).stem, content, gold_patches, "Gold")
content = append_patch(Path(file_name).stem, content, test_patches, "Test")
content["history"].insert(0, {"role": "Action Summary", "content": get_action_summary(content)})
return append_results(
Path(file_name),
Path(file_name).stem,
content,
results,
results_file,
)
def load_results(results_path: Path) -> dict[str, Any] | None:
"""Load results from results.json.
If file is not found, return None.
"""
if not results_path.exists():
return None
with open(results_path) as infile:
results = json.load(infile)
# Different versions of the code used "not_generated" or "no_generation".
# Let's standardize this here
if "no_generation" in results:
results["not_generated"] = results["no_generation"]
del results["no_generation"]
return results
def get_status(traj_path) -> str:
"""Return results emoji for single trajectory"""
results = load_results(Path(traj_path).parent / "results.json")
info = json.loads(Path(traj_path).read_text()).get("info", {})
n_steps = info.get("model_stats", {}).get("api_calls", "N/A")
exit_status = info.get("exit_status", "N/A")
exit_status_str = f" ({exit_status} after {n_steps} steps)"
instance_id = Path(traj_path).stem
if results is None:
return f"{exit_status_str}"
elif instance_id in results["resolved_ids"]:
return ""
else:
return f"{exit_status_str}"
class Handler(http.server.SimpleHTTPRequestHandler):
file_mod_times = {} # Dictionary to keep track of file modification times
def __init__(self, *args, **kwargs):
self.gold_patches = {}
self.test_patches = {}
if "gold_patches" in kwargs:
self.gold_patches = kwargs.pop("gold_patches")
if "test_patches" in kwargs:
self.test_patches = kwargs.pop("test_patches")
self.traj_dir = kwargs.pop("directory", ".") # Extract directory
super().__init__(*args, **kwargs)
def serve_directory_info(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(json.dumps({"directory": self.traj_dir}).encode())
def serve_file_content(self, file_path):
try:
content = load_content(
Path(self.traj_dir) / file_path,
self.gold_patches,
self.test_patches,
)
self.send_response(200)
self.send_header("Content-type", "text/plain")
self.end_headers()
self.wfile.write(json.dumps(content).encode())
except FileNotFoundError:
self.send_error(404, f"File {file_path} not found")
def do_GET(self):
if self.path == "/directory_info":
self.serve_directory_info()
elif self.path.startswith("/files"):
self.handle_files_request()
elif self.path.startswith("/trajectory/"):
file_path = self.path[len("/trajectory/") :]
self.serve_file_content(file_path)
elif self.path.startswith("/check_update"):
self.check_for_updates()
else:
super().do_GET()
def handle_files_request(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
files = sorted(
(
str(file.relative_to(Path(self.traj_dir))) + " " * 4 + get_status(file)
for file in Path(self.traj_dir).glob("**/*.traj")
),
key=lambda x: str(Path(self.traj_dir) / x),
reverse=True,
)
self.wfile.write(json.dumps(files).encode())
def check_for_updates(self):
current_mod_times = {str(file): file.stat().st_mtime for file in Path(self.traj_dir).glob("**/*.traj")}
if current_mod_times != Handler.file_mod_times:
Handler.file_mod_times = current_mod_times
self.send_response(200) # Send response that there's an update
else:
self.send_response(204) # Send no content response if no update
self.end_headers()
def end_headers(self):
self.send_header("Access-Control-Allow-Origin", "*")
super().end_headers()
def main(data_path, directory, port):
data = []
if data_path is not None:
if data_path.endswith(".jsonl"):
data = [json.loads(x) for x in Path(data_path).read_text().splitlines(keepends=True)]
elif data_path.endswith(".json"):
with open(data_path) as f:
data = json.load(f)
elif "args.yaml" in os.listdir(directory):
with open(Path(directory) / "args.yaml") as file:
args = yaml.safe_load(file)
if "environment" in args and "data_path" in args["environment"]:
data_path = Path(__file__).parent.parent / args["environment"]["data_path"]
if data_path.exists:
with open(data_path) as f:
data = json.load(f)
gold_patches = {d["instance_id"]: d["patch"] if "patch" in d else None for d in data}
test_patches = {d["instance_id"]: d["test_patch"] if "test_patch" in d else None for d in data}
handler_with_directory = partial(
Handler,
directory=directory,
gold_patches=gold_patches,
test_patches=test_patches,
)
try:
with socketserver.TCPServer(("", port), handler_with_directory) as httpd:
print(f"Serving at http://localhost:{port}")
httpd.serve_forever()
except OSError as e:
if e.errno == 48:
print(f"ERROR: Port ({port}) is already in use. Try another port with the --port flag.")
else:
raise e
def get_parser():
parser = ArgumentParser()
parser.add_argument(
"--data_path",
type=str,
help="Path to dataset that was used for the trajectories. Necessary to display gold patches.",
)
parser.add_argument("--directory", type=str, help="Directory to serve", default=os.getcwd(), nargs="?")
parser.add_argument("--port", type=int, help="Port to serve", default=8000)
return parser
def run_from_cli(args: list[str] | None = None):
# Hack to make sure all the templates and all are found
parsed_args = get_parser().parse_args(args)
# convert directory, relative to the absolute path
parsed_args.directory = str(Path(parsed_args.directory).resolve().absolute())
os.chdir(Path(__file__).parent)
main(**vars(parsed_args))
if __name__ == "__main__":
run_from_cli()

View File

@@ -0,0 +1,169 @@
from __future__ import annotations
import json
import logging
import traceback
from argparse import ArgumentParser
from pathlib import Path
import yaml
from tqdm.auto import tqdm
try:
from .server import load_content
except ImportError:
from server import load_content
logger = logging.getLogger(__name__)
logging.getLogger("simple_parsing").setLevel(logging.INFO)
TEMPLATE = """
<html>
<head>
<title>Trajectory Viewer</title>
<style>
{style_sheet}
</style>
</head>
<body>
<div class="container">
{file_path_tree}
<h2>Conversation History</h2>
<pre id="fileContent">{file_content}</pre>
</div>
</body>
</html>
"""
try:
with open(Path(__file__).parent / "style.css") as infile:
STYLE_SHEET = infile.read()
except Exception as e:
style_file = Path(__file__).parent / "style.css"
logger.error(f"Failed to load style sheet from {style_file}: {traceback.format_exc()}")
raise e
def _load_file(file_name, gold_patches, test_patches):
try:
role_map = {
"user": "Computer",
"assistant": "SWE-Agent",
"subroutine": "SWE-Agent subroutine",
"default": "Default",
"system": "System",
"demo": "Demonstration",
}
content = load_content(file_name, gold_patches, test_patches)
if "history" in content and isinstance(content["history"], list):
history_content = ""
for index, item in enumerate(content["history"]):
item_content = item.get("content", "").replace("<", "&lt;").replace(">", "&gt;")
if item.get("agent") and item["agent"] != "primary":
role_class = "subroutine"
else:
role_class = item.get("role", "default").lower().replace(" ", "-")
element_id = f"historyItem{index}"
role_name = role_map.get(item.get("role", ""), item.get("role", ""))
history_content += (
f"""<div class="history-item {role_class}" id="{element_id}">"""
f"""<div class="role-bar {role_class}"><strong><span>{role_name}</span></strong></div>"""
f"""<div class="content-container">"""
f"""<pre>{item_content}</pre>"""
f"""</div>"""
f"""<div class="shadow"></div>"""
f"""</div>"""
)
return history_content
else:
return "No history content found."
except Exception:
return f"Error loading content. {traceback.format_exc()}"
def _make_file_path_tree(file_path):
path_parts = file_path.split("/")
relevant_parts = path_parts[-3:]
html_string = '<div class="filepath">\n'
for part in relevant_parts:
html_string += f'<div class="part">{part}</div>\n'
html_string += "</div>"
return html_string
def save_static_viewer(file_path):
if not isinstance(file_path, Path):
file_path = Path(file_path)
data = []
if "args.yaml" in list(map(lambda x: x.name, file_path.parent.iterdir())):
args = yaml.safe_load(Path(file_path.parent / "args.yaml").read_text())
if "environment" in args and "data_path" in args["environment"]:
data_path = Path(__file__).parent.parent / args["environment"]["data_path"]
if data_path.exists():
with open(data_path) as f:
data = json.load(f)
if not isinstance(data, list) or not data or "patch" not in data[0] or "test_patch" not in data[0]:
data = []
gold_patches = {x["instance_id"]: x["patch"] for x in data}
test_patches = {x["instance_id"]: x["test_patch"] for x in data}
content = _load_file(file_path, gold_patches, test_patches)
file_path_tree = _make_file_path_tree(file_path.absolute().as_posix())
icons_path = Path(__file__).parent / "icons"
relative_icons_path = find_relative_path(file_path, icons_path)
style_sheet = STYLE_SHEET.replace("url('icons/", f"url('{relative_icons_path.as_posix()}/").replace(
'url("icons/',
f'url("{relative_icons_path.as_posix()}/',
)
data = TEMPLATE.format(file_content=content, style_sheet=style_sheet, file_path_tree=file_path_tree)
output_file = file_path.with_suffix(".html")
with open(output_file, "w+") as outfile:
print(data, file=outfile)
logger.info(f"Saved static viewer to {output_file}")
def find_relative_path(from_path, to_path):
# Convert paths to absolute for uniformity
from_path = from_path.resolve()
to_path = to_path.resolve()
if from_path.is_file():
from_path = from_path.parent
if to_path.is_file():
to_path = to_path.parent
if not from_path.is_dir() or not to_path.is_dir():
msg = f"Both from_path and to_path must be directories, but got {from_path} and {to_path}"
raise ValueError(msg)
# Identify the common ancestor and the parts of each path beyond it
common_parts = 0
for from_part, to_part in zip(from_path.parts, to_path.parts):
if from_part != to_part:
break
common_parts += 1
# Calculate the '../' needed to get back from from_path to the common ancestor
back_to_ancestor = [".."] * (len(from_path.parts) - common_parts)
# Direct path from common ancestor to to_path
to_target = to_path.parts[common_parts:]
# Combine to get the relative path
return Path(*back_to_ancestor, *to_target)
def save_all_trajectories(directory):
if not isinstance(directory, Path):
directory = Path(directory)
all_files = list(directory.glob("**/*.traj"))
logger.info(f"Found {len(all_files)} trajectory files in {directory}")
for file_path in tqdm(all_files, desc="Saving static viewers"):
save_static_viewer(file_path)
logger.info(f"Saved static viewers for all trajectories in {args.directory}")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("directory", type=str, help="Directory containing trajectory files")
args = parser.parse_args()
save_all_trajectories(args.directory)

View File

@@ -0,0 +1,454 @@
body {
font-family: Arial, sans-serif;
margin: 0;
padding: 20px;
background-color: #1e1e1e;
color: #d4d4d4;
}
h1,
h2 {
color: #d4d4d4;
}
#fileList {
list-style-type: none;
padding: 10px;
max-height: 400px;
overflow-y: auto;
margin: 0;
background-color: #252526;
border: 3px solid #1697e2;
border-radius: 10px;
}
#fileList li {
cursor: pointer;
padding: 10px;
background-color: #2d2d2d;
margin-bottom: 5px;
border: 2px solid #454545;
border-radius: 5px;
transition: background-color 0.3s;
color: #d4d4d4;
}
#fileList li:hover {
background-color: #3e3e3e;
}
#fileList li.selected {
border-color: #fabb00;
}
#fileContent {
background-color: #1e1e1e;
border: 1px solid #454545;
padding: 10px;
margin-top: 20px;
white-space: pre-wrap;
color: #d4d4d4;
}
.button-container {
display: flex;
justify-content: center;
align-items: center;
}
#refreshButton {
padding: 4px 10px;
min-width: 80px;
border: none;
font: inherit;
color: #373030;
border-radius: 10px;
outline: none;
text-decoration: none;
cursor: default;
font-weight: 400;
background: #fff;
box-shadow:
0px 0px 1px #0000004d,
0px 1px 1px #00000066;
}
#refreshButton:hover {
/* hover MUST come before active */
background: linear-gradient(#0000004d, #00000066);
color: #fff;
position: relative;
}
#refreshButton:active {
background: linear-gradient(#4faefc, #006bff);
color: #fff;
position: relative;
}
.history-item {
border: 3px solid black;
border-radius: 5px;
padding: 0px;
/* padding-bottom: 5%; */
margin-bottom: 15px;
overflow-x: hidden;
white-space: normal;
/* overflow-x: auto; Enables horizontal scrolling */
/* white-space: nowrap; Keeps content in a single line */
max-height: 450px; /* Adjust as needed for 25 lines */
overflow: hidden;
position: relative;
}
.shadow {
height: 30px; /* Height of the shadow */
background: linear-gradient(to bottom, transparent, rgba(0, 0, 0, 0.4));
position: absolute;
bottom: 0;
left: 0;
right: 0;
pointer-events: none; /* Ensures the shadow doesn't interfere with interaction */
display: none; /* Initially hidden */
}
.has-shadow .shadow {
display: block;
}
.content-container {
max-height: 400px; /* Adjust as needed */
overflow-y: auto;
position: relative;
padding: 10px;
}
.content-container pre {
white-space: pre-wrap; /* Wrap lines and preserve whitespace */
overflow-wrap: break-word; /* Handle long words */
}
.container {
max-width: 1000px;
margin: 0 auto; /* Centers the container */
padding: 20px; /* Optional: for some inner spacing */
}
.history-item.user {
border-color: #1697e2;
}
.history-item.tool {
border-color: #1483c3;
}
.history-item.system {
border-color: #004b80;
}
.history-item.subroutine {
border-color: #006b00;
}
.history-item.gold-patch {
border-color: #fabb00;
}
.history-item.assistant {
border-color: rgb(0, 0, 0);
}
.history-item.test-patch {
border-color: #7373d9;
}
.history-item.evaluation-report {
border-color: #35614b;
}
/* filepath-tree stuff */
.filepath {
display: flex;
flex-direction: column; /* Changes layout to one part per line */
align-items: flex-start; /* Aligns parts to the start of the container */
font-size: 16px;
gap: 10px;
padding: 5px;
background-color: #2d2d2d;
}
.part {
border: 1px solid #454545;
white-space: nowrap; /* Prevents wrapping within parts */
padding: 5px;
background-color: #3e3e3e;
border-radius: 5px;
color: #d4d4d4;
}
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 1s ease-out;
}
.trajectory-container {
position: relative;
width: 100%;
}
.trajectory-item {
display: flex;
flex-direction: column;
margin-bottom: 0px;
border: 3px solid #1697e2;
border-radius: 5px;
position: relative;
background: #2d2d2d;
min-height: fit-content; /* Added to ensure proper expansion */
height: auto; /* Added to ensure proper expansion */
overflow: visible; /* Changed from hidden/auto to visible */
}
.trajectory-main {
flex: 1;
min-width: 0;
display: flex;
flex-direction: column;
gap: 0;
}
.response-section,
.observation-section,
.messages-section {
padding: 0 4px;
position: relative;
min-height: 30px;
height: 400px;
max-height: min-content;
display: flex;
flex-direction: column;
resize: vertical;
overflow: hidden;
width: calc(100% - 16px);
}
.response-section {
background-color: #252526;
border-bottom: 1px solid #454545;
color: #d4d4d4;
}
.observation-section {
background-color: #2d2d2d;
color: #d4d4d4;
}
.observation-images-section {
padding: 4px;
position: relative;
min-height: 30px;
height: 200px;
max-height: 200px;
background: #2d2d30;
border-top: 1px solid #454545;
display: flex;
flex-direction: column;
resize: vertical;
overflow: hidden;
width: calc(100% - 16px);
color: #d4d4d4;
}
/* Add section headers */
.response-section::before,
.observation-section::before,
.observation-images-section::before,
.messages-section::before {
content: attr(data-title);
font-weight: bold;
padding: 8px 12px; /* Keep this comfortable padding */
background: rgba(255, 255, 255, 0.05);
margin: 0 0 8px 0; /* Add some bottom margin to separate from content */
border-bottom: 1px solid #454545;
position: sticky;
top: 0;
width: 100%;
box-sizing: border-box;
color: #d4d4d4;
}
/* Scrollable content containers */
.content-wrapper {
overflow-y: auto;
flex: 1;
padding: 0;
margin: 0;
display: flex; /* Add flex display */
flex-direction: column; /* Stack children vertically */
}
.content-wrapper pre {
margin: 0;
padding: 0;
flex: 1; /* Allow pre to fill the space */
display: flex; /* Make pre a flex container */
flex-direction: column; /* Stack children vertically */
}
.content-wrapper pre code {
margin: 0;
padding: 0;
flex: 1; /* Allow code to fill the space */
}
.messages-section {
padding: 4px;
position: relative;
min-height: 30px;
height: 30px;
max-height: min-content; /* Only expand to fit content */
background: #252526;
border-top: 1px solid #454545;
display: flex;
flex-direction: column;
resize: vertical;
overflow: hidden;
width: calc(100% - 16px);
color: #d4d4d4;
}
.messages-section::before {
content: attr(data-title);
font-weight: bold;
padding: 8px 12px; /* Increased padding around title text */
background: rgba(255, 255, 255, 0.05);
margin: -8px -8px 5px -8px; /* Reduced bottom margin */
border-bottom: 1px solid #454545;
position: sticky;
top: -8px;
width: calc(100% + 16px);
box-sizing: border-box;
color: #d4d4d4;
}
.messages-toggle {
display: none;
}
.execution-time {
font-size: 0.8em;
color: #666;
text-align: right;
padding: 5px;
border-top: 1px solid #ddd;
}
/* Add a visual indicator for resizable areas */
.response-section:hover,
.observation-section:hover,
.observation-images-section:hover,
.messages-section.expanded:hover {
outline: 1px dashed #999;
}
.demo-message {
color: #2ecc71;
font-weight: bold;
}
/* Syntax highlighting overrides for dark theme */
.hljs {
background: #1e1e1e !important;
color: #d4d4d4 !important;
}
.hljs-keyword,
.hljs-selector-tag,
.hljs-built_in,
.hljs-name,
.hljs-tag {
color: #569cd6 !important; /* Blue */
}
.hljs-string,
.hljs-title,
.hljs-section,
.hljs-attribute,
.hljs-literal,
.hljs-template-tag,
.hljs-template-variable,
.hljs-type {
color: #ce9178 !important; /* Orange/Pink */
}
.hljs-comment,
.hljs-quote {
color: #6a9955 !important; /* Green */
}
.hljs-number,
.hljs-regexp,
.hljs-symbol,
.hljs-variable,
.hljs-template-variable,
.hljs-link,
.hljs-selector-attr,
.hljs-selector-pseudo {
color: #b5cea8 !important; /* Light green */
}
label {
color: #d4d4d4; /* Matches the theme's text color */
display: flex;
align-items: center;
gap: 5px;
}
input[type="checkbox"] {
margin: 0;
cursor: pointer;
}
/* observation images styles */
.observation-images {
margin-top: 10px;
padding: 10px 0;
}
.observation-image-container {
margin-bottom: 15px;
text-align: center;
}
.observation-image {
max-width: 100%;
max-height: 400px;
border: 1px solid #454545;
border-radius: 5px;
background-color: #1e1e1e;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3);
cursor: pointer;
transition: transform 0.2s ease;
}
.observation-image:hover {
transform: scale(1.02);
border-color: #1697e2;
}
.image-caption {
margin-top: 5px;
font-size: 0.9em;
color: #b3b3b3;
font-style: italic;
}
.observation-image.expanded {
position: fixed;
top: 50%;
left: 50%;
transform: translate(-50%, -50%) scale(1);
max-width: 90vw;
max-height: 90vh;
z-index: 1000;
border: 2px solid #1697e2;
box-shadow: 0 0 20px rgba(0, 0, 0, 0.8);
}
.image-overlay {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background-color: rgba(0, 0, 0, 0.8);
z-index: 999;
display: none;
}
.image-overlay.active {
display: block;
}

View File

View File

@@ -0,0 +1,158 @@
"""This module contains an auxiliary class for rendering progress of the batch run."""
import collections
from pathlib import Path
from threading import Lock
import yaml
from rich.console import Group
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskID,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from sweagent.agent.models import GLOBAL_STATS
def _shorten_str(s: str, max_len: int, shorten_left=False) -> str:
if not shorten_left:
s = s[: max_len - 3] + "..." if len(s) > max_len else s
else:
s = "..." + s[-max_len + 3 :] if len(s) > max_len else s
return f"{s:<{max_len}}"
class RunBatchProgressManager:
def __init__(
self,
num_instances: int,
yaml_report_path: Path | None = None,
):
"""This class manages a progress bar/UI for run-batch
Args:
num_instances: Number of task instances
yaml_report_path: Path to save a yaml report of the instances and their exit statuses
"""
self._spinner_tasks: dict[str, TaskID] = {}
"""We need to map instance ID to the task ID that is used by the rich progress bar."""
self._lock = Lock()
self._instances_by_exit_status = collections.defaultdict(list)
self._main_progress_bar = Progress(
SpinnerColumn(spinner_name="dots2"),
TextColumn("[progress.description]{task.description} (${task.fields[total_cost]})"),
BarColumn(),
MofNCompleteColumn(),
TaskProgressColumn(),
TimeElapsedColumn(),
TextColumn("[cyan]eta:[/cyan]"),
TimeRemainingColumn(),
# Wait 5 min before estimating speed
speed_estimate_period=60 * 5,
)
self._task_progress_bar = Progress(
SpinnerColumn(spinner_name="dots2"),
TextColumn("{task.fields[instance_id]}"),
TextColumn("{task.fields[status]}"),
TimeElapsedColumn(),
)
"""Task progress bar for individual instances. There's only one progress bar
with one task for each instance.
"""
self._main_task_id = self._main_progress_bar.add_task(
"[cyan]Overall Progress", total=num_instances, total_cost=0
)
self.render_group = Group(Table(), self._task_progress_bar, self._main_progress_bar)
self._yaml_report_path = yaml_report_path
@property
def n_completed(self) -> int:
return sum(len(instances) for instances in self._instances_by_exit_status.values())
def update_exit_status_table(self):
# We cannot update the existing table, so we need to create a new one and
# assign it back to the render group.
t = Table()
t.add_column("Exit Status")
t.add_column("Count", justify="right", style="bold cyan")
t.add_column("Most recent instances")
t.show_header = False
with self._lock:
t.show_header = True
# Sort by number of instances in descending order
sorted_items = sorted(self._instances_by_exit_status.items(), key=lambda x: len(x[1]), reverse=True)
for status, instances in sorted_items:
instances_str = _shorten_str(", ".join(reversed(instances)), 55)
t.add_row(status, str(len(instances)), instances_str)
assert self.render_group is not None
self.render_group.renderables[0] = t
def _update_total_costs(self) -> None:
with self._lock:
self._main_progress_bar.update(self._main_task_id, total_cost=f"{GLOBAL_STATS.total_cost:.2f}")
def update_instance_status(self, instance_id: str, message: str):
assert self._task_progress_bar is not None
assert self._main_progress_bar is not None
with self._lock:
self._task_progress_bar.update(
self._spinner_tasks[instance_id],
status=_shorten_str(message, 30),
instance_id=_shorten_str(instance_id, 25, shorten_left=True),
)
self._update_total_costs()
def on_instance_start(self, instance_id: str):
with self._lock:
self._spinner_tasks[instance_id] = self._task_progress_bar.add_task(
description=f"Task {instance_id}",
status="Task initialized",
total=None,
instance_id=instance_id,
)
def on_instance_end(self, instance_id: str, exit_status: str | None) -> None:
self._instances_by_exit_status[exit_status].append(instance_id)
with self._lock:
self._task_progress_bar.remove_task(self._spinner_tasks[instance_id])
self._main_progress_bar.update(TaskID(0), advance=1)
self.update_exit_status_table()
self._update_total_costs()
if self._yaml_report_path is not None:
self._save_overview_data_yaml(self._yaml_report_path)
def on_uncaught_exception(self, instance_id: str, exception: Exception) -> None:
self.on_instance_end(instance_id, f"Uncaught {type(exception).__name__}")
def print_report(self) -> None:
"""Print complete list of instances and their exit statuses."""
for status, instances in self._instances_by_exit_status.items():
print(f"{status}: {len(instances)}")
for instance in instances:
print(f" {instance}")
def _get_overview_data(self) -> dict:
"""Get data like exit statuses, total costs, etc."""
return {
# convert defaultdict to dict because of serialization
"instances_by_exit_status": dict(self._instances_by_exit_status),
"total_cost": GLOBAL_STATS.total_cost,
}
def _save_overview_data_yaml(self, path: Path) -> None:
"""Save a yaml report of the instances and their exit statuses."""
with self._lock:
path.write_text(yaml.dump(self._get_overview_data(), indent=4))

View File

@@ -0,0 +1,449 @@
import json
import os
import random
import re
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, model_validator
from swerex.deployment.config import (
DeploymentConfig,
DockerDeploymentConfig,
DummyDeploymentConfig,
LocalDeploymentConfig,
)
from typing_extensions import Self
from sweagent.agent.problem_statement import (
ProblemStatementConfig,
SWEBenchMultimodalProblemStatement,
TextProblemStatement,
)
from sweagent.environment.repo import GithubRepoConfig, LocalRepoConfig, PreExistingRepoConfig, SWESmithRepoConfig
from sweagent.environment.swe_env import EnvironmentConfig
from sweagent.utils.files import load_file
from sweagent.utils.github import _is_repo_private
from sweagent.utils.log import get_logger
logger = get_logger("swea-config", emoji="🔧")
class AbstractInstanceSource(ABC):
"""Anything that adheres to this standard can be used to load instances."""
@abstractmethod
def get_instance_configs(self) -> list[EnvironmentConfig]: ...
class BatchInstance(BaseModel):
"""A single instance in a batch of instances.
This specifies both the environment configuration and the problem statement.
"""
env: EnvironmentConfig
problem_statement: ProblemStatementConfig
def _slice_spec_to_slice(slice_spec: str) -> slice:
if slice_spec == "":
return slice(None)
parts = slice_spec.split(":")
values = [None if p == "" else int(p) for p in parts]
if len(parts) == 1:
return slice(values[0])
if len(parts) == 2:
return slice(values[0], values[1])
if len(parts) == 3:
return slice(values[0], values[1], values[2])
msg = (
f"Invalid slice specification: {slice_spec!r}. "
"Here's the expected format: stop or start:stop or start:stop:step "
"(i.e., it behaves exactly like python's list slicing `list[slice]`)."
)
raise ValueError(msg)
def _filter_batch_items(
instances: list[BatchInstance], *, filter_: str, slice_: str = "", shuffle: bool = False
) -> list[BatchInstance]:
if shuffle:
instances = sorted(instances.copy(), key=lambda x: x.problem_statement.id)
random.seed(42)
random.shuffle(instances)
before_filter = len(instances)
instances = [instance for instance in instances if re.match(filter_, instance.problem_statement.id)]
after_filter = len(instances)
if before_filter != after_filter:
logger.info("Instance filter: %d -> %d instances", before_filter, after_filter)
if slice_:
instances = instances[_slice_spec_to_slice(slice_)]
after_slice = len(instances)
if before_filter != after_slice:
logger.info("Instance slice: %d -> %d instances", before_filter, after_slice)
return instances
class SimpleBatchInstance(BaseModel):
"""A simple way to configure a single instance in a batch of instances that all
use similar deployment configurations.
Predominantly used for benchmarking purposes. Assumes that the repository is already
present in the docker container.
"""
image_name: str
problem_statement: str
instance_id: str
repo_name: str = ""
"""Specifies the repository to use. If empty, no repository is used.
If the string does not contain a slash, it is interpreted as an already existing repository at the root
of the docker container. If it contains the word "github", it is interpreted as a github repository.
Else, it is interpreted as a local repository.
"""
base_commit: str = "HEAD"
"""Used to reset repo."""
extra_fields: dict[str, Any] = Field(default_factory=dict)
"""Any additional data to be added to the instance.
This data will be available when formatting prompt templates.
"""
# Ignore instead of allow because they should be added as `extra_fields`
model_config = ConfigDict(extra="ignore")
def to_full_batch_instance(self, deployment: DeploymentConfig) -> BatchInstance:
"""Merge the deployment options into the `SimpleBatchInstance` object to get a full `BatchInstance`."""
# Very important: Make a copy of the deployment config because it will be shared among instances!!!
deployment = deployment.model_copy(deep=True)
if "issue_images" in self.extra_fields:
problem_statement = SWEBenchMultimodalProblemStatement(
text=self.problem_statement,
issue_images=self.extra_fields.pop("issue_images"),
id=self.instance_id,
extra_fields=self.extra_fields,
)
else:
problem_statement = TextProblemStatement(
text=self.problem_statement, id=self.instance_id, extra_fields=self.extra_fields
)
if not self.repo_name:
repo = None
elif "github" in self.repo_name:
repo = GithubRepoConfig(github_url=self.repo_name, base_commit=self.base_commit)
elif "/" not in self.repo_name:
repo = PreExistingRepoConfig(repo_name=self.repo_name, base_commit=self.base_commit)
else:
repo = LocalRepoConfig(path=Path(self.repo_name), base_commit=self.base_commit)
if isinstance(deployment, LocalDeploymentConfig):
if self.image_name:
msg = "Local deployment does not support image_name"
raise ValueError(msg)
return BatchInstance(
env=EnvironmentConfig(deployment=deployment, repo=repo), problem_statement=problem_statement
)
if isinstance(deployment, DummyDeploymentConfig):
return BatchInstance(
env=EnvironmentConfig(deployment=deployment, repo=repo), problem_statement=problem_statement
)
deployment.image = self.image_name # type: ignore
if isinstance(deployment, DockerDeploymentConfig) and deployment.python_standalone_dir is None:
# Note: you can disable this by setting python_standalone_dir to ""
deployment.python_standalone_dir = "/root" # type: ignore
return BatchInstance(
env=EnvironmentConfig(deployment=deployment, repo=repo), problem_statement=problem_statement
)
@model_validator(mode="before")
@classmethod
def handle_legacy_id(cls, data):
# Handling compatibility with swe-agent <= 1.0.1
if isinstance(data, dict):
if "id" in data and "instance_id" not in data:
data["instance_id"] = data["id"]
data.pop("id")
return data
# todo: Maybe populate extra fields?
@classmethod
def from_swe_bench(cls, instance: dict[str, Any]) -> Self:
"""Convert instances from the classical SWE-bench dataset to the `SimpleBatchInstance` format."""
iid = instance["instance_id"]
image_name = instance.get("image_name", None)
if image_name is None:
# Docker doesn't allow double underscore, so we replace them with a magic token
id_docker_compatible = iid.replace("__", "_1776_")
image_name = f"docker.io/swebench/sweb.eval.x86_64.{id_docker_compatible}:latest".lower()
extra_fields = {}
if "image_assets" in instance:
issue_images = json.loads(instance["image_assets"])["problem_statement"]
extra_fields["issue_images"] = issue_images
return cls(
image_name=image_name,
problem_statement=instance["problem_statement"],
instance_id=iid,
repo_name="testbed",
base_commit=instance["base_commit"],
extra_fields=extra_fields,
)
class InstancesFromFile(BaseModel, AbstractInstanceSource):
"""Load instances from a file."""
path: Path
filter: str = ".*"
"""Regular expression to filter the instances by instance id."""
slice: str = ""
"""Select only a slice of the instances (after filtering by `filter`).
Possible values are stop or start:stop or start:stop:step
(i.e., it behaves exactly like python's list slicing `list[slice]`).
"""
shuffle: bool = False
"""Shuffle the instances (before filtering and slicing)."""
deployment: DeploymentConfig = Field(
default_factory=lambda: DockerDeploymentConfig(image="python:3.11"),
description="Deployment options.",
)
"""Note that the image_name option is overwritten by the images specified in the task instances."""
simple: Literal[True] = True
"""Convenience discriminator for (de)serialization/CLI. Do not change."""
type: Literal["file"] = "file"
"""Discriminator for (de)serialization/CLI. Do not change."""
def get_instance_configs(self) -> list[BatchInstance]:
instance_dicts = load_file(self.path)
simple_instances = [SimpleBatchInstance.model_validate(instance_dict) for instance_dict in instance_dicts]
instances = [instance.to_full_batch_instance(self.deployment) for instance in simple_instances]
return _filter_batch_items(instances, filter_=self.filter, slice_=self.slice, shuffle=self.shuffle)
@property
def id(self) -> str:
return self.path.stem
class InstancesFromHuggingFace(BaseModel, AbstractInstanceSource):
"""Load instances from HuggingFace."""
dataset_name: str
"""Name of the HuggingFace dataset. Same as when using `datasets.load_dataset`."""
split: str = "dev"
filter: str = ".*"
"""Regular expression to filter the instances by instance id."""
slice: str = ""
"""Select only a slice of the instances (after filtering by `filter`).
Possible values are stop or start:stop or start:stop:step.
(i.e., it behaves exactly like python's list slicing `list[slice]`).
"""
shuffle: bool = False
"""Shuffle the instances (before filtering and slicing)."""
deployment: DeploymentConfig = Field(
default_factory=lambda: DockerDeploymentConfig(image="python:3.11"),
)
"""Deployment configuration. Note that the `image_name` option is overwritten by the images specified in the task instances.
"""
type: Literal["huggingface"] = "huggingface"
"""Discriminator for (de)serialization/CLI. Do not change."""
def get_instance_configs(self) -> list[BatchInstance]:
from datasets import load_dataset
ds: list[dict[str, Any]] = load_dataset(self.dataset_name, split=self.split) # type: ignore
simple_instances: list[SimpleBatchInstance] = [SimpleBatchInstance.model_validate(instance) for instance in ds]
instances = [instance.to_full_batch_instance(self.deployment) for instance in simple_instances]
return _filter_batch_items(instances, filter_=self.filter, slice_=self.slice, shuffle=self.shuffle)
@property
def id(self) -> str:
ds_name = "".join(l for l in self.dataset_name if l.isalnum() or l in ["-", "_"])
return f"{ds_name}_{self.split}"
class SWEBenchInstances(BaseModel, AbstractInstanceSource):
"""Load instances from SWE-bench."""
subset: Literal["lite", "verified", "full", "multimodal", "multilingual"] = "lite"
"""Subset of swe-bench to use"""
# IMPORTANT: Do not call this `path`, because then if people do not specify instance.type,
# it might be resolved to ExpertInstancesFromFile or something like that.
path_override: str | Path | None = None
"""Allow to specify a different huggingface dataset name or path to a huggingface
dataset. This will override the automatic path set by `subset`.
"""
split: Literal["dev", "test"] = "dev"
deployment: DeploymentConfig = Field(
default_factory=lambda: DockerDeploymentConfig(image="python:3.11"),
)
"""Deployment configuration. Note that the image_name option is overwritten by the images specified in the task instances.
"""
type: Literal["swe_bench"] = "swe_bench"
"""Discriminator for (de)serialization/CLI. Do not change."""
filter: str = ".*"
"""Regular expression to filter the instances by instance id."""
slice: str = ""
"""Select only a slice of the instances (after filtering by `filter`).
Possible values are stop or start:stop or start:stop:step.
(i.e., it behaves exactly like python's list slicing `list[slice]`).
"""
shuffle: bool = False
"""Shuffle the instances (before filtering and slicing)."""
evaluate: bool = False
"""Run sb-cli to evaluate"""
def _get_dataset_path(self) -> str:
if self.path_override is not None:
return str(self.path_override)
dataset_mapping = {
"full": "princeton-nlp/SWE-Bench",
"verified": "princeton-nlp/SWE-Bench_Verified",
"lite": "princeton-nlp/SWE-Bench_Lite",
"multimodal": "princeton-nlp/SWE-Bench_Multimodal",
"multilingual": "swe-bench/SWE-Bench_Multilingual",
}
if self.subset not in dataset_mapping:
msg = f"Unsupported subset: {self.subset}"
raise ValueError(msg)
return dataset_mapping[self.subset]
def get_instance_configs(self) -> list[BatchInstance]:
from datasets import load_dataset
ds: list[dict[str, Any]] = load_dataset(self._get_dataset_path(), split=self.split) # type: ignore
if isinstance(self.deployment, DockerDeploymentConfig):
self.deployment.platform = "linux/amd64"
instances = [
SimpleBatchInstance.from_swe_bench(instance).to_full_batch_instance(self.deployment) for instance in ds
]
return _filter_batch_items(instances, filter_=self.filter, slice_=self.slice, shuffle=self.shuffle)
@property
def id(self) -> str:
return f"swe_bench_{self.subset}_{self.split}"
class ExpertInstancesFromFile(BaseModel, AbstractInstanceSource):
"""Load instances from a file. The difference to `InstancesFromFile` is that the instances are configured as full
`EnvironmentInstanceConfig` objects, i.e., we could specify separate deployment configurations etc.
"""
path: Path
filter: str = ".*"
"""Regular expression to filter the instances by instance id."""
slice: str = ""
"""Select only a slice of the instances (after filtering by `filter`).
Possible values are stop or start:stop or start:stop:step.
(i.e., it behaves exactly like python's list slicing `list[slice]`).
"""
shuffle: bool = False
"""Shuffle the instances (before filtering and slicing)."""
type: Literal["expert_file"] = "expert_file"
"""Discriminator for (de)serialization/CLI. Do not change."""
def get_instance_configs(self) -> list[BatchInstance]:
instance_dicts = load_file(self.path)
instances = [BatchInstance.model_validate(instance_dict) for instance_dict in instance_dicts]
return _filter_batch_items(instances, filter_=self.filter, slice_=self.slice, shuffle=self.shuffle)
@property
def id(self) -> str:
return self.path.stem
class SWESmithInstances(BaseModel, AbstractInstanceSource):
"""Load instances from SWE-smith."""
path: Path
deployment: DeploymentConfig = Field(
default_factory=lambda: DockerDeploymentConfig(image="python:3.11"),
)
"""Deployment configuration. Note that the image_name option is overwritten by the images specified in the task instances.
"""
filter: str = ".*"
"""Regular expression to filter the instances by instance id."""
slice: str = ""
"""Select only a slice of the instances (after filtering by `filter`).
Possible values are stop or start:stop or start:stop:step.
(i.e., it behaves exactly like python's list slicing `list[slice]`).
"""
shuffle: bool = False
"""Shuffle the instances (before filtering and slicing)."""
type: Literal["swesmith"] = "swesmith"
"""Discriminator for (de)serialization/CLI. Do not change."""
def get_instance_configs(self) -> list[BatchInstance]:
github_token = os.getenv("GITHUB_TOKEN", "")
instance_dicts = load_file(self.path)
instances = []
for instance_dict in instance_dicts:
deployment = self.deployment.model_copy(deep=True)
deployment.image = instance_dict["image_name"] # type: ignore
if isinstance(deployment, DockerDeploymentConfig) and deployment.python_standalone_dir is None:
deployment.python_standalone_dir = "/root" # type: ignore
instance_id = instance_dict["instance_id"]
repo_field = instance_dict.get("repo", "")
mirror_url = ""
if repo_field and _is_repo_private(repo_field, github_token):
if not github_token:
msg = (
f"Repo '{repo_field}' appears to be private but GITHUB_TOKEN is not set. "
"Set GITHUB_TOKEN with 'repo' scope to access private repositories."
)
raise ValueError(msg)
mirror_url = f"https://github.com/{repo_field}.git"
repo = SWESmithRepoConfig(
repo_name="testbed",
base_commit=instance_id,
mirror_url=mirror_url,
)
problem_statement = TextProblemStatement(
text=instance_dict.get("problem_statement", ""),
id=instance_id,
extra_fields={"fail_to_pass": instance_dict.get("FAIL_TO_PASS", [])},
)
instances.append(
BatchInstance(
env=EnvironmentConfig(deployment=deployment, repo=repo),
problem_statement=problem_statement,
)
)
return _filter_batch_items(instances, filter_=self.filter, slice_=self.slice, shuffle=self.shuffle)
@property
def id(self) -> str:
return f"swesmith_{self.path.stem}"
BatchInstanceSourceConfig = (
InstancesFromHuggingFace | InstancesFromFile | SWEBenchInstances | ExpertInstancesFromFile | SWESmithInstances
)

View File

@@ -0,0 +1,387 @@
"""Common functionality for the run scripts."""
import json
import sys
from argparse import ArgumentParser
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from types import UnionType
from typing import Any
import yaml
from pydantic import ValidationError
from pydantic_settings import BaseSettings, CliApp, SettingsError
from rich import print as rich_print
from rich.panel import Panel
from sweagent import CONFIG_DIR
from sweagent.types import AgentInfo, AgentRunResult
from sweagent.utils.log import get_logger
from sweagent.utils.serialization import merge_nested_dicts
def _shorten_strings(data, *, max_length=30):
"""
Recursively shortens all strings in a nested data structure to a maximum length.
Args:
data: The nested data structure (dicts, lists, and strings).
max_length: The maximum length for strings.
Returns:
The modified data structure with shortened strings.
"""
if isinstance(data, str):
# Shorten the string if it exceeds the max length
data = data.replace("\n", "\\n")
return data[: max_length - 3] + "..."
elif isinstance(data, list):
# Recursively process each item in the list
return [_shorten_strings(item, max_length=max_length) for item in data]
elif isinstance(data, dict):
# Recursively process each value in the dictionary
return {key: _shorten_strings(value, max_length=max_length) for key, value in data.items()}
else:
# Return the data as is if it's neither a string, list, nor dict
return data
_VALIDATION_ERROR_HELP_TEXT = """
The following errors are raised by Pydantic, trying to instantiate the configuration based on
the merged configuration dictionary [bold](see above)[/bold].
Every new indented block corresponds to a different error from Pydantic.
The first line of each block is the attribute that failed validation, the following lines are the error messages.
If you see many lines of errors, there are probably different ways to instantiate the same object (a union type).
For example, there are different deployments with different options each. Pydantic is then trying
one after the other and reporting the failures for each of them.
More on union types: [link=https://swe-agent.com/latest/usage/cl_tutorial/#union-types]https://swe-agent.com/latest/usage/cl_tutorial/#union-types[/link]
"""
_SETTING_ERROR_HINTS = """
[red][bold]Hints:[/bold][/red]
Run `sweagent <subcommand> --help` for usage examples.
[red][bold]Common mistakes:[/bold][/red]
- You used dashes instead of underscores (wrong: `--num-workers`, correct: `--num_workers`).
- You forgot about part of the hierarchy (wrong: `--model.name`, correct: `--agent.model.name`).
"""
class AutoCorrectSuggestion:
def __init__(
self, original: str, alternative: str = "", *, condition: Callable | None = None, help: str | None = None
):
self.original = original
self.alternative = alternative
self.condition = condition
self.help = help
if self.help and self.alternative:
msg = "Cannot set both help and alternative"
raise ValueError(msg)
def show(self, args: list[str]) -> bool:
no_equal = []
for arg in args:
if "=" in arg:
no_equal.extend(arg.split("="))
else:
no_equal.append(arg)
if self.condition is not None:
return self.condition(no_equal)
return f"--{self.original}" in no_equal
def format(self) -> str:
if self.help:
return self.help
return f"You wrote [red]--{self.original}[/red]. Did you mean [green]--{self.alternative}[/green]?"
class ConfigHelper:
"""Produce easy-to-read help text from pydantic setting objects."""
def _get_type_name(self, item: Any, full: bool = False):
"""Given a config type, return a string that is either the full name or just the class name."""
full_name = str(item).removeprefix("<class '").removesuffix("'>")
if full:
return full_name
return full_name.split(".")[-1]
def _get_value_help_string(self, item: Any, description: str | None):
"""Given an item, document it"""
if hasattr(item, "model_fields"):
# It's a pydantic config class
full_name = self._get_type_name(item, full=True)
name = self._get_type_name(item)
out = f"[green]{name}[/green]\n"
if description:
out += f" {description}\n"
out += f" Run [green]--help_option {full_name}[/green] for more info"
return out
if isinstance(item, UnionType):
name = self._get_type_name(item)
out = ""
if description:
out += f" {description}\n"
out += " This config item can be one of the following things (run [green]--help_option <name>[/green] for more info):\n"
things = str(item).split("|")
for thing in things:
out += f" [green]{thing.strip()}[/green]\n"
return out.strip()
return self._get_type_name(item)
def get_help(self, config_type: type[BaseSettings]) -> str:
lines = []
for name, field_info in config_type.model_fields.items():
line = f"[green][bold]{name}[/bold][/green]: "
line += self._get_value_help_string(field_info.annotation, field_info.description)
lines.append(line)
return "\n\n".join(lines)
def _nested_dict():
"""Helper function to create nested dictionaries."""
return defaultdict(_nested_dict)
def _parse_args_to_nested_dict(args):
"""Parse the command-line arguments into a nested dictionary."""
result = _nested_dict()
i = 0
while i < len(args):
arg = args[i]
if not arg.startswith("--"):
i += 1
continue
# Handle --key=value format
if "=" in arg:
key, value = arg[2:].split("=", 1)
# Handle --key value format
else:
key = arg[2:]
i += 1
if i >= len(args):
break
value = args[i]
# Convert value to int if possible
value = int(value) if value.isdigit() else value
# Build nested dict structure
keys = key.split(".")
current = result
for k in keys[:-1]:
current = current[k]
current[keys[-1]] = value
i += 1
return result
# todo: Parameterize type hints
class BasicCLI:
def __init__(
self,
config_type: type[BaseSettings],
*,
default_settings: bool = True,
help_text: str | None = None,
default_config_file: Path = CONFIG_DIR / "default.yaml",
):
"""This class implements a basic CLI for SWE-agent. It is based on pydantic-settings, i.e., takes
a `BaseSettings` object. In principle you could just initialize these via `pydantic-settings`'s `CliApp.run`,
however, we also want to add a `--config` option to load additional config files and some other things.
We also try to improve a bit on the pydantic error messages in here.
Args:
config_type: The type of the configuration object to instantiate.
default_settings: Whether to load the default settings.
help_text: If given, this will override the default help text that would usually be shown
by argparse.
"""
self.arg_type = config_type
self.default_settings = default_settings
self.logger = get_logger("swea-cli", emoji="🔧")
self.help_text = help_text
self.default_config_file = default_config_file
def maybe_show_auto_correct(self, args: list[str]):
auto_correct = []
if hasattr(self.arg_type, "_get_auto_correct"):
for ac in self.arg_type._get_auto_correct(): # type: ignore
if ac.show(args):
auto_correct.append(ac)
if auto_correct:
rich_print(
Panel.fit(
"[red][bold]Auto-correct suggestions[/bold][/red]\n\n"
+ "\n".join(ac.format() for ac in auto_correct),
)
)
def get_config(self, args: list[str] | None = None) -> BaseSettings:
"""Get the configuration object from defaults and command arguments."""
# >>> Step 1: Use argparse to add a --config option to load whole config files
# The defaults if no config file is provided
# Otherwise, the configs from the respective classes will be used
parser = ArgumentParser(description=__doc__, add_help=False)
parser.add_argument(
"--config",
type=Path,
action="append",
default=[],
help=(
"Load additional config files. Use this option multiple times to load "
"multiple files, e.g., --config config1.yaml --config config2.yaml"
),
)
parser.add_argument(
"-h",
"--help",
help="Show help text and exit",
action="store_true",
)
parser.add_argument(
"--help_option",
help="Show help text for a specific option",
)
if self.default_settings:
parser.add_argument(
"--no_config_file",
action="store_true",
help="Do not load default config file when no config file is provided",
)
parser.add_argument(
"--print_config",
action="store_true",
help="Print the final config and exit",
)
# >>> Step 2: Parse argparse arguments but keep all the remaining arguments.
# Explicitly handle --help and --print-options
cli_args, remaining_args = parser.parse_known_args(args)
if cli_args.help:
if self.help_text:
rich_print(self.help_text)
else:
parser.print_help()
exit(0)
if cli_args.help_option:
module, _, name = cli_args.help_option.rpartition(".")
if module not in sys.modules:
__import__(module)
type_ = getattr(sys.modules[module], name)
rich_print(ConfigHelper().get_help(type_))
exit(0)
# >>> Step 3: Load config files and merge them in a big nested data structure
config_merged = {}
config_files = []
if cli_args.config:
config_files.extend(cli_args.config)
for _f in cli_args.config:
txt = Path(_f).read_text()
if not txt.strip():
self.logger.warning(f"Config file {_f} is empty")
continue
_loaded = yaml.safe_load(txt)
merge_nested_dicts(config_merged, _loaded)
elif self.default_settings and not cli_args.no_config_file:
config_file = self.default_config_file
config_files.append(config_file)
msg = (
f"Loading default config from {config_file}, because no other "
"config file is specified. Specify --no_config_file to disable this."
)
self.logger.info(msg)
txt = config_file.read_text()
if not txt.strip():
self.logger.warning(f"Default config file {config_file} is empty")
config_merged = {}
else:
config_merged = yaml.safe_load(txt)
else:
config_merged = {}
# For informational purposes, we also merge in the command line options
cl_options_dict = _parse_args_to_nested_dict(remaining_args)
# >>> Step 4: Bring together remaining arguments and the merged config to initialize the config object
# This is done by CliApp.run from pydantic-settings
try:
config: BaseSettings = CliApp.run(self.arg_type, remaining_args, **config_merged, cli_exit_on_error=False) # type: ignore
except ValidationError as e:
rich_print(
Panel.fit(
"[red][bold]Configuration from config files\n[/bold]"
"This is all the configuration that was provided from defaults, --config, and CLI arguments[/red]\n\n"
+ yaml.dump(_shorten_strings(config_merged))
)
)
rich_print(
Panel.fit(
"[red][bold]Configuration from CLI arguments\n[/bold]"
"This is all the configuration that was provided from the command line arguments[/red]\n\n"
+ yaml.dump(_shorten_strings(cl_options_dict))
)
)
rich_print(
Panel.fit(
"[red][bold]Merged configuration\n[/bold]"
"This is the merged configuration that was used to instantiate the config object[/red]\n\n"
+ yaml.dump(_shorten_strings(merge_nested_dicts(config_merged, cl_options_dict)))
)
)
rich_print(
Panel.fit(
"[red][bold]Validation error[/bold]\n" + _VALIDATION_ERROR_HELP_TEXT + "[/red]\n" + str(e),
)
)
self.maybe_show_auto_correct(remaining_args)
msg = "Invalid configuration. Please check the above output."
raise RuntimeError(msg) from None
except SettingsError as e:
rich_print(Panel.fit("[red][bold]SettingsError[/bold][/red]\n\n" + str(e) + "\n\n" + _SETTING_ERROR_HINTS))
self.maybe_show_auto_correct(remaining_args)
msg = "Invalid command line arguments. Please check the above output in the box."
raise RuntimeError(msg) from None
if cli_args.print_config: # type: ignore
print(yaml.dump(config.model_dump()))
exit(0)
# Attach config files to the arg object, because we need them for file naming purposes
# (the output traj directory is named after the last config file)
config._config_files = config_files # type: ignore
return config
def save_predictions(traj_dir: Path, instance_id: str, result: AgentRunResult):
"""Save predictions in a file readable by SWE-bench"""
output_file = traj_dir / instance_id / (instance_id + ".pred")
output_file.parent.mkdir(parents=True, exist_ok=True)
datum = {
"model_name_or_path": traj_dir.name,
"instance_id": instance_id,
"model_patch": result.info.get("submission"),
}
output_file.write_text(json.dumps(datum))
def _is_promising_patch(info: AgentInfo) -> bool:
"""Do we actually believe that the patch will solve the issue?
Or are we just submitting the last patch we generated before hitting an error?
"""
# The exit status can also be `submitted (exit_cost)` etc.
return info.get("exit_status") == "submitted" and info.get("submission") is not None

View File

@@ -0,0 +1,123 @@
import argparse
import json
from pathlib import Path
from tabulate import tabulate
def get_resolved(path: Path) -> set[str]:
data = json.loads(path.read_text())
if "resolved" in data:
data["resolved_ids"] = data["resolved"]
return set(data["resolved_ids"])
def get_submitted(path: Path) -> set[str]:
return set(json.loads(path.read_text())["submitted_ids"])
def stats_single(path: Path) -> None:
evaluated_ids = sorted(get_submitted(path))
resolved_ids = sorted(get_resolved(path))
print(f"Total evaluated: {len(evaluated_ids)}")
print(f"Total resolved: {len(resolved_ids)}")
def compare_many(paths: list[Path]) -> None:
evaluated_ids = {}
resolved_ids = {}
for path in paths:
evaluated_ids[path] = sorted(get_submitted(path))
resolved_ids[path] = sorted(get_resolved(path))
header: list[str] = ["ID"] + [str(i) for i in range(len(paths))] + ["Success rate"]
table: list[list[str | float | int]] = []
def get_emoji(id: str, path: Path) -> str:
if id not in evaluated_ids[path]:
return ""
if id in resolved_ids[path]:
return ""
return ""
ids_to_compare = set(evaluated_ids[paths[0]])
for id in sorted(ids_to_compare):
row = [id] + [get_emoji(id, path) for path in paths]
n_success = sum(id in resolved_ids[path] for path in paths)
n_evaluated = sum(id in evaluated_ids[path] for path in paths)
row.append(f"{n_success / n_evaluated:.2f}")
table.append(row)
successes: list[str | float] = ["Successes"]
success_rates: list[str | float] = ["Success rates"]
for path in paths:
n_success = sum(id in resolved_ids[path] for id in ids_to_compare)
n_evaluated = sum(id in evaluated_ids[path] for id in ids_to_compare)
successes.append(n_success)
success_rates.append(f"{n_success / n_evaluated:.2f}")
table.append(successes)
table.append(success_rates)
print(tabulate(table, headers=header))
print()
header: list[str] = ["#", "ID", "Successes", "Success rate"]
table: list[list[str | float | int]] = []
for i, path in enumerate(paths):
row = [i, path.parent.name, successes[i + 1], success_rates[i + 1]]
table.append(row)
print(tabulate(table, headers=header))
def compare_pair(new_path: Path, old_path: Path, *, show_same=False) -> None:
evaluated_ids = sorted(get_submitted(new_path))
resolved_ids = sorted(get_resolved(new_path))
old_evaluated_ids = sorted(get_submitted(old_path))
old_resolved_ids = sorted(get_resolved(old_path))
print(f"Total evaluated: new {len(evaluated_ids)}, old {len(old_evaluated_ids)}")
print(f"Total resolved: new {len(resolved_ids)}, old {len(old_resolved_ids)}")
print("-" * 80)
print("Emoji legend:")
print("❓: Not evaluated in old version, so guessing it's either 😀 or 👾")
print("😀: Newly resolved in new version")
print("✅: Resolved in both")
print("❌: Resolved in old, not in new")
print("👾: Unresolved in both")
print("-" * 80)
for id in evaluated_ids:
resolved_now = id in resolved_ids
resolved_before = id in old_resolved_ids
if id not in old_evaluated_ids and resolved_now:
emoji = "😀❓"
elif id not in old_evaluated_ids and not resolved_now:
emoji = "👾❓"
elif resolved_now and not resolved_before:
emoji = "😀"
elif resolved_now and resolved_before:
emoji = ""
if not show_same:
continue
elif not resolved_now and resolved_before:
emoji = ""
else:
emoji = "👾"
if not show_same:
continue
print(f"{emoji} {id}")
def run_from_cli(_args: list[str] | None = None) -> None:
def get_preds_path(path: Path) -> Path:
if path.is_dir():
return path / "results.json"
return path
parser = argparse.ArgumentParser()
parser.add_argument("paths", type=Path, nargs="+")
parser.add_argument("--show-same", action="store_true")
args = parser.parse_args(_args)
args.paths = [get_preds_path(path) for path in args.paths]
if len(args.paths) == 1:
stats_single(args.paths[0])
elif len(args.paths) == 2:
compare_pair(args.paths[0], args.paths[1], show_same=args.show_same)
else:
compare_many(args.paths)

View File

@@ -0,0 +1,19 @@
"""If for some reason the .pred file isn't saved, we can extract it from the .traj file."""
import argparse
import json
from pathlib import Path
def run_from_cli(_args: list[str] | None = None):
parser = argparse.ArgumentParser()
parser.add_argument("traj_path", type=Path)
args = parser.parse_args(_args)
data = json.loads(args.traj_path.read_text())
pred_path = args.traj_path.with_suffix(".pred")
pred_data = {
"model_name_or_path": args.traj_path.resolve().parent.parent.name,
"model_patch": data["info"]["submission"],
"instance_id": args.traj_path.resolve().parent.name,
}
pred_path.write_text(json.dumps(pred_data))

View File

View File

@@ -0,0 +1,67 @@
from sweagent.agent.problem_statement import ProblemStatement, ProblemStatementConfig
from sweagent.environment.swe_env import SWEEnv
from sweagent.types import AgentRunResult
class RunHook:
"""Hook structure for the web server or other addons to interface with"""
def on_init(self, *, run):
"""Called when hook is initialized"""
def on_start(self):
"""Called at the beginning of `Main.main`"""
def on_end(self):
"""Called at the end of `Main.main`"""
def on_instance_start(
self, *, index: int, env: SWEEnv, problem_statement: ProblemStatement | ProblemStatementConfig
):
"""Called at the beginning of each instance loop in `Main.run`"""
def on_instance_skipped(
self,
):
"""Called when an instance is skipped in `Main.run`"""
def on_instance_completed(self, *, result: AgentRunResult):
"""Called when an instance is completed in `Main.run`"""
class CombinedRunHooks(RunHook):
def __init__(self):
self._hooks = []
def add_hook(self, hook: RunHook) -> None:
self._hooks.append(hook)
@property
def hooks(self) -> list[RunHook]:
return self._hooks
def on_init(self, *, run):
for hook in self._hooks:
hook.on_init(run=run)
def on_start(self):
for hook in self._hooks:
hook.on_start()
def on_end(self):
for hook in self._hooks:
hook.on_end()
def on_instance_start(
self, *, index: int, env: SWEEnv, problem_statement: ProblemStatement | ProblemStatementConfig
):
for hook in self._hooks:
hook.on_instance_start(index=index, env=env, problem_statement=problem_statement)
def on_instance_skipped(self):
for hook in self._hooks:
hook.on_instance_skipped()
def on_instance_completed(self, *, result: AgentRunResult):
for hook in self._hooks:
hook.on_instance_completed(result=result)

View File

@@ -0,0 +1,110 @@
import subprocess
import threading
from pathlib import Path
import rich
import rich.markdown
import rich.panel
from sweagent.agent.problem_statement import ProblemStatementConfig
from sweagent.environment.repo import LocalRepoConfig
from sweagent.environment.swe_env import SWEEnv
from sweagent.run.common import _is_promising_patch
from sweagent.run.hooks.abstract import RunHook
from sweagent.types import AgentRunResult
from sweagent.utils.log import get_logger
class SaveApplyPatchHook(RunHook):
"""This hook saves patches to a separate directory and optionally applies them to a local repository."""
def __init__(self, apply_patch_locally: bool = False, show_success_message: bool = True):
self.logger = get_logger("swea-save_apply_patch", emoji="⚡️")
self._apply_patch_locally = apply_patch_locally
self._show_success_message = show_success_message
# Thread-local storage so that concurrent workers in run-batch do not
# overwrite each other's per-instance state (_env, _problem_statement).
self._local = threading.local()
def on_init(self, *, run):
self._output_dir = Path(run.output_dir)
def on_instance_start(self, *, index: int, env: SWEEnv, problem_statement: ProblemStatementConfig):
self._local.env = env
self._local.problem_statement = problem_statement
def on_instance_completed(self, *, result: AgentRunResult):
instance_id = self._local.problem_statement.id
patch_path = self._save_patch(instance_id, result.info)
if patch_path:
if not self._apply_patch_locally:
return
if not _is_promising_patch(result.info):
return
if self._local.env.repo is None:
return
if not isinstance(self._local.env.repo, LocalRepoConfig):
return
local_dir = Path(self._local.env.repo.path)
self._apply_patch(patch_path, local_dir)
@staticmethod
def _print_patch_message(patch_output_file: Path):
console = rich.console.Console()
msg = [
"SWE-agent has produced a patch that it believes will solve the issue you submitted!",
"Use the code snippet below to inspect or apply it!",
]
panel = rich.panel.Panel.fit(
"\n".join(msg),
title="🎉 Submission successful 🎉",
)
console.print(panel)
content = [
"```bash",
"# The patch has been saved to your local filesystem at:",
f"PATCH_FILE_PATH='{patch_output_file.resolve()}'",
"# Inspect it:",
'cat "${PATCH_FILE_PATH}"',
"# Apply it to a local repository:",
"cd <your local repo root>",
'git apply "${PATCH_FILE_PATH}"',
"```",
]
console.print(rich.markdown.Markdown("\n".join(content)))
def _save_patch(self, instance_id: str, info) -> Path | None:
"""Create patch files that can be applied with `git am`.
Returns:
The path to the patch file, if it was saved. Otherwise, returns None.
"""
patch_output_dir = self._output_dir / instance_id
patch_output_dir.mkdir(exist_ok=True, parents=True)
patch_output_file = patch_output_dir / f"{instance_id}.patch"
if info.get("submission") is None:
self.logger.info("No patch to save.")
return None
model_patch = info["submission"]
patch_output_file.write_text(model_patch)
if _is_promising_patch(info):
# Only print big congratulations if we actually believe
# the patch will solve the issue
if self._show_success_message:
self._print_patch_message(patch_output_file)
return patch_output_file
def _apply_patch(self, patch_file: Path, local_dir: Path) -> None:
"""Apply a patch to a local directory."""
assert local_dir.is_dir()
assert patch_file.exists()
# The resolve() is important, because we're gonna run the cmd
# somewhere else
cmd = ["git", "apply", str(patch_file.resolve())]
try:
subprocess.run(cmd, cwd=local_dir, check=True)
except subprocess.CalledProcessError as e:
self.logger.error(f"Failed to apply patch {patch_file} to {local_dir}: {e}")
return
self.logger.info(f"Applied patch {patch_file} to {local_dir}")

View File

@@ -0,0 +1,244 @@
import os
import random
import shlex
from ghapi.all import GhApi
from pydantic import BaseModel
from sweagent.environment.swe_env import SWEEnv
from sweagent.run.hooks.abstract import RunHook
from sweagent.types import AgentRunResult
from sweagent.utils.github import (
InvalidGithubURL,
_get_associated_commit_urls,
_get_gh_issue_data,
_parse_gh_issue_url,
)
from sweagent.utils.log import get_logger
# NOTE
# THE IMPLEMENTATION DETAILS HERE WILL CHANGE SOON!
# fixme: Bring back the ability to open the PR to a fork
def open_pr(*, logger, token, env: SWEEnv, github_url, trajectory, _dry_run: bool = False) -> None:
"""Create PR to repository
Args:
trajectory: Trajectory of actions taken by the agent
_dry_run: Whether to actually push anything or just simulate it
"""
issue_url = github_url
logger.info("Opening PR")
try:
issue = _get_gh_issue_data(issue_url, token=token)
except InvalidGithubURL as e:
msg = "Data path must be a github issue URL if open_pr is set to True."
raise ValueError(msg) from e
branch_name = f"swe-agent-fix-#{issue.number}-" + str(random.random())[2:10]
env.communicate(
input="git config user.email 'noemail@swe-agent.com' && git config user.name 'SWE-agent'",
error_msg="Failed to set git user",
timeout=10,
check="raise",
)
env.communicate(input="rm -f model.patch", error_msg="Failed to remove model patch", timeout=10, check="raise")
env.communicate(
input=f"git checkout -b {branch_name}", error_msg="Failed to switch to new branch", timeout=10, check="raise"
)
env.communicate(input="git add .", error_msg="Failed to add commits", timeout=10, check="raise")
dry_run_flag = "--allow-empty" if _dry_run else ""
commit_msg = [
shlex.quote(f"Fix: {issue.title}"),
shlex.quote(f"Closes #{issue.number}"),
]
out = env.communicate(
input=f"git commit -m {commit_msg[0]} -m {commit_msg[1]} {dry_run_flag}",
error_msg="Failed to commit changes",
timeout=10,
check="raise",
)
logger.debug(f"Committed changes: {out}")
owner, repo, _ = _parse_gh_issue_url(issue_url)
# fixme: bring this back
# If `--repo_path` was specified with a different github URL, then the record will contain
# the forking user
forker = owner
head = branch_name
remote = "origin"
if forker != owner:
head = f"{forker}:{branch_name}"
token_prefix = ""
if token:
token_prefix = f"{token}@"
fork_url = f"https://{token_prefix}github.com/{forker}/{repo}.git"
logger.debug(f"Using fork: {fork_url}")
env.communicate(
input=f"git remote add fork {fork_url}",
error_msg="Failed to create new git remote",
timeout=10,
)
remote = "fork"
dry_run_prefix = "echo " if _dry_run else ""
out = env.communicate(
input=f"{dry_run_prefix} git push {remote} {branch_name}",
error_msg=(
"Failed to push branch to remote. Please check your token and permissions. "
"You might want to push to a fork with the push_gh_repo_url option."
),
timeout=10,
)
logger.debug(f"Pushed commit to {remote=} {branch_name=}: {out}")
body = (
f"This is a PR opened by AI tool [SWE Agent](https://github.com/SWE-agent/SWE-agent/) "
f"to close [#{issue.number}]({issue_url}) ({issue.title}).\n\nCloses #{issue.number}."
)
body += "\n\n" + format_trajectory_markdown(trajectory, char_limit=60_000)
api = GhApi(token=token)
default_branch = api.repos.get(owner, repo).default_branch
if not _dry_run:
args = dict(
owner=owner,
repo=repo,
title=f"SWE-agent[bot] PR to fix: {issue.title}",
head=head,
base=default_branch,
body=body,
draft=True,
)
logger.debug(f"Creating PR with args: {args}")
pr_info = api.pulls.create(**args) # type: ignore
logger.info(
f"🎉 PR created as a draft at {pr_info.html_url}. Please review it carefully, push "
"any required changes onto the branch and then click "
"'Ready for Review' to bring it to the attention of the maintainers.",
)
class OpenPRConfig(BaseModel):
# Option to be used with open_pr: Skip action if there are already commits claiming
# to fix the issue. Please only set this to False if you are sure the commits are
# not fixes or if this is your own repository!
skip_if_commits_reference_issue: bool = True
class OpenPRHook(RunHook):
"""This hook opens a PR if the issue is solved and the user has enabled the option."""
def __init__(self, config: OpenPRConfig):
self.logger = get_logger("swea-open_pr", emoji="⚡️")
self._config = config
def on_init(self, *, run):
self._env = run.env
self._token: str = os.getenv("GITHUB_TOKEN", "")
self._problem_statement = run.problem_statement
def on_instance_completed(self, result: AgentRunResult):
if self.should_open_pr(result):
open_pr(
logger=self.logger,
token=self._token,
env=self._env,
github_url=self._problem_statement.github_url,
trajectory=result.trajectory,
)
def should_open_pr(self, result: AgentRunResult) -> bool:
"""Does opening a PR make sense?"""
if not result.info.get("submission"):
self.logger.info("Not opening PR because no submission was made.")
return False
if result.info.get("exit_status") != "submitted":
self.logger.info(
"Not opening PR because exit status was %s and not submitted.", result.info.get("exit_status")
)
return False
try:
issue = _get_gh_issue_data(self._problem_statement.github_url, token=self._token)
except InvalidGithubURL:
self.logger.info("Currently only GitHub is supported to open PRs to. Skipping PR creation.")
return False
if issue.state != "open":
self.logger.info(f"Issue is not open (state={issue.state}. Skipping PR creation.")
return False
if issue.assignee:
self.logger.info("Issue is already assigned. Skipping PR creation. Be nice :)")
return False
if issue.locked:
self.logger.info("Issue is locked. Skipping PR creation.")
return False
org, repo, issue_number = _parse_gh_issue_url(self._problem_statement.github_url)
associated_commits = _get_associated_commit_urls(org, repo, issue_number, token=self._token)
if associated_commits:
commit_url_strs = ", ".join(associated_commits)
if self._config.skip_if_commits_reference_issue:
self.logger.info(f"Issue already has associated commits (see {commit_url_strs}). Skipping PR creation.")
return False
else:
self.logger.warning(
"Proceeding with PR creation even though there are already commits "
f"({commit_url_strs}) associated with the issue. Please only do this for your own repositories "
"or after verifying that the existing commits do not fix the issue.",
)
return True
def _remove_triple_backticks(text: str) -> str:
return "\n".join(line.removeprefix("```") for line in text.splitlines())
def format_trajectory_markdown(trajectory: list[dict[str, str]], char_limit: int | None = None):
"""Format a trajectory as a markdown string for use in gh PR description.
Args:
char_limit: If not None, truncate the trajectory to this many characters.
"""
prefix = [
"<details>",
"<summary>Thought process ('trajectory') of SWE-agent (click to expand)</summary>",
"",
"",
]
prefix_text = "\n".join(prefix)
suffix = [
"",
"</details>",
]
suffix_text = "\n".join(suffix)
steps = []
current_length = len(prefix_text) + len(suffix_text)
for i, step in enumerate(trajectory):
step_strs = [
f"**🧑‍🚒 Response ({i})**: ",
f"{step['response'].strip()}",
f"**👀‍ Observation ({i})**:",
"```",
f"{_remove_triple_backticks(step['observation']).strip()}",
"```",
]
step_text = "\n".join(step_strs)
# Calculate separator length (only needed for steps after the first one)
separator_length = 0
if steps:
separator_length = len("\n\n---\n\n")
# Check if adding this step would exceed the character limit
if char_limit is not None and current_length + separator_length + len(step_text) > char_limit:
if i > 0:
steps.append("\n\n... (truncated due to length limit)")
break
if steps:
steps.append("\n\n---\n\n")
current_length += separator_length
steps.append(step_text)
current_length += len(step_text)
return prefix_text + "".join(steps) + suffix_text

View File

@@ -0,0 +1,113 @@
"""SweBench evaluation hook.
Will be automatically added to `run_batch` if `SWEBenchInstances.evaluate` is set to true
"""
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from threading import Lock
from time import time
from sweagent.run.hooks.abstract import RunHook
from sweagent.run.merge_predictions import merge_predictions
from sweagent.types import AgentRunResult
from sweagent.utils.log import get_logger
class SweBenchEvaluate(RunHook):
_SUBSET_MAP = {"lite": "swe-bench_lite", "verified": "swe-bench_verified", "multimodal": "swe-bench_multimodal"}
def __init__(self, output_dir: Path, subset: str, split: str, continuous_submission_every: int = 0) -> None:
super().__init__()
self.output_dir = output_dir
self.subset = subset
self.split = split
self.continuous_submission_every = continuous_submission_every
self.logger = get_logger("SB-evaluate", emoji="😬")
self.merge_lock = Lock()
self.last_evaluation_time = time()
self.evaluation_interval = continuous_submission_every
self._running_calls = []
# We need to add a suffix to the run_id to avoid collisions when you reuse the name of your run
self._time_suffix = datetime.now().strftime("%Y%m%d%H%M%S%f")
@property
def run_id(self) -> str:
return f"{self.output_dir.name}_{self._time_suffix}"
def _get_sb_call(self, preds_path: Path, submit_only: bool = False) -> list[str]:
args = [
"sb-cli",
"submit",
self._SUBSET_MAP[self.subset],
self.split,
"--predictions_path",
str(preds_path),
"--run_id",
self.run_id,
"--output_dir",
str(self.output_dir / "sb-cli-reports"),
]
if submit_only:
args.extend(["--wait_for_evaluation", "0", "--gen_report", "0", "--verify_submission", "0"])
return args
def check_running_calls(self) -> None:
"""Warn if one of the running calls failed."""
for call in self._running_calls:
if call.poll() is not None:
if call.returncode != 0:
self.logger.error("Failed to submit results to SweBench eval: %s", call.stderr.read())
self._running_calls.remove(call)
def on_instance_completed(self, *, result: AgentRunResult):
if self.evaluation_interval == 0:
return
current_time = time()
if current_time - self.last_evaluation_time < self.evaluation_interval:
return
with self.merge_lock:
merge_predictions([self.output_dir], self.output_dir / "tmppreds.json")
self.last_evaluation_time = current_time
self._running_calls.append(
subprocess.Popen(
self._get_sb_call(preds_path=self.output_dir / "tmppreds.json", submit_only=True),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
)
def move_sb_cli_report(self) -> None:
"""Move report from `sb-cli-reports` to `results.json`."""
output_dir = self.output_dir / "sb-cli-reports"
if not output_dir.exists():
self.logger.warning("No SweBench report found at %s", output_dir)
return
(self.output_dir / "results.json").unlink(missing_ok=True)
reports = list(output_dir.glob("*.json"))
if len(reports) != 1:
self.logger.warning("Expected 1 SweBench report at %s, found %d. Cannot rename.", output_dir, len(reports))
return
reports[0].rename(self.output_dir / "results.json")
def on_end(self) -> None:
self.logger.info("Submitting results to SWE-Bench")
try:
subprocess.run(
self._get_sb_call(preds_path=self.output_dir / "preds.json"),
check=True,
stdout=sys.stdout,
stderr=sys.stderr,
)
except subprocess.CalledProcessError as e:
self.logger.error("Failed to submit results to SweBench eval: %s", e)
else:
# remove temporary predictions if they exist
if (self.output_dir / "tmppreds.json").exists():
(self.output_dir / "tmppreds.json").unlink()
self.move_sb_cli_report()

View File

@@ -0,0 +1,493 @@
"""This is a command line tool to inspect trajectory JSON files."""
import argparse
import collections
import copy
import json
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from rich.syntax import Syntax
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container, Vertical, VerticalScroll
from textual.screen import ModalScreen
from textual.widgets import Footer, Header, Input, ListItem, ListView, Static
from sweagent.utils.files import load_file
from sweagent.utils.serialization import _yaml_serialization_with_linebreaks
def _move_items_top(d: dict, keys: list[str]) -> dict:
"""Reorder items in a dictionary.
The first keys will be those specified in `keys`, the rest will
be in the same order as in the original dictionary.
"""
new_d = {}
for key in keys:
if key in d:
new_d[key] = d[key]
for key in d.keys():
if key not in keys:
new_d[key] = d[key]
return new_d
class TrajectoryViewer(Static):
BINDINGS = [
Binding("right,l", "next_item", "Step++"),
Binding("left,h", "previous_item", "Step--"),
Binding("0", "first_item", "Step=0"),
Binding("$", "last_item", "Step=-1"),
Binding("v", "toggle_view", "Toggle view"),
Binding("j,down", "scroll_down", "Scroll down"),
Binding("k,up", "scroll_up", "Scroll up"),
]
def __init__(self, path: Path, title: str, overview_stats: dict, *, gold_patch: str | None = None):
"""View a single trajectory."""
super().__init__()
self.i_step = -1
self.trajectory = json.loads(path.read_text())
self.show_full = False
self.title = title
self.overview_stats = overview_stats
self.gold_patch = gold_patch
def load_trajectory(self, path: Path, title: str, overview_stats: dict, *, gold_patch: str | None = None):
"""Load a new trajectory and update the viewer."""
print("Loading", path)
self.trajectory = json.loads(path.read_text())
self.title = title
self.gold_patch = gold_patch
self.overview_stats = overview_stats
self.scroll_top()
self.i_step = -1
self.update_content()
def compose(self) -> ComposeResult:
with VerticalScroll():
yield Static(id="content", markup=False)
def on_mount(self) -> None:
self.update_content()
@property
def n_steps(self) -> int:
return len(self.trajectory["trajectory"])
def _show_step_yaml(self, item: dict) -> None:
"""Show full yaml of trajectory item"""
content_str = _yaml_serialization_with_linebreaks(
_move_items_top(item, ["thought", "action", "observation", "response", "execution_time"])
)
syntax = Syntax(content_str, "yaml", theme="monokai", word_wrap=True)
content = self.query_one("#content")
content.update(syntax) # type: ignore
self.app.sub_title = f"{self.title} - Step {self.i_step + 1}/{self.n_steps} - Full View"
def _show_step_simple(self, item: dict) -> None:
# Simplified view - show action and observation as plain text
thought = item.get("thought", "")
action = item.get("action", "")
observation = item.get("observation", "")
content_str = f"THOUGHT:\n{thought}\n\nACTION:\n{action}\n\nOBSERVATION:\n{observation}"
content = self.query_one("#content")
content.update(content_str) # type: ignore
self.app.sub_title = f"{self.title} - Step {self.i_step + 1}/{self.n_steps} - Simple View"
def _show_info(self):
info = copy.deepcopy(self.trajectory["info"])
info["result"] = self.overview_stats["result"]
info["gold_patch"] = self.gold_patch
info = _move_items_top(info, ["result", "exit_status", "model_stats", "submission", "gold_patch"])
syntax = Syntax(_yaml_serialization_with_linebreaks(info), "yaml", theme="monokai", word_wrap=True)
content = self.query_one("#content")
content.update(syntax) # type: ignore
next_help = "Press l to see step 1" if self.i_step < 0 else f"Press h to see step {self.n_steps}"
self.app.sub_title = f"{self.title} - Info ({next_help})"
def update_content(self) -> None:
print(self.i_step)
if self.i_step < 0 or self.i_step >= self.n_steps:
return self._show_info()
item = self.trajectory["trajectory"][self.i_step]
if self.show_full:
return self._show_step_yaml(item)
return self._show_step_simple(item)
def action_next_item(self) -> None:
if self.i_step < self.n_steps:
self.i_step += 1
self.scroll_top()
self.update_content()
def action_previous_item(self) -> None:
if self.i_step > -1:
self.i_step -= 1
self.scroll_top()
self.update_content()
def action_toggle_view(self) -> None:
self.show_full = not self.show_full
self.update_content()
def action_first_item(self) -> None:
self.i_step = 0
self.update_content()
def action_last_item(self) -> None:
self.i_step = self.n_steps - 1
self.update_content()
def scroll_top(self) -> None:
"""Resets scrolling viewport"""
vs = self.query_one(VerticalScroll)
vs.scroll_home(animate=False)
def action_scroll_down(self) -> None:
vs = self.query_one(VerticalScroll)
vs.scroll_to(y=vs.scroll_target_y + 15)
def action_scroll_up(self) -> None:
vs = self.query_one(VerticalScroll)
vs.scroll_to(y=vs.scroll_target_y - 15)
class TrajectorySelectorScreen(ModalScreen[int]):
BINDINGS = [
Binding("escape", "dismiss(None)", "Cancel"),
]
def __init__(self, paths: list[Path], current_index: int, overview_stats: dict):
super().__init__()
self.paths = paths
self.current_index = current_index
self.overview_stats = overview_stats
self.all_items = [] # Store all items for filtering
self.filtered_indices = []
def _get_list_item_texts(self, paths: list[Path]) -> list[str]:
"""Remove the common prefix from a list of paths."""
prefix = os.path.commonpath([str(p) for p in paths])
labels = []
for p in paths:
ostat = self.overview_stats[p.stem]
ostat_str = f"{ostat['exit_status']} {ostat['result']} ${ostat['cost']:.2f} {ostat['api_calls']} calls"
shortened_path = str(p)[len(prefix) :].lstrip("/\\")
if Path(shortened_path).stem == Path(shortened_path).parent.name:
# We have the instance ID twice (in the folder and the traj)
shortened_path = Path(shortened_path).stem
labels.append(f"{shortened_path} - {ostat_str}")
return labels
def compose(self) -> ComposeResult:
with Vertical(id="dialog"):
yield Static(
"Press <TAB> to switch between search and list. Use <ARROW KEY>/<ENTER> to select.",
id="title",
markup=False,
)
yield Input(placeholder="Type to filter (auto-select if only one item remains)...", id="filter-input")
yield ListView(
*[ListItem(Static(p, markup=False)) for p in self._get_list_item_texts(self.paths)],
id="trajectory-list",
initial_index=self.current_index,
)
# Store all items for later filtering
self.all_items = self._get_list_item_texts(self.paths)
self.filtered_indices = list(range(len(self.all_items)))
def on_input_changed(self, event: Input.Changed) -> None:
"""Filter list items based on input"""
filter_text = event.value.lower()
list_view = self.query_one("#trajectory-list", ListView)
# Filter items and keep track of original indices
self.filtered_indices = [i for i, item in enumerate(self.all_items) if filter_text in item.lower()]
filtered_items = [self.all_items[i] for i in self.filtered_indices]
if len(filtered_items) == 1:
# Find the index of the filtered item in the original list
selected_index = self.all_items.index(filtered_items[0])
self.dismiss(selected_index)
return
# Update ListView with filtered items
list_view.clear()
for item in filtered_items:
list_view.append(ListItem(Static(item, markup=False)))
def on_list_view_selected(self, event: ListView.Selected) -> None:
# Map the filtered index back to the original index
original_index = self.filtered_indices[event.list_view.index]
print(f"Selected index: {original_index}")
self.dismiss(original_index)
CSS = """
#dialog {
background: $surface;
padding: 1;
border: thick $primary;
width: 100%;
height: 100%;
}
#title {
text-align: center;
padding: 1;
}
#filter-input {
dock: top;
margin: 1 0;
}
ListView {
height: 100%;
border: solid $primary;
}
ListItem {
padding: 0 1;
}
ListItem:hover {
background: $accent;
}
"""
class FileViewerScreen(ModalScreen):
BINDINGS = [
Binding("q,escape", "dismiss", "Back"),
Binding("j,down", "scroll_down", "Scroll down"),
Binding("k,up", "scroll_up", "Scroll up"),
Binding("e", "open_editor", "Open in $EDITOR"),
]
def __init__(self, path: Path):
super().__init__()
self.path = path
def compose(self) -> ComposeResult:
with VerticalScroll():
text = self.path.read_text()
truncated = False
if len(text) > 10_000:
# More than ~1000 lines
self.app.notify(
"File is too large to display. Showing first 10k chars. Use e to open in editor.",
severity="warning",
)
text = text[:10_000]
truncated = True
if self.path.exists():
if self.path.suffix == ".traj" and not truncated:
# Syntax highlighting breaks if we truncate
content_str = _yaml_serialization_with_linebreaks(json.loads(text))
syntax = Syntax(content_str, "yaml", theme="monokai", word_wrap=True)
yield Static(syntax, markup=False)
else:
yield Static(text, markup=False)
else:
yield Static(f"No file found at {self.path}", markup=False)
def action_scroll_down(self) -> None:
vs = self.query_one(VerticalScroll)
vs.scroll_to(y=vs.scroll_target_y + 15)
def action_scroll_up(self) -> None:
vs = self.query_one(VerticalScroll)
vs.scroll_to(y=vs.scroll_target_y - 15)
async def action_open_editor(self) -> None:
editor = os.environ.get("EDITOR")
if not editor:
self.app.notify("No editor found in $EDITOR environment variable, cannot perform action", severity="error")
return
try:
# Suspend the TUI app to restore terminal state before launching editor
with self.app.suspend():
subprocess.run([editor, str(self.path)], check=True)
except subprocess.CalledProcessError:
pass
CSS = """
ScrollableContainer {
width: 100%;
height: 100%;
background: $surface;
padding: 1;
border: thick $primary;
}
"""
class TrajectoryInspectorApp(App):
BINDINGS = [
Binding("q", "quit", "Quit"),
Binding("L", "next_traj", "Traj++"),
Binding("H", "previous_traj", "Traj--"),
Binding("t", "show_traj_selector", "Select Traj"),
Binding("o", "show_log", "View Log"),
Binding("r", "show_full", "Show full"),
]
CSS = """
Screen {
layout: grid;
grid-size: 1;
}
#viewer {
width: 100%;
height: 100%;
}
ScrollView {
width: 100%;
height: 100%;
border: solid green;
}
"""
def __init__(self, input_path: str | Path, data_path: Path | None = None):
super().__init__()
self.input_path = Path(input_path)
if not self.input_path.exists():
msg = f"{self.input_path} doesn't exist"
raise FileNotFoundError(msg)
self.available_traj_paths = self._get_available_trajs()
if not self.available_traj_paths:
msg = "No trajectory *.traj files available"
raise ValueError(msg)
self.trajectory_index = 0
self.overview_stats = collections.defaultdict(dict)
self._build_overview_stats()
self._data = load_file(data_path)
def get_gold_patch(self, instance_id: str) -> str | None:
if self._data is None:
return None
return self._data.get(instance_id, {}).get("patch", None)
def _build_overview_stats(self):
results_path = self.input_path / "results.json"
results = None
if results_path.exists():
results = json.loads(results_path.read_text())
for traj in self.available_traj_paths:
instance_id = traj.stem
if results is None:
result = ""
elif instance_id in results["resolved_ids"]:
result = ""
else:
result = ""
self.overview_stats[instance_id]["result"] = result
def _get_info(traj: Path) -> tuple[str, dict]:
traj_info = json.loads(traj.read_text()).get("info", {})
return traj.stem, traj_info
with ThreadPoolExecutor() as executor:
# Map returns results in the same order as inputs
all_infos = executor.map(_get_info, self.available_traj_paths)
for instance_id, info in all_infos:
self.overview_stats[instance_id]["info"] = info
self.overview_stats[instance_id]["exit_status"] = info.get("exit_status", "?")
self.overview_stats[instance_id]["api_calls"] = info.get("model_stats", {}).get("api_calls", 0)
self.overview_stats[instance_id]["cost"] = info.get("model_stats", {}).get("instance_cost", 0)
def _get_viewer_title(self, index: int) -> str:
instance_id = self.available_traj_paths[index].stem
if len(instance_id) > 20:
instance_id = "..." + instance_id[-17:]
return f"Traj {index + 1}/{len(self.available_traj_paths)} - {instance_id}"
def _load_traj(self):
instance_id = self.available_traj_paths[self.trajectory_index].stem
traj_viewer = self.query_one(TrajectoryViewer)
traj_viewer.load_trajectory(
self.available_traj_paths[self.trajectory_index],
self._get_viewer_title(self.trajectory_index),
self.overview_stats[instance_id],
gold_patch=self.get_gold_patch(instance_id),
)
def _get_available_trajs(self) -> list[Path]:
if self.input_path.is_file():
return [self.input_path]
elif self.input_path.is_dir():
return sorted(self.input_path.rglob("*.traj"))
raise ValueError
def compose(self) -> ComposeResult:
yield Header()
with Container():
yield TrajectoryViewer(
self.available_traj_paths[self.trajectory_index],
self._get_viewer_title(self.trajectory_index),
self.overview_stats[self.available_traj_paths[self.trajectory_index].stem],
)
yield Footer()
def action_next_traj(self):
self.trajectory_index = (self.trajectory_index + 1) % len(self.available_traj_paths)
self._load_traj()
def action_previous_traj(self):
self.trajectory_index = (self.trajectory_index - 1) % len(self.available_traj_paths)
self._load_traj()
async def action_show_traj_selector(self) -> None:
selector = TrajectorySelectorScreen(self.available_traj_paths, self.trajectory_index, self.overview_stats)
def handler(index: int | None):
if index is not None:
self.trajectory_index = index
self._load_traj()
await self.push_screen(selector, handler) # This returns when the modal is dismissed
async def action_show_log(self) -> None:
current_traj = self.available_traj_paths[self.trajectory_index]
log_path = current_traj.with_suffix(".debug.log")
log_viewer = FileViewerScreen(log_path)
await self.push_screen(log_viewer)
async def action_show_full(self) -> None:
"""Show full yaml of trajectory file"""
current_traj = self.available_traj_paths[self.trajectory_index]
viewer = FileViewerScreen(current_traj)
await self.push_screen(viewer)
def main(args: list[str] | None = None):
parser = argparse.ArgumentParser(description="Inspect trajectory JSON files")
parser.add_argument(
"trajectory_path",
help="Path to the trajectory JSON file or directory containing trajectories",
default=os.getcwd(),
nargs="?",
)
parser.add_argument("-d", "--data_path", type=Path, help="Path to the data file to load gold patches from")
parsed_args = parser.parse_args(args)
app = TrajectoryInspectorApp(parsed_args.trajectory_path)
app.run()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,64 @@
import argparse
import json
from pathlib import Path
from sweagent.utils.log import get_logger
"""Merge multiple predictions into a single file."""
logger = get_logger("merge", emoji="")
def merge_predictions(directories: list[Path], output: Path | None = None) -> None:
"""Merge predictions found in `directories` into a single JSON file.
Args:
directory: Directory containing predictions.
output: Output file. If not provided, the merged predictions will be
written to `directory/preds.json`.
"""
preds = []
for directory in directories:
new = list(directory.rglob("*.pred"))
preds.extend(new)
logger.debug("Found %d predictions in %s", len(new), directory)
logger.info("Found %d predictions", len(preds))
if not preds:
logger.warning("No predictions found in %s", directory)
return
if output is None:
output = directories[0] / "preds.json"
data = {}
for pred in preds:
_data = json.loads(pred.read_text())
instance_id = _data["instance_id"]
if "model_patch" not in _data:
logger.warning("Prediction %s does not contain a model patch. SKIPPING", pred)
continue
# Ensure model_patch is a string
_data["model_patch"] = str(_data["model_patch"]) if _data["model_patch"] is not None else ""
if instance_id in data:
msg = f"Duplicate instance ID found: {instance_id}"
raise ValueError(msg)
data[instance_id] = _data
output.parent.mkdir(parents=True, exist_ok=True)
output.write_text(json.dumps(data, indent=4))
logger.info("Wrote merged predictions to %s", output)
def get_cli_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("directories", type=Path, help="Directory containing predictions", nargs="+")
parser.add_argument("--output", type=Path, help="Output file")
return parser
def run_from_cli(args: list[str] | None = None) -> None:
cli_parser = get_cli_parser()
cli_args = cli_parser.parse_args(args)
merge_predictions(cli_args.directories, cli_args.output)
if __name__ == "__main__":
run_from_cli()

View File

@@ -0,0 +1,96 @@
#!/usr/bin/env python3
import argparse
import collections
import json
from pathlib import Path
import numpy as np
from sweagent.utils.log import get_logger
"""Calculate statistics from .traj files."""
logger = get_logger("quick-stats", emoji="📊")
def quick_stats(directory: Path | str = ".") -> str:
"""Calculate statistics from .traj files.
Args:
directory: Directory to search for .traj files (default: current directory)
Returns:
str: Summary of statistics
"""
directory = Path(directory)
# Find all .traj files
traj_files = list(directory.glob("**/*.traj"))
if not traj_files:
logger.warning("No .traj files found in %s", directory)
return "No .traj files found."
# Extract api_calls from each file
api_calls = []
files_by_exit_status = collections.defaultdict(list)
for file_path in traj_files:
try:
data = json.loads(file_path.read_text())
# Extract the api_calls value using dictionary path
if "info" in data and "model_stats" in data["info"] and "api_calls" in data["info"]["model_stats"]:
api_calls.append(data["info"]["model_stats"]["api_calls"])
if "info" in data and "exit_status" in data["info"]:
status = data["info"]["exit_status"]
files_by_exit_status[status].append(file_path)
except Exception as e:
logger.error("Error processing %s: %s", file_path, e)
files_by_exit_status = dict(sorted(files_by_exit_status.items(), key=lambda x: len(x[1]), reverse=True))
if not api_calls:
logger.warning("No valid api_calls data found in the .traj files")
return "No valid api_calls data found in the .traj files."
# Calculate and return the average
logger.info("Exit statuses:")
# Sort exit statuses by count (highest to lowest)
for status, files in files_by_exit_status.items():
logger.info("%s: %d", status, len(files))
average_api_calls = np.mean(api_calls)
logger.info("Avg api calls: %s", average_api_calls)
# Print exit statuses in the requested format
result = []
for status, files in files_by_exit_status.items():
result.append(f"\n## `{status}`\n")
# Extract unique subdirectories instead of full paths
subdirs = {str(Path(file_path).parent) for file_path in files}
result.append(" ".join(subdirs))
return "\n".join(result)
def get_cli_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"directory",
type=Path,
nargs="?",
default=Path("."),
help="Directory to search for .traj files (default: current directory)",
)
return parser
def run_from_cli(args: list[str] | None = None) -> None:
cli_parser = get_cli_parser()
cli_args = cli_parser.parse_args(args)
result = quick_stats(cli_args.directory)
print(result)
if __name__ == "__main__":
run_from_cli()

View File

@@ -0,0 +1,63 @@
"""Remove unfinished trajectories."""
import argparse
import shutil
from pathlib import Path
from sweagent.utils.files import load_file
from sweagent.utils.log import get_logger
logger = get_logger("remove_unfinished")
def remove_unfinished(base_dir: Path, dry_run: bool = True) -> None:
"""Remove unfinished trajectories."""
to_remove = []
for directory in base_dir.iterdir():
if not directory.is_dir():
continue
if "__" not in directory.name:
continue
trajs = list(directory.glob("*.traj"))
if not trajs:
logger.info("No trajectories found in %s", directory)
continue
if len(trajs) > 1:
logger.warning("Found multiple trajectories in %s. Skipping.", directory)
continue
try:
traj = load_file(trajs[0])
except Exception as e:
logger.warning("Error loading trajectory %s: %s. Adding to remove list.", trajs[0], e)
to_remove.append(directory)
continue
submission = traj.get("info", {}).get("submission", None)
if submission is None:
logger.warning("No submission found in %s. Adding to remove list.", directory)
to_remove.append(directory)
continue
if dry_run:
logger.info("Would remove %d unfinished trajectories.", len(to_remove))
for directory in to_remove:
logger.info(directory)
else:
for directory in to_remove:
logger.info("Removing %s", directory)
shutil.rmtree(directory)
def get_cli_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--base_dir", type=Path, help="Base directory", default=Path("."))
parser.add_argument("--remove", action="store_true", help="Remove unfinished trajectories")
return parser
def run_from_cli(args: list[str] | None = None) -> None:
cli_parser = get_cli_parser()
cli_args = cli_parser.parse_args(args)
remove_unfinished(cli_args.base_dir, dry_run=not cli_args.remove)
if __name__ == "__main__":
run_from_cli()

View File

@@ -0,0 +1,91 @@
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from random import random
from threading import Lock
from rich.console import Group
from rich.live import Live
from rich.logging import RichHandler
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TaskID,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
logging.basicConfig(level="NOTSET", handlers=[RichHandler(level="NOTSET")])
logger = logging.getLogger("rich")
# Lock for thread-safe progress updates
progress_lock = Lock()
class RunBatch:
def __init__(self):
self.tasks = list(range(10)) # Reduced to 10 tasks for example clarity
self._main_progress_bar: Progress | None = None
self._task_progress_bar: Progress | None = None
self._spinner_tasks: dict[TaskID, TaskID] = {}
def do_task(self, task_id: TaskID):
assert self._main_progress_bar is not None
assert self._task_progress_bar is not None
# Create a spinner for this task
with progress_lock:
spinner_task_id = self._task_progress_bar.add_task(f"Task {task_id}", total=None)
logger.info("Starting task %d", task_id)
# Startup
time.sleep(random() * 4.5)
# Work
with progress_lock:
self._task_progress_bar.update(spinner_task_id, description=f"Task {task_id} (working)")
time.sleep(random() * 4.5 + 2)
logger.info("Finished task %d", task_id)
# Remove spinner and update main progress
with progress_lock:
self._task_progress_bar.remove_task(spinner_task_id)
self._main_progress_bar.update(TaskID(0), advance=1)
def main(self):
# Custom progress columns
self._main_progress_bar = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeElapsedColumn(),
TimeRemainingColumn(),
)
self._task_progress_bar = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
TimeElapsedColumn(),
)
group = Group(self._main_progress_bar, self._task_progress_bar)
with Live(group):
# Add main progress bar
self._main_task_id = self._main_progress_bar.add_task("[cyan]Overall Progress", total=len(self.tasks))
# Create thread pool and run tasks
with ThreadPoolExecutor(max_workers=5) as executor:
# Submit all tasks
futures = [executor.submit(self.do_task, task_id) for task_id in self.tasks]
# Wait for all tasks to complete
for future in futures:
future.result()
if __name__ == "__main__":
run_batch = RunBatch()
run_batch.main()

View File

@@ -0,0 +1,147 @@
"""[cyan][bold]Main command line interface for SWE-agent.[/bold][/cyan]
[cyan][bold]=== USAGE ===[/bold][/cyan]
[green]sweagent <command> [options][/green]
Display usage instructions for a specific command:
[green]sweagent <command> [bold]--help[/bold][/green]
[cyan][bold]=== SUBCOMMANDS TO RUN SWE-AGENT ===[/bold][/cyan]
[bold][green]run[/green][/bold] or [bold][green]r[/green][/bold]: Run swe-agent on a single problem statement, for example a github issue.
[bold][green]run-batch[/green][/bold] or [bold][green]b[/green][/bold]: Run swe-agent on a batch of problem statements, e.g., on SWE-Bench.
[cyan][bold]=== MISC SUBCOMMANDS ===[/bold][/cyan]
[bold][green]merge-preds[/green][/bold]: Merge multiple prediction files into a single file. In most cases
[green]run-batch[/green] will already do this, but you can use this to merge predictions
from multiple directories.
[bold][green]inspect[/green][/bold] or [bold][green]i[/green][/bold]: Open a single trajectory file in a terminal-based viewer.
[bold][green]inspector[/green][/bold] or [bold][green]I[/green][/bold]: Open trajectories in a web-based viewer.
[bold][green]run-replay[/green][/bold]: Replay a trajectory file or a demo file.
This can be useful to fill in environment output when creating demonstrations.
[bold][green]traj-to-demo[/green][/bold]: Convert a trajectory file to an easy to edit demo file.
[bold][green]run-api[/green][/bold]: Run swe-agent as a backend for a GUI
[bold][green]remove-unfinished[/green][/bold] or [bold][green]ru[/green][/bold]: Remove unfinished trajectories
[bold][green]quick-stats[/green][/bold] or [bold][green]qs[/green][/bold]: Calculate quick stats from a directory of trajectories
"""
import argparse
import sys
import rich
def get_cli():
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(
"command",
choices=[
"run",
"run-batch",
"run-replay",
"traj-to-demo",
"run-api",
"merge-preds",
"inspect",
"inspector",
"r",
"b",
"i",
"I",
"extract-pred",
"compare-runs",
"cr",
"remove-unfinished",
"ru",
"quick-stats",
"qs",
"shell",
"sh",
],
nargs="?",
)
parser.add_argument("-h", "--help", action="store_true", help="Show this help message and exit")
return parser
def main(args: list[str] | None = None):
if args is None:
args = sys.argv[1:]
cli = get_cli()
parsed_args, remaining_args = cli.parse_known_args(args) # type: ignore
command = parsed_args.command
show_help = parsed_args.help
if show_help:
if not command:
# Show main help
rich.print(__doc__)
sys.exit(0)
else:
# Add to remaining_args
remaining_args.append("--help")
elif not command:
cli.print_help()
sys.exit(2)
# Defer imports to avoid unnecessary long loading times
if command in ["run", "r"]:
from sweagent.run.run_single import run_from_cli as run_single_main
run_single_main(remaining_args)
elif command in ["run-batch", "b"]:
from sweagent.run.run_batch import run_from_cli as run_batch_main
run_batch_main(remaining_args)
elif command == "run-replay":
from sweagent.run.run_replay import run_from_cli as run_replay_main
run_replay_main(remaining_args)
elif command == "traj-to-demo":
from sweagent.run.run_traj_to_demo import run_from_cli as convert_traj_to_demo_main
convert_traj_to_demo_main(remaining_args)
elif command == "run-api":
from sweagent.api.server import run_from_cli as run_api_main
run_api_main(remaining_args)
elif command == "merge-preds":
from sweagent.run.merge_predictions import run_from_cli as merge_predictions_main
merge_predictions_main(remaining_args)
elif command in ["inspector", "I"]:
from sweagent.inspector.server import run_from_cli as inspector_main
inspector_main(remaining_args)
elif command in ["inspect", "i"]:
from sweagent.run.inspector_cli import main as inspect_main
inspect_main(remaining_args)
elif command == "extract-pred":
from sweagent.run.extract_pred import run_from_cli as extract_pred_main
extract_pred_main(remaining_args)
elif command in ["compare-runs", "cr"]:
from sweagent.run.compare_runs import run_from_cli as compare_runs_main
compare_runs_main(remaining_args)
elif command in ["remove-unfinished", "ru"]:
from sweagent.run.remove_unfinished import run_from_cli as remove_unfinished_main
remove_unfinished_main(remaining_args)
elif command in ["quick-stats", "qs"]:
from sweagent.run.quick_stats import run_from_cli as quick_stats_main
quick_stats_main(remaining_args)
elif command in ["shell", "sh"]:
from sweagent.run.run_shell import run_from_cli as run_shell_main
run_shell_main(remaining_args)
else:
msg = f"Unknown command: {command}"
raise ValueError(msg)
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,442 @@
"""
Run on a batch of instances/issues, e.g., SWE-bench.
[cyan][bold]=== BASIC OPTIONS ===[/bold][/cyan]
-h --help Show help text and exit
--help_option Print specific help text and exit
[cyan][bold]=== EXAMPLES ===[/bold][/cyan]
Basic usage: Run over a [bold][cyan]SWE-bench lite[/bold][/cyan][green]:
sweagent run-batch \\
--instances.type swe_bench \\ # configure instances
--instances.subset lite \\
--instances.split dev \\
--instances.slice :50 \\ # first 50 instances
--instances.shuffle=True \\ # shuffle instances (with fixed seed)
--config config/default.yaml \\
--agent.model.name gpt-4o # configure model
[/green]
[cyan][bold]=== LOADING INSTANCES ===[/bold][/cyan]
[cyan][bold]From a file[/bold][/cyan] [green]--instances.type file --instances.path /path/to/file[/green].
[cyan][bold]From huggingface[/bold][/cyan] [green]--instances.type huggingface --instances.dataset_name=SWE_Bench_lite --instances.split=dev[/green].
All instance specifications support the [green]filter[/green], [green]slice[/green], and [green]shuffle[/green] options.
With [green]filter[/green], you can select specific instances, e.g., [green]--instances.filter='instance_id_1|instance_id_2'[/green].
"""
import getpass
import json
import logging
import random
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
from pathlib import Path
from typing import Self
import yaml
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from rich.live import Live
from swerex.deployment.hooks.status import SetStatusDeploymentHook
from sweagent import TRAJECTORY_DIR
from sweagent.agent.agents import AgentConfig, get_agent_from_config
from sweagent.agent.hooks.status import SetStatusAgentHook
from sweagent.environment.hooks.status import SetStatusEnvironmentHook
from sweagent.environment.swe_env import SWEEnv
from sweagent.exceptions import ModelConfigurationError, TotalCostLimitExceededError
from sweagent.run._progress import RunBatchProgressManager
from sweagent.run.batch_instances import BatchInstance, BatchInstanceSourceConfig, SWEBenchInstances
from sweagent.run.common import BasicCLI, ConfigHelper, save_predictions
from sweagent.run.hooks.abstract import CombinedRunHooks, RunHook
from sweagent.run.hooks.apply_patch import SaveApplyPatchHook
from sweagent.run.merge_predictions import merge_predictions
from sweagent.run.run_single import RunSingleConfig
from sweagent.types import AgentRunResult
from sweagent.utils.config import load_environment_variables
from sweagent.utils.log import (
add_file_handler,
add_logger_names_to_stream_handlers,
get_logger,
register_thread_name,
remove_file_handler,
set_stream_handler_levels,
)
class RunBatchConfig(BaseSettings, cli_implicit_flags=False):
instances: BatchInstanceSourceConfig = Field(description="Instances to run.")
agent: AgentConfig = Field(description="Agent options.")
output_dir: Path = Field(default=Path("DEFAULT"), description="Output directory.")
suffix: str = ""
"""Suffix to add to the output directory. Only used if `output_dir` is `DEFAULT`."""
raise_exceptions: bool = False
"""Raise exceptions instead of skipping instances."""
redo_existing: bool = False
"""Do not skip instances that already have a trajectory."""
env_var_path: Path | None = None
"""Path to a .env file to load environment variables from."""
num_workers: int = Field(default=1)
"""Number of parallel workers to use."""
random_delay_multiplier: float = 0.3
"""We will wait for a random amount of time between 0 and `random_delay_multiplier`
times the number of workers at the start of each instance. This is to avoid any
potential race condition or issues with bottlenecks, e.g., when running on a platform
with few CPUs that cannot handle the startup of all containers in time.
"""
progress_bar: bool = True
"""Whether to show a progress bar. Progress bar is never shown for human models.
Progress bar is always shown for multi-worker runs.
"""
# pydantic config
model_config = SettingsConfigDict(extra="forbid", env_prefix="SWE_AGENT_")
def set_default_output_dir(self) -> None:
# Needs to be called explicitly, because self._config_files will be setup
# post-init.
if self.output_dir == Path("DEFAULT"):
user_id = getpass.getuser()
source_id = self.instances.id
try:
model_id = self.agent.model.id # type: ignore[attr-defined]
except AttributeError:
model_id = "unknown"
config_file = getattr(self, "_config_files", ["no_config"])[0]
if config_file != "no_config":
config_file = Path(config_file).stem
suffix = f"__{self.suffix}" if self.suffix else ""
self.output_dir = TRAJECTORY_DIR / user_id / f"{config_file}__{model_id}___{source_id}{suffix}"
@model_validator(mode="after")
def evaluate_and_redo_existing(self) -> Self:
if not isinstance(self.instances, SWEBenchInstances):
return self
if self.instances.evaluate and self.redo_existing:
msg = (
"Cannot evaluate and redo existing at the same time. This would cause invalid results, because "
"after the first merge_preds gives you a preds.json, this file would be submitted to SB-CLI, causing"
"evaluation of old instances, which could then not be overwritten by the new ones."
)
raise ValueError(msg)
return self
class _BreakLoop(Exception):
"""Used for internal control flow"""
class RunBatch:
def __init__(
self,
instances: list[BatchInstance],
agent_config: AgentConfig,
*,
output_dir: Path = Path("."),
hooks: list[RunHook] | None = None,
raise_exceptions: bool = False,
redo_existing: bool = False,
num_workers: int = 1,
progress_bar: bool = True,
random_delay_multiplier: float = 0.3,
):
"""Note: When initializing this class, make sure to add the hooks that are required by your actions.
See `from_config` for an example.
Args:
hooks: If not specified, the default hooks will be used.
num_workers: Number of parallel workers to use. Default is 1 (sequential execution).
progress_bar: Whether to show a progress bar. Progress bar is never shown for human models.
Progress bar is always shown for multi-worker runs.
random_delay_multiplier: We will wait for a random amount of time between 0 and `random_delay_multiplier`
times the number of workers at the start of each instance. This is to avoid any
potential race conditions.
"""
if self._model_id in ["human", "human_thought"] and num_workers > 1:
msg = "Cannot run with human model in parallel"
raise ValueError(msg)
self.logger = get_logger("swea-run", emoji="🏃")
add_file_handler(
output_dir / "run_batch.log",
id_="progress",
filter=lambda name: "swea-run" in name or "config" in name,
)
self.instances = instances
self.agent_config = agent_config
self.output_dir = output_dir
self._raise_exceptions = raise_exceptions
self._chooks = CombinedRunHooks()
self._redo_existing = redo_existing
self._num_workers = min(num_workers, len(instances))
for hook in hooks or [SaveApplyPatchHook(show_success_message=False)]:
self.add_hook(hook)
self._progress_manager = RunBatchProgressManager(
num_instances=len(instances), yaml_report_path=output_dir / "run_batch_exit_statuses.yaml"
)
self._show_progress_bar = progress_bar
self._random_delay_multiplier = random_delay_multiplier
@property
def _model_id(self) -> str:
try:
return self.agent_config.model.id # type: ignore[attr-defined]
except AttributeError:
return "unknown"
@classmethod
def from_config(cls, config: RunBatchConfig) -> Self:
load_environment_variables(config.env_var_path)
config.set_default_output_dir()
config.output_dir.mkdir(parents=True, exist_ok=True)
(config.output_dir / "run_batch.config.yaml").write_text(yaml.dump(config.model_dump_json(), indent=2))
logger = get_logger("run", emoji="🏃")
logger.debug("Loading instances from %s", f"{config.instances!r}")
instances = config.instances.get_instance_configs()
logger.info("Loaded %d instances", len(instances))
if not instances:
msg = (
"No instances to run. Here are a few things to check:\n"
"- With huggingface data: Check that you have the right split (test or dev)\n"
"- Check your filter does not exclude all instances (check the info log messages)"
)
raise ValueError(msg)
logger.debug("The first instance is %s", f"{instances[0]!r}")
rb = cls(
instances=instances,
agent_config=config.agent,
output_dir=config.output_dir,
raise_exceptions=config.raise_exceptions,
redo_existing=config.redo_existing,
num_workers=config.num_workers,
progress_bar=config.progress_bar,
random_delay_multiplier=config.random_delay_multiplier,
)
if isinstance(config.instances, SWEBenchInstances) and config.instances.evaluate:
from sweagent.run.hooks.swe_bench_evaluate import SweBenchEvaluate
rb.add_hook(
SweBenchEvaluate(
output_dir=config.output_dir,
subset=config.instances.subset,
split=config.instances.split,
continuous_submission_every=30,
)
)
return rb
def add_hook(self, hook: RunHook) -> None:
hook.on_init(run=self)
self._chooks.add_hook(hook)
def main(self) -> None:
self.logger.info("Starting run. Find output files at %s", self.output_dir)
self._chooks.on_start()
if self._num_workers <= 1:
self.main_single_worker()
else:
self.main_multi_worker()
output_dirs = []
for instance in self.instances:
output_dirs.append(self.output_dir / instance.problem_statement.id)
merge_predictions(output_dirs, self.output_dir / "preds.json")
self._chooks.on_end()
def main_single_worker(self) -> None:
with ExitStack() as stack:
# Conditionally add progress bar
if self._model_id not in ["human", "human_thought"] and self._show_progress_bar:
stack.enter_context(Live(self._progress_manager.render_group))
for instance in self.instances:
try:
self.run_instance(instance)
except _BreakLoop:
self.logger.info("Stopping loop over instances")
break
def main_multi_worker(self) -> None:
add_logger_names_to_stream_handlers()
# Set all stream handlers to WARNING and set everything where we want to have
# more verbosity explicitly
set_stream_handler_levels(logging.WARNING)
self.logger.setLevel(logging.TRACE) # type: ignore
with Live(self._progress_manager.render_group):
with ThreadPoolExecutor(max_workers=self._num_workers) as executor:
futures = [executor.submit(self.run_instance, instance) for instance in self.instances]
try:
for future in as_completed(futures):
future.result()
except (KeyboardInterrupt, _BreakLoop):
msg = (
"Received keyboard interrupt, waiting for running instances "
"to finish, but cancelled everything else"
)
self.logger.info(msg)
executor.shutdown(wait=False, cancel_futures=True)
finally:
self._progress_manager.print_report()
def run_instance(self, instance: BatchInstance) -> None:
self.logger.info("Running on instance %s", instance.problem_statement.id)
register_thread_name(instance.problem_statement.id)
self._add_instance_log_file_handlers(instance.problem_statement.id, multi_worker=self._num_workers > 1)
# Let's add some randomness to avoid any potential race conditions or thundering herd
if self._progress_manager.n_completed < self._num_workers:
time.sleep(random.random() * self._random_delay_multiplier * (self._num_workers - 1))
self._progress_manager.on_instance_start(instance.problem_statement.id)
if previous_exit_status := self.should_skip(instance):
self._progress_manager.on_instance_end(
instance.problem_statement.id, exit_status=f"skipped ({previous_exit_status})"
)
self._remove_instance_log_file_handlers(instance.problem_statement.id)
return
# Either catch and silence exception, or raise _BreakLoop to stop the loop
# over the instances
try:
result = self._run_instance(instance)
except KeyboardInterrupt:
raise _BreakLoop
except (SystemExit, ModelConfigurationError, TotalCostLimitExceededError) as e:
if self._raise_exceptions:
raise
self.logger.critical(f"❌ Exiting because {e.__class__.__name__} was called")
raise _BreakLoop
except Exception as e:
self.logger.error(traceback.format_exc())
self.logger.error(f"❌ Failed on {instance.problem_statement.id}: {e}")
self._progress_manager.on_uncaught_exception(instance.problem_statement.id, e)
if self._raise_exceptions:
raise
else:
self._progress_manager.on_instance_end(
instance.problem_statement.id, exit_status=result.info.get("exit_status", "unknown_exit")
)
finally:
self._progress_manager.update_exit_status_table()
self._remove_instance_log_file_handlers(instance.problem_statement.id)
def _run_instance(self, instance: BatchInstance) -> AgentRunResult:
output_dir = Path(self.output_dir) / instance.problem_statement.id
output_dir.mkdir(parents=True, exist_ok=True)
self.agent_config.name = f"{instance.problem_statement.id}"
agent = get_agent_from_config(self.agent_config)
single_run_replay_config = RunSingleConfig(
agent=self.agent_config,
problem_statement=instance.problem_statement,
env=instance.env,
)
(output_dir / f"{instance.problem_statement.id}.config.yaml").write_text(
yaml.dump(single_run_replay_config.model_dump_json(), indent=2)
)
agent.replay_config = single_run_replay_config # type: ignore[attr-defined]
agent.add_hook(SetStatusAgentHook(instance.problem_statement.id, self._progress_manager.update_instance_status))
self._progress_manager.update_instance_status(instance.problem_statement.id, "Starting environment")
instance.env.name = f"{instance.problem_statement.id}"
env = SWEEnv.from_config(instance.env)
env.add_hook(
SetStatusEnvironmentHook(instance.problem_statement.id, self._progress_manager.update_instance_status)
)
env.deployment.add_hook(
SetStatusDeploymentHook(instance.problem_statement.id, self._progress_manager.update_instance_status)
)
try:
env.start()
self._chooks.on_instance_start(index=0, env=env, problem_statement=instance.problem_statement)
result = agent.run(
problem_statement=instance.problem_statement,
env=env,
output_dir=output_dir,
)
except Exception:
# The actual handling is happening in `run_instance`, but we need to make sure that
# we log it to the agent specific logger as well
agent.logger.error(traceback.format_exc()) # type: ignore[attr-defined]
raise
finally:
env.close()
save_predictions(self.output_dir, instance.problem_statement.id, result)
self._chooks.on_instance_completed(result=result)
return result
def should_skip(self, instance: BatchInstance) -> bool | str:
"""Check if we should skip this instance.
Returns previous exit status if the instance should be skipped.
"""
if self._redo_existing:
return False
# Check if there's an existing trajectory for this instance
log_path = self.output_dir / instance.problem_statement.id / (instance.problem_statement.id + ".traj")
if not log_path.exists():
return False
content = log_path.read_text()
if not content.strip():
self.logger.warning("Found empty trajectory: %s. Removing.", log_path)
log_path.unlink()
return False
try:
data = json.loads(content)
# If the trajectory has no exit status, it's incomplete and we will redo it
exit_status = data["info"].get("exit_status", None)
if exit_status == "early_exit" or exit_status is None:
self.logger.warning(f"Found existing trajectory with no exit status: {log_path}. Removing.")
log_path.unlink()
return False
except Exception as e:
self.logger.error(f"Failed to check existing trajectory: {log_path}: {e}. Removing.")
# If we can't check the trajectory, we will redo it
log_path.unlink()
return False
# otherwise, we will skip it
self.logger.info(f"⏭️ Skipping existing trajectory: {log_path}")
return exit_status
def _add_instance_log_file_handlers(self, instance_id: str, multi_worker: bool = False) -> None:
filename_template = f"{instance_id}.{{level}}.log"
for level in ["trace", "debug", "info"]:
filter = instance_id if multi_worker else ""
add_file_handler(
self.output_dir / instance_id / filename_template.format(level=level),
filter=filter,
level=level,
id_=f"{instance_id}-{level}",
)
def _remove_instance_log_file_handlers(self, instance_id: str) -> None:
for level in ["trace", "debug", "info"]:
remove_file_handler(f"{instance_id}-{level}")
def run_from_config(config: RunBatchConfig):
RunBatch.from_config(config).main()
def run_from_cli(args: list[str] | None = None):
if args is None:
args = sys.argv[1:]
assert __doc__ is not None
help_text = ( # type: ignore
__doc__ + "\n[cyan][bold]=== ALL THE OPTIONS ===[/bold][/cyan]\n\n" + ConfigHelper().get_help(RunBatchConfig)
)
run_from_config(BasicCLI(RunBatchConfig, help_text=help_text).get_config(args)) # type: ignore
if __name__ == "__main__":
run_from_cli()

View File

@@ -0,0 +1,219 @@
"""[cyan][bold]Replay a trajectory file.[/bold][/cyan]
[cyan][bold]=== DESCRIPTION ===[/bold][/cyan]
We will take all actions in the trajectory and execute them in an environment.
This has two main use cases:
1. Create a demo from a yaml file containing actions (can also be created from a trajectory file with [green]sweagent run traj-to-demo[/green]).
[green]run-replay[/green] will execute the actions to get the environment output and produce a full trajectory to be used as a demo.
2. Debugging and testing of tools and environment behavior.
[cyan][bold]=== EXAMPLES ===[/bold][/cyan]
Replay a trajectory file:
[green]sweagent run replay --traj_path mytraj.traj[/green]
Replay a demo file:
[green]sweagent run replay --traj_path mydemo.demo.yaml[/green]
"""
import json
import sys
import tempfile
from getpass import getuser
from pathlib import Path
from typing import Any
import yaml
from pydantic_settings import BaseSettings, SettingsConfigDict
from swerex.deployment.abstract import AbstractDeployment
from swerex.deployment.config import DeploymentConfig, get_deployment
from typing_extensions import Self
from sweagent.agent.agents import DefaultAgent
from sweagent.agent.models import ReplayModelConfig
from sweagent.environment.swe_env import SWEEnv
from sweagent.run.common import BasicCLI, ConfigHelper
from sweagent.run.run_single import RunSingle, RunSingleConfig
from sweagent.utils.config import load_environment_variables
from sweagent.utils.log import get_logger
class RunReplayConfig(BaseSettings, cli_implicit_flags=False):
traj_path: Path
deployment: DeploymentConfig | None = None
"""Override the deployment in the trajectory."""
output_dir: Path = Path("DEFAULT")
env_var_path: Path | None = None
"""Path to a .env file to load environment variables from."""
update_config: list[Path] = []
"""Additional config files to merge with the replay config."""
# pydantic config
model_config = SettingsConfigDict(extra="forbid", env_prefix="SWE_AGENT_")
def model_post_init(self, __context: Any) -> None:
if self.output_dir == Path("DEFAULT"):
user_id = getuser()
self.output_dir = Path.cwd() / "trajectories" / user_id / f"replay___{self.traj_path.stem}"
self.output_dir.mkdir(parents=True, exist_ok=True)
class RunReplay:
def __init__(
self,
*,
traj_path: Path,
deployment: AbstractDeployment | None,
output_dir: Path,
update_config: list[Path] | None = None,
_catch_errors: bool = False,
_require_zero_exit_code: bool = False,
):
self.traj_path = traj_path
self.output_dir = output_dir
self._replay_action_trajs_path = Path(tempfile.NamedTemporaryFile(suffix=".json").name)
self.logger = get_logger("swea-run", emoji="🏃")
self._catch_errors = _catch_errors
self._require_zero_exit_code = _require_zero_exit_code
self._update_config = update_config if update_config is not None else []
if traj_path.suffix == ".yaml":
self._traj_data = yaml.safe_load(traj_path.read_text())
else:
self._traj_data = json.loads(traj_path.read_text())
self.config = self._get_config_from_agent(self._traj_data)
if deployment is None:
self.deployment = get_deployment(self.config.env.deployment)
else:
self.deployment = deployment
def _get_config_from_agent(self, traj_data):
try:
if isinstance(traj_data["replay_config"], str):
traj_data["replay_config"] = json.loads(traj_data["replay_config"])
config = RunSingleConfig.model_validate(traj_data["replay_config"])
except KeyError:
msg = "Replay config not found in trajectory. Are you running on an old trajectory?"
raise ValueError(msg)
# Merge any additional config files
for config_path in self._update_config:
update_data = yaml.safe_load(config_path.read_text())
# Store the current model config before merging
current_model = config.agent.model
# Convert the merged data back to a RunSingleConfig
config_dict = config.model_dump(mode="json")
merged_dict = config_dict | update_data
# Ensure agent.model is preserved if not explicitly updated
if "agent" in merged_dict and "model" not in merged_dict["agent"]:
merged_dict["agent"]["model"] = current_model.model_dump(mode="json")
config = RunSingleConfig.model_validate(merged_dict)
config.agent.model = ReplayModelConfig(replay_path=self._replay_action_trajs_path)
return config
@property
def instance_id(self) -> str:
return Path(self.traj_path).stem
@classmethod
def from_config(cls, config: RunReplayConfig, **kwargs) -> Self:
load_environment_variables(config.env_var_path)
return cls(
traj_path=config.traj_path,
deployment=get_deployment(config.deployment) if config.deployment else None,
output_dir=config.output_dir,
update_config=config.update_config,
**kwargs,
)
def _create_actions_file(self) -> None:
# Verify config compatibility with tool calls
has_tool_calls = any(
"tool_calls" in item and item["tool_calls"] is not None
for item in self._traj_data["history"]
if item["role"] == "assistant"
)
agent_config = self.config.agent
parse_function = agent_config.tools.parse_function.type
use_function_calling = parse_function == "function_calling"
if has_tool_calls and not use_function_calling:
msg = (
"Trajectory contains tool calls but config is not set up for function calling. "
"Check that the config you want to use has agent.tools.parse_function.type set to 'function_calling'."
)
raise ValueError(msg)
actions = []
for ix, item in enumerate(self._traj_data["history"]):
if item["role"] != "assistant":
continue
action = {"message": item["content"]}
if use_function_calling:
assert "tool_calls" in item and item["tool_calls"] is not None, (
f"Config is set to use `function_calling` but trajectory item {ix} is missing a tool call "
f"or has tool_calls set to None"
)
action["tool_calls"] = item["tool_calls"]
actions.append(action)
if len(actions) == 0:
msg = "No actions found in trajectory"
raise ValueError(msg)
self._replay_action_trajs_path.write_text(json.dumps({self.instance_id: actions}))
def _get_env(self) -> SWEEnv:
return SWEEnv(
deployment=self.deployment,
repo=self.config.env.repo,
post_startup_commands=[],
)
def _get_agent(self) -> DefaultAgent:
agent = DefaultAgent.from_config(self.config.agent)
agent._catch_errors = self._catch_errors
agent._always_require_zero_exit_code = self._require_zero_exit_code
return agent
def _get_run_single(self) -> RunSingle:
return RunSingle(
self._get_env(),
self._get_agent(),
problem_statement=self.config.problem_statement,
output_dir=Path(self.output_dir),
)
def main(self):
self._create_actions_file()
run_single = self._get_run_single()
run_single.agent.replay_config = RunSingleConfig(
agent=self.config.agent,
problem_statement=run_single.problem_statement,
env=self.config.env,
)
run_single.run()
def run_from_config(config: RunReplayConfig):
RunReplay.from_config(config).main()
def run_from_cli(args: list[str] | None = None):
if args is None:
args = sys.argv[1:]
help_text = ( # type: ignore
__doc__ + "\n[cyan][bold]=== ALL THE OPTIONS ===[/bold][/cyan]\n\n" + ConfigHelper().get_help(RunReplayConfig)
)
run_from_config(BasicCLI(RunReplayConfig, help_text=help_text, default_settings=False).get_config(args)) # type: ignore
if __name__ == "__main__":
run_from_cli()

View File

@@ -0,0 +1,155 @@
"""[cyan][bold]Run SWE-agent in semi-interactive mode.[/bold][/cyan]
[cyan][bold]sweagen-sh is EXPERIMENTAL[/bold][/cyan]
[cyan][bold]=== BASIC OPTIONS ===[/bold][/cyan]
-h --help Show help text and exit
--help_option Print specific help text and exit
--config CONFIG Load additional config files. Use this option multiple times to load
multiple files, e.g., --config config1.yaml --config config2.yaml
"""
import argparse
import logging
from pathlib import Path
import yaml
from rich.prompt import Prompt
from swerex.deployment.config import DockerDeploymentConfig
from sweagent import CONFIG_DIR
from sweagent.agent.agents import AbstractAgent, ShellAgentConfig
from sweagent.agent.extra.shell_agent import ShellAgent
from sweagent.agent.problem_statement import (
GithubIssue,
ProblemStatement,
ProblemStatementConfig,
TextProblemStatement,
)
from sweagent.environment.repo import PreExistingRepoConfig
from sweagent.environment.swe_env import EnvironmentConfig, SWEEnv
from sweagent.run.common import save_predictions
from sweagent.run.hooks.abstract import CombinedRunHooks, RunHook
from sweagent.utils.config import load_environment_variables
from sweagent.utils.github import _is_github_issue_url
from sweagent.utils.log import add_file_handler, get_logger, set_stream_handler_levels
class RunShell:
def __init__(
self,
env: SWEEnv,
agent: AbstractAgent,
problem_statement: ProblemStatement | ProblemStatementConfig,
*,
output_dir: Path = Path("."),
hooks: list[RunHook] | None = None,
):
"""Note: When initializing this class, make sure to add the hooks that are required by your actions.
See `from_config` for an example.
"""
self.logger = get_logger("swea-run", emoji="🏃")
instance_id = problem_statement.id
_log_filename_template = f"{instance_id}.{{level}}.log"
for level in ["trace", "debug", "info"]:
add_file_handler(
output_dir / instance_id / _log_filename_template.format(level=level),
level=level,
id_=f"{instance_id}-{level}",
)
self.env = env
self.agent = agent
self.output_dir = output_dir
self._hooks = []
self._chooks = CombinedRunHooks()
self.problem_statement = problem_statement
for hook in hooks or []:
self.add_hook(hook)
@property
def hooks(self) -> list[RunHook]:
return self._chooks.hooks
def add_hook(self, hook: RunHook) -> None:
hook.on_init(run=self)
self._chooks.add_hook(hook)
def run(self):
self._chooks.on_start()
self.logger.info("Starting environment")
self.env.start()
self.logger.info("Running agent")
self._chooks.on_instance_start(index=0, env=self.env, problem_statement=self.problem_statement)
output_dir = self.output_dir / self.problem_statement.id
output_dir.mkdir(parents=True, exist_ok=True)
result = self.agent.run(
problem_statement=self.problem_statement,
env=self.env,
output_dir=output_dir,
)
self._chooks.on_instance_completed(result=result)
self.logger.info("Done")
self._chooks.on_end()
save_predictions(self.output_dir, self.problem_statement.id, result)
self.env.close()
def get_cli():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("-r", "--repo", type=Path, help="Path to the repository.", default=None)
# parser.add_argument(dest="--model", type=str, help="Model to use.", default="claude-sonnet-4-20250514")
parser.add_argument(
"--config",
type=Path,
help="Path to the agent config file.",
default=CONFIG_DIR / "exotic" / "default_shell.yaml",
)
parser.add_argument(
"-p",
type=str,
help="Problem statement.",
default="",
)
return parser
def run_from_cli(args: list[str] | None = None):
set_stream_handler_levels(logging.INFO)
cli_args = get_cli().parse_args(args)
try:
load_environment_variables(Path(".env"))
except FileNotFoundError:
print("Env file .env not found, please set API key as env variables.")
env_config = EnvironmentConfig(
repo=PreExistingRepoConfig(repo_name="repo", reset=False),
deployment=DockerDeploymentConfig(
image="python:3.11",
docker_args=[
"-v",
f"{cli_args.repo}:/repo",
],
python_standalone_dir="/root",
),
)
agent_config = ShellAgentConfig.model_validate(yaml.safe_load(cli_args.config.read_text())["agent"])
agent = ShellAgent.from_config(agent_config)
env = SWEEnv.from_config(env_config)
if cli_args.repo is None:
cli_args.repo = Path(Prompt.ask("[cyan]Repository path[/cyan]", default="", show_default=False))
problem_input = cli_args.p
if not problem_input:
problem_input = Prompt.ask("[cyan]Problem statement or GitHub issue URL[/cyan]", default="", show_default=False)
if _is_github_issue_url(problem_input):
problem_statement = GithubIssue(github_url=problem_input)
else:
problem_statement = TextProblemStatement(
text=problem_input,
)
run_shell = RunShell(env, agent, problem_statement=problem_statement, output_dir=Path.home() / "sweagent_shell")
run_shell.run()
if __name__ == "__main__":
run_from_cli()

View File

@@ -0,0 +1,225 @@
"""[cyan][bold]Run SWE-agent on a single instance taken from github or similar.[/bold][/cyan]
[cyan][bold]=== BASIC OPTIONS ===[/bold][/cyan]
-h --help Show help text and exit
--help_option Print specific help text and exit
--config CONFIG Load additional config files. Use this option multiple times to load
multiple files, e.g., --config config1.yaml --config config2.yaml
[cyan][bold]=== EXAMPLES ===[/bold][/cyan]
Basic usage: Run over a [bold][cyan]github issue[/bold][/cyan][green]:
sweagent run --config config/default.yaml --agent.model.name "gpt-4o" \\
--env.repo.github_url=https://github.com/SWE-agent/test-repo/ \\
--problem_statement.github_url=https://github.com/SWE-agent/test-repo/issues/1
[/green]
By default this will start a docker container and run the agent in there.
You can set the image with [green]--env.docker.image[/green].
Here's an example that uses [bold][cyan]modal[/bold][/cyan] instead of docker and also a [bold][cyan]local repository[/bold][/cyan]:
[green]sweagent run --config config/default.yaml --agent.model.name "gpt-4o" \\
--env.deployment.type=modal --env.repo.path /path/to/repo \\
--problem_statement.path=path/to/problem_statement.md
[/green]
"""
import getpass
import sys
from pathlib import Path
from typing import Self
import yaml
from pydantic import BaseModel, ConfigDict, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from sweagent.agent.agents import AbstractAgent, AgentConfig, get_agent_from_config
from sweagent.agent.problem_statement import (
EmptyProblemStatement,
ProblemStatement,
ProblemStatementConfig,
)
from sweagent.environment.swe_env import EnvironmentConfig, SWEEnv
from sweagent.run.common import AutoCorrectSuggestion as ACS
from sweagent.run.common import BasicCLI, ConfigHelper, save_predictions
from sweagent.run.hooks.abstract import CombinedRunHooks, RunHook
from sweagent.run.hooks.apply_patch import SaveApplyPatchHook
from sweagent.run.hooks.open_pr import OpenPRConfig, OpenPRHook
from sweagent.utils.config import load_environment_variables
from sweagent.utils.log import add_file_handler, get_logger
class RunSingleActionConfig(BaseModel):
"""Run real-life actions (opening PRs, etc.) if we can solve the issue."""
# Open a PR with the patch if we can solve the issue
open_pr: bool = False
pr_config: OpenPRConfig = Field(default_factory=OpenPRConfig)
# When working with local repository: Apply patch
apply_patch_locally: bool = False
# pydantic config
model_config = ConfigDict(extra="forbid")
def _get_default_output_dir(output_dir: Path, problem_statement: ProblemStatement, agent: AgentConfig) -> Path:
if output_dir == Path("DEFAULT"):
user_id = getpass.getuser()
problem_id = problem_statement.id
try:
model_id = agent.model.id # type: ignore[attr-defined]
except AttributeError:
model_id = "unknown_model"
config_file = getattr(agent, "_config_files", ["no_config"])[0]
if isinstance(config_file, Path):
config_file = config_file.stem
return Path.cwd() / "trajectories" / user_id / f"{config_file}__{model_id}___{problem_id}"
return output_dir
class RunSingleConfig(BaseSettings, cli_implicit_flags=False):
env: EnvironmentConfig = Field(default_factory=EnvironmentConfig, description="Environment options.")
agent: AgentConfig = Field(description="Agent options.")
problem_statement: ProblemStatementConfig = Field(
default_factory=EmptyProblemStatement, description="Problem statement options."
)
output_dir: Path = Field(default=Path("DEFAULT"), description="Output directory.")
actions: RunSingleActionConfig = Field(default_factory=RunSingleActionConfig)
env_var_path: Path | None = None
"""Path to a .env file to load environment variables from."""
# pydantic config
model_config = SettingsConfigDict(extra="forbid", env_prefix="SWE_AGENT_")
def set_default_output_dir(self) -> None:
# Needs to be called explicitly, because self._config_files will be setup
# post-init.
self.output_dir = _get_default_output_dir(self.output_dir, self.problem_statement, self.agent)
@classmethod
def _get_auto_correct(cls) -> list[ACS]:
return [
ACS("model", "agent.model.name"),
ACS("agent.model", "agent.model.name"),
ACS("model.name", "agent.model.name"),
ACS("per_instance_cost_limit", "agent.model.per_instance_cost_limit"),
ACS("model.per_instance_cost_limit", "agent.model.per_instance_cost_limit"),
ACS("config_file", "config"),
ACS(
"data_path",
help="--data_path is no longer support for SWE-A 1.0. Please check the tutorial and use one of the --problem_statement options, e.g., --problem_statement.github_url or --problem_statement.path",
),
ACS(
"repo_path",
help="--repo_path is no longer support for SWE-A 1.0. Please check the tutorial and use one of the --env.repo options, e.g., --env.repo.github_url or --env.repo.path",
),
ACS("repo.path", "env.repo.path"),
]
class RunSingle:
def __init__(
self,
env: SWEEnv,
agent: AbstractAgent,
problem_statement: ProblemStatement | ProblemStatementConfig,
*,
output_dir: Path = Path("."),
hooks: list[RunHook] | None = None,
actions: RunSingleActionConfig | None = None,
):
"""Note: When initializing this class, make sure to add the hooks that are required by your actions.
See `from_config` for an example.
"""
self.logger = get_logger("swea-run", emoji="🏃")
instance_id = problem_statement.id
_log_filename_template = f"{instance_id}.{{level}}.log"
for level in ["trace", "debug", "info"]:
add_file_handler(
output_dir / instance_id / _log_filename_template.format(level=level),
level=level,
id_=f"{instance_id}-{level}",
)
self.env = env
self.agent = agent
self.output_dir = output_dir
self._hooks = []
if actions is not None:
actions = RunSingleActionConfig()
self.actions = actions
self._chooks = CombinedRunHooks()
self.problem_statement = problem_statement
for hook in hooks or []:
self.add_hook(hook)
@property
def hooks(self) -> list[RunHook]:
return self._chooks.hooks
@classmethod
def from_config(cls, config: RunSingleConfig) -> Self:
load_environment_variables(config.env_var_path)
config.set_default_output_dir()
config.output_dir.mkdir(parents=True, exist_ok=True)
agent = get_agent_from_config(config.agent)
agent.replay_config = config # type: ignore[attr-defined]
self = cls(
env=SWEEnv.from_config(config.env),
agent=agent,
problem_statement=config.problem_statement,
output_dir=config.output_dir,
actions=config.actions,
)
self.add_hook(SaveApplyPatchHook(apply_patch_locally=config.actions.apply_patch_locally))
if config.actions.open_pr:
self.logger.debug("Adding OpenPRHook")
self.add_hook(OpenPRHook(config.actions.pr_config))
return self
def add_hook(self, hook: RunHook) -> None:
hook.on_init(run=self)
self._chooks.add_hook(hook)
def run(self):
self._chooks.on_start()
self.logger.info("Starting environment")
self.env.start()
self.logger.info("Running agent")
self._chooks.on_instance_start(index=0, env=self.env, problem_statement=self.problem_statement)
output_dir = self.output_dir / self.problem_statement.id
output_dir.mkdir(parents=True, exist_ok=True)
if self.agent.replay_config is not None: # type: ignore[attr-defined]
(output_dir / "config.yaml").write_text(yaml.dump(self.agent.replay_config.model_dump_json(), indent=2)) # type: ignore[attr-defined]
result = self.agent.run(
problem_statement=self.problem_statement,
env=self.env,
output_dir=output_dir,
)
self._chooks.on_instance_completed(result=result)
self.logger.info("Done")
self._chooks.on_end()
save_predictions(self.output_dir, self.problem_statement.id, result)
self.env.close()
def run_from_config(config: RunSingleConfig):
RunSingle.from_config(config).run()
def run_from_cli(args: list[str] | None = None):
if args is None:
args = sys.argv[1:]
assert __doc__ is not None
help_text = ( # type: ignore
__doc__ + "\n[cyan][bold]=== ALL THE OPTIONS ===[/bold][/cyan]\n\n" + ConfigHelper().get_help(RunSingleConfig)
)
run_from_config(BasicCLI(RunSingleConfig, help_text=help_text).get_config(args)) # type: ignore
if __name__ == "__main__":
run_from_cli()

View File

@@ -0,0 +1,85 @@
"""Convert a trajectory file to a yaml file for editing of demos.
You can then load the yaml file with `run_replay.py` to replay the actions in an environment to get
environment output.
"""
from __future__ import annotations
import json
from argparse import ArgumentParser
from pathlib import Path
from sweagent.utils.log import get_logger
from sweagent.utils.serialization import _yaml_serialization_with_linebreaks
logger = get_logger("traj2demo")
DEMO_COMMENT = """# This is a demo file generated from trajectory file:
# {traj_path}
# You can use this demo file to replay the actions in the trajectory with run_replay.py.
# You can edit the content of the actions in this file to modify the replay behavior.
# NOTICE:
# Only the actions of the assistant will be replayed.
# You do not need to modify the observation's contents or any other fields.
# You can add or remove actions to modify the replay behavior."""
def save_demo(data: str | dict | list, file: Path, traj_path: Path) -> None:
"""Save demo data as a yaml file. Takes care of multi-line strings and adds a header."""
content = _yaml_serialization_with_linebreaks(data)
header = DEMO_COMMENT.format(traj_path=str(traj_path))
with open(file, "w") as f:
f.write(f"{header}\n{content}")
def convert_traj_to_action_demo(traj_path: Path, output_file: Path, include_user: bool = False) -> None:
with open(traj_path) as file:
traj = json.load(file)
replay_config = traj["replay_config"]
if isinstance(traj["replay_config"], str):
replay_config = json.loads(traj["replay_config"])
history = traj["history"]
copy_fields = {"content", "role", "tool_calls", "agent", "message_type", "tool_call_ids"}
admissible_roles = {"assistant", "user", "tool"} if include_user else {"assistant"}
filtered_history = [
{k: v for k, v in step.items() if k in copy_fields}
for step in history
if step["role"] in admissible_roles
and step.get("agent", "main") in {"main", "primary"}
and not step.get("is_demo")
]
output_data = {"history": filtered_history, "replay_config": replay_config}
save_demo(output_data, output_file, traj_path)
logger.info(f"Saved demo to {output_file}")
def main(traj_path: Path, output_dir: Path, suffix: str = "", overwrite: bool = False, include_user: bool = False):
output_file = output_dir / (traj_path.parent.name + suffix) / (traj_path.stem.removesuffix(".traj") + ".demo.yaml")
if output_file.exists() and not overwrite:
msg = f"Output file already exists: {output_file}. Use --overwrite to overwrite."
raise FileExistsError(msg)
output_file.parent.mkdir(parents=True, exist_ok=True)
convert_traj_to_action_demo(traj_path, output_file, include_user)
def run_from_cli(args: list[str] | None = None):
"""Convert a trajectory file to a demo file."""
parser = ArgumentParser(description=__doc__)
parser.add_argument("traj_path", type=Path, help="Path to trajectory file")
parser.add_argument("--output_dir", type=Path, help="Output directory for action demos", default=Path("./demos"))
parser.add_argument("--suffix", type=str, help="Suffix for the output file", default="")
parser.add_argument("--overwrite", help="Overwrite existing files", action="store_true")
parser.add_argument(
"--include_user",
help="Include user responses (computer)",
action="store_true",
)
parsed_args = parser.parse_args(args)
main(**vars(parsed_args))
if __name__ == "__main__":
run_from_cli()

View File

View File

@@ -0,0 +1,57 @@
from __future__ import annotations
from pathlib import Path
import yaml
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from sweagent.tools.commands import Command
from sweagent.utils.config import _convert_path_to_abspath
class BundleConfig(BaseModel):
tools: dict[str, dict]
state_command: str | None = None
class Bundle(BaseModel):
path: Path
hidden_tools: list[str] = Field(default_factory=list)
_config: BundleConfig = PrivateAttr(default=None)
@model_validator(mode="after")
def validate_tools(self):
self.path = _convert_path_to_abspath(self.path)
if not self.path.exists():
msg = f"Bundle path '{self.path}' does not exist."
raise ValueError(msg)
config_path = self.path / "config.yaml"
if not config_path.exists():
msg = f"Bundle config file '{config_path}' does not exist."
raise ValueError(msg)
config_data = yaml.safe_load(config_path.read_text())
self._config = BundleConfig(**config_data)
invalid_hidden_tools = set(self.hidden_tools) - set(self._config.tools.keys())
if invalid_hidden_tools:
msg = f"Hidden tools {invalid_hidden_tools} do not exist in available tools"
raise ValueError(msg)
return self
@property
def state_command(self) -> str | None:
return self.config.state_command
@property
def config(self) -> BundleConfig:
return self._config
@property
def commands(self) -> list[Command]:
return [
Command(name=tool, **tool_config.model_dump() if isinstance(tool_config, Command) else tool_config)
for tool, tool_config in self.config.tools.items()
if tool not in self.hidden_tools
]

View File

@@ -0,0 +1,220 @@
"""
Core module for defining and parsing commands in the SWE Agent system.
This module provides the foundational classes and utilities for defining commands that can be executed by the agent.
It is used extensively by:
- tools.py: For command installation, execution and environment management
- parsing.py: For parsing model outputs into executable commands
- utils.py: For handling multi-line commands and argument quoting
Key Classes:
- Command: Represents an executable command with arguments and documentation
- Argument: Defines an argument that can be passed to a command
The module supports both simple bash commands and complex multi-line commands with typed arguments.
Commands can be defined either in bash scripts with YAML docstrings or as bash functions.
"""
from __future__ import annotations
import re
import string
from collections import Counter
from functools import cached_property
from pydantic import BaseModel, field_validator, model_validator
from sweagent.utils.jinja_warnings import _warn_probably_wrong_jinja_syntax
ARGUMENT_NAME_PATTERN = r"[a-zA-Z_][a-zA-Z0-9_-]*"
def _extract_keys(format_string: str) -> set[str]:
"""Given a format string, returns a set of all the keys in the format string.
Used for validating that command signatures match their argument definitions.
Args:
format_string: A Python format string containing named fields
Returns:
Set of field names found in the format string
"""
formatter = string.Formatter()
keys = set()
for _, field_name, _, _ in formatter.parse(format_string):
if field_name is not None:
keys.add(field_name)
return keys
class Argument(BaseModel):
f"""Defines an argument that can be passed to a command.
Attributes:
name: The argument name, must match {ARGUMENT_NAME_PATTERN!r}
type: The argument type (e.g. "string", "integer")
description: Human readable description of the argument
required: Whether this argument must be provided
enum: Optional list of allowed values
argument_format: Format string for how to render the argument value in the command
"""
name: str
type: str
items: dict[str, str] | None = None
description: str
required: bool
enum: list[str] | None = None
argument_format: str = "{{value}}"
"""How to invoke the argument in the command. Make sure to use jinja syntax ({{value}}) instead of {value})."""
@field_validator("argument_format")
def validate_argument_format(cls, value: str) -> str:
_warn_probably_wrong_jinja_syntax(value)
return value
class Command(BaseModel):
"""Represents an executable command with arguments and documentation.
A command can be either a simple bash command or a multi-line command terminated by an end marker.
Attributes:
name: The command name
docstring: Human readable description of what the command does
signature: Optional custom signature override
end_name: For multi-line commands, the terminating marker
arguments: List of arguments accepted by the command
Properties:
invoke_format: Format string for constructing the full command invocation
"""
name: str
docstring: str | None
signature: str | None = None
# if there is an end_name, then it is a multi-line command
end_name: str | None = None
arguments: list[Argument] = []
@cached_property
def invoke_format(self) -> str:
"""Gets the format string for invoking this command with arguments.
Returns either the custom signature with argument placeholders replaced,
or a default format of "command arg1 arg2 ...".
"""
if self.signature:
# First validate that all arguments are present in the original signature
for arg in self.arguments:
if not (
f"<{arg.name}>" in self.signature
or f"[<{arg.name}>]" in self.signature
or f"{{{arg.name}}}" in self.signature
or f"--{arg.name}" in self.signature
):
msg = (
f"Missing argument {arg.name} in signature: {self.signature}. Did you format the signature correctly? "
f"You must include all argument names in the signature with <{arg.name}>, [<{arg.name}>], {{{arg.name}}}, or --{arg.name} notation."
)
raise ValueError(msg)
# Then do the replacement
return re.sub(rf"\[?<({ARGUMENT_NAME_PATTERN})>\]?", r"{\1}", self.signature)
else:
# cmd arg_format_1 arg_format_2 ...
_invoke_format = f"{self.name} "
for arg in self.arguments:
_invoke_format += f"{{{arg.name}}} "
return _invoke_format
def get_function_calling_tool(self) -> dict:
"""Converts this command into an OpenAI function calling tool definition.
Returns:
Dict containing the OpenAI function schema for this command
"""
tool = {
"type": "function",
"function": {
"name": self.name,
"description": self.docstring or "",
},
}
properties = {}
required = []
if self.arguments:
for arg in self.arguments:
properties[arg.name] = {"type": arg.type, "description": arg.description}
if arg.items:
properties[arg.name]["items"] = arg.items
if arg.required:
required.append(arg.name)
# Handle enum if present
if arg.enum:
properties[arg.name]["enum"] = arg.enum
tool["function"]["parameters"] = {"type": "object", "properties": properties, "required": required}
return tool
@model_validator(mode="after")
def validate_arguments(self) -> Command:
"""Validates command argument configuration.
Checks:
- Required arguments come before optional ones
- Argument names are unique
- Argument names match the pattern
- Arguments match the signature
Returns:
The validated Command instance
Raises:
ValueError: If validation fails
"""
if not self.arguments:
return self
found_optional = False
for arg in self.arguments:
if found_optional and arg.required:
msg = f"Command '{self.name}': Required argument '{arg.name}' cannot come after optional arguments"
raise ValueError(msg)
if not arg.required:
found_optional = True
name_counts = Counter(arg.name for arg in self.arguments)
duplicates = {name for name, count in name_counts.items() if count > 1}
if duplicates:
msg = f"Command '{self.name}': Duplicate argument names: {duplicates}"
raise ValueError(msg)
for arg in self.arguments:
if not re.match(ARGUMENT_NAME_PATTERN, arg.name):
msg = f"Command '{self.name}': Invalid argument name: '{arg.name}'"
raise ValueError(msg)
if (invoke_keys := _extract_keys(self.invoke_format)) != {arg.name for arg in self.arguments}:
msg = f"Command '{self.name}': Argument names ({invoke_keys}) in signature / invoke_format {self.invoke_format!r} do not match argument names"
raise ValueError(msg)
return self
# Default Bash tool
BASH_COMMAND = Command(
name="bash",
# name="execute_bash",
signature="<command>",
# signature="echo '<command>'\n<command>\necho \"root@workspace:${{PWD}} #\n[Command finished with exit code ${{?}}]\"",
docstring="runs the given command directly in bash",
arguments=[
Argument(
name="command",
type="string",
description="The bash command to execute.",
required=True,
)
],
)

View File

@@ -0,0 +1,619 @@
"""Our parsers parse output from the LM into thoughts and actions.
For example, our most basic parser is the `ThoughtActionParser`.
It expects the model response to be a discussion followed by a command wrapped in backticks like so:
```
Let's look at the files in the current directory.
Action:
```
ls -l
```
```
For models that support function calling, we instead recommend using the `FunctionCallingParser`.
To use a specific parser, set the `parse_function` key in your tool config to the `type` field of the parser.
```yaml
agent:
tools:
...
parse_function:
type: "thought_action"
```
Or from the command line: `--agent.tools.parse_function.type=thought_action`.
!!! note "Describing available tools"
If you do not use the `FunctionCallingParser`, you need to include documentation about the available tools
in your system prompt. You can use the `{{command_docs}}` variable to include the automatically generated
documentation or explicitly describe the available tools.
Also see [#1130](https://github.com/SWE-agent/SWE-agent/issues/1130).
"""
import json
import re
import textwrap
from abc import ABC, abstractmethod
from shlex import quote
from textwrap import dedent
from typing import Any, Literal
from jinja2 import Template
from pydantic import BaseModel
from sweagent.exceptions import FormatError, FunctionCallingFormatError
from sweagent.tools.commands import Command
from sweagent.tools.utils import _should_quote
class AbstractParseFunction(ABC):
"""
Abstract class for parsing functions.
We use get to generate the right parser based on the name of the parser.
"""
error_message: str
@abstractmethod
def __call__(self, model_response, commands: list[Command], strict=False) -> tuple[str, str]:
raise NotImplementedError
@property
def format_error_template(self):
return textwrap.dedent(self.error_message)
# DEFINE NEW PARSING FUNCTIONS BELOW THIS LINE
class ActionParser(AbstractParseFunction, BaseModel):
"""
Expects the model response to be a single command.
Example: "ls -l"
"""
error_message: str = """\
The command you provided was not recognized. Please specify one of the commands (+ any necessary arguments) from the following list in your response. Do not include any other text.
COMMANDS:
{command_docs}
"""
type: Literal["action"] = "action"
"""Type for (de)serialization. Do not change."""
def __call__(self, model_response: dict, commands: list[Command], strict=False):
if model_response["message"].split():
action = model_response["message"].strip().split()[0]
if action in {command.name for command in commands}:
return model_response["message"], model_response["message"]
msg = "First word in model response is not a valid command."
raise FormatError(msg)
class ActionOnlyParser(AbstractParseFunction, BaseModel):
"""Expects the model response to be a single command."""
error_message: str = "No message found in model response."
type: Literal["action_only"] = "action_only"
"""Type for (de)serialization. Do not change."""
def __call__(self, model_response: dict, commands: list[Command], strict=False):
return "", model_response["message"]
class ThoughtActionParser(AbstractParseFunction, BaseModel):
"""
Expects the model response to be a discussion followed by a command wrapped in backticks.
Example:
Let's look at the files in the current directory.
```
ls -l
```
"""
error_message: str = dedent("""\
Your output was not formatted correctly. You must always include one discussion and one command as part of your response. Make sure you do not have multiple discussion/command tags.
Please make sure your output precisely matches the following format:
DISCUSSION
Discuss here with yourself about what your planning and what you're going to do in this step.
```
command(s) that you're going to run
```
""")
type: Literal["thought_action"] = "thought_action"
"""Type for (de)serialization. Do not change."""
def __call__(self, model_response: dict, commands: list[Command], strict=False):
"""
Parses the action from the output of the API call.
We assume that the action is the last code block in the model_response.
We also assume that the action is not nested within another code block.
This is problematic if the model_response includes many unnamed ``` blocks.
For instance:
```
This is a code block.
```
```
This is another code block.
```
In this case, only the second code block will be parsed as the action.
"""
code_block_pat = re.compile(r"^```(\S*)\s*\n|^```\s*$", re.MULTILINE)
stack = []
last_valid_block = None
for match in code_block_pat.finditer(model_response["message"]):
if stack and not match.group(1): # Closing of a code block
start = stack.pop()
# Check if it's not nested within another block
if not stack:
last_valid_block = (start, match)
elif match.group(1) is not None: # Opening of a code block
stack.append(match)
if last_valid_block:
start, end = last_valid_block
thought = model_response["message"][: start.start()] + model_response["message"][end.end() :]
return thought, model_response["message"][start.end() : end.start()]
msg = "No action found in model response."
raise FormatError(msg)
class XMLThoughtActionParser(AbstractParseFunction, BaseModel):
"""
Expects the model response to be a discussion followed by a command wrapped in XML tags.
Example:
Let's look at the files in the current directory.
<command>
ls -l
</command>
"""
error_message: str = dedent("""\
Your output was not formatted correctly. You must always include one discussion and one command as part of your response. Make sure you do not have multiple discussion/command tags.
Please make sure your output precisely matches the following format:
""")
type: Literal["xml_thought_action"] = "xml_thought_action"
"""Type for (de)serialization. Do not change."""
def __call__(self, model_response: dict, commands: list[Command], strict=False) -> tuple[str, str]:
"""
Parses the action from the output of the API call.
We assume that the action is the last code block in the model_response.
We also assume that the action is not nested within another code block.
This is problematic if the model_response includes many unnamed ``` blocks.
For instance:
<command>
This is a code block.
</command>
<command>
This is another code block.
</command>
In this case, only the second code block will be parsed as the action.
"""
if "<command>" not in model_response["message"] or "</command>" not in model_response["message"]:
msg = "No action found in model response."
raise FormatError(msg)
# `action` is everything between the last <command> and </command> tags
start_action = model_response["message"].rfind("<command>") + len(
"<command>"
) # start after the last <command> tag
end_thought = model_response["message"].rfind("<command>") # end before the last <command> tag
end_action = model_response["message"].rfind("</command>") # end before the last </command> tag
restart_thought = model_response["message"].rfind("</command>") + len(
"</command>"
) # start after the last </command> tag
# `thought` is everything not in between <command> and </command> tags (includes after the last </command> tag)
action = model_response["message"][start_action:end_action]
thought = model_response["message"][:end_thought] + model_response["message"][restart_thought:]
return thought.strip(), action.strip()
FN_REGEX_PATTERN = r"<function=([^>]+)>\n(.*?)</function>"
FN_PARAM_REGEX_PATTERN = r"<parameter=([^>]+)>(.*?)</parameter>"
class XMLFunctionCallingParser(AbstractParseFunction, BaseModel):
"""
Expects the model response to be a tool calling format, where the command and parameters are specified
in XML tags.
Example:
Let's look at the files in the current directory.
<function=bash>
<parameter=command>find /testbed -type f -name "_discovery.py"</parameter>
</function>
"""
error_message: str = dedent("""\
{%- if error_code == "missing" -%}
Your last output did not use any tool calls!
Please make sure your output includes exactly _ONE_ function call!
If you think you have already resolved the issue, please submit your changes by running the `submit` command.
If you think you cannot solve the problem, please run `submit`.
Else, please continue with a new tool call!
{%- elif error_code == "multiple" -%}
Your last output included multiple tool calls!
Please make sure your output includes a thought and exactly _ONE_ function call.
{%- elif error_code == "unexpected_arg" -%}
Your action could not be parsed properly: {{exception_message}}.
Make sure your function call doesn't include any extra arguments that are not in the allowed arguments, and only use the allowed commands.
{%- else -%}
Your action could not be parsed properly: {{exception_message}}.
{% endif %}
""")
type: Literal["xml_function_calling"] = "xml_function_calling"
def __call__(self, model_response: dict, commands: list[Command], strict=False) -> tuple[str, str]:
fn_match = re.search(FN_REGEX_PATTERN, model_response["message"], re.DOTALL)
if not fn_match:
msg = "No function found in model response."
raise FormatError(msg)
fn_name = fn_match.group(1).strip()
# Handle different names in SWE-agent vs. SWE-gym
if fn_name == "execute_bash":
fn_name = "bash"
if fn_name == "finish":
fn_name = "submit"
fn_body = fn_match.group(2)
thought = model_response["message"][: fn_match.start()] + model_response["message"][fn_match.end() :]
thought = thought.strip()
commands_dict = {c.name: c for c in commands}
command = commands_dict.get(fn_name)
if not command:
msg = f"Command '{fn_name}' not found in list of available commands."
raise FormatError(msg)
params_dict = {
param[0]: re.sub(r"^\n|\n$", "", param[1])
for param in re.findall(FN_PARAM_REGEX_PATTERN, fn_body, re.DOTALL)
}
if "view_range" in params_dict:
# Check that value is format as [x, y]
v = params_dict["view_range"]
if isinstance(v, str):
if not re.match(r"\[\d+,\s*\d+\]", v):
msg = f"view_range must be in the format [<start>, <end>], got {v}."
raise FormatError(msg)
params_dict["view_range"] = json.loads(v)
# Check if all required arguments are there
required_args = {arg.name for arg in command.arguments if arg.required}
missing_args = required_args - params_dict.keys()
if missing_args:
msg = f"Required argument(s) missing: {', '.join(missing_args)}"
raise FormatError(msg)
# Check if all arguments are valid
valid_args = {arg.name for arg in command.arguments}
extra_args = set(params_dict.keys()) - valid_args
if command.end_name:
# sometimes the model will include the end_name in the arguments - just ignore it
extra_args.discard(command.end_name)
if extra_args:
msg = f"Unexpected argument(s): {', '.join(extra_args)}"
raise FormatError(msg)
# Format arguments using their individual argument_format
formatted_args = {
arg.name: Template(arg.argument_format).render(
value=quote(params_dict[arg.name])
if _should_quote(params_dict[arg.name], command)
else params_dict[arg.name]
)
if arg.name in params_dict
else ""
for arg in command.arguments
}
return thought, command.invoke_format.format(**formatted_args).strip()
class EditFormat(ThoughtActionParser, BaseModel):
"""
Expects the model response to be a discussion followed by a command wrapped in backticks.
Example:
We'll replace the contents of the current window with the following:
```
import os
os.listdir()
```
"""
error_message: str = dedent("""\
Your output was not formatted correctly. You must wrap the replacement text in backticks (```).
Please make sure your output precisely matches the following format:
COMMENTS
You can write comments here about what you're going to do if you want.
```
New window contents.
Make sure you copy the entire contents of the window here, with the required indentation.
Make the changes to the window above directly in this window.
Remember that all of the window's contents will be replaced with the contents of this window.
Don't include line numbers in your response.
```
""")
type: Literal["edit_format"] = "edit_format"
"""Type for (de)serialization. Do not change."""
class Identity(AbstractParseFunction, BaseModel):
"""This parser does not do any parsing. It just returns the model response as both the thought and action."""
error_message: str = """\
It seems like something went wrong with your output. Please try again.
"""
type: Literal["identity"] = "identity"
"""Type for (de)serialization. Do not change."""
def __call__(self, model_response: dict, commands: list[Command], strict=False) -> tuple[str, str]:
"""
This doesn't do any parsing. It just returns the model response as the thought and action.
"""
return model_response["message"], model_response["message"]
class FunctionCallingParser(AbstractParseFunction, BaseModel):
"""Expects the model response to be a LiteLLM tool call."""
error_message: str = dedent("""\
{%- if error_code == "missing" -%}
Your last output did not use any tool calls!
Please make sure your output includes exactly _ONE_ function call!
You must invoke the function directly using the function call format.
You cannot invoke commands with ```, you have to use the function call format.
If you think you have already resolved the issue, please submit your changes by running the `submit` command.
If you think you cannot solve the problem, please run `exit_forfeit` (if available) or `submit`.
Else, please continue with a new tool call!
{%- elif error_code == "multiple" -%}
Your last output included multiple tool calls!
Please make sure your output includes a thought and exactly _ONE_ function call.
{%- elif error_code == "unexpected_arg" -%}
Your action could not be parsed properly: {{exception_message}}.
Make sure your function call doesn't include any extra arguments that are not in the allowed arguments, and only use the allowed commands.
{%- else -%}
Your action could not be parsed properly: {{exception_message}}.
{% endif %}
""")
type: Literal["function_calling"] = "function_calling"
"""Type for (de)serialization. Do not change."""
def _parse_tool_call(self, tool_call: dict, commands: list[Command]):
name = tool_call["function"]["name"]
command = {c.name: c for c in commands}.get(name)
if not command:
msg = f"Command '{name}' not found in list of available commands."
raise FunctionCallingFormatError(msg, "invalid_command")
if not isinstance(tool_call["function"]["arguments"], dict):
try:
values = json.loads(tool_call["function"]["arguments"])
except json.JSONDecodeError:
msg = "Tool call arguments are not valid JSON."
raise FunctionCallingFormatError(msg, "invalid_json")
required_args = {arg.name for arg in command.arguments if arg.required}
missing_args = required_args - values.keys()
if missing_args:
msg = f"Required argument(s) missing: {', '.join(missing_args)}"
raise FunctionCallingFormatError(msg, "missing_arg")
valid_args = {arg.name for arg in command.arguments}
extra_args = set(values.keys()) - valid_args
if command.end_name:
# sometimes the model will include the end_name in the arguments - just ignore it
extra_args.discard(command.end_name)
if extra_args:
msg = f"Unexpected argument(s): {', '.join(extra_args)}"
raise FunctionCallingFormatError(msg, "unexpected_arg")
def get_quoted_arg(value: Any) -> str:
if isinstance(value, str):
return quote(value) if _should_quote(value, command) else value
# See https://github.com/SWE-agent/SWE-agent/issues/1159
if value is None:
return ""
return value
formatted_args = {
arg.name: Template(arg.argument_format).render(value=get_quoted_arg(values[arg.name]))
if arg.name in values
else ""
for arg in command.arguments
}
return command.invoke_format.format(**formatted_args).strip()
def __call__(self, model_response: dict, commands: list[Command], strict=False):
message = model_response["message"]
tool_calls = model_response.get("tool_calls", None)
if tool_calls is None or len(tool_calls) != 1:
num_tools = len(tool_calls) if tool_calls else 0
msg = (
f"Expected exactly one tool call in model response - received {num_tools} "
f"tool calls with message: {message}"
)
error_code = "missing" if num_tools == 0 else "multiple"
raise FunctionCallingFormatError(msg, error_code, num_tools=num_tools)
tool_call = tool_calls[0]
action = self._parse_tool_call(tool_call, commands)
return message, action
class JsonParser(AbstractParseFunction, BaseModel):
"""Expects the model response to be a JSON object."""
error_message: str = dedent("""\
Your output could not be parsed as JSON. Please make sure your output 1) is valid JSON and
2) Includes the "thought" and "command" fields.
""")
type: Literal["json"] = "json"
"""Type for (de)serialization. Do not change."""
def __call__(self, model_response: dict, commands: list[Command], strict=False):
"""Parses the action from the output of the API call.
We assume that model output is a JSON object with the following fields:
{
"thought": "discussion text here.",
"command": {
"arguments": {
"arg1": "value1",
"arg2": "value2",
...
},
"name": "command_name"
}
}
"""
try:
data = json.loads(model_response["message"])
if not isinstance(data, dict):
msg = "Model output is not a JSON object."
raise FormatError(msg)
# Check if required keys are present
required_keys = ["thought", "command"]
for key in required_keys:
if key not in data:
msg = f"Key '{key}' is missing from model output."
raise FormatError(msg)
# Check structure of 'command' key
data_command = data["command"]
if not isinstance(data_command, dict):
msg = "Value of 'command' key is not a JSON object."
raise FormatError(msg)
# Check if required keys are present in 'command' object
command_keys = ["name"]
for key in command_keys:
if key not in data_command:
msg = f"Key '{key}' is missing from 'command' object."
raise FormatError(msg)
thought = data["thought"]
commands_dict = {c.name: c for c in commands}
command = commands_dict.get(data_command["name"])
# Handle command parsing based on strict mode
if command is None:
if strict:
msg = f"Command '{data_command['name']}' not found in list of available commands."
raise FormatError(msg)
# In non-strict mode, just join command name with argument values
return thought, " ".join([data_command["name"], *data_command.get("arguments", {}).values()])
# Format arguments using their individual argument_format
formatted_args = {}
if command.arguments:
for arg in command.arguments:
if arg.name in data_command.get("arguments", {}):
value = data_command["arguments"][arg.name]
if _should_quote(value, command):
value = quote(value)
formatted_args[arg.name] = Template(arg.argument_format).render(value=value)
elif strict and arg.required:
msg = f"Required argument '{arg.name}' missing for command '{command.name}'"
raise FormatError(msg)
# Use the formatted arguments with invoke_format
action = command.invoke_format.format(**formatted_args).strip()
return thought, action
except json.JSONDecodeError:
msg = "Model output is not valid JSON."
raise FormatError(msg)
class BashCodeBlockParser(AbstractParseFunction, BaseModel):
"""Executes all commands in ```bash code blocks."""
error_message: str = dedent("""\
No bash code blocks were detected in your output.
You need to include at least one bash code block in your output.
It must follow this format exactly to be valid:
```bash
cmd arg1 arg2 ...
...
Other types of code blocks (e.g. python, rust, none, etc.) won't be executed. Only bash.
""")
type: Literal["all_bash_code_blocks"] = "all_bash_code_blocks"
def __call__(self, model_response: dict, commands: list[Command], strict=False):
"""Parses the action from the output of the API call.
We assume that model output is a JSON object with the following fields:
"""
pattern = re.compile(r"```bash\n(.*?)\n```", re.DOTALL)
matches = pattern.findall(model_response["message"])
if not matches:
msg = "No bash code blocks were detected in your output."
raise FormatError(msg)
thought = pattern.sub("<extracted_code_block>", model_response["message"])
action = "\n".join(matches)
return thought, action
class SingleBashCodeBlockParser(AbstractParseFunction, BaseModel):
"""Executes all commands in ```bash code blocks."""
error_message: str = dedent("""\
We did not detect the right number of bash code blocks in your output.
You need to include EXACTLY ONE bash code block in your output.
It must follow this format exactly to be valid:
```bash
cmd arg1 arg2 ...
```
""")
type: Literal["single_bash_code_block"] = "single_bash_code_block"
def __call__(self, model_response: dict, commands: list[Command], strict=False):
"""Parses the action from the output of the API call.
We assume that model output is a JSON object with the following fields:
"""
pattern = re.compile(r"```bash\n(.*?)\n```", re.DOTALL)
matches = pattern.findall(model_response["message"])
if not matches:
msg = "No bash code blocks were detected in your output."
raise FormatError(msg)
if len(matches) > 1:
msg = (
"We detected multiple bash code blocks in your output. "
"You need to include EXACTLY ONE bash code block in your output."
)
raise FormatError(msg)
thought = pattern.sub("<extracted_code_block>", model_response["message"])
action = "\n".join(matches)
return thought, action
ParseFunction = (
ActionParser
| ThoughtActionParser
| ActionOnlyParser
| XMLThoughtActionParser
| XMLFunctionCallingParser
| FunctionCallingParser
| EditFormat
| Identity
| JsonParser
| BashCodeBlockParser
| SingleBashCodeBlockParser
)

View File

@@ -0,0 +1,430 @@
"""
This module contains the configuration for the tools that are made available to the agent.
The `ToolConfig` class is used to configure the tools that are available to the agent.
The `ToolHandler` class is used to handle the tools that are available to the agent.
"""
import asyncio
import json
import os
import re
from functools import cached_property
from pathlib import Path
from typing import Any
from pydantic import BaseModel, Field
from swerex.runtime.abstract import Command as RexCommand
from swerex.runtime.abstract import UploadRequest
from typing_extensions import Self
from sweagent.environment.swe_env import SWEEnv
from sweagent.tools.bundle import Bundle
from sweagent.tools.commands import BASH_COMMAND, Command
from sweagent.tools.parsing import FunctionCallingParser, JsonParser, ParseFunction
from sweagent.tools.utils import _guard_multiline_input, generate_command_docs
from sweagent.utils.log import get_logger
class ToolFilterConfig(BaseModel):
"""Filter out commands that are blocked by the environment
(for example interactive commands like `vim`).
"""
blocklist_error_template: str = "Operation '{{action}}' is not supported by this environment."
"""The error template to use when a command is blocked."""
blocklist: list[str] = [
"vim",
"vi",
"emacs",
"nano",
"nohup",
"gdb",
"less",
"tail -f",
"python -m venv",
"make",
]
"""Block any command that starts with one of these"""
blocklist_standalone: list[str] = [
"python",
"python3",
"ipython",
"bash",
"sh",
"/bin/bash",
"/bin/sh",
"nohup",
"vi",
"vim",
"emacs",
"nano",
"su",
]
"""Block any command that matches one of these exactly"""
block_unless_regex: dict[str, str] = {
"radare2": r"\b(?:radare2)\b.*\s+-c\s+.*",
"r2": r"\b(?:radare2)\b.*\s+-c\s+.*",
}
"""Block any command that matches one of these names unless it also matches the regex"""
class ToolConfig(BaseModel):
"""Configuration for the tools that are made available to the agent."""
filter: ToolFilterConfig = ToolFilterConfig()
"""Filter out commands that are blocked by the environment
(for example interactive commands like `vim`).
"""
bundles: list[Bundle] = Field(default_factory=list)
"""The tool bundles to load."""
propagate_env_variables: list[str] = []
"""Environment variables to propagate to the environment.
This is useful if you want to propagate API keys or similar from your own environment to the
environment in which the tools run.
IMPORTANT NOTE: The value of the environment variables can be read in debug log files,
so be careful with your API keys!
"""
env_variables: dict[str, Any] = {
"PAGER": "cat",
"MANPAGER": "cat",
"LESS": "-R",
"PIP_PROGRESS_BAR": "off",
"TQDM_DISABLE": "1",
"GIT_PAGER": "cat",
}
"""Shorthand to set environment variables for the tools, effectively
equivalent to adding `export VARNAME=value` to the `reset_commands`.
"""
registry_variables: dict[str, Any] = {}
"""Populate the registry with these variables. Will be written out as json in the registry file."""
submit_command: str = "submit"
"""The command/tool to use to submit the solution."""
parse_function: ParseFunction = Field(default_factory=FunctionCallingParser)
"""The action parser that is responsible for parsing the model output into a thought and action.
"""
enable_bash_tool: bool = True
"""Whether to enable the bash tool in addition to the other tools specified in bundles."""
format_error_template: str = None # type: ignore
"""Defaults to format_error_template in ParseFunction"""
command_docs: str = None # type: ignore
"""Automatically generated documentation generated based on
the loaded tool bundles.
"""
multi_line_command_endings: dict[str, str] = {}
submit_command_end_name: str | None = None
"""Commands to install dependencies and tools.
These commands are executed in a subprocess and are not part of the environment state.
"""
reset_commands: list[str | list[str]] = []
"""Commands to reset the environment. They will also be called when we start the environment.
Unlike `install_commands`, these commands are part of the environment state.
"""
execution_timeout: int = 30
"""Timeout for executing commands in the environment"""
install_timeout: int = 300
"""Timeout used for each of the installation commands"""
total_execution_timeout: int = 1800
"""Timeout for executing all commands in the environment.
Note: Does not interrupt running commands, but will stop the agent for the next step.
"""
max_consecutive_execution_timeouts: int = 3
"""Maximum number of consecutive execution timeouts before the agent exits.
"""
@cached_property
def use_function_calling(self) -> bool:
return isinstance(self.parse_function, FunctionCallingParser)
@cached_property
def state_commands(self) -> list[str]:
"""This property returns the state commands from all bundles.
State commands are commands that are used to get the state of the environment
(e.g., the current working directory).
"""
return [bundle.state_command for bundle in self.bundles if bundle.state_command]
# todo: move to ToolHandler?
@cached_property
def commands(self) -> list[Command]:
"""Read command files and return parsed command objects"""
commands = []
tool_sources: dict[str, Path] = {} # Track which file each tool comes from
# Add bash command if enabled
if self.enable_bash_tool:
commands.append(BASH_COMMAND)
tool_sources[BASH_COMMAND.name] = Path("<builtin>")
# Collect commands from all bundles
for bundle in self.bundles:
for command in bundle.commands:
if command.name in tool_sources:
existing_source = tool_sources[command.name]
msg = (
f"Tool '{command.name}' is defined multiple times:\n"
f" - First definition in: {existing_source}\n"
f" - Duplicate definition in: {bundle.path}"
)
raise ValueError(msg)
commands.append(command)
tool_sources[command.name] = bundle.path
return commands
@cached_property
def tools(self) -> list[dict]:
return [command.get_function_calling_tool() for command in self.commands]
# todo: can some of these be moved to ToolHandler?
def model_post_init(self, __context):
# for caching:
commands = self.commands
multi_line_command_endings = {
command.name: command.end_name for command in commands if command.end_name is not None
}
self.tools
# assert not self.enable_bash_tool and parse_function is FunctionCallingParser or JsonParser
if not self.enable_bash_tool and not (
isinstance(self.parse_function, FunctionCallingParser) or isinstance(self.parse_function, JsonParser)
):
msg = f"Bash tool can only be disabled if {FunctionCallingParser.type} parser or {JsonParser.type} parser is used."
raise ValueError(msg)
self.multi_line_command_endings = multi_line_command_endings
self.command_docs = generate_command_docs(
self.commands,
[],
**self.env_variables,
)
if self.format_error_template is None:
self.format_error_template = self.parse_function.format_error_template
for command in commands:
if command.name == self.submit_command:
self.submit_command_end_name = command.end_name
break
class ToolHandler:
def __init__(self, tools: ToolConfig):
"""This class handles most of the tool usage. It has the following responsibilities:
- Install the tools
- Parse commands and handle multiline commands
- Decide if an action should be blocked
- Get the current state of the environment
"""
# Always copy config to avoid shared state between different instances across threads
self.config = tools.model_copy(deep=True)
# partially initialized in `install_commands`.
self._reset_commands = []
self._command_patterns = self._get_command_patterns()
self.logger = get_logger("swea-tools", emoji="🧰")
# For testing: Return this state instead of querying the environment
self.mock_state: dict[str, str] | None = None
@classmethod
def from_config(cls, config: ToolConfig) -> Self:
return cls(config)
# Installation & Reset
# --------------------
def install(self, env: SWEEnv) -> None:
self._install_commands(env)
self.reset(env)
def reset(self, env: SWEEnv) -> None:
self.logger.info("Resetting tools")
env_variables = self.config.env_variables.copy() | {
var: os.getenv(var) for var in self.config.propagate_env_variables
}
env.set_env_variables(env_variables)
env.write_file("/root/.swe-agent-env", json.dumps(self.config.registry_variables))
env.write_file("/root/state.json", "{}")
env.communicate(" && ".join(self._reset_commands), check="raise", timeout=self.config.install_timeout)
async def _upload_bundles(self, env: SWEEnv) -> None:
await asyncio.gather(
*(
env.deployment.runtime.upload(
UploadRequest(source_path=bundle.path.as_posix(), target_path=f"/root/tools/{bundle.path.name}")
)
for bundle in self.config.bundles
)
)
async def _is_command_available(self, env, command: str, env_vars: dict[str, str]) -> None:
if command == "bash":
return
try:
await env.deployment.runtime.execute(
RexCommand(command=f"which {command}", shell=True, check=True, env=env_vars)
)
except Exception:
msg = f"Tool {command} is not available in the container."
raise RuntimeError(msg) from None
async def _check_available_commands(self, env: SWEEnv, env_vars: dict[str, str]) -> None:
await asyncio.gather(
*(self._is_command_available(env, command.name, env_vars) for command in self.config.commands)
)
def _install_commands(self, env: SWEEnv) -> None:
"""Make sure all commands are available in the container"""
env.set_env_variables(self.config.env_variables)
cwd = env.communicate("pwd", check="raise").strip()
asyncio.run(self._upload_bundles(env))
for bundle in self.config.bundles:
cmds = [
f"export PATH=/root/tools/{bundle.path.name}/bin:$PATH",
f"chmod +x /root/tools/{bundle.path.name}/bin/*",
]
if (bundle.path / "install.sh").exists():
cmds.append(f"cd /root/tools/{bundle.path.name} && source install.sh")
cmds.append(f"chmod +x /root/tools/{bundle.path.name}/bin/*")
env.communicate(
" && ".join(cmds),
check="raise",
timeout=self.config.install_timeout,
)
env.communicate(f"cd {cwd}", check="raise")
path = env.communicate("echo $PATH", check="raise").strip()
asyncio.run(self._check_available_commands(env, {"PATH": path}))
# Getting state
# -------------
def _get_state(self, env: SWEEnv) -> dict[str, str]:
"""Retrieve the state from the environment"""
try:
state_str = env.read_file("/root/state.json")
except FileNotFoundError:
self.logger.warning("State file not found, returning empty state")
return {}
if not state_str.strip():
self.logger.warning("State file is empty, returning empty state")
return {}
try:
state = json.loads(state_str)
except json.JSONDecodeError as e:
msg = f"State {state_str!r} is not valid json. This is an internal error, please report it."
raise ValueError(msg) from e
if not isinstance(state, dict):
msg = f"State commands must return a dictionary. Got {state!r} instead."
raise ValueError(msg)
return state
def get_state(self, env: SWEEnv) -> dict[str, str]:
"""Execute state commands from all bundles and combine their results.
This can be used to extract environment variables etc. from the environment.
"""
if self.mock_state is not None:
return self.mock_state
for state_command in self.config.state_commands:
env.communicate(state_command, check="warn")
combined_state = self._get_state(env)
self.logger.debug(f"Retrieved state from environment: {combined_state}")
return combined_state
# Blocking
# --------
def should_block_action(self, action: str) -> bool:
"""Check if the command should be blocked."""
action = action.strip()
if not action:
return False
if any(action.startswith(f) for f in self.config.filter.blocklist):
return True
if action in self.config.filter.blocklist_standalone:
return True
name = action.split()[0]
if name in self.config.filter.block_unless_regex and not re.search(
self.config.filter.block_unless_regex[name], action
):
return True
return False
# Parsing & multiline commands
# -----------------------------
def check_for_submission_cmd(self, output: str) -> bool:
"""Function for checking submission request."""
if r"<<SWE_AGENT_SUBMISSION>>" in output:
return True
return False
def parse_actions(self, output: dict) -> tuple[str, str]:
"""Parse the model output into a thought and action."""
return self.config.parse_function(output, self.config.commands)
def guard_multiline_input(self, action: str) -> str:
"""Split action by multiline commands, then append the first line in each multiline command with "<< '{end_name}'".
Multiline commands (which are specified by an end_name) are commands that span multiple lines and are terminated by a specific end_name.
Their multi-line argument is sent using a heredoc, which is a way to send a multi-line string to a command in bash.
"""
return _guard_multiline_input(action, self._get_first_multiline_cmd)
def _get_first_multiline_cmd(self, action: str) -> re.Match | None:
"""Return the first match of a command pattern in the action string.
Where first match is defined by the start of the match.
The match object has three groups: (1) command name, (2) command arguments, (3) end name
"""
patterns = {
k: v
for k, v in self._command_patterns.items()
if k in self.config.multi_line_command_endings or k == self.config.submit_command
}
matches = list()
for _, pat in patterns.items():
match = pat.search(action)
if match:
matches.append(match)
if len(matches) == 0:
return None
matches = sorted(matches, key=lambda x: x.start())
return matches[0]
def _get_command_patterns(self) -> dict[str, re.Pattern]:
"""Creates regular expressions for the commands"""
_command_patterns = {}
for command in self.config.commands:
if command.end_name is not None:
pat = re.compile(
rf"^\s*({command.name})\s*(.*?)^({command.end_name})\s*$",
re.DOTALL | re.MULTILINE,
)
_command_patterns[command.name] = pat
else:
pat = re.compile(rf"^\s*({command.name})\s*(.*?)$", re.MULTILINE)
_command_patterns[command.name] = pat
submit_pat = re.compile(
rf"^\s*({self.config.submit_command})\s*(.*?)^({self.config.submit_command_end_name})\s*$",
re.DOTALL | re.MULTILINE,
)
_command_patterns[self.config.submit_command] = submit_pat
return _command_patterns

View File

@@ -0,0 +1,108 @@
import re
from collections.abc import Callable
from typing import Any
from sweagent.tools.commands import Command
def _guard_multiline_input(action: str, match_fct: Callable[[str], re.Match | None]) -> str:
"""Split action by multiline commands, then append the first line in each multiline command with "<< '{end_name}'".
Multiline commands (which are specified by an end_name) are commands that span multiple lines and are terminated by a specific end_name.
Their multi-line argument is sent using a heredoc, which is a way to send a multi-line string to a command in bash.
"""
parsed_action = []
rem_action = action
while rem_action.strip():
first_match = match_fct(rem_action)
if first_match:
pre_action = rem_action[: first_match.start()]
match_action = rem_action[first_match.start() : first_match.end()]
rem_action = rem_action[first_match.end() :]
if pre_action.strip():
parsed_action.append(pre_action)
if match_action.strip():
eof = first_match.group(3).strip()
if not match_action.split("\n")[0].strip().endswith(f"<< '{eof}'"):
guarded_command = match_action[first_match.start() :]
first_line = guarded_command.split("\n")[0]
guarded_command = guarded_command.replace(first_line, first_line + f" << '{eof}'", 1)
parsed_action.append(guarded_command)
else:
parsed_action.append(match_action)
else:
parsed_action.append(rem_action)
rem_action = ""
return "\n".join(parsed_action)
def _should_quote(value: Any, command: Command) -> bool:
"""Returns True if the value should be quoted, False otherwise."""
if command.name == "bash":
return False
return isinstance(value, str) and command.end_name is None
def get_signature(cmd):
"""Generate a command signature from its arguments.
Args:
cmd: Command object to generate signature for
Returns:
Formatted signature string
"""
signature = cmd.name
if "arguments" in cmd.__dict__ and cmd.arguments is not None:
if cmd.end_name is None:
for argument in cmd.arguments:
param = argument.name
if argument.required:
signature += f" <{param}>"
else:
signature += f" [<{param}>]"
else:
for argument in cmd.arguments[:-1]:
param = argument.name
if argument.required:
signature += f" <{param}>"
else:
signature += f" [<{param}>]"
signature += f"\n{list(cmd.arguments[-1].keys())[0]}\n{cmd.end_name}"
return signature
def generate_command_docs(
commands: list[Command],
subroutine_types,
**kwargs,
) -> str:
"""Generate detailed command documentation.
Format includes docstring, signature and argument details.
Args:
commands: List of commands to document
subroutine_types: List of subroutines to document
**kwargs: Additional format variables for docstrings
Returns:
Formatted documentation string
"""
docs = ""
for cmd in commands + subroutine_types:
docs += f"{cmd.name}:\n"
if cmd.docstring is not None:
docs += f" docstring: {cmd.docstring.format(**kwargs)}\n"
if cmd.signature is not None:
docs += f" signature: {cmd.signature}\n"
else:
docs += f" signature: {get_signature(cmd)}\n"
if cmd.arguments:
docs += " arguments:\n"
for argument in cmd.arguments:
param = argument.name
req_string = "required" if argument.required else "optional"
docs += f" - {param} ({argument.type}) [{req_string}]: {argument.description}\n"
docs += "\n"
return docs

102
.agent/vendor/mini-swe/sweagent/types.py vendored Normal file
View File

@@ -0,0 +1,102 @@
"""This file has types/dataclass definitions that are used in the SWE agent
for exchanging data between different modules/functions/classes.
They oftentimes cannot be defined in the same file where they are used
because of circular dependencies.
"""
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel
from typing_extensions import TypedDict
class StepOutput(BaseModel):
query: list[dict] = [{}]
thought: str = ""
action: str = ""
output: str = ""
observation: str = ""
execution_time: float = 0.0
done: bool = False
exit_status: int | str | None = None
submission: str | None = None
state: dict[str, str] = {}
tool_calls: list[dict[str, Any]] | None = None
tool_call_ids: list[str] | None = None
thinking_blocks: list[dict[str, Any]] | None = None
"""State of the environment at the end of the step"""
extra_info: dict[str, Any] = {}
def to_template_format_dict(self) -> dict[str, str | int | float | bool | None]:
"""Used for formatting (error) prompt templates"""
out = {}
for k, v in self.model_dump().items():
if k in ("tool_calls", "tool_call_ids", "state"):
continue
out[k] = v
out |= self.state
return out
class TrajectoryStep(TypedDict):
action: str
observation: str
response: str
state: dict[str, str]
thought: str
execution_time: float
query: list[dict[str, Any]]
extra_info: dict[str, Any]
# required fields go here
class _HistoryItem(TypedDict):
role: str
content: str | list[dict[str, Any]]
message_type: Literal["thought", "action", "observation"]
# see _HistoryItem for required fields
class HistoryItem(_HistoryItem, total=False):
agent: str
is_demo: bool
thought: str
action: str | None
tool_calls: list[dict[str, str]] | None
tool_call_ids: list[str] | None
tags: list[str]
cache_control: dict[str, Any] | None
thinking_blocks: list[dict[str, Any]] | None
"""HistoryProcessors can add these tags to enable special processing"""
History = list[HistoryItem]
Trajectory = list[TrajectoryStep]
# todo: Make this actually have the dataclasses instead of dict versions
class AgentInfo(TypedDict, total=False):
# same as `APIStats` from models.py
model_stats: dict[str, float]
exit_status: str | None
submission: str | None
# same as `ReviewerResult`
review: dict[str, Any]
edited_files30: str
edited_files50: str
edited_files70: str
# only if summarizer is used
summarizer: dict
swe_agent_hash: str
swe_agent_version: str
swe_rex_version: str
swe_rex_hash: str
class AgentRunResult(BaseModel):
info: AgentInfo
trajectory: Trajectory

View File

View File

@@ -0,0 +1,80 @@
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
from dotenv import load_dotenv
from sweagent import REPO_ROOT
from sweagent.utils.log import get_logger
logger = get_logger("swea-config", emoji="🔧")
def _convert_path_relative_to_repo_root(path: Path | str, root: Path | None = None) -> Path | str:
original_type = type(path)
path = Path(path).resolve()
root = Path(root or os.getenv("SWE_AGENT_CONFIG_ROOT", REPO_ROOT))
relative_path = path.relative_to(root) if root in path.parents else path
return relative_path if original_type is Path else str(relative_path)
def _could_be_a_path(v: Any) -> bool:
try:
return Path(v).exists()
except Exception:
return False
def _strip_abspath_from_dict(value: dict | list | str, root: Path | None = None) -> dict | list | str:
root = Path(root or os.getenv("SWE_AGENT_CONFIG_ROOT", REPO_ROOT))
if isinstance(value, dict):
return {k: _strip_abspath_from_dict(v, root) for k, v in value.items()}
elif isinstance(value, list):
return [_strip_abspath_from_dict(v, root) for v in value]
elif isinstance(value, str) and _could_be_a_path(value):
return _convert_path_relative_to_repo_root(value, root)
else:
return value
def _convert_path_to_abspath(path: Path | str) -> Path:
"""If path is not absolute, convert it to an absolute path
using the SWE_AGENT_CONFIG_ROOT environment variable (if set) or
REPO_ROOT as base.
"""
path = Path(path)
root = Path(os.getenv("SWE_AGENT_CONFIG_ROOT", REPO_ROOT))
assert root.is_dir()
if not path.is_absolute():
path = root / path
assert path.is_absolute()
return path.resolve()
def _convert_paths_to_abspath(paths: list[Path] | list[str]) -> list[Path]:
return [_convert_path_to_abspath(p) for p in paths]
def load_environment_variables(path: Path | None = None):
"""Load environment variables from a .env file.
If path is not provided, we first look for a .env file in the current working
directory and then in the repository root.
"""
if path is None:
cwd_path = Path.cwd() / ".env"
repo_path = REPO_ROOT / ".env"
if cwd_path.exists():
path = cwd_path
elif repo_path.exists():
path = REPO_ROOT / ".env"
else:
logger.debug("No .env file found")
return
if not path.is_file():
msg = f"No .env file found at {path}"
raise FileNotFoundError(msg)
anything_loaded = load_dotenv(dotenv_path=path)
if anything_loaded:
logger.info(f"Loaded environment variables from {path}")

View File

@@ -0,0 +1,27 @@
import json
from pathlib import Path
from typing import Any
import yaml
def load_file(path: Path | str | None) -> Any:
"""Load files based on their extension."""
if path is None:
return None
if isinstance(path, str):
path = Path(path)
if not path.exists():
raise FileNotFoundError(path)
if path.is_dir():
from datasets import load_from_disk
return load_from_disk(path)
if path.suffix in [".json", ".traj"]:
return json.loads(path.read_text())
if path.suffix == ".jsonl":
return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]
if path.suffix == ".yaml":
return yaml.safe_load(path.read_text())
msg = f"Unsupported file extension: {path.suffix}"
raise NotImplementedError(msg)

View File

@@ -0,0 +1,155 @@
import json
import re
import urllib.error
import urllib.request
from ghapi.all import GhApi
from sweagent.utils.log import get_logger
_logger = get_logger("swea-github", emoji="🔧")
_repo_privacy_cache: dict[str, bool] = {}
GITHUB_ISSUE_URL_PATTERN = re.compile(r"github\.com\/(.*?)\/(.*?)\/issues\/(\d+)")
class InvalidGithubURL(Exception):
"""Raised when a github URL is invalid"""
GITHUB_REPO_URL_PATTERN = re.compile(r".*[/@]?github\.com\/([^/]+)\/([^/]+)")
def _is_github_repo_url(data_path: str) -> bool:
"""Check if data_path is an URL pointing to a github repository.
Paths to issues or PRs will also match this pattern.
"""
return GITHUB_REPO_URL_PATTERN.search(data_path) is not None
def _is_github_issue_url(data_path: str) -> bool:
"""Check if data_path is an URL pointing to a github issue"""
return GITHUB_ISSUE_URL_PATTERN.search(data_path) is not None
def _get_commit(api: GhApi, owner: str, repo: str, ref: str | None = None):
"""Get commit object from github api
Args:
api (GhApi):
owner (str): Repo owner, e.g., "SWE-agent"
repo (str): Repo, e.g., "SWE-agent"
ref (str, optional): Branch, tag or commit hash
Returns:
_type_: _description_
"""
if ref:
return api.repos.get_commit(owner, repo, ref) # type: ignore
return api.repos.list_commits(owner, repo)[0] # type: ignore
def _parse_gh_issue_url(issue_url: str) -> tuple[str, str, str]:
"""
Returns:
owner: Repo owner
repo: Repo name
issue number: Issue number as str
Raises:
InvalidGithubURL: If the URL is not a valid github issue URL
"""
match = GITHUB_ISSUE_URL_PATTERN.search(issue_url)
if not match:
msg = f"Invalid GitHub issue URL: {issue_url}"
raise InvalidGithubURL(msg)
res = match.groups()
assert len(res) == 3
return tuple(res) # type: ignore
def _parse_gh_repo_url(repo_url: str) -> tuple[str, str]:
"""
Returns:
owner: Repo owner/org
repo: Repo name
Raises:
InvalidGithubURL: If the URL is not a valid github repo URL
"""
match = GITHUB_REPO_URL_PATTERN.search(repo_url)
if not match:
msg = f"Invalid GitHub issue URL: {repo_url}"
raise InvalidGithubURL(msg)
res = match.groups()
assert len(res) == 2
return tuple(res) # type: ignore
def _get_gh_issue_data(issue_url: str, *, token: str = ""):
"""Returns github issue data in the form of a dictionary.
See https://docs.github.com/en/rest/issues/issues?apiVersion=2022-11-28#get-an-issue
for return format
"""
owner, repo, issue_number = _parse_gh_issue_url(issue_url)
api = GhApi(token=token)
return api.issues.get(owner, repo, issue_number) # type: ignore
def _get_problem_statement_from_github_issue(
owner: str, repo: str, issue_number: str, *, token: str | None = ""
) -> str:
"""Return problem statement from github issue"""
api = GhApi(token=token)
issue = api.issues.get(owner, repo, issue_number) # type: ignore
title = issue.title if issue.title else ""
body = issue.body if issue.body else ""
return f"{title}\n{body}\n"
def _get_associated_commit_urls(org: str, repo: str, issue_number: str, *, token: str = "") -> list[str]:
"""Return the URLs of commits that would close an issue."""
api = GhApi(token=token)
# Strangely the "pull_request" field of api.issues.get is often not set
# so we have to go through the events to check if there's a commit
events = api.issues.list_events(org, repo, issue_number) # type: ignore
commit_urls = []
for event in events:
if event.event != "referenced":
continue
if not event.commit_id:
continue
commit = api.repos.get_commit(org, repo, event.commit_id) # type: ignore
message = commit.commit.message
if f"fixes #{issue_number}" in message.lower() or f"closes #{issue_number}" in message.lower():
commit_urls.append(commit.html_url)
return commit_urls
def _is_repo_private(owner_repo: str, token: str) -> bool:
"""Check if a GitHub repository is private via the GitHub API.
Returns True if the repo is private or if a 404 is returned (GitHub returns
404 for private repos when the token lacks access). Any other HTTP or
network error is raised so callers can handle it explicitly.
"""
if owner_repo in _repo_privacy_cache:
return _repo_privacy_cache[owner_repo]
url = f"https://api.github.com/repos/{owner_repo}"
headers = {"User-Agent": "sweagent"}
if token:
headers["Authorization"] = f"token {token}"
req = urllib.request.Request(url, headers=headers)
try:
with urllib.request.urlopen(req) as resp:
data = json.loads(resp.read())
private = data.get("private", False)
except urllib.error.HTTPError as e:
if e.code == 404:
_logger.warning("Repo '%s' returned 404 — assuming private", owner_repo)
private = True
else:
raise
_repo_privacy_cache[owner_repo] = private
return private

View File

@@ -0,0 +1,14 @@
from sweagent.utils.log import get_logger
def _warn_probably_wrong_jinja_syntax(template: str | None) -> None:
"""Warn if the template uses {var} instead of {{var}}."""
if template is None:
return
if "{" not in template:
return
for s in ["{%", "{ %", "{{"]:
if s in template:
return
logger = get_logger("swea-config", emoji="🔧")
logger.warning("Probably wrong Jinja syntax in template: %s. Make sure to use {{var}} instead of {var}.", template)

View File

@@ -0,0 +1,175 @@
from __future__ import annotations
import logging
import os
import threading
import uuid
from collections.abc import Callable
from pathlib import Path, PurePath
from rich.logging import RichHandler
from rich.text import Text
_SET_UP_LOGGERS: set[str] = set()
_ADDITIONAL_HANDLERS: dict[str, logging.Handler] = {}
_LOG_LOCK = threading.Lock()
logging.TRACE = 5 # type: ignore
logging.addLevelName(logging.TRACE, "TRACE") # type: ignore
def _interpret_level(level: int | str | None, *, default=logging.DEBUG) -> int:
if not level:
return default
if isinstance(level, int):
return level
if level.isnumeric():
return int(level)
return getattr(logging, level.upper())
_STREAM_LEVEL = _interpret_level(os.environ.get("SWE_AGENT_LOG_STREAM_LEVEL"))
_INCLUDE_LOGGER_NAME_IN_STREAM_HANDLER = False
_THREAD_NAME_TO_LOG_SUFFIX: dict[str, str] = {}
"""Mapping from thread name to suffix to add to the logger name."""
def register_thread_name(name: str) -> None:
"""Register a suffix to add to the logger name for the current thread."""
thread_name = threading.current_thread().name
_THREAD_NAME_TO_LOG_SUFFIX[thread_name] = name
class _RichHandlerWithEmoji(RichHandler):
def __init__(self, emoji: str, *args, **kwargs):
"""Subclass of RichHandler that adds an emoji to the log message."""
super().__init__(*args, **kwargs)
if not emoji.endswith(" "):
emoji += " "
self.emoji = emoji
def get_level_text(self, record: logging.LogRecord) -> Text:
level_name = record.levelname.replace("WARNING", "WARN")
return Text.styled((self.emoji + level_name).ljust(10), f"logging.level.{level_name.lower()}")
def get_logger(name: str, *, emoji: str = "") -> logging.Logger:
"""Get logger. Use this instead of `logging.getLogger` to ensure
that the logger is set up with the correct handlers.
"""
thread_name = threading.current_thread().name
if thread_name != "MainThread":
name = name + "-" + _THREAD_NAME_TO_LOG_SUFFIX.get(thread_name, thread_name)
logger = logging.getLogger(name)
if logger.hasHandlers():
# Already set up
return logger
handler = _RichHandlerWithEmoji(
emoji=emoji,
show_time=bool(os.environ.get("SWE_AGENT_LOG_TIME", False)),
show_path=False,
)
handler.setLevel(_STREAM_LEVEL)
# Set to lowest level and only use stream handlers to adjust levels
logger.setLevel(logging.TRACE) # type: ignore
logger.addHandler(handler)
logger.propagate = False
_SET_UP_LOGGERS.add(name)
with _LOG_LOCK:
for handler in _ADDITIONAL_HANDLERS.values():
my_filter = getattr(handler, "my_filter", None)
if my_filter is None:
logger.addHandler(handler)
elif isinstance(my_filter, str) and my_filter in name:
logger.addHandler(handler)
elif callable(my_filter) and my_filter(name):
logger.addHandler(handler)
if _INCLUDE_LOGGER_NAME_IN_STREAM_HANDLER:
_add_logger_name_to_stream_handler(logger)
return logger
def add_file_handler(
path: PurePath | str,
*,
filter: str | Callable[[str], bool] | None = None,
level: int | str = logging.TRACE, # type: ignore[attr-defined]
id_: str = "",
) -> str:
"""Adds a file handler to all loggers that we have set up
and all future loggers that will be set up with `get_logger`.
Args:
filter: If str: Check that the logger name contains the filter string.
If callable: Check that the logger name satisfies the condition returned by the callable.
level: The level of the handler.
id_: The id of the handler. If not provided, a random id will be generated.
Returns:
The id of the handler. This can be used to remove the handler later.
"""
Path(path).parent.mkdir(parents=True, exist_ok=True)
handler = logging.FileHandler(path, encoding="utf-8")
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
handler.setFormatter(formatter)
handler.setLevel(_interpret_level(level))
with _LOG_LOCK:
# Lock because other thread might be modifying the _SET_UP_LOGGERS set
for name in _SET_UP_LOGGERS:
if filter is not None:
if isinstance(filter, str) and filter not in name:
continue
if callable(filter) and not filter(name):
continue
logger = logging.getLogger(name)
logger.addHandler(handler)
handler.my_filter = filter # type: ignore
if not id_:
id_ = str(uuid.uuid4())
_ADDITIONAL_HANDLERS[id_] = handler
return id_
def remove_file_handler(id_: str) -> None:
"""Remove a file handler by its id."""
handler = _ADDITIONAL_HANDLERS.pop(id_)
with _LOG_LOCK:
# Lock because other thread might be modifying the _SET_UP_LOGGERS set
for log_name in _SET_UP_LOGGERS:
logger = logging.getLogger(log_name)
logger.removeHandler(handler)
def _add_logger_name_to_stream_handler(logger: logging.Logger) -> None:
for handler in logger.handlers:
if isinstance(handler, _RichHandlerWithEmoji):
formatter = logging.Formatter("[%(name)s] %(message)s")
handler.setFormatter(formatter)
def add_logger_names_to_stream_handlers() -> None:
"""Add the logger name to the stream handler for all loggers that we have set up."""
global _INCLUDE_LOGGER_NAME_IN_STREAM_HANDLER
_INCLUDE_LOGGER_NAME_IN_STREAM_HANDLER = True
with _LOG_LOCK:
for logger in _SET_UP_LOGGERS:
_add_logger_name_to_stream_handler(logging.getLogger(logger))
def set_stream_handler_levels(level: int) -> None:
"""Set the default stream level and adjust the levels of all stream handlers
to be at most the given level.
Note: Can only be used to lower the level, not raise it.
"""
global _STREAM_LEVEL
_STREAM_LEVEL = level
with _LOG_LOCK:
for name in _SET_UP_LOGGERS:
logger = logging.getLogger(name)
for handler in logger.handlers:
if isinstance(handler, _RichHandlerWithEmoji):
current_level = handler.level
if current_level < level:
handler.setLevel(level)

View File

@@ -0,0 +1,152 @@
from collections.abc import Callable
from unidiff import PatchSet
class PatchFormatter:
def __init__(
self,
patch: str,
read_method: Callable[[str], str],
):
"""Given the final patch and access to the container that contains the repository,
extract relevant lines from the modified file.
Args:
patch: The patch as a string.
read_method: Callable with path to file (relative to repository root) as argument
that returns the file content as a string.
"""
self._patch = PatchSet(patch)
self._patched_files: dict[str, str] = {}
self._original_files: dict[str, str] = {}
self._patch_applied = True
self._read_file = read_method
self._read_files(original=False)
@staticmethod
def _merge_intervals(starts: list[int], stops: list[int]) -> tuple[list[int], list[int]]:
"""Given two lists of integers, starts and stops, merges all overlapping intervals.
For example `starts=[1, 5, 18]`, `stops=[10, 13, 20]`
should return `starts=[1, 18]`, `stops=[13, 20]`
"""
if not starts:
assert not stops
return [], []
intervals = sorted(zip(starts, stops))
merged = []
for start, stop in intervals:
if not merged or merged[-1][1] < start:
# No overlap
merged.append([start, stop])
else:
# Overlap
merged[-1][1] = max(merged[-1][1], stop)
# Unzip again
merged_starts, merged_stops = zip(*merged)
return list(merged_starts), list(merged_stops)
def format_file(self, text: str, starts: list[int], stops: list[int], *, linenos: bool = True) -> str:
"""Reads file and returns string representation of the relevant lines.
Args:
path: The path to the file within the repo location
starts: The starting line numbers of the relevant lines. The first line is line 1.
stops: The stopping line numbers of the relevant lines. The stop is not inclusive.
The first line is line 1.
linenos: Whether to include line numbers
"""
if not starts:
assert not stops
return ""
assert len(starts) == len(stops)
assert all(start >= 1 for start in starts)
assert all(start < stop for start, stop in zip(starts, stops))
starts, stops = self._merge_intervals(starts, stops)
assert all(hunk1_start < hunk2_start for hunk1_start, hunk2_start in zip(starts, starts[1:]))
out: list[str] = []
if starts[0] > 1:
# Count from 1
out.append(f"[{starts[0] - 1} lines above omitted]")
last_stop: int | None = None
lines = text.splitlines()
for start, stop in zip(starts, stops):
assert start >= 1
if last_stop is not None:
n_omitted = start - last_stop
# Check that we have non-overlapping hunks
assert n_omitted >= 0
if n_omitted:
out.append(f"\n[{n_omitted} lines omitted]\n")
# Count from 1
these_lines = lines[start - 1 : stop - 1]
if linenos:
out.append("\n".join([f"{i:6d}: {l}" for i, l in enumerate(these_lines, start=start)]))
else:
out.append("\n".join(these_lines))
last_stop = stop
if last_stop < len(lines):
# Stop is not inclusive
omitted = len(lines) - last_stop
assert omitted > 0
out.append(f"[{omitted} lines below omitted]")
return "\n".join(out)
def _get_hunk_lines(self, original: bool, *, context_length: int) -> dict[str, tuple[list[int], list[int]]]:
"""Get the starts and stops for all files in the patch.
Args:
original: Whether to read the original file or the patched file
context_length: The number of lines to include above and below the hunk
Returns:
A dictionary with the file path as key and a tuple of lists of starts and stops as value.
"""
out: dict[str, tuple[list[int], list[int]]] = {}
for patch in self._patch:
if not patch.is_modified_file:
continue
starts: list[int] = []
stops: list[int] = []
for hunk in patch:
if original:
# 1 is the lowest line number
start = max(1, hunk.source_start - context_length)
stop = hunk.source_start + hunk.source_length + context_length
else:
start = max(1, hunk.target_start - context_length)
stop = hunk.target_start + hunk.target_length + context_length
starts.append(start)
stops.append(stop)
out[patch.path] = (starts, stops)
return out
def _read_files(self, original: bool) -> None:
for patch in self._patch:
path = patch.path
if not patch.is_modified_file:
continue
if original:
msg = "Original file reading not implemented"
raise NotImplementedError(msg)
else:
assert self._patch_applied
self._patched_files[path] = self._read_file(path)
@staticmethod
def concat_files_strings(files: dict[str, str]) -> str:
"""Concatenate multiple `read_files` outputs into a single string."""
out = []
for path, content in files.items():
out.append(f"[File: {path}]\n{content}")
return "\n\n".join(out)
def get_files_str(self, *, original: bool, context_length: int | None = 50, linenos: bool = True) -> str:
hunk_lines = self._get_hunk_lines(original=original, context_length=context_length)
sources = self._original_files if original else self._patched_files
return self.concat_files_strings(
{path: self.format_file(text, *hunk_lines[path], linenos=linenos) for path, text in sources.items()}
)

View File

@@ -0,0 +1,45 @@
import io
from copy import deepcopy
from typing import Any
from ruamel.yaml import YAML
from ruamel.yaml.scalarstring import LiteralScalarString as LSS
def _convert_to_yaml_literal_string(d: Any) -> Any:
"""Convert any multi-line strings in nested data object to LiteralScalarString.
This will then use the `|-` syntax of yaml.
"""
d = deepcopy(d)
if isinstance(d, dict):
for key, value in d.items():
d[key] = _convert_to_yaml_literal_string(value)
elif isinstance(d, list):
for i, item in enumerate(d):
d[i] = _convert_to_yaml_literal_string(item)
elif isinstance(d, str) and "\n" in d:
d = LSS(d.replace("\r\n", "\n").replace("\r", "\n"))
return d
def _yaml_serialization_with_linebreaks(data: Any) -> str:
data = _convert_to_yaml_literal_string(data)
yaml = YAML()
yaml.indent(mapping=2, sequence=4, offset=2)
yaml.width = float("inf")
yaml.default_flow_style = False
buffer = io.StringIO()
yaml.dump(data, buffer)
return buffer.getvalue()
def merge_nested_dicts(d1: dict, d2: dict) -> dict:
"""Merge two nested dictionaries, updating d1 in place.
If a key exists in both dictionaries, the value from d2 will be used.
"""
for key, value in d2.items():
if isinstance(value, dict):
d1[key] = merge_nested_dicts(d1.get(key, {}), value)
else:
d1[key] = value
return d1