Skip to content

Commit cb2f23e

Browse files
[Backport 2.x] Fixed static fields initialization in WorkflowStepFactory (#533)
Fixed static fields initialization in WorkflowStepFactory (#532) Fixed static fields initialization (cherry picked from commit 24bf51a) Signed-off-by: Owais Kazi <[email protected]> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 4446fb4 commit cb2f23e

File tree

4 files changed

+44
-139
lines changed

4 files changed

+44
-139
lines changed

src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java

-2
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,6 @@ public Collection<Object> createComponents(
121121
);
122122
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
123123
threadPool,
124-
clusterService,
125-
client,
126124
mlClient,
127125
flowFrameworkIndicesHandler,
128126
flowFrameworkSettings

src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java

+42-99
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
import org.apache.logging.log4j.LogManager;
1212
import org.apache.logging.log4j.Logger;
13-
import org.opensearch.client.Client;
14-
import org.opensearch.cluster.service.ClusterService;
1513
import org.opensearch.common.unit.TimeValue;
1614
import org.opensearch.core.common.Strings;
1715
import org.opensearch.core.rest.RestStatus;
@@ -61,37 +59,47 @@ public class WorkflowStepFactory {
6159

6260
private final Map<String, Supplier<WorkflowStep>> stepMap = new HashMap<>();
6361
private static final Logger logger = LogManager.getLogger(WorkflowStepFactory.class);
64-
private static ThreadPool threadPool;
65-
private static MachineLearningNodeClient mlClient;
66-
private static FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
67-
private static FlowFrameworkSettings flowFrameworkSettings;
6862

6963
/**
7064
* Instantiate this class.
7165
*
7266
* @param threadPool The OpenSearch thread pool
73-
* @param clusterService The OpenSearch cluster service
74-
* @param client The OpenSearch client steps can use
7567
* @param mlClient Machine Learning client to perform ml operations
7668
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
7769
* @param flowFrameworkSettings common settings of the plugin
7870
*/
7971
public WorkflowStepFactory(
8072
ThreadPool threadPool,
81-
ClusterService clusterService,
82-
Client client,
8373
MachineLearningNodeClient mlClient,
8474
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
8575
FlowFrameworkSettings flowFrameworkSettings
8676
) {
87-
this.threadPool = threadPool;
88-
this.mlClient = mlClient;
89-
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
90-
this.flowFrameworkSettings = flowFrameworkSettings;
91-
// Initialize the WorkflowSteps enum inside the constructor
92-
for (WorkflowSteps workflowStep : WorkflowSteps.values()) {
93-
stepMap.put(workflowStep.getWorkflowStepName(), workflowStep.step());
94-
}
77+
stepMap.put(NoOpStep.NAME, NoOpStep::new);
78+
stepMap.put(
79+
RegisterLocalCustomModelStep.NAME,
80+
() -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
81+
);
82+
stepMap.put(
83+
RegisterLocalSparseEncodingModelStep.NAME,
84+
() -> new RegisterLocalSparseEncodingModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
85+
);
86+
stepMap.put(
87+
RegisterLocalPretrainedModelStep.NAME,
88+
() -> new RegisterLocalPretrainedModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
89+
);
90+
stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler));
91+
stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient));
92+
stepMap.put(
93+
DeployModelStep.NAME,
94+
() -> new DeployModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
95+
);
96+
stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient));
97+
stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
98+
stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient));
99+
stepMap.put(RegisterModelGroupStep.NAME, () -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler));
100+
stepMap.put(ToolStep.NAME, ToolStep::new);
101+
stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler));
102+
stepMap.put(DeleteAgentStep.NAME, () -> new DeleteAgentStep(mlClient));
95103
}
96104

97105
/**
@@ -101,16 +109,15 @@ public WorkflowStepFactory(
101109
public enum WorkflowSteps {
102110

103111
/** Noop Step */
104-
NOOP("noop", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), null, NoOpStep::new),
112+
NOOP("noop", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), null),
105113

106114
/** Create Connector Step */
107115
CREATE_CONNECTOR(
108116
CreateConnectorStep.NAME,
109117
List.of(NAME_FIELD, DESCRIPTION_FIELD, VERSION_FIELD, PROTOCOL_FIELD, PARAMETERS_FIELD, CREDENTIAL_FIELD, ACTIONS_FIELD),
110118
List.of(CONNECTOR_ID),
111119
List.of(OPENSEARCH_ML),
112-
TimeValue.timeValueSeconds(60),
113-
() -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)
120+
TimeValue.timeValueSeconds(60)
114121
),
115122

116123
/** Register Local Custom Model Step */
@@ -129,8 +136,7 @@ public enum WorkflowSteps {
129136
),
130137
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
131138
List.of(OPENSEARCH_ML),
132-
TimeValue.timeValueSeconds(60),
133-
() -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
139+
TimeValue.timeValueSeconds(60)
134140
),
135141

136142
/** Register Local Sparse Encoding Model Step */
@@ -139,8 +145,7 @@ public enum WorkflowSteps {
139145
List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT, FUNCTION_NAME, MODEL_CONTENT_HASH_VALUE, URL),
140146
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
141147
List.of(OPENSEARCH_ML),
142-
TimeValue.timeValueSeconds(60),
143-
() -> new RegisterLocalSparseEncodingModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
148+
TimeValue.timeValueSeconds(60)
144149
),
145150

146151
/** Register Local Pretrained Model Step */
@@ -149,8 +154,7 @@ public enum WorkflowSteps {
149154
List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT),
150155
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
151156
List.of(OPENSEARCH_ML),
152-
TimeValue.timeValueSeconds(60),
153-
() -> new RegisterLocalPretrainedModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
157+
TimeValue.timeValueSeconds(60)
154158
),
155159

156160
/** Register Remote Model Step */
@@ -159,8 +163,7 @@ public enum WorkflowSteps {
159163
List.of(NAME_FIELD, CONNECTOR_ID),
160164
List.of(MODEL_ID, REGISTER_MODEL_STATUS),
161165
List.of(OPENSEARCH_ML),
162-
null,
163-
() -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)
166+
null
164167
),
165168

166169
/** Register Model Group Step */
@@ -169,94 +172,42 @@ public enum WorkflowSteps {
169172
List.of(NAME_FIELD),
170173
List.of(MODEL_GROUP_ID, MODEL_GROUP_STATUS),
171174
List.of(OPENSEARCH_ML),
172-
null,
173-
() -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler)
175+
null
174176
),
175177

176178
/** Deploy Model Step */
177-
DEPLOY_MODEL(
178-
DeployModelStep.NAME,
179-
List.of(MODEL_ID),
180-
List.of(MODEL_ID),
181-
List.of(OPENSEARCH_ML),
182-
TimeValue.timeValueSeconds(15),
183-
() -> new DeployModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings)
184-
),
179+
DEPLOY_MODEL(DeployModelStep.NAME, List.of(MODEL_ID), List.of(MODEL_ID), List.of(OPENSEARCH_ML), TimeValue.timeValueSeconds(15)),
185180

186181
/** Undeploy Model Step */
187-
UNDEPLOY_MODEL(
188-
UndeployModelStep.NAME,
189-
List.of(MODEL_ID),
190-
List.of(SUCCESS),
191-
List.of(OPENSEARCH_ML),
192-
null,
193-
() -> new UndeployModelStep(mlClient)
194-
),
182+
UNDEPLOY_MODEL(UndeployModelStep.NAME, List.of(MODEL_ID), List.of(SUCCESS), List.of(OPENSEARCH_ML), null),
195183

196184
/** Delete Model Step */
197-
DELETE_MODEL(
198-
DeleteModelStep.NAME,
199-
List.of(MODEL_ID),
200-
List.of(MODEL_ID),
201-
List.of(OPENSEARCH_ML),
202-
null,
203-
() -> new DeleteModelStep(mlClient)
204-
),
185+
DELETE_MODEL(DeleteModelStep.NAME, List.of(MODEL_ID), List.of(MODEL_ID), List.of(OPENSEARCH_ML), null),
205186

206187
/** Delete Connector Step */
207-
DELETE_CONNECTOR(
208-
DeleteConnectorStep.NAME,
209-
List.of(CONNECTOR_ID),
210-
List.of(CONNECTOR_ID),
211-
List.of(OPENSEARCH_ML),
212-
null,
213-
() -> new DeleteConnectorStep(mlClient)
214-
),
188+
DELETE_CONNECTOR(DeleteConnectorStep.NAME, List.of(CONNECTOR_ID), List.of(CONNECTOR_ID), List.of(OPENSEARCH_ML), null),
215189

216190
/** Register Agent Step */
217-
REGISTER_AGENT(
218-
RegisterAgentStep.NAME,
219-
List.of(NAME_FIELD, TYPE),
220-
List.of(AGENT_ID),
221-
List.of(OPENSEARCH_ML),
222-
null,
223-
() -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)
224-
),
191+
REGISTER_AGENT(RegisterAgentStep.NAME, List.of(NAME_FIELD, TYPE), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null),
225192

226193
/** Delete Agent Step */
227-
DELETE_AGENT(
228-
DeleteAgentStep.NAME,
229-
List.of(AGENT_ID),
230-
List.of(AGENT_ID),
231-
List.of(OPENSEARCH_ML),
232-
null,
233-
() -> new DeleteAgentStep(mlClient)
234-
),
194+
DELETE_AGENT(DeleteAgentStep.NAME, List.of(AGENT_ID), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null),
235195

236196
/** Create Tool Step */
237-
CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null, ToolStep::new);
197+
CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null);
238198

239199
private final String workflowStepName;
240200
private final List<String> inputs;
241201
private final List<String> outputs;
242202
private final List<String> requiredPlugins;
243203
private final TimeValue timeout;
244-
private final Supplier<WorkflowStep> workflowStep;
245-
246-
WorkflowSteps(
247-
String workflowStepName,
248-
List<String> inputs,
249-
List<String> outputs,
250-
List<String> requiredPlugins,
251-
TimeValue timeout,
252-
Supplier<WorkflowStep> workflowStep
253-
) {
204+
205+
WorkflowSteps(String workflowStepName, List<String> inputs, List<String> outputs, List<String> requiredPlugins, TimeValue timeout) {
254206
this.workflowStepName = workflowStepName;
255207
this.inputs = List.copyOf(inputs);
256208
this.outputs = List.copyOf(outputs);
257209
this.requiredPlugins = requiredPlugins;
258210
this.timeout = timeout;
259-
this.workflowStep = workflowStep;
260211
}
261212

262213
/**
@@ -299,14 +250,6 @@ public TimeValue timeout() {
299250
return timeout;
300251
}
301252

302-
/**
303-
* Get the step
304-
* @return the step
305-
*/
306-
public Supplier<WorkflowStep> step() {
307-
return workflowStep;
308-
}
309-
310253
/**
311254
* Get the workflow step validator object
312255
* @return the WorkflowStepValidator

src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java

+1-30
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,6 @@
88
*/
99
package org.opensearch.flowframework.model;
1010

11-
import org.opensearch.client.AdminClient;
12-
import org.opensearch.client.Client;
13-
import org.opensearch.client.ClusterAdminClient;
14-
import org.opensearch.cluster.service.ClusterService;
15-
import org.opensearch.common.settings.ClusterSettings;
16-
import org.opensearch.common.settings.Setting;
17-
import org.opensearch.common.settings.Settings;
1811
import org.opensearch.flowframework.common.FlowFrameworkSettings;
1912
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
2013
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
@@ -27,14 +20,7 @@
2720
import java.util.HashMap;
2821
import java.util.List;
2922
import java.util.Map;
30-
import java.util.Set;
31-
import java.util.stream.Collectors;
32-
import java.util.stream.Stream;
33-
34-
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
35-
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS;
36-
import static org.opensearch.flowframework.common.FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION;
37-
import static org.opensearch.flowframework.common.FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT;
23+
3824
import static org.mockito.Mockito.mock;
3925
import static org.mockito.Mockito.when;
4026

@@ -174,26 +160,11 @@ public void testParseWorkflowValidator() throws IOException {
174160
public void testWorkflowStepFactoryHasValidators() throws IOException {
175161

176162
ThreadPool threadPool = mock(ThreadPool.class);
177-
ClusterService clusterService = mock(ClusterService.class);
178-
ClusterAdminClient clusterAdminClient = mock(ClusterAdminClient.class);
179-
AdminClient adminClient = mock(AdminClient.class);
180-
Client client = mock(Client.class);
181-
when(client.admin()).thenReturn(adminClient);
182-
when(adminClient.cluster()).thenReturn(clusterAdminClient);
183163
MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class);
184164
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
185165

186-
final Set<Setting<?>> settingsSet = Stream.concat(
187-
ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(),
188-
Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION)
189-
).collect(Collectors.toSet());
190-
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet);
191-
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
192-
193166
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
194167
threadPool,
195-
clusterService,
196-
client,
197168
mlClient,
198169
flowFrameworkIndicesHandler,
199170
flowFrameworkSettings

src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java

+1-8
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,7 @@ public static void setup() throws IOException {
118118
FLOW_FRAMEWORK_THREAD_POOL_PREFIX + DEPROVISION_WORKFLOW_THREAD_POOL
119119
)
120120
);
121-
WorkflowStepFactory factory = new WorkflowStepFactory(
122-
testThreadPool,
123-
clusterService,
124-
client,
125-
mlClient,
126-
flowFrameworkIndicesHandler,
127-
flowFrameworkSettings
128-
);
121+
WorkflowStepFactory factory = new WorkflowStepFactory(testThreadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings);
129122
workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, client, flowFrameworkSettings);
130123
}
131124

0 commit comments

Comments
 (0)