Skip to content

Commit 711c67a

Browse files
committed
Tenant Aware Integ Tests
Signed-off-by: Daniel Widdis <[email protected]>
1 parent 51bd8cb commit 711c67a

9 files changed

+1872
-3
lines changed

.github/workflows/CI-workflow.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ jobs:
8484
echo "::add-mask::$COHERE_KEY" &&
8585
echo "build and run tests" && ./gradlew build -x spotlessJava &&
8686
echo "Publish to Maven Local" && ./gradlew publishToMavenLocal -x spotlessJava &&
87-
echo "Multi Nodes Integration Testing" && ./gradlew integTest -PnumNodes=3 -x spotlessJava'
87+
echo "Multi Nodes Integration Testing" && ./gradlew integTest -PnumNodes=3 -x spotlessJava &&
88+
echo "Tenant Aware Integration Testing" && ./gradlew integTest -PnumNodes=3 -Dtests.rest.tenantaware=true -x spotlessJava'
8889
plugin=`basename $(ls plugin/build/distributions/*.zip)`
8990
echo $plugin
9091
mv -v plugin/build/distributions/$plugin ./
@@ -235,6 +236,9 @@ jobs:
235236
echo "::add-mask::$OPENAI_KEY"
236237
echo "::add-mask::$COHERE_KEY"
237238
./gradlew.bat build -x spotlessJava
239+
- name: Tenant Aware Tests
240+
shell: bash
241+
run: ./gradlew.bat integTest -Dtests.rest.tenantaware=true -x spotlessJava
238242
- name: Publish to Maven Local
239243
run: |
240244
./gradlew publishToMavenLocal -x spotlessJava

plugin/build.gradle

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,19 @@ integTest {
178178
systemProperty "user", System.getProperty("user")
179179
systemProperty "password", System.getProperty("password")
180180

181+
// Only tenant aware test if set
182+
if (System.getProperty("tests.rest.tenantaware") == "true") {
183+
filter {
184+
includeTestsMatching "org.opensearch.ml.rest.*TenantAwareIT"
185+
}
186+
systemProperty "plugins.ml_commons.multi_tenancy_enabled", "true"
187+
}
188+
181189
// Only rest case can run with remote cluster
182-
if (System.getProperty("tests.rest.cluster") != null) {
190+
if (System.getProperty("tests.rest.cluster") != null && System.getProperty("tests.rest.tenantaware") == null) {
183191
filter {
184192
includeTestsMatching "org.opensearch.ml.rest.*IT"
193+
excludeTestsMatching "org.opensearch.ml.rest.*TenantAwareIT"
185194
// mock LLM run in localhost, it will not reachable for docker or remote cluster
186195
excludeTestsMatching "org.opensearch.ml.tools.VisualizationsToolIT"
187196
}
@@ -203,6 +212,28 @@ integTest {
203212

204213
// The 'doFirst' delays till execution time.
205214
doFirst {
215+
if (System.getProperty("tests.rest.tenantaware") == "true") {
216+
def ymlFile = file("$buildDir/testclusters/integTest-0/config/opensearch.yml")
217+
if (ymlFile.exists()) {
218+
ymlFile.withWriterAppend {
219+
writer ->
220+
writer.write("\n# Set multitenancy\n")
221+
writer.write("plugins.ml_commons.multi_tenancy_enabled: true\n")
222+
}
223+
// TODO this properly uses the remote client factory but needs a remote cluster set up
224+
// TODO get the endpoint from a system property
225+
if (System.getProperty("tests.rest.cluster") != null) {
226+
ymlFile.withWriterAppend { writer ->
227+
writer.write("\n# Use a remote cluster\n")
228+
writer.write("plugins.ml_commons.remote_metadata_type: RemoteOpenSearch\n")
229+
writer.write("plugins.ml_commons.remote_metadata_endpoint: https://127.0.0.1:9200\n")
230+
}
231+
}
232+
} else {
233+
throw new GradleException("opensearch.yml not found at: $ymlFile")
234+
}
235+
}
236+
206237
// Tell the test JVM if the cluster JVM is running under a debugger so that tests can
207238
// use longer timeouts for requests.
208239
def isDebuggingCluster = getDebug() || System.getProperty("test.debug") != null
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.rest;
7+
8+
import static java.nio.charset.StandardCharsets.UTF_8;
9+
import static java.util.Collections.emptyMap;
10+
import static java.util.Collections.singletonList;
11+
import static org.opensearch.common.xcontent.XContentType.JSON;
12+
import static org.opensearch.ml.common.input.Constants.TENANT_ID_HEADER;
13+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED;
14+
15+
import java.io.IOException;
16+
import java.util.HashMap;
17+
import java.util.List;
18+
import java.util.Map;
19+
import java.util.concurrent.TimeUnit;
20+
import java.util.stream.Collectors;
21+
22+
import org.apache.http.Header;
23+
import org.apache.http.message.BasicHeader;
24+
import org.opensearch.action.search.SearchResponse;
25+
import org.opensearch.client.Response;
26+
import org.opensearch.client.ResponseException;
27+
import org.opensearch.common.xcontent.json.JsonXContent;
28+
import org.opensearch.core.common.bytes.BytesArray;
29+
import org.opensearch.core.rest.RestStatus;
30+
import org.opensearch.core.xcontent.DeprecationHandler;
31+
import org.opensearch.core.xcontent.NamedXContentRegistry;
32+
import org.opensearch.core.xcontent.XContentParser;
33+
import org.opensearch.ml.common.input.Constants;
34+
import org.opensearch.ml.utils.TestHelper;
35+
import org.opensearch.rest.RestRequest;
36+
import org.opensearch.test.rest.FakeRestRequest;
37+
38+
public abstract class MLCommonsTenantAwareRestTestCase extends MLCommonsRestTestCase {
39+
40+
// Toggle to run DDB tests
41+
// TODO: Get this from a property
42+
protected static final boolean DDB = false;
43+
44+
protected static final String DOC_ID = "_id";
45+
46+
// REST methods
47+
protected static final String POST = RestRequest.Method.POST.name();
48+
protected static final String GET = RestRequest.Method.GET.name();
49+
protected static final String PUT = RestRequest.Method.PUT.name();
50+
protected static final String DELETE = RestRequest.Method.DELETE.name();
51+
52+
// REST paths; some subclasses need multiple of these
53+
protected static final String AGENTS_PATH = "/_plugins/_ml/agents/";
54+
protected static final String CONNECTORS_PATH = "/_plugins/_ml/connectors/";
55+
protected static final String MODELS_PATH = "/_plugins/_ml/models/";
56+
protected static final String MODEL_GROUPS_PATH = "/_plugins/_ml/model_groups/";
57+
58+
// REST body
59+
protected static final String MATCH_ALL_QUERY = "{\"query\":{\"match_all\":{}}}";
60+
protected static final String EMPTY_CONTENT = "{}";
61+
62+
// REST Response error reasons
63+
protected static final String MISSING_TENANT_REASON = "Tenant ID header is missing";
64+
protected static final String NO_PERMISSION_REASON = "You don't have permission to access this resource";
65+
protected static final String DEPLOYED_REASON =
66+
"Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete";
67+
68+
// Common constants and fields used in subclasses
69+
protected static final String CONNECTOR_ID = "connector_id";
70+
71+
protected String tenantId = randomAlphaOfLength(5);
72+
protected String otherTenantId = randomAlphaOfLength(6);
73+
74+
protected final RestRequest tenantRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
75+
.withHeaders(Map.of(TENANT_ID_HEADER, singletonList(tenantId)))
76+
.build();
77+
protected final RestRequest otherTenantRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
78+
.withHeaders(Map.of(TENANT_ID_HEADER, singletonList(otherTenantId)))
79+
.build();
80+
protected final RestRequest nullTenantRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
81+
.withHeaders(emptyMap())
82+
.build();
83+
84+
protected final RestRequest tenantMatchAllRequest = getRestRequestWithHeadersAndContent(tenantId, MATCH_ALL_QUERY);
85+
protected final RestRequest otherTenantMatchAllRequest = getRestRequestWithHeadersAndContent(otherTenantId, MATCH_ALL_QUERY);
86+
protected final RestRequest nullTenantMatchAllRequest = getRestRequestWithHeadersAndContent(null, MATCH_ALL_QUERY);
87+
88+
protected static boolean isMultiTenancyEnabled() throws IOException {
89+
// pass -Dtests.rest.tenantaware=true on gradle command line to enable
90+
return Boolean.parseBoolean(System.getProperty(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey()))
91+
|| Boolean.parseBoolean(System.getenv(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey()));
92+
}
93+
94+
protected static Response makeRequest(RestRequest request, String method, String path) throws IOException {
95+
return TestHelper
96+
.makeRequest(client(), method, path, request.params(), request.content().utf8ToString(), getHeadersFromRequest(request));
97+
}
98+
99+
private static List<Header> getHeadersFromRequest(RestRequest request) {
100+
return request
101+
.getHeaders()
102+
.entrySet()
103+
.stream()
104+
.map(e -> new BasicHeader(e.getKey(), e.getValue().stream().collect(Collectors.joining(","))))
105+
.collect(Collectors.toList());
106+
}
107+
108+
protected static RestRequest getRestRequestWithHeadersAndContent(String tenantId, String requestContent) {
109+
Map<String, List<String>> headers = new HashMap<>();
110+
if (tenantId != null) {
111+
headers.put(Constants.TENANT_ID_HEADER, singletonList(tenantId));
112+
}
113+
return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
114+
.withHeaders(headers)
115+
.withContent(new BytesArray(requestContent), JSON)
116+
.build();
117+
}
118+
119+
@SuppressWarnings("unchecked")
120+
protected static Map<String, Object> responseToMap(Response response) throws IOException {
121+
return parseResponseToMap(response);
122+
}
123+
124+
@SuppressWarnings("unchecked")
125+
protected static String getErrorReasonFromResponseMap(Map<String, Object> map) {
126+
// Two possible cases:
127+
String type = ((Map<String, String>) map.get("error")).get("type");
128+
129+
// {
130+
// "error": {
131+
// "root_cause": [
132+
// {
133+
// "type": "status_exception",
134+
// "reason": "You don't have permission to access this resource"
135+
// }
136+
// ],
137+
// "type": "status_exception",
138+
// "reason": "You don't have permission to access this resource"
139+
// },
140+
// "status": 403
141+
// }
142+
if ("status_exception".equals(type)) {
143+
return ((Map<String, String>) map.get("error")).get("reason");
144+
}
145+
146+
// Due to https://github.com/opensearch-project/ml-commons/issues/2958
147+
if ("m_l_resource_not_found_exception".equals(type)) {
148+
return ((Map<String, String>) map.get("error")).get("reason");
149+
}
150+
151+
// {
152+
// "error": {
153+
// "reason": "System Error",
154+
// "details": "You don't have permission to access this resource",
155+
// "type": "OpenSearchStatusException"
156+
// },
157+
// "status": 403
158+
// }
159+
return ((Map<String, String>) map.get("error")).get("details");
160+
}
161+
162+
protected static SearchResponse searchResponseFromResponse(Response response) throws IOException {
163+
XContentParser parser = JsonXContent.jsonXContent
164+
.createParser(
165+
NamedXContentRegistry.EMPTY,
166+
DeprecationHandler.IGNORE_DEPRECATIONS,
167+
TestHelper.httpEntityToString(response.getEntity()).getBytes(UTF_8)
168+
);
169+
return SearchResponse.fromXContent(parser);
170+
}
171+
172+
protected static void assertBadRequest(Response response) {
173+
assertEquals(RestStatus.BAD_REQUEST.getStatus(), response.getStatusLine().getStatusCode());
174+
}
175+
176+
protected static void assertNotFound(Response response) {
177+
assertEquals(RestStatus.NOT_FOUND.getStatus(), response.getStatusLine().getStatusCode());
178+
}
179+
180+
protected static void assertForbidden(Response response) {
181+
assertEquals(RestStatus.FORBIDDEN.getStatus(), response.getStatusLine().getStatusCode());
182+
}
183+
184+
protected static void assertUnauthorized(Response response) {
185+
assertEquals(RestStatus.UNAUTHORIZED.getStatus(), response.getStatusLine().getStatusCode());
186+
}
187+
188+
protected void refreshBeforeSearch(boolean extraDelay) {
189+
try {
190+
refreshAllIndices();
191+
Thread.sleep(extraDelay ? 60000L : 5000L);
192+
} catch (IOException | InterruptedException e) {
193+
// ignore
194+
}
195+
}
196+
197+
/**
198+
* Delete the specified document and wait until a search matches only the specified number of hits
199+
* @param tenantId The tenant ID to filter the search by
200+
* @param restPath The base path for the REST API
201+
* @param id The document ID to be appended to the REST API for deletion
202+
* @param hits The number of hits to expect after the deletion is processed
203+
* @throws Exception on failures with building or making the request
204+
*/
205+
protected static void deleteAndWaitForSearch(String tenantId, String restPath, String id, int hits) throws Exception {
206+
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
207+
.withHeaders(Map.of(TENANT_ID_HEADER, singletonList(tenantId)))
208+
.build();
209+
// First process the deletion. Dependent resources (e.g. model with connector) may cause 409 status until they are deleted
210+
assertBusy(() -> {
211+
try {
212+
Response deleteResponse = makeRequest(request, DELETE, restPath + id);
213+
// first successful deletion should produce an OK
214+
assertOK(deleteResponse);
215+
} catch (ResponseException e) {
216+
// repeat deletions can produce a 404, treat as a success
217+
assertNotFound(e.getResponse());
218+
}
219+
}, 20, TimeUnit.SECONDS);
220+
// Deletion processed, now wait for it to disappear from search
221+
RestRequest searchRequest = getRestRequestWithHeadersAndContent(tenantId, MATCH_ALL_QUERY);
222+
assertBusy(() -> {
223+
Response response = makeRequest(searchRequest, GET, restPath + "_search");
224+
assertOK(response);
225+
SearchResponse searchResponse = searchResponseFromResponse(response);
226+
assertEquals(hits, searchResponse.getHits().getTotalHits().value);
227+
}, 20, TimeUnit.SECONDS);
228+
}
229+
230+
protected static String registerRemoteModelContent(String description, String connectorId, String modelGroupId) {
231+
StringBuilder sb = new StringBuilder();
232+
sb.append("{\n");
233+
sb.append(" \"name\": \"remote model for connector_id ").append(connectorId).append("\",\n");
234+
sb.append(" \"function_name\": \"remote\",\n");
235+
sb.append(" \"description\": \"").append(description).append("\",\n");
236+
if (modelGroupId != null) {
237+
sb.append(" \"model_group_id\": \"").append(modelGroupId).append("\",\n");
238+
}
239+
sb.append(" \"connector_id\": \"").append(connectorId).append("\"\n");
240+
sb.append("}");
241+
return sb.toString();
242+
}
243+
}

0 commit comments

Comments
 (0)