linyang 3 týždňov pred
rodič
commit
cd954e2b00

+ 1 - 1
src/app/sample/models/knowledge_base.py

@@ -12,7 +12,7 @@ class KnowledgeBase(BaseModel):
     __tablename__ = "t_samp_knowledge_base"
 
     name = Column(String(100), nullable=False, comment="知识库名称")
-    collection_name_parent = Column(String(100), nullable=False, unique=True, comment="Milvus集合名称(Table Name)(父)")
+    collection_name_parent = Column(String(100), nullable=True, unique=True, comment="Milvus集合名称(Table Name)(父)")
     collection_name_children = Column(String(100), nullable=True, comment="Milvus集合名称(Table Name)(子)")
     description = Column(String(500), nullable=True, comment="描述")
     # 默认禁用,只有同步成功后才置为 normal

+ 3 - 3
src/app/sample/schemas/knowledge_base.py

@@ -11,7 +11,7 @@ class MetadataField(BaseModel):
 
 class KnowledgeBaseBase(BaseModel):
     name: str = Field(..., description="知识库名称")
-    collection_name_parent: str = Field(..., description="Milvus集合名称(父)")
+    collection_name_parent: Optional[str] = Field(None, description="Milvus集合名称(父)")
     collection_name_children: str = Field(..., description="Milvus集合名称(子)")
     description: Optional[str] = Field(None, description="描述")
     # 默认状态改为禁用,只有同步成功后才置为 normal
@@ -27,7 +27,7 @@ class CustomSchemaField(BaseModel):
 
 class KnowledgeBaseCreate(KnowledgeBaseBase):
     """创建知识库请求参数"""
-    dimension: int = Field(768, description="向量维度,默认768")
+    dimension: int = Field(4096, description="向量维度,默认4096")
     metadata_fields: Optional[List[MetadataField]] = Field(None, description="元数据字段列表")
     custom_schemas: Optional[List[CustomSchemaField]] = Field(None, description="自定义Schema字段列表")
 
@@ -48,7 +48,7 @@ class KnowledgeBaseResponse(BaseModelSchema):
     """知识库响应模型"""
     id: str
     name: str
-    collection_name_parent: str
+    collection_name_parent: Optional[str] = None
     collection_name_children: Optional[str]
     description: Optional[str]
     status: str

+ 1 - 0
src/app/sample/schemas/search_engine.py

@@ -69,6 +69,7 @@ class KBSearchResultItem(BaseModel):
     content: str
     meta_info: str
     document_id: Optional[str] = None
+    parent_id: Optional[str] = None
     metadata: Optional[Dict[str, Any]] = None
     score: float
     

+ 65 - 31
src/app/services/knowledge_base_service.py

@@ -4,7 +4,7 @@
 from math import ceil
 from typing import List, Optional, Tuple, Dict, Any
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy import select, func, or_, delete as sql_delete
+from sqlalchemy import select, func, or_, delete as sql_delete, update as sql_update
 from datetime import datetime
 import uuid
 import asyncio
@@ -12,6 +12,7 @@ import asyncio
 from app.sample.models.knowledge_base import KnowledgeBase
 from app.sample.models.metadata import SampleMetadata
 from app.sample.models.custom_schema import CustomSchema
+from app.sample.models.base_info import DocumentMain
 from app.sample.schemas.knowledge_base import (
     KnowledgeBaseCreate, 
     KnowledgeBaseUpdate,
@@ -219,6 +220,39 @@ class KnowledgeBaseService:
     ) -> Tuple[List[KnowledgeBase], PaginationSchema]:
         """获取知识库列表"""
         
+        # --- 同步 Milvus 数据 (简化版:仅更新现有KB的计数和状态) ---
+        try:
+            # 1. 获取 Milvus 所有集合
+            milvus_names = milvus_service.client.list_collections()
+            
+            # 2. 获取 DB 中已有的集合
+            result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.is_deleted == 0))
+            existing_kbs = result.scalars().all()
+            
+            # 3. 更新现有KB的统计
+            has_changes = False
+            for kb in existing_kbs:
+                total_count = 0
+                
+                # 统计 collection_name_parent
+                if kb.collection_name_parent and kb.collection_name_parent in milvus_names:
+                    total_count += await self._get_collection_row_count(kb.collection_name_parent)
+                    
+                # 统计 collection_name_children
+                if kb.collection_name_children and kb.collection_name_children in milvus_names:
+                    total_count += await self._get_collection_row_count(kb.collection_name_children)
+                    
+                if kb.document_count != total_count:
+                    kb.document_count = total_count
+                    has_changes = True
+
+            if has_changes:
+                await db.commit()
+
+        except Exception as e:
+            print(f"Sync Milvus collections failed: {e}")
+        # ----------------------
+
         # 查询未删除的 KB
         query = select(KnowledgeBase).where(KnowledgeBase.is_deleted == 0)
         
@@ -245,19 +279,15 @@ class KnowledgeBaseService:
         items = result.scalars().all()
 
         # 设置 is_synced (辅助字段,不存库)
-        try:
-            milvus_names_set = set(milvus_service.client.list_collections())
-            for item in items:
-                c1_ok = item.collection_name_parent in milvus_names_set
-                c2_ok = True
-                if item.collection_name_children:
-                    c2_ok = item.collection_name_children in milvus_names_set
-                
-                item.is_synced = c1_ok and c2_ok
-        except Exception as e:
-            print(f"Check Milvus sync status failed: {e}")
-            for item in items:
-                item.is_synced = False
+        milvus_names_set = set(milvus_service.client.list_collections())
+        for item in items:
+            # 父集合可能为空(未勾选创建父集合),此时只要子集合存在也算已同步
+            c1_ok = True if not item.collection_name_parent else item.collection_name_parent in milvus_names_set
+            c2_ok = True
+            if item.collection_name_children:
+                c2_ok = item.collection_name_children in milvus_names_set
+            
+            item.is_synced = c1_ok and c2_ok
         
         meta = PaginationSchema(
             page=page,
@@ -270,31 +300,36 @@ class KnowledgeBaseService:
 
     async def create(self, db: AsyncSession, payload: KnowledgeBaseCreate) -> KnowledgeBase:
         """创建新知识库"""
+        parent_name = (payload.collection_name_parent or "").strip() or None
+        child_name = (payload.collection_name_children or "").strip()
+
+        if not child_name:
+            raise ValueError("请输入子集合名称")
+
         # 1. 检查 DB 是否已存在
         # 检查父子集合名称不能相同
-        if payload.collection_name_children and payload.collection_name_parent == payload.collection_name_children:
+        if parent_name and parent_name == child_name:
             raise ValueError("父集合名称和子集合名称不能相同")
 
         # 检查 collection_name_parent (可选)
-        if payload.collection_name_parent:
+        if parent_name:
             exists1 = await db.execute(select(KnowledgeBase).where(
-                KnowledgeBase.collection_name_parent == payload.collection_name_parent,
+                KnowledgeBase.collection_name_parent == parent_name,
                 KnowledgeBase.is_deleted == 0
             ))
             if exists1.scalars().first():
-                raise ValueError(f"集合名称 {payload.collection_name_parent} 已存在")
+                raise ValueError(f"集合名称 {parent_name} 已存在")
             
         # 检查 collection_name_children
-        if payload.collection_name_children:
-            exists2 = await db.execute(select(KnowledgeBase).where(
-                or_(
-                    KnowledgeBase.collection_name_parent == payload.collection_name_children,
-                    KnowledgeBase.collection_name_children == payload.collection_name_children
-                ),
-                KnowledgeBase.is_deleted == 0
-            ))
-            if exists2.scalars().first():
-                raise ValueError(f"集合名称 {payload.collection_name_children} 已存在")
+        exists2 = await db.execute(select(KnowledgeBase).where(
+            or_(
+                KnowledgeBase.collection_name_parent == child_name,
+                KnowledgeBase.collection_name_children == child_name
+            ),
+            KnowledgeBase.is_deleted == 0
+        ))
+        if exists2.scalars().first():
+            raise ValueError(f"集合名称 {child_name} 已存在")
 
         try:
             # 3. 创建 DB 记录
@@ -302,8 +337,8 @@ class KnowledgeBaseService:
             new_kb = KnowledgeBase(
                 id=str(uuid.uuid4()),
                 name=payload.name,
-                collection_name_parent=payload.collection_name_parent,
-                collection_name_children=payload.collection_name_children,
+                collection_name_parent=parent_name,
+                collection_name_children=child_name,
                 description=payload.description,
                 # 默认创建为禁用状态,待同步成功后再启用
                 status="disabled",
@@ -493,7 +528,6 @@ class KnowledgeBaseService:
             {"name": "dense", "type": "FLOAT_VECTOR", "description": "向量列"},
             {"name": "sparse", "type": "BM25", "description": "内容的BM25关键字检索"},
             {"name": "document_id", "type": "VARCHAR", "max_length": 128, "description": "样本中心上传文档ID"},
-            {"name": "kb_id", "type": "VARCHAR", "max_length": 128, "description": "知识库ID"},
             {"name": "parent_id", "type": "VARCHAR", "max_length": 128, "description": "父段ID"},
             {"name": "index", "type": "INT64", "description": "索引序号"},
             {"name": "tag_list", "type": "VARCHAR", "max_length": 2048, "description": "标签"},

+ 42 - 31
src/app/services/milvus_service.py

@@ -39,23 +39,23 @@ class MilvusService:
             self.ensure_collection_exists(name)
 
     async def insert_knowledge(self, content: str, doc_info: Dict[str, Any]):
-        """将 Markdown 内容切分并入库 (支持通过路由到明确的父子集合)"""
+        """将 Markdown 内容切分并入库 (支持父子段分表)"""
         try:
             doc_id = doc_info.get("doc_id")
             doc_name = doc_info.get("doc_name")
-            kb_method = doc_info.get("kb_method")
+            doc_version = doc_info.get("doc_version", int(time.time()))
+            tags = doc_info.get("tags", "")
+            user_id = doc_info.get("user_id", "system")
             
-            # 获取明确的集合名称 (由业务层从数据库查出)
-            parent_col = doc_info.get("collection_name_parent") or PARENT_COLLECTION_NAME
-            child_col = doc_info.get("collection_name_children") or CHILD_COLLECTION_NAME
+            kb_method = doc_info.get("kb_method")
+            target_collection = doc_info.get("collection_name") or PARENT_COLLECTION_NAME
             
             from langchain_text_splitters import RecursiveCharacterTextSplitter
 
             if kb_method == "parent_child":
-                # --- 方案 A: 父子段分表入库 (双写模式) ---
-                # 确保两个集合都存在
-                self.ensure_collection_exists(parent_col)
-                self.ensure_collection_exists(child_col)
+                # --- 方案 A: 父子段分表入库 ---
+                parent_col = f"{target_collection}_parent"
+                child_col = f"{target_collection}_child"
                 
                 # 1. 切分父段 (较大块)
                 parent_splitter = RecursiveCharacterTextSplitter(
@@ -86,6 +86,10 @@ class MilvusService:
                         # 子段的 parent_id 指向父段的 p_id
                         c_metadata = self._prepare_metadata(doc_info, p_id, c_idx, p_id)
                         child_docs.append(Document(page_content=c_content, metadata=c_metadata))
+
+                # 确保两个集合都存在
+                self.ensure_collection_exists(parent_col)
+                self.ensure_collection_exists(child_col)
                 
                 # 分别入库
                 if parent_docs:
@@ -96,7 +100,7 @@ class MilvusService:
                 logger.info(f"Successfully inserted parent-child chunks for {doc_name}: {len(parent_docs)} parents -> {len(child_docs)} children")
             
             else:
-                # --- 方案 B: 常规单表入库 (只进子表,parent_id 设为空) ---
+                # --- 常规单表入库逻辑 ---
                 chunks = []
                 if kb_method == "length":
                     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
@@ -109,7 +113,6 @@ class MilvusService:
                     )
                     chunks = splitter.split_text(content)
                 else:
-                    # 默认按双换行切分
                     chunks = [p.strip() for p in re.split(r"\n\s*\n+", content) if p.strip()]
                 
                 if not chunks:
@@ -119,15 +122,13 @@ class MilvusService:
                 documents = []
                 for idx, chunk in enumerate(chunks):
                     p_id = hashlib.sha1(f"{doc_id}_{idx}".encode()).hexdigest()
-                    # 对于单表模式,parent_id 设为空字符串
-                    metadata = self._prepare_metadata(doc_info, p_id, idx, "")
+                    metadata = self._prepare_metadata(doc_info, p_id, idx, p_id)
                     documents.append(Document(page_content=chunk, metadata=metadata))
 
-                # 确保子表集合存在
-                self.ensure_collection_exists(child_col)
-                get_milvus_vectorstore(child_col).add_documents(documents)
+                self.ensure_collection_exists(target_collection)
+                get_milvus_vectorstore(target_collection).add_documents(documents)
                 
-                logger.info(f"Successfully inserted {len(documents)} chunks for {doc_name} into {child_col} (kb_method: {kb_method})")
+                logger.info(f"Successfully inserted {len(documents)} chunks for {doc_name} into {target_collection}")
 
         except Exception as e:
             logger.error(f"Error inserting knowledge into Milvus: {e}")
@@ -140,11 +141,9 @@ class MilvusService:
         doc_version = doc_info.get("doc_version", int(time.time()))
         tags = doc_info.get("tags", "")
         user_id = doc_info.get("user_id", "system")
-        kb_id = doc_info.get("kb_id", "")
         
         return {
             "document_id": doc_id,
-            "kb_id": kb_id,
             "parent_id": parent_ref_id,
             "index": index,
             "tag_list": tags,
@@ -174,7 +173,6 @@ class MilvusService:
             schema.add_field("dense", DataType.FLOAT_VECTOR, dim=self.DENSE_DIM)
             schema.add_field("sparse", DataType.SPARSE_FLOAT_VECTOR)
             schema.add_field("document_id", DataType.VARCHAR, max_length=256)
-            schema.add_field("kb_id", DataType.VARCHAR, max_length=256)
             schema.add_field("parent_id", DataType.VARCHAR, max_length=256)
             schema.add_field("index", DataType.INT64)
             schema.add_field("tag_list", DataType.VARCHAR, max_length=2048)
@@ -227,9 +225,22 @@ class MilvusService:
             )
             needs_index = True
 
-        # [Optimized] 移除对 JSON 字段的冗余索引创建逻辑,避免在 Milvus 2.4+ 环境下因缺少参数报错
-        # 同时确保 core index (dense/sparse) 命名与 create_collection 保持一致
+        if "permission" in fields_in_collection and "permission" not in existing_indexes:
+            index_params.add_index(
+                field_name="permission",
+                index_type="INVERTED",
+                params={"json_cast_type": "VARCHAR"}
+            )
+            needs_index = True
         
+        if "metadata" in fields_in_collection and "metadata" not in existing_indexes:
+            index_params.add_index(
+                field_name="metadata",
+                index_type="INVERTED",
+                params={"json_cast_type": "VARCHAR"}
+            )
+            needs_index = True
+
         if needs_index:
             logger.info(f"Creating missing indexes for collection: {name}")
             try:
@@ -270,7 +281,6 @@ class MilvusService:
                 # 如果没有定义主键,添加默认主键
                 schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
             
-            # 检查是否有默认向量列,如果没有则添加 (兼容旧逻辑)
             # 检查是否有默认向量列,如果没有则添加 (兼容旧逻辑,但如果fields里有dense则不添加)
             has_vector = any(f.get("type") == "FLOAT_VECTOR" for f in fields)
             if not has_vector:
@@ -349,18 +359,15 @@ class MilvusService:
             # 5. 为所有向量字段添加索引
             for f in fields:
                 ftype = f.get("type", "").upper()
-                fname = f.get("name")
                 if ftype == "FLOAT_VECTOR":
                     index_params.add_index(
-                        field_name=fname, 
-                        index_name=f"{fname}_idx", # 显式命名,如 dense_idx,确保与 ensure_collection_exists 一致
+                        field_name=f.get("name"), 
                         index_type="AUTOINDEX",
                         metric_type="IP" # [Modified] 更改为 IP (内积),通常对规范化向量效果更好,与 COSINE 类似但更简单
                     )
                 elif ftype == "BM25" or ftype == "SPARSE_FLOAT_VECTOR":
                     index_params.add_index(
-                        field_name=fname,
-                        index_name="bm25_idx", # 显式命名,确保与 ensure_collection_exists 一致
+                        field_name=f.get("name"),
                         index_type="SPARSE_INVERTED_INDEX", # 稀疏向量索引
                         metric_type="BM25"
                     )
@@ -375,9 +382,13 @@ class MilvusService:
                         index_type="INVERTED"
                     )
                 elif ftype == "JSON":
-                    # JSON 字段索引在某些环境下存在兼容性问题(如缺少 json_cast_type 报错),
-                    # 考虑到目前主要通过表达式过滤 JSON,且非核心性能瓶颈,暂时不自动创建 JSON 索引。
-                    pass
+                    # Milvus 2.4+ JSON 索引必须指定 json_cast_type
+                    # 这里为 JSON 字段添加默认索引,以便支持查询
+                    index_params.add_index(
+                        field_name=f.get("name"),
+                        index_type="INVERTED",
+                        params={"json_cast_type": "VARCHAR"}
+                    )
 
             # 7. 创建集合
             self.client.create_collection(

+ 47 - 101
src/app/services/sample_service.py

@@ -111,7 +111,7 @@ class SampleService:
         Args:
             doc_ids: 文档ID列表
             username: 操作人
-            kb_id: 知识库ID (可选,若不传则根据 source_type 自动匹配)
+            kb_id: 知识库ID
             kb_method: 切分方法
         """
         conn = get_db_connection()
@@ -129,7 +129,7 @@ class SampleService:
             # 1. 获取所有选中选中的文档详情
             placeholders = ','.join(['%s']*len(doc_ids))
             fetch_sql = f"""
-                SELECT id, title, source_type, md_url, conversion_status, whether_to_enter, created_time, kb_id 
+                SELECT id, title, source_type, md_url, conversion_status, whether_to_enter, created_time 
                 FROM t_samp_document_main 
                 WHERE id IN ({placeholders})
             """
@@ -146,7 +146,6 @@ class SampleService:
                 status = doc.get('conversion_status')
                 whether_to_enter = doc.get('whether_to_enter', 0)
                 md_url = doc.get('md_url')
-                source_type = doc.get('source_type')
                 
                 # A. 检查是否已入库
                 if whether_to_enter == 1:
@@ -169,38 +168,7 @@ class SampleService:
                     error_details.append(f"· {title}: 转换结果地址丢失")
                     continue
                 
-                # C. 确定入库策略 (严格使用弹窗传入的参数)
-                # 不从数据库读取旧的 kb_method,保证入库逻辑由本次操作决定
-                current_kb_id = kb_id or doc.get('kb_id')
-                current_kb_method = kb_method  # 直接使用前端传来的切分方式
-
-                if not current_kb_id:
-                    logger.warning(f"文档 {title}({doc_id}) 未指定知识库,跳过入库")
-                    failed_count += 1
-                    error_details.append(f"· {title}: 未指定目标知识库")
-                    continue
-
-                if not current_kb_method:
-                    logger.warning(f"文档 {title}({doc_id}) 未指定切分方式,跳过入库")
-                    failed_count += 1
-                    error_details.append(f"· {title}: 未指定切分策略")
-                    continue
-
-                # 获取知识库信息 (collection_name_parent, collection_name_children)
-                kb_sql = "SELECT collection_name_parent, collection_name_children FROM t_samp_knowledge_base WHERE id = %s AND is_deleted = 0"
-                cursor.execute(kb_sql, (current_kb_id,))
-                kb_res = cursor.fetchone()
-                
-                if not kb_res:
-                    logger.warning(f"找不到指定的知识库: id={current_kb_id}")
-                    failed_count += 1
-                    error_details.append(f"· {title}: 指定的知识库不存在或已被删除")
-                    continue
-                
-                collection_name_parent = kb_res['collection_name_parent']
-                collection_name_children = kb_res['collection_name_children']
-                
-                # D. 从 MinIO 获取 Markdown 内容
+                # B. 从 MinIO 获取 Markdown 内容
                 try:
                     md_content = self.minio_manager.get_object_content(md_url)
                     if not md_content:
@@ -211,32 +179,39 @@ class SampleService:
                     error_details.append(f"· {title}: 读取云端文件失败")
                     continue
                 
-                # E. 调用 MilvusService 进行切分和入库
+                # C. 调用 MilvusService 进行切分和入库
                 try:
+                    # 如果有 kb_id,需要根据它获取 collection_name
+                    collection_name = None
+                    if kb_id:
+                        kb_sql = "SELECT collection_name FROM t_samp_knowledge_base WHERE id = %s"
+                        cursor.execute(kb_sql, (kb_id,))
+                        kb_res = cursor.fetchone()
+                        if kb_res:
+                            collection_name = kb_res['collection_name']
+                    
                     # 准备元数据
-                    current_date = int(datetime.now().strftime('%Y%m%d'))
                     doc_info = {
                         "doc_id": doc_id,
                         "doc_name": title,
-                        "doc_version": int(doc['created_time'].strftime('%Y%m%d')) if doc.get('created_time') else current_date,
-                        "tags": source_type or 'unknown',
+                        "doc_version": int(doc['created_time'].strftime('%Y%m%d')) if doc.get('created_time') else 20260127,
+                        "tags": doc.get('source_type') or 'unknown',
                         "user_id": username,  # 传递操作人作为 created_by
-                        "kb_id": current_kb_id,
-                        "kb_method": current_kb_method,
-                        "collection_name_parent": collection_name_parent,
-                        "collection_name_children": collection_name_children
+                        "kb_id": kb_id,
+                        "kb_method": kb_method,
+                        "collection_name": collection_name
                     }
                     await self.milvus_service.insert_knowledge(md_content, doc_info)
                     
-                    # F. 添加到任务管理中心 (类型为 data)
+                    # D. 添加到任务管理中心 (类型为 data)
                     try:
                         await task_service.add_task(doc_id, 'data')
                     except Exception as task_err:
                         logger.error(f"添加文档 {title} 到任务中心失败: {task_err}")
 
-                    # G. 更新数据库状态
+                    # E. 更新数据库状态
                     update_sql = "UPDATE t_samp_document_main SET whether_to_enter = 1, kb_id = %s, kb_method = %s, updated_by = %s, updated_time = NOW() WHERE id = %s"
-                    cursor.execute(update_sql, (current_kb_id, current_kb_method, username, doc_id))
+                    cursor.execute(update_sql, (kb_id, kb_method, username, doc_id))
                     success_count += 1
                     
                 except Exception as milvus_err:
@@ -410,9 +385,8 @@ class SampleService:
                     LEFT JOIN {sub_table} s ON m.id = s.id
                     LEFT JOIN t_sys_user u1 ON m.created_by = u1.id
                     LEFT JOIN t_sys_user u2 ON m.updated_by = u2.id
-                    LEFT JOIN t_samp_knowledge_base kb ON m.kb_id = kb.id
                 """
-                fields_sql = "m.*, s.*, u1.username as creator_name, u2.username as updater_name, kb.name as kb_name, m.id as id"
+                fields_sql = "m.*, s.*, u1.username as creator_name, u2.username as updater_name, m.id as id"
                 where_clauses.append("m.source_type = %s")
                 params.append(table_type)
                 order_sql = "m.created_time DESC"
@@ -433,8 +407,8 @@ class SampleService:
                         where_clauses.append("s.level_4_classification = %s")
                         params.append(level_4_classification)
             else:
-                from_sql = "t_samp_document_main m LEFT JOIN t_sys_user u1 ON m.created_by = u1.id LEFT JOIN t_sys_user u2 ON m.updated_by = u2.id LEFT JOIN t_samp_knowledge_base kb ON m.kb_id = kb.id"
-                fields_sql = "m.*, u1.username as creator_name, u2.username as updater_name, kb.name as kb_name"
+                from_sql = "t_samp_document_main m LEFT JOIN t_sys_user u1 ON m.created_by = u1.id LEFT JOIN t_sys_user u2 ON m.updated_by = u2.id"
+                fields_sql = "m.*, u1.username as creator_name, u2.username as updater_name"
                 order_sql = "m.created_time DESC"
                 title_field = "m.title"
             
@@ -457,6 +431,7 @@ class SampleService:
             sql = f"SELECT {fields_sql} FROM {from_sql} {where_sql} ORDER BY {order_sql} LIMIT %s OFFSET %s"
             params.extend([size, offset])
             
+            logger.info(f"Executing SQL: {sql} with params: {params}")
             cursor.execute(sql, tuple(params))
             items = [self._format_document_row(row) for row in cursor.fetchall()]
             
@@ -571,13 +546,12 @@ class SampleService:
                 INSERT INTO t_samp_document_main (
                     id, title, source_type, file_url, 
                     file_extension, created_by, updated_by, created_time, updated_time,
-                    conversion_status, whether_to_task, kb_id
-                ) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW(), 0, 0, %s)
+                    conversion_status, whether_to_task
+                ) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW(), 0, 0)
                 """,
                 (
                     doc_id, doc_data.get('title'), table_type, file_url,
-                    doc_data.get('file_extension'), user_id, user_id,
-                    doc_data.get('kb_id')
+                    doc_data.get('file_extension'), user_id, user_id
                 )
             )
 
@@ -674,14 +648,14 @@ class SampleService:
             # 1. 更新主表
             cursor.execute(
                 """
-                UPDATE t_samp_document_main SET 
-                    title = %s, file_url = %s, file_extension = %s, 
-                    updated_by = %s, updated_time = NOW(), kb_id = %s
+                UPDATE t_samp_document_main 
+                SET title = %s, file_url = %s, file_extension = %s,
+                    updated_by = %s, updated_time = NOW()
                 WHERE id = %s
                 """,
                 (
                     doc_data.get('title'), file_url, doc_data.get('file_extension'),
-                    updater_id, doc_data.get('kb_id'), doc_id
+                    updater_id, doc_id
                 )
             )
 
@@ -780,7 +754,7 @@ class SampleService:
                     s.participating_units, s.reference_basis,
                     s.created_by, u1.username as creator_name, s.created_time,
                     s.updated_by, u2.username as updater_name, s.updated_time,
-                    m.file_url, m.conversion_status, m.md_url, m.json_url, m.kb_id, m.whether_to_enter
+                    m.file_url, m.conversion_status, m.md_url, m.json_url
                 """
                 field_map = {
                     'title': 's.chinese_name',
@@ -804,7 +778,7 @@ class SampleService:
                     s.note, 
                     s.created_by, u1.username as creator_name, s.created_time,
                     s.updated_by, u2.username as updater_name, s.updated_time,
-                    m.file_url, m.conversion_status, m.md_url, m.json_url, m.kb_id, m.whether_to_enter
+                    m.file_url, m.conversion_status, m.md_url, m.json_url
                 """
                 field_map = {
                     'title': 's.plan_name',
@@ -825,7 +799,7 @@ class SampleService:
                     s.note, 
                     s.created_by, u1.username as creator_name, s.created_time,
                     s.updated_by, u2.username as updater_name, s.updated_time,
-                    m.file_url, m.conversion_status, m.md_url, m.json_url, m.kb_id, m.whether_to_enter
+                    m.file_url, m.conversion_status, m.md_url, m.json_url
                 """
                 field_map = {
                     'title': 's.file_name',
@@ -886,12 +860,11 @@ class SampleService:
             
             # 使用 LEFT JOIN 关联主表和用户表获取姓名
             sql = f"""
-                SELECT {fields}, kb.name as kb_name
+                SELECT {fields} 
                 FROM {table_name} s
                 LEFT JOIN t_samp_document_main m ON s.id = m.id
                 LEFT JOIN t_sys_user u1 ON s.created_by = u1.id
                 LEFT JOIN t_sys_user u2 ON s.updated_by = u2.id
-                LEFT JOIN t_samp_knowledge_base kb ON m.kb_id = kb.id
                 {where_sql} 
                 ORDER BY s.created_time DESC 
                 LIMIT %s OFFSET %s
@@ -1035,12 +1008,12 @@ class SampleService:
                 INSERT INTO t_samp_document_main (
                     id, title, source_type, file_url, 
                     file_extension, created_by, updated_by, created_time, updated_time,
-                    conversion_status, whether_to_task, kb_id
-                ) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW(), 0, 0, %s)
+                    conversion_status, whether_to_task
+                ) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW(), 0, 0)
                 """,
                 (
                     doc_id, data.get('title'), type, file_url,
-                    file_extension, user_id, user_id, data.get('kb_id')
+                    file_extension, user_id, user_id
                 )
             )
             
@@ -1149,10 +1122,10 @@ class SampleService:
             cursor.execute(
                 """
                 UPDATE t_samp_document_main 
-                SET title = %s, file_url = %s, file_extension = %s, updated_by = %s, updated_time = NOW(), kb_id = %s
+                SET title = %s, file_url = %s, file_extension = %s, updated_by = %s, updated_time = NOW()
                 WHERE id = %s
                 """,
-                (data.get('title'), file_url, file_extension, updater_id, data.get('kb_id'), doc_id)
+                (data.get('title'), file_url, file_extension, updater_id, doc_id)
             )
 
             # 2. 更新子表 (移除 file_url)
@@ -1227,10 +1200,6 @@ class SampleService:
 
     async def delete_basic_info(self, type: str, doc_id: str) -> Tuple[bool, str]:
         """删除基本信息"""
-        if not doc_id:
-            return False, "缺少 ID 参数"
-            
-        logger.info(f"Deleting basic info: type={type}, id={doc_id}")
         conn = get_db_connection()
         if not conn:
             return False, "数据库连接失败"
@@ -1241,44 +1210,21 @@ class SampleService:
             if not table_name:
                 return False, "无效的类型"
             
-            # 1. 显式删除子表记录 (防止 CASCADE 未生效)
-            try:
-                cursor.execute(f"DELETE FROM {table_name} WHERE id = %s", (doc_id,))
-                logger.info(f"Deleted from sub-table {table_name}, affected: {cursor.rowcount}")
-            except Exception as sub_e:
-                logger.warning(f"删除子表 {table_name} 记录失败 (可能不存在): {sub_e}")
-
-            # 2. 同步删除任务管理中心的数据 (优先删除关联数据)
-            try:
-                # 使用当前事务删除任务记录(如果 task_service 支持的话,目前它自建连接)
-                # 这里我们直接在当前 cursor 中也执行一次,确保事务一致性
-                cursor.execute("DELETE FROM t_task_management WHERE business_id = %s", (doc_id,))
-                logger.info(f"Deleted from t_task_management, affected: {cursor.rowcount}")
-            except Exception as task_e:
-                logger.warning(f"在主事务中删除任务记录失败: {task_e}")
-
-            # 3. 删除主表记录
+            # 1. 删除主表记录 (由于设置了 ON DELETE CASCADE,子表记录会自动删除)
             cursor.execute("DELETE FROM t_samp_document_main WHERE id = %s", (doc_id,))
-            affected_main = cursor.rowcount
-            logger.info(f"Deleted from t_samp_document_main, affected: {affected_main}")
             
-            if affected_main == 0:
-                logger.warning(f"未找到主表记录: {doc_id}")
-                # 即使主表没找到,我们也 commit 之前的操作并返回成功(幂等性)
-            
-            conn.commit()
-            
-            # 4. 再次确保任务中心数据已删除 (调用原有服务)
+            # 同步删除任务管理中心的数据
             try:
                 await task_service.delete_task(doc_id)
             except Exception as task_err:
-                logger.error(f"调用 task_service 删除任务失败: {task_err}")
+                logger.error(f"同步删除任务中心数据失败 (ID: {doc_id}): {task_err}")
 
+            conn.commit()
             return True, "删除成功"
         except Exception as e:
-            logger.exception(f"删除基本信息异常 (ID: {doc_id})")
+            logger.exception("删除基本信息失败")
             conn.rollback()
-            return False, f"删除失败: {str(e)}"
+            return False, str(e)
         finally:
             cursor.close()
             conn.close()

+ 10 - 129
src/app/services/search_engine_service.py

@@ -38,15 +38,11 @@ class SearchEngineService:
         from sqlalchemy import text
         try:
             # 简单判断是否是 UUID 格式或数字 ID,尝试查询数据库
-            # 修改 collection_name 为 collection_name_parent,并增加对 children 的兼容
-            kb_query = text("SELECT collection_name_parent FROM t_samp_knowledge_base WHERE id = :kb_id OR collection_name_parent = :kb_id OR collection_name_children = :kb_id")
+            kb_query = text("SELECT collection_name FROM t_samp_knowledge_base WHERE id = :kb_id OR collection_name = :kb_id")
             kb_res = await db.execute(kb_query, {"kb_id": original_kb_id})
             kb_row = kb_res.fetchone()
             if kb_row:
                 collection_name = kb_row[0]
-                # 如果是 parent_child 模式,剥离 _parent 后缀供后面拼接
-                if collection_name and collection_name.endswith('_parent'):
-                    collection_name = collection_name[:-7]
                 logging.info(f"Resolved kb_id {original_kb_id} to collection_name: {collection_name}")
         except Exception as db_err:
             logging.warning(f"Failed to resolve kb_id {original_kb_id} from database: {db_err}")
@@ -326,22 +322,7 @@ class SearchEngineService:
 
                     # PDR 模式内容获取 (可选)
                     item_content = item.get('text') or item.get('content') or item.get('page_content') or ""
-                    if is_pdr:
-                        parent_id = item_metadata.get("parent_id") or item.get("parent_id")
-                        if parent_id:
-                            try:
-                                parent_results = milvus_service.client.query(
-                                    collection_name=parent_col,
-                                    filter=f'parent_id == "{parent_id}"',
-                                    output_fields=["text", "content", "page_content"]
-                                )
-                                if parent_results:
-                                    p_entity = parent_results[0]
-                                    parent_full = p_entity.get("text") or p_entity.get("content") or p_entity.get("page_content")
-                                    if parent_full:
-                                        item_content = f"【父段内容】\n{parent_full}\n\n【片段内容】\n{item_content}"
-                            except:
-                                pass
+                    parent_id = item.get("parent_id") or item_metadata.get("parent_id") or ""
 
                     doc_name = (
                         item_metadata.get('doc_name')
@@ -361,6 +342,7 @@ class SearchEngineService:
                         content=item_content,
                         meta_info=str(item_metadata),
                         document_id=str(item.get("document_id") or ""),
+                        parent_id=str(parent_id) if parent_id is not None else None,
                         metadata=item_metadata,
                         score=0
                     ))
@@ -518,115 +500,14 @@ class SearchEngineService:
                     # 1. 如果有 PDR 标记,则 parent_col 已知
                     # 2. 如果没有,需要查询 KnowledgeBase 表获取 collection_name_parent (父表)
                     
-                    final_parent_content = None
-                    parent_id = entity.get("metadata", {}).get("parent_id") if entity.get("metadata") else entity.get("parent_id")
+                    parent_id = entity.get("parent_id")
+                    if not parent_id and isinstance(entity.get("metadata"), dict):
+                        parent_id = entity.get("metadata", {}).get("parent_id")
                     
                     # 尝试从 metadata 中获取 parent_id,如果 entity 直接有也行
                     if not parent_id and "metadata" in entity and isinstance(entity["metadata"], dict):
                         parent_id = entity["metadata"].get("parent_id")
-
-                    if parent_id:
-                        # 确定父表名
-                        target_parent_col = parent_col if is_pdr else None
-                        
-                        if not target_parent_col:
-                            # 尝试从 DB 查询
-                            # 注意:这里需要 db session,但当前 search_kb 参数中有 db
-                            # 我们可以查询 KnowledgeBase 表
-                            try:
-                                from app.sample.models.knowledge_base import KnowledgeBase
-                                # 假设 kb_id 是子表名 (collection_name)
-                                kb_stmt = select(KnowledgeBase.collection_name_parent).where(
-                                    or_(
-                                        KnowledgeBase.collection_name == kb_id,
-                                        KnowledgeBase.collection_name_parent == kb_id # 兼容
-                                    )
-                                )
-                                # 这里是在循环里,查询 DB 可能会慢,但为了功能正确性先加上
-                                # 优化:应该在循环外批量查,或者缓存 KB 信息
-                                # 鉴于 search_kb 入口已经解析过 collection_name,我们可以复用
-                                # 但 search_kb 解析的是 original_kb_id -> collection_name
-                                # 如果 kb_id == child_col (即 is_pdr=True),则 parent_col 已经设置好了
-                                
-                                # 如果 !is_pdr,那么 kb_id 可能就是子表名 (用户直接选了子表)
-                                # 此时我们需要反查父表
-                                kb_res = await db.execute(kb_stmt)
-                                p_row = kb_res.fetchone()
-                                if p_row and p_row[0]:
-                                    target_parent_col = p_row[0]
-                            except Exception:
-                                pass
-                        
-                        if target_parent_col and milvus_service.has_collection(target_parent_col):
-                            try:
-                                # 构造查询表达式 (参考 SnippetService)
-                                # 检查父表 document_id 类型
-                                p_desc = milvus_service.client.describe_collection(target_parent_col)
-                                p_fields = p_desc.get('fields', [])
-                                p_is_int = False
-                                for f in p_fields:
-                                    if f['name'] == 'document_id':
-                                        if f.get('type') == 5: p_is_int = True
-                                        break
-                                
-                                clean_pid = str(parent_id)
-                                if clean_pid.startswith("SNIP-"):
-                                    clean_pid = clean_pid.replace("SNIP-", "")
-                                    
-                                p_expr = ""
-                                if "document_id" in [f['name'] for f in p_fields]:
-                                    if p_is_int and clean_pid.isdigit():
-                                        p_expr = f'document_id == {clean_pid}'
-                                    else:
-                                        p_expr = f'document_id == "{clean_pid}"'
-                                
-                                p_content = None
-                                if p_expr:
-                                    p_res = milvus_service.client.query(target_parent_col, filter=p_expr, output_fields=["text", "content", "page_content"], limit=1)
-                                    if p_res:
-                                        p_content = p_res[0].get("text") or p_res[0].get("content") or p_res[0].get("page_content")
-                                
-                                # 尝试 PK
-                                if not p_content:
-                                    pk_field = "pk"
-                                    pk_is_int = True
-                                    for f in p_fields:
-                                        if f.get('primary_key'):
-                                            pk_field = f.get('name')
-                                            pk_is_int = (f.get('type') == 5)
-                                            break
-                                    
-                                    pk_expr = ""
-                                    if pk_is_int and clean_pid.isdigit():
-                                        pk_expr = f'{pk_field} == {clean_pid}'
-                                    else:
-                                        pk_expr = f'{pk_field} == "{clean_pid}"'
-                                        
-                                    p_res = milvus_service.client.query(target_parent_col, filter=pk_expr, output_fields=["text", "content", "page_content"], limit=1)
-                                    if p_res:
-                                        p_content = p_res[0].get("text") or p_res[0].get("content") or p_res[0].get("page_content")
-                                
-                                if p_content:
-                                    final_parent_content = p_content
-                                    # [Optional] 是否替换 content? 
-                                    # 用户需求:"如果有父段的话也要去显示父段的内容"
-                                    # 通常是作为附加信息,或者替换片段内容(如果片段内容太短)
-                                    # 这里我们选择追加或者替换,视情况而定。
-                                    # 为了明确展示,我们可以在 content 前面加标识,或者直接替换
-                                    # 参考 SnippetService,它是返回 parent_content 字段
-                                    # 但 KBSearchResultItem 只有 content 字段
-                                    # 我们可以把父段内容拼接到 content 中
-                                    # content = f"[父段内容]:\n{final_parent_content}\n\n[片段内容]:\n{content}"
-                                    # 或者只显示父段内容(如果用户认为父段内容更完整)
-                                    # 既然用户说“显示父段的内容”,可能意图是“上下文”
-                                    # 我们这里采用拼接方式,清晰展示
-                                    pass # 逻辑在下面处理
-
-                            except Exception as e:
-                                logging.error(f"Failed to fetch parent chunk {parent_id} from {target_parent_col}: {e}")
-
-                    if final_parent_content:
-                         content = f"【父段内容】\n{final_parent_content}\n\n【片段内容】\n{content}"
+                    # 详情页需要按 parent_id 动态查询父段并支持多条展示,这里不在检索结果中拼接父段内容
 
                     # 获取文档名称
                     doc_name = "未知文档"
@@ -732,13 +613,13 @@ class SearchEngineService:
                         content=content,
                         meta_info=meta_str,
                         document_id=str(document_id) if document_id is not None else None,
+                        parent_id=str(parent_id) if parent_id is not None else None,
                         metadata=meta_dict if isinstance(meta_dict, dict) else None,
                         score=similarity_pct
                     ))
-            
-            # 按相似度由大到小排序
-            formatted_results.sort(key=lambda x: x.score, reverse=True)
 
+            formatted_results.sort(key=lambda r: (r.score or 0), reverse=True)
+            
             # [Fix] 动态计算 total 用于分页
             # 如果当前页结果不满 limit,说明是最后一页
             current_count = len(formatted_results)

+ 69 - 65
src/app/services/snippet_service.py

@@ -793,11 +793,26 @@ class SnippetService:
                         snippet_data = self._format_snippet(res[0], kb)
 
             if not snippet_data:
-                return None
+                snippet_data = {
+                    "id": "",
+                    "collection_name": kb,
+                    "doc_name": "",
+                    "code": "",
+                    "content": "",
+                    "char_count": 0,
+                    "meta_info": "",
+                    "metadata": {},
+                    "document_id": "",
+                    "parent_id": "",
+                    "tag_list": "",
+                    "status": "normal",
+                    "created_at": "-",
+                    "updated_at": "-"
+                }
 
             # [New Feature] 获取父段内容
             # 逻辑:根据当前子表(kb) -> 查 KnowledgeBase 表找到对应的父表 -> 用 parent_id 查父表内容
-            parent_id = snippet_data.get("parent_id")
+            parent_id = snippet_data.get("parent_id") or clean_id
             # print(f"DEBUG: snippet_id={id}, kb={kb}, found parent_id={parent_id}")
             
             if parent_id:
@@ -820,82 +835,71 @@ class SnippetService:
                     if kb_record and kb_record.collection_name_parent:
                         parent_kb = kb_record.collection_name_parent
                         
-                        # 2. 在父表中查询 parent_id
-                        # 父表中的 ID 应该是 document_id 或者 pk
+                        # 2. 在父表中查询 parent_id 相同的父段(可能有多个切片)
                         if milvus_service.has_collection(parent_kb):
-                            # 构造查询表达式
-                            # 假设 parent_id 对应父表的 document_id
-                            # 先检查父表字段
                             print("successqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq")
                             p_desc = milvus_service.client.describe_collection(parent_kb)
                             p_fields = p_desc.get('fields', [])
                             p_field_names = [f['name'] for f in p_fields]
                             
+                            clean_pid = str(parent_id)
+                            if clean_pid.startswith("SNIP-"):
+                                clean_pid = clean_pid.replace("SNIP-", "")
+
+                            parent_segments = []
                             p_expr = ""
-                            if "document_id" in p_field_names:
-                                # 检查类型
+                            if "parent_id" in p_field_names:
                                 p_is_int = False
                                 for f in p_fields:
-                                    if f['name'] == 'document_id':
-                                        if f.get('type') == 5: p_is_int = True
+                                    if f['name'] == 'parent_id':
+                                        if f.get('type') == 5:
+                                            p_is_int = True
                                         break
-                                
-                                # [Fix] 增强匹配逻辑:
-                                # 如果 parent_id 是纯数字字符串,且字段是 INT,则转 int
-                                # 如果 parent_id 是纯数字字符串,且字段是 VARCHAR,则保留引号
-                                # 如果 parent_id 是 SNIP- 开头,去掉前缀后再试
-                                
-                                clean_pid = str(parent_id)
-                                if clean_pid.startswith("SNIP-"):
-                                    clean_pid = clean_pid.replace("SNIP-", "")
-                                
-                                if p_is_int:
-                                    if clean_pid.isdigit():
-                                        p_expr = f'document_id == {clean_pid}'
+                                if p_is_int and clean_pid.isdigit():
+                                    p_expr = f'parent_id == {clean_pid}'
                                 else:
-                                    # VARCHAR 类型,直接加引号
-                                    p_expr = f'document_id == "{clean_pid}"'
-                            
-                            # print(f"DEBUG: Parent query expr (document_id): {p_expr}")
-                            
-                            p_content = None
+                                    p_expr = f'parent_id == "{clean_pid}"'
+                            elif "metadata" in p_field_names:
+                                p_expr = f'metadata["parent_id"] == "{clean_pid}"'
+
                             if p_expr:
-                                p_res = milvus_service.client.query(parent_kb, filter=p_expr, output_fields=["text", "content"], limit=1)
+                                p_res = milvus_service.client.query(
+                                    collection_name=parent_kb,
+                                    filter=p_expr,
+                                    output_fields=["pk", "id", "text", "content", "page_content", "index", "document_id", "parent_id", "metadata", "created_time"],
+                                    limit=200
+                                )
                                 if p_res:
-                                    p_content = p_res[0].get("text") or p_res[0].get("content")
-                                    print('cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc')
-                            
-                            # 如果还没查到,尝试 PK
-                            if not p_content:
-                                p_pk_field = "pk"
-                                p_pk_int = True
-                                for f in p_fields:
-                                    if f.get('primary_key'):
-                                        p_pk_field = f.get('name')
-                                        p_pk_int = (f.get('type') == 5)
-                                        break
-                                
-                                clean_pid = str(parent_id)
-                                if clean_pid.startswith("SNIP-"):
-                                    clean_pid = clean_pid.replace("SNIP-", "")
-
-                                p_expr_pk = ""
-                                if p_pk_int:
-                                    if clean_pid.isdigit():
-                                        p_expr_pk = f'{p_pk_field} == {clean_pid}'
-                                else:
-                                    p_expr_pk = f'{p_pk_field} == "{clean_pid}"'
-                                
-                                # print(f"DEBUG: Parent query expr (PK): {p_expr_pk}")
-                                    
-                                if p_expr_pk:
-                                    p_res = milvus_service.client.query(parent_kb, filter=p_expr_pk, output_fields=["text", "content"], limit=1)
-                                    if p_res:
-                                        p_content = p_res[0].get("text") or p_res[0].get("content")
-
-                            if p_content:
-                                snippet_data["parent_content"] = p_content
-                                # print(f"DEBUG: Found parent content, len={len(p_content)}")
+                                    for r in p_res:
+                                        p_content = r.get("text") or r.get("content") or r.get("page_content") or ""
+                                        parent_segments.append({
+                                            "id": str(r.get("pk") or r.get("id") or ""),
+                                            "index": r.get("index"),
+                                            "document_id": str(r.get("document_id") or ""),
+                                            "parent_id": str(r.get("parent_id") or clean_pid),
+                                            "metadata": r.get("metadata") or {},
+                                            "created_time": r.get("created_time"),
+                                            "content": p_content
+                                        })
+
+                            if parent_segments:
+                                def _sort_key(x):
+                                    idx = x.get("index")
+                                    try:
+                                        idx = int(idx)
+                                    except Exception:
+                                        idx = 0
+                                    ct = x.get("created_time")
+                                    try:
+                                        ct = int(ct)
+                                    except Exception:
+                                        ct = 0
+                                    pid = x.get("id") or ""
+                                    return (idx, ct, pid)
+
+                                parent_segments.sort(key=_sort_key)
+                                snippet_data["parent_segments"] = parent_segments
+                                snippet_data["parent_content"] = parent_segments[0].get("content") or ""
                             # else:
                                 # print("DEBUG: Parent content NOT found")
                     # else:

+ 1 - 1
src/views/knowledge_base_view.py

@@ -60,7 +60,7 @@ async def get_knowledge_base_simple_list(
     return ResponseSchema(
         code=0,
         message="获取成功",
-        data=[{"id": item.id, "name": item.name, "collection_name": item.collection_name_parent} for item in items]
+        data=[{"id": item.id, "name": item.name, "collection_name": item.collection_name_children or item.collection_name_parent} for item in items]
     )
 
 @router.post("", response_model=ResponseSchema)