classifier.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. """
  2. 简化版分类器(一级/二级/三级)
  3. 直接调用 OpenAI 兼容 API,不依赖 core/foundation 代码。
  4. """
  5. import asyncio
  6. import csv
  7. import json
  8. import re
  9. from pathlib import Path
  10. from typing import Any, Dict, List, Optional, Tuple
  11. from openai import AsyncOpenAI
  12. # ==================== 配置默认值 ====================
  13. DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
  14. DEFAULT_MODEL = "qwen3.5-122b-a10b"
  15. DEFAULT_CONCURRENCY = 10
  16. # 一级分类标准
  17. PRIMARY_CATEGORIES = {
  18. "编制依据": "basis",
  19. "工程概况": "overview",
  20. "施工计划": "plan",
  21. "施工工艺技术": "technology",
  22. "安全保证措施": "safety",
  23. "质量保证措施": "quality",
  24. "环境保证措施": "environment",
  25. "施工管理及作业人员配备与分工": "management",
  26. "验收要求": "acceptance",
  27. "其他资料": "other",
  28. }
  29. # 标准二级标题白名单
  30. STANDARD_SECONDARY_TITLES: Dict[str, List[str]] = {
  31. "basis": ["法律法规", "标准规范", "文件制度", "编制原则", "编制范围"],
  32. "overview": ["设计概况", "工程地质与水文气象", "周边环境", "施工平面及立面布置", "施工要求和技术保证条件", "风险辨识与分级", "参建各方责任主体单位"],
  33. "plan": ["施工进度计划", "施工材料计划", "施工设备计划", "劳动力计划", "安全生产费用使用计划"],
  34. "technology": ["主要施工方法概述", "技术参数", "工艺流程", "施工准备", "施工方法及操作要求", "检查要求"],
  35. "safety": ["安全保证体系", "组织保证措施", "技术保证措施", "监测监控措施", "应急处置措施"],
  36. "quality": ["质量保证体系", "质量目标", "工程创优规划", "质量控制程序与具体措施"],
  37. "environment": ["环境保证体系", "环境保护组织机构", "环境保护及文明施工措施"],
  38. "management": ["施工管理人员", "专职安全生产管理人员", "其他作业人员"],
  39. "acceptance": ["验收标准", "验收程序", "验收内容", "验收时间", "验收人员"],
  40. "other": ["计算书", "相关施工图纸", "附图附表", "编制及审核人员情况"],
  41. }
  42. class SimpleClassifier:
  43. """简化版文档分类器"""
  44. def __init__(
  45. self,
  46. api_key: str,
  47. base_url: str = DEFAULT_BASE_URL,
  48. model: str = DEFAULT_MODEL,
  49. concurrency: int = DEFAULT_CONCURRENCY,
  50. csv_path: Optional[str] = None,
  51. ):
  52. self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
  53. self.model = model
  54. self.concurrency = concurrency
  55. self.classification_tree = self._load_classification_tree(csv_path)
  56. def _load_classification_tree(self, csv_path: Optional[str]) -> Dict[str, Dict[str, Any]]:
  57. """从 CSV 加载分类标准树"""
  58. tree: Dict[str, Dict[str, Any]] = {}
  59. if csv_path is None:
  60. # 默认路径:相对于项目根目录
  61. csv_path = Path(__file__).parent.parent.parent / "core" / "construction_review" / "component" / "doc_worker" / "config" / "StandardCategoryTable.csv"
  62. else:
  63. csv_path = Path(csv_path)
  64. if not csv_path.exists():
  65. # 如果找不到 CSV,使用硬编码的最小标准
  66. return self._build_minimal_tree()
  67. with csv_path.open("r", encoding="utf-8-sig") as f:
  68. reader = csv.DictReader(f)
  69. for row in reader:
  70. first_code = (row.get("first_code") or "").strip()
  71. first_name = (row.get("first_name") or "").strip()
  72. second_code = (row.get("second_code") or "").strip()
  73. second_name = (row.get("second_name") or "").strip()
  74. second_focus = (row.get("second_focus") or "").strip()
  75. third_code = (row.get("third_code") or "").strip()
  76. third_name = (row.get("third_name") or "").strip()
  77. third_focus = (row.get("third_focus") or "").strip()
  78. if not first_code or not second_code:
  79. continue
  80. if first_code not in tree:
  81. tree[first_code] = {}
  82. if second_code not in tree[first_code]:
  83. tree[first_code][second_code] = {
  84. "second_name": second_name,
  85. "second_focus": second_focus,
  86. "third_items": [],
  87. }
  88. if third_code and third_name:
  89. tree[first_code][second_code]["third_items"].append({
  90. "third_code": third_code,
  91. "third_name": third_name,
  92. "third_focus": third_focus,
  93. })
  94. return tree
  95. def _build_minimal_tree(self) -> Dict[str, Dict[str, Any]]:
  96. """构建最小化的分类标准树(兜底)"""
  97. tree: Dict[str, Dict[str, Any]] = {}
  98. for first_name, first_code in PRIMARY_CATEGORIES.items():
  99. tree[first_code] = {}
  100. second_titles = STANDARD_SECONDARY_TITLES.get(first_code, [])
  101. for idx, title in enumerate(second_titles, 1):
  102. tree[first_code][f"sec_{idx}"] = {
  103. "second_name": title,
  104. "second_focus": "",
  105. "third_items": [],
  106. }
  107. return tree
  108. # ==================== 公共接口 ====================
  109. async def classify_primary(self, toc_items: List[Dict[str, Any]]) -> Dict[str, Any]:
  110. """一级目录分类"""
  111. level1_items = [item for item in toc_items if item["level"] == 1]
  112. if not level1_items:
  113. return {"items": [], "total_count": 0, "target_level": 1, "category_stats": {}}
  114. semaphore = asyncio.Semaphore(self.concurrency)
  115. async def _classify_one(item: Dict[str, Any]) -> Dict[str, Any]:
  116. async with semaphore:
  117. return await self._call_llm_primary(item)
  118. tasks = [_classify_one(item) for item in level1_items]
  119. classified_items = await asyncio.gather(*tasks)
  120. category_stats = {}
  121. for item in classified_items:
  122. cat = item.get("category", "非标准项")
  123. category_stats[cat] = category_stats.get(cat, 0) + 1
  124. return {
  125. "items": classified_items,
  126. "total_count": len(classified_items),
  127. "target_level": 1,
  128. "category_stats": category_stats,
  129. }
  130. async def classify_secondary(self, primary_result: Dict[str, Any]) -> Dict[str, Any]:
  131. """二级目录分类"""
  132. primary_items = primary_result.get("items", [])
  133. if not primary_items:
  134. return {"items": [], "total_count": 0, "category_stats": {}}
  135. semaphore = asyncio.Semaphore(self.concurrency)
  136. async def _classify_one(item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
  137. async with semaphore:
  138. first_category = item.get("category", "")
  139. first_code = item.get("category_code", "")
  140. level2_titles = item.get("level2_titles", [])
  141. if not level2_titles:
  142. return None
  143. return await self._call_llm_secondary(
  144. first_category, first_code, level2_titles, item.get("title", "")
  145. )
  146. tasks = [_classify_one(item) for item in primary_items]
  147. results = await asyncio.gather(*tasks)
  148. results = [r for r in results if r is not None]
  149. category_stats = {}
  150. for result in results:
  151. for cls in result.get("classifications", []):
  152. code = cls.get("category_code", "non_standard")
  153. category_stats[code] = category_stats.get(code, 0) + 1
  154. return {
  155. "items": results,
  156. "total_count": sum(r.get("level2_count", 0) for r in results),
  157. "category_stats": category_stats,
  158. }
  159. async def classify_tertiary(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  160. """三级分类(简化版:逐 chunk 分类)"""
  161. if not chunks:
  162. return chunks
  163. semaphore = asyncio.Semaphore(self.concurrency)
  164. async def _classify_chunk(chunk: Dict[str, Any]) -> Dict[str, Any]:
  165. first_code = chunk.get("chapter_classification", "")
  166. second_code = chunk.get("secondary_category_code", "")
  167. if not first_code or not second_code or second_code == "non_standard":
  168. chunk["tertiary_category_code"] = "none"
  169. chunk["tertiary_category_cn"] = "无"
  170. return chunk
  171. standards = self._build_tertiary_standards(first_code, second_code)
  172. if not standards:
  173. chunk["tertiary_category_code"] = "none"
  174. chunk["tertiary_category_cn"] = "无"
  175. return chunk
  176. async with semaphore:
  177. return await self._call_llm_tertiary(chunk, standards)
  178. tasks = [_classify_chunk(c) for c in chunks]
  179. return list(await asyncio.gather(*tasks))
  180. # ==================== LLM 调用实现 ====================
  181. async def _call_llm(self, system_prompt: str, user_prompt: str) -> Optional[Dict[str, Any]]:
  182. """基础 LLM 调用"""
  183. try:
  184. response = await self.client.chat.completions.create(
  185. model=self.model,
  186. messages=[
  187. {"role": "system", "content": system_prompt},
  188. {"role": "user", "content": user_prompt},
  189. ],
  190. temperature=0.3,
  191. )
  192. content = response.choices[0].message.content or ""
  193. return _extract_json(content)
  194. except Exception as e:
  195. print(f"[LLM 调用失败] {e}")
  196. return None
  197. async def _call_llm_primary(self, item: Dict[str, Any]) -> Dict[str, Any]:
  198. """调用 LLM 进行一级分类"""
  199. title = item.get("title", "")
  200. system_prompt = """你是一个施工方案文档目录分类专家。
  201. 请将给定的一级章节标题分类到以下类别之一,返回 JSON 格式:
  202. {"category_cn": "类别中文名", "category_code": "类别代码", "confidence": 0.95}
  203. 可选类别:
  204. - 编制依据 (basis)
  205. - 工程概况 (overview)
  206. - 施工计划 (plan)
  207. - 施工工艺技术 (technology)
  208. - 安全保证措施 (safety)
  209. - 质量保证措施 (quality)
  210. - 环境保证措施 (environment)
  211. - 施工管理及作业人员配备与分工 (management)
  212. - 验收要求 (acceptance)
  213. - 其他资料 (other)
  214. - 非标准项 (non_standard)
  215. 如果标题明显不属于以上任何类别,归为"非标准项"。"""
  216. user_prompt = f"一级章节标题:{title}"
  217. result = await self._call_llm(system_prompt, user_prompt)
  218. if result and isinstance(result, dict):
  219. category_cn = result.get("category_cn", "")
  220. category_code = result.get("category_code", "")
  221. confidence = result.get("confidence", 0.0)
  222. if category_cn not in PRIMARY_CATEGORIES and category_cn != "非标准项":
  223. category_cn = "非标准项"
  224. category_code = "non_standard"
  225. confidence = 0.0
  226. if category_cn in PRIMARY_CATEGORIES and not category_code:
  227. category_code = PRIMARY_CATEGORIES[category_cn]
  228. else:
  229. category_cn = "非标准项"
  230. category_code = "non_standard"
  231. confidence = 0.0
  232. return {
  233. "title": title,
  234. "page": item.get("page", 0),
  235. "level": item.get("level", 1),
  236. "category": category_cn,
  237. "category_code": category_code,
  238. "original": item.get("original", ""),
  239. "level2_titles": item.get("level2_titles", []),
  240. "confidence": confidence,
  241. }
  242. async def _call_llm_secondary(
  243. self,
  244. first_category: str,
  245. first_category_code: str,
  246. level2_titles: List[str],
  247. original_title: str,
  248. ) -> Dict[str, Any]:
  249. """调用 LLM 进行二级分类(批量模式)"""
  250. # 获取该一级分类下的二级标准
  251. secondary_items = []
  252. if first_category_code in self.classification_tree:
  253. for sec_code, sec_data in self.classification_tree[first_category_code].items():
  254. secondary_items.append(f"- {sec_data['second_name']} ({sec_code})")
  255. standards_text = "\n".join(secondary_items) if secondary_items else "(无预定义标准)"
  256. titles_list = "\n".join(f"{i+1}. {title}" for i, title in enumerate(level2_titles))
  257. system_prompt = f"""你是一个施工方案文档目录分类专家。
  258. 请将以下二级小节标题分类到对应类别,返回 JSON 格式:
  259. {{"classifications": [{{"title": "原标题", "category_index": 索引, "category_name": "分类名"}}]}}
  260. 一级分类:{first_category}
  261. 可选二级分类:
  262. {standards_text}
  263. 特殊索引:0 = 非标准项
  264. 要求:
  265. 1. 返回的 classifications 数组长度必须与输入标题数量完全一致
  266. 2. category_index 必须是数字索引
  267. 3. 只返回 JSON,不要其他解释"""
  268. user_prompt = f"待分类的二级标题:\n{titles_list}"
  269. result = await self._call_llm(system_prompt, user_prompt)
  270. classifications = []
  271. if result and isinstance(result, dict) and "classifications" in result:
  272. raw_list = result["classifications"]
  273. if len(raw_list) == len(level2_titles):
  274. for i, raw in enumerate(raw_list):
  275. idx = raw.get("category_index", 0)
  276. name = raw.get("category_name", "")
  277. # 查找代码
  278. code = "non_standard"
  279. if first_category_code in self.classification_tree:
  280. for sec_code, sec_data in self.classification_tree[first_category_code].items():
  281. if sec_data["second_name"] == name or sec_code == name:
  282. code = sec_code
  283. break
  284. if idx == 0 or not name:
  285. name = "非标准项"
  286. code = "non_standard"
  287. classifications.append({
  288. "title": level2_titles[i],
  289. "category_index": idx,
  290. "category_code": code,
  291. "category_name": name,
  292. })
  293. else:
  294. # 数量不匹配,全部设为非标准项
  295. for title in level2_titles:
  296. classifications.append({
  297. "title": title,
  298. "category_index": 0,
  299. "category_code": "non_standard",
  300. "category_name": "非标准项",
  301. })
  302. else:
  303. # LLM 调用失败,全部设为非标准项
  304. for title in level2_titles:
  305. classifications.append({
  306. "title": title,
  307. "category_index": 0,
  308. "category_code": "non_standard",
  309. "category_name": "非标准项",
  310. })
  311. return {
  312. "first_category": first_category,
  313. "first_category_code": first_category_code,
  314. "original_title": original_title,
  315. "level2_count": len(level2_titles),
  316. "classifications": classifications,
  317. }
  318. async def _call_llm_tertiary(
  319. self,
  320. chunk: Dict[str, Any],
  321. standards: List[Dict[str, str]],
  322. ) -> Dict[str, Any]:
  323. """调用 LLM 进行三级分类(简化版)"""
  324. content = chunk.get("review_chunk_content", "")[:500] # 限制长度
  325. section_label = chunk.get("section_label", "")
  326. standards_text = "\n".join(
  327. f"{i+1}. {s['name']} ({s['code']}) - {s.get('focus', '')}"
  328. for i, s in enumerate(standards)
  329. )
  330. system_prompt = """你是一个施工方案文档内容分类专家。
  331. 请判断给定的文档内容属于哪个三级分类,返回 JSON 格式:
  332. {"category_index": 索引, "category_name": "分类名"}
  333. 如果内容不属于任何类别,返回 {"category_index": 0, "category_name": "非标准项"}。
  334. 只返回 JSON,不要其他解释。"""
  335. user_prompt = f"""文档章节:{section_label}
  336. 内容预览:
  337. {content}
  338. 可选分类:
  339. {standards_text}
  340. """
  341. result = await self._call_llm(system_prompt, user_prompt)
  342. if result and isinstance(result, dict):
  343. idx = result.get("category_index", 0)
  344. name = result.get("category_name", "")
  345. if idx == 0 or not name:
  346. chunk["tertiary_category_code"] = "non_standard"
  347. chunk["tertiary_category_cn"] = "非标准项"
  348. else:
  349. # 查找 code
  350. code = "non_standard"
  351. if idx <= len(standards):
  352. code = standards[idx - 1]["code"]
  353. name = standards[idx - 1]["name"]
  354. chunk["tertiary_category_code"] = code
  355. chunk["tertiary_category_cn"] = name
  356. else:
  357. chunk["tertiary_category_code"] = "non_standard"
  358. chunk["tertiary_category_cn"] = "非标准项"
  359. return chunk
  360. def _build_tertiary_standards(self, first_code: str, second_code: str) -> List[Dict[str, str]]:
  361. """构建三级分类标准列表"""
  362. if first_code not in self.classification_tree:
  363. return []
  364. if second_code not in self.classification_tree[first_code]:
  365. return []
  366. third_items = self.classification_tree[first_code][second_code].get("third_items", [])
  367. if not third_items:
  368. return []
  369. return [
  370. {
  371. "code": item["third_code"],
  372. "name": item["third_name"],
  373. "focus": item.get("third_focus", ""),
  374. }
  375. for item in third_items
  376. ]
  377. # ==================== 工具函数 ====================
  378. def _extract_json(text: str) -> Optional[Dict[str, Any]]:
  379. """从字符串中提取第一个有效 JSON 对象"""
  380. if not text or not text.strip():
  381. return None
  382. text = text.strip()
  383. try:
  384. return json.loads(text)
  385. except json.JSONDecodeError:
  386. pass
  387. for pattern in [r"```json\s*(\{.*?})\s*```", r"```\s*(\{.*?})\s*```"]:
  388. m = re.search(pattern, text, re.DOTALL)
  389. if m:
  390. try:
  391. return json.loads(m.group(1))
  392. except json.JSONDecodeError:
  393. pass
  394. try:
  395. for candidate in re.findall(r"(\{[\s\S]*?})", text):
  396. try:
  397. result = json.loads(candidate)
  398. if isinstance(result, dict):
  399. return result
  400. except json.JSONDecodeError:
  401. continue
  402. except Exception:
  403. pass
  404. return None