Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dev/math-vista/math-vista.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
" }\n",
" ]\n",
" chat_completion = await client.chat.completions.create(\n",
" model=model.name, messages=trajectory.messages()\n",
" model=model.get_inference_name(), messages=trajectory.messages()\n",
" )\n",
" choice = chat_completion.choices[0]\n",
" trajectory.messages_and_choices.append(choice)\n",
Expand Down
2 changes: 1 addition & 1 deletion dev/math-vista/math-vista.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def rollout(scenario: Scenario) -> art.Trajectory:
]

chat_completion = await client.chat.completions.create(
model=model.name, messages=trajectory.messages()
model=model.get_inference_name(), messages=trajectory.messages()
)
choice = chat_completion.choices[0]
trajectory.messages_and_choices.append(choice)
Expand Down
8 changes: 7 additions & 1 deletion dev/new_models/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ async def main():
iteration_start = time.perf_counter()
# launch concurrent requests and time each individually
tasks = [
timed_request(client, model.name, prompt, max_tokens, temperature)
timed_request(
client,
model.get_inference_name(),
prompt,
max_tokens,
temperature,
)
for _ in range(concurrency)
]
# Wait for all responses
Expand Down
2 changes: 1 addition & 1 deletion dev/new_models/gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
client = model.openai_client()
chat_completion = await client.chat.completions.create(
messages=messages,
model=model.name,
model=model.get_inference_name(),
max_tokens=100,
timeout=100,
)
Expand Down
2 changes: 1 addition & 1 deletion dev/new_models/qwen3_try.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
" client = model.openai_client()\n",
" chat_completion = await client.chat.completions.create(\n",
" messages=messages,\n",
" model=model.name,\n",
" model=model.get_inference_name(),\n",
" max_tokens=100,\n",
" timeout=100,\n",
" extra_body={\"chat_template_kwargs\": {\"enable_thinking\": False}},\n",
Expand Down
2 changes: 1 addition & 1 deletion dev/new_models/qwen3_try.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
client = model.openai_client()
chat_completion = await client.chat.completions.create(
messages=messages,
model=model.name,
model=model.get_inference_name(),
max_tokens=100,
timeout=100,
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
Expand Down
3 changes: 2 additions & 1 deletion dev/yes-no-maybe-megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ async def main():
train_groups = await art.gather_trajectory_groups(
(
art.TrajectoryGroup(
rollout(openai_client, model.name, prompt) for _ in range(32)
rollout(openai_client, model.get_inference_name(), prompt)
for _ in range(32)
)
for prompt in prompts
)
Expand Down
2 changes: 1 addition & 1 deletion dev/yes-no-maybe-vision/train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
" }\n",
" ]\n",
" chat_completion = await client.chat.completions.create(\n",
" model=model.name, messages=messages, max_tokens=100, timeout=100\n",
" model=model.get_inference_name(), messages=messages, max_tokens=100, timeout=100\n",
" )\n",
" choice = chat_completion.choices[0]\n",
" content = choice.message.content\n",
Expand Down
2 changes: 1 addition & 1 deletion dev/yes-no-maybe.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
" }\n",
" ]\n",
" chat_completion = await client.chat.completions.create(\n",
" messages=messages, model=model.name, max_tokens=100, timeout=100\n",
" messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100\n",
" )\n",
" choice = chat_completion.choices[0]\n",
" content = choice.message.content\n",
Expand Down
2 changes: 1 addition & 1 deletion dev/yes-no-maybe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def rollout(client: openai.AsyncOpenAI, prompt: str) -> art.Trajectory:
}
]
chat_completion = await client.chat.completions.create(
messages=messages, model=model.name, max_tokens=100, timeout=100
messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100
)
choice = chat_completion.choices[0]
content = choice.message.content
Expand Down
4 changes: 2 additions & 2 deletions docs/fundamentals/art-client.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ messages: art.Messages = [
]
chat_completion = await openai_client.chat.completions.create(
messages=messages,
model=model.name,
model=model.get_inference_name(),
max_tokens=100,
timeout=100,
tools=[...]
Expand Down Expand Up @@ -157,7 +157,7 @@ async def rollout(model: art.Model, scenario: Scenario) -> art.Trajectory:

# generate a completion using the client
chat_completion = await openai_client.chat.completions.create(
messages=trajectory.messages(), model=model.name
messages=trajectory.messages(), model=model.get_inference_name()
)
choice = chat_completion.choices[0]
trajectory.messages_and_choices.append(choice)
Expand Down
6 changes: 3 additions & 3 deletions docs/integrations/langgraph-integration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def return_final_answer_tool(answer: str, reference_message_ids: list[str]) -> d
@weave.op
async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTrajectory:
# Initialize chat model with temperature
chat_model = init_chat_model(model.name, temperature=1.0)
chat_model = init_chat_model(model.get_inference_name(), temperature=1.0)

# Define available tools
tools = [search_inbox_tool, read_email_tool, return_final_answer_tool]
Expand Down Expand Up @@ -394,7 +394,7 @@ async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTra
return final_answer.model_dump()

tools = [search_inbox_tool, read_email_tool, return_final_answer_tool]
chat_model = init_chat_model(model.name, temperature=1.0)
chat_model = init_chat_model(model.get_inference_name(), temperature=1.0)
react_agent = create_react_agent(chat_model, tools)

try:
Expand Down Expand Up @@ -522,7 +522,7 @@ To use this example, simply replace the mock email functions (`search_emails`, `

**Empty trajectories or no training data captured:**

- Ensure you're using `init_chat_model(model.name)` in your rollout function
- Ensure you're using `init_chat_model(model.get_inference_name())` in your rollout function
- Verify your rollout function actually executes the agent and makes LLM calls
- Check that `init_chat_model()` is called before creating your LangGraph agent

Expand Down
2 changes: 1 addition & 1 deletion examples/2048/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async def get_completion():
return await client.chat.completions.create(
max_completion_tokens=128,
messages=trajectory.messages(),
model=model.name,
model=model.get_inference_name(),
)

try:
Expand Down
2 changes: 1 addition & 1 deletion examples/just-the-facts/just_the_facts/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def rollout(model: art.Model, scenario: FactsScenario) -> art.Trajectory:
)

completion = await client.chat.completions.create(
model=model.name if model.trainable else model.inference_model_name,
model=model.get_inference_name(),
messages=traj.messages(),
max_completion_tokens=500,
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
Expand Down
4 changes: 1 addition & 3 deletions examples/mcp-rl/mcp_rl/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,7 @@ async def rollout(
)

response = await client.chat.completions.create(
model=model.inference_model_name
if model.inference_model_name
else model.name,
model=model.get_inference_name(),
messages=traj.messages(),
temperature=1.0,
tools=tool_schemas,
Expand Down
14 changes: 11 additions & 3 deletions examples/prisoners-dilemma.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"\n",
"\n",
"async def rollout_game(\n",
" models: tuple[str, str] = (model.name, model.name),\n",
" models: tuple[str, str] = (model.get_inference_name(), model.get_inference_name()),\n",
") -> tuple[art.Trajectory, art.Trajectory]:\n",
" messages: tuple[art.Messages, art.Messages] = (\n",
" [{\"role\": \"user\", \"content\": prompt}],\n",
Expand Down Expand Up @@ -122,11 +122,19 @@
" # Simultaneously rollout self-play games, and games versus the base model.\n",
" self_play_trajectories, base_play_trajectories = await asyncio.gather(\n",
" art.gather_trajectories(\n",
" (rollout_game(models=(model.name, model.name)) for _ in range(8)),\n",
" (\n",
" rollout_game(\n",
" models=(model.get_inference_name(), model.get_inference_name())\n",
" )\n",
" for _ in range(8)\n",
" ),\n",
" pbar_desc=\"versus-self\",\n",
" ),\n",
" art.gather_trajectories(\n",
" (rollout_game(models=(model.name, BASE_MODEL)) for _ in range(8)),\n",
" (\n",
" rollout_game(models=(model.get_inference_name(), BASE_MODEL))\n",
" for _ in range(8)\n",
" ),\n",
" pbar_desc=\"versus-base\",\n",
" ),\n",
" )\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/temporal_clue/temporal-clue-7b-async.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
" ]\n",
" client = model.openai_client()\n",
" chat_completion = await client.chat.completions.create(\n",
" messages=messages, model=model.name, max_tokens=4096\n",
" messages=messages, model=model.get_inference_name(), max_tokens=4096\n",
" )\n",
" choice = chat_completion.choices[0]\n",
" content = choice.message.content\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/temporal_clue/temporal-clue-7b.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
" ]\n",
" client = model.openai_client()\n",
" chat_completion = await client.chat.completions.create(\n",
" messages=messages, model=model.name, max_tokens=4096\n",
" messages=messages, model=model.get_inference_name(), max_tokens=4096\n",
" )\n",
" choice = chat_completion.choices[0]\n",
" content = choice.message.content\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/temporal_clue/temporal-clue.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def rollout(model: art.Model, puzzle: TemporalCluePuzzle) -> art.Trajector
messages: art.Messages = [{"role": "user", "content": puzzle["prompt"]}]
client = model.openai_client()
chat_completion = await client.chat.completions.create(
messages=messages, model=model.name
messages=messages, model=model.get_inference_name()
)
choice = chat_completion.choices[0]
content = choice.message.content
Expand Down
14 changes: 7 additions & 7 deletions src/art/dev/openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ def get_openai_server_config(
config = OpenAIServerConfig()
log_file = config.get("log_file", log_file)

# Build LoRA modules list for multi-checkpoint support
# Register under both model_name (for "current" model) and model_name@step (for specific checkpoint)
# Build LoRA modules list for multi-checkpoint support.
# Only register the explicit step-qualified name so unsuffixed
# trainable model names fail loudly.
lora_modules: list[str] | None = None
if lora_path:
step = int(os.path.basename(lora_path))
lora_modules = [
f'{{"name": "{model_name}", "path": "{lora_path}"}}',
f'{{"name": "{model_name}@{step}", "path": "{lora_path}"}}',
]
lora_modules = [f'{{"name": "{model_name}@{step}", "path": "{lora_path}"}}']

server_args = ServerArgs(
api_key="default",
Expand All @@ -38,7 +36,9 @@ def get_openai_server_config(
server_args.update(config.get("server_args", {}))
engine_args = EngineArgs(
model=base_model,
served_model_name=model_name,
# Serve the base model under its own HF name when LoRA is enabled so
# `model.name` does not silently route to a stale/incorrect adapter.
served_model_name=base_model if lora_path else model_name,
generation_config="vllm",
)
engine_args.update(config.get("engine_args", {}))
Expand Down
25 changes: 18 additions & 7 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
import os
import shutil
import socket
import subprocess
from types import TracebackType
from typing import AsyncIterator, Iterable, Literal, cast
Expand Down Expand Up @@ -270,26 +271,36 @@ async def _prepare_backend_for_training(
model: AnyTrainableModel,
config: dev.OpenAIServerConfig | None = None,
) -> tuple[str, str]:
config_dict: dict = dict(config or {})
server_args = dict(config_dict.get("server_args", {}))

# Avoid binding collisions on busy hosts when no explicit port is provided.
if "port" not in server_args:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
server_args["port"] = s.getsockname()[1]
config_dict["server_args"] = server_args
resolved_config = cast(dev.OpenAIServerConfig, config_dict)

service = await self._get_service(model)
host, port = await service.start_openai_server(config=config)
host, port = await service.start_openai_server(config=resolved_config)

base_url = f"http://{host}:{port}/v1"
api_key = (config or {}).get("server_args", {}).get(
"api_key", None
) or "default"
api_key = server_args.get("api_key") or "default"

def done_callback(_: asyncio.Task[None]) -> None:
close_proxy(self._services.pop(model.name))

asyncio.create_task(
self._monitor_openai_server(model.name, base_url, api_key)
self._monitor_openai_server(model, base_url, api_key)
).add_done_callback(done_callback)

return base_url, api_key

async def _monitor_openai_server(
self, model_name: str, base_url: str, api_key: str
self, model: AnyTrainableModel, base_url: str, api_key: str
) -> None:
model_name = model.name
openai_client = AsyncOpenAI(
base_url=base_url,
api_key=api_key,
Expand Down Expand Up @@ -324,7 +335,7 @@ async def _monitor_openai_server(
try:
# Send a health check with a short timeout
await openai_client.completions.create(
model=model_name,
model=self._model_inference_name(model),
prompt="Hi",
max_tokens=1,
timeout=float(
Expand Down
11 changes: 0 additions & 11 deletions src/art/megatron/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,6 @@ async def _add_lora_aliases(
)
if not added:
raise RuntimeError(f"Failed to add LoRA adapter for step {step}")
added_alias = await llm.add_lora(
LoRARequest(
lora_name=self.model_name,
lora_int_id=self._next_lora_id(),
lora_path=checkpoint_dir,
)
)
if not added_alias:
raise RuntimeError(
f"Failed to add LoRA alias for step {step} at {checkpoint_dir}"
)
self._latest_step = step

async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/art/pipeline_trainer/binary_prefix_tool_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ async def do_rollout(scenario: Scenario, temp: float) -> art.Trajectory:
messages: art.Messages = scenario["messages"]
response = await openai_client.chat.completions.create(
messages=messages,
model=model.name,
model=model.get_inference_name(),
max_tokens=max_tokens,
timeout=request_timeout,
temperature=temp,
Expand Down
5 changes: 0 additions & 5 deletions src/art/tinker/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,6 @@ def custom_loss_fn(
last_checkpoint_dir.with_name(f"{next_step:04d}"),
state.training_client,
)
state.sampling_clients_and_renderers[self.model_name] = (
new_sampling_client,
state.renderer,
)
state.sampling_clients_and_renderers[f"{self.model_name}@{next_step}"] = (
new_sampling_client,
state.renderer,
Expand Down Expand Up @@ -223,7 +219,6 @@ async def _get_state(self) -> "TinkerState":
rest_client=rest_client,
training_client=training_client,
sampling_clients_and_renderers={
self.model_name: (sampling_client, renderer),
f"{self.model_name}@{current_step}": (sampling_client, renderer),
},
renderer=renderer,
Expand Down
Loading