""" 检索引擎业务逻辑服务 """ from math import ceil from typing import List, Optional, Tuple, Dict from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, func, or_ from datetime import datetime import uuid import random import json import hashlib import math from app.sample.models.search_engine import SearchEngine from app.sample.schemas.search_engine import ( SearchEngineCreate, SearchEngineUpdate, KBSearchRequest, KBSearchResultItem, KBSearchResponse ) from app.schemas.base import PaginationSchema from app.services.milvus_service import milvus_service from app.utils.vector_utils import text_to_vector_algo import logging class SearchEngineService: async def search_kb(self, db: AsyncSession, payload: KBSearchRequest) -> KBSearchResponse: """ 知识库搜索 (基于算法向量) """ kb_id = payload.kb_id if not milvus_service.has_collection(kb_id): return KBSearchResponse(results=[], total=0) # 1. 使用算法生成向量 (替代 Embedding 模型) # 尝试从 Milvus collection 获取向量维度,动态匹配维度 # 这样相同的查询词会生成相同的向量,具备了基本的检索能力 try: collection_detail = milvus_service.get_collection_detail(kb_id) except Exception: collection_detail = None dim = None if collection_detail and isinstance(collection_detail, dict): fields = collection_detail.get("fields", []) or [] for f in fields: # 根据字段类型查找向量字段(Milvus 向量字段类型通常为 FloatVector / float_vector) if not isinstance(f, dict): continue ftype = str(f.get("type") or "").lower() if "100" in ftype or '101' in ftype: # 假设 100 和 101 分别代表 FloatVector 和 BinaryVector # 找到向量字段,优先从 params.dim 获取维度 params = f.get("params") or {} if params and params.get("dim"): try: dim = int(params.get("dim")) break except Exception: dim = None # 回退默认维度 if not dim: dim = 768 # 选择 Milvus 向量字段名(anns_field),字段名可能不是固定的 "vector",也可能叫 'dense'/'denser' 等 anns_field = "vector" if collection_detail and isinstance(collection_detail, dict): fields = collection_detail.get("fields", []) or [] # 优先寻找有 params.dim 的向量字段 for f in fields: if not isinstance(f, dict): continue params = f.get("params") or {} if params and params.get("dim") and f.get("name"): anns_field = f.get("name") try: dim = int(params.get("dim")) except Exception: pass break # 若未找到带 dim 的字段,尝试匹配常见的向量字段名或字段类型包含 "vector" if anns_field == "vector": for f in fields: if not isinstance(f, dict): continue fname = (f.get("name") or "") ftype = str(f.get("type") or "").lower() if fname and fname.lower() in ("vector", "denser", "dense", "embedding", "embeddings"): anns_field = fname break if fname and "vector" in ftype: anns_field = fname break # 1. 向量搜索 (Dense Retrieval) # 默认使用 Hybrid 混合检索逻辑,但为了简化,这里先保留向量检索的核心 # 如果 metric_type 指定为 hybrid,则可能需要结合关键词搜索等 # 目前后端实现主要是基于 Milvus 的 ANN 搜索 # 强制使用 hybrid 混合检索模式作为基础(结合关键词匹配和向量相似度) # 除非用户明确指定了其他度量方式(通常不会) requested_metric = payload.metric_type use_hybrid = False # 只有当 metric_type 为 None 或者特定值时才尝试混合检索 # 或者我们可以认为只要不指定,就优先尝试混合 if not requested_metric or requested_metric.lower() == 'hybrid': use_hybrid = True search_params = { "metric_type": "L2", # 默认内部计算用 L2 "params": {"nprobe": 10}, } # 如果前端指定了 metric_type (虽然前端现在默认 hybrid,但保留参数兼容性) if payload.metric_type and payload.metric_type.upper() != 'HYBRID': search_params["metric_type"] = payload.metric_type # 2. 构建过滤表达式 expr_list = [] # 兼容旧的单一字段过滤 if payload.metadata_field and payload.metadata_value: safe_field = payload.metadata_field.replace("'", "").replace('"', "").strip() safe_value = payload.metadata_value.replace("'", "").replace('"', "").strip() if safe_field and safe_value: if safe_value.isdigit(): expr_list.append(f'{safe_field} == {safe_value}') else: expr_list.append(f'{safe_field} == "{safe_value}"') # 处理新的多重过滤 if payload.filters: for f in payload.filters: # 特殊处理文档过滤 (IN 查询) if f.field == 'doc_name_in': try: doc_names = json.loads(f.value) if isinstance(doc_names, list) and doc_names: # 构建 doc_name in ["A", "B"] # 注意:Schema 中 doc_name 字段名可能不统一,通常是 doc_name, file_name, title # 这里我们需要尝试匹配正确的字段名。 # 假设我们在 create 时主要存的是 file_name 或 doc_name # 简单起见,我们尝试对常见字段做 OR,但这在 Milvus expr 中可能复杂 # 更稳妥的是我们在存数据时统一了字段。 # 假设统一用 "file_name" 或 "doc_name" # 获取 collection fields target_field = "file_name" # 默认 if collection_detail and isinstance(collection_detail, dict): fields = [fl.get("name") for fl in collection_detail.get("fields", []) if isinstance(fl, dict)] if "doc_name" in fields: target_field = "doc_name" elif "title" in fields: target_field = "title" # 构建 IN 列表 in_values = ",".join([f'"{name}"' for name in doc_names]) expr_list.append(f'{target_field} in [{in_values}]') except Exception as e: print(f"Error parsing doc_name_in: {e}") else: safe_field = f.field.replace("'", "").replace('"', "").strip() safe_value = f.value.replace("'", "").replace('"', "").strip() if safe_field and safe_value: if safe_value.isdigit(): expr_list.append(f'{safe_field} == {safe_value}') else: expr_list.append(f'{safe_field} == "{safe_value}"') # 组合所有条件 (使用 AND) expr = " and ".join(expr_list) if expr_list else "" # 选择 Milvus 向量字段名后生成向量 (移到这里,因为之前代码被替换掉了) query_vector = text_to_vector_algo(payload.query, dim=dim) # 检测 collection 使用的 metric (恢复这部分逻辑,因为后续 search 需要) metric_type = None # 优先从 collection_detail 检测真实 metric if collection_detail and isinstance(collection_detail, dict): indices = collection_detail.get("indices") or [] if isinstance(indices, list) and len(indices) > 0: for idx in indices: try: mt = idx.get("metric_type") or idx.get("metric") if mt: metric_type = str(mt).upper() break except Exception: continue # 尝试从 properties 中读取 if not metric_type and collection_detail and isinstance(collection_detail, dict): props = collection_detail.get("properties") or {} if isinstance(props, dict): mt = props.get("metric_type") or props.get("metric") if mt: metric_type = str(mt).upper() actual_search_metric = metric_type if not actual_search_metric: # 如果无法检测到 collection metric (如无索引),则可以使用用户请求的或默认 L2 actual_search_metric = requested_metric if requested_metric and requested_metric.upper() != 'HYBRID' else "L2" metric_type = actual_search_metric logger = logging.getLogger(__name__) logger.info(f"Search KB={kb_id} using anns_field={anns_field}, dim={dim}, metric={metric_type} (requested={requested_metric})") # 3. 执行 Milvus 搜索 try: # 使用 collection 实际的 metric_type 作为检索度量,避免 mismatch 错误 # metric_type 已在上面检测并存放于变量 metric_type search_params = { "metric_type": metric_type, "params": {"nprobe": 10} } # 如果 top_k <= 0 或未指定,解释为返回该 collection 中的所有文段 # 优先使用 page/page_size 计算 limit 和 offset page = payload.page if payload.page and payload.page > 0 else 1 page_size = payload.page_size if payload.page_size and payload.page_size > 0 else 10 # 如果 payload 中有 top_k 且未传 page_size (或者 page_size 是默认值),可以使用 top_k 覆盖 page_size # 但这里为了清晰,优先使用 page_size offset = (page - 1) * page_size limit = page_size # Milvus 对 limit + offset 有限制 (通常 16384),这里做个简单的保护 if offset + limit > 16384: # 如果超出深度分页限制,可能需要提示或截断 # 这里暂时不做处理,让 Milvus 报错或自行截断 pass # 获取集合总数用于分页显示 (total) collection_count = 0 if collection_detail and isinstance(collection_detail, dict): collection_count = collection_detail.get("entity_count") or 0 if not collection_count: try: stats = milvus_service.client.get_collection_stats(collection_name=kb_id) collection_count = int(stats.get("row_count")) if isinstance(stats, dict) and stats.get("row_count") else 0 except Exception: collection_count = 0 # 如果是按照 top_k 逻辑 (不传 page/page_size),保留旧逻辑 (top_k 即 limit, offset=0) # 但现在 Schema 默认 page=1, page_size=10,所以总是走分页逻辑 try: # 尝试使用混合检索 (Hybrid Search) # 只有当用户没有显式指定 metric_type 或者指定为 hybrid 时,且集合支持(通常通过异常回退处理)时使用 # 但考虑到 metric_type 可能是 L2/COSINE,我们这里先尝试 hybrid,如果失败回退到普通 # 为了不破坏现有逻辑,我们可以根据某种标志来决定是否使用 hybrid # 或者默认尝试 hybrid,如果 collection 不支持 sparse 则会报错回退 # 这里我们直接调用 milvus_service.hybrid_search # 注意:hybrid_search 返回的格式与 client.search 不同,需要适配 use_hybrid = False # 只有当 metric_type 为 None 或者特定值时才尝试混合检索,避免与用户明确指定的 metric 冲突 # 或者我们可以认为只要不指定,就优先尝试混合 # 已经在上面判断过 use_hybrid = True 了 if use_hybrid: logger.info(f"Attempting hybrid search for KB={kb_id}") try: # Hybrid search (LangChain Milvus) 暂时不支持直接传 offset # 所以我们需要获取 top_k = offset + limit,然后手动切片 target_k = offset + limit hybrid_results = milvus_service.hybrid_search( collection_name=kb_id, query_text=payload.query, top_k=target_k, expr=expr if expr else None ) # 手动切片实现分页 start = offset end = offset + limit # 确保不越界 if start >= len(hybrid_results): sliced_results = [] else: sliced_results = hybrid_results[start:end] formatted_results = [] for item in sliced_results: _meta = item.get('metadata', {}) or {} if isinstance(_meta, str): try: _meta = json.loads(_meta) except Exception: _meta = {} formatted_results.append(KBSearchResultItem( id=str(item.get('id')), kb_name=kb_id, doc_name=item.get('metadata', {}).get('file_name') or item.get('metadata', {}).get('source') or "未知文档", content=item.get('text_content') or "", meta_info=str(item.get('metadata', {})), document_id=_meta.get('document_id'), metadata=_meta if isinstance(_meta, dict) else None, score=item.get('similarity', 0) * 100 # 假设是 0-1 )) return KBSearchResponse(results=formatted_results, total=collection_count) except Exception as hybrid_err: logger.warning(f"Hybrid search failed, falling back to vector search: {hybrid_err}") # Fallback to standard vector search below pass results = milvus_service.client.search( collection_name=kb_id, data=[query_vector], anns_field=anns_field, search_params=search_params, limit=limit, offset=offset, # 添加 offset 支持分页 filter=expr if expr else "", output_fields=["*"] ) except Exception as milvus_err: # 捕获 Milvus 异常,常见原因包括 metric mismatch logger.error(f"Milvus search failed for collection={kb_id}, metric_requested={metric_type}, anns_field={anns_field}: {milvus_err}") # Retry Logic: 如果是因为 metric 不匹配,解析错误信息中的 expected metric 并重试 error_msg = str(milvus_err) if "metric type not match" in error_msg: import re # 匹配 expected=COSINE 或 expected='COSINE' 等格式 # 支持 COSINE, L2, IP, BM25 等 match = re.search(r"expected\s*=\s*['\"]?([A-Za-z0-9_]+)['\"]?", error_msg) if match: correct_metric = match.group(1).upper() logger.warning(f"Detected metric mismatch. Retrying with correct metric: {correct_metric}") # 更新 metric_type 并重试搜索 search_params["metric_type"] = correct_metric # 同时也需要更新后续计算分数所用的 metric_type 变量,以便正确计算相似度 metric_type = correct_metric # 特殊处理: BM25 可能需要 sparse vector 或其他参数,但 Milvus search 接口应该是一致的 # 如果是 BM25,可能 anns_field 也要调整(通常 BM25 用 sparse vector) # 但这里假设 anns_field 是正确的,只是 metric 不对 results = milvus_service.client.search( collection_name=kb_id, data=[query_vector], anns_field=anns_field, search_params=search_params, limit=limit, offset=offset, # 同样加上 offset filter=expr if expr else "", output_fields=["*"] ) else: raise else: raise # 4. 格式化结果 formatted_results = [] for hits in results: for hit in hits: # 过滤低相似度结果 (算法生成的向量相似度可能较低,阈值可适当调低) # if hit.score < payload.score_threshold: # continue entity = hit.entity content = entity.get("text") or entity.get("content") or entity.get("page_content") or "" if not content: debug_data = {k: v for k, v in entity.items() if k != anns_field} content = json.dumps(debug_data, ensure_ascii=False)[:200] + "..." doc_name = entity.get("file_name") or entity.get("title") or entity.get("source") or "未知文档" _meta = entity.get("metadata") or {} if isinstance(_meta, str): try: _meta = json.loads(_meta) except Exception: _meta = {} document_id = entity.get("document_id") or (_meta.get("document_id") if isinstance(_meta, dict) else None) meta_info = [] for k, v in entity.items(): if k not in [anns_field, "text", "content", "page_content", "id", "pk"]: meta_info.append(f"{k}: {v}") meta_str = "; ".join(meta_info[:3]) # 根据 collection 的 metric 动态计算相似度分数 # 如果用户请求了特定的 metric,尝试适配;否则使用实际 metric display_metric = requested_metric if requested_metric else metric_type similarity_pct = None try: raw_score = float(hit.score) except Exception: raw_score = None if raw_score is not None: # 核心计算逻辑:先根据 metric_type 理解 raw_score,再根据 display_metric 转换 # 目前简化处理:直接根据 display_metric 解释 raw_score,忽略不兼容的情况 # 更好的做法是: # 1. 识别 raw_score 的物理意义(距离还是相似度),基于 metric_type # 2. 转换为 display_metric 要求的格式 # Case 1: 实际是 L2 (距离),用户想看 L2 if "L2" in metric_type or "EUCLIDEAN" in metric_type: distance = raw_score if display_metric and ("COSINE" in display_metric): # L2 距离转 Cosine 相似度 (仅适用于归一化向量) # dist^2 = 2(1-cos) => cos = 1 - dist^2/2 # 但这里简单起见,如果类型不匹配,还是按 L2 算百分比,避免数值错误 similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2) else: similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2) # Case 2: 实际是 Cosine (相似度 [-1, 1]) elif "COSINE" in metric_type: cosine_score = raw_score # 无论用户想看什么,Cosine Score 本身就是相似度,直接归一化到 0-100 最直观 similarity_pct = round(max(min((cosine_score + 1.0) / 2.0, 1.0), 0.0) * 100.0, 2) # Case 3: IP (内积) elif "IP" in metric_type or "INNER" in metric_type: similarity_pct = round(raw_score * 100.0, 2) # Fallback else: # 兼容 BM25 或其他未知 metric if "BM25" in metric_type: # BM25 分数通常是正数,没有固定上限,直接显示原值 similarity_pct = round(raw_score, 2) else: similarity_pct = round(raw_score * 100.0, 2) else: similarity_pct = 0.0 formatted_results.append(KBSearchResultItem( id=str(hit.id), kb_name=kb_id, doc_name=doc_name, content=content, meta_info=meta_str, document_id=document_id, metadata=_meta if isinstance(_meta, dict) else None, score=similarity_pct )) # [Fix] 批量查询 t_samp_document_main 获取准确的文档名称 # 提取所有可能的 document_id doc_ids = set() for item in formatted_results: did = item.document_id if not did: # 尝试从 metadata 中再次获取 if item.metadata and isinstance(item.metadata, dict): did = item.metadata.get("document_id") or item.metadata.get("doc_id") if did: item.document_id = did # 兼容处理 document_id 可能为 int 的情况 if did and isinstance(did, (str, int)) and len(str(did)) > 0: doc_ids.add(str(did)) if doc_ids: try: from app.sample.models.base_info import DocumentMain # 由于 search_kb 方法签名中已经传入了 db: AsyncSession,我们直接使用它 # 不需要像 SnippetService 那样重新创建连接 # 打印调试信息 # print(f"SearchEngine: Querying DocumentMain for {len(doc_ids)} ids") stmt = select(DocumentMain.id, DocumentMain.title).where(DocumentMain.id.in_(list(doc_ids))) doc_res = await db.execute(stmt) rows = doc_res.all() doc_name_map = {str(row[0]): row[1] for row in rows} if doc_name_map: for item in formatted_results: did = item.document_id if did and str(did) in doc_name_map: item.doc_name = doc_name_map[str(did)] except Exception as e: import traceback traceback.print_exc() print(f"SearchEngine: Failed to fetch document names from DB: {e}") return KBSearchResponse(results=formatted_results, total=collection_count) # [Modified] 修复分页总数返回错误 except Exception as e: print(f"Search error: {e}") return KBSearchResponse(results=[], total=0) # ... (Keep existing CRUD methods below) ... async def get_list( self, db: AsyncSession, page: int = 1, page_size: int = 10, keyword: Optional[str] = None, status: Optional[str] = None ) -> Tuple[List[SearchEngine], PaginationSchema]: """获取检索引擎列表""" query = select(SearchEngine).where(SearchEngine.is_deleted == 0) if keyword: query = query.where(or_( SearchEngine.name.like(f"%{keyword}%"), SearchEngine.description.like(f"%{keyword}%") )) if status: query = query.where(SearchEngine.status == status) # 计算总数 count_query = select(func.count()).select_from(query.subquery()) total = await db.scalar(count_query) or 0 # 分页查询 query = query.order_by(SearchEngine.created_at.desc()).offset((page - 1) * page_size).limit(page_size) result = await db.execute(query) items = result.scalars().all() total_pages = ceil(total / page_size) if page_size else 0 meta = PaginationSchema( page=page, page_size=page_size, total=total, total_pages=total_pages, ) return items, meta async def create(self, db: AsyncSession, payload: SearchEngineCreate) -> SearchEngine: """创建检索引擎""" # 1. 检查名称是否已存在 exists = await db.execute(select(SearchEngine).where( SearchEngine.name == payload.name, SearchEngine.is_deleted == 0 )) if exists.scalars().first(): raise ValueError("检索引擎名称已存在") try: now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") new_engine = SearchEngine( id=str(uuid.uuid4()), name=payload.name, engine_type=payload.engine_type, base_url=payload.base_url, api_key=payload.api_key, description=payload.description, status=payload.status or "normal", created_at=now, updated_at=now ) db.add(new_engine) await db.commit() await db.refresh(new_engine) return new_engine except Exception as e: await db.rollback() raise e async def update(self, db: AsyncSession, id: str, payload: SearchEngineUpdate) -> SearchEngine: """更新检索引擎信息""" result = await db.execute(select(SearchEngine).where(SearchEngine.id == id, SearchEngine.is_deleted == 0)) engine = result.scalars().first() if not engine: raise ValueError("检索引擎不存在") try: if payload.name is not None: engine.name = payload.name if payload.engine_type is not None: engine.engine_type = payload.engine_type if payload.base_url is not None: engine.base_url = payload.base_url if payload.api_key is not None: engine.api_key = payload.api_key if payload.description is not None: engine.description = payload.description if payload.status is not None: engine.status = payload.status engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") await db.commit() await db.refresh(engine) return engine except Exception as e: await db.rollback() raise e async def update_status(self, db: AsyncSession, id: str, status: str) -> SearchEngine: """更新检索引擎状态""" result = await db.execute(select(SearchEngine).where(SearchEngine.id == id, SearchEngine.is_deleted == 0)) engine = result.scalars().first() if not engine: raise ValueError("检索引擎不存在") try: engine.status = status engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") await db.commit() await db.refresh(engine) return engine except Exception as e: await db.rollback() raise e async def delete(self, db: AsyncSession, id: str) -> None: """删除检索引擎""" result = await db.execute(select(SearchEngine).where(SearchEngine.id == id)) engine = result.scalars().first() if not engine: raise ValueError("检索引擎不存在") try: # 软删除 engine.is_deleted = 1 engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S") await db.commit() except Exception as e: await db.rollback() raise e search_engine_service = SearchEngineService()