search_engine_service.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. """
  2. 检索引擎业务逻辑服务
  3. """
  4. from math import ceil
  5. from typing import List, Optional, Tuple, Dict
  6. from sqlalchemy.ext.asyncio import AsyncSession
  7. from sqlalchemy import select, func, or_
  8. from datetime import datetime
  9. import uuid
  10. import random
  11. import json
  12. import hashlib
  13. import math
  14. from app.sample.models.search_engine import SearchEngine
  15. from app.sample.schemas.search_engine import (
  16. SearchEngineCreate,
  17. SearchEngineUpdate,
  18. KBSearchRequest,
  19. KBSearchResultItem,
  20. KBSearchResponse
  21. )
  22. from app.schemas.base import PaginationSchema
  23. from app.services.milvus_service import milvus_service
  24. from app.utils.vector_utils import text_to_vector_algo
  25. import logging
  26. class SearchEngineService:
  27. async def search_kb(self, db: AsyncSession, payload: KBSearchRequest) -> KBSearchResponse:
  28. """
  29. 知识库搜索 (基于算法向量)
  30. """
  31. kb_id = payload.kb_id
  32. if not milvus_service.has_collection(kb_id):
  33. return KBSearchResponse(results=[], total=0)
  34. # 1. 使用算法生成向量 (替代 Embedding 模型)
  35. # 尝试从 Milvus collection 获取向量维度,动态匹配维度
  36. # 这样相同的查询词会生成相同的向量,具备了基本的检索能力
  37. try:
  38. collection_detail = milvus_service.get_collection_detail(kb_id)
  39. except Exception:
  40. collection_detail = None
  41. dim = None
  42. if collection_detail and isinstance(collection_detail, dict):
  43. fields = collection_detail.get("fields", []) or []
  44. for f in fields:
  45. # 根据字段类型查找向量字段(Milvus 向量字段类型通常为 FloatVector / float_vector)
  46. if not isinstance(f, dict):
  47. continue
  48. ftype = str(f.get("type") or "").lower()
  49. print(ftype+'是什么东西')
  50. if "100" in ftype or '101' in ftype: # 假设 100 和 101 分别代表 FloatVector 和 BinaryVector
  51. # 找到向量字段,优先从 params.dim 获取维度
  52. params = f.get("params") or {}
  53. if params and params.get("dim"):
  54. try:
  55. dim = int(params.get("dim"))
  56. break
  57. except Exception:
  58. dim = None
  59. # 回退默认维度
  60. if not dim:
  61. dim = 768
  62. # 选择 Milvus 向量字段名(anns_field),字段名可能不是固定的 "vector",也可能叫 'dense'/'denser' 等
  63. anns_field = "vector"
  64. if collection_detail and isinstance(collection_detail, dict):
  65. fields = collection_detail.get("fields", []) or []
  66. # 优先寻找有 params.dim 的向量字段
  67. for f in fields:
  68. if not isinstance(f, dict):
  69. continue
  70. params = f.get("params") or {}
  71. if params and params.get("dim") and f.get("name"):
  72. anns_field = f.get("name")
  73. try:
  74. dim = int(params.get("dim"))
  75. except Exception:
  76. pass
  77. break
  78. # 若未找到带 dim 的字段,尝试匹配常见的向量字段名或字段类型包含 "vector"
  79. if anns_field == "vector":
  80. for f in fields:
  81. if not isinstance(f, dict):
  82. continue
  83. fname = (f.get("name") or "")
  84. ftype = str(f.get("type") or "").lower()
  85. if fname and fname.lower() in ("vector", "denser", "dense", "embedding", "embeddings"):
  86. anns_field = fname
  87. break
  88. if fname and "vector" in ftype:
  89. anns_field = fname
  90. break
  91. # 1. 向量搜索 (Dense Retrieval)
  92. # 默认使用 Hybrid 混合检索逻辑,但为了简化,这里先保留向量检索的核心
  93. # 如果 metric_type 指定为 hybrid,则可能需要结合关键词搜索等
  94. # 目前后端实现主要是基于 Milvus 的 ANN 搜索
  95. # 强制使用 hybrid 混合检索模式作为基础(结合关键词匹配和向量相似度)
  96. # 除非用户明确指定了其他度量方式(通常不会)
  97. requested_metric = payload.metric_type
  98. use_hybrid = False
  99. # 只有当 metric_type 为 None 或者特定值时才尝试混合检索
  100. # 或者我们可以认为只要不指定,就优先尝试混合
  101. if not requested_metric or requested_metric.lower() == 'hybrid':
  102. use_hybrid = True
  103. search_params = {
  104. "metric_type": "L2", # 默认内部计算用 L2
  105. "params": {"nprobe": 10},
  106. }
  107. # 如果前端指定了 metric_type (虽然前端现在默认 hybrid,但保留参数兼容性)
  108. if payload.metric_type and payload.metric_type.upper() != 'HYBRID':
  109. search_params["metric_type"] = payload.metric_type
  110. # 2. 构建过滤表达式
  111. expr_list = []
  112. # 兼容旧的单一字段过滤
  113. if payload.metadata_field and payload.metadata_value:
  114. safe_field = payload.metadata_field.replace("'", "").replace('"', "").strip()
  115. safe_value = payload.metadata_value.replace("'", "").replace('"', "").strip()
  116. if safe_field and safe_value:
  117. if safe_value.isdigit():
  118. expr_list.append(f'{safe_field} == {safe_value}')
  119. else:
  120. expr_list.append(f'{safe_field} == "{safe_value}"')
  121. # 处理新的多重过滤
  122. if payload.filters:
  123. for f in payload.filters:
  124. safe_field = f.field.replace("'", "").replace('"', "").strip()
  125. safe_value = f.value.replace("'", "").replace('"', "").strip()
  126. if safe_field and safe_value:
  127. if safe_value.isdigit():
  128. expr_list.append(f'{safe_field} == {safe_value}')
  129. else:
  130. expr_list.append(f'{safe_field} == "{safe_value}"')
  131. # 组合所有条件 (使用 AND)
  132. expr = " and ".join(expr_list) if expr_list else ""
  133. # 选择 Milvus 向量字段名后生成向量 (移到这里,因为之前代码被替换掉了)
  134. query_vector = text_to_vector_algo(payload.query, dim=dim)
  135. # 检测 collection 使用的 metric (恢复这部分逻辑,因为后续 search 需要)
  136. metric_type = None
  137. # 优先从 collection_detail 检测真实 metric
  138. if collection_detail and isinstance(collection_detail, dict):
  139. indices = collection_detail.get("indices") or []
  140. if isinstance(indices, list) and len(indices) > 0:
  141. for idx in indices:
  142. try:
  143. mt = idx.get("metric_type") or idx.get("metric")
  144. if mt:
  145. metric_type = str(mt).upper()
  146. break
  147. except Exception:
  148. continue
  149. # 尝试从 properties 中读取
  150. if not metric_type and collection_detail and isinstance(collection_detail, dict):
  151. props = collection_detail.get("properties") or {}
  152. if isinstance(props, dict):
  153. mt = props.get("metric_type") or props.get("metric")
  154. if mt:
  155. metric_type = str(mt).upper()
  156. actual_search_metric = metric_type
  157. if not actual_search_metric:
  158. # 如果无法检测到 collection metric (如无索引),则可以使用用户请求的或默认 L2
  159. actual_search_metric = requested_metric if requested_metric and requested_metric.upper() != 'HYBRID' else "L2"
  160. metric_type = actual_search_metric
  161. logger = logging.getLogger(__name__)
  162. logger.info(f"Search KB={kb_id} using anns_field={anns_field}, dim={dim}, metric={metric_type} (requested={requested_metric})")
  163. # 3. 执行 Milvus 搜索
  164. try:
  165. # 使用 collection 实际的 metric_type 作为检索度量,避免 mismatch 错误
  166. # metric_type 已在上面检测并存放于变量 metric_type
  167. search_params = {
  168. "metric_type": metric_type,
  169. "params": {"nprobe": 10}
  170. }
  171. # 如果 top_k <= 0 或未指定,解释为返回该 collection 中的所有文段
  172. # 优先使用 page/page_size 计算 limit 和 offset
  173. page = payload.page if payload.page and payload.page > 0 else 1
  174. page_size = payload.page_size if payload.page_size and payload.page_size > 0 else 10
  175. # 如果 payload 中有 top_k 且未传 page_size (或者 page_size 是默认值),可以使用 top_k 覆盖 page_size
  176. # 但这里为了清晰,优先使用 page_size
  177. offset = (page - 1) * page_size
  178. limit = page_size
  179. # Milvus 对 limit + offset 有限制 (通常 16384),这里做个简单的保护
  180. if offset + limit > 16384:
  181. # 如果超出深度分页限制,可能需要提示或截断
  182. # 这里暂时不做处理,让 Milvus 报错或自行截断
  183. pass
  184. # 获取集合总数用于分页显示 (total)
  185. collection_count = 0
  186. if collection_detail and isinstance(collection_detail, dict):
  187. collection_count = collection_detail.get("entity_count") or 0
  188. if not collection_count:
  189. try:
  190. stats = milvus_service.client.get_collection_stats(collection_name=kb_id)
  191. collection_count = int(stats.get("row_count")) if isinstance(stats, dict) and stats.get("row_count") else 0
  192. except Exception:
  193. collection_count = 0
  194. # 如果是按照 top_k 逻辑 (不传 page/page_size),保留旧逻辑 (top_k 即 limit, offset=0)
  195. # 但现在 Schema 默认 page=1, page_size=10,所以总是走分页逻辑
  196. try:
  197. # 尝试使用混合检索 (Hybrid Search)
  198. # 只有当用户没有显式指定 metric_type 或者指定为 hybrid 时,且集合支持(通常通过异常回退处理)时使用
  199. # 但考虑到 metric_type 可能是 L2/COSINE,我们这里先尝试 hybrid,如果失败回退到普通
  200. # 为了不破坏现有逻辑,我们可以根据某种标志来决定是否使用 hybrid
  201. # 或者默认尝试 hybrid,如果 collection 不支持 sparse 则会报错回退
  202. # 这里我们直接调用 milvus_service.hybrid_search
  203. # 注意:hybrid_search 返回的格式与 client.search 不同,需要适配
  204. use_hybrid = False
  205. # 只有当 metric_type 为 None 或者特定值时才尝试混合检索,避免与用户明确指定的 metric 冲突
  206. # 或者我们可以认为只要不指定,就优先尝试混合
  207. # 已经在上面判断过 use_hybrid = True 了
  208. if use_hybrid:
  209. logger.info(f"Attempting hybrid search for KB={kb_id}")
  210. try:
  211. # Hybrid search (LangChain Milvus) 暂时不支持直接传 offset
  212. # 所以我们需要获取 top_k = offset + limit,然后手动切片
  213. target_k = offset + limit
  214. hybrid_results = milvus_service.hybrid_search(
  215. collection_name=kb_id,
  216. query_text=payload.query,
  217. top_k=target_k
  218. )
  219. # 手动切片实现分页
  220. start = offset
  221. end = offset + limit
  222. # 确保不越界
  223. if start >= len(hybrid_results):
  224. sliced_results = []
  225. else:
  226. sliced_results = hybrid_results[start:end]
  227. formatted_results = []
  228. for item in sliced_results:
  229. formatted_results.append(KBSearchResultItem(
  230. id=str(item.get('id')),
  231. kb_name=kb_id,
  232. doc_name=item.get('metadata', {}).get('file_name') or item.get('metadata', {}).get('source') or "未知文档",
  233. content=item.get('text_content') or "",
  234. meta_info=str(item.get('metadata', {})),
  235. score=item.get('similarity', 0) * 100 # 假设是 0-1
  236. ))
  237. return KBSearchResponse(results=formatted_results, total=collection_count)
  238. except Exception as hybrid_err:
  239. logger.warning(f"Hybrid search failed, falling back to vector search: {hybrid_err}")
  240. # Fallback to standard vector search below
  241. pass
  242. results = milvus_service.client.search(
  243. collection_name=kb_id,
  244. data=[query_vector],
  245. anns_field=anns_field,
  246. search_params=search_params,
  247. limit=limit,
  248. offset=offset, # 添加 offset 支持分页
  249. filter=expr if expr else "",
  250. output_fields=["*"]
  251. )
  252. except Exception as milvus_err:
  253. # 捕获 Milvus 异常,常见原因包括 metric mismatch
  254. logger.error(f"Milvus search failed for collection={kb_id}, metric_requested={metric_type}, anns_field={anns_field}: {milvus_err}")
  255. # Retry Logic: 如果是因为 metric 不匹配,解析错误信息中的 expected metric 并重试
  256. error_msg = str(milvus_err)
  257. if "metric type not match" in error_msg:
  258. import re
  259. # 匹配 expected=COSINE 或 expected='COSINE' 等格式
  260. # 支持 COSINE, L2, IP, BM25 等
  261. match = re.search(r"expected\s*=\s*['\"]?([A-Za-z0-9_]+)['\"]?", error_msg)
  262. if match:
  263. correct_metric = match.group(1).upper()
  264. logger.warning(f"Detected metric mismatch. Retrying with correct metric: {correct_metric}")
  265. # 更新 metric_type 并重试搜索
  266. search_params["metric_type"] = correct_metric
  267. # 同时也需要更新后续计算分数所用的 metric_type 变量,以便正确计算相似度
  268. metric_type = correct_metric
  269. # 特殊处理: BM25 可能需要 sparse vector 或其他参数,但 Milvus search 接口应该是一致的
  270. # 如果是 BM25,可能 anns_field 也要调整(通常 BM25 用 sparse vector)
  271. # 但这里假设 anns_field 是正确的,只是 metric 不对
  272. results = milvus_service.client.search(
  273. collection_name=kb_id,
  274. data=[query_vector],
  275. anns_field=anns_field,
  276. search_params=search_params,
  277. limit=limit,
  278. offset=offset, # 同样加上 offset
  279. filter=expr if expr else "",
  280. output_fields=["*"]
  281. )
  282. else:
  283. raise
  284. else:
  285. raise
  286. # 4. 格式化结果
  287. formatted_results = []
  288. for hits in results:
  289. for hit in hits:
  290. # 过滤低相似度结果 (算法生成的向量相似度可能较低,阈值可适当调低)
  291. # if hit.score < payload.score_threshold:
  292. # continue
  293. entity = hit.entity
  294. content = entity.get("text") or entity.get("content") or entity.get("page_content") or ""
  295. if not content:
  296. debug_data = {k: v for k, v in entity.items() if k != anns_field}
  297. content = json.dumps(debug_data, ensure_ascii=False)[:200] + "..."
  298. doc_name = entity.get("file_name") or entity.get("title") or entity.get("source") or "未知文档"
  299. meta_info = []
  300. for k, v in entity.items():
  301. if k not in [anns_field, "text", "content", "page_content", "id", "pk"]:
  302. meta_info.append(f"{k}: {v}")
  303. meta_str = "; ".join(meta_info[:3])
  304. # 根据 collection 的 metric 动态计算相似度分数
  305. # 如果用户请求了特定的 metric,尝试适配;否则使用实际 metric
  306. display_metric = requested_metric if requested_metric else metric_type
  307. similarity_pct = None
  308. try:
  309. raw_score = float(hit.score)
  310. except Exception:
  311. raw_score = None
  312. if raw_score is not None:
  313. # 核心计算逻辑:先根据 metric_type 理解 raw_score,再根据 display_metric 转换
  314. # 目前简化处理:直接根据 display_metric 解释 raw_score,忽略不兼容的情况
  315. # 更好的做法是:
  316. # 1. 识别 raw_score 的物理意义(距离还是相似度),基于 metric_type
  317. # 2. 转换为 display_metric 要求的格式
  318. # Case 1: 实际是 L2 (距离),用户想看 L2
  319. if "L2" in metric_type or "EUCLIDEAN" in metric_type:
  320. distance = raw_score
  321. if display_metric and ("COSINE" in display_metric):
  322. # L2 距离转 Cosine 相似度 (仅适用于归一化向量)
  323. # dist^2 = 2(1-cos) => cos = 1 - dist^2/2
  324. # 但这里简单起见,如果类型不匹配,还是按 L2 算百分比,避免数值错误
  325. similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
  326. else:
  327. similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
  328. # Case 2: 实际是 Cosine (相似度 [-1, 1])
  329. elif "COSINE" in metric_type:
  330. cosine_score = raw_score
  331. # 无论用户想看什么,Cosine Score 本身就是相似度,直接归一化到 0-100 最直观
  332. similarity_pct = round(max(min((cosine_score + 1.0) / 2.0, 1.0), 0.0) * 100.0, 2)
  333. # Case 3: IP (内积)
  334. elif "IP" in metric_type or "INNER" in metric_type:
  335. similarity_pct = round(raw_score * 100.0, 2)
  336. # Fallback
  337. else:
  338. # 兼容 BM25 或其他未知 metric
  339. if "BM25" in metric_type:
  340. # BM25 分数通常是正数,没有固定上限,直接显示原值
  341. similarity_pct = round(raw_score, 2)
  342. else:
  343. similarity_pct = round(raw_score * 100.0, 2)
  344. else:
  345. similarity_pct = 0.0
  346. formatted_results.append(KBSearchResultItem(
  347. id=str(hit.id),
  348. kb_name=kb_id,
  349. doc_name=doc_name,
  350. content=content,
  351. meta_info=meta_str,
  352. score=similarity_pct
  353. ))
  354. return KBSearchResponse(results=formatted_results, total=len(formatted_results))
  355. except Exception as e:
  356. print(f"Search error: {e}")
  357. return KBSearchResponse(results=[], total=0)
  358. # ... (Keep existing CRUD methods below) ...
  359. async def get_list(
  360. self,
  361. db: AsyncSession,
  362. page: int = 1,
  363. page_size: int = 10,
  364. keyword: Optional[str] = None,
  365. status: Optional[str] = None
  366. ) -> Tuple[List[SearchEngine], PaginationSchema]:
  367. """获取检索引擎列表"""
  368. query = select(SearchEngine).where(SearchEngine.is_deleted == 0)
  369. if keyword:
  370. query = query.where(or_(
  371. SearchEngine.name.like(f"%{keyword}%"),
  372. SearchEngine.description.like(f"%{keyword}%")
  373. ))
  374. if status:
  375. query = query.where(SearchEngine.status == status)
  376. # 计算总数
  377. count_query = select(func.count()).select_from(query.subquery())
  378. total = await db.scalar(count_query) or 0
  379. # 分页查询
  380. query = query.order_by(SearchEngine.created_at.desc()).offset((page - 1) * page_size).limit(page_size)
  381. result = await db.execute(query)
  382. items = result.scalars().all()
  383. total_pages = ceil(total / page_size) if page_size else 0
  384. meta = PaginationSchema(
  385. page=page,
  386. page_size=page_size,
  387. total=total,
  388. total_pages=total_pages,
  389. )
  390. return items, meta
  391. async def create(self, db: AsyncSession, payload: SearchEngineCreate) -> SearchEngine:
  392. """创建检索引擎"""
  393. # 1. 检查名称是否已存在
  394. exists = await db.execute(select(SearchEngine).where(
  395. SearchEngine.name == payload.name,
  396. SearchEngine.is_deleted == 0
  397. ))
  398. if exists.scalars().first():
  399. raise ValueError("检索引擎名称已存在")
  400. try:
  401. now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  402. new_engine = SearchEngine(
  403. id=str(uuid.uuid4()),
  404. name=payload.name,
  405. engine_type=payload.engine_type,
  406. base_url=payload.base_url,
  407. api_key=payload.api_key,
  408. description=payload.description,
  409. status=payload.status or "normal",
  410. created_at=now,
  411. updated_at=now
  412. )
  413. db.add(new_engine)
  414. await db.commit()
  415. await db.refresh(new_engine)
  416. return new_engine
  417. except Exception as e:
  418. await db.rollback()
  419. raise e
  420. async def update(self, db: AsyncSession, id: str, payload: SearchEngineUpdate) -> SearchEngine:
  421. """更新检索引擎信息"""
  422. result = await db.execute(select(SearchEngine).where(SearchEngine.id == id, SearchEngine.is_deleted == 0))
  423. engine = result.scalars().first()
  424. if not engine:
  425. raise ValueError("检索引擎不存在")
  426. try:
  427. if payload.name is not None:
  428. engine.name = payload.name
  429. if payload.engine_type is not None:
  430. engine.engine_type = payload.engine_type
  431. if payload.base_url is not None:
  432. engine.base_url = payload.base_url
  433. if payload.api_key is not None:
  434. engine.api_key = payload.api_key
  435. if payload.description is not None:
  436. engine.description = payload.description
  437. if payload.status is not None:
  438. engine.status = payload.status
  439. engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  440. await db.commit()
  441. await db.refresh(engine)
  442. return engine
  443. except Exception as e:
  444. await db.rollback()
  445. raise e
  446. async def update_status(self, db: AsyncSession, id: str, status: str) -> SearchEngine:
  447. """更新检索引擎状态"""
  448. result = await db.execute(select(SearchEngine).where(SearchEngine.id == id, SearchEngine.is_deleted == 0))
  449. engine = result.scalars().first()
  450. if not engine:
  451. raise ValueError("检索引擎不存在")
  452. try:
  453. engine.status = status
  454. engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  455. await db.commit()
  456. await db.refresh(engine)
  457. return engine
  458. except Exception as e:
  459. await db.rollback()
  460. raise e
  461. async def delete(self, db: AsyncSession, id: str) -> None:
  462. """删除检索引擎"""
  463. result = await db.execute(select(SearchEngine).where(SearchEngine.id == id))
  464. engine = result.scalars().first()
  465. if not engine:
  466. raise ValueError("检索引擎不存在")
  467. try:
  468. # 软删除
  469. engine.is_deleted = 1
  470. engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  471. await db.commit()
  472. except Exception as e:
  473. await db.rollback()
  474. raise e
# Module-level singleton instance of the service.
search_engine_service = SearchEngineService()