Просмотр исходного кода

Merge branch 'dev' of http://192.168.0.3:3000/CRBC-MaaS-Platform-Project/LQAdminPlatform into dev

chenkun 1 месяц назад
Родитель
Commit
8ae899c290

+ 25 - 0
src/app/sample/models/custom_schema.py

@@ -0,0 +1,25 @@
+"""
+知识库自定义Schema定义模型
+"""
+from sqlalchemy import Column, String, Integer, Boolean, Text, DateTime, func
+from sqlalchemy.dialects.mysql import CHAR, TINYINT
+from app.base.async_mysql_connection import Base
+import uuid
+
class CustomSchema(Base):
    """Per-knowledge-base custom Milvus schema field definition.

    Each row describes one scalar field the user wants added to the Milvus
    collection backing a knowledge base (name, type, optional VARCHAR length).
    """
    __tablename__ = "t_samp_custom_schema"

    # UUID4 primary key, generated client-side on insert.
    id = Column(CHAR(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="主键ID")
    # Owning knowledge base (FK by convention, not enforced here).
    knowledge_base_id = Column(String(36), nullable=False, comment="知识库ID")
    field_name = Column(String(255), nullable=False, comment="字段名称(英文)")
    # Milvus data-type name, e.g. VARCHAR / INT64 / JSON.
    field_type = Column(String(50), nullable=False, comment="字段类型")
    # Only meaningful for VARCHAR fields; nullable otherwise.
    max_length = Column(Integer, nullable=True, comment="最大长度")
    is_primary = Column(Boolean, default=False, comment="是否主键")
    description = Column(String(1000), nullable=True, comment="描述")

    created_time = Column(DateTime, default=func.now(), comment="创建时间")
    updated_time = Column(DateTime, default=func.now(), onupdate=func.now(), comment="修改时间")

    def __repr__(self):
        return "<CustomSchema kb_id={} field={}>".format(self.knowledge_base_id, self.field_name)

+ 28 - 0
src/app/sample/models/metadata.py

@@ -0,0 +1,28 @@
+"""
+知识库元数据定义模型
+"""
+from sqlalchemy import Column, String, Text, DateTime, func
+from sqlalchemy.dialects.mysql import CHAR, TINYINT
+from app.base.async_mysql_connection import Base
+import uuid
+
class SampleMetadata(Base):
    """Metadata field definition attached to a knowledge base.

    Rows describe the user-facing metadata dictionary of a knowledge base:
    a Chinese display name, an English field name, and a coarse type tag.
    """
    __tablename__ = "t_samp_metadata"

    # UUID4 primary key, generated client-side on insert.
    id = Column(CHAR(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="主键ID")
    knowledge_base_id = Column(String(36), nullable=False, comment="知识库ID")
    field_zh_name = Column(String(255), nullable=False, comment="字段名称(中文)")
    field_en_name = Column(String(500), nullable=False, comment="字段英文名称")
    # "text" or "num" (see column comment below).
    field_type = Column(String(10), nullable=False, comment="字段类型: text(文本), num(数字)")
    remark = Column(String(1000), nullable=True, comment="备注")

    def to_dict(self) -> dict:
        """Serialize every mapped column of this row into a plain dict."""
        data = {}
        for col in self.__table__.columns:
            data[col.name] = getattr(self, col.name)
        return data

    def __repr__(self):
        return "<SampleMetadata kb_id={} field={}>".format(self.knowledge_base_id, self.field_en_name)

+ 18 - 1
src/app/sample/schemas/knowledge_base.py

@@ -1,16 +1,33 @@
-from typing import Optional
+from typing import Optional, List
 from pydantic import BaseModel, Field
 from pydantic import BaseModel, Field
 from app.schemas.base import BaseModelSchema
 from app.schemas.base import BaseModelSchema
 
 
class MetadataField(BaseModel):
    """元数据字段定义"""
    # Chinese display name and English machine name of the field.
    field_zh_name: str = Field(..., description="中文名称")
    field_en_name: str = Field(..., description="英文名称")
    # Expected values: "text" or "num".
    field_type: str = Field(..., description="字段类型: text/num")
    remark: Optional[str] = Field(None, description="备注")
+
 class KnowledgeBaseBase(BaseModel):
 class KnowledgeBaseBase(BaseModel):
     name: str = Field(..., description="知识库名称")
     name: str = Field(..., description="知识库名称")
     collection_name: str = Field(..., description="Milvus集合名称")
     collection_name: str = Field(..., description="Milvus集合名称")
     description: Optional[str] = Field(None, description="描述")
     description: Optional[str] = Field(None, description="描述")
     status: Optional[str] = Field("normal", description="状态")
     status: Optional[str] = Field("normal", description="状态")
 
 
class CustomSchemaField(BaseModel):
    """自定义Schema字段定义"""
    # English field name used in the Milvus schema.
    field_name: str = Field(..., description="字段名称(英文)")
    field_type: str = Field(..., description="字段类型: BOOL/INT8/INT16/INT32/INT64/FLOAT/DOUBLE/VARCHAR/JSON")
    # Required only when field_type is VARCHAR.
    max_length: Optional[int] = Field(None, description="最大长度(VARCHAR需要)")
    # Usually left False; the service adds its own auto-id primary key.
    is_primary: bool = Field(False, description="是否主键(通常不需要用户指定)")
    description: Optional[str] = Field(None, description="描述")
+
 class KnowledgeBaseCreate(KnowledgeBaseBase):
 class KnowledgeBaseCreate(KnowledgeBaseBase):
     """创建知识库请求参数"""
     """创建知识库请求参数"""
     dimension: int = Field(768, description="向量维度,默认768")
     dimension: int = Field(768, description="向量维度,默认768")
+    metadata_fields: Optional[List[MetadataField]] = Field(None, description="元数据字段列表")
+    custom_schemas: Optional[List[CustomSchemaField]] = Field(None, description="自定义Schema字段列表")
 
 
 class KnowledgeBaseUpdate(BaseModel):
 class KnowledgeBaseUpdate(BaseModel):
     """更新知识库请求参数"""
     """更新知识库请求参数"""

+ 10 - 2
src/app/sample/schemas/search_engine.py

@@ -44,14 +44,22 @@ class SearchEngineResponse(BaseModelSchema):
 
 
 # --- 新增:知识库搜索相关模型 ---
 # --- 新增:知识库搜索相关模型 ---
 
 
class FilterCondition(BaseModel):
    # One exact-match filter term: a scalar field name and the value it must equal.
    # (No docstring on purpose: pydantic would surface it as the OpenAPI description.)
    field: str
    value: str
+
 class KBSearchRequest(BaseModel):
 class KBSearchRequest(BaseModel):
     """知识库搜索请求"""
     """知识库搜索请求"""
     kb_id: str = Field(..., description="知识库ID或集合名称")
     kb_id: str = Field(..., description="知识库ID或集合名称")
     query: str = Field(..., description="检索关键字")
     query: str = Field(..., description="检索关键字")
-    metadata_field: Optional[str] = Field(None, description="元数据字典字段")
-    metadata_value: Optional[str] = Field(None, description="元数据字典值")
+    metadata_field: Optional[str] = Field(None, description="元数据字典字段(兼容旧版)")
+    metadata_value: Optional[str] = Field(None, description="元数据字典值(兼容旧版)")
+    filters: Optional[List[FilterCondition]] = Field(None, description="多重过滤条件")
     top_k: int = Field(10, description="返回结果数量")
     top_k: int = Field(10, description="返回结果数量")
     score_threshold: float = Field(0.0, description="相似度阈值")
     score_threshold: float = Field(0.0, description="相似度阈值")
+    metric_type: Optional[str] = Field(None, description="相似度计算方式")
+    page: int = Field(1, description="页码")
+    page_size: int = Field(10, description="每页数量")
 
 
 class KBSearchResultItem(BaseModel):
 class KBSearchResultItem(BaseModel):
     """单条搜索结果"""
     """单条搜索结果"""

+ 152 - 11
src/app/services/knowledge_base_service.py

@@ -4,11 +4,13 @@
 from math import ceil
 from math import ceil
 from typing import List, Optional, Tuple, Dict, Any
 from typing import List, Optional, Tuple, Dict, Any
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy import select, func, or_
+from sqlalchemy import select, func, or_, delete as sql_delete
 from datetime import datetime
 from datetime import datetime
 import uuid
 import uuid
 
 
 from app.sample.models.knowledge_base import KnowledgeBase
 from app.sample.models.knowledge_base import KnowledgeBase
+from app.sample.models.metadata import SampleMetadata
+from app.sample.models.custom_schema import CustomSchema
 from app.sample.schemas.knowledge_base import (
 from app.sample.schemas.knowledge_base import (
     KnowledgeBaseCreate, 
     KnowledgeBaseCreate, 
     KnowledgeBaseUpdate,
     KnowledgeBaseUpdate,
@@ -120,17 +122,13 @@ class KnowledgeBaseService:
         if exists.scalars().first():
         if exists.scalars().first():
             raise ValueError("知识库集合名称已存在")
             raise ValueError("知识库集合名称已存在")
 
 
-        # 2. 检查 Milvus 是否已存在
-        if milvus_service.has_collection(payload.collection_name):
-            raise ValueError("Milvus集合已存在,请使用其他名称")
+        # 2. 检查 Milvus 是否已存在 (如果之前残留)
+        # if milvus_service.has_collection(payload.collection_name):
+        #     raise ValueError("Milvus集合已存在,请使用其他名称")
 
 
         try:
         try:
-            # 3. 创建 Milvus 集合
-            milvus_service.create_collection(
-                name=payload.collection_name,
-                dimension=payload.dimension,
-                description=payload.description or ""
-            )
+            # 3. 创建 Milvus 集合 (延迟到点击同步按钮时创建)
+            # milvus_service.create_collection(...)
 
 
             # 4. 创建 DB 记录
             # 4. 创建 DB 记录
             now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
             now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -144,6 +142,36 @@ class KnowledgeBaseService:
                 updated_time=now
                 updated_time=now
             )
             )
             db.add(new_kb)
             db.add(new_kb)
+            
+            # 5. 保存元数据定义 (如果有)
+            if payload.metadata_fields:
+                for field in payload.metadata_fields:
+                    new_metadata = SampleMetadata(
+                        id=str(uuid.uuid4()),
+                        knowledge_base_id=new_kb.id,
+                        field_zh_name=field.field_zh_name,
+                        field_en_name=field.field_en_name,
+                        field_type=field.field_type,
+                        remark=field.remark
+                    )
+                    db.add(new_metadata)
+            
+            # 6. 保存自定义Schema定义 (如果有)
+            if payload.custom_schemas:
+                for schema_field in payload.custom_schemas:
+                    new_schema = CustomSchema(
+                        id=str(uuid.uuid4()),
+                        knowledge_base_id=new_kb.id,
+                        field_name=schema_field.field_name,
+                        field_type=schema_field.field_type,
+                        max_length=schema_field.max_length,
+                        is_primary=schema_field.is_primary,
+                        description=schema_field.description,
+                        created_time=now,
+                        updated_time=now
+                    )
+                    db.add(new_schema)
+
             await db.commit()
             await db.commit()
             await db.refresh(new_kb)
             await db.refresh(new_kb)
 
 
@@ -214,14 +242,127 @@ class KnowledgeBaseService:
 
 
         try:
         try:
             # 1. 删除 Milvus 集合 (强制删除)
             # 1. 删除 Milvus 集合 (强制删除)
-            milvus_service.drop_collection(kb.collection_name)
+            try:
+                if milvus_service.has_collection(kb.collection_name):
+                    milvus_service.drop_collection(kb.collection_name)
+            except Exception as milvus_err:
+                # 如果是命名不规范等导致的错误,忽略它,继续删除数据库记录
+                print(f"Ignore Milvus error during delete: {milvus_err}")
             
             
             # 2. 软删除 DB 记录
             # 2. 软删除 DB 记录
             kb.is_deleted = 1
             kb.is_deleted = 1
             kb.created_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
             kb.created_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            
+            # 3. 删除关联的元数据 (硬删除)
+            await db.execute(sql_delete(SampleMetadata).where(SampleMetadata.knowledge_base_id == id))
+            
+            # 4. 删除关联的自定义Schema (硬删除)
+            await db.execute(sql_delete(CustomSchema).where(CustomSchema.knowledge_base_id == id))
+
             await db.commit()
             await db.commit()
         except Exception as e:
         except Exception as e:
             await db.rollback()
             await db.rollback()
             raise e
             raise e
 
 
+    async def sync_to_milvus(self, db: AsyncSession, id: str) -> KnowledgeBase:
+        """同步知识库到Milvus"""
+        result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == id, KnowledgeBase.is_deleted == 0))
+        kb = result.scalars().first()
+        
+        if not kb:
+            raise ValueError("知识库不存在")
+            
+        if milvus_service.has_collection(kb.collection_name):
+            raise ValueError("Milvus集合已存在")
+            
+        # 查询自定义Schema
+        schema_query = select(CustomSchema).where(CustomSchema.knowledge_base_id == id)
+        schema_result = await db.execute(schema_query)
+        custom_schemas = schema_result.scalars().all()
+        
+        fields = []
+        # 1. 添加用户自定义的Schema字段
+        if custom_schemas:
+            for s in custom_schemas:
+                fields.append({
+                    "name": s.field_name,
+                    "type": s.field_type,
+                    "max_length": s.max_length,
+                    "is_primary": s.is_primary,
+                    "description": s.description
+                })
+        
+        # 2. 自动添加 metadata 字段 (JSON类型)
+        # 即使没有定义元数据字段,通常也需要一个 JSON 类型的 metadata 字段来存储灵活的元数据
+        # 如果用户在 t_samp_metadata 中定义了元数据结构,这些结构实际上是存储在 metadata 字段中的 KV 对
+        # 但为了方便检索,我们也可以选择将 metadata 作为一个独立的 JSON 字段存在 Milvus 中
+        
+        # 检查是否已经有名为 'metadata' 的自定义字段,避免冲突
+        has_metadata_field = any(f['name'] == 'metadata' for f in fields)
+        if not has_metadata_field:
+            fields.append({
+                "name": "metadata",
+                "type": "JSON",
+                "description": "默认元数据字段"
+            })
+        
+        try:
+            # 暂时无法获取维度信息,默认768,或者应该在数据库中存储维度
+            # 假设默认 768,后续可以在 KnowledgeBase 模型中增加 dimension 字段
+            milvus_service.create_collection(
+                name=kb.collection_name,
+                dimension=768, 
+                description=kb.description or "",
+                fields=fields if fields else None
+            )
+            return kb
+        except Exception as e:
+            raise e
+
+    async def get_metadata_and_schema(self, db: AsyncSession, kb_id: str) -> Dict[str, List[dict]]:
+        """获取知识库的元数据字段列表和自定义Schema"""
+        # 检查知识库是否存在
+        result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0))
+        kb = result.scalars().first()
+        if not kb:
+            raise ValueError("知识库不存在")
+
+        # 查询元数据表
+        meta_query = select(SampleMetadata).where(SampleMetadata.knowledge_base_id == kb_id)
+        meta_result = await db.execute(meta_query)
+        metadata_fields = [f.to_dict() for f in meta_result.scalars().all()]
+        
+        # 查询自定义Schema表
+        schema_query = select(CustomSchema).where(CustomSchema.knowledge_base_id == kb_id)
+        schema_result = await db.execute(schema_query)
+        
+        custom_schemas = []
+        for s in schema_result.scalars().all():
+            custom_schemas.append({
+                "field_name": s.field_name,
+                "field_type": s.field_type,
+                "max_length": s.max_length,
+                "description": s.description
+            })
+            
+        return {
+            "metadata_fields": metadata_fields,
+            "custom_schemas": custom_schemas
+        }
+
+    async def get_metadata_fields(self, db: AsyncSession, kb_id: str) -> List[dict]:
+        """获取知识库的元数据字段列表"""
+        # 检查知识库是否存在
+        result = await db.execute(select(KnowledgeBase).where(KnowledgeBase.id == kb_id, KnowledgeBase.is_deleted == 0))
+        kb = result.scalars().first()
+        if not kb:
+            raise ValueError("知识库不存在")
+
+        # 查询元数据表
+        query = select(SampleMetadata).where(SampleMetadata.knowledge_base_id == kb_id)
+        result = await db.execute(query)
+        fields = result.scalars().all()
+        
+        return [f.to_dict() for f in fields]
+
 knowledge_base_service = KnowledgeBaseService()
 knowledge_base_service = KnowledgeBaseService()

+ 114 - 11
src/app/services/milvus_service.py

@@ -25,21 +25,99 @@ class MilvusService:
         # 获取embedding model
         # 获取embedding model
         self.emdmodel = get_embedding_model()
         self.emdmodel = get_embedding_model()
 
 
-    def create_collection(self, name: str, dimension: int = 768, description: str = "") -> None:
-        """创建 Milvus 集合"""
+    def create_collection(self, name: str, dimension: int = 768, description: str = "", fields: List[Dict] = None) -> None:
+        """
+        创建 Milvus 集合
+        :param fields: 自定义字段列表,每个元素为 {"name": "age", "type": "INT64", ...}
+        """
         if self.client.has_collection(name):
         if self.client.has_collection(name):
             logger.info(f"Collection {name} already exists.")
             logger.info(f"Collection {name} already exists.")
             return
             return
         
         
-        # 使用简化的 create_collection API
-        self.client.create_collection(
-            collection_name=name,
-            dimension=dimension,
-            description=description,
-            auto_id=True,  # 自动生成 ID
-            id_type="int", # ID 类型
-            metric_type="COSINE" # 默认使用余弦相似度
-        )
+        # 如果有自定义字段,使用 schema 创建
+        if fields:
+            from pymilvus import MilvusClient, DataType
+            
+            # 1. 创建 Schema
+            schema = MilvusClient.create_schema(
+                auto_id=True,
+                enable_dynamic_field=True,
+                description=description
+            )
+            
+            # 2. 添加必须的默认字段
+            schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
+            schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dimension)
+            # schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR) # 如果需要混合检索,可能需要
+            
+            # 3. 添加用户自定义字段
+            # 映射字符串类型到 pymilvus DataType
+            type_map = {
+                "BOOL": DataType.BOOL,
+                "INT8": DataType.INT8,
+                "INT16": DataType.INT16,
+                "INT32": DataType.INT32,
+                "INT64": DataType.INT64,
+                "FLOAT": DataType.FLOAT,
+                "DOUBLE": DataType.DOUBLE,
+                "VARCHAR": DataType.VARCHAR,
+                "JSON": DataType.JSON,
+                "FLOAT_VECTOR": DataType.FLOAT_VECTOR
+            }
+            
+            for f in fields:
+                dtype = type_map.get(f.get("type", "").upper())
+                if not dtype:
+                    continue # 忽略未知类型
+                
+                kwargs = {
+                    "field_name": f.get("name"),
+                    "datatype": dtype,
+                    "description": f.get("description", "")
+                }
+                
+                if dtype == DataType.VARCHAR:
+                    kwargs["max_length"] = f.get("max_length", 65535)
+                
+                schema.add_field(**kwargs)
+            
+            # 4. 准备索引参数
+            index_params = self.client.prepare_index_params()
+            
+            # 5. 添加向量索引
+            index_params.add_index(
+                field_name="vector", 
+                index_type="AUTOINDEX",
+                metric_type="COSINE"
+            )
+            
+            # 6. 为自定义标量字段添加索引 (可选,这里为所有标量字段添加倒排索引以加速过滤)
+            for f in fields:
+                # VARCHAR/INT/BOOL 等支持索引
+                if f.get("type", "").upper() in ["VARCHAR", "INT64", "INT32", "BOOL"]:
+                    index_params.add_index(
+                        field_name=f.get("name"),
+                        index_type="INVERTED" # 标量字段倒排索引
+                    )
+
+            # 7. 创建集合
+            self.client.create_collection(
+                collection_name=name,
+                schema=schema,
+                index_params=index_params
+            )
+            
+        else:
+            # 使用简化的 create_collection API
+            self.client.create_collection(
+                collection_name=name,
+                dimension=dimension,
+                description=description,
+                auto_id=True,  # 自动生成 ID
+                id_type="int", # ID 类型
+                metric_type="COSINE" # 默认使用余弦相似度
+            )
+        
         logger.info(f"Created collection {name} with dimension {dimension}")
         logger.info(f"Created collection {name} with dimension {dimension}")
 
 
     def drop_collection(self, name: str) -> None:
     def drop_collection(self, name: str) -> None:
@@ -187,6 +265,8 @@ class MilvusService:
 
 
         # 提取索引信息
         # 提取索引信息
         indices = []
         indices = []
+        
+        # 尝试从 describe_collection 结果中获取 (兼容旧逻辑)
         if "indexes" in desc:
         if "indexes" in desc:
             for idx in desc["indexes"]:
             for idx in desc["indexes"]:
                 index_info = {
                 index_info = {
@@ -197,6 +277,29 @@ class MilvusService:
                     "params": idx.get("params"),
                     "params": idx.get("params"),
                 }
                 }
                 indices.append(index_info)
                 indices.append(index_info)
+        
+        # 如果没有获取到索引信息,尝试主动查询 list_indexes
+        if not indices:
+            try:
+                # 获取索引列表 (通常返回索引名称列表)
+                index_names = self.client.list_indexes(collection_name=name)
+                if index_names:
+                    for idx_name in index_names:
+                        try:
+                            # 获取索引详情
+                            idx_desc = self.client.describe_index(collection_name=name, index_name=idx_name)
+                            if idx_desc:
+                                indices.append({
+                                    "field_name": idx_desc.get("field_name"),
+                                    "index_name": idx_desc.get("index_name"),
+                                    "index_type": idx_desc.get("index_type"),
+                                    "metric_type": idx_desc.get("metric_type"),
+                                    "params": idx_desc.get("params"),
+                                })
+                        except Exception:
+                            continue
+            except Exception as e:
+                logger.warning(f"Failed to list/describe indexes for {name}: {e}")
 
 
         detail = {
         detail = {
             "name": name,
             "name": name,

+ 338 - 21
src/app/services/search_engine_service.py

@@ -23,6 +23,7 @@ from app.sample.schemas.search_engine import (
 from app.schemas.base import PaginationSchema
 from app.schemas.base import PaginationSchema
 from app.services.milvus_service import milvus_service
 from app.services.milvus_service import milvus_service
 from app.utils.vector_utils import text_to_vector_algo
 from app.utils.vector_utils import text_to_vector_algo
+import logging
 
 
 class SearchEngineService:
 class SearchEngineService:
     
     
@@ -36,33 +37,300 @@ class SearchEngineService:
             return KBSearchResponse(results=[], total=0)
             return KBSearchResponse(results=[], total=0)
             
             
         # 1. 使用算法生成向量 (替代 Embedding 模型)
         # 1. 使用算法生成向量 (替代 Embedding 模型)
+        # 尝试从 Milvus collection 获取向量维度,动态匹配维度
         # 这样相同的查询词会生成相同的向量,具备了基本的检索能力
         # 这样相同的查询词会生成相同的向量,具备了基本的检索能力
-        query_vector = text_to_vector_algo(payload.query, dim=768)
+        try:
+            collection_detail = milvus_service.get_collection_detail(kb_id)
+        except Exception:
+            collection_detail = None
+
+        dim = None
+        if collection_detail and isinstance(collection_detail, dict):
+            fields = collection_detail.get("fields", []) or []
+            for f in fields:
+                # 根据字段类型查找向量字段(Milvus 向量字段类型通常为 FloatVector / float_vector)
+                if not isinstance(f, dict):
+                    continue
+                ftype = str(f.get("type") or "").lower()
+                print(ftype+'是什么东西')
+                if "100" in ftype or '101' in ftype:  # 假设 100 和 101 分别代表 FloatVector 和 BinaryVector
+                    # 找到向量字段,优先从 params.dim 获取维度
+                    params = f.get("params") or {}
+                    if params and params.get("dim"):
+                        try:
+                            dim = int(params.get("dim"))
+                            break
+                        except Exception:
+                            dim = None
+        # 回退默认维度
+        if not dim:
+            dim = 768
+
+        # 选择 Milvus 向量字段名(anns_field),字段名可能不是固定的 "vector",也可能叫 'dense'/'denser' 等
+        anns_field = "vector"
+        if collection_detail and isinstance(collection_detail, dict):
+            fields = collection_detail.get("fields", []) or []
+            # 优先寻找有 params.dim 的向量字段
+            for f in fields:
+                if not isinstance(f, dict):
+                    continue
+                params = f.get("params") or {}
+                if params and params.get("dim") and f.get("name"):
+                    anns_field = f.get("name")
+                    try:
+                        dim = int(params.get("dim"))
+                    except Exception:
+                        pass
+                    break
+
+            # 若未找到带 dim 的字段,尝试匹配常见的向量字段名或字段类型包含 "vector"
+            if anns_field == "vector":
+                for f in fields:
+                    if not isinstance(f, dict):
+                        continue
+                    fname = (f.get("name") or "")
+                    ftype = str(f.get("type") or "").lower()
+                    if fname and fname.lower() in ("vector", "denser", "dense", "embedding", "embeddings"):
+                        anns_field = fname
+                        break
+                    if fname and "vector" in ftype:
+                        anns_field = fname
+                        break
+
+        # 1. 向量搜索 (Dense Retrieval)
+        # 默认使用 Hybrid 混合检索逻辑,但为了简化,这里先保留向量检索的核心
+        # 如果 metric_type 指定为 hybrid,则可能需要结合关键词搜索等
+        # 目前后端实现主要是基于 Milvus 的 ANN 搜索
+        
+        # 强制使用 hybrid 混合检索模式作为基础(结合关键词匹配和向量相似度)
+        # 除非用户明确指定了其他度量方式(通常不会)
+        requested_metric = payload.metric_type
+        use_hybrid = False
+        
+        # 只有当 metric_type 为 None 或者特定值时才尝试混合检索
+        # 或者我们可以认为只要不指定,就优先尝试混合
+        if not requested_metric or requested_metric.lower() == 'hybrid':
+             use_hybrid = True
+        
+        search_params = {
+            "metric_type": "L2", # 默认内部计算用 L2
+            "params": {"nprobe": 10},
+        }
+        
+        # 如果前端指定了 metric_type (虽然前端现在默认 hybrid,但保留参数兼容性)
+        if payload.metric_type and payload.metric_type.upper() != 'HYBRID':
+             search_params["metric_type"] = payload.metric_type
         
         
         # 2. 构建过滤表达式
         # 2. 构建过滤表达式
-        expr = ""
+        expr_list = []
+        
+        # 兼容旧的单一字段过滤
         if payload.metadata_field and payload.metadata_value:
         if payload.metadata_field and payload.metadata_value:
-            # 示例:假设元数据直接作为字段存在,或者在 extra_info JSON 中
-            # 这里需要根据实际 Milvus Collection 的 Schema 调整
-            # 暂时忽略,以免报错
-            pass
+            safe_field = payload.metadata_field.replace("'", "").replace('"', "").strip()
+            safe_value = payload.metadata_value.replace("'", "").replace('"', "").strip()
             
             
+            if safe_field and safe_value:
+                if safe_value.isdigit():
+                    expr_list.append(f'{safe_field} == {safe_value}')
+                else:
+                    expr_list.append(f'{safe_field} == "{safe_value}"')
+        
+        # 处理新的多重过滤
+        if payload.filters:
+            for f in payload.filters:
+                safe_field = f.field.replace("'", "").replace('"', "").strip()
+                safe_value = f.value.replace("'", "").replace('"', "").strip()
+                
+                if safe_field and safe_value:
+                    if safe_value.isdigit():
+                        expr_list.append(f'{safe_field} == {safe_value}')
+                    else:
+                        expr_list.append(f'{safe_field} == "{safe_value}"')
+        
+        # 组合所有条件 (使用 AND)
+        expr = " and ".join(expr_list) if expr_list else ""
+        
+        # 选择 Milvus 向量字段名后生成向量 (移到这里,因为之前代码被替换掉了)
+        query_vector = text_to_vector_algo(payload.query, dim=dim)
+        
+        # 检测 collection 使用的 metric (恢复这部分逻辑,因为后续 search 需要)
+        metric_type = None
+        # 优先从 collection_detail 检测真实 metric
+        if collection_detail and isinstance(collection_detail, dict):
+            indices = collection_detail.get("indices") or []
+            if isinstance(indices, list) and len(indices) > 0:
+                for idx in indices:
+                    try:
+                        mt = idx.get("metric_type") or idx.get("metric")
+                        if mt:
+                            metric_type = str(mt).upper()
+                            break
+                    except Exception:
+                        continue
+        
+        # 尝试从 properties 中读取
+        if not metric_type and collection_detail and isinstance(collection_detail, dict):
+            props = collection_detail.get("properties") or {}
+            if isinstance(props, dict):
+                mt = props.get("metric_type") or props.get("metric")
+                if mt:
+                    metric_type = str(mt).upper()
+        
+        actual_search_metric = metric_type
+        if not actual_search_metric:
+             # 如果无法检测到 collection metric (如无索引),则可以使用用户请求的或默认 L2
+             actual_search_metric = requested_metric if requested_metric and requested_metric.upper() != 'HYBRID' else "L2"
+        
+        metric_type = actual_search_metric
+        
+        logger = logging.getLogger(__name__)
+        logger.info(f"Search KB={kb_id} using anns_field={anns_field}, dim={dim}, metric={metric_type} (requested={requested_metric})")
+
         # 3. 执行 Milvus 搜索
         # 3. 执行 Milvus 搜索
         try:
         try:
+            # 使用 collection 实际的 metric_type 作为检索度量,避免 mismatch 错误
+            # metric_type 已在上面检测并存放于变量 metric_type
             search_params = {
             search_params = {
-                "metric_type": "COSINE", 
+                "metric_type": metric_type,
                 "params": {"nprobe": 10}
                 "params": {"nprobe": 10}
             }
             }
+
+            # 如果 top_k <= 0 或未指定,解释为返回该 collection 中的所有文段
+            # 优先使用 page/page_size 计算 limit 和 offset
+            page = payload.page if payload.page and payload.page > 0 else 1
+            page_size = payload.page_size if payload.page_size and payload.page_size > 0 else 10
             
             
-            results = milvus_service.client.search(
-                collection_name=kb_id,
-                data=[query_vector],
-                anns_field="vector", 
-                search_params=search_params,
-                limit=payload.top_k,
-                filter=expr if expr else "",
-                output_fields=["*"] 
-            )
+            # 如果 payload 中有 top_k 且未传 page_size (或者 page_size 是默认值),可以使用 top_k 覆盖 page_size
+            # 但这里为了清晰,优先使用 page_size
+            
+            offset = (page - 1) * page_size
+            limit = page_size
+            
+            # Milvus 对 limit + offset 有限制 (通常 16384),这里做个简单的保护
+            if offset + limit > 16384:
+                # 如果超出深度分页限制,可能需要提示或截断
+                # 这里暂时不做处理,让 Milvus 报错或自行截断
+                pass
+
+            # 获取集合总数用于分页显示 (total)
+            collection_count = 0
+            if collection_detail and isinstance(collection_detail, dict):
+                collection_count = collection_detail.get("entity_count") or 0
+            
+            if not collection_count:
+                try:
+                    stats = milvus_service.client.get_collection_stats(collection_name=kb_id)
+                    collection_count = int(stats.get("row_count")) if isinstance(stats, dict) and stats.get("row_count") else 0
+                except Exception:
+                    collection_count = 0
+
+            # 如果是按照 top_k 逻辑 (不传 page/page_size),保留旧逻辑 (top_k 即 limit, offset=0)
+            # 但现在 Schema 默认 page=1, page_size=10,所以总是走分页逻辑
+            
+            try:
+                # 尝试使用混合检索 (Hybrid Search)
+                # 只有当用户没有显式指定 metric_type 或者指定为 hybrid 时,且集合支持(通常通过异常回退处理)时使用
+                # 但考虑到 metric_type 可能是 L2/COSINE,我们这里先尝试 hybrid,如果失败回退到普通
+                
+                # 为了不破坏现有逻辑,我们可以根据某种标志来决定是否使用 hybrid
+                # 或者默认尝试 hybrid,如果 collection 不支持 sparse 则会报错回退
+                
+                # 这里我们直接调用 milvus_service.hybrid_search
+                # 注意:hybrid_search 返回的格式与 client.search 不同,需要适配
+                
+                use_hybrid = False
+                # 只有当 metric_type 为 None 或者特定值时才尝试混合检索,避免与用户明确指定的 metric 冲突
+                # 或者我们可以认为只要不指定,就优先尝试混合
+                # 已经在上面判断过 use_hybrid = True 了
+                
+                if use_hybrid:
+                    logger.info(f"Attempting hybrid search for KB={kb_id}")
+                    try:
+                        # Hybrid search (LangChain Milvus) 暂时不支持直接传 offset
+                        # 所以我们需要获取 top_k = offset + limit,然后手动切片
+                        target_k = offset + limit
+                        
+                        hybrid_results = milvus_service.hybrid_search(
+                            collection_name=kb_id,
+                            query_text=payload.query,
+                            top_k=target_k
+                        )
+                        
+                        # 手动切片实现分页
+                        start = offset
+                        end = offset + limit
+                        # 确保不越界
+                        if start >= len(hybrid_results):
+                            sliced_results = []
+                        else:
+                            sliced_results = hybrid_results[start:end]
+                        
+                        formatted_results = []
+                        for item in sliced_results:
+                            formatted_results.append(KBSearchResultItem(
+                                id=str(item.get('id')),
+                                kb_name=kb_id,
+                                doc_name=item.get('metadata', {}).get('file_name') or item.get('metadata', {}).get('source') or "未知文档",
+                                content=item.get('text_content') or "",
+                                meta_info=str(item.get('metadata', {})),
+                                score=item.get('similarity', 0) * 100 # 假设是 0-1
+                            ))
+
+                        return KBSearchResponse(results=formatted_results, total=collection_count)
+
+                    except Exception as hybrid_err:
+                        logger.warning(f"Hybrid search failed, falling back to vector search: {hybrid_err}")
+                        # Fallback to standard vector search below
+                        pass
+
+                results = milvus_service.client.search(
+                    collection_name=kb_id,
+                    data=[query_vector],
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    limit=limit,
+                    offset=offset, # 添加 offset 支持分页
+                    filter=expr if expr else "",
+                    output_fields=["*"] 
+                )
+            except Exception as milvus_err:
+                # 捕获 Milvus 异常,常见原因包括 metric mismatch
+                logger.error(f"Milvus search failed for collection={kb_id}, metric_requested={metric_type}, anns_field={anns_field}: {milvus_err}")
+                
+                # Retry Logic: 如果是因为 metric 不匹配,解析错误信息中的 expected metric 并重试
+                error_msg = str(milvus_err)
+                if "metric type not match" in error_msg:
+                    import re
+                    # 匹配 expected=COSINE 或 expected='COSINE' 等格式
+                    # 支持 COSINE, L2, IP, BM25 等
+                    match = re.search(r"expected\s*=\s*['\"]?([A-Za-z0-9_]+)['\"]?", error_msg)
+                    if match:
+                        correct_metric = match.group(1).upper()
+                        logger.warning(f"Detected metric mismatch. Retrying with correct metric: {correct_metric}")
+                        
+                        # 更新 metric_type 并重试搜索
+                        search_params["metric_type"] = correct_metric
+                        # 同时也需要更新后续计算分数所用的 metric_type 变量,以便正确计算相似度
+                        metric_type = correct_metric
+                        
+                        # 特殊处理: BM25 可能需要 sparse vector 或其他参数,但 Milvus search 接口应该是一致的
+                        # 如果是 BM25,可能 anns_field 也要调整(通常 BM25 用 sparse vector)
+                        # 但这里假设 anns_field 是正确的,只是 metric 不对
+                        
+                        results = milvus_service.client.search(
+                            collection_name=kb_id,
+                            data=[query_vector],
+                            anns_field=anns_field,
+                            search_params=search_params,
+                            limit=limit,
+                            offset=offset, # 同样加上 offset
+                            filter=expr if expr else "",
+                            output_fields=["*"] 
+                        )
+                    else:
+                        raise
+                else:
+                    raise
             
             
             # 4. 格式化结果
             # 4. 格式化结果
             formatted_results = []
             formatted_results = []
@@ -73,27 +341,76 @@ class SearchEngineService:
                     #     continue
                     #     continue
                         
                         
                     entity = hit.entity
                     entity = hit.entity
-                    
+
                     content = entity.get("text") or entity.get("content") or entity.get("page_content") or ""
                     content = entity.get("text") or entity.get("content") or entity.get("page_content") or ""
                     if not content:
                     if not content:
-                        debug_data = {k:v for k,v in entity.items() if k != "vector"}
+                        debug_data = {k: v for k, v in entity.items() if k != anns_field}
                         content = json.dumps(debug_data, ensure_ascii=False)[:200] + "..."
                         content = json.dumps(debug_data, ensure_ascii=False)[:200] + "..."
                         
                         
                     doc_name = entity.get("file_name") or entity.get("title") or entity.get("source") or "未知文档"
                     doc_name = entity.get("file_name") or entity.get("title") or entity.get("source") or "未知文档"
                     
                     
                     meta_info = []
                     meta_info = []
                     for k, v in entity.items():
                     for k, v in entity.items():
-                        if k not in ["vector", "text", "content", "page_content", "id", "pk"]:
+                        if k not in [anns_field, "text", "content", "page_content", "id", "pk"]:
                             meta_info.append(f"{k}: {v}")
                             meta_info.append(f"{k}: {v}")
                     meta_str = "; ".join(meta_info[:3])
                     meta_str = "; ".join(meta_info[:3])
                     
                     
+                    # 根据 collection 的 metric 动态计算相似度分数
+                    # 如果用户请求了特定的 metric,尝试适配;否则使用实际 metric
+                    display_metric = requested_metric if requested_metric else metric_type
+                    
+                    similarity_pct = None
+                    try:
+                        raw_score = float(hit.score)
+                    except Exception:
+                        raw_score = None
+
+                    if raw_score is not None:
+                        # 核心计算逻辑:先根据 metric_type 理解 raw_score,再根据 display_metric 转换
+                        # 目前简化处理:直接根据 display_metric 解释 raw_score,忽略不兼容的情况
+                        # 更好的做法是:
+                        # 1. 识别 raw_score 的物理意义(距离还是相似度),基于 metric_type
+                        # 2. 转换为 display_metric 要求的格式
+                        
+                        # Case 1: 实际是 L2 (距离),用户想看 L2
+                        if "L2" in metric_type or "EUCLIDEAN" in metric_type:
+                            distance = raw_score
+                            if display_metric and ("COSINE" in display_metric):
+                                # L2 距离转 Cosine 相似度 (仅适用于归一化向量)
+                                # dist^2 = 2(1-cos) => cos = 1 - dist^2/2
+                                # 但这里简单起见,如果类型不匹配,还是按 L2 算百分比,避免数值错误
+                                similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
+                            else:
+                                similarity_pct = round((1.0 / (1.0 + distance)) * 100.0, 2)
+                                
+                        # Case 2: 实际是 Cosine (相似度 [-1, 1])
+                        elif "COSINE" in metric_type:
+                            cosine_score = raw_score
+                            # 无论用户想看什么,Cosine Score 本身就是相似度,直接归一化到 0-100 最直观
+                            similarity_pct = round(max(min((cosine_score + 1.0) / 2.0, 1.0), 0.0) * 100.0, 2)
+                            
+                        # Case 3: IP (内积)
+                        elif "IP" in metric_type or "INNER" in metric_type:
+                             similarity_pct = round(raw_score * 100.0, 2)
+                        
+                        # Fallback
+                        else:
+                            # 兼容 BM25 或其他未知 metric
+                            if "BM25" in metric_type:
+                                # BM25 分数通常是正数,没有固定上限,直接显示原值
+                                similarity_pct = round(raw_score, 2)
+                            else:
+                                similarity_pct = round(raw_score * 100.0, 2)
+                    else:
+                        similarity_pct = 0.0
+
                     formatted_results.append(KBSearchResultItem(
                     formatted_results.append(KBSearchResultItem(
                         id=str(hit.id),
                         id=str(hit.id),
-                        kb_name=kb_id, 
+                        kb_name=kb_id,
                         doc_name=doc_name,
                         doc_name=doc_name,
                         content=content,
                         content=content,
                         meta_info=meta_str,
                         meta_info=meta_str,
-                        score=round(hit.score * 100, 2)
+                        score=similarity_pct
                     ))
                     ))
             
             
             return KBSearchResponse(results=formatted_results, total=len(formatted_results))
             return KBSearchResponse(results=formatted_results, total=len(formatted_results))

+ 17 - 4
src/app/services/snippet_service.py

@@ -132,14 +132,21 @@ class SnippetService:
         # 使用统一算法生成向量
         # 使用统一算法生成向量
         fake_vector = text_to_vector_algo(payload.content, dim=768)
         fake_vector = text_to_vector_algo(payload.content, dim=768)
         
         
-        data = [{
+        # 基础数据
+        item = {
             "vector": fake_vector,
             "vector": fake_vector,
             "text": payload.content,
             "text": payload.content,
             "source": payload.doc_name,
             "source": payload.doc_name,
             "doc_id": "manual_add",
             "doc_id": "manual_add",
             "file_name": payload.doc_name, 
             "file_name": payload.doc_name, 
             "title": payload.doc_name
             "title": payload.doc_name
-        }]
+        }
+        
+        # 合并自定义字段
+        if hasattr(payload, 'custom_fields') and payload.custom_fields:
+            item.update(payload.custom_fields)
+            
+        data = [item]
         
         
         res = milvus_service.client.insert(
         res = milvus_service.client.insert(
             collection_name=payload.collection_name,
             collection_name=payload.collection_name,
@@ -169,14 +176,20 @@ class SnippetService:
         # 使用统一算法生成向量
         # 使用统一算法生成向量
         fake_vector = text_to_vector_algo(payload.content, dim=768)
         fake_vector = text_to_vector_algo(payload.content, dim=768)
         
         
-        data = [{
+        item = {
             "vector": fake_vector,
             "vector": fake_vector,
             "text": payload.content,
             "text": payload.content,
             "source": payload.doc_name or "已更新",
             "source": payload.doc_name or "已更新",
             "doc_id": "updated",
             "doc_id": "updated",
             "file_name": payload.doc_name,
             "file_name": payload.doc_name,
             "title": payload.doc_name
             "title": payload.doc_name
-        }]
+        }
+        
+        # 合并自定义字段
+        if hasattr(payload, 'custom_fields') and payload.custom_fields:
+            item.update(payload.custom_fields)
+            
+        data = [item]
         
         
         milvus_service.client.insert(collection_name=kb, data=data)
         milvus_service.client.insert(collection_name=kb, data=data)
         milvus_service.client.flush(kb)
         milvus_service.client.flush(kb)

+ 38 - 0
src/views/knowledge_base_view.py

@@ -102,3 +102,41 @@ async def delete_knowledge_base(
 
 
     await knowledge_base_service.delete(db, id)
     await knowledge_base_service.delete(db, id)
     return ResponseSchema(code=0, message="删除成功")
     return ResponseSchema(code=0, message="删除成功")
+
@router.get("/{id}/metadata", response_model=ResponseSchema)
async def get_knowledge_base_metadata(
    id: str = Path(..., description="知识库ID"),
    db: AsyncSession = Depends(get_db),
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return the metadata field definitions and custom schema of a knowledge base."""
    # Reject the call early when the bearer token fails verification.
    if not verify_token(credentials.credentials):
        return ResponseSchema(code=401, message="无效的访问令牌")

    try:
        result = await knowledge_base_service.get_metadata_and_schema(db, id)
    except ValueError as exc:
        # Domain validation errors (e.g. unknown knowledge base) map to code 400.
        return ResponseSchema(code=400, message=str(exc))
    except Exception as exc:
        # Any other failure is surfaced as a code-500 payload with the reason.
        return ResponseSchema(code=500, message=f"获取失败: {str(exc)}")
    return ResponseSchema(code=0, message="获取成功", data=result)
+
@router.post("/{id}/sync", response_model=ResponseSchema)
async def sync_knowledge_base(
    id: str = Path(..., description="知识库ID"),
    db: AsyncSession = Depends(get_db),
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create/synchronize the Milvus collection backing this knowledge base."""
    # Authentication guard: bail out before touching the service layer.
    if not verify_token(credentials.credentials):
        return ResponseSchema(code=401, message="无效的访问令牌")

    try:
        await knowledge_base_service.sync_to_milvus(db, id)
    except ValueError as exc:
        # Expected, user-correctable problems come back as code 400.
        return ResponseSchema(code=400, message=str(exc))
    except Exception as exc:
        # Unexpected failures (Milvus connectivity, schema issues, ...) -> code 500.
        return ResponseSchema(code=500, message=f"同步失败: {str(exc)}")
    return ResponseSchema(code=0, message="同步成功")

+ 3 - 1
src/views/snippet_view.py

@@ -3,7 +3,7 @@
 """
 """
 from fastapi import APIRouter, Depends, Query, Path, Body
 from fastapi import APIRouter, Depends, Query, Path, Body
 from fastapi.responses import StreamingResponse
 from fastapi.responses import StreamingResponse
-from typing import Optional
+from typing import Optional, Dict, Any
 from datetime import datetime
 from datetime import datetime
 import urllib.parse
 import urllib.parse
 
 
@@ -22,11 +22,13 @@ class SnippetCreate(BaseModel):
     doc_name: str = "手动添加"
     doc_name: str = "手动添加"
     content: str
     content: str
     meta_info: Optional[str] = None
     meta_info: Optional[str] = None
+    custom_fields: Optional[Dict[str, Any]] = None
 
 
class SnippetUpdate(BaseModel):
    """Request payload for updating an existing snippet."""
    # Target Milvus collection that stores the snippet.
    collection_name: str
    # Optional new document name; when falsy the service substitutes a
    # placeholder (see SnippetService: `payload.doc_name or "已更新"`).
    doc_name: Optional[str] = None
    # Replacement text content; the service re-vectorizes it on update.
    content: str
    # Extra user-defined fields merged verbatim into the stored Milvus entity.
    custom_fields: Optional[Dict[str, Any]] = None
 
 
 @router.get("", response_model=PaginatedResponseSchema)
 @router.get("", response_model=PaginatedResponseSchema)
 async def get_snippets(
 async def get_snippets(

+ 2 - 1
src/views/system_view.py

@@ -789,7 +789,8 @@ async def update_role_menus(
         
         
         # 调用 service 层
         # 调用 service 层
         system_service = SystemService()
         system_service = SystemService()
-        success, data, message = await system_service.update_role_menus(role_id, menu_ids)
+        updater_id = payload.get("sub")
+        success, data, message = await system_service.update_role_menus(role_id, menu_ids, updater_id)
         
         
         if success:
         if success:
             return ApiResponse(
             return ApiResponse(