| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250 |
- """
- 知识库相关接口
- """
import logging
import uuid
from datetime import datetime
from math import ceil
from typing import List

from fastapi import APIRouter, Depends, HTTPException, Path, Query
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config.database import get_db
from app.sample.models.knowledge_base import KnowledgeBase
from app.sample.schemas.knowledge_base import (
    KnowledgeBaseCreate,
    KnowledgeBaseResponse,
    KnowledgeBaseUpdate,
)
from app.schemas.base import PaginatedResponseSchema, PaginationSchema, ResponseSchema
from app.services.milvus_service import milvus_service
# Router holding every knowledge-base endpoint in this module. The routes use
# empty/relative paths, so it is presumably mounted under a prefix elsewhere.
router: APIRouter = APIRouter()
def _collection_row_count(collection_name: str):
    """Return the Milvus ``row_count`` statistic for a collection, or None if it can't be read."""
    try:
        stats = milvus_service.client.get_collection_stats(collection_name)
        return int(stats.get("row_count", 0))
    except Exception:
        logging.getLogger(__name__).warning(
            "Reading stats of Milvus collection %s failed", collection_name, exc_info=True
        )
        return None


async def _sync_milvus_collections(db: AsyncSession) -> set:
    """Mirror Milvus collections into the knowledge-base table, best effort.

    A Milvus collection with no live (non-deleted) DB record gets a new record.
    A live record gets its ``document_count`` refreshed from ``row_count``.

    Returns:
        The set of collection names reported by Milvus. Returns an empty set
        when Milvus cannot be listed. Failures are logged and never propagate.
        A failed DB sync is rolled back so the session stays usable.
    """
    logger = logging.getLogger(__name__)
    # NOTE(review): the Milvus client calls below are not awaited, so they run
    # synchronously on the event loop; consider offloading them to a thread.
    try:
        milvus_names = list(milvus_service.client.list_collections())
    except Exception:
        logger.exception("Listing Milvus collections failed")
        return set()

    try:
        result = await db.execute(
            select(KnowledgeBase).where(KnowledgeBase.is_deleted == False)  # noqa: E712
        )
        existing_map = {kb.collection_name: kb for kb in result.scalars().all()}

        has_changes = False
        for m_name in milvus_names:
            row_count = _collection_row_count(m_name)
            kb = existing_map.get(m_name)
            if kb is None:
                now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                db.add(KnowledgeBase(
                    id=str(uuid.uuid4()),
                    name=m_name,
                    collection_name=m_name,
                    description="Synced from Milvus",
                    status="normal",
                    # Unknown stats fall back to 0 for brand-new records.
                    document_count=row_count if row_count is not None else 0,
                    created_time=now,
                    updated_time=now,
                ))
                has_changes = True
            elif row_count is not None and kb.document_count != row_count:
                # Skip when stats are unavailable so a transient error does
                # not overwrite a known count with 0.
                kb.document_count = row_count
                has_changes = True

        if has_changes:
            await db.commit()
    except Exception:
        # Discard half-applied changes; otherwise the listing query that
        # follows would run on a session that needs a rollback.
        await db.rollback()
        logger.exception("Sync Milvus collections failed")

    return set(milvus_names)


@router.get("", response_model=PaginatedResponseSchema)
async def get_knowledge_bases(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(10, ge=1, le=100, description="每页数量"),
    keyword: str = Query(None, description="搜索关键词"),
    status: str = Query(None, description="状态筛选"),
    db: AsyncSession = Depends(get_db)
):
    """List non-deleted knowledge bases, paginated and newest first.

    Before querying, Milvus collections are synced into the DB (best effort).
    ``keyword`` matches name or collection name with LIKE; ``status`` is an
    exact match. Each returned item carries ``is_synced``: whether its
    collection currently exists in Milvus.
    """
    # Always bound, even when Milvus is unreachable (then: empty set).
    milvus_names = await _sync_milvus_collections(db)

    query = select(KnowledgeBase).where(KnowledgeBase.is_deleted == False)  # noqa: E712
    if keyword:
        pattern = f"%{keyword}%"
        query = query.where(or_(
            KnowledgeBase.name.like(pattern),
            KnowledgeBase.collection_name.like(pattern),
        ))
    if status:
        query = query.where(KnowledgeBase.status == status)

    # Count over the filtered query before pagination is applied.
    count_query = select(func.count()).select_from(query.subquery())
    total = await db.scalar(count_query) or 0  # scalar() may return None

    # The model's timestamp column is created_time (not created_at).
    page_query = (
        query.order_by(KnowledgeBase.created_time.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    result = await db.execute(page_query)
    items = result.scalars().all()

    # is_synced is not a DB column; it is set for the frontend display only.
    for item in items:
        item.is_synced = item.collection_name in milvus_names

    total_pages = ceil(total / page_size) if page_size else 0
    meta = PaginationSchema(
        page=page,
        page_size=page_size,
        total=total,
        total_pages=total_pages,
    )
    return PaginatedResponseSchema(
        code=0,
        message="获取知识库列表成功",
        data=[KnowledgeBaseResponse.model_validate(item) for item in items],
        meta=meta,
    )
@router.post("", response_model=ResponseSchema)
async def create_knowledge_base(
    payload: KnowledgeBaseCreate,
    db: AsyncSession = Depends(get_db)
):
    """Create a Milvus collection and its knowledge-base DB record.

    Returns code 400 if the collection name is already used by a live DB
    record or already exists in Milvus. Returns code 500 if creation fails;
    in that case a Milvus collection created by this call is dropped again
    when the DB record was not committed.
    """
    # 1. Reject names already used by a live (non-deleted) DB record.
    exists = await db.execute(select(KnowledgeBase).where(
        KnowledgeBase.collection_name == payload.collection_name,
        KnowledgeBase.is_deleted == False,  # noqa: E712
    ))
    if exists.scalars().first():
        return ResponseSchema(code=400, message="知识库集合名称已存在")

    # 2. Reject names that already exist in Milvus.
    if milvus_service.has_collection(payload.collection_name):
        return ResponseSchema(code=400, message="Milvus集合已存在,请使用其他名称")

    collection_created = False
    committed = False
    try:
        # 3. Create the Milvus collection.
        milvus_service.create_collection(
            name=payload.collection_name,
            dimension=payload.dimension,
            description=payload.description or ""
        )
        collection_created = True

        # 4. Create the DB record.
        new_kb = KnowledgeBase(
            name=payload.name,
            collection_name=payload.collection_name,
            description=payload.description,
            status=payload.status or "normal"
        )
        db.add(new_kb)
        await db.commit()
        committed = True
        await db.refresh(new_kb)
        return ResponseSchema(code=0, message="创建成功", data=KnowledgeBaseResponse.model_validate(new_kb))

    except Exception as e:
        await db.rollback()
        if collection_created and not committed:
            # Compensate: without a DB record the collection would be orphaned,
            # and the listing sync would later re-import it as a phantom KB.
            try:
                milvus_service.drop_collection(payload.collection_name)
            except Exception:
                logging.getLogger(__name__).exception(
                    "Dropping orphaned Milvus collection %s failed", payload.collection_name
                )
        return ResponseSchema(code=500, message=f"创建失败: {str(e)}")
@router.put("/{id}", response_model=ResponseSchema)
async def update_knowledge_base(
    id: str = Path(..., description="知识库ID"),  # name fixed by the "/{id}" route
    payload: KnowledgeBaseUpdate = ...,  # noqa: B008
    db: AsyncSession = Depends(get_db)
):
    """Update the name, description and/or status of a live knowledge base.

    ``name`` and ``status`` are applied only when non-empty. ``description``
    is applied whenever it is not None, so an empty string clears it.
    Returns code 404 if no non-deleted record has this id, 500 on failure.
    """
    result = await db.execute(select(KnowledgeBase).where(
        KnowledgeBase.id == id,
        KnowledgeBase.is_deleted == False,  # noqa: E712
    ))
    kb = result.scalars().first()

    if not kb:
        return ResponseSchema(code=404, message="知识库不存在")
    try:
        if payload.name:
            kb.name = payload.name
        # `is not None` instead of truthiness so "" can clear the description.
        # NOTE(review): assumes KnowledgeBaseUpdate.description defaults to None.
        if payload.description is not None:
            kb.description = payload.description
            # TODO: also update the Milvus collection description once
            # milvus_service offers update_collection_description().
        if payload.status:
            kb.status = payload.status

        await db.commit()
        await db.refresh(kb)

        return ResponseSchema(code=0, message="更新成功", data=KnowledgeBaseResponse.model_validate(kb))
    except Exception as e:
        await db.rollback()
        return ResponseSchema(code=500, message=f"更新失败: {str(e)}")
@router.patch("/{id}/status", response_model=ResponseSchema)
async def update_knowledge_base_status(
    id: str = Path(..., description="知识库ID"),  # name fixed by the "/{id}/status" route
    status: str = Query(..., description="状态: normal/test/disabled"),
    db: AsyncSession = Depends(get_db)
):
    """Change a live knowledge base's status (enable/disable).

    ``status`` must be one of normal / test / disabled, otherwise code 400.
    "normal" also calls Milvus ``load_collection`` and "disabled" calls
    ``release_collection``; "test" touches only the DB. Returns code 404 if
    no non-deleted record has this id, and 500 on failure (DB rolled back).
    """
    # Reject values outside the documented set before touching DB or Milvus.
    if status not in ("normal", "test", "disabled"):
        return ResponseSchema(code=400, message=f"无效的状态: {status}")

    result = await db.execute(select(KnowledgeBase).where(
        KnowledgeBase.id == id,
        KnowledgeBase.is_deleted == False,  # noqa: E712
    ))
    kb = result.scalars().first()

    if not kb:
        return ResponseSchema(code=404, message="知识库不存在")

    try:
        kb.status = status

        # Mirror the status into Milvus before committing, so a Milvus failure
        # leaves the DB row unchanged (rolled back below).
        if status == "normal":
            milvus_service.client.load_collection(kb.collection_name)
        elif status == "disabled":
            milvus_service.client.release_collection(kb.collection_name)

        await db.commit()

        return ResponseSchema(code=0, message=f"状态已更新为 {status}")
    except Exception as e:
        await db.rollback()
        return ResponseSchema(code=500, message=f"状态更新失败: {str(e)}")
@router.delete("/{id}", response_model=ResponseSchema)
async def delete_knowledge_base(
    id: str = Path(..., description="知识库ID"),  # name fixed by the "/{id}" route
    db: AsyncSession = Depends(get_db)
):
    """Drop the knowledge base's Milvus collection and soft-delete its DB record.

    Only live (non-deleted) records are eligible; others yield code 404.
    Returns code 500 (DB rolled back) if dropping or committing fails.
    """
    # Filter out soft-deleted rows: create_knowledge_base only rejects names
    # used by *live* records, so a deleted record may share its collection_name
    # with a newer KB, whose collection must not be dropped through a stale id.
    result = await db.execute(select(KnowledgeBase).where(
        KnowledgeBase.id == id,
        KnowledgeBase.is_deleted == False,  # noqa: E712
    ))
    kb = result.scalars().first()

    if not kb:
        return ResponseSchema(code=404, message="知识库不存在")
    try:
        # 1. Drop the Milvus collection first; if this fails nothing has changed.
        milvus_service.drop_collection(kb.collection_name)

        # 2. Soft-delete the DB record.
        kb.is_deleted = True
        await db.commit()

        return ResponseSchema(code=0, message="删除成功")
    except Exception as e:
        await db.rollback()
        return ResponseSchema(code=500, message=f"删除失败: {str(e)}")
|