diff --git a/skyrl-tx/pyproject.toml b/skyrl-tx/pyproject.toml
index af5c01775..eaaa28f7e 100644
--- a/skyrl-tx/pyproject.toml
+++ b/skyrl-tx/pyproject.toml
@@ -34,6 +34,10 @@ tpu = [
     "jax[tpu]>=0.7.2",
 ]
 
+ray = [
+    "ray[default]>=2.53.0",
+]
+
 tinker = [
     "tinker>=0.3.0",
     "fastapi[standard]",
diff --git a/skyrl-tx/tests/tinker/test_engine.py b/skyrl-tx/tests/tinker/test_engine.py
index 3319a8c3e..46a183498 100644
--- a/skyrl-tx/tests/tinker/test_engine.py
+++ b/skyrl-tx/tests/tinker/test_engine.py
@@ -1,3 +1,5 @@
+from unittest.mock import MagicMock
+
 from cloudpathlib import AnyPath
 from datetime import datetime, timedelta, timezone
 
@@ -24,7 +26,9 @@ def test_process_unload_model():
     model_id = "test_model"
 
     _ = engine.process_single_request(
-        types.RequestType.CREATE_MODEL, model_id, {"lora_config": {"rank": 8, "alpha": 16, "seed": 0}}
+        types.RequestType.CREATE_MODEL,
+        model_id,
+        {"lora_config": {"rank": 8, "alpha": 16, "seed": 0}},
     )
     assert engine.backend.has_model(model_id)
 
@@ -50,7 +54,9 @@ def test_cleanup_stale_sessions():
 
     # Create model in backend
     _ = engine.process_single_request(
-        types.RequestType.CREATE_MODEL, model_id, {"lora_config": {"rank": 8, "alpha": 16, "seed": 0}}
+        types.RequestType.CREATE_MODEL,
+        model_id,
+        {"lora_config": {"rank": 8, "alpha": 16, "seed": 0}},
     )
     assert engine.backend.has_model(model_id)
 
@@ -80,3 +86,74 @@ def test_cleanup_stale_sessions():
     # Run cleanup and assert one model was unloaded
     assert engine.cleanup_stale_sessions() == 1
     assert not engine.backend.has_model(model_id)
+
+
+def test_shutdown_without_ray():
+    """Test that shutdown() works correctly when Ray is not enabled."""
+    config = EngineConfig(
+        base_model=BASE_MODEL,
+        checkpoints_base=AnyPath(""),
+        backend_config={"max_lora_adapters": 4, "max_lora_rank": 32},
+        database_url="sqlite:///:memory:",
+    )
+    engine = TinkerEngine(config)
+    SQLModel.metadata.create_all(engine.db_engine)
+
+    # Without Ray, _ray_process_manager should be None
+    assert engine._ray_process_manager is None
+
+    # shutdown() should not raise an error even when Ray is not used
+    engine.shutdown()
+
+    # Verify _ray_process_manager is still None (no change)
+    assert engine._ray_process_manager is None
+
+
+def test_shutdown_with_ray_process_manager():
+    """Test that shutdown() correctly calls RayProcessManager.shutdown()."""
+    config = EngineConfig(
+        base_model=BASE_MODEL,
+        checkpoints_base=AnyPath(""),
+        backend_config={"max_lora_adapters": 4, "max_lora_rank": 32},
+        database_url="sqlite:///:memory:",
+    )
+    engine = TinkerEngine(config)
+    SQLModel.metadata.create_all(engine.db_engine)
+
+    # Mock the RayProcessManager
+    mock_ray_manager = MagicMock()
+    engine._ray_process_manager = mock_ray_manager
+
+    # Call shutdown
+    engine.shutdown()
+
+    # Verify RayProcessManager.shutdown() was called exactly once
+    mock_ray_manager.shutdown.assert_called_once()
+
+    # Verify _ray_process_manager is set to None after shutdown
+    assert engine._ray_process_manager is None
+
+
+def test_shutdown_idempotent():
+    """Test that calling shutdown() multiple times is safe (idempotent)."""
+    config = EngineConfig(
+        base_model=BASE_MODEL,
+        checkpoints_base=AnyPath(""),
+        backend_config={"max_lora_adapters": 4, "max_lora_rank": 32},
+        database_url="sqlite:///:memory:",
+    )
+    engine = TinkerEngine(config)
+    SQLModel.metadata.create_all(engine.db_engine)
+
+    # Mock the RayProcessManager
+    mock_ray_manager = MagicMock()
+    engine._ray_process_manager = mock_ray_manager
+
+    # Call shutdown multiple times
+    engine.shutdown()
+    engine.shutdown()
+    engine.shutdown()
+
+    # Verify RayProcessManager.shutdown() was called only once
+    # (subsequent calls should be no-ops since _ray_process_manager is None)
+    mock_ray_manager.shutdown.assert_called_once()
diff --git a/skyrl-tx/tx/tinker/api.py b/skyrl-tx/tx/tinker/api.py
index dbedd198d..fe2c09e34 100644
--- a/skyrl-tx/tx/tinker/api.py
+++ b/skyrl-tx/tx/tinker/api.py
@@ -53,19 +53,29 @@ async def lifespan(app: FastAPI):
     # Setup external inference client if configured
     if app.state.engine_config.external_inference_url:
-        app.state.external_inference_client = ExternalInferenceClient(app.state.engine_config, app.state.db_engine)
-        logger.info(f"External engine configured: {app.state.engine_config.external_inference_url}")
+        app.state.external_inference_client = ExternalInferenceClient(
+            app.state.engine_config, app.state.db_engine
+        )
+        logger.info(
+            f"External engine configured: {app.state.engine_config.external_inference_url}"
+        )
     else:
         app.state.external_inference_client = None
         logger.info("Using internal engine for inference")
 
     # Build subprocess command with engine config parameters
-    cmd = ["uv", "run", "--extra", "tinker", "-m", "tx.tinker.engine"]
+    cmd = ["uv", "run", "--extra", "tinker"]
+
+    # Add --extra ray if enable_ray is set in backend_config
+    backend_config = app.state.engine_config.backend_config or {}
+    if backend_config.get("enable_ray", False):
+        cmd.extend(["--extra", "ray"])
+
+    cmd.extend(["-m", "tx.tinker.engine"])
     cmd.extend(config_to_argv(app.state.engine_config))
     background_engine = await asyncio.create_subprocess_exec(*cmd)
     app.state.background_engine = background_engine
-    logger.info(f"Started background engine with PID {background_engine.pid}: {' '.join(cmd)}")
 
     shutting_down = False
 
@@ -73,7 +83,9 @@ async def monitor_engine():
         """Monitor engine process and exit API server if it crashes."""
         exit_code = await background_engine.wait()
         if not shutting_down:
-            logger.error(f"Background engine crashed with exit code {exit_code}, exiting API server")
+            logger.error(
+                f"Background engine crashed with exit code {exit_code}, exiting API server"
+            )
             # Start a background timer that force-exits after timeout.
             # Using a thread instead of asyncio task because SIGTERM handling
@@ -104,7 +116,9 @@ def force_exit():
     try:
         await asyncio.wait_for(background_engine.wait(), timeout=5)
     except asyncio.TimeoutError:
-        logger.warning(f"Background engine (PID {background_engine.pid}) did not terminate gracefully, killing")
+        logger.warning(
+            f"Background engine (PID {background_engine.pid}) did not terminate gracefully, killing"
+        )
         background_engine.kill()
         await background_engine.wait()
     logger.info("Background engine stopped")
@@ -174,14 +188,16 @@ async def create_checkpoint(
         raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")
     else:
         raise HTTPException(
-            status_code=409, detail=f"Checkpoint '{checkpoint_id}' already exists for model '{model_id}'"
+            status_code=409,
+            detail=f"Checkpoint '{checkpoint_id}' already exists for model '{model_id}'",
        )
 
 
 class LoRAConfig(BaseModel):
     rank: int
     seed: int | None = Field(
-        default=None, description="Seed for LoRA weight initialization. If None, a random seed is used."
+        default=None,
+        description="Seed for LoRA weight initialization.
If None, a random seed is used.", ) @@ -278,8 +294,16 @@ def to_types(self) -> types.Datum: loss_fn_inputs=types.LossFnInputs( target_tokens=inp["target_tokens"].to_types(), weights=weights, - advantages=inp["advantages"].to_types() if "advantages" in inp else types.TensorData(data=[]), - logprobs=inp["logprobs"].to_types() if "logprobs" in inp else types.TensorData(data=[]), + advantages=( + inp["advantages"].to_types() + if "advantages" in inp + else types.TensorData(data=[]) + ), + logprobs=( + inp["logprobs"].to_types() + if "logprobs" in inp + else types.TensorData(data=[]) + ), ), model_input=self.model_input.to_types(), ) @@ -290,7 +314,9 @@ class ForwardBackwardInput(BaseModel): loss_fn: Literal["cross_entropy", "importance_sampling", "ppo"] def to_types(self) -> types.ForwardBackwardInput: - return types.ForwardBackwardInput(data=[datum.to_types() for datum in self.data], loss_fn=self.loss_fn) + return types.ForwardBackwardInput( + data=[datum.to_types() for datum in self.data], loss_fn=self.loss_fn + ) class ForwardBackwardRequest(BaseModel): @@ -334,8 +360,12 @@ class SaveWeightsForSamplerRequest(BaseModel): @model_validator(mode="after") def check_path_or_ids(self): - if not self.path and (self.sampling_session_seq_id is None or self.seq_id is None): - raise ValueError("Either 'path' or both 'sampling_session_seq_id' and 'seq_id' must be provided") + if not self.path and ( + self.sampling_session_seq_id is None or self.seq_id is None + ): + raise ValueError( + "Either 'path' or both 'sampling_session_seq_id' and 'seq_id' must be provided" + ) return self @@ -349,7 +379,9 @@ class SamplingParams(BaseModel): def to_types(self) -> types.SamplingParams: if self.max_tokens is None: - raise HTTPException(status_code=400, detail="max_tokens is currently required") + raise HTTPException( + status_code=400, detail="max_tokens is currently required" + ) # Generate a random seed if not provided seed = self.seed if self.seed is not None else random.randint(0, 2**31 - 1) @@ -399,7 +431,9 @@ def validate_model_source(self): """ if self.sampling_session_id is not None: if self.seq_id is None: - raise ValueError("'seq_id' must be provided when 'sampling_session_id' is used") + raise ValueError( + "'seq_id' must be provided when 'sampling_session_id' is used" + ) return self if (self.base_model is None) == (self.model_path is None): raise ValueError( @@ -530,7 +564,9 @@ async def healthz(): @app.post("/api/v1/create_session", response_model=CreateSessionResponse) -async def create_session(request: CreateSessionRequest, session: AsyncSession = Depends(get_session)): +async def create_session( + request: CreateSessionRequest, session: AsyncSession = Depends(get_session) +): """Create a new session + persist in DB""" session_id = f"session_{uuid4().hex[:8]}" session_db = SessionDB( @@ -546,7 +582,9 @@ async def create_session(request: CreateSessionRequest, session: AsyncSession = @app.post("/api/v1/session_heartbeat", response_model=SessionHeartbeatResponse) -async def session_heartbeat(request: SessionHeartbeatRequest, session: AsyncSession = Depends(get_session)): +async def session_heartbeat( + request: SessionHeartbeatRequest, session: AsyncSession = Depends(get_session) +): """Heartbeat for an active session to keep it alive.""" session_db = await session.get(SessionDB, request.session_id) if session_db is None: @@ -557,15 +595,22 @@ async def session_heartbeat(request: SessionHeartbeatRequest, session: AsyncSess return SessionHeartbeatResponse() -@app.post("/api/v1/create_sampling_session", 
response_model=CreateSamplingSessionResponse) -async def create_sampling_session(request: CreateSamplingSessionRequest, session: AsyncSession = Depends(get_session)): +@app.post( + "/api/v1/create_sampling_session", response_model=CreateSamplingSessionResponse +) +async def create_sampling_session( + request: CreateSamplingSessionRequest, session: AsyncSession = Depends(get_session) +): """Create a new sampling session within an existing session.""" session_db = await session.get(SessionDB, request.session_id) if session_db is None: raise HTTPException(status_code=404, detail="Session not found") # Exactly one of base_model or model_path must be provided if (request.base_model is None) == (request.model_path is None): - raise HTTPException(status_code=400, detail="Exactly one of base_model or model_path must be provided") + raise HTTPException( + status_code=400, + detail="Exactly one of base_model or model_path must be provided", + ) sampling_session_id = f"sampling_{uuid4().hex[:8]}" sampling_db = SamplingSessionDB( sampling_session_id=sampling_session_id, @@ -580,7 +625,9 @@ async def create_sampling_session(request: CreateSamplingSessionRequest, session @app.post("/api/v1/create_model", response_model=CreateModelResponse) -async def create_model(request: CreateModelRequest, session: AsyncSession = Depends(get_session)): +async def create_model( + request: CreateModelRequest, session: AsyncSession = Depends(get_session) +): """Create a new model, optionally with a LoRA adapter.""" # Validate session exists session_db = await session.get(SessionDB, request.session_id) @@ -591,7 +638,11 @@ async def create_model(request: CreateModelRequest, session: AsyncSession = Depe # alpha = 32 seems to be the tinker default (see https://thinkingmachines.ai/blog/lora/) # Generate a random seed if not provided - seed = request.lora_config.seed if request.lora_config.seed is not None else random.randint(0, 2**31 - 1) + seed = ( + request.lora_config.seed + if request.lora_config.seed is not None + else random.randint(0, 2**31 - 1) + ) lora_config = types.LoraConfig(rank=request.lora_config.rank, alpha=32.0, seed=seed) request_id = await create_future( session=session, @@ -622,7 +673,9 @@ async def create_model(request: CreateModelRequest, session: AsyncSession = Depe @app.post("/api/v1/unload_model", response_model=UnloadModelResponse) -async def unload_model(request: UnloadModelRequest, session: AsyncSession = Depends(get_session)): +async def unload_model( + request: UnloadModelRequest, session: AsyncSession = Depends(get_session) +): """Unload a model and free all associated resources.""" # Validate model exists model_db = await session.get(ModelDB, request.model_id) @@ -651,16 +704,22 @@ class GetInfoRequest(BaseModel): @app.post("/api/v1/get_info", response_model=ModelInfoResponse) -async def get_model_info(request: GetInfoRequest, session: AsyncSession = Depends(get_session)): +async def get_model_info( + request: GetInfoRequest, session: AsyncSession = Depends(get_session) +): """Retrieve information about the current model.""" model = await get_model(session, request.model_id) lora_config = types.LoraConfig.model_validate(model.lora_config) model_data = ModelData( - base_model=model.base_model, lora_config=LoRAConfig(rank=lora_config.rank), model_name=model.base_model + base_model=model.base_model, + lora_config=LoRAConfig(rank=lora_config.rank), + model_name=model.base_model, ) - return ModelInfoResponse(model_id=model.model_id, status=model.status, model_data=model_data) + return 
ModelInfoResponse( + model_id=model.model_id, status=model.status, model_data=model_data + ) @app.get("/api/v1/training_runs/{model_id}", response_model=TrainingRun) @@ -686,7 +745,9 @@ async def get_training_run(model_id: str, session: AsyncSession = Depends(get_se @app.post("/api/v1/forward_backward", response_model=FutureResponse) -async def forward_backward(request: ForwardBackwardRequest, session: AsyncSession = Depends(get_session)): +async def forward_backward( + request: ForwardBackwardRequest, session: AsyncSession = Depends(get_session) +): """Compute and accumulate gradients.""" await get_model(session, request.model_id) @@ -699,11 +760,15 @@ async def forward_backward(request: ForwardBackwardRequest, session: AsyncSessio await session.commit() - return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + return FutureResponse( + future_id=str(request_id), status="pending", request_id=str(request_id) + ) @app.post("/api/v1/forward", response_model=FutureResponse) -async def forward(request: ForwardRequest, session: AsyncSession = Depends(get_session)): +async def forward( + request: ForwardRequest, session: AsyncSession = Depends(get_session) +): """Forward pass to obtain logprobs without accumulating gradients""" await get_model(session, request.model_id) @@ -716,11 +781,15 @@ async def forward(request: ForwardRequest, session: AsyncSession = Depends(get_s await session.commit() - return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + return FutureResponse( + future_id=str(request_id), status="pending", request_id=str(request_id) + ) @app.post("/api/v1/optim_step", response_model=FutureResponse) -async def optim_step(request: OptimStepRequest, session: AsyncSession = Depends(get_session)): +async def optim_step( + request: OptimStepRequest, session: AsyncSession = Depends(get_session) +): """Update model using accumulated gradients.""" await get_model(session, request.model_id) @@ -733,11 +802,17 @@ async def optim_step(request: OptimStepRequest, session: AsyncSession = Depends( await session.commit() - return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + return FutureResponse( + future_id=str(request_id), status="pending", request_id=str(request_id) + ) @app.post("/api/v1/load_weights", response_model=FutureResponse) -async def load_weights(request: LoadWeightsRequest, req: Request, session: AsyncSession = Depends(get_session)): +async def load_weights( + request: LoadWeightsRequest, + req: Request, + session: AsyncSession = Depends(get_session), +): """Loads weights and training state.""" await get_model(session, request.model_id) @@ -749,25 +824,34 @@ async def load_weights(request: LoadWeightsRequest, req: Request, session: Async or not (checkpoint_id := path.secondary_id) ): raise HTTPException( - status_code=400, detail="request.path must be in format tinker://source_model_id/weights/checkpoint_id" + status_code=400, + detail="request.path must be in format tinker://source_model_id/weights/checkpoint_id", ) - await validate_checkpoint(req, source_model_id, checkpoint_id, types.CheckpointType.TRAINING, session) + await validate_checkpoint( + req, source_model_id, checkpoint_id, types.CheckpointType.TRAINING, session + ) request_id = await create_future( session=session, request_type=types.RequestType.LOAD_WEIGHTS, model_id=request.model_id, - request_data=types.LoadWeightsInput(source_model_id=source_model_id, checkpoint_id=checkpoint_id), + 
request_data=types.LoadWeightsInput( + source_model_id=source_model_id, checkpoint_id=checkpoint_id + ), ) await session.commit() - return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + return FutureResponse( + future_id=str(request_id), status="pending", request_id=str(request_id) + ) @app.post("/api/v1/save_weights", response_model=FutureResponse) -async def save_weights(request: SaveWeightsRequest, session: AsyncSession = Depends(get_session)): +async def save_weights( + request: SaveWeightsRequest, session: AsyncSession = Depends(get_session) +): """Saves weights and training state.""" # Create pending checkpoint entry (validates model exists) await create_checkpoint( @@ -786,16 +870,22 @@ async def save_weights(request: SaveWeightsRequest, session: AsyncSession = Depe await session.commit() - return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + return FutureResponse( + future_id=str(request_id), status="pending", request_id=str(request_id) + ) @app.post("/api/v1/save_weights_for_sampler", response_model=FutureResponse) -async def save_weights_for_sampler(request: SaveWeightsForSamplerRequest, session: AsyncSession = Depends(get_session)): +async def save_weights_for_sampler( + request: SaveWeightsForSamplerRequest, session: AsyncSession = Depends(get_session) +): """Saves weights in a format compatible with sampling/inference servers.""" # Get the model (validates it exists and gives us the session_id) model = await get_model(session, request.model_id) - checkpoint_id = request.path or f"ss{request.sampling_session_seq_id}_seq{request.seq_id}" + checkpoint_id = ( + request.path or f"ss{request.sampling_session_seq_id}_seq{request.seq_id}" + ) sampling_session_id = None if request.sampling_session_seq_id is not None and request.seq_id is not None: # Create the sampling session using the model's session @@ -831,14 +921,20 @@ async def save_weights_for_sampler(request: SaveWeightsForSamplerRequest, sessio await session.commit() - return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + return FutureResponse( + future_id=str(request_id), status="pending", request_id=str(request_id) + ) -async def get_sampling_model(request: SampleRequest, session: AsyncSession) -> (str | None, str | None): +async def get_sampling_model( + request: SampleRequest, session: AsyncSession +) -> (str | None, str | None): """Return (base_model, model_path) for a sampling request.""" # Resolve model/base from sampling_session_id if provided if request.sampling_session_id is not None: - sampling_session = await session.get(SamplingSessionDB, request.sampling_session_id) + sampling_session = await session.get( + SamplingSessionDB, request.sampling_session_id + ) if sampling_session is None: raise HTTPException(status_code=404, detail="Sampling session not found") return (sampling_session.base_model, sampling_session.model_path) @@ -846,7 +942,9 @@ async def get_sampling_model(request: SampleRequest, session: AsyncSession) -> ( @app.post("/api/v1/asample", response_model=FutureResponse) -async def asample(request: SampleRequest, req: Request, session: AsyncSession = Depends(get_session)): +async def asample( + request: SampleRequest, req: Request, session: AsyncSession = Depends(get_session) +): """Generates samples from the model (async version).""" base_model, model_path = await get_sampling_model(request, session) @@ -868,12 +966,16 @@ async def asample(request: SampleRequest, req: 
Request, session: AsyncSession = ) await get_model(session, model_id) # Validate that the checkpoint exists and is ready - await validate_checkpoint(req, model_id, checkpoint_id, types.CheckpointType.SAMPLER, session) + await validate_checkpoint( + req, model_id, checkpoint_id, types.CheckpointType.SAMPLER, session + ) request_id = await create_future( session=session, request_type=( - types.RequestType.EXTERNAL if req.app.state.external_inference_client else types.RequestType.SAMPLE + types.RequestType.EXTERNAL + if req.app.state.external_inference_client + else types.RequestType.SAMPLE ), model_id=model_id, request_data=types.SampleInput( @@ -890,13 +992,19 @@ async def asample(request: SampleRequest, req: Request, session: AsyncSession = if req.app.state.external_inference_client: asyncio.create_task( - req.app.state.external_inference_client.call_and_store_result(request_id, request, model_id, checkpoint_id) + req.app.state.external_inference_client.call_and_store_result( + request_id, request, model_id, checkpoint_id + ) ) - return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + return FutureResponse( + future_id=str(request_id), status="pending", request_id=str(request_id) + ) -@app.get("/api/v1/get_server_capabilities", response_model=GetServerCapabilitiesResponse) +@app.get( + "/api/v1/get_server_capabilities", response_model=GetServerCapabilitiesResponse +) async def get_server_capabilities(request: Request): """Retrieve information about supported models and server capabilities.""" supported_models = [ @@ -923,7 +1031,9 @@ async def retrieve_future(request: RetrieveFutureRequest, req: Request): try: async with AsyncSession(req.app.state.db_engine) as session: # First, only query the status to avoid deserializing JSON data - statement = select(FutureDB.status).where(FutureDB.request_id == int(request.request_id)) + statement = select(FutureDB.status).where( + FutureDB.request_id == int(request.request_id) + ) result = await session.exec(statement) status = result.first() @@ -932,7 +1042,9 @@ async def retrieve_future(request: RetrieveFutureRequest, req: Request): # Only fetch full record if status is terminal (completed or failed) if status in (RequestStatus.COMPLETED, RequestStatus.FAILED): - statement = select(FutureDB).where(FutureDB.request_id == int(request.request_id)) + statement = select(FutureDB).where( + FutureDB.request_id == int(request.request_id) + ) result = await session.exec(statement) future = result.first() @@ -942,7 +1054,9 @@ async def retrieve_future(request: RetrieveFutureRequest, req: Request): if future.status == RequestStatus.FAILED: # Return 400 for handled errors (validation, etc.), 500 for unexpected failures if future.result_data and "error" in future.result_data: - raise HTTPException(status_code=400, detail=future.result_data["error"]) + raise HTTPException( + status_code=400, detail=future.result_data["error"] + ) else: raise HTTPException(status_code=500, detail="Unknown error") except SATimeoutError: @@ -963,22 +1077,40 @@ async def send_telemetry(request: TelemetryRequest): async def validate_checkpoint( - request: Request, unique_id: str, checkpoint_id: str, checkpoint_type: types.CheckpointType, session: AsyncSession + request: Request, + unique_id: str, + checkpoint_id: str, + checkpoint_type: types.CheckpointType, + session: AsyncSession, ): """Validate that a model and checkpoint exist in the database, returning the checkpoint path.""" - checkpoint_db = await session.get(CheckpointDB, (unique_id, 
checkpoint_id, checkpoint_type)) + checkpoint_db = await session.get( + CheckpointDB, (unique_id, checkpoint_id, checkpoint_type) + ) if not checkpoint_db: - raise HTTPException(status_code=404, detail=f"Checkpoint not found: {unique_id}/{checkpoint_id}") + raise HTTPException( + status_code=404, detail=f"Checkpoint not found: {unique_id}/{checkpoint_id}" + ) if checkpoint_db.status == CheckpointStatus.PENDING: raise HTTPException(status_code=425, detail="Checkpoint is still being created") if checkpoint_db.status == CheckpointStatus.FAILED: - raise HTTPException(status_code=500, detail=f"Checkpoint creation failed: {checkpoint_db.error_message}") + raise HTTPException( + status_code=500, + detail=f"Checkpoint creation failed: {checkpoint_db.error_message}", + ) - subdir = "sampler_weights" if checkpoint_type == types.CheckpointType.SAMPLER else "" - return request.app.state.engine_config.checkpoints_base / unique_id / subdir / f"{checkpoint_id}.tar.gz" + subdir = ( + "sampler_weights" if checkpoint_type == types.CheckpointType.SAMPLER else "" + ) + return ( + request.app.state.engine_config.checkpoints_base + / unique_id + / subdir + / f"{checkpoint_id}.tar.gz" + ) @app.get("/api/v1/training_runs") @@ -988,7 +1120,11 @@ async def list_training_runs( """List all training runs""" # Use window function to get total count alongside paginated results in a single query - statement = select(ModelDB, func.count().over().label("total_count")).offset(offset).limit(limit) + statement = ( + select(ModelDB, func.count().over().label("total_count")) + .offset(offset) + .limit(limit) + ) result = await session.exec(statement) rows = result.all() @@ -1015,7 +1151,8 @@ async def list_training_runs( ) return TrainingRunsResponse( - training_runs=training_runs, cursor=Cursor(offset=offset, limit=limit, total_count=total_count) + training_runs=training_runs, + cursor=Cursor(offset=offset, limit=limit, total_count=total_count), ) @@ -1023,14 +1160,24 @@ async def list_training_runs( async def get_checkpoint_archive_url( request: Request, unique_id: str = fastapi.Path(..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH), - checkpoint_id: str = fastapi.Path(..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH), + checkpoint_id: str = fastapi.Path( + ..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH + ), session: AsyncSession = Depends(get_session), ): """Return a 302 redirect to the download URL (SDK expects this pattern)""" - await validate_checkpoint(request, unique_id, checkpoint_id, types.CheckpointType.SAMPLER, session) + await validate_checkpoint( + request, unique_id, checkpoint_id, types.CheckpointType.SAMPLER, session + ) # Generate URL to the download endpoint and return 302 redirect - download_url = str(request.url_for("download_checkpoint_archive", unique_id=unique_id, checkpoint_id=checkpoint_id)) + download_url = str( + request.url_for( + "download_checkpoint_archive", + unique_id=unique_id, + checkpoint_id=checkpoint_id, + ) + ) expires = datetime.utcnow() + timedelta(minutes=120) response = RedirectResponse(url=download_url, status_code=302) @@ -1042,7 +1189,9 @@ async def get_checkpoint_archive_url( async def download_checkpoint_archive( request: Request, unique_id: str = fastapi.Path(..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH), - checkpoint_id: str = fastapi.Path(..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH), + checkpoint_id: str = fastapi.Path( + ..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH + ), session: AsyncSession = Depends(get_session), ): """Actually download the checkpoint 
archive bytes""" @@ -1058,7 +1207,9 @@ async def download_checkpoint_archive( "Content-Length": str(file_buffer.getbuffer().nbytes), } - return StreamingResponse(file_buffer, media_type="application/octet-stream", headers=headers) + return StreamingResponse( + file_buffer, media_type="application/octet-stream", headers=headers + ) @app.get("/api/v1/training_runs/{unique_id}/checkpoints") @@ -1077,7 +1228,11 @@ async def list_checkpoints( checkpoints = [] for checkpoint in result.all(): # Construct tinker_path based on checkpoint type - path_kind = "weights" if checkpoint.checkpoint_type == types.CheckpointType.TRAINING else "sampler_weights" + path_kind = ( + "weights" + if checkpoint.checkpoint_type == types.CheckpointType.TRAINING + else "sampler_weights" + ) tinker_path = f"tinker://{unique_id}/{path_kind}/{checkpoint.checkpoint_id}" checkpoints.append( @@ -1102,13 +1257,18 @@ async def list_checkpoints_models( @app.post("/api/v1/weights_info", response_model=WeightsInfoResponse) -async def get_weights_info(request: WeightsInfoRequest, req: Request, session: AsyncSession = Depends(get_session)): +async def get_weights_info( + request: WeightsInfoRequest, + req: Request, + session: AsyncSession = Depends(get_session), +): """Get information about weights/checkpoint from a tinker path.""" path = types.TinkerPath.parse(request.tinker_path) if not path or path.kind != "weights": raise HTTPException( - status_code=400, detail="Invalid tinker path format. Expected: tinker://model_id/weights/checkpoint_id" + status_code=400, + detail="Invalid tinker path format. Expected: tinker://model_id/weights/checkpoint_id", ) model_id = path.primary_id @@ -1118,7 +1278,9 @@ async def get_weights_info(request: WeightsInfoRequest, req: Request, session: A model = await get_model(session, model_id) # Validate checkpoint exists and is completed - await validate_checkpoint(req, model_id, checkpoint_id, types.CheckpointType.TRAINING, session) + await validate_checkpoint( + req, model_id, checkpoint_id, types.CheckpointType.TRAINING, session + ) lora_config = types.LoraConfig.model_validate(model.lora_config) is_lora = lora_config.rank > 0 @@ -1137,7 +1299,11 @@ async def root(): "name": "Tinker API Mock", "version": "0.0.1", "endpoints": { - "models": ["/api/v1/create_model", "/api/v1/get_info", "/api/v1/training_runs/{model_id}"], + "models": [ + "/api/v1/create_model", + "/api/v1/get_info", + "/api/v1/training_runs/{model_id}", + ], "training": ["/api/v1/forward_backward", "/api/v1/optim_step"], "futures": ["/api/v1/retrieve_future"], "service": ["/api/v1/get_server_capabilities"], @@ -1163,7 +1329,9 @@ async def root(): args = parser.parse_args() # Create EngineConfig from parsed arguments (only EngineConfig fields) - engine_config = EngineConfig.model_validate({k: v for k, v in vars(args).items() if k in EngineConfig.model_fields}) + engine_config = EngineConfig.model_validate( + {k: v for k, v in vars(args).items() if k in EngineConfig.model_fields} + ) # Store config in app.state so lifespan can access it app.state.engine_config = engine_config diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index e57bc7b7e..be3ff688f 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -59,10 +59,16 @@ class JaxBackendConfig(BaseModel, extra="forbid"): """Configuration specific to the JAX backend.""" - max_lora_adapters: int = Field(default=32, description="Maximum number of LoRA adapters") + max_lora_adapters: int = Field( + default=32, 
description="Maximum number of LoRA adapters" + ) max_lora_rank: int = Field(default=32, description="Maximum LoRA rank") - tensor_parallel_size: int = Field(default=1, description="Tensor parallelism degree to use for the model") - expert_parallel_size: int = Field(default=1, description="Expert parallelism degree for MoE layers") + tensor_parallel_size: int = Field( + default=1, description="Tensor parallelism degree to use for the model" + ) + expert_parallel_size: int = Field( + default=1, description="Expert parallelism degree for MoE layers" + ) fully_sharded_data_parallel_size: int = Field( default=1, description="Fully sharded data parallelism degree for the model" ) @@ -74,7 +80,9 @@ class JaxBackendConfig(BaseModel, extra="forbid"): default=0, description="Maximum batch size (measured in number of sequences) for sampling/generation; 0 means disabled (use full batch)", ) - enforce_eager: bool = Field(default=False, description="Disable JAX JIT compilation") + enforce_eager: bool = Field( + default=False, description="Disable JAX JIT compilation" + ) shard_attention_heads: bool = Field( default=True, description="Whether to shard attention linear layers (qkvo projections) across tensor parallel devices", @@ -92,6 +100,32 @@ class JaxBackendConfig(BaseModel, extra="forbid"): default=None, description="Total number of processes in the multi-node cluster", ) + # Ray configuration (Solution C: Ray as process launcher only) + enable_ray: bool = Field( + default=False, + description="If true, use Ray to automatically manage worker processes across nodes", + ) + ray_num_workers: int = Field( + default=1, + description="Number of Ray worker actors to spawn (only used when enable_ray=True)", + ) + # TODO: Future enhancements for GPU allocation: + # 1. Support "auto" mode: detect GPUs per node via ray.cluster_resources() + # 2. Support scheduling_strategy (SPREAD, PACK) for better placement control + # 3. Support placement_group for gang scheduling (ensure all workers start together) + # 4. Support accelerator_type (e.g., "A100", "H100") for heterogeneous clusters + ray_gpus_per_worker: int | None = Field( + default=None, + description=( + "Number of GPUs to allocate per Ray worker actor (typically the number of local GPUs per node). " + "If None, no GPU resource constraint is applied. " + "Example: For 2 nodes with 8 GPUs each, set ray_num_workers=2 and ray_gpus_per_worker=8." 
+ ), + ) + ray_cpus_per_worker: int = Field( + default=1, + description="Number of CPUs to allocate per Ray worker actor", + ) @jax.tree_util.register_dataclass @@ -103,14 +137,18 @@ class AccumulatedGradients: counts: jax.Array @classmethod - def create(cls, lora_params: nnx.State, max_adapters: int) -> "AccumulatedGradients": + def create( + cls, lora_params: nnx.State, max_adapters: int + ) -> "AccumulatedGradients": """Initialize with zeros.""" return cls( grad_sum=jax.tree.map(jnp.zeros_like, lora_params), counts=jnp.zeros((max_adapters,), dtype=jnp.int32), ) - def add(self, lora_grads: nnx.State, adapter_indices: jax.Array) -> "AccumulatedGradients": + def add( + self, lora_grads: nnx.State, adapter_indices: jax.Array + ) -> "AccumulatedGradients": """Accumulate gradients and increment counts.""" # Count occurrences of each adapter index in the batch batch_counts = jnp.bincount(adapter_indices, length=self.counts.shape[0]) @@ -123,14 +161,18 @@ def get_mean(self, adapter_index: jax.Array) -> nnx.State: """Compute mean gradients for a specific adapter, with zeros for all other adapters.""" count = self.counts[adapter_index] return jax.tree.map( - lambda g: jnp.zeros_like(g).at[adapter_index].set(g[adapter_index] / count.astype(g.dtype)), + lambda g: jnp.zeros_like(g) + .at[adapter_index] + .set(g[adapter_index] / count.astype(g.dtype)), self.grad_sum, ) def reset_adapter(self, adapter_index: jax.Array) -> "AccumulatedGradients": """Reset gradients and count for a specific adapter.""" return AccumulatedGradients( - grad_sum=jax.tree.map(lambda g: g.at[adapter_index].set(0.0), self.grad_sum), + grad_sum=jax.tree.map( + lambda g: g.at[adapter_index].set(0.0), self.grad_sum + ), counts=self.counts.at[adapter_index].set(0), ) @@ -177,17 +219,29 @@ def __init__(self, base_model: str, config: JaxBackendConfig): ("fsdp", "ep", "tp"), ) with jax.set_mesh(self.mesh), nnx.use_eager_sharding(True): - self.model = model_class(self.model_config, dtype=get_dtype(self.model_config.dtype), rngs=nnx.Rngs(0)) + self.model = model_class( + self.model_config, + dtype=get_dtype(self.model_config.dtype), + rngs=nnx.Rngs(0), + ) load_safetensors(checkpoint_path, self.model_config, self.model) # Split model into LoRA and non-LoRA parameters - self.graphdef, self.lora_params, self.non_lora_params = nnx.split(self.model, self.model.is_lora_param, ...) + self.graphdef, self.lora_params, self.non_lora_params = nnx.split( + self.model, self.model.is_lora_param, ... 
+ ) # Initialize adapter 0 with minimal config (required for base model sampling path) - init_lora_adapter(self.model, adapter_index=0, lora_config=types.LoraConfig(rank=1, alpha=1.0, seed=0)) + init_lora_adapter( + self.model, + adapter_index=0, + lora_config=types.LoraConfig(rank=1, alpha=1.0, seed=0), + ) # Initialize global accumulated gradients - self.accumulated_grads = AccumulatedGradients.create(self.lora_params, config.max_lora_adapters) + self.accumulated_grads = AccumulatedGradients.create( + self.lora_params, config.max_lora_adapters + ) # Per-model optimizer storage (managed internally) self.optimizers: dict[str, nnx.Optimizer] = {} @@ -215,14 +269,20 @@ def _jit_timing_context(self, seq_len: int, mode: str): seq_len: The sequence length being compiled mode: Either 'train' or 'sample' to track separately """ - jit_times = self.metrics.train_seq_len_jit_times if mode == "train" else self.metrics.sample_seq_len_jit_times + jit_times = ( + self.metrics.train_seq_len_jit_times + if mode == "train" + else self.metrics.sample_seq_len_jit_times + ) if not self.config.enforce_eager and seq_len not in jit_times: logger.info(f"JIT compiling for {mode} seq_len={seq_len} in progress...") start_time = time.time() yield elapsed = time.time() - start_time jit_times[seq_len] = elapsed - logger.info(f"JIT compilation for {mode} seq_len={seq_len} took {elapsed:.2f}s") + logger.info( + f"JIT compilation for {mode} seq_len={seq_len} took {elapsed:.2f}s" + ) else: yield @@ -245,7 +305,9 @@ def _model_forward( attention_mask=attention_mask, adapter_indices=adapter_indices, ) - return model.compute_logprobs(output.last_hidden_state, target_ids, adapter_indices) + return model.compute_logprobs( + output.last_hidden_state, target_ids, adapter_indices + ) if self.config.gradient_checkpointing: # Wrap the model forward call to use jax.checkpoint for gradient checkpointing @@ -274,7 +336,9 @@ def loss_for_lora( target_ids, ) - def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages): + def compute_loss_per_example( + loss_fn_type, target_logprobs, loss_mask, sampling_logprobs, advantages + ): return jax.lax.switch( loss_fn_type, LOSS_FUNCTIONS, @@ -292,7 +356,9 @@ def compute_loss_per_example(loss_fn_type, target_logprobs, loss_mask, sampling_ advantages, ) - per_seq_loss = per_token_losses.sum(axis=-1) / jnp.maximum(loss_mask.sum(axis=-1), 1e-9) + per_seq_loss = per_token_losses.sum(axis=-1) / jnp.maximum( + loss_mask.sum(axis=-1), 1e-9 + ) # Return sum of losses (we'll divide gradients by per-adapter batch size later) return per_seq_loss.sum(), (target_logprobs, per_token_losses) @@ -365,14 +431,17 @@ def forward_backward_and_accumulate( else: # Retrieve the sharding of lora and non_lora params and compute the sharding of inputs and outputs lora_shardings = jax.tree.map( - lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.lora_params) + lambda spec: jax.NamedSharding(self.mesh, spec), + nnx.get_partition_spec(self.lora_params), ) non_lora_shardings = jax.tree.map( - lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.non_lora_params) + lambda spec: jax.NamedSharding(self.mesh, spec), + nnx.get_partition_spec(self.non_lora_params), ) # Get sharding for AccumulatedGradients accumulated_grads_shardings = jax.tree.map( - lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.accumulated_grads) + lambda spec: jax.NamedSharding(self.mesh, spec), + 
nnx.get_partition_spec(self.accumulated_grads), ) # Shard batch inputs along the FSDP axis (batch, seq_len) @@ -395,14 +464,32 @@ def forward_backward_and_accumulate( ) self._forward_backward_and_accumulate = jax.jit( forward_backward_and_accumulate, - in_shardings=(accumulated_grads_shardings, lora_shardings, non_lora_shardings) + input_shardings, - out_shardings=(accumulated_grads_shardings, batch_sharded_2d, batch_sharded_2d), + in_shardings=( + accumulated_grads_shardings, + lora_shardings, + non_lora_shardings, + ) + + input_shardings, + out_shardings=( + accumulated_grads_shardings, + batch_sharded_2d, + batch_sharded_2d, + ), donate_argnames=("accumulated_grads",), ) self._forward = jax.jit( forward_only, - in_shardings=(accumulated_grads_shardings, lora_shardings, non_lora_shardings) + input_shardings, - out_shardings=(accumulated_grads_shardings, batch_sharded_2d, batch_sharded_2d), + in_shardings=( + accumulated_grads_shardings, + lora_shardings, + non_lora_shardings, + ) + + input_shardings, + out_shardings=( + accumulated_grads_shardings, + batch_sharded_2d, + batch_sharded_2d, + ), ) # JIT-compiled function to compute full gradients and apply optimizer update @@ -435,13 +522,17 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: used_indices = {m.adapter_index for m in self.models.values()} available_indices = set(range(1, self.config.max_lora_adapters)) - used_indices if not available_indices: - raise ValueError(f"Maximum number of LoRA adapters ({self.config.max_lora_adapters}) reached") + raise ValueError( + f"Maximum number of LoRA adapters ({self.config.max_lora_adapters}) reached" + ) adapter_index = min(available_indices) assert 1 <= adapter_index <= self.config.max_lora_adapters - 1 # Validate rank doesn't exceed max if not (0 < lora_config.rank <= self.config.max_lora_rank): - raise ValueError(f"LoRA rank {lora_config.rank} must be between 1 and {self.config.max_lora_rank}") + raise ValueError( + f"LoRA rank {lora_config.rank} must be between 1 and {self.config.max_lora_rank}" + ) # Store model metadata self.models[model_id] = types.ModelMetadata( @@ -452,11 +543,15 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: # Create optimizer with jax.set_mesh(self.mesh): tx = optax.inject_hyperparams(optax.adamw)(learning_rate=0.0) - self.optimizers[model_id] = nnx.Optimizer(self.model, tx, wrt=self.model.is_lora_param) + self.optimizers[model_id] = nnx.Optimizer( + self.model, tx, wrt=self.model.is_lora_param + ) # Configure adapter init_lora_adapter(self.model, adapter_index, lora_config) - logger.info(f"Created model {model_id} with adapter_index={adapter_index}, config={lora_config}") + logger.info( + f"Created model {model_id} with adapter_index={adapter_index}, config={lora_config}" + ) def delete_model(self, model_id: str) -> None: """Delete a model and free all associated resources.""" @@ -507,7 +602,10 @@ def _model_pass( request_batch_slices = prepared_batch.request_batch_slices # Convert model_ids to adapter_indices - all_adapter_indices = [self.models[model_id].adapter_index for model_id in prepared_batch.all_model_ids] + all_adapter_indices = [ + self.models[model_id].adapter_index + for model_id in prepared_batch.all_model_ids + ] # Pad sequences to same length. Also bin it so the JIT has to compile fewer kernels. 
max_len = round_up_seq_len(max(len(seq) for seq in all_input_ids)) @@ -518,7 +616,9 @@ def _model_pass( loss_fn_types = np.array(all_loss_fn_types, dtype=np.int32) # Create attention mask (1 for real tokens, 0 for padding) - attention_mask = pad_batch([[1] * len(seq) for seq in all_input_ids], max_len, np.int32) + attention_mask = pad_batch( + [[1] * len(seq) for seq in all_input_ids], max_len, np.int32 + ) loss_mask = pad_batch(all_token_weights, max_len, np.float32) sampling_logprobs = pad_batch(all_sampling_logprobs, max_len, np.float32) advantages = pad_batch(all_advantages, max_len, np.float32) @@ -565,18 +665,20 @@ def _model_pass( (sharding_2d,) * 6 + (sharding_1d,) * 2, ) - self.accumulated_grads, per_token_losses, target_logprobs = model_pass_fn( - self.accumulated_grads, - self.lora_params, - self.non_lora_params, - mb_input_ids, - mb_attention_mask, - mb_adapter_indices, - mb_target_ids, - mb_loss_mask, - mb_loss_fn_types, - mb_sampling_logprobs, - mb_advantages, + self.accumulated_grads, per_token_losses, target_logprobs = ( + model_pass_fn( + self.accumulated_grads, + self.lora_params, + self.non_lora_params, + mb_input_ids, + mb_attention_mask, + mb_adapter_indices, + mb_target_ids, + mb_loss_mask, + mb_loss_fn_types, + mb_sampling_logprobs, + mb_advantages, + ) ) # Slice back to original size (remove FSDP padding) token_losses_device.append(per_token_losses[: mb_end - mb_start]) @@ -584,11 +686,19 @@ def _model_pass( # Gather results from all hosts before device_get if jax.process_count() > 1: - token_losses_device = [multihost_utils.process_allgather(x, tiled=True) for x in token_losses_device] - logprobs_device = [multihost_utils.process_allgather(x, tiled=True) for x in logprobs_device] + token_losses_device = [ + multihost_utils.process_allgather(x, tiled=True) + for x in token_losses_device + ] + logprobs_device = [ + multihost_utils.process_allgather(x, tiled=True) + for x in logprobs_device + ] # Single batched device-to-host transfer for all arrays - token_losses_host, logprobs_host = jax.device_get((token_losses_device, logprobs_device)) + token_losses_host, logprobs_host = jax.device_get( + (token_losses_device, logprobs_device) + ) # Flatten microbatches and slice to actual sequence lengths token_losses_out = [] @@ -596,7 +706,9 @@ def _model_pass( idx = 0 for mb_losses, mb_logprobs in zip(token_losses_host, logprobs_host): for i in range(mb_losses.shape[0]): - token_losses_out.append(mb_losses[i, : seq_lens[idx]].astype(jnp.float32)) + token_losses_out.append( + mb_losses[i, : seq_lens[idx]].astype(jnp.float32) + ) logprobs_out.append(mb_logprobs[i, : seq_lens[idx]].astype(jnp.float32)) idx += 1 @@ -645,14 +757,18 @@ def forward( """Run forward-only pass on a batch (no gradient computation).""" return self._model_pass(prepared_batch, self._forward) - def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput: + def optim_step( + self, model_id: str, request_data: types.OptimStepInput + ) -> types.OptimStepOutput: """Apply an optimizer step using accumulated gradients.""" adapter_index = self.models[model_id].adapter_index optimizer = self.optimizers[model_id] # Check if we have any gradients accumulated (count > 0) if self.accumulated_grads.counts[adapter_index] == 0: - logger.warning(f"No accumulated gradients for model {model_id}, skipping optimizer step") + logger.warning( + f"No accumulated gradients for model {model_id}, skipping optimizer step" + ) return types.OptimStepOutput() # Update hyperparameters from the request 
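Context for the `optim_step` change above: `create_model` builds each adapter's optimizer as `optax.inject_hyperparams(optax.adamw)(learning_rate=0.0)`, and `optim_step` then overwrites the hyperparameters from the request before applying the accumulated gradients. A minimal optax-only sketch of that pattern (illustrative values, not the engine's actual update path):

```python
import jax
import jax.numpy as jnp
import optax

# Placeholder learning rate at construction time, as in create_model().
tx = optax.inject_hyperparams(optax.adamw)(learning_rate=0.0)

params = {"w": jnp.zeros((4,))}
opt_state = tx.init(params)

# inject_hyperparams exposes the learning rate in the optimizer state,
# so a per-request value can be written in before the update is applied.
opt_state.hyperparams["learning_rate"] = jnp.asarray(1e-4)

grads = jax.tree.map(jnp.ones_like, params)  # stand-in for the accumulated gradients
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```

This is why a single optimizer construction path can still honor a different learning rate on every optim_step request.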
@@ -672,7 +788,9 @@ def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types jnp.int32(adapter_index), ) - logger.info(f"Applied optimizer step for model {model_id} (adapter {adapter_index})") + logger.info( + f"Applied optimizer step for model {model_id} (adapter {adapter_index})" + ) return types.OptimStepOutput() def sample( @@ -703,7 +821,9 @@ def sample( total_batch_size = len(all_prompts) max_batch_size = ( - self.config.sample_max_num_sequences if self.config.sample_max_num_sequences > 0 else total_batch_size + self.config.sample_max_num_sequences + if self.config.sample_max_num_sequences > 0 + else total_batch_size ) # Collect generated sequences and prompt logprobs across batches all_sequences: list[types.GeneratedSequence] = [] @@ -717,22 +837,36 @@ def sample( model = nnx.merge(self.graphdef, self.lora_params, self.non_lora_params) for batch_start in range(0, total_batch_size, max_batch_size): batch_end = min(batch_start + max_batch_size, total_batch_size) - batch_prompts = pad(all_prompts[batch_start:batch_end], max_batch_size, fill=[]) - batch_adapter_indices = pad(all_adapter_indices[batch_start:batch_end], max_batch_size, fill=0) + batch_prompts = pad( + all_prompts[batch_start:batch_end], max_batch_size, fill=[] + ) + batch_adapter_indices = pad( + all_adapter_indices[batch_start:batch_end], max_batch_size, fill=0 + ) sampling_params = pad( - all_sampling_params[batch_start:batch_end], max_batch_size, fill=all_sampling_params[batch_start] + all_sampling_params[batch_start:batch_end], + max_batch_size, + fill=all_sampling_params[batch_start], ) # Pad sequences to same length within the batch to minimize memory usage. # Also bin it so the JIT has to compile fewer kernels. # Use right-padding, which means during decoding there will be "gaps" in the attention mask. 
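The right-padding note above is easiest to see with a toy example (plain NumPy, not the repo's `pad_batch` helper): padding positions stay zero in the attention mask, and tokens generated during decoding are appended after them.

```python
import numpy as np

# Two prompts of different lengths, right-padded to the same binned length.
prompts = [[11, 12, 13], [21, 22, 23, 24, 25]]
max_len = 8
input_ids = np.zeros((2, max_len), dtype=np.int32)
attention_mask = np.zeros((2, max_len), dtype=np.int32)
for i, seq in enumerate(prompts):
    input_ids[i, : len(seq)] = seq
    attention_mask[i, : len(seq)] = 1

print(attention_mask)
# [[1 1 1 0 0 0 0 0]
#  [1 1 1 1 1 0 0 0]]
# Generated tokens are appended after position max_len, so for the shorter
# prompt the zeros above remain as "gaps" between prompt and generation.
```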
- max_len = round_up_seq_len(max((len(seq) for seq in batch_prompts), default=0)) + max_len = round_up_seq_len( + max((len(seq) for seq in batch_prompts), default=0) + ) input_ids = pad_batch(batch_prompts, max_len, np.int32) - attention_mask = pad_batch([[1] * len(seq) for seq in batch_prompts], max_len, np.int32) + attention_mask = pad_batch( + [[1] * len(seq) for seq in batch_prompts], max_len, np.int32 + ) # Shard inputs along FSDP axis (already padded to max_batch_size) input_ids, attention_mask, adapter_indices = jax.device_put( - (input_ids, attention_mask, np.array(batch_adapter_indices, dtype=np.int32)), + ( + input_ids, + attention_mask, + np.array(batch_adapter_indices, dtype=np.int32), + ), (sharding_2d, sharding_2d, sharding_1d), ) @@ -748,7 +882,9 @@ def sample( # Only take the actual results, not the padded ones batch_size = batch_end - batch_start all_sequences.extend( - types.GeneratedSequence(stop_reason=stop_reason, tokens=tokens, logprobs=logprobs) + types.GeneratedSequence( + stop_reason=stop_reason, tokens=tokens, logprobs=logprobs + ) for stop_reason, tokens, logprobs in zip( result.stop_reasons[:batch_size], result.generated_ids[:batch_size], @@ -758,13 +894,23 @@ def sample( if needs_prompt_logprobs and result.prompt_logprobs: all_prompt_logprobs.extend(result.prompt_logprobs[:batch_size]) - for request_id, _, start_idx, end_idx, prompt_logprobs_requested in request_batch_slices: + for ( + request_id, + _, + start_idx, + end_idx, + prompt_logprobs_requested, + ) in request_batch_slices: sequences = [all_sequences[i] for i in range(start_idx, end_idx)] # Each of `num_samples` samples in a request share the same prompt; use the first's prompt logprobs prompt_logprobs = ( - all_prompt_logprobs[start_idx] if prompt_logprobs_requested and all_prompt_logprobs else None + all_prompt_logprobs[start_idx] + if prompt_logprobs_requested and all_prompt_logprobs + else None + ) + results[request_id] = types.SampleOutput( + sequences=sequences, prompt_logprobs=prompt_logprobs ) - results[request_id] = types.SampleOutput(sequences=sequences, prompt_logprobs=prompt_logprobs) return results @@ -785,7 +931,9 @@ def _extract_checkpoint_data(self, model_id: str) -> dict: adapter_index = self.models[model_id].adapter_index rank = self.models[model_id].lora_config.rank lora_weights = extract_adapter_state(adapter_index, self.lora_params, rank) - optimizer_state = extract_adapter_state(adapter_index, nnx.state(self.optimizers[model_id]), rank) + optimizer_state = extract_adapter_state( + adapter_index, nnx.state(self.optimizers[model_id]), rank + ) return { "lora_weights": lora_weights, "optimizer_state": optimizer_state, @@ -803,9 +951,14 @@ def _insert_checkpoint_data(self, model_id: str, checkpoint_data: dict) -> None: f"model configured with rank {self.models[model_id].lora_config.rank}" ) - insert_adapter_state(adapter_index, self.lora_params, checkpoint_data["lora_weights"], rank) insert_adapter_state( - adapter_index, nnx.state(self.optimizers[model_id]), checkpoint_data["optimizer_state"], rank + adapter_index, self.lora_params, checkpoint_data["lora_weights"], rank + ) + insert_adapter_state( + adapter_index, + nnx.state(self.optimizers[model_id]), + checkpoint_data["optimizer_state"], + rank, ) def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None: @@ -817,7 +970,9 @@ def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None: ) if checkpoint is None: - raise FileNotFoundError(f"Training checkpoint not found in {checkpoint_path}") + raise 
FileNotFoundError( + f"Training checkpoint not found in {checkpoint_path}" + ) self._insert_checkpoint_data(model_id, checkpoint) logger.info(f"Loaded training checkpoint from {checkpoint_path}") @@ -834,15 +989,21 @@ def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str) -> None: ) logger.info(f"Saved LoRA sampler checkpoint to {output_path}") - def load_sampler_checkpoint(self, model_id: str, checkpoint_id: str, checkpoint_path: AnyPath) -> None: + def load_sampler_checkpoint( + self, model_id: str, checkpoint_id: str, checkpoint_path: AnyPath + ) -> None: """Insert sampler weights from checkpoint file.""" adapter_index = self.models[model_id].adapter_index adapter_config = self.models[model_id].lora_config load_lora_checkpoint(self.model, adapter_config, adapter_index, checkpoint_path) self.models[model_id].loaded_checkpoint_id = checkpoint_id - logger.info(f"Loaded LoRA sampler weights for model {model_id} at adapter index {adapter_index}") + logger.info( + f"Loaded LoRA sampler weights for model {model_id} at adapter index {adapter_index}" + ) - def load_sampler_weights(self, prepared_batch: types.PreparedSampleBatch) -> list[int]: + def load_sampler_weights( + self, prepared_batch: types.PreparedSampleBatch + ) -> list[int]: """Load sampler weights for all requests and return adapter indices array. Ensures all required checkpoints are loaded before sampling. @@ -858,7 +1019,9 @@ def load_sampler_weights(self, prepared_batch: types.PreparedSampleBatch) -> lis loaded_adapters = set() # Track adapters already used in this batch for model_id, checkpoint_id, checkpoint_path in zip( - prepared_batch.all_model_ids, prepared_batch.all_checkpoint_ids, prepared_batch.all_checkpoint_paths + prepared_batch.all_model_ids, + prepared_batch.all_checkpoint_ids, + prepared_batch.all_checkpoint_paths, ): if model_id: # This code path is for sampling from a LoRA adapter @@ -870,10 +1033,16 @@ def load_sampler_weights(self, prepared_batch: types.PreparedSampleBatch) -> lis adapter_indices.append(adapter_index) else: # Need to load from disk - assert adapter_index not in loaded_adapters, "Cannot override already used adapter" + assert ( + adapter_index not in loaded_adapters + ), "Cannot override already used adapter" - logger.info(f"Loading LoRA sampler checkpoint from {checkpoint_path}") - self.load_sampler_checkpoint(model_id, checkpoint_id, AnyPath(checkpoint_path)) + logger.info( + f"Loading LoRA sampler checkpoint from {checkpoint_path}" + ) + self.load_sampler_checkpoint( + model_id, checkpoint_id, AnyPath(checkpoint_path) + ) adapter_indices.append(adapter_index) loaded_adapters.add(adapter_index) @@ -940,9 +1109,20 @@ class JaxBackend(JaxBackendImpl): def __init__(self, base_model: str, config: JaxBackendConfig): if config.coordinator_address is not None: + # Calculate num_processes: if enable_ray, use ray_num_workers + 1 (for coordinator) + num_processes = config.num_processes + if num_processes is None and config.enable_ray: + num_processes = config.ray_num_workers + 1 + logger.info( + f"Auto-calculated num_processes={num_processes} (ray_num_workers={config.ray_num_workers} + 1 coordinator)" + ) + + if num_processes is None: + raise ValueError("num_processes must be set when using multi-node mode") + jax.distributed.initialize( coordinator_address=config.coordinator_address, - num_processes=config.num_processes, + num_processes=num_processes, process_id=0, ) logger.info( @@ -955,33 +1135,48 @@ def __init__(self, base_model: str, config: JaxBackendConfig): def 
_broadcast_and_call(self, method: str, **kwargs): """Broadcast method call to workers and execute locally via super().""" if jax.process_count() > 1: - clean = {k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in kwargs.items()} + clean = { + k: v.model_dump() if isinstance(v, BaseModel) else v + for k, v in kwargs.items() + } _broadcast_command(RpcPayload(method=method, kwargs=clean)) return getattr(super(), method)(**kwargs) def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: - self._broadcast_and_call("create_model", model_id=model_id, lora_config=lora_config) + self._broadcast_and_call( + "create_model", model_id=model_id, lora_config=lora_config + ) def forward_backward(self, prepared_batch: types.PreparedModelPassBatch): - return self._broadcast_and_call("forward_backward", prepared_batch=prepared_batch) + return self._broadcast_and_call( + "forward_backward", prepared_batch=prepared_batch + ) def forward(self, prepared_batch: types.PreparedModelPassBatch): return self._broadcast_and_call("forward", prepared_batch=prepared_batch) def optim_step(self, model_id: str, request_data: types.OptimStepInput): - return self._broadcast_and_call("optim_step", model_id=model_id, request_data=request_data) + return self._broadcast_and_call( + "optim_step", model_id=model_id, request_data=request_data + ) def sample(self, prepared_batch: types.PreparedSampleBatch): return self._broadcast_and_call("sample", prepared_batch=prepared_batch) def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None: - self._broadcast_and_call("save_checkpoint", output_path=output_path, model_id=model_id) + self._broadcast_and_call( + "save_checkpoint", output_path=output_path, model_id=model_id + ) def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None: - self._broadcast_and_call("load_checkpoint", checkpoint_path=checkpoint_path, model_id=model_id) + self._broadcast_and_call( + "load_checkpoint", checkpoint_path=checkpoint_path, model_id=model_id + ) def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str) -> None: - self._broadcast_and_call("save_sampler_checkpoint", output_path=output_path, model_id=model_id) + self._broadcast_and_call( + "save_sampler_checkpoint", output_path=output_path, model_id=model_id + ) def run_worker(coordinator_address: str, num_processes: int, process_id: int): @@ -1010,9 +1205,13 @@ def run_worker(coordinator_address: str, num_processes: int, process_id: int): # Receive INIT payload with base_model and config from coordinator init_payload = _broadcast_command(None) - assert init_payload.method == "__init__", f"Expected __init__, got {init_payload.method}" + assert ( + init_payload.method == "__init__" + ), f"Expected __init__, got {init_payload.method}" config = JaxBackendConfig.model_validate(init_payload.kwargs["config"]) - logger.info(f"Worker received config: base_model={init_payload.kwargs['base_model']}, config={config}") + logger.info( + f"Worker received config: base_model={init_payload.kwargs['base_model']}, config={config}" + ) backend = JaxBackendImpl(init_payload.kwargs["base_model"], config) @@ -1029,10 +1228,283 @@ def run_worker(coordinator_address: str, num_processes: int, process_id: int): # Re-hydrate raw dicts into Pydantic models using type hints hints = get_type_hints(method) - kwargs = {k: TypeAdapter(hints[k]).validate_python(v) if k in hints else v for k, v in payload.kwargs.items()} + kwargs = { + k: TypeAdapter(hints[k]).validate_python(v) if k in hints else v + for k, v in 
payload.kwargs.items() + } method(**kwargs) +# ============================================================================= +# Ray-based process launcher (Solution C) +# ============================================================================= +# +# This uses Ray only to launch worker processes on remote nodes. +# After launching, all communication goes through JAX distributed (PR #810 architecture). +# +# Usage: +# uv run -m tx.tinker.api --base-model Qwen/Qwen3-8B \ +# --backend-config '{"enable_ray": true, "ray_num_workers": 2}' + +try: + import ray + + RAY_AVAILABLE = True +except ImportError: + RAY_AVAILABLE = False + + +def _create_ray_worker_launcher_class(): + """Factory function to create the RayWorkerLauncher class with ray.remote decorator.""" + if not RAY_AVAILABLE: + raise ImportError("Ray is not installed. Install it with: pip install ray") + + @ray.remote + class RayWorkerLauncher: + """Ray Actor that launches and runs a JAX worker process. + + This actor runs on a Ray worker node and executes the run_worker() function, + which enters the JAX distributed worker loop. The actor stays alive as long + as the worker is running. + """ + + def __init__(self): + self.process_id: int | None = None + self.is_running = False + + def start_worker( + self, + coordinator_address: str, + num_processes: int, + process_id: int, + ) -> dict: + """Start the JAX worker process. + + This method blocks and runs the worker loop until shutdown. + """ + self.process_id = process_id + self.is_running = True + + logger.info( + f"[RayWorkerLauncher] Starting JAX worker process_id={process_id}" + ) + + try: + # This will block and run the worker loop + run_worker(coordinator_address, num_processes, process_id) + except Exception as e: + logger.error(f"[RayWorkerLauncher] Worker {process_id} failed: {e}") + self.is_running = False + raise + + self.is_running = False + return {"process_id": process_id, "status": "stopped"} + + def get_status(self) -> dict: + """Get worker status.""" + return { + "process_id": self.process_id, + "is_running": self.is_running, + } + + return RayWorkerLauncher + + +class RayProcessManager: + """Manages launching JAX worker processes on a Ray cluster. + + This class uses Ray to spawn worker processes on different nodes, + then relies on JAX distributed for all subsequent communication. + This is the "Solution C" approach - Ray is just a process launcher. + + Usage: + manager = RayProcessManager(coordinator_address, num_workers, gpus_per_worker=4) + manager.start_workers() + # ... workers are now running and communicating via JAX distributed ... + manager.shutdown() + """ + + def __init__( + self, + coordinator_address: str, + num_workers: int, + gpus_per_worker: int | None = None, + cpus_per_worker: int = 1, + ): + if not RAY_AVAILABLE: + raise ImportError("Ray is not installed. Install it with: pip install ray") + + # Ray should already be initialized by start_ray_workers() + if not ray.is_initialized(): + raise RuntimeError( + "Ray must be initialized before creating RayProcessManager" + ) + + self.coordinator_address = coordinator_address + self.num_workers = num_workers + self.gpus_per_worker = gpus_per_worker + self.cpus_per_worker = cpus_per_worker + self.worker_handles: list = [] + self.worker_futures: list = [] + + def start_workers(self) -> None: + """Start all worker processes on Ray cluster. + + Workers are started with process_id from 1 to num_workers. + Process 0 is reserved for the coordinator (JaxBackend). 
+
+        Resource allocation strategy:
+        - gpus_per_worker: Number of GPUs reserved for each worker actor (typically all GPUs on its node, e.g. 8)
+        - This ensures each worker is placed on a separate node with dedicated GPU access
+        - The Ray scheduler will wait until a node with enough GPUs is available
+
+        Example topology (2 nodes, 8 GPUs each):
+            Node 0 (head): Coordinator (process_id=0) + 8 GPUs
+            Node 1: Worker (process_id=1) + 8 GPUs via RayWorkerLauncher
+        """
+        RayWorkerLauncher = _create_ray_worker_launcher_class()
+
+        num_processes = self.num_workers + 1  # +1 for coordinator (process 0)
+
+        # Build resource options for Ray actor placement
+        # gpus_per_worker should match the number of local GPUs per node
+        actor_options: dict = {"num_cpus": self.cpus_per_worker}
+        if self.gpus_per_worker is not None:
+            actor_options["num_gpus"] = self.gpus_per_worker
+
+        # TODO: Future enhancements for resource allocation:
+        # 1. Auto-detect GPUs per node:
+        #    cluster_resources = ray.cluster_resources()
+        #    total_gpus = cluster_resources.get("GPU", 0)
+        #    gpus_per_node = total_gpus // num_nodes
+        #
+        # 2. Use SPREAD scheduling to distribute workers across nodes:
+        #    actor_options["scheduling_strategy"] = "SPREAD"
+        #
+        # 3. Use placement groups for gang scheduling:
+        #    pg = ray.util.placement_group([{"GPU": gpus_per_worker}] * num_workers)
+        #    ray.get(pg.ready())
+        #    actor_options["placement_group"] = pg
+
+        logger.info(
+            f"Launching {self.num_workers} JAX workers via Ray with resources: "
+            f"num_gpus={self.gpus_per_worker}, num_cpus={self.cpus_per_worker}"
+        )
+
+        for i in range(self.num_workers):
+            process_id = i + 1  # Workers start from 1
+
+            # Create the launcher actor with resource specification
+            launcher = RayWorkerLauncher.options(**actor_options).remote()
+            self.worker_handles.append(launcher)
+
+            # Start the worker (non-blocking, returns a future)
+            future = launcher.start_worker.remote(
+                self.coordinator_address,
+                num_processes,
+                process_id,
+            )
+            self.worker_futures.append(future)
+
+            logger.info(f"Launched worker process_id={process_id} via Ray")
+
+        logger.info(f"All {self.num_workers} Ray worker launchers started")
+
+    def shutdown(self) -> None:
+        """Shut down all worker processes."""
+        logger.info("Shutting down Ray worker launchers...")
+        for handle in self.worker_handles:
+            try:
+                ray.kill(handle)
+            except Exception as e:
+                logger.warning(f"Error killing worker: {e}")
+        self.worker_handles = []
+        self.worker_futures = []
+        logger.info("RayProcessManager shutdown complete")
+
+
+def _get_coordinator_address(config: JaxBackendConfig) -> str:
+    """Get or auto-detect the JAX coordinator address.
+
+    When enable_ray is True and coordinator_address is not set,
+    automatically use the Ray head node's IP with a default port.
+ + Args: + config: JaxBackendConfig with coordinator_address and enable_ray + + Returns: + The coordinator address (host:port) + """ + if config.coordinator_address is not None: + return config.coordinator_address + + if not config.enable_ray: + raise ValueError("coordinator_address must be set when enable_ray is False") + + # Auto-detect coordinator address using Ray head node IP + # This assumes the coordinator (JaxBackend) runs on the Ray head node + head_node_ip = ray.util.get_node_ip_address() + default_port = 7777 # Default JAX coordinator port + coordinator_address = f"{head_node_ip}:{default_port}" + logger.info(f"Auto-detected coordinator_address: {coordinator_address}") + return coordinator_address + + +def start_ray_workers( + config: JaxBackendConfig, +) -> tuple[RayProcessManager | None, str | None]: + """Start Ray worker processes if enable_ray is True. + + This should be called before JaxBackend.__init__ so that workers are + ready to receive the JAX distributed initialization. + + When enable_ray is True but coordinator_address is not set, this function + will auto-detect the coordinator address using the Ray head node's IP. + + Args: + config: JaxBackendConfig with enable_ray and ray_num_workers + + Returns: + Tuple of (RayProcessManager, coordinator_address) if enable_ray is True, + (None, None) otherwise. The coordinator_address is returned so the caller + can use it when initializing JaxBackend. + """ + if not config.enable_ray: + return None, None + + if not RAY_AVAILABLE: + raise ImportError("Ray is not installed. Install it with: pip install ray") + + # Initialize Ray if not already initialized (needed for auto-detect) + if not ray.is_initialized(): + ray.init( + dashboard_host="0.0.0.0", + dashboard_port=8265, + ) + logger.info( + "Ray initialized for process management (dashboard at http://localhost:8265)" + ) + + # Get or auto-detect coordinator address + coordinator_address = _get_coordinator_address(config) + + manager = RayProcessManager( + coordinator_address=coordinator_address, + num_workers=config.ray_num_workers, + gpus_per_worker=config.ray_gpus_per_worker, + cpus_per_worker=config.ray_cpus_per_worker, + ) + manager.start_workers() + + # Give workers time to initialize JAX distributed + import time + + time.sleep(2) + + return manager, coordinator_address + + def main(): """Entry point for running as a worker process.""" import argparse diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py index b7f4b2917..d402fbfd9 100644 --- a/skyrl-tx/tx/tinker/engine.py +++ b/skyrl-tx/tx/tinker/engine.py @@ -1,6 +1,8 @@ """Background engine for processing training requests.""" import argparse +import signal +import sys import time from contextlib import contextmanager from datetime import datetime, timedelta, timezone @@ -11,10 +13,17 @@ from pydantic import BaseModel from sqlmodel import create_engine, Session, select, update, func -from tx.tinker.db_models import FutureDB, RequestStatus, CheckpointDB, CheckpointStatus, ModelDB, SessionDB +from tx.tinker.db_models import ( + FutureDB, + RequestStatus, + CheckpointDB, + CheckpointStatus, + ModelDB, + SessionDB, +) from tx.tinker import types from tx.tinker.config import EngineConfig, add_model -from tx.tinker.backends.jax import JaxBackend, JaxBackendConfig +from tx.tinker.backends.jax import JaxBackend, JaxBackendConfig, start_ray_workers from tx.tinker.backends.utils import log_timing from tx.tinker.loss_fns import LOSS_TYPES from tx.utils.log import logger @@ -43,17 +52,24 @@ def 
prepare_sample_batch( all_checkpoint_paths = [] request_batch_slices = [] - needs_prompt_logprobs = any(request_data.prompt_logprobs for (_, request_data) in requests.values()) + needs_prompt_logprobs = any( + request_data.prompt_logprobs for (_, request_data) in requests.values() + ) for request_id, (model_id, request_data) in requests.items(): request_start = len(all_prompts) # Expand requests for num_samples - prompt_tokens = [token for chunk in request_data.prompt.chunks for token in chunk.tokens] + prompt_tokens = [ + token for chunk in request_data.prompt.chunks for token in chunk.tokens + ] checkpoint_path = "" if model_id and request_data.checkpoint_id and checkpoints_base: checkpoint_path = str( - checkpoints_base / model_id / "sampler_weights" / f"{request_data.checkpoint_id}.tar.gz" + checkpoints_base + / model_id + / "sampler_weights" + / f"{request_data.checkpoint_id}.tar.gz" ) for _ in range(request_data.num_samples): all_prompts.append(prompt_tokens) @@ -63,7 +79,13 @@ def prepare_sample_batch( all_checkpoint_paths.append(checkpoint_path) request_batch_slices.append( - (request_id, model_id, request_start, len(all_prompts), request_data.prompt_logprobs) + ( + request_id, + model_id, + request_start, + len(all_prompts), + request_data.prompt_logprobs, + ) ) return types.PreparedSampleBatch( @@ -115,7 +137,9 @@ def prepare_model_pass_batch( all_model_ids.append(model_id) all_loss_fn_types.append(loss_fn_type) - request_batch_slices.append((request_id, model_id, request_start, len(all_input_ids))) + request_batch_slices.append( + (request_id, model_id, request_start, len(all_input_ids)) + ) return types.PreparedModelPassBatch( all_input_ids=all_input_ids, @@ -188,16 +212,43 @@ def __init__( # Initialize the backend (handles model state, computation, and adapter management) if config.backend not in BACKENDS: - raise ValueError(f"Unknown backend: {config.backend}. Available backends: {list(BACKENDS.keys())}") + raise ValueError( + f"Unknown backend: {config.backend}. 
Available backends: {list(BACKENDS.keys())}" + ) backend_class, backend_config_class = BACKENDS[config.backend] backend_config = backend_config_class(**config.backend_config) + + # If enable_ray is True, start Ray worker processes before initializing the backend + # The workers will connect via JAX distributed when JaxBackend initializes + self._ray_process_manager = None + if hasattr(backend_config, "enable_ray") and backend_config.enable_ray: + logger.info("Starting Ray worker processes for multi-node support...") + self._ray_process_manager, coordinator_address = start_ray_workers( + backend_config + ) + + # Update backend_config with auto-detected coordinator_address if it was None + if ( + backend_config.coordinator_address is None + and coordinator_address is not None + ): + # Create a new config with the auto-detected coordinator_address + backend_config = backend_config.model_copy( + update={"coordinator_address": coordinator_address} + ) + logger.info( + f"Updated backend_config with auto-detected coordinator_address: {coordinator_address}" + ) + self.backend = backend_class(config.base_model, backend_config) # Track last cleanup time for periodic stale session cleanup self._last_cleanup_time: float = time.time() - logger.info(f"Initialized TinkerEngine with backend={type(self.backend).__name__}") + logger.info( + f"Initialized TinkerEngine with backend={type(self.backend).__name__}" + ) @property def metrics(self) -> types.EngineMetrics: @@ -205,14 +256,18 @@ def metrics(self) -> types.EngineMetrics: return self.backend.metrics @contextmanager - def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoint_type: types.CheckpointType): + def _checkpoint_status_context( + self, model_id: str, checkpoint_id: str, checkpoint_type: types.CheckpointType + ): """Context manager to handle checkpoint DB status updates. Fetches the checkpoint entry, yields it, and updates its status to COMPLETED or FAILED based on whether an exception occurred. 
""" with Session(self.db_engine) as session: - checkpoint_db = session.get(CheckpointDB, (model_id, checkpoint_id, checkpoint_type)) + checkpoint_db = session.get( + CheckpointDB, (model_id, checkpoint_id, checkpoint_type) + ) if checkpoint_db is None: raise ValueError( f"Checkpoint entry not found for model '{model_id}', checkpoint '{checkpoint_id}', type '{checkpoint_type}'" @@ -222,7 +277,9 @@ def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoi yield checkpoint_db checkpoint_db.status = CheckpointStatus.COMPLETED except Exception as e: - logger.exception(f"Error saving checkpoint for model {model_id}, checkpoint {checkpoint_id}: {e}") + logger.exception( + f"Error saving checkpoint for model {model_id}, checkpoint {checkpoint_id}: {e}" + ) checkpoint_db.status = CheckpointStatus.FAILED checkpoint_db.error_message = str(e) raise @@ -268,14 +325,23 @@ def find_batchable_model_passes( ops = session.exec(query).all() # Filter: only include ops that come before their model's barrier - batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]] + batchable = [ + op + for op in ops + if op.model_id not in barriers or op.request_id < barriers[op.model_id] + ] return { - str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data)) + str(f.request_id): ( + f.model_id, + types.ForwardBackwardInput.model_validate(f.request_data), + ) for f in batchable } - def find_batchable_sample(self, session: Session) -> dict[str, tuple[str, types.SampleInput]]: + def find_batchable_sample( + self, session: Session + ) -> dict[str, tuple[str, types.SampleInput]]: """Find all sample ops that can be safely batched together. Returns sample operations ensuring that each model_id has only one checkpoint_id @@ -300,22 +366,39 @@ def find_batchable_sample(self, session: Session) -> dict[str, tuple[str, types. sample_ops = session.exec(sample_query).all() batchable = [] - model_checkpoints = {} # Map from model_id to checkpoint_id of first request to that model + model_checkpoints = ( + {} + ) # Map from model_id to checkpoint_id of first request to that model for op in sample_ops: checkpoint_id = op.request_data["checkpoint_id"] # Base model requests (empty checkpoint_id) are always compatible, otherwise only # take only requests with one checkpoint_id for a given model_id - if checkpoint_id == "" or model_checkpoints.setdefault(op.model_id, checkpoint_id) == checkpoint_id: + if ( + checkpoint_id == "" + or model_checkpoints.setdefault(op.model_id, checkpoint_id) + == checkpoint_id + ): batchable.append(op) # TODO: This leaks the abstraction by accessing backend-specific config. # We should find a better way to handle this going forward. 
- if isinstance(self.backend, JaxBackend) and self.backend.config.sample_max_num_sequences > 0: + if ( + isinstance(self.backend, JaxBackend) + and self.backend.config.sample_max_num_sequences > 0 + ): batchable = batchable[: self.backend.config.sample_max_num_sequences] - return {str(f.request_id): (f.model_id, types.SampleInput.model_validate(f.request_data)) for f in batchable} + return { + str(f.request_id): ( + f.model_id, + types.SampleInput.model_validate(f.request_data), + ) + for f in batchable + } - def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.RequestType, dict]]: + def find_single_requests( + self, session: Session + ) -> dict[str, tuple[str, types.RequestType, dict]]: """Find all requests that need to be processed individually (not batchable). Args: @@ -335,9 +418,14 @@ def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.R ) other_futures = session.exec(statement).all() - return {str(f.request_id): (f.model_id, f.request_type, f.request_data) for f in other_futures} + return { + str(f.request_id): (f.model_id, f.request_type, f.request_data) + for f in other_futures + } - def process_create_model(self, model_id: str, request_data: types.CreateModelInput) -> types.CreateModelOutput: + def process_create_model( + self, model_id: str, request_data: types.CreateModelInput + ) -> types.CreateModelOutput: """Create and initialize a model.""" # Create model in backend (allocates adapter_index, creates optimizer, and configures adapter) self.backend.create_model(model_id, request_data.lora_config) @@ -350,16 +438,24 @@ def process_create_model(self, model_id: str, request_data: types.CreateModelInp lora_config=request_data.lora_config, ) - def process_unload_model(self, model_id: str, request_data: types.UnloadModelInput) -> types.UnloadModelOutput: + def process_unload_model( + self, model_id: str, request_data: types.UnloadModelInput + ) -> types.UnloadModelOutput: """Unload a model and free all resources.""" if not self.backend.has_model(model_id): - logger.warning(f"Ignoring unload request for model {model_id} that is not loaded.") + logger.warning( + f"Ignoring unload request for model {model_id} that is not loaded." 
+ ) else: self.backend.delete_model(model_id) # Update model status in DB with Session(self.db_engine) as session: - _ = session.exec(update(ModelDB).where(ModelDB.model_id == model_id).values(status="unloaded")) + _ = session.exec( + update(ModelDB) + .where(ModelDB.model_id == model_id) + .values(status="unloaded") + ) session.commit() logger.info(f"Unloaded model {model_id}") @@ -372,7 +468,9 @@ def cleanup_stale_sessions(self) -> int: Returns: Number of models unloaded """ - cutoff = datetime.now(timezone.utc) - timedelta(seconds=self.config.session_timeout_sec) + cutoff = datetime.now(timezone.utc) - timedelta( + seconds=self.config.session_timeout_sec + ) unloaded_count = 0 with Session(self.db_engine) as session: @@ -404,9 +502,13 @@ def cleanup_stale_sessions(self) -> int: self.backend.delete_model(model.model_id) model.status = "unloaded" unloaded_count += 1 - logger.info(f"Auto-unloaded stale model {model.model_id} from session {model.session_id}") + logger.info( + f"Auto-unloaded stale model {model.model_id} from session {model.session_id}" + ) except Exception as e: - logger.error(f"Failed to auto-unload model {model.model_id}: {e}") + logger.error( + f"Failed to auto-unload model {model.model_id}: {e}" + ) sessions_with_failed_unloads.add(model.session_id) else: # Model not in backend but status not unloaded - fix DB state @@ -415,48 +517,66 @@ def cleanup_stale_sessions(self) -> int: for sess in stale_sessions: if sess.session_id not in sessions_with_failed_unloads: sess.status = "expired" - logger.info(f"Expired stale session {sess.session_id} (last heartbeat: {sess.last_heartbeat_at})") + logger.info( + f"Expired stale session {sess.session_id} (last heartbeat: {sess.last_heartbeat_at})" + ) session.commit() return unloaded_count - def process_optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput: + def process_optim_step( + self, model_id: str, request_data: types.OptimStepInput + ) -> types.OptimStepOutput: """Process an optim_step request and apply accumulated gradients.""" if not self.backend.has_model(model_id): raise ValueError(f"Model {model_id} not loaded") return self.backend.optim_step(model_id, request_data) - def process_forward_backward(self, requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict: + def process_forward_backward( + self, requests: dict[str, tuple[str, types.ForwardBackwardInput]] + ) -> dict: """Run forward and backward pass on a batch of requests.""" prepared = prepare_model_pass_batch(requests) return self.backend.forward_backward(prepared) - def process_forward(self, requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict: + def process_forward( + self, requests: dict[str, tuple[str, types.ForwardBackwardInput]] + ) -> dict: """Run forward-only pass on a batch of requests.""" prepared = prepare_model_pass_batch(requests) return self.backend.forward(prepared) - def process_sample(self, requests: dict[str, tuple[str, types.SampleInput]]) -> dict: + def process_sample( + self, requests: dict[str, tuple[str, types.SampleInput]] + ) -> dict: """Generate samples for a batch of requests.""" prepared = prepare_sample_batch(requests, self.config.checkpoints_base) return self.backend.sample(prepared) - def process_load_weights(self, model_id: str, request_data: types.LoadWeightsInput) -> types.LoadWeightsOutput: + def process_load_weights( + self, model_id: str, request_data: types.LoadWeightsInput + ) -> types.LoadWeightsOutput: """Loads a clean, trimmed training checkpoint.""" if not 
self.backend.has_model(model_id): - raise ValueError("Model not loaded. Create the model before loading a checkpoint.") + raise ValueError( + "Model not loaded. Create the model before loading a checkpoint." + ) checkpoint_path = ( - self.config.checkpoints_base / request_data.source_model_id / f"{request_data.checkpoint_id}.tar.gz" + self.config.checkpoints_base + / request_data.source_model_id + / f"{request_data.checkpoint_id}.tar.gz" ) self.backend.load_checkpoint(checkpoint_path, model_id) return types.LoadWeightsOutput(type="load_weights") - def process_save_weights(self, model_id: str, request_data: types.SaveWeightsInput) -> types.SaveWeightsOutput: + def process_save_weights( + self, model_id: str, request_data: types.SaveWeightsInput + ) -> types.SaveWeightsOutput: """ Saves a clean training checkpoint by converting the trimmed NNX graph to a pure dictionary before serialization, following official Flax docs. @@ -465,11 +585,17 @@ def process_save_weights(self, model_id: str, request_data: types.SaveWeightsInp raise ValueError(f"Model {model_id} not loaded") checkpoint_id = request_data.path - output_path = self.config.checkpoints_base / model_id / f"{checkpoint_id}.tar.gz" + output_path = ( + self.config.checkpoints_base / model_id / f"{checkpoint_id}.tar.gz" + ) - with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.TRAINING): + with self._checkpoint_status_context( + model_id, checkpoint_id, types.CheckpointType.TRAINING + ): self.backend.save_checkpoint(output_path, model_id) - logger.info(f"Saved trimmed training checkpoint for model {model_id} to {output_path}") + logger.info( + f"Saved trimmed training checkpoint for model {model_id} to {output_path}" + ) return types.SaveWeightsOutput( path=f"tinker://{model_id}/weights/{checkpoint_id}", @@ -485,14 +611,26 @@ def process_save_weights_for_sampler( # Make sure the user cannot store checkpoints in places like ../../ checkpoint_id = Path(request_data.path).name - output_path = self.config.checkpoints_base / model_id / "sampler_weights" / f"{checkpoint_id}.tar.gz" + output_path = ( + self.config.checkpoints_base + / model_id + / "sampler_weights" + / f"{checkpoint_id}.tar.gz" + ) - with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.SAMPLER): + with self._checkpoint_status_context( + model_id, checkpoint_id, types.CheckpointType.SAMPLER + ): self.backend.save_sampler_checkpoint(output_path, model_id) - logger.info(f"Saved LoRA adapter weights for model {model_id} to {output_path}") + logger.info( + f"Saved LoRA adapter weights for model {model_id} to {output_path}" + ) # Return path=None when using sampling_session_seq_id and seq_id (SDK expects this) - if request_data.sampling_session_seq_id is not None and request_data.seq_id is not None: + if ( + request_data.sampling_session_seq_id is not None + and request_data.seq_id is not None + ): output_path_str = None else: output_path_str = f"tinker://{model_id}/{checkpoint_id}" @@ -514,7 +652,11 @@ def _complete_futures(self, results: dict[str, BaseModel]): { "request_id": int(request_id), "result_data": result.model_dump(), - "status": RequestStatus.FAILED if isinstance(result, types.ErrorResponse) else RequestStatus.COMPLETED, + "status": ( + RequestStatus.FAILED + if isinstance(result, types.ErrorResponse) + else RequestStatus.COMPLETED + ), "completed_at": completed_at, } for request_id, result in results.items() @@ -524,26 +666,41 @@ def _complete_futures(self, results: dict[str, BaseModel]): 
session.execute(update(FutureDB), params) session.commit() - def process_single_request(self, request_type: types.RequestType, model_id: str, request_data: dict) -> BaseModel: + def process_single_request( + self, request_type: types.RequestType, model_id: str, request_data: dict + ) -> BaseModel: match request_type: case types.RequestType.CREATE_MODEL: - return self.process_create_model(model_id, types.CreateModelInput.model_validate(request_data)) + return self.process_create_model( + model_id, types.CreateModelInput.model_validate(request_data) + ) case types.RequestType.OPTIM_STEP: - return self.process_optim_step(model_id, types.OptimStepInput.model_validate(request_data)) + return self.process_optim_step( + model_id, types.OptimStepInput.model_validate(request_data) + ) case types.RequestType.SAVE_WEIGHTS_FOR_SAMPLER: return self.process_save_weights_for_sampler( - model_id, types.SaveWeightsForSamplerInput.model_validate(request_data) + model_id, + types.SaveWeightsForSamplerInput.model_validate(request_data), ) case types.RequestType.SAVE_WEIGHTS: - return self.process_save_weights(model_id, types.SaveWeightsInput.model_validate(request_data)) + return self.process_save_weights( + model_id, types.SaveWeightsInput.model_validate(request_data) + ) case types.RequestType.LOAD_WEIGHTS: - return self.process_load_weights(model_id, types.LoadWeightsInput.model_validate(request_data)) + return self.process_load_weights( + model_id, types.LoadWeightsInput.model_validate(request_data) + ) case types.RequestType.UNLOAD_MODEL: - return self.process_unload_model(model_id, types.UnloadModelInput.model_validate(request_data)) + return self.process_unload_model( + model_id, types.UnloadModelInput.model_validate(request_data) + ) case _: raise ValueError(f"Unknown request type: {request_type}") - def process_single_requests(self, requests: dict[str, tuple[str, types.RequestType, dict]]): + def process_single_requests( + self, requests: dict[str, tuple[str, types.RequestType, dict]] + ): """Process a collection of single (non-batchable) requests. 
Args: @@ -555,7 +712,9 @@ def process_single_requests(self, requests: dict[str, tuple[str, types.RequestTy for request_id, (model_id, request_type, request_data) in requests.items(): with log_timing(f"process_single_request({request_type.value})"): try: - result = self.process_single_request(request_type, model_id, request_data) + result = self.process_single_request( + request_type, model_id, request_data + ) except Exception as e: logger.exception(f"Error processing request {request_id}: {e}") result = types.ErrorResponse(error=str(e), status="failed") @@ -587,7 +746,10 @@ def process_batch_requests( results = error_results except Exception as e: logger.exception(f"Error processing batch: {e}") - results = {request_id: types.ErrorResponse(error=str(e), status="failed") for request_id in requests} + results = { + request_id: types.ErrorResponse(error=str(e), status="failed") + for request_id in requests + } self._complete_futures(results) def process_pending_requests(self): @@ -599,29 +761,52 @@ def process_pending_requests(self): forward_backward_requests = self.find_batchable_model_passes( session, types.RequestType.FORWARD_BACKWARD ) - forward_requests = self.find_batchable_model_passes(session, types.RequestType.FORWARD) + forward_requests = self.find_batchable_model_passes( + session, types.RequestType.FORWARD + ) # Find pending sample requests that can be batched sample_requests = self.find_batchable_sample(session) # Get other pending requests (non forward_backward and non sampling) other_requests = self.find_single_requests(session) # Process batches outside of session context - self.process_batch_requests(forward_backward_requests, self.process_forward_backward, "forward_backward") - self.process_batch_requests(forward_requests, self.process_forward, "forward") + self.process_batch_requests( + forward_backward_requests, + self.process_forward_backward, + "forward_backward", + ) + self.process_batch_requests( + forward_requests, self.process_forward, "forward" + ) self.process_batch_requests(sample_requests, self.process_sample, "sample") # Process other request types individually (in the future we can also batch independent optim_steps) self.process_single_requests(other_requests) # Periodically cleanup stale sessions (disabled if either config is negative) - cleanup_enabled = self.config.session_cleanup_interval_sec >= 0 and self.config.session_timeout_sec >= 0 - if cleanup_enabled and time.time() - self._last_cleanup_time > self.config.session_cleanup_interval_sec: + cleanup_enabled = ( + self.config.session_cleanup_interval_sec >= 0 + and self.config.session_timeout_sec >= 0 + ) + if ( + cleanup_enabled + and time.time() - self._last_cleanup_time + > self.config.session_cleanup_interval_sec + ): _ = self.cleanup_stale_sessions() self._last_cleanup_time = time.time() # Poll every 100ms time.sleep(0.1) + def shutdown(self): + """Gracefully shutdown the engine and release resources.""" + logger.info("Shutting down TinkerEngine...") + if self._ray_process_manager is not None: + self._ray_process_manager.shutdown() + self._ray_process_manager = None + logger.info("TinkerEngine shutdown complete") + def run(self): """Entry point to start the engine.""" logger.info("Starting background engine...") @@ -631,7 +816,9 @@ def run(self): def main(): """Entry point for the background engine.""" # Create argument parser and add Pydantic model fields - parser = argparse.ArgumentParser(description="SkyRL tx tinker engine for processing requests") + parser = argparse.ArgumentParser( + 
description="SkyRL tx tinker engine for processing requests" + ) add_model(parser, EngineConfig) # Parse command-line arguments @@ -640,8 +827,21 @@ def main(): # Create EngineConfig from parsed arguments config = EngineConfig.model_validate(vars(args)) - # Initialize and run the engine - TinkerEngine(config).run() + # Initialize the engine + engine = TinkerEngine(config) + + # Register signal handlers for graceful shutdown + def handle_shutdown_signal(signum, frame): + signal_name = signal.Signals(signum).name + logger.info(f"Received {signal_name}, initiating graceful shutdown...") + engine.shutdown() + sys.exit(0) + + signal.signal(signal.SIGTERM, handle_shutdown_signal) + signal.signal(signal.SIGINT, handle_shutdown_signal) + + # Run the engine + engine.run() if __name__ == "__main__":
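For reference, below is a minimal sketch (not part of the patch) of how the pieces added in this diff fit together when enable_ray is set: start_ray_workers launches the RayWorkerLauncher actors, the auto-detected coordinator_address is copied back into the config, and JaxBackend then initializes jax.distributed with num_processes = ray_num_workers + 1. The base model name and resource numbers are illustrative, a running Ray cluster is assumed, and JaxBackendConfig may require additional fields not shown here.

# Illustrative sketch: wiring start_ray_workers + JaxBackend by hand.
# Assumes a Ray cluster is already running and reachable from this process.
from tx.tinker.backends.jax import JaxBackend, JaxBackendConfig, start_ray_workers

config = JaxBackendConfig(
    enable_ray=True,
    ray_num_workers=2,      # two JAX worker processes launched as Ray actors (process_id 1 and 2)
    ray_gpus_per_worker=8,  # illustrative: reserve a full 8-GPU node per worker
    ray_cpus_per_worker=1,
)

# Launch workers first so they are ready to join jax.distributed.initialize();
# the coordinator address is auto-detected from the Ray head node when not set.
manager, coordinator_address = start_ray_workers(config)
config = config.model_copy(update={"coordinator_address": coordinator_address})

# Coordinator runs as process_id=0; num_processes is auto-calculated as ray_num_workers + 1.
backend = JaxBackend("Qwen/Qwen3-8B", config)
try:
    ...  # process requests via the backend
finally:
    manager.shutdown()  # ray.kill() each RayWorkerLauncher actor

In the engine itself this wiring happens inside TinkerEngine.__init__ (guarded by backend_config.enable_ray) and is torn down by TinkerEngine.shutdown(), which the new SIGTERM/SIGINT handlers in main() invoke; in practice the flow is driven via the CLI shown earlier in the diff, i.e. passing --backend-config '{"enable_ray": true, "ray_num_workers": 2}' rather than constructing these objects by hand.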