Kaynağa Gözat

检索和知识库相关的更新和完善

linyang 1 ay önce
ebeveyn
işleme
0505c3a948

+ 25 - 0
src/app/sample/models/custom_schema.py

@@ -0,0 +1,25 @@
+"""
+知识库自定义Schema定义模型
+"""
+from sqlalchemy import Column, String, Integer, Boolean, Text, DateTime, func
+from sqlalchemy.dialects.mysql import CHAR, TINYINT
+from app.base.async_mysql_connection import Base
+import uuid
+
+class CustomSchema(Base):
+    """知识库自定义Schema表"""
+    __tablename__ = "t_samp_custom_schema"
+
+    id = Column(CHAR(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="主键ID")
+    knowledge_base_id = Column(String(36), nullable=False, comment="知识库ID")
+    field_name = Column(String(255), nullable=False, comment="字段名称(英文)")
+    field_type = Column(String(50), nullable=False, comment="字段类型")
+    max_length = Column(Integer, nullable=True, comment="最大长度")
+    is_primary = Column(Boolean, default=False, comment="是否主键")
+    description = Column(String(1000), nullable=True, comment="描述")
+    
+    created_time = Column(DateTime, default=func.now(), comment="创建时间")
+    updated_time = Column(DateTime, default=func.now(), onupdate=func.now(), comment="修改时间")
+
+    def __repr__(self):
+        return f"<CustomSchema kb_id={self.knowledge_base_id} field={self.field_name}>"

+ 28 - 0
src/app/sample/models/metadata.py

@@ -0,0 +1,28 @@
+"""
+知识库元数据定义模型
+"""
+from sqlalchemy import Column, String, Text, DateTime, func
+from sqlalchemy.dialects.mysql import CHAR, TINYINT
+from app.base.async_mysql_connection import Base
+import uuid
+
+class SampleMetadata(Base):
+    """知识库元数据定义表"""
+    __tablename__ = "t_samp_metadata"
+
+    id = Column(CHAR(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="主键ID")
+    knowledge_base_id = Column(String(36), nullable=False, comment="知识库ID")
+    field_zh_name = Column(String(255), nullable=False, comment="字段名称(中文)")
+    field_en_name = Column(String(500), nullable=False, comment="字段英文名称")
+    field_type = Column(String(10), nullable=False, comment="字段类型: text(文本), num(数字)")
+    remark = Column(String(1000), nullable=True, comment="备注")
+    
+    def to_dict(self) -> dict:
+        """转换为字典"""
+        return {
+            column.name: getattr(self, column.name)
+            for column in self.__table__.columns
+        }
+
+    def __repr__(self):
+        return f"<SampleMetadata kb_id={self.knowledge_base_id} field={self.field_en_name}>"

+ 18 - 1
src/app/sample/schemas/knowledge_base.py

@@ -1,16 +1,33 @@
-from typing import Optional
+from typing import Optional, List
 from pydantic import BaseModel, Field
 from app.schemas.base import BaseModelSchema
 
+class MetadataField(BaseModel):
+    """元数据字段定义"""
+    field_zh_name: str = Field(..., description="中文名称")
+    field_en_name: str = Field(..., description="英文名称")
+    field_type: str = Field(..., description="字段类型: text/num")
+    remark: Optional[str] = Field(None, description="备注")
+
 class KnowledgeBaseBase(BaseModel):
     name: str = Field(..., description="知识库名称")
     collection_name: str = Field(..., description="Milvus集合名称")
     description: Optional[str] = Field(None, description="描述")
     status: Optional[str] = Field("normal", description="状态")
 
+class CustomSchemaField(BaseModel):
+    """自定义Schema字段定义"""
+    field_name: str = Field(..., description="字段名称(英文)")
+    field_type: str = Field(..., description="字段类型: BOOL/INT8/INT16/INT32/INT64/FLOAT/DOUBLE/VARCHAR/JSON")
+    max_length: Optional[int] = Field(None, description="最大长度(VARCHAR需要)")
+    is_primary: bool = Field(False, description="是否主键(通常不需要用户指定)")
+    description: Optional[str] = Field(None, description="描述")
+
 class KnowledgeBaseCreate(KnowledgeBaseBase):
     """创建知识库请求参数"""
     dimension: int = Field(768, description="向量维度,默认768")
+    metadata_fields: Optional[List[MetadataField]] = Field(None, description="元数据字段列表")
+    custom_schemas: Optional[List[CustomSchemaField]] = Field(None, description="自定义Schema字段列表")
 
 class KnowledgeBaseUpdate(BaseModel):
     """更新知识库请求参数"""

+ 10 - 2
src/app/sample/schemas/search_engine.py

@@ -44,14 +44,22 @@ class SearchEngineResponse(BaseModelSchema):
 
 # --- 新增:知识库搜索相关模型 ---
 
+class FilterCondition(BaseModel):
+    field: str
+    value: str
+
 class KBSearchRequest(BaseModel):
     """知识库搜索请求"""
     kb_id: str = Field(..., description="知识库ID或集合名称")
     query: str = Field(..., description="检索关键字")
-    metadata_field: Optional[str] = Field(None, description="元数据字典字段")
-    metadata_value: Optional[str] = Field(None, description="元数据字典值")
+    metadata_field: Optional[str] = Field(None, description="元数据字典字段(兼容旧版)")
+    metadata_value: Optional[str] = Field(None, description="元数据字典值(兼容旧版)")
+    filters: Optional[List[FilterCondition]] = Field(None, description="多重过滤条件")
     top_k: int = Field(10, description="返回结果数量")
     score_threshold: float = Field(0.0, description="相似度阈值")
+    metric_type: Optional[str] = Field(None, description="相似度计算方式")
+    page: int = Field(1, description="页码")
+    page_size: int = Field(10, description="每页数量")
 
 class KBSearchResultItem(BaseModel):
     """单条搜索结果"""

+ 152 - 11
src/app/services/knowledge_base_service.py

@@ -4,11 +4,13 @@
 from math import ceil
 from typing import List, Optional, Tuple, Dict, Any
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy import select, func, or_
+from sqlalchemy import select, func, or_, delete as sql_delete
 from datetime import datetime
 import uuid
 
 from app.sample.models.knowledge_base import KnowledgeBase
+from app.sample.models.metadata import SampleMetadata
+from app.sample.models.custom_schema import CustomSchema
 from app.sample.schemas.knowledge_base import (
     KnowledgeBaseCreate, 
     KnowledgeBaseUpdate,
@@ -120,17 +122,13 @@ class KnowledgeBaseService:
         if exists.scalars().first():
             raise ValueError("知识库集合名称已存在")
 
-        # 2. 检查 Milvus 是否已存在
-        if milvus_service.has_collection(payload.collection_name):
-            raise ValueError("Milvus集合已存在,请使用其他名称")
+        # 2. 检查 Milvus 是否已存在 (如果之前残留)
+        # if milvus_service.has_collection(payload.collection_name):
+        #     raise ValueError("Milvus集合已存在,请使用其他名称")
 
         try:
-            # 3. 创建 Milvus 集合
-            milvus_service.create_collection(
-                name=payload.collection_name,
-                dimension=payload.dimension,
-                description=payload.description or ""
-            )
+            # 3. 创建 Milvus 集合 (延迟到点击同步按钮时创建)
+            # milvus_service.create_collection(...)
 
             # 4. 创建 DB 记录
             now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -144,6 +142,36 @@ class KnowledgeBaseService:
                 updated_time=now
             )
             db.add(new_kb)
+            
+            # 5. 保存元数据定义 (如果有)
+            if payload.metadata_fields:
+                for field in payload.metadata_fields:
+                    new_metadata = SampleMetadata(
+                        id=str(uuid.uuid4()),
+                        knowledge_base_id=new_kb.id,
+                        field_zh_name=field.field_zh_name,
+                        field_en_name=field.field_en_name,
+                        field_type=field.field_type,
+                        remark=field.remark
+                    )
+                    db.add(new_metadata)
+            
+            # 6. 保存自定义Schema定义 (如果有)
+            if payload.custom_schemas:
+                for schema_field in payload.custom_schemas:
+                    new_schema = CustomSchema(
+                        id=str(uuid.uuid4()),
+                        knowledge_base_id=new_kb.id,
+                        field_name=schema_field.field_name,
+                        field_type=schema_field.field_type,
+                        max_length=schema_field.max_length,
+                        is_primary=schema_field.is_primary,
+                        description=schema_field.description,
+                        created_time=now,
+                        updated_time=now
+                    )
+                    db.add(new_schema)
+
             await db.commit()
             await db.refresh(new_kb)
 
@@ -214,14 +242,127 @@ class KnowledgeBaseService:
 
         try:
             # 1. 删除 Milvus 集合 (强制删除)
-            milvus_service.drop_collection(kb.collection_name)
+            try:
+                if milvus_service.has_collection(kb.collection_name):
+                    milvus_service.drop_collection(kb.collection_name)
+            except Exception as milvus_err:
+                # 如果是命名不规范等导致的错误,忽略它,继续删除数据库记录
+                print(f"Ignore Milvus error during delete: {milvus_err}")
             
             # 2. 软删除 DB 记录
             kb.is_deleted = 1
             kb.created_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            
+            # 3. 删除关联的元数据 (硬删除)
+            await db.execute(sql_delete(SampleMetadata).where(SampleMetadata.knowledge_base_id == id))
+            
+            # 4. 删除关联的自定义Schema (硬删除)
+            await db.execute(sql_delete(CustomSchema).where(CustomSchema.knowledge_base_id == id))
+
             await db.commit()
         except Exception as e:
             await db.rollback()
             raise e
 
+    async def sync_to_milvus(self, db: AsyncSession, id: str) -> KnowledgeBase:
+        """同步知识库到Milvus"""
+        result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == id, KnowledgeBase.is_deleted == 0))
+        kb = result.scalars().first()
+        
+        if not kb:
+            raise ValueError("知识库不存在")
+            
+        if milvus_service.has_collection(kb.collection_name):
+            raise ValueError("Milvus集合已存在")
+            
+        # 查询自定义Schema
+        schema_query = select(CustomSchema).where(CustomSchema.knowledge_base_id == id)
+        schema_result = await db.execute(schema_query)
+        custom_schemas = schema_result.scalars().all()
+        
+        fields = []
+        # 1. 添加用户自定义的Schema字段
+        if custom_schemas:
+            for s in custom_schemas:
+                fields.append({
+                    "name": s.field_name,
+                    "type": s.field_type,
+                    "max_length": s.max_length,
+                    "is_primary": s.is_primary,
+                    "description": s.description
+                })
+        
+        # 2. 自动添加 metadata 字段 (JSON类型)
+        # 即使没有定义元数据字段,通常也需要一个 JSON 类型的 metadata 字段来存储灵活的元数据
+        # 如果用户在 t_samp_metadata 中定义了元数据结构,这些结构实际上是存储在 metadata 字段中的 KV 对
+        # 但为了方便检索,我们也可以选择将 metadata 作为一个独立的 JSON 字段存在 Milvus 中
+        
+        # 检查是否已经有名为 'metadata' 的自定义字段,避免冲突
+        has_metadata_field = any(f['name'] == 'metadata' for f in fields)
+        if not has_metadata_field:
+            fields.append({
+                "name": "metadata",
+                "type": "JSON",
+                "description": "默认元数据字段"
+            })
+        
+        try:
+            # 暂时无法获取维度信息,默认768,或者应该在数据库中存储维度
+            # 假设默认 768,后续可以在 KnowledgeBase 模型中增加 dimension 字段
+            milvus_service.create_collection(
+                name=kb.collection_name,
+                dimension=768, 
+                description=kb.description or "",
+                fields=fields if fields else None
+            )
+            return kb
+        except Exception as e:
+            raise e
+
+    async def get_metadata_and_schema(self, db: AsyncSession, kb_id: str) -> Dict[str, List[dict]]:
+        """获取知识库的元数据字段列表和自定义Schema"""
+        # 检查知识库是否存在
+        result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0))
+        kb = result.scalars().first()
+        if not kb:
+            raise ValueError("知识库不存在")
+
+        # 查询元数据表
+        meta_query = select(SampleMetadata).where(SampleMetadata.knowledge_base_id == kb_id)
+        meta_result = await db.execute(meta_query)
+        metadata_fields = [f.to_dict() for f in meta_result.scalars().all()]
+        
+        # 查询自定义Schema表
+        schema_query = select(CustomSchema).where(CustomSchema.knowledge_base_id == kb_id)
+        schema_result = await db.execute(schema_query)
+        
+        custom_schemas = []
+        for s in schema_result.scalars().all():
+            custom_schemas.append({
+                "field_name": s.field_name,
+                "field_type": s.field_type,
+                "max_length": s.max_length,
+                "description": s.description
+            })
+            
+        return {
+            "metadata_fields": metadata_fields,
+            "custom_schemas": custom_schemas
+        }
+
+    async def get_metadata_fields(self, db: AsyncSession, kb_id: str) -> List[dict]:
+        """获取知识库的元数据字段列表"""
+        # 检查知识库是否存在
+        result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0))
+        kb = result.scalars().first()
+        if not kb:
+            raise ValueError("知识库不存在")
+
+        # 查询元数据表
+        query = select(SampleMetadata).where(SampleMetadata.knowledge_base_id == kb_id)
+        result = await db.execute(query)
+        fields = result.scalars().all()
+        
+        return [f.to_dict() for f in fields]
+
 knowledge_base_service = KnowledgeBaseService()

+ 114 - 11
src/app/services/milvus_service.py

@@ -25,21 +25,99 @@ class MilvusService:
         # 获取embedding model
         self.emdmodel = get_embedding_model()
 
-    def create_collection(self, name: str, dimension: int = 768, description: str = "") -> None:
-        """创建 Milvus 集合"""
+    def create_collection(self, name: str, dimension: int = 768, description: str = "", fields: List[Dict] = None) -> None:
+        """
+        创建 Milvus 集合
+        :param fields: 自定义字段列表,每个元素为 {"name": "age", "type": "INT64", ...}
+        """
         if self.client.has_collection(name):
             logger.info(f"Collection {name} already exists.")
             return
         
-        # 使用简化的 create_collection API
-        self.client.create_collection(
-            collection_name=name,
-            dimension=dimension,
-            description=description,
-            auto_id=True,  # 自动生成 ID
-            id_type="int", # ID 类型
-            metric_type="COSINE" # 默认使用余弦相似度
-        )
+        # 如果有自定义字段,使用 schema 创建
+        if fields:
+            from pymilvus import MilvusClient, DataType
+            
+            # 1. 创建 Schema
+            schema = MilvusClient.create_schema(
+                auto_id=True,
+                enable_dynamic_field=True,
+                description=description
+            )
+            
+            # 2. 添加必须的默认字段
+            schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
+            schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dimension)
+            # schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR) # 如果需要混合检索,可能需要
+            
+            # 3. 添加用户自定义字段
+            # 映射字符串类型到 pymilvus DataType
+            type_map = {
+                "BOOL": DataType.BOOL,
+                "INT8": DataType.INT8,
+                "INT16": DataType.INT16,
+                "INT32": DataType.INT32,
+                "INT64": DataType.INT64,
+                "FLOAT": DataType.FLOAT,
+                "DOUBLE": DataType.DOUBLE,
+                "VARCHAR": DataType.VARCHAR,
+                "JSON": DataType.JSON,
+                "FLOAT_VECTOR": DataType.FLOAT_VECTOR
+            }
+            
+            for f in fields:
+                dtype = type_map.get(f.get("type", "").upper())
+                if not dtype:
+                    continue # 忽略未知类型
+                
+                kwargs = {
+                    "field_name": f.get("name"),
+                    "datatype": dtype,
+                    "description": f.get("description", "")
+                }
+                
+                if dtype == DataType.VARCHAR:
+                    kwargs["max_length"] = f.get("max_length", 65535)
+                
+                schema.add_field(**kwargs)
+            
+            # 4. 准备索引参数
+            index_params = self.client.prepare_index_params()
+            
+            # 5. 添加向量索引
+            index_params.add_index(
+                field_name="vector", 
+                index_type="AUTOINDEX",
+                metric_type="COSINE"
+            )
+            
+            # 6. 为自定义标量字段添加索引 (可选,这里为所有标量字段添加倒排索引以加速过滤)
+            for f in fields:
+                # VARCHAR/INT/BOOL 等支持索引
+                if f.get("type", "").upper() in ["VARCHAR", "INT64", "INT32", "BOOL"]:
+                    index_params.add_index(
+                        field_name=f.get("name"),
+                        index_type="INVERTED" # 标量字段倒排索引
+                    )
+
+            # 7. 创建集合
+            self.client.create_collection(
+                collection_name=name,
+                schema=schema,
+                index_params=index_params
+            )
+            
+        else:
+            # 使用简化的 create_collection API
+            self.client.create_collection(
+                collection_name=name,
+                dimension=dimension,
+                description=description,
+                auto_id=True,  # 自动生成 ID
+                id_type="int", # ID 类型
+                metric_type="COSINE" # 默认使用余弦相似度
+            )
+        
         logger.info(f"Created collection {name} with dimension {dimension}")
 
     def drop_collection(self, name: str) -> None:
@@ -187,6 +265,8 @@ class MilvusService:
 
         # 提取索引信息
         indices = []
+        
+        # 尝试从 describe_collection 结果中获取 (兼容旧逻辑)
         if "indexes" in desc:
             for idx in desc["indexes"]:
                 index_info = {
@@ -197,6 +277,29 @@ class MilvusService:
                     "params": idx.get("params"),
                 }
                 indices.append(index_info)
+        
+        # 如果没有获取到索引信息,尝试主动查询 list_indexes
+        if not indices:
+            try:
+                # 获取索引列表 (通常返回索引名称列表)
+                index_names = self.client.list_indexes(collection_name=name)
+                if index_names:
+                    for idx_name in index_names:
+                        try:
+                            # 获取索引详情
+                            idx_desc = self.client.describe_index(collection_name=name, index_name=idx_name)
+                            if idx_desc:
+                                indices.append({
+                                    "field_name": idx_desc.get("field_name"),
+                                    "index_name": idx_desc.get("index_name"),
+                                    "index_type": idx_desc.get("index_type"),
+                                    "metric_type": idx_desc.get("metric_type"),
+                                    "params": idx_desc.get("params"),
+                                })
+                        except Exception:
+                            continue
+            except Exception as e:
+                logger.warning(f"Failed to list/describe indexes for {name}: {e}")
 
         detail = {
             "name": name,

+ 338 - 21
src/app/services/search_engine_service.py

@@ -23,6 +23,7 @@ from app.sample.schemas.search_engine import (
 from app.schemas.base import PaginationSchema
 from app.services.milvus_service import milvus_service
 from app.utils.vector_utils import text_to_vector_algo
+import logging
 
 class SearchEngineService:
     
@@ -36,33 +37,300 @@ class SearchEngineService:
             return KBSearchResponse(results=[], total=0)
             
         # 1. 使用算法生成向量 (替代 Embedding 模型)
+        # 尝试从 Milvus collection 获取向量维度,动态匹配维度
         # 这样相同的查询词会生成相同的向量,具备了基本的检索能力
-        query_vector = text_to_vector_algo(payload.query, dim=768)
+        try:
+            collection_detail = milvus_service.get_collection_detail(kb_id)
+        except Exception:
+            collection_detail = None
+
+        dim = None
+        if collection_detail and isinstance(collection_detail, dict):
+            fields = collection_detail.get("fields", []) or []
+            for f in fields:
+                # 根据字段类型查找向量字段(Milvus 向量字段类型通常为 FloatVector / float_vector)
+                if not isinstance(f, dict):
+                    continue
+                ftype = str(f.get("type") or "").lower()
+                print(ftype+'是什么东西')
+                if "100" in ftype or '101' in ftype:  # 假设 100 和 101 分别代表 FloatVector 和 BinaryVector
+                    # 找到向量字段,优先从 params.dim 获取维度
+                    params = f.get("params") or {}
+                    if params and params.get("dim"):
+                        try:
+                            dim = int(params.get("dim"))
+                            break
+                        except Exception:
+                            dim = None
+        # 回退默认维度
+        if not dim:
+            dim = 768
+
+        # 选择 Milvus 向量字段名(anns_field),字段名可能不是固定的 "vector",也可能叫 'dense'/'denser' 等
+        anns_field = "vector"
+        if collection_detail and isinstance(collection_detail, dict):
+            fields = collection_detail.get("fields", []) or []
+            # 优先寻找有 params.dim 的向量字段
+            for f in fields:
+                if not isinstance(f, dict):
+                    continue
+                params = f.get("params") or {}
+                if params and params.get("dim") and f.get("name"):
+                    anns_field = f.get("name")
+                    try:
+                        dim = int(params.get("dim"))
+                    except Exception:
+                        pass
+                    break
+
+            # 若未找到带 dim 的字段,尝试匹配常见的向量字段名或字段类型包含 "vector"
+            if anns_field == "vector":
+                for f in fields:
+                    if not isinstance(f, dict):
+                        continue
+                    fname = (f.get("name") or "")
+                    ftype = str(f.get("type") or "").lower()
+                    if fname and fname.lower() in ("vector", "denser", "dense", "embedding", "embeddings"):
+                        anns_field = fname
+                        break
+                    if fname and "vector" in ftype:
+                        anns_field = fname
+                        break
+
+        # 1. 向量搜索 (Dense Retrieval)
+        # 默认使用 Hybrid 混合检索逻辑,但为了简化,这里先保留向量检索的核心
+        # 如果 metric_type 指定为 hybrid,则可能需要结合关键词搜索等
+        # 目前后端实现主要是基于 Milvus 的 ANN 搜索
+        
+        # 强制使用 hybrid 混合检索模式作为基础(结合关键词匹配和向量相似度)
+        # 除非用户明确指定了其他度量方式(通常不会)
+        requested_metric = payload.metric_type
+        use_hybrid = False
+        
+        # 只有当 metric_type 为 None 或者特定值时才尝试混合检索
+        # 或者我们可以认为只要不指定,就优先尝试混合
+        if not requested_metric or requested_metric.lower() == 'hybrid':
+             use_hybrid = True
+        
+        search_params = {
+            "metric_type": "L2", # 默认内部计算用 L2
+            "params": {"nprobe": 10},
+        }
+        
+        # 如果前端指定了 metric_type (虽然前端现在默认 hybrid,但保留参数兼容性)
+        if payload.metric_type and payload.metric_type.upper() != 'HYBRID':
+             search_params["metric_type"] = payload.metric_type
         
         # 2. 构建过滤表达式
-        expr = ""
+        expr_list = []
+        
+        # 兼容旧的单一字段过滤
         if payload.metadata_field and payload.metadata_value:
-            # 示例:假设元数据直接作为字段存在,或者在 extra_info JSON 中
-            # 这里需要根据实际 Milvus Collection 的 Schema 调整
-            # 暂时忽略,以免报错
-            pass
+            safe_field = payload.metadata_field.replace("'", "").replace('"', "").strip()
+            safe_value = payload.metadata_value.replace("'", "").replace('"', "").strip()
             
+            if safe_field and safe_value:
+                if safe_value.isdigit():
+                    expr_list.append(f'{safe_field} == {safe_value}')
+                else:
+                    expr_list.append(f'{safe_field} == "{safe_value}"')
+        
+        # 处理新的多重过滤
+        if payload.filters:
+            for f in payload.filters:
+                safe_field = f.field.replace("'", "").replace('"', "").strip()
+                safe_value = f.value.replace("'", "").replace('"', "").strip()
+                
+                if safe_field and safe_value:
+                    if safe_value.isdigit():
+                        expr_list.append(f'{safe_field} == {safe_value}')
+                    else:
+                        expr_list.append(f'{safe_field} == "{safe_value}"')
+        
+        # 组合所有条件 (使用 AND)
+        expr = " and ".join(expr_list) if expr_list else ""
+        
+        # 选择 Milvus 向量字段名后生成向量 (移到这里,因为之前代码被替换掉了)
+        query_vector = text_to_vector_algo(payload.query, dim=dim)
+        
+        # 检测 collection 使用的 metric (恢复这部分逻辑,因为后续 search 需要)
+        metric_type = None
+        # 优先从 collection_detail 检测真实 metric
+        if collection_detail and isinstance(collection_detail, dict):
+            indices = collection_detail.get("indices") or []
+            if isinstance(indices, list) and len(indices) > 0:
+                for idx in indices:
+                    try:
+                        mt = idx.get("metric_type") or idx.get("metric")
+                        if mt:
+                            metric_type = str(mt).upper()
+                            break
+                    except Exception:
+                        continue
+        
+        # 尝试从 properties 中读取
+        if not metric_type and collection_detail and isinstance(collection_detail, dict):
+            props = collection_detail.get("properties") or {}
+            if isinstance(props, dict):
+                mt = props.get("metric_type") or props.get("metric")
+                if mt:
+                    metric_type = str(mt).upper()
+        
+        actual_search_metric = metric_type
+        if not actual_search_metric:
+             # 如果无法检测到 collection metric (如无索引),则可以使用用户请求的或默认 L2
+             actual_search_metric = requested_metric if requested_metric and requested_metric.upper() != 'HYBRID' else "L2"
+        
+        metric_type = actual_search_metric
+        
+        logger = logging.getLogger(__name__)
+        logger.info(f"Search KB={kb_id} using anns_field={anns_field}, dim={dim}, metric={metric_type} (requested={requested_metric})")
+
         # 3. 执行 Milvus 搜索
         try:
+            # 使用 collection 实际的 metric_type 作为检索度量,避免 mismatch 错误
+            # metric_type 已在上面检测并存放于变量 metric_type
             search_params = {
-                "metric_type": "COSINE", 
+                "metric_type": metric_type,
                 "params": {"nprobe": 10}
             }
+
+            # 如果 top_k <= 0 或未指定,解释为返回该 collection 中的所有文段
+            # 优先使用 page/page_size 计算 limit 和 offset
+            page = payload.page if payload.page and payload.page > 0 else 1
+            page_size = payload.page_size if payload.page_size and payload.page_size > 0 else 10
             
-            results = milvus_service.client.search(
-                collection_name=kb_id,
-                data=[query_vector],
-                anns_field="vector", 
-                search_params=search_params,
-                limit=payload.top_k,
-                filter=expr if expr else "",
-                output_fields=["*"] 
-            )
+            # 如果 payload 中有 top_k 且未传 page_size (或者 page_size 是默认值),可以使用 top_k 覆盖 page_size
+            # 但这里为了清晰,优先使用 page_size
+            
+            offset = (page - 1) * page_size
+            limit = page_size
+            
+            # Milvus 对 limit + offset 有限制 (通常 16384),这里做个简单的保护
+            if offset + limit > 16384:
+                # 如果超出深度分页限制,可能需要提示或截断
+                # 这里暂时不做处理,让 Milvus 报错或自行截断
+                pass
+
+            # 获取集合总数用于分页显示 (total)
+            collection_count = 0
+            if collection_detail and isinstance(collection_detail, dict):
+                collection_count = collection_detail.get("entity_count") or 0
+            
+            if not collection_count:
+                try:
+                    stats = milvus_service.client.get_collection_stats(collection_name=kb_id)
+                    collection_count = int(stats.get("row_count")) if isinstance(stats, dict) and stats.get("row_count") else 0
+                except Exception:
+                    collection_count = 0
+
+            # 如果是按照 top_k 逻辑 (不传 page/page_size),保留旧逻辑 (top_k 即 limit, offset=0)
+            # 但现在 Schema 默认 page=1, page_size=10,所以总是走分页逻辑
+            
+            try:
+                # 尝试使用混合检索 (Hybrid Search)
+                # 只有当用户没有显式指定 metric_type 或者指定为 hybrid 时,且集合支持(通常通过异常回退处理)时使用
+                # 但考虑到 metric_type 可能是 L2/COSINE,我们这里先尝试 hybrid,如果失败回退到普通
+                
+                # 为了不破坏现有逻辑,我们可以根据某种标志来决定是否使用 hybrid
+                # 或者默认尝试 hybrid,如果 collection 不支持 sparse 则会报错回退
+                
+                # 这里我们直接调用 milvus_service.hybrid_search
+                # 注意:hybrid_search 返回的格式与 client.search 不同,需要适配
+                
+                use_hybrid = False
+                # 只有当 metric_type 为 None 或者特定值时才尝试混合检索,避免与用户明确指定的 metric 冲突
+                # 或者我们可以认为只要不指定,就优先尝试混合
+                # 已经在上面判断过 use_hybrid = True 了
+                
+                if use_hybrid:
+                    logger.info(f"Attempting hybrid search for KB={kb_id}")
+                    try:
+                        # Hybrid search (LangChain Milvus) 暂时不支持直接传 offset
+                        # 所以我们需要获取 top_k = offset + limit,然后手动切片
+                        target_k = offset + limit
+                        
+                        hybrid_results = milvus_service.hybrid_search(
+                            collection_name=kb_id,
+                            query_text=payload.query,
+                            top_k=target_k
+                        )
+                        
+                        # 手动切片实现分页
+                        start = offset
+                        end = offset + limit
+                        # 确保不越界
+                        if start >= len(hybrid_results):
+                            sliced_results = []
+                        else:
+                            sliced_results = hybrid_results[start:end]
+                        
+                        formatted_results = []
+                        for item in sliced_results:
+                            formatted_results.append(KBSearchResultItem(
+                                id=str(item.get('id')),
+                                kb_name=kb_id,
+                                doc_name=item.get('metadata', {}).get('file_name') or item.get('metadata', {}).get('source') or "未知文档",
+                                content=item.get('text_content') or "",
+                                meta_info=str(item.get('metadata', {})),
+                                score=item.get('similarity', 0) * 100 # 假设是 0-1
+                            ))
+
+                        return KBSearchResponse(results=formatted_results, total=collection_count)
+
+                    except Exception as hybrid_err:
+                        logger.warning(f"Hybrid search failed, falling back to vector search: {hybrid_err}")
+                        # Fallback to standard vector search below
+                        pass
+
+                results = milvus_service.client.search(
+                    collection_name=kb_id,
+                    data=[query_vector],
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    limit=limit,
+                    offset=offset, # 添加 offset 支持分页
+                    filter=expr if expr else "",
+                    output_fields=["*"] 
+                )
+            except Exception as milvus_err:
+                # 捕获 Milvus 异常,常见原因包括 metric mismatch
+                logger.error(f"Milvus search failed for collection={kb_id}, metric_requested={metric_type}, anns_field={anns_field}: {milvus_err}")
+                
+                # Retry Logic: 如果是因为 metric 不匹配,解析错误信息中的 expected metric 并重试
+                error_msg = str(milvus_err)
+                if "metric type not match" in error_msg:
+                    import re
+                    # 匹配 expected=COSINE 或 expected='COSINE' 等格式
+                    # 支持 COSINE, L2, IP, BM25 等
+                    match = re.search(r"expected\s*=\s*['\"]?([A-Za-z0-9_]+)['\"]?", error_msg)
+                    if match:
+                        correct_metric = match.group(1).upper()
+                        logger.warning(f"Detected metric mismatch. Retrying with correct metric: {correct_metric}")
+                        
+                        # 更新 metric_type 并重试搜索
+                        search_params["metric_type"] = correct_metric
+                        # 同时也需要更新后续计算分数所用的 metric_type 变量,以便正确计算相似度
+                        metric_type = correct_metric
+                        
+                        # 特殊处理: BM25 可能需要 sparse vector 或其他参数,但 Milvus search 接口应该是一致的
+                        # 如果是 BM25,可能 anns_field 也要调整(通常 BM25 用 sparse vector)
+                        # 但这里假设 anns_field 是正确的,只是 metric 不对
+                        
+                        results = milvus_service.client.search(
+                            collection_name=kb_id,
+                            data=[query_vector],
+                            anns_field=anns_field,
+                            search_params=search_params,
+                            limit=limit,
+                            offset=offset, # 同样加上 offset
+                            filter=expr if expr else "",
+                            output_fields=["*"] 
+                        )
+                    else:
+                        raise
+                else:
+                    raise
             
             # 4. 格式化结果
             formatted_results = []
@@ -73,27 +341,76 @@ class SearchEngineService:
                     #     continue
                         
                     entity = hit.entity
-                    
+
                     content = entity.get("text") or entity.get("content") or entity.get("page_content") or ""
                     if not content:
-                        debug_data = {k:v for k,v in entity.items() if k != "vector"}
+                        debug_data = {k: v for k, v in entity.items() if k != anns_field}
                         content = json.dumps(debug_data, ensure_ascii=False)[:200] + "..."
                         
                     doc_name = entity.get("file_name") or entity.get("title") or entity.get("source") or "未知文档"
                     
                     meta_info = []
                     for k, v in entity.items():
-                        if k not in ["vector", "text", "content", "page_content", "id", "pk"]:
+                        if k not in [anns_field, "text", "content", "page_content", "id", "pk"]:
                             meta_info.append(f"{k}: {v}")
                     meta_str = "; ".join(meta_info[:3])
                     
+                    # 根据 collection 的 metric 动态计算相似度分数
+                    # 如果用户请求了特定的 metric,尝试适配;否则使用实际 metric
+                    display_metric = requested_metric if requested_metric else metric_type
+                    
+                    similarity_pct = None
+                    try:
+                        raw_score = float(hit.score)
+                    except Exception:
+                        raw_score = None
+
+                    if raw_score is not None:
+                        # 核心计算逻辑:先根据 metric_type 理解 raw_score,再根据 display_metric 转换
+                        # 目前简化处理:直接根据 display_metric 解释 raw_score,忽略不兼容的情况
+                        # 更好的做法是:
+                        # 1. 识别 raw_score 的物理意义(距离还是相似度),基于 metric_type
+                        # 2. 转换为 display_metric 要求的格式
+                        
+                        # Case 1: 实际是 L2 (距离),用户想看 L2
+                        if "L2" in metric_type or "EUCLIDEAN" in metric_type:
+                            distance = raw_score
+                            if display_metric and ("COSINE" in display_metric):
+                                # L2 距离转 Cosine 相似度 (仅适用于归一化向量)
+                                # dist^2 = 2(1-cos) => cos = 1 - dist^2/2
+                                # 但这里简单起见,如果类型不匹配,还是按 L2 算百分比,避免数值错误
+                                similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
+                            else:
+                                similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
+                                
+                        # Case 2: 实际是 Cosine (相似度 [-1, 1])
+                        elif "COSINE" in metric_type:
+                            cosine_score = raw_score
+                            # 无论用户想看什么,Cosine Score 本身就是相似度,直接归一化到 0-100 最直观
+                            similarity_pct = round(max(min((cosine_score + 1.0) / 2.0, 1.0), 0.0) * 100.0, 2)
+                            
+                        # Case 3: IP (内积)
+                        elif "IP" in metric_type or "INNER" in metric_type:
+                             similarity_pct = round(raw_score * 100.0, 2)
+                        
+                        # Fallback
+                        else:
+                            # 兼容 BM25 或其他未知 metric
+                            if "BM25" in metric_type:
+                                # BM25 分数通常是正数,没有固定上限,直接显示原值
+                                similarity_pct = round(raw_score, 2)
+                            else:
+                                similarity_pct = round(raw_score * 100.0, 2)
+                    else:
+                        similarity_pct = 0.0
+
                     formatted_results.append(KBSearchResultItem(
                         id=str(hit.id),
-                        kb_name=kb_id, 
+                        kb_name=kb_id,
                         doc_name=doc_name,
                         content=content,
                         meta_info=meta_str,
-                        score=round(hit.score * 100, 2)
+                        score=similarity_pct
                     ))
             
             return KBSearchResponse(results=formatted_results, total=len(formatted_results))

+ 17 - 4
src/app/services/snippet_service.py

@@ -132,14 +132,21 @@ class SnippetService:
         # 使用统一算法生成向量
         fake_vector = text_to_vector_algo(payload.content, dim=768)
         
-        data = [{
+        # 基础数据
+        item = {
             "vector": fake_vector,
             "text": payload.content,
             "source": payload.doc_name,
             "doc_id": "manual_add",
             "file_name": payload.doc_name, 
             "title": payload.doc_name
-        }]
+        }
+        
+        # 合并自定义字段
+        if hasattr(payload, 'custom_fields') and payload.custom_fields:
+            item.update(payload.custom_fields)
+            
+        data = [item]
         
         res = milvus_service.client.insert(
             collection_name=payload.collection_name,
@@ -169,14 +176,20 @@ class SnippetService:
         # 使用统一算法生成向量
         fake_vector = text_to_vector_algo(payload.content, dim=768)
         
-        data = [{
+        item = {
             "vector": fake_vector,
             "text": payload.content,
             "source": payload.doc_name or "已更新",
             "doc_id": "updated",
             "file_name": payload.doc_name,
             "title": payload.doc_name
-        }]
+        }
+        
+        # 合并自定义字段
+        if hasattr(payload, 'custom_fields') and payload.custom_fields:
+            item.update(payload.custom_fields)
+            
+        data = [item]
         
         milvus_service.client.insert(collection_name=kb, data=data)
         milvus_service.client.flush(kb)

+ 38 - 0
src/views/knowledge_base_view.py

@@ -102,3 +102,41 @@ async def delete_knowledge_base(
 
     await knowledge_base_service.delete(db, id)
     return ResponseSchema(code=0, message="删除成功")
+
+@router.get("/{id}/metadata", response_model=ResponseSchema)
+async def get_knowledge_base_metadata(
+    id: str = Path(..., description="知识库ID"),
+    db: AsyncSession = Depends(get_db),
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """获取知识库的元数据字段定义和自定义Schema"""
+    payload_token = verify_token(credentials.credentials)
+    if not payload_token:
+        return ResponseSchema(code=401, message="无效的访问令牌")
+
+    try:
+        data = await knowledge_base_service.get_metadata_and_schema(db, id)
+        return ResponseSchema(code=0, message="获取成功", data=data)
+    except ValueError as e:
+        return ResponseSchema(code=400, message=str(e))
+    except Exception as e:
+        return ResponseSchema(code=500, message=f"获取失败: {str(e)}")
+
+@router.post("/{id}/sync", response_model=ResponseSchema)
+async def sync_knowledge_base(
+    id: str = Path(..., description="知识库ID"),
+    db: AsyncSession = Depends(get_db),
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """同步创建Milvus集合"""
+    payload_token = verify_token(credentials.credentials)
+    if not payload_token:
+        return ResponseSchema(code=401, message="无效的访问令牌")
+
+    try:
+        await knowledge_base_service.sync_to_milvus(db, id)
+        return ResponseSchema(code=0, message="同步成功")
+    except ValueError as e:
+        return ResponseSchema(code=400, message=str(e))
+    except Exception as e:
+        return ResponseSchema(code=500, message=f"同步失败: {str(e)}")

+ 3 - 1
src/views/snippet_view.py

@@ -3,7 +3,7 @@
 """
 from fastapi import APIRouter, Depends, Query, Path, Body
 from fastapi.responses import StreamingResponse
-from typing import Optional
+from typing import Optional, Dict, Any
 from datetime import datetime
 import urllib.parse
 
@@ -22,11 +22,13 @@ class SnippetCreate(BaseModel):
     doc_name: str = "手动添加"
     content: str
     meta_info: Optional[str] = None
+    custom_fields: Optional[Dict[str, Any]] = None
 
 class SnippetUpdate(BaseModel):
     collection_name: str
     doc_name: Optional[str] = None
     content: str
+    custom_fields: Optional[Dict[str, Any]] = None
 
 @router.get("", response_model=PaginatedResponseSchema)
 async def get_snippets(

+ 2 - 1
src/views/system_view.py

@@ -789,7 +789,8 @@ async def update_role_menus(
         
         # 调用 service 层
         system_service = SystemService()
-        success, data, message = await system_service.update_role_menus(role_id, menu_ids)
+        updater_id = payload.get("sub")
+        success, data, message = await system_service.update_role_menus(role_id, menu_ids, updater_id)
         
         if success:
             return ApiResponse(