diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java index e227c30478..72a5acccb8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java @@ -6,11 +6,14 @@ package org.opensearch.ml.common.transport.connector; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.utils.StringUtils.validateFields; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; @@ -18,6 +21,7 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.utils.FieldDescriptor; import lombok.Builder; import lombok.Getter; @@ -38,12 +42,14 @@ public MLCreateConnectorRequest(StreamInput in) throws IOException { @Override public ActionRequestValidationException validate() { - ActionRequestValidationException exception = null; if (mlCreateConnectorInput == null) { - exception = addValidationError("ML Connector input can't be null", exception); + return addValidationError("ML Connector input can't be null", null); } + Map fieldsToValidate = new HashMap<>(); + fieldsToValidate.put("Model connector name", new FieldDescriptor(mlCreateConnectorInput.getName(), true)); + fieldsToValidate.put("Model connector description", new FieldDescriptor(mlCreateConnectorInput.getDescription(), false)); - return exception; + return validateFields(fieldsToValidate); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java index d496004db5..e1fe1db68a 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java @@ -6,11 +6,14 @@ package org.opensearch.ml.common.transport.connector; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.utils.StringUtils.validateFields; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; @@ -19,6 +22,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.utils.FieldDescriptor; import lombok.Builder; import lombok.Getter; @@ -57,8 +61,12 @@ public ActionRequestValidationException validate() { if (updateContent == null) { exception = addValidationError("Update connector content can't be null", exception); + } else { + Map fieldsToValidate = new HashMap<>(); + fieldsToValidate.put("Model connector name", new FieldDescriptor(updateContent.getName(), false)); + fieldsToValidate.put("Model connector description", new FieldDescriptor(updateContent.getDescription(), false)); + exception = validateFields(fieldsToValidate); } - return exception; } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java index 61524689f7..48389763e8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java @@ -6,11 +6,14 @@ package org.opensearch.ml.common.transport.model; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.utils.StringUtils.validateFields; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; @@ -18,6 +21,7 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.utils.FieldDescriptor; import lombok.AccessLevel; import lombok.Builder; @@ -44,12 +48,13 @@ public MLUpdateModelRequest(StreamInput in) throws IOException { @Override public ActionRequestValidationException validate() { - ActionRequestValidationException exception = null; if (updateModelInput == null) { - exception = addValidationError("Update Model Input can't be null", exception); + return addValidationError("Update Model Input can't be null", null); } - - return exception; + Map fieldsToValidate = new HashMap<>(); + fieldsToValidate.put("Model Name", new FieldDescriptor(updateModelInput.getName(), false)); + fieldsToValidate.put("Model Description", new FieldDescriptor(updateModelInput.getDescription(), false)); + return validateFields(fieldsToValidate); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java index 4ecfa46b4b..5fd56292ee 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java @@ -6,11 +6,14 @@ package org.opensearch.ml.common.transport.model_group; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.utils.StringUtils.validateFields; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; @@ -18,6 +21,7 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.utils.FieldDescriptor; import lombok.AccessLevel; import lombok.Builder; @@ -44,12 +48,15 @@ public MLRegisterModelGroupRequest(StreamInput in) throws IOException { @Override public ActionRequestValidationException validate() { - ActionRequestValidationException exception = null; if (registerModelGroupInput == null) { - exception = addValidationError("Model meta input can't be null", exception); + return addValidationError("Model group input can't be null", null); } - return exception; + Map fieldsToValidate = new HashMap<>(); + fieldsToValidate.put("Model group name", new FieldDescriptor(registerModelGroupInput.getName(), true)); + fieldsToValidate.put("Model group description", new FieldDescriptor(registerModelGroupInput.getDescription(), false)); + + return validateFields(fieldsToValidate); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java index e3f103dcf3..e130975c71 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java @@ -6,11 +6,14 @@ package org.opensearch.ml.common.transport.model_group; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.utils.StringUtils.validateFields; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; @@ -18,6 +21,7 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.utils.FieldDescriptor; import lombok.AccessLevel; import lombok.Builder; @@ -44,12 +48,15 @@ public MLUpdateModelGroupRequest(StreamInput in) throws IOException { @Override public ActionRequestValidationException validate() { - ActionRequestValidationException exception = null; if (updateModelGroupInput == null) { - exception = addValidationError("Update Model group input can't be null", exception); + return addValidationError("Update Model group input can't be null", null); } - return exception; + Map fieldsToValidate = new HashMap<>(); + fieldsToValidate.put("Model group name", new FieldDescriptor(updateModelGroupInput.getName(), false)); + fieldsToValidate.put("Model group description", new FieldDescriptor(updateModelGroupInput.getDescription(), false)); + + return validateFields(fieldsToValidate); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java index adff46812f..4dcfaea971 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java @@ -6,11 +6,14 @@ package org.opensearch.ml.common.transport.register; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.utils.StringUtils.validateFields; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; @@ -18,6 +21,7 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.utils.FieldDescriptor; import lombok.AccessLevel; import lombok.Builder; @@ -44,12 +48,15 @@ public MLRegisterModelRequest(StreamInput in) throws IOException { @Override public ActionRequestValidationException validate() { - ActionRequestValidationException exception = null; if (registerModelInput == null) { - exception = addValidationError("ML input can't be null", exception); + return addValidationError("ML input can't be null", null); } - return exception; + Map fieldsToValidate = new HashMap<>(); + fieldsToValidate.put("Model name", new FieldDescriptor(registerModelInput.getModelName(), true)); + fieldsToValidate.put("Model description", new FieldDescriptor(registerModelInput.getDescription(), false)); + + return validateFields(fieldsToValidate); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/utils/FieldDescriptor.java b/common/src/main/java/org/opensearch/ml/common/utils/FieldDescriptor.java new file mode 100644 index 0000000000..101bb84f69 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/utils/FieldDescriptor.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils; + +public class FieldDescriptor { + private final String value; + private final boolean required; + + public FieldDescriptor(String value, boolean required) { + this.value = value; + this.required = required; + } + + public String getValue() { + return value; + } + + public boolean isRequired() { + return required; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 4fd9332519..5247754480 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.utils; +import static org.opensearch.action.ValidateActions.addValidationError; + import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.AccessController; @@ -28,6 +30,7 @@ import org.json.JSONException; import org.json.JSONObject; import org.opensearch.OpenSearchParseException; +import org.opensearch.action.ActionRequestValidationException; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; @@ -60,6 +63,12 @@ public class StringUtils { + " return input;" + "\n }\n"; + // Regex allows letters, digits, spaces, hyphens, underscores, and dots. + private static final Pattern SAFE_INPUT_PATTERN = Pattern.compile("^[\\p{L}\\p{N}\\s.,!?():@\\-_'\"]*$"); + + public static final String SAFE_INPUT_DESCRIPTION = + "can only contain letters, numbers, whitespace, and basic punctuation (.,!?():@-_'\")"; + public static final Gson gson; static { @@ -497,4 +506,51 @@ public static String hashString(String input) { } } + /** + * Validates a map of fields to ensure that their values only contain allowed characters. + *

+ * Allowed characters are: letters, digits, spaces, underscores (_), hyphens (-), dots (.), and colons (:). + * If a value does not comply, a validation error is added. + * + * @param fields A map where the key is the field name (used for error messages) and the value is the text to validate. + * @return An {@link ActionRequestValidationException} containing all validation errors, or {@code null} if all fields are valid. + */ + public static ActionRequestValidationException validateFields(Map fields) { + ActionRequestValidationException exception = null; + + for (Map.Entry entry : fields.entrySet()) { + String key = entry.getKey(); + FieldDescriptor descriptor = entry.getValue(); + String value = descriptor.getValue(); + + if (descriptor.isRequired()) { + if (!isSafeText(value)) { + String reason = (value == null || value.isBlank()) ? "is required and cannot be null or blank" : SAFE_INPUT_DESCRIPTION; + exception = addValidationError(key + " " + reason, exception); + } + } else { + if (value != null && !value.isBlank() && !matchesSafePattern(value)) { + exception = addValidationError(key + " " + SAFE_INPUT_DESCRIPTION, exception); + } + } + } + + return exception; + } + + /** + * Checks if the input is safe (non-null, non-blank, matches safe character set). + * + * @param value The input string to validate + * @return true if input is safe, false otherwise + */ + public static boolean isSafeText(String value) { + return value != null && !value.isBlank() && matchesSafePattern(value); + } + + // Just checks pattern + public static boolean matchesSafePattern(String value) { + return SAFE_INPUT_PATTERN.matcher(value).matches(); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java index 719e427684..4c24b67f84 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java @@ -9,6 +9,8 @@ import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.opensearch.ml.common.utils.StringUtils.SAFE_INPUT_DESCRIPTION; import java.io.IOException; import java.io.UncheckedIOException; @@ -147,4 +149,125 @@ public void writeTo(StreamOutput out) throws IOException { }; MLCreateConnectorRequest.fromActionRequest(actionRequest); } + + @Test + public void validateWithUnsafeModelConnectorName() { + MLCreateConnectorInput unsafeInput = MLCreateConnectorInput + .builder() + .name("") // Unsafe name + .description("safe description") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test")) + .credential(Map.of("key", "value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1")) + .addAllBackendRoles(false) + .build(); + + MLCreateConnectorRequest request = MLCreateConnectorRequest.builder().mlCreateConnectorInput(unsafeInput).build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model connector name " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } + + @Test + public void validateWithUnsafeModelConnectorDescription() { + MLCreateConnectorInput unsafeInput = MLCreateConnectorInput + .builder() + .name("safeName") + .description("") // Unsafe description + .version("1") + .protocol("http") + .parameters(Map.of("input", "test")) + .credential(Map.of("key", "value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1")) + .addAllBackendRoles(false) + .build(); + + MLCreateConnectorRequest request = MLCreateConnectorRequest.builder().mlCreateConnectorInput(unsafeInput).build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model connector description " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } + + @Test + public void validateWithEmptyAndInvalidModelConnectorNameAndDescription() { + // Test empty name (should fail validation) + MLCreateConnectorInput emptyNameInput = MLCreateConnectorInput + .builder() + .name("") // Empty name + .description("valid description") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test")) + .credential(Map.of("key", "value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1")) + .addAllBackendRoles(false) + .build(); + + MLCreateConnectorRequest emptyNameRequest = MLCreateConnectorRequest.builder().mlCreateConnectorInput(emptyNameInput).build(); + ActionRequestValidationException emptyNameException = emptyNameRequest.validate(); + assertEquals( + "Validation Failed: 1: Model connector name is required and cannot be null or blank;", + emptyNameException.getMessage() + ); + + // Test empty description (should pass validation) + MLCreateConnectorInput emptyDescriptionInput = MLCreateConnectorInput + .builder() + .name("valid name") + .description("") // Empty description + .version("1") + .protocol("http") + .parameters(Map.of("input", "test")) + .credential(Map.of("key", "value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1")) + .addAllBackendRoles(false) + .build(); + + MLCreateConnectorRequest emptyDescriptionRequest = MLCreateConnectorRequest + .builder() + .mlCreateConnectorInput(emptyDescriptionInput) + .build(); + ActionRequestValidationException emptyDescriptionException = emptyDescriptionRequest.validate(); + assertNull("Empty description should pass validation", emptyDescriptionException); + + // Test invalid characters in name and description + MLCreateConnectorInput invalidInput = MLCreateConnectorInput + .builder() + .name("invalid") + .description("invalid") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test")) + .credential(Map.of("key", "value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1")) + .addAllBackendRoles(false) + .build(); + + MLCreateConnectorRequest invalidRequest = MLCreateConnectorRequest.builder().mlCreateConnectorInput(invalidInput).build(); + ActionRequestValidationException invalidException = invalidRequest.validate(); + String exceptionMessage = invalidException.getMessage(); + assertTrue( + "Error message should contain name validation failure", + exceptionMessage + .contains("Model connector name can only contain letters, numbers, whitespace, and basic punctuation (.,!?():@-_'\");") + ); + assertTrue( + "Error message should contain description validation failure", + exceptionMessage + .contains( + "Model connector description can only contain letters, numbers, whitespace, and basic punctuation (.,!?():@-_'\");" + ) + ); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java index 9fab57d545..5c89ebbd78 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java @@ -10,6 +10,7 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; +import static org.opensearch.ml.common.utils.StringUtils.SAFE_INPUT_DESCRIPTION; import java.io.IOException; import java.io.UncheckedIOException; @@ -184,4 +185,34 @@ public void writeTo_withTenantId_Success() throws IOException { assertEquals(connectorId, parsedRequest.getConnectorId()); } + @Test + public void validate_Exception_UnsafeConnectorName() { + MLCreateConnectorInput unsafeInput = MLCreateConnectorInput + .builder() + .name("") // Unsafe name + .description("safe description") + .updateConnector(true) + .build(); + + MLUpdateConnectorRequest request = MLUpdateConnectorRequest.builder().connectorId("connectorId").updateContent(unsafeInput).build(); + + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model connector name " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } + + @Test + public void validate_Exception_UnsafeConnectorDescription() { + MLCreateConnectorInput unsafeInput = MLCreateConnectorInput + .builder() + .name("safeName") + .description("") // Unsafe description + .updateConnector(true) + .build(); + + MLUpdateConnectorRequest request = MLUpdateConnectorRequest.builder().connectorId("connectorId").updateContent(unsafeInput).build(); + + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model connector description " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java index 184ab097d2..5085ddbae8 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java @@ -9,6 +9,7 @@ import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.opensearch.ml.common.utils.StringUtils.SAFE_INPUT_DESCRIPTION; import java.io.IOException; import java.io.UncheckedIOException; @@ -113,4 +114,50 @@ public void writeTo(StreamOutput out) throws IOException { MLUpdateModelRequest.fromActionRequest(actionRequest); } + @Test + public void validate_Exception_InvalidName() { + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MLUpdateModelInput input = MLUpdateModelInput + .builder() + .modelId("test-model_id") + .name("") // unsafe input + .description("safe description") + .modelConfig(config) + .build(); + + MLUpdateModelRequest request = MLUpdateModelRequest.builder().updateModelInput(input).build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model Name " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } + + @Test + public void validate_Exception_InvalidDescription() { + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MLUpdateModelInput input = MLUpdateModelInput + .builder() + .modelId("test-model_id") + .name("safeName") + .description("") // unsafe input + .modelConfig(config) + .build(); + + MLUpdateModelRequest request = MLUpdateModelRequest.builder().updateModelInput(input).build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model Description " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java index f675f9f321..0b2b1b66da 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java @@ -1,9 +1,11 @@ package org.opensearch.ml.common.transport.model_group; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.opensearch.ml.common.utils.StringUtils.SAFE_INPUT_DESCRIPTION; import java.io.IOException; import java.io.UncheckedIOException; @@ -68,11 +70,10 @@ public void validateSuccess() { public void validateNullMLRegisterModelGroupInputException() { MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder().build(); ActionRequestValidationException exception = request.validate(); - assertEquals("Validation Failed: 1: Model meta input can't be null;", exception.getMessage()); + assertEquals("Validation Failed: 1: Model group input can't be null;", exception.getMessage()); } @Test - // MLRegisterModelGroupInput check its parameters when created, so exception is not thrown here public void validateNullMLModelNameException() { mlRegisterModelGroupInput.setName(null); MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest @@ -80,8 +81,9 @@ public void validateNullMLModelNameException() { .registerModelGroupInput(mlRegisterModelGroupInput) .build(); - assertNull(request.validate()); - assertNull(request.getRegisterModelGroupInput().getName()); + ActionRequestValidationException exception = request.validate(); + assertNotNull(exception); + assertEquals("Validation Failed: 1: Model group name is required and cannot be null or blank;", exception.getMessage()); } @Test @@ -122,4 +124,39 @@ public void writeTo(StreamOutput out) throws IOException { }; MLRegisterModelGroupRequest.fromActionRequest(actionRequest); } + + @Test + public void validate_Exception_UnsafeModelGroupName() { + MLRegisterModelGroupInput unsafeInput = MLRegisterModelGroupInput + .builder() + .name("") // unsafe input + .description("safe description") + .backendRoles(List.of("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); + + MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder().registerModelGroupInput(unsafeInput).build(); + + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model group name " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } + + @Test + public void validate_Exception_UnsafeModelGroupDescription() { + MLRegisterModelGroupInput unsafeInput = MLRegisterModelGroupInput + .builder() + .name("safeName") + .description("") // unsafe input + .backendRoles(List.of("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); + + MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder().registerModelGroupInput(unsafeInput).build(); + + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model group description " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java index d823b77b16..7e9d8fbe3c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java @@ -4,6 +4,7 @@ import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.opensearch.ml.common.utils.StringUtils.SAFE_INPUT_DESCRIPTION; import java.io.IOException; import java.io.UncheckedIOException; @@ -114,4 +115,38 @@ public void writeTo(StreamOutput out) throws IOException { }; MLUpdateModelGroupRequest.fromActionRequest(actionRequest); } + + @Test + public void validateWithUnsafeModelGroupName() { + MLUpdateModelGroupInput unsafeInput = MLUpdateModelGroupInput + .builder() + .modelGroupID("modelGroupId") + .name("") // unsafe input + .description("safe description") + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); + + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(unsafeInput).build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model group name " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } + + @Test + public void validateWithUnsafeModelGroupDescription() { + MLUpdateModelGroupInput unsafeInput = MLUpdateModelGroupInput + .builder() + .modelGroupID("modelGroupId") + .name("safeName") + .description("") // unsafe input + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); + + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(unsafeInput).build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model group description " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java index bcbee60593..85e6d92b4b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java @@ -1,6 +1,7 @@ package org.opensearch.ml.common.transport.register; import static org.junit.Assert.*; +import static org.opensearch.ml.common.utils.StringUtils.SAFE_INPUT_DESCRIPTION; import java.io.IOException; import java.io.UncheckedIOException; @@ -81,13 +82,13 @@ public void validate_Exception_NullMLRegisterModelInput() { } @Test - // MLRegisterModelInput check its parameters when created, so exception is not thrown here public void validate_Exception_NullMLModelName() { mlRegisterModelInput.setModelName(null); MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(mlRegisterModelInput).build(); - assertNull(request.validate()); - assertNull(request.getRegisterModelInput().getModelName()); + ActionRequestValidationException exception = request.validate(); + assertNotNull(exception); + assertEquals("Validation Failed: 1: Model name is required and cannot be null or blank;", exception.getMessage()); } @Test @@ -134,4 +135,60 @@ public void writeTo(StreamOutput out) throws IOException { }; MLRegisterModelRequest.fromActionRequest(actionRequest); } + + @Test + public void validate_Exception_UnsafeModelName() { + TextEmbeddingModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MLRegisterModelInput unsafeInput = MLRegisterModelInput + .builder() + .functionName(FunctionName.KMEANS) + .modelName("") // unsafe + .version("version") + .url("url") + .modelGroupId("modelGroupId") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .build(); + + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(unsafeInput).build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model name " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } + + @Test + public void validate_Exception_UnsafeDescription() { + TextEmbeddingModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MLRegisterModelInput unsafeInput = MLRegisterModelInput + .builder() + .functionName(FunctionName.KMEANS) + .modelName("SafeModel") + .description("") // unsafe + .version("version") + .url("url") + .modelGroupId("modelGroupId") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .build(); + + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(unsafeInput).build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model description " + SAFE_INPUT_DESCRIPTION + ";", exception.getMessage()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index c4eb3b5c32..43e3be42f1 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -7,6 +7,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; @@ -32,6 +33,7 @@ import org.junit.Assert; import org.junit.Test; import org.opensearch.OpenSearchParseException; +import org.opensearch.action.ActionRequestValidationException; import com.jayway.jsonpath.JsonPath; @@ -75,7 +77,7 @@ public void isJson_False() { public void toUTF8() { String rawString = "\uD83D\uDE00\uD83D\uDE0D\uD83D\uDE1C"; String utf8 = StringUtils.toUTF8(rawString); - Assert.assertNotNull(utf8); + assertNotNull(utf8); } @Test @@ -749,4 +751,111 @@ public void testValidateSchema() throws IOException { String json2 = "{\"key1\": \"foo\"}"; assertThrows(OpenSearchParseException.class, () -> StringUtils.validateSchema(schema, json2)); } + + @Test + public void testIsSafeText_ValidInputs() { + assertTrue(StringUtils.isSafeText("Model-Name_1.0")); + assertTrue(StringUtils.isSafeText("This is a description:")); + assertTrue(StringUtils.isSafeText("Name_with-dots.and:colons")); + } + + @Test + public void testValidateFields_AllValid() { + Map fields = Map + .of("Field1", new FieldDescriptor("Valid Name 1", true), "Field2", new FieldDescriptor("Another_Valid-Field.Name:Here", true)); + assertNull(StringUtils.validateFields(fields)); + } + + @Test + public void testValidateFields_OptionalFieldsValidWhenBlank() { + Map fields = Map + .of( + "OptionalField1", + new FieldDescriptor("", false), + "OptionalField2", + new FieldDescriptor(" ", false), + "OptionalField3", + new FieldDescriptor(null, false) + ); + assertNull(StringUtils.validateFields(fields)); + } + + @Test + public void testValidateFields_OptionalFieldInvalidPattern() { + Map fields = Map.of("OptionalField1", new FieldDescriptor("Bad@Value$", false)); + ActionRequestValidationException exception = StringUtils.validateFields(fields); + assertNotNull(exception); + assertTrue(exception.getMessage().contains("OptionalField1")); + } + + @Test + public void testIsSafeText_AdvancedValidInputs() { + // Testing all allowed characters + assertTrue(StringUtils.isSafeText("Hello World")); // spaces + assertTrue(StringUtils.isSafeText("Hello.World")); // period + assertTrue(StringUtils.isSafeText("Hello,World")); // comma + assertTrue(StringUtils.isSafeText("Hello!World")); // exclamation + assertTrue(StringUtils.isSafeText("Hello?World")); // question mark + assertTrue(StringUtils.isSafeText("Hello(World)")); // parentheses + assertTrue(StringUtils.isSafeText("Hello:World")); // colon + assertTrue(StringUtils.isSafeText("Hello@World")); // at sign + assertTrue(StringUtils.isSafeText("Hello-World")); // hyphen + assertTrue(StringUtils.isSafeText("Hello_World")); // underscore + assertTrue(StringUtils.isSafeText("Hello'World")); // single quote + assertTrue(StringUtils.isSafeText("Hello\"World")); // double quote + } + + @Test + public void testIsSafeText_AdvancedInvalidInputs() { + // Testing specifically excluded characters + assertFalse(StringUtils.isSafeText("HelloWorld")); // greater than + assertFalse(StringUtils.isSafeText("Hello/World")); // forward slash + assertFalse(StringUtils.isSafeText("Hello\\World")); // backslash + assertFalse(StringUtils.isSafeText("Hello&World")); // ampersand + assertFalse(StringUtils.isSafeText("Hello+World")); // plus + assertFalse(StringUtils.isSafeText("Hello=World")); // equals + assertFalse(StringUtils.isSafeText("Hello;World")); // semicolon + assertFalse(StringUtils.isSafeText("Hello|World")); // pipe + assertFalse(StringUtils.isSafeText("Hello*World")); // asterisk + } + + @Test + public void testValidateFields_RequiredFields_MissingOrInvalid() { + Map fields = new HashMap<>(); + fields.put("RequiredField1", new FieldDescriptor("", true)); + fields.put("RequiredField2", new FieldDescriptor(" ", true)); + fields.put("RequiredField3", new FieldDescriptor("Bad@#Char$", true)); + fields.put("RequiredField4", new FieldDescriptor(null, true)); + + ActionRequestValidationException exception = StringUtils.validateFields(fields); + assertNotNull(exception); + String message = exception.getMessage(); + assertTrue(message.contains("RequiredField1")); + assertTrue(message.contains("RequiredField2")); + assertTrue(message.contains("RequiredField3")); + assertTrue(message.contains("RequiredField4")); + } + + @Test + public void testValidateFields_EmptyMap() { + Map fields = new HashMap<>(); + assertNull(StringUtils.validateFields(fields)); + } + + @Test + public void testValidateFields_UnicodeLettersAndNumbers() { + Map fields = Map + .of("field1", new FieldDescriptor("Hello世界123", true), "field2", new FieldDescriptor("Café42", true)); + assertNull(StringUtils.validateFields(fields)); + } + + @Test + public void testValidateFields_InvalidCharacterSet() { + Map fields = Map.of("Field1", new FieldDescriptor("Bad#Value$With^Weird*Chars", true)); + ActionRequestValidationException exception = StringUtils.validateFields(fields); + assertNotNull(exception); + assertTrue(exception.getMessage().contains("Field1")); + } + }