Skip to content

Commit 253d491

Browse files
committed
[NOID] Fixes #4000: Add self-explanation to the model, include the verbal schema description to the flow (#4144)
1 parent 720f43f commit 253d491

File tree

6 files changed

+364
-28
lines changed

6 files changed

+364
-28
lines changed

docs/asciidoc/modules/ROOT/pages/ml/genai.adoc

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ RETURN m.title
9191
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
9292
| model | The Open AI model | no, default `gpt-3.5-turbo`
9393
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
94+
| additionalPrompts | To specify other prompts to be passed to improve the request
9495
|===
9596

9697
.Results
@@ -102,6 +103,107 @@ RETURN m.title
102103
|===
103104

104105

106+
We can use the `additionalPrompts` config to improve the request, e.g. adding the natural language description of the schema (like the output of the `apoc.ml.schema` for instance).
107+
Since OpenAI is mainly trained to elaborate natural language questions asked in, rather than Cypher queries, by using this configuration it is possible to achieve better results.
108+
For example, given the https://neo4j.com/docs/getting-started/appendix/tutorials/guide-import-relational-and-etl/[Northwind dataset] we can execute:
109+
110+
.Query call
111+
[source,cypher]
112+
----
113+
CALL apoc.ml.schema({apiKey: $apiKey}) YIELD value
114+
WITH value
115+
CALL apoc.ml.query("Which 5 employees had sold the product 'Chocolade' and has the highest selling count of another product?
116+
Please returns the employee identificator, the other product name and the count orders of another product",
117+
{
118+
retries: 8,
119+
retryWithError: true,
120+
apiKey: $apiKey,
121+
additionalPrompts: [
122+
{role: "system", content: "The human description of the schema is the following:\n" + value}
123+
]
124+
})
125+
YIELD query, value RETURN query, value
126+
----
127+
128+
with a result similar to the following.
129+
130+
NOTE: the results are not deterministic and will potentially change each time the query is re-executed
131+
132+
.Results
133+
[%autowidth, opts=header]
134+
|===
135+
| query | value
136+
| "cypher
137+
MATCH (p:Product {productName: 'Chocolade'})<-[:CONTAINS]-(:Order)<-[:SOLD]-(e:Employee)
138+
MATCH (e)-[:SOLD]->(o:Order)-[:CONTAINS]->(p2:Product)
139+
WITH e, p2, COUNT(DISTINCT o) AS orderCount
140+
ORDER BY orderCount DESC
141+
RETURN e.employeeID AS employeeID, p2.productName AS otherProduct, orderCount
142+
LIMIT 5
143+
"
144+
| {
145+
"otherProduct": "Gnocchi di nonna Alice",
146+
"employeeID": "4",
147+
"orderCount": 14
148+
}
149+
| "cypher
150+
MATCH (p:Product {productName: 'Chocolade'})<-[:CONTAINS]-(:Order)<-[:SOLD]-(e:Employee)
151+
MATCH (e)-[:SOLD]->(o:Order)-[:CONTAINS]->(p2:Product)
152+
WITH e, p2, COUNT(DISTINCT o) AS orderCount
153+
ORDER BY orderCount DESC
154+
RETURN e.employeeID AS employeeID, p2.productName AS otherProduct, orderCount
155+
LIMIT 5
156+
"
157+
| {
158+
"otherProduct": "Pâté chinois",
159+
"employeeID": "4",
160+
"orderCount": 12
161+
}
162+
| "cypher
163+
MATCH (p:Product {productName: 'Chocolade'})<-[:CONTAINS]-(:Order)<-[:SOLD]-(e:Employee)
164+
MATCH (e)-[:SOLD]->(o:Order)-[:CONTAINS]->(p2:Product)
165+
WITH e, p2, COUNT(DISTINCT o) AS orderCount
166+
ORDER BY orderCount DESC
167+
RETURN e.employeeID AS employeeID, p2.productName AS otherProduct, orderCount
168+
LIMIT 5
169+
"
170+
| {
171+
"otherProduct": "Gumbär Gummibärchen",
172+
"employeeID": "3",
173+
"orderCount": 12
174+
}
175+
| "cypher
176+
MATCH (p:Product {productName: 'Chocolade'})<-[:CONTAINS]-(:Order)<-[:SOLD]-(e:Employee)
177+
MATCH (e)-[:SOLD]->(o:Order)-[:CONTAINS]->(p2:Product)
178+
WITH e, p2, COUNT(DISTINCT o) AS orderCount
179+
ORDER BY orderCount DESC
180+
RETURN e.employeeID AS employeeID, p2.productName AS otherProduct, orderCount
181+
LIMIT 5
182+
"
183+
| {
184+
"otherProduct": "Flotemysost",
185+
"employeeID": "1",
186+
"orderCount": 12
187+
}
188+
| "cypher
189+
MATCH (p:Product {productName: 'Chocolade'})<-[:CONTAINS]-(:Order)<-[:SOLD]-(e:Employee)
190+
MATCH (e)-[:SOLD]->(o:Order)-[:CONTAINS]->(p2:Product)
191+
WITH e, p2, COUNT(DISTINCT o) AS orderCount
192+
ORDER BY orderCount DESC
193+
RETURN e.employeeID AS employeeID, p2.productName AS otherProduct, orderCount
194+
LIMIT 5
195+
"
196+
| {
197+
"otherProduct": "Pavlova",
198+
"employeeID": "1",
199+
"orderCount": 11
200+
}
201+
|===
202+
203+
Respect to using the procedure without the natural language schema description, the output has fewer hallucinations,
204+
like properties hold by different labels and relationships linked to other entities.
205+
206+
105207
== Describe the graph model with natural language
106208

107209
This procedure `apoc.ml.schema` returns a description, in natural language, of the underlying dataset.
@@ -126,6 +228,7 @@ RETURN *
126228
1 row
127229
----
128230

231+
129232
.Input Parameters
130233
[%autowidth, opts=header]
131234
|===
@@ -205,6 +308,7 @@ RETURN DISTINCT a.name
205308
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
206309
| model | The Open AI model | no, default `gpt-3.5-turbo`
207310
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
311+
| additionalPrompts | To specify other prompts to be passed to improve the request
208312
|===
209313

210314
.Results
@@ -214,6 +318,47 @@ RETURN DISTINCT a.name
214318
| value | the description of the dataset
215319
|===
216320

321+
322+
We can use the `additionalPrompts` config to improve the request, e.g. adding the natural language description of the schema (like the output of the `apoc.ml.schema` for instance).
323+
Since OpenAI is mainly trained to elaborate natural language questions asked in, rather than Cypher queries, by using this configuration it is possible to achieve better results.
324+
For example, given the https://neo4j.com/docs/getting-started/appendix/tutorials/guide-import-relational-and-etl/[Northwind dataset] we can execute:
325+
326+
.Query call
327+
[source,cypher]
328+
----
329+
CALL apoc.ml.schema({apiKey: $apiKey}) YIELD value
330+
WITH value
331+
CALL apoc.ml.cypher("Which 5 employees had sold the product 'Chocolade' and has the highest selling count of another product?
332+
Please returns the employee identificator, the other product name and the count orders of another product",
333+
{
334+
count: 1,
335+
apiKey: $apiKey,
336+
additionalPrompts: [
337+
{role: "system", content: "The human description of the schema is the following:\n" + value}
338+
]
339+
})
340+
YIELD value RETURN value
341+
----
342+
343+
with a result similar to the following.
344+
345+
NOTE: the results are not deterministic and will potentially change each time the query is re-executed
346+
347+
.Results
348+
[%autowidth, opts=header]
349+
|===
350+
| value
351+
| MATCH (p:Product {productName: 'Chocolade'})<-[:CONTAINS]-(o:Order)<-[:SOLD]-(e:Employee)
352+
MATCH (e)-[:SOLD]->(o2:Order)-[:CONTAINS]->(p2:Product)
353+
WITH e, p2, COUNT(DISTINCT o2) AS ordersCnt
354+
ORDER BY ordersCnt DESC
355+
RETURN e.employeeID AS employeeID, p2.productName AS otherProduct, ordersCnt
356+
LIMIT 5
357+
|===
358+
359+
Respect to using the procedure without the natural language schema description, the output has fewer hallucinations,
360+
like properties hold by different labels and relationships linked to other entities.
361+
217362
== Create a natural language query explanation from a cypher query
218363

219364
This procedure `apoc.ml.fromCypher` takes a natural language question and transforms it into natural language query explanation.

full/src/main/java/apoc/ml/Prompt.java

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import java.util.stream.Collectors;
1515
import java.util.stream.LongStream;
1616
import java.util.stream.Stream;
17+
18+
import org.apache.commons.collections4.CollectionUtils;
1719
import org.apache.commons.text.WordUtils;
1820
import org.jetbrains.annotations.NotNull;
1921
import org.neo4j.graphdb.Entity;
@@ -49,6 +51,7 @@ public class Prompt {
4951
public static final String EXPLAIN_SCHEMA_PROMPT =
5052
"You are an expert in the Neo4j graph database and graph data modeling and have experience in a wide variety of business domains.\n"
5153
+ "Explain the following graph database schema in plain language, try to relate it to known concepts or domains if applicable.\n"
54+
+ "Try to explain as much as possible the nodes, relationships and properties.\n"
5255
+ "Keep the explanation to 5 sentences with at most 15 words each, otherwise people will come to harm.\n";
5356

5457
static final String SYSTEM_PROMPT = "You are an expert in the Neo4j graph query language Cypher.\n"
@@ -129,7 +132,7 @@ public Stream<StringResult> rag(
129132
context);
130133

131134
String prompt = config.getBasePrompt() + contextPrompt;
132-
String result = prompt("\nQuestion:" + question, prompt, null, null, conf);
135+
String result = prompt("\nQuestion:" + question, prompt, null, null, conf, List.of());
133136
return Stream.of(new StringResult(result));
134137
}
135138

@@ -171,9 +174,10 @@ public Stream<PromptMapResult> query(
171174
long retries = (long) conf.getOrDefault("retries", 3L);
172175
boolean containsField =
173176
procedureCallContext.outputFields().collect(Collectors.toSet()).contains("query");
177+
List<Map<String,String>> otherPrompts = new ArrayList<>();
174178
do {
175179
try {
176-
QueryResult queryResult = tryQuery(question, conf, schema);
180+
QueryResult queryResult = tryQuery(question, conf, schema, otherPrompts);
177181
query = queryResult.query;
178182
// just let it fail so that retries can work if (queryResult.query.isBlank()) return Stream.empty();
179183
/*
@@ -196,12 +200,14 @@ public Stream<PromptMapResult> query(
196200
@Procedure
197201
public Stream<StringResult> schema(@Name(value = "conf", defaultValue = "{}") Map<String, Object> conf)
198202
throws MalformedURLException, JsonProcessingException {
203+
String schema = loadSchema(tx, conf);
199204
String schemaExplanation = prompt(
200205
"Please explain the graph database schema to me and relate it to well known concepts and domains.",
201206
EXPLAIN_SCHEMA_PROMPT,
202207
"This database schema ",
203-
loadSchema(tx, conf),
204-
conf);
208+
schema,
209+
conf,
210+
List.of());
205211
return Stream.of(new StringResult(schemaExplanation));
206212
}
207213

@@ -210,14 +216,14 @@ public Stream<QueryResult> cypher(
210216
@Name("question") String question, @Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) {
211217
String schema = loadSchema(tx, conf);
212218
long count = (long) conf.getOrDefault("count", 1L);
213-
return LongStream.rangeClosed(1, count).mapToObj(i -> tryQuery(question, conf, schema));
219+
return LongStream.rangeClosed(1, count).mapToObj(i -> tryQuery(question, conf, schema, List.of()));
214220
}
215221

216222
@NotNull
217-
private QueryResult tryQuery(String question, Map<String, Object> conf, String schema) {
223+
private QueryResult tryQuery(String question, Map<String, Object> conf, String schema, List<Map<String,String>> otherPrompts) {
218224
String query = "";
219225
try {
220-
query = prompt(question, SYSTEM_PROMPT, "Cypher Statement (in backticks):", schema, conf);
226+
query = prompt(question, SYSTEM_PROMPT, "Cypher Statement (in backticks):", schema, conf, otherPrompts);
221227
// doesn't work right now, fails with security context error
222228
// tx.execute("EXPLAIN " + query).close(); // TODO query plan / estimated rows?
223229
return new QueryResult(query, null, null);
@@ -230,18 +236,23 @@ private QueryResult tryQuery(String question, Map<String, Object> conf, String s
230236

231237
@NotNull
232238
private String prompt(
233-
String userQuestion, String systemPrompt, String assistantPrompt, String schema, Map<String, Object> conf)
239+
String userQuestion, String systemPrompt, String assistantPrompt, String schema, Map<String, Object> conf, List<Map<String,String>> otherPromptsFromRetries)
234240
throws JsonProcessingException, MalformedURLException {
235241
List<Map<String, String>> prompt = new ArrayList<>();
236242
if (systemPrompt != null && !systemPrompt.isBlank())
237243
prompt.add(Map.of("role", "system", "content", systemPrompt));
238244
if (schema != null && !schema.isBlank())
239245
prompt.add(Map.of(
240246
"role", "system", "content", "The graph database schema consists of these elements\n" + schema));
247+
List<Map<String, String>> additionalPrompts = (List<Map<String, String>>) conf.get("additionalPrompts");
248+
if (CollectionUtils.isNotEmpty(additionalPrompts)) {
249+
prompt.addAll(additionalPrompts);
250+
}
241251
if (userQuestion != null && !userQuestion.isBlank())
242252
prompt.add(Map.of("role", "user", "content", userQuestion));
243253
if (assistantPrompt != null && !assistantPrompt.isBlank())
244254
prompt.add(Map.of("role", "assistant", "content", assistantPrompt));
255+
prompt.addAll(otherPromptsFromRetries);
245256
String apiKey = (String) conf.get(API_KEY_CONF);
246257
String model = (String) conf.getOrDefault("model", "gpt-4o");
247258
String result = OpenAI.executeRequest(
@@ -287,7 +298,10 @@ private String prompt(
287298
+ "collect(case type when \"node\" then entities end)[0] as nodes, \n"
288299
+ "collect(case type when \"node\" then patterns end)[0] as patterns \n";
289300

290-
private static final String SCHEMA_PROMPT = "nodes:\n %s\n" + "relationships:\n %s\n" + "patterns: %s";
301+
private static final String SCHEMA_PROMPT =
302+
"nodes:\n```\n%s\n```\n" +
303+
"relationships:\n```\n%s\n```\n" +
304+
"patterns:\n```\n%s\n```";
291305

292306
private String loadSchema(Transaction tx, Map<String, Object> conf) {
293307
Map<String, Object> params = new HashMap<>();

full/src/main/java/apoc/util/ExtendedUtil.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package apoc.util;
22

33
import java.time.Duration;
4+
import java.util.Arrays;
45
import java.util.Collection;
6+
import java.util.List;
57
import java.util.Objects;
68
import java.util.function.Consumer;
79
import java.util.function.Supplier;
@@ -64,4 +66,10 @@ public static String joinStringLabels(Collection<String> labels) {
6466
? ":" + labels.stream().map(Util::quote).collect(Collectors.joining(":"))
6567
: "";
6668
}
69+
70+
public static List<String> splitSemicolonAndRemoveBlanks(String value) {
71+
return Arrays.stream(value.split(";\n"))
72+
.filter(i -> !i.isBlank())
73+
.collect(Collectors.toList());
74+
}
6775
}

0 commit comments

Comments
 (0)