| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- # -*- coding: utf-8 -*-
- """构建文档对话检索查询文本和关键词。"""
- from __future__ import annotations
- import re
- from typing import Any, Dict, List, Optional, Sequence
- from core.document_chat.retrieval.config import DEFAULT_ACTION_TERMS, DEFAULT_DOMAIN_TERMS
def build_query(
    state: Dict[str, Any],
    domain_terms: Optional[Sequence[str]] = None,
    action_terms: Optional[Sequence[str]] = None,
) -> str:
    """Build a concise retrieval query, avoiding a redundant project summary.

    The query is assembled, in order, from the user message, the normalized
    instruction of the intent result, the "<index> <title>" label of the
    selected section, and up to eight extracted keywords. Duplicate fragments
    are dropped and the result is capped at 120 characters.

    Args:
        state: Conversation state. Reads the keys ``user_message``,
            ``intent_result`` (its ``normalized_instruction``) and
            ``selected_section`` (its ``index`` and ``title``); missing or
            falsy containers are treated as empty.
        domain_terms: Optional domain vocabulary forwarded to keyword
            extraction.
        action_terms: Optional action vocabulary forwarded to keyword
            extraction.

    Returns:
        The de-duplicated, length-limited query string.
    """
    selected_section = state.get("selected_section") or {}
    intent_result = state.get("intent_result") or {}
    keywords = build_query_keywords(state, domain_terms=domain_terms, action_terms=action_terms)

    # dict.get(key, "") only falls back when the key is absent; a key stored
    # with an explicit None would otherwise be rendered as the text "None".
    # Compare against None (not truthiness) so an index of 0 is preserved.
    section_index = selected_section.get("index")
    section_title = selected_section.get("title")
    section_index = "" if section_index is None else section_index
    section_title = "" if section_title is None else section_title
    section_label = f"{section_index} {section_title}".strip()

    fragments = [
        state.get("user_message") or "",
        intent_result.get("normalized_instruction") or "",
        section_label,
        " ".join(keywords[:8]),
    ]
    return dedupe_join(fragments, max_chars=120)
def build_query_keywords(
    state: Dict[str, Any],
    query: Optional[str] = None,
    domain_terms: Optional[Sequence[str]] = None,
    action_terms: Optional[Sequence[str]] = None,
) -> List[str]:
    """Extract retrieval keywords from several sources of the conversation state.

    Sources are scanned in this order: the user message, the normalized
    instruction, the selected section's "<index> <title>" label, the first
    500 characters of the selected section's content, the optional ``query``,
    and finally the non-assistant turns among the last six history entries.

    Args:
        state: Conversation state. Reads ``user_message``, ``intent_result``,
            ``selected_section`` and ``conversation_history``.
        query: Optional extra text to mine for keywords.
        domain_terms: Optional domain vocabulary forwarded to
            ``extract_retrieval_keywords``.
        action_terms: Optional action vocabulary forwarded to
            ``extract_retrieval_keywords``.

    Returns:
        Up to 20 unique, stripped keywords in first-seen order.
    """
    selected_section = state.get("selected_section") or {}
    intent_result = state.get("intent_result") or {}
    history = state.get("conversation_history") or []

    # dict.get(key, "") only falls back when the key is absent; a key stored
    # with an explicit None would otherwise be rendered as the text "None"
    # and pollute the keyword sources. Keep falsy-but-valid values such as 0.
    section_index = selected_section.get("index")
    section_title = selected_section.get("title")
    section_index = "" if section_index is None else section_index
    section_title = "" if section_title is None else section_title

    sources = [
        state.get("user_message") or "",
        intent_result.get("normalized_instruction") or "",
        f"{section_index} {section_title}",
        str(selected_section.get("content") or "")[:500],
        query or "",
    ]

    # Only the six most recent turns are considered; assistant-side turns are
    # skipped so that only non-assistant text contributes keywords.
    for turn in history[-6:]:
        if not isinstance(turn, dict):
            continue
        role = str(turn.get("role") or turn.get("sender") or "").lower()
        if role in ("assistant", "ai", "bot", "model"):
            continue
        content = str(turn.get("content") or turn.get("message") or "")
        if content:
            sources.append(content)

    keywords: List[str] = []
    seen = set()
    for text in sources:
        extracted = extract_retrieval_keywords(
            str(text or ""),
            domain_terms=domain_terms,
            action_terms=action_terms,
        )
        for keyword in extracted:
            normalized = keyword.strip()
            if not normalized or normalized in seen:
                continue
            seen.add(normalized)
            keywords.append(normalized)
            # Stop as soon as the cap is reached; later sources are ignored.
            if len(keywords) >= 20:
                return keywords
    return keywords
def dedupe_join(parts: List[str], max_chars: int) -> str:
    """Join text fragments after de-duplication, limiting the total length.

    Each fragment has its whitespace runs collapsed to a single space and is
    stripped; empty fragments and repeats of an earlier fragment are dropped.
    The survivors are joined with single spaces and cut to ``max_chars``.

    Args:
        parts: Fragments to join; falsy entries (``None``, ``""``) are ignored.
        max_chars: Maximum length of the result. Negative values are treated
            as 0 rather than slicing characters off the end.

    Returns:
        The joined string, at most ``max_chars`` long and without trailing
        whitespace (a cut that lands right after a separator would otherwise
        leave a dangling space).
    """
    normalized = (re.sub(r"\s+", " ", str(part or "")).strip() for part in parts)
    # dict.fromkeys keeps first-seen order while discarding duplicates.
    unique = dict.fromkeys(value for value in normalized if value)
    limit = max(max_chars, 0)
    return " ".join(unique)[:limit].rstrip()
def extract_retrieval_keywords(
    text: str,
    domain_terms: Optional[Sequence[str]] = None,
    action_terms: Optional[Sequence[str]] = None,
) -> List[str]:
    """Extract retrieval keywords from a piece of text.

    Keywords are collected in this order and then de-duplicated, keeping the
    first occurrence:

    1. Letter+digit codes (1-8 letters, optional whitespace, 2-8 digits and
       an optional dash-separated 2-4 digit suffix), with whitespace removed
       and upper-cased.
    2. Titles enclosed in 《》 (2-40 characters).
    3. Domain terms that occur verbatim in the text.
    4. Phrases of 2-20 characters that end with an action term, optionally
       preceded by up to 12 CJK/alphanumeric/"."/"-" characters.
    5. Punctuation/whitespace-delimited tokens of 2-12 characters that
       contain a domain term.

    Args:
        text: Text to scan; an empty value yields an empty list.
        domain_terms: Domain vocabulary; ``DEFAULT_DOMAIN_TERMS`` is used
            when this is ``None`` or empty.
        action_terms: Action vocabulary; ``DEFAULT_ACTION_TERMS`` is used
            when this is ``None`` or empty.

    Returns:
        Unique, stripped keywords in first-seen order.
    """
    if not text:
        return []

    # Drop empty terms up front. "" is a substring of every string, so a
    # single empty domain term would turn every token into a keyword.
    domain_terms = tuple(term for term in (domain_terms or DEFAULT_DOMAIN_TERMS) if term)
    action_terms = tuple(term for term in (action_terms or DEFAULT_ACTION_TERMS) if term)

    keywords: List[str] = []

    # 1. Letter+digit codes, normalized to an upper-case form without spaces.
    keywords.extend(
        re.sub(r"\s+", "", code).upper()
        for code in re.findall(r"[A-Za-z]{1,8}\s*\d{2,8}(?:[-—]\d{2,4})?", text)
    )

    # 2. Titles quoted with 《》.
    keywords.extend(title.strip() for title in re.findall(r"《([^》]{2,40})》", text))

    # 3. Domain terms present verbatim.
    keywords.extend(term for term in domain_terms if term in text)

    # 4. Action phrases: an action term plus up to 12 leading characters.
    action_pattern = "|".join(re.escape(term) for term in action_terms)
    if action_pattern:
        keywords.extend(
            phrase
            for phrase in re.findall(rf"[一-鿿A-Za-z0-9.-]{{0,12}}(?:{action_pattern})", text)
            if 2 <= len(phrase) <= 20
        )

    # 5. Tokens (split on whitespace and punctuation) containing a domain term.
    normalized = re.sub(r"[\s,,。;;::、/\\|()\[\]{}<>《》\"'""??]+", " ", text)
    keywords.extend(
        token
        for token in normalized.split()
        if 2 <= len(token) <= 12 and any(term in token for term in domain_terms)
    )

    # dict.fromkeys removes duplicates while preserving first-seen order.
    stripped = (keyword.strip() for keyword in keywords)
    return list(dict.fromkeys(keyword for keyword in stripped if keyword))
|