ソースを参照

Merge branch 'dev' of http://47.109.151.80:15030/CRBC-MaaS-Platform-Project/LQAdminPlatform into dev

lingmin_package@163.com 1 ヶ月 前
コミット
c88e21eaa2

+ 245 - 0
src/app/api/v1/document/knowledge_base.py

@@ -0,0 +1,245 @@
+"""
+知识库相关接口
+"""
+from math import ceil
+from typing import List
+from fastapi import APIRouter, Query, Path, Depends, HTTPException
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, func, or_
+from datetime import datetime
+
+from app.config.database import get_db
+from app.sample.models.knowledge_base import KnowledgeBase
+from app.schemas.base import PaginatedResponseSchema, PaginationSchema, ResponseSchema
+from app.sample.schemas.knowledge_base import (
+    KnowledgeBaseCreate, 
+    KnowledgeBaseUpdate, 
+    KnowledgeBaseResponse
+)
+from app.services.milvus_service import milvus_service
+
+router = APIRouter()
+
+async def _sync_milvus_collections(db: AsyncSession) -> None:
+    """Mirror the current Milvus collections into the knowledge_base table.
+
+    A collection without a live (non-deleted) row gets a new row; an existing
+    row has its ``document_count`` refreshed from the collection's
+    ``row_count``. Commits only when something changed.
+
+    Raises:
+        Exception: whatever Milvus or the database raises; the caller decides
+            how to recover.
+    """
+    import uuid
+
+    milvus_names = milvus_service.client.list_collections()
+
+    result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.is_deleted == 0))
+    existing_map = {kb.collection_name: kb for kb in result.scalars().all()}
+
+    has_changes = False
+    for m_name in milvus_names:
+        try:
+            stats = milvus_service.client.get_collection_stats(m_name)
+            row_count = int(stats.get("row_count", 0))
+        except Exception:
+            # Statistics are best-effort: an unreadable collection counts as 0.
+            row_count = 0
+
+        kb = existing_map.get(m_name)
+        if kb is None:
+            # One timestamp for both columns so they cannot differ by a second.
+            now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            db.add(KnowledgeBase(
+                id=str(uuid.uuid4()),
+                name=m_name,
+                collection_name=m_name,
+                description="Synced from Milvus",
+                status="normal",
+                document_count=row_count,
+                created_at=now,
+                updated_at=now,
+            ))
+            has_changes = True
+        elif kb.document_count != row_count:
+            kb.document_count = row_count
+            has_changes = True
+
+    if has_changes:
+        await db.commit()
+
+
+@router.get("", response_model=PaginatedResponseSchema)
+async def get_knowledge_bases(
+    page: int = Query(1, ge=1, description="页码"),
+    page_size: int = Query(10, ge=1, le=100, description="每页数量"),
+    keyword: str = Query(None, description="搜索关键词"),
+    status: str = Query(None, description="状态筛选"),
+    db: AsyncSession = Depends(get_db)
+):
+    """List knowledge bases with optional keyword/status filters and pagination.
+
+    Before querying, the table is synchronised with Milvus on a best-effort
+    basis: a failed sync is reported and skipped, it never fails the listing.
+
+    Args:
+        page: 1-based page number.
+        page_size: Rows per page (1-100).
+        keyword: Substring matched against name and collection_name.
+        status: Exact status to filter on.
+        db: Request-scoped async database session.
+
+    Returns:
+        PaginatedResponseSchema with the page of knowledge bases and paging meta.
+    """
+    try:
+        await _sync_milvus_collections(db)
+    except Exception as e:
+        print(f"Sync Milvus collections failed: {e}")
+        # A failed flush/commit can leave the session refusing further
+        # statements until it is rolled back; this also discards any
+        # half-applied sync changes so the listing below starts clean.
+        await db.rollback()
+
+    # is_deleted is an Integer column, so compare against 0 like the sync query does.
+    query = select(KnowledgeBase).where(KnowledgeBase.is_deleted == 0)
+
+    if keyword:
+        query = query.where(or_(
+            KnowledgeBase.name.like(f"%{keyword}%"),
+            KnowledgeBase.collection_name.like(f"%{keyword}%")
+        ))
+
+    if status:
+        query = query.where(KnowledgeBase.status == status)
+
+    # Count the filtered rows before paging is applied.
+    count_query = select(func.count()).select_from(query.subquery())
+    total = await db.scalar(count_query) or 0
+
+    # Newest first, then cut out the requested page.
+    query = query.order_by(KnowledgeBase.created_at.desc()).offset((page - 1) * page_size).limit(page_size)
+    result = await db.execute(query)
+    items = result.scalars().all()
+
+    total_pages = ceil(total / page_size) if page_size else 0
+
+    meta = PaginationSchema(
+        page=page,
+        page_size=page_size,
+        total=total,
+        total_pages=total_pages,
+    )
+
+    return PaginatedResponseSchema(
+        code=0,
+        message="获取知识库列表成功",
+        data=[KnowledgeBaseResponse.model_validate(item) for item in items],
+        meta=meta,
+    )
+
+@router.post("", response_model=ResponseSchema)
+async def create_knowledge_base(
+    payload: KnowledgeBaseCreate,
+    db: AsyncSession = Depends(get_db)
+):
+    """Create a knowledge base: a Milvus collection plus its database row.
+
+    Args:
+        payload: Name, collection name, description, status and vector dimension.
+        db: Request-scoped async database session.
+
+    Returns:
+        ResponseSchema with code 0 and the new record, code 400 when the
+        collection name is already taken (in the DB or in Milvus), or code 500
+        when creation fails.
+    """
+    import uuid
+
+    # 1. Reject names already used by a live DB record.
+    exists = await db.execute(select(KnowledgeBase).where(
+        KnowledgeBase.collection_name == payload.collection_name,
+        KnowledgeBase.is_deleted == 0
+    ))
+    if exists.scalars().first():
+        return ResponseSchema(code=400, message="知识库集合名称已存在")
+
+    # 2. Reject names already used by a Milvus collection.
+    if milvus_service.has_collection(payload.collection_name):
+        return ResponseSchema(code=400, message="Milvus集合已存在,请使用其他名称")
+
+    # True only between "collection created" and "row committed": the window
+    # in which a failure would otherwise leave an orphaned Milvus collection.
+    needs_cleanup = False
+    try:
+        # 3. Create the Milvus collection.
+        milvus_service.create_collection(
+            name=payload.collection_name,
+            dimension=payload.dimension,
+            description=payload.description or ""
+        )
+        needs_cleanup = True
+
+        # 4. Create the DB row. The primary key has no column default, so it
+        #    must be generated here; timestamps use the stored string format.
+        now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        new_kb = KnowledgeBase(
+            id=str(uuid.uuid4()),
+            name=payload.name,
+            collection_name=payload.collection_name,
+            description=payload.description,
+            status=payload.status or "normal",
+            created_at=now,
+            updated_at=now
+        )
+        db.add(new_kb)
+        await db.commit()
+        needs_cleanup = False
+        await db.refresh(new_kb)
+
+        return ResponseSchema(code=0, message="创建成功", data=KnowledgeBaseResponse.model_validate(new_kb))
+
+    except Exception as e:
+        await db.rollback()
+        if needs_cleanup:
+            # Best-effort compensation; the original error is what gets reported.
+            try:
+                milvus_service.drop_collection(payload.collection_name)
+            except Exception as cleanup_error:
+                print(f"Cleanup of Milvus collection failed: {cleanup_error}")
+        return ResponseSchema(code=500, message=f"创建失败: {str(e)}")
+
+@router.put("/{id}", response_model=ResponseSchema)
+async def update_knowledge_base(
+    id: str = Path(..., description="知识库ID"),
+    payload: KnowledgeBaseUpdate = ...,  # noqa: B008
+    db: AsyncSession = Depends(get_db)
+):
+    """Update the name, description and/or status of a live knowledge base.
+
+    Only the database row is changed; the Milvus collection description is
+    not synchronised by this handler.
+
+    Args:
+        id: Knowledge base ID.
+        payload: Fields to change; omitted (None) fields are left untouched.
+        db: Request-scoped async database session.
+
+    Returns:
+        ResponseSchema with code 0 and the updated record, code 404 when no
+        live record has this ID, or code 500 when the update fails.
+    """
+    result = await db.execute(select(KnowledgeBase).where(
+        KnowledgeBase.id == id,
+        KnowledgeBase.is_deleted == 0
+    ))
+    kb = result.scalars().first()
+
+    if not kb:
+        return ResponseSchema(code=404, message="知识库不存在")
+
+    try:
+        if payload.name:
+            kb.name = payload.name
+        # "is not None" rather than truthiness: an empty string is a valid
+        # request to clear the description.
+        if payload.description is not None:
+            kb.description = payload.description
+        if payload.status:
+            kb.status = payload.status
+
+        kb.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        await db.commit()
+        await db.refresh(kb)
+
+        return ResponseSchema(code=0, message="更新成功", data=KnowledgeBaseResponse.model_validate(kb))
+    except Exception as e:
+        await db.rollback()
+        return ResponseSchema(code=500, message=f"更新失败: {str(e)}")
+
+@router.patch("/{id}/status", response_model=ResponseSchema)
+async def update_knowledge_base_status(
+    id: str = Path(..., description="知识库ID"),
+    status: str = Query(..., description="状态: normal/test/disabled"),
+    db: AsyncSession = Depends(get_db)
+):
+    """Set the status of a live knowledge base and load/release its collection.
+
+    "normal" loads the Milvus collection, "disabled" releases it; any other
+    value only updates the database row.
+
+    Returns:
+        ResponseSchema with code 0 on success, 404 when no live record has
+        this ID, or 500 when Milvus or the database fails.
+    """
+    stmt = select(KnowledgeBase).where(
+        KnowledgeBase.id == id,
+        KnowledgeBase.is_deleted == False,  # noqa: E712 - SQL expression, not a Python comparison
+    )
+    record = (await db.execute(stmt)).scalars().first()
+    if record is None:
+        return ResponseSchema(code=404, message="知识库不存在")
+
+    try:
+        record.status = status
+
+        # Keep the collection's loaded state in line with the new status.
+        collection = record.collection_name
+        if status == "normal":
+            milvus_service.client.load_collection(collection)
+        elif status == "disabled":
+            milvus_service.client.release_collection(collection)
+
+        await db.commit()
+        return ResponseSchema(code=0, message=f"状态已更新为 {status}")
+    except Exception as exc:
+        await db.rollback()
+        return ResponseSchema(code=500, message=f"状态更新失败: {str(exc)}")
+
+@router.delete("/{id}", response_model=ResponseSchema)
+async def delete_knowledge_base(
+    id: str = Path(..., description="知识库ID"),
+    db: AsyncSession = Depends(get_db)
+):
+    """Delete a knowledge base: drop its Milvus collection, soft-delete its row.
+
+    Args:
+        id: Knowledge base ID.
+        db: Request-scoped async database session.
+
+    Returns:
+        ResponseSchema with code 0 on success, 404 when no live record has
+        this ID, or 500 when Milvus or the database fails.
+    """
+    # Only live rows are eligible, matching the update handlers: a row that is
+    # already soft-deleted must not trigger another collection drop.
+    result = await db.execute(select(KnowledgeBase).where(
+        KnowledgeBase.id == id,
+        KnowledgeBase.is_deleted == 0
+    ))
+    kb = result.scalars().first()
+
+    if not kb:
+        return ResponseSchema(code=404, message="知识库不存在")
+
+    try:
+        # 1. Drop the Milvus collection.
+        milvus_service.drop_collection(kb.collection_name)
+
+        # 2. Soft-delete the DB row (is_deleted is an Integer flag).
+        kb.is_deleted = 1
+        kb.updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        await db.commit()
+
+        return ResponseSchema(code=0, message="删除成功")
+    except Exception as e:
+        await db.rollback()
+        return ResponseSchema(code=500, message=f"删除失败: {str(e)}")

+ 3 - 2
src/app/config/config.ini

@@ -16,13 +16,14 @@ RELOAD=True
 # 注意:如果密码包含特殊字符(如@),需要进行URL编码
 # @ 编码为 %40
 # 例如:密码 lq@123 应该写成 lq%40123
-DATABASE_URL=mysql+aiomysql://root:admin@localhost:3306/lq_db
-#DATABASE_URL=mysql+aiomysql://root:lq%40123@192.168.92.61:13306/lq_oauth_db
+#DATABASE_URL=mysql+aiomysql://root:admin@localhost:3306/lq_db
+DATABASE_URL=mysql+aiomysql://root:lq123@192.168.92.61:13306/lq_oauth_db
 DATABASE_ECHO=False
 
 # Milvus向量数据库配置信息
 MILVUS_HOST=192.168.92.61
 MILVUS_PORT=19530
+MILVUS_DB=lq_db
 MILVUS_USER=
 MILVUS_PASSWORD=
 

+ 5 - 11
src/app/sample/models/__init__.py

@@ -1,18 +1,12 @@
 """
 样本中心模块 - 数据模型
 """
-from app.models.knowledge_base import (
-    KnowledgeBase,
-    Document,
-    DocumentChunk,
-    Tag,
-    DocumentTag,
-)
+from app.sample.models.knowledge_base import KnowledgeBase
+
+# Document, DocumentChunk, Tag, DocumentTag 暂时未迁移或不存在
+# from app.sample.models.document import Document
+# ...
 
 __all__ = [
     "KnowledgeBase",
-    "Document",
-    "DocumentChunk",
-    "Tag",
-    "DocumentTag",
 ]

+ 22 - 0
src/app/sample/models/knowledge_base.py

@@ -0,0 +1,22 @@
+"""
+知识库数据库模型
+"""
+from sqlalchemy import Column, String, Integer, Text
+from app.base.async_mysql_connection import Base
+
+class KnowledgeBase(Base):
+    """ORM model for the ``knowledge_base`` table.
+
+    Each row describes one knowledge base and records the name of the Milvus
+    collection associated with it.
+    """
+    __tablename__ = "knowledge_base"
+
+    # Primary key; no column default is declared, so the value must be supplied on insert.
+    id = Column(String(36), primary_key=True, comment="ID")
+    # Display name of the knowledge base.
+    name = Column(String(100), nullable=False, comment="知识库名称")
+    # Milvus collection name; unique across all rows, including soft-deleted ones.
+    collection_name = Column(String(100), nullable=False, unique=True, comment="Milvus集合名称(Table Name)")
+    # Optional free-text description.
+    description = Column(Text, nullable=True, comment="描述")
+    # Status string; the column comment lists normal, test and disabled.
+    status = Column(String(20), default="normal", comment="状态: normal(正常), test(测试), disabled(禁用)")
+    # Document count, defaulting to 0.
+    document_count = Column(Integer, default=0, comment="文档数量")
+    # Soft-delete flag stored as an integer, defaulting to 0.
+    is_deleted = Column(Integer, default=0, comment="是否删除")
+    # Timestamps are stored as strings (up to 32 characters), not DATETIME columns.
+    created_at = Column(String(32), comment="创建时间")
+    updated_at = Column(String(32), comment="更新时间")
+
+    def __repr__(self):
+        # Short debugging representation showing only the name.
+        return f"<KnowledgeBase {self.name}>"

+ 24 - 2
src/app/sample/schemas/__init__.py

@@ -1,6 +1,28 @@
 """
 样本中心模块 - 数据结构
 """
-# 这里将来添加样本中心相关的schema定义
+from app.sample.schemas.knowledge_base import (
+    KnowledgeBaseBase,
+    KnowledgeBaseCreate,
+    KnowledgeBaseUpdate,
+    KnowledgeBaseResponse,
+    DescriptionUpdate
+)
+from app.sample.schemas.sample_schemas import (
+    BatchEnterRequest,
+    BatchDeleteRequest,
+    ConvertRequest,
+    DocumentAdd
+)
 
-__all__ = []
+__all__ = [
+    "KnowledgeBaseBase",
+    "KnowledgeBaseCreate",
+    "KnowledgeBaseUpdate",
+    "KnowledgeBaseResponse",
+    "DescriptionUpdate",
+    "BatchEnterRequest",
+    "BatchDeleteRequest",
+    "ConvertRequest",
+    "DocumentAdd"
+]

+ 42 - 0
src/app/sample/schemas/knowledge_base.py

@@ -0,0 +1,42 @@
+from typing import Optional
+from pydantic import BaseModel, Field
+from app.schemas.base import BaseModelSchema
+
+class KnowledgeBaseBase(BaseModel):
+    """Fields shared by knowledge base request schemas."""
+    # Required display name.
+    name: str = Field(..., description="知识库名称")
+    # Required Milvus collection name.
+    collection_name: str = Field(..., description="Milvus集合名称")
+    # Optional description; defaults to None.
+    description: Optional[str] = Field(None, description="描述")
+    # Optional status; defaults to "normal".
+    status: Optional[str] = Field("normal", description="状态")
+
+class KnowledgeBaseCreate(KnowledgeBaseBase):
+    """Request body for creating a knowledge base."""
+    # Vector dimension; defaults to 768.
+    dimension: int = Field(768, description="向量维度,默认768")
+
+class KnowledgeBaseUpdate(BaseModel):
+    """Request body for updating a knowledge base; every field is optional."""
+    # None (the default) means the field was not supplied.
+    name: Optional[str] = None
+    description: Optional[str] = None
+    status: Optional[str] = None
+
+class DescriptionUpdate(BaseModel):
+    """Request body that updates only the description (kept for backward compatibility)."""
+    # Required new description text.
+    description: str
+
+from datetime import datetime
+
+class KnowledgeBaseResponse(BaseModelSchema):
+    """Response schema for a knowledge base record."""
+    id: str
+    name: str
+    collection_name: str
+    # Optional type but no default value is declared for this field.
+    description: Optional[str]
+    status: str
+    document_count: int
+    # Declared as datetime here, while the ORM model stores these columns as
+    # strings - NOTE(review): relies on the validator parsing that format; confirm.
+    created_at: Optional[datetime] = None
+    updated_at: Optional[datetime] = None
+
+    class Config:
+        # Allow building the schema from attribute-bearing objects such as ORM rows.
+        from_attributes = True
+        # Serialise datetimes as "YYYY-MM-DD HH:MM:SS"; falsy values become None.
+        json_encoders = {
+            datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S") if v else None
+        }

+ 20 - 1
src/app/server/app.py

@@ -46,6 +46,8 @@ from views.system_view import router as system_router
 from views.oauth_view import router as oauth_router
 from views.sample_view import router as sample_router
 from views.auth_view import router as auth_router
+from views.knowledge_base_view import router as knowledge_base_router
+from views.snippet_view import router as snippet_router
 
 # 导入现有API路由
 from app.api.v1.api_router import api_router
@@ -126,13 +128,28 @@ app.add_middleware(
         "http://127.0.0.1:3000",
         "http://127.0.0.1:3001",
         "http://127.0.0.1:5173",
-        "http://127.0.0.1:8080"
+        "http://127.0.0.1:8080",
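+        "*" # TODO(security): temporary debug wildcard - together with allow_credentials=True below this admits credentialed requests from any origin; remove before release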
     ],
     allow_credentials=True,
     allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
     allow_headers=["*"],
 )
 
+# --- 调试中间件 ---
+@app.middleware("http")
+async def log_requests(request: Request, call_next):
+    """Pass the request down the chain, logging any exception before re-raising it."""
+    try:
+        return await call_next(request)
+    except Exception as exc:
+        logger.error(f"请求处理异常: {exc}")
+        raise
+# ------------------
+
+
 
 # 全局异常处理
 @app.exception_handler(BaseAPIException)
@@ -229,6 +246,8 @@ app.include_router(system_router, prefix="/api/v1")
 app.include_router(oauth_router, prefix="")
 app.include_router(auth_router, prefix="/api/v1")
 app.include_router(sample_router, prefix="/api/v1")
+app.include_router(knowledge_base_router, prefix="/api/v1")
+app.include_router(snippet_router, prefix="/api/v1")
 
 
 def create_app() -> FastAPI:

+ 227 - 0
src/app/services/knowledge_base_service.py

@@ -0,0 +1,227 @@
+"""
+知识库业务逻辑服务
+"""
+from math import ceil
+from typing import List, Optional, Tuple, Dict, Any
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, func, or_
+from datetime import datetime
+import uuid
+
+from app.sample.models.knowledge_base import KnowledgeBase
+from app.sample.schemas.knowledge_base import (
+    KnowledgeBaseCreate, 
+    KnowledgeBaseUpdate,
+    KnowledgeBaseResponse
+)
+from app.services.milvus_service import milvus_service
+from app.schemas.base import PaginationSchema
+
+class KnowledgeBaseService:
+    """Business logic for knowledge bases, keeping DB rows and Milvus collections in step.
+
+    Methods raise ``ValueError`` for "not found" / "already exists" conditions
+    and re-raise any other failure after rolling the session back.
+    """
+
+    @staticmethod
+    def _now() -> str:
+        """Return the current local time in the string format stored in the timestamp columns."""
+        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    async def _sync_from_milvus(self, db: AsyncSession) -> None:
+        """Mirror the current Milvus collections into the knowledge_base table.
+
+        A collection without a live row gets a new row; an existing row has
+        its ``document_count`` refreshed from the collection's ``row_count``
+        (``updated_at`` is deliberately left alone for count-only changes).
+        Commits only when something changed; exceptions propagate.
+        """
+        milvus_names = milvus_service.client.list_collections()
+
+        result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.is_deleted == 0))
+        existing_map = {kb.collection_name: kb for kb in result.scalars().all()}
+
+        has_changes = False
+        for m_name in milvus_names:
+            try:
+                stats = milvus_service.client.get_collection_stats(m_name)
+                row_count = int(stats.get("row_count", 0))
+            except Exception:
+                # Statistics are best-effort: an unreadable collection counts as 0.
+                row_count = 0
+
+            kb = existing_map.get(m_name)
+            if kb is None:
+                now = self._now()
+                db.add(KnowledgeBase(
+                    id=str(uuid.uuid4()),
+                    name=m_name,
+                    collection_name=m_name,
+                    description="Synced from Milvus",
+                    status="normal",
+                    document_count=row_count,
+                    created_at=now,
+                    updated_at=now,
+                ))
+                has_changes = True
+            elif kb.document_count != row_count:
+                kb.document_count = row_count
+                has_changes = True
+
+        if has_changes:
+            await db.commit()
+
+    async def get_list(
+        self,
+        db: AsyncSession,
+        page: int = 1,
+        page_size: int = 10,
+        keyword: Optional[str] = None,
+        status: Optional[str] = None
+    ) -> Tuple[List[KnowledgeBase], PaginationSchema]:
+        """Return one page of live knowledge bases plus pagination metadata.
+
+        The table is first synchronised with Milvus on a best-effort basis; a
+        failed sync is reported and skipped, it never fails the listing.
+
+        Args:
+            db: Async database session.
+            page: 1-based page number.
+            page_size: Rows per page.
+            keyword: Substring matched against name and collection_name.
+            status: Exact status to filter on.
+
+        Returns:
+            ``(items, meta)`` - the rows of the requested page and a PaginationSchema.
+        """
+        try:
+            await self._sync_from_milvus(db)
+        except Exception as e:
+            # A sync failure must not break the query; report it and move on.
+            print(f"Sync Milvus collections failed: {e}")
+            # A failed flush/commit can leave the session refusing further
+            # statements until it is rolled back; this also discards any
+            # half-applied sync changes.
+            await db.rollback()
+
+        query = select(KnowledgeBase).where(KnowledgeBase.is_deleted == 0)
+
+        if keyword:
+            query = query.where(or_(
+                KnowledgeBase.name.like(f"%{keyword}%"),
+                KnowledgeBase.collection_name.like(f"%{keyword}%")
+            ))
+
+        if status:
+            query = query.where(KnowledgeBase.status == status)
+
+        # Count the filtered rows before paging is applied.
+        count_query = select(func.count()).select_from(query.subquery())
+        total = await db.scalar(count_query) or 0
+
+        # Newest first, then cut out the requested page.
+        query = query.order_by(KnowledgeBase.created_at.desc()).offset((page - 1) * page_size).limit(page_size)
+        result = await db.execute(query)
+        items = result.scalars().all()
+
+        total_pages = ceil(total / page_size) if page_size else 0
+
+        meta = PaginationSchema(
+            page=page,
+            page_size=page_size,
+            total=total,
+            total_pages=total_pages,
+        )
+
+        return items, meta
+
+    async def create(self, db: AsyncSession, payload: KnowledgeBaseCreate) -> KnowledgeBase:
+        """Create a Milvus collection and its database row.
+
+        Raises:
+            ValueError: the collection name is already used by a live row or
+                by an existing Milvus collection.
+            Exception: any Milvus/database failure, re-raised after rollback.
+        """
+        # 1. Reject names already used by a live DB record.
+        exists = await db.execute(select(KnowledgeBase).where(
+            KnowledgeBase.collection_name == payload.collection_name,
+            KnowledgeBase.is_deleted == 0
+        ))
+        if exists.scalars().first():
+            raise ValueError("知识库集合名称已存在")
+
+        # 2. Reject names already used by a Milvus collection.
+        if milvus_service.has_collection(payload.collection_name):
+            raise ValueError("Milvus集合已存在,请使用其他名称")
+
+        # True only between "collection created" and "row committed": the
+        # window in which a failure would otherwise orphan the collection.
+        needs_cleanup = False
+        try:
+            # 3. Create the Milvus collection.
+            milvus_service.create_collection(
+                name=payload.collection_name,
+                dimension=payload.dimension,
+                description=payload.description or ""
+            )
+            needs_cleanup = True
+
+            # 4. Create the DB row.
+            now = self._now()
+            new_kb = KnowledgeBase(
+                id=str(uuid.uuid4()),
+                name=payload.name,
+                collection_name=payload.collection_name,
+                description=payload.description,
+                status=payload.status or "normal",
+                created_at=now,
+                updated_at=now
+            )
+            db.add(new_kb)
+            await db.commit()
+            needs_cleanup = False
+            await db.refresh(new_kb)
+
+            return new_kb
+        except Exception:
+            await db.rollback()
+            if needs_cleanup:
+                # Best-effort compensation; the original error is what propagates.
+                try:
+                    milvus_service.drop_collection(payload.collection_name)
+                except Exception as cleanup_error:
+                    print(f"Cleanup of Milvus collection failed: {cleanup_error}")
+            raise
+
+    async def update(self, db: AsyncSession, id: str, payload: KnowledgeBaseUpdate) -> KnowledgeBase:
+        """Apply the supplied (non-None) fields to a live knowledge base.
+
+        A new description is also pushed to the Milvus collection.
+
+        Raises:
+            ValueError: no live record has this ID.
+            Exception: any Milvus/database failure, re-raised after rollback.
+        """
+        result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == id, KnowledgeBase.is_deleted == 0))
+        kb = result.scalars().first()
+
+        if not kb:
+            raise ValueError("知识库不存在")
+
+        try:
+            if payload.name is not None:
+                kb.name = payload.name
+            if payload.description is not None:
+                kb.description = payload.description
+                # Keep the Milvus collection description in step with the row.
+                milvus_service.update_collection_description(kb.collection_name, payload.description)
+            if payload.status is not None:
+                kb.status = payload.status
+
+            kb.updated_at = self._now()
+            await db.commit()
+            await db.refresh(kb)
+
+            return kb
+        except Exception:
+            await db.rollback()
+            raise
+
+    async def update_status(self, db: AsyncSession, id: str, status: str) -> KnowledgeBase:
+        """Set the status of a live knowledge base and load/release its collection.
+
+        "normal" loads the collection, "disabled" releases it; any other value
+        only updates the row.
+
+        Raises:
+            ValueError: no live record has this ID.
+            Exception: any Milvus/database failure, re-raised after rollback.
+        """
+        result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == id, KnowledgeBase.is_deleted == 0))
+        kb = result.scalars().first()
+
+        if not kb:
+            raise ValueError("知识库不存在")
+
+        try:
+            kb.status = status
+            kb.updated_at = self._now()
+
+            # Keep the collection's loaded state in line with the new status.
+            if status == "normal":
+                milvus_service.set_collection_state(kb.collection_name, "load")
+            elif status == "disabled":
+                milvus_service.set_collection_state(kb.collection_name, "release")
+
+            await db.commit()
+            await db.refresh(kb)
+            return kb
+        except Exception:
+            await db.rollback()
+            raise
+
+    async def delete(self, db: AsyncSession, id: str) -> None:
+        """Drop the Milvus collection of a live knowledge base and soft-delete its row.
+
+        Raises:
+            ValueError: no live record has this ID.
+            Exception: any Milvus/database failure, re-raised after rollback.
+        """
+        # Only live rows are eligible, matching update/update_status: a row
+        # that is already soft-deleted must not trigger another drop.
+        result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == id, KnowledgeBase.is_deleted == 0))
+        kb = result.scalars().first()
+
+        if not kb:
+            raise ValueError("知识库不存在")
+
+        try:
+            # 1. Drop the Milvus collection.
+            milvus_service.drop_collection(kb.collection_name)
+
+            # 2. Soft-delete the DB row.
+            kb.is_deleted = 1
+            kb.updated_at = self._now()
+            await db.commit()
+        except Exception:
+            await db.rollback()
+            raise
+
+knowledge_base_service = KnowledgeBaseService()

+ 27 - 0
src/app/services/milvus_service.py

@@ -23,6 +23,33 @@ class MilvusService:
     def __init__(self):
         self.client = get_milvus_manager().client
 
+    def create_collection(self, name: str, dimension: int = 768, description: str = "") -> None:
+        """Create the Milvus collection ``name`` unless it already exists.
+
+        Args:
+            name: Collection name.
+            dimension: Vector dimension passed to the client (default 768).
+            description: Collection description passed to the client.
+        """
+        already_exists = self.client.has_collection(name)
+        if already_exists:
+            logger.info(f"Collection {name} already exists.")
+            return
+
+        # Uses the simplified create_collection form of the client API.
+        options = {
+            "collection_name": name,
+            "dimension": dimension,
+            "description": description,
+            "auto_id": True,  # IDs are generated automatically
+            "id_type": "int",  # ID type
+            "metric_type": "COSINE",  # cosine similarity by default
+        }
+        self.client.create_collection(**options)
+        logger.info(f"Created collection {name} with dimension {dimension}")
+
+    def drop_collection(self, name: str) -> None:
+        """Drop the Milvus collection ``name``; do nothing when it does not exist."""
+        if not self.client.has_collection(name):
+            return
+        self.client.drop_collection(name)
+        logger.info(f"Dropped collection {name}")
+
+    def has_collection(self, name: str) -> bool:
+        """Return whether the client reports a collection called ``name``."""
+        exists = self.client.has_collection(name)
+        return exists
+
     def get_collection_details(self) -> List[Dict[str, Any]]:
         """
         获取所有 Collections 详细信息

+ 126 - 0
src/views/knowledge_base_view.py

@@ -0,0 +1,126 @@
+"""
+知识库视图路由
+"""
+from fastapi import APIRouter, Depends, Query, Path
+from sqlalchemy.ext.asyncio import AsyncSession
+from typing import Optional
+
+from app.base.async_mysql_connection import get_db
+from app.services.knowledge_base_service import knowledge_base_service
+from app.sample.schemas.knowledge_base import (
+    KnowledgeBaseCreate, 
+    KnowledgeBaseUpdate, 
+    KnowledgeBaseResponse
+)
+from app.schemas.base import PaginatedResponseSchema, ResponseSchema
+from app.services.jwt_token import verify_token
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+
+router = APIRouter(prefix="/sample/knowledge-base", tags=["样本中心-知识库"])
+security = HTTPBearer()
+
+@router.get("", response_model=PaginatedResponseSchema)
+async def get_knowledge_bases(
+    page: int = Query(1, ge=1, description="页码"),
+    page_size: int = Query(10, ge=1, le=100, description="每页数量"),
+    keyword: str = Query(None, description="搜索关键词"),
+    status: str = Query(None, description="状态筛选"),
+    db: AsyncSession = Depends(get_db),
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """List knowledge bases for an authenticated caller.
+
+    Args:
+        page: 1-based page number.
+        page_size: Rows per page (1-100).
+        keyword: Optional search keyword forwarded to the service.
+        status: Optional status filter forwarded to the service.
+        db: Request-scoped async database session.
+        credentials: Bearer token taken from the Authorization header.
+
+    Returns:
+        PaginatedResponseSchema with code 0 and the requested page, or code
+        401 with an empty list when the token is rejected.
+    """
+    # Authentication
+    payload = verify_token(credentials.credentials)
+    if not payload:
+        return PaginatedResponseSchema(code=401, message="无效的访问令牌", data=[], meta=None)
+
+    items, meta = await knowledge_base_service.get_list(
+        db, page=page, page_size=page_size, keyword=keyword, status=status
+    )
+
+    return PaginatedResponseSchema(
+        code=0,
+        message="获取知识库列表成功",
+        data=[KnowledgeBaseResponse.model_validate(item) for item in items],
+        meta=meta,
+    )
+
+@router.post("", response_model=ResponseSchema)
+async def create_knowledge_base(
+    payload: KnowledgeBaseCreate,
+    db: AsyncSession = Depends(get_db),
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """Create a knowledge base for an authenticated caller.
+
+    Returns:
+        ResponseSchema with code 0 and the new record, 401 for a rejected
+        token, 400 when the service raises ValueError, or 500 otherwise.
+    """
+    if not verify_token(credentials.credentials):
+        return ResponseSchema(code=401, message="无效的访问令牌")
+
+    try:
+        created = await knowledge_base_service.create(db, payload)
+        body = KnowledgeBaseResponse.model_validate(created)
+        return ResponseSchema(code=0, message="创建成功", data=body)
+    except ValueError as exc:
+        # The service signals name conflicts with ValueError.
+        return ResponseSchema(code=400, message=str(exc))
+    except Exception as exc:
+        return ResponseSchema(code=500, message=f"创建失败: {str(exc)}")
+
+@router.put("/{id}", response_model=ResponseSchema)
+async def update_knowledge_base(
+    payload: KnowledgeBaseUpdate,
+    id: str = Path(..., description="知识库ID"),
+    db: AsyncSession = Depends(get_db),
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """Update a knowledge base for an authenticated caller.
+
+    Returns:
+        ResponseSchema with code 0 and the updated record, 401 for a rejected
+        token, 404 when the service raises ValueError, or 500 otherwise.
+    """
+    if not verify_token(credentials.credentials):
+        return ResponseSchema(code=401, message="无效的访问令牌")
+
+    try:
+        updated = await knowledge_base_service.update(db, id, payload)
+        body = KnowledgeBaseResponse.model_validate(updated)
+        return ResponseSchema(code=0, message="更新成功", data=body)
+    except ValueError as exc:
+        # The service signals a missing record with ValueError.
+        return ResponseSchema(code=404, message=str(exc))
+    except Exception as exc:
+        return ResponseSchema(code=500, message=f"更新失败: {str(exc)}")
+
+@router.patch("/{id}/status", response_model=ResponseSchema)
+async def update_knowledge_base_status(
+    id: str = Path(..., description="知识库ID"),
+    status: str = Query(..., description="状态: normal/test/disabled"),
+    db: AsyncSession = Depends(get_db),
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """Change the status of a knowledge base for an authenticated caller.
+
+    Returns:
+        ResponseSchema with code 0 on success, 401 for a rejected token, 404
+        when the service raises ValueError, or 500 otherwise.
+    """
+    payload_token = verify_token(credentials.credentials)
+    if not payload_token:
+        return ResponseSchema(code=401, message="无效的访问令牌")
+
+    try:
+        # The updated record is not part of the response, so it is not kept.
+        await knowledge_base_service.update_status(db, id, status)
+        return ResponseSchema(code=0, message=f"状态已更新为 {status}")
+    except ValueError as e:
+        return ResponseSchema(code=404, message=str(e))
+    except Exception as e:
+        return ResponseSchema(code=500, message=f"状态更新失败: {str(e)}")
+
+@router.delete("/{id}", response_model=ResponseSchema)
+async def delete_knowledge_base(
+    id: str = Path(..., description="知识库ID"),
+    db: AsyncSession = Depends(get_db),
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """Delete a knowledge base for an authenticated caller.
+
+    Returns:
+        ResponseSchema with code 0 on success, 401 for a rejected token, 404
+        when the service raises ValueError, or 500 otherwise.
+    """
+    if not verify_token(credentials.credentials):
+        return ResponseSchema(code=401, message="无效的访问令牌")
+
+    try:
+        await knowledge_base_service.delete(db, id)
+    except ValueError as exc:
+        # The service signals a missing record with ValueError.
+        return ResponseSchema(code=404, message=str(exc))
+    except Exception as exc:
+        return ResponseSchema(code=500, message=f"删除失败: {str(exc)}")
+    return ResponseSchema(code=0, message="删除成功")

+ 286 - 0
src/views/snippet_view.py

@@ -0,0 +1,286 @@
+"""
+知识片段视图路由
+"""
+from fastapi import APIRouter, Depends, Query, Path, Body
+from sqlalchemy.ext.asyncio import AsyncSession
+from typing import Optional, List, Dict, Any
+from datetime import datetime, timezone
+import json
+
+from app.base.async_mysql_connection import get_db
+from app.services.milvus_service import milvus_service
+from app.schemas.base import ResponseSchema, PaginatedResponseSchema, PaginationSchema
+from app.services.jwt_token import verify_token
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from pydantic import BaseModel
+
+# All snippet endpoints in this module are mounted under /document/snippet.
+router = APIRouter(prefix="/document/snippet", tags=["样本中心-知识片段"])
+# HTTPBearer only extracts the "Authorization: Bearer ..." credentials from
+# the request; validating the token itself is left to each handler.
+security = HTTPBearer()
+
+# Schemas
+# Request body for creating a snippet (POST on this router).
+class SnippetCreate(BaseModel):
+    # Name of the Milvus collection the snippet is inserted into.
+    collection_name: str
+    # Document name stored with the snippet; defaults to a "manually added" label.
+    doc_name: str = "手动添加"
+    # Snippet text.
+    content: str
+    # Optional free-form metadata.
+    # NOTE(review): create_snippet in this file does not read this field — confirm intent.
+    meta_info: Optional[str] = None
+
+# Request body for updating a snippet (PUT /{id} on this router).
+class SnippetUpdate(BaseModel):
+    # Name of the Milvus collection that holds the snippet.
+    collection_name: str
+    # New document name; optional.
+    doc_name: Optional[str] = None
+    # Replacement snippet text.
+    content: str
+
+@router.get("", response_model=PaginatedResponseSchema)
+async def get_snippets(
+    page: int = Query(1, ge=1),
+    page_size: int = Query(10, ge=1),
+    kb: Optional[str] = Query(None, description="知识库集合名称"),
+    keyword: Optional[str] = Query(None),
+    status: Optional[str] = Query(None),
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """获取知识片段列表 (跨集合查询)"""
+    try:
+        # 1. 确定要查询的目标集合列表
+        target_collections = []
+        if kb:
+            target_collections = [kb]
+        else:
+            # 查询所有启用的知识库 (这里简单起见,直接查 Milvus 或者需要注入 DB 查询)
+            # 为了解耦,这里直接查 Milvus 的所有集合,或者如果需要 DB 过滤,则需要注入 db session
+            # 简单起见,先查 Milvus
+            target_collections = milvus_service.client.list_collections()
+        
+        if not target_collections:
+             return PaginatedResponseSchema(
+                code=0, message="没有可用的知识库", data=[], 
+                meta=PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)
+            )
+        
+        # 2. 计算分页逻辑 (跨集合分页算法)
+        global_total = 0
+        items = []
+        
+        # 需要跳过的全局偏移量
+        skip_count = (page - 1) * page_size
+        # 需要获取的目标数量
+        need_count = page_size
+        
+        # 遍历所有集合
+        for col_name in target_collections:
+            if not milvus_service.has_collection(col_name):
+                continue
+                
+            try:
+                # 获取该集合总数
+                stats = milvus_service.client.get_collection_stats(col_name)
+                col_count = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0
+                
+                if keyword:
+                    # 关键词模式:必须实际查询
+                    desc = milvus_service.client.describe_collection(col_name)
+                    existing_fields = [f['name'] for f in desc.get('fields', [])]
+                    
+                    # 尝试获取所有字段
+                    output_fields = ["*"]
+                    
+                    expr = f'text like "%{keyword}%"' if 'text' in existing_fields else "" 
+                    if not expr: continue 
+                    
+                    res = milvus_service.client.query(col_name, filter=expr, output_fields=output_fields, limit=100)
+                    col_hits = len(res)
+                    global_total += col_hits
+                    
+                    if skip_count >= col_hits:
+                        skip_count -= col_hits
+                        continue
+                    
+                    take = min(need_count, col_hits - skip_count)
+                    chunk = res[skip_count : skip_count + take]
+                    
+                    for r in chunk:
+                        items.append(format_snippet(r, col_name))
+                    
+                    skip_count = 0 
+                    need_count -= take
+                    if need_count <= 0: break
+                    
+                else:
+                    # 无关键词模式
+                    global_total += col_count
+                    
+                    if skip_count >= col_count:
+                        skip_count -= col_count
+                        continue
+                    
+                    if need_count > 0:
+                        current_offset = skip_count
+                        current_limit = min(need_count, col_count - current_offset)
+                        
+                        output_fields = ["*"]
+                        
+                        res = milvus_service.client.query(
+                            collection_name=col_name,
+                            filter="",
+                            output_fields=output_fields,
+                            limit=current_limit,
+                            offset=current_offset
+                        )
+                        
+                        for r in res:
+                            items.append(format_snippet(r, col_name))
+                        
+                        skip_count = 0 
+                        need_count -= current_limit
+
+            except Exception as e:
+                print(f"Collection {col_name} query error: {e}")
+                continue
+
+        total_pages = (global_total + page_size - 1) // page_size
+
+        return PaginatedResponseSchema(
+            code=0, 
+            message="获取成功", 
+            data=items, 
+            meta=PaginationSchema(total=global_total, page=page, page_size=page_size, total_pages=total_pages)
+        )
+        
+    except Exception as e:
+        print(f"Query Snippets Error: {e}")
+        return PaginatedResponseSchema(
+            code=500, message=f"查询失败: {str(e)}", data=[], 
+            meta=PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)
+        )
+
+def format_snippet(r: Dict, col_name: str) -> Dict:
+    id_val = r.get("id") or r.get("pk")
+    content = r.get("text") or r.get("content") or r.get("page_content") or ""
+    
+    # 兜底:如果内容为空,显示 Keys 以便调试
+    if not content:
+        try:
+            debug_content = r.copy()
+            if "dense" in debug_content: del debug_content["dense"]
+            content = json.dumps(debug_content, default=str, ensure_ascii=False)
+        except:
+            content = "无法解析内容"
+
+    doc_name = r.get("file_name") or r.get("title") or r.get("source") or r.get("doc_name") or "未知文档"
+    meta_info = f"ParentID: {r.get('parent_id', '-')}"
+    
+    return {
+        "id": str(id_val),
+        "collection_name": col_name,
+        "doc_name": doc_name,
+        "code": f"SNIP-{id_val}",
+        "content": content,
+        "char_count": len(content) if content else 0,
+        "meta_info": meta_info,
+        "status": "normal",
+        "created_at": "-",
+        "updated_at": "-"
+    }
+
+@router.post("", response_model=ResponseSchema)
+async def create_snippet(
+    payload: SnippetCreate,
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """创建知识片段"""
+    try:
+        import random
+        fake_vector = [random.random() for _ in range(768)] 
+        
+        data = [{
+            "vector": fake_vector,
+            "text": payload.content,
+            "source": payload.doc_name,
+            "doc_id": "manual_add",
+            "file_name": payload.doc_name, # 确保这些字段都有值
+            "title": payload.doc_name
+        }]
+        
+        res = milvus_service.client.insert(
+            collection_name=payload.collection_name,
+            data=data
+        )
+        
+        milvus_service.client.flush(payload.collection_name)
+        
+        return ResponseSchema(code=0, message="创建成功", data={"count": res.get("insert_count", 1)})
+    except Exception as e:
+        print(f"Create Snippet Error: {e}")
+        return ResponseSchema(code=500, message=str(e))
+
+@router.put("/{id}", response_model=ResponseSchema)
+async def update_snippet(
+    id: str,
+    payload: SnippetUpdate,
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """更新知识片段"""
+    try:
+        kb = payload.collection_name
+        
+        # 1. 删除旧数据
+        desc = milvus_service.client.describe_collection(kb)
+        fields = [f['name'] for f in desc.get('fields', [])]
+        pk_field = "pk" if "pk" in fields else "id"
+        
+        if id.isdigit():
+            expr = f"{pk_field} in [{id}]"
+        else:
+            expr = f"{pk_field} in ['{id}']"
+        
+        milvus_service.client.delete(collection_name=kb, filter=expr)
+        
+        # 2. 插入新数据
+        import random
+        fake_vector = [random.random() for _ in range(768)] 
+        
+        data = [{
+            "vector": fake_vector,
+            "text": payload.content,
+            "source": payload.doc_name or "已更新",
+            "doc_id": "updated",
+            "file_name": payload.doc_name,
+            "title": payload.doc_name
+        }]
+        
+        milvus_service.client.insert(collection_name=kb, data=data)
+        milvus_service.client.flush(kb)
+        
+        return ResponseSchema(code=0, message="更新成功 (ID已变更)")
+    except Exception as e:
+        print(f"Update Snippet Error: {e}")
+        return ResponseSchema(code=500, message=str(e))
+
+@router.delete("/{id}", response_model=ResponseSchema)
+async def delete_snippet(
+    id: str, 
+    kb: str = Query(..., description="知识库名称"), 
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """删除知识片段"""
+    try:
+        if not milvus_service.has_collection(kb):
+             return ResponseSchema(code=404, message="知识库不存在")
+             
+        desc = milvus_service.client.describe_collection(kb)
+        fields = [f['name'] for f in desc.get('fields', [])]
+        pk_field = "pk" if "pk" in fields else "id"
+        
+        if id.isdigit():
+            expr = f"{pk_field} in [{id}]"
+        else:
+            expr = f"{pk_field} in ['{id}']"
+            
+        milvus_service.client.delete(
+            collection_name=kb,
+            filter=expr
+        )
+        milvus_service.client.flush(kb)
+        
+        return ResponseSchema(code=0, message="删除成功")
+    except Exception as e:
+        return ResponseSchema(code=500, message=str(e))

+ 2 - 2
test/run_server.py

@@ -9,9 +9,9 @@ import socket
 import logging
 
 # 添加项目根目录到Python路径
-project_root = os.path.dirname(os.path.abspath(__file__))
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.dirname(current_dir)
 sys.path.insert(0, project_root)
-sys.path.insert(0, os.path.join(project_root, 'src'))
 
 # 导入配置
 from src.app.core.config import config_handler