Skip to content

Commit 53bfeba

Browse files
authored
Add AsyncQueryRequestContext to QueryIdProvider parameter (#2870)
Signed-off-by: Tomoyuki Morita <[email protected]>
1 parent aa7a690 commit 53bfeba

File tree

5 files changed

+59
-14
lines changed

5 files changed

+59
-14
lines changed

async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55

66
package org.opensearch.sql.spark.dispatcher;
77

8+
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
89
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
910
import org.opensearch.sql.spark.utils.IDUtils;
1011

1112
/** Generates QueryId by embedding Datasource name and random UUID */
1213
public class DatasourceEmbeddedQueryIdProvider implements QueryIdProvider {
1314

1415
@Override
15-
public String getQueryId(DispatchQueryRequest dispatchQueryRequest) {
16+
public String getQueryId(
17+
DispatchQueryRequest dispatchQueryRequest,
18+
AsyncQueryRequestContext asyncQueryRequestContext) {
1619
return IDUtils.encode(dispatchQueryRequest.getDatasource());
1720
}
1821
}

async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
package org.opensearch.sql.spark.dispatcher;
77

8+
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
89
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
910

1011
/** Interface for extension point to specify queryId. Called when new query is executed. */
1112
public interface QueryIdProvider {
12-
String getQueryId(DispatchQueryRequest dispatchQueryRequest);
13+
String getQueryId(
14+
DispatchQueryRequest dispatchQueryRequest, AsyncQueryRequestContext asyncQueryRequestContext);
1315
}

async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java

+8-4
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ private DispatchQueryResponse handleFlintExtensionQuery(
6969
DataSourceMetadata dataSourceMetadata) {
7070
IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
7171
DispatchQueryContext context =
72-
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
72+
getDefaultDispatchContextBuilder(
73+
dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext)
7374
.indexQueryDetails(indexQueryDetails)
7475
.asyncQueryRequestContext(asyncQueryRequestContext)
7576
.build();
@@ -84,7 +85,8 @@ private DispatchQueryResponse handleDefaultQuery(
8485
DataSourceMetadata dataSourceMetadata) {
8586

8687
DispatchQueryContext context =
87-
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
88+
getDefaultDispatchContextBuilder(
89+
dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext)
8890
.asyncQueryRequestContext(asyncQueryRequestContext)
8991
.build();
9092

@@ -93,11 +95,13 @@ private DispatchQueryResponse handleDefaultQuery(
9395
}
9496

9597
private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
96-
DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) {
98+
DispatchQueryRequest dispatchQueryRequest,
99+
DataSourceMetadata dataSourceMetadata,
100+
AsyncQueryRequestContext asyncQueryRequestContext) {
97101
return DispatchQueryContext.builder()
98102
.dataSourceMetadata(dataSourceMetadata)
99103
.tags(getDefaultTagsForJobSubmission(dispatchQueryRequest))
100-
.queryId(queryIdProvider.getQueryId(dispatchQueryRequest));
104+
.queryId(queryIdProvider.getQueryId(dispatchQueryRequest, asyncQueryRequestContext));
101105
}
102106

103107
private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(

async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java

+9-8
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ public void setUp() {
185185
public void createDropIndexQuery() {
186186
givenSparkExecutionEngineConfigIsSupplied();
187187
givenValidDataSourceMetadataExist();
188-
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
188+
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
189189
String indexName = "flint_datasource_name_table_name_index_name_index";
190190
givenFlintIndexMetadataExists(indexName);
191191
givenCancelJobRunSucceed();
@@ -209,7 +209,7 @@ public void createDropIndexQuery() {
209209
public void createVacuumIndexQuery() {
210210
givenSparkExecutionEngineConfigIsSupplied();
211211
givenValidDataSourceMetadataExist();
212-
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
212+
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
213213
String indexName = "flint_datasource_name_table_name_index_name_index";
214214
givenFlintIndexMetadataExists(indexName);
215215

@@ -231,7 +231,7 @@ public void createVacuumIndexQuery() {
231231
public void createAlterIndexQuery() {
232232
givenSparkExecutionEngineConfigIsSupplied();
233233
givenValidDataSourceMetadataExist();
234-
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
234+
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
235235
String indexName = "flint_datasource_name_table_name_index_name_index";
236236
givenFlintIndexMetadataExists(indexName);
237237
givenCancelJobRunSucceed();
@@ -261,7 +261,7 @@ public void createAlterIndexQuery() {
261261
public void createStreamingQuery() {
262262
givenSparkExecutionEngineConfigIsSupplied();
263263
givenValidDataSourceMetadataExist();
264-
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
264+
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
265265
when(awsemrServerless.startJobRun(any()))
266266
.thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID));
267267

@@ -297,7 +297,7 @@ private void verifyStartJobRunCalled() {
297297
public void createCreateIndexQuery() {
298298
givenSparkExecutionEngineConfigIsSupplied();
299299
givenValidDataSourceMetadataExist();
300-
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
300+
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
301301
when(awsemrServerless.startJobRun(any()))
302302
.thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID));
303303

@@ -321,7 +321,7 @@ public void createCreateIndexQuery() {
321321
public void createRefreshQuery() {
322322
givenSparkExecutionEngineConfigIsSupplied();
323323
givenValidDataSourceMetadataExist();
324-
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
324+
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
325325
when(awsemrServerless.startJobRun(any()))
326326
.thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID));
327327

@@ -344,7 +344,7 @@ public void createInteractiveQuery() {
344344
givenSparkExecutionEngineConfigIsSupplied();
345345
givenValidDataSourceMetadataExist();
346346
givenSessionExists();
347-
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
347+
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
348348
when(sessionIdProvider.getSessionId(any())).thenReturn(SESSION_ID);
349349
givenSessionExists(); // called twice
350350
when(awsemrServerless.startJobRun(any()))
@@ -538,7 +538,8 @@ private void givenGetJobRunReturnJobRunWithState(String state) {
538538
}
539539

540540
private void verifyGetQueryIdCalled() {
541-
verify(queryIdProvider).getQueryId(dispatchQueryRequestArgumentCaptor.capture());
541+
verify(queryIdProvider)
542+
.getQueryId(dispatchQueryRequestArgumentCaptor.capture(), eq(asyncQueryRequestContext));
542543
DispatchQueryRequest dispatchQueryRequest = dispatchQueryRequestArgumentCaptor.getValue();
543544
assertEquals(ACCOUNT_ID, dispatchQueryRequest.getAccountId());
544545
assertEquals(APPLICATION_ID, dispatchQueryRequest.getApplicationId());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.spark.dispatcher;
7+
8+
import static org.junit.jupiter.api.Assertions.assertNotNull;
9+
import static org.mockito.Mockito.verifyNoInteractions;
10+
11+
import org.junit.jupiter.api.Test;
12+
import org.junit.jupiter.api.extension.ExtendWith;
13+
import org.mockito.Mock;
14+
import org.mockito.junit.jupiter.MockitoExtension;
15+
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
16+
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
17+
18+
@ExtendWith(MockitoExtension.class)
19+
class DatasourceEmbeddedQueryIdProviderTest {
20+
@Mock AsyncQueryRequestContext asyncQueryRequestContext;
21+
22+
DatasourceEmbeddedQueryIdProvider datasourceEmbeddedQueryIdProvider =
23+
new DatasourceEmbeddedQueryIdProvider();
24+
25+
@Test
26+
public void test() {
27+
String queryId =
28+
datasourceEmbeddedQueryIdProvider.getQueryId(
29+
DispatchQueryRequest.builder().datasource("DATASOURCE").build(),
30+
asyncQueryRequestContext);
31+
32+
assertNotNull(queryId);
33+
verifyNoInteractions(asyncQueryRequestContext);
34+
}
35+
}

0 commit comments

Comments
 (0)