Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ docs/_spelling/
/skyrl-gym/dist

*.log
nohup.out
tensorboard_log/

# SQLite database files
*.db
Expand Down
6 changes: 3 additions & 3 deletions skyrl-tx/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"datasets>=4.0.0",
"flax>=0.12.2",
"optax>=0.2.5",
"pillow>=11.3.0",
"rich>=14.1.0",
"safetensors>=0.6.2",
Expand All @@ -22,7 +20,6 @@ dependencies = [
"peft",
"hf_transfer",
"cloudpathlib>=0.23.0",
"jax>=0.8,<1.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -61,7 +58,10 @@ azure = [
# respectively.

jax = [
"jax>=0.8,<1.0",
"jax[cuda12]>=0.7.2; sys_platform == 'linux'",
"flax>=0.12.2",
"optax>=0.2.5",
]

skyrl_train = [
Expand Down
30 changes: 18 additions & 12 deletions skyrl-tx/tx/tinker/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
from tx.tinker.db_models import FutureDB, RequestStatus, CheckpointDB, CheckpointStatus, ModelDB, SessionDB
from tx.tinker import types
from tx.tinker.config import EngineConfig, add_model
from tx.tinker.backends.jax import JaxBackend, JaxBackendConfig
from tx.tinker.backends.skyrl_train import SkyRLTrainBackend, SkyRLTrainBackendConfig
from tx.tinker.backends.utils import log_timing
from tx.tinker.loss_fns import LOSS_TYPES
from tx.tinker.types import LOSS_TYPES
from tx.utils.log import logger


Expand Down Expand Up @@ -130,10 +128,21 @@ def prepare_model_pass_batch(
)


BACKENDS = {
"jax": (JaxBackend, JaxBackendConfig),
"skyrl_train": (SkyRLTrainBackend, SkyRLTrainBackendConfig),
}
def get_backend_classes(backend_name: str):
"""Lazy import backends to avoid importing unused backend dependencies (e.g., JAX, Ray)."""
if backend_name == "jax":
from tx.tinker.backends.jax import JaxBackend, JaxBackendConfig

return JaxBackend, JaxBackendConfig
elif backend_name == "skyrl_train":
from tx.tinker.backends.skyrl_train import SkyRLTrainBackend, SkyRLTrainBackendConfig

return SkyRLTrainBackend, SkyRLTrainBackendConfig
else:
raise ValueError(
f"Unknown backend: {backend_name}. Available backends: jax, skyrl_train. "
f"Make sure the backend's dependencies are installed (e.g., pip install skyrl-tx[jax])"
)


class TinkerEngine:
Expand Down Expand Up @@ -189,10 +198,7 @@ def __init__(
self.db_engine = create_engine(config.database_url, echo=False)

# Initialize the backend (handles model state, computation, and adapter management)
if config.backend not in BACKENDS:
raise ValueError(f"Unknown backend: {config.backend}. Available backends: {list(BACKENDS.keys())}")

backend_class, backend_config_class = BACKENDS[config.backend]
backend_class, backend_config_class = get_backend_classes(config.backend)
backend_config = backend_config_class(**config.backend_config)
self.backend = backend_class(config.base_model, backend_config)

Expand Down Expand Up @@ -312,7 +318,7 @@ def find_batchable_sample(self, session: Session) -> dict[str, tuple[str, types.

# TODO: This leaks the abstraction by accessing backend-specific config.
# We should find a better way to handle this going forward.
if isinstance(self.backend, JaxBackend) and self.backend.config.sample_max_num_sequences > 0:
if self.config.backend == "jax" and self.backend.config.sample_max_num_sequences > 0:
batchable = batchable[: self.backend.config.sample_max_num_sequences]

return {str(f.request_id): (f.model_id, types.SampleInput.model_validate(f.request_data)) for f in batchable}
Expand Down
2 changes: 1 addition & 1 deletion skyrl-tx/tx/tinker/loss_fns.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Loss functions for training."""
"""Loss functions for training (JAX implementations)."""

import jax
import jax.numpy as jnp
Expand Down
9 changes: 9 additions & 0 deletions skyrl-tx/tx/tinker/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,12 @@ class PreparedSampleBatch(BaseModel):

# Mapping from samples back to requests: (request_id, model_id, start_idx, end_idx, prompt_logprobs_requested)
request_batch_slices: list[tuple[str, str, int, int, bool]]


# Loss function type mappings (used for validation and backend dispatch)
# NOTE: Must stay in sync with LOSS_FUNCTION_MAP in loss_fns.py
LOSS_TYPES = {
"cross_entropy": 0,
"importance_sampling": 1,
"ppo": 2,
}
Loading