Skip to content

Commit 1928ef1

Browse files
authored
Enhance: support mocking final functions and classes (opensearch-project#1528)
* support mock inline Signed-off-by: yuye-aws <[email protected]> * adjust ut Signed-off-by: yuye-aws <[email protected]> * UT fix for inference processor Signed-off-by: yuye-aws <[email protected]> * UT fix for NeuralSparseTwoPhaseProcessorTests Signed-off-by: yuye-aws <[email protected]> * fix UT in HybridQueryUtilTests Signed-off-by: yuye-aws <[email protected]> * fix tests in HybridQueryWeightTests Signed-off-by: yuye-aws <[email protected]> * update change log Signed-off-by: yuye-aws <[email protected]> * fix UT in HybridCollectorManagerTests Signed-off-by: yuye-aws <[email protected]> * UT fix for NeuralQueryBuilderRewriteTests Signed-off-by: yuye-aws <[email protected]> * UT fix after rebase Signed-off-by: yuye-aws <[email protected]> * fix forbidden API Signed-off-by: yuye-aws <[email protected]> * Address code review comments Signed-off-by: yuye-aws <[email protected]> --------- Signed-off-by: yuye-aws <[email protected]>
1 parent 9a943d4 commit 1928ef1

File tree

12 files changed

+54
-18
lines changed

12 files changed

+54
-18
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1616

1717
### Infrastructure
1818

19+
- [Unit Test] Enable mocking of final classes and static functions ([#1528](https://github.com/opensearch-project/neural-search/pull/1528)).
20+
1921
### Documentation
2022

2123
### Maintenance

build.gradle

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,9 @@ dependencies {
291291
testImplementation fileTree(dir: knnJarDirectory, include: ["opensearch-knn-${opensearch_build}.jar", "remote-index-build-client-${opensearch_build}.jar"])
292292
testImplementation "org.opensearch.plugin:parent-join-client:${opensearch_version}"
293293
testImplementation 'org.assertj:assertj-core:3.24.2'
294+
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: "${versions.bytebuddy}"
295+
testImplementation group: 'org.objenesis', name: 'objenesis', version: "${versions.objenesis}"
296+
testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: "${versions.bytebuddy}"
294297
}
295298

296299
// In order to add the jar to the classpath, we need to unzip the
@@ -320,10 +323,7 @@ def _numNodes = findProperty('numNodes') as Integer ?: 1
320323
test {
321324
include '**/*Tests.class'
322325
systemProperty 'tests.security.manager', 'false'
323-
filter {
324-
// TODO: include sparse tests
325-
excludeTestsMatching "org.opensearch.neuralsearch.sparse.codec.SparseTermsLuceneReaderTests.*"
326-
}
326+
systemProperty "jdk.attach.allowAttachSelf", true
327327
}
328328

329329
// Setting up Integration Tests

src/test/java/org/opensearch/neuralsearch/processor/AgenticQueryTranslatorProcessorTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.opensearch.index.query.QueryBuilder;
1414
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
1515
import org.opensearch.neuralsearch.query.AgenticSearchQueryBuilder;
16+
import org.opensearch.neuralsearch.stats.events.EventStatsManager;
1617
import org.opensearch.search.aggregations.AggregationBuilders;
1718
import org.opensearch.search.builder.SearchSourceBuilder;
1819
import org.opensearch.search.pipeline.PipelineProcessingContext;
@@ -67,6 +68,7 @@ public void setUp() throws Exception {
6768
mockContext = mock(PipelineProcessingContext.class);
6869
mockSettingsAccessor = mock(NeuralSearchSettingsAccessor.class);
6970
when(mockSettingsAccessor.isAgenticSearchEnabled()).thenReturn(true);
71+
EventStatsManager.instance().initialize(mockSettingsAccessor);
7072

7173
// Use factory to create processor since constructor is private
7274
AgenticQueryTranslatorProcessor.Factory factory = new AgenticQueryTranslatorProcessor.Factory(

src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
import lombok.Getter;
88
import org.junit.Before;
99
import org.mockito.ArgumentCaptor;
10+
import org.mockito.Mock;
1011
import org.mockito.MockitoAnnotations;
12+
import org.opensearch.cluster.ClusterState;
13+
import org.opensearch.cluster.metadata.IndexMetadata;
14+
import org.opensearch.cluster.metadata.Metadata;
1115
import org.opensearch.cluster.service.ClusterService;
1216
import org.opensearch.common.settings.Settings;
1317
import org.opensearch.core.action.ActionListener;
@@ -26,7 +30,7 @@
2630
import java.util.function.Consumer;
2731

2832
import static org.mockito.ArgumentMatchers.any;
29-
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
33+
import static org.mockito.ArgumentMatchers.anyString;
3034
import static org.mockito.Mockito.mock;
3135
import static org.mockito.Mockito.never;
3236
import static org.mockito.Mockito.verify;
@@ -37,7 +41,14 @@ public class InferenceProcessorTests extends InferenceProcessorTestCase {
3741
private MLCommonsClientAccessor clientAccessor;
3842
private Environment environment;
3943

40-
private ClusterService clusterService = mock(ClusterService.class, RETURNS_DEEP_STUBS);
44+
@Mock
45+
private ClusterService clusterService;
46+
@Mock
47+
private ClusterState clusterState;
48+
@Mock
49+
private Metadata metadata;
50+
@Mock
51+
private IndexMetadata indexMetadata;
4152

4253
private static final String TAG = "tag";
4354
private static final String TYPE = "type";
@@ -50,6 +61,10 @@ public class InferenceProcessorTests extends InferenceProcessorTestCase {
5061
@Before
5162
public void setup() {
5263
MockitoAnnotations.openMocks(this);
64+
when(clusterService.state()).thenReturn(clusterState);
65+
when(clusterState.metadata()).thenReturn(metadata);
66+
when(metadata.index(anyString())).thenReturn(indexMetadata);
67+
when(indexMetadata.getSettings()).thenReturn(null);
5368
clientAccessor = mock(MLCommonsClientAccessor.class);
5469
environment = mock(Environment.class);
5570
Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build();

src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.HashMap;
2424
import java.util.Map;
2525

26+
import static org.mockito.ArgumentMatchers.anyFloat;
2627
import static org.mockito.ArgumentMatchers.eq;
2728
import static org.mockito.Mockito.when;
2829
import static org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils.setUpClusterService;
@@ -215,6 +216,7 @@ public void testProcessRequest_whenTwoPhaseEnabledWithNeuralQuerySparseEmbedding
215216

216217
when(neuralQueryBuilder.isTargetSparseEmbedding(searchRequest)).thenReturn(true);
217218
when(neuralQueryBuilder.prepareTwoPhaseQuery(eq(0.5f), eq(PruneType.MAX_RATIO))).thenReturn(copy);
219+
when(copy.boost(anyFloat())).thenReturn(copy);
218220

219221
processor.processRequest(searchRequest);
220222

src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
public class HybridQueryWeightTests extends OpenSearchQueryTestCase {
4747

4848
private static final String TERM_QUERY_TEXT = "keyword";
49-
private static final String RANGE_FIELD = "date _range";
49+
private static final String RANGE_FIELD = "date_range";
5050
private static final String FROM_TEXT = "123";
5151
private static final String TO_TEXT = "456";
5252

@@ -103,6 +103,7 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() {
103103
@SneakyThrows
104104
public void testSubQueries_whenMultipleEqualSubQueries_thenSuccessful() {
105105
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
106+
when(mockQueryShardContext.convertToShardContext()).thenReturn(mockQueryShardContext);
106107
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
107108
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
108109

src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderRewriteTests.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ public void testRewrite_whenTargetKnnAndOnShardDirectly_thenKnnQueryBuilder() {
296296
// prepare data to rewrite on shard level
297297
final QueryShardContext queryShardContext = mock(QueryShardContext.class);
298298
final KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class);
299+
when(queryShardContext.convertToShardContext()).thenReturn(queryShardContext);
299300
when(queryShardContext.fieldMapper(FIELD_NAME)).thenReturn(knnVectorFieldType);
300301
when(knnVectorFieldType.typeName()).thenReturn(KNNVectorFieldMapper.CONTENT_TYPE);
301302

@@ -678,6 +679,7 @@ private QueryShardContext mockQueryShardContextForKnn(
678679
final QueryShardContext queryShardContext = mock(QueryShardContext.class);
679680
final SemanticFieldMapper.SemanticFieldType semanticFieldType = mock(SemanticFieldMapper.SemanticFieldType.class);
680681
final KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class);
682+
when(queryShardContext.convertToShardContext()).thenReturn(queryShardContext);
681683
when(queryShardContext.fieldMapper(FIELD_NAME)).thenReturn(semanticFieldType);
682684
final String semanticInfoFieldPath = semanticInfoFieldName == null ? FIELD_NAME + "_semantic_info" : semanticInfoFieldName;
683685
String embeddingFullPath;
@@ -842,6 +844,7 @@ public void testRewriteTargetSemanticRankFeatures_whenMultipleTargetIndices_then
842844
final QueryShardContext queryShardContext = mock(QueryShardContext.class);
843845
final Index index = mock(Index.class);
844846
when(queryShardContext.index()).thenReturn(index);
847+
when(queryShardContext.convertToShardContext()).thenReturn(queryShardContext);
845848
when(index.getName()).thenReturn(LOCAL_INDEX_NAME);
846849
final SemanticFieldMapper.SemanticFieldType semanticFieldType = mock(SemanticFieldMapper.SemanticFieldType.class);
847850
final RankFeaturesFieldMapper.RankFeaturesFieldType rankFeaturesFieldType = new RankFeaturesFieldMapper.RankFeaturesFieldType(
@@ -870,6 +873,7 @@ public void testRewriteTargetSemanticRankFeatures_whenMultipleTargetIndices_then
870873
final QueryShardContext queryShardContext2 = mock(QueryShardContext.class);
871874
final Index index2 = mock(Index.class);
872875
when(queryShardContext2.index()).thenReturn(index2);
876+
when(queryShardContext2.convertToShardContext()).thenReturn(queryShardContext2);
873877
when(index2.getName()).thenReturn(LOCAL_INDEX_NAME_2);
874878
final SemanticFieldMapper.SemanticFieldType semanticFieldType2 = mock(SemanticFieldMapper.SemanticFieldType.class);
875879
when(queryShardContext2.fieldMapper(FIELD_NAME)).thenReturn(semanticFieldType2);
@@ -974,6 +978,7 @@ public void testRewriteTargetSemanticRankFeatures_whenMultipleTargetIndicesOneWi
974978
final QueryShardContext queryShardContext = mock(QueryShardContext.class);
975979
final Index index = mock(Index.class);
976980
when(queryShardContext.index()).thenReturn(index);
981+
when(queryShardContext.convertToShardContext()).thenReturn(queryShardContext);
977982
when(index.getName()).thenReturn(LOCAL_INDEX_NAME);
978983
final SemanticFieldMapper.SemanticFieldType semanticFieldType = mock(SemanticFieldMapper.SemanticFieldType.class);
979984
final RankFeaturesFieldMapper.RankFeaturesFieldType rankFeaturesFieldType = new RankFeaturesFieldMapper.RankFeaturesFieldType(
@@ -1006,6 +1011,7 @@ public void testRewriteTargetSemanticRankFeatures_whenMultipleTargetIndicesOneWi
10061011
final QueryShardContext queryShardContext2 = mock(QueryShardContext.class);
10071012
final Index index2 = mock(Index.class);
10081013
when(queryShardContext2.index()).thenReturn(index2);
1014+
when(queryShardContext2.convertToShardContext()).thenReturn(queryShardContext2);
10091015
when(index2.getName()).thenReturn(LOCAL_INDEX_NAME_2);
10101016
final SemanticFieldMapper.SemanticFieldType semanticFieldType2 = mock(SemanticFieldMapper.SemanticFieldType.class);
10111017
when(queryShardContext2.fieldMapper(FIELD_NAME)).thenReturn(semanticFieldType2);

src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,7 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the
977977
ScorerSupplier mockScorerSupplier = mock(ScorerSupplier.class);
978978
when(mockScorerSupplier.get(anyLong())).thenReturn(rescoreScorer);
979979
when(mockScorerSupplier.cost()).thenReturn(1L);
980-
when(rescoreWeight.scorerSupplier(any(LeafReaderContext.class))).thenReturn(mockScorerSupplier);
980+
when(rescoreWeight.scorer(any(LeafReaderContext.class))).thenReturn(rescoreScorer);
981981

982982
when(rescoreScorer.docID()).thenReturn(1);
983983
DocIdSetIterator iterator = mock(DocIdSetIterator.class);

src/test/java/org/opensearch/neuralsearch/stats/NeuralStatsInputTests.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,11 @@ public void test_builderWithAllFields() {
8181
public void test_streamInput() throws IOException {
8282
StreamInput mockInput = mock(StreamInput.class);
8383

84-
// Have to return the readByte since readBoolean can't be mocked
85-
when(mockInput.readByte()).thenReturn((byte) 1) // true for includeMetadata
86-
.thenReturn((byte) 1) // true for flatten
87-
.thenReturn((byte) 0) // false for includeIndividualNodes
88-
.thenReturn((byte) 0) // false for includeAllNodes
89-
.thenReturn((byte) 0); // false for includeInfo
84+
when(mockInput.readBoolean()).thenReturn(true) // true for includeMetadata
85+
.thenReturn(true) // true for flatten
86+
.thenReturn(false) // false for includeIndividualNodes
87+
.thenReturn(false) // false for includeAllNodes
88+
.thenReturn(false); // false for includeInfo
9089

9190
when(mockInput.readOptionalStringList()).thenReturn(Arrays.asList(NODE_ID_1, NODE_ID_2));
9291
when(mockInput.readOptionalEnumSet(EventStatName.class)).thenReturn(EnumSet.of(EVENT_STAT));
@@ -103,7 +102,7 @@ public void test_streamInput() throws IOException {
103102
assertFalse(input.isIncludeAllNodes());
104103
assertFalse(input.isIncludeInfo());
105104

106-
verify(mockInput, times(5)).readByte();
105+
verify(mockInput, times(5)).readBoolean();
107106
verify(mockInput, times(1)).readOptionalStringList();
108107
verify(mockInput, times(2)).readOptionalEnumSet(any());
109108
}

src/test/java/org/opensearch/neuralsearch/transport/NeuralStatsResponseTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ public void test_constructor() throws IOException {
6767
.thenReturn((Map) aggregatedNodeStats)
6868
.thenReturn((Map) nodeIdToNodeEventStats);
6969

70-
// Booleans as bytes
71-
when(mockStreamInput.readByte()).thenReturn((byte) 1).thenReturn((byte) 0);
70+
when(mockStreamInput.readBoolean()).thenReturn(true).thenReturn(false);
7271

7372
NeuralStatsResponse response = new NeuralStatsResponse(mockStreamInput);
7473

0 commit comments

Comments
 (0)