diff --git a/plugin/build.gradle b/plugin/build.gradle index e2ef299232..7923f35c8a 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -213,7 +213,7 @@ tasks.named("check").configure { dependsOn(integTest) } // Enable Security if -Dsecurity=true or -Dhttps=true def securityEnabled = System.getProperty("security", "false") == "true" || - System.getProperty("https", "false") == "true" + System.getProperty("https", "false") == "true" integTest { dependsOn "bundlePlugin" @@ -321,7 +321,7 @@ def configureClusterPlugins(cluster, jobSchedZip, securityZip, securityEnabled) plugin(project.tasks.bundlePlugin.archiveFile) if (securityEnabled) { - nodes.each { node -> + cluster.nodes.each { node -> node.extraConfigFile("kirk.pem", file("build/resources/test/kirk.pem")) node.extraConfigFile("kirk-key.pem", file("build/resources/test/kirk-key.pem")) node.extraConfigFile("esnode.pem", file("build/resources/test/esnode.pem")) diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 0052c8bfaa..9873ee5fad 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -7,6 +7,8 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; +import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import java.util.ArrayList; @@ -52,7 +54,6 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; -import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.transport.client.Client; import com.google.common.annotations.VisibleForTesting; @@ -147,11 +148,13 @@ public void search(SdkClient sdkClient, SearchRequest request, String tenantId, mlFeatureEnabledSetting.isMultiTenancyEnabled(), CommonValue.ML_MODEL_GROUP_INDEX ); - boolean rsClientPresent = ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null; - if (rsClientPresent && user != null && modelAccessControlHelper.modelAccessControlEnabled() && hasModelGroupIndex) { + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE) + && user != null + && modelAccessControlHelper.modelAccessControlEnabled() + && hasModelGroupIndex) { // RSC fast-path: get accessible group IDs → gate models (IDs or missing) - ResourceSharingClient rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + var rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); rsc.getAccessibleResourceIds(CommonValue.ML_MODEL_GROUP_INDEX, ActionListener.wrap(ids -> { SearchSourceBuilder gated = Optional.ofNullable(request.source()).orElseGet(SearchSourceBuilder::new); gated.query(rewriteQueryBuilderRSC(gated.query(), ids)); // ids may be empty → "missing only" diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index fbb87c0ff6..e486e7def6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -6,7 +6,9 @@ package org.opensearch.ml.action.model_group; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; import org.opensearch.ExceptionsHelper; @@ -27,7 +29,6 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction; @@ -96,7 +97,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); // if resource sharing feature is enabled, access will be automatically checked by security plugin, so no need to check again - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { checkForAssociatedModels(modelGroupId, tenantId, wrappedListener); } else { validateAndDeleteModelGroup(modelGroupId, tenantId, wrappedListener); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java index 74fc87da30..915a4d8a4f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java @@ -8,6 +8,8 @@ import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; +import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; @@ -27,7 +29,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.MLModelGroup; -import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; @@ -186,7 +187,7 @@ private void validateModelGroupAccess( ) { // if resource sharing feature is enabled, security plugin will have automatically evaluated access to this model group, hence no // need to validate again - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build()); return; } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java index b9f00a706e..a9d31acf8c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java @@ -7,6 +7,8 @@ import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; +import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import java.util.Collections; @@ -31,7 +33,6 @@ import org.opensearch.remote.metadata.client.SearchDataObjectRequest; import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; @@ -89,7 +90,7 @@ private void preProcessRoleAndPerformSearch( .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); // If resource-sharing feature is enabled, we fetch accessible model-groups and restrict the search to those model-groups only. - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { // If a model-group is shared, then it will have been shared at-least at read access, hence the final result is guaranteed // to only contain model-groups that the user at-least has read access to. addAccessibleModelGroupsFilterAndSearch(tenantId, request, doubleWrappedListener); @@ -113,7 +114,7 @@ private void addAccessibleModelGroupsFilterAndSearch( ActionListener wrappedListener ) { SearchSourceBuilder sourceBuilder = request.source() != null ? request.source() : new SearchSourceBuilder(); - ResourceSharingClient rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + var rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); // filter by accessible model-groups rsc.getAccessibleResourceIds(ML_MODEL_GROUP_INDEX, ActionListener.wrap(ids -> { sourceBuilder.query(modelAccessControlHelper.mergeWithAccessFilter(sourceBuilder.query(), ids)); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 2db39e2952..9c9aef0683 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -9,6 +9,8 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.BACKEND_ROLES_FIELD; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; +import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz; import static org.opensearch.ml.utils.MLExceptionUtils.logException; import java.time.Instant; @@ -36,7 +38,6 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; -import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; @@ -150,7 +151,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener feature is disabled, follow old route - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() == null) { + if (!shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { // TODO: At some point, this call must be replaced by the one above, (i.e. no user info to // be stored in model-group index) if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java index adcc5d196e..91036e619d 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java @@ -9,6 +9,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.BACKEND_ROLES_FIELD; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; import java.util.Collections; @@ -57,7 +58,6 @@ import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.transport.client.Client; import com.google.common.collect.ImmutableList; @@ -98,8 +98,8 @@ public void validateModelGroupAccess(User user, String modelGroupId, String acti listener.onResponse(true); return; } - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { - ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { + var resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, action, ActionListener.wrap(isAuthorized -> { if (!isAuthorized) { listener @@ -173,8 +173,8 @@ public void validateModelGroupAccess( listener.onResponse(true); return; } - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { - ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { + var resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, action, ActionListener.wrap(isAuthorized -> { if (!isAuthorized) { listener @@ -288,6 +288,16 @@ public void checkModelGroupPermission(MLModelGroup mlModelGroup, User user, Acti } } + /** + * Checks whether to utilize new ResourceAuthz + * @param resourceType for which to decide whether to use resource authz + * @return true if the resource-sharing feature is enabled, false otherwise. + */ + public static boolean shouldUseResourceAuthz(String resourceType) { + var client = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + return client != null && client.isFeatureEnabledForType(resourceType); + } + public boolean skipModelAccessControl(User user) { // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin // Case 2: If Security is enabled and filter is disabled, proceed with search as diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java index a8cb908ab0..5f9b0548da 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java @@ -9,7 +9,10 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; @@ -39,6 +42,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; @@ -48,6 +52,7 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -123,6 +128,8 @@ public void setup() throws IOException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); } @Test @@ -291,6 +298,110 @@ public void testDeleteModelGroup_Failure() { assertEquals("Failed to delete data object from index .plugins-ml-model-group", argumentCaptor.getValue().getMessage()); } + @Test + public void test_RSC_FeatureEnabled_TypeEnabled_SkipsLegacyValidation() throws Exception { + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + + // Feature enabled for this type => apply resource sharing + when(rsc.isFeatureEnabledForType(any())).thenReturn(true); + + // Associated models search -> empty => proceed to delete + SearchResponse empty = getEmptySearchResponse(); + doAnswer(inv -> { + ActionListener l = inv.getArgument(1); + l.onResponse(empty); + return null; + }).when(client).search(any(), isA(ActionListener.class)); + + // Delete succeeds + doAnswer(inv -> { + ActionListener l = inv.getArgument(1); + l.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); + + // Legacy validation must be skipped + verify(modelAccessControlHelper, never()).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); + + // RSC path still does search (for associated models) then delete + verify(client, times(1)).search(any(), any()); + verify(client, times(1)).delete(any(), any()); + + verify(actionListener, times(1)).onResponse(any(DeleteResponse.class)); + } + + @Test + public void test_RSC_FeatureEnabled_TypeDisabled_UsesLegacyValidation() throws Exception { + // Feature enabled globally but TYPE disabled → legacy path + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + + when(rsc.isFeatureEnabledForType(any())).thenReturn(false); + + // Associated models search -> empty => proceed to delete + SearchResponse empty = getEmptySearchResponse(); + doAnswer(inv -> { + ActionListener l = inv.getArgument(1); + l.onResponse(empty); + return null; + }).when(client).search(any(), isA(ActionListener.class)); + + // Delete succeeds + doAnswer(inv -> { + ActionListener l = inv.getArgument(1); + l.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); + + // Legacy validation must run + verify(modelAccessControlHelper, times(1)).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); + + // Search (models) + delete executed + verify(client, times(1)).search(any(), any()); + verify(client, times(1)).delete(any(), any()); + + verify(actionListener, times(1)).onResponse(any(DeleteResponse.class)); + } + + @Test + public void test_RSC_FeatureDisabled_UsesLegacyValidation() throws Exception { + // Feature disabled by forcing the gate to false + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); + + // (setup() already stubs validateModelGroupAccess(...)->onResponse(true)) + + // Associated models search -> empty => proceed to delete + SearchResponse empty = getEmptySearchResponse(); + doAnswer(inv -> { + ActionListener l = inv.getArgument(1); + l.onResponse(empty); + return null; + }).when(client).search(any(), isA(ActionListener.class)); + + // Delete succeeds + doAnswer(inv -> { + ActionListener l = inv.getArgument(1); + l.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); + + // Legacy validation must run + verify(modelAccessControlHelper, times(1)).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); + + // Search (models) + delete executed + verify(client, times(1)).search(any(), any()); + verify(client, times(1)).delete(any(), any()); + + verify(actionListener, times(1)).onResponse(any(DeleteResponse.class)); + } + private SearchResponse getEmptySearchResponse() { SearchHits hits = new SearchHits(new SearchHit[0], null, Float.NaN); SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, true, false, null, 1); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java index bd6792136b..ba73f396c5 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java @@ -6,8 +6,12 @@ package org.opensearch.ml.action.model_group; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM; @@ -37,12 +41,14 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.impl.SdkClientFactory; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -114,10 +120,10 @@ public void setup() throws IOException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); } public void test_Success() throws IOException { - GetResponse getResponse = prepareMLModelGroup(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -217,6 +223,106 @@ public void testGetModel_RuntimeException() { assertEquals("Failed to get data object from index .plugins-ml-model-group", argumentCaptor.getValue().getMessage()); } + public void test_Get_RSC_FeatureEnabled_TypeEnabled_SkipsLegacyValidation() throws IOException { + // Force RSC fast-path (feature + type enabled) + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + + when(rsc.isFeatureEnabledForType(any())).thenReturn(true); + + // Tenant on request and document must match for TenantAwareHelper.validateTenantResource + String tenantId = "t-1"; + MLModelGroupGetRequest req = MLModelGroupGetRequest.builder().modelGroupId("mg-123").tenantId(tenantId).build(); + + // SDK returns the model-group doc + GetResponse getResponse = prepareMLModelGroup(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + // Execute + getModelGroupTransportAction.doExecute(null, req, actionListener); + + // Legacy validation MUST be skipped + verify(modelAccessControlHelper, times(0)).validateModelGroupAccess(any(), any(), any(), any(), any()); + + ArgumentCaptor captor = ArgumentCaptor.forClass(MLModelGroupGetResponse.class); + verify(actionListener).onResponse(captor.capture()); + MLModelGroupGetResponse resp = captor.getValue(); + assertNotNull(resp); + assertNotNull(resp.getMlModelGroup()); + assertEquals("modelGroup", resp.getMlModelGroup().getName()); + } + + public void test_Get_RSC_FeatureEnabled_TypeDisabled_UsesLegacyValidation() throws IOException { + // Feature enabled globally but TYPE disabled → legacy path + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + + when(rsc.isFeatureEnabledForType(any())).thenReturn(false); + + // Allow legacy access validation to pass + doAnswer(inv -> { + ActionListener l = inv.getArgument(4); + l.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), anyString(), anyString(), any(), any()); + + String tenantId = "t-2"; + MLModelGroupGetRequest req = MLModelGroupGetRequest.builder().modelGroupId("mg-456").tenantId(tenantId).build(); + + GetResponse getResponse = prepareMLModelGroup(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + getModelGroupTransportAction.doExecute(null, req, actionListener); + + // Legacy validation MUST run + verify(modelAccessControlHelper, times(1)).validateModelGroupAccess(any(), eq("mg-456"), anyString(), eq(client), any()); + + // Successful response + ArgumentCaptor captor = ArgumentCaptor.forClass(MLModelGroupGetResponse.class); + verify(actionListener, times(1)).onResponse(captor.capture()); + assertEquals("modelGroup", captor.getValue().getMlModelGroup().getName()); + } + + public void test_Get_RSC_FeatureDisabled_UsesLegacyValidation() throws IOException { + // Entire feature disabled → legacy path + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); + + // Allow legacy access validation to pass + doAnswer(inv -> { + ActionListener l = inv.getArgument(4); + l.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), anyString(), anyString(), any(), any()); + + String tenantId = "t-3"; + MLModelGroupGetRequest req = MLModelGroupGetRequest.builder().modelGroupId("mg-789").tenantId(tenantId).build(); + + GetResponse getResponse = prepareMLModelGroup(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + getModelGroupTransportAction.doExecute(null, req, actionListener); + + // Legacy validation MUST run + verify(modelAccessControlHelper, times(1)).validateModelGroupAccess(any(), eq("mg-789"), anyString(), eq(client), any()); + + // Successful response + ArgumentCaptor captor = ArgumentCaptor.forClass(MLModelGroupGetResponse.class); + verify(actionListener, times(1)).onResponse(captor.capture()); + assertEquals("modelGroup", captor.getValue().getMlModelGroup().getName()); + } + public GetResponse prepareMLModelGroup() throws IOException { MLModelGroup mlModelGroup = MLModelGroup .builder() diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java index 7287b551c6..c5d3fa640c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java @@ -9,6 +9,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -122,6 +123,8 @@ public void setUp() throws Exception { // Simplify the merged query for tests when(modelAccessControlHelper.mergeWithAccessFilter(any(QueryBuilder.class), any(Set.class))) .thenAnswer(inv -> QueryBuilders.termQuery("dummy", "value")); + + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); } @Override @@ -267,6 +270,7 @@ public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() { public void testResourceSharingEnabled_successPath_filtersByAccessibleIds_andCallsSdkClient() { ResourceSharingClient rsc = mock(ResourceSharingClient.class); ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + when(rsc.isFeatureEnabledForType(any())).thenReturn(true); ArgumentCaptor>> rscListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -300,6 +304,7 @@ public void testResourceSharingEnabled_successPath_filtersByAccessibleIds_andCal public void testResourceSharingEnabled_failSafePath_usesEmptySet_andCallsSdkClient() { ResourceSharingClient rsc = mock(ResourceSharingClient.class); ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + when(rsc.isFeatureEnabledForType(any())).thenReturn(true); ArgumentCaptor>> rscListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -323,6 +328,33 @@ public void testResourceSharingEnabled_failSafePath_usesEmptySet_andCallsSdkClie verify(actionListener).onResponse(any(SearchResponse.class)); } + @Test + public void testResourceSharingEnabled_notMarkedAsProtectedType_skipsEvaluation() { + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + + // Feature disabled for this type => skip resource sharing + when(rsc.isFeatureEnabledForType(any())).thenReturn(false); + + CompletableFuture future = new CompletableFuture<>(); + when(sdkClient.searchDataObjectAsync(any(SearchDataObjectRequest.class))).thenReturn(future); + + SearchRequest sr = new SearchRequest(new String[] { CommonValue.ML_MODEL_GROUP_INDEX }, new SearchSourceBuilder()); + MLSearchActionRequest req = new MLSearchActionRequest(sr, "tenant-2"); + + searchModelGroupTransportAction.doExecute(null, req, actionListener); + + // Complete the search as if it returned normally + future.complete(emptySearchDataObjectResponse()); + + // Verify RSC is NOT called + verify(rsc, never()).getAccessibleResourceIds(any(), any()); + + // Verify we executed normal flow (i.e., used SDK and returned a response) + verify(sdkClient).searchDataObjectAsync(any(SearchDataObjectRequest.class)); + verify(actionListener).onResponse(any(SearchResponse.class)); + } + @Test public void testThreadContext_isRestored_afterExecution() { String key = "test-header"; diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index c62716d793..5d7d1d5917 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -8,6 +8,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -42,6 +43,7 @@ import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; @@ -55,12 +57,14 @@ import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; +//@ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class TransportUpdateModelGroupActionTests extends OpenSearchTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -164,6 +168,8 @@ public void setup() throws IOException { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); } public void test_NonOwnerChangingAccessContentException() { @@ -454,6 +460,80 @@ public void test_ExceptionSecurityDisabledCluster() { ); } + public void test_Update_RSC_FeatureEnabled_TypeEnabled_SkipsLegacyValidation() throws Exception { + // Enable RSC fast-path. + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + when(rsc.isFeatureEnabledForType(any())).thenReturn(true); + + // No ACL changes in request (so even legacy would pass, but we won't go there). + MLUpdateModelGroupRequest req = prepareRequest(null, null, null); + + transportUpdateModelGroupAction.doExecute(null, req, actionListener); + + // Legacy validation was skipped. + verify(modelAccessControlHelper, times(0)).isSecurityEnabledAndModelAccessControlEnabled(any()); + verify(modelAccessControlHelper, times(0)).isOwner(any(), any()); + verify(modelAccessControlHelper, times(0)).isAdmin(any()); + + // Update succeeded. + ArgumentCaptor captor = ArgumentCaptor.forClass(MLUpdateModelGroupResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals("Updated", captor.getValue().getStatus()); + } + + public void test_Update_RSC_FeatureEnabled_TypeDisabled_UsesLegacyValidation() throws Exception { + // RSC feature on, but type disabled → legacy path. + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + when(rsc.isFeatureEnabledForType(any())).thenReturn(false); + + // Allow legacy validation to pass: + // security/model-access-control enabled: + when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); + // user is allowed to update: + when(modelAccessControlHelper.isOwner(any(), any())).thenReturn(true); + when(modelAccessControlHelper.isOwnerStillHasPermission(any(), any())).thenReturn(true); + when(modelAccessControlHelper.isAdmin(any())).thenReturn(false); + when(modelAccessControlHelper.isUserHasBackendRole(any(), any())).thenReturn(true); + + MLUpdateModelGroupRequest req = prepareRequest(null, null, null); + + transportUpdateModelGroupAction.doExecute(null, req, actionListener); + + // Legacy path consulted helper + verify(modelAccessControlHelper, times(1)).isSecurityEnabledAndModelAccessControlEnabled(any()); + + // Update succeeded + ArgumentCaptor captor = ArgumentCaptor.forClass(MLUpdateModelGroupResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals("Updated", captor.getValue().getStatus()); + } + + public void test_Update_RSC_FeatureDisabled_UsesLegacyValidation() throws Exception { + // Entire feature disabled → legacy path. + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); + + // Allow legacy validation to pass: + when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); + when(modelAccessControlHelper.isOwner(any(), any())).thenReturn(true); + when(modelAccessControlHelper.isOwnerStillHasPermission(any(), any())).thenReturn(true); + when(modelAccessControlHelper.isAdmin(any())).thenReturn(false); + when(modelAccessControlHelper.isUserHasBackendRole(any(), any())).thenReturn(true); + + MLUpdateModelGroupRequest req = prepareRequest(null, null, null); + + transportUpdateModelGroupAction.doExecute(null, req, actionListener); + + // Legacy path consulted helper + verify(modelAccessControlHelper, times(1)).isSecurityEnabledAndModelAccessControlEnabled(any()); + + // Update succeeded + ArgumentCaptor captor = ArgumentCaptor.forClass(MLUpdateModelGroupResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals("Updated", captor.getValue().getStatus()); + } + private MLUpdateModelGroupRequest prepareRequest(List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { MLUpdateModelGroupInput UpdateModelGroupInput = MLUpdateModelGroupInput .builder() diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index 9a9ec854b2..036601395a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.models; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.isA; @@ -18,6 +19,7 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; +import java.util.Set; import org.apache.lucene.search.TotalHits; import org.junit.Before; @@ -39,12 +41,16 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.action.handler.MLSearchHandler; +import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.search.MLSearchActionRequest; @@ -58,6 +64,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -165,6 +172,7 @@ public void setup() { mock(SearchResponse.Clusters.class), null ); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); } public void test_DoExecute_admin() { @@ -342,6 +350,114 @@ public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() throws In verify(client, times(2)).search(any(), any()); } + @Test + public void test_RSC_featureEnabled_typeEnabled_callsGetAccessibleIds() throws Exception { + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + // Feature enabled for this type => apply resource sharing + when(rsc.isFeatureEnabledForType(any())).thenReturn(true); + + var user = new User(); + + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, user.toString()); + // model access control enabled & index exists + when(modelAccessControlHelper.modelAccessControlEnabled()).thenReturn(true); + + // When RSC asks for accessible IDs, capture listener and respond + ArgumentCaptor>> rscListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doAnswer(inv -> { + ActionListener> l = inv.getArgument(1); + // Simulate async success with some accessible IDs + l.onResponse(Set.of("model_group_IT")); + return null; + }).when(rsc).getAccessibleResourceIds(eq(CommonValue.ML_MODEL_GROUP_INDEX), rscListenerCaptor.capture()); + + // The final remote search goes through client.search(...); return a normal response + doAnswer(inv -> { + ActionListener l = inv.getArgument(1); + l.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + SearchRequest req = new SearchRequest(new String[] { "ml_model" }, new SearchSourceBuilder()); + mlSearchActionRequest = new MLSearchActionRequest(req, "tenant-1"); + + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + + // Verify RSC path was taken + verify(rsc, times(1)).getAccessibleResourceIds(eq(CommonValue.ML_MODEL_GROUP_INDEX), any()); + + // Verify we executed the final search and returned + verify(client, times(1)).search(any(), any()); + verify(actionListener, times(1)).onResponse(any(SearchResponse.class)); + + } + + @Test + public void test_RSC_featureEnabled_typeDisabled_skipsRSC() throws Exception { + // Feature enabled globally but TYPE disabled → shouldUseResourceAuthz = false + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + when(rsc.isFeatureEnabledForType(any())).thenReturn(false); + + // Legacy path will query model-group index then models → 2 searches + when(modelAccessControlHelper.modelAccessControlEnabled()).thenReturn(true); + when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + + // First search (model-groups) returns one group to force the legacy gate path + SearchResponse mgResponse = createModelGroupSearchResponse(); + doAnswer(inv -> { + ActionListener l = inv.getArgument(1); + l.onResponse(mgResponse); + return null; + }).doAnswer(inv -> { // second call: final models search + ActionListener l = inv.getArgument(1); + l.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + + // RSC should not be called at all + verify(rsc, times(0)).getAccessibleResourceIds(any(), any()); + + // Legacy path did two searches + verify(client, times(2)).search(any(), any()); + verify(actionListener, times(1)).onResponse(any(SearchResponse.class)); + } + + @Test + public void test_RSC_featureDisabled_skipsRSC_entirely() throws Exception { + // Entire feature disabled → skip resource sharing + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + when(rsc.isFeatureEnabledForType(any())).thenReturn(true); + + // With feature disabled, we go to legacy path. Make it return zero model-groups, then proceed. + when(modelAccessControlHelper.modelAccessControlEnabled()).thenReturn(true); + when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + + // First search returns 0 hits; second search still executes with rewritten "missing only" + doAnswer(inv -> { + ActionListener l = inv.getArgument(1); + l.onResponse(searchResponse /* 0 hits already configured in setup() */); + return null; + }).doAnswer(inv -> { + ActionListener l = inv.getArgument(1); + l.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); + + // Verify RSC is never invoked + verify(rsc, times(0)).getAccessibleResourceIds(any(), any()); + + // Legacy path still performed searches + verify(client, times(2)).search(any(), any()); + verify(actionListener, times(1)).onResponse(any(SearchResponse.class)); + } + private SearchResponse createModelGroupSearchResponse() throws IOException { String modelContent = "{\n" + " \"created_time\": 1684981986069,\n" diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java index 1ebf9eac6d..b1b1c5f675 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java @@ -44,6 +44,7 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.MLModelGroup.MLModelGroupBuilder; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.impl.SdkClientFactory; @@ -101,6 +102,8 @@ public void setup() { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); } public void setupModelGroup(String owner, String access, List backendRoles) throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/resources/MLResourceSharingExtensionTests.java b/plugin/src/test/java/org/opensearch/ml/resources/MLResourceSharingExtensionTests.java index 03bcad8ec4..fc0ca4fbe2 100644 --- a/plugin/src/test/java/org/opensearch/ml/resources/MLResourceSharingExtensionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/resources/MLResourceSharingExtensionTests.java @@ -37,6 +37,10 @@ public void tearDown() { ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); } + private static Object getResourceSharingClient() { + return ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + } + @Test public void testGetResourceProviders_returnsExpectedSingleProvider() { MLResourceSharingExtension ext = new MLResourceSharingExtension(); @@ -71,15 +75,11 @@ public void testAssignResourceSharingClient_setsClientOnAccessor() { MLResourceSharingExtension ext = new MLResourceSharingExtension(); ResourceSharingClient mockClient = mock(ResourceSharingClient.class); - assertThat(ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), is(nullValue())); + assertThat(getResourceSharingClient(), is(nullValue())); ext.assignResourceSharingClient(mockClient); - assertThat( - "Accessor should hold the client passed to extension", - ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), - equalTo(mockClient) - ); + assertThat("Accessor should hold the client passed to extension", getResourceSharingClient(), equalTo(mockClient)); } @Test @@ -90,16 +90,12 @@ public void testAssignResourceSharingClient_overwritesExistingClient() { // Prime with the first client ResourceSharingClientAccessor.getInstance().setResourceSharingClient(first); - assertThat(ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), equalTo(first)); + assertThat(getResourceSharingClient(), equalTo(first)); // Now assign a new one via the extension ext.assignResourceSharingClient(second); - assertThat( - "Accessor should be updated to the new client", - ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), - equalTo(second) - ); + assertThat("Accessor should be updated to the new client", getResourceSharingClient(), equalTo(second)); } @Test