knowledge_base.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. """
  2. 知识库相关接口
  3. """
  4. from math import ceil
  5. from typing import List
  6. from fastapi import APIRouter, Query, Path, Depends, HTTPException
  7. from sqlalchemy.ext.asyncio import AsyncSession
  8. from sqlalchemy import select, func, or_
  9. from datetime import datetime
  10. from app.config.database import get_db
  11. from app.sample.models.knowledge_base import KnowledgeBase
  12. from app.schemas.base import PaginatedResponseSchema, PaginationSchema, ResponseSchema
  13. from app.sample.schemas.knowledge_base import (
  14. KnowledgeBaseCreate,
  15. KnowledgeBaseUpdate,
  16. KnowledgeBaseResponse
  17. )
  18. from app.services.milvus_service import milvus_service
# Router on which the knowledge-base endpoints defined in this module are
# registered via the @router.get/post/put/patch/delete decorators.
router = APIRouter()
  20. @router.get("", response_model=PaginatedResponseSchema)
  21. async def get_knowledge_bases(
  22. page: int = Query(1, ge=1, description="页码"),
  23. page_size: int = Query(10, ge=1, le=100, description="每页数量"),
  24. keyword: str = Query(None, description="搜索关键词"),
  25. status: str = Query(None, description="状态筛选"),
  26. db: AsyncSession = Depends(get_db)
  27. ):
  28. """获取知识库列表"""
  29. # --- 同步 Milvus 数据 (新增逻辑) ---
  30. try:
  31. # 1. 获取 Milvus 所有集合
  32. milvus_names = milvus_service.client.list_collections()
  33. # 2. 获取 DB 中已有的集合
  34. result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.is_deleted == 0))
  35. existing_kbs = result.scalars().all()
  36. existing_map = {kb.collection_name: kb for kb in existing_kbs}
  37. # 3. 同步逻辑
  38. has_changes = False
  39. import uuid
  40. for m_name in milvus_names:
  41. # 获取统计信息
  42. try:
  43. stats = milvus_service.client.get_collection_stats(m_name)
  44. row_count = int(stats.get("row_count", 0))
  45. except Exception:
  46. row_count = 0
  47. if m_name not in existing_map:
  48. # 新增
  49. new_kb = KnowledgeBase(
  50. id=str(uuid.uuid4()),
  51. name=m_name,
  52. collection_name=m_name,
  53. description="Synced from Milvus",
  54. status="normal",
  55. document_count=row_count,
  56. created_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
  57. updated_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  58. )
  59. db.add(new_kb)
  60. has_changes = True
  61. else:
  62. # 更新统计
  63. kb = existing_map[m_name]
  64. if kb.document_count != row_count:
  65. kb.document_count = row_count
  66. has_changes = True
  67. if has_changes:
  68. await db.commit()
  69. except Exception as e:
  70. print(f"Sync Milvus collections failed: {e}")
  71. # ----------------------
  72. query = select(KnowledgeBase).where(KnowledgeBase.is_deleted == False)
  73. if keyword:
  74. query = query.where(or_(
  75. KnowledgeBase.name.like(f"%{keyword}%"),
  76. KnowledgeBase.collection_name.like(f"%{keyword}%")
  77. ))
  78. if status:
  79. query = query.where(KnowledgeBase.status == status)
  80. # 计算总数
  81. count_query = select(func.count()).select_from(query.subquery())
  82. total = await db.scalar(count_query) or 0 # 确保 total 不为 None
  83. # 分页查询
  84. # 使用 created_time 而不是 created_at
  85. query = query.order_by(KnowledgeBase.created_time.desc()).offset((page - 1) * page_size).limit(page_size)
  86. result = await db.execute(query)
  87. items = result.scalars().all()
  88. # 设置 is_synced 属性 (非数据库字段,用于前端展示)
  89. for item in items:
  90. item.is_synced = item.collection_name in milvus_names
  91. total_pages = ceil(total / page_size) if page_size else 0
  92. meta = PaginationSchema(
  93. page=page,
  94. page_size=page_size,
  95. total=total,
  96. total_pages=total_pages,
  97. )
  98. return PaginatedResponseSchema(
  99. code=0,
  100. message="获取知识库列表成功",
  101. data=[KnowledgeBaseResponse.model_validate(item) for item in items],
  102. meta=meta,
  103. )
  104. @router.post("", response_model=ResponseSchema)
  105. async def create_knowledge_base(
  106. payload: KnowledgeBaseCreate,
  107. db: AsyncSession = Depends(get_db)
  108. ):
  109. """创建新知识库"""
  110. # 1. 检查 DB 是否已存在
  111. exists = await db.execute(select(KnowledgeBase).where(
  112. KnowledgeBase.collection_name == payload.collection_name,
  113. KnowledgeBase.is_deleted == False
  114. ))
  115. if exists.scalars().first():
  116. return ResponseSchema(code=400, message="知识库集合名称已存在")
  117. # 2. 检查 Milvus 是否已存在
  118. if milvus_service.has_collection(payload.collection_name):
  119. return ResponseSchema(code=400, message="Milvus集合已存在,请使用其他名称")
  120. try:
  121. # 3. 创建 Milvus 集合
  122. milvus_service.create_collection(
  123. name=payload.collection_name,
  124. dimension=payload.dimension,
  125. description=payload.description or ""
  126. )
  127. # 4. 创建 DB 记录
  128. new_kb = KnowledgeBase(
  129. name=payload.name,
  130. collection_name=payload.collection_name,
  131. description=payload.description,
  132. status=payload.status or "normal"
  133. )
  134. db.add(new_kb)
  135. await db.commit()
  136. await db.refresh(new_kb)
  137. return ResponseSchema(code=0, message="创建成功", data=KnowledgeBaseResponse.model_validate(new_kb))
  138. except Exception as e:
  139. await db.rollback()
  140. return ResponseSchema(code=500, message=f"创建失败: {str(e)}")
  141. @router.put("/{id}", response_model=ResponseSchema)
  142. async def update_knowledge_base(
  143. id: str = Path(..., description="知识库ID"),
  144. payload: KnowledgeBaseUpdate = ..., # noqa: B008
  145. db: AsyncSession = Depends(get_db)
  146. ):
  147. """更新知识库信息"""
  148. result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == id, KnowledgeBase.is_deleted == False))
  149. kb = result.scalars().first()
  150. if not kb:
  151. return ResponseSchema(code=404, message="知识库不存在")
  152. try:
  153. if payload.name:
  154. kb.name = payload.name
  155. if payload.description:
  156. kb.description = payload.description
  157. # 同步更新 Milvus 描述
  158. # 注意:milvus_service 需要实现 update_collection_description
  159. # milvus_service.update_collection_description(kb.collection_name, payload.description)
  160. if payload.status:
  161. kb.status = payload.status
  162. await db.commit()
  163. await db.refresh(kb)
  164. return ResponseSchema(code=0, message="更新成功", data=KnowledgeBaseResponse.model_validate(kb))
  165. except Exception as e:
  166. await db.rollback()
  167. return ResponseSchema(code=500, message=f"更新失败: {str(e)}")
  168. @router.patch("/{id}/status", response_model=ResponseSchema)
  169. async def update_knowledge_base_status(
  170. id: str = Path(..., description="知识库ID"),
  171. status: str = Query(..., description="状态: normal/test/disabled"),
  172. db: AsyncSession = Depends(get_db)
  173. ):
  174. """更新知识库状态(启用/禁用)"""
  175. result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == id, KnowledgeBase.is_deleted == False))
  176. kb = result.scalars().first()
  177. if not kb:
  178. return ResponseSchema(code=404, message="知识库不存在")
  179. try:
  180. kb.status = status
  181. # 可选:同步操作 Milvus Load/Release
  182. if status == "normal":
  183. milvus_service.client.load_collection(kb.collection_name)
  184. elif status == "disabled":
  185. milvus_service.client.release_collection(kb.collection_name)
  186. await db.commit()
  187. return ResponseSchema(code=0, message=f"状态已更新为 {status}")
  188. except Exception as e:
  189. await db.rollback()
  190. return ResponseSchema(code=500, message=f"状态更新失败: {str(e)}")
  191. @router.delete("/{id}", response_model=ResponseSchema)
  192. async def delete_knowledge_base(
  193. id: str = Path(..., description="知识库ID"),
  194. db: AsyncSession = Depends(get_db)
  195. ):
  196. """删除知识库"""
  197. result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == id))
  198. kb = result.scalars().first()
  199. if not kb:
  200. return ResponseSchema(code=404, message="知识库不存在")
  201. try:
  202. # 1. 删除 Milvus 集合 (强制删除)
  203. milvus_service.drop_collection(kb.collection_name)
  204. # 2. 软删除 DB 记录
  205. kb.is_deleted = True
  206. await db.commit()
  207. return ResponseSchema(code=0, message="删除成功")
  208. except Exception as e:
  209. await db.rollback()
  210. return ResponseSchema(code=500, message=f"删除失败: {str(e)}")