|
|
@@ -0,0 +1,246 @@
|
|
|
+"""
|
|
|
+检索引擎业务逻辑服务
|
|
|
+"""
|
|
|
import hashlib
import json
import logging
import math
import random
import uuid
from datetime import datetime
from math import ceil
from typing import Dict, List, Optional, Tuple

from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession

from app.sample.models.search_engine import SearchEngine
from app.sample.schemas.search_engine import (
    SearchEngineCreate,
    SearchEngineUpdate,
    KBSearchRequest,
    KBSearchResultItem,
    KBSearchResponse,
)
from app.schemas.base import PaginationSchema
from app.services.milvus_service import milvus_service
from app.utils.vector_utils import text_to_vector_algo
|
|
|
+
|
|
|
class SearchEngineService:
    """Business-logic service for search engines.

    Provides knowledge-base vector search through Milvus, plus CRUD
    operations (with soft delete) over the ``SearchEngine`` table.
    """

    @staticmethod
    def _now() -> str:
        """Return the current local time as the DB's string timestamp format."""
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    async def _get_engine(
        self, db: AsyncSession, id: str, *, include_deleted: bool = False
    ) -> SearchEngine:
        """Fetch a search engine by id.

        Args:
            db: Async database session.
            id: Primary key of the engine.
            include_deleted: When True, soft-deleted rows are also matched
                (used by ``delete``, which mirrors the original behavior of
                not filtering on ``is_deleted``).

        Raises:
            ValueError: If no matching row exists.
        """
        conditions = [SearchEngine.id == id]
        if not include_deleted:
            conditions.append(SearchEngine.is_deleted == 0)
        result = await db.execute(select(SearchEngine).where(*conditions))
        engine = result.scalars().first()
        if not engine:
            raise ValueError("检索引擎不存在")
        return engine

    async def search_kb(self, db: AsyncSession, payload: KBSearchRequest) -> KBSearchResponse:
        """Search a knowledge base using an algorithmically generated query vector.

        Args:
            db: Async database session (currently unused; kept for interface
                symmetry with the other service methods).
            payload: Request carrying ``kb_id``, ``query``, ``top_k`` and
                optional metadata filter fields.

        Returns:
            A ``KBSearchResponse``. Best-effort semantics: an empty response is
            returned when the collection does not exist or the search fails.
        """
        kb_id = payload.kb_id

        if not milvus_service.has_collection(kb_id):
            return KBSearchResponse(results=[], total=0)

        # 1. Generate the query vector with a deterministic algorithm (a
        #    stand-in for an embedding model): identical query strings yield
        #    identical vectors, which gives basic retrieval capability.
        query_vector = text_to_vector_algo(payload.query, dim=768)

        # 2. Build the filter expression. Metadata filtering depends on the
        #    actual Milvus collection schema (fields may live inside an
        #    extra_info JSON), so it is intentionally left unimplemented to
        #    avoid schema-mismatch errors.
        expr = ""
        if payload.metadata_field and payload.metadata_value:
            pass

        # 3. Execute the Milvus search.
        try:
            search_params = {
                "metric_type": "COSINE",
                "params": {"nprobe": 10},
            }

            results = milvus_service.client.search(
                collection_name=kb_id,
                data=[query_vector],
                anns_field="vector",
                search_params=search_params,
                limit=payload.top_k,
                filter=expr if expr else "",
                output_fields=["*"],
            )

            # 4. Format the hits into response items.
            formatted_results = []
            for hits in results:
                for hit in hits:
                    # Score-threshold filtering is deliberately disabled:
                    # algorithmically generated vectors tend to score low, so
                    # the threshold would need tuning before re-enabling.
                    # if hit.score < payload.score_threshold:
                    #     continue

                    entity = hit.entity

                    content = (
                        entity.get("text")
                        or entity.get("content")
                        or entity.get("page_content")
                        or ""
                    )
                    if not content:
                        # No known text field: fall back to a truncated JSON
                        # dump of the entity (minus the vector) for debugging.
                        debug_data = {k: v for k, v in entity.items() if k != "vector"}
                        content = json.dumps(debug_data, ensure_ascii=False)[:200] + "..."

                    doc_name = (
                        entity.get("file_name")
                        or entity.get("title")
                        or entity.get("source")
                        or "未知文档"
                    )

                    # Surface up to three non-content metadata fields as a
                    # human-readable summary string.
                    meta_info = [
                        f"{k}: {v}"
                        for k, v in entity.items()
                        if k not in ["vector", "text", "content", "page_content", "id", "pk"]
                    ]
                    meta_str = "; ".join(meta_info[:3])

                    formatted_results.append(KBSearchResultItem(
                        id=str(hit.id),
                        kb_name=kb_id,
                        doc_name=doc_name,
                        content=content,
                        meta_info=meta_str,
                        score=round(hit.score * 100, 2),
                    ))

            return KBSearchResponse(results=formatted_results, total=len(formatted_results))

        except Exception:
            # Best-effort: log the full traceback (the original `print` lost
            # it) and return an empty result set rather than propagating
            # Milvus errors to the caller.
            logging.getLogger(__name__).exception("Milvus search failed for kb %s", kb_id)
            return KBSearchResponse(results=[], total=0)

    async def get_list(
        self,
        db: AsyncSession,
        page: int = 1,
        page_size: int = 10,
        keyword: Optional[str] = None,
        status: Optional[str] = None
    ) -> Tuple[List[SearchEngine], PaginationSchema]:
        """Return a paginated list of non-deleted search engines.

        Args:
            db: Async database session.
            page: 1-based page number.
            page_size: Rows per page.
            keyword: Optional substring match against name or description.
            status: Optional exact status filter.

        Returns:
            Tuple of (rows for the requested page, pagination metadata).
        """
        query = select(SearchEngine).where(SearchEngine.is_deleted == 0)

        if keyword:
            query = query.where(or_(
                SearchEngine.name.like(f"%{keyword}%"),
                SearchEngine.description.like(f"%{keyword}%")
            ))

        if status:
            query = query.where(SearchEngine.status == status)

        # Count total rows matching the filters (before pagination).
        count_query = select(func.count()).select_from(query.subquery())
        total = await db.scalar(count_query) or 0

        # Apply ordering and pagination.
        query = (
            query.order_by(SearchEngine.created_at.desc())
            .offset((page - 1) * page_size)
            .limit(page_size)
        )
        result = await db.execute(query)
        items = result.scalars().all()

        total_pages = ceil(total / page_size) if page_size else 0

        meta = PaginationSchema(
            page=page,
            page_size=page_size,
            total=total,
            total_pages=total_pages,
        )

        return items, meta

    async def create(self, db: AsyncSession, payload: SearchEngineCreate) -> SearchEngine:
        """Create a new search engine.

        Raises:
            ValueError: If a non-deleted engine with the same name exists.
        """
        # Reject duplicate names among non-deleted rows.
        exists = await db.execute(select(SearchEngine).where(
            SearchEngine.name == payload.name,
            SearchEngine.is_deleted == 0
        ))
        if exists.scalars().first():
            raise ValueError("检索引擎名称已存在")

        try:
            now = self._now()
            new_engine = SearchEngine(
                id=str(uuid.uuid4()),
                name=payload.name,
                engine_type=payload.engine_type,
                base_url=payload.base_url,
                api_key=payload.api_key,
                description=payload.description,
                status=payload.status or "normal",
                created_at=now,
                updated_at=now
            )
            db.add(new_engine)
            await db.commit()
            await db.refresh(new_engine)

            return new_engine
        except Exception as e:
            await db.rollback()
            raise e

    async def update(self, db: AsyncSession, id: str, payload: SearchEngineUpdate) -> SearchEngine:
        """Update a search engine; only fields present in the payload change.

        NOTE(review): unlike ``create``, this does not check name uniqueness
        when renaming — confirm whether that is intentional.

        Raises:
            ValueError: If the engine does not exist (or is soft-deleted).
        """
        engine = await self._get_engine(db, id)

        try:
            # Apply only the fields the caller actually supplied.
            for field in ("name", "engine_type", "base_url", "api_key",
                          "description", "status"):
                value = getattr(payload, field)
                if value is not None:
                    setattr(engine, field, value)

            engine.updated_at = self._now()
            await db.commit()
            await db.refresh(engine)

            return engine
        except Exception as e:
            await db.rollback()
            raise e

    async def update_status(self, db: AsyncSession, id: str, status: str) -> SearchEngine:
        """Set the status of a search engine.

        Raises:
            ValueError: If the engine does not exist (or is soft-deleted).
        """
        engine = await self._get_engine(db, id)

        try:
            engine.status = status
            engine.updated_at = self._now()
            await db.commit()
            await db.refresh(engine)
            return engine
        except Exception as e:
            await db.rollback()
            raise e

    async def delete(self, db: AsyncSession, id: str) -> None:
        """Soft-delete a search engine by setting ``is_deleted = 1``.

        Raises:
            ValueError: If the engine does not exist.
        """
        # Matches the original behavior: the lookup does not filter on
        # is_deleted, so re-deleting an already-deleted row is a no-op update.
        engine = await self._get_engine(db, id, include_deleted=True)

        try:
            engine.is_deleted = 1
            engine.updated_at = self._now()
            await db.commit()
        except Exception as e:
            await db.rollback()
            raise e
|
|
|
+
|
|
|
+search_engine_service = SearchEngineService()
|