Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ private Constant() {

public static final String PYTHON_TRIES_COUNT = "PYTHON_TRIES_COUNT";

// 标记是否进入Python执行失败的降级模式(超过最大重试次数后触发)
public static final String PYTHON_FALLBACK_MODE = "PYTHON_FALLBACK_MODE";

// If code execution succeeds, output code running result; if fails, output error
// information
public static final String PYTHON_EXECUTE_NODE_OUTPUT = "PYTHON_EXECUTE_NODE_OUTPUT";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,9 @@ public class CodeExecutorProperties {
*/
String networkMode = "none";

/**
* Python执行的最大重试次数
*/
Integer pythonMaxTriesCount = 5;

}
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ public WebClient.Builder webClientBuilder(@Value("${webclient.response.timeout:6
}

@Bean
public StateGraph nl2sqlGraph(NodeBeanUtil nodeBeanUtil) throws GraphStateException {
public StateGraph nl2sqlGraph(NodeBeanUtil nodeBeanUtil, CodeExecutorProperties codeExecutorProperties)
throws GraphStateException {

KeyStrategyFactory keyStrategyFactory = () -> {
HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
Expand Down Expand Up @@ -155,6 +156,7 @@ public StateGraph nl2sqlGraph(NodeBeanUtil nodeBeanUtil) throws GraphStateExcept
keyStrategyHashMap.put(SQL_RESULT_LIST_MEMORY, KeyStrategy.REPLACE);
keyStrategyHashMap.put(PYTHON_IS_SUCCESS, KeyStrategy.REPLACE);
keyStrategyHashMap.put(PYTHON_TRIES_COUNT, KeyStrategy.REPLACE);
keyStrategyHashMap.put(PYTHON_FALLBACK_MODE, KeyStrategy.REPLACE);
keyStrategyHashMap.put(PYTHON_EXECUTE_NODE_OUTPUT, KeyStrategy.REPLACE);
keyStrategyHashMap.put(PYTHON_GENERATE_NODE_OUTPUT, KeyStrategy.REPLACE);
keyStrategyHashMap.put(PYTHON_ANALYSIS_NODE_OUTPUT, KeyStrategy.REPLACE);
Expand Down Expand Up @@ -203,7 +205,7 @@ public StateGraph nl2sqlGraph(NodeBeanUtil nodeBeanUtil) throws GraphStateExcept
.addEdge(PLANNER_NODE, PLAN_EXECUTOR_NODE)
// python nodes
.addEdge(PYTHON_GENERATE_NODE, PYTHON_EXECUTE_NODE)
.addConditionalEdges(PYTHON_EXECUTE_NODE, edge_async(new PythonExecutorDispatcher()),
.addConditionalEdges(PYTHON_EXECUTE_NODE, edge_async(new PythonExecutorDispatcher(codeExecutorProperties)),
Map.of(PYTHON_ANALYZE_NODE, PYTHON_ANALYZE_NODE, END, END, PYTHON_GENERATE_NODE,
PYTHON_GENERATE_NODE))
.addEdge(PYTHON_ANALYZE_NODE, PLAN_EXECUTOR_NODE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import com.alibaba.cloud.ai.dataagent.common.util.StateUtil;
import com.alibaba.cloud.ai.dataagent.config.CodeExecutorProperties;
import lombok.extern.slf4j.Slf4j;

import static com.alibaba.cloud.ai.dataagent.common.constant.Constant.*;
Expand All @@ -31,16 +32,28 @@
@Slf4j
public class PythonExecutorDispatcher implements EdgeAction {

private final CodeExecutorProperties codeExecutorProperties;

public PythonExecutorDispatcher(CodeExecutorProperties codeExecutorProperties) {
this.codeExecutorProperties = codeExecutorProperties;
}

@Override
public String apply(OverAllState state) throws Exception {
boolean isFallbackMode = StateUtil.getObjectValue(state, PYTHON_FALLBACK_MODE, Boolean.class, false);
if (isFallbackMode) {
log.warn("Python执行进入降级模式,跳过重试直接进入分析节点");
return PYTHON_ANALYZE_NODE;
}

// Determine if failed
boolean isSuccess = StateUtil.getObjectValue(state, PYTHON_IS_SUCCESS, Boolean.class, false);
if (!isSuccess) {
String message = StateUtil.getStringValue(state, PYTHON_EXECUTE_NODE_OUTPUT);
log.error("Python Executor Node Error: {}", message);
int tries = StateUtil.getObjectValue(state, PYTHON_TRIES_COUNT, Integer.class, 0);
if (tries <= 0) {
log.warn("Python Executor Node Error: Exceeding the maximum number of iterations");
if (tries >= codeExecutorProperties.getPythonMaxTriesCount()) {
log.error("Python执行失败且已超过最大重试次数(已尝试次数:{}),流程终止", tries);
return END;
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.alibaba.cloud.ai.dataagent.common.util.FluxUtil;
import com.alibaba.cloud.ai.dataagent.common.util.PlanProcessUtil;
import com.alibaba.cloud.ai.dataagent.common.util.StateUtil;
import com.alibaba.cloud.ai.dataagent.common.util.ChatResponseUtil;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatResponse;
Expand Down Expand Up @@ -60,7 +61,27 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
Map<String, String> sqlExecuteResult = StateUtil.getObjectValue(state, SQL_EXECUTE_NODE_OUTPUT, Map.class,
new HashMap<>());

// Load Python code generation template
// 检查是否进入降级模式
boolean isFallbackMode = StateUtil.getObjectValue(state, PYTHON_FALLBACK_MODE, Boolean.class, false);

if (isFallbackMode) {
// 降级模式
String fallbackMessage = "Python 高级分析功能暂时不可用,出现错误";
log.warn("Python分析节点检测到降级模式,返回固定提示信息");

Flux<ChatResponse> fallbackFlux = Flux.just(ChatResponseUtil.createResponse(fallbackMessage));

Flux<GraphResponse<StreamingOutput>> generator = FluxUtil.createStreamingGeneratorWithMessages(
this.getClass(), state, "正在处理分析结果...\n", "\n处理完成。", aiResponse -> {
Map<String, String> updatedSqlResult = PlanProcessUtil.addStepResult(sqlExecuteResult,
currentStep, fallbackMessage);
log.info("python fallback message: {}", fallbackMessage);
return Map.of(SQL_EXECUTE_NODE_OUTPUT, updatedSqlResult, PLAN_CURRENT_STEP, currentStep + 1);
}, fallbackFlux);

return Map.of(PYTHON_ANALYSIS_NODE_OUTPUT, generator);
}

String systemPrompt = PromptConstant.getPythonAnalyzePromptTemplate()
.render(Map.of("python_output", pythonOutput, "user_query", userQuery));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.alibaba.cloud.ai.dataagent.common.enums.TextType;
import com.alibaba.cloud.ai.dataagent.common.util.JsonParseUtil;
import com.alibaba.cloud.ai.dataagent.config.CodeExecutorProperties;
import com.alibaba.cloud.ai.graph.GraphResponse;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
Expand Down Expand Up @@ -55,10 +56,14 @@ public class PythonExecuteNode implements NodeAction {

private final JsonParseUtil jsonParseUtil;

public PythonExecuteNode(CodePoolExecutorService codePoolExecutor, JsonParseUtil jsonParseUtil) {
private final CodeExecutorProperties codeExecutorProperties;

public PythonExecuteNode(CodePoolExecutorService codePoolExecutor, JsonParseUtil jsonParseUtil,
CodeExecutorProperties codeExecutorProperties) {
this.codePoolExecutor = codePoolExecutor;
this.objectMapper = JsonUtil.getObjectMapper();
this.jsonParseUtil = jsonParseUtil;
this.codeExecutorProperties = codeExecutorProperties;
}

@Override
Expand All @@ -69,6 +74,10 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
String pythonCode = StateUtil.getStringValue(state, PYTHON_GENERATE_NODE_OUTPUT);
List<Map<String, String>> sqlResults = StateUtil.hasValue(state, SQL_RESULT_LIST_MEMORY)
? StateUtil.getListValue(state, SQL_RESULT_LIST_MEMORY) : new ArrayList<>();

// 检查重试次数
int triesCount = StateUtil.getObjectValue(state, PYTHON_TRIES_COUNT, Integer.class, 0);

CodePoolExecutorService.TaskRequest taskRequest = new CodePoolExecutorService.TaskRequest(pythonCode,
objectMapper.writeValueAsString(sqlResults), null);

Expand All @@ -78,6 +87,28 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
String errorMsg = "Python Execute Failed!\nStdOut: " + taskResponse.stdOut() + "\nStdErr: "
+ taskResponse.stdErr() + "\nExceptionMsg: " + taskResponse.exceptionMsg();
log.error(errorMsg);

// 检查是否超过最大重试次数
if (triesCount >= codeExecutorProperties.getPythonMaxTriesCount()) {
log.error("Python执行失败且已超过最大重试次数(已尝试次数:{}),启动降级兜底逻辑。错误信息: {}", triesCount, errorMsg);

String fallbackOutput = "{}";

Flux<ChatResponse> fallbackDisplayFlux = Flux.create(emitter -> {
emitter.next(ChatResponseUtil.createResponse("开始执行Python代码..."));
emitter.next(ChatResponseUtil.createResponse("Python代码执行失败已超过最大重试次数,采用降级策略继续处理。"));
emitter.complete();
});

Flux<GraphResponse<StreamingOutput>> fallbackGenerator = FluxUtil
.createStreamingGeneratorWithMessages(this.getClass(), state,
v -> Map.of(PYTHON_EXECUTE_NODE_OUTPUT, fallbackOutput, PYTHON_IS_SUCCESS, false,
PYTHON_FALLBACK_MODE, true),
fallbackDisplayFlux);

return Map.of(PYTHON_EXECUTE_NODE_OUTPUT, fallbackGenerator);
}

throw new RuntimeException(errorMsg);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ public class PythonGenerateNode implements NodeAction {

private static final int SAMPLE_DATA_NUMBER = 5;

private static final int MAX_TRIES_COUNT = 5;

private final ObjectMapper objectMapper;

private final CodeExecutorProperties codeExecutorProperties;
Expand All @@ -78,7 +76,7 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
List<Map<String, String>> sqlResults = StateUtil.hasValue(state, SQL_RESULT_LIST_MEMORY)
? StateUtil.getListValue(state, SQL_RESULT_LIST_MEMORY) : new ArrayList<>();
boolean codeRunSuccess = StateUtil.getObjectValue(state, PYTHON_IS_SUCCESS, Boolean.class, true);
int triesCount = StateUtil.getObjectValue(state, PYTHON_TRIES_COUNT, Integer.class, MAX_TRIES_COUNT);
int triesCount = StateUtil.getObjectValue(state, PYTHON_TRIES_COUNT, Integer.class, 0);

String userPrompt = StateUtil.getCanonicalQuery(state);
if (!codeRunSuccess) {
Expand Down Expand Up @@ -121,7 +119,7 @@ public Map<String, Object> apply(OverAllState state) throws Exception {
aiResponse.length() - TextType.PYTHON.getEndSign().length());
aiResponse = MarkdownParserUtil.extractRawText(aiResponse);
log.info("Python Generate Code: {}", aiResponse);
return Map.of(PYTHON_GENERATE_NODE_OUTPUT, aiResponse, PYTHON_TRIES_COUNT, triesCount - 1);
return Map.of(PYTHON_GENERATE_NODE_OUTPUT, aiResponse, PYTHON_TRIES_COUNT, triesCount + 1);
},
Flux.concat(Flux.just(ChatResponseUtil.createPureResponse(TextType.PYTHON.getStartSign())),
pythonGenerateFlux,
Expand Down
3 changes: 3 additions & 0 deletions data-agent-management/src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ spring:
code-executor:
# 运行Python代码的环境(生产环境建议使用docker,不建议使用local)
code-pool-executor: local
# Python执行的最大重试次数
python-max-tries-count: 5
file:
type: local
path-prefix: data-agent
Expand All @@ -47,6 +49,7 @@ spring:
enabled: true

mybatis:

configuration:
map-underscore-to-camel-case: true
log-impl: org.apache.ibatis.logging.slf4j.Slf4jImpl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ try:
# 将输入数据转换为DataFrame以便于分析
df = pd.DataFrame(input_data)

# 自动类型推断:将字符串形式的数值转换为实际数值类型(解决后端统一返回字符串的问题)
for col in df.columns:
df[col] = pd.to_numeric(df[col], errors='ignore')

# 动态分析逻辑
# 示例:计算某些统计指标
result = \{
Expand Down Expand Up @@ -79,8 +83,9 @@ except Exception:
# 注意事项

1. **输入验证**:确保代码能够正确处理空输入或格式不正确的输入,并在异常时提供清晰的错误信息。**处理的数据必须来自`json.load(sys.stdin)`**。
2. **性能优化**:尽量减少不必要的计算和内存占用,确保代码在性能约束内高效运行。
3. **结果完整性**:输出的JSON对象应全面反映分析结果,且字段命名清晰易懂。
2. **类型转换**:由于后端SQL查询结果统一转换为字符串格式,**必须在DataFrame创建后进行类型推断**,使用`pd.to_numeric()`将数值字符串转换为实际数值类型,以便进行数学运算和统计分析。
3. **性能优化**:尽量减少不必要的计算和内存占用,确保代码在性能约束内高效运行。
4. **结果完整性**:输出的JSON对象应全面反映分析结果,且字段命名清晰易懂。

---

Expand All @@ -101,6 +106,10 @@ try:
# 转换为DataFrame
df = pd.DataFrame(input_data)

# 自动类型推断:将字符串形式的数值转换为实际数值类型
for col in df.columns:
df[col] = pd.to_numeric(df[col], errors='ignore')

# 动态分析逻辑
result = \{
"channel_stats": []
Expand Down
Loading