Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -349,5 +349,5 @@ private MLCommonsSettings() {}

// Feature flag for enabling telemetry static metric collection job -- MLStatsJobProcessor
public static final Setting<Boolean> ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED = Setting
.boolSetting("plugins.ml_commons.metrics_static_collection_enabled", false, Setting.Property.NodeScope, Setting.Property.Final);
.boolSetting("plugins.ml_commons.metrics_static_collection_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> isRagSearchPipelineEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED, it -> {
isStaticMetricCollectionEnabled = it;
for (SettingsChangeListener listener : listeners) {
listener.onStaticMetricCollectionEnabledChanged(it);
}
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,20 @@ public interface SettingsChangeListener {
* <li><code>false</code> if multi-tenancy is disabled</li>
* </ul>
*/
void onMultiTenancyEnabledChanged(boolean isEnabled);
default void onMultiTenancyEnabledChanged(boolean isEnabled) {
// do nothing
}

/**
* Callback method that gets triggered when the static metric collection setting changes.
*
* @param isEnabled A boolean value indicating the new state of the static metric collection setting:
* <ul>
* <li><code>true</code> if static metric collection is enabled</li>
* <li><code>false</code> if static metric collection is disabled</li>
* </ul>
*/
default void onStaticMetricCollectionEnabledChanged(boolean isEnabled) {
// do nothing
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,19 @@ public void testMultiTenancyChangeNotifiesListeners() {
setting.notifyMultiTenancyListeners(true);
verify(mockListener).onMultiTenancyEnabledChanged(true);
}

@Test
public void testStaticMetricCollectionSettingChangeNotifiesListeners() {
Settings settings = Settings.builder().put("plugins.ml_commons.metrics_static_collection_enabled", false).build();

MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);

SettingsChangeListener mockListener = mock(SettingsChangeListener.class);
setting.addListener(mockListener);

mockClusterSettings.applySettings(Settings.builder().put("plugins.ml_commons.metrics_static_collection_enabled", true).build());

verify(mockListener).onStaticMetricCollectionEnabledChanged(true);
assertTrue(setting.isStaticMetricCollectionEnabled());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ public void clusterChanged(ClusterChangedEvent event) {
* The following logic implements this behavior.
*/
for (DiscoveryNode node : state.nodes()) {
if (node.isDataNode() && Version.V_3_1_0.onOrAfter(node.getVersion())) {
if (node.isDataNode() && node.getVersion().onOrAfter(Version.V_3_1_0)) {
if (mlFeatureEnabledSetting.isMetricCollectionEnabled() && mlFeatureEnabledSetting.isStaticMetricCollectionEnabled()) {
mlTaskManager.startStatsCollectorJob();
mlTaskManager.indexStatsCollectorJob(true);
}

if (clusterService.state().getMetadata().hasIndex(TASK_POLLING_JOB_INDEX)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ public class MLJobParameter implements ScheduledJobParameter {

public MLJobParameter() {}

public MLJobParameter(String name, Schedule schedule, Long lockDurationSeconds, Double jitter, MLJobType jobType) {
public MLJobParameter(String name, Schedule schedule, Long lockDurationSeconds, Double jitter, MLJobType jobType, boolean isEnabled) {
this.jobName = name;
this.schedule = schedule;
this.lockDurationSeconds = lockDurationSeconds;
this.jitter = jitter;

Instant now = Instant.now();
this.isEnabled = true;
this.isEnabled = isEnabled;
this.enabledTime = now;
this.lastUpdateTime = now;
this.jobType = jobType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionCont
throw new IllegalArgumentException("Job parameters is invalid.");
}

if (!jobParameter.isEnabled()) {
throw new IllegalStateException(String.format("Attempted to run disabled job of type: %s", jobParameter.getJobType().name()));
}

switch (jobParameter.getJobType()) {
case STATS_COLLECTOR:
MLStatsJobProcessor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ public Collection<Object> createComponents(
modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings);
connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings);
mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings);
mlFeatureEnabledSetting.addListener(mlTaskManager);
mlModelManager = new MLModelManager(
clusterService,
scriptService,
Expand Down
25 changes: 15 additions & 10 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.settings.SettingsChangeListener;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.jobs.MLJobParameter;
import org.opensearch.ml.jobs.MLJobType;
Expand All @@ -73,7 +74,7 @@
* MLTaskManager is responsible for managing MLTask.
*/
@Log4j2
public class MLTaskManager {
public class MLTaskManager implements SettingsChangeListener {
public static int TASK_SEMAPHORE_TIMEOUT = 5000; // 5 seconds
private final Map<String, MLTaskCache> taskCaches;
private final Client client;
Expand Down Expand Up @@ -553,7 +554,8 @@ public void startTaskPollingJob() {
new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES),
20L,
null,
MLJobType.BATCH_TASK_UPDATE
MLJobType.BATCH_TASK_UPDATE,
true
);

IndexRequest indexRequest = new IndexRequest()
Expand All @@ -562,24 +564,27 @@ public void startTaskPollingJob() {
.source(jobParameter.toXContent(JsonXContent.contentBuilder(), null))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

startJob(indexRequest, MLJobType.BATCH_TASK_UPDATE, () -> this.taskPollingJobStarted = true);
indexJob(indexRequest, MLJobType.BATCH_TASK_UPDATE, () -> this.taskPollingJobStarted = true);
} catch (IOException e) {
log.error("Failed to index task polling job", e);
}
}

public void startStatsCollectorJob() {
if (statsCollectorJobStarted) {
return;
}
@Override
public void onStaticMetricCollectionEnabledChanged(boolean isEnabled) {
log.info("Static metric collection setting changed to: {}", isEnabled);
indexStatsCollectorJob(isEnabled);
}

public void indexStatsCollectorJob(boolean enabled) {
try {
MLJobParameter jobParameter = new MLJobParameter(
MLJobType.STATS_COLLECTOR.name(),
new IntervalSchedule(Instant.now(), 5, ChronoUnit.MINUTES),
60L,
null,
MLJobType.STATS_COLLECTOR
MLJobType.STATS_COLLECTOR,
enabled
);

IndexRequest indexRequest = new IndexRequest()
Expand All @@ -588,7 +593,7 @@ public void startStatsCollectorJob() {
.source(jobParameter.toXContent(JsonXContent.contentBuilder(), null))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

startJob(indexRequest, MLJobType.STATS_COLLECTOR, () -> this.statsCollectorJobStarted = true);
indexJob(indexRequest, MLJobType.STATS_COLLECTOR, () -> {});
} catch (IOException e) {
log.error("Failed to index stats collection job", e);
}
Expand All @@ -601,7 +606,7 @@ public void startStatsCollectorJob() {
* @param jobType the type of job being started
* @param successCallback callback to execute on successful job indexing
*/
private void startJob(IndexRequest indexRequest, MLJobType jobType, Runnable successCallback) {
private void indexJob(IndexRequest indexRequest, MLJobType jobType, Runnable successCallback) {
mlIndicesHandler.initMLJobsIndex(ActionListener.wrap(success -> {
if (success) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.cluster;

import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.TASK_POLLING_JOB_INDEX;

import java.util.Collections;

import org.junit.Before;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.Version;
import org.opensearch.cluster.ClusterChangedEvent;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodeRole;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.client.Client;

public class MLCommonsClusterEventListenerTests extends OpenSearchTestCase {

@Mock
private ClusterService clusterService;
@Mock
private MLModelManager mlModelManager;
@Mock
private MLTaskManager mlTaskManager;
@Mock
private MLModelCacheHelper modelCacheHelper;
@Mock
private MLModelAutoReDeployer mlModelAutoReDeployer;
@Mock
private Client client;
@Mock
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
@Mock
private ClusterChangedEvent event;
@Mock
private ClusterState clusterState;
@Mock
private Metadata metadata;

private MLCommonsClusterEventListener listener;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
listener = new MLCommonsClusterEventListener(
clusterService,
mlModelManager,
mlTaskManager,
modelCacheHelper,
mlModelAutoReDeployer,
client,
mlFeatureEnabledSetting
);
}

public void testClusterChanged_WithV31DataNode_MetricCollectionEnabled() {
DiscoveryNode dataNode = createDataNode(Version.V_3_1_0);
setupClusterState(dataNode, false);

when(mlFeatureEnabledSetting.isMetricCollectionEnabled()).thenReturn(true);
when(mlFeatureEnabledSetting.isStaticMetricCollectionEnabled()).thenReturn(true);

listener.clusterChanged(event);

verify(mlTaskManager).indexStatsCollectorJob(true);
verify(mlTaskManager, never()).startTaskPollingJob();
}

public void testClusterChanged_WithV31DataNode_TaskPollingIndexExists() {
DiscoveryNode dataNode = createDataNode(Version.V_3_1_0);
setupClusterState(dataNode, true);

when(mlFeatureEnabledSetting.isMetricCollectionEnabled()).thenReturn(false);

listener.clusterChanged(event);

verify(mlTaskManager, never()).indexStatsCollectorJob(anyBoolean());
verify(mlTaskManager).startTaskPollingJob();
}

public void testClusterChanged_WithPreV31DataNode_NoJobsStarted() {
DiscoveryNode dataNode = createDataNode(Version.V_3_0_0);
setupClusterState(dataNode, true);

when(mlFeatureEnabledSetting.isMetricCollectionEnabled()).thenReturn(true);
when(mlFeatureEnabledSetting.isStaticMetricCollectionEnabled()).thenReturn(true);

listener.clusterChanged(event);

verify(mlTaskManager, never()).indexStatsCollectorJob(anyBoolean());
verify(mlTaskManager, never()).startTaskPollingJob();
}

private DiscoveryNode createDataNode(Version version) {
return new DiscoveryNode(
"dataNode",
"dataNodeId",
buildNewFakeTransportAddress(),
Collections.emptyMap(),
Collections.singleton(DiscoveryNodeRole.DATA_ROLE),
version
);
}

private void setupClusterState(DiscoveryNode node, boolean hasTaskPollingIndex) {
DiscoveryNodes nodes = DiscoveryNodes.builder().add(node).build();

when(event.state()).thenReturn(clusterState);
when(event.previousState()).thenReturn(clusterState);
when(event.nodesDelta()).thenReturn(mock(DiscoveryNodes.Delta.class));
when(clusterState.nodes()).thenReturn(nodes);
when(clusterState.getMetadata()).thenReturn(metadata);
when(clusterService.state()).thenReturn(clusterState);
when(metadata.hasIndex(TASK_POLLING_JOB_INDEX)).thenReturn(hasTaskPollingIndex);
when(metadata.settings()).thenReturn(org.opensearch.common.settings.Settings.EMPTY);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void setUp() {
lockDurationSeconds = 20L;
jitter = 0.5;
jobType = null;
jobParameter = new MLJobParameter(jobName, schedule, lockDurationSeconds, jitter, jobType);
jobParameter = new MLJobParameter(jobName, schedule, lockDurationSeconds, jitter, jobType, true);
}

@Test
Expand All @@ -54,7 +54,7 @@ public void testToXContent() throws Exception {
@Test
public void testNullCase() throws IOException {
String newJobName = "test-job";
MLJobParameter nullParameter = new MLJobParameter(newJobName, null, null, null, null);
MLJobParameter nullParameter = new MLJobParameter(newJobName, null, null, null, null, true);
nullParameter.setLastUpdateTime(null);
nullParameter.setEnabledTime(null);

Expand All @@ -64,6 +64,7 @@ public void testNullCase() throws IOException {

assertTrue(jsonString.contains(newJobName));
assertEquals(newJobName, nullParameter.getName());
assertTrue(nullParameter.isEnabled());
assertNull(nullParameter.getSchedule());
assertNull(nullParameter.getLockDurationSeconds());
assertNull(nullParameter.getJitter());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,11 @@ public void testRunJobWithNullJobType() {
when(jobParameter.getJobType()).thenReturn(null);
jobRunner.runJob(jobParameter, jobExecutionContext);
}

@Test(expected = IllegalStateException.class)
public void testRunJobWithDisabledJob() {
when(jobParameter.isEnabled()).thenReturn(false);
when(jobParameter.getJobType()).thenReturn(MLJobType.STATS_COLLECTOR);
jobRunner.runJob(jobParameter, jobExecutionContext);
}
}
Loading