|
@@ -13,7 +13,6 @@ from langgraph.graph import StateGraph, END
|
|
|
from langgraph.graph.message import add_messages
|
|
from langgraph.graph.message import add_messages
|
|
|
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
|
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
|
|
from foundation.logger.loggering import server_logger as logger
|
|
from foundation.logger.loggering import server_logger as logger
|
|
|
-from foundation.utils.time_statistics import track_execution_time
|
|
|
|
|
from ..component import AIReviewEngine
|
|
from ..component import AIReviewEngine
|
|
|
|
|
|
|
|
|
|
|
|
@@ -29,23 +28,16 @@ class ReviewResult:
|
|
|
|
|
|
|
|
class AIReviewState(TypedDict):
|
|
class AIReviewState(TypedDict):
|
|
|
"""AI审查工作流状态"""
|
|
"""AI审查工作流状态"""
|
|
|
- # 基本信息
|
|
|
|
|
|
|
+
|
|
|
file_id: str
|
|
file_id: str
|
|
|
callback_task_id: str
|
|
callback_task_id: str
|
|
|
user_id: str
|
|
user_id: str
|
|
|
structured_content: Dict[str, Any]
|
|
structured_content: Dict[str, Any]
|
|
|
-
|
|
|
|
|
- # AI审查结果
|
|
|
|
|
review_results: Optional[Dict[str, Any]]
|
|
review_results: Optional[Dict[str, Any]]
|
|
|
-
|
|
|
|
|
- # 状态和进度
|
|
|
|
|
current_stage: str
|
|
current_stage: str
|
|
|
status: str
|
|
status: str
|
|
|
error_message: Optional[str]
|
|
error_message: Optional[str]
|
|
|
-
|
|
|
|
|
- # 进度管理
|
|
|
|
|
progress_manager: Optional[Any]
|
|
progress_manager: Optional[Any]
|
|
|
-
|
|
|
|
|
# 消息日志(用于LangGraph状态追踪)
|
|
# 消息日志(用于LangGraph状态追踪)
|
|
|
messages: Annotated[List[BaseMessage], add_messages]
|
|
messages: Annotated[List[BaseMessage], add_messages]
|
|
|
|
|
|
|
@@ -54,32 +46,82 @@ class AIReviewWorkflow:
|
|
|
"""基于LangGraph的AI审查工作流"""
|
|
"""基于LangGraph的AI审查工作流"""
|
|
|
|
|
|
|
|
def __init__(self, file_id: str, callback_task_id: str, user_id: str,
|
|
def __init__(self, file_id: str, callback_task_id: str, user_id: str,
|
|
|
- structured_content: Dict[str, Any], progress_manager=None):
|
|
|
|
|
|
|
+ structured_content: Dict[str, Any], progress_manager=None,
|
|
|
|
|
+ max_review_units: int = None, review_mode: str = "all"):
|
|
|
|
|
+ """
|
|
|
|
|
+ 初始化AI审查工作流
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ file_id: 文件ID
|
|
|
|
|
+ callback_task_id: 回调任务ID
|
|
|
|
|
+ user_id: 用户ID
|
|
|
|
|
+ structured_content: 结构化内容
|
|
|
|
|
+ progress_manager: 进度管理器
|
|
|
|
|
+ max_review_units: 最大审查单元数量(None表示审查所有)
|
|
|
|
|
+ review_mode: 审查模式 ("all"=全部, "first"=前N个, "random"=随机N个)
|
|
|
|
|
+ """
|
|
|
self.file_id = file_id
|
|
self.file_id = file_id
|
|
|
self.callback_task_id = callback_task_id
|
|
self.callback_task_id = callback_task_id
|
|
|
self.user_id = user_id
|
|
self.user_id = user_id
|
|
|
self.structured_content = structured_content
|
|
self.structured_content = structured_content
|
|
|
self.progress_manager = progress_manager
|
|
self.progress_manager = progress_manager
|
|
|
self.ai_review_engine = AIReviewEngine()
|
|
self.ai_review_engine = AIReviewEngine()
|
|
|
-
|
|
|
|
|
- # 构建LangGraph工作流
|
|
|
|
|
|
|
+ self.max_review_units = max_review_units
|
|
|
|
|
+ self.review_mode = review_mode
|
|
|
self.graph = self._build_workflow()
|
|
self.graph = self._build_workflow()
|
|
|
|
|
|
|
|
|
|
+ def _filter_review_units(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 根据配置筛选要审查的单元
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ chunks: 所有审查单元
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ List[Dict[str, Any]]: 筛选后的审查单元
|
|
|
|
|
+ """
|
|
|
|
|
+ if self.max_review_units is None or self.review_mode == "all":
|
|
|
|
|
+ return chunks
|
|
|
|
|
+
|
|
|
|
|
+ # 验证原始chunks不为空
|
|
|
|
|
+ if not chunks:
|
|
|
|
|
+ logger.warning("没有可用的审查单元")
|
|
|
|
|
+ return []
|
|
|
|
|
+
|
|
|
|
|
+ # 安全的切片操作,考虑边界情况
|
|
|
|
|
+ start_index = min(30, len(chunks) - 1) # 确保start_index不超过数组边界
|
|
|
|
|
+ chunks = chunks[start_index:]
|
|
|
|
|
+
|
|
|
|
|
+ # 再次验证切片后的结果
|
|
|
|
|
+ if not chunks:
|
|
|
|
|
+ logger.warning(f"从索引{start_index}切片后没有可用的审查单元")
|
|
|
|
|
+ return []
|
|
|
|
|
+
|
|
|
|
|
+ total_chunks = len(chunks)
|
|
|
|
|
+ actual_review_count = min(self.max_review_units, total_chunks)
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"审查模式: {self.review_mode}, 总单元数: {total_chunks}, 实际审查数: {actual_review_count}")
|
|
|
|
|
+
|
|
|
|
|
+ if self.review_mode == "first":
|
|
|
|
|
+ # 取前N个
|
|
|
|
|
+ return chunks[:actual_review_count]
|
|
|
|
|
+ elif self.review_mode == "random":
|
|
|
|
|
+ # 随机取N个
|
|
|
|
|
+ import random
|
|
|
|
|
+ return random.sample(chunks, actual_review_count)
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 默认取前N个
|
|
|
|
|
+ return chunks[:actual_review_count]
|
|
|
|
|
+
|
|
|
def _build_workflow(self) -> StateGraph:
|
|
def _build_workflow(self) -> StateGraph:
|
|
|
"""构建AI审查的LangGraph工作流图"""
|
|
"""构建AI审查的LangGraph工作流图"""
|
|
|
workflow = StateGraph(AIReviewState)
|
|
workflow = StateGraph(AIReviewState)
|
|
|
-
|
|
|
|
|
- # 添加节点
|
|
|
|
|
workflow.add_node("start", self._start_node)
|
|
workflow.add_node("start", self._start_node)
|
|
|
workflow.add_node("initialize_progress", self._initialize_progress_node)
|
|
workflow.add_node("initialize_progress", self._initialize_progress_node)
|
|
|
workflow.add_node("ai_review", self._ai_review_node)
|
|
workflow.add_node("ai_review", self._ai_review_node)
|
|
|
workflow.add_node("complete", self._complete_node)
|
|
workflow.add_node("complete", self._complete_node)
|
|
|
workflow.add_node("error_handler", self._error_handler_node)
|
|
workflow.add_node("error_handler", self._error_handler_node)
|
|
|
-
|
|
|
|
|
- # 设置入口点
|
|
|
|
|
- workflow.set_entry_point("start")
|
|
|
|
|
-
|
|
|
|
|
- # 添加边(定义流程)
|
|
|
|
|
|
|
+ workflow.set_entry_point("start")# 设置入口节点
|
|
|
workflow.add_edge("start", "initialize_progress")
|
|
workflow.add_edge("start", "initialize_progress")
|
|
|
workflow.add_edge("initialize_progress", "ai_review")
|
|
workflow.add_edge("initialize_progress", "ai_review")
|
|
|
workflow.add_edge("ai_review", "complete")
|
|
workflow.add_edge("ai_review", "complete")
|
|
@@ -105,8 +147,6 @@ class AIReviewWorkflow:
|
|
|
"""执行基于LangGraph的AI审查工作流"""
|
|
"""执行基于LangGraph的AI审查工作流"""
|
|
|
try:
|
|
try:
|
|
|
logger.info(f"开始AI审查工作流,文件ID: {self.file_id}")
|
|
logger.info(f"开始AI审查工作流,文件ID: {self.file_id}")
|
|
|
-
|
|
|
|
|
- # 初始状态
|
|
|
|
|
initial_state = AIReviewState(
|
|
initial_state = AIReviewState(
|
|
|
file_id=self.file_id,
|
|
file_id=self.file_id,
|
|
|
callback_task_id=self.callback_task_id,
|
|
callback_task_id=self.callback_task_id,
|
|
@@ -144,8 +184,6 @@ class AIReviewWorkflow:
|
|
|
logger.error(f"LangGraph AI审查工作流执行失败: {str(e)}")
|
|
logger.error(f"LangGraph AI审查工作流执行失败: {str(e)}")
|
|
|
raise
|
|
raise
|
|
|
|
|
|
|
|
- # ========== LangGraph节点实现 ==========
|
|
|
|
|
-
|
|
|
|
|
async def _start_node(self, state: AIReviewState) -> AIReviewState:
|
|
async def _start_node(self, state: AIReviewState) -> AIReviewState:
|
|
|
"""开始节点"""
|
|
"""开始节点"""
|
|
|
logger.info(f"AI审查工作流启动: {state['file_id']}")
|
|
logger.info(f"AI审查工作流启动: {state['file_id']}")
|
|
@@ -177,18 +215,24 @@ class AIReviewWorkflow:
|
|
|
return state
|
|
return state
|
|
|
|
|
|
|
|
async def _ai_review_node(self, state: AIReviewState) -> AIReviewState:
|
|
async def _ai_review_node(self, state: AIReviewState) -> AIReviewState:
|
|
|
- """AI审查节点 - 使用LangGraph编排原子化组件方法"""
|
|
|
|
|
|
|
+ """AI审查节点"""
|
|
|
try:
|
|
try:
|
|
|
logger.info(f"执行AI审查: {state['file_id']}")
|
|
logger.info(f"执行AI审查: {state['file_id']}")
|
|
|
|
|
|
|
|
state["current_stage"] = "ai_review"
|
|
state["current_stage"] = "ai_review"
|
|
|
|
|
|
|
|
- total_units = len(state['structured_content']['chunks'])
|
|
|
|
|
|
|
+ # 筛选要审查的单元
|
|
|
|
|
+ all_chunks = state['structured_content']['chunks']
|
|
|
|
|
+ review_chunks = self._filter_review_units(all_chunks)
|
|
|
|
|
+
|
|
|
|
|
+ total_units = len(review_chunks)
|
|
|
|
|
+ total_all_units = len(all_chunks)
|
|
|
completed_units = 0
|
|
completed_units = 0
|
|
|
|
|
|
|
|
|
|
+ logger.info(f"AI审查开始: 总单元数 {total_all_units}, 实际审查 {total_units} 个单元")
|
|
|
|
|
+
|
|
|
# 进度回调函数
|
|
# 进度回调函数
|
|
|
def progress_callback(progress: int, message: str):
|
|
def progress_callback(progress: int, message: str):
|
|
|
- # 将AI审查的进度映射到整体进度
|
|
|
|
|
overall_progress = 50 + int(progress * 0.4) # AI审查占整体进度的40%
|
|
overall_progress = 50 + int(progress * 0.4) # AI审查占整体进度的40%
|
|
|
if state["progress_manager"]:
|
|
if state["progress_manager"]:
|
|
|
asyncio.create_task(
|
|
asyncio.create_task(
|
|
@@ -201,16 +245,17 @@ class AIReviewWorkflow:
|
|
|
)
|
|
)
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # 使用原子化组件方法审查单个单元
|
|
|
|
|
- async def review_single_unit(unit_content: Dict[str, Any], unit_index: int) -> ReviewResult:
|
|
|
|
|
|
|
+ # 基本审查单元
|
|
|
|
|
+ async def review_single_unit(unit_content: Dict[str, Any], unit_index: int,callback_task_id) -> ReviewResult:
|
|
|
"""使用LangGraph编排的原子化组件方法审查单个单元"""
|
|
"""使用LangGraph编排的原子化组件方法审查单个单元"""
|
|
|
- async with self.ai_review_engine.semaphore:
|
|
|
|
|
- try:
|
|
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 构建Trace ID
|
|
|
|
|
+ trace_id_idx = "("+str(callback_task_id)+'-'+str(unit_index)+")"
|
|
|
# 并发执行各种原子化审查方法
|
|
# 并发执行各种原子化审查方法
|
|
|
review_tasks = [
|
|
review_tasks = [
|
|
|
- self.ai_review_engine.basic_compliance_check(unit_content),
|
|
|
|
|
- self.ai_review_engine.technical_compliance_check(unit_content),
|
|
|
|
|
- self.ai_review_engine.rag_enhanced_check(unit_content)
|
|
|
|
|
|
|
+ self.ai_review_engine.basic_compliance_check(trace_id_idx, unit_content),
|
|
|
|
|
+ self.ai_review_engine.technical_compliance_check(trace_id_idx, unit_content),
|
|
|
|
|
+ # self.ai_review_engine.rag_enhanced_check(unit_content, trace_id_idx)
|
|
|
]
|
|
]
|
|
|
|
|
|
|
|
# 等待所有审查完成
|
|
# 等待所有审查完成
|
|
@@ -218,8 +263,10 @@ class AIReviewWorkflow:
|
|
|
|
|
|
|
|
# 处理异常结果
|
|
# 处理异常结果
|
|
|
basic_result = review_results[0] if not isinstance(review_results[0], Exception) else {"error": str(review_results[0])}
|
|
basic_result = review_results[0] if not isinstance(review_results[0], Exception) else {"error": str(review_results[0])}
|
|
|
- technical_result = review_results[1] if not isinstance(review_results[1], Exception) else {"error": str(review_results[1])}
|
|
|
|
|
- rag_result = review_results[2] if not isinstance(review_results[2], Exception) else {"error": str(review_results[2])}
|
|
|
|
|
|
|
+ technical_result = review_results[1] if len(review_results) > 1 and not isinstance(review_results[1], Exception) else {"error": str(review_results[1]) if len(review_results) > 1 else "No result"}
|
|
|
|
|
+
|
|
|
|
|
+ # RAG检查已注释,提供空结果
|
|
|
|
|
+ rag_result = {"error": "RAG check disabled"}
|
|
|
|
|
|
|
|
# 计算总体风险等级
|
|
# 计算总体风险等级
|
|
|
overall_risk = self._calculate_overall_risk(basic_result, technical_result, rag_result)
|
|
overall_risk = self._calculate_overall_risk(basic_result, technical_result, rag_result)
|
|
@@ -234,29 +281,29 @@ class AIReviewWorkflow:
|
|
|
progress_callback(progress, message)
|
|
progress_callback(progress, message)
|
|
|
|
|
|
|
|
return ReviewResult(
|
|
return ReviewResult(
|
|
|
- unit_index=unit_index,
|
|
|
|
|
- unit_content=unit_content,
|
|
|
|
|
- basic_compliance=basic_result,
|
|
|
|
|
- technical_compliance=technical_result,
|
|
|
|
|
- rag_enhanced=rag_result,
|
|
|
|
|
- overall_risk=overall_risk
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ unit_index=unit_index,
|
|
|
|
|
+ unit_content=unit_content,
|
|
|
|
|
+ basic_compliance=basic_result,
|
|
|
|
|
+ technical_compliance=technical_result,
|
|
|
|
|
+ rag_enhanced=rag_result,
|
|
|
|
|
+ overall_risk=overall_risk
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"审查单元 {unit_index} 失败: {str(e)}")
|
|
|
|
|
- return ReviewResult(
|
|
|
|
|
- unit_index=unit_index,
|
|
|
|
|
- unit_content=unit_content,
|
|
|
|
|
- basic_compliance={"error": str(e)},
|
|
|
|
|
- technical_compliance={"error": str(e)},
|
|
|
|
|
- rag_enhanced={"error": str(e)},
|
|
|
|
|
- overall_risk="error"
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"审查单元 {unit_index} 失败: {str(e)}")
|
|
|
|
|
+ return ReviewResult(
|
|
|
|
|
+ unit_index=unit_index,
|
|
|
|
|
+ unit_content=unit_content,
|
|
|
|
|
+ basic_compliance={"error": str(e)},
|
|
|
|
|
+ technical_compliance={"error": str(e)},
|
|
|
|
|
+ rag_enhanced={"error": str(e)},
|
|
|
|
|
+ overall_risk="error"
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- # 并发审查所有单元
|
|
|
|
|
|
|
+ # 实现异步并发
|
|
|
review_tasks = [
|
|
review_tasks = [
|
|
|
- asyncio.create_task(review_single_unit(content, i))
|
|
|
|
|
- for i, content in enumerate(state['structured_content']['chunks'])
|
|
|
|
|
|
|
+ asyncio.create_task(review_single_unit(content, i,state["callback_task_id"]))
|
|
|
|
|
+ for i, content in enumerate(review_chunks)
|
|
|
]
|
|
]
|
|
|
|
|
|
|
|
# 等待所有审查完成
|
|
# 等待所有审查完成
|
|
@@ -269,9 +316,12 @@ class AIReviewWorkflow:
|
|
|
summary = self._aggregate_results(successful_results)
|
|
summary = self._aggregate_results(successful_results)
|
|
|
|
|
|
|
|
review_results = {
|
|
review_results = {
|
|
|
- 'total_units': total_units,
|
|
|
|
|
|
|
+ 'total_all_units': total_all_units, # 原始总单元数
|
|
|
|
|
+ 'total_reviewed_units': total_units, # 实际审查的单元数
|
|
|
'successful_units': len(successful_results),
|
|
'successful_units': len(successful_results),
|
|
|
'failed_units': total_units - len(successful_results),
|
|
'failed_units': total_units - len(successful_results),
|
|
|
|
|
+ 'review_mode': self.review_mode,
|
|
|
|
|
+ 'max_review_units': self.max_review_units,
|
|
|
'review_results': successful_results,
|
|
'review_results': successful_results,
|
|
|
'summary': summary
|
|
'summary': summary
|
|
|
}
|
|
}
|
|
@@ -333,7 +383,6 @@ class AIReviewWorkflow:
|
|
|
|
|
|
|
|
return state
|
|
return state
|
|
|
|
|
|
|
|
- # ========== 辅助方法 ==========
|
|
|
|
|
|
|
|
|
|
def _calculate_overall_risk(self, basic_result: Dict, technical_result: Dict, rag_result: Dict) -> str:
|
|
def _calculate_overall_risk(self, basic_result: Dict, technical_result: Dict, rag_result: Dict) -> str:
|
|
|
"""计算总体风险等级"""
|
|
"""计算总体风险等级"""
|
|
@@ -379,7 +428,6 @@ class AIReviewWorkflow:
|
|
|
logger.error(f"结果汇总失败: {str(e)}")
|
|
logger.error(f"结果汇总失败: {str(e)}")
|
|
|
return {}
|
|
return {}
|
|
|
|
|
|
|
|
- # ========== 条件边函数 ==========
|
|
|
|
|
|
|
|
|
|
def _check_ai_review_result(self, state: AIReviewState) -> str:
|
|
def _check_ai_review_result(self, state: AIReviewState) -> str:
|
|
|
"""检查AI审查结果"""
|
|
"""检查AI审查结果"""
|
|
@@ -393,8 +441,6 @@ class AIReviewWorkflow:
|
|
|
grandalf_graph.print_ascii()
|
|
grandalf_graph.print_ascii()
|
|
|
|
|
|
|
|
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
async def _get_status(self) -> dict:
|
|
async def _get_status(self) -> dict:
|
|
|
"""获取工作流状态"""
|
|
"""获取工作流状态"""
|
|
|
if self.progress_manager:
|
|
if self.progress_manager:
|