Skip to content

Commit a7c1ab8

Browse files
ghondarHenryHengZJ
andauthored
Feature/Add Neo4j GraphRag support (#3686)
* added: Neo4j database connectivity, Neo4j credentials, supports the usage of the GraphCypherQaChain node and modifies the FewShotPromptTemplate node to handle variables from the prefix field. * Merge branch 'main' of github.com:FlowiseAI/Flowise into feature/graphragsupport * revert pnpm-lock.yaml * add: neo4j package * Refactor GraphCypherQAChain: Update version to 1.0, remove memory input, and enhance prompt handling - Changed version from 2.0 to 1.0. - Removed the 'Memory' input parameter from the GraphCypherQAChain. - Made 'cypherPrompt' optional and improved error handling for prompt validation. - Updated the 'init' and 'run' methods to streamline input processing and response handling. - Enhanced streaming response logic based on the 'returnDirect' flag. * Refactor GraphCypherQAChain: Simplify imports and update init method signature - Consolidated import statements for better readability. - Removed the 'input' and 'options' parameters from the 'init' method, streamlining its signature to only accept 'nodeData'. * add output, format final response, fix optional inputs --------- Co-authored-by: Henry <[email protected]>
1 parent 93f3a5d commit a7c1ab8

File tree

8 files changed

+34325
-33897
lines changed

8 files changed

+34325
-33897
lines changed
+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import { INodeParams, INodeCredential } from '../src/Interface'
2+
3+
class Neo4jApi implements INodeCredential {
4+
label: string
5+
name: string
6+
version: number
7+
description: string
8+
inputs: INodeParams[]
9+
10+
constructor() {
11+
this.label = 'Neo4j API'
12+
this.name = 'neo4jApi'
13+
this.version = 1.0
14+
this.description =
15+
'Refer to <a target="_blank" href="https://neo4j.com/docs/operations-manual/current/authentication-authorization/">official guide</a> on Neo4j authentication'
16+
this.inputs = [
17+
{
18+
label: 'Neo4j URL',
19+
name: 'url',
20+
type: 'string',
21+
description: 'Your Neo4j instance URL (e.g., neo4j://localhost:7687)'
22+
},
23+
{
24+
label: 'Username',
25+
name: 'username',
26+
type: 'string',
27+
description: 'Neo4j database username'
28+
},
29+
{
30+
label: 'Password',
31+
name: 'password',
32+
type: 'password',
33+
description: 'Neo4j database password'
34+
}
35+
]
36+
}
37+
}
38+
39+
module.exports = { credClass: Neo4jApi }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
import { ICommonObject, INode, INodeData, INodeParams, INodeOutputsValue, IServerSideEventStreamer } from '../../../src/Interface'
2+
import { FromLLMInput, GraphCypherQAChain } from '@langchain/community/chains/graph_qa/cypher'
3+
import { getBaseClasses } from '../../../src/utils'
4+
import { BasePromptTemplate, PromptTemplate, FewShotPromptTemplate } from '@langchain/core/prompts'
5+
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
6+
import { ConsoleCallbackHandler as LCConsoleCallbackHandler } from '@langchain/core/tracers/console'
7+
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
8+
import { formatResponse } from '../../outputparsers/OutputParserHelpers'
9+
10+
class GraphCypherQA_Chain implements INode {
11+
label: string
12+
name: string
13+
version: number
14+
type: string
15+
icon: string
16+
category: string
17+
description: string
18+
baseClasses: string[]
19+
inputs: INodeParams[]
20+
sessionId?: string
21+
outputs: INodeOutputsValue[]
22+
23+
constructor(fields?: { sessionId?: string }) {
24+
this.label = 'Graph Cypher QA Chain'
25+
this.name = 'graphCypherQAChain'
26+
this.version = 1.0
27+
this.type = 'GraphCypherQAChain'
28+
this.icon = 'graphqa.svg'
29+
this.category = 'Chains'
30+
this.description = 'Advanced chain for question-answering against a Neo4j graph by generating Cypher statements'
31+
this.baseClasses = [this.type, ...getBaseClasses(GraphCypherQAChain)]
32+
this.sessionId = fields?.sessionId
33+
this.inputs = [
34+
{
35+
label: 'Language Model',
36+
name: 'model',
37+
type: 'BaseLanguageModel',
38+
description: 'Model for generating Cypher queries and answers.'
39+
},
40+
{
41+
label: 'Neo4j Graph',
42+
name: 'graph',
43+
type: 'Neo4j'
44+
},
45+
{
46+
label: 'Cypher Generation Prompt',
47+
name: 'cypherPrompt',
48+
optional: true,
49+
type: 'BasePromptTemplate',
50+
description: 'Prompt template for generating Cypher queries. Must include {schema} and {question} variables'
51+
},
52+
{
53+
label: 'Cypher Generation Model',
54+
name: 'cypherModel',
55+
optional: true,
56+
type: 'BaseLanguageModel',
57+
description: 'Model for generating Cypher queries. If not provided, the main model will be used.'
58+
},
59+
{
60+
label: 'QA Prompt',
61+
name: 'qaPrompt',
62+
optional: true,
63+
type: 'BasePromptTemplate',
64+
description: 'Prompt template for generating answers. Must include {context} and {question} variables'
65+
},
66+
{
67+
label: 'QA Model',
68+
name: 'qaModel',
69+
optional: true,
70+
type: 'BaseLanguageModel',
71+
description: 'Model for generating answers. If not provided, the main model will be used.'
72+
},
73+
{
74+
label: 'Input Moderation',
75+
description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
76+
name: 'inputModeration',
77+
type: 'Moderation',
78+
optional: true,
79+
list: true
80+
},
81+
{
82+
label: 'Return Direct',
83+
name: 'returnDirect',
84+
type: 'boolean',
85+
default: false,
86+
optional: true,
87+
description: 'If true, return the raw query results instead of using the QA chain'
88+
}
89+
]
90+
this.outputs = [
91+
{
92+
label: 'Graph Cypher QA Chain',
93+
name: 'graphCypherQAChain',
94+
baseClasses: [this.type, ...getBaseClasses(GraphCypherQAChain)]
95+
},
96+
{
97+
label: 'Output Prediction',
98+
name: 'outputPrediction',
99+
baseClasses: ['string', 'json']
100+
}
101+
]
102+
}
103+
104+
async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
105+
const model = nodeData.inputs?.model
106+
const cypherModel = nodeData.inputs?.cypherModel
107+
const qaModel = nodeData.inputs?.qaModel
108+
const graph = nodeData.inputs?.graph
109+
const cypherPrompt = nodeData.inputs?.cypherPrompt as BasePromptTemplate | FewShotPromptTemplate | undefined
110+
const qaPrompt = nodeData.inputs?.qaPrompt as BasePromptTemplate | undefined
111+
const returnDirect = nodeData.inputs?.returnDirect as boolean
112+
const output = nodeData.outputs?.output as string
113+
114+
// Handle prompt values if they exist
115+
let cypherPromptTemplate: PromptTemplate | FewShotPromptTemplate | undefined
116+
let qaPromptTemplate: PromptTemplate | undefined
117+
118+
if (cypherPrompt) {
119+
if (cypherPrompt instanceof PromptTemplate) {
120+
cypherPromptTemplate = new PromptTemplate({
121+
template: cypherPrompt.template as string,
122+
inputVariables: cypherPrompt.inputVariables
123+
})
124+
if (!qaPrompt) {
125+
throw new Error('QA Prompt is required when Cypher Prompt is a Prompt Template')
126+
}
127+
} else if (cypherPrompt instanceof FewShotPromptTemplate) {
128+
const examplePrompt = cypherPrompt.examplePrompt as PromptTemplate
129+
cypherPromptTemplate = new FewShotPromptTemplate({
130+
examples: cypherPrompt.examples,
131+
examplePrompt: examplePrompt,
132+
inputVariables: cypherPrompt.inputVariables,
133+
prefix: cypherPrompt.prefix,
134+
suffix: cypherPrompt.suffix,
135+
exampleSeparator: cypherPrompt.exampleSeparator,
136+
templateFormat: cypherPrompt.templateFormat
137+
})
138+
} else {
139+
cypherPromptTemplate = cypherPrompt as PromptTemplate
140+
}
141+
}
142+
143+
if (qaPrompt instanceof PromptTemplate) {
144+
qaPromptTemplate = new PromptTemplate({
145+
template: qaPrompt.template as string,
146+
inputVariables: qaPrompt.inputVariables
147+
})
148+
}
149+
150+
if ((!cypherModel || !qaModel) && !model) {
151+
throw new Error('Language Model is required when Cypher Model or QA Model are not provided')
152+
}
153+
154+
// Validate required variables in prompts
155+
if (
156+
cypherPromptTemplate &&
157+
(!cypherPromptTemplate?.inputVariables.includes('schema') || !cypherPromptTemplate?.inputVariables.includes('question'))
158+
) {
159+
throw new Error('Cypher Generation Prompt must include {schema} and {question} variables')
160+
}
161+
162+
const fromLLMInput: FromLLMInput = {
163+
llm: model,
164+
graph,
165+
returnDirect
166+
}
167+
168+
if (cypherModel && cypherPromptTemplate) {
169+
fromLLMInput['cypherLLM'] = cypherModel
170+
fromLLMInput['cypherPrompt'] = cypherPromptTemplate
171+
}
172+
173+
if (qaModel && qaPromptTemplate) {
174+
fromLLMInput['qaLLM'] = qaModel
175+
fromLLMInput['qaPrompt'] = qaPromptTemplate
176+
}
177+
178+
const chain = GraphCypherQAChain.fromLLM(fromLLMInput)
179+
180+
if (output === this.name) {
181+
return chain
182+
} else if (output === 'outputPrediction') {
183+
nodeData.instance = chain
184+
return await this.run(nodeData, input, options)
185+
}
186+
187+
return chain
188+
}
189+
190+
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
191+
const chain = nodeData.instance as GraphCypherQAChain
192+
const moderations = nodeData.inputs?.inputModeration as Moderation[]
193+
const returnDirect = nodeData.inputs?.returnDirect as boolean
194+
195+
const shouldStreamResponse = options.shouldStreamResponse
196+
const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
197+
const chatId = options.chatId
198+
199+
// Handle input moderation if configured
200+
if (moderations && moderations.length > 0) {
201+
try {
202+
input = await checkInputs(moderations, input)
203+
} catch (e) {
204+
await new Promise((resolve) => setTimeout(resolve, 500))
205+
if (shouldStreamResponse) {
206+
streamResponse(sseStreamer, chatId, e.message)
207+
}
208+
return formatResponse(e.message)
209+
}
210+
}
211+
212+
const obj = {
213+
query: input
214+
}
215+
216+
const loggerHandler = new ConsoleCallbackHandler(options.logger)
217+
const callbackHandlers = await additionalCallbacks(nodeData, options)
218+
let callbacks = [loggerHandler, ...callbackHandlers]
219+
220+
if (process.env.DEBUG === 'true') {
221+
callbacks.push(new LCConsoleCallbackHandler())
222+
}
223+
224+
try {
225+
let response
226+
if (shouldStreamResponse) {
227+
if (returnDirect) {
228+
response = await chain.invoke(obj, { callbacks })
229+
let result = response?.result
230+
if (typeof result === 'object') {
231+
result = '```json\n' + JSON.stringify(result, null, 2)
232+
}
233+
if (result && typeof result === 'string') {
234+
streamResponse(sseStreamer, chatId, result)
235+
}
236+
} else {
237+
const handler = new CustomChainHandler(sseStreamer, chatId, 2)
238+
callbacks.push(handler)
239+
response = await chain.invoke(obj, { callbacks })
240+
}
241+
} else {
242+
response = await chain.invoke(obj, { callbacks })
243+
}
244+
245+
return formatResponse(response?.result)
246+
} catch (error) {
247+
console.error('Error in GraphCypherQAChain:', error)
248+
if (shouldStreamResponse) {
249+
streamResponse(sseStreamer, chatId, error.message)
250+
}
251+
return formatResponse(`Error: ${error.message}`)
252+
}
253+
}
254+
}
255+
256+
module.exports = { nodeClass: GraphCypherQA_Chain }
Loading

0 commit comments

Comments
 (0)