"""Knowledge-base API endpoints (知识库相关接口).

Each knowledge base is a ``KnowledgeBase`` DB row paired with a Milvus
collection named by its ``collection_name``. Rows are soft-deleted via the
``is_deleted`` flag.
"""
import logging
import uuid
from datetime import datetime
from math import ceil
from typing import List, Set

from fastapi import APIRouter, Query, Path, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, or_

from app.config.database import get_db
from app.sample.models.knowledge_base import KnowledgeBase
from app.schemas.base import PaginatedResponseSchema, PaginationSchema, ResponseSchema
from app.sample.schemas.knowledge_base import (
    KnowledgeBaseCreate,
    KnowledgeBaseUpdate,
    KnowledgeBaseResponse,
)
from app.services.milvus_service import milvus_service

logger = logging.getLogger(__name__)

router = APIRouter()

# Format of the string-typed created_time / updated_time values written below.
_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"


async def _sync_milvus_collections(db: AsyncSession) -> Set[str]:
    """Best-effort mirror of Milvus collections into the KnowledgeBase table.

    - Every Milvus collection without a non-deleted DB row gets a new row.
    - Existing rows get their ``document_count`` refreshed from the
      collection's ``row_count`` stat.

    Failures are logged and never propagate. A failed commit is rolled back
    so the session stays usable for the caller's subsequent queries.

    Returns:
        The set of Milvus collection names, or an empty set if listing the
        collections failed.
    """
    try:
        milvus_names = list(milvus_service.client.list_collections())
    except Exception:
        logger.exception("Sync Milvus collections failed: could not list collections")
        return set()

    try:
        result = await db.execute(
            select(KnowledgeBase).where(KnowledgeBase.is_deleted == False)  # noqa: E712
        )
        existing_map = {kb.collection_name: kb for kb in result.scalars().all()}

        has_changes = False
        now = datetime.now().strftime(_TIME_FORMAT)
        for m_name in milvus_names:
            try:
                stats = milvus_service.client.get_collection_stats(m_name)
                row_count = int(stats.get("row_count", 0))
            except Exception:
                # Stats unavailable: do not overwrite a known count with 0.
                row_count = None

            kb = existing_map.get(m_name)
            if kb is None:
                db.add(KnowledgeBase(
                    id=str(uuid.uuid4()),
                    name=m_name,
                    collection_name=m_name,
                    description="Synced from Milvus",
                    status="normal",
                    document_count=row_count or 0,
                    created_time=now,
                    updated_time=now,
                ))
                has_changes = True
            elif row_count is not None and kb.document_count != row_count:
                kb.document_count = row_count
                has_changes = True

        if has_changes:
            await db.commit()
    except Exception:
        # Without a rollback, a failed commit leaves the session unusable
        # for the list queries that follow in the same request.
        await db.rollback()
        logger.exception("Sync Milvus collections failed")

    return set(milvus_names)


def _drop_collection_quietly(collection_name: str) -> None:
    """Drop a Milvus collection, logging (not raising) any failure."""
    try:
        milvus_service.drop_collection(collection_name)
    except Exception:
        logger.exception("Failed to drop Milvus collection %s", collection_name)


@router.get("", response_model=PaginatedResponseSchema)
async def get_knowledge_bases(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(10, ge=1, le=100, description="每页数量"),
    keyword: str = Query(None, description="搜索关键词"),
    status: str = Query(None, description="状态筛选"),
    db: AsyncSession = Depends(get_db)
):
    """获取知识库列表"""
    # List knowledge bases (newest first, paginated) after syncing from Milvus.
    # `keyword` matches name or collection_name with SQL LIKE; `status` is an
    # exact match. Each item gets a non-column `is_synced` flag telling the
    # frontend whether its Milvus collection currently exists.
    milvus_names = await _sync_milvus_collections(db)

    query = select(KnowledgeBase).where(KnowledgeBase.is_deleted == False)  # noqa: E712
    if keyword:
        query = query.where(or_(
            KnowledgeBase.name.like(f"%{keyword}%"),
            KnowledgeBase.collection_name.like(f"%{keyword}%")
        ))
    if status:
        query = query.where(KnowledgeBase.status == status)

    # Total row count for the filtered query (scalar may return None).
    count_query = select(func.count()).select_from(query.subquery())
    total = await db.scalar(count_query) or 0

    # Page of results; the model's timestamp column is created_time.
    query = (
        query.order_by(KnowledgeBase.created_time.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    result = await db.execute(query)
    items = result.scalars().all()

    for item in items:
        item.is_synced = item.collection_name in milvus_names

    total_pages = ceil(total / page_size) if page_size else 0
    meta = PaginationSchema(
        page=page,
        page_size=page_size,
        total=total,
        total_pages=total_pages,
    )

    return PaginatedResponseSchema(
        code=0,
        message="获取知识库列表成功",
        data=[KnowledgeBaseResponse.model_validate(item) for item in items],
        meta=meta,
    )


@router.post("", response_model=ResponseSchema)
async def create_knowledge_base(
    payload: KnowledgeBaseCreate,
    db: AsyncSession = Depends(get_db)
):
    """创建新知识库"""
    # Create a Milvus collection and its DB row.
    # - Returns code=400 if the collection name is taken in the DB or Milvus.
    # - Returns code=500 on any other failure.
    # - If the DB commit fails after the collection was created, the
    #   collection is dropped again. Otherwise it would linger as an orphan,
    #   and the list endpoint's sync would re-import it as a phantom entry.
    exists = await db.execute(select(KnowledgeBase).where(
        KnowledgeBase.collection_name == payload.collection_name,
        KnowledgeBase.is_deleted == False  # noqa: E712
    ))
    if exists.scalars().first():
        return ResponseSchema(code=400, message="知识库集合名称已存在")

    if milvus_service.has_collection(payload.collection_name):
        return ResponseSchema(code=400, message="Milvus集合已存在,请使用其他名称")

    collection_created = False
    committed = False
    try:
        milvus_service.create_collection(
            name=payload.collection_name,
            dimension=payload.dimension,
            description=payload.description or ""
        )
        collection_created = True

        new_kb = KnowledgeBase(
            name=payload.name,
            collection_name=payload.collection_name,
            description=payload.description,
            status=payload.status or "normal"
        )
        db.add(new_kb)
        await db.commit()
        committed = True
        await db.refresh(new_kb)

        return ResponseSchema(code=0, message="创建成功", data=KnowledgeBaseResponse.model_validate(new_kb))
    except Exception as e:
        await db.rollback()
        # Only compensate when the row was NOT persisted; after a successful
        # commit the collection belongs to a real knowledge base.
        if collection_created and not committed:
            _drop_collection_quietly(payload.collection_name)
        return ResponseSchema(code=500, message=f"创建失败: {str(e)}")


@router.put("/{id}", response_model=ResponseSchema)
async def update_knowledge_base(
    id: str = Path(..., description="知识库ID"),
    payload: KnowledgeBaseUpdate = ...,  # noqa: B008
    db: AsyncSession = Depends(get_db)
):
    """更新知识库信息"""
    # Update name / description / status of a non-deleted knowledge base.
    # Falsy payload values (None or "") are treated as "leave unchanged".
    result = await db.execute(select(KnowledgeBase).where(
        KnowledgeBase.id == id,
        KnowledgeBase.is_deleted == False  # noqa: E712
    ))
    kb = result.scalars().first()
    if not kb:
        return ResponseSchema(code=404, message="知识库不存在")

    try:
        if payload.name:
            kb.name = payload.name
        if payload.description:
            kb.description = payload.description
            # TODO: also update the Milvus collection description once
            # milvus_service provides update_collection_description().
        if payload.status:
            kb.status = payload.status

        await db.commit()
        await db.refresh(kb)
        return ResponseSchema(code=0, message="更新成功", data=KnowledgeBaseResponse.model_validate(kb))
    except Exception as e:
        await db.rollback()
        return ResponseSchema(code=500, message=f"更新失败: {str(e)}")


@router.patch("/{id}/status", response_model=ResponseSchema)
async def update_knowledge_base_status(
    id: str = Path(..., description="知识库ID"),
    status: str = Query(..., description="状态: normal/test/disabled"),
    db: AsyncSession = Depends(get_db)
):
    """更新知识库状态(启用/禁用)"""
    # Set the status of a non-deleted knowledge base and mirror it in Milvus:
    # "normal" loads the collection, "disabled" releases it, and any other
    # value only updates the DB.
    # NOTE(review): `status` is not validated against the documented values.
    result = await db.execute(select(KnowledgeBase).where(
        KnowledgeBase.id == id,
        KnowledgeBase.is_deleted == False  # noqa: E712
    ))
    kb = result.scalars().first()
    if not kb:
        return ResponseSchema(code=404, message="知识库不存在")

    try:
        kb.status = status

        if status == "normal":
            milvus_service.client.load_collection(kb.collection_name)
        elif status == "disabled":
            milvus_service.client.release_collection(kb.collection_name)

        await db.commit()
        return ResponseSchema(code=0, message=f"状态已更新为 {status}")
    except Exception as e:
        await db.rollback()
        return ResponseSchema(code=500, message=f"状态更新失败: {str(e)}")


@router.delete("/{id}", response_model=ResponseSchema)
async def delete_knowledge_base(
    id: str = Path(..., description="知识库ID"),
    db: AsyncSession = Depends(get_db)
):
    """删除知识库"""
    # Drop the Milvus collection and soft-delete the DB row.
    # Already-deleted rows are treated as not found. Their collection_name may
    # have been reused by a newer knowledge base (create only checks
    # non-deleted rows), so dropping it again could destroy that one's data.
    result = await db.execute(select(KnowledgeBase).where(
        KnowledgeBase.id == id,
        KnowledgeBase.is_deleted == False  # noqa: E712
    ))
    kb = result.scalars().first()
    if not kb:
        return ResponseSchema(code=404, message="知识库不存在")

    try:
        milvus_service.drop_collection(kb.collection_name)

        kb.is_deleted = True
        await db.commit()
        return ResponseSchema(code=0, message="删除成功")
    except Exception as e:
        await db.rollback()
        return ResponseSchema(code=500, message=f"删除失败: {str(e)}")