Skip to content

Commit be5410b

Browse files
opensearch-trigger-bot[bot]github-actions[bot]joshpalis
authored
[Backport 2.13] Added new Guardrail field for remote model (#624)
Added new Guardrail field for remote model (#622) * Added new field guarddail for remote model * Fixed parsing * Deserialize * fixing guardrails * Added break --------- (cherry picked from commit 4a12730) Signed-off-by: Owais Kazi <[email protected]> Signed-off-by: Joshua Palis <[email protected]> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Joshua Palis <[email protected]>
1 parent 9782e5f commit be5410b

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

src/main/java/org/opensearch/flowframework/common/CommonValue.java

+2
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ private CommonValue() {}
168168
public static final String PIPELINE_ID = "pipeline_id";
169169
/** Pipeline Configurations */
170170
public static final String CONFIGURATIONS = "configurations";
171+
/** Guardrails field */
172+
public static final String GUARDRAILS_FIELD = "guardrails";
171173

172174
/*
173175
* Constants associated with resource provisioning / state

src/main/java/org/opensearch/flowframework/model/WorkflowNode.java

+10-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.opensearch.flowframework.workflow.ProcessNode;
2121
import org.opensearch.flowframework.workflow.WorkflowData;
2222
import org.opensearch.flowframework.workflow.WorkflowStep;
23+
import org.opensearch.ml.common.model.Guardrails;
2324

2425
import java.io.IOException;
2526
import java.util.ArrayList;
@@ -32,6 +33,7 @@
3233
import static java.util.concurrent.TimeUnit.SECONDS;
3334
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
3435
import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS;
36+
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
3537
import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD;
3638
import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap;
3739
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
@@ -95,6 +97,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
9597
xContentBuilder.field(e.getKey());
9698
if (e.getValue() instanceof String || e.getValue() instanceof Number || e.getValue() instanceof Boolean) {
9799
xContentBuilder.value(e.getValue());
100+
} else if (GUARDRAILS_FIELD.equals(e.getKey())) {
101+
Guardrails g = (Guardrails) e.getValue();
102+
xContentBuilder.value(g);
98103
} else if (e.getValue() instanceof Map<?, ?>) {
99104
buildStringToStringMap(xContentBuilder, (Map<?, ?>) e.getValue());
100105
} else if (e.getValue() instanceof Object[]) {
@@ -156,13 +161,16 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
156161
userInputs.put(inputFieldName, parser.text());
157162
break;
158163
case START_OBJECT:
159-
if (CONFIGURATIONS.equals(inputFieldName)) {
164+
if (GUARDRAILS_FIELD.equals(inputFieldName)) {
165+
userInputs.put(inputFieldName, Guardrails.parse(parser));
166+
break;
167+
} else if (CONFIGURATIONS.equals(inputFieldName)) {
160168
Map<String, Object> configurationsMap = parser.map();
161169
try {
162170
String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap);
163171
userInputs.put(inputFieldName, configurationsString);
164172
} catch (Exception ex) {
165-
String errorMessage = "Failed to parse configuration map";
173+
String errorMessage = "Failed to parse" + inputFieldName + "map";
166174
logger.error(errorMessage, ex);
167175
throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST);
168176
}

src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.opensearch.flowframework.util.ParseUtils;
2121
import org.opensearch.ml.client.MachineLearningNodeClient;
2222
import org.opensearch.ml.common.FunctionName;
23+
import org.opensearch.ml.common.model.Guardrails;
2324
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
2425
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
2526
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
@@ -29,6 +30,7 @@
2930

3031
import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
3132
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
33+
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
3234
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
3335
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
3436
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
@@ -71,7 +73,7 @@ public PlainActionFuture<WorkflowData> execute(
7173
PlainActionFuture<WorkflowData> registerRemoteModelFuture = PlainActionFuture.newFuture();
7274

7375
Set<String> requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID);
74-
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD);
76+
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD);
7577

7678
try {
7779
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
@@ -87,6 +89,7 @@ public PlainActionFuture<WorkflowData> execute(
8789
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
8890
String description = (String) inputs.get(DESCRIPTION_FIELD);
8991
String connectorId = (String) inputs.get(CONNECTOR_ID);
92+
Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD);
9093
final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD);
9194

9295
MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
@@ -103,6 +106,11 @@ public PlainActionFuture<WorkflowData> execute(
103106
if (deploy != null) {
104107
builder.deployModel(deploy);
105108
}
109+
110+
if (guardRails != null) {
111+
builder.guardrails(guardRails);
112+
}
113+
106114
MLRegisterModelInput mlInput = builder.build();
107115

108116
mlClient.register(mlInput, new ActionListener<MLRegisterModelResponse>() {

0 commit comments

Comments
 (0)