Make per-turn max_tokens handling (vs. truncation at the end) more explicit and better for multi-turn #406

@CharlieFRuan

Currently we truncate to the max generation length only after the agent loop finishes. Ideally the agent loop would stop generating on its own with stop_reason="length"; if we handle max_tokens in each turn of the agent loop, the final truncation should have little effect.

We should make it explicit that the max generation length is not the per-turn max_tokens, but the total generation length of the trajectory.

This should also be tied to the engine's max_model_len config.
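A minimal sketch of what per-turn handling could look like, not the current implementation. The helper name `per_turn_max_tokens` and all of its parameters are hypothetical. It assumes the trajectory-level budget counts tokens spent in earlier turns, and that each turn must also fit in whatever context the engine has left under max_model_len; how observation tokens count against the budget is left open here.

```python
def per_turn_max_tokens(
    total_budget: int,
    generated_so_far: int,
    context_len: int,
    max_model_len: int,
) -> int:
    """Return the max_tokens to request for the next turn (hypothetical helper).

    total_budget: generation tokens allowed for the whole trajectory.
    generated_so_far: tokens already charged against that budget.
    context_len: tokens currently in the context (prompt plus all prior turns).
    max_model_len: the engine's context window.
    """
    remaining_budget = total_budget - generated_so_far
    remaining_context = max_model_len - context_len
    # Never request a negative number of tokens; 0 means "stop the loop".
    return max(0, min(remaining_budget, remaining_context))
```

With a per-turn cap like this, the engine itself would return stop_reason="length" when the budget runs out, and the agent loop could stop as soon as the helper returns 0.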

# Determine stop reason
max_response_tokens = (
    self.generator_cfg.sampling_params.max_generate_length
    + self.generator_cfg.max_input_length
    - initial_prompt_length
)
stop_reason = "complete"  # Default for trial completion
if len(response_ids) > max_response_tokens:
    stop_reason = "length"

# need to truncate loss mask correctly for responses that go to max length
if self.max_turns > 1:
    # max total resp length = max tokens (max length of final turn generation)
    # + max_input_length (max input for any generation turn) - len(original prompt)
    max_response_tokens = max_tokens + max_input_length - initial_prompt_length
else:
    max_response_tokens = max_tokens
if len(response_ids) > max_response_tokens:
    stop_reason = "length"
response_ids = response_ids[:max_response_tokens]
loss_mask = loss_mask[:max_response_tokens]

# Calculate maximum response tokens allowed
max_response_tokens = max_tokens + max_input_length - initial_prompt_length
# Determine stop reason
stop_reason = "complete"  # Default for trial completion
if len(response_ids) > max_response_tokens:
    stop_reason = "length"
# Truncate to maximum allowed length
response_ids = response_ids[:max_response_tokens]
loss_mask = loss_mask[:max_response_tokens]
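The snippets above compute nearly the same bound in several places. One possible way to make the semantics explicit is a single shared helper, sketched below. The name `truncate_to_budget` and the `multi_turn` flag are hypothetical and not part of the codebase; the flag mirrors the `self.max_turns > 1` branch, and the clamp to zero is an added assumption to avoid negative slice bounds.

```python
from typing import List, Tuple


def truncate_to_budget(
    response_ids: List[int],
    loss_mask: List[int],
    max_tokens: int,
    max_input_length: int,
    initial_prompt_length: int,
    multi_turn: bool = True,
) -> Tuple[List[int], List[int], str]:
    """Apply the end-of-trajectory truncation in one place (hypothetical helper)."""
    if multi_turn:
        # Same bound as the snippets above: final-turn generation budget plus
        # the largest allowed input, minus the original prompt.
        max_response_tokens = max_tokens + max_input_length - initial_prompt_length
    else:
        max_response_tokens = max_tokens
    # Assumption: clamp so a misconfigured bound never becomes a negative slice index.
    max_response_tokens = max(0, max_response_tokens)

    stop_reason = "length" if len(response_ids) > max_response_tokens else "complete"
    return (
        response_ids[:max_response_tokens],
        loss_mask[:max_response_tokens],
        stop_reason,
    )
```

If max_tokens is enforced in every turn, this helper should almost always return the inputs unchanged with stop_reason="complete", which makes it a cheap safety net and not the primary length control.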
