linyang il y a 3 semaines
Parent
commit
f553a1c9f2

+ 4 - 1
src/app/api/v1/document/knowledge_base.py

@@ -3,6 +3,7 @@
 """
 from math import ceil
 from typing import List
+import logging
 from fastapi import APIRouter, Query, Path, Depends, HTTPException
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy import select, func, or_
@@ -18,6 +19,8 @@ from app.sample.schemas.knowledge_base import (
 )
 from app.services.milvus_service import milvus_service
 
+logger = logging.getLogger(__name__)
+
 router = APIRouter()
 
 @router.get("", response_model=PaginatedResponseSchema)
@@ -76,7 +79,7 @@ async def get_knowledge_bases(
             await db.commit()
             
     except Exception as e:
-        print(f"Sync Milvus collections failed: {e}")
+        logger.exception("Sync Milvus collections failed")
     # ----------------------
 
     query = select(KnowledgeBase).where(KnowledgeBase.is_deleted == False)

+ 43 - 19
src/app/services/knowledge_base_service.py

@@ -3,6 +3,7 @@
 """
 from math import ceil
 from typing import List, Optional, Tuple, Dict, Any
+import logging
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy import select, func, or_, delete as sql_delete, update as sql_update
 from datetime import datetime
@@ -21,6 +22,8 @@ from app.sample.schemas.knowledge_base import (
 from app.services.milvus_service import milvus_service
 from app.schemas.base import PaginationSchema
 
+logger = logging.getLogger(__name__)
+
 class KnowledgeBaseService:
     
     async def _get_collection_row_count(self, collection_name: str) -> int:
@@ -34,13 +37,11 @@ class KnowledgeBaseService:
                     indexes = milvus_service.client.list_indexes(collection_name)
                     if not indexes:
                          # 无索引无法加载,直接跳过,进入 Fallback 使用 stats
-                         # print(f"Collection {collection_name} has no index, skipping load.")
                          raise Exception("Collection has no index, cannot load")
                 except Exception:
                     # list_indexes 失败也视为无法加载
                     raise Exception("Failed to check indexes or no index")
 
-                # print(f"Auto loading collection {collection_name} for counting...")
                 milvus_service.set_collection_state(collection_name, "load")
             
             # 尝试使用 count(*) 获取准确的实时数量
@@ -77,7 +78,7 @@ class KnowledgeBaseService:
                 if res and isinstance(res, list) and "count(*)" in res[0]:
                     return int(res[0]["count(*)"])
             except Exception as e:
-                print(f"Query count with filter error for {collection_name}: {e}")
+                logger.warning("Query count with filter error for %s: %s", collection_name, e)
                 # 再次尝试不过滤 (使用恒真表达式)
                 if milvus_service.get_collection_state(collection_name) == "Loaded":
                      # 获取 PK 字段名
@@ -98,9 +99,8 @@ class KnowledgeBaseService:
                      res = milvus_service.client.query(collection_name, filter=filter_expr, output_fields=["count(*)"])
                      if res and isinstance(res, list) and "count(*)" in res[0]:
                         return int(res[0]["count(*)"])
-        except Exception as e:
-            # print(f"Get collection row count error for {collection_name}: {e}")
-            pass
+        except Exception:
+            logger.exception("Get collection row count error for %s", collection_name)
             
         # Fallback: 使用 get_collection_stats (可能包含已删除未 Compaction 的数据)
         try:
@@ -206,9 +206,9 @@ class KnowledgeBaseService:
                         db.add(new_metadata)
                     
                     # 注意:调用方负责 commit,这里不 commit 以支持批量事务
-                    print(f"Auto inferred metadata for {target_col}: {inferred_keys}")
+                    logger.info("Auto inferred metadata for %s: %s", target_col, inferred_keys)
         except Exception as e:
-            print(f"Failed to infer metadata for {target_col}: {e}")
+            logger.exception("Failed to infer metadata for %s", target_col)
 
     async def get_list(
         self, 
@@ -250,7 +250,7 @@ class KnowledgeBaseService:
                 await db.commit()
 
         except Exception as e:
-            print(f"Sync Milvus collections failed: {e}")
+            logger.exception("Sync Milvus collections failed")
         # ----------------------
 
         # 查询未删除的 KB
@@ -300,12 +300,24 @@ class KnowledgeBaseService:
 
     async def create(self, db: AsyncSession, payload: KnowledgeBaseCreate) -> KnowledgeBase:
         """创建新知识库"""
+        name = (payload.name or "").strip()
+        if not name:
+            raise ValueError("请输入知识库名称")
+
         parent_name = (payload.collection_name_parent or "").strip() or None
         child_name = (payload.collection_name_children or "").strip()
 
         if not child_name:
             raise ValueError("请输入子集合名称")
 
+        # 检查知识库名称是否重名
+        exists_name = await db.execute(select(KnowledgeBase).where(
+            KnowledgeBase.name == name,
+            KnowledgeBase.is_deleted == 0
+        ))
+        if exists_name.scalars().first():
+            raise ValueError(f"知识库名称 {name} 已存在")
+
         # 1. 检查 DB 是否已存在
         # 检查父子集合名称不能相同
         if parent_name and parent_name == child_name:
@@ -314,11 +326,14 @@ class KnowledgeBaseService:
         # 检查 collection_name_parent (可选)
         if parent_name:
             exists1 = await db.execute(select(KnowledgeBase).where(
-                KnowledgeBase.collection_name_parent == parent_name,
+                or_(
+                    KnowledgeBase.collection_name_parent == parent_name,
+                    KnowledgeBase.collection_name_children == parent_name
+                ),
                 KnowledgeBase.is_deleted == 0
             ))
             if exists1.scalars().first():
-                raise ValueError(f"集合名称 {parent_name} 已存在")
+                raise ValueError(f"集合名称 {parent_name} 已存在")
             
         # 检查 collection_name_children
         exists2 = await db.execute(select(KnowledgeBase).where(
@@ -329,14 +344,14 @@ class KnowledgeBaseService:
             KnowledgeBase.is_deleted == 0
         ))
         if exists2.scalars().first():
-            raise ValueError(f"集合名称 {child_name} 已存在")
+            raise ValueError(f"集合名称 {child_name} 已存在")
 
         try:
             # 3. 创建 DB 记录
             now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
             new_kb = KnowledgeBase(
                 id=str(uuid.uuid4()),
-                name=payload.name,
+                name=name,
                 collection_name_parent=parent_name,
                 collection_name_children=child_name,
                 has_parent_collection=1 if parent_name else 0,
@@ -382,7 +397,18 @@ class KnowledgeBaseService:
 
         try:
             if payload.name is not None:
-                kb.name = payload.name
+                new_name = (payload.name or "").strip()
+                if not new_name:
+                    raise ValueError("请输入知识库名称")
+                if new_name != kb.name:
+                    exists_name = await db.execute(select(KnowledgeBase).where(
+                        KnowledgeBase.name == new_name,
+                        KnowledgeBase.is_deleted == 0,
+                        KnowledgeBase.id != id
+                    ))
+                    if exists_name.scalars().first():
+                        raise ValueError(f"知识库名称 {new_name} 已存在")
+                    kb.name = new_name
 
             if payload.description is not None:
                 kb.description = payload.description
@@ -491,7 +517,7 @@ class KnowledgeBaseService:
                     if milvus_service.has_collection(col):
                         milvus_service.drop_collection(col)
                 except Exception as milvus_err:
-                    print(f"Ignore Milvus error during delete {col}: {milvus_err}")
+                    logger.warning("Ignore Milvus error during delete %s: %s", col, milvus_err)
             
             # 2. 解除文档关联 (将 kb_id 置空,状态改为未入库)
             await db.execute(
@@ -647,9 +673,9 @@ class KnowledgeBaseService:
                         
                         await db.commit()
                         metadata_fields = new_fields
-                        print(f"Auto inferred metadata for {target_col}: {inferred_keys}")
+                        logger.info("Auto inferred metadata for %s: %s", target_col, inferred_keys)
             except Exception as e:
-                print(f"Failed to infer metadata for {target_col}: {e}")
+                logger.exception("Failed to infer metadata for %s", target_col)
                 # 推断失败不影响正常返回
         
         # 返回空的 custom_schemas,因为现在是固定 Schema
@@ -692,7 +718,6 @@ class KnowledgeBaseService:
                 # 确保集合已加载
                 state = milvus_service.get_collection_state(collection_name)
                 if state != "Loaded":
-                    # print(f"Collection {collection_name} is {state}, loading...")
                     milvus_service.set_collection_state(collection_name, "load")
                 
                 # 获取该集合的计数
@@ -709,7 +734,6 @@ class KnowledgeBaseService:
                 
                 # 更新数据库
                 if kb.document_count != total_count:
-                    # print(f"Updating doc count for KB {kb.name}: {kb.document_count} -> {total_count}")
                     kb.document_count = total_count
                     kb.updated_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                     await db.commit()

+ 23 - 26
src/app/services/milvus_service.py

@@ -640,7 +640,7 @@ class MilvusService:
 
         # 2. 重新获取集合信息
         desc = self.client.describe_collection(collection_name=name)
-        print(desc)
+        logger.debug("Collection %s describe: %s", name, desc)
         stats = self.client.get_collection_stats(collection_name=name)
         load_state = self.client.get_load_state(collection_name=name)
 
@@ -749,21 +749,19 @@ if __name__ == "__main__":
 
     service = MilvusService()
     
-    # 测试混合搜索 hybrid_search
-    print("=" * 50)
-    print("测试混合检索 (Hybrid Search)")
-    print("=" * 50)
+    logger.info("=" * 50)
+    logger.info("测试混合检索 (Hybrid Search)")
+    logger.info("=" * 50)
     
     try:
         # 示例参数,需要根据实际情况修改
         collection_name = "first_bfp_collection_status" 
         query_text = "《公路水运工程临时用电技术规程》(JTT1499-2024)状态为现行"  # 修改为实际查询内容
         
-        # 测试 weighted 模式
-        print("\n1. 测试 Weighted 重排序模式:")
-        print(f"   集合: {collection_name}")
-        print(f"   查询: {query_text}")
-        print(f"   密集权重: 0.7, 稀疏权重: 0.3")
+        logger.info("1. 测试 Weighted 重排序模式:")
+        logger.info("集合: %s", collection_name)
+        logger.info("查询: %s", query_text)
+        logger.info("密集权重: 0.7, 稀疏权重: 0.3")
         
         results_weighted = service.hybrid_search(
             collection_name=collection_name,
@@ -774,14 +772,14 @@ if __name__ == "__main__":
             sparse_weight=0.3
         )
         
-        print(f"\n   结果数量: {len(results_weighted)}")
+        logger.info("结果数量: %s", len(results_weighted))
         for i, result in enumerate(results_weighted, 1):
-            print(f"   [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")
+            text_preview = (result.get("text_content") or "")[:50]
+            logger.info("[%s] ID: %s, Text: %s...", i, result.get("id"), text_preview)
         
-        # 测试 RRF 模式
-        print("\n2. 测试 RRF (Reciprocal Rank Fusion) 重排序模式:")
-        print(f"   集合: {collection_name}")
-        print(f"   查询: {query_text}")
+        logger.info("2. 测试 RRF (Reciprocal Rank Fusion) 重排序模式:")
+        logger.info("集合: %s", collection_name)
+        logger.info("查询: %s", query_text)
         
         results_rrf = service.hybrid_search(
             collection_name=collection_name,
@@ -790,21 +788,20 @@ if __name__ == "__main__":
             ranker_type="rrf"
         )
         
-        print(f"\n   结果数量: {len(results_rrf)}")
+        logger.info("结果数量: %s", len(results_rrf))
         for i, result in enumerate(results_rrf, 1):
-            print(f"   [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")
+            text_preview = (result.get("text_content") or "")[:50]
+            logger.info("[%s] ID: %s, Text: %s...", i, result.get("id"), text_preview)
         
-        print("\n✓ 混合检索测试完成")
+        logger.info("混合检索测试完成")
         
     except Exception as e:
-        print(f"\n✗ 混合检索测试失败: {e}")
-        import traceback
-        traceback.print_exc()
+        logger.exception("混合检索测试失败")
     
     # 也可以查看集合详情
-    print("\n" + "=" * 50)
-    print("获取所有集合信息:")
-    print("=" * 50)
+    logger.info("=" * 50)
+    logger.info("获取所有集合信息:")
+    logger.info("=" * 50)
     data = service.get_collection_details()
     for item in data:
-        print(json.dumps(item, ensure_ascii=False, indent=2))
+        logger.info("%s", json.dumps(item, ensure_ascii=False, indent=2))

+ 57 - 110
src/app/services/sample_service.py

@@ -105,15 +105,14 @@ class SampleService:
 
     # ==================== 文档管理 ====================
     
-    async def batch_enter_knowledge_base(self, doc_ids: List[str], username: str, kb_method: str = "general", chunk_size: int = 500, separator: str = "。") -> Tuple[int, str]:
+    async def batch_enter_knowledge_base(self, doc_ids: List[str], username: str, kb_id: str = None, kb_method: str = None) -> Tuple[int, str]:
         """批量将文档入库到知识库
         
         Args:
             doc_ids: 文档ID列表
             username: 操作人
+            kb_id: 知识库ID
             kb_method: 切分方法
-            chunk_size: 切分长度
-            separator: 切分符号
         """
         conn = get_db_connection()
         if not conn:
@@ -121,6 +120,7 @@ class SampleService:
         
         cursor = conn.cursor()
         success_count = 0
+        skipped_count = 0
         already_entered_count = 0
         failed_count = 0
         error_details = []
@@ -129,7 +129,7 @@ class SampleService:
             # 1. 获取所有选中选中的文档详情
             placeholders = ','.join(['%s']*len(doc_ids))
             fetch_sql = f"""
-                SELECT id, title, source_type, md_url, conversion_status, whether_to_enter, created_time, kb_id 
+                SELECT id, title, source_type, md_url, conversion_status, whether_to_enter, created_time 
                 FROM t_samp_document_main 
                 WHERE id IN ({placeholders})
             """
@@ -146,7 +146,6 @@ class SampleService:
                 status = doc.get('conversion_status')
                 whether_to_enter = doc.get('whether_to_enter', 0)
                 md_url = doc.get('md_url')
-                source_type = doc.get('source_type')
                 
                 # A. 检查是否已入库
                 if whether_to_enter == 1:
@@ -158,48 +157,18 @@ class SampleService:
                 # B. 检查转换状态
                 if status != 2:
                     reason = "尚未转换成功" if status == 0 else "正在转换中" if status == 1 else "转换失败"
-                    logger.warning(f"文档 {title}({doc_id}) 状态为 {status},入库失败: {reason}")
-                    failed_count += 1
+                    logger.warning(f"文档 {title}({doc_id}) 状态为 {status},跳过入库: {reason}")
+                    skipped_count += 1
                     error_details.append(f"· {title}: {reason}")
                     continue
                 
                 if not md_url:
-                    logger.warning(f"文档 {title}({doc_id}) 缺少 md_url,入库失败")
-                    failed_count += 1
+                    logger.warning(f"文档 {title}({doc_id}) 缺少 md_url,跳过入库")
+                    skipped_count += 1
                     error_details.append(f"· {title}: 转换结果地址丢失")
                     continue
                 
-                # C. 确定入库策略 (从数据库读取已绑定的知识库)
-                current_kb_id = doc.get('kb_id')
-                current_kb_method = kb_method  # 直接使用前端传来的切分方式
-
-                if not current_kb_id:
-                    logger.warning(f"文档 {title}({doc_id}) 未指定知识库,跳过入库")
-                    failed_count += 1
-                    error_details.append(f"· {title}: 未指定目标知识库")
-                    continue
-
-                if not current_kb_method:
-                    logger.warning(f"文档 {title}({doc_id}) 未指定切分方式,跳过入库")
-                    failed_count += 1
-                    error_details.append(f"· {title}: 未指定切分策略")
-                    continue
-
-                # 获取知识库信息 (collection_name_parent, collection_name_children)
-                kb_sql = "SELECT collection_name_parent, collection_name_children FROM t_samp_knowledge_base WHERE id = %s AND is_deleted = 0"
-                cursor.execute(kb_sql, (current_kb_id,))
-                kb_res = cursor.fetchone()
-                
-                if not kb_res:
-                    logger.warning(f"找不到指定的知识库: id={current_kb_id}")
-                    failed_count += 1
-                    error_details.append(f"· {title}: 指定的知识库不存在或已被删除")
-                    continue
-                
-                collection_name_parent = kb_res['collection_name_parent']
-                collection_name_children = kb_res['collection_name_children']
-                
-                # D. 从 MinIO 获取 Markdown 内容
+                # B. 从 MinIO 获取 Markdown 内容
                 try:
                     md_content = self.minio_manager.get_object_content(md_url)
                     if not md_content:
@@ -210,34 +179,39 @@ class SampleService:
                     error_details.append(f"· {title}: 读取云端文件失败")
                     continue
                 
-                # E. 调用 MilvusService 进行切分和入库
+                # C. 调用 MilvusService 进行切分和入库
                 try:
+                    # 如果有 kb_id,需要根据它获取 collection_name
+                    collection_name = None
+                    if kb_id:
+                        kb_sql = "SELECT collection_name FROM t_samp_knowledge_base WHERE id = %s"
+                        cursor.execute(kb_sql, (kb_id,))
+                        kb_res = cursor.fetchone()
+                        if kb_res:
+                            collection_name = kb_res['collection_name']
+                    
                     # 准备元数据
-                    current_date = int(datetime.now().strftime('%Y%m%d'))
                     doc_info = {
                         "doc_id": doc_id,
-                        "file_name": title,
-                        "doc_version": int(doc['created_time'].strftime('%Y%m%d')) if doc.get('created_time') else current_date,
-                        "tags": "",
+                        "doc_name": title,
+                        "doc_version": int(doc['created_time'].strftime('%Y%m%d')) if doc.get('created_time') else 20260127,
+                        "tags": doc.get('source_type') or 'unknown',
                         "user_id": username,  # 传递操作人作为 created_by
-                        "kb_id": current_kb_id,
-                        "kb_method": current_kb_method,
-                        "collection_name_parent": collection_name_parent,
-                        "collection_name_children": collection_name_children,
-                        "chunk_size": chunk_size,
-                        "separator": separator
+                        "kb_id": kb_id,
+                        "kb_method": kb_method,
+                        "collection_name": collection_name
                     }
                     await self.milvus_service.insert_knowledge(md_content, doc_info)
                     
-                    # F. 添加到任务管理中心 (类型为 data)
+                    # D. 添加到任务管理中心 (类型为 data)
                     try:
                         await task_service.add_task(doc_id, 'data')
                     except Exception as task_err:
                         logger.error(f"添加文档 {title} 到任务中心失败: {task_err}")
 
-                    # G. 更新数据库状态
+                    # E. 更新数据库状态
                     update_sql = "UPDATE t_samp_document_main SET whether_to_enter = 1, kb_id = %s, kb_method = %s, updated_by = %s, updated_time = NOW() WHERE id = %s"
-                    cursor.execute(update_sql, (current_kb_id, current_kb_method, username, doc_id))
+                    cursor.execute(update_sql, (kb_id, kb_method, username, doc_id))
                     success_count += 1
                     
                 except Exception as milvus_err:
@@ -249,12 +223,14 @@ class SampleService:
             conn.commit()
             
             # 构造详细的消息
-            if success_count == len(doc_ids) and failed_count == 0 and already_entered_count == 0:
+            if success_count == len(doc_ids) and failed_count == 0 and skipped_count == 0 and already_entered_count == 0:
                 msg = f"✅ 入库成功!共处理 {success_count} 份文档。"
             else:
                 msg = f"📊 入库处理完成:\n· 成功:{success_count} 份\n"
                 if already_entered_count > 0:
                     msg += f"· 跳过:{already_entered_count} 份 (已入库)\n"
+                if skipped_count > 0:
+                    msg += f"· 跳过:{skipped_count} 份 (转换中或失败)\n"
                 if failed_count > 0:
                     msg += f"· 失败:{failed_count} 份\n"
             
@@ -409,9 +385,8 @@ class SampleService:
                     LEFT JOIN {sub_table} s ON m.id = s.id
                     LEFT JOIN t_sys_user u1 ON m.created_by = u1.id
                     LEFT JOIN t_sys_user u2 ON m.updated_by = u2.id
-                    LEFT JOIN t_samp_knowledge_base kb ON m.kb_id = kb.id
                 """
-                fields_sql = "m.*, s.*, u1.username as creator_name, u2.username as updater_name, kb.name as kb_name, m.id as id"
+                fields_sql = "m.*, s.*, u1.username as creator_name, u2.username as updater_name, m.id as id"
                 where_clauses.append("m.source_type = %s")
                 params.append(table_type)
                 order_sql = "m.created_time DESC"
@@ -432,8 +407,8 @@ class SampleService:
                         where_clauses.append("s.level_4_classification = %s")
                         params.append(level_4_classification)
             else:
-                from_sql = "t_samp_document_main m LEFT JOIN t_sys_user u1 ON m.created_by = u1.id LEFT JOIN t_sys_user u2 ON m.updated_by = u2.id LEFT JOIN t_samp_knowledge_base kb ON m.kb_id = kb.id"
-                fields_sql = "m.*, u1.username as creator_name, u2.username as updater_name, kb.name as kb_name"
+                from_sql = "t_samp_document_main m LEFT JOIN t_sys_user u1 ON m.created_by = u1.id LEFT JOIN t_sys_user u2 ON m.updated_by = u2.id"
+                fields_sql = "m.*, u1.username as creator_name, u2.username as updater_name"
                 order_sql = "m.created_time DESC"
                 title_field = "m.title"
             
@@ -456,6 +431,7 @@ class SampleService:
             sql = f"SELECT {fields_sql} FROM {from_sql} {where_sql} ORDER BY {order_sql} LIMIT %s OFFSET %s"
             params.extend([size, offset])
             
+            logger.info(f"Executing SQL: {sql} with params: {params}")
             cursor.execute(sql, tuple(params))
             items = [self._format_document_row(row) for row in cursor.fetchall()]
             
@@ -570,13 +546,12 @@ class SampleService:
                 INSERT INTO t_samp_document_main (
                     id, title, source_type, file_url, 
                     file_extension, created_by, updated_by, created_time, updated_time,
-                    conversion_status, whether_to_task, kb_id
-                ) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW(), 0, 0, %s)
+                    conversion_status, whether_to_task
+                ) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW(), 0, 0)
                 """,
                 (
                     doc_id, doc_data.get('title'), table_type, file_url,
-                    doc_data.get('file_extension'), user_id, user_id,
-                    doc_data.get('kb_id')
+                    doc_data.get('file_extension'), user_id, user_id
                 )
             )
 
@@ -673,14 +648,14 @@ class SampleService:
             # 1. 更新主表
             cursor.execute(
                 """
-                UPDATE t_samp_document_main SET 
-                    title = %s, file_url = %s, file_extension = %s, 
-                    updated_by = %s, updated_time = NOW(), kb_id = %s
+                UPDATE t_samp_document_main 
+                SET title = %s, file_url = %s, file_extension = %s,
+                    updated_by = %s, updated_time = NOW()
                 WHERE id = %s
                 """,
                 (
                     doc_data.get('title'), file_url, doc_data.get('file_extension'),
-                    updater_id, doc_data.get('kb_id'), doc_id
+                    updater_id, doc_id
                 )
             )
 
@@ -779,7 +754,7 @@ class SampleService:
                     s.participating_units, s.reference_basis,
                     s.created_by, u1.username as creator_name, s.created_time,
                     s.updated_by, u2.username as updater_name, s.updated_time,
-                    m.file_url, m.conversion_status, m.md_url, m.json_url, m.kb_id, m.whether_to_enter
+                    m.file_url, m.conversion_status, m.md_url, m.json_url
                 """
                 field_map = {
                     'title': 's.chinese_name',
@@ -803,7 +778,7 @@ class SampleService:
                     s.note, 
                     s.created_by, u1.username as creator_name, s.created_time,
                     s.updated_by, u2.username as updater_name, s.updated_time,
-                    m.file_url, m.conversion_status, m.md_url, m.json_url, m.kb_id, m.whether_to_enter
+                    m.file_url, m.conversion_status, m.md_url, m.json_url
                 """
                 field_map = {
                     'title': 's.plan_name',
@@ -824,7 +799,7 @@ class SampleService:
                     s.note, 
                     s.created_by, u1.username as creator_name, s.created_time,
                     s.updated_by, u2.username as updater_name, s.updated_time,
-                    m.file_url, m.conversion_status, m.md_url, m.json_url, m.kb_id, m.whether_to_enter
+                    m.file_url, m.conversion_status, m.md_url, m.json_url
                 """
                 field_map = {
                     'title': 's.file_name',
@@ -885,12 +860,11 @@ class SampleService:
             
             # 使用 LEFT JOIN 关联主表和用户表获取姓名
             sql = f"""
-                SELECT {fields}, kb.name as kb_name
+                SELECT {fields} 
                 FROM {table_name} s
                 LEFT JOIN t_samp_document_main m ON s.id = m.id
                 LEFT JOIN t_sys_user u1 ON s.created_by = u1.id
                 LEFT JOIN t_sys_user u2 ON s.updated_by = u2.id
-                LEFT JOIN t_samp_knowledge_base kb ON m.kb_id = kb.id
                 {where_sql} 
                 ORDER BY s.created_time DESC 
                 LIMIT %s OFFSET %s
@@ -1034,12 +1008,12 @@ class SampleService:
                 INSERT INTO t_samp_document_main (
                     id, title, source_type, file_url, 
                     file_extension, created_by, updated_by, created_time, updated_time,
-                    conversion_status, whether_to_task, kb_id
-                ) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW(), 0, 0, %s)
+                    conversion_status, whether_to_task
+                ) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW(), 0, 0)
                 """,
                 (
                     doc_id, data.get('title'), type, file_url,
-                    file_extension, user_id, user_id, data.get('kb_id')
+                    file_extension, user_id, user_id
                 )
             )
             
@@ -1148,10 +1122,10 @@ class SampleService:
             cursor.execute(
                 """
                 UPDATE t_samp_document_main 
-                SET title = %s, file_url = %s, file_extension = %s, updated_by = %s, updated_time = NOW(), kb_id = %s
+                SET title = %s, file_url = %s, file_extension = %s, updated_by = %s, updated_time = NOW()
                 WHERE id = %s
                 """,
-                (data.get('title'), file_url, file_extension, updater_id, data.get('kb_id'), doc_id)
+                (data.get('title'), file_url, file_extension, updater_id, doc_id)
             )
 
             # 2. 更新子表 (移除 file_url)
@@ -1226,10 +1200,6 @@ class SampleService:
 
     async def delete_basic_info(self, type: str, doc_id: str) -> Tuple[bool, str]:
         """删除基本信息"""
-        if not doc_id:
-            return False, "缺少 ID 参数"
-            
-        logger.info(f"Deleting basic info: type={type}, id={doc_id}")
         conn = get_db_connection()
         if not conn:
             return False, "数据库连接失败"
@@ -1240,44 +1210,21 @@ class SampleService:
             if not table_name:
                 return False, "无效的类型"
             
-            # 1. 显式删除子表记录 (防止 CASCADE 未生效)
-            try:
-                cursor.execute(f"DELETE FROM {table_name} WHERE id = %s", (doc_id,))
-                logger.info(f"Deleted from sub-table {table_name}, affected: {cursor.rowcount}")
-            except Exception as sub_e:
-                logger.warning(f"删除子表 {table_name} 记录失败 (可能不存在): {sub_e}")
-
-            # 2. 同步删除任务管理中心的数据 (优先删除关联数据)
-            try:
-                # 使用当前事务删除任务记录(如果 task_service 支持的话,目前它自建连接)
-                # 这里我们直接在当前 cursor 中也执行一次,确保事务一致性
-                cursor.execute("DELETE FROM t_task_management WHERE business_id = %s", (doc_id,))
-                logger.info(f"Deleted from t_task_management, affected: {cursor.rowcount}")
-            except Exception as task_e:
-                logger.warning(f"在主事务中删除任务记录失败: {task_e}")
-
-            # 3. 删除主表记录
+            # 1. 删除主表记录 (由于设置了 ON DELETE CASCADE,子表记录会自动删除)
             cursor.execute("DELETE FROM t_samp_document_main WHERE id = %s", (doc_id,))
-            affected_main = cursor.rowcount
-            logger.info(f"Deleted from t_samp_document_main, affected: {affected_main}")
-            
-            if affected_main == 0:
-                logger.warning(f"未找到主表记录: {doc_id}")
-                # 即使主表没找到,我们也 commit 之前的操作并返回成功(幂等性)
-            
-            conn.commit()
             
-            # 4. 再次确保任务中心数据已删除 (调用原有服务)
+            # 同步删除任务管理中心的数据
             try:
                 await task_service.delete_task(doc_id)
             except Exception as task_err:
-                logger.error(f"调用 task_service 删除任务失败: {task_err}")
+                logger.error(f"同步删除任务中心数据失败 (ID: {doc_id}): {task_err}")
 
+            conn.commit()
             return True, "删除成功"
         except Exception as e:
-            logger.exception(f"删除基本信息异常 (ID: {doc_id})")
+            logger.exception("删除基本信息失败")
             conn.rollback()
-            return False, f"删除失败: {str(e)}"
+            return False, str(e)
         finally:
             cursor.close()
             conn.close()

+ 186 - 26
src/app/services/search_engine_service.py

@@ -25,6 +25,8 @@ from app.services.milvus_service import milvus_service
 from app.utils.vector_utils import text_to_vector_algo
 import logging
 
+logger = logging.getLogger(__name__)
+
 class SearchEngineService:
     
     async def search_kb(self, db: AsyncSession, payload: KBSearchRequest) -> KBSearchResponse:
@@ -34,15 +36,21 @@ class SearchEngineService:
         original_kb_id = payload.kb_id 
         collection_name = original_kb_id
         
-        # 0. 尝试从数据库解析 kb_id 为 collection_name (如果是 UUID)
+        # 0. 尝试从数据库解析 kb_id 为 collection_name (如果传的是知识库ID)
         from sqlalchemy import text
         try:
             # 简单判断是否是 UUID 格式或数字 ID,尝试查询数据库
-            kb_query = text("SELECT collection_name FROM t_samp_knowledge_base WHERE id = :kb_id OR collection_name = :kb_id")
+            kb_query = text(
+                "SELECT collection_name_children, collection_name_parent "
+                "FROM t_samp_knowledge_base "
+                "WHERE id = :kb_id "
+                "   OR collection_name_children = :kb_id "
+                "   OR collection_name_parent = :kb_id"
+            )
             kb_res = await db.execute(kb_query, {"kb_id": original_kb_id})
             kb_row = kb_res.fetchone()
             if kb_row:
-                collection_name = kb_row[0]
+                collection_name = kb_row[0] or kb_row[1] or collection_name
                 logging.info(f"Resolved kb_id {original_kb_id} to collection_name: {collection_name}")
         except Exception as db_err:
             logging.warning(f"Failed to resolve kb_id {original_kb_id} from database: {db_err}")
@@ -61,6 +69,13 @@ class SearchEngineService:
             logging.info(f"Detected PDR collection for {collection_name}, searching in {child_col}")
         elif not milvus_service.has_collection(collection_name):
             return KBSearchResponse(results=[], total=0)
+
+        try:
+            state = milvus_service.get_collection_state(kb_id)
+            if state != "Loaded":
+                milvus_service.set_collection_state(kb_id, "load")
+        except Exception:
+            pass
             
         # 1. 使用算法生成向量 (替代 Embedding 模型)
         # 尝试从 Milvus collection 获取向量维度,动态匹配维度
@@ -74,23 +89,19 @@ class SearchEngineService:
         if collection_detail and isinstance(collection_detail, dict):
             fields = collection_detail.get("fields", []) or []
             for f in fields:
-                # 根据字段类型查找向量字段(Milvus 向量字段类型通常为 FloatVector / float_vector)
                 if not isinstance(f, dict):
                     continue
-                ftype = str(f.get("type") or "").lower()
-                print(ftype+'是什么东西')
-                if "100" in ftype or '101' in ftype:  # 假设 100 和 101 分别代表 FloatVector 和 BinaryVector
-                    # 找到向量字段,优先从 params.dim 获取维度
-                    params = f.get("params") or {}
-                    if params and params.get("dim"):
-                        try:
-                            dim = int(params.get("dim"))
-                            break
-                        except Exception:
-                            dim = None
-        # 回退默认维度
+                params = f.get("params") or {}
+                if params and params.get("dim"):
+                    try:
+                        dim = int(params.get("dim"))
+                        break
+                    except Exception:
+                        dim = None
+
+        # 回退默认维度:与系统 embedding 维度保持一致(避免向量维度不匹配导致检索报错)
         if not dim:
-            dim = 768
+            dim = milvus_service.DENSE_DIM
 
         # 选择 Milvus 向量字段名(anns_field),字段名可能不是固定的 "vector",也可能叫 'dense'/'denser' 等
         anns_field = "dense"
@@ -149,6 +160,52 @@ class SearchEngineService:
         
         # 2. 构建过滤表达式
         expr_list = []
+
+        metadata_type_map: Dict[str, str] = {}
+        try:
+            from app.sample.models.knowledge_base import KnowledgeBase
+            from app.sample.models.metadata import SampleMetadata
+
+            kb_stmt = select(KnowledgeBase).where(
+                or_(
+                    KnowledgeBase.collection_name_children == collection_name,
+                    KnowledgeBase.collection_name_parent == collection_name
+                ),
+                KnowledgeBase.is_deleted == 0
+            )
+            kb_res = await db.execute(kb_stmt)
+            kb_obj = kb_res.scalars().first()
+            if kb_obj:
+                meta_stmt = select(SampleMetadata.field_en_name, SampleMetadata.field_type).where(
+                    SampleMetadata.knowledge_base_id == kb_obj.id
+                )
+                meta_res = await db.execute(meta_stmt)
+                rows = meta_res.all()
+                metadata_type_map = {str(r[0]): str(r[1]) for r in rows if r and r[0]}
+        except Exception:
+            metadata_type_map = {}
+
+        def build_eq_expr(target_field: str, value: str, field_type: Optional[str], is_top_level: bool) -> str:
+            field_type_norm = (field_type or "").strip().lower()
+            looks_numeric = False
+            try:
+                float(value)
+                looks_numeric = True
+            except Exception:
+                looks_numeric = False
+
+            if field_type_norm == "num" or (looks_numeric and not is_top_level):
+                try:
+                    num_val = float(value)
+                    num_expr = f"{target_field} == {int(num_val) if num_val.is_integer() else num_val}"
+                    if is_top_level and field_type_norm == "num":
+                        return num_expr
+                    str_expr = f'{target_field} == "{value}"'
+                    return f"({num_expr} or {str_expr})"
+                except Exception:
+                    return f'{target_field} == "{value}"'
+
+            return f'{target_field} == "{value}"'
         
         # 兼容旧的单一字段过滤
         if payload.metadata_field and payload.metadata_value:
@@ -176,7 +233,13 @@ class SearchEngineService:
                 if not is_top_level:
                      target_field = f'metadata["{safe_field}"]'
 
-                expr_list.append(f'{target_field} == "{safe_value}"')
+                if is_top_level:
+                    expr_list.append(build_eq_expr(target_field, safe_value, metadata_type_map.get(safe_field), True))
+                else:
+                    alt_field = f'metadata["metadata"]["{safe_field}"]'
+                    expr_main = build_eq_expr(target_field, safe_value, metadata_type_map.get(safe_field), False)
+                    expr_alt = build_eq_expr(alt_field, safe_value, metadata_type_map.get(safe_field), False)
+                    expr_list.append(f"({expr_main} or {expr_alt})")
         
         # 处理新的多重过滤
         if payload.filters:
@@ -221,6 +284,9 @@ class SearchEngineService:
                                 sub_exprs.append(f'metadata["file_name"] in {val_list_str}')
                                 sub_exprs.append(f'metadata["doc_name"] in {val_list_str}')
                                 sub_exprs.append(f'metadata["title"] in {val_list_str}')
+                                sub_exprs.append(f'metadata["metadata"]["file_name"] in {val_list_str}')
+                                sub_exprs.append(f'metadata["metadata"]["doc_name"] in {val_list_str}')
+                                sub_exprs.append(f'metadata["metadata"]["title"] in {val_list_str}')
                                 
                                 # 组合成 (A or B or C)
                                 # 注意:如果某些字段不存在,Milvus 可能会报错吗?
@@ -248,7 +314,13 @@ class SearchEngineService:
                          target_field = f'metadata["{safe_field}"]'
 
                     # [Fix] 统一将 metadata 值视为字符串查询
-                    expr_list.append(f'{target_field} == "{safe_value}"')
+                    if is_top_level:
+                        expr_list.append(build_eq_expr(target_field, safe_value, metadata_type_map.get(safe_field), True))
+                    else:
+                        alt_field = f'metadata["metadata"]["{safe_field}"]'
+                        expr_main = build_eq_expr(target_field, safe_value, metadata_type_map.get(safe_field), False)
+                        expr_alt = build_eq_expr(alt_field, safe_value, metadata_type_map.get(safe_field), False)
+                        expr_list.append(f"({expr_main} or {expr_alt})")
         
         # 组合所有条件 (使用 AND)
         expr = " and ".join(expr_list) if expr_list else ""
@@ -276,9 +348,13 @@ class SearchEngineService:
                     total = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0
                 else:
                     # 带条件 count
-                    res_cnt = milvus_service.client.query(kb_id, filter=count_expr, output_fields=["count(*)"])
-                    if res_cnt:
-                        total = res_cnt[0].get("count(*)") or 0
+                    try:
+                        res_cnt = milvus_service.client.query(kb_id, filter=count_expr, output_fields=["count(*)"])
+                        if res_cnt:
+                            total = int(res_cnt[0].get("count(*)") or 0)
+                    except Exception as e:
+                        logger.warning(f"Scalar count(*) failed for KB={kb_id}, expr={count_expr}: {e}")
+                        total = 0
                 
                 # 2. 分页查询
                 # 如果没有 expr,Milvus query 需要一个 valid expression
@@ -351,10 +427,94 @@ class SearchEngineService:
 
             except Exception as e:
                 logging.error(f"Scalar query failed: {e}")
-                return KBSearchResponse(results=[], total=0)
+                raise ValueError(f"元数据过滤条件查询失败:{e}")
+
+        # --- 分支 B: 关键词/混合检索 (有关键词) ---
+        # 优先尝试用 text like 做关键词召回(比算法向量更贴合“相关性”);无结果再回退向量检索
+        query_text = (payload.query or "").strip()
+        if use_hybrid and query_text:
+            logger = logging.getLogger(__name__)
+            safe_q = query_text.replace('"', "").replace("'", "").strip()
+            like_expr = f'text like "%{safe_q}%"'
+            combined_expr = f"({expr}) and ({like_expr})" if expr else like_expr
+
+            page = payload.page if payload.page and payload.page > 0 else 1
+            page_size = payload.page_size if payload.page_size and payload.page_size > 0 else 10
+            offset = (page - 1) * page_size
+            limit = page_size
+
+            try:
+                total = 0
+                try:
+                    res_cnt = milvus_service.client.query(kb_id, filter=combined_expr, output_fields=["count(*)"])
+                    if res_cnt:
+                        total = int(res_cnt[0].get("count(*)") or 0)
+                except Exception as e:
+                    logger.warning(f"Keyword count(*) failed for KB={kb_id}, expr={combined_expr}: {e}")
+                    total = 0
+
+                fetch_limit = min(500, max(limit * 3, offset + limit))
+                res_page = milvus_service.client.query(
+                    collection_name=kb_id,
+                    filter=combined_expr,
+                    output_fields=["*"],
+                    limit=fetch_limit,
+                    offset=0
+                )
+
+                formatted_results = []
+                for item in res_page or []:
+                    item_metadata = item.get('metadata') or {}
+                    if isinstance(item_metadata, str):
+                        try:
+                            item_metadata = json.loads(item_metadata)
+                        except Exception:
+                            item_metadata = {}
+
+                    content = item.get('text') or item.get('content') or item.get('page_content') or ""
+                    doc_name = (
+                        item_metadata.get('doc_name')
+                        or item_metadata.get('file_name')
+                        or item_metadata.get('title')
+                        or item_metadata.get('source')
+                        or item.get('file_name')
+                        or item.get('title')
+                        or item.get('source')
+                        or "未知文档"
+                    )
+                    parent_id = item.get("parent_id") or item_metadata.get("parent_id") or ""
+                    document_id = item.get("document_id") or item_metadata.get("document_id") or ""
+
+                    occ = 0
+                    try:
+                        occ = (content or "").count(query_text)
+                    except Exception:
+                        occ = 0
+                    score = 60.0 + min(40.0, float(min(4, occ)) * 10.0)
+                    if query_text and doc_name and query_text in str(doc_name):
+                        score = min(100.0, score + 10.0)
+
+                    formatted_results.append(KBSearchResultItem(
+                        id=str(item.get('pk') or item.get('id')),
+                        kb_name=original_kb_id,
+                        doc_name=doc_name,
+                        content=content,
+                        meta_info=str(item_metadata),
+                        document_id=str(document_id) if document_id is not None else None,
+                        parent_id=str(parent_id) if parent_id is not None else None,
+                        metadata=item_metadata,
+                        score=round(score, 2)
+                    ))
+
+                formatted_results.sort(key=lambda r: (r.score or 0), reverse=True)
+                paged = formatted_results[offset: offset + limit]
+                if total == 0:
+                    total = len(formatted_results)
+                return KBSearchResponse(results=paged, total=total)
+            except Exception as e:
+                logger.warning(f"Keyword search fallback to vector due to error: {e}")
 
-        # --- 分支 B: 向量/混合检索 (有关键词) ---
-        # 选择 Milvus 向量字段名后生成向量 (移到这里,因为之前代码被替换掉了)
+        # 回退到向量检索(算法向量)
         query_vector = text_to_vector_algo(payload.query, dim=dim)
         
         # 检测 collection 使用的 metric (恢复这部分逻辑,因为后续 search 需要)
@@ -639,7 +799,7 @@ class SearchEngineService:
             return KBSearchResponse(results=formatted_results, total=final_total)
             
         except Exception as e:
-            print(f"Search error: {e}")
+            logger.exception("Search error")
             return KBSearchResponse(results=[], total=0)
 
     # ... (Keep existing CRUD methods below) ...

+ 20 - 36
src/app/services/snippet_service.py

@@ -10,6 +10,7 @@ import io
 import time
 import uuid
 from datetime import datetime
+import logging
 
 from app.services.milvus_service import milvus_service
 from app.schemas.base import PaginationSchema, PaginatedResponseSchema
@@ -24,6 +25,8 @@ from app.base.async_mysql_connection import get_db_connection
 
 from app.sample.models.base_info import DocumentMain
 
+logger = logging.getLogger(__name__)
+
 class SnippetService:
     
     async def get_list(
@@ -56,7 +59,7 @@ class SnippetService:
                         kb_name_map[row[1]] = row[2]
                         
         except Exception as e:
-            print(f"Failed to load KB map: {e}")
+            logger.exception("Failed to load KB map")
 
         # 1. 确定要查询的目标集合列表
         target_collections = []
@@ -101,10 +104,10 @@ class SnippetService:
                 # 目前 Schema 只有 is_deleted。
                 # 兼容处理:
                 try:
-                    # 尝试查询一条数据,看是否支持 is_deleted 字段
-                    # 这是一个简单的探针查询,如果报错则说明字段不存在
-                    milvus_service.client.query(col_name, filter="is_deleted == false", output_fields=["count(*)"], limit=1)
-                    has_is_deleted = True
+                    desc = milvus_service.client.describe_collection(col_name)
+                    fields = desc.get("fields", []) if isinstance(desc, dict) else []
+                    field_names = [f.get("name") for f in fields if isinstance(f, dict)]
+                    has_is_deleted = "is_deleted" in field_names
                 except Exception:
                     has_is_deleted = False
 
@@ -173,10 +176,6 @@ class SnippetService:
                     # 如果有状态过滤,也必须 query,不能直接用 stats
                     if status:
                          # 必须 query 计数
-                         # 优化:先 count
-                         res_cnt = milvus_service.client.query(col_name, filter=expr, output_fields=["count(*)"])
-                         # res_cnt 格式可能不同,视 Milvus 版本。通常 query 不支持聚合。
-                         # 只能先 query id
                          res = milvus_service.client.query(col_name, filter=expr, output_fields=["pk"])
                          col_hits = len(res)
                          global_total += col_hits
@@ -234,7 +233,7 @@ class SnippetService:
                         need_count -= len(res_page)
 
             except Exception as e:
-                print(f"Collection {col_name} query error: {e}")
+                logger.exception("Collection %s query error", col_name)
                 continue
 
         total_pages = (global_total + page_size - 1) // page_size if page_size else 0
@@ -273,13 +272,10 @@ class SnippetService:
                 from sqlalchemy import select
 
                 async with get_db_connection() as db:
-                    # 打印调试信息
-                    # print(f"Querying DocumentMain for {len(doc_ids)} ids: {list(doc_ids)[:5]}...")
                     stmt = select(DocumentMain.id, DocumentMain.title).where(DocumentMain.id.in_(list(doc_ids)))
                     result = await db.execute(stmt)
                     rows = result.all()
                     doc_name_map = {str(row[0]): row[1] for row in rows}
-                    # print(f"Found {len(doc_name_map)} documents.")
                 
                 if doc_name_map:
                     for item in items:
@@ -287,9 +283,7 @@ class SnippetService:
                         if did and str(did) in doc_name_map:
                             item["doc_name"] = doc_name_map[str(did)]
             except Exception as e:
-                import traceback
-                traceback.print_exc()
-                print(f"Failed to fetch document names from DB: {e}")
+                logger.exception("Failed to fetch document names from DB")
 
         return items, meta
 
@@ -399,7 +393,7 @@ class SnippetService:
                 if res[0].get("created_by"):
                     old_created_by = res[0].get("created_by")
         except Exception as e:
-            print(f"Failed to fetch old item info: {e}")
+            logger.exception("Failed to fetch old item info")
 
         # 1. 删除旧数据
 
@@ -541,9 +535,9 @@ class SnippetService:
                 try:
                     async with AsyncSessionLocal() as db:
                         count = await knowledge_base_service.update_doc_count(db, kb)
-                        print(f"Synced doc count for {kb} after delete: {count}")
+                        logger.info("Synced doc count for %s after delete: %s", kb, count)
                 except Exception as ex:
-                    print(f"Error in sync_count task: {ex}")
+                    logger.exception("Error in sync_count task")
             
             # 检查是否有正在运行的 loop
             try:
@@ -555,7 +549,7 @@ class SnippetService:
                 # 没有 loop,可以直接 run
                 asyncio.run(sync_count())
         except Exception as e:
-            print(f"Failed to sync doc count after delete: {e}")
+            logger.exception("Failed to sync doc count after delete")
 
     def _format_snippet(self, r: Dict, col_name: str, kb_map: Dict[str, str] = None) -> Dict:
         id_val = r.get("pk") or r.get("id")
@@ -598,7 +592,7 @@ class SnippetService:
         meta_dict = meta if isinstance(meta, dict) else {}
         
         parent_id = r.get("parent_id")
-        print(f"parent_id from DB: {parent_id}11111111111111111111111111111111111111111111111111111111")
+        
         if not parent_id and "parent_id" in meta_dict:
             parent_id = meta_dict["parent_id"]
             
@@ -722,7 +716,7 @@ class SnippetService:
             
         try:
             # 获取集合字段信息
-            print("111111111111111111111111111111111111111111111111111111")
+            
             desc = milvus_service.client.describe_collection(kb)
             fields = desc.get('fields', [])
             field_names = [f['name'] for f in fields]
@@ -813,10 +807,8 @@ class SnippetService:
             # [New Feature] 获取父段内容
             # 逻辑:根据当前子表(kb) -> 查 KnowledgeBase 表找到对应的父表 -> 用 parent_id 查父表内容
             parent_id = snippet_data.get("parent_id") or clean_id
-            # print(f"DEBUG: snippet_id={id}, kb={kb}, found parent_id={parent_id}")
             
             if parent_id:
-                print(parent_id+'2222222222222222222222222222222222222222222222222222222')
                 try:
                     # 1. 查找 KnowledgeBase 记录
                     from sqlalchemy import select, or_
@@ -837,7 +829,7 @@ class SnippetService:
                         
                         # 2. 在父表中查询 parent_id 相同的父段(可能有多个切片)
                         if milvus_service.has_collection(parent_kb):
-                            print("successqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq")
+                            logger.debug("Parent collection describe succeeded")
                             p_desc = milvus_service.client.describe_collection(parent_kb)
                             p_fields = p_desc.get('fields', [])
                             p_field_names = [f['name'] for f in p_fields]
@@ -900,22 +892,14 @@ class SnippetService:
                                 parent_segments.sort(key=_sort_key)
                                 snippet_data["parent_segments"] = parent_segments
                                 snippet_data["parent_content"] = parent_segments[0].get("content") or ""
-                            # else:
-                                # print("DEBUG: Parent content NOT found")
-                    # else:
-                        # print("DEBUG: Parent KB record NOT found or no collection_name_parent")
                                 
                 except Exception as e:
-                    print(f"Failed to fetch parent content: {e}")
-                    import traceback
-                    traceback.print_exc()
+                    logger.exception("Failed to fetch parent content")
 
             return snippet_data
             
         except Exception as e:
-            import traceback
-            traceback.print_exc()
-            print(f"Get snippet detail error: {e}, id={id}, kb={kb}")
+            logger.exception("Get snippet detail error, id=%s, kb=%s", id, kb)
             return None
 
     def export_snippets(self, kb: Optional[str] = None, keyword: Optional[str] = None) -> Any:
@@ -975,7 +959,7 @@ class SnippetService:
                         break
                         
             except Exception as e:
-                print(f"Collection {col_name} export error: {e}")
+                logger.exception("Collection %s export error", col_name)
                 continue
 
     def generate_csv_stream(self, kb: Optional[str] = None, keyword: Optional[str] = None):

+ 131 - 92
src/deploy/admin_front_deploy.py

@@ -14,11 +14,14 @@ import argparse
 import tempfile
 import hashlib
 import subprocess
+import logging
 from pathlib import Path
 from typing import Optional, Tuple, List
 import getpass
 import shutil
 
+logger = logging.getLogger(__name__)
+
 class VueAutoDeployer:
     def __init__(self, hostname: str, username: str, 
                  local_source_dir: str, remote_deploy_dir: str,
@@ -244,7 +247,7 @@ class VueAutoDeployer:
 
     def _validate_all_local_directories(self):
         """验证本地目录是否存在且包含必要文件"""
-        print(f"检查本地目录: {self.local_source_dir}")
+        logger.info("检查本地目录: %s", self.local_source_dir)
         
         if not os.path.exists(self.local_source_dir):
             raise FileNotFoundError(f"本地目录不存在: {self.local_source_dir}")
@@ -272,17 +275,21 @@ class VueAutoDeployer:
             )
         
         # 显示目录内容
-        print("本地目录内容:")
+        logger.info("本地目录内容:")
         for item in os.listdir(self.local_source_dir):
             item_path = os.path.join(self.local_source_dir, item)
             if os.path.isdir(item_path):
-                print(f"  📁 {item}/")
+                logger.info("  📁 %s/", item)
             else:
-                print(f"  📄 {item}")
+                logger.info("  📄 %s", item)
         
+<<<<<<< Updated upstream
         print("✓ 本地目录验证通过")
 
 
+=======
+        logger.info("本地目录验证通过")
+>>>>>>> Stashed changes
     
     def _create_zip_from_source(self) -> str:
         """
@@ -291,7 +298,7 @@ class VueAutoDeployer:
         Returns:
             zip文件的临时路径
         """
-        print(f"\n正在创建压缩包: {self.zip_filename}")
+        logger.info("正在创建压缩包: %s", self.zip_filename)
         
         # 创建临时文件
         temp_dir = tempfile.mkdtemp(prefix="vue_deploy_")
@@ -302,7 +309,7 @@ class VueAutoDeployer:
                 # 添加index.html
                 index_path = os.path.join(self.local_source_dir, 'index.html')
                 zipf.write(index_path, 'index.html')
-                print(f"  ✓ 添加: index.html")
+                logger.info("  ✓ 添加: index.html")
                 
                 # 添加assets目录
                 assets_dir = os.path.join(self.local_source_dir, 'assets')
@@ -316,7 +323,7 @@ class VueAutoDeployer:
                             file_path = os.path.join(root, file)
                             arcname = os.path.join(rel_path, file)
                             zipf.write(file_path, arcname)
-                            print(f"  ✓ 添加: {arcname}")
+                            logger.info("  ✓ 添加: %s", arcname)
                 
                 # 添加其他可能的文件(css, js文件)
                 for item in os.listdir(self.local_source_dir):
@@ -324,16 +331,16 @@ class VueAutoDeployer:
                         item_path = os.path.join(self.local_source_dir, item)
                         if os.path.isfile(item_path) and item.endswith(('.css', '.js')):
                             zipf.write(item_path, item)
-                            print(f"  ✓ 添加: {item}")
+                            logger.info("  ✓ 添加: %s", item)
             
             # 获取压缩包信息
             zip_size = os.path.getsize(zip_path)
             file_count = len(zipfile.ZipFile(zip_path, 'r').namelist())
             
-            print(f"✓ 压缩包创建完成:")
-            print(f"  文件路径: {zip_path}")
-            print(f"  文件大小: {zip_size / 1024 / 1024:.2f} MB")
-            print(f"  包含文件: {file_count} 个")
+            logger.info("压缩包创建完成")
+            logger.info("文件路径: %s", zip_path)
+            logger.info("文件大小: %.2f MB", zip_size / 1024 / 1024)
+            logger.info("包含文件: %s 个", file_count)
             
             return zip_path
             
@@ -344,7 +351,7 @@ class VueAutoDeployer:
     
     def connect(self) -> bool:
         """连接到SSH服务器"""
-        print(f"\n正在连接到服务器 {self.hostname}:{self.port}...")
+        logger.info("正在连接到服务器 %s:%s...", self.hostname, self.port)
         
         try:
             self.ssh_client = paramiko.SSHClient()
@@ -370,7 +377,7 @@ class VueAutoDeployer:
                 connect_params['password'] = self.password
                 auth_method = "密码认证"
             
-            print(f"使用认证方式: {auth_method}")
+            logger.info("使用认证方式: %s", auth_method)
             self.ssh_client.connect(**connect_params, timeout=30)
             
             # 测试连接
@@ -379,22 +386,22 @@ class VueAutoDeployer:
             user = output.split('\n')[1] if len(output.split('\n')) > 1 else '未知'
             hostname = output.split('\n')[2] if len(output.split('\n')) > 2 else '未知'
             
-            print(f"✓ SSH连接成功!")
-            print(f"  服务器: {hostname}")
-            print(f"  用户: {user}")
+            logger.info("SSH连接成功")
+            logger.info("服务器: %s", hostname)
+            logger.info("用户: %s", user)
             
             # 创建SFTP客户端
             self.sftp_client = self.ssh_client.open_sftp()
             return True
             
         except paramiko.AuthenticationException:
-            print("✗ SSH认证失败!请检查用户名/密码/密钥")
+            logger.error("SSH认证失败,请检查用户名/密码/密钥")
             return False
         except paramiko.SSHException as e:
-            print(f"✗ SSH连接异常: {e}")
+            logger.exception("SSH连接异常: %s", e)
             return False
         except Exception as e:
-            print(f"✗ 连接失败: {e}")
+            logger.exception("连接失败: %s", e)
             return False
     
     def disconnect(self):
@@ -403,7 +410,7 @@ class VueAutoDeployer:
             self.sftp_client.close()
         if self.ssh_client:
             self.ssh_client.close()
-        print("✓ 已断开SSH连接")
+        logger.info("已断开SSH连接")
     
     def execute_command(self, command: str, verbose: bool = True) -> Tuple[int, str, str]:
         """
@@ -418,7 +425,7 @@ class VueAutoDeployer:
         """
         try:
             if verbose:
-                print(f"执行命令: {command}")
+                logger.info("执行命令: %s", command)
             
             stdin, stdout, stderr = self.ssh_client.exec_command(command, timeout=60)
             
@@ -431,15 +438,15 @@ class VueAutoDeployer:
             
             if verbose:
                 if stdout_str:
-                    print(f"输出:\n{stdout_str}")
+                    logger.info("输出:\n%s", stdout_str)
                 if stderr_str and exit_status != 0:
-                    print(f"错误:\n{stderr_str}")
-                print(f"返回码: {exit_status}")
+                    logger.warning("错误:\n%s", stderr_str)
+                logger.info("返回码: %s", exit_status)
             
             return exit_status, stdout_str, stderr_str
             
         except Exception as e:
-            print(f"✗ 执行命令失败: {e}")
+            logger.exception("执行命令失败: %s", e)
             return -1, "", str(e)
     
     def upload_file(self, local_path: str, remote_path: str) -> bool:
@@ -455,36 +462,36 @@ class VueAutoDeployer:
         """
         try:
             if not os.path.exists(local_path):
-                print(f"✗ 本地文件不存在: {local_path}")
+                logger.error("本地文件不存在: %s", local_path)
                 return False
             
-            print(f"本地文件: {local_path}")
+            logger.info("本地文件: %s", local_path)
             file_size = os.path.getsize(local_path)
-            print(f"正在上传文件: {os.path.basename(local_path)} ({file_size/1024/1024:.2f} MB)")
+            logger.info("正在上传文件: %s (%.2f MB)", os.path.basename(local_path), file_size / 1024 / 1024)
             
             # 确保远程目录存在并检查权限
             remote_dir = os.path.dirname(remote_path)
-            print(f"远程文件目录: {remote_dir}")
+            logger.info("远程文件目录: %s", remote_dir)
             
             # 创建目录
             exit_code, stdout, stderr = self.execute_command(f"mkdir -p {remote_dir}", verbose=False)
             if exit_code != 0:
-                print(f"✗ 创建远程目录失败: {stderr}")
+                logger.error("创建远程目录失败: %s", stderr)
                 return False
             
             # 检查目录权限
             exit_code, stdout, stderr = self.execute_command(f"ls -ld {remote_dir}", verbose=False)
             if exit_code == 0:
-                print(f"目录权限: {stdout}")
+                logger.info("目录权限: %s", stdout)
             
             # 检查写入权限
             test_file = os.path.join(remote_dir, "test_write_permission.tmp")
             exit_code, stdout, stderr = self.execute_command(f"touch {test_file} && rm -f {test_file}", verbose=False)
             if exit_code != 0:
-                print(f"✗ 远程目录没有写入权限: {remote_dir}")
-                print(f"错误: {stderr}")
+                logger.error("远程目录没有写入权限: %s", remote_dir)
+                logger.error("错误: %s", stderr)
                 return False
-            print("✓ 远程目录写入权限检查通过")
+            logger.info("远程目录写入权限检查通过")
             
             # 使用SFTP上传文件(显示进度)
             start_time = time.time()
@@ -502,30 +509,30 @@ class VueAutoDeployer:
             # 验证上传
             exit_code, stdout, stderr = self.execute_command(f"ls -lh {remote_path}", verbose=False)
             if exit_code == 0 and stdout:
-                print(f"✓ 文件验证成功: {stdout}")
+                logger.info("文件验证成功: %s", stdout)
             else:
-                print(f"⚠ 文件验证失败: 文件可能未正确上传")
+                logger.warning("文件验证失败: 文件可能未正确上传")
                 if stderr:
-                    print(f"错误: {stderr}")
+                    logger.warning("错误: %s", stderr)
                 # 检查目录内容
-                print(f"检查目录内容: {remote_dir}")
+                logger.info("检查目录内容: %s", remote_dir)
                 self.execute_command(f"ls -la {remote_dir}", verbose=True)
                 return False
             
             elapsed = time.time() - start_time
-            print(f"\n✓ 文件上传成功!耗时: {elapsed:.1f}秒")
+            logger.info("文件上传成功,耗时: %.1f秒", elapsed)
             
             return True
             
         except paramiko.SFTPError as e:
-            print(f"\n✗ SFTP上传失败: {e}")
-            print("可能的原因:")
-            print("  1. 远程目录权限不足")
-            print("  2. 磁盘空间不足")
-            print("  3. 网络连接中断")
+            logger.exception("SFTP上传失败: %s", e)
+            logger.info("可能的原因:")
+            logger.info("  1. 远程目录权限不足")
+            logger.info("  2. 磁盘空间不足")
+            logger.info("  3. 网络连接中断")
             return False
         except Exception as e:
-            print(f"\n✗ 文件上传失败: {e}")
+            logger.exception("文件上传失败: %s", e)
             return False
     
     def check_remote_prerequisites(self) -> bool:
@@ -535,7 +542,7 @@ class VueAutoDeployer:
         Returns:
             是否满足条件
         """
-        print("\n检查远程服务器部署条件...")
+        logger.info("检查远程服务器部署条件...")
         
         checks = []
         
@@ -598,24 +605,24 @@ class VueAutoDeployer:
             checks.append(("nginx目录", "✓", stdout.strip().split()[-1]))
         
         # 显示检查结果
-        print("\n" + "="*60)
-        print("服务器环境检查结果:")
-        print("="*60)
+        logger.info("=" * 60)
+        logger.info("服务器环境检查结果:")
+        logger.info("=" * 60)
         
         all_passed = True
         for check_name, status, message in checks:
             if status == "✓":
-                print(f"  {status} {check_name}: {message}")
+                logger.info("  %s %s: %s", status, check_name, message)
             elif status == "⚠":
-                print(f"  {status} {check_name}: {message}")
+                logger.info("  %s %s: %s", status, check_name, message)
             else:
-                print(f"  {status} {check_name}: {message}")
+                logger.info("  %s %s: %s", status, check_name, message)
                 all_passed = False
         
-        print("="*60)
+        logger.info("=" * 60)
         
         if not all_passed:
-            print("\n⚠ 警告: 部分检查未通过,部署可能会失败")
+            logger.warning("警告: 部分检查未通过,部署可能会失败")
             response = input("是否继续部署?(y/N): ").strip().lower()
             return response == 'y'
         
@@ -631,9 +638,9 @@ class VueAutoDeployer:
         Returns:
             是否部署成功
         """
-        print("="*70)
-        print("Vue前端应用自动化部署流程")
-        print("="*70)
+        logger.info("=" * 70)
+        logger.info("Vue前端应用自动化部署流程")
+        logger.info("=" * 70)
         
         temp_zip_path = None
         temp_dir = None
@@ -647,9 +654,14 @@ class VueAutoDeployer:
                     return False
             
             # 步骤1: 本地压缩文件
+<<<<<<< Updated upstream
             step_num = "1/5" if self.frontend_project_dir else "1/4"
             print(f"\n[步骤 {step_num}] 本地压缩Vue构建文件")
             print("-"*40)
+=======
+            logger.info("[步骤 1/4] 本地压缩Vue构建文件")
+            logger.info("-" * 40)
+>>>>>>> Stashed changes
             temp_zip_path = self._create_zip_from_source()
             temp_dir = os.path.dirname(temp_zip_path)
 
@@ -658,9 +670,14 @@ class VueAutoDeployer:
             self._validate_all_local_directories()
             
             # 步骤2: 连接到服务器
+<<<<<<< Updated upstream
             step_num = "2/5" if self.frontend_project_dir else "2/4"
             print(f"\n[步骤 {step_num}] 连接到远程服务器")
             print("-"*40)
+=======
+            logger.info("[步骤 2/4] 连接到远程服务器")
+            logger.info("-" * 40)
+>>>>>>> Stashed changes
             if not self.connect():
                 return False
             
@@ -669,50 +686,63 @@ class VueAutoDeployer:
                 return False
             
             # 步骤3: 上传文件
+<<<<<<< Updated upstream
             step_num = "3/5" if self.frontend_project_dir else "3/4"
             print(f"\n[步骤 {step_num}] 上传文件到服务器")
             print("-"*40)
+=======
+            logger.info("[步骤 3/4] 上传文件到服务器")
+            logger.info("-" * 40)
+>>>>>>> Stashed changes
             remote_zip_path = os.path.join(self.remote_deploy_dir, self.zip_filename)
             
             if not self.upload_file(temp_zip_path, remote_zip_path):
                 return False
             
             # 步骤4: 执行部署脚本
+<<<<<<< Updated upstream
             step_num = "4/5" if self.frontend_project_dir else "4/4"
             print(f"\n[步骤 {step_num}] 执行远程部署脚本")
             print("-"*40)
+=======
+            logger.info("[步骤 4/4] 执行远程部署脚本")
+            logger.info("-" * 40)
+>>>>>>> Stashed changes
             
             # 构建部署命令,传递上传的zip文件路径作为参数  {remote_zip_path}
             deploy_command = f"{self.remote_script_path}"
             
-            print(f"执行部署命令: {deploy_command}")
-            print("-"*40)
+            logger.info("执行部署命令: %s", deploy_command)
+            logger.info("-" * 40)
             
             start_time = time.time()
             exit_code, stdout, stderr = self.execute_command(deploy_command, verbose=True)
             elapsed_time = time.time() - start_time
             
-            print("-"*40)
-            print(f"部署执行完成,耗时: {elapsed_time:.1f}秒")
+            logger.info("-" * 40)
+            logger.info("部署执行完成,耗时: %.1f秒", elapsed_time)
             
             if exit_code != 0:
-                print(f"✗ 部署失败!返回码: {exit_code}")
+                logger.error("部署失败,返回码: %s", exit_code)
                 if stderr:
-                    print(f"错误信息:\n{stderr}")
+                    logger.error("错误信息:\n%s", stderr)
                 return False
             
-            print("✅ 部署成功完成!")
+            logger.info("部署成功完成")
             
             # 可选: 验证部署结果
+<<<<<<< Updated upstream
             print("\n验证部署结果...")
             self.execute_command("ls -la /home/lq/nginx/html/ 2>/dev/null | head -10", verbose=True)
+=======
+            logger.info("验证部署结果...")
+            self.execute_command("ls -la /usr/share/nginx/html/ 2>/dev/null | head -10", verbose=True)
+>>>>>>> Stashed changes
             
             return True
             
         except Exception as e:
-            print(f"\n✗ 部署过程中发生错误: {e}")
-            import traceback
-            traceback.print_exc()
+            logger.exception("部署过程中发生错误: %s", e)
             return False
             
         finally:
@@ -720,9 +750,9 @@ class VueAutoDeployer:
             if cleanup_temp and temp_dir and os.path.exists(temp_dir):
                 try:
                     shutil.rmtree(temp_dir)
-                    print(f"✓ 已清理临时文件: {temp_dir}")
+                    logger.info("已清理临时文件: %s", temp_dir)
                 except:
-                    print(f"⚠ 清理临时文件失败: {temp_dir}")
+                    logger.warning("清理临时文件失败: %s", temp_dir)
             
             # 断开连接
             self.disconnect()
@@ -773,13 +803,15 @@ def parse_arguments():
 
 def main():
     """主函数"""
+    logging.basicConfig(level=logging.INFO, format="%(message)s")
     args = parse_arguments()
     
     # 显示欢迎信息
-    print("="*70)
-    print("🚀 Vue前端应用自动化部署工具")
-    print("="*70)
+    logger.info("=" * 70)
+    logger.info("🚀 Vue前端应用自动化部署工具")
+    logger.info("=" * 70)
     
+<<<<<<< Updated upstream
     print(f"服务器: {args.host}:{args.port}")
     print(f"用户: {args.user}")
     print(f"本地源目录: {args.source}")
@@ -789,6 +821,14 @@ def main():
     print(f"远程目录: {args.remote_dir}")
     print(f"部署脚本: {args.script}")
     print("="*70)
+=======
+    logger.info("服务器: %s:%s", args.host, args.port)
+    logger.info("用户: %s", args.user)
+    logger.info("本地源目录: %s", args.source)
+    logger.info("远程目录: %s", args.remote_dir)
+    logger.info("部署脚本: %s", args.script)
+    logger.info("=" * 70)
+>>>>>>> Stashed changes
     
     try:
         # 创建部署器实例
@@ -809,31 +849,29 @@ def main():
         success = deployer.deploy(cleanup_temp=not args.no_cleanup)
         
         if success:
-            print("\n" + "="*70)
-            print("🎉 部署流程全部完成!")
-            print("="*70)
-            print("\n建议操作:")
-            print("  1. 访问网站检查是否正常显示")
-            print("  2. 检查nginx错误日志: sudo tail -f /var/log/nginx/error.log")
-            print("  3. 如果需要重启nginx: sudo systemctl restart nginx")
-            print("="*70)
+            logger.info("=" * 70)
+            logger.info("🎉 部署流程全部完成!")
+            logger.info("=" * 70)
+            logger.info("建议操作:")
+            logger.info("  1. 访问网站检查是否正常显示")
+            logger.info("  2. 检查nginx错误日志: sudo tail -f /var/log/nginx/error.log")
+            logger.info("  3. 如果需要重启nginx: sudo systemctl restart nginx")
+            logger.info("=" * 70)
             return 0
         else:
-            print("\n" + "="*70)
-            print("❌ 部署失败!")
-            print("="*70)
+            logger.error("=" * 70)
+            logger.error("❌ 部署失败!")
+            logger.error("=" * 70)
             return 1
             
     except FileNotFoundError as e:
-        print(f"\n✗ 文件错误: {e}")
+        logger.error("文件错误: %s", e)
         return 1
     except KeyboardInterrupt:
-        print("\n\n⚠ 用户中断操作")
+        logger.warning("用户中断操作")
         return 130
     except Exception as e:
-        print(f"\n✗ 发生未预期的错误: {e}")
-        import traceback
-        traceback.print_exc()
+        logger.exception("发生未预期的错误: %s", e)
         return 1
 
 if __name__ == "__main__":
@@ -841,8 +879,9 @@ if __name__ == "__main__":
     try:
         import paramiko
     except ImportError:
-        print("错误: 未安装paramiko库")
-        print("请安装依赖: pip install paramiko")
+        logging.basicConfig(level=logging.INFO, format="%(message)s")
+        logger.error("错误: 未安装paramiko库")
+        logger.error("请安装依赖: pip install paramiko")
         sys.exit(1)
     
-    sys.exit(main())
+    sys.exit(main())

+ 14 - 4
src/views/knowledge_base_view.py

@@ -74,8 +74,13 @@ async def create_knowledge_base(
     if not payload_token:
         return ResponseSchema(code=401, message="无效的访问令牌")
 
-    new_kb = await knowledge_base_service.create(db, payload)
-    return ResponseSchema(code=0, message="创建成功", data=KnowledgeBaseResponse.model_validate(new_kb))
+    try:
+        new_kb = await knowledge_base_service.create(db, payload)
+        return ResponseSchema(code=0, message="创建成功", data=KnowledgeBaseResponse.model_validate(new_kb))
+    except ValueError as e:
+        return ResponseSchema(code=400, message=str(e))
+    except Exception:
+        return ResponseSchema(code=500, message="创建失败")
 
 @router.post("/{id}", response_model=ResponseSchema)
 async def update_knowledge_base(
@@ -89,8 +94,13 @@ async def update_knowledge_base(
     if not payload_token:
         return ResponseSchema(code=401, message="无效的访问令牌")
 
-    kb = await knowledge_base_service.update(db, id, payload)
-    return ResponseSchema(code=0, message="更新成功", data=KnowledgeBaseResponse.model_validate(kb))
+    try:
+        kb = await knowledge_base_service.update(db, id, payload)
+        return ResponseSchema(code=0, message="更新成功", data=KnowledgeBaseResponse.model_validate(kb))
+    except ValueError as e:
+        return ResponseSchema(code=400, message=str(e))
+    except Exception:
+        return ResponseSchema(code=500, message="更新失败")
 
 @router.post("/{id}/status", response_model=ResponseSchema)
 async def update_knowledge_base_status(

+ 7 - 2
src/views/search_engine_view.py

@@ -32,8 +32,13 @@ async def search_knowledge_base(
     if not payload_token:
         return ResponseSchema(code=401, message="无效的访问令牌")
 
-    result = await search_engine_service.search_kb(db, payload)
-    return ResponseSchema(code=0, message="搜索成功", data=result)
+    try:
+        result = await search_engine_service.search_kb(db, payload)
+        return ResponseSchema(code=0, message="搜索成功", data=result)
+    except ValueError as e:
+        return ResponseSchema(code=400, message=str(e), data=KBSearchResponse(results=[], total=0))
+    except Exception:
+        return ResponseSchema(code=500, message="搜索失败", data=KBSearchResponse(results=[], total=0))
 
 
 # --- 原有接口 ---