Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,25 +281,39 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
patches = []
remaining_files_list_new = []
files_in_patch_list = []

# Cache verbosity_level so that expensive config access isn't repeated in hot loop
settings = get_settings()
# config might rarely change but is extremely expensive to access, so cache for one complete function call
config = getattr(settings, 'config', None)
verbosity_level = getattr(config, 'verbosity_level', 0) if config is not None else 0

hard_threshold = max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
soft_threshold = max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD

# Use set for remaining_files_list_prev to optimize 'not in' checks for large sets
prev_files_set = set(remaining_files_list_prev)

for filename, data in file_dict.items():
if filename not in remaining_files_list_prev:
if filename not in prev_files_set:
continue

patch = data['patch']
new_patch_tokens = data['tokens']
edit_type = data['edit_type']

# Hard Stop, no more tokens
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
get_logger().warning(f"File was fully skipped, no more tokens: {filename}.")
if total_tokens > hard_threshold:
if verbosity_level >= 2:
get_logger().warning(f"File was fully skipped, no more tokens: {filename}.")
continue

# If the patch is too large, just show the file name
if total_tokens + new_patch_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
if total_tokens + new_patch_tokens > soft_threshold:
# Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements
if get_settings().config.verbosity_level >= 2:
if verbosity_level >= 2:
get_logger().warning(f"Patch too large, skipping it: '{filename}'")
remaining_files_list_new.append(filename)
continue
Expand All @@ -312,7 +326,7 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
patches.append(patch_final)
total_tokens += token_handler.count_tokens(patch_final)
files_in_patch_list.append(filename)
if get_settings().config.verbosity_level >= 2:
if verbosity_level >= 2:
get_logger().info(f"Tokens: {total_tokens}, last filename: {filename}")
return total_tokens, patches, remaining_files_list_new, files_in_patch_list

Expand Down
45 changes: 22 additions & 23 deletions pr_agent/algo/token_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from threading import Lock
from math import ceil
import re
from math import ceil
from threading import Lock

from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model, get_encoding
Expand All @@ -12,11 +12,11 @@
class ModelTypeValidator:
@staticmethod
def is_openai_model(model_name: str) -> bool:
return 'gpt' in model_name or re.match(r"^o[1-9](-mini|-preview)?$", model_name)
return "gpt" in model_name or re.match(r"^o[1-9](-mini|-preview)?$", model_name)

@staticmethod
def is_anthropic_model(model_name: str) -> bool:
return 'claude' in model_name
return "claude" in model_name


class TokenEncoder:
Expand All @@ -32,8 +32,9 @@ def get_token_encoder(cls):
if cls._encoder_instance is None or model != cls._model:
cls._model = model
try:
cls._encoder_instance = encoding_for_model(cls._model) if "gpt" in cls._model else get_encoding(
"o200k_base")
cls._encoder_instance = (
encoding_for_model(cls._model) if "gpt" in cls._model else get_encoding("o200k_base")
)
except:
cls._encoder_instance = get_encoding("o200k_base")
return cls._encoder_instance
Expand All @@ -54,7 +55,7 @@ class TokenHandler:

# Constants
CLAUDE_MODEL = "claude-3-7-sonnet-20250219"
CLAUDE_MAX_CONTENT_SIZE = 9_000_000 # Maximum allowed content size (9MB) for Claude API
CLAUDE_MAX_CONTENT_SIZE = 9_000_000 # Maximum allowed content size (9MB) for Claude API

def __init__(self, pr=None, vars: dict = {}, system="", user=""):
"""
Expand All @@ -67,7 +68,7 @@ def __init__(self, pr=None, vars: dict = {}, system="", user=""):
- user: The user string.
"""
self.encoder = TokenEncoder.get_token_encoder()

if pr is not None:
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)

Expand Down Expand Up @@ -99,12 +100,13 @@ def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
def _calc_claude_tokens(self, patch: str) -> int:
try:
import anthropic

from pr_agent.algo import MAX_TOKENS
client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key'))

client = anthropic.Anthropic(api_key=get_settings(use_context=False).get("anthropic.key"))
max_tokens = MAX_TOKENS[get_settings().config.model]

if len(patch.encode('utf-8')) > self.CLAUDE_MAX_CONTENT_SIZE:
if len(patch.encode("utf-8")) > self.CLAUDE_MAX_CONTENT_SIZE:
get_logger().warning(
"Content too large for Anthropic token counting API, falling back to local tokenizer"
)
Expand All @@ -113,10 +115,7 @@ def _calc_claude_tokens(self, patch: str) -> int:
response = client.messages.count_tokens(
model=self.CLAUDE_MODEL,
system="system",
messages=[{
"role": "user",
"content": patch
}],
messages=[{"role": "user", "content": patch}],
)
return response.input_tokens

Expand All @@ -125,9 +124,9 @@ def _calc_claude_tokens(self, patch: str) -> int:
return max_tokens

def _apply_estimation_factor(self, model_name: str, default_estimate: int) -> int:
factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
factor = 1 + get_settings().get("config.model_token_count_estimate_factor", 0)
get_logger().warning(f"{model_name}'s token count cannot be accurately estimated. Using factor of {factor}")

return ceil(factor * default_estimate)

def _get_token_count_by_model_type(self, patch: str, default_estimate: int) -> int:
Expand All @@ -142,15 +141,15 @@ def _get_token_count_by_model_type(self, patch: str, default_estimate: int) -> i
int: The calculated token count.
"""
model_name = get_settings().config.model.lower()
if ModelTypeValidator.is_openai_model(model_name) and get_settings(use_context=False).get('openai.key'):

if ModelTypeValidator.is_openai_model(model_name) and get_settings(use_context=False).get("openai.key"):
return default_estimate

if ModelTypeValidator.is_anthropic_model(model_name) and get_settings(use_context=False).get('anthropic.key'):
if ModelTypeValidator.is_anthropic_model(model_name) and get_settings(use_context=False).get("anthropic.key"):
return self._calc_claude_tokens(patch)

return self._apply_estimation_factor(model_name, default_estimate)

def count_tokens(self, patch: str, force_accurate: bool = False) -> int:
"""
Counts the number of tokens in a given patch string.
Expand Down
71 changes: 40 additions & 31 deletions pr_agent/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,38 @@
from dynaconf import Dynaconf
from starlette_context import context

PR_AGENT_TOML_KEY = 'pr-agent'
PR_AGENT_TOML_KEY = "pr-agent"

current_dir = dirname(abspath(__file__))
global_settings = Dynaconf(
envvar_prefix=False,
merge_enabled=True,
settings_files=[join(current_dir, f) for f in [
"settings/configuration.toml",
"settings/ignore.toml",
"settings/generated_code_ignore.toml",
"settings/language_extensions.toml",
"settings/pr_reviewer_prompts.toml",
"settings/pr_questions_prompts.toml",
"settings/pr_line_questions_prompts.toml",
"settings/pr_description_prompts.toml",
"settings/code_suggestions/pr_code_suggestions_prompts.toml",
"settings/code_suggestions/pr_code_suggestions_prompts_not_decoupled.toml",
"settings/code_suggestions/pr_code_suggestions_reflect_prompts.toml",
"settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog_prompts.toml",
"settings/pr_custom_labels.toml",
"settings/pr_add_docs.toml",
"settings/custom_labels.toml",
"settings/pr_help_prompts.toml",
"settings/pr_help_docs_prompts.toml",
"settings/pr_help_docs_headings_prompts.toml",
"settings/.secrets.toml",
"settings_prod/.secrets.toml",
]]
settings_files=[
join(current_dir, f)
for f in [
"settings/configuration.toml",
"settings/ignore.toml",
"settings/generated_code_ignore.toml",
"settings/language_extensions.toml",
"settings/pr_reviewer_prompts.toml",
"settings/pr_questions_prompts.toml",
"settings/pr_line_questions_prompts.toml",
"settings/pr_description_prompts.toml",
"settings/code_suggestions/pr_code_suggestions_prompts.toml",
"settings/code_suggestions/pr_code_suggestions_prompts_not_decoupled.toml",
"settings/code_suggestions/pr_code_suggestions_reflect_prompts.toml",
"settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog_prompts.toml",
"settings/pr_custom_labels.toml",
"settings/pr_add_docs.toml",
"settings/custom_labels.toml",
"settings/pr_help_prompts.toml",
"settings/pr_help_docs_prompts.toml",
"settings/pr_help_docs_headings_prompts.toml",
"settings/.secrets.toml",
"settings_prod/.secrets.toml",
]
],
)


Expand Down Expand Up @@ -81,7 +84,7 @@ def _find_pyproject() -> Optional[Path]:

pyproject_path = _find_pyproject()
if pyproject_path is not None:
get_settings().load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}')
get_settings().load_file(pyproject_path, env=f"tool.{PR_AGENT_TOML_KEY}")


def apply_secrets_manager_config():
Expand All @@ -90,15 +93,17 @@ def apply_secrets_manager_config():
"""
try:
# Dynamic imports to avoid circular dependency (secret_providers imports config_loader)
from pr_agent.secret_providers import get_secret_provider
from pr_agent.log import get_logger
from pr_agent.secret_providers import get_secret_provider

secret_provider = get_secret_provider()
if not secret_provider:
return

if (hasattr(secret_provider, 'get_all_secrets') and
get_settings().get("CONFIG.SECRET_PROVIDER") == 'aws_secrets_manager'):
if (
hasattr(secret_provider, "get_all_secrets")
and get_settings().get("CONFIG.SECRET_PROVIDER") == "aws_secrets_manager"
):
try:
secrets = secret_provider.get_all_secrets()
if secrets:
Expand All @@ -109,6 +114,7 @@ def apply_secrets_manager_config():
except Exception as e:
try:
from pr_agent.log import get_logger

get_logger().debug(f"Secret provider not configured: {e}")
except:
# Fail completely silently if log module is not available
Expand All @@ -123,14 +129,17 @@ def apply_secrets_to_config(secrets: dict):
# Dynamic import to avoid potential circular dependency
from pr_agent.log import get_logger
except:

def get_logger():
class DummyLogger:
def debug(self, msg): pass
def debug(self, msg):
pass

return DummyLogger()

for key, value in secrets.items():
if '.' in key: # nested key like "openai.key"
parts = key.split('.')
if "." in key: # nested key like "openai.key"
parts = key.split(".")
if len(parts) == 2:
section, setting = parts
section_upper = section.upper()
Expand Down
3 changes: 2 additions & 1 deletion pr_agent/log/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

os.environ["AUTO_CAST_FOR_DYNACONF"] = "false"
import json
import logging
Expand Down Expand Up @@ -42,7 +43,7 @@ def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE
colorize=False,
serialize=True,
)
elif fmt == LoggingFormat.CONSOLE: # does not print the 'extra' fields
elif fmt == LoggingFormat.CONSOLE: # does not print the 'extra' fields
logger.remove(None)
logger.add(sys.stdout, level=level, colorize=True, filter=inv_analytics_filter)

Expand Down