
v0.0.2 - implement SSE progress push for the overall workflow

WangXuMing 3 months ago
parent
commit
3f2a38245d
49 changed files with 4609 additions and 1225 deletions
  1. .gitignore (+2 -1)
  2. Dockerfile (+1 -1)
  3. README.md (+37 -2)
  4. README_deploy.md (+38 -0)
  5. build_graph_app.png (BIN)
  6. config/config.ini (+28 -10)
  7. config/prompt/common_model_query.yaml (+10 -0)
  8. config/prompt/intent_prompt.yaml (+22 -0)
  9. config/prompt/system_prompt.yaml (+19 -0)
  10. config/sql/lq_db.sql (+68 -0)
  11. config/sql/test.sql (+1 -1)
  12. core/base/progress_manager.py (+266 -65)
  13. core/base/workflow_manager.py (+45 -13)
  14. core/construction_review/component/ai_review_engine.py (+1 -0)
  15. core/construction_review/component/document_processor.py (+376 -54)
  16. core/construction_review/component/reviewers/base_reviewer.py (+2 -2)
  17. core/construction_review/doc_worker/__init__.py (+50 -0)
  18. core/construction_review/doc_worker/config.yaml (+173 -0)
  19. core/construction_review/doc_worker/config_loader.py (+194 -0)
  20. core/construction_review/doc_worker/core.py (+205 -0)
  21. core/construction_review/doc_worker/llm_classifier.py (+212 -0)
  22. core/construction_review/doc_worker/result_saver.py (+294 -0)
  23. core/construction_review/doc_worker/text_splitter.py (+814 -0)
  24. core/construction_review/doc_worker/toc_extractor.py (+348 -0)
  25. core/construction_review/workflows/ai_review_workflow.py (+31 -21)
  26. core/construction_review/workflows/document_workflow.py (+13 -19)
  27. core/construction_review/workflows/report_workflow.py (+7 -11)
  28. database/repositories/bus_data_query.py (+36 -0)
  29. foundation/agent/generate/model_generate.py (+2 -2)
  30. foundation/agent/workflow/test_workflow_node.py (+12 -3)
  31. foundation/base/mysql/async_mysql_base_dao.py (+0 -157)
  32. foundation/base/tasks.py (+3 -3)
  33. foundation/logger/loggering.py (+2 -2)
  34. foundation/models/silicon_flow.py (+2 -35)
  35. foundation/rag/vector/base_vector.py (+108 -0)
  36. foundation/rag/vector/milvus_vector.py (+367 -0)
  37. foundation/rag/vector/pg_vector.py (+269 -0)
  38. foundation/trace/celery_trace.py (+2 -2)
  39. foundation/utils/redis_utils.py (+1 -0)
  40. foundation/utils/tool_utils.py (+5 -2)
  41. requirements.txt (+3 -1)
  42. run.sh (+1 -1)
  43. temp/AI审查结果.json (+49 -27)
  44. test/construction_review/api_test_client.py (+0 -281)
  45. test/construction_review/test_error_codes_pytest.py (+0 -370)
  46. views/__init__.py (+5 -4)
  47. views/construction_review/file_upload.py (+1 -0)
  48. views/construction_review/task_progress.py (+161 -126)
  49. views/test_views.py (+323 -9)

+ 2 - 1
.gitignore

@@ -61,4 +61,5 @@ target/
 todo.md
 .design
 .claude
-.R&D
+.R&D
+temp/AI审查结果.json

+ 1 - 1
Dockerfile

@@ -1,4 +1,4 @@
-FROM python:3.13-slim
+FROM python:3.12-slim
 
 ENV DEBIAN_FRONTEND=noninteractive \
     TZ=Asia/Shanghai

+ 37 - 2
README.md

@@ -10,17 +10,30 @@
     - uvicorn server.app:app --port=8001 --host=0.0.0.0
     - gunicorn -c gunicorn_config.py server.app:app       多进程启动
 
+    - python .\views\construction_review\app.py
 
+
+    
+
+
+
+  
     pip install aioredis -i https://mirrors.aliyun.com/pypi/simple/
     pip install langgraph-checkpoint-postgres -i https://mirrors.aliyun.com/pypi/simple/
     pip install langchain-redis -i https://mirrors.aliyun.com/pypi/simple/
 
 
-
 ### 增加组件依赖
   pip install aiomysql -i https://mirrors.aliyun.com/pypi/simple/
 
 
+
+### 向量模型和重排序模型测试
+  cd LQAgentPlatform
+  python foundation/models/silicon_flow.py
+
+
+
 ### 测试接口
 
   #### 生成模型接口 
@@ -118,4 +131,26 @@ curl -X POST "http://localhost:8001/test/agent/stream" \
           },
           "input": "李四"
         }
-      
+      
+
+
+#### 向量检索测试
+  - 向量检索测试
+   http://localhost:8001/test/bfp/search
+   {
+      "config": {
+          "session_id":"20"
+      },
+      "input": "安全生产条件"
+    }
+
+  - 向量检索和重排序测试
+    - http://localhost:8001/test/bfp/search/rerank
+    {
+      "config": {
+          "session_id":"20"
+      },
+      "input": "安全生产条件"
+    }
+
+
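The two retrieval test endpoints above accept the same JSON body: a `config.session_id` plus an `input` query. Below is a minimal sketch of calling them with `requests`, assuming the server runs on the uvicorn host/port shown at the top of this README; the helper name is illustrative, not part of the repo:

```python
import requests

BASE_URL = "http://localhost:8001"  # uvicorn --port=8001 --host=0.0.0.0

def vector_search(query: str, session_id: str = "20", rerank: bool = False) -> dict:
    # /test/bfp/search does plain vector retrieval;
    # /test/bfp/search/rerank retrieves and then reranks the hits.
    path = "/test/bfp/search/rerank" if rerank else "/test/bfp/search"
    payload = {"config": {"session_id": session_id}, "input": query}
    resp = requests.post(f"{BASE_URL}{path}", json=payload, timeout=30)
    resp.raise_for_status()
    return resp.json()

print(vector_search("安全生产条件"))               # retrieval only
print(vector_search("安全生产条件", rerank=True))  # retrieval + rerank
```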

+ 38 - 0
README_deploy.md

@@ -0,0 +1,38 @@
+
+sentence-transformers==4.1.0
+
+#### docker 容器部署
+##### 目前采用离线打包docker容器上传部署模式
+  - 1、本地容器打包        docker build -t lq_agent_platform_server:v0.1 .
+    - 1.1、保存本地镜像文件    docker save -o lq_agent_platform_v0.1.img lq_agent_platform_server:v0.1
+    - 1.2、容器压缩           tar -czvf lq_agent_platform_v0.1.tar.gz lq_agent_platform_v0.1.img
+    - 1.3、sftp上传到测试环境目录:       /home/cjb/lq_workspace/app/LqAgentServer/docker_tmp
+    - 1.4、容器解压           tar -xzvf lq_agent_platform_v0.1.tar.gz
+  - 2、删除测试环境原镜像   docker rmi lq_agent_platform_server:v0.1
+  - 3、容器加载           docker load -i lq_agent_platform_v0.1.img
+  - 4、容器启动
+    
+    docker run --name=LQAgentServer -d  --memory="4096m" --memory-swap="5000m" --cpus="3" --cpuset-cpus="0-2" --restart=always -p 8001:8001 -v /home/cjb/lq_workspace/app/LqAgentServer/config:/app/config/ -v /home/cjb/lq_workspace/app/LqAgentServer/gunicorn_log/:/app/gunicorn_log/  -v /home/cjb/lq_workspace/app/LqAgentServer/logs/:/app/logs/ --network=host lq_agent_platform_server:v0.1
+
+
+
+  - 5、容器copy文件
+     - 进入容器查看文件:docker exec -it LQAgentServer /bin/sh 
+     - copy外部文件到容器内:docker cp gunicorn_config.py LQAgentServer:/app/gunicorn_config
+  
+    - 实例启动: docker start LQAgentServer
+    - 实例停止: docker stop LQAgentServer
+    - 实例重启: docker restart LQAgentServer
+    - 使用docker查看日志  docker logs -f LQAgentServer
+    - 使用docker查看日志  docker logs -f --tail {行数} LQAgentServer
+
+  - 6、路径文件映射
+    - 配置文件
+      - 宿主机:/home/cjb/lq_workspace/app/LqAgentServer/config/ ==> 容器:/app/config/ 
+      - 宿主机:/home/cjb/lq_workspace/app/LqAgentServer/logs/ ==> 容器:/app/logs/
+    - 配置文件路径说明,注意修改后重启容器
+      - 宿主机配置文件路径
+        vim /home/cjb/lq_workspace/app/LqAgentServer/config/prompt/{prompt文件名称}.yaml
+        vim /home/cjb/lq_workspace/app/LqAgentServer/config/config.ini
+
+      - 宿主机日志地址 /home/cjb/lq_workspace/app/LqAgentServer/logs/
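The packaging steps 1 through 1.2 above can also be scripted. A minimal sketch in Python using the standard-library `subprocess` module, reusing the exact tag and file names from the commands above (the script itself is not part of the repo):

```python
import subprocess

# Automates build -> save -> compress (steps 1, 1.1 and 1.2 above).
TAG = "lq_agent_platform_server:v0.1"
IMG = "lq_agent_platform_v0.1.img"
ARCHIVE = "lq_agent_platform_v0.1.tar.gz"

for cmd in (
    ["docker", "build", "-t", TAG, "."],
    ["docker", "save", "-o", IMG, TAG],
    ["tar", "-czvf", ARCHIVE, IMG],
):
    subprocess.run(cmd, check=True)  # stop on the first failing step
```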

BIN
build_graph_app.png


+ 28 - 10
config/config.ini

@@ -1,13 +1,13 @@
 
 
 [model]
-MODEL_TYPE=qwen_local_1.5b
+MODEL_TYPE=gemini
 
 
 
 [gemini]
 GEMINI_SERVER_URL=https://generativelanguage.googleapis.com
-GEMINI_MODEL_ID=gemini-2.5-flash
+GEMINI_MODEL_ID=gemini-2.0-flash
 GEMINI_API_KEY=AIzaSyDcL1AZS4u9N-8OyE7q7M25wvYZhj2okJc
 
 [deepseek]
@@ -29,7 +29,7 @@ QWEN_API_KEY=ms-9ad4a379-d592-4acd-b92c-8bac08a4a045
 
 [ai_review]
 # 调试模式配置
-MAX_REVIEW_UNITS=10
+MAX_REVIEW_UNITS=3
 REVIEW_MODE=random
 # REVIEW_MODE=all/random/first
 
@@ -41,11 +41,11 @@ APP_SECRET=sx-73d32556-605e-11f0-9dd8-acde48001122
 
 
 [redis]
-REDIS_URL=redis://:123456@127.0.0.1:6379
-REDIS_HOST=127.0.0.1
+REDIS_URL=redis://:Wxcz666@@192.168.0.3:6379
+REDIS_HOST=192.168.0.3
 REDIS_PORT=6379
 REDIS_DB=0
-REDIS_PASSWORD=123456
+REDIS_PASSWORD=Wxcz666@
 REDIS_MAX_CONNECTIONS=50
 
 [log]
@@ -78,11 +78,29 @@ QWEN_LOCAL_14B_API_KEY=sk-dummy
 
 
 [mysql]
-MYSQL_HOST=localhost
-MYSQL_PORT=3306
+MYSQL_HOST=192.168.0.3
+MYSQL_PORT=13306
 MYSQL_USER=root
-MYSQL_PASSWORD=admin
+MYSQL_PASSWORD=lq@123
 MYSQL_DB=lq_db
 MYSQL_MIN_SIZE=1
-MYSQL_MAX_SIZE=2
+MYSQL_MAX_SIZE=5
 MYSQL_AUTO_COMMIT=True
+
+
+
+
+[pgvector]
+PGVECTOR_HOST=124.223.140.149
+PGVECTOR_PORT=7432
+PGVECTOR_DB=vector_db
+PGVECTOR_USER=vector_user
+PGVECTOR_PASSWORD=pg16@123
+
+
+[milvus]
+MILVUS_HOST=124.223.140.149
+MILVUS_PORT=7432
+MILVUS_DB=vector_db
+MILVUS_USER=vector_user
+MILVUS_PASSWORD=pg16@123
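For reference, the sections above can be read with the standard-library `configparser`; a minimal sketch that rebuilds `REDIS_URL` from the individual `[redis]` keys (the project's own `config_handler` wrapper may differ):

```python
from configparser import ConfigParser
from urllib.parse import quote

# interpolation=None keeps values raw, so "%" in other sections is harmless.
parser = ConfigParser(interpolation=None)
parser.read("config/config.ini")

host = parser.get("redis", "REDIS_HOST")
port = parser.get("redis", "REDIS_PORT")
db = parser.get("redis", "REDIS_DB")
# A password containing "@" (like the one above) must be percent-encoded
# before being embedded in a redis:// URL, or the host part is misparsed.
password = quote(parser.get("redis", "REDIS_PASSWORD"), safe="")

redis_url = (f"redis://:{password}@{host}:{port}/{db}"
             if password else f"redis://{host}:{port}/{db}")
print(redis_url)
```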

+ 10 - 0
config/prompt/common_model_query.yaml

@@ -0,0 +1,10 @@
+
+# 任务提示词
+task_prompt: |
+  你是一个智能助手,根据提供的信息回答问题。
+
+
+
+# test
+template: |
+  ## 测试内容

+ 22 - 0
config/prompt/intent_prompt.yaml

@@ -0,0 +1,22 @@
+
+# 系统提示词
+system_prompt: |
+  基于提供的样例,结合用户最近的对话历史上下文进行意图识别,精准匹配对应的业务场景指令。
+  必须优先参考最近的上下文语义及用户意图演变,若问题与样例中的任一业务场景相符,则返回对应指令;若无法匹配任何已定义场景,则返回 chat_box_generate。
+  严格遵守:仅输出指令字符串,不附加任何解释、说明或格式。
+  用户目前历史上下文信息:
+  {history}
+
+
+
+
+# 意图案例 准备few-shot样例;
+intent_examples: 
+  - inn: 你好;咨询.
+    out: chat_box_generate
+
+  - inn: 执行;操作;查询;处理;
+    out: common_agent
+
+
+           
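The `inn`/`out` pairs above are few-shot examples, and `{history}` in `system_prompt` is a substitution slot. A minimal sketch of turning this file into a chat-message list, assuming PyYAML and an OpenAI-style message format (neither of which the YAML itself mandates):

```python
import yaml

with open("config/prompt/intent_prompt.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

def build_intent_messages(user_input: str, history: str) -> list:
    # The system prompt carries the recent conversation history.
    messages = [{"role": "system",
                 "content": cfg["system_prompt"].format(history=history)}]
    # Each inn/out pair becomes a few-shot user/assistant exchange.
    for example in cfg["intent_examples"]:
        messages.append({"role": "user", "content": example["inn"]})
        messages.append({"role": "assistant", "content": example["out"]})
    messages.append({"role": "user", "content": user_input})
    return messages

print(build_intent_messages("查询今天的数据", history="(empty)"))
```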

+ 19 - 0
config/prompt/system_prompt.yaml

@@ -0,0 +1,19 @@
+
+
+# 系统提示词
+system_prompt: |
+  分析专家于一身的AI助手,提供全方位的智能化指导。
+        你的建议要务实、经济、易操作,并能基于物联网数据提供精准预警和具体解决方案。
+            
+
+    
+# 用户上下文会话记录 摘要提示词
+summary_system_prompt: |
+  请总结以下对话内容,保留关键信息:
+  {history}
+
+
+
+# test
+template: |
+  ## 测试内容

+ 68 - 0
config/sql/lq_db.sql

@@ -0,0 +1,68 @@
+
+
+
+-- 1、编制依据基本信息表
+
+
+DROP TABLE IF EXISTS t_basis_of_preparation;
+CREATE TABLE IF NOT EXISTS t_basis_of_preparation (
+    id INT AUTO_INCREMENT PRIMARY KEY COMMENT '标准唯一标识符',
+    chinese_name VARCHAR(500) NOT NULL COMMENT '中文标准名称',
+    english_name VARCHAR(500) COMMENT '英文标准名称',
+    standard_no VARCHAR(100)  COMMENT '标准编号',
+    issuing_authority VARCHAR(200) COMMENT '发布机构',
+    release_date DATE COMMENT '发布日期',
+    implementation_date DATE COMMENT '实施日期',
+    drafting_unit VARCHAR(300) COMMENT '起草单位',
+    approving_department VARCHAR(200) COMMENT '批准部门',
+    document_type VARCHAR(10) COMMENT '标准类型: national-国家标准, industry-行业标准, local-地方标准, enterprise-企业标准',
+    professional_field VARCHAR(15) COMMENT '专业领域:Laws-法律,Technical-技术规范,Reference-参考规范,Internal-内部规范',
+    engineering_phase VARCHAR(100) COMMENT '工程阶段',
+    participating_units VARCHAR(800) COMMENT '参编单位',
+    reference_basis_list VARCHAR(1000) COMMENT '参考依据列表',
+    file_url VARCHAR(500) COMMENT '文件路径',
+    status VARCHAR(10) COMMENT '状态:current-现行,void-作废',
+    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '记录创建时间',
+    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '记录更新时间',
+    INDEX idx_standard_no (standard_no) COMMENT '标准编号索引',
+    INDEX idx_chinese_name (chinese_name(100)) COMMENT '中文名称索引',
+    INDEX idx_release_date (release_date) COMMENT '发布日期索引',
+    INDEX idx_document_type (document_type) COMMENT '标准类型索引',
+    INDEX idx_professional_field (professional_field) COMMENT '专业领域索引'
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='编制依据基本信息表';
+
+
+
+
+
+
+
+
+INSERT INTO t_basis_of_preparation (
+    chinese_name, english_name, standard_no, issuing_authority, 
+    release_date, implementation_date, drafting_unit, approving_department, 
+    document_type, professional_field, engineering_phase, participating_units, 
+    reference_basis_list, file_url, status
+) VALUES
+('中华人民共和国安全生产法', NULL, NULL, NULL, '2021-06-10', NULL, NULL, NULL, 'national', 'Laws', NULL, NULL, NULL, 'https://safety.jining.gov.cn/module/download/downfile.jsp?classid=0&showname=%E4%B8%AD%E5%8D%8E%E4%BA%BA%E6%B0%91%E5%85%B1%E5%92%8C%E5%9B%BD%E5%AE%89%E5%85%A8%E7%94%9F%E4%BA%A7%E6%B3%95%EF%BC%882021%E5%B9%B46%E6%9C%8810%E6%97%A5%E4%BF%AE%E8%AE%A2%E7%89%88%EF%BC%89.pdf&filename=3b0ee62a494049869e9361ec8ee4fb83.pdf', 'current'),
+('公路水运工程质量监督管理规定', NULL, NULL, '交通运输部', '2017-09-14', NULL, NULL, NULL, 'industry', 'Laws', NULL, NULL, NULL, 'https://xxgk.mot.gov.cn/2020/jigou/fgs/202006/t20200623_3307899.html', 'current'),
+('公路水运工程拟淘汰危及生产安全施工工艺、设备和材料目录', NULL, NULL, '交通运输部', NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'http://ztjfjt.jtgs.taizhou.gov.cn/cms_files/filemanager/1718223565/attach/20235/7485f997a006433f9d2530c46a4b9861.pdf', 'current'),
+('公路桥涵施工技术规范', NULL, 'JTG/T3650-2020', '交通运输部', NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'https://xxgk.mot.gov.cn/2020/jigou/glj/202006/P020200630665628060420.pdf', 'current'),
+('公路工程质量检验评定标准', NULL, 'JTGF80-1-2017', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'https://jtst.mot.gov.cn/hb/search/stdHBDetailed?id=dd2ffc7d8c33835bad290e9d741f0634', 'current'),
+('公路工程施工安全技术规范', NULL, 'JTGF90-2015', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'https://jtst.mot.gov.cn/hb/search/stdHBDetailed?id=4c4ab59797b5b4013c4089972fbb2290', 'current'),
+('混凝土结构工程施工质量验收规范', NULL, 'GB50204-2015', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'http://www.cdapm.com.cn/upload/%E6%B7%B7%E5%87%9D%E5%9C%9F%E7%BB%93%E6%9E%84%E5%B7%A5%E7%A8%8B%E6%96%BD%E5%B7%A5%E8%B4%A8%E9%87%8F%E9%AA%8C%E6%94%B6%E8%A7%84%E8%8C%83GB%2050204-2015.pdf', 'current'),
+('施工现场临时用电安全技术规范', NULL, 'JGJ46-2016', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'https://zjw.sh.gov.cn/cmsres/73/7320cf3c54aa4a34827bfecbe6ea293d/5a01c703dcca637c3b9247f4c001542f.pdf', 'current'),
+('建筑施工塔式起重机安装、使用、拆卸安全技术规范', NULL, 'JGJ196-2010', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'https://zjw.sh.gov.cn/cmsres/99/99e29d723c8e49a488df5f787a529711/1314c992b03eb944fe2a020c26d457ae.pdf', 'current'),
+('建筑施工高空作业安全技术规范', NULL, 'JGJ80-2016', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'https://zjw.sh.gov.cn/cmsres/dd/dd2874d657124e648b54c66a113fb0b1/2b641c95070e63127349d11cc3109bc6.pdf', 'current'),
+('混凝土结构设计规范2015 年版', NULL, 'GB50010-2010', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'https://www.gbwindows.net/ow-content/uploads/download/gfbzdown/1.0.5%E6%9D%A1/%E5%85%B3%E8%81%94%E6%A0%87%E5%87%86/GB50010-2010(2015%E7%89%88)%20%20%E6%B7%B7%E5%87%9D%E5%9C%9F%E7%BB%93%E6%9E%84%E8%AE%BE%E8%AE%A1%E8%A7%84%E8%8C%83.pdf', 'current'),
+('混凝土结构工程施工质量验收规范', NULL, 'GB50204-2015', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'http://www.cdapm.com.cn/upload/%E6%B7%B7%E5%87%9D%E5%9C%9F%E7%BB%93%E6%9E%84%E5%B7%A5%E7%A8%8B%E6%96%BD%E5%B7%A5%E8%B4%A8%E9%87%8F%E9%AA%8C%E6%94%B6%E8%A7%84%E8%8C%83GB%2050204-2015.pdf', 'current'),
+('建筑施工模板安全技术规程', NULL, 'JGJ162-2008', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'http://www.cdapm.com.cn/upload/%E5%BB%BA%E7%AD%91%E6%96%BD%E5%B7%A5%E6%A8%A1%E6%9D%BF%E5%AE%89%E5%85%A8%E6%8A%80%E6%9C%AF%E8%A7%84%E8%8C%83JGJ162-2008.pdf', 'current'),
+('G4216 线屏山新市至金阳段高速公路 XJ4 标段两阶段施工设计图纸', NULL, NULL, NULL, NULL, NULL, NULL, NULL, 'enterprise', 'Reference', NULL, NULL, NULL, NULL, 'current'),
+('建设单位明确的工程施工工期、质量和环境保护要求以及关键工程控制要点', NULL, NULL, NULL, NULL, NULL, NULL, NULL, 'enterprise', 'Laws', NULL, NULL, NULL, NULL, 'current'),
+('本项目总体施工组织设计', NULL, NULL, NULL, NULL, NULL, NULL, NULL, 'enterprise', 'Reference', NULL, NULL, NULL, NULL, 'current'),
+('四川路桥集团《工程技术管理办法》及《工程质量管理办法》', NULL, NULL, NULL, NULL, NULL, NULL, NULL, 'enterprise', 'Internal', NULL, NULL, NULL, NULL, 'current'),
+('《起重机械安全规程》', NULL, 'GB6067-2010', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'https://openstd.samr.gov.cn/bzgk/gb/newGbInfo?hcno=9DED7058601D511BFD5EEE88677548D8', 'current'),
+('《架桥机通用技术条件》', NULL, 'GB/T26470-2011', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'https://openstd.samr.gov.cn/bzgk/gb/newGbInfo?hcno=F8FC50E035D93142F37F28F0F5E8B678', 'current'),
+('《架桥机安全规程》', NULL, 'GB 26496-2011', NULL, NULL, NULL, NULL, NULL, 'industry', 'Technical', NULL, NULL, NULL, 'https://openstd.samr.gov.cn/bzgk/gb/newGbInfo?hcno=DF194527717A2C929434449D62FF8196', 'current'),
+('《公路水运工程安全生产监督管理办法》', NULL, '交通运输部令2017 年第25号', '交通运输部', NULL, NULL, NULL, NULL, 'industry', 'Laws', NULL, NULL, NULL, 'https://xxgk.mot.gov.cn/2020/gz/202112/t20211227_3633480.html', 'current'),
+('《危险性较大的分部分项工程安全管理规定》', NULL, '住建部令第37 号', '住房和城乡建设部', NULL, NULL, NULL, NULL, 'industry', 'Laws', NULL, NULL, NULL, 'https://www.gov.cn/gongbao/content/2018/content_5294422.htm', 'current');
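A minimal sketch of querying the table above with `aiomysql` (the driver the README adds as a dependency); host, port and credentials mirror the `[mysql]` section of config/config.ini:

```python
import asyncio
import aiomysql

async def fetch_current_standards():
    pool = await aiomysql.create_pool(
        host="192.168.0.3", port=13306, user="root",
        password="lq@123", db="lq_db", autocommit=True,
    )
    async with pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cur:
            await cur.execute(
                "SELECT chinese_name, standard_no, document_type "
                "FROM t_basis_of_preparation WHERE status = %s",
                ("current",),
            )
            rows = await cur.fetchall()
    pool.close()
    await pool.wait_closed()
    return rows

print(asyncio.run(fetch_current_standards()))
```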

+ 1 - 1
config/sql/test.sql

@@ -2,7 +2,7 @@
 
 
  -- 测试信息表
-
+ DROP TABLE IF EXISTS test_tab;
 CREATE TABLE IF NOT EXISTS test_tab (
     id INT AUTO_INCREMENT PRIMARY KEY COMMENT '用户唯一标识符',
     name VARCHAR(100) NOT NULL COMMENT '用户姓名',

+ 266 - 65
core/base/progress_manager.py

@@ -4,59 +4,219 @@
 """
 
 import json
+import asyncio
 from typing import Dict, Any, Optional
 from datetime import datetime
 
 from foundation.logger.loggering import server_logger as logger
+from foundation.base.config import config_handler
+
+class SSECallbackManager:
+    """SSE回调管理器 - 单例模式管理全局SSE回调"""
+    _instance = None
+    _callbacks = {}  # {callback_task_id: callback_function}
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def register_callback(self, callback_task_id: str, callback_func):
+        """注册SSE回调函数"""
+        self._callbacks[callback_task_id] = callback_func
+        logger.info(f"SSE回调注册, 当前注册数: {len(self._callbacks)}")
+
+    def unregister_callback(self, callback_task_id: str):
+        """注销SSE回调函数"""
+        if callback_task_id in self._callbacks:
+            del self._callbacks[callback_task_id]
+            logger.info(f"SSE回调注销, 剩余注册数: {len(self._callbacks)}")
+
+    async def trigger_callback(self, callback_task_id: str, current_data: dict):
+        """触发SSE回调"""
+        if callback_task_id in self._callbacks:
+            try:
+                # 直接异步执行回调,保持trace上下文
+                await self._callbacks[callback_task_id](callback_task_id, current_data)
+                logger.debug(f"SSE回调执行成功: {callback_task_id}")
+
+                logger.debug(f"SSE回调已触发: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}")
+                return True
+
+            except Exception as e:
+                logger.error(f"SSE回调执行失败: {callback_task_id}, {e}")
+                return False
+        else:
+            logger.debug(f"未找到SSE回调: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}, 已注册ID: {list(self._callbacks.keys())}")
+            return False
+
+    def get_callbacks_count(self):
+        """获取当前回调数量"""
+        return len(self._callbacks)
+
+    def clear_all_callbacks(self):
+        """清空所有回调"""
+        self._callbacks.clear()
+        logger.info("已清空所有SSE回调")
+
+# 全局SSE回调管理器实例
+sse_callback_manager = SSECallbackManager()
 
 class ProgressManager:
-    """任务进度管理器"""
+    """任务进度管理器 - 增长型进度管理版本"""
 
     def __init__(self):
-        self.progress_data = {}  # 简化:使用内存存储
+        self.redis_client = None
+        self.redis_connected = False
+        self._init_redis()
+
+    def _init_redis(self):
+        """初始化Redis连接"""
+        try:
+            import redis
+
+            redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
+            redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
+            redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
+            redis_db = config_handler.get('redis', 'REDIS_DB', '0')
+
+            # 构建Redis连接URL
+            if redis_password:
+                redis_url = f"redis://:{redis_password}@{redis_host}:{redis_port}/{redis_db}"
+            else:
+                redis_url = f"redis://{redis_host}:{redis_port}/{redis_db}"
+
+            logger.debug(f"ProgressManager连接Redis: {redis_url}")
+
+            # 连接Redis
+            self.redis_client = redis.from_url(redis_url, decode_responses=True)
+
+            # 测试连接
+            self.redis_client.ping()
+            self.redis_connected = True
+            logger.debug(f"ProgressManager Redis连接成功: {redis_host}:{redis_port}")
+
+        except Exception as e:
+            logger.error(f"ProgressManager Redis连接失败: {e}")
+            self.redis_connected = False
+            logger.warning("ProgressManager将使用内存存储作为备选方案")
+            self.current_data = {}  # 备选内存存储
+
+    async def _get_redis_key(self, callback_task_id: str) -> str:
+        """获取Redis键名"""
+        return f"current:{callback_task_id}"
 
     async def initialize_progress(self, callback_task_id: str, user_id: str, stages: list):
         """初始化进度记录"""
         try:
-            self.progress_data[callback_task_id] = {
+
+            # 设置总量为100(百分比模式)
+            stage_name = stages[0]["stage_name"] if stages else ""
+            message = "任务开始"
+
+            current_data = {
                 "user_id": user_id,
-                "overall_progress": 0,
-                "current_stage": stages[0]["stage_name"] if stages else "",
-                "stages": stages,
-                "updated_at": datetime.now()
+                "current": 0,
+                "stage_name": "",
+                "status": "准备开始",
+                "message": "任务开始",
+                "updated_at": datetime.now().isoformat(),
+                "overall_task_status": "pending"
             }
-            logger.info(f"初始化任务进度: {callback_task_id}")
+
+            if self.redis_connected:
+                # 使用同步Redis操作避免异步任务销毁问题
+                try:
+                    redis_key = await self._get_redis_key(callback_task_id)
+                    self.redis_client.setex(
+                        redis_key,
+                        3600,  # 1小时过期
+                        json.dumps(current_data)
+                    )
+                    logger.info(f"初始化任务进度列表")
+                except Exception as redis_e:
+                    logger.warning(f"初始化进度到Redis失败: {callback_task_id}, {redis_e}")
+                    # 降级到内存存储
+                    if not hasattr(self, 'current_data'):
+                        self.current_data = {}
+                    self.current_data[callback_task_id] = current_data
+                    logger.info(f"降级使用内存存储: {callback_task_id}")
+            else:
+                # 使用内存存储
+                if not hasattr(self, 'current_data'):
+                    self.current_data = {}
+                self.current_data[callback_task_id] = current_data
+                logger.info(f"初始化任务进度到内存: {callback_task_id}")
 
         except Exception as e:
             logger.error(f"初始化进度失败: {str(e)}")
             raise
 
-    async def update_stage_progress(self, callback_task_id: str, stage_name: str,
-                                  progress: int, status: str, message: str = "",
-                                  sub_progress: int = 0):
+    async def update_stage_progress(self, callback_task_id: str, stage_name: str, current: int, status: str, message: str = ""):
         """更新阶段进度"""
         try:
-            if callback_task_id not in self.progress_data:
-                logger.warning(f"任务进度不存在: {callback_task_id}")
-                return
+            task_progress = None
+
+            if self.redis_connected:
+                # 从Redis读取
+                redis_key = await self._get_redis_key(callback_task_id)
+                progress_json = self.redis_client.get(redis_key)
+                if progress_json:
+                    task_progress = json.loads(progress_json)
+                else:
+                    logger.warning(f"Redis中未找到任务进度: {callback_task_id}")
+                    return
+            else:
+                # 从内存读取
+                if callback_task_id in self.current_data:
+                    task_progress = self.current_data[callback_task_id]
+                else:
+                    logger.warning(f"内存中未找到任务进度: {callback_task_id}")
+                    return
 
-            task_progress = self.progress_data[callback_task_id]
+            # 更新进度数据
+            task_progress["current"] = current
+            task_progress["stage_name"] = stage_name
+            task_progress["status"] = status
+            task_progress["message"] = message
+            task_progress["updated_at"] = datetime.now().isoformat()
 
-            # 更新阶段进度
-            for stage in task_progress["stages"]:
-                if stage["stage_name"] == stage_name:
-                    stage["progress"] = progress
-                    stage["stage_status"] = status
-                    stage["message"] = message
-                    stage["sub_progress"] = sub_progress
-                    break
+            # 保留overall_task_status字段,不要被普通进度更新覆盖
+            if "overall_task_status" not in task_progress:
+                task_progress["overall_task_status"] = "processing"
 
-            # 更新当前阶段和整体进度
-            task_progress["current_stage"] = stage_name
-            task_progress["overall_progress"] = self._calculate_overall_progress(task_progress["stages"])
-            task_progress["updated_at"] = datetime.now()
+            try:
+                if self.redis_connected:
+                    try:
+                        self.redis_client.setex(
+                            redis_key,
+                            3600,  # 1小时过期
+                            json.dumps(task_progress)
+                        )
+                        logger.debug(f"更新进度到Redis: {callback_task_id}, 进度: {current}%")
+                    except Exception as sync_e:
+                        logger.warning(f"同步Redis操作失败: {callback_task_id}, {sync_e}")
+                        # 同步操作也失败时,降级到内存存储
+                        if not hasattr(self, 'current_data'):
+                            self.current_data = {}
+                        self.current_data[callback_task_id] = task_progress
+                        logger.debug(f"降级使用内存存储: {callback_task_id}")
+                else:
+                    if not hasattr(self, 'current_data'):
+                        self.current_data = {}
+                    self.current_data[callback_task_id] = task_progress
+                    logger.debug(f"更新进度到内存: {callback_task_id}, 进度: {current}%")
+            except Exception as e:
+                logger.error(f"保存进度数据异常: {callback_task_id}, {e}")
+                if not hasattr(self, 'current_data'):
+                    self.current_data = {}
+                self.current_data[callback_task_id] = task_progress
 
-            logger.debug(f"更新进度: {callback_task_id}, 阶段: {stage_name}, 进度: {progress}%")
+            # 触发SSE推送 - 使用全局回调管理器
+            logger.debug(f"触发SSE推送: {callback_task_id}")
+            updated_progress = await self.get_progress(callback_task_id)
+            if updated_progress:
+                await sse_callback_manager.trigger_callback(callback_task_id, updated_progress)
 
         except Exception as e:
             logger.error(f"更新阶段进度失败: {str(e)}")
@@ -65,61 +225,102 @@ class ProgressManager:
     async def get_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
         """获取任务进度"""
         try:
-            if callback_task_id not in self.progress_data:
-                return None
-
-            task_progress = self.progress_data[callback_task_id]
-
-            # 计算整体状态
-            if any(stage["stage_status"] == "failed" for stage in task_progress["stages"]):
-                review_task_status = "failed"
-            elif all(stage["stage_status"] == "completed" for stage in task_progress["stages"]):
-                review_task_status = "completed"
-            elif any(stage["stage_status"] == "processing" for stage in task_progress["stages"]):
-                review_task_status = "processing"
+            logger.debug(f"开始获取进度: {callback_task_id}, Redis连接状态: {self.redis_connected}")
+            task_progress = None
+
+            if self.redis_connected:
+                # 从Redis读取
+                redis_key = await self._get_redis_key(callback_task_id)
+                logger.debug(f"Redis键: {redis_key}")
+                progress_json = self.redis_client.get(redis_key)
+                logger.debug(f"从Redis读取数据: {progress_json is not None}")
+                if progress_json:
+                    task_progress = json.loads(progress_json)
+                else:
+                    logger.debug(f"Redis中未找到任务进度: {callback_task_id}")
+                    return None
             else:
-                review_task_status = "pending"
+                # 从内存读取
+                if hasattr(self, 'current_data') and callback_task_id in self.current_data:
+                    task_progress = self.current_data[callback_task_id]
+                else:
+                    logger.debug(f"内存中未找到任务进度: {callback_task_id}")
+                    return None
+
+            # 获取overall_task_status,默认为"pending"
+            overall_task_status = task_progress.get("overall_task_status", "pending")
+
+            # 转换时间戳
+            updated_at = task_progress["updated_at"]
+            if isinstance(updated_at, str):
+                updated_at_timestamp = int(datetime.fromisoformat(updated_at).timestamp())
+            else:
+                updated_at_timestamp = int(updated_at.timestamp())
 
             return {
                 "callback_task_id": callback_task_id,
                 "user_id": task_progress["user_id"],
-                "review_task_status": review_task_status,
-                "overall_progress": task_progress["overall_progress"],
-                "stages": task_progress["stages"],
-                "updated_at": int(task_progress["updated_at"].timestamp()),
-                "estimated_remaining": 600
+                "current": task_progress["current"],
+                "stage_name": task_progress["stage_name"],
+                "status": task_progress["status"],
+                "message": task_progress["message"],
+                "overall_task_status": overall_task_status,
+                "updated_at": updated_at_timestamp
             }
 
         except Exception as e:
             logger.error(f"获取进度失败: {str(e)}")
             return None
 
-    async def complete_task(self, callback_task_id: str, result: Dict[str, Any]):
+    async def complete_task(self, callback_task_id: str):
         """标记任务完成"""
         try:
-            if callback_task_id in self.progress_data:
-                task_progress = self.progress_data[callback_task_id]
-
-                # 完成最后一个阶段
-                if task_progress["stages"]:
-                    task_progress["stages"][-1]["stage_status"] = "completed"
-                    task_progress["stages"][-1]["progress"] = 100
+            task_progress = None
+            logger.info(f"通知sse连接关闭: {callback_task_id}")
+            if self.redis_connected:
+                redis_key = await self._get_redis_key(callback_task_id)
+                progress_json = self.redis_client.get(redis_key)
+                if progress_json:
+                    task_progress = json.loads(progress_json)
+                else:
+                    logger.warning(f"Redis中未找到任务进度: {callback_task_id}")
+                    return
+            else:
+                # 从内存读取
+                if hasattr(self, 'current_data') and callback_task_id in self.current_data:
+                    task_progress = self.current_data[callback_task_id]
+                else:
+                    logger.warning(f"内存中未找到任务进度: {callback_task_id}")
+                    return
 
-                task_progress["overall_progress"] = 100
-                task_progress["updated_at"] = datetime.now()
+            task_progress["status"] = "completed"
+            task_progress["overall_task_status"] = "completed"
+            task_progress["message"] = "任务已全部完成"
+            task_progress["updated_at"] = datetime.now().isoformat()
 
-                # 保存结果
-                task_progress["result"] = result
 
-            logger.info(f"任务完成: {callback_task_id}")
+            # 保存更新后的数据
+            if self.redis_connected:
+                self.redis_client.setex(
+                    redis_key,
+                    3600,
+                    json.dumps(task_progress)
+                )
+            else:
+                if hasattr(self, 'current_data'):
+                    self.current_data[callback_task_id] = task_progress
 
+            # 触发SSE进度更新推送
+            completed_progress = await self.get_progress(callback_task_id)
+            if completed_progress:
+                await sse_callback_manager.trigger_callback(callback_task_id, completed_progress)
+                logger.debug(f"SSE完成进度已推送: {callback_task_id}")
+            else:
+                logger.warning(f"无法获取完成进度数据: {callback_task_id}")
         except Exception as e:
             logger.error(f"标记任务完成失败: {str(e)}")
             raise
 
-    def _calculate_overall_progress(self, stages: list) -> int:
-        """计算整体进度"""
-        if not stages:
-            return 0
-        total_progress = sum(stage["progress"] for stage in stages)
-        return int(total_progress / len(stages))
+
+
+    

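Putting the new pieces together: an SSE endpoint registers an async callback for its `callback_task_id`, and every `update_stage_progress`/`complete_task` call re-reads the stored progress and pushes it through that callback. A minimal end-to-end sketch (the callback here just prints; a real SSE view would write an event to the response stream):

```python
import asyncio
from core.base.progress_manager import ProgressManager, sse_callback_manager

async def demo():
    pm = ProgressManager()  # falls back to in-memory storage if Redis is down
    task_id = "demo-task-1"

    async def push_to_client(callback_task_id: str, current_data: dict):
        # A real SSE view would serialize this into an event-stream frame.
        print(f"SSE event for {callback_task_id}: {current_data}")

    sse_callback_manager.register_callback(task_id, push_to_client)
    try:
        await pm.initialize_progress(task_id, user_id="u1", stages=[])
        await pm.update_stage_progress(task_id, "文档处理", 40, "processing", "parsing")
        await pm.complete_task(task_id)  # pushes the final "completed" event
    finally:
        sse_callback_manager.unregister_callback(task_id)

asyncio.run(demo())
```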
+ 45 - 13
core/base/workflow_manager.py

@@ -12,6 +12,7 @@ from dataclasses import dataclass
 from langgraph.graph import StateGraph, END
 from langgraph.graph.message import add_messages
 from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
+import json
 
 from foundation.logger.loggering import server_logger as logger
 from foundation.utils.time_statistics import track_execution_time
@@ -19,6 +20,28 @@ from .progress_manager import ProgressManager
 from .redis_duplicate_checker import RedisDuplicateChecker
 from ..construction_review.workflows import DocumentWorkflow,AIReviewWorkflow,ReportWorkflow
 
+class ProgressManagerRegistry:
+    """ProgressManager注册表 - 为每个任务管理独立的ProgressManager实例"""
+    _registry = {}  # {callback_task_id: ProgressManager}
+
+    @classmethod
+    def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager):
+        """注册ProgressManager实例"""
+        cls._registry[callback_task_id] = progress_manager
+        logger.info(f"注册ProgressManager实例: {callback_task_id}, ID: {id(progress_manager)}")
+
+    @classmethod
+    def get_progress_manager(cls, callback_task_id: str) -> ProgressManager:
+        """获取ProgressManager实例"""
+        return cls._registry.get(callback_task_id)
+
+    @classmethod
+    def unregister_progress_manager(cls, callback_task_id: str):
+        """注销ProgressManager实例"""
+        if callback_task_id in cls._registry:
+            del cls._registry[callback_task_id]
+            logger.info(f"注销ProgressManager实例: {callback_task_id}")
+
 @dataclass
 class TaskChain:
     """任务链"""
@@ -48,7 +71,7 @@ class WorkflowManager:
         self.review_semaphore = asyncio.Semaphore(max_concurrent_reviews)
 
         # 服务组件
-        self.progress_manager = ProgressManager()
+        self.progress_manager = ProgressManager()  # 简化:直接使用实例
         self.redis_duplicate_checker = RedisDuplicateChecker()
 
         # 活跃任务跟踪
@@ -58,12 +81,16 @@ class WorkflowManager:
     async def submit_task_processing(self, file_info: dict) -> str:
         """异步提交任务处理(用于file_upload层)"""
         from foundation.base.tasks import submit_task_processing_task
+        from foundation.trace.celery_trace import CeleryTraceManager
 
         try:
             logger.info(f"提交文档处理任务到Celery: {file_info['file_id']}")
 
-            # 提交到Celery队列
-            task = submit_task_processing_task.delay(file_info)
+            # 使用CeleryTraceManager提交任务,自动传递trace_id
+            task = CeleryTraceManager.submit_celery_task(
+                submit_task_processing_task,
+                file_info
+            )
 
             logger.info(f"Celery任务已提交,Task ID: {task.id}")
             return task.id
@@ -85,8 +112,8 @@ class WorkflowManager:
             # 2. 创建任务链
             task_chain = TaskChain(
                 callback_task_id=callback_task_id,
-                file_id=file_info['file_id'],
-                user_id=file_info['user_id'],
+                file_id=file_info.get('file_id', ''),
+                user_id=file_info.get('user_id', 'default_user'),
                 status="processing",
                 current_stage="document_processing",
                 created_at=datetime.now()
@@ -96,16 +123,11 @@ class WorkflowManager:
             asyncio.run(self.redis_duplicate_checker.register_task(file_info, callback_task_id))
             self.active_chains[callback_task_id] = task_chain
 
-            # 5. 初始化进度
+            # 5. 初始化进度管理
             asyncio.run(self.progress_manager.initialize_progress(
                 callback_task_id=callback_task_id,
-                user_id=file_info['user_id'],
-                stages=[
-                    {"stage_name": "文件上传", "progress": 100, "status": "completed"},
-                    {"stage_name": "文档处理", "progress": 0, "status": "pending"},
-                    {"stage_name": "AI审查", "progress": 0, "status": "pending"},
-                    {"stage_name": "报告生成", "progress": 0, "status": "pending"}
-                ]
+                user_id=file_info.get('user_id', 'default_user'),
+                stages=[]
             ))
 
             # 6. 启动处理流程(同步执行)
@@ -205,6 +227,8 @@ class WorkflowManager:
 
             # 清理任务注册
             asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
+            # 通知SSE连接任务完成
+            asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id))
 
             logger.info(f"文档处理任务链完成: {task_chain.callback_task_id}")
             return task_chain.results
@@ -216,6 +240,14 @@ class WorkflowManager:
             # 清理任务注册
             asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
 
+            # 通知SSE连接任务失败
+            error_result = {
+                "error": str(e),
+                "status": "failed",
+                "timestamp": datetime.now().isoformat()
+            }
+            asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id))
+
             raise
         finally:
             # 清理活跃任务

+ 1 - 0
core/construction_review/component/ai_review_engine.py

@@ -3,6 +3,7 @@ AI审查引擎
 负责执行AI审查,支持审查条目并发处理
 """
 
+import time
 import asyncio
 from enum import Enum
 from dataclasses import dataclass

+ 376 - 54
core/construction_review/component/document_processor.py

@@ -1,23 +1,43 @@
 """
 文档处理器
 负责文档解析、内容提取和结构化处理
+集成doc_worker模块的智能处理能力
 """
 
-import io   
-from docx import Document
+import io
+import os
+import tempfile
+from pathlib import Path
 from typing import Dict, Any, Optional, Callable
 from datetime import datetime
 
 from foundation.logger.loggering import server_logger as logger
 
-from langchain_community.document_loaders import PyPDFLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+# 引入doc_worker核心组件
+try:
+    from ..doc_worker import TOCExtractor, TextSplitter, LLMClassifier
+    from ..doc_worker.config_loader import get_config
+except ImportError:
+    from core.construction_review.doc_worker import TOCExtractor, TextSplitter, LLMClassifier
+    from core.construction_review.doc_worker.config_loader import get_config
 
 class DocumentProcessor:
     """文档处理器"""
 
     def __init__(self):
         self.supported_types = ['pdf', 'docx']
+        # 初始化doc_worker组件
+        self.toc_extractor = TOCExtractor()
+        self.text_splitter = TextSplitter()
+        self.config = get_config()
+        # LLM分类器可选,如果配置了模型URL则初始化
+        self.llm_classifier = None
+        try:
+            model_url = self.config.llm_model_url
+            if model_url:
+                self.llm_classifier = LLMClassifier(model_url)
+        except Exception as e:
+            logger.warning(f"LLM分类器初始化失败,将使用基础处理模式: {str(e)}")
 
     async def process_document(self, file_content: bytes, file_type: str,
                              progress_callback: Optional[Callable[[int, str], None]] = None) -> Dict[str, Any]:
@@ -56,16 +76,286 @@ class DocumentProcessor:
             raise
 
     async def parse_pdf_content(self, file_content: bytes) -> Dict[str, Any]:
-        """解析PDF内容"""
+        """解析PDF内容,使用doc_worker的智能处理能力"""
+        temp_file_path = None
         try:
             # 保存到临时文件
-            import tempfile
             with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_file:
                 temp_file.write(file_content)
                 temp_file_path = temp_file.name
 
-            # 使用PyPDFLoader解析
-            loader = PyPDFLoader(temp_file_path)
+            logger.info(f"开始使用doc_worker处理PDF文档: {temp_file_path}")
+
+            # 步骤1: 提取目录
+            logger.info("步骤1: 提取文档目录")
+            toc_info = self.toc_extractor.extract_toc(temp_file_path)
+            
+            if toc_info['toc_count'] == 0:
+                logger.warning("未检测到目录,使用基础处理模式")
+                return await self._fallback_pdf_processing(temp_file_path)
+
+            logger.info(f"成功提取 {toc_info['toc_count']} 个目录项")
+
+            # 步骤2: 使用LLM进行分类(如果可用)
+            classified_items = None
+            target_level = self.config.target_level
+            
+            if self.llm_classifier:
+                try:
+                    logger.info(f"步骤2: 使用LLM对{target_level}级目录进行分类")
+                    classification_result = self.llm_classifier.classify(
+                        toc_info['toc_items'],
+                        target_level=target_level
+                    )
+                    if classification_result:
+                        classified_items = classification_result['items']
+                        logger.info(f"分类完成,共分类 {len(classified_items)} 个目录项")
+                except Exception as e:
+                    logger.warning(f"LLM分类失败,使用目录项直接处理: {str(e)}")
+            
+            # 如果没有分类结果,使用原始目录项(筛选目标层级)
+            if not classified_items:
+                classified_items = [
+                    item for item in toc_info['toc_items'] 
+                    if item['level'] == target_level
+                ]
+                # 为每个目录项添加默认分类信息
+                for item in classified_items:
+                    item['category'] = '未分类'
+                    item['category_code'] = 'other'
+
+            # 步骤3: 提取文档全文
+            logger.info("步骤3: 提取文档全文")
+            pages_content = self.text_splitter.extract_full_text(temp_file_path)
+            
+            if not pages_content:
+                logger.warning("无法提取文档全文,使用基础处理模式")
+                return await self._fallback_pdf_processing(temp_file_path)
+
+            total_chars = sum(len(page['text']) for page in pages_content)
+            logger.info(f"提取完成,共 {len(pages_content)} 页,{total_chars} 个字符")
+
+            # 步骤4: 按分类标题智能切分文本
+            logger.info("步骤4: 按分类标题智能切分文本")
+            max_chunk_size = self.config.max_chunk_size
+            min_chunk_size = self.config.min_chunk_size
+            
+            chunks = self.text_splitter.split_by_hierarchy(
+                classified_items,
+                pages_content,
+                toc_info,
+                target_level=target_level,
+                max_chunk_size=max_chunk_size,
+                min_chunk_size=min_chunk_size
+            )
+
+            if not chunks:
+                logger.warning("未能生成任何文本块,使用基础处理模式")
+                return await self._fallback_pdf_processing(temp_file_path)
+
+            logger.info(f"切分完成,共生成 {len(chunks)} 个文本块")
+
+            # 适配返回格式
+            return {
+                'document_type': 'pdf',
+                'total_pages': len(pages_content),
+                'total_chunks': len(chunks),
+                'chunks': [
+                    {
+                        'page': chunk.get('element_tag', {}).get('page', 0),
+                        'content': chunk.get('review_chunk_content', ''),
+                        'metadata': {
+                            'chunk_id': chunk.get('chunk_id', ''),
+                            'section_label': chunk.get('section_label', ''),
+                            'project_plan_type': chunk.get('project_plan_type', ''),
+                            'element_tag': chunk.get('element_tag', {})
+                        }
+                    }
+                    for chunk in chunks
+                ],
+                'splits': [
+                    {
+                        'content': chunk.get('review_chunk_content', ''),
+                        'metadata': {
+                            'chunk_id': chunk.get('chunk_id', ''),
+                            'section_label': chunk.get('section_label', ''),
+                            'page': chunk.get('element_tag', {}).get('page', 0)
+                        }
+                    }
+                    for chunk in chunks
+                ],
+                'toc_info': toc_info,
+                'classification': {
+                    'items': classified_items,
+                    'target_level': target_level
+                } if classified_items else None
+            }
+
+        except Exception as e:
+            logger.error(f"PDF解析失败: {str(e)}")
+            # 如果智能处理失败,尝试基础处理
+            if temp_file_path and os.path.exists(temp_file_path):
+                try:
+                    logger.info("尝试使用基础处理模式")
+                    return await self._fallback_pdf_processing(temp_file_path)
+                except Exception as fallback_error:
+                    logger.error(f"基础处理模式也失败: {str(fallback_error)}")
+            raise
+        finally:
+            # 清理临时文件
+            if temp_file_path and os.path.exists(temp_file_path):
+                try:
+                    os.unlink(temp_file_path)
+                except Exception as e:
+                    logger.warning(f"清理临时文件失败: {str(e)}")
+
+    async def parse_docx_content(self, file_content: bytes) -> Dict[str, Any]:
+        """解析DOCX内容,使用doc_worker的智能处理能力"""
+        temp_file_path = None
+        try:
+            # 保存到临时文件
+            with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as temp_file:
+                temp_file.write(file_content)
+                temp_file_path = temp_file.name
+
+            logger.info(f"开始使用doc_worker处理DOCX文档: {temp_file_path}")
+
+            # 步骤1: 提取目录
+            logger.info("步骤1: 提取文档目录")
+            toc_info = self.toc_extractor.extract_toc(temp_file_path)
+            
+            if toc_info['toc_count'] == 0:
+                logger.warning("未检测到目录,使用基础处理模式")
+                return await self._fallback_docx_processing(temp_file_path)
+
+            logger.info(f"成功提取 {toc_info['toc_count']} 个目录项")
+
+            # 步骤2: 使用LLM进行分类(如果可用)
+            classified_items = None
+            target_level = self.config.target_level
+            
+            if self.llm_classifier:
+                try:
+                    logger.info(f"步骤2: 使用LLM对{target_level}级目录进行分类")
+                    classification_result = self.llm_classifier.classify(
+                        toc_info['toc_items'],
+                        target_level=target_level
+                    )
+                    if classification_result:
+                        classified_items = classification_result['items']
+                        logger.info(f"分类完成,共分类 {len(classified_items)} 个目录项")
+                except Exception as e:
+                    logger.warning(f"LLM分类失败,使用目录项直接处理: {str(e)}")
+            
+            # 如果没有分类结果,使用原始目录项(筛选目标层级)
+            if not classified_items:
+                classified_items = [
+                    item for item in toc_info['toc_items'] 
+                    if item['level'] == target_level
+                ]
+                # 为每个目录项添加默认分类信息
+                for item in classified_items:
+                    item['category'] = '未分类'
+                    item['category_code'] = 'other'
+
+            # 步骤3: 提取文档全文
+            logger.info("步骤3: 提取文档全文")
+            pages_content = self.text_splitter.extract_full_text(temp_file_path)
+            
+            if not pages_content:
+                logger.warning("无法提取文档全文,使用基础处理模式")
+                return await self._fallback_docx_processing(temp_file_path)
+
+            total_chars = sum(len(page['text']) for page in pages_content)
+            logger.info(f"提取完成,共 {len(pages_content)} 页,{total_chars} 个字符")
+
+            # 步骤4: 按分类标题智能切分文本
+            logger.info("步骤4: 按分类标题智能切分文本")
+            max_chunk_size = self.config.max_chunk_size
+            min_chunk_size = self.config.min_chunk_size
+            
+            chunks = self.text_splitter.split_by_hierarchy(
+                classified_items,
+                pages_content,
+                toc_info,
+                target_level=target_level,
+                max_chunk_size=max_chunk_size,
+                min_chunk_size=min_chunk_size
+            )
+
+            if not chunks:
+                logger.warning("未能生成任何文本块,使用基础处理模式")
+                return await self._fallback_docx_processing(temp_file_path)
+
+            logger.info(f"切分完成,共生成 {len(chunks)} 个文本块")
+
+            # 适配返回格式
+            return {
+                'document_type': 'docx',
+                'total_pages': len(pages_content),
+                'total_chunks': len(chunks),
+                'chunks': [
+                    {
+                        'page': chunk.get('element_tag', {}).get('page', 0),
+                        'content': chunk.get('review_chunk_content', ''),
+                        'metadata': {
+                            'chunk_id': chunk.get('chunk_id', ''),
+                            'section_label': chunk.get('section_label', ''),
+                            'project_plan_type': chunk.get('project_plan_type', ''),
+                            'element_tag': chunk.get('element_tag', {})
+                        }
+                    }
+                    for chunk in chunks
+                ],
+                'splits': [
+                    {
+                        'content': chunk.get('review_chunk_content', ''),
+                        'metadata': {
+                            'chunk_id': chunk.get('chunk_id', ''),
+                            'section_label': chunk.get('section_label', ''),
+                            'page': chunk.get('element_tag', {}).get('page', 0)
+                        }
+                    }
+                    for chunk in chunks
+                ],
+                'full_text': ''.join([page['text'] for page in pages_content]),
+                'toc_info': toc_info,
+                'classification': {
+                    'items': classified_items,
+                    'target_level': target_level
+                } if classified_items else None,
+                'metadata': {
+                    'total_pages': len(pages_content),
+                    'total_chars': total_chars
+                }
+            }
+
+        except Exception as e:
+            logger.error(f"DOCX解析失败: {str(e)}")
+            # 如果智能处理失败,尝试基础处理
+            if temp_file_path and os.path.exists(temp_file_path):
+                try:
+                    logger.info("尝试使用基础处理模式")
+                    return await self._fallback_docx_processing(temp_file_path)
+                except Exception as fallback_error:
+                    logger.error(f"基础处理模式也失败: {str(fallback_error)}")
+            raise
+        finally:
+            # 清理临时文件
+            if temp_file_path and os.path.exists(temp_file_path):
+                try:
+                    os.unlink(temp_file_path)
+                except Exception as e:
+                    logger.warning(f"清理临时文件失败: {str(e)}")
+
+    async def _fallback_pdf_processing(self, file_path: str) -> Dict[str, Any]:
+        """PDF基础处理模式(当智能处理失败时使用)"""
+        try:
+            from langchain_community.document_loaders import PyPDFLoader
+            from langchain.text_splitter import RecursiveCharacterTextSplitter
+            
+            logger.info("使用基础PDF处理模式")
+            loader = PyPDFLoader(file_path)
             documents = loader.load()
 
             # 文本分块
@@ -75,23 +365,21 @@ class DocumentProcessor:
                 separators=["\n\n", "\n", " ", ""]
             )
             splits = text_splitter.split_documents(documents)
-            original_count = len(splits)  # 记录原始分块数量
 
-            # 过滤空内容切块,确保每个切块内容不为空
+            # 过滤空内容切块
             valid_splits = []
             for split in splits:
                 content = split.page_content.strip()
-                if content:  # 确保内容不为空
-                    split.page_content = content  # 更新清理后的内容
+                if content:
+                    split.page_content = content
                     valid_splits.append(split)
 
-            splits = valid_splits  # 使用过滤后的切块
-            logger.info(f"PDF解析完成,过滤前分块数量: {original_count},过滤后有效分块数量: {len(splits)}")
+            logger.info(f"基础处理完成,有效分块数量: {len(valid_splits)}")
 
             return {
                 'document_type': 'pdf',
                 'total_pages': len(documents),
-                'total_chunks': len(splits),
+                'total_chunks': len(valid_splits),
                 'chunks': [
                     {
                         'page': doc.metadata.get('page', 0),
@@ -105,22 +393,20 @@ class DocumentProcessor:
                         'content': split.page_content,
                         'metadata': split.metadata
                     }
-                    for split in splits
+                    for split in valid_splits
                 ]
             }
-
         except Exception as e:
-            logger.error(f"PDF解析失败: {str(e)}")
+            logger.error(f"基础PDF处理失败: {str(e)}")
             raise
 
-    async def parse_docx_content(self, file_content: bytes) -> Dict[str, Any]:
-        """解析DOCX内容"""
+    async def _fallback_docx_processing(self, file_path: str) -> Dict[str, Any]:
+        """DOCX基础处理模式(当智能处理失败时使用)"""
         try:
-            # 简化实现:直接返回文本内容
-            # 实际实现中可以使用python-docx库
-
-
-            doc = Document(io.BytesIO(file_content))
+            from docx import Document
+            
+            logger.info("使用基础DOCX处理模式")
+            doc = Document(file_path)
             full_text = '\n'.join([paragraph.text for paragraph in doc.paragraphs])
 
             # 简单分块,并过滤空内容
@@ -129,7 +415,7 @@ class DocumentProcessor:
             chunk_index = 1
             for i in range(0, len(full_text), chunk_size):
                 chunk_text = full_text[i:i+chunk_size].strip()
-                if chunk_text:  # 确保切块内容不为空
+                if chunk_text:
                     chunks.append({
                         'chunk_id': f'chunk_{chunk_index}',
                         'content': chunk_text,
@@ -137,7 +423,7 @@ class DocumentProcessor:
                     })
                     chunk_index += 1
 
-            logger.info(f"DOCX解析完成,有效分块数量: {len(chunks)}")
+            logger.info(f"基础处理完成,有效分块数量: {len(chunks)}")
 
             return {
                 'document_type': 'docx',
@@ -149,48 +435,84 @@ class DocumentProcessor:
                     'word_count': len(full_text.split())
                 }
             }
-
         except Exception as e:
-            logger.error(f"DOCX解析失败: {str(e)}")
+            logger.error(f"基础DOCX处理失败: {str(e)}")
             raise
 
     def structure_content(self, raw_content: Dict[str, Any]) -> Dict[str, Any]:
-        """结构化处理"""
+        """结构化处理,适配doc_worker返回的格式"""
         try:
-            if raw_content['document_type'] == 'pdf':
-                # PDF结构化
+            document_type = raw_content.get('document_type', 'unknown')
+            
+            # 检查是否使用了doc_worker的智能处理(有toc_info或classification字段)
+            is_smart_processing = 'toc_info' in raw_content or 'classification' in raw_content
+            
+            if is_smart_processing:
+                # 使用doc_worker智能处理的结果
                 chunks = []
-                for i, chunk in enumerate(raw_content['chunks']):
-                    content = chunk['content'].strip()
-                    if content:  # 确保内容不为空
+                for chunk in raw_content.get('chunks', []):
+                    content = chunk.get('content', '').strip()
+                    if content:
+                        metadata = chunk.get('metadata', {})
+                        element_tag = metadata.get('element_tag', {})
+                        
                         chunks.append({
-                            'chunk_id': f'chunk_{i+1}',
-                            'page': chunk['page'],
+                            'chunk_id': metadata.get('chunk_id', ''),
+                            'page': chunk.get('page', 0),
                             'content': content,
-                            'chapter': f'第{chunk["page"]}页',
-                            'title': f'内容块{i+1}',
+                            'section_label': metadata.get('section_label', ''),
+                            'project_plan_type': metadata.get('project_plan_type', ''),
+                            'element_tag': element_tag,
+                            'chapter': metadata.get('section_label', f'第{chunk.get("page", 0)}页'),
+                            'title': metadata.get('section_label', ''),
                             'original_content': content[:100] + '...' if len(content) > 100 else content
                         })
             else:
-                # DOCX结构化 - 也进行空内容检查
-                all_chunks = raw_content.get('chunks', [])
-                chunks = []
-                for chunk in all_chunks:
-                    content = chunk.get('content', '').strip()
-                    if content:  # 确保内容不为空
-                        chunks.append({
-                            'chunk_id': chunk.get('chunk_id', f'chunk_{len(chunks)+1}'),
-                            'content': content,
-                            'metadata': chunk.get('metadata', {})
-                        })
+                # 使用基础处理的结果
+                if document_type == 'pdf':
+                    chunks = []
+                    for i, chunk in enumerate(raw_content.get('chunks', [])):
+                        content = chunk.get('content', '').strip() if isinstance(chunk, dict) else str(chunk).strip()
+                        if content:
+                            page = chunk.get('page', 0) if isinstance(chunk, dict) else 0
+                            chunks.append({
+                                'chunk_id': f'chunk_{i+1}',
+                                'page': page,
+                                'content': content,
+                                'chapter': f'第{page}页',
+                                'title': f'内容块{i+1}',
+                                'original_content': content[:100] + '...' if len(content) > 100 else content
+                            })
+                else:
+                    # DOCX基础处理
+                    all_chunks = raw_content.get('chunks', [])
+                    chunks = []
+                    for chunk in all_chunks:
+                        content = chunk.get('content', '').strip()
+                        if content:
+                            chunks.append({
+                                'chunk_id': chunk.get('chunk_id', f'chunk_{len(chunks)+1}'),
+                                'content': content,
+                                'metadata': chunk.get('metadata', {})
+                            })
 
-            return {
-                'document_name': f"施工方案文档_{raw_content.get('document_type', 'unknown')}",
-                'document_type': raw_content['document_type'],
-                'total_chunks': len(chunks),  # 使用实际的切块数量
+            # 构建返回结果
+            result = {
+                'document_name': f"施工方案文档_{document_type}",
+                'document_type': document_type,
+                'total_chunks': len(chunks),
                 'chunks': chunks,
                 'metadata': raw_content.get('metadata', {})
             }
+            
+            # 如果使用了智能处理,保留额外信息
+            if is_smart_processing:
+                if 'toc_info' in raw_content:
+                    result['toc_info'] = raw_content['toc_info']
+                if 'classification' in raw_content:
+                    result['classification'] = raw_content['classification']
+
+            return result
 
         except Exception as e:
             logger.error(f"内容结构化失败: {str(e)}")
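For reference, a hypothetical sketch of one normalized chunk as emitted by structure_content on the smart-processing (doc_worker) path. The field names follow the diff above; every value here is made up for illustration:

    chunk = {
        'chunk_id': 'doc_chunk_1.5_1',
        'page': 12,
        'content': '1.5 施工条件 ……',
        'section_label': '1.5 施工条件',
        'project_plan_type': 'plan',
        'element_tag': {'chunk_id': 'doc_chunk_1.5_1', 'page': 12, 'serial_number': '1.5'},
        'chapter': '1.5 施工条件',
        'title': '1.5 施工条件',
        'original_content': '1.5 施工条件 ……',
    }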

+ 2 - 2
core/construction_review/component/reviewers/base_reviewer.py

@@ -9,7 +9,7 @@ import time
 from abc import ABC
 from typing import Dict, Any, Optional
 from dataclasses import dataclass
-# from langfuse import obverse
+#from langfuse import obverse
 from foundation.agent.monitor.ai_trace_monitor import lf
 from foundation.agent.generate.model_generate import generate_model_client
 from core.construction_review.component.reviewers.utils.prompt_loader import prompt_loader
@@ -32,7 +32,7 @@ class BaseReviewer(ABC):
         self.model_client = generate_model_client
         self.prompt_loader = prompt_loader
     
-    # @obverse
+    #@obverse
     async def review(self, name: str, trace_id: str, reviewer_type: str, prompt_name: str, review_content: str,review_references: str = None) -> ReviewResult:
         """
         执行审查
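Only the signature of review() is visible in this hunk, so the following is a minimal, hypothetical call sketch; SafetyReviewer and all argument values are assumptions, not part of the commit:

    # hypothetical BaseReviewer subclass; must be awaited inside an async function
    reviewer = SafetyReviewer()
    result = await reviewer.review(
        name="安全审查",
        trace_id="trace-001",
        reviewer_type="safety",
        prompt_name="safety_check",
        review_content="...待审查的方案文本...",
        review_references=None,
    )
    print(result)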

+ 50 - 0
core/construction_review/doc_worker/__init__.py

@@ -0,0 +1,50 @@
+"""
+文档分类切分库
+支持PDF和Word文档的目录提取、智能分类和文本切分
+
+主要功能:
+1. 提取PDF/Word文档的目录结构
+2. 使用大语言模型对目录进行智能分类
+3. 按目录层级和字符数智能切分文本
+4. 保存分类结果到多种格式
+
+使用示例:
+    from core.construction_review.doc_worker import DocumentClassifier
+    
+    # 创建分类器实例
+    classifier = DocumentClassifier(
+        model_url="http://172.16.35.50:8000/v1/chat/completions"
+    )
+    
+    # 处理文档
+    result = classifier.process_document(
+        file_path="document.pdf",
+        target_level=2,
+        output_dir="./output"
+    )
+"""
+
+__version__ = "2.0.0"
+__author__ = "Your Name"
+
+try:
+    from .core import DocumentClassifier
+    from .toc_extractor import TOCExtractor
+    from .text_splitter import TextSplitter
+    from .llm_classifier import LLMClassifier
+    from .result_saver import ResultSaver
+except ImportError:
+    from core import DocumentClassifier
+    from toc_extractor import TOCExtractor
+    from text_splitter import TextSplitter
+    from llm_classifier import LLMClassifier
+    from result_saver import ResultSaver
+
+__all__ = [
+    'DocumentClassifier',
+    'TOCExtractor',
+    'TextSplitter',
+    'LLMClassifier',
+    'ResultSaver'
+]
+

+ 173 - 0
core/construction_review/doc_worker/config.yaml

@@ -0,0 +1,173 @@
+# 文档分类切分库配置文件
+
+# 大语言模型配置
+llm:
+  # 模型API地址
+  model_url: "http://172.16.35.50:8000/v1/chat/completions"
+  # 模型名称
+  model_name: "Qwen2.5-7B-Instruct"
+  # 温度参数(越低越确定)
+  temperature: 0.1
+  # 请求超时时间(秒)
+  timeout: 60
+
+# 文本切分配置
+text_splitting:
+  # 目标层级(默认按几级目录分类)
+  target_level: 1
+  # 最大分块字符数
+  max_chunk_size: 1000
+  # 最小分块字符数
+  min_chunk_size: 500
+  # 模糊匹配阈值(0-1)
+  fuzzy_threshold: 0.80
+
+# 目录提取配置
+toc_extraction:
+  # 最多读取的页数(目录通常在前几页)
+  max_pages: 15
+  # Word文档每页段落数(模拟分页)
+  paragraphs_per_page: 30
+
+# 分类类别配置
+categories:
+  # 中文名称到英文代码的映射
+  mapping:
+    编制依据: basis
+    工程概况: overview
+    施工计划: plan
+    施工工艺计算: technology
+    安全保证措施: safety
+    质量保证措施: quality
+    环境保证措施: environment
+    施工管理及作业人员配备与分工: management
+    验收要求: acceptance
+    其它资料: other
+  
+  # 类别描述(用于LLM分类提示词)
+  descriptions:
+    编制依据: "包括编制依据、编制说明、规范标准、设计文件、相关法律法规等内容"
+    工程概况: "包括项目概况、工程概况、项目背景、建设概况、工程特点等内容"
+    施工计划: "包括施工计划、施工进度计划、施工部署、施工准备、总体安排等内容"
+    施工工艺计算: "包括施工工艺、施工方法、工艺流程、技术方案、施工计算等内容"
+    安全保证措施: "包括安全保证措施、安全管理、安全施工、安全防护、安全生产等内容"
+    质量保证措施: "包括质量保证措施、质量管理、质量控制、质量检验、质量标准等内容"
+    环境保证措施: "包括环境保护措施、环保施工、水土保持、文明施工、环境管理等内容"
+    施工管理及作业人员配备与分工: "包括人员配置、组织机构、人员分工、劳动力安排、管理体系等内容"
+    验收要求: "包括验收标准、验收程序、验收要求、交工验收、竣工验收等内容"
+    其它资料: "其他说明等不属于以上任何类别的内容"
+
+# LLM分类提示词模板
+prompts:
+  classification: |
+    你是一个专业的工程文档分析助手。现在需要你对以下目录项进行分类。
+
+    【分类类别说明】
+    {category_descriptions}
+
+    【待分类的目录项】
+    {toc_items}
+
+    【任务要求】
+    1. 请仔细阅读每个目录项的标题
+    2. 根据标题的语义,将每个目录项分配到最合适的类别中
+    3. 每个目录项只能属于一个类别
+    4. 如果某个目录项不确定或不属于任何明确类别,请归类到"其它资料"
+
+    【输出格式】
+    请严格按照以下JSON格式输出,不要包含任何其他文字说明:
+    {{
+      "分类结果": [
+        {{
+          "序号": 1,
+          "标题": "目录项标题",
+          "类别": "所属类别名称"
+        }}
+      ]
+    }}
+
+    请开始分类:
+
+# 输出配置
+output:
+  # 默认输出目录名称
+  default_dir_name: "分类切分结果"
+  # 是否默认保存结果
+  save_results: true
+  # 文件名最大长度
+  max_filename_length: 200
+
+# 标题层级识别配置
+title_patterns:
+  # 一级标题模式
+  level1:
+    - '^【\d+】'
+    - '^第[一二三四五六七八九十\d]+章'
+    - '^第[一二三四五六七八九十\d]+部分'
+    - '^[一二三四五六七八九十]、'
+    - '^\d+、'
+    - '^第\d+条'
+  
+  # 二级标题模式
+  level2:
+    - '^第[一二三四五六七八九十\d]+节'
+    - '^[一二三四五六七八九十]+、'
+    - '^\(\d+\)'
+    - '^([一二三四五六七八九十\d]+)'
+    - '^〖\d+(?:\.\d+)*〗'
+  
+  # 三级标题模式
+  level3:
+    - '^\([一二三四五六七八九十]+\)'
+    - '^[①②③④⑤⑥⑦⑧⑨⑩]'
+
+# 编号格式配置
+numbering:
+  # 支持的编号格式
+  formats:
+    - '^【\d+】'
+    - '^第[一二三四五六七八九十\d]+[章节条款]'
+    - '^\d+[、..]'
+    - '^[一二三四五六七八九十]+[、..]'
+    - '^\d+\.\d+'
+    - '^\(\d+\)'
+    - '^([一二三四五六七八九十\d]+)'
+    - '^\([一二三四五六七八九十]+\)'
+    - '^[①②③④⑤⑥⑦⑧⑨⑩]'
+    - '^〖\d+(?:\.\d+)*〗'
+
+# 噪音过滤配置
+noise_filters:
+  # 噪音模式(用于过滤非目录内容)
+  patterns:
+    - '^\d{4}[-年]\d{1,2}[-月]\d{1,2}'
+    - '^[A-Za-z0-9\-]{20,}$'
+    - '^http[s]?://'
+    - '^第\s*\d+\s*页'
+    - '^共\s*\d+\s*页'
+    - '^[\d\s\-_.]+$'
+
+# 目录识别配置
+toc_detection:
+  # 目录行的正则模式
+  patterns:
+    - '^(第[一二三四五六七八九十\d]+[章节条款].+?)[.·]{2,}\s*(\d{1,4})\s*$'
+    - '^(〖\d+(?:\.\d+)*〗.+?)[.·]{2,}\s*(\d{1,4})\s*$'
+    - '^(\d+[、..]\s*.+?)[.·]{2,}\s*(\d{1,4})\s*$'
+    - '^([一二三四五六七八九十]+[、..]\s*.+?)[.·]{2,}\s*(\d{1,4})\s*$'
+    - '^(\d+(?:\.\d+)+\s*.+?)[.·]{2,}\s*(\d{1,4})\s*$'
+    - '^(.+?)[.·]{2,}\s*(\d{1,4})\s*$'
+  
+  # 标题长度限制
+  min_length: 3
+  max_length: 200
+
+# 日志配置
+logging:
+  # 日志级别(DEBUG, INFO, WARNING, ERROR)
+  level: INFO
+  # 日志格式
+  format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+  # 日志文件名
+  filename: 'doc_classifier.log'
+
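One detail worth noting in prompts.classification above: the JSON skeleton uses doubled braces ({{ and }}) so that str.format() leaves them as literal braces when the placeholders are filled. A minimal sketch with the template abridged and values hypothetical:

    template = '请严格按照以下JSON格式输出:\n{{"分类结果": [{{"序号": 1, "标题": "{toc_items}"}}]}}'
    print(template.format(toc_items="工程概况"))
    # -> 请严格按照以下JSON格式输出:
    #    {"分类结果": [{"序号": 1, "标题": "工程概况"}]}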

+ 194 - 0
core/construction_review/doc_worker/config_loader.py

@@ -0,0 +1,194 @@
+"""
+配置加载模块
+从config.yaml文件加载配置参数
+"""
+
+import yaml
+from pathlib import Path
+
+
+class Config:
+    """配置类,用于加载和访问配置参数"""
+    
+    _instance = None
+    _config = None
+    
+    def __new__(cls):
+        """单例模式"""
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+    
+    def __init__(self):
+        """初始化配置"""
+        if self._config is None:
+            self.load_config()
+    
+    def load_config(self, config_path=None):
+        """
+        加载配置文件
+        
+        参数:
+            config_path: 配置文件路径,默认为当前目录下的config.yaml
+        """
+        if config_path is None:
+            config_path = Path(__file__).parent / 'config.yaml'
+        else:
+            config_path = Path(config_path)
+        
+        if not config_path.exists():
+            raise FileNotFoundError(f"配置文件不存在: {config_path}")
+        
+        with open(config_path, 'r', encoding='utf-8') as f:
+            self._config = yaml.safe_load(f)
+    
+    def get(self, key_path, default=None):
+        """
+        获取配置值
+        
+        参数:
+            key_path: 配置键路径,用点号分隔,如 'llm.model_url'
+            default: 默认值
+            
+        返回:
+            配置值
+        """
+        keys = key_path.split('.')
+        value = self._config
+        
+        for key in keys:
+            if isinstance(value, dict) and key in value:
+                value = value[key]
+            else:
+                return default
+        
+        return value
+    
+    # LLM配置
+    @property
+    def llm_model_url(self):
+        return self.get('llm.model_url', 'http://172.16.35.50:8000/v1/chat/completions')
+    
+    @property
+    def llm_model_name(self):
+        return self.get('llm.model_name', 'Qwen2.5-7B-Instruct')
+    
+    @property
+    def llm_temperature(self):
+        return self.get('llm.temperature', 0.1)
+    
+    @property
+    def llm_timeout(self):
+        return self.get('llm.timeout', 60)
+    
+    # 文本切分配置
+    @property
+    def target_level(self):
+        return self.get('text_splitting.target_level', 2)
+    
+    @property
+    def max_chunk_size(self):
+        return self.get('text_splitting.max_chunk_size', 1000)
+    
+    @property
+    def min_chunk_size(self):
+        return self.get('text_splitting.min_chunk_size', 500)
+    
+    @property
+    def fuzzy_threshold(self):
+        return self.get('text_splitting.fuzzy_threshold', 0.80)
+    
+    # 目录提取配置
+    @property
+    def toc_max_pages(self):
+        return self.get('toc_extraction.max_pages', 15)
+    
+    @property
+    def paragraphs_per_page(self):
+        return self.get('toc_extraction.paragraphs_per_page', 30)
+    
+    # 分类配置
+    @property
+    def category_mapping(self):
+        return self.get('categories.mapping', {})
+    
+    @property
+    def category_descriptions(self):
+        return self.get('categories.descriptions', {})
+    
+    # 提示词配置
+    @property
+    def classification_prompt_template(self):
+        return self.get('prompts.classification', '')
+    
+    # 输出配置
+    @property
+    def default_output_dir(self):
+        return self.get('output.default_dir_name', '分类切分结果')
+    
+    @property
+    def save_results_default(self):
+        return self.get('output.save_results', True)
+    
+    @property
+    def max_filename_length(self):
+        return self.get('output.max_filename_length', 200)
+    
+    # 标题模式配置
+    @property
+    def level1_patterns(self):
+        return self.get('title_patterns.level1', [])
+    
+    @property
+    def level2_patterns(self):
+        return self.get('title_patterns.level2', [])
+    
+    @property
+    def level3_patterns(self):
+        return self.get('title_patterns.level3', [])
+    
+    # 编号格式配置
+    @property
+    def numbering_formats(self):
+        return self.get('numbering.formats', [])
+    
+    # 噪音过滤配置
+    @property
+    def noise_patterns(self):
+        return self.get('noise_filters.patterns', [])
+    
+    # 目录检测配置
+    @property
+    def toc_patterns(self):
+        return self.get('toc_detection.patterns', [])
+    
+    @property
+    def toc_min_length(self):
+        return self.get('toc_detection.min_length', 3)
+    
+    @property
+    def toc_max_length(self):
+        return self.get('toc_detection.max_length', 200)
+    
+    # 日志配置
+    @property
+    def log_level(self):
+        return self.get('logging.level', 'INFO')
+    
+    @property
+    def log_format(self):
+        return self.get('logging.format', '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    
+    @property
+    def log_filename(self):
+        return self.get('logging.filename', 'doc_classifier.log')
+
+
+# 全局配置实例
+config = Config()
+
+
+def get_config():
+    """获取全局配置实例"""
+    return config
+
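A quick usage sketch of the loader; the printed values simply echo the defaults defined in config.yaml above:

    from core.construction_review.doc_worker.config_loader import get_config

    cfg = get_config()
    print(cfg.llm_model_name)                        # Qwen2.5-7B-Instruct
    print(cfg.get('text_splitting.max_chunk_size'))  # 1000
    print(cfg.get('no.such.key', 'fallback'))        # fallback (missing keys return the default)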

+ 205 - 0
core/construction_review/doc_worker/core.py

@@ -0,0 +1,205 @@
+"""
+核心处理模块
+提供统一的文档处理接口
+"""
+
+from pathlib import Path
+from collections import Counter
+
+try:
+    from .toc_extractor import TOCExtractor
+    from .llm_classifier import LLMClassifier
+    from .text_splitter import TextSplitter
+    from .result_saver import ResultSaver
+    from .config_loader import get_config
+except ImportError:
+    from toc_extractor import TOCExtractor
+    from llm_classifier import LLMClassifier
+    from text_splitter import TextSplitter
+    from result_saver import ResultSaver
+    from config_loader import get_config
+
+
+class DocumentClassifier:
+    """
+    文档分类切分器
+    
+    支持PDF和Word文档的目录提取、分类和文本切分
+    """
+    
+    def __init__(self, model_url=None):
+        """
+        初始化文档分类器
+        
+        参数:
+            model_url: 大语言模型API地址(可选,默认从配置文件读取)
+        """
+        self.config = get_config()
+        self.model_url = model_url or self.config.llm_model_url
+        self.toc_extractor = TOCExtractor()
+        self.llm_classifier = LLMClassifier(model_url)
+        self.text_splitter = TextSplitter()
+        self.result_saver = ResultSaver()
+    
+    def process_document(self, file_path, target_level=None, output_dir=None, 
+                        max_chunk_size=None, min_chunk_size=None, save_results=None):
+        """
+        处理文档:提取目录、分类、切分文本块
+        
+        参数:
+            file_path: 文档文件路径(PDF或Word)
+            target_level: 要分类的目标层级(可选,默认从配置文件读取)
+            output_dir: 输出目录(可选,仅在save_results=True时使用)
+            max_chunk_size: 最大分块字符数(可选,默认从配置文件读取)
+            min_chunk_size: 最小分块字符数(可选,默认从配置文件读取)
+            save_results: 是否保存结果到文件(可选,默认从配置文件读取)
+            
+        返回:
+            dict: 处理结果,包含目录、分类和文本块信息
+        """
+        # 从配置文件读取默认值
+        if target_level is None:
+            target_level = self.config.target_level
+        if max_chunk_size is None:
+            max_chunk_size = self.config.max_chunk_size
+        if min_chunk_size is None:
+            min_chunk_size = self.config.min_chunk_size
+        if save_results is None:
+            save_results = self.config.save_results_default
+        file_path = Path(file_path)
+        
+        # 检查文件是否存在
+        if not file_path.exists():
+            raise FileNotFoundError(f"文件不存在: {file_path}")
+        
+        # 检查文件格式
+        file_ext = file_path.suffix.lower()
+        if file_ext not in ['.pdf', '.docx', '.doc']:
+            raise ValueError(f"不支持的文件格式: {file_ext}")
+        
+        print("=" * 100)
+        print("文档分类切分工具 v2.0")
+        print("=" * 100)
+        print(f"\n文件: {file_path}")
+        print(f"格式: {file_ext.upper()}")
+        print(f"目标层级: {target_level}级")
+        print(f"分块大小: {min_chunk_size}-{max_chunk_size}字符")
+        print(f"模型地址: {self.model_url}")
+        
+        # 设置输出目录
+        if output_dir is None:
+            output_dir = file_path.parent / self.config.default_output_dir
+        else:
+            output_dir = Path(output_dir)
+        
+        # ========== 步骤1: 提取目录 ==========
+        print("\n" + "=" * 100)
+        print("步骤1: 提取文档目录")
+        print("=" * 100)
+        
+        toc_info = self.toc_extractor.extract_toc(file_path)
+        
+        if toc_info['toc_count'] == 0:
+            raise ValueError("未在文档中检测到目录,无法继续处理")
+        
+        print(f"\n成功提取 {toc_info['toc_count']} 个目录项")
+        print(f"目录所在页: {', '.join(map(str, toc_info['toc_pages']))}")
+        
+        # 显示目录层级统计
+        level_counts = Counter([item['level'] for item in toc_info['toc_items']])
+        print("\n目录层级分布:")
+        for level in sorted(level_counts.keys()):
+            print(f"  {level}级: {level_counts[level]} 项")
+        
+        # ========== 步骤2: 调用模型进行分类 ==========
+        print("\n" + "=" * 100)
+        print("步骤2: 调用模型进行智能分类")
+        print("=" * 100)
+        
+        classification_result = self.llm_classifier.classify(
+            toc_info['toc_items'],
+            target_level=target_level
+        )
+        
+        if classification_result is None:
+            raise ValueError("分类失败,无法继续处理")
+        
+        # 显示分类统计
+        category_counts = Counter([item['category'] for item in classification_result['items']])
+        print(f"\n分类统计:")
+        for category, count in sorted(category_counts.items(), key=lambda x: x[1], reverse=True):
+            print(f"  {category}: {count} 项")
+        
+        # ========== 步骤3: 提取文档全文 ==========
+        print("\n" + "=" * 100)
+        print("步骤3: 提取文档全文")
+        print("=" * 100)
+        
+        pages_content = self.text_splitter.extract_full_text(file_path)
+        
+        if not pages_content:
+            raise ValueError("无法提取文档全文")
+        
+        total_chars = sum(len(page['text']) for page in pages_content)
+        print(f"\n提取完成,共 {len(pages_content)} 页,{total_chars} 个字符")
+        
+        # ========== 步骤4: 按分类标题切分文本 ==========
+        print("\n" + "=" * 100)
+        print("步骤4: 按分类标题智能切分文本")
+        print("=" * 100)
+        
+        chunks = self.text_splitter.split_by_hierarchy(
+            classification_result['items'],
+            pages_content,
+            toc_info,
+            target_level=target_level,
+            max_chunk_size=max_chunk_size,
+            min_chunk_size=min_chunk_size
+        )
+        
+        if not chunks:
+            raise ValueError("未能生成任何文本块")
+        
+        print(f"\n切分完成,共生成 {len(chunks)} 个文本块")
+        
+        # 显示前5个文本块的信息
+        print("\n文本块预览:")
+        for i, chunk in enumerate(chunks[:5], 1):
+            print(f"  [{i}] {chunk['section_label']} ({len(chunk['review_chunk_content'])} 字符)")
+        if len(chunks) > 5:
+            print(f"  ... 还有 {len(chunks) - 5} 个文本块")
+        
+        # ========== 步骤5: 保存结果(可选) ==========
+        saved_files = None
+        if save_results:
+            print("\n" + "=" * 100)
+            print("步骤5: 保存结果")
+            print("=" * 100)
+            
+            # 保存结果
+            saved_files = self.result_saver.save_all(
+                file_path, 
+                toc_info, 
+                classification_result, 
+                chunks, 
+                output_dir
+            )
+        
+        # ========== 完成 ==========
+        print("\n" + "=" * 100)
+        print("处理完成!")
+        print("=" * 100)
+        
+        if save_results:
+            print(f"\n结果已保存到: {output_dir}")
+        print(f"文本块总数: {len(chunks)}")
+        print(f"类别数量: {len(category_counts)}")
+        
+        return {
+            'toc_info': toc_info,
+            'classification': classification_result,
+            'chunks': chunks,
+            'saved_files': saved_files,
+            'output_dir': str(output_dir) if output_dir else None
+        }
+
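An end-to-end invocation sketch of the pipeline above; the file name is hypothetical, and save_results=False keeps the run free of filesystem side effects:

    from core.construction_review.doc_worker import DocumentClassifier

    classifier = DocumentClassifier()  # model_url defaults to config.yaml's llm.model_url
    result = classifier.process_document("桥梁上部结构施工方案.pdf",
                                         target_level=2, save_results=False)
    for chunk in result['chunks'][:3]:
        print(chunk['chunk_id'], chunk['section_label'],
              len(chunk['review_chunk_content']))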

+ 212 - 0
core/construction_review/doc_worker/llm_classifier.py

@@ -0,0 +1,212 @@
+"""
+大语言模型分类模块
+使用LLM对目录项进行智能分类
+"""
+
+import json
+import re
+import requests
+
+try:
+    from .config_loader import get_config
+except ImportError:
+    from config_loader import get_config
+
+
+class LLMClassifier:
+    """大语言模型分类器"""
+    
+    def __init__(self, model_url=None, model_name=None):
+        """
+        初始化分类器
+        
+        参数:
+            model_url: 模型API地址(可选,默认从配置文件读取)
+            model_name: 模型名称(可选,默认从配置文件读取)
+        """
+        self.config = get_config()
+        self.model_url = model_url or self.config.llm_model_url
+        self.model_name = model_name or self.config.llm_model_name
+        self.category_mapping = self.config.category_mapping
+    
+    def classify(self, toc_items, target_level=2):
+        """
+        对目录项进行智能分类
+        
+        参数:
+            toc_items: 目录项列表
+            target_level: 要分类的目标层级
+            
+        返回:
+            dict: 分类结果
+        """
+        print(f"\n正在对{target_level}级目录进行智能分类...")
+        
+        # 构建提示词
+        prompt_result = self._build_prompt(toc_items, target_level)
+        if prompt_result is None:
+            print(f"  警告: 未找到{target_level}级目录项")
+            return None
+        
+        prompt, filtered_items = prompt_result
+        
+        print(f"  找到 {len(filtered_items)} 个{target_level}级目录项")
+        print("  正在调用模型进行分类...")
+        
+        # 调用模型
+        llm_response = self._call_api(prompt)
+        
+        if llm_response is None:
+            print("  错误: 模型调用失败")
+            return None
+        
+        print("  模型调用成功,正在解析结果...")
+        
+        # 解析结果
+        classification = self._parse_result(llm_response)
+        
+        if classification is None:
+            print("  错误: 结果解析失败")
+            print(f"  模型原始返回:\n{llm_response[:500]}...")
+            return None
+        
+        if "分类结果" not in classification:
+            print(f"  警告: 解析结果中没有'分类结果'字段")
+            print(f"  模型原始返回:\n{llm_response[:500]}...")
+            return None
+        
+        # 整合分类结果到原始目录项
+        classified_items = []
+        classification_map = {}
+        
+        if "分类结果" in classification:
+            for item in classification["分类结果"]:
+                title = item.get("标题", "")
+                category = item.get("类别", "其他")
+                classification_map[title] = category
+        
+        for item in filtered_items:
+            title = item['title']
+            
+            # 尝试直接匹配
+            category_cn = classification_map.get(title, None)
+            
+            # 如果直接匹配失败,尝试去掉编号后匹配
+            if category_cn is None:
+                # 去掉开头的编号(如 "1 ", "1. ", "第一章 " 等)
+                title_without_number = re.sub(r'^[\d一二三四五六七八九十]+[、\.\s]+', '', title)
+                title_without_number = re.sub(r'^第[一二三四五六七八九十\d]+[章节条款]\s*', '', title_without_number)
+                category_cn = classification_map.get(title_without_number, None)
+            
+            # 如果还是没找到,尝试模糊匹配
+            if category_cn is None:
+                for map_title, map_category in classification_map.items():
+                    if map_title in title or title in map_title:
+                        category_cn = map_category
+                        break
+            
+            # 最后的默认值
+            if category_cn is None:
+                category_cn = "未分类"
+            
+            category_en = self.category_mapping.get(category_cn, "other")
+            
+            classified_items.append({
+                'title': title,
+                'page': item['page'],
+                'level': item['level'],
+                'category': category_cn,
+                'category_code': category_en,
+                'original': item.get('original', '')
+            })
+        
+        print(f"  分类完成!共分类 {len(classified_items)} 个目录项")
+        
+        return {
+            'items': classified_items,
+            'total_count': len(classified_items),
+            'target_level': target_level
+        }
+    
+    def _build_prompt(self, toc_items, target_level=2):
+        """构建目录分类的提示词"""
+        # 从配置文件读取分类类别描述
+        categories = self.config.category_descriptions
+        
+        # 筛选出指定层级的目录项
+        filtered_items = [item for item in toc_items if item['level'] == target_level]
+        
+        if not filtered_items:
+            return None
+        
+        # 构建目录项列表字符串
+        toc_list_str = "\n".join([f"{i+1}. {item['title']}" for i, item in enumerate(filtered_items)])
+        
+        # 构建分类说明字符串
+        category_desc = "\n".join([f"- {cat}: {desc}" for cat, desc in categories.items()])
+        
+        # 从配置文件读取提示词模板
+        prompt_template = self.config.classification_prompt_template
+        
+        # 替换模板中的占位符
+        prompt = prompt_template.format(
+            category_descriptions=category_desc,
+            toc_items=toc_list_str
+        )
+        
+        return prompt, filtered_items
+    
+    def _call_api(self, prompt, temperature=None):
+        """调用大语言模型API进行目录分类"""
+        if temperature is None:
+            temperature = self.config.llm_temperature
+        
+        try:
+            headers = {
+                "Content-Type": "application/json"
+            }
+            
+            data = {
+                "model": self.model_name,
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": prompt
+                    }
+                ],
+                "stream": False,
+                "temperature": temperature
+            }
+            
+            timeout = self.config.llm_timeout
+            response = requests.post(self.model_url, headers=headers, json=data, timeout=timeout)
+            response.raise_for_status()
+            
+            result = response.json()
+            content = result.get("choices", [{}])[0].get("message", {}).get("content", "")
+            
+            return content
+        
+        except requests.exceptions.RequestException as e:
+            print(f"  错误: 调用模型API失败 - {str(e)}")
+            return None
+        except Exception as e:
+            print(f"  错误: 解析模型返回结果失败 - {str(e)}")
+            return None
+    
+    def _parse_result(self, llm_response):
+        """解析模型返回的分类结果"""
+        try:
+            # 尝试提取JSON部分
+            json_match = re.search(r'\{[\s\S]*\}', llm_response)
+            if json_match:
+                json_str = json_match.group(0)
+                result = json.loads(json_str)
+                return result
+            else:
+                print("  警告: 无法从模型返回中提取JSON格式")
+                return None
+        except json.JSONDecodeError as e:
+            print(f"  错误: 解析JSON失败 - {str(e)}")
+            return None
+
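The greedy regex in _parse_result tolerates prose around the model's JSON by grabbing the outermost brace-delimited span. A standalone sketch of the same extraction (sample response hypothetical):

    import json
    import re

    def extract_json(llm_response):
        # grab the outermost {...} span, ignoring any surrounding prose
        m = re.search(r'\{[\s\S]*\}', llm_response)
        return json.loads(m.group(0)) if m else None

    raw = '好的,分类如下:\n{"分类结果": [{"序号": 1, "标题": "工程概况", "类别": "工程概况"}]}'
    print(extract_json(raw)["分类结果"][0]["类别"])  # 工程概况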

+ 294 - 0
core/construction_review/doc_worker/result_saver.py

@@ -0,0 +1,294 @@
+"""
+结果保存模块
+保存分类和切分结果到多种格式
+"""
+
+import json
+from pathlib import Path
+from datetime import datetime
+from collections import defaultdict, Counter
+
+try:
+    from .config_loader import get_config
+except ImportError:
+    from config_loader import get_config
+
+
+class ResultSaver:
+    """结果保存器"""
+    
+    def __init__(self):
+        self.config = get_config()
+    
+    def save_all(self, file_path, toc_info, classification_result, chunks, output_dir):
+        """
+        保存所有结果
+        
+        参数:
+            file_path: 源文件路径
+            toc_info: 目录信息
+            classification_result: 分类结果
+            chunks: 文本块列表
+            output_dir: 输出目录
+            
+        返回:
+            dict: 保存的文件路径
+        """
+        output_path = Path(output_dir)
+        output_path.mkdir(parents=True, exist_ok=True)
+        
+        saved_files = {}
+        
+        # 保存完整JSON
+        json_file = self._save_json(file_path, toc_info, classification_result, chunks, output_dir)
+        saved_files['json'] = json_file
+        
+        # 按类别保存文本块
+        print("\n按类别保存文本块:")
+        category_files = self._save_by_category(chunks, file_path, output_dir)
+        saved_files['category_files'] = category_files
+        
+        # 创建索引
+        index_file = self._create_index(chunks, file_path, output_dir)
+        saved_files['index'] = index_file
+        
+        # 保存统计报告
+        report_file = self._save_report(file_path, toc_info, classification_result, chunks, output_dir)
+        saved_files['report'] = report_file
+        
+        return saved_files
+    
+    def _save_json(self, file_path, toc_info, classification_result, chunks, output_dir):
+        """保存完整的分类和切分结果到JSON"""
+        output_path = Path(output_dir)
+        output_path.mkdir(parents=True, exist_ok=True)
+        
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        file_name = Path(file_path).stem
+        
+        json_file = output_path / f"{file_name}_完整结果_{timestamp}.json"
+        
+        output_data = {
+            'source_file': str(file_path),
+            'process_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+            'toc_summary': {
+                'total_items': toc_info['toc_count'],
+                'toc_pages': toc_info['toc_pages']
+            },
+            'classification': classification_result,
+            'chunks': chunks
+        }
+        
+        with open(json_file, 'w', encoding='utf-8') as f:
+            json.dump(output_data, f, ensure_ascii=False, indent=2)
+        
+        print(f"已保存完整结果JSON: {json_file}")
+        return str(json_file)
+    
+    def _save_by_category(self, chunks, file_path, output_dir):
+        """按类别保存文本块到独立的Markdown文件"""
+        output_path = Path(output_dir)
+        file_name = Path(file_path).stem
+        
+        # 按类别分组
+        category_groups = defaultdict(list)
+        for chunk in chunks:
+            category = chunk['project_plan_type']
+            category_groups[category].append(chunk)
+        
+        saved_files = {}
+        
+        # 为每个类别创建子文件夹并保存文件
+        for category, category_chunks in category_groups.items():
+            category_dir = output_path / self._sanitize_filename(category)
+            category_dir.mkdir(parents=True, exist_ok=True)
+            
+            category_files = []
+            
+            # 为每个文本块创建一个MD文件
+            for i, chunk in enumerate(category_chunks, 1):
+                section_label = chunk['section_label']
+                safe_label = self._sanitize_filename(section_label)
+                
+                md_filename = f"{i:03d}_{safe_label}.md"
+                md_file = category_dir / md_filename
+                
+                with open(md_file, 'w', encoding='utf-8') as f:
+                    f.write(f"# {section_label}\n\n")
+                    f.write(f"**类别**: {category}\n\n")
+                    f.write(f"**来源文件**: {chunk['file_name']}\n\n")
+                    f.write(f"**页码**: {chunk['element_tag']['page']}\n\n")
+                    f.write(f"**块ID**: {chunk['chunk_id']}\n\n")
+                    f.write(f"**字符数**: {len(chunk['review_chunk_content'])}\n\n")
+                    f.write("---\n\n")
+                    f.write(chunk['review_chunk_content'])
+                    
+                    if not chunk['review_chunk_content'].endswith('\n'):
+                        f.write('\n')
+                
+                category_files.append(str(md_file))
+            
+            saved_files[category] = category_files
+            print(f"  [{category}] 保存了 {len(category_files)} 个文件到: {category_dir}")
+        
+        return saved_files
+    
+    def _create_index(self, chunks, file_path, output_dir):
+        """创建按类别分组的索引文件"""
+        output_path = Path(output_dir)
+        output_path.mkdir(parents=True, exist_ok=True)
+        
+        file_name = Path(file_path).stem
+        index_file = output_path / "README.md"
+        
+        # 按类别分组
+        category_groups = defaultdict(list)
+        for chunk in chunks:
+            category_groups[chunk['project_plan_type']].append(chunk)
+        
+        with open(index_file, 'w', encoding='utf-8') as f:
+            f.write(f"# {file_name} - 分类切分结果索引\n\n")
+            f.write(f"**来源文件**: {Path(file_path).name}\n\n")
+            f.write(f"**处理时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
+            f.write(f"**文本块总数**: {len(chunks)}\n\n")
+            f.write(f"**类别数量**: {len(category_groups)}\n\n")
+            
+            # 统计信息
+            total_chars = sum(len(chunk['review_chunk_content']) for chunk in chunks)
+            f.write(f"**总字符数**: {total_chars}\n\n")
+            
+            f.write("---\n\n")
+            f.write("## 分类统计\n\n")
+            
+            # 按类别统计
+            for category, category_chunks in sorted(category_groups.items()):
+                category_chars = sum(len(chunk['review_chunk_content']) for chunk in category_chunks)
+                f.write(f"- **{category}**: {len(category_chunks)} 个文本块, {category_chars} 字符\n")
+            
+            f.write("\n---\n\n")
+            f.write("## 详细目录\n\n")
+            
+            # 按类别输出详细目录
+            for category, category_chunks in sorted(category_groups.items()):
+                f.write(f"### {category}\n\n")
+                
+                for i, chunk in enumerate(category_chunks, 1):
+                    section_label = chunk['section_label']
+                    safe_label = self._sanitize_filename(section_label)
+                    category_safe = self._sanitize_filename(category)
+                    md_filename = f"{i:03d}_{safe_label}.md"
+                    
+                    char_count = len(chunk['review_chunk_content'])
+                    page = chunk['element_tag']['page']
+                    
+                    f.write(f"{i}. [{section_label}]({category_safe}/{md_filename}) - 页码: {page}, 字符数: {char_count}\n")
+                
+                f.write("\n")
+        
+        print(f"已保存索引文件: {index_file}")
+        return str(index_file)
+    
+    def _save_report(self, file_path, toc_info, classification_result, chunks, output_dir):
+        """保存详细的统计报告"""
+        output_path = Path(output_dir)
+        output_path.mkdir(parents=True, exist_ok=True)
+        
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        file_name = Path(file_path).stem
+        
+        report_file = output_path / f"{file_name}_统计报告_{timestamp}.txt"
+        
+        with open(report_file, 'w', encoding='utf-8') as f:
+            f.write("=" * 100 + "\n")
+            f.write("文档分类切分统计报告\n")
+            f.write("=" * 100 + "\n\n")
+            
+            f.write(f"源文件: {Path(file_path).name}\n")
+            f.write(f"处理时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
+            
+            # 目录统计
+            f.write("=" * 100 + "\n")
+            f.write("目录提取统计\n")
+            f.write("=" * 100 + "\n\n")
+            f.write(f"目录项总数: {toc_info['toc_count']}\n")
+            f.write(f"目录所在页: {', '.join(map(str, toc_info['toc_pages']))}\n\n")
+            
+            # 层级统计
+            level_counts = Counter([item['level'] for item in toc_info['toc_items']])
+            f.write("目录层级分布:\n")
+            for level in sorted(level_counts.keys()):
+                f.write(f"  {level}级: {level_counts[level]} 项\n")
+            f.write("\n")
+            
+            # 分类统计
+            if classification_result:
+                f.write("=" * 100 + "\n")
+                f.write("分类统计\n")
+                f.write("=" * 100 + "\n\n")
+                
+                category_counts = Counter([item['category'] for item in classification_result['items']])
+                f.write(f"已分类项数: {classification_result['total_count']}\n")
+                f.write(f"分类数量: {len(category_counts)}\n\n")
+                
+                f.write("各类别统计:\n")
+                for category, count in sorted(category_counts.items(), key=lambda x: x[1], reverse=True):
+                    f.write(f"  {category}: {count} 项\n")
+                f.write("\n")
+            
+            # 文本块统计
+            f.write("=" * 100 + "\n")
+            f.write("文本块切分统计\n")
+            f.write("=" * 100 + "\n\n")
+            
+            f.write(f"文本块总数: {len(chunks)}\n\n")
+            
+            total_chars = sum(len(chunk['review_chunk_content']) for chunk in chunks)
+            avg_chars = total_chars / len(chunks) if chunks else 0
+            
+            f.write(f"总字符数: {total_chars}\n")
+            f.write(f"平均每块字符数: {avg_chars:.1f}\n\n")
+            
+            # 按类别统计文本块
+            category_groups = defaultdict(list)
+            for chunk in chunks:
+                category_groups[chunk['project_plan_type']].append(chunk)
+            
+            f.write("按类别统计:\n")
+            for category, category_chunks in sorted(category_groups.items()):
+                category_chars = sum(len(chunk['review_chunk_content']) for chunk in category_chunks)
+                f.write(f"  {category}: {len(category_chunks)} 块, {category_chars} 字符\n")
+            f.write("\n")
+            
+            # 详细列表
+            f.write("=" * 100 + "\n")
+            f.write("文本块详细列表\n")
+            f.write("=" * 100 + "\n\n")
+            
+            for category, category_chunks in sorted(category_groups.items()):
+                f.write(f"\n【{category}】\n")
+                f.write("-" * 100 + "\n")
+                
+                for i, chunk in enumerate(category_chunks, 1):
+                    char_count = len(chunk['review_chunk_content'])
+                    page = chunk['element_tag']['page']
+                    f.write(f"  [{i}] {chunk['section_label']}\n")
+                    f.write(f"      页码: {page}, 字符数: {char_count}, 块ID: {chunk['chunk_id']}\n")
+        
+        print(f"已保存统计报告: {report_file}")
+        return str(report_file)
+    
+    def _sanitize_filename(self, filename):
+        """清理文件名,移除或替换不合法字符"""
+        invalid_chars = r'<>:"/\|?*'
+        for char in invalid_chars:
+            filename = filename.replace(char, '_')
+        
+        filename = filename.strip()
+        
+        # 从配置读取最大文件名长度
+        max_length = self.config.max_filename_length
+        if len(filename) > max_length:
+            filename = filename[:max_length]
+        
+        return filename
+
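A small check of the filename sanitizer (input value hypothetical): ASCII-reserved characters are replaced with underscores while full-width punctuation is kept.

    saver = ResultSaver()
    print(saver._sanitize_filename('1.2 桥梁/上部结构:测试?'))
    # -> 1.2 桥梁_上部结构:测试_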

+ 814 - 0
core/construction_review/doc_worker/text_splitter.py

@@ -0,0 +1,814 @@
+"""
+文本切分模块
+实现按目录层级和字符数的智能切分逻辑
+"""
+
+import re
+from pathlib import Path
+from difflib import SequenceMatcher
+import fitz  # PyMuPDF
+from docx import Document
+
+try:
+    from .config_loader import get_config
+except ImportError:
+    from config_loader import get_config
+
+
+class TextSplitter:
+    """文本切分器,支持PDF和Word格式"""
+    
+    def __init__(self):
+        self.config = get_config()
+    
+    def extract_full_text(self, file_path):
+        """
+        提取文档的全文内容
+        
+        参数:
+            file_path: 文档路径(PDF或Word)
+            
+        返回:
+            list: 每页的文本内容
+        """
+        file_path = Path(file_path)
+        file_ext = file_path.suffix.lower()
+        
+        if file_ext == '.pdf':
+            return self._extract_from_pdf(file_path)
+        elif file_ext in ['.docx', '.doc']:
+            return self._extract_from_word(file_path)
+        else:
+            raise ValueError(f"不支持的文件格式: {file_ext}")
+    
+    def _extract_from_pdf(self, pdf_path):
+        """提取PDF的全文内容"""
+        try:
+            doc = fitz.open(pdf_path)
+            pages_content = []
+            current_pos = 0
+            
+            for page_num in range(len(doc)):
+                page = doc[page_num]
+                text = page.get_text()
+                
+                pages_content.append({
+                    'page_num': page_num + 1,
+                    'text': text,
+                    'start_pos': current_pos,
+                    'end_pos': current_pos + len(text),
+                    'source_file': str(pdf_path)
+                })
+                
+                current_pos += len(text)
+            
+            doc.close()
+            return pages_content
+        except Exception as e:
+            print(f"  错误: 无法读取PDF全文 - {str(e)}")
+            return []
+    
+    def _extract_from_word(self, word_path):
+        """提取Word的全文内容(包括段落和表格)"""
+        try:
+            doc = Document(word_path)
+            pages_content = []
+            current_pos = 0
+            
+            # 提取所有内容(段落和表格按文档顺序)
+            all_content = []
+            
+            # 遍历文档的所有元素(段落和表格)
+            for element in doc.element.body:
+                # 检查是段落还是表格
+                if element.tag.endswith('p'):  # 段落
+                    for para in doc.paragraphs:
+                        if para._element == element:
+                            text = para.text
+                            if text.strip():
+                                all_content.append(text)
+                            break
+                elif element.tag.endswith('tbl'):  # 表格
+                    for table in doc.tables:
+                        if table._element == element:
+                            table_text = self._extract_table_text(table)
+                            all_content.append(table_text)
+                            break
+            
+            # 模拟分页:每 paragraphs_per_page 个元素作为一"页"(从配置读取,默认30)
+            elements_per_page = self.config.paragraphs_per_page
+            for page_num in range(0, len(all_content), elements_per_page):
+                page_elements = all_content[page_num:page_num + elements_per_page]
+                page_text = '\n'.join(page_elements)
+                
+                pages_content.append({
+                    'page_num': page_num // elements_per_page + 1,
+                    'text': page_text,
+                    'start_pos': current_pos,
+                    'end_pos': current_pos + len(page_text),
+                    'source_file': str(word_path)
+                })
+                
+                current_pos += len(page_text)
+            
+            return pages_content
+        except Exception as e:
+            print(f"  错误: 无法读取Word全文 - {str(e)}")
+            return []
+    
+    def _extract_table_text(self, table):
+        """提取表格内容为文本格式"""
+        table_text = []
+        for row in table.rows:
+            row_text = []
+            for cell in row.cells:
+                cell_text = cell.text.strip().replace('\n', ' ')
+                row_text.append(cell_text)
+            table_text.append('\t'.join(row_text))
+        
+        return '\n[表格开始]\n' + '\n'.join(table_text) + '\n[表格结束]\n'
+    
+    def split_by_hierarchy(self, classified_items, pages_content, toc_info, 
+                          target_level=2, max_chunk_size=1000, min_chunk_size=500):
+        """
+        按目录层级和字符数智能切分文本
+        
+        新的分块逻辑:
+        1. 按目录项定位到指定层级的正文标题
+        2. 在指定层级正文标题所属的正文块中,先按目录项的最低层级子标题进行分块
+        3. 然后逐个判断字符数:
+           - 超过max_chunk_size的进行分割(句子级,保持语义完整,分割的块不参与合并)
+           - 不足min_chunk_size的块进行合并(合并后不能超过max_chunk_size,否则不合并)
+        
+        参数:
+            classified_items: 已分类的目录项列表
+            pages_content: 文档全文内容(按页)
+            toc_info: 目录信息
+            target_level: 目标层级
+            max_chunk_size: 最大分块字符数
+            min_chunk_size: 最小分块字符数
+            
+        返回:
+            list: 带分类信息的文本块列表
+        """
+        full_text = ''.join([page['text'] for page in pages_content])
+        
+        print(f"  正在定位{len(classified_items)}个已分类的标题...")
+        print(f"  目录所在页: {toc_info['toc_pages']}")
+        
+        # 步骤1: 在正文中定位已分类的标题(跳过目录页)
+        located_titles = self._find_title_positions(
+            classified_items, 
+            full_text, 
+            pages_content, 
+            toc_info['toc_pages']
+        )
+        
+        # 只保留成功定位的标题
+        found_titles = [t for t in located_titles if t['found']]
+        
+        if not found_titles:
+            print(f"  错误: 未能在正文中定位任何标题")
+            return []
+        
+        print(f"  成功定位 {len(found_titles)}/{len(classified_items)} 个标题")
+        
+        # 按位置排序
+        found_titles.sort(key=lambda x: x['position'])
+        
+        # 步骤2: 提取所有层级的目录项,用于在正文块中查找子标题
+        all_toc_items = toc_info['toc_items']
+        
+        # 步骤3: 对每个目标层级的标题,提取其正文块并进行智能切分
+        all_chunks = []
+        
+        for i, title_info in enumerate(found_titles):
+            start_pos = title_info['position']
+            
+            # 确定正文块的结束位置(下一个同级标题的位置)
+            if i + 1 < len(found_titles):
+                end_pos = found_titles[i + 1]['position']
+            else:
+                end_pos = len(full_text)
+            
+            # 提取正文块
+            content_block = full_text[start_pos:end_pos]
+            
+            # 在正文块中查找子标题(比目标层级更低的层级)
+            sub_chunks = self._split_by_sub_titles(
+                content_block,
+                all_toc_items,
+                title_info,
+                target_level,
+                max_chunk_size,
+                min_chunk_size
+            )
+            
+            # 为每个子块添加元数据
+            for j, sub_chunk in enumerate(sub_chunks, 1):
+                # 计算实际页码
+                chunk_start_pos = start_pos + sub_chunk['relative_start']
+                page_num = self._get_page_number(chunk_start_pos, pages_content)
+                
+                # 构建section_label(层级路径)
+                section_label = self._build_section_label(
+                    title_info['title'],
+                    sub_chunk.get('sub_title', '')
+                )
+                
+                # 提取最低层级标题的编号
+                sub_title = sub_chunk.get('sub_title', '')
+                if sub_title:
+                    title_number = self._extract_title_number(sub_title)
+                else:
+                    # 如果没有子标题,从父标题提取
+                    title_number = self._extract_title_number(title_info['title'])
+                
+                # 构建chunk_id格式:doc_chunk_<serial_number>_<序号>
+                # 序号从1开始(如果合并了会从0开始)
+                chunk_id_str = f"doc_chunk_{title_number}_{j}" if title_number else f"doc_chunk_{j}"
+                
+                all_chunks.append({
+                    'file_name': Path(pages_content[0].get('source_file', 'unknown')).name if pages_content else 'unknown',
+                    'chunk_id': chunk_id_str,
+                    'section_label': section_label,
+                    'project_plan_type': title_info.get('category_code', 'other'),  # 使用LLM分类得到的类别代码
+                    'element_tag': {
+                        'chunk_id': chunk_id_str,
+                        'page': page_num,
+                        'serial_number': title_number if title_number else str(i + 1)
+                    },
+                    'review_chunk_content': sub_chunk['content'],
+                    '_title_number': title_number,  # 临时存储,用于合并时判断
+                    '_local_index': j  # 临时存储局部索引
+                })
+        
+        # 步骤4: 对小块进行合并
+        merged_chunks = self._merge_small_chunks(all_chunks, max_chunk_size, min_chunk_size)
+        
+        # 步骤5: 生成最终的chunk_id和serial_number
+        final_chunks = self._finalize_chunk_ids(merged_chunks)
+        
+        print(f"  初始切分: {len(all_chunks)} 个块")
+        print(f"  合并后: {len(merged_chunks)} 个块")
+        
+        return final_chunks
+    
+    def _find_title_positions(self, classified_items, full_text, pages_content, toc_pages):
+        """在正文中定位已分类的标题位置(跳过目录页)"""
+        # 计算目录页的文本范围
+        toc_start_pos = float('inf')
+        toc_end_pos = 0
+        
+        for page in pages_content:
+            if page['page_num'] in toc_pages:
+                toc_start_pos = min(toc_start_pos, page['start_pos'])
+                toc_end_pos = max(toc_end_pos, page['end_pos'])
+        
+        print(f"    目录页范围: {toc_start_pos} - {toc_end_pos}")
+        
+        located_titles = []
+        
+        for item in classified_items:
+            title = item['title']
+            category = item['category']
+            category_code = item.get('category_code', 'other')
+            
+            # 在全文中查找标题(使用配置的模糊匹配阈值)
+            fuzzy_threshold = self.config.fuzzy_threshold
+            pos = self._find_title_in_text(title, full_text, fuzzy_threshold=fuzzy_threshold)
+            
+            # 如果找到的位置在目录页范围内,继续查找下一个出现
+            if pos >= 0 and toc_start_pos <= pos < toc_end_pos:
+                print(f"    [跳过目录] {title} -> 位置: {pos} (在目录页)")
+                
+                # 尝试在目录页之后继续查找
+                search_start = toc_end_pos
+                remaining_text = full_text[search_start:]
+                pos_in_remaining = self._find_title_in_text(title, remaining_text, fuzzy_threshold=fuzzy_threshold)
+                
+                if pos_in_remaining >= 0:
+                    pos = search_start + pos_in_remaining
+                    print(f"    [找到正文] {title} -> 位置: {pos}")
+                else:
+                    pos = -1
+                    print(f"    [未找到] {title} (目录页之后)")
+            
+            if pos >= 0:
+                # 确认位置不在目录页
+                if not (toc_start_pos <= pos < toc_end_pos):
+                    # 找到对应的页码
+                    page_num = self._get_page_number(pos, pages_content)
+                    
+                    located_titles.append({
+                        'title': title,
+                        'category': category,
+                        'category_code': category_code,
+                        'position': pos,
+                        'toc_page': item.get('page', ''),
+                        'actual_page': page_num,
+                        'found': True
+                    })
+                    print(f"    [确认] {title} -> 页码: {page_num}, 位置: {pos}")
+                else:
+                    print(f"    [未找到] {title} (只在目录页)")
+                    located_titles.append({
+                        'title': title,
+                        'category': category,
+                        'category_code': category_code,
+                        'position': -1,
+                        'toc_page': item.get('page', ''),
+                        'found': False
+                    })
+            else:
+                print(f"    [未找到] {title}")
+                located_titles.append({
+                    'title': title,
+                    'category': category,
+                    'category_code': category_code,
+                    'position': -1,
+                    'toc_page': item.get('page', ''),
+                    'found': False
+                })
+        
+        return located_titles
+    
+    def _find_title_in_text(self, title, text, fuzzy_threshold=0.85):
+        """在文本中查找标题的位置"""
+        normalized_title = self._normalize_title(title)
+        
+        # 方法1: 精确匹配
+        if normalized_title in text:
+            return text.index(normalized_title)
+        
+        # 方法2: 移除所有空格后匹配,并把无空格坐标映射回原文坐标
+        title_no_space = normalized_title.replace(' ', '')
+        text_no_space = text.replace(' ', '')
+        if title_no_space in text_no_space:
+            pos_no_space = text_no_space.index(title_no_space)
+            non_space_seen = 0
+            for idx, ch in enumerate(text):
+                if ch != ' ':
+                    if non_space_seen == pos_no_space:
+                        return idx
+                    non_space_seen += 1
+        
+        # 方法3: 按行查找,匹配度最高的行
+        lines = text.split('\n')
+        current_pos = 0
+        best_ratio = 0
+        best_pos = -1
+        
+        for line in lines:
+            line_stripped = line.strip()
+            
+            if len(line_stripped) < 3:
+                current_pos += len(line) + 1
+                continue
+            
+            # 计算相似度
+            ratio = SequenceMatcher(None, normalized_title, line_stripped).ratio()
+            
+            if ratio > best_ratio:
+                best_ratio = ratio
+                best_pos = current_pos
+            
+            current_pos += len(line) + 1
+        
+        # 如果找到相似度足够高的行
+        if best_ratio >= fuzzy_threshold:
+            return best_pos
+        
+        return -1
+    
+    def _normalize_title(self, title):
+        """标准化标题用于匹配"""
+        normalized = re.sub(r'\s+', ' ', title)
+        normalized = normalized.strip()
+        return normalized
+    
+    def _extract_title_number(self, title):
+        """
+        从标题中提取编号部分
+        
+        例如:
+        "1.5 施工条件" -> "1.5"
+        "1.6 风险辨识与分级" -> "1.6"
+        "1 工程概况" -> "1"
+        
+        参数:
+            title: 标题字符串
+            
+        返回:
+            str: 编号部分,如果未找到则返回空字符串
+        """
+        # 匹配数字编号格式(如 1.5, 1.6, 1.2.3等)
+        number_match = re.match(r'^(\d+(?:\.\d+)*)', title)
+        if number_match:
+            return number_match.group(1)
+        
+        # 匹配中文编号格式(如 一、二、三等)
+        chinese_match = re.match(r'^([一二三四五六七八九十]+)[、..]', title)
+        if chinese_match:
+            return chinese_match.group(1)
+        
+        return ""
+    
+    def _get_page_number(self, position, pages_content):
+        """根据位置获取页码"""
+        for page in pages_content:
+            if page['start_pos'] <= position < page['end_pos']:
+                return page['page_num']
+        return 1
+    
+    def _split_by_sub_titles(self, content_block, all_toc_items, parent_title_info, 
+                            target_level, max_chunk_size, min_chunk_size):
+        """
+        在正文块中按子标题进行切分
+        
+        参数:
+            content_block: 正文块内容
+            all_toc_items: 所有目录项
+            parent_title_info: 父标题信息
+            target_level: 目标层级
+            max_chunk_size: 最大分块字符数
+            min_chunk_size: 最小分块字符数
+            
+        返回:
+            list: 子块列表
+        """
+        # 查找比目标层级更低的子标题
+        sub_titles = []
+        fuzzy_threshold = self.config.fuzzy_threshold
+        for toc_item in all_toc_items:
+            if toc_item['level'] > target_level:
+                # 在正文块中查找这个子标题
+                pos = self._find_title_in_text(toc_item['title'], content_block, fuzzy_threshold=fuzzy_threshold)
+                if pos >= 0:
+                    sub_titles.append({
+                        'title': toc_item['title'],
+                        'level': toc_item['level'],
+                        'position': pos
+                    })
+        
+        # 按位置排序
+        sub_titles.sort(key=lambda x: x['position'])
+        
+        # 如果没有找到子标题,将整个正文块作为一个块
+        if not sub_titles:
+            # 检查是否需要分割
+            if len(content_block) > max_chunk_size:
+                return self._split_large_chunk(content_block, max_chunk_size, parent_title_info['title'])
+            else:
+                return [{
+                    'content': content_block,
+                    'relative_start': 0,
+                    'sub_title': '',
+                    'serial_number': ''
+                }]
+        
+        # 按子标题切分
+        chunks = []
+        for i, sub_title in enumerate(sub_titles):
+            start_pos = sub_title['position']
+            
+            # 确定结束位置
+            if i + 1 < len(sub_titles):
+                end_pos = sub_titles[i + 1]['position']
+            else:
+                end_pos = len(content_block)
+            
+            chunk_content = content_block[start_pos:end_pos]
+            
+            # 检查是否需要分割
+            if len(chunk_content) > max_chunk_size:
+                split_chunks = self._split_large_chunk(chunk_content, max_chunk_size, sub_title['title'])
+                for j, split_chunk in enumerate(split_chunks):
+                    split_chunk['relative_start'] = start_pos + split_chunk['relative_start']
+                    split_chunk['sub_title'] = sub_title['title']
+                    chunks.append(split_chunk)
+            else:
+                chunks.append({
+                    'content': chunk_content,
+                    'relative_start': start_pos,
+                    'sub_title': sub_title['title']
+                })
+        
+        return chunks
+    
+    def _split_large_chunk(self, content, max_chunk_size, title):
+        """
+        将超大块按句子级分割(保持语义完整)
+        
+        参数:
+            content: 内容
+            max_chunk_size: 最大分块字符数
+            title: 标题
+            
+        返回:
+            list: 分割后的块列表
+        """
+        # 按句子分割(中文句号、问号、感叹号)
+        sentences = re.split(r'([。!?\n])', content)
+        
+        # 重新组合句子和标点;步长为2(片段+分隔符),并保留末尾无标点的片段
+        combined_sentences = []
+        for i in range(0, len(sentences), 2):
+            piece = sentences[i]
+            if i + 1 < len(sentences):
+                piece += sentences[i + 1]
+            if piece:
+                combined_sentences.append(piece)
+        
+        if not combined_sentences:
+            combined_sentences = [content]
+        
+        # 按max_chunk_size组合句子
+        chunks = []
+        current_chunk = ""
+        current_start = 0
+        
+        for sentence in combined_sentences:
+            if len(current_chunk) + len(sentence) <= max_chunk_size:
+                current_chunk += sentence
+            else:
+                if current_chunk:
+                    chunks.append({
+                        'content': current_chunk,
+                        'relative_start': current_start,
+                        'is_split': True  # 标记为分割块,不参与合并
+                    })
+                    current_start += len(current_chunk)
+                current_chunk = sentence
+        
+        # 添加最后一个块
+        if current_chunk:
+            chunks.append({
+                'content': current_chunk,
+                'relative_start': current_start,
+                'is_split': True
+            })
+        
+        return chunks
+    
+    def _merge_small_chunks(self, chunks, max_chunk_size, min_chunk_size):
+        """
+        合并小于min_chunk_size的块
+        
+        参数:
+            chunks: 块列表
+            max_chunk_size: 最大分块字符数
+            min_chunk_size: 最小分块字符数
+            
+        返回:
+            list: 合并后的块列表
+        """
+        if not chunks:
+            return []
+        
+        # 先按最低层级标题编号分组处理(在同一标题内合并)
+        current_title_number = None
+        title_groups = []
+        current_group = []
+        
+        for chunk in chunks:
+            title_number = chunk.get('_title_number', '')
+            
+            if title_number != current_title_number:
+                # 保存上一组
+                if current_group:
+                    title_groups.append({
+                        'title_number': current_title_number,
+                        'chunks': current_group
+                    })
+                # 开始新组
+                current_title_number = title_number
+                current_group = [chunk]
+            else:
+                current_group.append(chunk)
+        
+        # 保存最后一组
+        if current_group:
+            title_groups.append({
+                'title_number': current_title_number,
+                'chunks': current_group
+            })
+        
+        # 在每个组内合并小块
+        merged_groups = []
+        for group in title_groups:
+            merged_chunks = self._merge_within_title(group['chunks'], max_chunk_size, min_chunk_size)
+            merged_groups.append({
+                'title_number': group['title_number'],
+                'chunks': merged_chunks
+            })
+        
+        # 处理跨标题合并:如果上一组的最后一个块与当前组的第一个块都是小块,可以合并
+        final_merged = []
+        for i, group in enumerate(merged_groups):
+            if i == 0:
+                final_merged.extend(group['chunks'])
+            else:
+                # 检查是否可以与上一组的最后一个块合并
+                prev_group = merged_groups[i - 1]
+                if prev_group['chunks'] and group['chunks']:
+                    prev_last = prev_group['chunks'][-1]
+                    curr_first = group['chunks'][0]
+                    
+                    prev_content = prev_last['review_chunk_content']
+                    curr_content = curr_first['review_chunk_content']
+                    
+                    # 如果两个块都是小块且不是分割块,可以合并
+                    if (not prev_last.get('is_split', False) and 
+                        not curr_first.get('is_split', False) and
+                        len(prev_content) < min_chunk_size and
+                        len(curr_content) < min_chunk_size and
+                        len(prev_content) + len(curr_content) <= max_chunk_size):
+                        
+                        # 合并
+                        merged_content = prev_content + '\n\n' + curr_content
+                        merged_chunk = prev_last.copy()
+                        merged_chunk['review_chunk_content'] = merged_content
+                        merged_chunk['section_label'] = self._merge_section_labels(
+                            prev_last['section_label'],
+                            curr_first['section_label']
+                        )
+                        # 合并标题编号
+                        prev_title_num = prev_last.get('_title_number', '')
+                        curr_title_num = curr_first.get('_title_number', '')
+                        if prev_title_num and curr_title_num and prev_title_num != curr_title_num:
+                            # chunk_id中使用+号(无空格)
+                            merged_chunk['_title_number'] = f"{prev_title_num}+{curr_title_num}"
+                            # serial_number中使用空格(用于显示)
+                            merged_chunk['_title_number_display'] = f"{prev_title_num} + {curr_title_num}"
+                        merged_chunk['_is_merged'] = True
+                        
+                        # 替换上一组的最后一个块
+                        final_merged[-1] = merged_chunk
+                        # 跳过当前组的第一个块
+                        final_merged.extend(group['chunks'][1:])
+                    else:
+                        final_merged.extend(group['chunks'])
+                else:
+                    final_merged.extend(group['chunks'])
+        
+        return final_merged
+    
+    def _merge_within_title(self, title_chunks, max_chunk_size, min_chunk_size):
+        """在同一个最低层级标题内合并小块"""
+        if not title_chunks:
+            return []
+        
+        merged = []
+        i = 0
+        
+        while i < len(title_chunks):
+            current_chunk = title_chunks[i]
+            current_content = current_chunk['review_chunk_content']
+            
+            # 如果当前块是分割块,不参与合并
+            if current_chunk.get('is_split', False):
+                merged.append(current_chunk)
+                i += 1
+                continue
+            
+            # 如果当前块小于最小值,尝试与下一个块合并
+            if len(current_content) < min_chunk_size and i + 1 < len(title_chunks):
+                next_chunk = title_chunks[i + 1]
+                next_content = next_chunk['review_chunk_content']
+                
+                # 下一个块不是分割块、且合并后不超过最大分块大小时才合并
+                if (not next_chunk.get('is_split', False) and 
+                    len(current_content) + len(next_content) <= max_chunk_size):
+                    # 合并
+                    merged_content = current_content + '\n\n' + next_content
+                    merged_chunk = current_chunk.copy()
+                    merged_chunk['review_chunk_content'] = merged_content
+                    # 使用优化的标签合并函数
+                    merged_chunk['section_label'] = self._merge_section_labels(
+                        current_chunk['section_label'], 
+                        next_chunk['section_label']
+                    )
+                    merged.append(merged_chunk)
+                    i += 2  # 跳过下一个块
+                    continue
+            
+            # 否则直接添加
+            merged.append(current_chunk)
+            i += 1
+        
+        return merged
+    
+    def _finalize_chunk_ids(self, chunks):
+        """
+        生成最终的chunk_id和serial_number
+        
+        参数:
+            chunks: 合并后的块列表
+            
+        返回:
+            list: 最终处理后的块列表
+        """
+        final_chunks = []
+        current_title_number = None
+        local_index = 1
+        
+        for i, chunk in enumerate(chunks):
+            title_number = chunk.get('_title_number', '')
+            is_merged = chunk.get('_is_merged', False)
+            
+            # 提取标题编号的主要部分(用于判断是否在同一标题内)
+            # 如果包含+号,说明是跨标题合并的块
+            if '+' in str(title_number):
+                # 跨标题合并的块,序号从0开始
+                local_index = 0
+                # chunk_id中使用+号(无空格),如"1.5+1.6"
+                merged_title_number = title_number
+                # serial_number中使用空格,如"1.5 + 1.6"
+                serial_number_display = chunk.get('_title_number_display', title_number.replace('+', ' + '))
+                # 更新current_title_number为合并后的编号,这样下一个块会重新开始
+                current_title_number = title_number
+            else:
+                # 如果标题编号变化,重置索引
+                if title_number != current_title_number:
+                    current_title_number = title_number
+                    # 标题编号变化(含上一个块为跨标题合并块的情况)时,序号从1重新开始
+                    local_index = 1
+                else:
+                    local_index += 1
+                merged_title_number = title_number
+                serial_number_display = title_number
+            
+            # 生成chunk_id(使用无空格的编号)
+            if merged_title_number:
+                chunk_id_str = f"doc_chunk_{merged_title_number}_{local_index}"
+            else:
+                chunk_id_str = f"doc_chunk_{local_index}"
+            
+            # 更新chunk数据
+            final_chunk = {
+                'file_name': chunk['file_name'],
+                'chunk_id': chunk_id_str,
+                'section_label': chunk['section_label'],
+                'project_plan_type': 'bridge_up_part',
+                'element_tag': {
+                    'chunk_id': chunk_id_str,
+                    'page': chunk['element_tag']['page'],
+                    'serial_number': serial_number_display if merged_title_number else ''
+                },
+                'review_chunk_content': chunk['review_chunk_content']
+            }
+            
+            final_chunks.append(final_chunk)
+        
+        return final_chunks
+    
+    def _build_section_label(self, parent_title, sub_title):
+        """构建section_label(层级路径)"""
+        if sub_title:
+            return f"{parent_title}->{sub_title}"
+        else:
+            return parent_title
+    
+    def _merge_section_labels(self, label1, label2):
+        """
+        合并两个section_label,提取公共前缀
+        
+        例如:
+        "1 工程概况->1.3 工程地质" + "1 工程概况->1.4 气象水文"
+        => "1 工程概况->1.3 工程地质 + 1.4 气象水文"
+        
+        参数:
+            label1: 第一个标签
+            label2: 第二个标签
+            
+        返回:
+            str: 合并后的标签
+        """
+        # 按"->"分割标签
+        parts1 = label1.split('->')
+        parts2 = label2.split('->')
+        
+        # 找到公共前缀
+        common_prefix = []
+        for i in range(min(len(parts1), len(parts2))):
+            if parts1[i] == parts2[i]:
+                common_prefix.append(parts1[i])
+            else:
+                break
+        
+        # 如果有公共前缀
+        if common_prefix:
+            # 获取不同的部分
+            diff1 = '->'.join(parts1[len(common_prefix):])
+            diff2 = '->'.join(parts2[len(common_prefix):])
+            
+            # 构建合并后的标签
+            prefix = '->'.join(common_prefix)
+            if diff1 and diff2:
+                return f"{prefix}->{diff1} + {diff2}"
+            elif diff1:
+                return f"{prefix}->{diff1}"
+            elif diff2:
+                return f"{prefix}->{diff2}"
+            else:
+                return prefix
+        else:
+            # 没有公共前缀,直接用+连接
+            return f"{label1} + {label2}"
+
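
The common-prefix rule in _merge_section_labels is easy to sanity-check in isolation. Below is a minimal standalone sketch (function name hypothetical; the logic mirrors the method above) together with the example from its docstring:

def merge_section_labels(label1: str, label2: str) -> str:
    # Hierarchical labels use "->" as the level separator.
    parts1, parts2 = label1.split('->'), label2.split('->')
    common = []
    for a, b in zip(parts1, parts2):
        if a != b:
            break
        common.append(a)
    if not common:
        return f"{label1} + {label2}"
    diff1 = '->'.join(parts1[len(common):])
    diff2 = '->'.join(parts2[len(common):])
    tail = ' + '.join(p for p in (diff1, diff2) if p)
    prefix = '->'.join(common)
    return f"{prefix}->{tail}" if tail else prefix

print(merge_section_labels("1 工程概况->1.3 工程地质", "1 工程概况->1.4 气象水文"))
# 1 工程概况->1.3 工程地质 + 1.4 气象水文

Extracting the shared prefix keeps merged labels short while preserving the full hierarchy path of both source chunks.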

+ 348 - 0
core/construction_review/doc_worker/toc_extractor.py

@@ -0,0 +1,348 @@
+"""
+目录提取模块
+支持从PDF和Word文档中提取目录结构
+"""
+
+import re
+from pathlib import Path
+import fitz  # PyMuPDF
+from docx import Document
+
+try:
+    from .config_loader import get_config
+except ImportError:
+    from config_loader import get_config
+
+
+class TOCExtractor:
+    """目录提取器,支持PDF和Word格式"""
+    
+    def __init__(self):
+        self.config = get_config()
+    
+    def extract_toc(self, file_path):
+        """
+        提取文档目录
+        
+        参数:
+            file_path: 文档路径(PDF或Word)
+            
+        返回:
+            dict: 包含目录项和统计信息的字典
+        """
+        file_path = Path(file_path)
+        file_ext = file_path.suffix.lower()
+        
+        if file_ext == '.pdf':
+            return self._extract_from_pdf(file_path)
+        elif file_ext in ['.docx', '.doc']:
+            return self._extract_from_word(file_path)
+        else:
+            raise ValueError(f"不支持的文件格式: {file_ext}")
+    
+    def _extract_from_pdf(self, pdf_path, max_pages=None):
+        """从PDF中提取目录"""
+        if max_pages is None:
+            max_pages = self.config.toc_max_pages
+        pages_text = self._extract_pdf_pages(pdf_path, max_pages)
+        
+        all_toc_items = []
+        toc_page_nums = []
+        
+        for page_info in pages_text:
+            toc_items = self._detect_toc_patterns(page_info['text'])
+            
+            if toc_items:
+                all_toc_items.extend(toc_items)
+                toc_page_nums.append(page_info['page_num'])
+        
+        # 去重
+        unique_toc = []
+        seen = set()
+        for item in all_toc_items:
+            key = (item['title'], item['page'])
+            if key not in seen:
+                seen.add(key)
+                unique_toc.append(item)
+        
+        return {
+            'toc_items': unique_toc,
+            'toc_count': len(unique_toc),
+            'toc_pages': toc_page_nums
+        }
+    
+    def _extract_from_word(self, word_path, max_pages=None):
+        """从Word中提取目录"""
+        if max_pages is None:
+            max_pages = self.config.toc_max_pages
+        
+        # 方法1: 尝试提取内置目录结构
+        builtin_toc = self._extract_builtin_toc(word_path)
+        
+        # 方法2: 文本模式匹配(作为补充)
+        pages_text = self._extract_word_pages(word_path, max_pages)
+        
+        pattern_toc_items = []
+        toc_page_nums = []
+        
+        for page_info in pages_text:
+            toc_items = self._detect_toc_patterns(page_info['text'])
+            
+            if toc_items:
+                pattern_toc_items.extend(toc_items)
+                toc_page_nums.append(page_info['page_num'])
+        
+        # 合并两种方法的结果
+        all_toc_items = []
+        
+        # 优先使用内置目录
+        if builtin_toc:
+            all_toc_items.extend(builtin_toc)
+        
+        # 如果内置目录为空或数量较少,使用模式匹配的结果
+        if len(builtin_toc) < 3:
+            all_toc_items.extend(pattern_toc_items)
+        
+        # 去重
+        unique_toc = []
+        seen = set()
+        for item in all_toc_items:
+            key = (item['title'], item.get('page', '?'))
+            if key not in seen:
+                seen.add(key)
+                unique_toc.append(item)
+        
+        return {
+            'toc_items': unique_toc,
+            'toc_count': len(unique_toc),
+            'toc_pages': toc_page_nums if toc_page_nums else [1]
+        }
+    
+    def _extract_pdf_pages(self, pdf_path, max_pages=None):
+        """从PDF文件的前几页提取文本"""
+        if max_pages is None:
+            max_pages = self.config.toc_max_pages
+        try:
+            doc = fitz.open(pdf_path)
+            pages_text = []
+            
+            for page_num in range(min(len(doc), max_pages)):
+                page = doc[page_num]
+                text = page.get_text()
+                pages_text.append({
+                    'page_num': page_num + 1,
+                    'text': text
+                })
+            
+            doc.close()
+            return pages_text
+        except Exception as e:
+            print(f"  错误: 无法读取PDF - {str(e)}")
+            return []
+    
+    def _extract_word_pages(self, word_path, max_pages=None):
+        """从Word文件的前几页提取文本"""
+        if max_pages is None:
+            max_pages = self.config.toc_max_pages
+        
+        try:
+            doc = Document(word_path)
+            pages_text = []
+            
+            all_text = []
+            for para in doc.paragraphs:
+                text = para.text.strip()
+                if text:
+                    all_text.append(text)
+            
+            # 模拟分页:从配置读取每页段落数
+            paragraphs_per_page = self.config.paragraphs_per_page
+            for i in range(0, min(len(all_text), max_pages * paragraphs_per_page), paragraphs_per_page):
+                page_text = '\n'.join(all_text[i:i+paragraphs_per_page])
+                pages_text.append({
+                    'page_num': i // paragraphs_per_page + 1,
+                    'text': page_text
+                })
+            
+            return pages_text
+        except Exception as e:
+            print(f"  错误: 无法读取Word - {str(e)}")
+            return []
+    
+    def _extract_builtin_toc(self, word_path):
+        """提取Word文档的内置目录结构"""
+        try:
+            doc = Document(word_path)
+            toc_items = []
+            
+            for para in doc.paragraphs:
+                style_name = para.style.name if para.style else ""
+                text = para.text.strip()
+                
+                if not text:
+                    continue
+                
+                # 检查是否是标题样式
+                if style_name.startswith('Heading'):
+                    if not self._has_numbering(text):
+                        continue
+                    
+                    try:
+                        level = int(style_name.split()[-1]) if len(style_name.split()) > 1 else 1
+                    except (ValueError, IndexError):
+                        level = 1
+                    
+                    toc_items.append({
+                        'title': text,
+                        'level': level,
+                        'page': '?',
+                        'original': text,
+                        'source': 'heading_style'
+                    })
+                # 检查是否是TOC样式
+                elif 'toc' in style_name.lower():
+                    match = re.search(r'(\d+)\s*$', text)
+                    page = match.group(1) if match else '?'
+                    
+                    title = re.sub(r'\s*\d+\s*$', '', text).strip()
+                    
+                    if not self._has_numbering(title):
+                        continue
+                    
+                    level_match = re.search(r'TOC\s*(\d+)', style_name, re.IGNORECASE)
+                    level = int(level_match.group(1)) if level_match else 1
+                    
+                    if title:
+                        toc_items.append({
+                            'title': title,
+                            'level': level,
+                            'page': page,
+                            'original': text,
+                            'source': 'toc_style'
+                        })
+            
+            return toc_items
+        except Exception as e:
+            print(f"  错误: 无法读取Word内置目录 - {str(e)}")
+            return []
+    
+    def _has_numbering(self, text):
+        """检查文本是否包含编号格式"""
+        # 从配置读取编号格式
+        numbering_patterns = self.config.numbering_formats
+        
+        for pattern in numbering_patterns:
+            if re.match(pattern, text):
+                return True
+        
+        return False
+    
+    def _detect_toc_patterns(self, text):
+        """检测文本中的目录模式"""
+        toc_items = []
+        lines = text.split('\n')
+        
+        # 预处理:合并可能分行的目录项
+        merged_lines = []
+        i = 0
+        while i < len(lines):
+            line = lines[i].strip()
+            
+            if re.match(r'^第[一二三四五六七八九十\d]+[章节条款]\s*$', line):
+                if i + 1 < len(lines):
+                    next_line = lines[i + 1].strip()
+                    if re.search(r'[.·]{2,}.*\d{1,4}\s*$', next_line):
+                        merged_line = line + next_line
+                        merged_lines.append(merged_line)
+                        i += 2
+                        continue
+            
+            merged_lines.append(line)
+            i += 1
+        
+        # 从配置读取目录格式的正则表达式
+        patterns = self.config.toc_patterns
+        
+        # 从配置读取长度限制
+        min_length = self.config.toc_min_length
+        max_length = self.config.toc_max_length
+        
+        for line in merged_lines:
+            line = line.strip()
+            
+            if len(line) < min_length or len(line) > max_length:
+                continue
+            
+            if line.isdigit():
+                continue
+            
+            for pattern in patterns:
+                match = re.match(pattern, line)
+                if match:
+                    title = match.group(1).strip()
+                    page_num = match.group(2).strip()
+                    
+                    title_clean = re.sub(r'[.·]{2,}', '', title)
+                    title_clean = re.sub(r'\s{2,}', ' ', title_clean)
+                    title_clean = title_clean.strip()
+                    
+                    if title_clean and not self._is_likely_noise(title_clean):
+                        toc_items.append({
+                            'original': line,
+                            'title': title_clean,
+                            'page': page_num,
+                            'level': self._detect_level(title_clean)
+                        })
+                        break
+        
+        return toc_items
+    
+    def _is_likely_noise(self, text):
+        """判断文本是否可能是噪音(非目录内容)"""
+        # 从配置读取噪音模式
+        noise_patterns = self.config.noise_patterns
+        
+        for pattern in noise_patterns:
+            if re.search(pattern, text):
+                return True
+        
+        return False
+    
+    def _detect_level(self, title):
+        """检测目录项的层级"""
+        if re.match(r'^【\d+】', title):
+            return 1
+        
+        # 检查数字编号层级(如 1.1, 1.1.1, 1.1.1.1)
+        number_match = re.match(r'^(\d+(?:\.\d+)*)\s', title)
+        if number_match:
+            number_part = number_match.group(1)
+            dot_count = number_part.count('.')
+            return dot_count + 1
+        
+        # 检查〖〗格式的编号
+        bracket_match = re.match(r'^〖(\d+(?:\.\d+)*)〗', title)
+        if bracket_match:
+            number_part = bracket_match.group(1)
+            dot_count = number_part.count('.')
+            return dot_count + 1
+        
+        # 从配置读取标题模式
+        level1_patterns = self.config.level1_patterns
+        level2_patterns = self.config.level2_patterns
+        level3_patterns = self.config.level3_patterns
+        
+        for pattern in level1_patterns:
+            if re.match(pattern, title):
+                return 1
+        
+        for pattern in level2_patterns:
+            if re.match(pattern, title):
+                return 2
+        
+        for pattern in level3_patterns:
+            if re.match(pattern, title):
+                return 3
+        
+        return 1
+
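
The numbering branch of _detect_level carries most of the weight: the level is the dot count of the leading number plus one. A minimal sketch of just that rule, with hypothetical titles:

import re

def level_from_numbering(title: str) -> int:
    # "1 x" -> 1, "1.2 x" -> 2, "1.2.3 x" -> 3; non-numbered titles fall back to 1.
    m = re.match(r'^(\d+(?:\.\d+)*)\s', title)
    return m.group(1).count('.') + 1 if m else 1

assert level_from_numbering("1 工程概况") == 1
assert level_from_numbering("1.3 工程地质") == 2
assert level_from_numbering("2.4.1 桥梁下部结构") == 3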

+ 31 - 21
core/construction_review/workflows/ai_review_workflow.py

@@ -202,13 +202,17 @@ class AIReviewWorkflow:
 
         # 更新进度
         if state["progress_manager"]:
+            logger.debug(f"AI审查工作流中ProgressManager ID: {id(state['progress_manager'])}")
+            logger.debug(f"AI审查工作流中ProgressManager有SSE回调: {hasattr(state['progress_manager'], 'sse_callback')}")
             await state["progress_manager"].update_stage_progress(
                 callback_task_id=state["callback_task_id"],
                 stage_name="AI审查",
-                progress=0,
+                current=0,
                 status="processing",
                 message="开始AI审查"
             )
+        else:
+            logger.warning(f"AI审查工作流中未找到ProgressManager: {state.get('progress_manager', 'None')}")
 
         state["messages"].append(AIMessage(content="进度初始化完成"))
 
@@ -231,20 +235,17 @@ class AIReviewWorkflow:
 
             logger.info(f"AI审查开始: 总单元数 {total_all_units}, 实际审查 {total_units} 个单元")
 
-            # 进度回调函数
-            def progress_callback(progress: int, message: str):
-                overall_progress = 50 + int(progress * 0.4)  # AI审查占整体进度的40%
-                if state["progress_manager"]:
-                    asyncio.create_task(
-                        state["progress_manager"].update_stage_progress(
-                            callback_task_id=state["callback_task_id"],
-                            stage_name="AI审查",
-                            progress=overall_progress,
-                            status="processing",
-                            message=message
-                        )
-                    )
-
+            # 开始AI审查进度
+            if state["progress_manager"]:
+                await state["progress_manager"].update_stage_progress(
+                    callback_task_id=state["callback_task_id"],
+                    stage_name="AI审查",
+                    current=0,
+                    status="processing",
+                    message=f"开始AI审查,共 {total_units} 个审查单元"
+                )
+
             # 基本审查单元
             async def review_single_unit(unit_content: Dict[str, Any], unit_index: int,callback_task_id) -> ReviewResult:
                 """使用LangGraph编排的原子化组件方法审查单个单元"""
@@ -274,11 +275,20 @@ class AIReviewWorkflow:
                         # 更新进度
                         nonlocal completed_units
                         completed_units += 1
-                        progress = int((completed_units / total_units) * 100)
+                        current = int((completed_units / total_units) * 100)
                         message = f"已完成 {completed_units}/{total_units} 个审查单元"
-
-                        if progress_callback:
-                            progress_callback(progress, message)
+                        logger.info(f"更新进度: {current}% {message}")
+                        # 更新ProgressManager进度
+                        if state["progress_manager"]:
+                            asyncio.create_task(
+                                state["progress_manager"].update_stage_progress(
+                                    callback_task_id=state["callback_task_id"],
+                                    stage_name="AI审查",
+                                    current=current,
+                                    status="processing",
+                                    message=message
+                                )
+                            )
 
                         return ReviewResult(
                         unit_index=unit_index,
@@ -351,7 +361,7 @@ class AIReviewWorkflow:
             await state["progress_manager"].update_stage_progress(
                 callback_task_id=state["callback_task_id"],
                 stage_name="AI审查",
-                progress=90,
+                current=90,
                 status="completed",
                 message="AI审查完成"
             )
@@ -372,7 +382,7 @@ class AIReviewWorkflow:
             await state["progress_manager"].update_stage_progress(
                 callback_task_id=state["callback_task_id"],
                 stage_name="AI审查",
-                progress=50,
+                current=50,
                 status="failed",
                 message=f"AI审查失败: {state['error_message']}"
             )
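
The per-unit updates above are scheduled with asyncio.create_task rather than awaited, so a slow Redis/SSE push cannot stall the review loop. A self-contained sketch of that pattern; update_stage_progress here is a hypothetical stand-in for the ProgressManager method:

import asyncio

async def update_stage_progress(stage_name: str, current: int, status: str, message: str) -> None:
    # Stand-in for ProgressManager.update_stage_progress (e.g. Redis write + SSE push).
    await asyncio.sleep(0)
    print(f"[{stage_name}] {current}% {status}: {message}")

async def review_units(total_units: int) -> None:
    completed = 0
    for _ in range(total_units):
        await asyncio.sleep(0.01)  # stand-in for reviewing one unit
        completed += 1
        current = int(completed / total_units * 100)
        # Fire-and-forget: do not block the review loop on the progress push.
        asyncio.create_task(update_stage_progress(
            "AI审查", current, "processing", f"已完成 {completed}/{total_units} 个审查单元"))
    await asyncio.sleep(0.1)  # give pending pushes a chance to drain before exit

asyncio.run(review_units(4))

One caveat: the event loop keeps only weak references to tasks created this way, so a long-running service should hold task references, or await them at stage end, to guarantee the final pushes are delivered.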

+ 13 - 19
core/construction_review/workflows/document_workflow.py

@@ -27,25 +27,19 @@ class DocumentWorkflow:
         try:
             logger.info(f"开始文档处理工作流,文件ID: {self.file_id}")
 
-            # 2. 初始化进度
-            await self.progress_manager.initialize_progress(
-                callback_task_id=self.callback_task_id,
-                user_id=self.user_id,
-                stages=[
-                    {"stage_name": "文档上传", "progress": 100, "status": "completed"},
-                    {"stage_name": "文档解析", "progress": 0, "status": "pending"},
-                    {"stage_name": "内容提取", "progress": 0, "status": "pending"},
-                    {"stage_name": "结构化处理", "progress": 0, "status": "pending"}
-                ]
-            )
+            # 进度初始化已移至上游统一处理,此处仅校验进度数据是否存在
+            existing_progress = await self.progress_manager.get_progress(self.callback_task_id)
+            if not existing_progress:
+                logger.warning(f"文档处理工作流未找到进度数据: {self.callback_task_id}")
+
 
             # 4. 执行文档处理
-            def progress_callback(progress: int, message: str):
+            def progress_callback(current: int, message: str):
                 asyncio.create_task(
                     self.progress_manager.update_stage_progress(
                         callback_task_id=self.callback_task_id,
-                        stage_name="文档处理",
-                        progress=progress,
+                        stage_name="文档解析",
+                        current=current,
                         status="processing",
                         message=message
                     )
@@ -60,10 +54,10 @@ class DocumentWorkflow:
             # 5. 更新完成状态
             await self.progress_manager.update_stage_progress(
                 callback_task_id=self.callback_task_id,
-                stage_name="文档处理",
-                progress=100,
+                stage_name="文档解析",
+                current=100,
                 status="completed",
-                message="文档处理完成"
+                message="文档解析完成"
             )
 
             # 6. 保存处理结果
@@ -85,8 +79,8 @@ class DocumentWorkflow:
             if self.progress_manager:
                 await self.progress_manager.update_stage_progress(
                     callback_task_id=self.callback_task_id,
-                    stage_name="文档处理",
-                    progress=0,
+                    stage_name="文档解析",
+                    current=0,
                     status="failed",
                     message=f"处理失败: {str(e)}"
                 )

+ 7 - 11
core/construction_review/workflows/report_workflow.py

@@ -31,20 +31,20 @@ class ReportWorkflow:
             await self.progress_manager.update_stage_progress(
                 callback_task_id=self.callback_task_id,
                 stage_name="报告生成",
-                progress=0,
+                current=0,
                 status="processing",
                 message="开始生成报告"
             )
 
             # 2. 生成报告
-            def progress_callback(progress: int, message: str):
+            def progress_callback(current: int, message: str):
                 # 将报告生成的进度映射到整体进度
-                overall_progress = 90 + int(progress * 0.1)  # 报告生成占整体进度的10%
+                overall_progress = 90 + int(current * 0.1)  # 报告生成占整体进度的10%
                 asyncio.create_task(
                     self.progress_manager.update_stage_progress(
                         callback_task_id=self.callback_task_id,
                         stage_name="报告生成",
-                        progress=overall_progress,
+                        current=overall_progress,
                         status="processing",
                         message=message
                     )
@@ -60,16 +60,12 @@ class ReportWorkflow:
             await self.progress_manager.update_stage_progress(
                 callback_task_id=self.callback_task_id,
                 stage_name="报告生成",
-                progress=100,
+                current=100,
                 status="completed",
                 message="报告生成完成"
             )
 
-            # 4. 标记任务链完成
-            await self.progress_manager.complete_task(
-                callback_task_id=self.callback_task_id,
-                result=self._convert_report_to_dict(final_report)
-            )
+
 
             # 5. 处理结果
             result = self._convert_report_to_dict(final_report)
@@ -85,7 +81,7 @@ class ReportWorkflow:
                 await self.progress_manager.update_stage_progress(
                     callback_task_id=self.callback_task_id,
                     stage_name="报告生成",
-                    progress=90,
+                    current=90,
                     status="failed",
                     message=f"报告生成失败: {str(e)}"
                 )
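
The mapping above assumes report generation owns the last 10% of overall progress; as a quick check of the arithmetic:

def to_overall(current: int) -> int:
    # Report generation spans the final 10%: stage 0-100 maps to overall 90-100.
    return 90 + int(current * 0.1)

assert to_overall(0) == 90
assert to_overall(55) == 95
assert to_overall(100) == 100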

+ 36 - 0
database/repositories/bus_data_query.py

@@ -0,0 +1,36 @@
+from typing import List, Tuple, Any, Optional, Dict
+from foundation.logger.loggering import server_logger
+from foundation.utils.common import handler_err
+from foundation.base.mysql.async_mysql_base_dao import AsyncBaseDAO
+
+
+class BasisOfPreparationDAO(AsyncBaseDAO):
+    """异步编制依据 对象"""
+    
+    
+    async def get_info_by_id(self, id: int) -> Optional[Dict]:
+        """根据ID获取编制依据"""
+        query = "SELECT * FROM t_basis_of_preparation WHERE id = %s"
+        return await self.fetch_one(query, (id,))
+    
+    async def get_list(self) -> List[Dict]:
+        """获取所有编制依据"""
+        query = "SELECT * FROM t_basis_of_preparation WHERE status = 'current' ORDER BY created_at DESC"
+        return await self.fetch_all(query)
+    
+
+    async def get_info_by_condition(self, conditions: Dict) -> List[Dict]:
+        """根据条件查询编制依据"""
+        if not conditions:
+            return await self.get_list()
+        
+        try:
+            where_clause = " AND ".join([f"{field} = %s" for field in conditions.keys()])
+            where_values = list(conditions.values())
+            
+            query = f"SELECT * FROM t_basis_of_preparation WHERE {where_clause} AND status = 'current' ORDER BY created_at DESC"
+            return await self.fetch_all(query, tuple(where_values))
+            
+        except Exception as err:
+            handler_err(logger=server_logger, err=err, err_name="条件查询失败")
+            raise
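
get_info_by_condition parameterizes values but interpolates column names into the SQL string. That is fine for trusted internal callers; if conditions keys could ever originate from a request, a column whitelist is the usual guard. A sketch under that assumption (the whitelist itself is hypothetical):

ALLOWED_COLUMNS = {"id", "name", "doc_type", "status"}  # hypothetical whitelist

def build_condition_query(conditions: dict) -> tuple[str, tuple]:
    # Whitelist column names; values stay parameterized via %s placeholders.
    safe = {k: v for k, v in conditions.items() if k in ALLOWED_COLUMNS}
    if not safe:
        raise ValueError("no queryable columns in conditions")
    where_clause = " AND ".join(f"{field} = %s" for field in safe)
    query = ("SELECT * FROM t_basis_of_preparation "
             f"WHERE {where_clause} AND status = 'current' ORDER BY created_at DESC")
    return query, tuple(safe.values())

sql, params = build_condition_query({"doc_type": "bridge", "name": "编制依据A"})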

+ 2 - 2
foundation/agent/generate/model_generate.py

@@ -39,12 +39,12 @@ class GenerateModelClient:
         # logger.info(f"[模型生成结果]: {response.content}")
         return response.content
 
-    async def get_model_generate_stream(self, trace_id, task_prompt_info: dict):
+    def get_model_generate_stream(self, trace_id, task_prompt_info: dict):
         """
-            模型流式生成(异步)
+            模型流式生成(同步生成器)
         """
         prompt_template = task_prompt_info["task_prompt"]
         # 直接格式化消息,不需要额外的invoke步骤
         messages = prompt_template.format_messages()
         response = self.llm.stream(messages)
         for chunk in response:
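
With async removed, get_model_generate_stream is a plain generator, so callers iterate it with a normal for loop instead of async for. A stub-based sketch of the consumption pattern (the generator here is a hypothetical stand-in for the client):

from typing import Iterator

def get_model_generate_stream_stub(prompt: str) -> Iterator[str]:
    # Stand-in for the client: llm.stream(...) yields chunks; we re-yield their content.
    for piece in ("施工", "方案", "审查", "完成"):
        yield piece

for chunk in get_model_generate_stream_stub("demo"):
    print(chunk, end="", flush=True)
print()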

+ 12 - 3
foundation/agent/workflow/test_workflow_node.py

@@ -22,7 +22,7 @@ from foundation.agent.generate.test_intent import intent_identify_client
 from foundation.agent.test_agent import test_agent_client
 from foundation.schemas.test_schemas import FormConfig
 from foundation.agent.generate.model_generate import generate_model_client
-
+from foundation.utils.yaml_utils import system_prompt_config
 
 
 
@@ -87,7 +87,7 @@ class TestWorkflowNode:
         }
     
 
-    def chat_box_generate(self , state: TestCusState) -> dict:
+    async def chat_box_generate(self , state: TestCusState) -> dict:
         """
             模型生成节点(纯生成类问题)
             :param state:
@@ -98,7 +98,16 @@ class TestWorkflowNode:
         user_input = state["user_input"]
         task_prompt_info = state["task_prompt_info"]
         task_prompt_info["task_prompt"] = ""
-        response_content = generate_model_client.get_model_generate_invoke(trace_id=trace_id , task_prompt_info=task_prompt_info, input_query=user_input)
+
+        # 创建ChatPromptTemplate
+        template = ChatPromptTemplate.from_messages([
+            ("system", system_prompt_config['system_prompt']),
+            ("user", user_input)
+        ])
+
+        task_prompt_info = {"task_prompt": template}
+
+        response_content = await generate_model_client.get_model_generate_invoke(trace_id=trace_id , task_prompt_info=task_prompt_info)
         messages = [AIMessage(content=response_content , name="chat_box_generate")]
         server_logger.info(trace_id=trace_id, msg=f"【result】: {response_content}", log_type="chat_box_generate")
         return {
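
One caveat with the ("user", user_input) message above: from_messages treats message strings as templates, so raw user text containing { or } can break format_messages. Binding the input through a placeholder avoids that. A sketch assuming langchain_core is installed; the system prompt text is a stand-in for system_prompt_config['system_prompt']:

from langchain_core.prompts import ChatPromptTemplate

template = ChatPromptTemplate.from_messages([
    ("system", "你是一个施工方案审查助手。"),  # stand-in for system_prompt_config['system_prompt']
    ("user", "{user_input}"),                 # placeholder instead of raw user text
])
messages = template.format_messages(user_input="桥梁{上部}结构施工有哪些风险?")
for message in messages:
    print(message.type, ":", message.content)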

+ 0 - 157
foundation/base/mysql/async_mysql_base_dao.py

@@ -216,161 +216,4 @@ class TestTabDAO(AsyncBaseDAO):
         return await self.fetch_all(query)
     
 
-    async def get_users_by_condition(self, conditions: Dict) -> List[Dict]:
-        """根据条件查询用户"""
-        if not conditions:
-            return await self.get_all_users()
-        
-        try:
-            where_clause = " AND ".join([f"{field} = %s" for field in conditions.keys()])
-            where_values = list(conditions.values())
-            
-            query = f"SELECT * FROM test_tab WHERE {where_clause} AND status = 'active' ORDER BY created_at DESC"
-            return await self.fetch_all(query, tuple(where_values))
-            
-        except Exception as err:
-            handler_err(logger=server_logger, err=err, err_name="条件查询用户失败")
-            raise
-    
-    # ========== 修改方法 ==========
-    
-    async def update_user(self, user_id: int, **updates) -> bool:
-        """
-        更新用户信息
-        
-        Args:
-            user_id: 用户ID
-            **updates: 要更新的字段,如 name='新名字', age=25, email='new@email.com'
-        
-        Returns:
-            bool: 更新是否成功
-        """
-        if not updates:
-            server_logger.warning("没有提供更新字段")
-            return False
-        
-        # 过滤允许更新的字段
-        allowed_fields = {'name', 'email', 'age', 'status'}
-        valid_updates = {k: v for k, v in updates.items() if k in allowed_fields}
-        
-        if not valid_updates:
-            server_logger.warning("没有有效的更新字段")
-            return False
-        
-        try:
-            return await self.update_by_id('test_tab', user_id, valid_updates)
-            
-        except Exception as err:
-            handler_err(logger=server_logger, err=err, err_name="更新用户失败")
-            raise
-    
-    async def update_user_by_email(self, email: str, **updates) -> bool:
-        """
-        根据邮箱更新用户信息
-        
-        Args:
-            email: 用户邮箱
-            **updates: 要更新的字段
-        
-        Returns:
-            bool: 更新是否成功
-        """
-        if not updates:
-            server_logger.warning("没有提供更新字段")
-            return False
-        
-        # 过滤允许更新的字段
-        allowed_fields = {'name', 'age', 'status'}
-        valid_updates = {k: v for k, v in updates.items() if k in allowed_fields}
-        
-        if not valid_updates:
-            server_logger.warning("没有有效的更新字段")
-            return False
-        
-        try:
-            return await self.update_record('test_tab', valid_updates, {'email': email})
-            
-        except Exception as err:
-            handler_err(logger=server_logger, err=err, err_name="根据邮箱更新用户失败")
-            raise
-    
-    async def update_user_status(self, user_id: int, status: str) -> bool:
-        """
-        更新用户状态
-        
-        Args:
-            user_id: 用户ID
-            status: 状态值 ('active' 或 'inactive')
-        
-        Returns:
-            bool: 更新是否成功
-        """
-        if status not in ('active', 'inactive'):
-            raise ValueError("状态值必须是 'active' 或 'inactive'")
-        
-        try:
-            return await self.update_user(user_id, status=status)
-            
-        except Exception as err:
-            handler_err(logger=server_logger, err=err, err_name="更新用户状态失败")
-            raise
-    
-    async def batch_update_users(self, updates_list: List[Dict]) -> bool:
-        """
-        批量更新用户信息
-        
-        Args:
-            updates_list: 更新数据列表,每个元素必须包含id字段
-        
-        Returns:
-            bool: 批量更新是否成功
-        """
-        try:
-            return await self.batch_update('test_tab', updates_list, 'id')
-            
-        except Exception as err:
-            handler_err(logger=server_logger, err=err, err_name="批量更新用户失败")
-            raise
-    
-    async def update_users_age_range(self, min_age: int, max_age: int, updates: Dict) -> bool:
-        """
-        更新年龄范围内的用户
-        
-        Args:
-            min_age: 最小年龄
-            max_age: 最大年龄
-            updates: 要更新的字段
-        
-        Returns:
-            bool: 更新是否成功
-        """
-        try:
-            where_sql = "age BETWEEN %s AND %s AND status = 'active'"
-            params = (min_age, max_age)
-            
-            return await self.update_with_condition('test_tab', updates, where_sql, params)
-            
-        except Exception as err:
-            handler_err(logger=server_logger, err=err, err_name="更新年龄范围用户失败")
-            raise
-    
-    async def increment_user_age(self, user_id: int, increment: int = 1) -> bool:
-        """
-        增加用户年龄
-        
-        Args:
-            user_id: 用户ID
-            increment: 增加的值,默认为1
-        
-        Returns:
-            bool: 更新是否成功
-        """
-        try:
-            sql = "UPDATE test_tab SET age = age + %s WHERE id = %s AND status = 'active'"
-            return await self.execute_query(sql, (increment, user_id))
-            
-        except Exception as err:
-            handler_err(logger=server_logger, err=err, err_name="增加用户年龄失败")
-            raise
-
 

+ 3 - 3
foundation/base/tasks.py

@@ -22,7 +22,7 @@ def submit_task_processing_task(self, file_info: dict, _system_trace_id: str = N
     if _system_trace_id:
         from foundation.trace.trace_context import TraceContext
         TraceContext.set_trace_id(_system_trace_id)
-        logger.info(f"Celery任务恢复trace_id: {_system_trace_id}")
+        logger.info(f"Celery任务恢复")
 
     # 添加调试信息
     logger.info("=== Celery任务接收调试 ===")
@@ -37,7 +37,7 @@ def submit_task_processing_task(self, file_info: dict, _system_trace_id: str = N
     try:
         # 更新任务状态 - 开始处理
         self.update_state(
             state='PROGRESS',
             meta={
                 'current': 0,
                 'total': 100,
@@ -62,7 +62,7 @@ def submit_task_processing_task(self, file_info: dict, _system_trace_id: str = N
 
         # 更新任务状态 - 完成
         self.update_state(
             state='PROGRESS',
             meta={
                 'current': 100,
                 'total': 100,
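
For context on the update_state calls above: the first argument is a Celery task state, distinct from the current counter inside meta, and custom states are uppercase by convention, with 'PROGRESS' the usual choice for incremental updates. A minimal sketch (app and broker are hypothetical; update_state only has effect inside a running worker):

from celery import Celery

app = Celery("demo", broker="memory://")  # hypothetical in-memory test broker

@app.task(bind=True)
def long_job(self, total: int = 100) -> str:
    for done in range(0, total + 1, 20):
        # Custom 'PROGRESS' state; clients poll AsyncResult.state / .info for meta.
        self.update_state(state="PROGRESS",
                          meta={"current": done, "total": total, "status": "处理中..."})
    return "done"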

+ 2 - 2
foundation/logger/loggering.py

@@ -30,7 +30,7 @@ class CompatibleLogger(logging.Logger):
                  log_format=None, datefmt=None):
         # 初始化父类
         super().__init__(name)
-        self.setLevel(logging.DEBUG)  # 设置logger自身为最低级别
+        self.setLevel(logging.INFO)  # logger自身级别设为INFO,低于INFO的记录在此被过滤
 
         # 存储配置
         self.log_dir = log_dir
@@ -95,7 +95,7 @@ class CompatibleLogger(logging.Logger):
     def _create_console_handler(self):
         """创建控制台日志处理器"""
         console_handler = logging.StreamHandler(sys.stdout)
-        console_handler.setLevel(logging.INFO)
+        console_handler.setLevel(logging.DEBUG)
         console_handler.setFormatter(self.formatter)
         # 添加trace_filter,自动注入system_trace_id
         console_handler.addFilter(trace_filter)
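
A note on the two level changes above: a record must pass the logger's own level before any handler sees it, so with the logger raised to INFO, lowering the console handler to DEBUG has no visible effect. A self-contained illustration:

import logging, sys

logger = logging.getLogger("level-demo")
logger.setLevel(logging.INFO)                  # first gate: DEBUG records stop here

console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.DEBUG)                # second gate: never sees DEBUG records
logger.addHandler(console)

logger.debug("dropped at the logger level")
logger.info("printed: passes logger (INFO) and handler (DEBUG)")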

+ 2 - 35
foundation/models/silicon_flow.py

@@ -12,7 +12,7 @@ from foundation.logger.loggering import server_logger
 from foundation.utils.common import handler_err
 from openai import OpenAI
 from langchain_core.embeddings import Embeddings
-from chromadb.utils.embedding_functions import EmbeddingFunction
+#from chromadb.utils.embedding_functions import EmbeddingFunction
 from typing import List
 import numpy as np
 
@@ -55,39 +55,6 @@ class SiliconFlowEmbeddings(Embeddings):
 
 
 
-class ChromaSiliconFlowEmbedding(EmbeddingFunction):
-    """
-        将SiliconFlowEmbeddings适配到ChromaDB的嵌入函数接口
-    """
-    def __init__(self, embeddings):
-        self.embeddings = embeddings
-
-    def __call__(self, input: List[str]) -> List[List[float]]:
-        raw_embeddings = self.embeddings.embed_documents(input)  # 关键添加
-        return self.normalized_embeddings(raw_embeddings)
-
-    def embed_documents(self, input: List[str]) -> List[List[float]]:
-        raw_embeddings = self.embeddings.embed_documents(input)  # 关键添加
-        return self.normalized_embeddings(raw_embeddings)
-
-    def embed_query(self, text: str) -> List[float]:
-        """对查询文本进行向量化"""
-        raw_embeddings = self.embeddings.embed_documents([text])[0]
-        return self.normalized_embeddings(raw_embeddings)
-
-    
-    def normalized_embeddings(self , raw_embeddings):
-        # L2归一化处理
-        normalized = []
-        for vector in raw_embeddings:
-            norm = np.linalg.norm(vector)
-            if norm > 0:
-                normalized.append(vector / norm)
-            else:
-                normalized.append(vector)
-        return normalized
-
-
 
 class SiliconFlowAPI(BaseApiPlatform):
     def __init__(self , trace_id=""):
@@ -104,7 +71,7 @@ class SiliconFlowAPI(BaseApiPlatform):
         self.client = self.get_openai_client(self.model_server_url, self.api_key)
         # 创建LangChain兼容的嵌入对象
         langchain_embeddings = SiliconFlowEmbeddings(base_url = self.model_server_url , api_key=self.api_key , embed_model_id=self.embed_model_id)
-        self.embed_model = ChromaSiliconFlowEmbedding(embeddings=langchain_embeddings)
+        #self.embed_model = ChromaSiliconFlowEmbedding(embeddings=langchain_embeddings)
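
The removed ChromaDB adapter also L2-normalized embeddings before returning them. If any store downstream still assumes unit vectors (for example when using inner-product metrics), the same rule works as a standalone helper; a sketch with numpy:

import numpy as np

def l2_normalize(vectors):
    # Divide each vector by its L2 norm; leave zero vectors unchanged.
    normalized = []
    for vector in vectors:
        arr = np.asarray(vector, dtype=float)
        norm = np.linalg.norm(arr)
        normalized.append((arr / norm).tolist() if norm > 0 else arr.tolist())
    return normalized

print(l2_normalize([[3.0, 4.0]]))  # [[0.6, 0.8]]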
 
 
 

+ 108 - 0
foundation/rag/vector/base_vector.py

@@ -0,0 +1,108 @@
+from foundation.logger.loggering import server_logger as logger
+import os
+import time
+from tqdm import tqdm
+from typing import List, Dict, Any
+from foundation.models.base_online_platform import BaseApiPlatform
+
+
+class BaseVectorDB:
+    """
+      向量数据库操作基类
+    """
+        
+    def __init__(self, base_api_platform: BaseApiPlatform):
+        self.base_api_platform = base_api_platform
+
+
+
+    def text_to_vector(self, text: str) -> List[float]:
+        """
+        将文本转换为向量
+        """
+        return self.base_api_platform.get_embeddings([text])[0]
+    
+
+    def document_standard(self, documents: List[Dict[str, Any]]):
+        """
+          文档标准处理
+        """
+        raise NotImplementedError
+
+    
+    def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
+        """
+          单条添加文档
+          param: 扩展参数信息,如:表名称等
+          documents: 文档列表,包括元数据信息
+          # 返回: 添加的文档ID列表
+        """
+        raise NotImplementedError
+
+
+    def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
+        """
+          批量添加文档
+          param: 扩展参数信息,如:表名称等
+          documents: 文档列表,包括元数据信息
+          # 返回: 添加的文档ID列表
+        """
+        raise NotImplementedError
+
+
+    def add_tqdm_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]] , batch_size=10):
+        """
+          批量添加文档(带进度条)
+          param: 扩展参数信息,如:表名称等
+          documents: 文档列表,包括元数据信息
+          # 返回: 添加的文档ID列表
+        """
+        
+        logger.info(f"Inserting {len(documents)} documents.")
+        start_time = time.time()
+        total_docs_inserted = 0
+
+        total_batches = (len(documents) + batch_size - 1) // batch_size
+
+        with tqdm(total=total_batches, desc="Inserting batches", unit="batch") as pbar:
+            for i in range(0, len(documents), batch_size):
+                batch = documents[i:i + batch_size]
+                # 调用传入的插入函数
+                self.add_batch_documents(param, batch)
+
+                total_docs_inserted += len(batch)
+                # 计算并显示当前的TPM
+                elapsed_time = time.time() - start_time
+                if elapsed_time > 0:
+                    tpm = (total_docs_inserted / elapsed_time) * 60
+                    pbar.set_postfix({"TPM": f"{tpm:.2f}"})
+
+                pbar.update(1)
+
+        
+
+
+    def similarity_search(self, param: Dict[str, Any], query_text: str, min_score=0.5,
+                          top_k=10, filters: Dict[str, Any] = None):
+        """
+          相似度检索(过滤低于min_score的结果)
+        """
+        raise NotImplementedError
+
+
+    def retriever(self, param: Dict[str, Any], query_text: str,
+                          top_k: int = 5, filters: Dict[str, Any] = None):
+        """
+          根据用户问题检索文档
+        """
+        raise NotImplementedError
+
+
+    
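
The batch count in add_tqdm_batch_documents uses the integer ceiling-division identity; a quick check:

def batch_count(n_docs: int, batch_size: int) -> int:
    # (n + b - 1) // b == ceil(n / b) for positive integers
    return (n_docs + batch_size - 1) // batch_size

assert batch_count(25, 10) == 3
assert batch_count(20, 10) == 2
assert batch_count(1, 10) == 1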

+ 367 - 0
foundation/rag/vector/milvus_vector.py

@@ -0,0 +1,367 @@
+import time
+from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
+# from sentence_transformers import SentenceTransformer  # 本地向量化模型未使用,向量化由base_api_platform提供
+import numpy as np
+from typing import List, Dict, Any, Optional
+import json
+from foundation.base.config import config_handler
+from foundation.logger.loggering import server_logger as logger
+from foundation.rag.vector.base_vector import BaseVectorDB
+from foundation.models.base_online_platform import BaseApiPlatform
+
+class MilvusVectorManager(BaseVectorDB):
+    def __init__(self, base_api_platform :BaseApiPlatform):
+        """
+        初始化 Milvus 连接
+        """
+        self.base_api_platform = base_api_platform
+
+        self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
+        self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
+        self.milvus_db = config_handler.get('milvus', 'MILVUS_DB', 'default')
+        self.user = config_handler.get('milvus', 'MILVUS_USER')
+        self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')
+        
+        # 初始化文本向量化模型
+        #self.model = SentenceTransformer('all-MiniLM-L6-v2')  # 可以替换为其他模型
+        
+        # 连接到 Milvus
+        self.connect()
+    
+    def connect(self):
+        """连接到 Milvus 服务器
+        ,
+                password=self.password
+                alias="default",
+        """
+        try:
+            connections.connect(
+                alias="default",
+                host=self.host,
+                port=self.port,
+                user=self.user,
+                db_name="lq_db"
+            )
+            logger.info(f"Connected to Milvus at {self.host}:{self.port}")
+        except Exception as e:
+            logger.error(f"Failed to connect to Milvus: {e}")
+            raise
+    
+    def create_collection(self, collection_name: str, dimension: int = 768, 
+                         description: str = "Vector collection for text embeddings"):
+        """
+        创建向量集合
+        """
+        try:
+            # 检查集合是否已存在
+            if utility.has_collection(collection_name):
+                logger.info(f"Collection {collection_name} already exists")
+                utility.drop_collection(collection_name)
+                logger.info(f"Collection '{collection_name}' dropped successfully")
+                
+            
+            # 定义字段
+            fields = [
+                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
+                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
+                FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
+                FieldSchema(name="metadata", dtype=DataType.JSON),
+                FieldSchema(name="created_at", dtype=DataType.INT64)
+            ]
+            
+            # 创建集合模式
+            schema = CollectionSchema(
+                fields=fields,
+                description=description
+            )
+            
+            # 创建集合
+            collection = Collection(
+                name=collection_name,
+                schema=schema
+            )
+            
+            # 创建索引
+            index_params = {
+                "index_type": "IVF_FLAT",
+                "metric_type": "COSINE",
+                "params": {"nlist": 100}
+            }
+            
+            collection.create_index(field_name="embedding", index_params=index_params)
+            logger.info(f"Collection {collection_name} created successfully!")
+            
+        except Exception as e:
+            logger.error(f"Error creating collection: {e}")
+            raise
+    
+    
+    
+
+    def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
+        """
+        插入单个文本及其向量
+        """
+        try:
+            collection_name = param.get('collection_name')
+            text = document.get('content')
+            metadata = document.get('metadata')
+            collection = Collection(collection_name)
+            created_at = None
+            
+            # 转换文本为向量
+
+            embedding = self.text_to_vector(text)
+            #logger.info(f"Text converted to embedding:{isinstance(embedding, list)} ,{len(embedding)}")
+            #logger.info(f"Text converted to embedding:{embedding}")
+            # 准备数据
+            data = [
+                [embedding],  # embedding
+                [text],  # text_content
+                [metadata or {}],  # metadata
+                [created_at or int(time.time())]  # created_at
+            ]
+            logger.info(f"Preparing to insert text_contents:{len(data[0])} ,{len(data[1])},{len(data[2])},{len(data[3])}")
+            
+
+            # 插入数据
+            insert_result = collection.insert(data)
+            collection.flush()  # 确保数据被写入
+            
+            logger.info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
+            return insert_result.primary_keys[0]
+            
+        except Exception as e:
+            logger.error(f"Error inserting text: {e}")
+            return None
+    
+
+
+    def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
+        """
+        批量插入文本
+        texts: [{'text': '...', 'metadata': {...}}, ...]
+        """
+        try:
+            collection_name = param.get('collection_name')
+            collection = Collection(collection_name)
+            
+            text_contents = []
+            embeddings = []
+            metadatas = []
+            timestamps = []
+            
+            for item in documents:
+                text = item['content']
+                metadata = item.get('metadata', {})
+                
+                # 转换文本为向量
+                embedding = self.text_to_vector(text)
+                
+                text_contents.append(text)
+                embeddings.append(embedding)
+                metadatas.append(metadata)
+                timestamps.append(int(time.time()))
+            
+            
+            # 准备批量数据
+            data = [embeddings, text_contents, metadatas, timestamps]
+            #logger.info(f"Preparing to insert text_contents:{len(text_contents)} ,{len(embeddings)},{len(metadatas)},{len(timestamps)}")
+            
+            # 批量插入
+            insert_result = collection.insert(data)
+            collection.flush()  # 确保数据被写入
+            
+            logger.info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
+            return insert_result.primary_keys
+            
+        except Exception as e:
+            logger.error(f"Error batch inserting: {e}")
+            return None
+    
+
+
+
+    def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
+                           top_k=5, filters: Dict[str, Any] = None):
+        """
+        搜索相似文本
+        """
+        try:
+            collection_name = param.get('collection_name')
+            collection = Collection(collection_name)
+            
+            # 加载集合到内存(如果还没有加载)
+            collection.load()
+            
+            # 转换查询文本为向量
+            query_embedding = self.text_to_vector(query_text)
+            
+            # 搜索参数
+            search_params = {
+                "metric_type": "COSINE",
+                "params": {"nprobe": 10}
+            }
+             # 构建过滤表达式
+            filter_expr = self._create_filter(filters)
+            
+            # 执行搜索
+            results = collection.search(
+                data=[query_embedding],
+                anns_field="vector",
+                param=search_params,
+                limit=top_k,
+                expr=filter_expr,
+                output_fields=["text_content", "metadata"]
+            )
+            
+            # 格式化结果(COSINE度量下,Milvus返回的distance即相似度,越大越相似)
+            formatted_results = []
+            for hits in results:
+                for hit in hits:
+                    similarity = hit.distance
+                    # 过滤低于min_score的结果
+                    if similarity < min_score:
+                        continue
+                    formatted_results.append({
+                        'id': hit.id,
+                        'text_content': hit.entity.get('text_content'),
+                        'metadata': hit.entity.get('metadata'),
+                        'distance': hit.distance,
+                        'similarity': similarity
+                    })
+            
+            return formatted_results
+            
+        except Exception as e:
+            logger.error(f"Error searching: {e}")
+            return []
+    
+    def retriever(self, param: Dict[str, Any], query_text: str, 
+                          top_k: int = 5, filters: Dict[str, Any] = None):
+        """
+        带过滤条件的相似搜索
+        """
+        try:
+            collection_name = param.get('collection_name')
+            collection = Collection(collection_name)
+            collection.load()
+            
+            query_embedding = self.text_to_vector(query_text)
+            
+            # 构建过滤表达式
+            filter_expr = self._create_filter(filters)
+            
+            search_params = {
+                "metric_type": "COSINE",
+                "params": {"nprobe": 10}
+            }
+            
+            results = collection.search(
+                data=[query_embedding],
+                anns_field="vector",
+                param=search_params,
+                limit=top_k,
+                expr=filter_expr or None,  # some pymilvus versions reject an empty expression string
+                output_fields=["text_content", "metadata"]
+            )
+            
+            formatted_results = []
+            for hits in results:
+                for hit in hits:
+                    formatted_results.append({
+                        'id': hit.id,
+                        'text_content': hit.entity.get('text_content'),
+                        'metadata': hit.entity.get('metadata'),
+                        'distance': hit.distance,
+                        'similarity': hit.distance  # COSINE metric returns similarity directly
+                    })
+            
+            return formatted_results
+            
+        except Exception as e:
+            logger.error(f"Error searching with filter: {e}")
+            return []
+    
+    
+    def _create_filter(self, filters: Dict[str, Any]) -> str:
+        """
+        创建过滤条件
+        """
+        # 构建过滤表达式
+        filter_expr = ""
+        if filters:
+            conditions = []
+            for key, value in filters.items():
+                if isinstance(value, str):
+                    conditions.append(f'metadata["{key}"] == "{value}"')
+                elif isinstance(value, (int, float)):
+                    conditions.append(f'metadata["{key}"] == {value}')
+                else:
+                    # JSON-encoded values contain double quotes, so wrap them in single quotes
+                    conditions.append(f'metadata["{key}"] == \'{json.dumps(value)}\'')
+            filter_expr = " and ".join(conditions)
+        
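+        # Example (hypothetical filters): {'category': 'AI', 'year': 2024}
+        # yields: metadata["category"] == "AI" and metadata["year"] == 2024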
+        return filter_expr
+
+    def db_test(self, query_text):
+        query = query_text
+        # Initialize the client (requires the SILICONFLOW_API_KEY environment variable)
+        from foundation.models.silicon_flow import SiliconFlowAPI
+        client = SiliconFlowAPI()
+        # Initialize the Milvus manager
+        milvus_manager = MilvusVectorManager(base_api_platform=client)
+
+        # Create the collection
+        collection_name = 'text_embeddings'
+        milvus_manager.create_collection(collection_name, dimension=768)
+        
+        param = {"collection_name": collection_name}
+        
+        # 插入单个文本
+        sample_text = "这是一个关于人工智能的文档。"
+        milvus_manager.add_document(
+            param, 
+            {"content":sample_text , "metadata": {'category': 'AI', 'source': 'example'}}
+        )
+        
+        # 批量插入文本
+        sample_texts = [
+            {
+                'content': '机器学习是人工智能的一个重要分支。',
+                'metadata': {'category': 'ML', 'author': 'John'}
+            },
+            {
+                'content': '深度学习在图像识别领域取得了显著成果。',
+                'metadata': {'category': 'Deep Learning', 'author': 'Jane'}
+            },
+            {
+                'content': '自然语言处理技术在聊天机器人中得到广泛应用。',
+                'metadata': {'category': 'NLP', 'author': 'Bob'}
+            }
+            ,
+            {
+                'content': 'AI发展速度快,但需要更多的计算资源。',
+                'metadata': {'category': 'AI', 'author': 'Bob'}
+            }
+        ]
+
+        milvus_manager.add_batch_documents(param=param, documents=sample_texts)
+        
+        # Search for similar texts using the caller-supplied query
+        similar_docs = milvus_manager.similarity_search(param, query, top_k=5)
+        
+        logger.info(f"Similar documents found-{len(similar_docs)}:")
+        for doc in similar_docs:
+            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")
+        logger.info(f"{'=' *20}")
+        # 带过滤条件的搜索
+        filtered_docs = milvus_manager.retriever(
+            param, 
+            query, 
+            top_k=5, 
+            filters={'category': 'AI'}
+        )
+        
+        logger.info(f"\nFiltered similar documents-{len(filtered_docs)}:")
+        for doc in filtered_docs:
+            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")
+

+ 269 - 0
foundation/rag/vector/pg_vector.py

@@ -0,0 +1,269 @@
+
+import psycopg2
+from psycopg2.extras import RealDictCursor
+import numpy as np
+#from sentence_transformers import SentenceTransformer
+import json
+from typing import List, Dict, Any
+from foundation.base.config import config_handler
+from foundation.logger.loggering import server_logger as logger
+from foundation.rag.vector.base_vector import BaseVectorDB
+from foundation.models.base_online_platform import BaseApiPlatform
+
+class PGVectorDB(BaseVectorDB):
+    def __init__(self , base_api_platform :BaseApiPlatform):
+        """
+        初始化 pgvector 连接
+        """
+        self.connection_params = {
+            'host': config_handler.get('pgvector', 'PGVECTOR_HOST', 'localhost'),
+            'port': int(config_handler.get('pgvector', 'PGVECTOR_PORT', '5432')),
+            'database': config_handler.get('pgvector', 'PGVECTOR_DB', 'postgres'),
+            'user': config_handler.get('pgvector', 'PGVECTOR_USER', 'postgres'),
+            'password': config_handler.get('pgvector', 'PGVECTOR_PASSWORD', 'postgres')
+        }
+
+        self.base_api_platform = base_api_platform
+        
+        
+    def get_connection(self):
+        """获取数据库连接"""
+        #logger.info(f"Connecting to PostgreSQL...{self.connection_params}")
+        conn = psycopg2.connect(**self.connection_params)
+        # 启用 pgvector 扩展
+        with conn.cursor() as cur:
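+            # IF NOT EXISTS makes this a cheap no-op on later connections; the
+            # first CREATE EXTENSION requires sufficient database privileges.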
+            cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
+        conn.commit()
+        return conn
+    
+    def create_table(self, table_name: str, vector_dim: int = 384):
+        """
+        创建向量表
+        """
+        conn = self.get_connection()
+        try:
+            with conn.cursor() as cur:
+                # 创建表
+                create_table_sql = f"""
+                CREATE TABLE IF NOT EXISTS {table_name} (
+                    id SERIAL PRIMARY KEY,
+                    text_content TEXT,
+                    embedding vector({vector_dim}),
+                    metadata JSONB DEFAULT '{{}}'::jsonb,
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                );
+                
+                -- 创建向量相似度索引
+                CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding 
+                ON {table_name} USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
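+                -- lists=100 trades recall for speed; pgvector's rule of thumb
+                -- is roughly rows/1000 for tables up to ~1M rows.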
+                """
+                cur.execute(create_table_sql)
+                conn.commit()
+                print(f"Table {table_name} created successfully!")
+        except Exception as e:
+            logger.error(f"Error creating table: {e}")
+            conn.rollback()
+        finally:
+            conn.close()
+    
+
+    def document_standard(self, documents):
+        """
+        Normalize Document-like objects (with .page_content / .metadata)
+        into the {'content', 'metadata'} dict format used by this class
+        """
+        result = []
+        for doc in documents:
+            tmp = {}
+            tmp['content'] = doc.page_content
+            tmp['metadata'] = doc.metadata if doc.metadata else {}
+            result.append(tmp)
+        return result
+
+    def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
+        """
+        插入单个文本及其向量
+        """
+        table_name = param.get('table_name')
+        text = document.get('content')
+        metadata = document.get('metadata')
+
+        conn = self.get_connection()
+        try:
+            with conn.cursor() as cur:
+                embedding = self.text_to_vector(text)
+                metadata = metadata or {}
+                
+                insert_sql = f"""
+                INSERT INTO {table_name} (text_content, embedding, metadata)
+                VALUES (%s, %s::vector, %s)
+                RETURNING id;
+                """
+                # Pass the embedding in pgvector's text format ('[0.1, 0.2, ...]')
+                # and cast explicitly (assumes text_to_vector returns a plain list)
+                cur.execute(insert_sql, (text, str(embedding), json.dumps(metadata)))
+                inserted_id = cur.fetchone()[0]
+                conn.commit()
+                print(f"Text inserted with ID: {inserted_id}")
+                return inserted_id
+        except Exception as e:
+            print(f"Error inserting text: {e}")
+            conn.rollback()
+            return None
+        finally:
+            conn.close()
+    
+    def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
+        """
+        批量插入文本
+        texts: [{'text': '...', 'metadata': {...}}, ...]
+        """
+        table_name = param.get('table_name')
+        conn = self.get_connection()
+        try:
+            with conn.cursor() as cur:
+                # 准备数据
+                data_to_insert = []
+                for item in documents:
+                    text = item['content']
+                    metadata = item.get('metadata', {})
+                    embedding = self.text_to_vector(text)
+                    # Cast via pgvector's text format, as in add_document
+                    data_to_insert.append((text, str(embedding), json.dumps(metadata)))
+                
+                # 批量插入
+                insert_sql = f"""
+                INSERT INTO {table_name} (text_content, embedding, metadata)
+                VALUES (%s, %s::vector, %s)
+                """
+                cur.executemany(insert_sql, data_to_insert)
+                conn.commit()
+                logger.info(f"Batch inserted {len(data_to_insert)} records")
+        except Exception as e:
+            logger.error(f"Error batch inserting: {e}")
+            conn.rollback()
+        finally:
+            conn.close()
+    
+    def similarity_search(self, param: Dict[str, Any], query_text: str, min_score=0.5,
+                          top_k=5, filters: Dict[str, Any] = None):
+        """
+        Search for similar texts, ordered by cosine distance (smaller = more similar).
+        min_score and filters are accepted for interface compatibility but not yet applied.
+        """
+        table_name = param.get('table_name')
+        conn = self.get_connection()
+        try:
+            with conn.cursor(cursor_factory=RealDictCursor) as cur:
+                # str() yields pgvector's text format, which avoids relying on
+                # array-to-vector cast support (assumes a plain list return type)
+                query_embedding = str(self.text_to_vector(query_text))
+                
+                search_sql = f"""
+                SELECT id, text_content, metadata, 
+                       embedding <=> %s::vector AS distance
+                FROM {table_name}
+                ORDER BY embedding <=> %s::vector
+                LIMIT %s;
+                """
+                cur.execute(search_sql, (query_embedding, query_embedding, top_k))
+                results = cur.fetchall()
+                
+                return results
+        except Exception as e:
+            logger.error(f"Error searching: {e}")
+            return []
+        finally:
+            conn.close()
+
+    
+    def retriever(self, param: Dict[str, Any], query_text: str, min_score=0.1,
+                  top_k=10, filters: Dict[str, Any] = None):
+        """
+        Search for similar texts by cosine similarity, keeping hits above min_score
+        """
+        table_name = param.get('table_name')
+        conn = self.get_connection()
+        try:
+            with conn.cursor(cursor_factory=RealDictCursor) as cur:
+                # As above: pass the embedding in pgvector's text format
+                query_embedding = str(self.text_to_vector(query_text))
+                
+                search_sql = f"""
+                SELECT id, text_content, metadata,
+                       1 - (embedding <=> %s::vector) AS cosine_similarity
+                FROM {table_name}
+                WHERE 1 - (embedding <=> %s::vector) > %s
+                ORDER BY 1 - (embedding <=> %s::vector) DESC
+                LIMIT %s;
+                """
+                cur.execute(search_sql, (query_embedding, query_embedding, min_score, query_embedding, top_k))
+                results = cur.fetchall()
+                # 打印结果
+                self.result_logger_info(query_text , results)
+                return results
+        except Exception as e:
+            logger.error(f"Error searching with cosine similarity: {e}")
+            return []
+        finally:
+            conn.close()
+
+    
+    def result_logger_info(self, query, result_docs_cos):
+        """
+        Log the search results
+        """
+        logger.info(f"\n {'=' * 50}")
+        # 使用余弦相似度搜索
+        logger.info(f"\nSimilar documents with cosine similarity,query:{query},result_count: {len(result_docs_cos)}:")
+        for doc in result_docs_cos:
+            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")
+
+
+
+    def db_test(self, query_text: str):
+        """
+        Smoke-test database connectivity and the main operations
+        """
+        table_name = 'test_documents'
+        # 创建表
+        self.create_table(table_name, vector_dim=768)
+        
+        # 插入单个文本
+        sample_text = "这是一个关于人工智能的文档。"
+        #self.add_document({'table_name': table_name}, {'content': sample_text, 'metadata': {'category': 'AI', 'source': 'example'}})
+        
+        # 批量插入文本
+        # add_batch_documents expects a 'content' key for each item
+        sample_texts = [
+            {
+                'content': '机器学习是人工智能的一个重要分支。',
+                'metadata': {'category': 'ML', 'author': 'John'}
+            },
+            {
+                'content': '深度学习在图像识别领域取得了显著成果。',
+                'metadata': {'category': 'Deep Learning', 'author': 'Jane'}
+            },
+            {
+                'content': '自然语言处理技术在聊天机器人中得到广泛应用。',
+                'metadata': {'category': 'NLP', 'author': 'Bob'}
+            }
+        ]
+        
+        #self.add_batch_documents({'table_name': table_name}, sample_texts)
+        
+
+        logger.info(f"\n {'=' * 50}")
+        # 搜索相似文本
+        #query = "人工智能相关的技术"
+        query = query_text
+        logger.info(f"\n query={query}")
+
+        similar_docs = self.similarity_search({'table_name': table_name}, query, top_k=3)
+        logger.info(f"Similar documents found {len(similar_docs)}:")
+        for doc in similar_docs:
+            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {1 - doc['distance']:.3f}")
+        
+        logger.info(f"\n {'=' * 50}")
+        # 使用余弦相似度搜索
+        similar_docs_cos = self.retriever({'table_name': table_name}, query, top_k=3)
+        
+        logger.info(f"\nSimilar documents with cosine similarity {len(similar_docs_cos)}:")
+        for doc in similar_docs_cos:
+            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['cosine_similarity']:.3f}")
+
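+
+# Minimal usage sketch (hypothetical wiring; assumes a reachable PostgreSQL
+# instance with the pgvector extension and a configured embedding platform):
+#
+#   from foundation.models.silicon_flow import SiliconFlowAPI
+#   db = PGVectorDB(base_api_platform=SiliconFlowAPI())
+#   db.create_table('docs', vector_dim=768)
+#   db.add_document({'table_name': 'docs'},
+#                   {'content': 'sample text', 'metadata': {'category': 'demo'}})
+#   hits = db.retriever({'table_name': 'docs'}, 'query text', top_k=5)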

+ 2 - 2
foundation/trace/celery_trace.py

@@ -83,9 +83,9 @@ class CeleryTraceManager:
 
         # 将trace_id添加到任务参数中
         if current_trace_id and current_trace_id != 'no-trace':
-            kwargs['_system_trace_id'] = current_trace_id
+            kwargs['_system_trace_id'] = current_trace_id   
 
-        logger.info(f"提交Celery任务,trace_id: {current_trace_id}")
+        logger.info(f"提交Celery任务")
 
         # 提交任务
         return task_func.delay(*args, **kwargs)

+ 1 - 0
foundation/utils/redis_utils.py

@@ -37,6 +37,7 @@ async def get_redis_result_cache_data(data_type: str , trace_id: str):
      # 直接获取 RedisStore
     redis_store = await RedisConnectionFactory.get_redis_store()
     value = await redis_store.get(key) 
+    # Redis returns bytes (or None on a cache miss), so decode defensively
+    if value is not None:
+        value = value.decode('utf-8')
     return value
 
 

+ 5 - 2
foundation/utils/tool_utils.py

@@ -9,7 +9,7 @@ from foundation.logger.loggering import server_logger
 from foundation.utils.common import handler_err
 from foundation.base.config import config_handler
 import json
-from datetime import datetime
+from datetime import datetime, date
 # 获取当前文件的目录
 current_dir = os.path.dirname(__file__)
 # 构建到 .env 的相对路径
@@ -44,8 +44,11 @@ def get_system_prompt() -> str:
 
 
 
+
 class DateTimeEncoder(json.JSONEncoder):
     def default(self, obj):
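+        # Order matters: datetime must be checked before date, since datetime subclasses date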
         if isinstance(obj, datetime):
-            return obj.isoformat()  # 转换为 ISO 8601 格式字符串
+            return obj.strftime('%Y-%m-%d %H:%M:%S')
+        elif isinstance(obj, date):  # also handle plain date values
+            return obj.strftime('%Y-%m-%d')
         return super().default(obj)
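+
+# Usage sketch: json.dumps(payload, cls=DateTimeEncoder) renders datetime values
+# as 'YYYY-MM-DD HH:MM:SS' and date values as 'YYYY-MM-DD'.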

+ 3 - 1
requirements.txt

@@ -121,7 +121,6 @@ pydantic-settings==2.10.1
 pydantic_core==2.33.2
 Pygments==2.19.2
 PyJWT==2.8.0
-pymilvus==2.5.12
 PyMuPDF==1.26.3
 PyMySQL==1.1.1
 pyperclip==1.9.0
@@ -182,4 +181,7 @@ langgraph-checkpoint-postgres==2.0.23
 langgraph-checkpoint-redis==0.0.8
 langchain-redis==0.2.3
 aiomysql==0.3.2
+celery==5.5.3
+pypdf==6.2.0
+grandalf==0.8
 

+ 1 - 1
run.sh

@@ -1,7 +1,7 @@
 #!/bin/bash
 
 # 服务管理脚本
-APP_NAME="xiwu_agent_server"         # 自定义服务名称
+APP_NAME="lq_agent_platform_server"         # 自定义服务名称
 PID_FILE="./gunicorn_log/gunicorn.pid"          # PID 文件路径
 LOG_FILE="./gunicorn_log/gunicorn.log"          # 日志文件路径
 START_COMMAND="gunicorn -c gunicorn_config.py server.app:app"

Diff file is too large
+ 49 - 27
temp/AI审查结果.json


+ 0 - 281
test/construction_review/api_test_client.py

@@ -1,281 +0,0 @@
-"""
-施工方案审查API测试客户端
-用于测试Mock接口和前端联调
-"""
-
-import requests
-import json
-import time
-import uuid
-from pathlib import Path
-from typing import Optional, Dict, Any
-
-class ConstructionReviewAPIClient:
-    """施工方案审查API客户端"""
-
-    def __init__(self, base_url: str = "http://127.0.0.1:8034", api_key: Optional[str] = None):
-        self.base_url = base_url.rstrip('/')
-        self.api_key = api_key
-        self.session = requests.Session()
-
-        if api_key:
-            self.session.headers.update({
-                'Authorization': f'Bearer {api_key}'
-            })
-
-    def upload_file(self, file_path: str, project_plan_type: str, user: str,
-                   callback_url: Optional[str] = None) -> Dict[str, Any]:
-        """
-        上传文件
-
-        Args:
-            file_path: 文件路径
-            project_plan_type: 工程方案类型
-            user: 用户标识
-            callback_url: 回调URL(可选)
-
-        Returns:
-            上传响应结果
-        """
-        url = f"{self.base_url}/sgsc/file_upload"
-
-        if not Path(file_path).exists():
-            raise FileNotFoundError(f"文件不存在: {file_path}")
-
-        with open(file_path, 'rb') as f:
-            files = {'file': f}
-            data = {
-                'project_plan_type': project_plan_type,
-                'user': user
-            }
-
-            if callback_url:
-                data['callback_url'] = callback_url
-
-            response = self.session.post(url, files=files, data=data)
-            response.raise_for_status()
-            return response.json()
-
-    def get_task_progress(self, callback_task_id: str, user: str) -> Dict[str, Any]:
-        """
-        查询任务进度
-
-        Args:
-            callback_task_id: 任务ID
-            user: 用户标识
-
-        Returns:
-            进度查询结果
-        """
-        url = f"{self.base_url}/sgsc/task_progress/{callback_task_id}"
-        params = {'user': user}
-
-        response = self.session.get(url, params=params)
-        response.raise_for_status()
-        return response.json()
-
-    def get_review_results(self, file_id: str, user: str, result_type: str) -> Dict[str, Any]:
-        """
-        获取审查结果
-
-        Args:
-            file_id: 文件ID
-            user: 用户标识
-            result_type: 结果类型 ("summary" 或 "issues")
-
-        Returns:
-            审查结果
-        """
-        url = f"{self.base_url}/sgsc/review_results"
-        data = {
-            'id': file_id,
-            'user': user,
-            'type': result_type
-        }
-
-        response = self.session.post(url, json=data)
-        response.raise_for_status()
-        return response.json()
-
-    def wait_for_completion(self, callback_task_id: str, user: str,
-                          timeout: int = 1800, interval: int = 10) -> Dict[str, Any]:
-        """
-        等待任务完成
-
-        Args:
-            callback_task_id: 任务ID
-            user: 用户标识
-            timeout: 超时时间(秒)
-            interval: 轮询间隔(秒)
-
-        Returns:
-            最终任务状态
-        """
-        start_time = time.time()
-
-        while time.time() - start_time < timeout:
-            try:
-                result = self.get_task_progress(callback_task_id, user)
-
-                if result['data']['review_task_status'] == 'completed':
-                    print(f"任务完成! 总进度: {result['data']['overall_progress']}%")
-                    return result
-                else:
-                    progress = result['data']['overall_progress']
-                    print(f"任务进行中... 进度: {progress}%")
-                    time.sleep(interval)
-
-            except Exception as e:
-                print(f"查询进度失败: {e}")
-                time.sleep(interval)
-
-        raise TimeoutError(f"任务超时,等待时间超过 {timeout} 秒")
-
-class MockTestRunner:
-    """Mock测试运行器"""
-
-    def __init__(self, client: ConstructionReviewAPIClient):
-        self.client = client
-
-    def test_file_upload(self, file_path: str = None) -> Dict[str, Any]:
-        """测试文件上传"""
-        print("=== 测试文件上传 ===")
-
-        # 创建测试文件(如果没有提供文件路径)
-        if not file_path:
-            test_file = Path(r"D:\wx_work\sichuan_luqiao\LQAgentPlatform\data_pipeline\test_rawdata\1f3e1d98-5b4a-4a06-87b3-c7f0413b901a.pdf")
-            if not test_file.exists():
-                # 创建一个简单的测试PDF文件内容
-                test_file.write_bytes(b"%PDF-1.4\n%Mock PDF for testing\n")
-            file_path = str(test_file)
-
-        try:
-            result = self.client.upload_file(
-                file_path=file_path,
-                project_plan_type="bridge_up_part",
-                user=str(uuid.uuid4()),
-                callback_url="https://client.example.com/callback"
-            )
-
-            print(f"✅ 文件上传成功")
-            print(f"文件ID: {result['data']['id']}")
-            print(f"任务ID: {result['data']['callback_task_id']}")
-
-            return result
-
-        except Exception as e:
-            print(f"❌ 文件上传失败: {e}")
-            raise
-
-    def test_progress_query(self, callback_task_id: str, user: str) -> None:
-        """测试进度查询"""
-        print("\n=== 测试进度查询 ===")
-
-        try:
-            result = self.client.get_task_progress(callback_task_id, user)
-
-            print(f"✅ 进度查询成功")
-            print(f"任务状态: {result['data']['review_task_status']}")
-            print(f"总进度: {result['data']['overall_progress']}%")
-
-            for stage in result['data']['stages']:
-                print(f"  - {stage['stage_name']}: {stage['progress']}% ({stage['stage_status']})")
-
-        except Exception as e:
-            print(f"❌ 进度查询失败: {e}")
-            raise
-
-    def test_review_results(self, file_id: str, user: str) -> None:
-        """测试审查结果获取"""
-        print("\n=== 测试审查结果获取 ===")
-
-        # 测试获取总结报告
-        try:
-            result = self.client.get_review_results(file_id, user, "summary")
-
-            print(f"✅ 总结报告获取成功")
-            print(f"风险统计: {result['data']['risk_stats']}")
-            print(f"四维评分: {result['data']['dimension_scores']}")
-            print(f"总结报告: {result['data']['summary_report']}")
-
-        except Exception as e:
-            print(f"❌ 总结报告获取失败: {e}")
-
-        # 测试获取问题条文
-        try:
-            result = self.client.get_review_results(file_id, user, "issues")
-
-            print(f"\n✅ 问题条文获取成功")
-            issues = result['data']['issues']
-            print(f"发现问题数量: {len(issues)}")
-
-            for i, issue in enumerate(issues):
-                print(f"\n问题 {i+1}:")
-                print(f"  ID: {issue['issue_id']}")
-                print(f"  页码: {issue['metadata']['page']}")
-                print(f"  章节: {issue['metadata']['chapter']}")
-                print(f"  风险等级: {issue['risk_summary']['max_risk_level']}")
-                print(f"  检查项数量: {len(issue['review_lists'])}")
-
-        except Exception as e:
-            print(f"❌ 问题条文获取失败: {e}")
-
-    def run_complete_test(self) -> None:
-        """运行完整测试流程"""
-        print("开始施工方案审查API完整测试...")
-
-        try:
-            # 1. 上传文件
-            upload_result = self.test_file_upload()
-            file_id = upload_result['data']['id']
-            callback_task_id = upload_result['data']['callback_task_id']
-            user = str(uuid.uuid4())  # 实际应该从上传响应中获取,这里简化
-
-            # 2. 查询进度(等待一段时间让任务完成)
-            print("\n等待任务完成...")
-            time.sleep(2)  # 短暂等待
-
-            # 先测试进度查询
-            self.test_progress_query(callback_task_id, user)
-
-            # 3. 获取审查结果(可能需要等待任务完成)
-            print("\n获取审查结果...")
-
-            # 如果任务还未完成,直接标记完成(仅用于Mock测试)
-            try:
-                self.test_review_results(file_id, user)
-            except Exception as e:
-                print(f"审查结果获取失败,尝试完成任务: {e}")
-
-                # 完成任务(Mock功能)
-                response = requests.post(f"{self.client.base_url}/sgsc/mock/complete_task",
-                                       data={"callback_task_id": callback_task_id})
-                print("任务已强制完成,重新获取结果...")
-
-                self.test_review_results(file_id, user)
-
-            print("\n🎉 完整测试流程执行成功!")
-
-        except Exception as e:
-            print(f"\n❌ 测试失败: {e}")
-            raise
-
-def main():
-    """主函数 - 运行测试"""
-    print("施工方案审查API Mock测试客户端")
-    print("=" * 50)
-
-    # 创建客户端
-    client = ConstructionReviewAPIClient(
-        base_url="http://127.0.0.1:8034",
-        api_key="mock-api-key-12345"
-    )
-
-    # 创建测试运行器
-    test_runner = MockTestRunner(client)
-
-    # 运行完整测试
-    test_runner.run_complete_test()
-
-if __name__ == "__main__":
-    main()

+ 0 - 370
test/construction_review/test_error_codes_pytest.py

@@ -1,370 +0,0 @@
-"""
-施工方案审查API错误码测试 - pytest版本
-使用pytest运行的标准测试套件
-"""
-
-import pytest
-import requests
-import json
-import uuid
-import time
-import os
-from typing import Dict, Any
-
-# pytest fixtures
-@pytest.fixture(scope="class")
-def api_config():
-    """API配置fixture"""
-    return {
-        "base_url": "http://127.0.0.1:8034",
-        "api_prefix": "/sgsc",
-        "valid_user": "user-001",
-        "valid_project_type": "bridge_up_part",
-        "test_callback_url": "http://test.callback.com"
-    }
-
-@pytest.fixture
-def test_file():
-    """测试文件fixture - 每个测试都创建新的文件对象"""
-    file_path = "data_pipeline/test_rawdata/1f3e1d98-5b4a-4a06-87b3-c7f0413b901a.pdf"
-
-    class TestFile:
-        def __init__(self):
-            if os.path.exists(file_path):
-                self.file = open(file_path, 'rb')
-                self.file_tuple = (os.path.basename(file_path), self.file, 'application/pdf')
-                self.close_file = True
-            else:
-                self.file = None
-                self.file_tuple = ("test.pdf", b"mock pdf content", "application/pdf")
-                self.close_file = False
-
-        def get_file(self):
-            """获取文件元组"""
-            if self.close_file and self.file:
-                # 重新打开文件,确保文件未被关闭
-                self.file.seek(0)
-            return self.file_tuple
-
-        def cleanup(self):
-            """清理资源"""
-            if self.close_file and self.file:
-                self.file.close()
-
-    test_file_obj = TestFile()
-    yield test_file_obj
-    test_file_obj.cleanup()
-
-class TestFileUploadErrors:
-    """文件上传接口错误码测试"""
-
-    @pytest.mark.parametrize("test_case,expected_code", [
-        ("missing_file", "WJSC001"),
-        ("empty_file", "WJSC003"),
-        ("unsupported_format", "WJSC004"),
-        ("invalid_project_type", "WJSC006")
-    ])
-    def test_file_upload_errors(self, api_config, test_case, expected_code):
-        """测试文件上传各种错误场景"""
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/file_upload"
-
-        if test_case == "missing_file":
-            # 不上传文件
-            data = {
-                "callback_url": api_config["test_callback_url"],
-                "project_plan_type": api_config["valid_project_type"],
-                "user": api_config["valid_user"]
-            }
-            response = requests.post(url, data=data)
-
-        elif test_case == "empty_file":
-            # 上传空文件
-            files = {"file": ("empty.pdf", b"", "application/pdf")}
-            data = {
-                "callback_url": api_config["test_callback_url"],
-                "project_plan_type": api_config["valid_project_type"],
-                "user": api_config["valid_user"]
-            }
-            response = requests.post(url, files=files, data=data)
-
-        elif test_case == "unsupported_format":
-            # 上传不支持的格式
-            files = {"file": ("test.txt", b"text content", "text/plain")}
-            data = {
-                "callback_url": api_config["test_callback_url"],
-                "project_plan_type": api_config["valid_project_type"],
-                "user": api_config["valid_user"]
-            }
-            response = requests.post(url, files=files, data=data)
-
-        elif test_case == "invalid_project_type":
-            # 无效的工程方案类型
-            files = {"file": ("test.pdf", b"mock pdf content", "application/pdf")}
-            data = {
-                "callback_url": api_config["test_callback_url"],
-                "project_plan_type": "invalid_type",
-                "user": api_config["valid_user"]
-            }
-            response = requests.post(url, files=files, data=data)
-
-        # 验证错误响应
-        assert response.status_code in [400, 403, 404]  # 允许的业务错误状态码
-
-        try:
-            error_data = response.json()
-            assert error_data["code"] == expected_code
-            assert "error_type" in error_data
-            assert "message" in error_data
-        except json.JSONDecodeError:
-            pytest.fail(f"响应不是有效的JSON: {response.text}")
-
-    def test_file_upload_success(self, api_config, test_file):
-        """测试文件上传成功"""
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/file_upload"
-        files = {"file": test_file.get_file()}
-        data = {
-            "callback_url": api_config["test_callback_url"],
-            "project_plan_type": api_config["valid_project_type"],
-            "user": api_config["valid_user"]
-        }
-
-        response = requests.post(url, files=files, data=data)
-        assert response.status_code == 200
-
-        result = response.json()
-        assert "data" in result
-        assert "callback_task_id" in result["data"]
-        assert "id" in result["data"]
-
-    @pytest.mark.skip(reason="文件大小检查未实现")
-    def test_wjsc005_file_size_exceeded(self, api_config):
-        """测试WJSC005: 文件过大 - 跳过因为未实现"""
-        pass
-
-    @pytest.mark.skip(reason="认证检查未实现")
-    def test_wjsc007_unauthorized(self, api_config):
-        """测试WJSC007: 认证失败 - 跳过因为未实现"""
-        pass
-
-
-class TestTaskProgressErrors:
-    """进度查询接口错误码测试"""
-
-    def test_jdlx001_missing_parameters(self, api_config):
-        """测试JDLX001: 请求参数缺失"""
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/task_progress/test-callback-id"
-        response = requests.get(url)  # 不提供user参数
-
-        assert response.status_code == 400
-        error_data = response.json()
-        assert error_data["code"] == "JDLX001"
-
-    @pytest.mark.parametrize("invalid_id", ["short", "123", "invalid-format"])
-    def test_jdlx002_invalid_param_format(self, api_config, invalid_id):
-        """测试JDLX002: 请求参数格式错误"""
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/task_progress/{invalid_id}"
-        params = {"user": api_config["valid_user"]}
-
-        response = requests.get(url, params=params)
-        assert response.status_code == 400
-
-        error_data = response.json()
-        assert error_data["code"] == "JDLX002"
-
-    @pytest.mark.parametrize("invalid_user", ["invalid_user", "user-999", ""])
-    def test_jdlx004_invalid_user(self, api_config, test_file, invalid_user):
-        """测试JDLX004: 用户标识无效"""
-        # 先上传文件获取有效的callback_task_id
-        callback_task_id = self._upload_file_and_get_callback(api_config, test_file)
-
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/task_progress/{callback_task_id}"
-        params = {"user": invalid_user}
-
-        response = requests.get(url, params=params)
-        assert response.status_code == 403
-
-        error_data = response.json()
-        assert error_data["code"] == "JDLX004"
-
-    def test_jdlx005_task_not_found(self, api_config):
-        """测试JDLX005: 任务不存在"""
-        fake_callback_id = f"{uuid.uuid4()}-{int(time.time())}"
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/task_progress/{fake_callback_id}"
-        params = {"user": api_config["valid_user"]}
-
-        response = requests.get(url, params=params)
-        assert response.status_code == 404
-
-        error_data = response.json()
-        assert error_data["code"] == "JDLX005"
-
-    def _upload_file_and_get_callback(self, api_config, test_file):
-        """辅助方法:上传文件并获取callback_task_id"""
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/file_upload"
-        files = {"file": test_file.get_file()}
-        data = {
-            "callback_url": api_config["test_callback_url"],
-            "project_plan_type": api_config["valid_project_type"],
-            "user": api_config["valid_user"]
-        }
-
-        response = requests.post(url, files=files, data=data)
-        assert response.status_code == 200
-        result = response.json()
-        return result["data"]["callback_task_id"]
-
-    @pytest.mark.skip(reason="认证检查未实现")
-    def test_jdlx003_unauthorized(self, api_config):
-        """测试JDLX003: 认证失败 - 跳过因为未实现"""
-        pass
-
-
-class TestReviewResultsErrors:
-    """审查结果接口错误码测试"""
-
-    @pytest.mark.parametrize("invalid_type", ["invalid", "risk", "detail", ""])
-    def test_scjg001_invalid_type(self, api_config, invalid_type):
-        """测试SCJG001: 结果类型无效"""
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/review_results"
-        payload = {
-            "id": str(uuid.uuid4()),
-            "user": api_config["valid_user"],
-            "type": invalid_type
-        }
-
-        response = requests.post(url, json=payload)
-        assert response.status_code == 400
-
-        error_data = response.json()
-        assert error_data["code"] == "SCJG001"
-
-    @pytest.mark.parametrize("invalid_id", ["", None])
-    def test_scjg002_missing_param_id(self, api_config, invalid_id):
-        """测试SCJG002: 缺少文档ID"""
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/review_results"
-
-        if invalid_id is None:
-            payload = {
-                "user": api_config["valid_user"],
-                "type": "summary"
-            }
-        else:
-            payload = {
-                "id": invalid_id,
-                "user": api_config["valid_user"],
-                "type": "summary"
-            }
-
-        response = requests.post(url, json=payload)
-        assert response.status_code == 400
-
-        error_data = response.json()
-        assert error_data["code"] == "SCJG002"
-
-    @pytest.mark.parametrize("invalid_format", ["123", "short-id", "invalid-uuid-format"])
-    def test_scjg003_invalid_id_format(self, api_config, invalid_format):
-        """测试SCJG003: 文档ID格式错误"""
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/review_results"
-        payload = {
-            "id": invalid_format,
-            "user": api_config["valid_user"],
-            "type": "summary"
-        }
-
-        response = requests.post(url, json=payload)
-        assert response.status_code == 400
-
-        error_data = response.json()
-        assert error_data["code"] == "SCJG003"
-
-    def test_scjg005_invalid_user_review_results(self, api_config, test_file):
-        """测试SCJG005: 用户标识无效(审查结果接口)"""
-        # 先上传文件获取有效的文件ID
-        file_id = self._upload_file_and_get_file_id(api_config, test_file)
-
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/review_results"
-        payload = {
-            "id": file_id,
-            "user": "invalid_user",
-            "type": "summary"
-        }
-
-        response = requests.post(url, json=payload)
-        assert response.status_code == 403
-
-        error_data = response.json()
-        assert error_data["code"] == "SCJG005"
-
-    def test_scjg006_task_not_found_review_results(self, api_config):
-        """测试SCJG006: 任务不存在(审查结果接口)"""
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/review_results"
-        payload = {
-            "id": str(uuid.uuid4()),
-            "user": api_config["valid_user"],
-            "type": "summary"
-        }
-
-        response = requests.post(url, json=payload)
-        assert response.status_code == 404
-
-        error_data = response.json()
-        assert error_data["code"] == "SCJG006"
-
-    def _upload_file_and_get_file_id(self, api_config, test_file):
-        """辅助方法:上传文件并获取文件ID"""
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/file_upload"
-        files = {"file": test_file.get_file()}
-        data = {
-            "callback_url": api_config["test_callback_url"],
-            "project_plan_type": api_config["valid_project_type"],
-            "user": api_config["valid_user"]
-        }
-
-        response = requests.post(url, files=files, data=data)
-        assert response.status_code == 200
-        result = response.json()
-        return result["data"]["id"]
-
-    @pytest.mark.skip(reason="认证检查未实现")
-    def test_scjg004_unauthorized(self, api_config):
-        """测试SCJG004: 认证失败 - 跳过因为未实现"""
-        pass
-
-
-class TestIntegration:
-    """集成测试"""
-
-    def test_complete_workflow_success(self, api_config, test_file):
-        """测试完整工作流程成功场景"""
-        # 1. 文件上传
-        callback_task_id = self._upload_file_and_get_callback(api_config, test_file)
-        assert callback_task_id is not None
-
-        # 2. 进度查询
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/task_progress/{callback_task_id}"
-        params = {"user": api_config["valid_user"]}
-        response = requests.get(url, params=params)
-        assert response.status_code == 200
-
-    def _upload_file_and_get_callback(self, api_config, test_file):
-        """辅助方法:上传文件并获取callback_task_id"""
-        url = f"{api_config['base_url']}{api_config['api_prefix']}/file_upload"
-        files = {"file": test_file.get_file()}
-        data = {
-            "callback_url": api_config["test_callback_url"],
-            "project_plan_type": api_config["valid_project_type"],
-            "user": api_config["valid_user"]
-        }
-
-        response = requests.post(url, files=files, data=data)
-        assert response.status_code == 200
-        result = response.json()
-        return result["data"]["callback_task_id"]
-
-
-if __name__ == "__main__":
-    # 如果直接运行此文件,给出提示
-    print("请使用 pytest 运行此测试文件:")
-    print("pytest test/construction_review/test_error_codes_pytest.py -v")
-    print("或者运行所有测试:")
-    print("pytest test/ -v")

+ 5 - 4
views/__init__.py

@@ -21,10 +21,11 @@ async def lifespan(app: FastAPI):
     # 启动时加载工具
     #await mcp_server.get_mcp_tools()
     # 全局数据库连接池实例
-    async_db_pool = AsyncMySQLPool()
-    await async_db_pool.initialize()
-    app.state.async_db_pool = async_db_pool
-    server_logger.info(f"✅ MySQL数据库连接池:{app.state.async_db_pool}")
+    async_db_pool = None
+    # async_db_pool = AsyncMySQLPool()
+    # await async_db_pool.initialize()
+    # app.state.async_db_pool = async_db_pool
+    #server_logger.info(f"✅ MySQL数据库连接池:{app.state.async_db_pool}")
 
     yield
     # 关闭时清理

+ 1 - 0
views/construction_review/file_upload.py

@@ -161,6 +161,7 @@ async def file_upload(
 
         # 生成回调任务ID
         callback_task_id = f"{file_id}-{int(datetime.now().timestamp())}"
+        #callback_task_id = "d0856b13c5328e732e9c590209554b76-1763369817"            
 
         # 更新trace_id为正式的callback_task_id
         TraceContext.set_trace_id(callback_task_id)

+ 161 - 126
views/construction_review/task_progress.py

@@ -1,158 +1,193 @@
 """
-审查进度轮询接口
-支持Celery任务状态查询和进度展示
+审查进度SSE实时推送接口
 """
 
-import time
-import random
+import json
+import asyncio
+from typing import Dict
 from datetime import datetime
-from fastapi import APIRouter, HTTPException, Query
 from pydantic import BaseModel
-from typing import Optional
-from celery.result import AsyncResult
-from foundation.base.celery_app import app
+from fastapi import APIRouter, Query
+from .schemas.error_schemas import TaskProgressErrors
+from fastapi.responses import StreamingResponse
+from foundation.logger.loggering import server_logger as logger
+from foundation.trace.trace_context import TraceContext, auto_trace
+from core.base.progress_manager import ProgressManager, sse_callback_manager
 
-task_progress_router = APIRouter(prefix="/sgsc", tags=["进度轮询"])
+progress_manager = ProgressManager()
 
+task_progress_router = APIRouter(prefix="/sgsc", tags=["进度推送"])
 
-# 导入错误码定义
-from .schemas.error_schemas import TaskProgressErrors
+async def sse_progress_callback(callback_task_id: str, current_data: dict):
+    """SSE推送回调函数 - 接收进度更新并推送到客户端"""
+    await sse_manager.send_progress(callback_task_id, current_data)
 
 class TaskProgressResponse(BaseModel):
     code: int
     data: dict
 
-def update_task_progress(callback_task_id: str) -> dict:
-    """更新任务进度(模拟真实的处理过程)"""
-    if callback_task_id not in uploaded_files:
-        return None
 
-    task_info = uploaded_files[callback_task_id]
-    current_time = int(time.time())
+class SimpleSSEManager:
+    """SSE连接管理器 - 管理客户端SSE连接和消息推送"""
 
-    # 根据时间模拟进度推进
-    time_elapsed = current_time - task_info.get("updated_at", current_time)
 
-    # 定义各阶段的时间分配(总时长约30分钟)
-    stage_durations = {
-        "格式校验": 60,      # 1分钟
-        "内容提取": 900,     # 15分钟
-        "智能审查": 840      # 14分钟
-    }
+    def __init__(self):
+        self.connections: Dict[str, asyncio.Queue] = {}
 
-    total_duration = sum(stage_durations.values())
 
-    # 计算当前应该处于哪个阶段
-    accumulated_time = 0
-    overall_progress = 0
-    stages = []
+    async def connect(self, callback_task_id: str):
+        """建立SSE连接 - 创建消息队列并发送连接确认"""
+        queue = asyncio.Queue()
+        self.connections[callback_task_id] = queue
 
-    for stage_name, duration in stage_durations.items():
-        if time_elapsed > accumulated_time + duration:
-            # 阶段已完成
-            stages.append({
-                "stage_name": stage_name,
-                "progress": 100,
-                "stage_status": "completed"
-            })
-            accumulated_time += duration
-        elif time_elapsed > accumulated_time:
-            # 阶段进行中
-            stage_progress = min(100, int((time_elapsed - accumulated_time) / duration * 100))
-            stages.append({
-                "stage_name": stage_name,
-                "progress": stage_progress,
-                "stage_status": "processing"
-            })
-            accumulated_time += duration
-        else:
-            # 阶段未开始
-            stages.append({
-                "stage_name": stage_name,
-                "progress": 0,
-                "stage_status": "pending"
-            })
+        await queue.put({
+            "type": "connection_established",
+            "callback_task_id": callback_task_id,
+            "timestamp": datetime.now().isoformat()
+        })
 
-    # 计算总进度
-    overall_progress = min(100, int(time_elapsed / total_duration * 100))
-
-    # 确定任务状态
-    if overall_progress >= 100:
-        review_task_status = "completed"
-        estimated_remaining = 0
-    else:
-        review_task_status = "processing"
-        estimated_remaining = max(0, total_duration - time_elapsed)
-
-    # 更新任务信息
-    task_info.update({
-        "review_task_status": review_task_status,
-        "overall_progress": overall_progress,
-        "stages": stages,
-        "updated_at": current_time,
-        "estimated_remaining": estimated_remaining
-    })
-
-    return task_info
-
-@task_progress_router.get("/task_progress/{callback_task_id}", response_model=TaskProgressResponse)
-async def task_progress(
-    callback_task_id: str,
-    user: str = Query(None)
-):
-    """
-    任务进度轮询接口
-    """
-    try:
-        # 验证参数
-        if user is None or not isinstance(user, str):
-            raise TaskProgressErrors.missing_parameters()
+        logger.info(f"SSE连接: {callback_task_id}")
+        return queue
 
-        if not callback_task_id or not isinstance(callback_task_id, str):
-            raise TaskProgressErrors.missing_parameters()
 
-        # 检查callback_task_id格式(应该是UUID-时间戳格式)
-        if len(callback_task_id) < 20 or callback_task_id.count('-') < 4:
-            raise TaskProgressErrors.invalid_param_format()
+    async def disconnect(self, callback_task_id: str):
+        """断开SSE连接 - 清理连接队列"""
+        if callback_task_id in self.connections:
+            del self.connections[callback_task_id]
+        logger.info(f"SSE连接已断开: {callback_task_id}")
 
-        # 验证用户标识(应该是指定用户如user-001)
-        valid_users = {"user-001", "user-002", "user-003"}  # 可以配置化
-        if user == "" or user not in valid_users:
-            raise TaskProgressErrors.invalid_user()
 
-        # 检查任务是否存在
-        if callback_task_id not in uploaded_files:
-            raise TaskProgressErrors.task_not_found()
+    async def send_progress(self, callback_task_id: str, current_data: dict):
+        """发送进度更新 - 将进度数据放入队列推送给客户端"""
+        queue = self.connections.get(callback_task_id)
+        if queue:
+            await queue.put({
+                "type": "progress_update",
+                "data": current_data,
+                "timestamp": datetime.now().isoformat()
+            })
+            logger.debug(f"SSE进度已推送: {callback_task_id}")
+
+sse_manager = SimpleSSEManager()
+
+def format_sse_event(event_type: str, data: str) -> str:
+    """格式化SSE事件 - a blank line terminates each event frame"""
+    return f"event: {event_type}\ndata: {data}\n\n"
 
-        # 验证用户权限
-        task_info = uploaded_files[callback_task_id]
-        if task_info.get("user") != user:
-            raise TaskProgressErrors.invalid_user()
 
-        # 更新进度
-        updated_task = update_task_progress(callback_task_id)
-
-        return TaskProgressResponse(
-            code=200,
-            data={
-                "callback_task_id": callback_task_id,
-                "user": user,
-                "review_task_status": updated_task["review_task_status"],
-                "overall_progress": updated_task["overall_progress"],
-                "stages": updated_task["stages"],
-                "updated_at": updated_task["updated_at"]
+@task_progress_router.get("/sse/current/{callback_task_id}")
+@auto_trace("callback_task_id")
+async def sse_progress_stream(
+    callback_task_id: str,
+    user: str = Query(..., description="用户标识")
+):
+    """SSE实时进度推送接口 - 建立SSE连接并实时推送任务进度"""
+    try:
+        valid_users = {"user-001", "user-002", "user-003"}
+        if user not in valid_users:
+            raise TaskProgressErrors.invalid_user()
+        sse_callback_manager.register_callback(callback_task_id, sse_progress_callback)
+
+        queue = await sse_manager.connect(callback_task_id)
+
+        async def generate_events():
+            """生成SSE事件流 - 处理连接确认、进度推送和任务完成检测"""
+            try:
+                logger.info(f"开始SSE事件流: {callback_task_id}")
+
+                connected_data = json.dumps({
+                    "callback_task_id": callback_task_id,
+                    "message": "SSE连接已建立,等待进度更新...",
+                    "timestamp": datetime.now().isoformat()
+                }, ensure_ascii=False)
+                yield format_sse_event("connected", connected_data)
+
+                current_progress = await progress_manager.get_progress(callback_task_id)
+                if current_progress:
+                    progress_json = json.dumps(current_progress, ensure_ascii=False)
+                    yield format_sse_event("current", progress_json)
+
+                logger.debug(f"开始监听队列中的进度更新: {callback_task_id}")
+
+                while True:
+                    try:
+                        message = await queue.get()
+
+                        if message.get("type") == "progress_update":
+                            current_data = message.get("data")
+                            if current_data:
+                                logger.info(f"总流程处理进度: {current_data.get("message")}")
+
+                                progress_json = json.dumps(current_data, ensure_ascii=False)
+                                yield format_sse_event("current", progress_json)
+
+                                overall_task_status = current_data.get("overall_task_status")
+
+                                if overall_task_status in ["completed", "failed"]:
+                                    completion_data = {
+                                        "callback_task_id": callback_task_id,
+                                        "task_status": overall_task_status,
+                                        "overall_progress": current_data.get("current", 100),
+                                        "timestamp": datetime.now().isoformat(),
+                                        "message": "全部任务完成!"
+                                    }
+                                    completion_json = json.dumps(completion_data, ensure_ascii=False)
+                                    yield format_sse_event("completed", completion_json)
+
+                                    logger.info(f"全部任务完成,断开SSE连接: {callback_task_id}, 状态: {overall_task_status}")
+                                    break
+
+                        elif message.get("type") == "connection_established":
+                            pass
+
+                    except Exception as e:
+                        logger.error(f"队列消息处理异常: {callback_task_id}, {e}")
+                        break
+
+            except Exception as e:
+                logger.error(f"SSE事件流异常: {callback_task_id}, {e}")
+                error_data = json.dumps({
+                    "error": f"SSE异常: {str(e)}",
+                    "timestamp": datetime.now().isoformat()
+                }, ensure_ascii=False)
+                yield format_sse_event("error", error_data)
+
+            finally:
+                sse_callback_manager.unregister_callback(callback_task_id)
+                await sse_manager.disconnect(callback_task_id)
+                logger.debug(f"SSE流已结束: {callback_task_id}")
+
+        return StreamingResponse(
+            generate_events(),
+            media_type="text/event-stream",
+            headers={
+                "Cache-Control": "no-cache, no-store, must-revalidate",
+                "Connection": "keep-alive",
+                "Access-Control-Allow-Origin": "*",
+                "Access-Control-Allow-Headers": "Cache-Control, EventSource",
+                "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
+                "X-Accel-Buffering": "no",  
+                "X-Content-Type-Options": "nosniff"
             }
         )
 
-    except HTTPException:
-        raise
     except Exception as e:
+        logger.error(f"SSE连接失败: {callback_task_id}, {e}")
         raise TaskProgressErrors.server_internal_error(e)
 
-@task_progress_router.post("/mock/advance_time")
-async def advance_time(seconds: int = 300):
-    """Mock接口:推进时间(用于测试)"""
-    for callback_task_id in list(uploaded_files.keys()):
-        if "review_task_status" in uploaded_files[callback_task_id]:
-            uploaded_files[callback_task_id]["updated_at"] -= seconds
-    return {"message": f"时间推进了 {seconds} 秒"}
+
+@task_progress_router.get("/sse/status")
+async def get_sse_status():
+    """获取SSE连接状态 - 返回当前活跃的SSE连接信息"""
+    return {
+        "active_connections": len(sse_manager.connections),
+        "connections": list(sse_manager.connections.keys()),
+        "timestamp": datetime.now().isoformat()
+    }
+
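+
+# Client-side sketch (hypothetical host/port):
+#   curl -N "http://127.0.0.1:8034/sgsc/sse/current/<callback_task_id>?user=user-001"
+# The stream emits "connected", then "current" updates, and ends with a
+# "completed" event once the task status reaches completed/failed.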

+ 323 - 9
views/test_views.py

@@ -24,7 +24,15 @@ from views import test_router, get_operation_id
 from foundation.agent.workflow.test_workflow_graph import test_workflow_graph
 
 from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
+from database.repositories.bus_data_query import BasisOfPreparationDAO
 from foundation.utils.tool_utils import DateTimeEncoder
+from langchain_core.prompts import ChatPromptTemplate
+from foundation.utils.yaml_utils import system_prompt_config
+
+from foundation.models.silicon_flow import SiliconFlowAPI
+from foundation.rag.vector.pg_vector import PGVectorDB
+from foundation.rag.vector.milvus_vector import MilvusVectorManager
+
 
 
 @test_router.post("/generate/chat", response_model=TestForm)
@@ -43,21 +51,27 @@ async def generate_chat_endpoint(
         context = param.context
         header_info = {
         }
-        task_prompt_info = {"task_prompt": ""}
-        output = generate_model_client.get_model_generate_invoke(trace_id , task_prompt_info, 
-                                                                                 input_query, context)
+
+        # Build the ChatPromptTemplate
+        template = ChatPromptTemplate.from_messages([
+            ("system", system_prompt_config['system_prompt']),
+            ("user", input_query)
+        ])
+
+        task_prompt_info = {"task_prompt": template}
+        output = await generate_model_client.get_model_generate_invoke(trace_id=trace_id , task_prompt_info=task_prompt_info)
         # 直接执行
-        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
+        server_logger.info(trace_id=trace_id, msg=f"【result】: {output}", log_type="agent/chat")
         # 返回字典格式的响应
         return JSONResponse(
             return_json(data={"output": output}, data_type="text", trace_id=trace_id))
 
     except ValueError as err:
-        handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
         return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
 
     except Exception as err:
-        handler_err(server_logger, trace_id=trace_id, err=err, err_name="generate/stream")
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="agent/chat")
         return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
 
 
@@ -77,13 +91,18 @@ async def generate_stream_endpoint(
         context = param.context
         header_info = {
         }
-        task_prompt_info = {"task_prompt": ""}
+        # Build the ChatPromptTemplate
+        template = ChatPromptTemplate.from_messages([
+            ("system", system_prompt_config['system_prompt']),
+            ("user", input_query)
+        ])
+
+        task_prompt_info = {"task_prompt": template}
         # 创建 SSE 流式响应
         async def event_generator():
             try:
                 # 流式处理查询 trace_id, task_prompt_info: dict, input_query, context=None
-                for chunk in generate_model_client.get_model_generate_stream(trace_id , task_prompt_info, 
-                                                                                 input_query, context):
+                for chunk in generate_model_client.get_model_generate_stream(trace_id=trace_id , task_prompt_info=task_prompt_info):
                     # 发送数据块
                     yield {
                         "event": "message",
@@ -340,6 +359,46 @@ async def chat_graph_stream(param: TestForm,
 
 
 
+@test_router.post("/redis", response_model=TestForm)
+async def test_redis(
+        request: Request,
+        param: TestForm,
+        trace_id: str = Depends(get_operation_id)):
+    """
+    根据MySQL应用
+    """
+    try:
+        server_logger.info(trace_id=trace_id, msg=f"{param}")
+        # 验证参数
+
+        # 从字典中获取input
+        input_data = param.input
+        session_id = param.config.session_id
+        context = param.context
+        header_info = {
+        }
+        from foundation.utils.redis_utils import set_redis_result_cache_data , get_redis_result_cache_data
+        output = "success"
+        data_type = "output"
+
+        await set_redis_result_cache_data(data_type=data_type , trace_id=trace_id , value=input_data)
+        server_logger.info(trace_id=trace_id, msg=f"key:{trace_id}:{data_type},value:{input_data} redis 设置成功")
+        output = await get_redis_result_cache_data(data_type=data_type , trace_id=trace_id)
+        # 直接执行
+        server_logger.info(trace_id=trace_id, msg=f"【result】: {output}", log_type="/redis")
+        # 返回字典格式的响应
+        return JSONResponse(
+            return_json(data={"output": output}, data_type="text", trace_id=trace_id))
+    except ValueError as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="/redis")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+
+    except Exception as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="/redis")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+    
+
+
 
 
 @test_router.post("/mysql/add", response_model=TestForm)
@@ -511,3 +570,258 @@ async def test_mysql_add(
         handler_err(server_logger, trace_id=trace_id, err=err, err_name="/mysql/update")
         return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
     
+
+
+
+
+@test_router.post("/bop/get", response_model=TestForm)
+async def test_bop_get(
+        request: Request,
+        param: TestForm,
+        trace_id: str = Depends(get_operation_id)):
+    """
+    根据MySQL应用
+    """
+    try:
+        server_logger.info(trace_id=trace_id, msg=f"{param}")
+        # 验证参数
+
+        # 从字典中获取input
+        input_data = param.input
+        session_id = param.config.session_id
+        context = param.context
+        header_info = {
+        }
+        # 从app.state中获取数据库连接池
+        async_db_pool = request.app.state.async_db_pool
+        bop_dao = BasisOfPreparationDAO(async_db_pool)
+        test_id = input_data;
+        data = await bop_dao.get_info_by_id(id=test_id)
+        server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/get")
+        json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
+        output = json_str
+        # 直接执行
+        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/get")
+        # 返回字典格式的响应
+        return JSONResponse(
+            return_json(data={"output": output}, data_type="text", trace_id=trace_id))
+    except ValueError as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+
+    except Exception as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/get")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+    
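A hedged example of calling this endpoint from a client. The payload shape is inferred from how the handler reads `param` (`input`, `config.session_id`, `context`); the base URL, router prefix, and the exact TestForm schema are assumptions to verify against the real models:

    import httpx

    payload = {
        "input": "1",                       # record id consumed via param.input
        "config": {"session_id": "demo"},   # present in TestForm but unused by this endpoint
        "context": None,
    }
    # Base URL and router prefix are illustrative assumptions.
    resp = httpx.post("http://localhost:8001/test/bop/get", json=payload, timeout=30)
    print(resp.json())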
+@test_router.post("/bop/list", response_model=TestForm)
+async def test_mysql_add(
+        request: Request,
+        param: TestForm,
+        trace_id: str = Depends(get_operation_id)):
+    """
+    根据MySQL应用
+    """
+    try:
+        server_logger.info(trace_id=trace_id, msg=f"{param}")
+        # 验证参数
+
+        # 从字典中获取input
+        input_data = param.input
+        session_id = param.config.session_id
+        context = param.context
+        header_info = {
+        }
+        # 从app.state中获取数据库连接池
+        async_db_pool = request.app.state.async_db_pool
+        from foundation.base.mysql.async_mysql_base_dao import TestTabDAO
+        bop_dao = BasisOfPreparationDAO(async_db_pool)
+        test_id = input_data;
+        data = await bop_dao.get_list()
+        server_logger.info(trace_id=trace_id, msg=f"【result】: {data}", log_type="/bop/list")
+        json_str = json.dumps(data , cls=DateTimeEncoder, ensure_ascii=False, indent=2)
+        output = json_str
+        # 直接执行
+        server_logger.debug(trace_id=trace_id, msg=f"【result】: {output}", log_type="/bop/list")
+        # 返回字典格式的响应
+        return JSONResponse(
+            return_json(data={"output": output}, data_type="text", trace_id=trace_id))
+    except ValueError as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+
+    except Exception as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="/bop/list")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+
+
+
+
+##################【RAG tests】##############################################
+@test_router.post("/embedding", response_model=TestForm)
+async def embedding_test_endpoint(
+        param: TestForm,
+        trace_id: str = Depends(get_operation_id)):
+    """
+        Embedding model test
+    """
+    try:
+        server_logger.info(trace_id=trace_id, msg=f"{param}")
+        # Pull the input from the request payload
+        input_query = param.input
+        session_id = param.config.session_id
+        context = param.context
+        header_info = {
+        }
+        task_prompt_info = {"task_prompt": ""}
+        text = input_query
+        # Initialize the client (requires the SILICONFLOW_API_KEY environment variable)
+        from foundation.models.silicon_flow import SiliconFlowAPI
+        base_api_platform = SiliconFlowAPI()
+        embedding = base_api_platform.get_embeddings([text])[0]
+        embed_dim = len(embedding)
+        server_logger.info(trace_id=trace_id, msg=f"【result】: {embed_dim}")
+
+        output = f"embed_dim={embed_dim},embedding:{embedding}"
+        # Return the response as a dict
+        return JSONResponse(
+            return_json(data={"output": output}, data_type="text", trace_id=trace_id))
+
+    except ValueError as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+
+    except Exception as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="embedding")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+
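Since `get_embeddings` returns one vector per input text, a quick way to sanity-check the model is to compare two related strings. A minimal sketch with no extra dependencies (the sample strings are arbitrary):

    import math

    from foundation.models.silicon_flow import SiliconFlowAPI

    api = SiliconFlowAPI()  # requires SILICONFLOW_API_KEY in the environment
    vec_a, vec_b = api.get_embeddings(["质量验收规范", "施工质量验收标准"])

    def cosine(a: list[float], b: list[float]) -> float:
        # Dot product divided by the product of vector norms.
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(y * y for y in b))
        return dot / (norm_a * norm_b)

    print(f"similarity={cosine(vec_a, vec_b):.4f}")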
+
+
+
+
+@test_router.post("/bfp/search", response_model=TestForm)
+async def bfp_search_endpoint(
+        param: TestForm,
+        trace_id: str = Depends(get_operation_id)):
+    """
+        编制依据向量检索
+    """
+    try:
+        server_logger.info(trace_id=trace_id, msg=f"{param}")
+        print(trace_id)
+        # 从字典中获取input
+        input_query = param.input
+        session_id = param.config.session_id
+        context = param.context
+        header_info = {
+        }
+        task_prompt_info = {"task_prompt": ""}
+        top_k = int(session_id)
+        
+        output = None
+        # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
+        client = SiliconFlowAPI()
+        # 抽象测试
+        pg_vector_db = PGVectorDB(base_api_platform=client)
+        output = pg_vector_db.retriever(param={"table_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
+
+        # 返回字典格式的响应
+        return JSONResponse(
+            return_json(data={"output": output}, data_type="text", trace_id=trace_id))
+
+    except ValueError as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+
+    except Exception as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+    
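Each row returned by `retriever` is a dict carrying at least a `text_content` field (the rerank endpoint below relies on it); any other columns depend on the pgvector table schema. A minimal sketch of unpacking the hits (the import path for `PGVectorDB` is assumed from this commit's file layout):

    from foundation.models.silicon_flow import SiliconFlowAPI
    from foundation.rag.vector.pg_vector import PGVectorDB  # assumed import path

    db = PGVectorDB(base_api_platform=SiliconFlowAPI())
    hits = db.retriever(
        param={"table_name": "tv_basis_of_preparation"},
        query_text="混凝土强度等级",  # sample query
        top_k=3,
    )
    for rank, hit in enumerate(hits, start=1):
        # Only text_content is guaranteed by the usage in this module.
        print(f"#{rank}: {hit['text_content'][:80]}")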
+
+
+
+@test_router.post("/bfp/search/rerank", response_model=TestForm)
+async def bfp_search_endpoint(
+        param: TestForm,
+        trace_id: str = Depends(get_operation_id)):
+    """
+        编制依据文档检索和重排序
+    """
+    try:
+        server_logger.info(trace_id=trace_id, msg=f"{param}")
+        print(trace_id)
+        # 从字典中获取input
+        input_query = param.input
+        session_id = param.config.session_id
+        context = param.context
+        header_info = {
+        }
+        task_prompt_info = {"task_prompt": ""}
+        top_k = int(session_id)
+        
+        output = None
+        # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
+        client = SiliconFlowAPI()
+        # 抽象测试
+        pg_vector_db = PGVectorDB(base_api_platform=client)
+        output = pg_vector_db.retriever(param={"table_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
+        # 重排序处理
+        content_list = [doc["text_content"] for doc in output]
+        output = client.rerank(input_query=input_query, documents=content_list , top_n=top_k)
+
+        # 返回字典格式的响应
+        return JSONResponse(
+            return_json(data={"output": output}, data_type="text", trace_id=trace_id))
+
+    except ValueError as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+
+    except Exception as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/search")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+    
+
+
+
+
+@test_router.post("/data/bfp/milvus/search", response_model=TestForm)
+async def bfp_search_endpoint(
+        param: TestForm,
+        trace_id: str = Depends(get_operation_id)):
+    """
+        编制依据文档切分处理 和 入库处理
+    """
+    try:
+        server_logger.info(trace_id=trace_id, msg=f"{param}")
+        print(trace_id)
+        # 从字典中获取input
+        input_query = param.input
+        session_id = param.config.session_id
+        context = param.context
+        header_info = {
+        }
+        task_prompt_info = {"task_prompt": ""}
+        top_k = int(session_id)
+        
+        output = None
+        # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
+        client = SiliconFlowAPI()
+        # 抽象测试
+        vector_db = MilvusVectorManager(base_api_platform=client)
+        output = vector_db.retriever(param={"collection_name": "tv_basis_of_preparation"}, query_text=input_query , top_k=top_k)
+
+        # 返回字典格式的响应
+        return JSONResponse(
+            return_json(data={"output": output}, data_type="text", trace_id=trace_id))
+
+    except ValueError as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/milvus/search")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
+
+    except Exception as err:
+        handler_err(server_logger, trace_id=trace_id, err=err, err_name="bfp/milvus/search")
+        return JSONResponse(return_json(code=100500, msg=f"{err}", trace_id=trace_id))
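Both vector backends expose the same `retriever(param=..., query_text=..., top_k=...)` call and differ only in the key naming the target store (`table_name` for pgvector, `collection_name` for Milvus). A sketch of hiding that difference behind one helper (the helper itself is hypothetical; import paths are assumed from this commit's file layout):

    from foundation.models.silicon_flow import SiliconFlowAPI
    from foundation.rag.vector.milvus_vector import MilvusVectorManager  # assumed import path
    from foundation.rag.vector.pg_vector import PGVectorDB               # assumed import path

    def bfp_search(query: str, top_k: int, backend: str = "pg") -> list:
        """Run the same basis-of-preparation search against either backend."""
        client = SiliconFlowAPI()
        if backend == "pg":
            db = PGVectorDB(base_api_platform=client)
            return db.retriever(param={"table_name": "tv_basis_of_preparation"},
                                query_text=query, top_k=top_k)
        db = MilvusVectorManager(base_api_platform=client)
        return db.retriever(param={"collection_name": "tv_basis_of_preparation"},
                            query_text=query, top_k=top_k)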

Some files were not shown because too many files changed in this diff