diff --git a/dev/math-vista/math-vista.ipynb b/dev/math-vista/math-vista.ipynb
index a33ba5614..93965f89a 100644
--- a/dev/math-vista/math-vista.ipynb
+++ b/dev/math-vista/math-vista.ipynb
@@ -128,7 +128,7 @@
     "        }\n",
     "    ]\n",
     "    chat_completion = await client.chat.completions.create(\n",
-    "        model=model.name, messages=trajectory.messages()\n",
+    "        model=model.get_inference_name(), messages=trajectory.messages()\n",
     "    )\n",
     "    choice = chat_completion.choices[0]\n",
     "    trajectory.messages_and_choices.append(choice)\n",
diff --git a/dev/math-vista/math-vista.py b/dev/math-vista/math-vista.py
index 455bf764c..68694ccd7 100644
--- a/dev/math-vista/math-vista.py
+++ b/dev/math-vista/math-vista.py
@@ -61,7 +61,7 @@ async def rollout(scenario: Scenario) -> art.Trajectory:
     ]

     chat_completion = await client.chat.completions.create(
-        model=model.name, messages=trajectory.messages()
+        model=model.get_inference_name(), messages=trajectory.messages()
     )
     choice = chat_completion.choices[0]
     trajectory.messages_and_choices.append(choice)
diff --git a/dev/new_models/benchmark_inference.py b/dev/new_models/benchmark_inference.py
index 8d1038897..2366f2859 100644
--- a/dev/new_models/benchmark_inference.py
+++ b/dev/new_models/benchmark_inference.py
@@ -77,7 +77,13 @@ async def main():
         iteration_start = time.perf_counter()
         # launch concurrent requests and time each individually
         tasks = [
-            timed_request(client, model.name, prompt, max_tokens, temperature)
+            timed_request(
+                client,
+                model.get_inference_name(),
+                prompt,
+                max_tokens,
+                temperature,
+            )
             for _ in range(concurrency)
         ]
         # Wait for all responses
diff --git a/dev/new_models/gemma3.py b/dev/new_models/gemma3.py
index 6cbe57705..e95d984c6 100644
--- a/dev/new_models/gemma3.py
+++ b/dev/new_models/gemma3.py
@@ -19,7 +19,7 @@ async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
     client = model.openai_client()
     chat_completion = await client.chat.completions.create(
         messages=messages,
-        model=model.name,
+        model=model.get_inference_name(),
         max_tokens=100,
         timeout=100,
     )
diff --git a/dev/new_models/qwen3_try.ipynb b/dev/new_models/qwen3_try.ipynb
index fc2123948..00cd9d180 100644
--- a/dev/new_models/qwen3_try.ipynb
+++ b/dev/new_models/qwen3_try.ipynb
@@ -32,7 +32,7 @@
     "    client = model.openai_client()\n",
     "    chat_completion = await client.chat.completions.create(\n",
     "        messages=messages,\n",
-    "        model=model.name,\n",
+    "        model=model.get_inference_name(),\n",
     "        max_tokens=100,\n",
     "        timeout=100,\n",
     "        extra_body={\"chat_template_kwargs\": {\"enable_thinking\": False}},\n",
diff --git a/dev/new_models/qwen3_try.py b/dev/new_models/qwen3_try.py
index 9ff33805f..c3c43b1e1 100644
--- a/dev/new_models/qwen3_try.py
+++ b/dev/new_models/qwen3_try.py
@@ -19,7 +19,7 @@ async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
     client = model.openai_client()
     chat_completion = await client.chat.completions.create(
         messages=messages,
-        model=model.name,
+        model=model.get_inference_name(),
         max_tokens=100,
         timeout=100,
         extra_body={"chat_template_kwargs": {"enable_thinking": False}},
diff --git a/dev/yes-no-maybe-megatron.py b/dev/yes-no-maybe-megatron.py
index ea7a669fa..ad1e3c1e9 100644
--- a/dev/yes-no-maybe-megatron.py
+++ b/dev/yes-no-maybe-megatron.py
@@ -64,7 +64,8 @@ async def main():
     train_groups = await art.gather_trajectory_groups(
         (
             art.TrajectoryGroup(
-                rollout(openai_client, model.name, prompt) for _ in range(32)
+                rollout(openai_client, model.get_inference_name(), prompt)
+                for _ in range(32)
             )
             for prompt in prompts
         )
diff --git a/dev/yes-no-maybe-vision/train.ipynb b/dev/yes-no-maybe-vision/train.ipynb
index 939c47aca..a32878b3a 100644
--- a/dev/yes-no-maybe-vision/train.ipynb
+++ b/dev/yes-no-maybe-vision/train.ipynb
@@ -60,7 +60,7 @@
     "        }\n",
     "    ]\n",
     "    chat_completion = await client.chat.completions.create(\n",
-    "        model=model.name, messages=messages, max_tokens=100, timeout=100\n",
+    "        model=model.get_inference_name(), messages=messages, max_tokens=100, timeout=100\n",
     "    )\n",
     "    choice = chat_completion.choices[0]\n",
     "    content = choice.message.content\n",
diff --git a/dev/yes-no-maybe.ipynb b/dev/yes-no-maybe.ipynb
index b3113db5e..444106d1e 100644
--- a/dev/yes-no-maybe.ipynb
+++ b/dev/yes-no-maybe.ipynb
@@ -65,7 +65,7 @@
     "        }\n",
     "    ]\n",
     "    chat_completion = await client.chat.completions.create(\n",
-    "        messages=messages, model=model.name, max_tokens=100, timeout=100\n",
+    "        messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100\n",
    "    )\n",
     "    choice = chat_completion.choices[0]\n",
     "    content = choice.message.content\n",
diff --git a/dev/yes-no-maybe.py b/dev/yes-no-maybe.py
index a396b2194..e32215f47 100644
--- a/dev/yes-no-maybe.py
+++ b/dev/yes-no-maybe.py
@@ -17,7 +17,7 @@ async def rollout(client: openai.AsyncOpenAI, prompt: str) -> art.Trajectory:
         }
     ]
     chat_completion = await client.chat.completions.create(
-        messages=messages, model=model.name, max_tokens=100, timeout=100
+        messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100
     )
     choice = chat_completion.choices[0]
     content = choice.message.content
diff --git a/docs/fundamentals/art-client.mdx b/docs/fundamentals/art-client.mdx
index b99ee0514..79198c714 100644
--- a/docs/fundamentals/art-client.mdx
+++ b/docs/fundamentals/art-client.mdx
@@ -104,7 +104,7 @@ messages: art.Messages = [
 ]
 chat_completion = await openai_client.chat.completions.create(
     messages=messages,
-    model=model.name,
+    model=model.get_inference_name(),
     max_tokens=100,
     timeout=100,
     tools=[...]
@@ -157,7 +157,7 @@ async def rollout(model: art.Model, scenario: Scenario) -> art.Trajectory:

     # generate a completion using the client
     chat_completion = await openai_client.chat.completions.create(
-        messages=trajectory.messages(), model=model.name
+        messages=trajectory.messages(), model=model.get_inference_name()
     )
     choice = chat_completion.choices[0]
     trajectory.messages_and_choices.append(choice)
diff --git a/docs/integrations/langgraph-integration.mdx b/docs/integrations/langgraph-integration.mdx
index 5ba5214e0..a9f49d742 100644
--- a/docs/integrations/langgraph-integration.mdx
+++ b/docs/integrations/langgraph-integration.mdx
@@ -89,7 +89,7 @@ def return_final_answer_tool(answer: str, reference_message_ids: list[str]) -> d
 @weave.op
 async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTrajectory:
     # Initialize chat model with temperature
-    chat_model = init_chat_model(model.name, temperature=1.0)
+    chat_model = init_chat_model(model.get_inference_name(), temperature=1.0)

     # Define available tools
     tools = [search_inbox_tool, read_email_tool, return_final_answer_tool]
@@ -394,7 +394,7 @@ async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTra
         return final_answer.model_dump()

     tools = [search_inbox_tool, read_email_tool, return_final_answer_tool]
-    chat_model = init_chat_model(model.name, temperature=1.0)
+    chat_model = init_chat_model(model.get_inference_name(), temperature=1.0)
     react_agent = create_react_agent(chat_model, tools)

     try:
@@ -522,7 +522,7 @@ To use this example, simply replace the mock email functions (`search_emails`, `

 **Empty trajectories or no training data captured:**

-- Ensure you're using `init_chat_model(model.name)` in your rollout function
+- Ensure you're using `init_chat_model(model.get_inference_name())` in your rollout function
 - Verify your rollout function actually executes the agent and makes LLM calls
 - Check that `init_chat_model()` is called before creating your LangGraph agent
diff --git a/examples/2048/rollout.py b/examples/2048/rollout.py
index 1cb2828ef..cc7c66859 100644
--- a/examples/2048/rollout.py
+++ b/examples/2048/rollout.py
@@ -57,7 +57,7 @@ async def get_completion():
         return await client.chat.completions.create(
             max_completion_tokens=128,
             messages=trajectory.messages(),
-            model=model.name,
+            model=model.get_inference_name(),
         )

     try:
diff --git a/examples/just-the-facts/just_the_facts/rollout.py b/examples/just-the-facts/just_the_facts/rollout.py
index df01a931e..bf7528762 100644
--- a/examples/just-the-facts/just_the_facts/rollout.py
+++ b/examples/just-the-facts/just_the_facts/rollout.py
@@ -52,7 +52,7 @@ async def rollout(model: art.Model, scenario: FactsScenario) -> art.Trajectory:
     )

     completion = await client.chat.completions.create(
-        model=model.name if model.trainable else model.inference_model_name,
+        model=model.get_inference_name(),
         messages=traj.messages(),
         max_completion_tokens=500,
         extra_body={"chat_template_kwargs": {"enable_thinking": False}},
diff --git a/examples/mcp-rl/mcp_rl/rollout.py b/examples/mcp-rl/mcp_rl/rollout.py
index c7dc35923..8f8a2d258 100644
--- a/examples/mcp-rl/mcp_rl/rollout.py
+++ b/examples/mcp-rl/mcp_rl/rollout.py
@@ -150,9 +150,7 @@ async def rollout(
         )

         response = await client.chat.completions.create(
-            model=model.inference_model_name
-            if model.inference_model_name
-            else model.name,
+            model=model.get_inference_name(),
             messages=traj.messages(),
             temperature=1.0,
             tools=tool_schemas,
diff --git a/examples/prisoners-dilemma.ipynb b/examples/prisoners-dilemma.ipynb
index b04f7ad3d..05921df7a 100644
--- a/examples/prisoners-dilemma.ipynb
+++ b/examples/prisoners-dilemma.ipynb
@@ -50,7 +50,7 @@
     "\n",
     "\n",
     "async def rollout_game(\n",
-    "    models: tuple[str, str] = (model.name, model.name),\n",
+    "    models: tuple[str, str] = (model.get_inference_name(), model.get_inference_name()),\n",
     ") -> tuple[art.Trajectory, art.Trajectory]:\n",
     "    messages: tuple[art.Messages, art.Messages] = (\n",
     "        [{\"role\": \"user\", \"content\": prompt}],\n",
@@ -122,11 +122,19 @@
     "    # Simultaneously rollout self-play games, and games versus the base model.\n",
     "    self_play_trajectories, base_play_trajectories = await asyncio.gather(\n",
     "        art.gather_trajectories(\n",
-    "            (rollout_game(models=(model.name, model.name)) for _ in range(8)),\n",
+    "            (\n",
+    "                rollout_game(\n",
+    "                    models=(model.get_inference_name(), model.get_inference_name())\n",
+    "                )\n",
+    "                for _ in range(8)\n",
+    "            ),\n",
     "            pbar_desc=\"versus-self\",\n",
     "        ),\n",
     "        art.gather_trajectories(\n",
-    "            (rollout_game(models=(model.name, BASE_MODEL)) for _ in range(8)),\n",
+    "            (\n",
+    "                rollout_game(models=(model.get_inference_name(), BASE_MODEL))\n",
+    "                for _ in range(8)\n",
+    "            ),\n",
     "            pbar_desc=\"versus-base\",\n",
     "        ),\n",
     "    )\n",
diff --git a/examples/temporal_clue/temporal-clue-7b-async.ipynb b/examples/temporal_clue/temporal-clue-7b-async.ipynb
index 929bfd017..cc9988e4d 100644
--- a/examples/temporal_clue/temporal-clue-7b-async.ipynb
+++ b/examples/temporal_clue/temporal-clue-7b-async.ipynb
@@ -90,7 +90,7 @@
     "    ]\n",
     "    client = model.openai_client()\n",
     "    chat_completion = await client.chat.completions.create(\n",
-    "        messages=messages, model=model.name, max_tokens=4096\n",
+    "        messages=messages, model=model.get_inference_name(), max_tokens=4096\n",
     "    )\n",
     "    choice = chat_completion.choices[0]\n",
     "    content = choice.message.content\n",
diff --git a/examples/temporal_clue/temporal-clue-7b.ipynb b/examples/temporal_clue/temporal-clue-7b.ipynb
index e75dad331..ea8ba9401 100644
--- a/examples/temporal_clue/temporal-clue-7b.ipynb
+++ b/examples/temporal_clue/temporal-clue-7b.ipynb
@@ -69,7 +69,7 @@
     "    ]\n",
     "    client = model.openai_client()\n",
     "    chat_completion = await client.chat.completions.create(\n",
-    "        messages=messages, model=model.name, max_tokens=4096\n",
+    "        messages=messages, model=model.get_inference_name(), max_tokens=4096\n",
     "    )\n",
     "    choice = chat_completion.choices[0]\n",
     "    content = choice.message.content\n",
diff --git a/examples/temporal_clue/temporal-clue.py b/examples/temporal_clue/temporal-clue.py
index e4fc078de..1b69735a7 100644
--- a/examples/temporal_clue/temporal-clue.py
+++ b/examples/temporal_clue/temporal-clue.py
@@ -36,7 +36,7 @@ async def rollout(model: art.Model, puzzle: TemporalCluePuzzle) -> art.Trajector
     messages: art.Messages = [{"role": "user", "content": puzzle["prompt"]}]
     client = model.openai_client()
     chat_completion = await client.chat.completions.create(
-        messages=messages, model=model.name
+        messages=messages, model=model.get_inference_name()
     )
     choice = chat_completion.choices[0]
     content = choice.message.content
diff --git a/src/art/dev/openai_server.py b/src/art/dev/openai_server.py
index e6f400d16..b3b8ab535 100644
--- a/src/art/dev/openai_server.py
+++ b/src/art/dev/openai_server.py
@@ -18,15 +18,13 @@ def get_openai_server_config(
         config = OpenAIServerConfig()
     log_file = config.get("log_file", log_file)

-    # Build LoRA modules list for multi-checkpoint support
-    # Register under both model_name (for "current" model) and model_name@step (for specific checkpoint)
+    # Build LoRA modules list for multi-checkpoint support.
+    # Only register the explicit step-qualified name so unsuffixed
+    # trainable model names fail loudly.
     lora_modules: list[str] | None = None
     if lora_path:
         step = int(os.path.basename(lora_path))
-        lora_modules = [
-            f'{{"name": "{model_name}", "path": "{lora_path}"}}',
-            f'{{"name": "{model_name}@{step}", "path": "{lora_path}"}}',
-        ]
+        lora_modules = [f'{{"name": "{model_name}@{step}", "path": "{lora_path}"}}']

     server_args = ServerArgs(
         api_key="default",
@@ -38,7 +36,9 @@
     server_args.update(config.get("server_args", {}))
     engine_args = EngineArgs(
         model=base_model,
-        served_model_name=model_name,
+        # Serve the base model under its own HF name when LoRA is enabled so
+        # `model.name` does not silently route to a stale/incorrect adapter.
+        served_model_name=base_model if lora_path else model_name,
         generation_config="vllm",
     )
     engine_args.update(config.get("engine_args", {}))
diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index 640ee1f33..afba81b18 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -3,6 +3,7 @@
 import math
 import os
 import shutil
+import socket
 import subprocess
 from types import TracebackType
 from typing import AsyncIterator, Iterable, Literal, cast
@@ -270,26 +271,36 @@ async def _prepare_backend_for_training(
         model: AnyTrainableModel,
         config: dev.OpenAIServerConfig | None = None,
     ) -> tuple[str, str]:
+        config_dict: dict = dict(config or {})
+        server_args = dict(config_dict.get("server_args", {}))
+
+        # Avoid binding collisions on busy hosts when no explicit port is provided.
+        if "port" not in server_args:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(("", 0))
+                server_args["port"] = s.getsockname()[1]
+        config_dict["server_args"] = server_args
+        resolved_config = cast(dev.OpenAIServerConfig, config_dict)
+
         service = await self._get_service(model)
-        host, port = await service.start_openai_server(config=config)
+        host, port = await service.start_openai_server(config=resolved_config)
         base_url = f"http://{host}:{port}/v1"
-        api_key = (config or {}).get("server_args", {}).get(
-            "api_key", None
-        ) or "default"
+        api_key = server_args.get("api_key") or "default"

         def done_callback(_: asyncio.Task[None]) -> None:
             close_proxy(self._services.pop(model.name))

         asyncio.create_task(
-            self._monitor_openai_server(model.name, base_url, api_key)
+            self._monitor_openai_server(model, base_url, api_key)
         ).add_done_callback(done_callback)

         return base_url, api_key

     async def _monitor_openai_server(
-        self, model_name: str, base_url: str, api_key: str
+        self, model: AnyTrainableModel, base_url: str, api_key: str
     ) -> None:
+        model_name = model.name
         openai_client = AsyncOpenAI(
             base_url=base_url,
             api_key=api_key,
@@ -324,7 +335,7 @@ async def _monitor_openai_server(
             try:
                 # Send a health check with a short timeout
                 await openai_client.completions.create(
-                    model=model_name,
+                    model=self._model_inference_name(model),
                    prompt="Hi",
                    max_tokens=1,
                    timeout=float(
diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py
index c335d4e2b..e0f94367a 100644
--- a/src/art/megatron/service.py
+++ b/src/art/megatron/service.py
@@ -147,17 +147,6 @@ async def _add_lora_aliases(
         )
         if not added:
             raise RuntimeError(f"Failed to add LoRA adapter for step {step}")
-        added_alias = await llm.add_lora(
-            LoRARequest(
-                lora_name=self.model_name,
-                lora_int_id=self._next_lora_id(),
-                lora_path=checkpoint_dir,
-            )
-        )
-        if not added_alias:
-            raise RuntimeError(
-                f"Failed to add LoRA alias for step {step} at {checkpoint_dir}"
-            )
         self._latest_step = step

     async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
diff --git a/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py b/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py
index 606a33318..52c829750 100644
--- a/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py
+++ b/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py
@@ -234,7 +234,7 @@ async def do_rollout(scenario: Scenario, temp: float) -> art.Trajectory:
         messages: art.Messages = scenario["messages"]
         response = await openai_client.chat.completions.create(
             messages=messages,
-            model=model.name,
+            model=model.get_inference_name(),
             max_tokens=max_tokens,
             timeout=request_timeout,
             temperature=temp,
diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py
index 702bdce77..c2d9515f4 100644
--- a/src/art/tinker/service.py
+++ b/src/art/tinker/service.py
@@ -143,10 +143,6 @@ def custom_loss_fn(
             last_checkpoint_dir.with_name(f"{next_step:04d}"),
             state.training_client,
         )
-        state.sampling_clients_and_renderers[self.model_name] = (
-            new_sampling_client,
-            state.renderer,
-        )
         state.sampling_clients_and_renderers[f"{self.model_name}@{next_step}"] = (
             new_sampling_client,
             state.renderer,
@@ -223,7 +219,6 @@ async def _get_state(self) -> "TinkerState":
             rest_client=rest_client,
             training_client=training_client,
             sampling_clients_and_renderers={
-                self.model_name: (sampling_client, renderer),
                 f"{self.model_name}@{current_step}": (sampling_client, renderer),
             },
             renderer=renderer,
diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py
index 741721686..e5eb1180e 100644
--- a/src/art/tinker_native/backend.py
+++ b/src/art/tinker_native/backend.py
@@ -334,7 +334,7 @@ async def _run_openai_server(
         @app.post("/v1/chat/completions")
         async def chat_completions(body: CompletionCreateParams) -> ChatCompletion:
             model_name = body.get("model")
-            _, step = self._parse_model_name(model_name)
+            parsed_model_name, step = self._parse_model_name(model_name)
             sampler_client = await self._get_sampler_client(state, step)

             messages = self._normalize_messages(body["messages"])
@@ -427,7 +427,7 @@ async def chat_completions(body: CompletionCreateParams) -> ChatCompletion:
                 id=str(uuid.uuid4()),
                 choices=choices,
                 created=int(time.time()),
-                model=self._format_response_model(model_name, step, state),
+                model=self._format_response_model(parsed_model_name, step),
                 object="chat.completion",
                 usage=CompletionUsage(
                     completion_tokens=completion_tokens,
@@ -666,27 +666,32 @@ def _normalize_tools(
             normalized.append(dict(tool))
         return normalized

-    def _parse_model_name(
-        self, model_name: str | None
-    ) -> tuple[str | None, int | None]:
-        if model_name and "@" in model_name:
-            base_name, step_str = model_name.rsplit("@", 1)
-            try:
-                return base_name, int(step_str)
-            except ValueError as exc:
-                raise HTTPException(
-                    status_code=400, detail=f"Invalid model step: {model_name}"
-                ) from exc
-        return model_name, None
-
-    def _format_response_model(
-        self, model_name: str | None, step: int | None, state: ModelState
-    ) -> str:
-        if model_name is None:
-            return f"{state.model_name}@{state.current_step}"
-        if step is None and "@" not in model_name:
-            return f"{model_name}@{state.current_step}"
-        return model_name
+    def _parse_model_name(self, model_name: str | None) -> tuple[str, int]:
+        if not model_name:
+            raise HTTPException(
+                status_code=400,
+                detail="Model name is required and must include an '@step' suffix. Use model.get_inference_name().",
+            )
+        if "@" not in model_name:
+            raise HTTPException(
+                status_code=400,
+                detail=(
+                    f"Model '{model_name}' is missing an '@step' suffix. "
+                    "Use model.get_inference_name()."
+                ),
+            )
+
+        base_name, step_str = model_name.rsplit("@", 1)
+        try:
+            return base_name, int(step_str)
+        except ValueError as exc:
+            raise HTTPException(
+                status_code=400, detail=f"Invalid model step: {model_name}"
+            ) from exc
+
+    def _format_response_model(self, model_name: str, step: int) -> str:
+        # Echo back the explicit model@step used for this completion.
+        return f"{model_name}@{step}"

     async def _create_training_client_from_checkpoint(
         self,
diff --git a/tests/unit/test_multi_checkpoint_inference.py b/tests/unit/test_multi_checkpoint_inference.py
index 108a7e1c4..dadaf09a4 100644
--- a/tests/unit/test_multi_checkpoint_inference.py
+++ b/tests/unit/test_multi_checkpoint_inference.py
@@ -311,6 +311,22 @@ def test_lora_name_step_zero(self):
         assert len(lora_modules) == 1
         assert "my-model@0" in lora_modules[0]

+    def test_served_model_name_uses_base_model_when_lora_enabled(self):
+        """With LoRA enabled, served model name should remain the base model."""
+        from art.dev.openai_server import get_openai_server_config
+
+        config = get_openai_server_config(
+            model_name="my-model",
+            base_model="meta-llama/Llama-3.1-8B",
+            log_file="/tmp/test.log",
+            lora_path="/path/to/checkpoints/0005",
+        )
+
+        assert (
+            config.get("engine_args", {}).get("served_model_name")
+            == "meta-llama/Llama-3.1-8B"
+        )
+

 # =============================================================================
 # Step Parsing Tests
@@ -318,32 +334,40 @@


 class TestStepParsing:
-    """Test parsing of @step suffix from model names."""
-
-    def test_parse_step_from_model_name(self):
-        """Test the step parsing logic used in TinkerService."""
-        test_cases = [
-            ("model-name", None),  # No @ suffix
-            ("model-name@5", 5),  # Valid step
-            ("model-name@0", 0),  # Step 0
-            ("model-name@100", 100),  # Large step
-            ("model@name@5", 5),  # Multiple @ (use last)
-            ("model-name@invalid", None),  # Invalid step (not a number)
-            ("model-name@", None),  # Empty step
-        ]
-
-        for model_name, expected_step in test_cases:
-            step = None
-            if "@" in str(model_name):
-                _, step_str = str(model_name).rsplit("@", 1)
-                try:
-                    step = int(step_str)
-                except ValueError:
-                    pass
-
-            assert step == expected_step, (
-                f"Failed for {model_name}: got {step}, expected {expected_step}"
-            )
+    """Test TinkerNative model-name parsing behavior."""
+
+    @pytest.fixture
+    def tinker_native_backend_class(self):
+        """Import TinkerNativeBackend, skipping if dependency unavailable."""
+        try:
+            from art.tinker_native.backend import TinkerNativeBackend
+
+            return TinkerNativeBackend
+        except ImportError as e:
+            pytest.skip(f"Tinker dependencies not available: {e}")
+
+    def test_parse_step_from_model_name(self, tinker_native_backend_class):
+        """Valid `model@step` names should parse correctly."""
+        backend = object.__new__(tinker_native_backend_class)
+        assert backend._parse_model_name("model-name@5") == ("model-name", 5)
+        assert backend._parse_model_name("model-name@0") == ("model-name", 0)
+        assert backend._parse_model_name("model@name@12") == ("model@name", 12)
+
+    def test_missing_step_suffix_fails_loudly(self, tinker_native_backend_class):
+        """Unsuffixed model names should fail with a helpful message."""
+        from fastapi import HTTPException
+
+        backend = object.__new__(tinker_native_backend_class)
+        with pytest.raises(HTTPException, match="missing an '@step' suffix"):
+            backend._parse_model_name("model-name")
+
+    def test_invalid_step_suffix_fails_loudly(self, tinker_native_backend_class):
+        """Non-numeric step suffix should fail with a helpful message."""
+        from fastapi import HTTPException
+
+        backend = object.__new__(tinker_native_backend_class)
+        with pytest.raises(HTTPException, match="Invalid model step"):
+            backend._parse_model_name("model-name@not-a-number")


 # =============================================================================
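
For reviewers, a minimal sketch of the calling convention these changes enforce. `art.TrainableModel`, `openai_client()`, `get_inference_name()`, and the step-qualified `name@step` serving names are all taken from the diffs above, but the specific model name, project, and base model below are placeholders:

```python
import art

# Hypothetical model for illustration; name/project/base_model are placeholders.
model = art.TrainableModel(
    name="my-model",
    project="demo",
    base_model="Qwen/Qwen2.5-7B-Instruct",
)


async def rollout(prompt: str) -> None:
    client = model.openai_client()
    # Pass the step-qualified inference name (e.g. "my-model@5"), not the bare
    # `model.name`: the server now registers only the "name@step" LoRA alias
    # and rejects unsuffixed names with a 400 instead of silently routing to
    # whichever adapter happened to be registered last.
    await client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=model.get_inference_name(),
        max_tokens=100,
    )
```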