linyang il y a 4 semaines
Parent
commit
ce2b6458e6

+ 237 - 9
src/app/services/knowledge_base_service.py

@@ -21,6 +21,99 @@ from app.schemas.base import PaginationSchema
 
 class KnowledgeBaseService:
     
+    async def _get_collection_row_count(self, collection_name: str) -> int:
+        """获取集合行数(优先尝试 count(*) 以获取实时准确值)"""
+        try:
+            # 尝试使用 count(*) 获取准确的实时数量
+            # 过滤掉已标记删除的数据 (is_deleted == false)
+            # 注意:如果 Schema 中没有 is_deleted 字段,这里可能会报错,需要根据实际 Schema 调整
+            # 但之前的代码中 Schema 确实包含了 is_deleted
+            try:
+                # 检查集合是否已加载
+                if milvus_service.get_collection_state(collection_name) == "Loaded":
+                    res = milvus_service.client.query(collection_name, filter="is_deleted == false", output_fields=["count(*)"])
+                    if res and isinstance(res, list) and "count(*)" in res[0]:
+                        return int(res[0]["count(*)"])
+            except Exception:
+                # 再次尝试不过滤
+                if milvus_service.get_collection_state(collection_name) == "Loaded":
+                    res = milvus_service.client.query(collection_name, filter="", output_fields=["count(*)"])
+                    if res and isinstance(res, list) and "count(*)" in res[0]:
+                        return int(res[0]["count(*)"])
+        except Exception:
+            pass
+            
+        # Fallback: 使用 get_collection_stats (可能包含已删除未 Compaction 的数据)
+        try:
+            stats = milvus_service.client.get_collection_stats(collection_name)
+            return int(stats.get("row_count", 0))
+        except Exception:
+            return 0
+
+    async def _infer_and_save_metadata(self, db: AsyncSession, kb: KnowledgeBase) -> None:
+        """
+        [Internal] 从 Milvus 数据中推断元数据并保存到 DB
+        仅当 DB 中没有定义元数据时调用
+        """
+        try:
+            # 检查是否已加载(避免不必要的错误)
+            # if milvus_service.get_collection_state(kb.collection_name) != "Loaded":
+            #     milvus_service.load_collection(kb.collection_name)
+            
+            # 采样查询 (获取前10条)
+            try:
+                res = milvus_service.client.query(
+                    collection_name=kb.collection_name,
+                    filter="is_deleted == false",
+                    output_fields=["metadata"],
+                    limit=10
+                )
+            except Exception as e:
+                # 如果 filter 查询失败(可能不支持 is_deleted),尝试无 filter 查询
+                res = milvus_service.client.query(
+                    collection_name=kb.collection_name,
+                    filter="",
+                    output_fields=["metadata"],
+                    limit=10
+                )
+            
+            if res:
+                inferred_keys = set()
+                for item in res:
+                    meta = item.get("metadata") or {}
+                    # Milvus 可能会返回 JSON 字符串
+                    if isinstance(meta, str):
+                        try:
+                            import json
+                            meta = json.loads(meta)
+                        except:
+                            meta = {}
+                            
+                    if isinstance(meta, dict):
+                        inferred_keys.update(meta.keys())
+                
+                # 过滤掉默认字段
+                ignore_keys = {"doc_name", "file_name", "title", "source", "chunk_id"}
+                inferred_keys = inferred_keys - ignore_keys
+                
+                if inferred_keys:
+                    # 自动生成并保存到 DB
+                    for key in inferred_keys:
+                        new_metadata = SampleMetadata(
+                            id=str(uuid.uuid4()),
+                            knowledge_base_id=kb.id,
+                            field_zh_name=key, # 默认用英文名
+                            field_en_name=key,
+                            field_type="text", # 默认推断为 text
+                            remark="Auto inferred from Milvus data"
+                        )
+                        db.add(new_metadata)
+                    
+                    # 注意:调用方负责 commit,这里不 commit 以支持批量事务
+                    print(f"Auto inferred metadata for {kb.collection_name}: {inferred_keys}")
+        except Exception as e:
+            print(f"Failed to infer metadata for {kb.collection_name}: {e}")
+
     async def get_list(
         self, 
         db: AsyncSession,
@@ -45,11 +138,7 @@ class KnowledgeBaseService:
             has_changes = False
             for m_name in milvus_names:
                 # 获取统计信息
-                try:
-                    stats = milvus_service.client.get_collection_stats(m_name)
-                    row_count = int(stats.get("row_count", 0))
-                except Exception:
-                    row_count = 0
+                row_count = await self._get_collection_row_count(m_name)
 
                 if m_name not in existing_map:
                     # 新增
@@ -64,6 +153,11 @@ class KnowledgeBaseService:
                         updated_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                     )
                     db.add(new_kb)
+                    
+                    # [新增逻辑] 对新发现的知识库,立即尝试推断元数据
+                    if row_count > 0:
+                        await self._infer_and_save_metadata(db, new_kb)
+                        
                     has_changes = True
                 else:
                     # 更新统计
@@ -72,10 +166,61 @@ class KnowledgeBaseService:
                         kb.document_count = row_count
                         # kb.created_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") # 统计更新不一定更新时间
                         has_changes = True
+                    
+                    # NOTE(review): existing KBs that already have rows but no
+                    # metadata definitions are handled by the batch check after
+                    # this loop, which avoids a per-KB DB lookup here.
             
             if has_changes:
                 await db.commit()
                 
+            # Batch-infer metadata for existing KBs that contain rows but have
+            # no SampleMetadata definitions yet: a single query determines which
+            # KB ids already have definitions, so only the remainder hit Milvus.
+            # Newly discovered KBs were already inferred inside the loop above.
+            if existing_kbs:
+                active_kbs = [kb for kb in existing_kbs if kb.document_count > 0]
+                if active_kbs:
+                    active_ids = [kb.id for kb in active_kbs]
+                    # 批量查询已存在元数据的 KB ID
+                    meta_res = await db.execute(select(SampleMetadata.knowledge_base_id).where(SampleMetadata.knowledge_base_id.in_(active_ids)).distinct())
+                    has_meta_ids = set(meta_res.scalars().all())
+                    
+                    for kb in active_kbs:
+                        if kb.id not in has_meta_ids:
+                            await self._infer_and_save_metadata(db, kb)
+                            has_changes = True
+
+            if has_changes:
+                await db.commit()
+
         except Exception as e:
             # 同步失败不影响查询,只打印日志
             print(f"Sync Milvus collections failed: {e}")
@@ -272,8 +417,8 @@ class KnowledgeBaseService:
         current_count = kb.document_count
         if kb.collection_name and milvus_service.has_collection(kb.collection_name):
             try:
-                stats = milvus_service.client.get_collection_stats(kb.collection_name)
-                current_count = int(stats.get("row_count", 0))
+                # 使用统一的计数方法
+                current_count = await self._get_collection_row_count(kb.collection_name)
             except Exception:
                 # 获取失败则使用 DB 中的缓存值
                 pass
@@ -292,7 +437,7 @@ class KnowledgeBaseService:
             
             # 2. 软删除 DB 记录
             kb.is_deleted = 1
-            kb.created_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            kb.updated_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
             
             # 3. 删除关联的元数据 (硬删除)
             await db.execute(sql_delete(SampleMetadata).where(SampleMetadata.knowledge_base_id == id))
@@ -431,7 +576,7 @@ class KnowledgeBaseService:
             raise e
 
     async def get_metadata_and_schema(self, db: AsyncSession, kb_id: str) -> Dict[str, List[dict]]:
-        """获取知识库的元数据字段列表 (Schema已固定,不再返回自定义Schema)"""
+        """获取知识库的元数据字段列表 (如果 DB 中没有定义,尝试从 Milvus 数据推断)"""
         # 检查知识库是否存在
         result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0))
         kb = result.scalars().first()
@@ -443,6 +588,67 @@ class KnowledgeBaseService:
         meta_result = await db.execute(meta_query)
         metadata_fields = [f.to_dict() for f in meta_result.scalars().all()]
         
+        # 自动推断逻辑:如果 DB 中没有定义元数据,且 Milvus 中有数据,尝试推断
+        if not metadata_fields and kb.collection_name and milvus_service.has_collection(kb.collection_name):
+            try:
+                # 采样查询 (获取前10条)
+                try:
+                    res = milvus_service.client.query(
+                        collection_name=kb.collection_name,
+                        filter="is_deleted == false",
+                        output_fields=["metadata"],
+                        limit=10
+                    )
+                except Exception as e:
+                    # 如果 filter 查询失败(可能不支持 is_deleted),尝试无 filter 查询
+                    res = milvus_service.client.query(
+                        collection_name=kb.collection_name,
+                        filter="",
+                        output_fields=["metadata"],
+                        limit=10
+                    )
+                
+                if res:
+                    inferred_keys = set()
+                    for item in res:
+                        meta = item.get("metadata") or {}
+                        # Milvus 可能会返回 JSON 字符串,尝试解析
+                        if isinstance(meta, str):
+                            try:
+                                import json
+                                meta = json.loads(meta)
+                            except:
+                                meta = {}
+                                
+                        if isinstance(meta, dict):
+                            inferred_keys.update(meta.keys())
+                    
+                    # 过滤掉一些默认字段,避免干扰
+                    ignore_keys = {"doc_name", "file_name", "title", "source", "chunk_id"}
+                    inferred_keys = inferred_keys - ignore_keys
+                    
+                    if inferred_keys:
+                        # 自动生成并保存到 DB
+                        new_fields = []
+                        for key in inferred_keys:
+                            new_metadata = SampleMetadata(
+                                id=str(uuid.uuid4()),
+                                knowledge_base_id=kb.id,
+                                field_zh_name=key, # 默认用英文名
+                                field_en_name=key,
+                                field_type="text", # 默认推断为 text
+                                remark="Auto inferred from Milvus data"
+                            )
+                            db.add(new_metadata)
+                            new_fields.append(new_metadata.to_dict())
+                        
+                        await db.commit()
+                        metadata_fields = new_fields
+                        print(f"Auto inferred metadata for {kb.collection_name}: {inferred_keys}")
+            except Exception as e:
+                print(f"Failed to infer metadata for {kb.collection_name}: {e}")
+                # 推断失败不影响正常返回
+        
         # 返回空的 custom_schemas,因为现在是固定 Schema
         return {
             "metadata_fields": metadata_fields,
@@ -464,4 +670,26 @@ class KnowledgeBaseService:
         
         return [f.to_dict() for f in fields]
 
+    async def update_doc_count(self, db: AsyncSession, collection_name: str) -> None:
+        """根据 Milvus 实时数据更新知识库文档数量"""
+        # 查找知识库
+        result = await db.execute(select(KnowledgeBase).where(
+            KnowledgeBase.collection_name == collection_name,
+            KnowledgeBase.is_deleted == 0
+        ))
+        kb = result.scalars().first()
+        
+        if kb and milvus_service.has_collection(collection_name):
+            try:
+                # 使用统一的计数方法
+                row_count = await self._get_collection_row_count(collection_name)
+                
+                # 更新数据库
+                if kb.document_count != row_count:
+                    kb.document_count = row_count
+                    kb.updated_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                    await db.commit()
+            except Exception as e:
+                print(f"Failed to update doc count for {collection_name}: {e}")
+
 knowledge_base_service = KnowledgeBaseService()

+ 9 - 1
src/app/services/milvus_service.py

@@ -421,7 +421,8 @@ class MilvusService:
 
     def hybrid_search(self, collection_name: str, query_text: str,
                      top_k: int = 3, ranker_type: str = "weighted",
-                     dense_weight: float = 0.7, sparse_weight: float = 0.3):
+                     dense_weight: float = 0.7, sparse_weight: float = 0.3,
+                     expr: str = None):
         """
         混合搜索(参考 test_hybrid_v2.6.py 的实现)
 
@@ -432,12 +433,16 @@ class MilvusService:
             ranker_type: 重排序类型 "weighted" 或 "rrf"
             dense_weight: 密集向量权重(当ranker_type="weighted"时使用)
             sparse_weight: 稀疏向量权重(当ranker_type="weighted"时使用)
+            expr: 过滤表达式 (Metadata Filtering)
 
         Returns:
             List[Dict]: 搜索结果列表
         """
         try:
             collection_name = collection_name
+            
+            # 确保集合已加载
+            self.client.load_collection(collection_name)
 
             # 获取 vectorstore 实例(包含 Milvus 和 BM25BuiltInFunction)
             vectorstore = get_milvus_vectorstore(
@@ -446,10 +451,12 @@ class MilvusService:
             )
 
             # 执行混合搜索 (完全按照 test_hybrid_v2.6.py 的逻辑)
+            # 注意:LangChain Milvus vectorstore 的 similarity_search 支持 expr 参数用于过滤
             if ranker_type == "weighted":
                 results = vectorstore.similarity_search(
                     query=query_text,
                     k=top_k,
+                    expr=expr,
                     ranker_type="weighted",
                     ranker_params={"weights": [dense_weight, sparse_weight]}
                 )
@@ -457,6 +464,7 @@ class MilvusService:
                 results = vectorstore.similarity_search(
                     query=query_text,
                     k=top_k,
+                    expr=expr,
                     ranker_type="rrf",
                     ranker_params={"k": 60}
                 )

+ 39 - 9
src/app/services/search_engine_service.py

@@ -138,14 +138,43 @@ class SearchEngineService:
         # 处理新的多重过滤
         if payload.filters:
             for f in payload.filters:
-                safe_field = f.field.replace("'", "").replace('"', "").strip()
-                safe_value = f.value.replace("'", "").replace('"', "").strip()
-                
-                if safe_field and safe_value:
-                    if safe_value.isdigit():
-                        expr_list.append(f'{safe_field} == {safe_value}')
-                    else:
-                        expr_list.append(f'{safe_field} == "{safe_value}"')
+                # 特殊处理文档过滤 (IN 查询)
+                if f.field == 'doc_name_in':
+                    try:
+                        doc_names = json.loads(f.value)
+                        if isinstance(doc_names, list) and doc_names:
+                            # 构建 doc_name in ["A", "B"]
+                            # 注意:Schema 中 doc_name 字段名可能不统一,通常是 doc_name, file_name, title
+                            # 这里我们需要尝试匹配正确的字段名。
+                            # 假设我们在 create 时主要存的是 file_name 或 doc_name
+                            
+                            # 简单起见,我们尝试对常见字段做 OR,但这在 Milvus expr 中可能复杂
+                            # 更稳妥的是我们在存数据时统一了字段。
+                            # 假设统一用 "file_name" 或 "doc_name"
+                            
+                            # 获取 collection fields
+                            target_field = "file_name" # 默认
+                            if collection_detail and isinstance(collection_detail, dict):
+                                fields = [fl.get("name") for fl in collection_detail.get("fields", []) if isinstance(fl, dict)]
+                                if "doc_name" in fields:
+                                    target_field = "doc_name"
+                                elif "title" in fields:
+                                    target_field = "title"
+                            
+                            # 构建 IN 列表
+                            in_values = ",".join([f'"{name}"' for name in doc_names])
+                            expr_list.append(f'{target_field} in [{in_values}]')
+                    except Exception as e:
+                        print(f"Error parsing doc_name_in: {e}")
+                else:
+                    safe_field = f.field.replace("'", "").replace('"', "").strip()
+                    safe_value = f.value.replace("'", "").replace('"', "").strip()
+                    
+                    if safe_field and safe_value:
+                        if safe_value.isdigit():
+                            expr_list.append(f'{safe_field} == {safe_value}')
+                        else:
+                            expr_list.append(f'{safe_field} == "{safe_value}"')
         
         # 组合所有条件 (使用 AND)
         expr = " and ".join(expr_list) if expr_list else ""
@@ -253,7 +282,8 @@ class SearchEngineService:
                         hybrid_results = milvus_service.hybrid_search(
                             collection_name=kb_id,
                             query_text=payload.query,
-                            top_k=target_k
+                            top_k=target_k,
+                            expr=expr if expr else None
                         )
                         
                         # 手动切片实现分页

+ 77 - 32
src/app/services/snippet_service.py

@@ -15,6 +15,11 @@ from app.services.milvus_service import milvus_service
 from app.schemas.base import PaginationSchema, PaginatedResponseSchema
 from app.utils.vector_utils import text_to_vector_algo
 
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+from app.sample.models.metadata import SampleMetadata
+from app.sample.models.knowledge_base import KnowledgeBase
+
 class SnippetService:
     
     def get_list(
@@ -188,27 +193,45 @@ class SnippetService:
         
         return items, meta
 
-    def create(self, payload: Any) -> Dict:
+    async def create(self, db: AsyncSession, payload: Any) -> Dict:
         """创建知识片段"""
         # 使用统一算法生成向量
         dim = milvus_service.DENSE_DIM
         fake_vector = text_to_vector_algo(payload.content, dim=dim)
         
+        # 1. 动态构建 metadata
+        # 查找 KnowledgeBase ID
+        kb_query = select(KnowledgeBase).where(KnowledgeBase.collection_name == payload.collection_name)
+        result = await db.execute(kb_query)
+        kb = result.scalars().first()
+        
+        metadata = {}
+        if kb:
+            # 查找该知识库定义的元数据字段
+            meta_query = select(SampleMetadata).where(SampleMetadata.knowledge_base_id == kb.id)
+            meta_result = await db.execute(meta_query)
+            defined_fields = meta_result.scalars().all()
+            
+            # 仅填充定义的字段
+            custom_fields = getattr(payload, 'custom_fields', {}) or {}
+            for field in defined_fields:
+                field_name = field.field_en_name
+                # 从 custom_fields 中获取值,如果不存在则为 None 或空字符串 (视需求而定)
+                # 这里我们只存有值的
+                if field_name in custom_fields:
+                    metadata[field_name] = custom_fields[field_name]
+        
         # 基础数据
         now = int(time.time() * 1000)
         item = {
             "vector": fake_vector,
             "text": payload.content,
-            "document_id": str(uuid.uuid4()), # 生成UUID
+            "document_id": payload.custom_fields.get("document_id") if hasattr(payload, 'custom_fields') and payload.custom_fields and payload.custom_fields.get("document_id") else str(uuid.uuid4()),
             "parent_id": payload.custom_fields.get("parent_id", "") if hasattr(payload, 'custom_fields') and payload.custom_fields else "",
             "index": 0,
-            "tag_list": "",
+            "tag_list": payload.custom_fields.get("tag_list", "") if hasattr(payload, 'custom_fields') and payload.custom_fields else "",
             "permission": {},
-            "metadata": {
-                "doc_name": payload.doc_name,
-                "file_name": payload.doc_name, 
-                "title": payload.doc_name
-            },
+            "metadata": metadata, # 使用动态构建的 metadata
             "is_deleted": False,
             "created_by": "system",
             "created_time": now,
@@ -216,10 +239,6 @@ class SnippetService:
             "updated_time": now
         }
         
-        # 合并自定义字段 (Schema已固定)
-        # if hasattr(payload, 'custom_fields') and payload.custom_fields:
-        #     item.update(payload.custom_fields)
-            
         data = [item]
         
         res = milvus_service.client.insert(
@@ -230,12 +249,12 @@ class SnippetService:
         milvus_service.client.flush(payload.collection_name)
         return {"count": res.get("insert_count", 1)}
 
-    def update(self, id: str, payload: Any) -> str:
+    async def update(self, db: AsyncSession, id: str, payload: Any) -> str:
         """更新知识片段"""
-        kb = payload.collection_name
+        kb_name = payload.collection_name
         
         # 1. 删除旧数据
-        desc = milvus_service.client.describe_collection(kb)
+        desc = milvus_service.client.describe_collection(kb_name)
         fields = [f['name'] for f in desc.get('fields', [])]
         pk_field = "pk" if "pk" in fields else "id"
         
@@ -244,42 +263,54 @@ class SnippetService:
         else:
             expr = f"{pk_field} in ['{id}']"
         
-        milvus_service.client.delete(collection_name=kb, filter=expr)
+        milvus_service.client.delete(collection_name=kb_name, filter=expr)
         
         # 2. 插入新数据
         # 使用统一算法生成向量
         dim = milvus_service.DENSE_DIM
         fake_vector = text_to_vector_algo(payload.content, dim=dim)
         
+        # 动态构建 metadata
+        # 查找 KnowledgeBase ID
+        kb_query = select(KnowledgeBase).where(KnowledgeBase.collection_name == kb_name)
+        result = await db.execute(kb_query)
+        kb_obj = result.scalars().first()
+        
+        metadata = {}
+        if kb_obj:
+            # 查找该知识库定义的元数据字段
+            meta_query = select(SampleMetadata).where(SampleMetadata.knowledge_base_id == kb_obj.id)
+            meta_result = await db.execute(meta_query)
+            defined_fields = meta_result.scalars().all()
+            
+            # 仅填充定义的字段
+            custom_fields = getattr(payload, 'custom_fields', {}) or {}
+            for field in defined_fields:
+                field_name = field.field_en_name
+                if field_name in custom_fields:
+                    metadata[field_name] = custom_fields[field_name]
+        
         now = int(time.time() * 1000)
         item = {
             "vector": fake_vector,
             "text": payload.content,
-            "document_id": str(uuid.uuid4()), # 更新也会生成新文档ID
+            "document_id": payload.custom_fields.get("document_id") if hasattr(payload, 'custom_fields') and payload.custom_fields and payload.custom_fields.get("document_id") else str(uuid.uuid4()),
             "parent_id": payload.custom_fields.get("parent_id", "") if hasattr(payload, 'custom_fields') and payload.custom_fields else "",
             "index": 0,
-            "tag_list": "",
+            "tag_list": payload.custom_fields.get("tag_list", "") if hasattr(payload, 'custom_fields') and payload.custom_fields else "",
             "permission": {},
-            "metadata": {
-                "doc_name": payload.doc_name or "已更新",
-                "file_name": payload.doc_name,
-                "title": payload.doc_name
-            },
+            "metadata": metadata, # 使用动态构建的 metadata
             "is_deleted": False,
             "created_by": "system",
             "created_time": now,
             "updated_by": "system",
             "updated_time": now
         }
-        
-        # 合并自定义字段 (Schema已固定)
-        # if hasattr(payload, 'custom_fields') and payload.custom_fields:
-        #     item.update(payload.custom_fields)
             
         data = [item]
         
-        milvus_service.client.insert(collection_name=kb, data=data)
-        milvus_service.client.flush(kb)
+        milvus_service.client.insert(collection_name=kb_name, data=data)
+        milvus_service.client.flush(kb_name)
         
         return "更新成功 (ID已变更)"
 
@@ -307,13 +338,23 @@ class SnippetService:
         id_val = r.get("pk") or r.get("id")
         content = r.get("text") or r.get("content") or ""
         
+        # 处理 metadata (Milvus 可能返回 JSON 字符串)
+        meta = r.get("metadata") or {}
+        if isinstance(meta, str):
+            try:
+                import json
+                meta = json.loads(meta)
+            except:
+                pass
+        
         # 尝试从 metadata 中获取 doc_name
         doc_name = "未知文档"
-        meta = r.get("metadata") or {}
+        # 优先从 metadata 字典取
         if isinstance(meta, dict):
              doc_name = meta.get("doc_name") or meta.get("file_name") or meta.get("title") or doc_name
-        else:
-             # 兼容旧数据
+        
+        # 如果 metadata 里没有,或者不是 dict,尝试从一级字段取 (兼容旧数据)
+        if doc_name == "未知文档":
              doc_name = r.get("file_name") or r.get("title") or r.get("source") or r.get("doc_name") or doc_name
 
         meta_info = f"ParentID: {r.get('parent_id', '-')}"
@@ -336,6 +377,10 @@ class SnippetService:
             "content": content,
             "char_count": len(content) if content else 0,
             "meta_info": meta_info,
+            "metadata": meta, # 透传完整元数据
+            "document_id": r.get("document_id", ""),
+            "parent_id": r.get("parent_id", ""),
+            "tag_list": r.get("tag_list", ""), # 返回 tag_list
             "status": "normal",
             "created_at": created_at,
             "updated_at": "-"

+ 15 - 5
src/views/snippet_view.py

@@ -6,8 +6,11 @@ from fastapi.responses import StreamingResponse
 from typing import Optional, Dict, Any
 from datetime import datetime
 import urllib.parse
+from sqlalchemy.ext.asyncio import AsyncSession
 
+from app.base.async_mysql_connection import get_db
 from app.services.snippet_service import snippet_service
+from app.services.knowledge_base_service import knowledge_base_service
 from app.schemas.base import ResponseSchema, PaginatedResponseSchema
 from app.services.jwt_token import verify_token
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
@@ -75,35 +78,38 @@ async def export_snippets(
 @router.post("", response_model=ResponseSchema)
 async def create_snippet(
     payload: SnippetCreate,
-    credentials: HTTPAuthorizationCredentials = Depends(security)
+    credentials: HTTPAuthorizationCredentials = Depends(security),
+    db: AsyncSession = Depends(get_db)
 ):
     """创建知识片段"""
     payload_token = verify_token(credentials.credentials)
     if not payload_token:
         return ResponseSchema(code=401, message="无效的访问令牌")
 
-    data = snippet_service.create(payload)
+    data = await snippet_service.create(db, payload)
     return ResponseSchema(code=0, message="创建成功", data=data)
 
 @router.post("/{id}", response_model=ResponseSchema)
 async def update_snippet(
     id: str,
     payload: SnippetUpdate,
-    credentials: HTTPAuthorizationCredentials = Depends(security)
+    credentials: HTTPAuthorizationCredentials = Depends(security),
+    db: AsyncSession = Depends(get_db)
 ):
     """更新知识片段"""
     payload_token = verify_token(credentials.credentials)
     if not payload_token:
         return ResponseSchema(code=401, message="无效的访问令牌")
 
-    msg = snippet_service.update(id, payload)
+    msg = await snippet_service.update(db, id, payload)
     return ResponseSchema(code=0, message=msg)
 
 @router.post("/{id}/delete", response_model=ResponseSchema)
 async def delete_snippet(
     id: str, 
     kb: str = Query(..., description="知识库名称"), 
-    credentials: HTTPAuthorizationCredentials = Depends(security)
+    credentials: HTTPAuthorizationCredentials = Depends(security),
+    db: AsyncSession = Depends(get_db)
 ):
     """删除知识片段"""
     payload_token = verify_token(credentials.credentials)
@@ -111,4 +117,8 @@ async def delete_snippet(
         return ResponseSchema(code=401, message="无效的访问令牌")
         
     snippet_service.delete(id, kb)
+    
+    # 更新知识库文档数量
+    await knowledge_base_service.update_doc_count(db, kb)
+    
     return ResponseSchema(code=0, message="删除成功")

+ 13 - 3
src/views/tag_view.py

@@ -151,6 +151,16 @@ async def get_category_tree(
 
 def _build_tree_response(category: TagCategory) -> dict:
     """将 TagCategory 转换为完整的树响应结构"""
+    # 确保 category 对象属性存在,避免延迟加载或属性缺失导致异常
+    created_at = getattr(category, 'created_at', None)
+    # 如果是 datetime 对象,可能需要转换,这里假设 ResponseSchema 或 json 序列化会自动处理
+    # 如果是 SQLAlchemy 模型,有时属性访问会触发加载,在 async 下需注意
+    
+    # 递归构建 children
+    children_data = None
+    if hasattr(category, 'children') and category.children:
+        children_data = [_build_tree_response(child) for child in category.children]
+        
     return {
         'id': category.id,
         'parent_id': category.parent_id,
@@ -164,11 +174,11 @@ def _build_tree_response(category: TagCategory) -> dict:
         'is_deleted': category.is_deleted,
         'created_by': category.created_by,
         'created_by_name': getattr(category, 'created_by_name', None),
-        'created_at': category.created_at,
+        'created_at': created_at,
         'updated_by': category.updated_by,
         'updated_by_name': getattr(category, 'updated_by_name', None),
-        'updated_at': category.updated_at,
-        'children': [_build_tree_response(child) for child in (getattr(category, 'children', None) or [])] if hasattr(category, 'children') and category.children else None
+        'updated_at': getattr(category, 'updated_at', None),
+        'children': children_data
     }