|
|
@@ -23,6 +23,7 @@ from app.sample.schemas.search_engine import (
|
|
|
from app.schemas.base import PaginationSchema
|
|
|
from app.services.milvus_service import milvus_service
|
|
|
from app.utils.vector_utils import text_to_vector_algo
|
|
|
+import logging
|
|
|
|
|
|
class SearchEngineService:
|
|
|
|
|
|
@@ -36,33 +37,300 @@ class SearchEngineService:
|
|
|
return KBSearchResponse(results=[], total=0)
|
|
|
|
|
|
# 1. 使用算法生成向量 (替代 Embedding 模型)
|
|
|
+ # 尝试从 Milvus collection 获取向量维度,动态匹配维度
|
|
|
# 这样相同的查询词会生成相同的向量,具备了基本的检索能力
|
|
|
- query_vector = text_to_vector_algo(payload.query, dim=768)
|
|
|
+ try:
|
|
|
+ collection_detail = milvus_service.get_collection_detail(kb_id)
|
|
|
+ except Exception:
|
|
|
+ collection_detail = None
|
|
|
+
|
|
|
+ dim = None
|
|
|
+ if collection_detail and isinstance(collection_detail, dict):
|
|
|
+ fields = collection_detail.get("fields", []) or []
|
|
|
+ for f in fields:
|
|
|
+ # 根据字段类型查找向量字段(Milvus 向量字段类型通常为 FloatVector / float_vector)
|
|
|
+ if not isinstance(f, dict):
|
|
|
+ continue
|
|
|
+ ftype = str(f.get("type") or "").lower()
|
|
|
+ print(ftype+'是什么东西')
|
|
|
+ if "100" in ftype or '101' in ftype: # 假设 100 和 101 分别代表 FloatVector 和 BinaryVector
|
|
|
+ # 找到向量字段,优先从 params.dim 获取维度
|
|
|
+ params = f.get("params") or {}
|
|
|
+ if params and params.get("dim"):
|
|
|
+ try:
|
|
|
+ dim = int(params.get("dim"))
|
|
|
+ break
|
|
|
+ except Exception:
|
|
|
+ dim = None
|
|
|
+ # 回退默认维度
|
|
|
+ if not dim:
|
|
|
+ dim = 768
|
|
|
+
|
|
|
+ # 选择 Milvus 向量字段名(anns_field),字段名可能不是固定的 "vector",也可能叫 'dense'/'denser' 等
|
|
|
+ anns_field = "vector"
|
|
|
+ if collection_detail and isinstance(collection_detail, dict):
|
|
|
+ fields = collection_detail.get("fields", []) or []
|
|
|
+ # 优先寻找有 params.dim 的向量字段
|
|
|
+ for f in fields:
|
|
|
+ if not isinstance(f, dict):
|
|
|
+ continue
|
|
|
+ params = f.get("params") or {}
|
|
|
+ if params and params.get("dim") and f.get("name"):
|
|
|
+ anns_field = f.get("name")
|
|
|
+ try:
|
|
|
+ dim = int(params.get("dim"))
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ break
|
|
|
+
|
|
|
+ # 若未找到带 dim 的字段,尝试匹配常见的向量字段名或字段类型包含 "vector"
|
|
|
+ if anns_field == "vector":
|
|
|
+ for f in fields:
|
|
|
+ if not isinstance(f, dict):
|
|
|
+ continue
|
|
|
+ fname = (f.get("name") or "")
|
|
|
+ ftype = str(f.get("type") or "").lower()
|
|
|
+ if fname and fname.lower() in ("vector", "denser", "dense", "embedding", "embeddings"):
|
|
|
+ anns_field = fname
|
|
|
+ break
|
|
|
+ if fname and "vector" in ftype:
|
|
|
+ anns_field = fname
|
|
|
+ break
|
|
|
+
|
|
|
+ # 1. 向量搜索 (Dense Retrieval)
|
|
|
+ # 默认使用 Hybrid 混合检索逻辑,但为了简化,这里先保留向量检索的核心
|
|
|
+ # 如果 metric_type 指定为 hybrid,则可能需要结合关键词搜索等
|
|
|
+ # 目前后端实现主要是基于 Milvus 的 ANN 搜索
|
|
|
+
|
|
|
+ # 强制使用 hybrid 混合检索模式作为基础(结合关键词匹配和向量相似度)
|
|
|
+ # 除非用户明确指定了其他度量方式(通常不会)
|
|
|
+ requested_metric = payload.metric_type
|
|
|
+ use_hybrid = False
|
|
|
+
|
|
|
+ # 只有当 metric_type 为 None 或者特定值时才尝试混合检索
|
|
|
+ # 或者我们可以认为只要不指定,就优先尝试混合
|
|
|
+ if not requested_metric or requested_metric.lower() == 'hybrid':
|
|
|
+ use_hybrid = True
|
|
|
+
|
|
|
+ search_params = {
|
|
|
+ "metric_type": "L2", # 默认内部计算用 L2
|
|
|
+ "params": {"nprobe": 10},
|
|
|
+ }
|
|
|
+
|
|
|
+ # 如果前端指定了 metric_type (虽然前端现在默认 hybrid,但保留参数兼容性)
|
|
|
+ if payload.metric_type and payload.metric_type.upper() != 'HYBRID':
|
|
|
+ search_params["metric_type"] = payload.metric_type
|
|
|
|
|
|
# 2. 构建过滤表达式
|
|
|
- expr = ""
|
|
|
+ expr_list = []
|
|
|
+
|
|
|
+ # 兼容旧的单一字段过滤
|
|
|
if payload.metadata_field and payload.metadata_value:
|
|
|
- # 示例:假设元数据直接作为字段存在,或者在 extra_info JSON 中
|
|
|
- # 这里需要根据实际 Milvus Collection 的 Schema 调整
|
|
|
- # 暂时忽略,以免报错
|
|
|
- pass
|
|
|
+ safe_field = payload.metadata_field.replace("'", "").replace('"', "").strip()
|
|
|
+ safe_value = payload.metadata_value.replace("'", "").replace('"', "").strip()
|
|
|
|
|
|
+ if safe_field and safe_value:
|
|
|
+ if safe_value.isdigit():
|
|
|
+ expr_list.append(f'{safe_field} == {safe_value}')
|
|
|
+ else:
|
|
|
+ expr_list.append(f'{safe_field} == "{safe_value}"')
|
|
|
+
|
|
|
+ # 处理新的多重过滤
|
|
|
+ if payload.filters:
|
|
|
+ for f in payload.filters:
|
|
|
+ safe_field = f.field.replace("'", "").replace('"', "").strip()
|
|
|
+ safe_value = f.value.replace("'", "").replace('"', "").strip()
|
|
|
+
|
|
|
+ if safe_field and safe_value:
|
|
|
+ if safe_value.isdigit():
|
|
|
+ expr_list.append(f'{safe_field} == {safe_value}')
|
|
|
+ else:
|
|
|
+ expr_list.append(f'{safe_field} == "{safe_value}"')
|
|
|
+
|
|
|
+ # 组合所有条件 (使用 AND)
|
|
|
+ expr = " and ".join(expr_list) if expr_list else ""
|
|
|
+
|
|
|
+ # 选择 Milvus 向量字段名后生成向量 (移到这里,因为之前代码被替换掉了)
|
|
|
+ query_vector = text_to_vector_algo(payload.query, dim=dim)
|
|
|
+
|
|
|
+ # 检测 collection 使用的 metric (恢复这部分逻辑,因为后续 search 需要)
|
|
|
+ metric_type = None
|
|
|
+ # 优先从 collection_detail 检测真实 metric
|
|
|
+ if collection_detail and isinstance(collection_detail, dict):
|
|
|
+ indices = collection_detail.get("indices") or []
|
|
|
+ if isinstance(indices, list) and len(indices) > 0:
|
|
|
+ for idx in indices:
|
|
|
+ try:
|
|
|
+ mt = idx.get("metric_type") or idx.get("metric")
|
|
|
+ if mt:
|
|
|
+ metric_type = str(mt).upper()
|
|
|
+ break
|
|
|
+ except Exception:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 尝试从 properties 中读取
|
|
|
+ if not metric_type and collection_detail and isinstance(collection_detail, dict):
|
|
|
+ props = collection_detail.get("properties") or {}
|
|
|
+ if isinstance(props, dict):
|
|
|
+ mt = props.get("metric_type") or props.get("metric")
|
|
|
+ if mt:
|
|
|
+ metric_type = str(mt).upper()
|
|
|
+
|
|
|
+ actual_search_metric = metric_type
|
|
|
+ if not actual_search_metric:
|
|
|
+ # 如果无法检测到 collection metric (如无索引),则可以使用用户请求的或默认 L2
|
|
|
+ actual_search_metric = requested_metric if requested_metric and requested_metric.upper() != 'HYBRID' else "L2"
|
|
|
+
|
|
|
+ metric_type = actual_search_metric
|
|
|
+
|
|
|
+ logger = logging.getLogger(__name__)
|
|
|
+ logger.info(f"Search KB={kb_id} using anns_field={anns_field}, dim={dim}, metric={metric_type} (requested={requested_metric})")
|
|
|
+
|
|
|
# 3. 执行 Milvus 搜索
|
|
|
try:
|
|
|
+ # 使用 collection 实际的 metric_type 作为检索度量,避免 mismatch 错误
|
|
|
+ # metric_type 已在上面检测并存放于变量 metric_type
|
|
|
search_params = {
|
|
|
- "metric_type": "COSINE",
|
|
|
+ "metric_type": metric_type,
|
|
|
"params": {"nprobe": 10}
|
|
|
}
|
|
|
+
|
|
|
+ # 如果 top_k <= 0 或未指定,解释为返回该 collection 中的所有文段
|
|
|
+ # 优先使用 page/page_size 计算 limit 和 offset
|
|
|
+ page = payload.page if payload.page and payload.page > 0 else 1
|
|
|
+ page_size = payload.page_size if payload.page_size and payload.page_size > 0 else 10
|
|
|
|
|
|
- results = milvus_service.client.search(
|
|
|
- collection_name=kb_id,
|
|
|
- data=[query_vector],
|
|
|
- anns_field="vector",
|
|
|
- search_params=search_params,
|
|
|
- limit=payload.top_k,
|
|
|
- filter=expr if expr else "",
|
|
|
- output_fields=["*"]
|
|
|
- )
|
|
|
+ # 如果 payload 中有 top_k 且未传 page_size (或者 page_size 是默认值),可以使用 top_k 覆盖 page_size
|
|
|
+ # 但这里为了清晰,优先使用 page_size
|
|
|
+
|
|
|
+ offset = (page - 1) * page_size
|
|
|
+ limit = page_size
|
|
|
+
|
|
|
+ # Milvus 对 limit + offset 有限制 (通常 16384),这里做个简单的保护
|
|
|
+ if offset + limit > 16384:
|
|
|
+ # 如果超出深度分页限制,可能需要提示或截断
|
|
|
+ # 这里暂时不做处理,让 Milvus 报错或自行截断
|
|
|
+ pass
|
|
|
+
|
|
|
+ # 获取集合总数用于分页显示 (total)
|
|
|
+ collection_count = 0
|
|
|
+ if collection_detail and isinstance(collection_detail, dict):
|
|
|
+ collection_count = collection_detail.get("entity_count") or 0
|
|
|
+
|
|
|
+ if not collection_count:
|
|
|
+ try:
|
|
|
+ stats = milvus_service.client.get_collection_stats(collection_name=kb_id)
|
|
|
+ collection_count = int(stats.get("row_count")) if isinstance(stats, dict) and stats.get("row_count") else 0
|
|
|
+ except Exception:
|
|
|
+ collection_count = 0
|
|
|
+
|
|
|
+ # 如果是按照 top_k 逻辑 (不传 page/page_size),保留旧逻辑 (top_k 即 limit, offset=0)
|
|
|
+ # 但现在 Schema 默认 page=1, page_size=10,所以总是走分页逻辑
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 尝试使用混合检索 (Hybrid Search)
|
|
|
+ # 只有当用户没有显式指定 metric_type 或者指定为 hybrid 时,且集合支持(通常通过异常回退处理)时使用
|
|
|
+ # 但考虑到 metric_type 可能是 L2/COSINE,我们这里先尝试 hybrid,如果失败回退到普通
|
|
|
+
|
|
|
+ # 为了不破坏现有逻辑,我们可以根据某种标志来决定是否使用 hybrid
|
|
|
+ # 或者默认尝试 hybrid,如果 collection 不支持 sparse 则会报错回退
|
|
|
+
|
|
|
+ # 这里我们直接调用 milvus_service.hybrid_search
|
|
|
+ # 注意:hybrid_search 返回的格式与 client.search 不同,需要适配
|
|
|
+
|
|
|
+ use_hybrid = False
|
|
|
+ # 只有当 metric_type 为 None 或者特定值时才尝试混合检索,避免与用户明确指定的 metric 冲突
|
|
|
+ # 或者我们可以认为只要不指定,就优先尝试混合
|
|
|
+ # 已经在上面判断过 use_hybrid = True 了
|
|
|
+
|
|
|
+ if use_hybrid:
|
|
|
+ logger.info(f"Attempting hybrid search for KB={kb_id}")
|
|
|
+ try:
|
|
|
+ # Hybrid search (LangChain Milvus) 暂时不支持直接传 offset
|
|
|
+ # 所以我们需要获取 top_k = offset + limit,然后手动切片
|
|
|
+ target_k = offset + limit
|
|
|
+
|
|
|
+ hybrid_results = milvus_service.hybrid_search(
|
|
|
+ collection_name=kb_id,
|
|
|
+ query_text=payload.query,
|
|
|
+ top_k=target_k
|
|
|
+ )
|
|
|
+
|
|
|
+ # 手动切片实现分页
|
|
|
+ start = offset
|
|
|
+ end = offset + limit
|
|
|
+ # 确保不越界
|
|
|
+ if start >= len(hybrid_results):
|
|
|
+ sliced_results = []
|
|
|
+ else:
|
|
|
+ sliced_results = hybrid_results[start:end]
|
|
|
+
|
|
|
+ formatted_results = []
|
|
|
+ for item in sliced_results:
|
|
|
+ formatted_results.append(KBSearchResultItem(
|
|
|
+ id=str(item.get('id')),
|
|
|
+ kb_name=kb_id,
|
|
|
+ doc_name=item.get('metadata', {}).get('file_name') or item.get('metadata', {}).get('source') or "未知文档",
|
|
|
+ content=item.get('text_content') or "",
|
|
|
+ meta_info=str(item.get('metadata', {})),
|
|
|
+ score=item.get('similarity', 0) * 100 # 假设是 0-1
|
|
|
+ ))
|
|
|
+
|
|
|
+ return KBSearchResponse(results=formatted_results, total=collection_count)
|
|
|
+
|
|
|
+ except Exception as hybrid_err:
|
|
|
+ logger.warning(f"Hybrid search failed, falling back to vector search: {hybrid_err}")
|
|
|
+ # Fallback to standard vector search below
|
|
|
+ pass
|
|
|
+
|
|
|
+ results = milvus_service.client.search(
|
|
|
+ collection_name=kb_id,
|
|
|
+ data=[query_vector],
|
|
|
+ anns_field=anns_field,
|
|
|
+ search_params=search_params,
|
|
|
+ limit=limit,
|
|
|
+ offset=offset, # 添加 offset 支持分页
|
|
|
+ filter=expr if expr else "",
|
|
|
+ output_fields=["*"]
|
|
|
+ )
|
|
|
+ except Exception as milvus_err:
|
|
|
+ # 捕获 Milvus 异常,常见原因包括 metric mismatch
|
|
|
+ logger.error(f"Milvus search failed for collection={kb_id}, metric_requested={metric_type}, anns_field={anns_field}: {milvus_err}")
|
|
|
+
|
|
|
+ # Retry Logic: 如果是因为 metric 不匹配,解析错误信息中的 expected metric 并重试
|
|
|
+ error_msg = str(milvus_err)
|
|
|
+ if "metric type not match" in error_msg:
|
|
|
+ import re
|
|
|
+ # 匹配 expected=COSINE 或 expected='COSINE' 等格式
|
|
|
+ # 支持 COSINE, L2, IP, BM25 等
|
|
|
+ match = re.search(r"expected\s*=\s*['\"]?([A-Za-z0-9_]+)['\"]?", error_msg)
|
|
|
+ if match:
|
|
|
+ correct_metric = match.group(1).upper()
|
|
|
+ logger.warning(f"Detected metric mismatch. Retrying with correct metric: {correct_metric}")
|
|
|
+
|
|
|
+ # 更新 metric_type 并重试搜索
|
|
|
+ search_params["metric_type"] = correct_metric
|
|
|
+ # 同时也需要更新后续计算分数所用的 metric_type 变量,以便正确计算相似度
|
|
|
+ metric_type = correct_metric
|
|
|
+
|
|
|
+ # 特殊处理: BM25 可能需要 sparse vector 或其他参数,但 Milvus search 接口应该是一致的
|
|
|
+ # 如果是 BM25,可能 anns_field 也要调整(通常 BM25 用 sparse vector)
|
|
|
+ # 但这里假设 anns_field 是正确的,只是 metric 不对
|
|
|
+
|
|
|
+ results = milvus_service.client.search(
|
|
|
+ collection_name=kb_id,
|
|
|
+ data=[query_vector],
|
|
|
+ anns_field=anns_field,
|
|
|
+ search_params=search_params,
|
|
|
+ limit=limit,
|
|
|
+ offset=offset, # 同样加上 offset
|
|
|
+ filter=expr if expr else "",
|
|
|
+ output_fields=["*"]
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ raise
|
|
|
+ else:
|
|
|
+ raise
|
|
|
|
|
|
# 4. 格式化结果
|
|
|
formatted_results = []
|
|
|
@@ -73,27 +341,76 @@ class SearchEngineService:
|
|
|
# continue
|
|
|
|
|
|
entity = hit.entity
|
|
|
-
|
|
|
+
|
|
|
content = entity.get("text") or entity.get("content") or entity.get("page_content") or ""
|
|
|
if not content:
|
|
|
- debug_data = {k:v for k,v in entity.items() if k != "vector"}
|
|
|
+ debug_data = {k: v for k, v in entity.items() if k != anns_field}
|
|
|
content = json.dumps(debug_data, ensure_ascii=False)[:200] + "..."
|
|
|
|
|
|
doc_name = entity.get("file_name") or entity.get("title") or entity.get("source") or "未知文档"
|
|
|
|
|
|
meta_info = []
|
|
|
for k, v in entity.items():
|
|
|
- if k not in ["vector", "text", "content", "page_content", "id", "pk"]:
|
|
|
+ if k not in [anns_field, "text", "content", "page_content", "id", "pk"]:
|
|
|
meta_info.append(f"{k}: {v}")
|
|
|
meta_str = "; ".join(meta_info[:3])
|
|
|
|
|
|
+ # 根据 collection 的 metric 动态计算相似度分数
|
|
|
+ # 如果用户请求了特定的 metric,尝试适配;否则使用实际 metric
|
|
|
+ display_metric = requested_metric if requested_metric else metric_type
|
|
|
+
|
|
|
+ similarity_pct = None
|
|
|
+ try:
|
|
|
+ raw_score = float(hit.score)
|
|
|
+ except Exception:
|
|
|
+ raw_score = None
|
|
|
+
|
|
|
+ if raw_score is not None:
|
|
|
+ # 核心计算逻辑:先根据 metric_type 理解 raw_score,再根据 display_metric 转换
|
|
|
+ # 目前简化处理:直接根据 display_metric 解释 raw_score,忽略不兼容的情况
|
|
|
+ # 更好的做法是:
|
|
|
+ # 1. 识别 raw_score 的物理意义(距离还是相似度),基于 metric_type
|
|
|
+ # 2. 转换为 display_metric 要求的格式
|
|
|
+
|
|
|
+ # Case 1: 实际是 L2 (距离),用户想看 L2
|
|
|
+ if "L2" in metric_type or "EUCLIDEAN" in metric_type:
|
|
|
+ distance = raw_score
|
|
|
+ if display_metric and ("COSINE" in display_metric):
|
|
|
+ # L2 距离转 Cosine 相似度 (仅适用于归一化向量)
|
|
|
+ # dist^2 = 2(1-cos) => cos = 1 - dist^2/2
|
|
|
+ # 但这里简单起见,如果类型不匹配,还是按 L2 算百分比,避免数值错误
|
|
|
+ similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
|
|
|
+ else:
|
|
|
+ similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
|
|
|
+
|
|
|
+ # Case 2: 实际是 Cosine (相似度 [-1, 1])
|
|
|
+ elif "COSINE" in metric_type:
|
|
|
+ cosine_score = raw_score
|
|
|
+ # 无论用户想看什么,Cosine Score 本身就是相似度,直接归一化到 0-100 最直观
|
|
|
+ similarity_pct = round(max(min((cosine_score + 1.0) / 2.0, 1.0), 0.0) * 100.0, 2)
|
|
|
+
|
|
|
+ # Case 3: IP (内积)
|
|
|
+ elif "IP" in metric_type or "INNER" in metric_type:
|
|
|
+ similarity_pct = round(raw_score * 100.0, 2)
|
|
|
+
|
|
|
+ # Fallback
|
|
|
+ else:
|
|
|
+ # 兼容 BM25 或其他未知 metric
|
|
|
+ if "BM25" in metric_type:
|
|
|
+ # BM25 分数通常是正数,没有固定上限,直接显示原值
|
|
|
+ similarity_pct = round(raw_score, 2)
|
|
|
+ else:
|
|
|
+ similarity_pct = round(raw_score * 100.0, 2)
|
|
|
+ else:
|
|
|
+ similarity_pct = 0.0
|
|
|
+
|
|
|
formatted_results.append(KBSearchResultItem(
|
|
|
id=str(hit.id),
|
|
|
- kb_name=kb_id,
|
|
|
+ kb_name=kb_id,
|
|
|
doc_name=doc_name,
|
|
|
content=content,
|
|
|
meta_info=meta_str,
|
|
|
- score=round(hit.score * 100, 2)
|
|
|
+ score=similarity_pct
|
|
|
))
|
|
|
|
|
|
return KBSearchResponse(results=formatted_results, total=len(formatted_results))
|