|
|
@@ -0,0 +1,402 @@
|
|
|
+"""
|
|
|
+基于LangGraph的AI审查工作流
|
|
|
+负责AI审查的流程控制和业务编排,使用LangGraph进行状态管理
|
|
|
+"""
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+from dataclasses import asdict
|
|
|
+import time
|
|
|
+from typing import Optional, Callable, Dict, Any, TypedDict, Annotated, List
|
|
|
+from dataclasses import dataclass
|
|
|
+from langgraph.graph import StateGraph, END
|
|
|
+from langgraph.graph.message import add_messages
|
|
|
+from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
|
|
+from foundation.logger.loggering import server_logger as logger
|
|
|
+from foundation.utils.time_statistics import track_execution_time
|
|
|
+from ..component import AIReviewEngine
|
|
|
+
|
|
|
+
|
|
|
@dataclass
class ReviewResult:
    """Review outcome for a single content unit.

    One instance is produced per entry of ``structured_content['chunks']``.
    Each ``*_compliance`` / ``rag_enhanced`` field holds whatever the
    corresponding engine check returned, or ``{"error": "<message>"}`` when
    that check (or the whole unit) failed.
    """

    # Position of the unit inside the chunk list (enumerate index).
    unit_index: int
    # The raw unit that was reviewed.
    unit_content: Dict[str, Any]
    # Result of the basic compliance check.
    basic_compliance: Dict[str, Any]
    # Result of the technical compliance check.
    technical_compliance: Dict[str, Any]
    # Result of the RAG-enhanced check.
    rag_enhanced: Dict[str, Any]
    # "low" / "medium" / "high", or "error" when the unit could not be reviewed.
    overall_risk: str
|
|
|
+
|
|
|
class AIReviewState(TypedDict):
    """State dictionary passed between the nodes of the AI review graph."""

    # --- Basic information -------------------------------------------------
    file_id: str
    callback_task_id: str
    user_id: str
    # The review node reads the units to review from the "chunks" entry.
    structured_content: Dict[str, Any]

    # --- AI review output --------------------------------------------------
    # None until the review node has stored its aggregated result.
    review_results: Optional[Dict[str, Any]]

    # --- Status ------------------------------------------------------------
    current_stage: str
    status: str
    # Set by the review node on failure; drives routing to the error handler.
    error_message: Optional[str]

    # --- Progress reporting ------------------------------------------------
    # Optional object exposing an awaitable ``update_stage_progress(...)``.
    progress_manager: Optional[Any]

    # --- Message log (LangGraph state tracing) -----------------------------
    # ``add_messages`` is the reducer LangGraph applies to this key.
    messages: Annotated[List[BaseMessage], add_messages]
|
|
|
+
|
|
|
+
|
|
|
class AIReviewWorkflow:
    """LangGraph-based AI review workflow.

    The graph is linear (``start`` -> ``initialize_progress`` ->
    ``ai_review``) and then branches on the outcome of the review node:
    ``complete`` on success, ``error_handler`` when the node recorded an
    ``error_message``. Both branches terminate the graph.
    """

    def __init__(self, file_id: str, callback_task_id: str, user_id: str,
                 structured_content: Dict[str, Any], progress_manager=None):
        """Store the review inputs and compile the workflow graph.

        Args:
            file_id: Identifier of the file under review.
            callback_task_id: Task id forwarded to the progress manager.
            user_id: Identifier of the requesting user.
            structured_content: Parsed document; its ``"chunks"`` entry is the
                list of units to review.
            progress_manager: Optional object with awaitable
                ``update_stage_progress`` / ``get_progress`` methods.
        """
        self.file_id = file_id
        self.callback_task_id = callback_task_id
        self.user_id = user_id
        self.structured_content = structured_content
        self.progress_manager = progress_manager
        self.ai_review_engine = AIReviewEngine()

        # Build the LangGraph workflow.
        self.graph = self._build_workflow()

    def _build_workflow(self) -> StateGraph:
        """Build and compile the review graph.

        Returns:
            The object returned by ``StateGraph.compile()`` (also stored on
            ``self.graph``).
            NOTE(review): the ``StateGraph`` annotation is kept from the
            original; the compiled graph is likely a different type.
        """
        workflow = StateGraph(AIReviewState)

        # Nodes
        workflow.add_node("start", self._start_node)
        workflow.add_node("initialize_progress", self._initialize_progress_node)
        workflow.add_node("ai_review", self._ai_review_node)
        workflow.add_node("complete", self._complete_node)
        workflow.add_node("error_handler", self._error_handler_node)

        # Entry point
        workflow.set_entry_point("start")

        # Unconditional edges
        workflow.add_edge("start", "initialize_progress")
        workflow.add_edge("initialize_progress", "ai_review")
        workflow.add_edge("complete", END)
        workflow.add_edge("error_handler", END)

        # "ai_review" is routed ONLY through this conditional edge. An extra
        # unconditional ai_review -> complete edge would schedule "complete"
        # on every run, including failed ones, alongside "error_handler".
        workflow.add_conditional_edges(
            "ai_review",
            self._check_ai_review_result,
            {
                "success": "complete",
                "error": "error_handler",
            },
        )

        # _get_workflow_graph() reads self.graph, so assign before calling it.
        self.graph = workflow.compile()
        self._get_workflow_graph()

        return self.graph

    async def execute(self) -> dict:
        """Run the workflow and return a flattened result summary.

        Returns:
            A dict with ``file_id``, ``total_units``, ``successful_units``,
            ``failed_units``, ``review_results``, ``summary`` and ``status``.
            Counters default to 0 and collections to empty when the review
            node produced no result (error path).

        Raises:
            Exception: Anything raised while invoking the graph is logged and
                re-raised unchanged.
        """
        try:
            logger.info(f"开始AI审查工作流,文件ID: {self.file_id}")

            # Initial state
            initial_state = AIReviewState(
                file_id=self.file_id,
                callback_task_id=self.callback_task_id,
                user_id=self.user_id,
                structured_content=self.structured_content,
                review_results=None,
                current_stage="start",
                status="processing",
                error_message=None,
                progress_manager=self.progress_manager,
                messages=[HumanMessage(content=f"开始AI审查: {self.file_id}")],
            )

            # Run the LangGraph workflow.
            result = await self.graph.ainvoke(initial_state)

            logger.info(f"LangGraph AI审查工作流完成,文件ID: {self.file_id}")

            # Stays None when the review node failed; normalise once.
            node_output = result.get('review_results') or {}
            review_results = {
                'file_id': result['file_id'],
                'total_units': node_output.get('total_units', 0),
                'successful_units': node_output.get('successful_units', 0),
                'failed_units': node_output.get('failed_units', 0),
                'review_results': node_output.get('review_results', []),
                'summary': node_output.get('summary', {}),
                'status': result['status'],
            }

            self._save_result_snapshot(result)

            return review_results

        except Exception as e:
            logger.error(f"LangGraph AI审查工作流执行失败: {str(e)}")
            raise

    def _save_result_snapshot(self, result: Dict[str, Any]) -> None:
        """Dump the final graph state to ``temp/AI审查结果.json`` (best effort).

        The snapshot is a debugging artefact: a missing directory or an
        unwritable file must not fail a review that already succeeded, so
        I/O and serialisation errors are logged and swallowed.
        """
        import os  # local import: only this helper needs it

        snapshot_path = 'temp/AI审查结果.json'
        logger.info("保存审查结果")
        try:
            os.makedirs(os.path.dirname(snapshot_path), exist_ok=True)
            with open(snapshot_path, "w", encoding='utf-8') as f:
                # default=str stringifies non-JSON values (messages, the
                # progress manager, ReviewResult instances).
                json.dump(result, f, ensure_ascii=False, indent=2, default=str)
        except (OSError, TypeError, ValueError) as e:
            logger.warning(f"保存审查结果失败(不影响审查流程): {str(e)}")

    # ========== LangGraph node implementations ==========

    async def _start_node(self, state: AIReviewState) -> AIReviewState:
        """Entry node: mark the run as started."""
        logger.info(f"AI审查工作流启动: {state['file_id']}")

        state["current_stage"] = "start"
        state["status"] = "processing"
        state["messages"].append(AIMessage(content="AI审查工作流启动"))

        return state

    async def _initialize_progress_node(self, state: AIReviewState) -> AIReviewState:
        """Report 0% progress for the review stage, if a manager is present."""
        logger.info(f"初始化AI审查进度: {state['file_id']}")

        state["current_stage"] = "initialize_progress"

        if state["progress_manager"]:
            await state["progress_manager"].update_stage_progress(
                callback_task_id=state["callback_task_id"],
                stage_name="AI审查",
                progress=0,
                status="processing",
                message="开始AI审查",
            )

        state["messages"].append(AIMessage(content="进度初始化完成"))

        return state

    async def _ai_review_node(self, state: AIReviewState) -> AIReviewState:
        """Review every chunk concurrently and store the aggregated result.

        Each unit runs the three engine checks concurrently while holding
        the engine's semaphore. A failing check is recorded as
        ``{"error": ...}`` for that check; a unit that fails as a whole gets
        ``overall_risk == "error"`` and is excluded from the stored results.

        On any exception at node level the message is written to
        ``state["error_message"]`` so the conditional edge routes to the
        error handler.
        """
        try:
            logger.info(f"执行AI审查: {state['file_id']}")

            state["current_stage"] = "ai_review"

            chunks = state['structured_content']['chunks']
            total_units = len(chunks)
            completed_units = 0

            # Progress updates are scheduled as tasks so they do not hold a
            # unit's semaphore slot. References are kept so the tasks cannot
            # be garbage-collected mid-flight and can be drained below.
            progress_tasks: List[asyncio.Task] = []

            def progress_callback(progress: int, message: str):
                # Map stage-local progress (0-100) onto the overall scale:
                # the AI review covers 40% of it, starting at 50.
                overall_progress = 50 + int(progress * 0.4)
                if state["progress_manager"]:
                    progress_tasks.append(asyncio.create_task(
                        state["progress_manager"].update_stage_progress(
                            callback_task_id=state["callback_task_id"],
                            stage_name="AI审查",
                            progress=overall_progress,
                            status="processing",
                            message=message,
                        )
                    ))

            async def review_single_unit(unit_content: Dict[str, Any], unit_index: int) -> ReviewResult:
                """Run the three engine checks for one unit."""
                nonlocal completed_units
                async with self.ai_review_engine.semaphore:
                    try:
                        check_outcomes = await asyncio.gather(
                            self.ai_review_engine.basic_compliance_check(unit_content),
                            self.ai_review_engine.technical_compliance_check(unit_content),
                            self.ai_review_engine.rag_enhanced_check(unit_content),
                            return_exceptions=True,
                        )

                        # BaseException (not Exception) so that a returned
                        # CancelledError is also converted into an error
                        # record instead of being treated as a result.
                        basic_result, technical_result, rag_result = [
                            {"error": str(outcome)} if isinstance(outcome, BaseException) else outcome
                            for outcome in check_outcomes
                        ]

                        overall_risk = self._calculate_overall_risk(
                            basic_result, technical_result, rag_result
                        )

                        # Update progress
                        completed_units += 1
                        progress = int((completed_units / total_units) * 100)
                        message = f"已完成 {completed_units}/{total_units} 个审查单元"
                        progress_callback(progress, message)

                        return ReviewResult(
                            unit_index=unit_index,
                            unit_content=unit_content,
                            basic_compliance=basic_result,
                            technical_compliance=technical_result,
                            rag_enhanced=rag_result,
                            overall_risk=overall_risk,
                        )

                    except Exception as e:
                        logger.error(f"审查单元 {unit_index} 失败: {str(e)}")
                        return ReviewResult(
                            unit_index=unit_index,
                            unit_content=unit_content,
                            basic_compliance={"error": str(e)},
                            technical_compliance={"error": str(e)},
                            rag_enhanced={"error": str(e)},
                            overall_risk="error",
                        )

            # Review all units concurrently.
            review_tasks = [
                asyncio.create_task(review_single_unit(content, i))
                for i, content in enumerate(chunks)
            ]
            all_results = await asyncio.gather(*review_tasks)

            # Drain outstanding progress updates so none of them can land
            # after the "completed"/"failed" update of the next node. A
            # failed update is logged but never fails the review.
            if progress_tasks:
                update_outcomes = await asyncio.gather(*progress_tasks, return_exceptions=True)
                for outcome in update_outcomes:
                    if isinstance(outcome, BaseException):
                        logger.warning(f"AI审查进度更新失败: {str(outcome)}")

            # Keep only the units that were reviewed successfully.
            successful_results = [result for result in all_results if result.overall_risk != "error"]

            summary = self._aggregate_results(successful_results)

            review_results = {
                'total_units': total_units,
                'successful_units': len(successful_results),
                'failed_units': total_units - len(successful_results),
                'review_results': successful_results,
                'summary': summary,
            }

            state["review_results"] = review_results
            state["messages"].append(AIMessage(
                content=f"AI审查完成,共处理{total_units}个单元,成功{len(successful_results)}个"
            ))

            return state

        except Exception as e:
            logger.error(f"AI审查失败: {str(e)}")
            state["error_message"] = str(e)
            state["messages"].append(AIMessage(content=f"AI审查失败: {str(e)}"))
            return state

    async def _complete_node(self, state: AIReviewState) -> AIReviewState:
        """Success terminal node: mark completion and report 90% progress."""
        logger.info(f"AI审查完成: {state['file_id']}")

        state["current_stage"] = "complete"
        state["status"] = "completed"

        if state["progress_manager"]:
            await state["progress_manager"].update_stage_progress(
                callback_task_id=state["callback_task_id"],
                stage_name="AI审查",
                progress=90,
                status="completed",
                message="AI审查完成",
            )

        state["messages"].append(AIMessage(content="AI审查工作流完成"))

        return state

    async def _error_handler_node(self, state: AIReviewState) -> AIReviewState:
        """Failure terminal node: mark the run failed and report the error."""
        logger.error(f"AI审查错误处理: {state['file_id']}, 错误: {state['error_message']}")

        state["status"] = "failed"
        state["current_stage"] = "error_handler"

        if state["progress_manager"]:
            await state["progress_manager"].update_stage_progress(
                callback_task_id=state["callback_task_id"],
                stage_name="AI审查",
                progress=50,
                status="failed",
                message=f"AI审查失败: {state['error_message']}",
            )

        state["messages"].append(AIMessage(
            content=f"错误处理: {state['error_message']}"
        ))

        return state

    # ========== Helpers ==========

    def _calculate_overall_risk(self, basic_result: Dict, technical_result: Dict, rag_result: Dict) -> str:
        """Derive a risk label from the basic and technical scores.

        Both scores >= 90 -> ``"low"``; both >= 70 -> ``"medium"``; otherwise
        ``"high"``. A missing ``overall_score`` counts as 0. ``rag_result``
        is accepted for interface symmetry but not used in the calculation.

        Returns:
            ``"low"``, ``"medium"`` or ``"high"``; ``"medium"`` as a fallback
            when the inputs cannot be evaluated (e.g. not a dict, or a
            non-numeric score).
        """
        try:
            basic_score = basic_result.get('overall_score', 0)
            technical_score = technical_result.get('overall_score', 0)

            if basic_score >= 90 and technical_score >= 90:
                return "low"
            if basic_score >= 70 and technical_score >= 70:
                return "medium"
            return "high"
        except Exception as e:
            # Narrower than a bare ``except`` so KeyboardInterrupt/SystemExit
            # still propagate; the fallback value is unchanged.
            logger.warning(f"风险等级计算失败,使用默认等级 medium: {str(e)}")
            return "medium"

    def _aggregate_results(self, successful_results: List[ReviewResult]) -> Dict[str, Any]:
        """Summarise risk distribution and average scores.

        Returns:
            ``{}`` for an empty input or when aggregation fails; otherwise a
            dict with ``risk_stats``, ``avg_basic_score``,
            ``avg_technical_score`` and ``total_reviewed``.
        """
        try:
            if not successful_results:
                return {}

            # Known labels are pre-seeded so they always appear in the
            # output; an unexpected label is counted instead of raising
            # KeyError and discarding the whole summary.
            risk_stats = {"low": 0, "medium": 0, "high": 0, "error": 0}
            for result in successful_results:
                risk_stats[result.overall_risk] = risk_stats.get(result.overall_risk, 0) + 1

            reviewed = len(successful_results)
            total_basic_score = sum(
                r.basic_compliance.get('overall_score', 0) for r in successful_results
            )
            total_technical_score = sum(
                r.technical_compliance.get('overall_score', 0) for r in successful_results
            )

            return {
                'risk_stats': risk_stats,
                'avg_basic_score': total_basic_score / reviewed,
                'avg_technical_score': total_technical_score / reviewed,
                'total_reviewed': reviewed,
            }
        except Exception as e:
            logger.error(f"结果汇总失败: {str(e)}")
            return {}

    # ========== Conditional-edge functions ==========

    def _check_ai_review_result(self, state: AIReviewState) -> str:
        """Route after the review node: ``"error"`` if an error was recorded."""
        if state.get("error_message"):
            return "error"
        return "success"

    def _get_workflow_graph(self):
        """Print an ASCII rendering of the compiled graph (debug aid).

        ASCII drawing needs an optional layout dependency; when it is not
        installed the rendering is skipped with a warning instead of
        failing workflow construction.
        """
        try:
            self.graph.get_graph().print_ascii()
        except ImportError as e:
            logger.warning(f"无法输出工作流图(缺少可选绘图依赖): {str(e)}")

    async def _get_status(self) -> dict:
        """Return the progress manager's view of this task, or ``{}``."""
        if self.progress_manager:
            return await self.progress_manager.get_progress(self.callback_task_id)
        return {}
|