Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions src/aiperf/common/config/config_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def parse_str_or_dict_as_tuple_list(input: Any | None) -> list[tuple[str, Any]]
into key and value, trims any whitespace, and coerces the value to the correct type.
- If the input is a dictionary, it is converted to a list of tuples by key and value pairs.
- If the input is a list, it recursively calls this function on each item, and aggregates the results.
- If the item is already a 2-element sequence (key-value pair), it is converted directly to a tuple.
- Otherwise, a ValueError is raised.

Args:
Expand All @@ -133,9 +134,14 @@ def parse_str_or_dict_as_tuple_list(input: Any | None) -> list[tuple[str, Any]]
if isinstance(input, list | tuple | set):
output = []
for item in input:
res = parse_str_or_dict_as_tuple_list(item)
if res is not None:
output.extend(res)
# If item is already a 2-element sequence (key-value pair), convert directly to tuple
if isinstance(item, list | tuple) and len(item) == 2:
key, value = item
output.append((str(key), coerce_value(value)))
else:
res = parse_str_or_dict_as_tuple_list(item)
if res is not None:
output.extend(res)
return output

if isinstance(input, dict):
Expand All @@ -150,11 +156,16 @@ def parse_str_or_dict_as_tuple_list(input: Any | None) -> list[tuple[str, Any]]
f"User Config: {input} - must be a valid JSON string"
) from e
else:
return [
(key.strip(), coerce_value(value.strip()))
for item in input.split(",")
for key, value in [item.split(":")]
]
result = []
for item in input.split(","):
parts = item.split(":", 1)
if len(parts) != 2:
raise ValueError(
f"User Config: {input} - each item must be in 'key:value' format"
)
key, value = parts
result.append((key.strip(), coerce_value(value.strip())))
return result

raise ValueError(f"User Config: {input} - must be a valid string, list, or dict")

Expand Down
65 changes: 55 additions & 10 deletions tests/config/test_config_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def test_invalid_json_string_raises_value_error(self, invalid_json):
[
["key1_no_colon"], # Missing colon
["key1:value1", "key2_no_colon"], # One valid, one invalid
["key1:value1:extra"], # Too many colons
],
)
def test_invalid_list_format_raises_value_error(self, invalid_list):
Expand Down Expand Up @@ -190,15 +189,36 @@ def test_invalid_input_type_raises_value_error(self, invalid_input):
with pytest.raises(ValueError, match="must be a valid string, list, or dict"):
parse_str_or_dict_as_tuple_list(invalid_input)

def test_string_with_multiple_colons_raises_value_error(self):
"""Test that strings with multiple colons raise ValueError."""
# NOTE(review): this expectation conflicts with test_values_can_contain_colons
# in this same listing, which expects "key1:value1:extra,key2:value2" to parse
# to [("key1", "value1:extra"), ("key2", "value2")]. Presumably this is the
# removed side of the diff -- confirm it no longer exists in the merged file.
with pytest.raises(ValueError):
parse_str_or_dict_as_tuple_list("key1:value1:extra,key2:value2")

def test_list_with_multiple_colons_raises_value_error(self):
"""Test that list items with multiple colons raise ValueError."""
# NOTE(review): this expectation conflicts with test_values_can_contain_colons
# in this same listing, which expects ["key1:value1:extra", "key2:value2"] to
# parse to [("key1", "value1:extra"), ("key2", "value2")]. Presumably this is
# the removed side of the diff -- confirm it no longer exists in the merged file.
with pytest.raises(ValueError):
parse_str_or_dict_as_tuple_list(["key1:value1:extra", "key2:value2"])
@pytest.mark.parametrize(
    ("input_value", "expected"),
    [
        # Comma-separated string where one value itself holds a colon
        (
            "key1:value1:extra,key2:value2",
            [("key1", "value1:extra"), ("key2", "value2")],
        ),
        # Same data supplied as a list of "key:value" strings
        (
            ["key1:value1:extra", "key2:value2"],
            [("key1", "value1:extra"), ("key2", "value2")],
        ),
        # URL value carrying both a scheme separator and a port
        ("url:http://example.com:8080", [("url", "http://example.com:8080")]),
        # Several entries mixing colon-bearing values (host:port, clock time)
        # with a plain one
        (
            "server:localhost:8080,time:12:30:45,status:active",
            [
                ("server", "localhost:8080"),
                ("time", "12:30:45"),
                ("status", "active"),
            ],
        ),
    ],
)
def test_values_can_contain_colons(self, input_value, expected):
    """Check that everything after an item's first colon is kept as the value.

    Each case pairs a raw input (string or list of strings) with the list of
    (key, value) tuples the parser is expected to return for it.
    """
    assert parse_str_or_dict_as_tuple_list(input_value) == expected

def test_whitespace_handling_in_string_input(self):
"""Test that whitespace is properly trimmed in string input."""
Expand Down Expand Up @@ -248,6 +268,31 @@ def test_none_input_returns_none(self):
result = parse_str_or_dict_as_tuple_list(None)
assert result is None

@pytest.mark.parametrize(
    ("input_list", "expected"),
    [
        # Pairs given as 2-element lists
        (
            [["temperature", 0.1], ["max_tokens", 150]],
            [("temperature", 0.1), ("max_tokens", 150)],
        ),
        # Pairs given as 2-element tuples
        (
            [("temperature", 0.1), ("max_tokens", 150)],
            [("temperature", 0.1), ("max_tokens", 150)],
        ),
        # Values of mixed types: str, int and bool
        (
            [("key1", "value1"), ("key2", 123), ("key3", True)],
            [("key1", "value1"), ("key2", 123), ("key3", True)],
        ),
    ],
)
def test_list_of_key_value_pairs_input(self, input_list, expected):
    """Check that a list of 2-element lists/tuples becomes a list of tuples.

    The parser output is also fed back into the parser to confirm that a
    second pass yields the same expected result.
    """
    first_pass = parse_str_or_dict_as_tuple_list(input_list)
    assert first_pass == expected

    # Parsing already-parsed output must give the same answer again.
    assert parse_str_or_dict_as_tuple_list(first_pass) == expected


class TestParseStrOrListOfPositiveValues:
"""Test suite for the parse_str_or_list_of_positive_values function."""
Expand Down
78 changes: 78 additions & 0 deletions tests/config/test_user_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,95 @@
import pytest

from aiperf.common.config import (
ConversationConfig,
EndpointConfig,
EndpointDefaults,
InputConfig,
LoadGeneratorConfig,
OutputConfig,
TokenizerConfig,
TurnConfig,
TurnDelayConfig,
UserConfig,
)
from aiperf.common.enums import EndpointType
from aiperf.common.enums.dataset_enums import CustomDatasetType
from aiperf.common.enums.timing_enums import TimingMode

"""
Test suite for the UserConfig class.
"""


class TestUserConfig:
    """Tests for the UserConfig class."""

    def test_user_config_serialization_to_json_string(self):
        """Round-trip a populated UserConfig through its JSON string form.

        The config is dumped with ``exclude_unset`` and then with
        ``exclude_defaults``; in both cases validating the resulting JSON must
        produce an object equal to the original.
        """
        endpoint = EndpointConfig(
            model_names=["model1", "model2"],
            type=EndpointType.CHAT,
            custom_endpoint="custom_endpoint",
            streaming=True,
            url="http://custom-url",
            extra=[
                ("key1", "value1"),
                ("key2", "value2"),
                ("key3", "value3"),
            ],
            headers=[
                ("Authorization", "Bearer token"),
                ("Content-Type", "application/json"),
            ],
            api_key="test_api_key",
            ssl_options={"verify": False},
            timeout=10,
        )
        turn = TurnConfig(
            mean=10,
            stddev=10,
            delay=TurnDelayConfig(mean=10, stddev=10),
        )
        config = UserConfig(
            endpoint=endpoint,
            conversation_config=ConversationConfig(num=10, turn=turn),
            input=InputConfig(custom_dataset_type=CustomDatasetType.SINGLE_TURN),
            output=OutputConfig(artifact_directory="test_artifacts"),
            tokenizer=TokenizerConfig(model_name="test_tokenizer"),
            loadgen=LoadGeneratorConfig(concurrency=10, request_rate=10),
            verbose=True,
            template_filename="test_template.yaml",
            cli_command="test_cli_command",
        )

        # NOTE: Currently, we have validation logic that uses the concept of whether a field was set by the user, so
        # exclude_unset must be used. exclude_defaults should also be able to work.
        for dump_options in ({"exclude_unset": True}, {"exclude_defaults": True}):
            serialized = config.model_dump_json(indent=4, **dump_options)
            assert UserConfig.model_validate_json(serialized) == config


def test_user_config_serialization_to_file():
"""
Expand Down