@@ -395,7 +395,7 @@ public void testRejectIfNeeded_whenFeatureIsNotEnabled() {
395
395
}
396
396
397
397
public void testOnTaskCompleted () {
398
- Task task = createMockTaskWithResourceStats ( SearchTask . class , 100 , 200 , 0 , 12 );
398
+ Task task = new SearchTask ( 12 , "" , "" , () -> "" , null , null );
399
399
mockThreadPool = new TestThreadPool ("queryGroupServiceTests" );
400
400
mockThreadPool .getThreadContext ().putHeader (QueryGroupTask .QUERY_GROUP_ID_HEADER , "testId" );
401
401
QueryGroupState queryGroupState = new QueryGroupState ();
@@ -442,7 +442,7 @@ public void testOnTaskCompleted() {
442
442
}
443
443
444
444
public void testShouldSBPHandle () {
445
- QueryGroupTask task = createMockTaskWithResourceStats (SearchTask .class , 100 , 200 , 0 , 12 );
445
+ SearchTask task = createMockTaskWithResourceStats (SearchTask .class , 100 , 200 , 0 , 12 );
446
446
QueryGroupState queryGroupState = new QueryGroupState ();
447
447
Set <QueryGroup > activeQueryGroups = new HashSet <>();
448
448
mockQueryGroupStateMap .put ("testId" , queryGroupState );
@@ -464,6 +464,8 @@ public void testShouldSBPHandle() {
464
464
mockThreadPool = new TestThreadPool ("queryGroupServiceTests" );
465
465
mockThreadPool .getThreadContext ()
466
466
.putHeader (QueryGroupTask .QUERY_GROUP_ID_HEADER , QueryGroupTask .DEFAULT_QUERY_GROUP_ID_SUPPLIER .get ());
467
+ // we haven't set the queryGroupId yet SBP should still track the task for cancellation
468
+ assertTrue (queryGroupService .shouldSBPHandle (task ));
467
469
task .setQueryGroupId (mockThreadPool .getThreadContext ());
468
470
assertTrue (queryGroupService .shouldSBPHandle (task ));
469
471
@@ -490,6 +492,15 @@ public void testShouldSBPHandle() {
490
492
);
491
493
assertTrue (queryGroupService .shouldSBPHandle (task ));
492
494
495
+ mockThreadPool .shutdownNow ();
496
+
497
+ // test the case when SBP should not track the task
498
+ when (mockWorkloadManagementSettings .getWlmMode ()).thenReturn (WlmMode .ENABLED );
499
+ task = new SearchTask (1 , "" , "test" , () -> "" , null , null );
500
+ mockThreadPool = new TestThreadPool ("queryGroupServiceTests" );
501
+ mockThreadPool .getThreadContext ().putHeader (QueryGroupTask .QUERY_GROUP_ID_HEADER , "testId" );
502
+ task .setQueryGroupId (mockThreadPool .getThreadContext ());
503
+ assertFalse (queryGroupService .shouldSBPHandle (task ));
493
504
}
494
505
495
506
private static Set <QueryGroup > getActiveQueryGroups (
0 commit comments