|
|
@@ -97,7 +97,7 @@ class SearchEngineService:
|
|
|
dim = 768
|
|
|
|
|
|
# 选择 Milvus 向量字段名(anns_field),字段名可能不是固定的 "vector",也可能叫 'dense'/'denser' 等
|
|
|
- anns_field = "vector"
|
|
|
+ anns_field = "dense"
|
|
|
if collection_detail and isinstance(collection_detail, dict):
|
|
|
fields = collection_detail.get("fields", []) or []
|
|
|
# 优先寻找有 params.dim 的向量字段
|
|
|
@@ -114,7 +114,7 @@ class SearchEngineService:
|
|
|
break
|
|
|
|
|
|
# 若未找到带 dim 的字段,尝试匹配常见的向量字段名或字段类型包含 "vector"
|
|
|
- if anns_field == "vector":
|
|
|
+ if anns_field == "dense":
|
|
|
for f in fields:
|
|
|
if not isinstance(f, dict):
|
|
|
continue
|
|
|
@@ -179,24 +179,24 @@ class SearchEngineService:
|
|
|
target_field = safe_field
|
|
|
if not is_top_level:
|
|
|
target_field = f'metadata["{safe_field}"]'
|
|
|
-
|
|
|
- if safe_value.isdigit():
|
|
|
- expr_list.append(f'{target_field} == {safe_value}')
|
|
|
- else:
|
|
|
- expr_list.append(f'{target_field} == "{safe_value}"')
|
|
|
+
|
|
|
+ expr_list.append(f'{target_field} == "{safe_value}"')
|
|
|
|
|
|
# 处理新的多重过滤
|
|
|
if payload.filters:
|
|
|
for f in payload.filters:
|
|
|
safe_field = f.field.replace("'", "").replace('"', "").strip()
|
|
|
- safe_value = f.value.replace("'", "").replace('"', "").strip()
|
|
|
+ raw_value = (f.value or "").strip()
|
|
|
+ safe_value = raw_value.replace("'", "").replace('"', "").strip()
|
|
|
|
|
|
if safe_field and safe_value:
|
|
|
# [Special Case] 文档名称过滤 (doc_name_in)
|
|
|
# 前端传递的是 "doc_name_in", value 是 JSON 数组字符串 (e.g. '["doc1", "doc2"]')
|
|
|
if safe_field == 'doc_name_in':
|
|
|
try:
|
|
|
- doc_names = json.loads(safe_value)
|
|
|
+ # doc_name_in 的 value 必须保留引号,否则不是合法 JSON
|
|
|
+ parse_value = raw_value
|
|
|
+ doc_names = json.loads(parse_value)
|
|
|
if doc_names and isinstance(doc_names, list):
|
|
|
# 构建 OR 条件: (metadata["doc_name"] == "A" || metadata["doc_name"] == "B")
|
|
|
# 注意:Milvus 字段可能是 doc_name, file_name, title, source
|
|
|
@@ -231,7 +231,7 @@ class SearchEngineService:
|
|
|
# JSON 字段内访问不存在的 key 通常返回 null/empty,不会报错
|
|
|
# 所以用 OR 连接是安全的
|
|
|
|
|
|
- doc_filter_expr = f"({' || '.join(sub_exprs)})"
|
|
|
+ doc_filter_expr = f"({' or '.join(sub_exprs)})"
|
|
|
expr_list.append(doc_filter_expr)
|
|
|
continue # 处理完特殊字段,跳过后续通用逻辑
|
|
|
except Exception as e:
|
|
|
@@ -251,14 +251,127 @@ class SearchEngineService:
|
|
|
if not is_top_level:
|
|
|
target_field = f'metadata["{safe_field}"]'
|
|
|
|
|
|
- if safe_value.isdigit():
|
|
|
- expr_list.append(f'{target_field} == {safe_value}')
|
|
|
- else:
|
|
|
- expr_list.append(f'{target_field} == "{safe_value}"')
|
|
|
+ # [Fix] 统一将 metadata 值视为字符串查询
|
|
|
+ expr_list.append(f'{target_field} == "{safe_value}"')
|
|
|
|
|
|
# 组合所有条件 (使用 AND)
|
|
|
expr = " and ".join(expr_list) if expr_list else ""
|
|
|
|
|
|
+ # 3. 确定分页参数
|
|
|
+ page = payload.page if payload.page and payload.page > 0 else 1
|
|
|
+ page_size = payload.page_size if payload.page_size and payload.page_size > 0 else 10
|
|
|
+ offset = (page - 1) * page_size
|
|
|
+ limit = page_size
|
|
|
+
|
|
|
+ has_query = payload.query and payload.query.strip()
|
|
|
+
|
|
|
+ if not has_query:
|
|
|
+ # --- 分支 A: 纯标量查询 (无关键词) ---
|
|
|
+ logger = logging.getLogger(__name__)
|
|
|
+ logger.info(f"Scalar query mode for KB={kb_id}, expr={expr}")
|
|
|
+ try:
|
|
|
+ # 1. 获取总数
|
|
|
+ total = 0
|
|
|
+ count_expr = expr if expr else ""
|
|
|
+ # 如果没有表达式,默认查所有 (需要满足 Milvus 语法)
|
|
|
+ if not count_expr:
|
|
|
+ # 简单获取 stats
|
|
|
+ stats = milvus_service.client.get_collection_stats(collection_name=kb_id)
|
|
|
+ total = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0
|
|
|
+ else:
|
|
|
+ # 带条件 count
|
|
|
+ res_cnt = milvus_service.client.query(kb_id, filter=count_expr, output_fields=["count(*)"])
|
|
|
+ if res_cnt:
|
|
|
+ total = res_cnt[0].get("count(*)") or 0
|
|
|
+
|
|
|
+ # 2. 分页查询
|
|
|
+ # 如果没有 expr,Milvus query 需要一个 valid expression
|
|
|
+ # 尝试用 id >= 0,前提是 id 是 int。如果是 varchar,用 id != ""
|
|
|
+ query_expr = expr
|
|
|
+ if not query_expr:
|
|
|
+ # 获取主键字段名和类型
|
|
|
+ pk_field = "pk" # 默认
|
|
|
+ is_int = True
|
|
|
+ try:
|
|
|
+ desc = milvus_service.client.describe_collection(kb_id)
|
|
|
+ if isinstance(desc, dict) and 'fields' in desc:
|
|
|
+ for f in desc['fields']:
|
|
|
+ if f.get('primary_key') or f.get('is_primary'):
|
|
|
+ pk_field = f.get('name')
|
|
|
+ # Type 5 is INT64, 21 is VARCHAR.
|
|
|
+ if f.get('type') == 21 or str(f.get('type')).upper() == 'VARCHAR':
|
|
|
+ is_int = False
|
|
|
+ break
|
|
|
+ except:
|
|
|
+ pass
|
|
|
+
|
|
|
+ query_expr = f'{pk_field} >= 0' if is_int else f'{pk_field} != ""'
|
|
|
+
|
|
|
+ res_page = milvus_service.client.query(
|
|
|
+ collection_name=kb_id,
|
|
|
+ filter=query_expr,
|
|
|
+ output_fields=["*"],
|
|
|
+ limit=limit,
|
|
|
+ offset=offset
|
|
|
+ )
|
|
|
+
|
|
|
+ formatted_results = []
|
|
|
+ for item in res_page:
|
|
|
+ item_metadata = item.get('metadata') or {}
|
|
|
+ if isinstance(item_metadata, str):
|
|
|
+ try:
|
|
|
+ item_metadata = json.loads(item_metadata)
|
|
|
+ except Exception:
|
|
|
+ item_metadata = {}
|
|
|
+
|
|
|
+ # PDR 模式内容获取 (可选)
|
|
|
+ item_content = item.get('text') or item.get('content') or item.get('page_content') or ""
|
|
|
+ if is_pdr:
|
|
|
+ parent_id = item_metadata.get("parent_id") or item.get("parent_id")
|
|
|
+ if parent_id:
|
|
|
+ try:
|
|
|
+ parent_results = milvus_service.client.query(
|
|
|
+ collection_name=parent_col,
|
|
|
+ filter=f'parent_id == "{parent_id}"',
|
|
|
+ output_fields=["text", "content", "page_content"]
|
|
|
+ )
|
|
|
+ if parent_results:
|
|
|
+ p_entity = parent_results[0]
|
|
|
+ parent_full = p_entity.get("text") or p_entity.get("content") or p_entity.get("page_content")
|
|
|
+ if parent_full:
|
|
|
+ item_content = f"【父段内容】\n{parent_full}\n\n【片段内容】\n{item_content}"
|
|
|
+ except:
|
|
|
+ pass
|
|
|
+
|
|
|
+ doc_name = (
|
|
|
+ item_metadata.get('doc_name')
|
|
|
+ or item_metadata.get('file_name')
|
|
|
+ or item_metadata.get('title')
|
|
|
+ or item_metadata.get('source')
|
|
|
+ or item.get('file_name')
|
|
|
+ or item.get('title')
|
|
|
+ or item.get('source')
|
|
|
+ or "未知文档"
|
|
|
+ )
|
|
|
+
|
|
|
+ formatted_results.append(KBSearchResultItem(
|
|
|
+ id=str(item.get('pk') or item.get('id')),
|
|
|
+ kb_name=original_kb_id,
|
|
|
+ doc_name=doc_name,
|
|
|
+ content=item_content,
|
|
|
+ meta_info=str(item_metadata),
|
|
|
+ document_id=str(item.get("document_id") or ""),
|
|
|
+ metadata=item_metadata,
|
|
|
+ score=0
|
|
|
+ ))
|
|
|
+
|
|
|
+ return KBSearchResponse(results=formatted_results, total=total)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"Scalar query failed: {e}")
|
|
|
+ return KBSearchResponse(results=[], total=0)
|
|
|
+
|
|
|
+ # --- 分支 B: 向量/混合检索 (有关键词) ---
|
|
|
# 选择 Milvus 向量字段名后生成向量 (移到这里,因为之前代码被替换掉了)
|
|
|
query_vector = text_to_vector_algo(payload.query, dim=dim)
|
|
|
|
|
|
@@ -337,80 +450,7 @@ class SearchEngineService:
|
|
|
# 但现在 Schema 默认 page=1, page_size=10,所以总是走分页逻辑
|
|
|
|
|
|
try:
|
|
|
- # 尝试使用混合检索 (Hybrid Search)
|
|
|
- # 只有当用户没有显式指定 metric_type 或者指定为 hybrid 时,且集合支持(通常通过异常回退处理)时使用
|
|
|
- # 但考虑到 metric_type 可能是 L2/COSINE,我们这里先尝试 hybrid,如果失败回退到普通
|
|
|
-
|
|
|
- # 为了不破坏现有逻辑,我们可以根据某种标志来决定是否使用 hybrid
|
|
|
- # 或者默认尝试 hybrid,如果 collection 不支持 sparse 则会报错回退
|
|
|
-
|
|
|
- # 这里我们直接调用 milvus_service.hybrid_search
|
|
|
- # 注意:hybrid_search 返回的格式与 client.search 不同,需要适配
|
|
|
-
|
|
|
- use_hybrid = False
|
|
|
- # 只有当 metric_type 为 None 或者特定值时才尝试混合检索,避免与用户明确指定的 metric 冲突
|
|
|
- # 或者我们可以认为只要不指定,就优先尝试混合
|
|
|
- # 已经在上面判断过 use_hybrid = True 了
|
|
|
-
|
|
|
- if use_hybrid:
|
|
|
- logger.info(f"Attempting hybrid search for KB={kb_id}")
|
|
|
- try:
|
|
|
- # Hybrid search (LangChain Milvus) 暂时不支持直接传 offset
|
|
|
- # 所以我们需要获取 top_k = offset + limit,然后手动切片
|
|
|
- target_k = offset + limit
|
|
|
-
|
|
|
- hybrid_results = milvus_service.hybrid_search(
|
|
|
- collection_name=kb_id,
|
|
|
- query_text=payload.query,
|
|
|
- top_k=target_k
|
|
|
- )
|
|
|
-
|
|
|
- # 手动切片实现分页
|
|
|
- start = offset
|
|
|
- end = offset + limit
|
|
|
- # 确保不越界
|
|
|
- if start >= len(hybrid_results):
|
|
|
- sliced_results = []
|
|
|
- else:
|
|
|
- sliced_results = hybrid_results[start:end]
|
|
|
-
|
|
|
- formatted_results = []
|
|
|
- for item in sliced_results:
|
|
|
- item_content = item.get('text_content') or ""
|
|
|
- item_metadata = item.get('metadata', {})
|
|
|
-
|
|
|
- # PDR 模式:从父表获取内容
|
|
|
- if is_pdr:
|
|
|
- parent_id = item_metadata.get("parent_id")
|
|
|
- if parent_id:
|
|
|
- try:
|
|
|
- parent_results = milvus_service.client.query(
|
|
|
- collection_name=parent_col,
|
|
|
- filter=f'parent_id == "{parent_id}"',
|
|
|
- output_fields=["text", "content", "page_content"]
|
|
|
- )
|
|
|
- if parent_results:
|
|
|
- p_entity = parent_results[0]
|
|
|
- item_content = p_entity.get("text") or p_entity.get("content") or p_entity.get("page_content") or item_content
|
|
|
- except Exception as e:
|
|
|
- logging.error(f"Failed to fetch parent chunk {parent_id} from {parent_col}: {e}")
|
|
|
-
|
|
|
- formatted_results.append(KBSearchResultItem(
|
|
|
- id=str(item.get('id')),
|
|
|
- kb_name=original_kb_id, # 使用原始 ID
|
|
|
- doc_name=item_metadata.get('file_name') or item_metadata.get('source') or "未知文档",
|
|
|
- content=item_content,
|
|
|
- meta_info=str(item_metadata),
|
|
|
- score=item.get('similarity', 0) * 100
|
|
|
- ))
|
|
|
-
|
|
|
- return KBSearchResponse(results=formatted_results, total=collection_count)
|
|
|
-
|
|
|
- except Exception as hybrid_err:
|
|
|
- logger.warning(f"Hybrid search failed, falling back to vector search: {hybrid_err}")
|
|
|
- # Fallback to standard vector search below
|
|
|
- pass
|
|
|
-
|
|
|
+ # 暂时仅使用向量检索,关闭混合检索以保证相似度与查询内容强相关
|
|
|
results = milvus_service.client.search(
|
|
|
collection_name=kb_id,
|
|
|
data=[query_vector],
|
|
|
@@ -651,8 +691,7 @@ class SearchEngineService:
|
|
|
meta_info.append(f"{k}: {v}")
|
|
|
meta_str = "; ".join(meta_info[:3])
|
|
|
|
|
|
- # 根据 collection 的 metric 动态计算相似度分数
|
|
|
- # 如果用户请求了特定的 metric,尝试适配;否则使用实际 metric
|
|
|
+ # 根据 collection 的 metric 动态计算相似度分数(先从原始向量距离/相似度换算到 0-100)
|
|
|
display_metric = requested_metric if requested_metric else metric_type
|
|
|
|
|
|
similarity_pct = None
|
|
|
@@ -662,54 +701,61 @@ class SearchEngineService:
|
|
|
raw_score = None
|
|
|
|
|
|
if raw_score is not None:
|
|
|
- # 核心计算逻辑:先根据 metric_type 理解 raw_score,再根据 display_metric 转换
|
|
|
- # 目前简化处理:直接根据 display_metric 解释 raw_score,忽略不兼容的情况
|
|
|
- # 更好的做法是:
|
|
|
- # 1. 识别 raw_score 的物理意义(距离还是相似度),基于 metric_type
|
|
|
- # 2. 转换为 display_metric 要求的格式
|
|
|
-
|
|
|
- # Case 1: 实际是 L2 (距离),用户想看 L2
|
|
|
if "L2" in metric_type or "EUCLIDEAN" in metric_type:
|
|
|
distance = raw_score
|
|
|
- if display_metric and ("COSINE" in display_metric):
|
|
|
- # L2 距离转 Cosine 相似度 (仅适用于归一化向量)
|
|
|
- # dist^2 = 2(1-cos) => cos = 1 - dist^2/2
|
|
|
- # 但这里简单起见,如果类型不匹配,还是按 L2 算百分比,避免数值错误
|
|
|
- similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
|
|
|
- else:
|
|
|
- similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
|
|
|
-
|
|
|
- # Case 2: 实际是 Cosine (相似度 [-1, 1])
|
|
|
+ similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
|
|
|
elif "COSINE" in metric_type:
|
|
|
cosine_score = raw_score
|
|
|
- # 无论用户想看什么,Cosine Score 本身就是相似度,直接归一化到 0-100 最直观
|
|
|
similarity_pct = round(max(min((cosine_score + 1.0) / 2.0, 1.0), 0.0) * 100.0, 2)
|
|
|
-
|
|
|
- # Case 3: IP (内积)
|
|
|
elif "IP" in metric_type or "INNER" in metric_type:
|
|
|
similarity_pct = round(raw_score * 100.0, 2)
|
|
|
-
|
|
|
- # Fallback
|
|
|
else:
|
|
|
- # 兼容 BM25 或其他未知 metric
|
|
|
if "BM25" in metric_type:
|
|
|
- # BM25 分数通常是正数,没有固定上限,直接显示原值
|
|
|
similarity_pct = round(raw_score, 2)
|
|
|
else:
|
|
|
similarity_pct = round(raw_score * 100.0, 2)
|
|
|
else:
|
|
|
similarity_pct = 0.0
|
|
|
|
|
|
+ # 结合关键词做一次简单的“相关性校正”:若片段完全不包含检索词,则降低相似度权重
|
|
|
+ query_text = (payload.query or "").strip()
|
|
|
+ if query_text:
|
|
|
+ plain_content = (content or "") + " " + (doc_name or "") + " " + (meta_str or "")
|
|
|
+ if query_text not in plain_content:
|
|
|
+ # 纯向量相似但没有任何文本命中,认为是“语义可能相关,但与关键词弱相关”,适当降权
|
|
|
+ similarity_pct = round(similarity_pct * 0.4, 2)
|
|
|
+
|
|
|
formatted_results.append(KBSearchResultItem(
|
|
|
id=str(hit.id),
|
|
|
kb_name=original_kb_id,
|
|
|
doc_name=doc_name,
|
|
|
content=content,
|
|
|
meta_info=meta_str,
|
|
|
+ document_id=str(document_id) if document_id is not None else None,
|
|
|
+ metadata=meta_dict if isinstance(meta_dict, dict) else None,
|
|
|
score=similarity_pct
|
|
|
))
|
|
|
|
|
|
- return KBSearchResponse(results=formatted_results, total=len(formatted_results))
|
|
|
+ # 按相似度由大到小排序
|
|
|
+ formatted_results.sort(key=lambda x: x.score, reverse=True)
|
|
|
+
|
|
|
+ # [Fix] 动态计算 total 用于分页
|
|
|
+ # 如果当前页结果不满 limit,说明是最后一页
|
|
|
+ current_count = len(formatted_results)
|
|
|
+ if current_count < limit:
|
|
|
+ final_total = offset + current_count
|
|
|
+ else:
|
|
|
+ # 否则,使用 collection_count (上限 1000)
|
|
|
+ # 如果 collection_count 获取失败(0),则至少允许翻页
|
|
|
+ base_total = collection_count if collection_count > 0 else 1000
|
|
|
+ final_total = min(base_total, 1000)
|
|
|
+
|
|
|
+ # 确保 final_total 至少能覆盖当前页+1 (如果有满页结果)
|
|
|
+ # 这样用户能看到"下一页"按钮
|
|
|
+ if final_total <= offset + current_count:
|
|
|
+ final_total = offset + current_count + 10 # 预留一页
|
|
|
+
|
|
|
+ return KBSearchResponse(results=formatted_results, total=final_total)
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f"Search error: {e}")
|