# -*- coding: utf-8 -*-
"""Build the retrieval query text and keywords for document chat."""

from __future__ import annotations

import re
from typing import Any, Dict, List, Optional, Sequence

from core.document_chat.retrieval.config import DEFAULT_ACTION_TERMS, DEFAULT_DOMAIN_TERMS

# Limits used when assembling queries and keyword lists.
_QUERY_MAX_CHARS = 120
_QUERY_KEYWORD_LIMIT = 8
_MAX_KEYWORDS = 20
_SECTION_CONTENT_CHARS = 500
_HISTORY_TURNS = 6

# Conversation turns carrying one of these roles are not mined for keywords.
_ASSISTANT_ROLES = frozenset({"assistant", "ai", "bot", "model"})

_WHITESPACE_PATTERN = re.compile(r"\s+")
# Letters followed by digits with an optional "-NNNN" suffix, e.g. "GB 50016-2014".
_CODE_PATTERN = re.compile(r"[A-Za-z]{1,8}\s*\d{2,8}(?:[-—]\d{2,4})?")
# Text enclosed in Chinese title marks.
_TITLE_PATTERN = re.compile(r"《([^》]{2,40})》")
# Token separators: whitespace plus ASCII and CJK punctuation. The fullwidth and
# curly variants are spelled as \uXXXX escapes so they cannot be lost or folded
# into their ASCII look-alikes by an editor or an encoding round trip.
_SEPARATOR_PATTERN = re.compile(
    r"[\s,;:?\"'/\\|()\[\]{}<>、。《》"
    r"\uff0c\uff1b\uff1a\uff1f\uff08\uff09\u201c\u201d]+"
)


def _section_label(selected_section: Dict[str, Any]) -> str:
    """Return "<index> <title>" for a section, treating missing/None fields as empty."""
    index = selected_section.get("index")
    title = selected_section.get("title")
    index_text = "" if index is None else str(index)
    title_text = "" if title is None else str(title)
    return f"{index_text} {title_text}".strip()


def build_query(
    state: Dict[str, Any],
    domain_terms: Optional[Sequence[str]] = None,
    action_terms: Optional[Sequence[str]] = None,
) -> str:
    """Build a compact retrieval query from the conversation state.

    The query is assembled, in order, from ``state["user_message"]``, the
    ``normalized_instruction`` of ``state["intent_result"]``, the index/title
    of ``state["selected_section"]`` and the first few extracted keywords.
    Duplicate parts are dropped and the result is cut to 120 characters.

    Args:
        state: Conversation state mapping; absent or falsy entries are
            treated as empty.
        domain_terms: Domain vocabulary forwarded to keyword extraction.
        action_terms: Action vocabulary forwarded to keyword extraction.

    Returns:
        The query string (possibly empty).
    """
    selected_section = state.get("selected_section") or {}
    intent_result = state.get("intent_result") or {}
    keywords = build_query_keywords(
        state,
        domain_terms=domain_terms,
        action_terms=action_terms,
    )
    parts = [
        state.get("user_message") or "",
        intent_result.get("normalized_instruction") or "",
        _section_label(selected_section),
        " ".join(keywords[:_QUERY_KEYWORD_LIMIT]),
    ]
    return dedupe_join(parts, max_chars=_QUERY_MAX_CHARS)


def build_query_keywords(
    state: Dict[str, Any],
    query: Optional[str] = None,
    domain_terms: Optional[Sequence[str]] = None,
    action_terms: Optional[Sequence[str]] = None,
) -> List[str]:
    """Collect retrieval keywords from several parts of the conversation state.

    Sources are scanned in this order: the user message, the normalized
    instruction, the selected section's index/title, the first 500 characters
    of the section content, ``query``, and finally the content of the last six
    conversation turns whose role is not an assistant-like one.

    Args:
        state: Conversation state mapping.
        query: Optional extra text to mine for keywords.
        domain_terms: Domain vocabulary forwarded to
            :func:`extract_retrieval_keywords`.
        action_terms: Action vocabulary forwarded to
            :func:`extract_retrieval_keywords`.

    Returns:
        Up to 20 unique keywords in first-seen order.
    """
    selected_section = state.get("selected_section") or {}
    intent_result = state.get("intent_result") or {}
    history = state.get("conversation_history") or []

    sources = [
        state.get("user_message") or "",
        intent_result.get("normalized_instruction") or "",
        _section_label(selected_section),
        str(selected_section.get("content") or "")[:_SECTION_CONTENT_CHARS],
        query or "",
    ]
    for turn in history[-_HISTORY_TURNS:]:
        if not isinstance(turn, dict):
            continue
        role = str(turn.get("role") or turn.get("sender") or "").lower()
        if role in _ASSISTANT_ROLES:
            continue
        content = str(turn.get("content") or turn.get("message") or "")
        if content:
            sources.append(content)

    keywords: List[str] = []
    seen = set()
    for text in sources:
        extracted = extract_retrieval_keywords(
            str(text or ""),
            domain_terms=domain_terms,
            action_terms=action_terms,
        )
        for keyword in extracted:
            normalized = keyword.strip()
            if not normalized or normalized in seen:
                continue
            seen.add(normalized)
            keywords.append(normalized)
            # Stop as soon as the cap is reached; later sources are lower priority.
            if len(keywords) >= _MAX_KEYWORDS:
                return keywords
    return keywords


def dedupe_join(parts: Sequence[str], max_chars: int) -> str:
    """Join text fragments with spaces, dropping duplicates, and cap the length.

    Each fragment has runs of whitespace collapsed to a single space and is
    stripped; empty fragments and exact repeats of an earlier fragment are
    skipped.

    Args:
        parts: Fragments to join, in priority order.
        max_chars: Maximum length of the returned string; the joined text is
            simply sliced, so the cut may fall inside a fragment.

    Returns:
        The joined, truncated string.
    """
    collapsed = (
        _WHITESPACE_PATTERN.sub(" ", str(part or "")).strip() for part in parts
    )
    # dict.fromkeys keeps first-seen order while discarding repeats.
    unique = dict.fromkeys(value for value in collapsed if value)
    return " ".join(unique)[:max_chars]


def extract_retrieval_keywords(
    text: str,
    domain_terms: Optional[Sequence[str]] = None,
    action_terms: Optional[Sequence[str]] = None,
) -> List[str]:
    """Extract retrieval keywords from a piece of text.

    Keywords are gathered in this order:

    1. Letter+digit codes (whitespace removed, upper-cased).
    2. Titles enclosed in ``《》``.
    3. Domain terms that occur in the text.
    4. Phrases of 2-20 characters that end with an action term.
    5. Punctuation/whitespace-delimited tokens of 2-12 characters that
       contain a domain term.

    Args:
        text: Text to analyse.
        domain_terms: Domain vocabulary; ``DEFAULT_DOMAIN_TERMS`` when omitted
            or empty. Blank entries are ignored.
        action_terms: Action vocabulary; ``DEFAULT_ACTION_TERMS`` when omitted
            or empty. Blank entries are ignored.

    Returns:
        Unique, stripped keywords in first-seen order.
    """
    if not text:
        return []

    # Blank terms must be dropped: "" is a substring of every token, so a
    # single empty domain term would turn every token into a keyword.
    domain = tuple(term for term in (domain_terms or DEFAULT_DOMAIN_TERMS) if term)
    actions = tuple(term for term in (action_terms or DEFAULT_ACTION_TERMS) if term)

    keywords: List[str] = [
        _WHITESPACE_PATTERN.sub("", code).upper()
        for code in _CODE_PATTERN.findall(text)
    ]
    keywords.extend(title.strip() for title in _TITLE_PATTERN.findall(text))
    keywords.extend(term for term in domain if term in text)

    if actions:
        action_pattern = "|".join(re.escape(term) for term in actions)
        # Up to 12 leading CJK/alphanumeric characters followed by an action term.
        phrase_pattern = rf"[一-鿿A-Za-z0-9.-]{{0,12}}(?:{action_pattern})"
        keywords.extend(
            phrase
            for phrase in re.findall(phrase_pattern, text)
            if 2 <= len(phrase) <= 20
        )

    for token in _SEPARATOR_PATTERN.sub(" ", text).split():
        if 2 <= len(token) <= 12 and any(term in token for term in domain):
            keywords.append(token)

    stripped = (keyword.strip() for keyword in keywords)
    return list(dict.fromkeys(keyword for keyword in stripped if keyword))