query_builder.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # -*- coding: utf-8 -*-
  2. """构建文档对话检索查询文本和关键词。"""
  3. from __future__ import annotations
  4. import re
  5. from typing import Any, Dict, List, Optional, Sequence
  6. from core.document_chat.retrieval.config import DEFAULT_ACTION_TERMS, DEFAULT_DOMAIN_TERMS
  7. def build_query(
  8. state: Dict[str, Any],
  9. domain_terms: Optional[Sequence[str]] = None,
  10. action_terms: Optional[Sequence[str]] = None,
  11. ) -> str:
  12. """构建精炼检索 query,避免冗余的项目摘要。"""
  13. selected_section = state.get("selected_section") or {}
  14. intent_result = state.get("intent_result") or {}
  15. keywords = build_query_keywords(state, domain_terms=domain_terms, action_terms=action_terms)
  16. parts = [
  17. state.get("user_message") or "",
  18. intent_result.get("normalized_instruction") or "",
  19. f"{selected_section.get('index', '')} {selected_section.get('title', '')}".strip(),
  20. " ".join(keywords[:8]),
  21. ]
  22. return dedupe_join(parts, max_chars=120)
  23. def build_query_keywords(
  24. state: Dict[str, Any],
  25. query: Optional[str] = None,
  26. domain_terms: Optional[Sequence[str]] = None,
  27. action_terms: Optional[Sequence[str]] = None,
  28. ) -> List[str]:
  29. """从多来源提取检索关键词。"""
  30. selected_section = state.get("selected_section") or {}
  31. intent_result = state.get("intent_result") or {}
  32. history = state.get("conversation_history") or []
  33. sources = [
  34. state.get("user_message") or "",
  35. intent_result.get("normalized_instruction") or "",
  36. f"{selected_section.get('index', '')} {selected_section.get('title', '')}",
  37. str(selected_section.get("content") or "")[:500],
  38. query or "",
  39. ]
  40. if history:
  41. for turn in history[-6:]:
  42. if not isinstance(turn, dict):
  43. continue
  44. role = str(turn.get("role") or turn.get("sender") or "").lower()
  45. if role in ("assistant", "ai", "bot", "model"):
  46. continue
  47. content = str(turn.get("content") or turn.get("message") or "")
  48. if content:
  49. sources.append(content)
  50. keywords: List[str] = []
  51. seen = set()
  52. for text in sources:
  53. for keyword in extract_retrieval_keywords(
  54. str(text or ""),
  55. domain_terms=domain_terms,
  56. action_terms=action_terms,
  57. ):
  58. normalized = keyword.strip()
  59. if not normalized or normalized in seen:
  60. continue
  61. seen.add(normalized)
  62. keywords.append(normalized)
  63. if len(keywords) >= 20:
  64. return keywords
  65. return keywords
  66. def dedupe_join(parts: List[str], max_chars: int) -> str:
  67. """去重后拼接文本片段,限制总长度。"""
  68. values = []
  69. seen = set()
  70. for part in parts:
  71. value = re.sub(r"\s+", " ", str(part or "")).strip()
  72. if not value or value in seen:
  73. continue
  74. seen.add(value)
  75. values.append(value)
  76. return " ".join(values)[:max_chars]
  77. def extract_retrieval_keywords(
  78. text: str,
  79. domain_terms: Optional[Sequence[str]] = None,
  80. action_terms: Optional[Sequence[str]] = None,
  81. ) -> List[str]:
  82. """从文本中提取检索关键词。"""
  83. if not text:
  84. return []
  85. keywords: List[str] = []
  86. for match in re.findall(r"[A-Za-z]{1,8}\s*\d{2,8}(?:[-—]\d{2,4})?", text):
  87. keywords.append(re.sub(r"\s+", "", match).upper())
  88. for match in re.findall(r"《([^》]{2,40})》", text):
  89. keywords.append(match.strip())
  90. domain_terms = tuple(domain_terms or DEFAULT_DOMAIN_TERMS)
  91. action_terms = tuple(action_terms or DEFAULT_ACTION_TERMS)
  92. for term in domain_terms:
  93. if term in text:
  94. keywords.append(term)
  95. action_pattern = "|".join(re.escape(term) for term in action_terms if term)
  96. if action_pattern:
  97. for match in re.findall(rf"[一-鿿A-Za-z0-9.-]{{0,12}}(?:{action_pattern})", text):
  98. if 2 <= len(match) <= 20:
  99. keywords.append(match)
  100. normalized = re.sub(r"[\s,,。;;::、/\\|()\[\]{}<>《》\"'""??]+", " ", text)
  101. for token in normalized.split():
  102. token = token.strip()
  103. if len(token) < 2 or len(token) > 12:
  104. continue
  105. if any(term in token for term in domain_terms):
  106. keywords.append(token)
  107. seen = set()
  108. unique = []
  109. for keyword in keywords:
  110. keyword = keyword.strip()
  111. if keyword and keyword not in seen:
  112. seen.add(keyword)
  113. unique.append(keyword)
  114. return unique