diff --git a/src/aiperf/common/config/config_validators.py b/src/aiperf/common/config/config_validators.py index e441077aa..0941b3479 100644 --- a/src/aiperf/common/config/config_validators.py +++ b/src/aiperf/common/config/config_validators.py @@ -118,6 +118,7 @@ def parse_str_or_dict_as_tuple_list(input: Any | None) -> list[tuple[str, Any]] into key and value, trims any whitespace, and coerces the value to the correct type. - If the input is a dictionary, it is converted to a list of tuples by key and value pairs. - If the input is a list, it recursively calls this function on each item, and aggregates the results. + - If the item is already a 2-element sequence (key-value pair), it is converted directly to a tuple. - Otherwise, a ValueError is raised. Args: @@ -133,9 +134,14 @@ def parse_str_or_dict_as_tuple_list(input: Any | None) -> list[tuple[str, Any]] if isinstance(input, list | tuple | set): output = [] for item in input: - res = parse_str_or_dict_as_tuple_list(item) - if res is not None: - output.extend(res) + # If item is already a 2-element sequence (key-value pair), convert directly to tuple + if isinstance(item, list | tuple) and len(item) == 2: + key, value = item + output.append((str(key), coerce_value(value))) + else: + res = parse_str_or_dict_as_tuple_list(item) + if res is not None: + output.extend(res) return output if isinstance(input, dict): @@ -150,11 +156,16 @@ def parse_str_or_dict_as_tuple_list(input: Any | None) -> list[tuple[str, Any]] f"User Config: {input} - must be a valid JSON string" ) from e else: - return [ - (key.strip(), coerce_value(value.strip())) - for item in input.split(",") - for key, value in [item.split(":")] - ] + result = [] + for item in input.split(","): + parts = item.split(":", 1) + if len(parts) != 2: + raise ValueError( + f"User Config: {input} - each item must be in 'key:value' format" + ) + key, value = parts + result.append((key.strip(), coerce_value(value.strip()))) + return result raise ValueError(f"User Config: {input} - must be a valid string, list, or dict") diff --git a/tests/config/test_config_validators.py b/tests/config/test_config_validators.py index 0e61e5bc7..372c6f81f 100644 --- a/tests/config/test_config_validators.py +++ b/tests/config/test_config_validators.py @@ -156,7 +156,6 @@ def test_invalid_json_string_raises_value_error(self, invalid_json): [ ["key1_no_colon"], # Missing colon ["key1:value1", "key2_no_colon"], # One valid, one invalid - ["key1:value1:extra"], # Too many colons ], ) def test_invalid_list_format_raises_value_error(self, invalid_list): @@ -190,15 +189,36 @@ def test_invalid_input_type_raises_value_error(self, invalid_input): with pytest.raises(ValueError, match="must be a valid string, list, or dict"): parse_str_or_dict_as_tuple_list(invalid_input) - def test_string_with_multiple_colons_raises_value_error(self): - """Test that strings with multiple colons raise ValueError.""" - with pytest.raises(ValueError): - parse_str_or_dict_as_tuple_list("key1:value1:extra,key2:value2") - - def test_list_with_multiple_colons_raises_value_error(self): - """Test that list items with multiple colons raise ValueError.""" - with pytest.raises(ValueError): - parse_str_or_dict_as_tuple_list(["key1:value1:extra", "key2:value2"]) + @pytest.mark.parametrize( + "input_value,expected", + [ + # String with multiple colons + ( + "key1:value1:extra,key2:value2", + [("key1", "value1:extra"), ("key2", "value2")], + ), + # List with multiple colons + ( + ["key1:value1:extra", "key2:value2"], + [("key1", "value1:extra"), ("key2", "value2")], + ), + # URL with port + ("url:http://example.com:8080", [("url", "http://example.com:8080")]), + # Multiple entries with colons in values (timestamps, ports, etc) + ( + "server:localhost:8080,time:12:30:45,status:active", + [ + ("server", "localhost:8080"), + ("time", "12:30:45"), + ("status", "active"), + ], + ), + ], + ) + def test_values_can_contain_colons(self, input_value, expected): + """Test that values can contain colons (URLs, timestamps, etc).""" + result = parse_str_or_dict_as_tuple_list(input_value) + assert result == expected def test_whitespace_handling_in_string_input(self): """Test that whitespace is properly trimmed in string input.""" @@ -248,6 +268,31 @@ def test_none_input_returns_none(self): result = parse_str_or_dict_as_tuple_list(None) assert result is None + @pytest.mark.parametrize( + "input_list,expected", + [ + ( + [["temperature", 0.1], ["max_tokens", 150]], + [("temperature", 0.1), ("max_tokens", 150)], + ), + ( + [("temperature", 0.1), ("max_tokens", 150)], + [("temperature", 0.1), ("max_tokens", 150)], + ), + ( + [("key1", "value1"), ("key2", 123), ("key3", True)], + [("key1", "value1"), ("key2", 123), ("key3", True)], + ), + ], + ) + def test_list_of_key_value_pairs_input(self, input_list, expected): + """Test that a list of key-value pairs (lists/tuples) is converted correctly to a list of tuples.""" + result = parse_str_or_dict_as_tuple_list(input_list) + assert result == expected + # Make sure that the result is the same when parsed again. + result2 = parse_str_or_dict_as_tuple_list(result) + assert result2 == expected + class TestParseStrOrListOfPositiveValues: """Test suite for the parse_str_or_list_of_positive_values function.""" diff --git a/tests/config/test_user_config.py b/tests/config/test_user_config.py index 67a1c4483..08bc26c91 100644 --- a/tests/config/test_user_config.py +++ b/tests/config/test_user_config.py @@ -6,17 +6,95 @@ import pytest from aiperf.common.config import ( + ConversationConfig, EndpointConfig, EndpointDefaults, InputConfig, LoadGeneratorConfig, OutputConfig, TokenizerConfig, + TurnConfig, + TurnDelayConfig, UserConfig, ) from aiperf.common.enums import EndpointType +from aiperf.common.enums.dataset_enums import CustomDatasetType from aiperf.common.enums.timing_enums import TimingMode +""" +Test suite for the UserConfig class. +""" + + +class TestUserConfig: + """Test suite for the UserConfig class.""" + + def test_user_config_serialization_to_json_string(self): + """Test the serialization and deserialization of a UserConfig object to and from a JSON string.""" + config = UserConfig( + endpoint=EndpointConfig( + model_names=["model1", "model2"], + type=EndpointType.CHAT, + custom_endpoint="custom_endpoint", + streaming=True, + url="http://custom-url", + extra=[ + ("key1", "value1"), + ("key2", "value2"), + ("key3", "value3"), + ], + headers=[ + ("Authorization", "Bearer token"), + ("Content-Type", "application/json"), + ], + api_key="test_api_key", + ssl_options={"verify": False}, + timeout=10, + ), + conversation_config=ConversationConfig( + num=10, + turn=TurnConfig( + mean=10, + stddev=10, + delay=TurnDelayConfig( + mean=10, + stddev=10, + ), + ), + ), + input=InputConfig( + custom_dataset_type=CustomDatasetType.SINGLE_TURN, + ), + output=OutputConfig( + artifact_directory="test_artifacts", + ), + tokenizer=TokenizerConfig( + model_name="test_tokenizer", + ), + loadgen=LoadGeneratorConfig( + concurrency=10, + request_rate=10, + ), + verbose=True, + template_filename="test_template.yaml", + cli_command="test_cli_command", + ) + + # NOTE: Currently, we have validation logic that uses the concept of whether a field was set by the user, so + # exclude_unset must be used. exclude_defaults should also be able to work. + assert ( + UserConfig.model_validate_json( + config.model_dump_json(indent=4, exclude_unset=True) + ) + == config + ) + assert ( + UserConfig.model_validate_json( + config.model_dump_json(indent=4, exclude_defaults=True) + ) + == config + ) + def test_user_config_serialization_to_file(): """