Explorar el Código

检索的初步

linyang hace 1 mes
padre
commit
8ac252a6a8

+ 23 - 0
src/app/sample/models/search_engine.py

@@ -0,0 +1,23 @@
+"""
+检索引擎数据库模型
+"""
+from sqlalchemy import Column, String, Integer, Text
+from app.base.async_mysql_connection import Base
+
class SearchEngine(Base):
    """ORM model for the ``search_engine`` table (soft-deleted rows kept)."""
    __tablename__ = "search_engine"

    id = Column(String(36), primary_key=True, comment="ID")  # UUID4 string, assigned by the service layer
    name = Column(String(100), nullable=False, comment="引擎名称")
    engine_type = Column(String(50), nullable=False, comment="引擎类型: google, bing, duckduckgo, custom")
    base_url = Column(String(255), nullable=True, comment="基础URL")
    api_key = Column(String(255), nullable=True, comment="API Key")
    description = Column(Text, nullable=True, comment="描述")
    status = Column(String(20), default="normal", comment="状态: normal(正常), disabled(禁用)")
    is_deleted = Column(Integer, default=0, comment="是否删除")  # soft-delete flag: 0 = live, 1 = deleted
    # NOTE(review): timestamps are stored as "%Y-%m-%d %H:%M:%S" strings, not
    # DATETIME columns; lexicographic ordering of this format matches
    # chronological ordering, which get_list() relies on for its sort.
    created_at = Column(String(32), comment="创建时间")
    updated_at = Column(String(32), comment="更新时间")

    def __repr__(self):
        return f"<SearchEngine {self.name}>"

+ 68 - 0
src/app/sample/schemas/search_engine.py

@@ -0,0 +1,68 @@
+from typing import Optional, List, Dict, Any
+from pydantic import BaseModel, Field
+from app.schemas.base import BaseModelSchema
+from datetime import datetime
+
class SearchEngineBase(BaseModel):
    """Fields shared by the create and read schemas for a search engine."""
    name: str = Field(..., description="引擎名称")
    engine_type: str = Field(..., description="引擎类型")
    base_url: Optional[str] = Field(None, description="基础URL")
    api_key: Optional[str] = Field(None, description="API Key")
    description: Optional[str] = Field(None, description="描述")
    status: Optional[str] = Field("normal", description="状态")
+
class SearchEngineCreate(SearchEngineBase):
    """Request body for creating a search engine (same fields as the base)."""
    pass
+
class SearchEngineUpdate(BaseModel):
    """Partial-update request body: only fields that are not None are applied."""
    name: Optional[str] = None
    engine_type: Optional[str] = None
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    description: Optional[str] = None
    status: Optional[str] = None
+
class SearchEngineResponse(BaseModelSchema):
    """API representation of a search-engine record."""
    id: str
    name: str
    engine_type: str
    base_url: Optional[str]
    api_key: Optional[str]
    description: Optional[str]
    status: str
    # Stored as formatted strings in the DB model; pydantic is expected to
    # parse them into datetimes here. None when the row has no timestamp.
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    class Config:
        # NOTE(review): this mixes the Pydantic v2 `from_attributes` flag with
        # the v1-era `json_encoders` hook. It works, but `json_encoders` is
        # deprecated in v2 — consider a field_serializer instead.
        from_attributes = True
        json_encoders = {
            datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S") if v else None
        }
+
+# --- 新增:知识库搜索相关模型 ---
+
class KBSearchRequest(BaseModel):
    """Parameters for a knowledge-base (Milvus collection) search."""
    kb_id: str = Field(..., description="知识库ID或集合名称")
    query: str = Field(..., description="检索关键字")
    # NOTE(review): the metadata filter pair is currently ignored by the
    # service layer (see SearchEngineService.search_kb) — confirm intent.
    metadata_field: Optional[str] = Field(None, description="元数据字典字段")
    metadata_value: Optional[str] = Field(None, description="元数据字典值")
    top_k: int = Field(10, description="返回结果数量")
    score_threshold: float = Field(0.0, description="相似度阈值")
+
class KBSearchResultItem(BaseModel):
    """A single search hit, flattened for display."""
    id: str
    kb_name: str
    doc_name: str
    content: str
    meta_info: str
    # Similarity score; the service layer scales the raw score by 100.
    score: float
+    
class KBSearchResponse(BaseModel):
    """Search response: matched items plus the total hit count."""
    results: List[KBSearchResultItem]
    total: int

+ 4 - 2
src/app/server/app.py

@@ -49,9 +49,10 @@ from views.auth_view import router as auth_router
 from views.knowledge_base_view import router as knowledge_base_router
 from views.snippet_view import router as snippet_router
 from views.tag_view import router as tag_router
+from views.search_engine_view import router as search_engine_router
 
 # 导入现有API路由
-from app.api.v1.api_router import api_router
+# from app.api.v1.api_router import api_router
 
 
 @asynccontextmanager
@@ -241,7 +242,7 @@ async def root():
 
 # 包含API路由
 # 现有的API路由(保持兼容)
-app.include_router(api_router, prefix="/api/v1")
+# app.include_router(api_router, prefix="/api/v1")
 # 新的模块化视图路由
 app.include_router(system_router, prefix="/api/v1")
 app.include_router(oauth_router, prefix="")
@@ -250,6 +251,7 @@ app.include_router(sample_router, prefix="/api/v1")
 app.include_router(knowledge_base_router, prefix="/api/v1")
 app.include_router(snippet_router, prefix="/api/v1")
 app.include_router(tag_router, prefix="/api/v1")
+app.include_router(search_engine_router, prefix="/api/v1")
 
 
 def create_app() -> FastAPI:

+ 246 - 0
src/app/services/search_engine_service.py

@@ -0,0 +1,246 @@
+"""
+检索引擎业务逻辑服务
+"""
+from math import ceil
+from typing import List, Optional, Tuple, Dict
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, func, or_
+from datetime import datetime
+import uuid
+import random
+import json
+import hashlib
+import math
+
+from app.sample.models.search_engine import SearchEngine
+from app.sample.schemas.search_engine import (
+    SearchEngineCreate,
+    SearchEngineUpdate,
+    KBSearchRequest,
+    KBSearchResultItem,
+    KBSearchResponse
+)
+from app.schemas.base import PaginationSchema
+from app.services.milvus_service import milvus_service
+from app.utils.vector_utils import text_to_vector_algo
+
class SearchEngineService:
    """Business logic for search-engine records and knowledge-base search.

    Soft-delete convention: rows are never removed; ``is_deleted`` is set to 1
    and all read paths filter on ``is_deleted == 0``.
    """

    async def search_kb(self, db: AsyncSession, payload: KBSearchRequest) -> KBSearchResponse:
        """Run a semantic search against a Milvus knowledge-base collection.

        Args:
            db: Async SQLAlchemy session (unused here; kept so all service
                methods share the same calling convention).
            payload: Search parameters (collection id, query text, top_k, ...).

        Returns:
            KBSearchResponse with formatted hits; empty on any backend failure
            or when the collection does not exist (best-effort, never raises).
        """
        kb_id = payload.kb_id

        # Unknown collection -> empty result instead of a Milvus error.
        if not milvus_service.has_collection(kb_id):
            return KBSearchResponse(results=[], total=0)

        # 1. Deterministic, algorithm-generated query vector (stand-in for a
        #    real embedding model): identical queries yield identical vectors,
        #    which gives basic retrieval capability without a model.
        query_vector = text_to_vector_algo(payload.query, dim=768)

        # 2. Metadata filter expression.
        # NOTE(review): metadata filtering is intentionally a no-op until the
        # actual collection schema is known — building an expression against a
        # non-existent field would make every search fail.
        expr = ""

        # 3. Execute the Milvus search.
        try:
            search_params = {
                "metric_type": "COSINE",
                "params": {"nprobe": 10},
            }

            results = milvus_service.client.search(
                collection_name=kb_id,
                data=[query_vector],
                anns_field="vector",
                search_params=search_params,
                limit=payload.top_k,
                filter=expr,
                output_fields=["*"],
            )

            # 4. Format the hits.
            formatted_results = []
            for hits in results:
                for hit in hits:
                    # Fix: honor the request's score_threshold (it was
                    # accepted but ignored before). The default of 0.0 is
                    # falsy, so default behavior is unchanged.
                    if payload.score_threshold and hit.score < payload.score_threshold:
                        continue

                    entity = hit.entity

                    content = entity.get("text") or entity.get("content") or entity.get("page_content") or ""
                    if not content:
                        # Fallback: a truncated JSON dump (minus the vector)
                        # so the hit is still inspectable in the UI.
                        debug_data = {k: v for k, v in entity.items() if k != "vector"}
                        content = json.dumps(debug_data, ensure_ascii=False)[:200] + "..."

                    doc_name = entity.get("file_name") or entity.get("title") or entity.get("source") or "未知文档"

                    # Summarize up to three non-content metadata fields.
                    meta_info = []
                    for k, v in entity.items():
                        if k not in ["vector", "text", "content", "page_content", "id", "pk"]:
                            meta_info.append(f"{k}: {v}")
                    meta_str = "; ".join(meta_info[:3])

                    formatted_results.append(KBSearchResultItem(
                        id=str(hit.id),
                        kb_name=kb_id,
                        doc_name=doc_name,
                        content=content,
                        meta_info=meta_str,
                        # Scale similarity to a 0-100 display score.
                        score=round(hit.score * 100, 2),
                    ))

            return KBSearchResponse(results=formatted_results, total=len(formatted_results))

        except Exception as e:
            # Best-effort search: degrade to an empty result set rather than
            # surfacing a 500 to the caller.
            print(f"Search error: {e}")
            return KBSearchResponse(results=[], total=0)

    async def get_list(
        self,
        db: AsyncSession,
        page: int = 1,
        page_size: int = 10,
        keyword: Optional[str] = None,
        status: Optional[str] = None
    ) -> Tuple[List[SearchEngine], PaginationSchema]:
        """Return a page of live search engines plus pagination metadata.

        ``keyword`` matches name or description (substring); ``status``
        filters exactly.
        """
        query = select(SearchEngine).where(SearchEngine.is_deleted == 0)

        if keyword:
            query = query.where(or_(
                SearchEngine.name.like(f"%{keyword}%"),
                SearchEngine.description.like(f"%{keyword}%")
            ))

        if status:
            query = query.where(SearchEngine.status == status)

        # Total count over the filtered set (before pagination).
        count_query = select(func.count()).select_from(query.subquery())
        total = await db.scalar(count_query) or 0

        # Newest first; created_at's fixed string format sorts chronologically.
        query = query.order_by(SearchEngine.created_at.desc()).offset((page - 1) * page_size).limit(page_size)
        result = await db.execute(query)
        items = result.scalars().all()

        total_pages = ceil(total / page_size) if page_size else 0

        meta = PaginationSchema(
            page=page,
            page_size=page_size,
            total=total,
            total_pages=total_pages,
        )

        return items, meta

    async def create(self, db: AsyncSession, payload: SearchEngineCreate) -> SearchEngine:
        """Create a search engine; raises ValueError on a duplicate name."""
        # Name must be unique among live (non-deleted) rows.
        exists = await db.execute(select(SearchEngine).where(
            SearchEngine.name == payload.name,
            SearchEngine.is_deleted == 0
        ))
        if exists.scalars().first():
            raise ValueError("检索引擎名称已存在")

        try:
            now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            new_engine = SearchEngine(
                id=str(uuid.uuid4()),
                name=payload.name,
                engine_type=payload.engine_type,
                base_url=payload.base_url,
                api_key=payload.api_key,
                description=payload.description,
                status=payload.status or "normal",
                created_at=now,
                updated_at=now
            )
            db.add(new_engine)
            await db.commit()
            await db.refresh(new_engine)

            return new_engine
        except Exception as e:
            await db.rollback()
            raise e

    async def update(self, db: AsyncSession, id: str, payload: SearchEngineUpdate) -> SearchEngine:
        """Apply the non-None fields of ``payload`` to an existing engine.

        Raises ValueError if the engine does not exist or if renaming would
        collide with another live engine's name.
        """
        result = await db.execute(select(SearchEngine).where(SearchEngine.id == id, SearchEngine.is_deleted == 0))
        engine = result.scalars().first()

        if not engine:
            raise ValueError("检索引擎不存在")

        # Fix: enforce the same name-uniqueness rule as create() — previously
        # an update could silently duplicate another engine's name.
        if payload.name is not None and payload.name != engine.name:
            dup = await db.execute(select(SearchEngine).where(
                SearchEngine.name == payload.name,
                SearchEngine.is_deleted == 0,
                SearchEngine.id != id
            ))
            if dup.scalars().first():
                raise ValueError("检索引擎名称已存在")

        try:
            if payload.name is not None:
                engine.name = payload.name
            if payload.engine_type is not None:
                engine.engine_type = payload.engine_type
            if payload.base_url is not None:
                engine.base_url = payload.base_url
            if payload.api_key is not None:
                engine.api_key = payload.api_key
            if payload.description is not None:
                engine.description = payload.description
            if payload.status is not None:
                engine.status = payload.status

            engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            await db.commit()
            await db.refresh(engine)

            return engine
        except Exception as e:
            await db.rollback()
            raise e

    async def update_status(self, db: AsyncSession, id: str, status: str) -> SearchEngine:
        """Set the status field of a live engine; raises ValueError if absent."""
        result = await db.execute(select(SearchEngine).where(SearchEngine.id == id, SearchEngine.is_deleted == 0))
        engine = result.scalars().first()

        if not engine:
            raise ValueError("检索引擎不存在")

        try:
            engine.status = status
            engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            await db.commit()
            await db.refresh(engine)
            return engine
        except Exception as e:
            await db.rollback()
            raise e

    async def delete(self, db: AsyncSession, id: str) -> None:
        """Soft-delete an engine; raises ValueError if it does not exist.

        Fix: filters on is_deleted == 0 like every other method, so deleting
        an already-deleted engine now reports "not found" instead of
        silently succeeding again.
        """
        result = await db.execute(select(SearchEngine).where(SearchEngine.id == id, SearchEngine.is_deleted == 0))
        engine = result.scalars().first()

        if not engine:
            raise ValueError("检索引擎不存在")

        try:
            # Soft delete: mark rather than remove.
            engine.is_deleted = 1
            engine.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            await db.commit()
        except Exception as e:
            await db.rollback()
            raise e

# Module-level singleton used by the view layer.
search_engine_service = SearchEngineService()

+ 87 - 2
src/app/services/snippet_service.py

@@ -5,10 +5,13 @@
 from typing import List, Optional, Tuple, Dict, Any
 import json
 import random
+import csv
+import io
 from datetime import datetime
 
 from app.services.milvus_service import milvus_service
 from app.schemas.base import PaginationSchema, PaginatedResponseSchema
+from app.utils.vector_utils import text_to_vector_algo
 
 class SnippetService:
     
@@ -126,7 +129,8 @@ class SnippetService:
 
     def create(self, payload: Any) -> Dict:
         """创建知识片段"""
-        fake_vector = [random.random() for _ in range(768)] 
+        # 使用统一算法生成向量
+        fake_vector = text_to_vector_algo(payload.content, dim=768)
         
         data = [{
             "vector": fake_vector,
@@ -162,7 +166,8 @@ class SnippetService:
         milvus_service.client.delete(collection_name=kb, filter=expr)
         
         # 2. 插入新数据
-        fake_vector = [random.random() for _ in range(768)] 
+        # 使用统一算法生成向量
+        fake_vector = text_to_vector_algo(payload.content, dim=768)
         
         data = [{
             "vector": fake_vector,
@@ -226,4 +231,84 @@ class SnippetService:
             "updated_at": "-"
         }
 
+    def export_snippets(self, kb: Optional[str] = None, keyword: Optional[str] = None) -> Any:
+        """导出知识片段 (生成器)"""
+        
+        # 1. 确定要查询的目标集合列表
+        target_collections = []
+        if kb:
+            target_collections = [kb]
+        else:
+            target_collections = milvus_service.client.list_collections()
+        
+        for col_name in target_collections:
+            if not milvus_service.has_collection(col_name):
+                continue
+                
+            try:
+                # 获取该集合总数
+                stats = milvus_service.client.get_collection_stats(col_name)
+                col_count = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0
+                
+                if col_count == 0:
+                    continue
+
+                output_fields = ["*"]
+                expr = ""
+                
+                if keyword:
+                    desc = milvus_service.client.describe_collection(col_name)
+                    existing_fields = [f['name'] for f in desc.get('fields', [])]
+                    if 'text' in existing_fields:
+                        expr = f'text like "%{keyword}%"'
+                    else:
+                        continue 
+                
+                # 分批获取所有数据
+                batch_size = 1000
+                offset = 0
+                
+                while True:
+                    res = milvus_service.client.query(
+                        collection_name=col_name,
+                        filter=expr,
+                        output_fields=output_fields,
+                        limit=batch_size,
+                        offset=offset
+                    )
+                    
+                    if not res:
+                        break
+                        
+                    for r in res:
+                        yield self._format_snippet(r, col_name)
+                        
+                    offset += len(res)
+                    if len(res) < batch_size:
+                        break
+                        
+            except Exception as e:
+                print(f"Collection {col_name} export error: {e}")
+                continue
+
+    def generate_csv_stream(self, kb: Optional[str] = None, keyword: Optional[str] = None):
+        """生成CSV流"""
+        output = io.StringIO()
+        fieldnames = ["id", "collection_name", "doc_name", "content", "meta_info", "created_at", "status"]
+        writer = csv.DictWriter(output, fieldnames=fieldnames)
+        
+        # 写入表头
+        writer.writeheader()
+        yield output.getvalue()
+        output.seek(0)
+        output.truncate(0)
+        
+        for item in self.export_snippets(kb, keyword):
+            # 过滤掉不在 fieldnames 中的字段
+            row = {k: item.get(k, "") for k in fieldnames}
+            writer.writerow(row)
+            yield output.getvalue()
+            output.seek(0)
+            output.truncate(0)
+
 snippet_service = SnippetService()

+ 122 - 0
src/views/search_engine_view.py

@@ -0,0 +1,122 @@
+"""
+检索引擎视图路由
+"""
+from fastapi import APIRouter, Depends, Query, Path, Body
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.base.async_mysql_connection import get_db
+from app.services.search_engine_service import search_engine_service
+from app.sample.schemas.search_engine import (
+    SearchEngineCreate, 
+    SearchEngineUpdate, 
+    SearchEngineResponse,
+    KBSearchRequest,
+    KBSearchResponse
+)
+from app.schemas.base import PaginatedResponseSchema, ResponseSchema
+from app.services.jwt_token import verify_token
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+
+router = APIRouter(prefix="/sample/search-engine", tags=["样本中心-检索引擎"])
+security = HTTPBearer()
+
+# --- 新增:知识库搜索接口 ---
@router.post("/search", response_model=ResponseSchema)
async def search_knowledge_base(
    payload: KBSearchRequest,
    db: AsyncSession = Depends(get_db),
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Semantic search over a knowledge-base collection."""
    # Auth guard: reject requests whose bearer token does not verify.
    if not verify_token(credentials.credentials):
        return ResponseSchema(code=401, message="无效的访问令牌")

    search_result = await search_engine_service.search_kb(db, payload)
    return ResponseSchema(code=0, message="搜索成功", data=search_result)
+
+
+# --- 原有接口 ---
@router.get("", response_model=PaginatedResponseSchema)
async def get_search_engines(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(10, ge=1, le=100, description="每页数量"),
    keyword: str = Query(None, description="搜索关键词"),
    status: str = Query(None, description="状态筛选"),
    db: AsyncSession = Depends(get_db),
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """List search engines, paginated, with optional keyword/status filters."""
    # Auth guard: every admin endpoint requires a valid bearer token.
    if not verify_token(credentials.credentials):
        return PaginatedResponseSchema(code=401, message="无效的访问令牌", data=[], meta=None)

    engines, pagination = await search_engine_service.get_list(
        db,
        page=page,
        page_size=page_size,
        keyword=keyword,
        status=status,
    )

    serialized = [SearchEngineResponse.model_validate(engine) for engine in engines]
    return PaginatedResponseSchema(
        code=0,
        message="获取检索引擎列表成功",
        data=serialized,
        meta=pagination,
    )
+
@router.post("", response_model=ResponseSchema)
async def create_search_engine(
    payload: SearchEngineCreate,
    db: AsyncSession = Depends(get_db),
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create a new search engine record."""
    # Auth guard.
    if not verify_token(credentials.credentials):
        return ResponseSchema(code=401, message="无效的访问令牌")

    created = await search_engine_service.create(db, payload)
    serialized = SearchEngineResponse.model_validate(created)
    return ResponseSchema(code=0, message="创建成功", data=serialized)
+
@router.post("/{id}", response_model=ResponseSchema)
async def update_search_engine(
    payload: SearchEngineUpdate,
    id: str = Path(..., description="检索引擎ID"),
    db: AsyncSession = Depends(get_db),
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Update an existing search engine (partial update)."""
    # Auth guard.
    if not verify_token(credentials.credentials):
        return ResponseSchema(code=401, message="无效的访问令牌")

    updated = await search_engine_service.update(db, id, payload)
    serialized = SearchEngineResponse.model_validate(updated)
    return ResponseSchema(code=0, message="更新成功", data=serialized)
+
@router.post("/{id}/status", response_model=ResponseSchema)
async def update_search_engine_status(
    id: str = Path(..., description="检索引擎ID"),
    status: str = Query(..., description="状态: normal/disabled"),
    db: AsyncSession = Depends(get_db),
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Set a search engine's status (normal / disabled)."""
    # Auth guard.
    if not verify_token(credentials.credentials):
        return ResponseSchema(code=401, message="无效的访问令牌")

    await search_engine_service.update_status(db, id, status)
    return ResponseSchema(code=0, message=f"状态已更新为 {status}")
+
@router.post("/{id}/delete", response_model=ResponseSchema)
async def delete_search_engine(
    id: str = Path(..., description="检索引擎ID"),
    db: AsyncSession = Depends(get_db),
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Soft-delete a search engine."""
    # Auth guard.
    if not verify_token(credentials.credentials):
        return ResponseSchema(code=401, message="无效的访问令牌")

    await search_engine_service.delete(db, id)
    return ResponseSchema(code=0, message="删除成功")

+ 26 - 0
src/views/snippet_view.py

@@ -2,7 +2,10 @@
 知识片段视图路由
 """
 from fastapi import APIRouter, Depends, Query, Path, Body
+from fastapi.responses import StreamingResponse
 from typing import Optional
+from datetime import datetime
+import urllib.parse
 
 from app.services.snippet_service import snippet_service
 from app.schemas.base import ResponseSchema, PaginatedResponseSchema
@@ -44,6 +47,29 @@ async def get_snippets(
         meta=meta
     )
 
@router.get("/export")
async def export_snippets(
    kb: Optional[str] = Query(None, description="知识库集合名称"),
    keyword: Optional[str] = Query(None),
    status: Optional[str] = Query(None),
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Export knowledge snippets as a streamed CSV download.

    NOTE(review): the ``status`` query parameter is accepted but never
    forwarded to the service layer — confirm whether status filtering is
    intended here.
    """
    # Auth guard.
    payload_token = verify_token(credentials.credentials)
    if not payload_token:
        return ResponseSchema(code=401, message="无效的访问令牌")

    filename = f"snippets_export_{datetime.now().strftime('%Y%m%d%H%M%S')}.csv"
    encoded_filename = urllib.parse.quote(filename)

    return StreamingResponse(
        snippet_service.generate_csv_stream(kb, keyword),
        media_type="text/csv",
        headers={
            # Fix: the plain filename= fallback previously contained a literal
            # "(unknown)" placeholder. Per RFC 6266, send both the ASCII
            # filename= form and the UTF-8 filename*= form.
            "Content-Disposition": f"attachment; filename=\"{filename}\"; filename*=utf-8''{encoded_filename}"
        }
    )
+
 @router.post("", response_model=ResponseSchema)
 async def create_snippet(
     payload: SnippetCreate,

+ 1 - 2
src/views/system_view.py

@@ -1278,7 +1278,6 @@ def generate_random_string(length=32):
     alphabet = string.ascii_letters + string.digits
     return ''.join(secrets.choice(alphabet) for _ in range(length))
 
-
 ### 2. 获取所有角色
 @router.get("/admin/roles")
 async def api_get_all_roles(
@@ -1563,4 +1562,4 @@ async def get_apps(
         
     except Exception as e:
         logger.exception("获取应用列表错误")
-        return ApiResponse(code=500001, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+        return ApiResponse(code=500001, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()