diff --git a/.gitignore b/.gitignore index 5b5b47357a..d8b39c1dc6 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,8 @@ docs/_spelling/ /skyrl-gym/dist *.log +nohup.out +tensorboard_log/ # SQLite database files *.db diff --git a/skyrl-tx/pyproject.toml b/skyrl-tx/pyproject.toml index 587d9c19ea..c0f9db742c 100644 --- a/skyrl-tx/pyproject.toml +++ b/skyrl-tx/pyproject.toml @@ -10,8 +10,6 @@ readme = "README.md" requires-python = ">=3.11" dependencies = [ "datasets>=4.0.0", - "flax>=0.12.2", - "optax>=0.2.5", "pillow>=11.3.0", "rich>=14.1.0", "safetensors>=0.6.2", @@ -22,7 +20,6 @@ dependencies = [ "peft", "hf_transfer", "cloudpathlib>=0.23.0", - "jax>=0.8,<1.0", ] [project.optional-dependencies] @@ -61,7 +58,10 @@ azure = [ # respectively. jax = [ + "jax>=0.8,<1.0", "jax[cuda12]>=0.7.2; sys_platform == 'linux'", + "flax>=0.12.2", + "optax>=0.2.5", ] skyrl_train = [ diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py index a0618f68c4..6cdd90d940 100644 --- a/skyrl-tx/tx/tinker/engine.py +++ b/skyrl-tx/tx/tinker/engine.py @@ -14,10 +14,8 @@ from tx.tinker.db_models import FutureDB, RequestStatus, CheckpointDB, CheckpointStatus, ModelDB, SessionDB from tx.tinker import types from tx.tinker.config import EngineConfig, add_model -from tx.tinker.backends.jax import JaxBackend, JaxBackendConfig -from tx.tinker.backends.skyrl_train import SkyRLTrainBackend, SkyRLTrainBackendConfig from tx.tinker.backends.utils import log_timing -from tx.tinker.loss_fns import LOSS_TYPES +from tx.tinker.types import LOSS_TYPES from tx.utils.log import logger @@ -130,10 +128,21 @@ def prepare_model_pass_batch( ) -BACKENDS = { - "jax": (JaxBackend, JaxBackendConfig), - "skyrl_train": (SkyRLTrainBackend, SkyRLTrainBackendConfig), -} +def get_backend_classes(backend_name: str): + """Lazy import backends to avoid importing unused backend dependencies (e.g., JAX, Ray).""" + if backend_name == "jax": + from tx.tinker.backends.jax import JaxBackend, JaxBackendConfig + + return JaxBackend, JaxBackendConfig + elif backend_name == "skyrl_train": + from tx.tinker.backends.skyrl_train import SkyRLTrainBackend, SkyRLTrainBackendConfig + + return SkyRLTrainBackend, SkyRLTrainBackendConfig + else: + raise ValueError( + f"Unknown backend: {backend_name}. Available backends: jax, skyrl_train. " + f"Make sure the backend's dependencies are installed (e.g., pip install skyrl-tx[jax])" + ) class TinkerEngine: @@ -189,10 +198,7 @@ def __init__( self.db_engine = create_engine(config.database_url, echo=False) # Initialize the backend (handles model state, computation, and adapter management) - if config.backend not in BACKENDS: - raise ValueError(f"Unknown backend: {config.backend}. Available backends: {list(BACKENDS.keys())}") - - backend_class, backend_config_class = BACKENDS[config.backend] + backend_class, backend_config_class = get_backend_classes(config.backend) backend_config = backend_config_class(**config.backend_config) self.backend = backend_class(config.base_model, backend_config) @@ -312,7 +318,7 @@ def find_batchable_sample(self, session: Session) -> dict[str, tuple[str, types. # TODO: This leaks the abstraction by accessing backend-specific config. # We should find a better way to handle this going forward. - if isinstance(self.backend, JaxBackend) and self.backend.config.sample_max_num_sequences > 0: + if self.config.backend == "jax" and self.backend.config.sample_max_num_sequences > 0: batchable = batchable[: self.backend.config.sample_max_num_sequences] return {str(f.request_id): (f.model_id, types.SampleInput.model_validate(f.request_data)) for f in batchable} diff --git a/skyrl-tx/tx/tinker/loss_fns.py b/skyrl-tx/tx/tinker/loss_fns.py index 6aa7c98db0..ad011df5b3 100644 --- a/skyrl-tx/tx/tinker/loss_fns.py +++ b/skyrl-tx/tx/tinker/loss_fns.py @@ -1,4 +1,4 @@ -"""Loss functions for training.""" +"""Loss functions for training (JAX implementations).""" import jax import jax.numpy as jnp diff --git a/skyrl-tx/tx/tinker/types.py b/skyrl-tx/tx/tinker/types.py index adbe1a8c2d..392da26605 100644 --- a/skyrl-tx/tx/tinker/types.py +++ b/skyrl-tx/tx/tinker/types.py @@ -256,3 +256,12 @@ class PreparedSampleBatch(BaseModel): # Mapping from samples back to requests: (request_id, model_id, start_idx, end_idx, prompt_logprobs_requested) request_batch_slices: list[tuple[str, str, int, int, bool]] + + +# Loss function type mappings (used for validation and backend dispatch) +# NOTE: Must stay in sync with LOSS_FUNCTION_MAP in loss_fns.py +LOSS_TYPES = { + "cross_entropy": 0, + "importance_sampling": 1, + "ppo": 2, +}