|
|
@@ -18,6 +18,7 @@
|
|
|
├── _start_node() # 开始节点
|
|
|
├── _initialize_progress_node() # 初始化进度节点
|
|
|
├── _ai_review_node() # AI审查核心节点
|
|
|
+├── _save_results_node() # 保存结果节点(入库/本地文件)
|
|
|
├── _complete_node() # 完成节点
|
|
|
└── _error_handler_node() # 错误处理节点
|
|
|
|
|
|
@@ -42,6 +43,8 @@ import asyncio
|
|
|
import json
|
|
|
import random
|
|
|
import re
|
|
|
+import time
|
|
|
+import os
|
|
|
from dataclasses import dataclass, asdict
|
|
|
from typing import Optional, Callable, Dict, Any, TypedDict, Annotated, List
|
|
|
from langgraph.graph import StateGraph, END
|
|
|
@@ -75,6 +78,7 @@ class AIReviewState(TypedDict):
|
|
|
|
|
|
file_id: str
|
|
|
callback_task_id: str
|
|
|
+ file_name: str
|
|
|
user_id: str
|
|
|
structured_content: Dict[str, Any]
|
|
|
review_results: Optional[Dict[str, Any]]
|
|
|
@@ -136,12 +140,14 @@ class AIReviewWorkflow:
|
|
|
workflow.add_node("start", self._start_node)
|
|
|
workflow.add_node("initialize_progress", self._initialize_progress_node)
|
|
|
workflow.add_node("ai_review", self._ai_review_node)
|
|
|
+ workflow.add_node("save_results", self._save_results_node)
|
|
|
workflow.add_node("complete", self._complete_node)
|
|
|
workflow.add_node("error_handler", self._error_handler_node)
|
|
|
workflow.set_entry_point("start")
|
|
|
workflow.add_edge("start", "initialize_progress")
|
|
|
workflow.add_edge("initialize_progress", "ai_review")
|
|
|
- workflow.add_edge("ai_review", "complete")
|
|
|
+ workflow.add_edge("ai_review", "save_results")
|
|
|
+ workflow.add_edge("save_results", "complete")
|
|
|
workflow.add_edge("complete", END)
|
|
|
workflow.add_edge("error_handler", END)
|
|
|
|
|
|
@@ -150,7 +156,7 @@ class AIReviewWorkflow:
|
|
|
"ai_review",
|
|
|
self.inter_tool._check_ai_review_result,
|
|
|
{
|
|
|
- "success": "complete",
|
|
|
+ "success": "save_results",
|
|
|
"error": "error_handler"
|
|
|
}
|
|
|
)
|
|
|
@@ -276,21 +282,28 @@ class AIReviewWorkflow:
|
|
|
# 4. 汇总结果
|
|
|
summary = self.inter_tool._aggregate_results(successful_results)
|
|
|
|
|
|
+ # 将所有单元的issues合并成一个列表
|
|
|
+ all_issues = []
|
|
|
+ for unit_issues in successful_results:
|
|
|
+ if unit_issues and isinstance(unit_issues, list):
|
|
|
+ all_issues.extend(unit_issues)
|
|
|
+
|
|
|
+ # 构建符合格式的review_results
|
|
|
review_results = {
|
|
|
- 'total_all_units': total_all_units, # 原始总单元数
|
|
|
- 'total_reviewed_units': total_units, # 实际审查的单元数
|
|
|
- 'successful_units': len(successful_results),
|
|
|
- 'failed_units': total_units - len(successful_results),
|
|
|
- 'review_mode': self.review_mode,
|
|
|
- 'max_review_units': self.max_review_units,
|
|
|
- 'review_results': successful_results,
|
|
|
- 'summary': summary
|
|
|
+ "callback_task_id": state["callback_task_id"],
|
|
|
+ "file_name": state.get("file_name", ""), # 从state中获取文件名
|
|
|
+ "user_id": state["user_id"],
|
|
|
+ "current": 100,
|
|
|
+ "stage_name": "完整审查结果",
|
|
|
+ "status": "full_review_result",
|
|
|
+ "message": f"审查完成,共发现{summary.get('total_issues', 0)}个问题",
|
|
|
+ "overall_task_status": "completed",
|
|
|
+ "updated_at": int(time.time()),
|
|
|
+ "issues": all_issues
|
|
|
}
|
|
|
|
|
|
+ # 将格式化的review_results存储到state中,供save_results_node使用
|
|
|
state["review_results"] = review_results
|
|
|
- state["messages"].append(AIMessage(
|
|
|
- content=f"AI审查完成,共处理{total_units}个单元,成功{len(successful_results)}个"
|
|
|
- ))
|
|
|
|
|
|
logger.info(f"AI审查节点执行成功,任务ID: {state['callback_task_id']}")
|
|
|
return state
|
|
|
@@ -301,6 +314,63 @@ class AIReviewWorkflow:
|
|
|
state["messages"].append(AIMessage(content=f"AI审查失败: {str(e)}"))
|
|
|
return state
|
|
|
|
|
|
async def _save_results_node(self, state: AIReviewState) -> AIReviewState:
    """
    Save-results node: persist the formatted review results to a local JSON file.

    The payload is whatever is stored under ``state["review_results"]``. It is
    written to ``temp/<callback_task_id>.json`` (relative to the current
    working directory); the ``temp`` directory is created when missing.

    Args:
        state: AI review workflow state. ``callback_task_id`` is required;
            ``review_results`` is read when present.

    Returns:
        AIReviewState: The same state object. On failure the exception is
        logged, ``state["error_message"]`` is set, a failure message is
        appended to ``state["messages"]``, and the state is returned; the
        exception is not re-raised.

    Note:
        Current implementation: results are stored as a JSON file on disk.
        Planned (per the original notes): store them in a database via SQL.
    """
    try:
        task_id = state['callback_task_id']
        logger.info(f"开始保存审查结果,任务ID: {task_id}")

        # Create the output directory if it does not exist yet.
        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)

        # Keep only the final path component of the task id so that an id
        # containing path separators cannot place the file outside temp_dir.
        safe_name = os.path.basename(str(task_id))
        file_path = os.path.join(temp_dir, f"{safe_name}.json")

        # The key may be present with value None (the field is Optional);
        # save an empty object in that case instead of a JSON `null`.
        review_results = state.get("review_results") or {}

        # Serialize completely before opening the file: if serialization
        # fails, an existing result file is neither truncated nor left
        # half-written.
        payload = json.dumps(review_results, ensure_ascii=False, indent=2, default=str)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(payload)

        logger.info(f"审查结果已保存到: {file_path}")

        # NOTE: this node does not emit a progress update and does not change
        # state["current_stage"].
        return state

    except Exception as e:
        # Best-effort node: record the failure on the state and keep going.
        logger.error(f"保存审查结果失败: {str(e)}", exc_info=True)
        state["error_message"] = f"保存结果失败: {str(e)}"
        state["messages"].append(AIMessage(content=f"保存结果失败: {str(e)}"))
        return state
|
|
|
+
|
|
|
async def _complete_node(self, state: AIReviewState) -> AIReviewState:
|
|
|
"""
|
|
|
完成节点 - 工作流结束处理
|
|
|
@@ -405,7 +475,7 @@ class AIReviewCoreFun:
|
|
|
|
|
|
|
|
|
async def _execute_concurrent_reviews(self, review_chunks: List[Dict[str, Any]],
|
|
|
- total_units: int, state: AIReviewState) -> List[ReviewResult]:
|
|
|
+ total_units: int, state: AIReviewState) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
执行并发审查
|
|
|
|
|
|
@@ -415,7 +485,7 @@ class AIReviewCoreFun:
|
|
|
state: AI审查状态
|
|
|
|
|
|
Returns:
|
|
|
- List[ReviewResult]: 审查结果列表
|
|
|
+ List[Dict[str, Any]]: 审查结果列表(issues格式)
|
|
|
"""
|
|
|
|
|
|
try:
|
|
|
@@ -446,10 +516,11 @@ class AIReviewCoreFun:
|
|
|
|
|
|
# 立即发送单元审查详情(包含unit_review和processing_flag事件)
|
|
|
await self._send_unit_review_progress(state, unit_index, total_units, section_label, issues, current)
|
|
|
-
|
|
|
+ return issues
|
|
|
else:
|
|
|
logger.error(f"执行单个单元审查失败: {str(result.error_message)}")
|
|
|
- return result
|
|
|
+ return None
|
|
|
+
|
|
|
|
|
|
# 创建并发任务
|
|
|
tasks = [
|
|
|
@@ -460,8 +531,8 @@ class AIReviewCoreFun:
|
|
|
# 等待所有任务完成
|
|
|
all_results = await asyncio.gather(*tasks)
|
|
|
|
|
|
- # 过滤成功结果
|
|
|
- successful_results = [result for result in all_results if result.overall_risk != "error"]
|
|
|
+ # 过滤有效结果(issues格式)
|
|
|
+ successful_results = [issues for issues in all_results if issues and isinstance(issues, list)]
|
|
|
return successful_results
|
|
|
|
|
|
except Exception as e:
|
|
|
@@ -1084,19 +1155,18 @@ class InterTool:
|
|
|
logger.warning(f"风险等级计算异常: {str(e)},使用默认风险等级")
|
|
|
return DEFAULT_RISK_LEVEL
|
|
|
|
|
|
- def _aggregate_results(self, successful_results: List[ReviewResult]) -> Dict[str, Any]:
|
|
|
+ def _aggregate_results(self, successful_results: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
|
|
|
"""
|
|
|
- 汇总审查结果
|
|
|
+ 汇总审查结果(issues格式)
|
|
|
|
|
|
Args:
|
|
|
- successful_results: 成功的审查结果列表,每个结果包含风险等级和得分
|
|
|
+ successful_results: 成功的审查结果列表(issues格式),每个单元返回一个issues列表
|
|
|
|
|
|
Returns:
|
|
|
Dict[str, Any]: 汇总后的统计信息,包含以下字段:
|
|
|
- risk_stats: 各风险等级的数量统计 {"low": 0, "medium": 0, "high": 0}
|
|
|
- - avg_basic_score: 基础合规性平均得分
|
|
|
- - avg_technical_score: 技术性审查平均得分
|
|
|
- total_reviewed: 成功审查的总数量
|
|
|
+ - total_issues: 总问题数量
|
|
|
|
|
|
Note:
|
|
|
当输入为空时返回空字典,异常时记录错误并返回空字典
|
|
|
@@ -1105,23 +1175,32 @@ class InterTool:
|
|
|
if not successful_results:
|
|
|
return {}
|
|
|
|
|
|
- # 计算风险等级统计
|
|
|
- risk_stats = {"low": 0, "medium": 0, "high": 0, "error": 0}
|
|
|
- for result in successful_results:
|
|
|
- risk_stats[result.overall_risk] += 1
|
|
|
+ # 计算风险等级统计和问题总数
|
|
|
+ risk_stats = {"low": 0, "medium": 0, "high": 0}
|
|
|
+ total_issues = 0
|
|
|
+
|
|
|
+ for unit_issues in successful_results:
|
|
|
+ # 每个unit_issues是一个issues列表
|
|
|
+ if unit_issues and isinstance(unit_issues, list):
|
|
|
+ total_issues += len(unit_issues)
|
|
|
|
|
|
- # 计算平均分
|
|
|
- total_basic_score = sum(r.basic_compliance.get('overall_score', 0) for r in successful_results)
|
|
|
- total_technical_score = sum(r.technical_compliance.get('overall_score', 0) for r in successful_results)
|
|
|
+ # 统计每个issue中的风险等级
|
|
|
+ for issue in unit_issues:
|
|
|
+ if isinstance(issue, dict):
|
|
|
+ # issue格式: {issue_id: {risk_summary: {...}}}
|
|
|
+ for issue_data in issue.values():
|
|
|
+ risk_summary = issue_data.get('risk_summary', {})
|
|
|
+ max_risk = risk_summary.get('max_risk_level', '0')
|
|
|
|
|
|
- avg_basic_score = total_basic_score / len(successful_results)
|
|
|
- avg_technical_score = total_technical_score / len(successful_results)
|
|
|
+ if max_risk in risk_stats:
|
|
|
+ risk_stats[max_risk] += 1
|
|
|
+ elif max_risk == '0':
|
|
|
+ risk_stats['low'] += 1 # 无风险视为低风险
|
|
|
|
|
|
return {
|
|
|
'risk_stats': risk_stats,
|
|
|
- 'avg_basic_score': avg_basic_score,
|
|
|
- 'avg_technical_score': avg_technical_score,
|
|
|
- 'total_reviewed': len(successful_results)
|
|
|
+ 'total_reviewed': len(successful_results),
|
|
|
+ 'total_issues': total_issues
|
|
|
}
|
|
|
except (ZeroDivisionError, KeyError, TypeError) as e:
|
|
|
logger.error(f"结果汇总失败: {str(e)}")
|