diff --git a/.gitignore b/.gitignore index 9ced3f51..4e7341d8 100644 --- a/.gitignore +++ b/.gitignore @@ -183,3 +183,8 @@ trae_config.yaml # VS Code settings .vscode/ !.vscode/launch.template.json + +# markdown files +*.md.example +Project_rules.md +user_rules.md diff --git a/trae_agent/agent/trae_agent.py b/trae_agent/agent/trae_agent.py index 1ac311e8..29fda49a 100644 --- a/trae_agent/agent/trae_agent.py +++ b/trae_agent/agent/trae_agent.py @@ -17,6 +17,7 @@ from trae_agent.utils.config import MCPServerConfig, TraeAgentConfig from trae_agent.utils.llm_clients.llm_basics import LLMMessage, LLMResponse from trae_agent.utils.mcp_client import MCPClient +from trae_agent.utils.project_rules import ProjectRulesLoader TraeAgentToolNames = [ "str_replace_based_edit_tool", @@ -43,6 +44,10 @@ def __init__(self, trae_agent_config: TraeAgentConfig): self.base_commit: str | None = None self.must_patch: str = "false" self.patch_path: str | None = None + self.project_rules_enabled: bool = trae_agent_config.project_rules_enabled + self.project_rules_path: str = trae_agent_config.project_rules_path + self.user_rules_enabled: bool = trae_agent_config.user_rules_enabled + self.user_rules_path: str = trae_agent_config.user_rules_path self.mcp_servers_config: dict[str, MCPServerConfig] | None = ( trae_agent_config.mcp_servers_config if trae_agent_config.mcp_servers_config else None ) @@ -157,9 +162,33 @@ async def execute_task(self) -> AgentExecution: return execution + def _get_project_rules(self) -> str: + """Get the combined project rules and user rules content for TraeAgent. + + Returns: + str: Formatted combined rules content, or empty string if disabled or loading failed + """ + if not self.project_path: + return "" + + # Check if any rules are enabled + if not self.project_rules_enabled and not self.user_rules_enabled: + return "" + + return ProjectRulesLoader.load_combined_rules( + project_path=self.project_path, + project_rules_path=self.project_rules_path, + user_rules_path=self.user_rules_path, + project_rules_enabled=self.project_rules_enabled, + user_rules_enabled=self.user_rules_enabled + ) + def get_system_prompt(self) -> str: """Get the system prompt for TraeAgent.""" - return TRAE_AGENT_SYSTEM_PROMPT + base_prompt = TRAE_AGENT_SYSTEM_PROMPT + project_rules = self._get_project_rules() + + return base_prompt + project_rules @override def reflect_on_result(self, tool_results: list[ToolResult]) -> str | None: diff --git a/trae_agent/cli.py b/trae_agent/cli.py index 165cb7b1..f0df2a26 100644 --- a/trae_agent/cli.py +++ b/trae_agent/cli.py @@ -18,6 +18,7 @@ from trae_agent.agent import Agent from trae_agent.utils.cli import CLIConsole, ConsoleFactory, ConsoleMode, ConsoleType from trae_agent.utils.config import Config, TraeAgentConfig +from trae_agent.utils.rules_manager import RulesManager # Load environment variables _ = load_dotenv() @@ -517,6 +518,138 @@ def tools(): console.print(tools_table) +@cli.group() +def rules(): + """Manage project and user rules files.""" + pass + + +@rules.command() +@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False)) +@click.option("--working-dir", "-w", help="Working directory for the rules files") +def list(file_type: str, working_dir: str | None = None): + """List all rules in the specified file. + + FILE_TYPE: Type of rules file to list (project or user) + """ + try: + manager = RulesManager(working_dir) + if not manager.validate_permissions(file_type.lower()): + sys.exit(1) + manager.list_rules(file_type.lower()) + except Exception as e: + console.print(f"[red]Error listing rules: {e}[/red]") + sys.exit(1) + + +@rules.command() +@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False)) +@click.argument("section") +@click.argument("rule") +@click.option("--working-dir", "-w", help="Working directory for the rules files") +def add(file_type: str, section: str, rule: str, working_dir: str | None = None): + """Add a new rule to the specified section. + + FILE_TYPE: Type of rules file (project or user) + SECTION: Section name to add the rule to + RULE: Rule text to add + """ + try: + manager = RulesManager(working_dir) + if not manager.validate_permissions(file_type.lower()): + sys.exit(1) + manager.add_rule(file_type.lower(), section, rule) + except Exception as e: + console.print(f"[red]Error adding rule: {e}[/red]") + sys.exit(1) + + +@rules.command() +@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False)) +@click.argument("section") +@click.argument("rule_pattern") +@click.option("--working-dir", "-w", help="Working directory for the rules files") +def remove(file_type: str, section: str, rule_pattern: str, working_dir: str | None = None): + """Remove a rule from the specified section. + + FILE_TYPE: Type of rules file (project or user) + SECTION: Section name to remove the rule from + RULE_PATTERN: Pattern to match the rule to remove + """ + try: + manager = RulesManager(working_dir) + if not manager.validate_permissions(file_type.lower()): + sys.exit(1) + manager.remove_rule(file_type.lower(), section, rule_pattern) + except Exception as e: + console.print(f"[red]Error removing rule: {e}[/red]") + sys.exit(1) + + +@rules.command() +@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False)) +@click.argument("section") +@click.argument("old_pattern") +@click.argument("new_rule") +@click.option("--working-dir", "-w", help="Working directory for the rules files") +def update(file_type: str, section: str, old_pattern: str, new_rule: str, working_dir: str | None = None): + """Update an existing rule in the specified section. + + FILE_TYPE: Type of rules file (project or user) + SECTION: Section name containing the rule + OLD_PATTERN: Pattern to match the rule to update + NEW_RULE: New rule text + """ + try: + manager = RulesManager(working_dir) + if not manager.validate_permissions(file_type.lower()): + sys.exit(1) + manager.update_rule(file_type.lower(), section, old_pattern, new_rule) + except Exception as e: + console.print(f"[red]Error updating rule: {e}[/red]") + sys.exit(1) + + +@rules.command(name="add-section") +@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False)) +@click.argument("section") +@click.option("--working-dir", "-w", help="Working directory for the rules files") +def add_section(file_type: str, section: str, working_dir: str | None = None): + """Add a new section to the rules file. + + FILE_TYPE: Type of rules file (project or user) + SECTION: Section name to add + """ + try: + manager = RulesManager(working_dir) + if not manager.validate_permissions(file_type.lower()): + sys.exit(1) + manager.add_section(file_type.lower(), section) + except Exception as e: + console.print(f"[red]Error adding section: {e}[/red]") + sys.exit(1) + + +@rules.command(name="remove-section") +@click.argument("file_type", type=click.Choice(["project", "user"], case_sensitive=False)) +@click.argument("section") +@click.option("--working-dir", "-w", help="Working directory for the rules files") +def remove_section(file_type: str, section: str, working_dir: str | None = None): + """Remove a section from the rules file. + + FILE_TYPE: Type of rules file (project or user) + SECTION: Section name to remove + """ + try: + manager = RulesManager(working_dir) + if not manager.validate_permissions(file_type.lower()): + sys.exit(1) + manager.remove_section(file_type.lower(), section) + except Exception as e: + console.print(f"[red]Error removing section: {e}[/red]") + sys.exit(1) + + def main(): """Main entry point for the CLI.""" cli() diff --git a/trae_agent/utils/config.py b/trae_agent/utils/config.py index 95e8164b..2ba3de45 100644 --- a/trae_agent/utils/config.py +++ b/trae_agent/utils/config.py @@ -34,16 +34,33 @@ class ModelConfig: model: str model_provider: ModelProvider - max_tokens: int temperature: float top_p: float top_k: int parallel_tool_calls: bool max_retries: int + max_tokens: int | None = None # Legacy max_tokens parameter, optional supports_tool_calling: bool = True candidate_count: int | None = None # Gemini specific field stop_sequences: list[str] | None = None - + max_completion_tokens: int | None = None # Azure OpenAI specific field + + def get_max_tokens_param(self) -> int: + """Get the maximum tokens parameter value.Prioritizes max_completion_tokens, falls back to max_tokens if not available. """ + if self.max_completion_tokens is not None: + return self.max_completion_tokens + elif self.max_tokens is not None: + return self.max_tokens + else: + # Return default value if neither is set + return 4096 + + def should_use_max_completion_tokens(self) -> bool: + """Determine whether to use the max_completion_tokens parameter.Primarily used for Azure OpenAI's newer models (e.g., gpt-5).""" + return (self.max_completion_tokens is not None and + self.model_provider.provider == "azure" and + ("gpt-5" in self.model or "o3" in self.model or "o4-mini" in self.model)) + def resolve_config_values( self, *, @@ -143,6 +160,10 @@ class TraeAgentConfig(AgentConfig): """ enable_lakeview: bool = True + project_rules_enabled: bool = False + project_rules_path: str = "Project_rules.md" + user_rules_enabled: bool = False + user_rules_path: str = "user_rules.md" tools: list[str] = field( default_factory=lambda: [ "bash", diff --git a/trae_agent/utils/llm_clients/openai_compatible_base.py b/trae_agent/utils/llm_clients/openai_compatible_base.py index 20dabf58..a324415c 100644 --- a/trae_agent/utils/llm_clients/openai_compatible_base.py +++ b/trae_agent/utils/llm_clients/openai_compatible_base.py @@ -83,6 +83,14 @@ def _create_response( extra_headers: dict[str, str] | None = None, ) -> ChatCompletion: """Create a response using the provider's API. This method will be decorated with retry logic.""" + """Select the correct token parameter based on model configuration. + If max_completion_tokens is set, use it. Otherwise, use max_tokens.""" + token_params = {} + if model_config.should_use_max_completion_tokens(): + token_params["max_completion_tokens"] = model_config.get_max_tokens_param() + else: + token_params["max_tokens"] = model_config.get_max_tokens_param() + return self.client.chat.completions.create( model=model_config.model, messages=self.message_history, @@ -93,9 +101,9 @@ def _create_response( and "gpt-5" not in model_config.model else openai.NOT_GIVEN, top_p=model_config.top_p, - max_tokens=model_config.max_tokens, extra_headers=extra_headers if extra_headers else None, n=1, + **token_params, ) @override diff --git a/trae_agent/utils/project_rules.py b/trae_agent/utils/project_rules.py new file mode 100644 index 00000000..463ba707 --- /dev/null +++ b/trae_agent/utils/project_rules.py @@ -0,0 +1,139 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +"""Project rules utilities for loading and parsing Project_rules.md files.""" + +import os +from pathlib import Path +from typing import Optional + + +class ProjectRulesLoader: + """Utility class for loading and parsing project rules and user rules files.""" + + @staticmethod + def load_project_rules(project_path: str, rules_file_path: str = "Project_rules.md") -> Optional[str]: + """Load Project_rules.md file content from project directory. + Args: + project_path: Project root directory path + rules_file_path: Relative path to rules file, defaults to "Project_rules.md" + Returns: + Rules file content string, or None if file doesn't exist or read fails + """ + try: + if os.path.isabs(rules_file_path): + full_path = rules_file_path + else: + full_path = os.path.join(project_path, rules_file_path) + + if not os.path.exists(full_path): + return None + + project_path_resolved = os.path.abspath(project_path) + full_path_resolved = os.path.abspath(full_path) + if not full_path_resolved.startswith(project_path_resolved): + return None + + with open(full_path, 'r', encoding='utf-8') as f: + content = f.read().strip() + + return content if content else None + + except (OSError, IOError, UnicodeDecodeError): + return None + + @staticmethod + def load_user_rules(project_path: str, rules_file_path: str = "user_rules.md") -> Optional[str]: + """Load user_rules.md file content from project directory. + Args: + project_path: Project root directory path + rules_file_path: Relative path to user rules file, defaults to "user_rules.md" + Returns: + User rules file content string, or None if file doesn't exist or read fails + """ + return ProjectRulesLoader.load_project_rules(project_path, rules_file_path) + + @staticmethod + def load_combined_rules(project_path: str, project_rules_path: str = "Project_rules.md", + user_rules_path: str = "user_rules.md", + project_rules_enabled: bool = True, + user_rules_enabled: bool = True) -> str: + """Load and combine project rules and user rules. + Args: + project_path: Project root directory path + project_rules_path: Relative path to project rules file + user_rules_path: Relative path to user rules file + project_rules_enabled: Whether to load project rules + user_rules_enabled: Whether to load user rules + Returns: + Combined formatted rules content + """ + combined_content = "" + + # Load project rules + if project_rules_enabled: + project_content = ProjectRulesLoader.load_project_rules(project_path, project_rules_path) + if project_content: + combined_content += ProjectRulesLoader.format_rules_for_prompt( + project_content, "PROJECT-SPECIFIC RULES" + ) + + # Load user rules + if user_rules_enabled: + user_content = ProjectRulesLoader.load_user_rules(project_path, user_rules_path) + if user_content: + combined_content += ProjectRulesLoader.format_rules_for_prompt( + user_content, "USER-SPECIFIC RULES" + ) + + return combined_content + + @staticmethod + def format_rules_for_prompt(rules_content: str, section_title: str = "PROJECT-SPECIFIC RULES") -> str: + """Format rules content for adding to system prompt. + Args: + rules_content: Original rules file content + section_title: Title for the rules section + Returns: + Formatted rules content + """ + if not rules_content: + return "" + + formatted_rules = f""" +# {section_title} +The following are {section_title.lower()} and guidelines that you MUST follow: +{rules_content} +# END OF {section_title} +""" + return formatted_rules + + @staticmethod + def validate_rules_file(file_path: str) -> bool: + """Validate if rules file is valid. + Args: + file_path: Rules file path + Returns: + Whether the file is valid + """ + try: + if not os.path.exists(file_path): + return False + + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + if len(content) > 10000: + return False + + try: + content.encode('utf-8') + for char in content: + if ord(char) < 32 and char not in '\n\r\t': + return False + + except UnicodeEncodeError: + return False + return True + except (OSError, IOError, UnicodeDecodeError): + return False \ No newline at end of file diff --git a/trae_agent/utils/rules_manager.py b/trae_agent/utils/rules_manager.py new file mode 100644 index 00000000..e2e4d063 --- /dev/null +++ b/trae_agent/utils/rules_manager.py @@ -0,0 +1,343 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +"""Rules Manager for managing project_rules.md and user_rules.md files.""" + +import os +import re +from pathlib import Path +from typing import List, Optional, Dict, Any + +from rich.console import Console +from rich.panel import Panel +from rich.table import Table + +console = Console() + + +class RulesManager: + """Manager for project and user rules files.""" + + def __init__(self, working_dir: Optional[str] = None): + """Initialize the rules manager. + + Args: + working_dir: Working directory path. If None, uses current directory. + """ + self.working_dir = Path(working_dir) if working_dir else Path.cwd() + self.project_rules_file = self.working_dir / "project_rules.md" + self.user_rules_file = self.working_dir / "user_rules.md" + + def _ensure_file_exists(self, file_path: Path, file_type: str) -> None: + """Ensure the rules file exists, create if not. + + Args: + file_path: Path to the rules file + file_type: Type of rules file (project or user) + """ + if not file_path.exists(): + template_content = self._get_template_content(file_type) + file_path.write_text(template_content, encoding='utf-8') + console.print(f"[green]Created {file_type} rules file: {file_path}[/green]") + + def _get_template_content(self, file_type: str) -> str: + """Get template content for rules file. + + Args: + file_type: Type of rules file (project or user) + + Returns: + Template content string + """ + if file_type == "project": + return """# Project Rules + +## Code Style +- Follow PEP 8 for Python code +- Use meaningful variable and function names +- Add docstrings for all public functions and classes + +## Architecture +- Follow the existing project structure +- Use dependency injection where appropriate +- Implement proper error handling + +## Testing +- Write unit tests for new functionality +- Ensure all tests pass before committing +- Maintain test coverage above 80% + +## Documentation +- Update README.md when adding new features +- Document API changes in appropriate files +- Use clear and concise comments +""" + else: # user rules + return """# User Rules + +## Personal Preferences +- Prefer explicit over implicit code +- Use type hints for better code clarity +- Favor composition over inheritance + +## Workflow +- Create feature branches for new work +- Use descriptive commit messages +- Review code before merging + +## Tools and Libraries +- Use pytest for testing +- Use black for code formatting +- Use mypy for type checking +""" + + def _parse_rules(self, content: str) -> Dict[str, List[str]]: + """Parse rules content into sections. + + Args: + content: Raw markdown content + + Returns: + Dictionary mapping section names to lists of rules + """ + sections = {} + current_section = None + current_rules = [] + + for line in content.split('\n'): + line = line.strip() + if line.startswith('## '): + if current_section: + sections[current_section] = current_rules + current_section = line[3:].strip() + current_rules = [] + elif line.startswith('- '): + current_rules.append(line[2:].strip()) + + if current_section: + sections[current_section] = current_rules + + return sections + + def _format_rules(self, sections: Dict[str, List[str]], title: str) -> str: + """Format rules sections back to markdown. + + Args: + sections: Dictionary mapping section names to lists of rules + title: Title for the rules file + + Returns: + Formatted markdown content + """ + content = f"# {title}\n\n" + for section, rules in sections.items(): + content += f"## {section}\n" + for rule in rules: + content += f"- {rule}\n" + content += "\n" + return content.rstrip() + "\n" + + def list_rules(self, file_type: str) -> None: + """List all rules in the specified file. + + Args: + file_type: Type of rules file ('project' or 'user') + """ + file_path = self.project_rules_file if file_type == "project" else self.user_rules_file + + if not file_path.exists(): + console.print(f"[yellow]{file_type.title()} rules file does not exist: {file_path}[/yellow]") + return + + content = file_path.read_text(encoding='utf-8') + sections = self._parse_rules(content) + + if not sections: + console.print(f"[yellow]No rules found in {file_type} rules file[/yellow]") + return + + table = Table(title=f"{file_type.title()} Rules") + table.add_column("Section", style="cyan") + table.add_column("Rules", style="green") + + for section, rules in sections.items(): + rules_text = "\n".join([f"• {rule}" for rule in rules]) + table.add_row(section, rules_text) + + console.print(table) + + def add_rule(self, file_type: str, section: str, rule: str) -> None: + """Add a new rule to the specified section. + + Args: + file_type: Type of rules file ('project' or 'user') + section: Section name to add the rule to + rule: Rule text to add + """ + file_path = self.project_rules_file if file_type == "project" else self.user_rules_file + title = "Project Rules" if file_type == "project" else "User Rules" + + self._ensure_file_exists(file_path, file_type) + + content = file_path.read_text(encoding='utf-8') + sections = self._parse_rules(content) + + if section not in sections: + sections[section] = [] + + if rule not in sections[section]: + sections[section].append(rule) + new_content = self._format_rules(sections, title) + file_path.write_text(new_content, encoding='utf-8') + console.print(f"[green]Added rule to {section} in {file_type} rules[/green]") + else: + console.print(f"[yellow]Rule already exists in {section}[/yellow]") + + def remove_rule(self, file_type: str, section: str, rule_pattern: str) -> None: + """Remove a rule from the specified section. + + Args: + file_type: Type of rules file ('project' or 'user') + section: Section name to remove the rule from + rule_pattern: Pattern to match the rule to remove + """ + file_path = self.project_rules_file if file_type == "project" else self.user_rules_file + title = "Project Rules" if file_type == "project" else "User Rules" + + if not file_path.exists(): + console.print(f"[red]{file_type.title()} rules file does not exist[/red]") + return + + content = file_path.read_text(encoding='utf-8') + sections = self._parse_rules(content) + + if section not in sections: + console.print(f"[red]Section '{section}' not found in {file_type} rules[/red]") + return + + original_count = len(sections[section]) + sections[section] = [rule for rule in sections[section] + if not re.search(rule_pattern, rule, re.IGNORECASE)] + + removed_count = original_count - len(sections[section]) + + if removed_count > 0: + new_content = self._format_rules(sections, title) + file_path.write_text(new_content, encoding='utf-8') + console.print(f"[green]Removed {removed_count} rule(s) from {section} in {file_type} rules[/green]") + else: + console.print(f"[yellow]No rules matching '{rule_pattern}' found in {section}[/yellow]") + + def update_rule(self, file_type: str, section: str, old_pattern: str, new_rule: str) -> None: + """Update an existing rule in the specified section. + + Args: + file_type: Type of rules file ('project' or 'user') + section: Section name containing the rule + old_pattern: Pattern to match the rule to update + new_rule: New rule text + """ + file_path = self.project_rules_file if file_type == "project" else self.user_rules_file + title = "Project Rules" if file_type == "project" else "User Rules" + + if not file_path.exists(): + console.print(f"[red]{file_type.title()} rules file does not exist[/red]") + return + + content = file_path.read_text(encoding='utf-8') + sections = self._parse_rules(content) + + if section not in sections: + console.print(f"[red]Section '{section}' not found in {file_type} rules[/red]") + return + + updated = False + for i, rule in enumerate(sections[section]): + if re.search(old_pattern, rule, re.IGNORECASE): + sections[section][i] = new_rule + updated = True + break + + if updated: + new_content = self._format_rules(sections, title) + file_path.write_text(new_content, encoding='utf-8') + console.print(f"[green]Updated rule in {section} in {file_type} rules[/green]") + else: + console.print(f"[yellow]No rule matching '{old_pattern}' found in {section}[/yellow]") + + def add_section(self, file_type: str, section: str) -> None: + """Add a new section to the rules file. + + Args: + file_type: Type of rules file ('project' or 'user') + section: Section name to add + """ + file_path = self.project_rules_file if file_type == "project" else self.user_rules_file + title = "Project Rules" if file_type == "project" else "User Rules" + + self._ensure_file_exists(file_path, file_type) + + content = file_path.read_text(encoding='utf-8') + sections = self._parse_rules(content) + + if section in sections: + console.print(f"[yellow]Section '{section}' already exists in {file_type} rules[/yellow]") + return + + sections[section] = [] + new_content = self._format_rules(sections, title) + file_path.write_text(new_content, encoding='utf-8') + console.print(f"[green]Added section '{section}' to {file_type} rules[/green]") + + def remove_section(self, file_type: str, section: str) -> None: + """Remove a section from the rules file. + + Args: + file_type: Type of rules file ('project' or 'user') + section: Section name to remove + """ + file_path = self.project_rules_file if file_type == "project" else self.user_rules_file + title = "Project Rules" if file_type == "project" else "User Rules" + + if not file_path.exists(): + console.print(f"[red]{file_type.title()} rules file does not exist[/red]") + return + + content = file_path.read_text(encoding='utf-8') + sections = self._parse_rules(content) + + if section not in sections: + console.print(f"[red]Section '{section}' not found in {file_type} rules[/red]") + return + + del sections[section] + new_content = self._format_rules(sections, title) + file_path.write_text(new_content, encoding='utf-8') + console.print(f"[green]Removed section '{section}' from {file_type} rules[/green]") + + def validate_permissions(self, file_type: str) -> bool: + """Validate write permissions for the rules file. + + Args: + file_type: Type of rules file ('project' or 'user') + + Returns: + True if permissions are valid, False otherwise + """ + file_path = self.project_rules_file if file_type == "project" else self.user_rules_file + + try: + # Check if directory is writable + if not os.access(file_path.parent, os.W_OK): + console.print(f"[red]No write permission for directory: {file_path.parent}[/red]") + return False + + # Check if file is writable (if it exists) + if file_path.exists() and not os.access(file_path, os.W_OK): + console.print(f"[red]No write permission for file: {file_path}[/red]") + return False + + return True + except Exception as e: + console.print(f"[red]Permission check failed: {e}[/red]") + return False \ No newline at end of file diff --git a/trae_config.yaml.example b/trae_config.yaml.example index 21f9d4d2..f2e77efd 100644 --- a/trae_config.yaml.example +++ b/trae_config.yaml.example @@ -1,6 +1,10 @@ agents: trae_agent: enable_lakeview: true + project_rules_enabled: true + project_rules_path: "Project_rules.md" + user_rules_enabled: true + user_rules_path: "user_rules.md" model: trae_agent_model max_steps: 200 tools: