Sfoglia il codice sorgente

feat:完成知识库相关接口

ZengChao 1 mese fa
parent
commit
90b0215426

+ 1 - 1
requirements/base.txt

@@ -44,4 +44,4 @@ flower==2.0.1
 python-dotenv==1.0.0
 
 # 向量数据库
-pymilvus==2.3.4
+pymilvus==2.6.3

+ 0 - 0
src/__init__.py


+ 3 - 1
src/app/api/v1/api_router.py

@@ -4,9 +4,11 @@ API路由聚合模块
 from fastapi import APIRouter
 from .auth.router import router as auth_router
 from .oauth.router import router as oauth_router
+from .document.router import router as document_router
 
 api_router = APIRouter()
 
 # 包含各个模块的路由
 api_router.include_router(auth_router, prefix="/auth")
-api_router.include_router(oauth_router, prefix="/oauth")
+api_router.include_router(oauth_router, prefix="/oauth")
+api_router.include_router(document_router, prefix="/document")

+ 1 - 0
src/app/api/v1/document/__init__.py

@@ -0,0 +1 @@
+"""Document (knowledge base) API module."""

+ 99 - 0
src/app/api/v1/document/knowledge_base.py

@@ -0,0 +1,99 @@
+"""
+知识库相关接口:提供Milvus集合信息给前端(分页)
+"""
+from math import ceil
+from fastapi import APIRouter, Query, Path
+from app.schemas.base import PaginatedResponseSchema, PaginationSchema, ResponseSchema
+from app.services.milvus_service import milvus_service
+from app.models.knowledge_base import DescriptionUpdate
+
+router = APIRouter()
+
+
+@router.get("/collections", response_model=PaginatedResponseSchema)
+async def get_collections(
+    page: int = Query(1, ge=1, description="页码"),
+    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
+):
+    """获取Milvus所有集合的详细信息(分页)"""
+    # The docstring above is kept verbatim: FastAPI publishes an endpoint's
+    # docstring as the operation description in the generated OpenAPI schema.
+    #
+    # Fetches every collection summary from the service and slices out the
+    # requested page in memory. Failures are reported in the response body
+    # (code 500001, exception text in ``data``) instead of being raised.
+    try:
+        collections = milvus_service.get_collection_details()
+        total = len(collections)
+        offset = (page - 1) * page_size
+        # NOTE(review): this success payload uses code=200 while the other
+        # endpoints in this module use code=0 -- confirm which one the
+        # front end expects.
+        return PaginatedResponseSchema(
+            code=200,
+            message="获取集合信息成功",
+            data=collections[offset:offset + page_size],
+            meta=PaginationSchema(
+                page=page,
+                page_size=page_size,
+                total=total,
+                total_pages=ceil(total / page_size) if page_size else 0,
+            ),
+        )
+    except Exception as exc:
+        return PaginatedResponseSchema(
+            code=500001,
+            message="获取集合信息失败",
+            data=str(exc),
+            meta=None,
+        )
+
+
+@router.post("/collections/{name}/state", response_model=ResponseSchema)
+async def set_collection_state(
+    name: str = Path(..., description="集合名称"),
+    action: str = Query(..., description="操作:load 或 release"),
+):
+    """更改指定集合的加载状态(load/release)"""
+    # The docstring above is kept verbatim: FastAPI publishes an endpoint's
+    # docstring as the operation description in the generated OpenAPI schema.
+    #
+    # Delegates to the service and maps the outcome onto a response body:
+    # code 0 on success, 400001 when the service raises ValueError,
+    # 500001 for any other exception.
+    try:
+        outcome = milvus_service.set_collection_state(name=name, action=action)
+        response = ResponseSchema(code=0, message="集合状态更新成功", data=outcome)
+    except ValueError as invalid:
+        response = ResponseSchema(code=400001, message=str(invalid), data=None)
+    except Exception as failure:
+        response = ResponseSchema(code=500001, message="集合状态更新失败", data=str(failure))
+    return response
+
+
+@router.get("/collections/{name}", response_model=ResponseSchema)
+async def get_collection_detail(name: str = Path(..., description="集合名称")):
+    """获取指定集合的详细信息"""
+    # The docstring above is kept verbatim: FastAPI publishes an endpoint's
+    # docstring as the operation description in the generated OpenAPI schema.
+    #
+    # Returns the service's detail dict with code 0; any exception is turned
+    # into a code 500001 response carrying the exception text in ``data``.
+    try:
+        collection_info = milvus_service.get_collection_detail(name=name)
+        response = ResponseSchema(code=0, message="获取集合详情成功", data=collection_info)
+    except Exception as failure:
+        response = ResponseSchema(code=500001, message="获取集合详情失败", data=str(failure))
+    return response
+
+
+@router.delete("/collections/{name}", response_model=ResponseSchema)
+async def delete_collection(name: str = Path(..., description="集合名称")):
+    """当集合内容为空时删除集合"""
+    # The docstring above is kept verbatim: FastAPI publishes an endpoint's
+    # docstring as the operation description in the generated OpenAPI schema.
+    #
+    # Asks the service to drop the collection only if it is empty. A
+    # ValueError from the service becomes a code 400001 response; any other
+    # exception becomes code 500001.
+    try:
+        outcome = milvus_service.delete_collection_if_empty(name=name)
+        response = ResponseSchema(code=0, message="删除集合成功", data=outcome)
+    except ValueError as refused:
+        response = ResponseSchema(code=400001, message=str(refused), data=None)
+    except Exception as failure:
+        response = ResponseSchema(code=500001, message="删除集合失败", data=str(failure))
+    return response
+
+
+@router.put("/collections/{name}/description", response_model=ResponseSchema)
+async def update_collection_description(
+    name: str = Path(..., description="集合名称"),
+    payload: DescriptionUpdate = ...,  # noqa: B008 fastapi dependency style
+):
+    """修改指定集合的描述"""
+    # The docstring above is kept verbatim: FastAPI publishes an endpoint's
+    # docstring as the operation description in the generated OpenAPI schema.
+    #
+    # Forwards the new description to the service. NotImplementedError is
+    # reported as code 400002; any other exception as code 500001.
+    try:
+        refreshed = milvus_service.update_collection_description(
+            name=name,
+            description=payload.description,
+        )
+        response = ResponseSchema(code=0, message="更新集合描述成功", data=refreshed)
+    except NotImplementedError as unsupported:
+        response = ResponseSchema(code=400002, message=str(unsupported), data=None)
+    except Exception as failure:
+        response = ResponseSchema(code=500001, message="更新集合描述失败", data=str(failure))
+    return response
+

+ 10 - 0
src/app/api/v1/document/router.py

@@ -0,0 +1,10 @@
+"""
+Document routing module: aggregates the knowledge base endpoints.
+"""
+from fastapi import APIRouter
+from .knowledge_base import router as knowledge_base
+
+router = APIRouter()
+
+# Include the knowledge base endpoints
+router.include_router(knowledge_base, tags=["知识库"])

+ 11 - 0
src/app/models/knowledge_base.py

@@ -0,0 +1,11 @@
+"""
+Knowledge base related request/response models.
+"""
+from pydantic import BaseModel
+from typing import Optional, List, Dict, Any
+
+
+class DescriptionUpdate(BaseModel):
+    """Payload for updating a collection description."""
+    # New description text for the collection; no default, so it is required.
+    description: str
+

+ 161 - 5
src/app/services/milvus_service.py

@@ -7,7 +7,7 @@ import logging
 from typing import List, Dict, Any
 from datetime import datetime
 
-from src.app.config.milvus import get_milvus_manager
+from app.config.milvus import get_milvus_manager
 
 logger = logging.getLogger(__name__)
 
@@ -18,10 +18,7 @@ class MilvusService:
 
     def get_collection_details(self) -> List[Dict[str, Any]]:
         """
-        获取所有 Collections 详细信息(按你的要求):
-        - 时间转换直接写在这里:physical_ms = ts_int >> 18
-        - load state / row_count 不保底:拿不到就让异常抛出
-        - 直接调用 MilvusClient 原生方法(不再二次封装)
+        获取所有 Collections 详细信息
         """
         details: List[Dict[str, Any]] = []
 
@@ -66,6 +63,165 @@ class MilvusService:
         logger.info(f"成功获取Collections详细信息,共{len(details)}个")
         return details
 
+    def set_collection_state(self, name: str, action: str) -> Dict[str, Any]:
+        """Load or release a collection and report its resulting load state.
+
+        Args:
+            name: Collection name.
+            action: ``"load"`` or ``"release"``. Surrounding whitespace and
+                letter case are ignored; ``None`` or an empty value is rejected.
+
+        Returns:
+            A dict with the keys ``name``, ``state`` and ``action`` (the
+            normalized action), e.g. ``{"name": name, "state": "Loaded", ...}``.
+
+        Raises:
+            ValueError: If ``action`` is neither ``load`` nor ``release``.
+        """
+        action_norm = (action or "").strip().lower()
+        if action_norm != "load" and action_norm != "release":
+            raise ValueError("action 必须为 'load' 或 'release'")
+
+        # Issue the client call that matches the requested transition.
+        if action_norm == "release":
+            self.client.release_collection(collection_name=name)
+        else:
+            self.client.load_collection(collection_name=name)
+
+        # Read the state back so the caller sees what the server now reports.
+        reported = self.client.get_load_state(collection_name=name)
+        if isinstance(reported, dict):
+            state = reported.get("state")
+        else:
+            state = reported
+
+        logger.info(f"集合 {name} 状态更新为 {state} (action={action_norm})")
+        return {"name": name, "state": state, "action": action_norm}
+
+    def delete_collection_if_empty(self, name: str) -> Dict[str, Any]:
+        """Drop a collection, but only when it holds no rows.
+
+        Args:
+            name: Collection name.
+
+        Returns:
+            ``{"name": name, "deleted": True}`` once the collection is dropped.
+
+        Raises:
+            ValueError: If the row count cannot be read from the collection
+                stats, or if the collection still contains rows.
+        """
+        stats = self.client.get_collection_stats(collection_name=name)
+
+        # Refuse to delete when the size is unknown: a missing count is
+        # treated as "possibly non-empty".
+        if not isinstance(stats, dict) or stats.get("row_count") is None:
+            raise ValueError("无法获取集合行数,禁止删除")
+        if int(stats["row_count"]) > 0:
+            raise ValueError("集合内容不为空,不能删除")
+
+        self.client.drop_collection(collection_name=name)
+        logger.info(f"集合 {name} 已删除")
+        return {"name": name, "deleted": True}
+
+    def get_collection_detail(self, name: str) -> Dict[str, Any]:
+        """Return the full description of a single collection.
+
+        Combines the results of ``describe_collection``,
+        ``get_collection_stats`` and ``get_load_state`` into one dict holding
+        the description, load status, row count, formatted creation/update
+        times, field schema, index summaries and the remaining collection
+        settings.
+
+        Args:
+            name: Collection name.
+
+        Returns:
+            A dict describing the collection; timestamps are
+            ``"%Y-%m-%d %H:%M:%S"`` strings, or ``None`` when absent.
+        """
+        desc = self.client.describe_collection(collection_name=name)
+        stats = self.client.get_collection_stats(collection_name=name)
+        load_state = self.client.get_load_state(collection_name=name)
+
+        def format_timestamp(key):
+            # Shifting the raw value right by 18 bits yields milliseconds,
+            # which are then rendered with datetime.fromtimestamp (local time).
+            raw = desc.get(key)
+            if raw is None:
+                return None
+            millis = int(raw) >> 18
+            return datetime.fromtimestamp(millis / 1000).strftime("%Y-%m-%d %H:%M:%S")
+
+        def summarize_field(field):
+            # Core attributes first, then the optional type parameters.
+            summary = {
+                "name": field.get("name"),
+                "type": str(field.get("type")),
+                "description": field.get("description", ""),
+                "is_primary": field.get("is_primary", False),
+                "auto_id": field.get("auto_id"),
+            }
+            if "params" in field:
+                params = field["params"]
+                if "dim" in params:
+                    # Vector dimension.
+                    summary["dim"] = params["dim"]
+                if "max_length" in params:
+                    # Maximum string length.
+                    summary["max_length"] = params["max_length"]
+                summary["params"] = params
+            return summary
+
+        def summarize_index(idx):
+            return {
+                key: idx.get(key)
+                for key in ("field_name", "index_name", "index_type", "metric_type", "params")
+            }
+
+        created_time = format_timestamp("created_timestamp")
+        updated_time = format_timestamp("update_timestamp")
+
+        entity_count = stats.get("row_count", 0)
+        if isinstance(load_state, dict):
+            status = load_state.get("state")
+        else:
+            status = load_state
+
+        fields = [summarize_field(f) for f in desc["fields"]] if "fields" in desc else []
+        indices = [summarize_index(i) for i in desc["indexes"]] if "indexes" in desc else []
+
+        logger.info(f"成功获取集合 {name} 的详细信息")
+        return {
+            "name": name,
+            "description": desc.get("description", ""),
+            "status": status,
+            "entity_count": entity_count,
+            "created_time": created_time,
+            "updated_time": updated_time,
+            "fields": fields,
+            "enable_dynamic_field": desc.get("enable_dynamic_field", False),
+            "consistency_level": desc.get("consistency_level"),
+            "num_shards": desc.get("num_shards"),
+            "num_partitions": desc.get("num_partitions"),
+            "indices": indices,
+            "properties": desc.get("properties"),
+            "aliases": desc.get("aliases", []),
+        }
+
+    
+    def update_collection_description(self, name: str, description: str) -> Dict[str, Any]:
+        """Update a collection's description and return its refreshed summary.
+
+        The description is written through ``alter_collection_properties``
+        under the ``collection.description`` property; the collection is then
+        re-read so the returned values reflect what the server reports.
+
+        Args:
+            name: Collection name.
+            description: New description; ``None`` or empty is stored as ``""``.
+
+        Returns:
+            A dict with ``name``, ``status``, ``entity_count``,
+            ``description``, ``created_time`` and ``updated_time``.
+        """
+        description = description or ""
+
+        # 1. Update the description (the only mutation performed here).
+        self.client.alter_collection_properties(
+            collection_name=name,
+            properties={"collection.description": description},
+        )
+
+        # 2. Re-read the collection information.
+        desc = self.client.describe_collection(collection_name=name)
+        # Was a stray print(); routed through the module logger so service
+        # output no longer goes to stdout.
+        logger.debug("describe_collection(%s) -> %s", name, desc)
+        stats = self.client.get_collection_stats(collection_name=name)
+        load_state = self.client.get_load_state(collection_name=name)
+
+        # 3. Timestamp conversion (Milvus TSO -> physical time).
+        def ts_to_str(ts):
+            if ts is None:
+                return None
+            physical_ms = int(ts) >> 18
+            return datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
+
+        created_time = ts_to_str(desc.get("created_timestamp"))
+        updated_time = ts_to_str(desc.get("update_timestamp"))
+
+        entity_count = stats.get("row_count") if isinstance(stats, dict) else None
+        status = load_state.get("state") if isinstance(load_state, dict) else load_state
+
+        # Log success like the other mutating methods of this service do.
+        logger.info(f"成功更新集合 {name} 的描述")
+        return {
+            "name": name,
+            "status": status,
+            "entity_count": entity_count,
+            "description": desc.get("description", ""),
+            "created_time": created_time,
+            "updated_time": updated_time,
+        }
+
 
 # 可选:单例
 milvus_service = MilvusService()