forked from opensearch-project/skills
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAbstractRetrieverTool.java
147 lines (128 loc) · 5.51 KB
/
AbstractRetrieverTool.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.agent.tools;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.transport.client.Client;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
/**
* Abstract tool supports search paradigms in neural-search plugin.
*/
@Log4j2
@Getter
@Setter
public abstract class AbstractRetrieverTool implements Tool {
public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index.";
public static final String INPUT_FIELD = "input";
public static final String INDEX_FIELD = "index";
public static final String SOURCE_FIELD = "source_field";
public static final String DOC_SIZE_FIELD = "doc_size";
public static final int DEFAULT_DOC_SIZE = 2;
protected String description = DEFAULT_DESCRIPTION;
protected Client client;
protected NamedXContentRegistry xContentRegistry;
protected String index;
protected String[] sourceFields;
protected Integer docSize;
protected String version;
protected AbstractRetrieverTool(
Client client,
NamedXContentRegistry xContentRegistry,
String index,
String[] sourceFields,
Integer docSize
) {
this.client = client;
this.xContentRegistry = xContentRegistry;
this.index = index;
this.sourceFields = sourceFields;
this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize;
}
protected abstract String getQueryBody(String queryText);
private static Map<String, Object> processResponse(SearchHit hit) {
Map<String, Object> docContent = new HashMap<>();
docContent.put("_index", hit.getIndex());
docContent.put("_id", hit.getId());
docContent.put("_score", hit.getScore());
docContent.put("_source", hit.getSourceAsMap());
return docContent;
}
protected <T> SearchRequest buildSearchRequest(Map<String, String> parameters) throws IOException {
String question = parameters.get(INPUT_FIELD);
if (StringUtils.isBlank(question)) {
throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it.");
}
String query = getQueryBody(question);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query);
searchSourceBuilder.parseXContent(queryParser);
searchSourceBuilder.fetchSource(sourceFields, null);
searchSourceBuilder.size(docSize);
return new SearchRequest().source(searchSourceBuilder).indices(parameters.getOrDefault(INDEX_FIELD, index));
}
@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
SearchRequest searchRequest;
try {
searchRequest = buildSearchRequest(parameters);
} catch (Exception e) {
log.error("Failed to build search request.", e);
listener.onFailure(e);
return;
}
ActionListener actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();
if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (SearchHit hit : hits) {
Map<String, Object> docContent = processResponse(hit);
String docContentInString = AccessController
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(docContent));
contextBuilder.append(docContentInString).append("\n");
}
listener.onResponse((T) contextBuilder.toString());
} else {
listener.onResponse((T) "Can not get any match from search result.");
}
}, e -> {
log.error("Failed to search index.", e);
listener.onFailure(e);
});
client.search(searchRequest, actionListener);
}
@Override
public boolean validate(Map<String, String> parameters) {
return parameters != null && parameters.size() > 0 && !StringUtils.isBlank(parameters.get("input"));
}
protected static abstract class Factory<T extends Tool> implements Tool.Factory<T> {
protected Client client;
protected NamedXContentRegistry xContentRegistry;
public void init(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}
@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}