Skip to content

Make JAX optional for tinker with skyrl-train backend#1003

Closed
tyler-griggs wants to merge 2 commits intotyler/tinker-sampling-mainfrom
tyler/tinker-optional-jax
Closed

Make JAX optional for tinker with skyrl-train backend#1003
tyler-griggs wants to merge 2 commits intotyler/tinker-sampling-mainfrom
tyler/tinker-optional-jax

Conversation

@tyler-griggs
Copy link
Member

Summary

This PR refactors skyrl-tx to allow running the tinker API server with the skyrl-train backend without requiring JAX dependencies.

Stacks on top of #999.

Changes

1. Made JAX Optional (pyproject.toml)

  • Removed jax>=0.8,<1.0, flax>=0.12.2, optax>=0.2.5 from base dependencies
  • Moved all three to the [jax] extra only
  • Base install no longer requires JAX

2. Separated Loss Type Definitions (types.py)

  • Added LOSS_TYPES mapping to types.py (no JAX imports)
  • This is just a string-to-index dict for validation and backend dispatch
  • Added comment to keep it in sync with loss_fns.py

3. Lazy Load Backends (engine.py)

  • Removed eager imports of JaxBackend and SkyRLTrainBackend at module level
  • Created get_backend_classes() function that imports backends on-demand
  • Backends now only loaded when --backend=<name> is specified
  • Changed isinstance(self.backend, JaxBackend) to self.config.backend == "jax" to avoid import

4. Updated Import (engine.py)

  • Changed from tx.tinker.loss_fns import LOSS_TYPES to from tx.tinker.types import LOSS_TYPES
  • No more JAX dependency in engine.py

Clean Separation Verified

No JAX imports in:

  • tx/tinker/types.py
  • tx/tinker/engine.py
  • tx/tinker/api.py
  • tx/tinker/db_models.py
  • tx/tinker/config.py
  • tx/tinker/backends/skyrl_train.py
  • tx/tinker/backends/backend.py
  • tx/tinker/backends/utils.py

JAX only in backend-specific files:

  • tx/tinker/backends/jax.py
  • tx/tinker/loss_fns.py

Installation Patterns

# Tinker API with skyrl-train backend (NO JAX!)
pip install skyrl-tx[tinker,skyrl_train]
python -m tx.tinker.api --backend=skyrl_train --base_model=Qwen/Qwen3-0.6B

# Tinker API with JAX backend
pip install skyrl-tx[tinker,jax]
python -m tx.tinker.api --backend=jax --base_model=Qwen/Qwen3-0.6B

# Both backends available
pip install skyrl-tx[tinker,jax,skyrl_train]

Test plan

  • Verify tinker API starts with skyrl-train backend without JAX installed
  • Verify JAX backend still works with JAX installed
  • Verify LOSS_TYPES mappings are in sync between types.py and loss_fns.py
  • Run existing tinker tests with both backends

🤖 Generated with Claude Code

tyler-griggs and others added 2 commits February 2, 2026 00:05
This refactoring allows running the tinker API server with the skyrl-train
backend without requiring JAX dependencies.

Changes:
- Move LOSS_TYPES from loss_fns.py to types.py (no JAX dependency)
- Lazy load backends in engine.py to avoid importing unused dependencies
- Move JAX, flax, and optax from base to [jax] extra in pyproject.toml
- Replace isinstance check with string comparison to avoid importing JaxBackend

Installation:
- Tinker + skyrl-train (no JAX): pip install skyrl-tx[tinker,skyrl_train]
- Tinker + JAX backend: pip install skyrl-tx[tinker,jax]

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
- Remove nohup.out, tinker-try.txt, and tensorboard_log/
- Add these patterns to .gitignore to prevent future commits

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
@tyler-griggs
Copy link
Member Author

Closed in favor of #1004 which is based on main instead of tyler/tinker-sampling-main, to avoid including the sampling changes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant

Comments