diff --git a/CHANGELOG.md b/CHANGELOG.md index 2600815d1..4ecd70dd4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) - Adds capability to automatically switch to old access-control if model-group is excluded from protected resources setting ([#1569](https://github.com/opensearch-project/anomaly-detection/pull/1569)) - Adding suggest and validate transport actions to node client ([#1605](https://github.com/opensearch-project/anomaly-detection/pull/1605)) - Adding auto create as an optional field on detectors ([#1602](https://github.com/opensearch-project/anomaly-detection/pull/1602)) +- Adding create and start to AD node client ([#1611](https://github.com/opensearch-project/anomaly-detection/pull/1611)) ### Bug Fixes diff --git a/build.gradle b/build.gradle index 954441b2f..6317b6860 100644 --- a/build.gradle +++ b/build.gradle @@ -192,6 +192,9 @@ dependencies { testImplementation 'org.reflections:reflections:0.10.2' testImplementation "org.opensearch.test:framework:${opensearch_version}" + + zipArchive("org.opensearch.plugin:opensearch-ml-plugin:${opensearch_build}") + } apply plugin: 'java' diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java index 63fd73bad..3bce7b3a6 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionClient.java @@ -9,9 +9,13 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ad.transport.GetAnomalyDetectorResponse; +import org.opensearch.ad.transport.IndexAnomalyDetectorRequest; +import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.transport.JobRequest; +import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.timeseries.transport.SuggestConfigParamRequest; import org.opensearch.timeseries.transport.SuggestConfigParamResponse; import org.opensearch.timeseries.transport.ValidateConfigRequest; @@ -110,4 +114,40 @@ default ActionFuture suggestAnomalyDetector(SuggestC * @param listener a listener to be notified of the result */ void suggestAnomalyDetector(SuggestConfigParamRequest suggestRequest, ActionListener listener); + + /** + * Create anomaly detector - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#create-detector + * @param createRequest request to create the detector + * @return ActionFuture of IndexAnomalyDetectorResponse + */ + default ActionFuture createAnomalyDetector(IndexAnomalyDetectorRequest createRequest) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + createAnomalyDetector(createRequest, actionFuture); + return actionFuture; + } + + /** + * Create anomaly detector - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#create-detector + * @param createRequest request to create the detector + * @param listener a listener to be notified of the result + */ + void createAnomalyDetector(IndexAnomalyDetectorRequest createRequest, ActionListener listener); + + /** + * Start anomaly detector - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#start-detector + * @param startRequest request to start the detector + * @return ActionFuture of JobResponse + */ + default ActionFuture startAnomalyDetector(JobRequest startRequest) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + startAnomalyDetector(startRequest, actionFuture); + return actionFuture; + } + + /** + * Start anomaly detector - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#start-detector + * @param startRequest request to start the detector + * @param listener a listener to be notified of the result + */ + void startAnomalyDetector(JobRequest startRequest, ActionListener listener); } diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java index ebbd34f96..f28cd2f17 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java @@ -9,8 +9,12 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.transport.AnomalyDetectorJobAction; import org.opensearch.ad.transport.GetAnomalyDetectorAction; import org.opensearch.ad.transport.GetAnomalyDetectorResponse; +import org.opensearch.ad.transport.IndexAnomalyDetectorAction; +import org.opensearch.ad.transport.IndexAnomalyDetectorRequest; +import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; import org.opensearch.ad.transport.SearchAnomalyDetectorAction; import org.opensearch.ad.transport.SearchAnomalyResultAction; import org.opensearch.ad.transport.SuggestAnomalyDetectorParamAction; @@ -19,6 +23,8 @@ import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.transport.JobRequest; +import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.timeseries.transport.SuggestConfigParamRequest; import org.opensearch.timeseries.transport.SuggestConfigParamResponse; import org.opensearch.timeseries.transport.ValidateConfigRequest; @@ -63,6 +69,16 @@ public void suggestAnomalyDetector(SuggestConfigParamRequest suggestRequest, Act this.client.execute(SuggestAnomalyDetectorParamAction.INSTANCE, suggestRequest, suggestConfigResponseActionListener(listener)); } + @Override + public void createAnomalyDetector(IndexAnomalyDetectorRequest createRequest, ActionListener listener) { + this.client.execute(IndexAnomalyDetectorAction.INSTANCE, createRequest, indexAnomalyDetectorResponseActionListener(listener)); + } + + @Override + public void startAnomalyDetector(JobRequest startRequest, ActionListener listener) { + this.client.execute(AnomalyDetectorJobAction.INSTANCE, startRequest, jobResponseActionListener(listener)); + } + // We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic // ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins. // Additionally, we need to inject the configured NamedWriteableRegistry so NamedWriteables (present in sub-fields of @@ -107,6 +123,30 @@ private ActionListener suggestConfigResponseActionLi return actionListener; } + private ActionListener indexAnomalyDetectorResponseActionListener( + ActionListener listener + ) { + ActionListener internalListener = ActionListener.wrap(indexAnomalyDetectorResponse -> { + listener.onResponse(indexAnomalyDetectorResponse); + }, listener::onFailure); + ActionListener actionListener = wrapActionListener(internalListener, actionResponse -> { + IndexAnomalyDetectorResponse response = IndexAnomalyDetectorResponse + .fromActionResponse(actionResponse, this.namedWriteableRegistry); + return response; + }); + return actionListener; + } + + private ActionListener jobResponseActionListener(ActionListener listener) { + ActionListener internalListener = ActionListener + .wrap(jobResponse -> { listener.onResponse(jobResponse); }, listener::onFailure); + ActionListener actionListener = wrapActionListener(internalListener, actionResponse -> { + JobResponse response = JobResponse.fromActionResponse(actionResponse, this.namedWriteableRegistry); + return response; + }); + return actionListener; + } + private ActionListener wrapActionListener( final ActionListener listener, final Function recreate diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java index 4b847c5ca..d429566e6 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java @@ -32,6 +32,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.timeseries.transport.BaseJobTransportAction; import org.opensearch.transport.TransportService; @@ -47,7 +48,8 @@ public AnomalyDetectorJobTransportAction( ClusterService clusterService, Settings settings, NamedXContentRegistry xContentRegistry, - ADIndexJobActionHandler adIndexJobActionHandler + ADIndexJobActionHandler adIndexJobActionHandler, + NamedWriteableRegistry namedWriteableRegistry ) { super( transportService, @@ -64,7 +66,8 @@ public AnomalyDetectorJobTransportAction( AnomalyDetector.class, adIndexJobActionHandler, Clock.systemUTC(), // inject cannot find clock due to OS limitation - AnomalyDetector.class + AnomalyDetector.class, + namedWriteableRegistry ); } } diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java index a493031de..4b272df4f 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java @@ -23,6 +23,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.index.seqno.SequenceNumbers; import org.opensearch.rest.RestRequest; public class IndexAnomalyDetectorRequest extends ActionRequest implements DocRequest { @@ -50,10 +51,10 @@ public IndexAnomalyDetectorRequest(StreamInput in) throws IOException { detector = new AnomalyDetector(in); method = in.readEnum(RestRequest.Method.class); requestTimeout = in.readTimeValue(); - maxSingleEntityAnomalyDetectors = in.readInt(); - maxMultiEntityAnomalyDetectors = in.readInt(); - maxAnomalyFeatures = in.readInt(); - maxCategoricalFields = in.readInt(); + maxSingleEntityAnomalyDetectors = in.readOptionalInt(); + maxMultiEntityAnomalyDetectors = in.readOptionalInt(); + maxAnomalyFeatures = in.readOptionalInt(); + maxCategoricalFields = in.readOptionalInt(); } public IndexAnomalyDetectorRequest( @@ -83,6 +84,22 @@ public IndexAnomalyDetectorRequest( this.maxCategoricalFields = maxCategoricalFields; } + public IndexAnomalyDetectorRequest(String detectorID, AnomalyDetector detector, RestRequest.Method method) { + this( + detectorID, + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM, + WriteRequest.RefreshPolicy.IMMEDIATE, + detector, + method, + TimeValue.timeValueSeconds(60), + null, + null, + null, + null + ); + } + public String getDetectorID() { return detectorID; } @@ -137,10 +154,10 @@ public void writeTo(StreamOutput out) throws IOException { detector.writeTo(out); out.writeEnum(method); out.writeTimeValue(requestTimeout); - out.writeInt(maxSingleEntityAnomalyDetectors); - out.writeInt(maxMultiEntityAnomalyDetectors); - out.writeInt(maxAnomalyFeatures); - out.writeInt(maxCategoricalFields); + out.writeOptionalInt(maxSingleEntityAnomalyDetectors); + out.writeOptionalInt(maxMultiEntityAnomalyDetectors); + out.writeOptionalInt(maxAnomalyFeatures); + out.writeOptionalInt(maxCategoricalFields); } @Override @@ -162,4 +179,32 @@ public String index() { public String id() { return detectorID; } + + public static IndexAnomalyDetectorRequest fromActionRequest( + final ActionRequest actionRequest, + org.opensearch.core.common.io.stream.NamedWriteableRegistry namedWriteableRegistry + ) { + if (actionRequest instanceof IndexAnomalyDetectorRequest) { + return (IndexAnomalyDetectorRequest) actionRequest; + } + + try ( + java.io.ByteArrayOutputStream baos = new java.io.ByteArrayOutputStream(); + org.opensearch.core.common.io.stream.OutputStreamStreamOutput osso = + new org.opensearch.core.common.io.stream.OutputStreamStreamOutput(baos) + ) { + actionRequest.writeTo(osso); + try ( + org.opensearch.core.common.io.stream.StreamInput input = new org.opensearch.core.common.io.stream.InputStreamStreamInput( + new java.io.ByteArrayInputStream(baos.toByteArray()) + ); + org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput namedInput = + new org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput(input, namedWriteableRegistry) + ) { + return new IndexAnomalyDetectorRequest(namedInput); + } + } catch (java.io.IOException e) { + throw new IllegalArgumentException("failed to parse ActionRequest into IndexAnomalyDetectorRequest", e); + } + } } diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorResponse.java index 661f16285..a21b6189f 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorResponse.java @@ -11,10 +11,17 @@ package org.opensearch.ad.transport; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.rest.RestStatus; @@ -81,4 +88,25 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(RestHandlerUtils._PRIMARY_TERM, primaryTerm) .endObject(); } + + public static IndexAnomalyDetectorResponse fromActionResponse( + ActionResponse actionResponse, + NamedWriteableRegistry namedWriteableRegistry + ) { + if (actionResponse instanceof IndexAnomalyDetectorResponse) { + return (IndexAnomalyDetectorResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try ( + StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray())); + NamedWriteableAwareStreamInput namedWriteableAwareInput = new NamedWriteableAwareStreamInput(input, namedWriteableRegistry) + ) { + return new IndexAnomalyDetectorResponse(namedWriteableAwareInput); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into IndexAnomalyDetectorResponse", e); + } + } } diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java index 65afbc0b8..bde16576d 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java @@ -24,6 +24,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -32,6 +33,7 @@ import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; +import org.opensearch.ad.settings.ADNumericSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; @@ -41,6 +43,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.query.QueryBuilders; import org.opensearch.rest.RestRequest; @@ -54,7 +57,7 @@ import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; -public class IndexAnomalyDetectorTransportAction extends HandledTransportAction { +public class IndexAnomalyDetectorTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(IndexAnomalyDetectorTransportAction.class); private final Client client; private final SecurityClientUtil clientUtil; @@ -66,6 +69,11 @@ public class IndexAnomalyDetectorTransportAction extends HandledTransportAction< private volatile Boolean filterByEnabled; private final SearchFeatureDao searchFeatureDao; private final Settings settings; + protected final NamedWriteableRegistry namedWriteableRegistry; + private volatile Integer maxSingleEntityAnomalyDetectors; + private volatile Integer maxMultiEntityAnomalyDetectors; + private volatile Integer maxAnomalyFeatures; + private volatile Integer maxCategoricalFields; @Inject public IndexAnomalyDetectorTransportAction( @@ -78,7 +86,8 @@ public IndexAnomalyDetectorTransportAction( ADIndexManagement anomalyDetectionIndices, NamedXContentRegistry xContentRegistry, ADTaskManager adTaskManager, - SearchFeatureDao searchFeatureDao + SearchFeatureDao searchFeatureDao, + NamedWriteableRegistry namedWriteableRegistry ) { super(IndexAnomalyDetectorAction.NAME, transportService, actionFilters, IndexAnomalyDetectorRequest::new); this.client = client; @@ -92,10 +101,42 @@ public IndexAnomalyDetectorTransportAction( filterByEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); this.settings = settings; + this.namedWriteableRegistry = namedWriteableRegistry; + + // Initialize cluster settings for node client requests + this.maxSingleEntityAnomalyDetectors = AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings); + this.maxMultiEntityAnomalyDetectors = AnomalyDetectorSettings.AD_MAX_HC_ANOMALY_DETECTORS.get(settings); + this.maxAnomalyFeatures = AnomalyDetectorSettings.MAX_ANOMALY_FEATURES.get(settings); + this.maxCategoricalFields = ADNumericSetting.maxCategoricalFields(); + + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer( + AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + it -> maxSingleEntityAnomalyDetectors = it + ); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(AnomalyDetectorSettings.AD_MAX_HC_ANOMALY_DETECTORS, it -> maxMultiEntityAnomalyDetectors = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(AnomalyDetectorSettings.MAX_ANOMALY_FEATURES, it -> maxAnomalyFeatures = it); } @Override - protected void doExecute(Task task, IndexAnomalyDetectorRequest request, ActionListener actionListener) { + protected void doExecute(Task task, ActionRequest actionRequest, ActionListener actionListener) { + IndexAnomalyDetectorRequest request = IndexAnomalyDetectorRequest.fromActionRequest(actionRequest, namedWriteableRegistry); + + // Use cached settings if request has nulls (request directly from AD Node Client) + Integer maxSingle = request.getMaxSingleEntityAnomalyDetectors() != null + ? request.getMaxSingleEntityAnomalyDetectors() + : maxSingleEntityAnomalyDetectors; + Integer maxMulti = request.getMaxMultiEntityAnomalyDetectors() != null + ? request.getMaxMultiEntityAnomalyDetectors() + : maxMultiEntityAnomalyDetectors; + Integer maxFeatures = request.getMaxAnomalyFeatures() != null ? request.getMaxAnomalyFeatures() : maxAnomalyFeatures; + Integer maxCategorical = request.getMaxCategoricalFields() != null ? request.getMaxCategoricalFields() : maxCategoricalFields; + User user = ParseUtils.getUserContext(client); String detectorId = request.getDetectorID(); RestRequest.Method method = request.getMethod(); @@ -105,13 +146,19 @@ protected void doExecute(Task task, IndexAnomalyDetectorRequest request, ActionL try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { verifyResourceAccessAndProcessRequest( ADCommonName.AD_RESOURCE_TYPE, - () -> indexDetector(user, detectorId, method, listener, detector -> adExecute(request, user, detector, context, listener)), + () -> indexDetector( + user, + detectorId, + method, + listener, + detector -> adExecute(request, user, detector, context, listener, maxSingle, maxMulti, maxFeatures, maxCategorical) + ), () -> resolveUserAndExecute( user, detectorId, method, listener, - (detector) -> adExecute(request, user, detector, context, listener) + (detector) -> adExecute(request, user, detector, context, listener, maxSingle, maxMulti, maxFeatures, maxCategorical) ) ); @@ -181,7 +228,11 @@ protected void adExecute( User user, AnomalyDetector currentDetector, ThreadContext.StoredContext storedContext, - ActionListener listener + ActionListener listener, + Integer maxSingleEntityAnomalyDetectors, + Integer maxMultiEntityAnomalyDetectors, + Integer maxAnomalyFeatures, + Integer maxCategoricalFields ) { anomalyDetectionIndices.update(); String detectorId = request.getDetectorID(); @@ -191,10 +242,6 @@ protected void adExecute( AnomalyDetector detector = request.getDetector(); RestRequest.Method method = request.getMethod(); TimeValue requestTimeout = request.getRequestTimeout(); - Integer maxSingleEntityAnomalyDetectors = request.getMaxSingleEntityAnomalyDetectors(); - Integer maxMultiEntityAnomalyDetectors = request.getMaxMultiEntityAnomalyDetectors(); - Integer maxAnomalyFeatures = request.getMaxAnomalyFeatures(); - Integer maxCategoricalFields = request.getMaxCategoricalFields(); storedContext.restore(); checkIndicesAndExecute(detector.getIndices(), () -> { diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java index b477d082b..462c03a7d 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java @@ -73,7 +73,15 @@ public ValidateAnomalyDetectorTransportAction( } @Override - protected Processor createProcessor(Config detector, ValidateConfigRequest request, User user) { + protected Processor createProcessor( + Config detector, + ValidateConfigRequest request, + User user, + Integer maxSingleStreamConfigs, + Integer maxHCConfigs, + Integer maxFeatures, + Integer maxCategoricalFields + ) { return new ValidateAnomalyDetectorActionHandler( clusterService, client, @@ -81,10 +89,10 @@ protected Processor createProcessor(Config detector, Val indexManagement, detector, request.getRequestTimeout(), - request.getMaxSingleEntityAnomalyDetectors(), - request.getMaxMultiEntityAnomalyDetectors(), - request.getMaxAnomalyFeatures(), - request.getMaxCategoricalFields(), + maxSingleStreamConfigs, + maxHCConfigs, + maxFeatures, + maxCategoricalFields, RestRequest.Method.POST, xContentRegistry, user, diff --git a/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java index df4f71a90..731e1e356 100644 --- a/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java +++ b/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java @@ -16,6 +16,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.forecast.ExecuteForecastResultResponseRecorder; import org.opensearch.forecast.indices.ForecastIndex; @@ -42,7 +43,8 @@ public ForecasterJobTransportAction( ClusterService clusterService, Settings settings, NamedXContentRegistry xContentRegistry, - ForecastIndexJobActionHandler forecastIndexJobActionHandler + ForecastIndexJobActionHandler forecastIndexJobActionHandler, + NamedWriteableRegistry namedWriteableRegistry ) { super( transportService, @@ -59,7 +61,8 @@ public ForecasterJobTransportAction( Forecaster.class, forecastIndexJobActionHandler, Clock.systemUTC(), // inject cannot find clock due to OS limitation - Forecaster.class + Forecaster.class, + namedWriteableRegistry ); } } diff --git a/src/main/java/org/opensearch/forecast/transport/ValidateForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ValidateForecasterTransportAction.java index a9050e44f..667294c10 100644 --- a/src/main/java/org/opensearch/forecast/transport/ValidateForecasterTransportAction.java +++ b/src/main/java/org/opensearch/forecast/transport/ValidateForecasterTransportAction.java @@ -68,7 +68,15 @@ public ValidateForecasterTransportAction( } @Override - protected Processor createProcessor(Config forecaster, ValidateConfigRequest request, User user) { + protected Processor createProcessor( + Config forecaster, + ValidateConfigRequest request, + User user, + Integer maxSingleStreamConfigs, + Integer maxHCConfigs, + Integer maxFeatures, + Integer maxCategoricalFields + ) { return new ValidateForecasterActionHandler( clusterService, client, @@ -76,10 +84,10 @@ protected Processor createProcessor(Config forecaster, V indexManagement, forecaster, request.getRequestTimeout(), - request.getMaxSingleEntityAnomalyDetectors(), - request.getMaxMultiEntityAnomalyDetectors(), - request.getMaxAnomalyFeatures(), - request.getMaxCategoricalFields(), + maxSingleStreamConfigs, + maxHCConfigs, + maxFeatures, + maxCategoricalFields, RestRequest.Method.POST, xContentRegistry, user, diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java index 8c3f6558f..a36067535 100644 --- a/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java @@ -14,6 +14,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionType; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -24,6 +25,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.tasks.Task; import org.opensearch.timeseries.ExecuteResultResponseRecorder; @@ -43,7 +45,7 @@ import org.opensearch.transport.client.Client; public abstract class BaseJobTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager, IndexableResultType extends IndexableResult, ProfileActionType extends ActionType, ExecuteResultResponseRecorderType extends ExecuteResultResponseRecorder, IndexJobActionHandlerType extends IndexJobActionHandler, ConfigType extends Config> - extends HandledTransportAction { + extends HandledTransportAction { private final Logger logger = LogManager.getLogger(BaseJobTransportAction.class); private final Client client; @@ -59,6 +61,7 @@ public abstract class BaseJobTransportAction & private final IndexJobActionHandlerType indexJobActionHandlerType; private final Clock clock; private final Class configTypeClass; + protected final NamedWriteableRegistry namedWriteableRegistry; public BaseJobTransportAction( TransportService transportService, @@ -75,7 +78,8 @@ public BaseJobTransportAction( Class configClass, IndexJobActionHandlerType indexJobActionHandlerType, Clock clock, - Class configTypeClass + Class configTypeClass, + NamedWriteableRegistry namedWriteableRegistry ) { super(jobActionName, transportService, actionFilters, JobRequest::new); this.transportService = transportService; @@ -92,10 +96,12 @@ public BaseJobTransportAction( this.indexJobActionHandlerType = indexJobActionHandlerType; this.clock = clock; this.configTypeClass = configTypeClass; + this.namedWriteableRegistry = namedWriteableRegistry; } @Override - protected void doExecute(Task task, JobRequest request, ActionListener actionListener) { + protected void doExecute(Task task, ActionRequest actionRequest, ActionListener actionListener) { + JobRequest request = JobRequest.fromActionRequest(actionRequest, namedWriteableRegistry); String configId = request.getConfigID(); DateRange dateRange = request.getDateRange(); boolean historical = request.isHistorical(); diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseValidateConfigTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseValidateConfigTransportAction.java index 9356038b2..5a63f0615 100644 --- a/src/main/java/org/opensearch/timeseries/transport/BaseValidateConfigTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/BaseValidateConfigTransportAction.java @@ -18,6 +18,8 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.settings.ADNumericSetting; +import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; @@ -64,6 +66,10 @@ public abstract class BaseValidateConfigTransportAction configTypeClass; + private volatile Integer maxSingleEntityAnomalyDetectors; + private volatile Integer maxMultiEntityAnomalyDetectors; + private volatile Integer maxAnomalyFeatures; + private volatile Integer maxCategoricalFields; public BaseValidateConfigTransportAction( String actionName, @@ -95,6 +101,25 @@ public BaseValidateConfigTransportAction( this.settings = settings; this.validationAspect = validationAspect; this.configTypeClass = configTypeClass; + + // Initialize cluster settings for node client requests + this.maxSingleEntityAnomalyDetectors = AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings); + this.maxMultiEntityAnomalyDetectors = AnomalyDetectorSettings.AD_MAX_HC_ANOMALY_DETECTORS.get(settings); + this.maxAnomalyFeatures = AnomalyDetectorSettings.MAX_ANOMALY_FEATURES.get(settings); + this.maxCategoricalFields = ADNumericSetting.maxCategoricalFields(); + + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer( + AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + it -> maxSingleEntityAnomalyDetectors = it + ); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(AnomalyDetectorSettings.AD_MAX_HC_ANOMALY_DETECTORS, it -> maxMultiEntityAnomalyDetectors = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(AnomalyDetectorSettings.MAX_ANOMALY_FEATURES, it -> maxAnomalyFeatures = it); } @Override @@ -216,6 +241,17 @@ public void validateExecute( ActionListener listener ) { storedContext.restore(); + + // Resolve settings (from request or cached values) + Integer maxSingleEntity = request.getMaxSingleEntityAnomalyDetectors() != null + ? request.getMaxSingleEntityAnomalyDetectors() + : maxSingleEntityAnomalyDetectors; + Integer maxMultiEntity = request.getMaxMultiEntityAnomalyDetectors() != null + ? request.getMaxMultiEntityAnomalyDetectors() + : maxMultiEntityAnomalyDetectors; + Integer maxFeatures = request.getMaxAnomalyFeatures() != null ? request.getMaxAnomalyFeatures() : maxAnomalyFeatures; + Integer maxCategorical = request.getMaxCategoricalFields() != null ? request.getMaxCategoricalFields() : maxCategoricalFields; + Config config = request.getConfig(); ActionListener validateListener = ActionListener.wrap(response -> { // forcing response to be empty @@ -232,7 +268,8 @@ public void validateExecute( }); checkIndicesAndExecute(config.getIndices(), () -> { try { - createProcessor(config, request, user).start(validateListener); + createProcessor(config, request, user, maxSingleEntity, maxMultiEntity, maxFeatures, maxCategorical) + .start(validateListener); } catch (Exception exception) { String errorMessage = String .format(Locale.ROOT, "Unknown exception caught while validating config %s", request.getConfig()); @@ -242,5 +279,13 @@ public void validateExecute( }, listener); } - protected abstract Processor createProcessor(Config config, ValidateConfigRequest request, User user); + protected abstract Processor createProcessor( + Config config, + ValidateConfigRequest request, + User user, + Integer maxSingleStreamConfigs, + Integer maxHCConfigs, + Integer maxFeatures, + Integer maxCategoricalFields + ); } diff --git a/src/main/java/org/opensearch/timeseries/transport/JobRequest.java b/src/main/java/org/opensearch/timeseries/transport/JobRequest.java index b97f80c23..0f718540c 100644 --- a/src/main/java/org/opensearch/timeseries/transport/JobRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/JobRequest.java @@ -121,4 +121,32 @@ public String index() { public String id() { return configID; } + + public static JobRequest fromActionRequest( + final ActionRequest actionRequest, + org.opensearch.core.common.io.stream.NamedWriteableRegistry namedWriteableRegistry + ) { + if (actionRequest instanceof JobRequest) { + return (JobRequest) actionRequest; + } + + try ( + java.io.ByteArrayOutputStream baos = new java.io.ByteArrayOutputStream(); + org.opensearch.core.common.io.stream.OutputStreamStreamOutput osso = + new org.opensearch.core.common.io.stream.OutputStreamStreamOutput(baos) + ) { + actionRequest.writeTo(osso); + try ( + org.opensearch.core.common.io.stream.StreamInput input = new org.opensearch.core.common.io.stream.InputStreamStreamInput( + new java.io.ByteArrayInputStream(baos.toByteArray()) + ); + org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput namedInput = + new org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput(input, namedWriteableRegistry) + ) { + return new JobRequest(namedInput); + } + } catch (java.io.IOException e) { + throw new IllegalArgumentException("failed to parse ActionRequest into JobRequest", e); + } + } } diff --git a/src/main/java/org/opensearch/timeseries/transport/JobResponse.java b/src/main/java/org/opensearch/timeseries/transport/JobResponse.java index faa7df2c8..56516ea60 100644 --- a/src/main/java/org/opensearch/timeseries/transport/JobResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/JobResponse.java @@ -11,9 +11,16 @@ package org.opensearch.timeseries.transport; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; @@ -45,4 +52,22 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { return builder.startObject().field(RestHandlerUtils._ID, id).endObject(); } + + public static JobResponse fromActionResponse(ActionResponse actionResponse, NamedWriteableRegistry namedWriteableRegistry) { + if (actionResponse instanceof JobResponse) { + return (JobResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try ( + StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray())); + NamedWriteableAwareStreamInput namedWriteableAwareInput = new NamedWriteableAwareStreamInput(input, namedWriteableRegistry) + ) { + return new JobResponse(namedWriteableAwareInput); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into JobResponse", e); + } + } } diff --git a/src/main/java/org/opensearch/timeseries/transport/ValidateConfigRequest.java b/src/main/java/org/opensearch/timeseries/transport/ValidateConfigRequest.java index 8d9675d93..e81eea68f 100644 --- a/src/main/java/org/opensearch/timeseries/transport/ValidateConfigRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/ValidateConfigRequest.java @@ -54,11 +54,11 @@ public ValidateConfigRequest(StreamInput in) throws IOException { } validationType = in.readString(); - maxSingleStreamConfigs = in.readInt(); - maxHCConfigs = in.readInt(); - maxFeatures = in.readInt(); + maxSingleStreamConfigs = in.readOptionalInt(); + maxHCConfigs = in.readOptionalInt(); + maxFeatures = in.readOptionalInt(); requestTimeout = in.readTimeValue(); - maxCategoricalFields = in.readInt(); + maxCategoricalFields = in.readOptionalInt(); } public ValidateConfigRequest( @@ -81,17 +81,21 @@ public ValidateConfigRequest( this.maxCategoricalFields = maxCategoricalFields; } + public ValidateConfigRequest(AnalysisType context, Config config, String validationType) { + this(context, config, validationType, null, null, null, TimeValue.timeValueSeconds(60), null); + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeEnum(context); config.writeTo(out); out.writeString(validationType); - out.writeInt(maxSingleStreamConfigs); - out.writeInt(maxHCConfigs); - out.writeInt(maxFeatures); + out.writeOptionalInt(maxSingleStreamConfigs); + out.writeOptionalInt(maxHCConfigs); + out.writeOptionalInt(maxFeatures); out.writeTimeValue(requestTimeout); - out.writeInt(maxCategoricalFields); + out.writeOptionalInt(maxCategoricalFields); } @Override @@ -134,7 +138,6 @@ public static ValidateConfigRequest fromActionRequest( if (actionRequest instanceof ValidateConfigRequest) { return (ValidateConfigRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try ( diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java index 52319f04d..a8e0cfa77 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionClientTests.java @@ -15,9 +15,13 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.transport.GetAnomalyDetectorResponse; +import org.opensearch.ad.transport.IndexAnomalyDetectorRequest; +import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; import org.opensearch.common.lucene.uid.Versions; import org.opensearch.core.action.ActionListener; import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.transport.JobRequest; +import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.timeseries.transport.SuggestConfigParamRequest; import org.opensearch.timeseries.transport.SuggestConfigParamResponse; import org.opensearch.timeseries.transport.ValidateConfigRequest; @@ -42,6 +46,12 @@ public class AnomalyDetectionClientTests { @Mock SuggestConfigParamResponse suggestResponse; + @Mock + IndexAnomalyDetectorResponse createResponse; + + @Mock + JobResponse startResponse; + @Before public void setUp() { MockitoAnnotations.initMocks(this); @@ -75,6 +85,19 @@ public void suggestAnomalyDetector( ) { listener.onResponse(suggestResponse); } + + @Override + public void createAnomalyDetector( + IndexAnomalyDetectorRequest createRequest, + ActionListener listener + ) { + listener.onResponse(createResponse); + } + + @Override + public void startAnomalyDetector(JobRequest startRequest, ActionListener listener) { + listener.onResponse(startResponse); + } }; } @@ -130,4 +153,32 @@ public void suggestAnomalyDetector() { assertEquals(suggestResponse, anomalyDetectionClient.suggestAnomalyDetector(suggestRequest).actionGet()); } + @Test + public void createAnomalyDetector() { + IndexAnomalyDetectorRequest createRequest = new IndexAnomalyDetectorRequest( + "test-detector-id", + 1L, + 1L, + org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE, + null, + org.opensearch.rest.RestRequest.Method.POST, + org.opensearch.common.unit.TimeValue.timeValueSeconds(30), + 10, + 10, + 5, + 2 + ); + assertEquals(createResponse, anomalyDetectionClient.createAnomalyDetector(createRequest).actionGet()); + } + + @Test + public void startAnomalyDetector() { + JobRequest startRequest = new JobRequest( + "test-detector-id", + ADIndex.CONFIG.getIndexName(), + "/_plugins/_anomaly_detection/detectors/test-detector-id/_start" + ); + assertEquals(startResponse, anomalyDetectionClient.startAnomalyDetector(startRequest).actionGet()); + } + } diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java index d5cca1fc6..69df8d38c 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java @@ -34,8 +34,12 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorType; import org.opensearch.ad.model.DetectorProfile; +import org.opensearch.ad.transport.AnomalyDetectorJobAction; import org.opensearch.ad.transport.GetAnomalyDetectorAction; import org.opensearch.ad.transport.GetAnomalyDetectorResponse; +import org.opensearch.ad.transport.IndexAnomalyDetectorAction; +import org.opensearch.ad.transport.IndexAnomalyDetectorRequest; +import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; import org.opensearch.ad.transport.SuggestAnomalyDetectorParamAction; import org.opensearch.ad.transport.ValidateAnomalyDetectorAction; import org.opensearch.common.lucene.uid.Versions; @@ -50,6 +54,8 @@ import org.opensearch.timeseries.model.ConfigValidationIssue; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.transport.JobRequest; +import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.timeseries.transport.SuggestConfigParamRequest; import org.opensearch.timeseries.transport.SuggestConfigParamResponse; import org.opensearch.timeseries.transport.ValidateConfigRequest; @@ -316,4 +322,79 @@ public void testSuggestAnomalyDetector() throws IOException { verify(clientSpy, times(1)).execute(any(SuggestAnomalyDetectorParamAction.class), any(), any()); } + @Test + public void testCreateAnomalyDetector() throws IOException { + ingestTestData(indexName, startTime, 1, "test", 10); + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(indexName), + ImmutableList.of(TestHelpers.randomFeature(true)), + null, + Instant.now(), + 1, + false, + null + ); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + IndexAnomalyDetectorResponse response = new IndexAnomalyDetectorResponse( + "test-detector-id", + 1L, + 1L, + 1L, + detector, + RestStatus.CREATED + ); + listener.onResponse(response); + return null; + }).when(clientSpy).execute(any(IndexAnomalyDetectorAction.class), any(), any()); + + IndexAnomalyDetectorRequest createRequest = new IndexAnomalyDetectorRequest( + "test-detector-id", + 1L, + 1L, + org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE, + detector, + org.opensearch.rest.RestRequest.Method.POST, + org.opensearch.common.unit.TimeValue.timeValueSeconds(30), + 10, + 10, + 5, + 2 + ); + + IndexAnomalyDetectorResponse response = adClient.createAnomalyDetector(createRequest).actionGet(10000); + assertNotNull(response); + assertEquals("test-detector-id", response.getId()); + verify(clientSpy, times(1)).execute(any(IndexAnomalyDetectorAction.class), any(), any()); + } + + @Test + public void testStartAnomalyDetector() throws IOException { + ingestTestData(indexName, startTime, 1, "test", 10); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + JobResponse response = new JobResponse("test-detector-id"); + listener.onResponse(response); + return null; + }).when(clientSpy).execute(any(AnomalyDetectorJobAction.class), any(), any()); + + JobRequest startRequest = new JobRequest( + "test-detector-id", + ADIndex.CONFIG.getIndexName(), + null, + false, + "/_plugins/_anomaly_detection/detectors/test-detector-id/_start" + ); + + JobResponse response = adClient.startAnomalyDetector(startRequest).actionGet(10000); + assertNotNull(response); + assertEquals("test-detector-id", response.getId()); + verify(clientSpy, times(1)).execute(any(AnomalyDetectorJobAction.class), any(), any()); + } + } diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java index e4f160aa1..4c167a8ad 100644 --- a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java @@ -23,6 +23,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -84,4 +85,81 @@ public void testIndexResponse() throws Exception { Map map = TestHelpers.XContentBuilderToMap(builder); Assert.assertEquals(map.get(RestHandlerUtils._ID), "1234"); } + + @Test + public void testIndexRequestFromActionRequest_SameType() throws Exception { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + IndexAnomalyDetectorRequest request = new IndexAnomalyDetectorRequest( + "1234", + 4321, + 5678, + WriteRequest.RefreshPolicy.NONE, + detector, + RestRequest.Method.PUT, + TimeValue.timeValueSeconds(60), + 1000, + 10, + 5, + 10 + ); + IndexAnomalyDetectorRequest result = IndexAnomalyDetectorRequest.fromActionRequest(request, writableRegistry()); + Assert.assertSame(request, result); + } + + @Test + public void testIndexRequestFromActionRequest_DifferentType() throws Exception { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + IndexAnomalyDetectorRequest original = new IndexAnomalyDetectorRequest( + "1234", + 4321, + 5678, + WriteRequest.RefreshPolicy.NONE, + detector, + RestRequest.Method.PUT, + TimeValue.timeValueSeconds(60), + 1000, + 10, + 5, + 10 + ); + + org.opensearch.action.ActionRequest actionRequest = new org.opensearch.action.ActionRequest() { + @Override + public org.opensearch.action.ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws java.io.IOException { + original.writeTo(out); + } + }; + + IndexAnomalyDetectorRequest result = IndexAnomalyDetectorRequest.fromActionRequest(actionRequest, writableRegistry()); + Assert.assertEquals(original.getDetectorID(), result.getDetectorID()); + } + + @Test + public void testIndexResponseFromActionResponse_SameType() throws Exception { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + IndexAnomalyDetectorResponse response = new IndexAnomalyDetectorResponse("1234", 2, 2, 2, detector, RestStatus.OK); + IndexAnomalyDetectorResponse result = IndexAnomalyDetectorResponse.fromActionResponse(response, writableRegistry()); + Assert.assertSame(response, result); + } + + @Test + public void testIndexResponseFromActionResponse_DifferentType() throws Exception { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + IndexAnomalyDetectorResponse original = new IndexAnomalyDetectorResponse("1234", 2, 2, 2, detector, RestStatus.OK); + + org.opensearch.core.action.ActionResponse actionResponse = new org.opensearch.core.action.ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws java.io.IOException { + original.writeTo(out); + } + }; + + IndexAnomalyDetectorResponse result = IndexAnomalyDetectorResponse.fromActionResponse(actionResponse, writableRegistry()); + Assert.assertEquals(original.getId(), result.getId()); + } } diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java index c37736829..79588c2f6 100644 --- a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java @@ -17,10 +17,12 @@ import static org.mockito.Mockito.when; import java.time.Instant; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Locale; import java.util.Map; @@ -52,9 +54,16 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.rest.RestRequest; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.threadpool.ThreadPool; @@ -78,6 +87,7 @@ public class IndexAnomalyDetectorTransportActionTests extends OpenSearchIntegTes private Client client = mock(Client.class); private SecurityClientUtil clientUtil; private SearchFeatureDao searchFeatureDao; + private NamedWriteableRegistry registry; @SuppressWarnings("unchecked") @Override @@ -87,10 +97,35 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); clusterSettings = new ClusterSettings( Settings.EMPTY, - Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES))) + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES, + AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + AnomalyDetectorSettings.AD_MAX_HC_ANOMALY_DETECTORS, + AnomalyDetectorSettings.MAX_ANOMALY_FEATURES + ) + ) + ) ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + List namedWriteables = new ArrayList<>(); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, BoolQueryBuilder.NAME, BoolQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, TermQueryBuilder.NAME, TermQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, RangeQueryBuilder.NAME, RangeQueryBuilder::new)); + namedWriteables + .add( + new NamedWriteableRegistry.Entry( + AggregationBuilder.class, + ValueCountAggregationBuilder.NAME, + ValueCountAggregationBuilder::new + ) + ); + registry = new NamedWriteableRegistry(namedWriteables); + ClusterName clusterName = new ClusterName("test"); Settings indexSettings = Settings .builder() @@ -119,7 +154,8 @@ public void setUp() throws Exception { mock(ADIndexManagement.class), xContentRegistry(), adTaskManager, - searchFeatureDao + searchFeatureDao, + registry ); task = mock(Task.class); AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); @@ -218,7 +254,8 @@ public void testIndexTransportActionWithUserAndFilterOn() { mock(ADIndexManagement.class), xContentRegistry(), adTaskManager, - searchFeatureDao + searchFeatureDao, + registry ); transportAction.doExecute(task, request, response); } @@ -243,7 +280,8 @@ public void testIndexTransportActionWithUserAndFilterOff() { mock(ADIndexManagement.class), xContentRegistry(), adTaskManager, - searchFeatureDao + searchFeatureDao, + registry ); transportAction.doExecute(task, request, response); } diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java index 5f7a4ede4..0e49b3adc 100644 --- a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java @@ -47,8 +47,29 @@ public void testValidateAnomalyDetectorRequestSerialization() throws IOException request1.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); ValidateConfigRequest request2 = new ValidateConfigRequest(input); + } + + @Test + public void testValidateAnomalyDetectorRequestWithNullValues() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + String typeStr = "model"; + + // Test convenience constructor with null cluster settings + ValidateConfigRequest request1 = new ValidateConfigRequest(AnalysisType.AD, detector, typeStr); + + // Test serialization with null values + BytesStreamOutput output = new BytesStreamOutput(); + request1.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + ValidateConfigRequest request2 = new ValidateConfigRequest(input); + assertEquals("serialization has the wrong detector", request2.getConfig(), detector); + assertEquals("serialization has the wrong validation type", request2.getValidationType(), typeStr); + assertNull("maxSingleStreamConfigs should be null", request2.getMaxSingleEntityAnomalyDetectors()); + assertNull("maxHCConfigs should be null", request2.getMaxMultiEntityAnomalyDetectors()); + assertNull("maxFeatures should be null", request2.getMaxAnomalyFeatures()); + assertEquals("serialization has the wrong typeStr", request2.getValidationType(), typeStr); - assertEquals("serialization has the wrong requestTimeout", request2.getRequestTimeout(), requestTimeout); + assertEquals("serialization has the wrong requestTimeout", request2.getRequestTimeout(), TimeValue.timeValueSeconds(60)); } } diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java index 903d055ac..3a36d5d84 100644 --- a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java @@ -548,4 +548,23 @@ public void testValidateAnomalyDetectorWithDateNanosWithoutIssue() throws IOExce assertNull(response.getIssue()); } + @Test + public void testValidateAnomalyDetectorWithNullSettings() throws IOException { + AnomalyDetector anomalyDetector = TestHelpers + .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(sumValueFeature(nameField, ipField + ".is_error", "test-3"))); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + null, + null, + null, + new TimeValue(5_000L), + null + ); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNull(response.getIssue()); + } + } diff --git a/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobActionTests.java b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobActionTests.java index 2117e48e1..054fd70ac 100644 --- a/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobActionTests.java @@ -16,9 +16,11 @@ import java.io.IOException; import java.time.Instant; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.List; import org.junit.Assert; import org.junit.Before; @@ -36,7 +38,14 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.threadpool.ThreadPool; @@ -49,6 +58,7 @@ public class AnomalyDetectorJobActionTests extends OpenSearchIntegTestCase { private Task task; private JobRequest request; private ActionListener response; + private NamedWriteableRegistry registry; @Override @Before @@ -68,7 +78,19 @@ public void setUp() throws Exception { org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); when(client.threadPool()).thenReturn(mockThreadPool); when(mockThreadPool.getThreadContext()).thenReturn(threadContext); - + List namedWriteables = new ArrayList<>(); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, BoolQueryBuilder.NAME, BoolQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, TermQueryBuilder.NAME, TermQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, RangeQueryBuilder.NAME, RangeQueryBuilder::new)); + namedWriteables + .add( + new NamedWriteableRegistry.Entry( + AggregationBuilder.class, + ValueCountAggregationBuilder.NAME, + ValueCountAggregationBuilder::new + ) + ); + registry = new NamedWriteableRegistry(namedWriteables); action = new AnomalyDetectorJobTransportAction( mock(TransportService.class), mock(ActionFilters.class), @@ -76,7 +98,8 @@ public void setUp() throws Exception { clusterService, indexSettings(), xContentRegistry(), - mock(ADIndexJobActionHandler.class) + mock(ADIndexJobActionHandler.class), + registry ); task = mock(Task.class); request = new JobRequest( @@ -156,4 +179,53 @@ public void testAdJobResponse() throws IOException { JobResponse newResponse = new JobResponse(input); Assert.assertEquals(response.getId(), newResponse.getId()); } + + @Test + public void testJobRequestFromActionRequest_SameType() { + JobRequest request = new JobRequest("1234", ADIndex.CONFIG.getIndexName(), "_start"); + JobRequest result = JobRequest.fromActionRequest(request, registry); + Assert.assertSame(request, result); + } + + @Test + public void testJobRequestFromActionRequest_DifferentType() { + JobRequest original = new JobRequest("1234", ADIndex.CONFIG.getIndexName(), "_start"); + + org.opensearch.action.ActionRequest actionRequest = new org.opensearch.action.ActionRequest() { + @Override + public org.opensearch.action.ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(org.opensearch.core.common.io.stream.StreamOutput out) throws java.io.IOException { + original.writeTo(out); + } + }; + + JobRequest result = JobRequest.fromActionRequest(actionRequest, registry); + Assert.assertEquals(original.getConfigID(), result.getConfigID()); + } + + @Test + public void testJobResponseFromActionResponse_SameType() { + JobResponse response = new JobResponse("1234"); + JobResponse result = JobResponse.fromActionResponse(response, registry); + Assert.assertSame(response, result); + } + + @Test + public void testJobResponseFromActionResponse_DifferentType() { + JobResponse original = new JobResponse("1234"); + + org.opensearch.core.action.ActionResponse actionResponse = new org.opensearch.core.action.ActionResponse() { + @Override + public void writeTo(org.opensearch.core.common.io.stream.StreamOutput out) throws java.io.IOException { + original.writeTo(out); + } + }; + + JobResponse result = JobResponse.fromActionResponse(actionResponse, registry); + Assert.assertEquals(original.getId(), result.getId()); + } }