112 changes: 105 additions & 7 deletions .claude/commands/create-pr.md
@@ -43,6 +43,75 @@ gh --version
**Action:** If there are uncommitted changes, stop and ask the user to commit or stash
them first.

### Step 1.5: Determine Push Remote (Fork Support)

Check whether the user has push access to `origin`; if not, identify the fork remote.

```bash
# List all remotes
git remote -v

# Try to determine push remote
# Option 1: Check if origin is writable (try a dry-run push)
git push --dry-run origin $(git branch --show-current) 2>&1

# Option 2: Check gh auth status and repo permissions
gh auth status
gh repo view --json owner,name,viewerPermission
```

**Logic for Determining Push Remote:**

**If `origin` is writable**: Use `origin` directly (maintainer workflow)

**If `origin` is NOT writable**: Look for a fork remote

- Common fork remote names: `fork`, `user`, or `<username>`; any remote pointing to
  the user's fork qualifies
- Verify the fork remote points to the user's own fork via
  `gh repo view <remote-url> --json owner`

**If no fork remote found**: Ask user to add their fork as a remote:

```bash
git remote add fork https://github.com/<username>/AReaL.git
```

**Store for later use:**

- `PUSH_REMOTE`: The remote to push to (e.g., `origin` or `fork`)
- `UPSTREAM_REPO`: The upstream repo for PR target (e.g., `inclusionAI/AReaL`)
- `FORK_OWNER`: Fork owner username (for `--head` parameter if needed)

```bash
# Example detection script
UPSTREAM_REPO="inclusionAI/AReaL"
PUSH_REMOTE=""

# Check if we can push to origin
if git push --dry-run origin HEAD 2>/dev/null; then
PUSH_REMOTE="origin"
else
# Find fork remote (any remote that's not origin and points to user's fork)
for remote in $(git remote); do
if [[ "$remote" != "origin" ]]; then
remote_url=$(git remote get-url "$remote" 2>/dev/null)
    # Match both HTTPS (github.com/owner) and SSH (github.com:owner) remote URLs
    if [[ "$remote_url" =~ github\.com[:/]([^/]+)/AReaL ]]; then
PUSH_REMOTE="$remote"
FORK_OWNER="${BASH_REMATCH[1]}"
break
fi
fi
done
fi

if [[ -z "$PUSH_REMOTE" ]]; then
echo "ERROR: No writable remote found. Please add your fork:"
echo " git remote add fork https://github.com/<username>/AReaL.git"
exit 1
fi
```

### Step 2: Check for Existing PR

```bash
@@ -256,6 +325,7 @@ Show preview to user:

```
─────────────────────────────────────────────────
Remote: <fork-remote> (fork) → origin (upstream)
Branch: feat/vision-rlvr → main

PR Title:
@@ -310,22 +380,38 @@ Files changed:
─────────────────────────────────────────────────

Commands to execute:
1. git push -f -u <fork-remote> feat/vision-rlvr
2. gh pr create --repo inclusionAI/AReaL --head <username>:feat/vision-rlvr --base main --title "..." --body "..." [--draft]
─────────────────────────────────────────────────
```

**Confirm with user**, then execute:

```bash
# Get current branch name
CURRENT_BRANCH=$(git branch --show-current)

# Force push branch to determined remote (required after squash)
# PUSH_REMOTE was determined in Step 1.5
git push -f -u "$PUSH_REMOTE" "$CURRENT_BRANCH"

# Determine if this is a fork PR (cross-repo)
if [[ "$PUSH_REMOTE" != "origin" ]]; then
# Fork workflow: PR from fork to upstream
PR_HEAD="${FORK_OWNER}:${CURRENT_BRANCH}"
GH_PR_REPO="--repo ${UPSTREAM_REPO}"
else
# Maintainer workflow: PR within same repo
PR_HEAD="$CURRENT_BRANCH"
GH_PR_REPO=""
fi

# Create or edit PR using gh CLI with GitHub template format
# If PR exists, use 'gh pr edit' instead of 'gh pr create'
# Note: `gh pr view` has no --head flag; use `gh pr list` to look up a PR by head branch
if [[ -n "$(gh pr list --repo "${UPSTREAM_REPO}" --head "${CURRENT_BRANCH}" --state open --json number --jq '.[].number')" ]]; then
# Update existing PR
gh pr edit \
--repo "${UPSTREAM_REPO}" \
--title "feat(workflow): add vision support to RLVR" \
--body "$(cat <<'EOF'
[PR description here]
@@ -334,6 +420,8 @@
EOF
else
# Create new PR
gh pr create \
--repo "${UPSTREAM_REPO}" \
--head "${PR_HEAD}" \
--base main \
--title "feat(workflow): add vision support to RLVR" \
--body "$(cat <<'EOF'
@@ -614,7 +702,14 @@ If force push fails:
1. Verify remote branch exists
1. Check GitHub authentication: `gh auth status`
1. Confirm branch protection rules allow force push
1. **For fork workflow**: Verify the fork remote URL is correct and you have push access
1. Provide manual push instructions if needed:
```bash
# Fork workflow
git push -f -u <fork-remote> <branch>
# Maintainer workflow
git push -f -u origin <branch>
```

### PR Creation/Update Failures

@@ -669,7 +764,10 @@ Invocation: /create-pr

## Design Philosophy

- Automates full PR creation workflow: detect remote, fetch, rebase, **squash to single commit**, push, create/update PR
- **Supports both maintainer and fork workflows**:
- Maintainer: push to `origin`, create PR within same repo
- Fork: push to fork remote, create cross-repo PR to upstream
- **Always squashes all commits** since `origin/main` into a single commit with message generated via the `commit-conventions` skill
- **Handles existing PRs** by detecting them and force-updating after user permission
- Follows repository's Conventional Commits format
54 changes: 54 additions & 0 deletions areal/infra/rpc/rpc_server.py
@@ -845,6 +845,60 @@ def retrieve_batch_data(shard_id: str):
return jsonify({"status": "error", "message": str(e)}), 500


@app.route("/data/batch", methods=["POST"])
def retrieve_batch_data_many():
"""Retrieve multiple batch data shards in one request."""

try:
payload = request.get_json(silent=True) or {}
shard_ids = payload.get("shard_ids", [])
if not isinstance(shard_ids, list) or not all(
isinstance(shard_id, str) for shard_id in shard_ids
):
return (
jsonify(
{
"status": "error",
"message": "Expected JSON body with string list field 'shard_ids'",
}
),
400,
)

data = []
missing_shard_ids = []
for shard_id in shard_ids:
try:
data.append(rtensor.fetch(shard_id))
except KeyError:
missing_shard_ids.append(shard_id)

if missing_shard_ids:
return (
jsonify(
{
"status": "error",
"message": "One or more requested shards were not found",
"missing_shard_ids": missing_shard_ids,
}
),
400,
)

serialized_data = serialize_value(data)
data_bytes = orjson.dumps(serialized_data)
logger.debug(
"Retrieved %s batch shards (size=%s bytes)",
len(shard_ids),
len(data_bytes),
)
return Response(data_bytes, mimetype="application/octet-stream")

except Exception as e:
logger.error(f"Error retrieving batch shards: {e}\n{traceback.format_exc()}")
return jsonify({"status": "error", "message": str(e)}), 500


@app.route("/data/clear", methods=["DELETE"])
def clear_batch_data():
"""Clear specified batch data shards.
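The validation and error contract of the new `/data/batch` handler can be sketched as a pure function (illustrative only — `shard_store` is a plain dict standing in for the server's `rtensor` shard registry, and `batch_response` is not part of the server API):

```python
def batch_response(shard_store, payload):
    """Mirror the /data/batch handler's contract (illustrative sketch).

    Returns an (http_status, body) pair: 400 with an error dict on a
    malformed request or missing shards, otherwise 200 with the shard
    values in request order.
    """
    shard_ids = payload.get("shard_ids", [])
    # Reject anything other than a list of strings, as the handler does.
    if not isinstance(shard_ids, list) or not all(
        isinstance(s, str) for s in shard_ids
    ):
        return 400, {
            "status": "error",
            "message": "Expected JSON body with string list field 'shard_ids'",
        }
    # Collect all missing IDs before failing, so the client sees every one.
    missing = [s for s in shard_ids if s not in shard_store]
    if missing:
        return 400, {
            "status": "error",
            "message": "One or more requested shards were not found",
            "missing_shard_ids": missing,
        }
    # Success: results are returned in the order they were requested.
    return 200, [shard_store[s] for s in shard_ids]
```

The real handler additionally serializes the fetched tensors with `serialize_value` and returns the bytes as `application/octet-stream` rather than JSON.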
98 changes: 92 additions & 6 deletions areal/infra/rpc/rtensor.py
@@ -5,7 +5,7 @@
from collections import defaultdict
from dataclasses import dataclass
from threading import Lock
from typing import Any, Protocol
from typing import Any, Protocol, cast

import aiohttp
import orjson
@@ -83,6 +83,11 @@ class TensorShardInfo:


class HttpRTensorBackend:
def __init__(self, max_shards_per_request: int = 32) -> None:
if max_shards_per_request <= 0:
raise ValueError("max_shards_per_request must be positive")
self.max_shards_per_request = max_shards_per_request

def _create_session(self) -> aiohttp.ClientSession:
"""Create a properly configured aiohttp session for large tensor transfers."""
timeout = aiohttp.ClientTimeout(
@@ -114,8 +119,10 @@ async def _fetch_tensor(
try:
async with session.get(url) as resp:
if resp.status != 200:
error_body = (await resp.text()).strip()
detail = f" body={error_body}" if error_body else ""
raise RuntimeError(
f"Failed to fetch shard from {url}: {resp.status}{detail}"
)
data_bytes = await resp.read()
serialized_data = orjson.loads(data_bytes)
@@ -138,17 +145,96 @@ async def _fetch_tensor(
f"Last error: {repr(last_exception)}"
)

async def _fetch_shard_group(
self,
session: aiohttp.ClientSession,
node_addr: str,
grouped: list[tuple[int, TensorShardInfo]],
max_retries: int = 3,
retry_delay: float = 1.0,
) -> list[torch.Tensor]:
from areal.infra.rpc.serialization import deserialize_value

shard_ids = [shard.shard_id for _, shard in grouped]
url = f"http://{node_addr}/data/batch"
last_exception = None

for attempt in range(max_retries):
try:
async with session.post(url, json={"shard_ids": shard_ids}) as resp:
if resp.status != 200:
error_body = (await resp.text()).strip()
detail = f" body={error_body}" if error_body else ""
raise RuntimeError(
f"Failed to fetch shard batch from {url}: {resp.status}{detail}"
)

data_bytes = await resp.read()
serialized_data = orjson.loads(data_bytes)
tensors = cast(
list[torch.Tensor], deserialize_value(serialized_data)
)
if len(tensors) != len(grouped):
raise RuntimeError(
f"Batch fetch from {url} returned {len(tensors)} shards for {len(grouped)} requested"
)
return tensors
except (TimeoutError, aiohttp.ClientError) as e:
last_exception = e
logger.warning(
"RTensor batch fetch from %s failed: %s: %s (attempt %d/%d)",
url,
e.__class__.__name__,
str(e),
attempt + 1,
max_retries,
)
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)

raise RuntimeError(
f"Failed to fetch shard batch from {url} after {max_retries} attempts. "
f"Last error: {repr(last_exception)}"
)

def fetch(self, shards: list[TensorShardInfo]) -> list[torch.Tensor]:
"""Fetch multiple shards concurrently via HTTP using a single session."""
if not shards:
return []

async def _fetch():
indexed_shards = list(enumerate(shards))
shards_by_node: dict[str, list[tuple[int, TensorShardInfo]]] = defaultdict(
list
)
for index, shard in indexed_shards:
shards_by_node[shard.node_addr].append((index, shard))

results: list[torch.Tensor | None] = [None] * len(shards)

async with self._create_session() as session:

async def _fetch_node(
node_addr: str, grouped: list[tuple[int, TensorShardInfo]]
) -> None:
for start in range(0, len(grouped), self.max_shards_per_request):
chunk = grouped[start : start + self.max_shards_per_request]
tensors = await self._fetch_shard_group(
session, node_addr, chunk
)
for (original_index, _), tensor in zip(
chunk, tensors, strict=True
):
results[original_index] = tensor

await asyncio.gather(
*[
_fetch_node(node_addr, grouped)
for node_addr, grouped in shards_by_node.items()
]
)

return cast(list[torch.Tensor], results)

return run_async_task(_fetch)
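The request planning inside `fetch()` — group shards by node, then split each node's group into chunks of at most `max_shards_per_request` while carrying the original indices — can be isolated as a small sketch (illustrative only; `plan_batch_requests` is not part of the backend, and `(node_addr, shard_id)` pairs stand in for `TensorShardInfo` objects):

```python
from collections import defaultdict


def plan_batch_requests(shards, max_shards_per_request=32):
    """Plan one /data/batch request per (node, chunk) pair.

    Returns a list of (node_addr, [(original_index, shard_id), ...])
    entries; the indices let results be scattered back into a
    pre-allocated list in the caller's original order.
    """
    # Group indexed shards by the node that holds them.
    by_node = defaultdict(list)
    for index, (node_addr, shard_id) in enumerate(shards):
        by_node[node_addr].append((index, shard_id))
    # Split each node's group into bounded chunks, preserving order.
    plans = []
    for node_addr, grouped in by_node.items():
        for start in range(0, len(grouped), max_shards_per_request):
            plans.append((node_addr, grouped[start : start + max_shards_per_request]))
    return plans
```

In `fetch()`, each planned chunk becomes one POST to `/data/batch` via `_fetch_shard_group`, and the per-node loops run concurrently under a single `asyncio.gather`.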

Expand Down
Loading
Loading