Browse Source

代码结构优化

linyang 1 month ago
parent
commit
4874ba220b
3 changed files with 269 additions and 262 deletions
  1. 229 0
      src/app/services/snippet_service.py
  2. 11 33
      src/views/knowledge_base_view.py
  3. 29 229
      src/views/snippet_view.py

+ 229 - 0
src/app/services/snippet_service.py

@@ -0,0 +1,229 @@
+
+"""
+知识片段业务逻辑服务
+"""
+from typing import List, Optional, Tuple, Dict, Any
+import json
+import random
+from datetime import datetime
+
+from app.services.milvus_service import milvus_service
+from app.schemas.base import PaginationSchema, PaginatedResponseSchema
+
+class SnippetService:
+    
+    def get_list(
+        self,
+        page: int = 1,
+        page_size: int = 10,
+        kb: Optional[str] = None,
+        keyword: Optional[str] = None,
+        status: Optional[str] = None
+    ) -> Tuple[List[Dict], PaginationSchema]:
+        """获取知识片段列表 (跨集合查询)"""
+        
+        # 1. 确定要查询的目标集合列表
+        target_collections = []
+        if kb:
+            target_collections = [kb]
+        else:
+            # 简单起见,先查 Milvus 的所有集合
+            target_collections = milvus_service.client.list_collections()
+        
+        if not target_collections:
+            return [], PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)
+        
+        # 2. 计算分页逻辑 (跨集合分页算法)
+        global_total = 0
+        items = []
+        
+        # 需要跳过的全局偏移量
+        skip_count = (page - 1) * page_size
+        # 需要获取的目标数量
+        need_count = page_size
+        
+        # 遍历所有集合
+        for col_name in target_collections:
+            if not milvus_service.has_collection(col_name):
+                continue
+                
+            try:
+                # 获取该集合总数
+                stats = milvus_service.client.get_collection_stats(col_name)
+                col_count = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0
+                
+                if keyword:
+                    # 关键词模式:必须实际查询
+                    desc = milvus_service.client.describe_collection(col_name)
+                    existing_fields = [f['name'] for f in desc.get('fields', [])]
+                    
+                    # 尝试获取所有字段
+                    output_fields = ["*"]
+                    
+                    expr = f'text like "%{keyword}%"' if 'text' in existing_fields else "" 
+                    if not expr: continue 
+                    
+                    res = milvus_service.client.query(col_name, filter=expr, output_fields=output_fields, limit=100)
+                    col_hits = len(res)
+                    global_total += col_hits
+                    
+                    if skip_count >= col_hits:
+                        skip_count -= col_hits
+                        continue
+                    
+                    take = min(need_count, col_hits - skip_count)
+                    chunk = res[skip_count : skip_count + take]
+                    
+                    for r in chunk:
+                        items.append(self._format_snippet(r, col_name))
+                    
+                    skip_count = 0 
+                    need_count -= take
+                    if need_count <= 0: break
+                    
+                else:
+                    # 无关键词模式
+                    global_total += col_count
+                    
+                    if skip_count >= col_count:
+                        skip_count -= col_count
+                        continue
+                    
+                    if need_count > 0:
+                        current_offset = skip_count
+                        current_limit = min(need_count, col_count - current_offset)
+                        
+                        output_fields = ["*"]
+                        
+                        res = milvus_service.client.query(
+                            collection_name=col_name,
+                            filter="",
+                            output_fields=output_fields,
+                            limit=current_limit,
+                            offset=current_offset
+                        )
+                        
+                        for r in res:
+                            items.append(self._format_snippet(r, col_name))
+                        
+                        skip_count = 0 
+                        need_count -= current_limit
+
+            except Exception as e:
+                print(f"Collection {col_name} query error: {e}")
+                continue
+
+        total_pages = (global_total + page_size - 1) // page_size if page_size else 0
+
+        meta = PaginationSchema(
+            page=page,
+            page_size=page_size,
+            total=global_total,
+            total_pages=total_pages
+        )
+        
+        return items, meta
+
+    def create(self, payload: Any) -> Dict:
+        """创建知识片段"""
+        fake_vector = [random.random() for _ in range(768)] 
+        
+        data = [{
+            "vector": fake_vector,
+            "text": payload.content,
+            "source": payload.doc_name,
+            "doc_id": "manual_add",
+            "file_name": payload.doc_name, 
+            "title": payload.doc_name
+        }]
+        
+        res = milvus_service.client.insert(
+            collection_name=payload.collection_name,
+            data=data
+        )
+        
+        milvus_service.client.flush(payload.collection_name)
+        return {"count": res.get("insert_count", 1)}
+
+    def update(self, id: str, payload: Any) -> str:
+        """更新知识片段"""
+        kb = payload.collection_name
+        
+        # 1. 删除旧数据
+        desc = milvus_service.client.describe_collection(kb)
+        fields = [f['name'] for f in desc.get('fields', [])]
+        pk_field = "pk" if "pk" in fields else "id"
+        
+        if id.isdigit():
+            expr = f"{pk_field} in [{id}]"
+        else:
+            expr = f"{pk_field} in ['{id}']"
+        
+        milvus_service.client.delete(collection_name=kb, filter=expr)
+        
+        # 2. 插入新数据
+        fake_vector = [random.random() for _ in range(768)] 
+        
+        data = [{
+            "vector": fake_vector,
+            "text": payload.content,
+            "source": payload.doc_name or "已更新",
+            "doc_id": "updated",
+            "file_name": payload.doc_name,
+            "title": payload.doc_name
+        }]
+        
+        milvus_service.client.insert(collection_name=kb, data=data)
+        milvus_service.client.flush(kb)
+        
+        return "更新成功 (ID已变更)"
+
+    def delete(self, id: str, kb: str) -> None:
+        """删除知识片段"""
+        if not milvus_service.has_collection(kb):
+             raise ValueError("知识库不存在")
+             
+        desc = milvus_service.client.describe_collection(kb)
+        fields = [f['name'] for f in desc.get('fields', [])]
+        pk_field = "pk" if "pk" in fields else "id"
+        
+        if id.isdigit():
+            expr = f"{pk_field} in [{id}]"
+        else:
+            expr = f"{pk_field} in ['{id}']"
+            
+        milvus_service.client.delete(
+            collection_name=kb,
+            filter=expr
+        )
+        milvus_service.client.flush(kb)
+
+    def _format_snippet(self, r: Dict, col_name: str) -> Dict:
+        id_val = r.get("id") or r.get("pk")
+        content = r.get("text") or r.get("content") or r.get("page_content") or ""
+        
+        if not content:
+            try:
+                debug_content = r.copy()
+                if "dense" in debug_content: del debug_content["dense"]
+                content = json.dumps(debug_content, default=str, ensure_ascii=False)
+            except:
+                content = "无法解析内容"
+
+        doc_name = r.get("file_name") or r.get("title") or r.get("source") or r.get("doc_name") or "未知文档"
+        meta_info = f"ParentID: {r.get('parent_id', '-')}"
+        
+        return {
+            "id": str(id_val),
+            "collection_name": col_name,
+            "doc_name": doc_name,
+            "code": f"SNIP-{id_val}",
+            "content": content,
+            "char_count": len(content) if content else 0,
+            "meta_info": meta_info,
+            "status": "normal",
+            "created_at": "-",
+            "updated_at": "-"
+        }
+
+snippet_service = SnippetService()

+ 11 - 33
src/views/knowledge_base_view.py

@@ -38,8 +38,6 @@ async def get_knowledge_bases(
         db, page=page, page_size=page_size, keyword=keyword, status=status
     )
 
-    print("11111111111111111111111111111111112222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222")
-
     return PaginatedResponseSchema(
         code=0,
         message="获取知识库列表成功",
@@ -58,15 +56,10 @@ async def create_knowledge_base(
     if not payload_token:
         return ResponseSchema(code=401, message="无效的访问令牌")
 
-    try:
-        new_kb = await knowledge_base_service.create(db, payload)
-        return ResponseSchema(code=0, message="创建成功", data=KnowledgeBaseResponse.model_validate(new_kb))
-    except ValueError as e:
-        return ResponseSchema(code=400, message=str(e))
-    except Exception as e:
-        return ResponseSchema(code=500, message=f"创建失败: {str(e)}")
+    new_kb = await knowledge_base_service.create(db, payload)
+    return ResponseSchema(code=0, message="创建成功", data=KnowledgeBaseResponse.model_validate(new_kb))
 
-@router.put("/{id}", response_model=ResponseSchema)
+@router.post("/{id}", response_model=ResponseSchema)
 async def update_knowledge_base(
     payload: KnowledgeBaseUpdate,
     id: str = Path(..., description="知识库ID"),
@@ -78,15 +71,10 @@ async def update_knowledge_base(
     if not payload_token:
         return ResponseSchema(code=401, message="无效的访问令牌")
 
-    try:
-        kb = await knowledge_base_service.update(db, id, payload)
-        return ResponseSchema(code=0, message="更新成功", data=KnowledgeBaseResponse.model_validate(kb))
-    except ValueError as e:
-        return ResponseSchema(code=404, message=str(e))
-    except Exception as e:
-        return ResponseSchema(code=500, message=f"更新失败: {str(e)}")
+    kb = await knowledge_base_service.update(db, id, payload)
+    return ResponseSchema(code=0, message="更新成功", data=KnowledgeBaseResponse.model_validate(kb))
 
-@router.patch("/{id}/status", response_model=ResponseSchema)
+@router.post("/{id}/status", response_model=ResponseSchema)
 async def update_knowledge_base_status(
     id: str = Path(..., description="知识库ID"),
     status: str = Query(..., description="状态: normal/test/disabled"),
@@ -98,15 +86,10 @@ async def update_knowledge_base_status(
     if not payload_token:
         return ResponseSchema(code=401, message="无效的访问令牌")
 
-    try:
-        kb = await knowledge_base_service.update_status(db, id, status)
-        return ResponseSchema(code=0, message=f"状态已更新为 {status}")
-    except ValueError as e:
-        return ResponseSchema(code=404, message=str(e))
-    except Exception as e:
-        return ResponseSchema(code=500, message=f"状态更新失败: {str(e)}")
+    await knowledge_base_service.update_status(db, id, status)
+    return ResponseSchema(code=0, message=f"状态已更新为 {status}")
 
-@router.delete("/{id}", response_model=ResponseSchema)
+@router.post("/{id}/delete", response_model=ResponseSchema)
 async def delete_knowledge_base(
     id: str = Path(..., description="知识库ID"),
     db: AsyncSession = Depends(get_db),
@@ -117,10 +100,5 @@ async def delete_knowledge_base(
     if not payload_token:
         return ResponseSchema(code=401, message="无效的访问令牌")
 
-    try:
-        await knowledge_base_service.delete(db, id)
-        return ResponseSchema(code=0, message="删除成功")
-    except ValueError as e:
-        return ResponseSchema(code=404, message=str(e))
-    except Exception as e:
-        return ResponseSchema(code=500, message=f"删除失败: {str(e)}")
+    await knowledge_base_service.delete(db, id)
+    return ResponseSchema(code=0, message="删除成功")

+ 29 - 229
src/views/snippet_view.py

@@ -2,14 +2,10 @@
 知识片段视图路由
 """
 from fastapi import APIRouter, Depends, Query, Path, Body
-from sqlalchemy.ext.asyncio import AsyncSession
-from typing import Optional, List, Dict, Any
-from datetime import datetime, timezone
-import json
+from typing import Optional
 
-from app.base.async_mysql_connection import get_db
-from app.services.milvus_service import milvus_service
-from app.schemas.base import ResponseSchema, PaginatedResponseSchema, PaginationSchema
+from app.services.snippet_service import snippet_service
+from app.schemas.base import ResponseSchema, PaginatedResponseSchema
 from app.services.jwt_token import verify_token
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from pydantic import BaseModel
@@ -39,147 +35,14 @@ async def get_snippets(
     credentials: HTTPAuthorizationCredentials = Depends(security)
 ):
     """获取知识片段列表 (跨集合查询)"""
-    try:
-        # 1. 确定要查询的目标集合列表
-        target_collections = []
-        if kb:
-            target_collections = [kb]
-        else:
-            # 查询所有启用的知识库 (这里简单起见,直接查 Milvus 或者需要注入 DB 查询)
-            # 为了解耦,这里直接查 Milvus 的所有集合,或者如果需要 DB 过滤,则需要注入 db session
-            # 简单起见,先查 Milvus
-            target_collections = milvus_service.client.list_collections()
-        
-        if not target_collections:
-             return PaginatedResponseSchema(
-                code=0, message="没有可用的知识库", data=[], 
-                meta=PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)
-            )
-        
-        # 2. 计算分页逻辑 (跨集合分页算法)
-        global_total = 0
-        items = []
-        
-        # 需要跳过的全局偏移量
-        skip_count = (page - 1) * page_size
-        # 需要获取的目标数量
-        need_count = page_size
-        
-        # 遍历所有集合
-        for col_name in target_collections:
-            if not milvus_service.has_collection(col_name):
-                continue
-                
-            try:
-                # 获取该集合总数
-                stats = milvus_service.client.get_collection_stats(col_name)
-                col_count = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0
-                
-                if keyword:
-                    # 关键词模式:必须实际查询
-                    desc = milvus_service.client.describe_collection(col_name)
-                    existing_fields = [f['name'] for f in desc.get('fields', [])]
-                    
-                    # 尝试获取所有字段
-                    output_fields = ["*"]
-                    
-                    expr = f'text like "%{keyword}%"' if 'text' in existing_fields else "" 
-                    if not expr: continue 
-                    
-                    res = milvus_service.client.query(col_name, filter=expr, output_fields=output_fields, limit=100)
-                    col_hits = len(res)
-                    global_total += col_hits
-                    
-                    if skip_count >= col_hits:
-                        skip_count -= col_hits
-                        continue
-                    
-                    take = min(need_count, col_hits - skip_count)
-                    chunk = res[skip_count : skip_count + take]
-                    
-                    for r in chunk:
-                        items.append(format_snippet(r, col_name))
-                    
-                    skip_count = 0 
-                    need_count -= take
-                    if need_count <= 0: break
-                    
-                else:
-                    # 无关键词模式
-                    global_total += col_count
-                    
-                    if skip_count >= col_count:
-                        skip_count -= col_count
-                        continue
-                    
-                    if need_count > 0:
-                        current_offset = skip_count
-                        current_limit = min(need_count, col_count - current_offset)
-                        
-                        output_fields = ["*"]
-                        
-                        res = milvus_service.client.query(
-                            collection_name=col_name,
-                            filter="",
-                            output_fields=output_fields,
-                            limit=current_limit,
-                            offset=current_offset
-                        )
-                        
-                        for r in res:
-                            items.append(format_snippet(r, col_name))
-                        
-                        skip_count = 0 
-                        need_count -= current_limit
-
-            except Exception as e:
-                print(f"Collection {col_name} query error: {e}")
-                continue
-
-        total_pages = (global_total + page_size - 1) // page_size
-
-        return PaginatedResponseSchema(
-            code=0, 
-            message="获取成功", 
-            data=items, 
-            meta=PaginationSchema(total=global_total, page=page, page_size=page_size, total_pages=total_pages)
-        )
-        
-    except Exception as e:
-        print(f"Query Snippets Error: {e}")
-        return PaginatedResponseSchema(
-            code=500, message=f"查询失败: {str(e)}", data=[], 
-            meta=PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)
-        )
-
-def format_snippet(r: Dict, col_name: str) -> Dict:
-    id_val = r.get("id") or r.get("pk")
-    content = r.get("text") or r.get("content") or r.get("page_content") or ""
-    
-    # 兜底:如果内容为空,显示 Keys 以便调试
-    if not content:
-        try:
-            debug_content = r.copy()
-            if "dense" in debug_content: del debug_content["dense"]
-            content = json.dumps(debug_content, default=str, ensure_ascii=False)
-        except:
-            content = "无法解析内容"
-
-    doc_name = r.get("file_name") or r.get("title") or r.get("source") or r.get("doc_name") or "未知文档"
-    meta_info = f"ParentID: {r.get('parent_id', '-')}"
+    items, meta = snippet_service.get_list(page, page_size, kb, keyword, status)
     
-    return {
-        "id": str(id_val),
-        "collection_name": col_name,
-        "doc_name": doc_name,
-        "code": f"SNIP-{id_val}",
-        "content": content,
-        "char_count": len(content) if content else 0,
-        "meta_info": meta_info,
-        "status": "normal",
-        "created_at": "-",
-        "updated_at": "-"
-    }
+    return PaginatedResponseSchema(
+        code=0, 
+        message="获取成功", 
+        data=items, 
+        meta=meta
+    )
 
 @router.post("", response_model=ResponseSchema)
 async def create_snippet(
@@ -187,100 +50,37 @@ async def create_snippet(
     credentials: HTTPAuthorizationCredentials = Depends(security)
 ):
     """创建知识片段"""
-    try:
-        import random
-        fake_vector = [random.random() for _ in range(768)] 
-        
-        data = [{
-            "vector": fake_vector,
-            "text": payload.content,
-            "source": payload.doc_name,
-            "doc_id": "manual_add",
-            "file_name": payload.doc_name, # 确保这些字段都有值
-            "title": payload.doc_name
-        }]
-        
-        res = milvus_service.client.insert(
-            collection_name=payload.collection_name,
-            data=data
-        )
-        
-        milvus_service.client.flush(payload.collection_name)
-        
-        return ResponseSchema(code=0, message="创建成功", data={"count": res.get("insert_count", 1)})
-    except Exception as e:
-        print(f"Create Snippet Error: {e}")
-        return ResponseSchema(code=500, message=str(e))
+    payload_token = verify_token(credentials.credentials)
+    if not payload_token:
+        return ResponseSchema(code=401, message="无效的访问令牌")
+
+    data = snippet_service.create(payload)
+    return ResponseSchema(code=0, message="创建成功", data=data)
 
-@router.put("/{id}", response_model=ResponseSchema)
+@router.post("/{id}", response_model=ResponseSchema)
 async def update_snippet(
     id: str,
     payload: SnippetUpdate,
     credentials: HTTPAuthorizationCredentials = Depends(security)
 ):
     """更新知识片段"""
-    try:
-        kb = payload.collection_name
-        
-        # 1. 删除旧数据
-        desc = milvus_service.client.describe_collection(kb)
-        fields = [f['name'] for f in desc.get('fields', [])]
-        pk_field = "pk" if "pk" in fields else "id"
-        
-        if id.isdigit():
-            expr = f"{pk_field} in [{id}]"
-        else:
-            expr = f"{pk_field} in ['{id}']"
-        
-        milvus_service.client.delete(collection_name=kb, filter=expr)
-        
-        # 2. 插入新数据
-        import random
-        fake_vector = [random.random() for _ in range(768)] 
-        
-        data = [{
-            "vector": fake_vector,
-            "text": payload.content,
-            "source": payload.doc_name or "已更新",
-            "doc_id": "updated",
-            "file_name": payload.doc_name,
-            "title": payload.doc_name
-        }]
-        
-        milvus_service.client.insert(collection_name=kb, data=data)
-        milvus_service.client.flush(kb)
-        
-        return ResponseSchema(code=0, message="更新成功 (ID已变更)")
-    except Exception as e:
-        print(f"Update Snippet Error: {e}")
-        return ResponseSchema(code=500, message=str(e))
+    payload_token = verify_token(credentials.credentials)
+    if not payload_token:
+        return ResponseSchema(code=401, message="无效的访问令牌")
+
+    msg = snippet_service.update(id, payload)
+    return ResponseSchema(code=0, message=msg)
 
-@router.delete("/{id}", response_model=ResponseSchema)
+@router.post("/{id}/delete", response_model=ResponseSchema)
 async def delete_snippet(
     id: str, 
     kb: str = Query(..., description="知识库名称"), 
     credentials: HTTPAuthorizationCredentials = Depends(security)
 ):
     """删除知识片段"""
-    try:
-        if not milvus_service.has_collection(kb):
-             return ResponseSchema(code=404, message="知识库不存在")
-             
-        desc = milvus_service.client.describe_collection(kb)
-        fields = [f['name'] for f in desc.get('fields', [])]
-        pk_field = "pk" if "pk" in fields else "id"
-        
-        if id.isdigit():
-            expr = f"{pk_field} in [{id}]"
-        else:
-            expr = f"{pk_field} in ['{id}']"
-            
-        milvus_service.client.delete(
-            collection_name=kb,
-            filter=expr
-        )
-        milvus_service.client.flush(kb)
+    payload_token = verify_token(credentials.credentials)
+    if not payload_token:
+        return ResponseSchema(code=401, message="无效的访问令牌")
         
-        return ResponseSchema(code=0, message="删除成功")
-    except Exception as e:
-        return ResponseSchema(code=500, message=str(e))
+    snippet_service.delete(id, kb)
+    return ResponseSchema(code=0, message="删除成功")