Make JAX optional for tinker with skyrl-train backend#1003
Closed
tyler-griggs wants to merge 2 commits intotyler/tinker-sampling-mainfrom
Closed
Make JAX optional for tinker with skyrl-train backend#1003tyler-griggs wants to merge 2 commits intotyler/tinker-sampling-mainfrom
tyler-griggs wants to merge 2 commits intotyler/tinker-sampling-mainfrom
Conversation
This refactoring allows running the tinker API server with the skyrl-train backend without requiring JAX dependencies. Changes: - Move LOSS_TYPES from loss_fns.py to types.py (no JAX dependency) - Lazy load backends in engine.py to avoid importing unused dependencies - Move JAX, flax, and optax from base to [jax] extra in pyproject.toml - Replace isinstance check with string comparison to avoid importing JaxBackend Installation: - Tinker + skyrl-train (no JAX): pip install skyrl-tx[tinker,skyrl_train] - Tinker + JAX backend: pip install skyrl-tx[tinker,jax] Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
- Remove nohup.out, tinker-try.txt, and tensorboard_log/ - Add these patterns to .gitignore to prevent future commits Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
f06e1e4 to
2b58d70
Compare
Member
Author
|
Closed in favor of #1004 which is based on main instead of tyler/tinker-sampling-main, to avoid including the sampling changes. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR refactors
skyrl-txto allow running the tinker API server with the skyrl-train backend without requiring JAX dependencies.Stacks on top of #999.
Changes
1. Made JAX Optional (pyproject.toml)
jax>=0.8,<1.0,flax>=0.12.2,optax>=0.2.5from base dependencies[jax]extra only2. Separated Loss Type Definitions (types.py)
LOSS_TYPESmapping totypes.py(no JAX imports)loss_fns.py3. Lazy Load Backends (engine.py)
JaxBackendandSkyRLTrainBackendat module levelget_backend_classes()function that imports backends on-demand--backend=<name>is specifiedisinstance(self.backend, JaxBackend)toself.config.backend == "jax"to avoid import4. Updated Import (engine.py)
from tx.tinker.loss_fns import LOSS_TYPEStofrom tx.tinker.types import LOSS_TYPESClean Separation Verified
No JAX imports in:
tx/tinker/types.pytx/tinker/engine.pytx/tinker/api.pytx/tinker/db_models.pytx/tinker/config.pytx/tinker/backends/skyrl_train.pytx/tinker/backends/backend.pytx/tinker/backends/utils.pyJAX only in backend-specific files:
tx/tinker/backends/jax.pytx/tinker/loss_fns.pyInstallation Patterns
Test plan
🤖 Generated with Claude Code