Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,8 @@ trae_config.yaml
# VS Code settings
.vscode/
!.vscode/launch.template.json

# markdown files
*.md.example
Project_rules.md
user_rules.md
31 changes: 30 additions & 1 deletion trae_agent/agent/trae_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from trae_agent.utils.config import MCPServerConfig, TraeAgentConfig
from trae_agent.utils.llm_clients.llm_basics import LLMMessage, LLMResponse
from trae_agent.utils.mcp_client import MCPClient
from trae_agent.utils.project_rules import ProjectRulesLoader

TraeAgentToolNames = [
"str_replace_based_edit_tool",
Expand All @@ -43,6 +44,10 @@ def __init__(self, trae_agent_config: TraeAgentConfig):
self.base_commit: str | None = None
self.must_patch: str = "false"
self.patch_path: str | None = None
self.project_rules_enabled: bool = trae_agent_config.project_rules_enabled
self.project_rules_path: str = trae_agent_config.project_rules_path
self.user_rules_enabled: bool = trae_agent_config.user_rules_enabled
self.user_rules_path: str = trae_agent_config.user_rules_path
self.mcp_servers_config: dict[str, MCPServerConfig] | None = (
trae_agent_config.mcp_servers_config if trae_agent_config.mcp_servers_config else None
)
Expand Down Expand Up @@ -157,9 +162,33 @@ async def execute_task(self) -> AgentExecution:

return execution

def _get_project_rules(self) -> str:
"""Get the combined project rules and user rules content for TraeAgent.

Returns:
str: Formatted combined rules content, or empty string if disabled or loading failed
"""
if not self.project_path:
return ""

# Check if any rules are enabled
if not self.project_rules_enabled and not self.user_rules_enabled:
return ""

return ProjectRulesLoader.load_combined_rules(
project_path=self.project_path,
project_rules_path=self.project_rules_path,
user_rules_path=self.user_rules_path,
project_rules_enabled=self.project_rules_enabled,
user_rules_enabled=self.user_rules_enabled
)

def get_system_prompt(self) -> str:
"""Get the system prompt for TraeAgent."""
return TRAE_AGENT_SYSTEM_PROMPT
base_prompt = TRAE_AGENT_SYSTEM_PROMPT
project_rules = self._get_project_rules()

return base_prompt + project_rules

@override
def reflect_on_result(self, tool_results: list[ToolResult]) -> str | None:
Expand Down
133 changes: 133 additions & 0 deletions trae_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from trae_agent.agent import Agent
from trae_agent.utils.cli import CLIConsole, ConsoleFactory, ConsoleMode, ConsoleType
from trae_agent.utils.config import Config, TraeAgentConfig
from trae_agent.utils.rules_manager import RulesManager

# Load environment variables
_ = load_dotenv()
Expand Down Expand Up @@ -517,6 +518,138 @@ def tools():
console.print(tools_table)


@cli.group()
def rules():
"""Manage project and user rules files."""
pass


@rules.command()
@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False))
@click.option("--working-dir", "-w", help="Working directory for the rules files")
def list(file_type: str, working_dir: str | None = None):
"""List all rules in the specified file.

FILE_TYPE: Type of rules file to list (project or user)
"""
try:
manager = RulesManager(working_dir)
if not manager.validate_permissions(file_type.lower()):
sys.exit(1)
manager.list_rules(file_type.lower())
except Exception as e:
console.print(f"[red]Error listing rules: {e}[/red]")
sys.exit(1)


@rules.command()
@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False))
@click.argument("section")
@click.argument("rule")
@click.option("--working-dir", "-w", help="Working directory for the rules files")
def add(file_type: str, section: str, rule: str, working_dir: str | None = None):
"""Add a new rule to the specified section.

FILE_TYPE: Type of rules file (project or user)
SECTION: Section name to add the rule to
RULE: Rule text to add
"""
try:
manager = RulesManager(working_dir)
if not manager.validate_permissions(file_type.lower()):
sys.exit(1)
manager.add_rule(file_type.lower(), section, rule)
except Exception as e:
console.print(f"[red]Error adding rule: {e}[/red]")
sys.exit(1)


@rules.command()
@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False))
@click.argument("section")
@click.argument("rule_pattern")
@click.option("--working-dir", "-w", help="Working directory for the rules files")
def remove(file_type: str, section: str, rule_pattern: str, working_dir: str | None = None):
"""Remove a rule from the specified section.

FILE_TYPE: Type of rules file (project or user)
SECTION: Section name to remove the rule from
RULE_PATTERN: Pattern to match the rule to remove
"""
try:
manager = RulesManager(working_dir)
if not manager.validate_permissions(file_type.lower()):
sys.exit(1)
manager.remove_rule(file_type.lower(), section, rule_pattern)
except Exception as e:
console.print(f"[red]Error removing rule: {e}[/red]")
sys.exit(1)


@rules.command()
@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False))
@click.argument("section")
@click.argument("old_pattern")
@click.argument("new_rule")
@click.option("--working-dir", "-w", help="Working directory for the rules files")
def update(file_type: str, section: str, old_pattern: str, new_rule: str, working_dir: str | None = None):
"""Update an existing rule in the specified section.

FILE_TYPE: Type of rules file (project or user)
SECTION: Section name containing the rule
OLD_PATTERN: Pattern to match the rule to update
NEW_RULE: New rule text
"""
try:
manager = RulesManager(working_dir)
if not manager.validate_permissions(file_type.lower()):
sys.exit(1)
manager.update_rule(file_type.lower(), section, old_pattern, new_rule)
except Exception as e:
console.print(f"[red]Error updating rule: {e}[/red]")
sys.exit(1)


@rules.command(name="add-section")
@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False))
@click.argument("section")
@click.option("--working-dir", "-w", help="Working directory for the rules files")
def add_section(file_type: str, section: str, working_dir: str | None = None):
"""Add a new section to the rules file.

FILE_TYPE: Type of rules file (project or user)
SECTION: Section name to add
"""
try:
manager = RulesManager(working_dir)
if not manager.validate_permissions(file_type.lower()):
sys.exit(1)
manager.add_section(file_type.lower(), section)
except Exception as e:
console.print(f"[red]Error adding section: {e}[/red]")
sys.exit(1)


@rules.command(name="remove-section")
@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False))
@click.argument("section")
@click.option("--working-dir", "-w", help="Working directory for the rules files")
def remove_section(file_type: str, section: str, working_dir: str | None = None):
"""Remove a section from the rules file.

FILE_TYPE: Type of rules file (project or user)
SECTION: Section name to remove
"""
try:
manager = RulesManager(working_dir)
if not manager.validate_permissions(file_type.lower()):
sys.exit(1)
manager.remove_section(file_type.lower(), section)
except Exception as e:
console.print(f"[red]Error removing section: {e}[/red]")
sys.exit(1)


def main():
"""Main entry point for the CLI."""
cli()
Expand Down
25 changes: 23 additions & 2 deletions trae_agent/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,33 @@ class ModelConfig:

model: str
model_provider: ModelProvider
max_tokens: int
temperature: float
top_p: float
top_k: int
parallel_tool_calls: bool
max_retries: int
max_tokens: int | None = None # Legacy max_tokens parameter, optional
supports_tool_calling: bool = True
candidate_count: int | None = None # Gemini specific field
stop_sequences: list[str] | None = None

max_completion_tokens: int | None = None # Azure OpenAI specific field

def get_max_tokens_param(self) -> int:
"""Get the maximum tokens parameter value.Prioritizes max_completion_tokens, falls back to max_tokens if not available. """
if self.max_completion_tokens is not None:
return self.max_completion_tokens
elif self.max_tokens is not None:
return self.max_tokens
else:
# Return default value if neither is set
return 4096

def should_use_max_completion_tokens(self) -> bool:
"""Determine whether to use the max_completion_tokens parameter.Primarily used for Azure OpenAI's newer models (e.g., gpt-5)."""
return (self.max_completion_tokens is not None and
self.model_provider.provider == "azure" and
("gpt-5" in self.model or "o3" in self.model or "o4-mini" in self.model))

def resolve_config_values(
self,
*,
Expand Down Expand Up @@ -143,6 +160,10 @@ class TraeAgentConfig(AgentConfig):
"""

enable_lakeview: bool = True
project_rules_enabled: bool = False
project_rules_path: str = "Project_rules.md"
user_rules_enabled: bool = False
user_rules_path: str = "user_rules.md"
tools: list[str] = field(
default_factory=lambda: [
"bash",
Expand Down
10 changes: 9 additions & 1 deletion trae_agent/utils/llm_clients/openai_compatible_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def _create_response(
extra_headers: dict[str, str] | None = None,
) -> ChatCompletion:
"""Create a response using the provider's API. This method will be decorated with retry logic."""
"""Select the correct token parameter based on model configuration.
If max_completion_tokens is set, use it. Otherwise, use max_tokens."""
token_params = {}
if model_config.should_use_max_completion_tokens():
token_params["max_completion_tokens"] = model_config.get_max_tokens_param()
else:
token_params["max_tokens"] = model_config.get_max_tokens_param()

return self.client.chat.completions.create(
model=model_config.model,
messages=self.message_history,
Expand All @@ -93,9 +101,9 @@ def _create_response(
and "gpt-5" not in model_config.model
else openai.NOT_GIVEN,
top_p=model_config.top_p,
max_tokens=model_config.max_tokens,
extra_headers=extra_headers if extra_headers else None,
n=1,
**token_params,
)

@override
Expand Down
Loading