Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
342 changes: 326 additions & 16 deletions codemcp/code_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,17 @@ async def run_code_command(
commit_message: str,
chat_id: Optional[str] = None,
) -> str:
"""Run a code command (lint, format, etc.) and handle git operations.
"""Run a code command (lint, format, etc.) and handle git operations using commutable commits.

This function implements a sophisticated auto-commit mechanism that:
1. Creates a PRE_COMMIT with all pending changes
2. Resets HEAD/index to the state before making this commit (working tree keeps changes)
3. Runs the intended command
4. Assesses the impact of the command:
a. If no changes were made, it does nothing and ignores PRE_COMMIT
b. If changes were made, it creates POST_COMMIT and tries to commute changes:
- If the cherry-pick succeeds, uses the commuted POST_COMMIT
- If the cherry-pick fails, uses the original uncommuted POST_COMMIT

Args:
project_dir: The directory path containing the code to process
Expand Down Expand Up @@ -128,18 +138,84 @@ async def run_code_command(
# Check if directory is in a git repository
is_git_repo = await is_git_repository(full_dir_path)

# If it's a git repo, commit any pending changes before running the command
# If it's a git repo, handle the commutable auto-commit mechanism
pre_commit_hash = None
original_head_hash = None
if is_git_repo:
logging.info(f"Committing any pending changes before {command_name}")
chat_id_str = str(chat_id) if chat_id is not None else ""
commit_result = await commit_changes(
full_dir_path,
f"Snapshot before auto-{command_name}",
chat_id_str,
commit_all=True,
)
if not commit_result[0]:
logging.warning(f"Failed to commit pending changes: {commit_result[1]}")
try:
git_cwd = await get_repository_root(full_dir_path)

# Get the current HEAD hash
head_hash_result = await run_command(
["git", "rev-parse", "HEAD"],
cwd=git_cwd,
capture_output=True,
text=True,
check=False,
)

if head_hash_result.returncode == 0:
original_head_hash = head_hash_result.stdout.strip()

# Check if there are any changes to commit
has_initial_changes = await check_for_changes(full_dir_path)

if has_initial_changes:
logging.info(f"Creating PRE_COMMIT before running {command_name}")
chat_id_str = str(chat_id) if chat_id is not None else ""

# Create the PRE_COMMIT with all changes
await run_command(
["git", "add", "."],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)

# Commit all changes (including untracked files)
await run_command(
[
"git",
"commit",
"--no-gpg-sign",
"-m",
f"PRE_COMMIT: Snapshot before auto-{command_name}",
],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)

# Get the hash of our PRE_COMMIT
pre_commit_hash_result = await run_command(
["git", "rev-parse", "HEAD"],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
pre_commit_hash = pre_commit_hash_result.stdout.strip()

logging.info(f"Created PRE_COMMIT: {pre_commit_hash}")

# Reset HEAD to the previous commit, but keep working tree changes (mixed mode)
# This effectively "uncommits" without losing the changes in the working tree
if original_head_hash:
await run_command(
["git", "reset", original_head_hash],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
logging.info(
f"Reset HEAD to {original_head_hash}, keeping changes in working tree"
)
except Exception as e:
logging.warning(f"Failed to set up PRE_COMMIT: {e}")
# Continue with command execution even if PRE_COMMIT setup fails

# Run the command
try:
Expand All @@ -151,13 +227,195 @@ async def run_code_command(
text=True,
)

# Additional logging is already done by run_command

# Truncate the output if needed, prioritizing the end content
truncated_stdout = truncate_output_content(result.stdout, prefer_end=True)

# If it's a git repo, commit any changes made by the command
if is_git_repo:
# If it's a git repo and PRE_COMMIT was created, handle commutation of changes
if is_git_repo and pre_commit_hash:
git_cwd = await get_repository_root(full_dir_path)

# Check if command made any changes
has_command_changes = await check_for_changes(full_dir_path)

if not has_command_changes:
logging.info(
f"No changes made by {command_name}, ignoring PRE_COMMIT"
)
return f"Code {command_name} successful (no changes made):\n{truncated_stdout}"

logging.info(
f"Changes detected after {command_name}, creating POST_COMMIT"
)

# Create POST_COMMIT with PRE_COMMIT as parent
# First, stage all changes (including untracked files)
await run_command(
["git", "add", "."],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)

# Create the POST_COMMIT on top of PRE_COMMIT
chat_id_str = str(chat_id) if chat_id is not None else ""

# Temporarily set HEAD to PRE_COMMIT
await run_command(
["git", "update-ref", "HEAD", pre_commit_hash],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)

# Create POST_COMMIT
await run_command(
[
"git",
"commit",
"--no-gpg-sign",
"-m",
f"POST_COMMIT: {commit_message}",
],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)

# Get the POST_COMMIT hash
post_commit_hash_result = await run_command(
["git", "rev-parse", "HEAD"],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
post_commit_hash = post_commit_hash_result.stdout.strip()
logging.info(f"Created POST_COMMIT: {post_commit_hash}")

# Now try to commute the changes
# Reset to original HEAD
await run_command(
["git", "reset", "--hard", original_head_hash],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)

# Try to cherry-pick PRE_COMMIT onto original HEAD
try:
await run_command(
["git", "cherry-pick", "--no-gpg-sign", pre_commit_hash],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)

# If we get here, PRE_COMMIT applied cleanly
commuted_pre_commit_hash_result = await run_command(
["git", "rev-parse", "HEAD"],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
commuted_pre_commit_hash = (
commuted_pre_commit_hash_result.stdout.strip()
)

# Now try to cherry-pick POST_COMMIT
await run_command(
["git", "cherry-pick", "--no-gpg-sign", post_commit_hash],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)

# Get the commuted POST_COMMIT hash
commuted_post_commit_hash_result = await run_command(
["git", "rev-parse", "HEAD"],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
commuted_post_commit_hash = (
commuted_post_commit_hash_result.stdout.strip()
)

# Verify that the final tree is the same
original_tree_result = await run_command(
["git", "rev-parse", f"{post_commit_hash}^{{tree}}"],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
original_tree = original_tree_result.stdout.strip()

commuted_tree_result = await run_command(
["git", "rev-parse", f"{commuted_post_commit_hash}^{{tree}}"],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
commuted_tree = commuted_tree_result.stdout.strip()

if original_tree == commuted_tree:
# Commutation successful and trees match!
# Make sure we have the same changes uncommitted
await run_command(
["git", "reset", commuted_pre_commit_hash],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
logging.info(
f"Commutation successful! Set HEAD to commuted POST_COMMIT and reset to commuted PRE_COMMIT"
)
return f"Code {command_name} successful (changes commuted successfully):\n{truncated_stdout}"
else:
# Trees don't match, go back to unconmuted version
logging.info(
f"Commutation resulted in different trees, using original POST_COMMIT"
)
await run_command(
["git", "reset", "--hard", post_commit_hash],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
return f"Code {command_name} successful (changes don't commute, using original order):\n{truncated_stdout}"

except subprocess.CalledProcessError:
# Cherry-pick failed, go back to unconmuted version
logging.info(f"Cherry-pick failed, using original POST_COMMIT")
await run_command(
["git", "cherry-pick", "--abort"],
cwd=git_cwd,
capture_output=True,
text=True,
check=False,
)
await run_command(
["git", "reset", "--hard", post_commit_hash],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
return f"Code {command_name} successful (changes don't commute, using original order):\n{truncated_stdout}"

# If no PRE_COMMIT was created or not a git repo, handle normally
elif is_git_repo:
has_changes = await check_for_changes(full_dir_path)
if has_changes:
logging.info(f"Changes detected after {command_name}, committing")
Expand All @@ -176,6 +434,32 @@ async def run_code_command(

return f"Code {command_name} successful:\n{truncated_stdout}"
except subprocess.CalledProcessError as e:
# If we were in the middle of the commutation process, try to restore the original state
if is_git_repo and pre_commit_hash and original_head_hash:
try:
git_cwd = await get_repository_root(full_dir_path)

# Abort any in-progress cherry-pick
await run_command(
["git", "cherry-pick", "--abort"],
cwd=git_cwd,
capture_output=True,
text=True,
check=False,
)

# Reset to original head
await run_command(
["git", "reset", "--hard", original_head_hash],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
logging.info(f"Restored original state after command failure")
except Exception as restore_error:
logging.error(f"Failed to restore original state: {restore_error}")

# Map the command_name to keep backward compatibility with existing tests
command_key = command_name.title()
if command_name == "linting":
Expand Down Expand Up @@ -212,6 +496,32 @@ async def run_code_command(
return f"Error: {error_msg}"

except Exception as e:
# If we were in the middle of the commutation process, try to restore the original state
if is_git_repo and pre_commit_hash and original_head_hash:
try:
git_cwd = await get_repository_root(full_dir_path)

# Abort any in-progress cherry-pick
await run_command(
["git", "cherry-pick", "--abort"],
cwd=git_cwd,
capture_output=True,
text=True,
check=False,
)

# Reset to original head
await run_command(
["git", "reset", "--hard", original_head_hash],
cwd=git_cwd,
capture_output=True,
text=True,
check=True,
)
logging.info(f"Restored original state after exception")
except Exception as restore_error:
logging.error(f"Failed to restore original state: {restore_error}")

error_msg = f"Error during {command_name}: {e}"
logging.error(error_msg)
return f"Error: {error_msg}"
Loading
Loading