Selaa lähdekoodia

feat: t_rag_kng_standard父子表创建(chinese)

ai02 4 viikkoa sitten
vanhempi
sitoutus
298abd6e09
1 muutettua tiedostoa jossa 250 lisäystä ja 0 poistoa
  1. 250 0
      src/app/scripts/base_collections_create.py

+ 250 - 0
src/app/scripts/base_collections_create.py

@@ -0,0 +1,250 @@
+"""
+创建中文施工标准规范 Milvus Collection 脚本
+- t_rag_kng_chinese_parent
+- t_rag_kng_chinese
+"""
+from __future__ import annotations
+
+from pymilvus import DataType, Function, FunctionType
+
+from app.config.milvus_client import get_milvusclient
+
+PARENT_COLLECTION_NAME = "t_rag_kng_standard_parent1"
+CHILD_COLLECTION_NAME = "t_rag_kng_standard1"
+DENSE_DIM = 4096
+
+
+def create_schema():
+    """按目标结构创建 schema。"""
+    client = get_milvusclient()
+    schema = client.create_schema(auto_id=True, enable_dynamic_fields=True)
+
+    schema.add_field(
+        "pk",
+        DataType.INT64,
+        is_primary=True,
+        auto_id=True,
+        nullable=False,
+        description="主键",
+    )
+    schema.add_field(
+        "text",
+        DataType.VARCHAR,
+        max_length=65535,
+        enable_analyzer=True,
+        analyzer_params={"type": "chinese"},
+        nullable=False,
+        description="内容",
+    )
+    schema.add_field(
+        "dense",
+        DataType.FLOAT_VECTOR,
+        dim=DENSE_DIM,
+        nullable=False,
+        description="向量列",
+    )
+    schema.add_field(
+        "sparse",
+        DataType.SPARSE_FLOAT_VECTOR,
+        nullable=False,
+        description="内容的BM25关键字检索",
+    )
+    schema.add_field(
+        "document_id",
+        DataType.VARCHAR,
+        max_length=128,
+        nullable=False,
+        description="样本中心上传文档ID",
+    )
+    schema.add_field(
+        "parent_id",
+        DataType.VARCHAR,
+        max_length=128,
+        nullable=False,
+        description="父段ID",
+    )
+    schema.add_field(
+        "index",
+        DataType.INT64,
+        nullable=False,
+        description="索引序号",
+    )
+    schema.add_field(
+        "tag_list",
+        DataType.VARCHAR,
+        max_length=2048,
+        nullable=False,
+        description="标签",
+    )
+    schema.add_field(
+        "permission",
+        DataType.JSON,
+        nullable=False,
+        description="权限",
+    )
+    schema.add_field(
+        "metadata",
+        DataType.JSON,
+        nullable=False,
+        description="元数据",
+    )
+    schema.add_field(
+        "is_deleted",
+        DataType.BOOL,
+        nullable=False,
+        description="删除标志",
+    )
+    schema.add_field(
+        "created_by",
+        DataType.VARCHAR,
+        max_length=128,
+        nullable=False,
+        description="创建人",
+    )
+    schema.add_field(
+        "created_time",
+        DataType.INT64,
+        nullable=False,
+        description="创建时间",
+    )
+    schema.add_field(
+        "updated_by",
+        DataType.VARCHAR,
+        max_length=128,
+        nullable=False,
+        description="修改人",
+    )
+    schema.add_field(
+        "updated_time",
+        DataType.INT64,
+        nullable=False,
+        description="修改时间",
+    )
+
+    schema.add_function(
+        Function(
+            name="bm25_fn",
+            input_field_names=["text"],
+            output_field_names=["sparse"],
+            function_type=FunctionType.BM25,
+        )
+    )
+    return schema
+
+
+def create_index(client, collection_name: str):
+    """按目标结构创建所有索引。"""
+    index_params = client.prepare_index_params()
+
+    index_params.add_index(
+        field_name="text",
+        index_name="text",
+        index_type="INVERTED",
+    )
+    index_params.add_index(
+        field_name="dense",
+        index_name="dense",
+        index_type="AUTOINDEX",
+        metric_type="IP",
+    )
+    index_params.add_index(
+        field_name="sparse",
+        index_name="sparse",
+        index_type="SPARSE_INVERTED_INDEX",
+        metric_type="BM25",
+    )
+    index_params.add_index(
+        field_name="document_id",
+        index_name="document_id",
+        index_type="INVERTED",
+    )
+    index_params.add_index(
+        field_name="parent_id",
+        index_name="parent_id",
+        index_type="INVERTED",
+    )
+    index_params.add_index(
+        field_name="index",
+        index_name="index",
+        index_type="INVERTED",
+    )
+    index_params.add_index(
+        field_name="tag_list",
+        index_name="tag_list",
+        index_type="INVERTED",
+    )
+    index_params.add_index(
+        field_name="permission",
+        index_name="permission",
+        index_type="INVERTED",
+        params={"json_cast_type": "VARCHAR"},
+    )
+    index_params.add_index(
+        field_name="metadata",
+        index_name="metadata",
+        index_type="INVERTED",
+        params={"json_cast_type": "VARCHAR"},
+    )
+    index_params.add_index(
+        field_name="is_deleted",
+        index_name="is_deleted",
+        index_type="INVERTED",
+    )
+    index_params.add_index(
+        field_name="created_by",
+        index_name="created_by",
+        index_type="INVERTED",
+    )
+    index_params.add_index(
+        field_name="created_time",
+        index_name="created_time",
+        index_type="INVERTED",
+    )
+    index_params.add_index(
+        field_name="updated_by",
+        index_name="updated_by",
+        index_type="INVERTED",
+    )
+    index_params.add_index(
+        field_name="updated_time",
+        index_name="updated_time",
+        index_type="INVERTED",
+    )
+
+    client.create_index(collection_name=collection_name, index_params=index_params)
+
+
+def ensure_collection(collection_name: str, auto_load: bool = True):
+    """确保 collection 存在,不存在则创建。"""
+    client = get_milvusclient()
+    if client.has_collection(collection_name=collection_name):
+        print(f"Collection 已存在: {collection_name}")
+        if auto_load:
+            client.load_collection(collection_name=collection_name)
+            print(f"Collection 已加载: {collection_name}")
+        return False
+
+    schema = create_schema()
+    client.create_collection(
+        collection_name=collection_name,
+        schema=schema,
+        consistency_level="Bounded",
+        num_shards=1,
+        properties={"timezone": "Asia/Shanghai"},
+    )
+    create_index(client, collection_name)
+
+    if auto_load:
+        client.load_collection(collection_name=collection_name)
+
+    print(f"Collection 创建完成: {collection_name}")
+    return True
+
+
+def main():
+    ensure_collection(PARENT_COLLECTION_NAME, auto_load=True)
+    ensure_collection(CHILD_COLLECTION_NAME, auto_load=True)
+
+
+if __name__ == "__main__":
+    main()