| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563 |
- """
- 检索引擎业务逻辑服务
- """
- from math import ceil
- from typing import List, Optional, Tuple, Dict
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy import select, func, or_
- from datetime import datetime
- import uuid
- import random
- import json
- import hashlib
- import math
- from app.sample.models.search_engine import SearchEngine
- from app.sample.schemas.search_engine import (
- SearchEngineCreate,
- SearchEngineUpdate,
- KBSearchRequest,
- KBSearchResultItem,
- KBSearchResponse
- )
- from app.schemas.base import PaginationSchema
- from app.services.milvus_service import milvus_service
- from app.utils.vector_utils import text_to_vector_algo
- import logging
- class SearchEngineService:
-
- async def search_kb(self, db: AsyncSession, payload: KBSearchRequest) -> KBSearchResponse:
- """
- 知识库搜索 (基于算法向量)
- """
- kb_id = payload.kb_id
-
- if not milvus_service.has_collection(kb_id):
- return KBSearchResponse(results=[], total=0)
-
- # 1. 使用算法生成向量 (替代 Embedding 模型)
- # 尝试从 Milvus collection 获取向量维度,动态匹配维度
- # 这样相同的查询词会生成相同的向量,具备了基本的检索能力
- try:
- collection_detail = milvus_service.get_collection_detail(kb_id)
- except Exception:
- collection_detail = None
- dim = None
- if collection_detail and isinstance(collection_detail, dict):
- fields = collection_detail.get("fields", []) or []
- for f in fields:
- # 根据字段类型查找向量字段(Milvus 向量字段类型通常为 FloatVector / float_vector)
- if not isinstance(f, dict):
- continue
- ftype = str(f.get("type") or "").lower()
- print(ftype+'是什么东西')
- if "100" in ftype or '101' in ftype: # 假设 100 和 101 分别代表 FloatVector 和 BinaryVector
- # 找到向量字段,优先从 params.dim 获取维度
- params = f.get("params") or {}
- if params and params.get("dim"):
- try:
- dim = int(params.get("dim"))
- break
- except Exception:
- dim = None
- # 回退默认维度
- if not dim:
- dim = 768
- # 选择 Milvus 向量字段名(anns_field),字段名可能不是固定的 "vector",也可能叫 'dense'/'denser' 等
- anns_field = "vector"
- if collection_detail and isinstance(collection_detail, dict):
- fields = collection_detail.get("fields", []) or []
- # 优先寻找有 params.dim 的向量字段
- for f in fields:
- if not isinstance(f, dict):
- continue
- params = f.get("params") or {}
- if params and params.get("dim") and f.get("name"):
- anns_field = f.get("name")
- try:
- dim = int(params.get("dim"))
- except Exception:
- pass
- break
- # 若未找到带 dim 的字段,尝试匹配常见的向量字段名或字段类型包含 "vector"
- if anns_field == "vector":
- for f in fields:
- if not isinstance(f, dict):
- continue
- fname = (f.get("name") or "")
- ftype = str(f.get("type") or "").lower()
- if fname and fname.lower() in ("vector", "denser", "dense", "embedding", "embeddings"):
- anns_field = fname
- break
- if fname and "vector" in ftype:
- anns_field = fname
- break
- # 1. 向量搜索 (Dense Retrieval)
- # 默认使用 Hybrid 混合检索逻辑,但为了简化,这里先保留向量检索的核心
- # 如果 metric_type 指定为 hybrid,则可能需要结合关键词搜索等
- # 目前后端实现主要是基于 Milvus 的 ANN 搜索
-
- # 强制使用 hybrid 混合检索模式作为基础(结合关键词匹配和向量相似度)
- # 除非用户明确指定了其他度量方式(通常不会)
- requested_metric = payload.metric_type
- use_hybrid = False
-
- # 只有当 metric_type 为 None 或者特定值时才尝试混合检索
- # 或者我们可以认为只要不指定,就优先尝试混合
- if not requested_metric or requested_metric.lower() == 'hybrid':
- use_hybrid = True
-
- search_params = {
- "metric_type": "L2", # 默认内部计算用 L2
- "params": {"nprobe": 10},
- }
-
- # 如果前端指定了 metric_type (虽然前端现在默认 hybrid,但保留参数兼容性)
- if payload.metric_type and payload.metric_type.upper() != 'HYBRID':
- search_params["metric_type"] = payload.metric_type
-
- # 2. 构建过滤表达式
- expr_list = []
-
- # 兼容旧的单一字段过滤
- if payload.metadata_field and payload.metadata_value:
- safe_field = payload.metadata_field.replace("'", "").replace('"', "").strip()
- safe_value = payload.metadata_value.replace("'", "").replace('"', "").strip()
-
- if safe_field and safe_value:
- if safe_value.isdigit():
- expr_list.append(f'{safe_field} == {safe_value}')
- else:
- expr_list.append(f'{safe_field} == "{safe_value}"')
-
- # 处理新的多重过滤
- if payload.filters:
- for f in payload.filters:
- safe_field = f.field.replace("'", "").replace('"', "").strip()
- safe_value = f.value.replace("'", "").replace('"', "").strip()
-
- if safe_field and safe_value:
- if safe_value.isdigit():
- expr_list.append(f'{safe_field} == {safe_value}')
- else:
- expr_list.append(f'{safe_field} == "{safe_value}"')
-
- # 组合所有条件 (使用 AND)
- expr = " and ".join(expr_list) if expr_list else ""
-
- # 选择 Milvus 向量字段名后生成向量 (移到这里,因为之前代码被替换掉了)
- query_vector = text_to_vector_algo(payload.query, dim=dim)
-
- # 检测 collection 使用的 metric (恢复这部分逻辑,因为后续 search 需要)
- metric_type = None
- # 优先从 collection_detail 检测真实 metric
- if collection_detail and isinstance(collection_detail, dict):
- indices = collection_detail.get("indices") or []
- if isinstance(indices, list) and len(indices) > 0:
- for idx in indices:
- try:
- mt = idx.get("metric_type") or idx.get("metric")
- if mt:
- metric_type = str(mt).upper()
- break
- except Exception:
- continue
-
- # 尝试从 properties 中读取
- if not metric_type and collection_detail and isinstance(collection_detail, dict):
- props = collection_detail.get("properties") or {}
- if isinstance(props, dict):
- mt = props.get("metric_type") or props.get("metric")
- if mt:
- metric_type = str(mt).upper()
-
- actual_search_metric = metric_type
- if not actual_search_metric:
- # 如果无法检测到 collection metric (如无索引),则可以使用用户请求的或默认 L2
- actual_search_metric = requested_metric if requested_metric and requested_metric.upper() != 'HYBRID' else "L2"
-
- metric_type = actual_search_metric
-
- logger = logging.getLogger(__name__)
- logger.info(f"Search KB={kb_id} using anns_field={anns_field}, dim={dim}, metric={metric_type} (requested={requested_metric})")
- # 3. 执行 Milvus 搜索
- try:
- # 使用 collection 实际的 metric_type 作为检索度量,避免 mismatch 错误
- # metric_type 已在上面检测并存放于变量 metric_type
- search_params = {
- "metric_type": metric_type,
- "params": {"nprobe": 10}
- }
- # 如果 top_k <= 0 或未指定,解释为返回该 collection 中的所有文段
- # 优先使用 page/page_size 计算 limit 和 offset
- page = payload.page if payload.page and payload.page > 0 else 1
- page_size = payload.page_size if payload.page_size and payload.page_size > 0 else 10
-
- # 如果 payload 中有 top_k 且未传 page_size (或者 page_size 是默认值),可以使用 top_k 覆盖 page_size
- # 但这里为了清晰,优先使用 page_size
-
- offset = (page - 1) * page_size
- limit = page_size
-
- # Milvus 对 limit + offset 有限制 (通常 16384),这里做个简单的保护
- if offset + limit > 16384:
- # 如果超出深度分页限制,可能需要提示或截断
- # 这里暂时不做处理,让 Milvus 报错或自行截断
- pass
- # 获取集合总数用于分页显示 (total)
- collection_count = 0
- if collection_detail and isinstance(collection_detail, dict):
- collection_count = collection_detail.get("entity_count") or 0
-
- if not collection_count:
- try:
- stats = milvus_service.client.get_collection_stats(collection_name=kb_id)
- collection_count = int(stats.get("row_count")) if isinstance(stats, dict) and stats.get("row_count") else 0
- except Exception:
- collection_count = 0
- # 如果是按照 top_k 逻辑 (不传 page/page_size),保留旧逻辑 (top_k 即 limit, offset=0)
- # 但现在 Schema 默认 page=1, page_size=10,所以总是走分页逻辑
-
- try:
- # 尝试使用混合检索 (Hybrid Search)
- # 只有当用户没有显式指定 metric_type 或者指定为 hybrid 时,且集合支持(通常通过异常回退处理)时使用
- # 但考虑到 metric_type 可能是 L2/COSINE,我们这里先尝试 hybrid,如果失败回退到普通
-
- # 为了不破坏现有逻辑,我们可以根据某种标志来决定是否使用 hybrid
- # 或者默认尝试 hybrid,如果 collection 不支持 sparse 则会报错回退
-
- # 这里我们直接调用 milvus_service.hybrid_search
- # 注意:hybrid_search 返回的格式与 client.search 不同,需要适配
-
- use_hybrid = False
- # 只有当 metric_type 为 None 或者特定值时才尝试混合检索,避免与用户明确指定的 metric 冲突
- # 或者我们可以认为只要不指定,就优先尝试混合
- # 已经在上面判断过 use_hybrid = True 了
-
- if use_hybrid:
- logger.info(f"Attempting hybrid search for KB={kb_id}")
- try:
- # Hybrid search (LangChain Milvus) 暂时不支持直接传 offset
- # 所以我们需要获取 top_k = offset + limit,然后手动切片
- target_k = offset + limit
-
- hybrid_results = milvus_service.hybrid_search(
- collection_name=kb_id,
- query_text=payload.query,
- top_k=target_k
- )
-
- # 手动切片实现分页
- start = offset
- end = offset + limit
- # 确保不越界
- if start >= len(hybrid_results):
- sliced_results = []
- else:
- sliced_results = hybrid_results[start:end]
-
- formatted_results = []
- for item in sliced_results:
- formatted_results.append(KBSearchResultItem(
- id=str(item.get('id')),
- kb_name=kb_id,
- doc_name=item.get('metadata', {}).get('file_name') or item.get('metadata', {}).get('source') or "未知文档",
- content=item.get('text_content') or "",
- meta_info=str(item.get('metadata', {})),
- score=item.get('similarity', 0) * 100 # 假设是 0-1
- ))
- return KBSearchResponse(results=formatted_results, total=collection_count)
- except Exception as hybrid_err:
- logger.warning(f"Hybrid search failed, falling back to vector search: {hybrid_err}")
- # Fallback to standard vector search below
- pass
- results = milvus_service.client.search(
- collection_name=kb_id,
- data=[query_vector],
- anns_field=anns_field,
- search_params=search_params,
- limit=limit,
- offset=offset, # 添加 offset 支持分页
- filter=expr if expr else "",
- output_fields=["*"]
- )
- except Exception as milvus_err:
- # 捕获 Milvus 异常,常见原因包括 metric mismatch
- logger.error(f"Milvus search failed for collection={kb_id}, metric_requested={metric_type}, anns_field={anns_field}: {milvus_err}")
-
- # Retry Logic: 如果是因为 metric 不匹配,解析错误信息中的 expected metric 并重试
- error_msg = str(milvus_err)
- if "metric type not match" in error_msg:
- import re
- # 匹配 expected=COSINE 或 expected='COSINE' 等格式
- # 支持 COSINE, L2, IP, BM25 等
- match = re.search(r"expected\s*=\s*['\"]?([A-Za-z0-9_]+)['\"]?", error_msg)
- if match:
- correct_metric = match.group(1).upper()
- logger.warning(f"Detected metric mismatch. Retrying with correct metric: {correct_metric}")
-
- # 更新 metric_type 并重试搜索
- search_params["metric_type"] = correct_metric
- # 同时也需要更新后续计算分数所用的 metric_type 变量,以便正确计算相似度
- metric_type = correct_metric
-
- # 特殊处理: BM25 可能需要 sparse vector 或其他参数,但 Milvus search 接口应该是一致的
- # 如果是 BM25,可能 anns_field 也要调整(通常 BM25 用 sparse vector)
- # 但这里假设 anns_field 是正确的,只是 metric 不对
-
- results = milvus_service.client.search(
- collection_name=kb_id,
- data=[query_vector],
- anns_field=anns_field,
- search_params=search_params,
- limit=limit,
- offset=offset, # 同样加上 offset
- filter=expr if expr else "",
- output_fields=["*"]
- )
- else:
- raise
- else:
- raise
-
- # 4. 格式化结果
- formatted_results = []
- for hits in results:
- for hit in hits:
- # 过滤低相似度结果 (算法生成的向量相似度可能较低,阈值可适当调低)
- # if hit.score < payload.score_threshold:
- # continue
-
- entity = hit.entity
- content = entity.get("text") or entity.get("content") or entity.get("page_content") or ""
- if not content:
- debug_data = {k: v for k, v in entity.items() if k != anns_field}
- content = json.dumps(debug_data, ensure_ascii=False)[:200] + "..."
-
- doc_name = entity.get("file_name") or entity.get("title") or entity.get("source") or "未知文档"
-
- meta_info = []
- for k, v in entity.items():
- if k not in [anns_field, "text", "content", "page_content", "id", "pk"]:
- meta_info.append(f"{k}: {v}")
- meta_str = "; ".join(meta_info[:3])
-
- # 根据 collection 的 metric 动态计算相似度分数
- # 如果用户请求了特定的 metric,尝试适配;否则使用实际 metric
- display_metric = requested_metric if requested_metric else metric_type
-
- similarity_pct = None
- try:
- raw_score = float(hit.score)
- except Exception:
- raw_score = None
- if raw_score is not None:
- # 核心计算逻辑:先根据 metric_type 理解 raw_score,再根据 display_metric 转换
- # 目前简化处理:直接根据 display_metric 解释 raw_score,忽略不兼容的情况
- # 更好的做法是:
- # 1. 识别 raw_score 的物理意义(距离还是相似度),基于 metric_type
- # 2. 转换为 display_metric 要求的格式
-
- # Case 1: 实际是 L2 (距离),用户想看 L2
- if "L2" in metric_type or "EUCLIDEAN" in metric_type:
- distance = raw_score
- if display_metric and ("COSINE" in display_metric):
- # L2 距离转 Cosine 相似度 (仅适用于归一化向量)
- # dist^2 = 2(1-cos) => cos = 1 - dist^2/2
- # 但这里简单起见,如果类型不匹配,还是按 L2 算百分比,避免数值错误
- similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
- else:
- similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
-
- # Case 2: 实际是 Cosine (相似度 [-1, 1])
- elif "COSINE" in metric_type:
- cosine_score = raw_score
- # 无论用户想看什么,Cosine Score 本身就是相似度,直接归一化到 0-100 最直观
- similarity_pct = round(max(min((cosine_score + 1.0) / 2.0, 1.0), 0.0) * 100.0, 2)
-
- # Case 3: IP (内积)
- elif "IP" in metric_type or "INNER" in metric_type:
- similarity_pct = round(raw_score * 100.0, 2)
-
- # Fallback
- else:
- # 兼容 BM25 或其他未知 metric
- if "BM25" in metric_type:
- # BM25 分数通常是正数,没有固定上限,直接显示原值
- similarity_pct = round(raw_score, 2)
- else:
- similarity_pct = round(raw_score * 100.0, 2)
- else:
- similarity_pct = 0.0
- formatted_results.append(KBSearchResultItem(
- id=str(hit.id),
- kb_name=kb_id,
- doc_name=doc_name,
- content=content,
- meta_info=meta_str,
- score=similarity_pct
- ))
-
- return KBSearchResponse(results=formatted_results, total=len(formatted_results))
-
- except Exception as e:
- print(f"Search error: {e}")
- return KBSearchResponse(results=[], total=0)
- # ... (Keep existing CRUD methods below) ...
-
- async def get_list(
- self,
- db: AsyncSession,
- page: int = 1,
- page_size: int = 10,
- keyword: Optional[str] = None,
- status: Optional[str] = None
- ) -> Tuple[List[SearchEngine], PaginationSchema]:
- """获取检索引擎列表"""
-
- query = select(SearchEngine).where(SearchEngine.is_deleted == 0)
-
- if keyword:
- query = query.where(or_(
- SearchEngine.name.like(f"%{keyword}%"),
- SearchEngine.description.like(f"%{keyword}%")
- ))
-
- if status:
- query = query.where(SearchEngine.status == status)
- # 计算总数
- count_query = select(func.count()).select_from(query.subquery())
- total = await db.scalar(count_query) or 0
- # 分页查询
- query = query.order_by(SearchEngine.created_at.desc()).offset((page - 1) * page_size).limit(page_size)
- result = await db.execute(query)
- items = result.scalars().all()
- total_pages = ceil(total / page_size) if page_size else 0
-
- meta = PaginationSchema(
- page=page,
- page_size=page_size,
- total=total,
- total_pages=total_pages,
- )
-
- return items, meta
- async def create(self, db: AsyncSession, payload: SearchEngineCreate) -> SearchEngine:
- """创建检索引擎"""
- # 1. 检查名称是否已存在
- exists = await db.execute(select(SearchEngine).where(
- SearchEngine.name == payload.name,
- SearchEngine.is_deleted == 0
- ))
- if exists.scalars().first():
- raise ValueError("检索引擎名称已存在")
- try:
- now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- new_engine = SearchEngine(
- id=str(uuid.uuid4()),
- name=payload.name,
- engine_type=payload.engine_type,
- base_url=payload.base_url,
- api_key=payload.api_key,
- description=payload.description,
- status=payload.status or "normal",
- created_at=now,
- updated_at=now
- )
- db.add(new_engine)
- await db.commit()
- await db.refresh(new_engine)
- return new_engine
- except Exception as e:
- await db.rollback()
- raise e
- async def update(self, db: AsyncSession, id: str, payload: SearchEngineUpdate) -> SearchEngine:
- """更新检索引擎信息"""
- result = await db.execute(select(SearchEngine).where(SearchEngine.id == id, SearchEngine.is_deleted == 0))
- engine = result.scalars().first()
-
- if not engine:
- raise ValueError("检索引擎不存在")
- try:
- if payload.name is not None:
- engine.name = payload.name
- if payload.engine_type is not None:
- engine.engine_type = payload.engine_type
- if payload.base_url is not None:
- engine.base_url = payload.base_url
- if payload.api_key is not None:
- engine.api_key = payload.api_key
- if payload.description is not None:
- engine.description = payload.description
- if payload.status is not None:
- engine.status = payload.status
-
- engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- await db.commit()
- await db.refresh(engine)
-
- return engine
- except Exception as e:
- await db.rollback()
- raise e
- async def update_status(self, db: AsyncSession, id: str, status: str) -> SearchEngine:
- """更新检索引擎状态"""
- result = await db.execute(select(SearchEngine).where(SearchEngine.id == id, SearchEngine.is_deleted == 0))
- engine = result.scalars().first()
-
- if not engine:
- raise ValueError("检索引擎不存在")
-
- try:
- engine.status = status
- engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- await db.commit()
- await db.refresh(engine)
- return engine
- except Exception as e:
- await db.rollback()
- raise e
- async def delete(self, db: AsyncSession, id: str) -> None:
- """删除检索引擎"""
- result = await db.execute(select(SearchEngine).where(SearchEngine.id == id))
- engine = result.scalars().first()
-
- if not engine:
- raise ValueError("检索引擎不存在")
- try:
- # 软删除
- engine.is_deleted = 1
- engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- await db.commit()
- except Exception as e:
- await db.rollback()
- raise e
# Module-level singleton instance shared by the API/routing layer.
search_engine_service = SearchEngineService()
|