Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"rich>=14.1.0",
"safetensors>=0.6.2",
"tokenizers>=0.21.2",
"transformers>=4.56.1,<5",
"transformers>=5.0.0",
"typer>=0.17.4",
# "wandb>=0.22.0",
"peft",
Expand Down Expand Up @@ -97,11 +97,11 @@ skyrl-train = [

fsdp = [
"skyrl[skyrl-train]",
"vllm==0.13.0; sys_platform == 'linux'",
"vllm==0.16.0; sys_platform == 'linux'",
"flash-attn==2.8.3; sys_platform == 'linux'",
"torch==2.9.0; sys_platform == 'linux'",
"torch==2.9.1; sys_platform == 'linux'",
"flashinfer-python; sys_platform == 'linux' and platform_machine == 'x86_64'",
"flashinfer-jit-cache==0.5.3; sys_platform == 'linux' and platform_machine == 'x86_64'",
"flashinfer-jit-cache==0.6.3; sys_platform == 'linux' and platform_machine == 'x86_64'",
"torchvision; sys_platform == 'linux'",
]

Expand All @@ -111,13 +111,13 @@ megatron = [
"skyrl[skyrl-train]; python_version == '3.12'",
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux' and python_version == '3.12'",
"flash-attn==2.8.1; sys_platform == 'linux' and python_version == '3.12'",
"vllm==0.13.0; sys_platform == 'linux' and python_version == '3.12'",
"torch==2.9.0; sys_platform == 'linux' and python_version == '3.12'",
"flashinfer-python==0.5.3; sys_platform == 'linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"vllm==0.16.0; sys_platform == 'linux' and python_version == '3.12'",
"torch==2.9.1; sys_platform == 'linux' and python_version == '3.12'",
"flashinfer-python==0.6.3; sys_platform == 'linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"torchvision; sys_platform == 'linux' and python_version == '3.12'",
"megatron-bridge; sys_platform == 'linux' and python_version == '3.12'",
"megatron-core==0.15.0; sys_platform == 'linux' and python_version == '3.12'",
"flashinfer-jit-cache==0.5.3; sys_platform == 'linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"flashinfer-jit-cache==0.6.3; sys_platform == 'linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"nvidia-modelopt; sys_platform == 'linux' and python_version == '3.12'",
]

Expand Down Expand Up @@ -174,7 +174,7 @@ required-environments = [
]

constraint-dependencies = [
"flashinfer-jit-cache==0.5.3",
"flashinfer-jit-cache==0.6.3",
]
# each backend should have separate dependencies that can potentially clash
# megatron also clashes with the jax dependency from gpu and tpu extras
Expand Down Expand Up @@ -209,6 +209,8 @@ override-dependencies = [
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
"megatron-core==0.15.0; sys_platform == 'linux'",
"ml_dtypes>=0.5.0; sys_platform == 'linux' and python_version == '3.12'",
"numpy>=2.0.0; sys_platform == 'linux'",
"transformers>=5.0.0; sys_platform == 'linux'",
]

[tool.uv.extra-build-dependencies]
Expand Down
7 changes: 6 additions & 1 deletion skyrl-agent/skyrl_agent/agents/oh_codeact/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,17 @@ def _encode_prompt(self, messages):
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=False,
enable_thinking=self.qwen3_enable_thinking,
chat_template=chat_template.read_text(),
)
else:
input_ids = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, enable_thinking=self.qwen3_enable_thinking
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=False,
enable_thinking=self.qwen3_enable_thinking,
)
return input_ids

Expand Down
10 changes: 7 additions & 3 deletions skyrl-agent/skyrl_agent/functional/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def encode_messages(
kwargs["chat_template"] = self.chat_template.read_text()
kwargs["enable_thinking"] = self.qwen3_enable_thinking
if is_first_message:
input_ids = self.tokenizer.apply_chat_template(formatted_messages, **kwargs)
input_ids = self.tokenizer.apply_chat_template(formatted_messages, return_dict=False, **kwargs)
else:
# do incremental encoding,
# for assistant messages, we assume the generation prompt is already added in the previous message
Expand All @@ -169,7 +169,9 @@ def encode_messages(
]
is_assistant_message = formatted_messages[0]["role"] == "assistant"
kwargs["add_generation_prompt"] = True if is_assistant_message else False
base_conversation_token_ids = self.tokenizer.apply_chat_template(base_conversation, **kwargs)
base_conversation_token_ids = self.tokenizer.apply_chat_template(
base_conversation, return_dict=False, **kwargs
)

if not is_assistant_message:
# remove tokens after the last EOS
Expand All @@ -181,7 +183,9 @@ def encode_messages(
base_conversation_token_ids = base_conversation_token_ids[: last_eos_token_index + 1]
kwargs["add_generation_prompt"] = add_generation
full_conversation = base_conversation + formatted_messages
full_conversation_token_ids = self.tokenizer.apply_chat_template(full_conversation, **kwargs)
full_conversation_token_ids = self.tokenizer.apply_chat_template(
full_conversation, return_dict=False, **kwargs
)
input_ids = full_conversation_token_ids[len(base_conversation_token_ids) :]

return input_ids
Expand Down
15 changes: 8 additions & 7 deletions skyrl-train/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ override-dependencies = [
"causal-conv1d; sys_platform == 'never'",
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
"megatron-core==0.15.0; sys_platform == 'linux'",
"transformers>=5.0.0; sys_platform == 'linux'",
]
[tool.uv.extra-build-dependencies]
flash-attn = [{requirement = "torch", match-runtime = true}]
Expand Down Expand Up @@ -130,11 +131,11 @@ harbor = [
"harbor",
]
vllm = [
"vllm==0.13.0; sys_platform == 'linux'",
"vllm==0.16.0; sys_platform == 'linux'",
"flash-attn==2.8.3; sys_platform == 'linux'",
"torch==2.9.0; sys_platform == 'linux'",
"torch==2.9.1; sys_platform == 'linux'",
"flashinfer-python; sys_platform == 'linux'",
"flashinfer-jit-cache==0.5.3; sys_platform == 'linux'",
"flashinfer-jit-cache==0.6.3; sys_platform == 'linux'",
"torchvision; sys_platform == 'linux'",
]
sglang = [
Expand All @@ -147,13 +148,13 @@ sglang = [
mcore = [
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
"flash-attn==2.8.1; sys_platform == 'linux'",
"vllm==0.13.0; sys_platform == 'linux'",
"torch==2.9.0; sys_platform == 'linux'",
"flashinfer-python==0.5.3; sys_platform == 'linux'",
"vllm==0.16.0; sys_platform == 'linux'",
"torch==2.9.1; sys_platform == 'linux'",
"flashinfer-python==0.6.3; sys_platform == 'linux'",
"torchvision; sys_platform == 'linux'",
"megatron-bridge; sys_platform == 'linux'",
"megatron-core==0.15.0; sys_platform == 'linux'",
"flashinfer-jit-cache==0.5.3; sys_platform == 'linux'",
"flashinfer-jit-cache==0.6.3; sys_platform == 'linux'",
"nvidia-modelopt; sys_platform == 'linux'",
]
flashrl = [
Expand Down
4 changes: 3 additions & 1 deletion skyrl-train/skyrl_train/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def _read_files_and_tokenize(self):
tokenizer = self.tokenizer
prompt_key = self.prompt_key
self.dataframe = self.dataframe.filter(
lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))
lambda doc: len(
tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True, return_dict=False)
)
<= self.max_prompt_length,
Comment on lines +61 to 64
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While the change to add `return_dict=False` is correct for `transformers>=5.0.0`, I've noticed that this file appears to be an exact duplicate of `skyrl/train/dataset/dataset.py`. There also appear to be other duplicated or near-duplicated files, such as `skyrl-train/skyrl_train/generators/skyrl_gym_generator.py` and `skyrl-train/skyrl_train/generators/utils.py`.

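This code duplication increases maintenance overhead: every change has to be applied in multiple places, which is error-prone (in this PR, for example, the same `return_dict=False` edit has to be made in both `skyrl-train/skyrl_train/dataset/dataset.py` and `skyrl/train/dataset/dataset.py`). It would be beneficial to refactor this to eliminate the duplication, for instance by moving these modules into a common library that both packages import.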

num_proc=self.num_workers,
desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
Expand Down
5 changes: 5 additions & 0 deletions skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def __init__(
self.base_conversation,
add_generation_prompt=False,
tokenize=True,
return_dict=False,
**self.generator_cfg.chat_template_kwargs,
)
# We remove tokens after the last EOS token so that it can be captured in `observation_ids`.
Expand Down Expand Up @@ -245,6 +246,7 @@ async def agent_loop(
add_generation_prompt=not retokenize_chat_history,
chat_template=self.custom_chat_template if retokenize_chat_history else None,
tokenize=True,
return_dict=False,
**self.generator_cfg.chat_template_kwargs,
)

Expand Down Expand Up @@ -295,6 +297,7 @@ async def agent_loop(
chat_template=self.custom_chat_template if retokenize_chat_history else None,
add_generation_prompt=True,
tokenize=True,
return_dict=False,
**self.generator_cfg.chat_template_kwargs,
)
agent_loop_state.loss_mask = []
Expand Down Expand Up @@ -532,6 +535,7 @@ def get_obs_ids_from_obs(self, new_obs: ConversationType, is_done: bool) -> List
[*self.base_conversation, *new_obs],
add_generation_prompt=not is_done,
tokenize=True,
return_dict=False,
**self.generator_cfg.chat_template_kwargs,
)[len(self.base_conversation_token_ids) :]
elif not is_done:
Expand Down Expand Up @@ -612,6 +616,7 @@ async def generate_batched(
init_prompts,
add_generation_prompt=True,
tokenize=True,
return_dict=False,
)
engine_input = InferenceEngineInput(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
engine_output = await self.inference_engine_client.generate(engine_input)
Expand Down
10 changes: 8 additions & 2 deletions skyrl-train/skyrl_train/generators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,14 @@ def get_generation_prompt_ids(tokenizer, chat_template: Optional[str] = None) ->
List[int]: Token IDs for the generation prompt (e.g., "<|im_start|>assistant\n" for Qwen).
"""
empty_user = tokenizer.apply_chat_template(
[{"role": "user", "content": ""}], tokenize=True, chat_template=chat_template
[{"role": "user", "content": ""}], tokenize=True, return_dict=False, chat_template=chat_template
)
empty_user_with_generation_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True, chat_template=chat_template
[{"role": "user", "content": ""}],
add_generation_prompt=True,
tokenize=True,
return_dict=False,
chat_template=chat_template,
)

generation_prompt_ids = empty_user_with_generation_prompt[len(empty_user) :]
Expand Down Expand Up @@ -447,6 +451,7 @@ def encode_messages_subset(messages: ConversationType, tokenizer, chat_template:
base_conversation,
add_generation_prompt=False,
tokenize=True,
return_dict=False,
chat_template=chat_template,
)

Expand All @@ -455,6 +460,7 @@ def encode_messages_subset(messages: ConversationType, tokenizer, chat_template:
full_conversation,
add_generation_prompt=False,
tokenize=True,
return_dict=False,
chat_template=chat_template,
)
conversation_token_ids = full_conversation_token_ids[len(base_conversation_token_ids) :]
Expand Down
16 changes: 7 additions & 9 deletions skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@
from types import SimpleNamespace
from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ErrorResponse,
)
from vllm.entrypoints.openai.completion.protocol import (
CompletionRequest,
CompletionResponse,
)
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
from vllm.lora.request import LoRARequest
from uuid import uuid4
from skyrl_train.inference_engines.base import (
Expand Down Expand Up @@ -519,8 +521,6 @@ async def _handle_openai_request(self, request_payload: Dict[str, Any], endpoint
assert request.stream is False, "Streaming is not supported in SkyRL yet, please set stream to False."
except Exception as e:
if version.parse(vllm.__version__) >= version.parse("0.10.0"):
from vllm.entrypoints.openai.protocol import ErrorInfo

return ErrorResponse(
error=ErrorInfo(
message=str(e),
Expand Down Expand Up @@ -568,8 +568,6 @@ async def _handle_openai_request(self, request_payload: Dict[str, Any], endpoint
http_status = HTTPStatus.INTERNAL_SERVER_ERROR

if version.parse(vllm.__version__) >= version.parse("0.10.0"):
from vllm.entrypoints.openai.protocol import ErrorInfo

return ErrorResponse(
error=ErrorInfo(
message=str(e),
Expand Down
16 changes: 7 additions & 9 deletions skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@
from types import SimpleNamespace
from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ErrorResponse,
)
from vllm.entrypoints.openai.completion.protocol import (
CompletionRequest,
CompletionResponse,
)
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
from vllm.lora.request import LoRARequest
from uuid import uuid4
from skyrl.backends.skyrl_train.inference_engines.base import (
Expand Down Expand Up @@ -519,8 +521,6 @@ async def _handle_openai_request(self, request_payload: Dict[str, Any], endpoint
assert request.stream is False, "Streaming is not supported in SkyRL yet, please set stream to False."
except Exception as e:
if version.parse(vllm.__version__) >= version.parse("0.10.0"):
from vllm.entrypoints.openai.protocol import ErrorInfo

return ErrorResponse(
error=ErrorInfo(
message=str(e),
Expand Down Expand Up @@ -568,8 +568,6 @@ async def _handle_openai_request(self, request_payload: Dict[str, Any], endpoint
http_status = HTTPStatus.INTERNAL_SERVER_ERROR

if version.parse(vllm.__version__) >= version.parse("0.10.0"):
from vllm.entrypoints.openai.protocol import ErrorInfo

return ErrorResponse(
error=ErrorInfo(
message=str(e),
Expand Down
4 changes: 3 additions & 1 deletion skyrl/train/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def _read_files_and_tokenize(self):
tokenizer = self.tokenizer
prompt_key = self.prompt_key
self.dataframe = self.dataframe.filter(
lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))
lambda doc: len(
tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True, return_dict=False)
)
<= self.max_prompt_length,
num_proc=self.num_workers,
desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
Expand Down
5 changes: 5 additions & 0 deletions skyrl/train/generators/skyrl_gym_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def __init__(
self.base_conversation,
add_generation_prompt=False,
tokenize=True,
return_dict=False,
**self.generator_cfg.chat_template_kwargs,
)
# We remove tokens after the last EOS token so that it can be captured in `observation_ids`.
Expand Down Expand Up @@ -243,6 +244,7 @@ async def agent_loop(
add_generation_prompt=not retokenize_chat_history,
chat_template=self.custom_chat_template if retokenize_chat_history else None,
tokenize=True,
return_dict=False,
**self.generator_cfg.chat_template_kwargs,
)

Expand Down Expand Up @@ -287,6 +289,7 @@ async def agent_loop(
chat_template=self.custom_chat_template if retokenize_chat_history else None,
add_generation_prompt=True,
tokenize=True,
return_dict=False,
**self.generator_cfg.chat_template_kwargs,
)
agent_loop_state.loss_mask = []
Expand Down Expand Up @@ -524,6 +527,7 @@ def get_obs_ids_from_obs(self, new_obs: ConversationType, is_done: bool) -> List
[*self.base_conversation, *new_obs],
add_generation_prompt=not is_done,
tokenize=True,
return_dict=False,
**self.generator_cfg.chat_template_kwargs,
)[len(self.base_conversation_token_ids) :]
elif not is_done:
Expand Down Expand Up @@ -604,6 +608,7 @@ async def generate_batched(
init_prompts,
add_generation_prompt=True,
tokenize=True,
return_dict=False,
)
engine_input = InferenceEngineInput(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
engine_output = await self.inference_engine_client.generate(engine_input)
Expand Down
Loading
Loading