-
Notifications
You must be signed in to change notification settings - Fork 346
Add ECHO terminal agent training integration #1716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vshrivas
wants to merge
2
commits into
NovaSky-AI:main
Choose a base branch
from
vshrivas:msft/echo-rl
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| # ECHO Terminal-Agent Training | ||
|
|
||
| This example trains terminal agents with ECHO, an environment cross-entropy hybrid objective. ECHO combines standard policy-gradient RL with an auxiliary cross-entropy loss on terminal-output tokens observed in the same rollout. | ||
|
|
||
| SkyRL provides the core RL training stack, distributed worker execution, and vLLM-backed inference. This example adds the terminal-agent dataset loader, prompt formatting, tool-call parsing, Harbor-backed environment execution, rollout construction, token masks, and the optional environment-prediction loss. | ||
|
|
||
| ## Structure | ||
|
|
||
| ```text | ||
| examples/train_integrations/echo_terminal/ | ||
| entrypoint.py # Training entrypoint | ||
| generator.py # Terminal rollout loop and SkyRL trajectory construction | ||
| dataset.py # Parquet dataset loader and prompt tokenization | ||
| harbor_environment.py # Harbor container execution wrapper | ||
| interaction.py # Rollout transcript and token-mask bookkeeping | ||
| parsers.py # Tool-call parsing | ||
| prompts.py # Terminal-agent system prompts | ||
| tools.py # Tool schemas | ||
| chat_template.py # Chat-template loading helpers | ||
| chat_templates/qwen3_xml_tool_calling.jinja | ||
| world_modeling/ | ||
| config.py # ECHO config extensions | ||
| fsdp_worker.py # FSDP auxiliary-loss hook implementation | ||
| loss.py # Environment-token CE loss | ||
| trainer.py # Training-batch conversion for ECHO masks | ||
| configs/ | ||
| qwen3_8b_rl.yaml # Vanilla GRPO baseline | ||
| qwen3_8b_rl_wm05.yaml # GRPO + ECHO loss, lambda=0.05 | ||
| ``` | ||
|
|
||
| ## Quick Start | ||
|
|
||
| Install SkyRL with the FSDP and Harbor dependencies: | ||
|
|
||
| ```bash | ||
| cd SkyRL | ||
| pip install -e ".[fsdp,harbor]" | ||
| ``` | ||
|
|
||
| Edit the train and validation parquet paths in the config you want to run: | ||
|
|
||
| ```yaml | ||
| data: | ||
| train_data: | ||
| - name: terminal_agent_train | ||
| path: /path/to/train.parquet | ||
| val_data: | ||
| - name: terminal_agent_train | ||
| path: /path/to/val.parquet | ||
| ``` | ||
|
|
||
| Set an output directory and launch the vanilla GRPO baseline: | ||
|
|
||
| ```bash | ||
| export OUTPUT_DIR=/path/to/outputs/qwen3_8b_rl | ||
| export CONFIG_PATH=examples/train_integrations/echo_terminal/configs/qwen3_8b_rl.yaml | ||
| bash examples/train_integrations/echo_terminal/run_echo_terminal.sh | ||
| ``` | ||
|
|
||
| Launch ECHO with the auxiliary environment-prediction loss: | ||
|
|
||
| ```bash | ||
| export OUTPUT_DIR=/path/to/outputs/qwen3_8b_rl_wm05 | ||
| export CONFIG_PATH=examples/train_integrations/echo_terminal/configs/qwen3_8b_rl_wm05.yaml | ||
| bash examples/train_integrations/echo_terminal/run_echo_terminal.sh | ||
| ``` | ||
|
|
||
| Checkpoints are written to `${OUTPUT_DIR}/ckpts`, and SkyRL logs are written to `${OUTPUT_DIR}/skyrl_logs`. | ||
|
|
||
| ## Design | ||
|
|
||
| The rollout loop is handled directly in this example rather than through Harbor's full rollout API. Harbor is used as the terminal task backend: it starts the task containers, runs shell commands, returns terminal observations, and executes verifiers. SkyRL/vLLM owns model generation so the training code has direct, batched access to generated token ids, logprobs, attention masks, sampling controls, and ECHO-specific token masks. | ||
|
|
||
| During training, the standard GRPO loss is computed on model-generated action tokens. When `trainer.algorithm.world_model_coeff > 0`, ECHO also computes cross entropy on selected terminal-output tokens from the same trajectory: | ||
|
|
||
| ```text | ||
| L = L_GRPO(action tokens) + world_model_coeff * CE(terminal-output tokens) | ||
| ``` | ||
|
|
||
| Setting `world_model_coeff: 0.0` recovers the vanilla GRPO baseline. The included ECHO config uses `world_model_coeff: 0.05` and `generator.world_loss_target: env_only`, which trains on terminal environment-output tokens while leaving the RL action-token mask unchanged. | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| __all__ = [] |
97 changes: 97 additions & 0 deletions
97
examples/train_integrations/echo_terminal/chat_template.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| from transformers import PreTrainedTokenizerBase | ||
|
|
||
| _PROBE_SENTINEL = "OBSERVATION_PROBE_TOKEN_ABCXYZ_12345" | ||
| _PROBE_MESSAGES: list[dict[str, str]] = [ | ||
| {"role": "system", "content": "You are a helpful agent with access to bash."}, | ||
| {"role": "user", "content": "Run `ls /tmp` and report the result."}, | ||
| { | ||
| "role": "assistant", | ||
| "content": ( | ||
| "<think>\nI should call the bash tool with the command.\n</think>\n\n" | ||
| "<tool_call>\n<function=bash>\n<parameter=command>\nls /tmp\n" | ||
| "</parameter>\n</function>\n</tool_call>" | ||
| ), | ||
| }, | ||
| ] | ||
|
|
||
|
|
||
| def check_role_roundtrip( | ||
| tokenizer: PreTrainedTokenizerBase, | ||
| role: str, | ||
| *, | ||
| template_kwargs: dict[str, Any] | None = None, | ||
| sentinel: str = _PROBE_SENTINEL, | ||
| ) -> tuple[bool, str]: | ||
| template_kwargs = template_kwargs or {} | ||
| try: | ||
| before = tokenizer.apply_chat_template( | ||
| _PROBE_MESSAGES, | ||
| add_generation_prompt=False, | ||
| tokenize=True, | ||
| return_dict=False, | ||
| **template_kwargs, | ||
| ) | ||
| after = tokenizer.apply_chat_template( | ||
| _PROBE_MESSAGES + [{"role": role, "content": sentinel}], | ||
| add_generation_prompt=True, | ||
| tokenize=True, | ||
| return_dict=False, | ||
| **template_kwargs, | ||
| ) | ||
| except Exception as exc: | ||
| return False, f"apply_chat_template raised {type(exc).__name__}: {exc}" | ||
|
|
||
| if after[: len(before)] != before: | ||
| return False, f"prefix mismatch when {role!r} was appended" | ||
| delta = after[len(before) :] | ||
| if not delta: | ||
| return False, "empty delta" | ||
| decoded = tokenizer.decode(delta, skip_special_tokens=False) | ||
| if sentinel not in decoded: | ||
| return False, f"sentinel content not found for role {role!r}" | ||
| return True, "" | ||
|
|
||
|
|
||
| def choose_obs_role( | ||
| tokenizer: PreTrainedTokenizerBase, | ||
| candidates: tuple[str, ...] = ("tool", "user"), | ||
| *, | ||
| template_kwargs: dict[str, Any] | None = None, | ||
| ) -> str: | ||
| failures: list[str] = [] | ||
| for role in candidates: | ||
| ok, reason = check_role_roundtrip(tokenizer, role, template_kwargs=template_kwargs) | ||
| if ok: | ||
| return role | ||
| failures.append(f" - {role!r}: {reason}") | ||
| raise RuntimeError( | ||
| f"None of {list(candidates)!r} produced a usable observation role for " | ||
| f"tokenizer {tokenizer.name_or_path!r}. Failures:\n" + "\n".join(failures) | ||
| ) | ||
|
|
||
|
|
||
| def resolve_chat_template_path(template_path: str | Path) -> Path: | ||
| path = Path(template_path).expanduser() | ||
| if path.is_absolute(): | ||
| candidates = [path] | ||
| else: | ||
| module_path = Path(__file__).resolve() | ||
| candidates = [ | ||
| Path.cwd() / path, | ||
| module_path.parent / path, | ||
| module_path.parents[2] / path, | ||
| ] | ||
|
|
||
| for candidate in candidates: | ||
| if candidate.exists(): | ||
| return candidate | ||
| raise FileNotFoundError(f"Chat template not found at {template_path!r}; checked {candidates!r}") | ||
|
|
||
|
|
||
| def load_chat_template(tokenizer: PreTrainedTokenizerBase, template_path: str | Path) -> None: | ||
| tokenizer.chat_template = resolve_chat_template_path(template_path).read_text() |
154 changes: 154 additions & 0 deletions
154
examples/train_integrations/echo_terminal/chat_templates/qwen3_xml_tool_calling.jinja
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| {%- set image_count = namespace(value=0) %} | ||
| {%- set video_count = namespace(value=0) %} | ||
| {%- macro render_content(content, do_vision_count, is_system_content=false) %} | ||
| {%- if content is string %} | ||
| {{- content }} | ||
| {%- elif content is iterable and content is not mapping %} | ||
| {%- for item in content %} | ||
| {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} | ||
| {%- if is_system_content %} | ||
| {{- raise_exception('System message cannot contain images.') }} | ||
| {%- endif %} | ||
| {%- if do_vision_count %} | ||
| {%- set image_count.value = image_count.value + 1 %} | ||
| {%- endif %} | ||
| {%- if add_vision_id %} | ||
| {{- 'Picture ' ~ image_count.value ~ ': ' }} | ||
| {%- endif %} | ||
| {{- '<|vision_start|><|image_pad|><|vision_end|>' }} | ||
| {%- elif 'video' in item or item.type == 'video' %} | ||
| {%- if is_system_content %} | ||
| {{- raise_exception('System message cannot contain videos.') }} | ||
| {%- endif %} | ||
| {%- if do_vision_count %} | ||
| {%- set video_count.value = video_count.value + 1 %} | ||
| {%- endif %} | ||
| {%- if add_vision_id %} | ||
| {{- 'Video ' ~ video_count.value ~ ': ' }} | ||
| {%- endif %} | ||
| {{- '<|vision_start|><|video_pad|><|vision_end|>' }} | ||
| {%- elif 'text' in item %} | ||
| {{- item.text }} | ||
| {%- else %} | ||
| {{- raise_exception('Unexpected item type in content.') }} | ||
| {%- endif %} | ||
| {%- endfor %} | ||
| {%- elif content is none or content is undefined %} | ||
| {{- '' }} | ||
| {%- else %} | ||
| {{- raise_exception('Unexpected content type.') }} | ||
| {%- endif %} | ||
| {%- endmacro %} | ||
| {%- if not messages %} | ||
| {{- raise_exception('No messages provided.') }} | ||
| {%- endif %} | ||
| {%- if tools and tools is iterable and tools is not mapping %} | ||
| {{- '<|im_start|>system\n' }} | ||
| {{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }} | ||
| {%- for tool in tools %} | ||
| {{- "\n" }} | ||
| {{- tool | tojson }} | ||
| {%- endfor %} | ||
| {{- "\n</tools>" }} | ||
| {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }} | ||
| {%- if messages[0].role == 'system' %} | ||
| {%- set content = render_content(messages[0].content, false, true)|trim %} | ||
| {%- if content %} | ||
| {{- '\n\n' + content }} | ||
| {%- endif %} | ||
| {%- endif %} | ||
| {{- '<|im_end|>\n' }} | ||
| {%- else %} | ||
| {%- if messages[0].role == 'system' %} | ||
| {%- set content = render_content(messages[0].content, false, true)|trim %} | ||
| {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} | ||
| {%- endif %} | ||
| {%- endif %} | ||
| {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} | ||
| {%- for message in messages[::-1] %} | ||
| {%- set index = (messages|length - 1) - loop.index0 %} | ||
| {%- if ns.multi_step_tool and message.role == "user" %} | ||
| {%- set content = render_content(message.content, false)|trim %} | ||
| {%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %} | ||
| {%- set ns.multi_step_tool = false %} | ||
| {%- set ns.last_query_index = index %} | ||
| {%- endif %} | ||
| {%- endif %} | ||
| {%- endfor %} | ||
| {%- if ns.multi_step_tool %} | ||
| {{- raise_exception('No user query found in messages.') }} | ||
| {%- endif %} | ||
| {%- for message in messages %} | ||
| {%- set content = render_content(message.content, true)|trim %} | ||
| {%- if message.role == "system" %} | ||
| {%- if not loop.first %} | ||
| {{- raise_exception('System message must be at the beginning.') }} | ||
| {%- endif %} | ||
| {%- elif message.role == "user" %} | ||
| {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} | ||
| {%- elif message.role == "assistant" %} | ||
| {%- set reasoning_content = '' %} | ||
| {%- if message.reasoning_content is string %} | ||
| {%- set reasoning_content = message.reasoning_content %} | ||
| {%- else %} | ||
| {%- if '</think>' in content %} | ||
| {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %} | ||
| {%- set content = content.split('</think>')[-1].lstrip('\n') %} | ||
| {%- endif %} | ||
| {%- endif %} | ||
| {%- set reasoning_content = reasoning_content|trim %} | ||
| {%- if loop.index0 > ns.last_query_index %} | ||
| {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }} | ||
| {%- else %} | ||
| {{- '<|im_start|>' + message.role + '\n' + content }} | ||
| {%- endif %} | ||
| {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} | ||
| {%- for tool_call in message.tool_calls %} | ||
| {%- if tool_call.function is defined %} | ||
| {%- set tool_call = tool_call.function %} | ||
| {%- endif %} | ||
| {%- if loop.first %} | ||
| {%- if content|trim %} | ||
| {{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }} | ||
| {%- else %} | ||
| {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }} | ||
| {%- endif %} | ||
| {%- else %} | ||
| {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }} | ||
| {%- endif %} | ||
| {%- if tool_call.arguments is defined %} | ||
| {%- for args_name, args_value in tool_call.arguments|items %} | ||
| {{- '<parameter=' + args_name + '>\n' }} | ||
| {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} | ||
| {{- args_value }} | ||
| {{- '\n</parameter>\n' }} | ||
| {%- endfor %} | ||
| {%- endif %} | ||
| {{- '</function>\n</tool_call>' }} | ||
| {%- endfor %} | ||
| {%- endif %} | ||
| {{- '<|im_end|>\n' }} | ||
| {%- elif message.role == "tool" %} | ||
| {%- if loop.previtem and loop.previtem.role != "tool" %} | ||
| {{- '<|im_start|>user' }} | ||
| {%- endif %} | ||
| {{- '\n<tool_response>\n' }} | ||
| {{- content }} | ||
| {{- '\n</tool_response>' }} | ||
| {%- if not loop.last and loop.nextitem.role != "tool" %} | ||
| {{- '<|im_end|>\n' }} | ||
| {%- elif loop.last %} | ||
| {{- '<|im_end|>\n' }} | ||
| {%- endif %} | ||
| {%- else %} | ||
| {{- raise_exception('Unexpected message role.') }} | ||
| {%- endif %} | ||
| {%- endfor %} | ||
| {%- if add_generation_prompt %} | ||
| {{- '<|im_start|>assistant\n' }} | ||
| {%- if enable_thinking is defined and enable_thinking is false %} | ||
| {{- '<think>\n\n</think>\n\n' }} | ||
| {%- else %} | ||
| {{- '<think>\n' }} | ||
| {%- endif %} | ||
| {%- endif %} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
once we merge this and verify it's working, we should pin the commit that's being used for reproducibility