| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472 |
- """
- 简化版分类器(一级/二级/三级)
- 直接调用 OpenAI 兼容 API,不依赖 core/foundation 代码。
- """
- import asyncio
- import csv
- import json
- import re
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple
- from openai import AsyncOpenAI
# ==================== Configuration defaults ====================
DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
DEFAULT_MODEL = "qwen3.5-122b-a10b"
DEFAULT_CONCURRENCY = 10

# Primary (level-1) categories: Chinese display name -> category code.
PRIMARY_CATEGORIES = {
    "编制依据": "basis",
    "工程概况": "overview",
    "施工计划": "plan",
    "施工工艺技术": "technology",
    "安全保证措施": "safety",
    "质量保证措施": "quality",
    "环境保证措施": "environment",
    "施工管理及作业人员配备与分工": "management",
    "验收要求": "acceptance",
    "其他资料": "other",
}

# Whitelist of standard secondary (level-2) titles, keyed by primary
# category code.
STANDARD_SECONDARY_TITLES: Dict[str, List[str]] = {
    "basis": [
        "法律法规",
        "标准规范",
        "文件制度",
        "编制原则",
        "编制范围",
    ],
    "overview": [
        "设计概况",
        "工程地质与水文气象",
        "周边环境",
        "施工平面及立面布置",
        "施工要求和技术保证条件",
        "风险辨识与分级",
        "参建各方责任主体单位",
    ],
    "plan": [
        "施工进度计划",
        "施工材料计划",
        "施工设备计划",
        "劳动力计划",
        "安全生产费用使用计划",
    ],
    "technology": [
        "主要施工方法概述",
        "技术参数",
        "工艺流程",
        "施工准备",
        "施工方法及操作要求",
        "检查要求",
    ],
    "safety": [
        "安全保证体系",
        "组织保证措施",
        "技术保证措施",
        "监测监控措施",
        "应急处置措施",
    ],
    "quality": [
        "质量保证体系",
        "质量目标",
        "工程创优规划",
        "质量控制程序与具体措施",
    ],
    "environment": [
        "环境保证体系",
        "环境保护组织机构",
        "环境保护及文明施工措施",
    ],
    "management": [
        "施工管理人员",
        "专职安全生产管理人员",
        "其他作业人员",
    ],
    "acceptance": [
        "验收标准",
        "验收程序",
        "验收内容",
        "验收时间",
        "验收人员",
    ],
    "other": [
        "计算书",
        "相关施工图纸",
        "附图附表",
        "编制及审核人员情况",
    ],
}
class SimpleClassifier:
    """Simplified three-level (primary / secondary / tertiary) document classifier.

    Each classification step sends a prompt to an OpenAI-compatible chat
    completion endpoint and validates the JSON answer against the
    classification tree loaded at construction time.  Any answer that cannot
    be validated is mapped to the "non-standard" category instead of raising,
    so one malformed LLM reply never aborts a whole batch.
    """

    # Code / display name used whenever an item matches no standard category.
    _NON_STANDARD_CODE = "non_standard"
    _NON_STANDARD_NAME = "非标准项"

    def __init__(
        self,
        api_key: str,
        base_url: str = DEFAULT_BASE_URL,
        model: str = DEFAULT_MODEL,
        concurrency: int = DEFAULT_CONCURRENCY,
        csv_path: Optional[str] = None,
    ):
        """Create the API client and load the classification tree.

        Args:
            api_key: API key passed to ``AsyncOpenAI``.
            base_url: Base URL of the OpenAI-compatible endpoint.
            model: Model name sent with every chat completion request.
            concurrency: Maximum number of LLM requests in flight at once;
                values below 1 are raised to 1.
            csv_path: Optional path to the category CSV; ``None`` selects the
                default location relative to this file.
        """
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        self.model = model
        # A semaphore of size 0 would block every request forever.
        self.concurrency = max(1, concurrency)
        self.classification_tree = self._load_classification_tree(csv_path)

    def _load_classification_tree(self, csv_path: Optional[str]) -> Dict[str, Dict[str, Any]]:
        """Load the classification tree from a CSV file.

        The result maps ``first_code -> second_code -> {"second_name",
        "second_focus", "third_items"}`` where ``third_items`` is a list of
        ``{"third_code", "third_name", "third_focus"}`` dicts.  Rows without
        a first or second code are skipped.  When the file is missing, or
        contains no usable row, the hard-coded minimal tree is returned.
        """
        if csv_path is None:
            # Default location: resolved three directories above this file.
            path = (
                Path(__file__).parent.parent.parent
                / "core"
                / "construction_review"
                / "component"
                / "doc_worker"
                / "config"
                / "StandardCategoryTable.csv"
            )
        else:
            path = Path(csv_path)

        if not path.exists():
            return self._build_minimal_tree()

        def cell(row: Dict[str, Any], key: str) -> str:
            # Missing columns and short rows both come back as None.
            return (row.get(key) or "").strip()

        tree: Dict[str, Dict[str, Any]] = {}
        # newline="" lets the csv module handle embedded line breaks itself.
        with path.open("r", encoding="utf-8-sig", newline="") as f:
            for row in csv.DictReader(f):
                first_code = cell(row, "first_code")
                second_code = cell(row, "second_code")
                if not first_code or not second_code:
                    continue

                # The first row seen for a (first, second) pair defines its
                # name and focus; later rows only contribute third-level items.
                second = tree.setdefault(first_code, {}).setdefault(
                    second_code,
                    {
                        "second_name": cell(row, "second_name"),
                        "second_focus": cell(row, "second_focus"),
                        "third_items": [],
                    },
                )

                third_code = cell(row, "third_code")
                third_name = cell(row, "third_name")
                if third_code and third_name:
                    second["third_items"].append({
                        "third_code": third_code,
                        "third_name": third_name,
                        "third_focus": cell(row, "third_focus"),
                    })

        # A CSV that yields no rows (e.g. wrong headers) is as unusable as a
        # missing one, so it gets the same fallback.
        return tree or self._build_minimal_tree()

    def _build_minimal_tree(self) -> Dict[str, Dict[str, Any]]:
        """Build the fallback tree from the hard-coded standard titles.

        Secondary codes are generated as ``sec_1``, ``sec_2``, ... in list
        order; no third-level items are available in this tree.
        """
        return {
            first_code: {
                f"sec_{idx}": {
                    "second_name": title,
                    "second_focus": "",
                    "third_items": [],
                }
                for idx, title in enumerate(STANDARD_SECONDARY_TITLES.get(first_code, []), 1)
            }
            for first_code in PRIMARY_CATEGORIES.values()
        }

    # ==================== Public interface ====================

    async def classify_primary(self, toc_items: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Classify every level-1 table-of-contents item.

        Items whose ``level`` is not 1 (or is missing) are ignored.

        Returns:
            A dict with ``items`` (one classified dict per level-1 item, in
            input order), ``total_count``, ``target_level`` (always 1) and
            ``category_stats`` (category name -> count).
        """
        level1_items = [item for item in toc_items if item.get("level") == 1]
        if not level1_items:
            return {"items": [], "total_count": 0, "target_level": 1, "category_stats": {}}

        semaphore = asyncio.Semaphore(self.concurrency)

        async def _classify_one(item: Dict[str, Any]) -> Dict[str, Any]:
            async with semaphore:
                return await self._call_llm_primary(item)

        classified_items = await asyncio.gather(*(_classify_one(item) for item in level1_items))

        category_stats: Dict[str, int] = {}
        for item in classified_items:
            cat = item.get("category", self._NON_STANDARD_NAME)
            category_stats[cat] = category_stats.get(cat, 0) + 1

        return {
            "items": classified_items,
            "total_count": len(classified_items),
            "target_level": 1,
            "category_stats": category_stats,
        }

    async def classify_secondary(self, primary_result: Dict[str, Any]) -> Dict[str, Any]:
        """Classify the level-2 titles attached to each primary item.

        Args:
            primary_result: The dict returned by ``classify_primary``.

        Returns:
            A dict with ``items`` (one result per primary item that has
            level-2 titles), ``total_count`` (number of level-2 titles) and
            ``category_stats`` (secondary category code -> count).
        """
        primary_items = primary_result.get("items", [])
        if not primary_items:
            return {"items": [], "total_count": 0, "category_stats": {}}

        semaphore = asyncio.Semaphore(self.concurrency)

        async def _classify_one(item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
            level2_titles = item.get("level2_titles", [])
            if not level2_titles:
                # Nothing to classify; do not occupy a concurrency slot.
                return None
            async with semaphore:
                return await self._call_llm_secondary(
                    item.get("category", ""),
                    item.get("category_code", ""),
                    level2_titles,
                    item.get("title", ""),
                )

        gathered = await asyncio.gather(*(_classify_one(item) for item in primary_items))
        results = [r for r in gathered if r is not None]

        category_stats: Dict[str, int] = {}
        for result in results:
            for cls in result.get("classifications", []):
                code = cls.get("category_code", self._NON_STANDARD_CODE)
                category_stats[code] = category_stats.get(code, 0) + 1

        return {
            "items": results,
            "total_count": sum(r.get("level2_count", 0) for r in results),
            "category_stats": category_stats,
        }

    async def classify_tertiary(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Assign a tertiary category to every chunk (one LLM call per chunk).

        Each chunk dict is updated in place with ``tertiary_category_code``
        and ``tertiary_category_cn``.  Chunks without a usable primary /
        secondary code, or whose secondary category has no third-level
        standards, are marked ``"none"`` / ``"无"`` without calling the LLM.

        Returns:
            The same chunk dicts, in input order.
        """
        if not chunks:
            return chunks

        semaphore = asyncio.Semaphore(self.concurrency)

        def _mark_none(chunk: Dict[str, Any]) -> Dict[str, Any]:
            chunk["tertiary_category_code"] = "none"
            chunk["tertiary_category_cn"] = "无"
            return chunk

        async def _classify_chunk(chunk: Dict[str, Any]) -> Dict[str, Any]:
            first_code = chunk.get("chapter_classification", "")
            second_code = chunk.get("secondary_category_code", "")
            if not first_code or not second_code or second_code == self._NON_STANDARD_CODE:
                return _mark_none(chunk)

            standards = self._build_tertiary_standards(first_code, second_code)
            if not standards:
                return _mark_none(chunk)

            async with semaphore:
                return await self._call_llm_tertiary(chunk, standards)

        return list(await asyncio.gather(*(_classify_chunk(c) for c in chunks)))

    # ==================== LLM call implementation ====================

    @staticmethod
    def _coerce_index(value: Any) -> Optional[int]:
        """Convert an LLM-provided index to ``int``.

        Accepts ints, integral floats and numeric strings (LLMs frequently
        quote numbers); returns ``None`` for anything else, including bools.
        """
        if isinstance(value, bool):
            return None
        if isinstance(value, int):
            return value
        if isinstance(value, float):
            return int(value) if value.is_integer() else None
        if isinstance(value, str):
            try:
                return int(value.strip())
            except ValueError:
                return None
        return None

    @classmethod
    def _resolve_choice(
        cls,
        raw: Any,
        options: List[Tuple[str, str]],
        *,
        prefer_index: bool,
    ) -> Tuple[int, str, str]:
        """Map one LLM answer onto a numbered option list.

        Args:
            raw: The LLM answer; expected to be a dict with
                ``category_index`` and/or ``category_name``.
            options: ``(code, name)`` pairs in the order they were numbered
                (1-based) in the prompt.
            prefer_index: When the index and the name point at different
                options, use the index if ``True`` and the name otherwise.

        Returns:
            ``(index, code, name)`` of the chosen option, or
            ``(0, "non_standard", "非标准项")`` when the answer is not a dict,
            explicitly selects index 0, or matches no option.
        """
        non_standard = (0, cls._NON_STANDARD_CODE, cls._NON_STANDARD_NAME)
        if not isinstance(raw, dict):
            return non_standard

        idx = cls._coerce_index(raw.get("category_index"))
        if idx == 0:
            # Index 0 is the documented "non-standard" marker in the prompts.
            return non_standard

        name = raw.get("category_name")
        name = name.strip() if isinstance(name, str) else ""

        # The explicit range check rejects negative values, which would
        # otherwise select an option through Python's negative indexing.
        by_index = idx if idx is not None and 1 <= idx <= len(options) else None
        by_name = None
        if name:
            by_name = next(
                (pos for pos, (code, label) in enumerate(options, 1) if name in (label, code)),
                None,
            )

        ordered = (by_index, by_name) if prefer_index else (by_name, by_index)
        chosen = next((pos for pos in ordered if pos is not None), None)
        if chosen is None:
            return non_standard

        code, label = options[chosen - 1]
        return chosen, code, label

    @classmethod
    def _non_standard_entry(cls, title: str) -> Dict[str, Any]:
        """Return the secondary classification record for an unmatched title."""
        return {
            "title": title,
            "category_index": 0,
            "category_code": cls._NON_STANDARD_CODE,
            "category_name": cls._NON_STANDARD_NAME,
        }

    async def _call_llm(self, system_prompt: str, user_prompt: str) -> Optional[Dict[str, Any]]:
        """Send one chat completion request and parse the reply as JSON.

        Best effort by design: any failure (network, API, empty choices) is
        printed and reported as ``None`` so callers can fall back to the
        non-standard category.
        """
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.3,
            )
            content = response.choices[0].message.content or ""
            return _extract_json(content)
        except Exception as e:
            print(f"[LLM 调用失败] {e}")
            return None

    async def _call_llm_primary(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Classify one level-1 title into a primary category.

        The returned ``category_code`` is always derived from
        ``PRIMARY_CATEGORIES`` (or is ``"non_standard"``); a code supplied by
        the LLM is only used to recover the category when the returned name
        is not a known one.

        Returns:
            A dict with ``title``, ``page``, ``level``, ``category``,
            ``category_code``, ``original``, ``level2_titles`` and
            ``confidence``.
        """
        title = item.get("title", "")
        system_prompt = """你是一个施工方案文档目录分类专家。
请将给定的一级章节标题分类到以下类别之一,返回 JSON 格式:
{"category_cn": "类别中文名", "category_code": "类别代码", "confidence": 0.95}
可选类别:
- 编制依据 (basis)
- 工程概况 (overview)
- 施工计划 (plan)
- 施工工艺技术 (technology)
- 安全保证措施 (safety)
- 质量保证措施 (quality)
- 环境保证措施 (environment)
- 施工管理及作业人员配备与分工 (management)
- 验收要求 (acceptance)
- 其他资料 (other)
- 非标准项 (non_standard)
如果标题明显不属于以上任何类别,归为"非标准项"。"""
        user_prompt = f"一级章节标题:{title}"

        result = await self._call_llm(system_prompt, user_prompt)

        category_cn = self._NON_STANDARD_NAME
        category_code = self._NON_STANDARD_CODE
        confidence = 0.0

        if isinstance(result, dict):
            raw_cn = result.get("category_cn")
            raw_cn = raw_cn.strip() if isinstance(raw_cn, str) else ""
            raw_code = result.get("category_code")
            raw_code = raw_code.strip() if isinstance(raw_code, str) else ""

            if raw_cn not in PRIMARY_CATEGORIES:
                # Unknown name: fall back to the code if it identifies a
                # standard category.
                raw_cn = next(
                    (cn for cn, code in PRIMARY_CATEGORIES.items() if code == raw_code),
                    raw_cn,
                )

            if raw_cn in PRIMARY_CATEGORIES:
                category_cn = raw_cn
                # Canonical code, never the (possibly mismatched) LLM one.
                category_code = PRIMARY_CATEGORIES[raw_cn]
                keep_confidence = True
            else:
                # Keep the confidence only for an explicit "non-standard"
                # verdict; an unrecognisable answer gets 0.0.
                keep_confidence = (
                    raw_cn == self._NON_STANDARD_NAME or raw_code == self._NON_STANDARD_CODE
                )

            if keep_confidence:
                try:
                    confidence = float(result.get("confidence", 0.0))
                except (TypeError, ValueError):
                    confidence = 0.0

        return {
            "title": title,
            "page": item.get("page", 0),
            "level": item.get("level", 1),
            "category": category_cn,
            "category_code": category_code,
            "original": item.get("original", ""),
            "level2_titles": item.get("level2_titles", []),
            "confidence": confidence,
        }

    async def _call_llm_secondary(
        self,
        first_category: str,
        first_category_code: str,
        level2_titles: List[str],
        original_title: str,
    ) -> Dict[str, Any]:
        """Classify all level-2 titles of one chapter in a single LLM call.

        The secondary categories of ``first_category_code`` are presented as
        a numbered list so that the ``category_index`` requested from the
        model refers to something concrete.  Each answer is resolved by
        category name (or code) first and by index second; unresolved
        answers become non-standard.  If the call fails or the number of
        answers differs from the number of titles, every title is marked
        non-standard.

        Returns:
            A dict with ``first_category``, ``first_category_code``,
            ``original_title``, ``level2_count`` and ``classifications``
            (one ``{"title", "category_index", "category_code",
            "category_name"}`` per input title, in input order, where
            ``category_index`` is the 1-based position of the resolved
            category or 0 for non-standard).
        """
        # (code, name) of every secondary category under this primary one.
        options: List[Tuple[str, str]] = [
            (sec_code, sec_data["second_name"])
            for sec_code, sec_data in self.classification_tree.get(first_category_code, {}).items()
        ]

        if options:
            standards_text = "\n".join(
                f"{pos}. {label} ({code})" for pos, (code, label) in enumerate(options, 1)
            )
        else:
            standards_text = "(无预定义标准)"
        titles_list = "\n".join(f"{i+1}. {title}" for i, title in enumerate(level2_titles))

        system_prompt = f"""你是一个施工方案文档目录分类专家。
请将以下二级小节标题分类到对应类别,返回 JSON 格式:
{{"classifications": [{{"title": "原标题", "category_index": 索引, "category_name": "分类名"}}]}}
一级分类:{first_category}
可选二级分类:
{standards_text}
特殊索引:0 = 非标准项
要求:
1. 返回的 classifications 数组长度必须与输入标题数量完全一致
2. category_index 必须是数字索引
3. 只返回 JSON,不要其他解释"""
        user_prompt = f"待分类的二级标题:\n{titles_list}"

        result = await self._call_llm(system_prompt, user_prompt)

        raw_list = result.get("classifications") if isinstance(result, dict) else None
        if isinstance(raw_list, list) and len(raw_list) == len(level2_titles):
            classifications = []
            for title, raw in zip(level2_titles, raw_list):
                idx, code, name = self._resolve_choice(raw, options, prefer_index=False)
                classifications.append({
                    "title": title,
                    "category_index": idx,
                    "category_code": code,
                    "category_name": name,
                })
        else:
            # Call failed, reply malformed, or answers cannot be aligned with
            # the titles: mark everything non-standard.
            classifications = [self._non_standard_entry(title) for title in level2_titles]

        return {
            "first_category": first_category,
            "first_category_code": first_category_code,
            "original_title": original_title,
            "level2_count": len(level2_titles),
            "classifications": classifications,
        }

    async def _call_llm_tertiary(
        self,
        chunk: Dict[str, Any],
        standards: List[Dict[str, str]],
    ) -> Dict[str, Any]:
        """Classify one chunk against the given third-level standards.

        Only the first 500 characters of ``review_chunk_content`` are sent.
        The answer is resolved by index first and by name second; an answer
        that is missing, out of range or unparseable marks the chunk as
        non-standard.

        Returns:
            ``chunk`` itself, updated in place with
            ``tertiary_category_code`` and ``tertiary_category_cn``.
        """
        content = (chunk.get("review_chunk_content") or "")[:500]  # limit prompt size
        section_label = chunk.get("section_label", "")

        standards_text = "\n".join(
            f"{i+1}. {s['name']} ({s['code']}) - {s.get('focus', '')}"
            for i, s in enumerate(standards)
        )

        system_prompt = """你是一个施工方案文档内容分类专家。
请判断给定的文档内容属于哪个三级分类,返回 JSON 格式:
{"category_index": 索引, "category_name": "分类名"}
如果内容不属于任何类别,返回 {"category_index": 0, "category_name": "非标准项"}。
只返回 JSON,不要其他解释。"""
        user_prompt = f"""文档章节:{section_label}
内容预览:
{content}
可选分类:
{standards_text}
"""

        result = await self._call_llm(system_prompt, user_prompt)

        options = [(s["code"], s["name"]) for s in standards]
        _, code, name = self._resolve_choice(result, options, prefer_index=True)
        chunk["tertiary_category_code"] = code
        chunk["tertiary_category_cn"] = name
        return chunk

    def _build_tertiary_standards(self, first_code: str, second_code: str) -> List[Dict[str, str]]:
        """Return the third-level standards under ``first_code`` / ``second_code``.

        Each entry has ``code``, ``name`` and ``focus``.  An empty list is
        returned when either code is unknown or the secondary category has
        no third-level items.
        """
        second = self.classification_tree.get(first_code, {}).get(second_code)
        if not second:
            return []
        return [
            {
                "code": item["third_code"],
                "name": item["third_name"],
                "focus": item.get("third_focus", ""),
            }
            for item in second.get("third_items", [])
        ]
- # ==================== 工具函数 ====================
- def _extract_json(text: str) -> Optional[Dict[str, Any]]:
- """从字符串中提取第一个有效 JSON 对象"""
- if not text or not text.strip():
- return None
- text = text.strip()
- try:
- return json.loads(text)
- except json.JSONDecodeError:
- pass
- for pattern in [r"```json\s*(\{.*?})\s*```", r"```\s*(\{.*?})\s*```"]:
- m = re.search(pattern, text, re.DOTALL)
- if m:
- try:
- return json.loads(m.group(1))
- except json.JSONDecodeError:
- pass
- try:
- for candidate in re.findall(r"(\{[\s\S]*?})", text):
- try:
- result = json.loads(candidate)
- if isinstance(result, dict):
- return result
- except json.JSONDecodeError:
- continue
- except Exception:
- pass
- return None
|