|
|
@@ -0,0 +1,171 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""
|
|
|
+批量查询多阶段召回示例脚本
|
|
|
+
|
|
|
+用途:
|
|
|
+- 复用 Milvus 混合检索 + rerank 的多阶段召回能力
|
|
|
+- 一次传入多条查询,逐条输出召回摘要,便于对比效果
|
|
|
+
|
|
|
+运行:
|
|
|
+ python test/batch_multi_stage_recall.py
|
|
|
+"""
|
|
|
+
|
|
|
+import os
|
|
|
+import sys
|
|
|
+import time
|
|
|
+from typing import List, Dict, Any
|
|
|
+
|
|
|
+# 确保可以从项目根目录导入
|
|
|
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
+
|
|
|
+from foundation.ai.rag.retrieval.retrieval import retrieval_manager
|
|
|
+from foundation.database.base.vector.milvus_vector import MilvusVectorManager
|
|
|
+
|
|
|
+
|
|
|
+def build_sample_documents() -> List[Dict[str, Any]]:
|
|
|
+ """构造示例文档,便于快速跑通脚本。"""
|
|
|
+ return [
|
|
|
+ {
|
|
|
+ "content": "大模型是一类具有大量参数的人工智能模型,通常拥有数十亿到数千亿个参数。",
|
|
|
+ "metadata": {"source": "ai_basics", "category": "technology"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "content": "深度学习是机器学习的一个分支,使用多层神经网络来学习数据的复杂模式。",
|
|
|
+ "metadata": {"source": "ml_basics", "category": "technology"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "content": "自然语言处理是人工智能的一个重要领域,专注于计算机与人类语言的交互。",
|
|
|
+ "metadata": {"source": "nlp_basics", "category": "ai_field"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "content": "Transformer架构是现代大语言模型的基础,由Google在2017年提出。",
|
|
|
+ "metadata": {"source": "transformer_info", "category": "architecture"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "content": "GPT、BERT、T5等都是基于Transformer架构的著名预训练语言模型。",
|
|
|
+ "metadata": {"source": "model_examples", "category": "models"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "content": "苹果公司是一家美国跨国技术公司,总部位于加利福尼亚州库比蒂诺。",
|
|
|
+ "metadata": {"source": "company_info", "category": "business"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "content": "机器学习算法需要大量的训练数据来学习有效的特征表示。",
|
|
|
+ "metadata": {"source": "ml_training", "category": "technology"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "content": "计算机视觉是人工智能的另一个重要分支,专注于图像和视频的理解。",
|
|
|
+ "metadata": {"source": "cv_basics", "category": "ai_field"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "content": "强化学习通过与环境交互来学习最优策略,广泛应用于游戏和机器人控制。",
|
|
|
+ "metadata": {"source": "rl_basics", "category": "technology"},
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "content": "预训练-微调范式已经成为大模型训练的标准方法,可以显著提高模型性能。",
|
|
|
+ "metadata": {"source": "training_paradigm", "category": "technology"},
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+
|
|
|
+def create_collection_with_samples(collection_name: str) -> str:
|
|
|
+ """创建(或覆盖)测试集合并写入示例文档。"""
|
|
|
+ docs = build_sample_documents()
|
|
|
+ vector_manager = MilvusVectorManager()
|
|
|
+ vectorstore = vector_manager.create_hybrid_collection(
|
|
|
+ collection_name=collection_name,
|
|
|
+ documents=docs,
|
|
|
+ )
|
|
|
+ if vectorstore:
|
|
|
+ print(f"✓ 集合 '{collection_name}' 准备就绪,文档数: {len(docs)}")
|
|
|
+ return collection_name
|
|
|
+ raise RuntimeError("创建或写入集合失败")
|
|
|
+
|
|
|
+
|
|
|
+def run_batch_queries(
|
|
|
+ collection_name: str,
|
|
|
+ queries: List[str],
|
|
|
+ hybrid_top_k: int = 10,
|
|
|
+ final_top_k: int = 5,
|
|
|
+) -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 对多条查询依次执行多阶段召回。
|
|
|
+ 返回所有结果(按查询顺序拼接)。
|
|
|
+ """
|
|
|
+ all_results: List[Dict[str, Any]] = []
|
|
|
+ for idx, query in enumerate(queries, 1):
|
|
|
+ print("\n" + "-" * 60)
|
|
|
+ print(f"[{idx}/{len(queries)}] 查询: {query}")
|
|
|
+ start = time.time()
|
|
|
+ try:
|
|
|
+ results = retrieval_manager.multi_stage_recall(
|
|
|
+ collection_name=collection_name,
|
|
|
+ query_text=query,
|
|
|
+ hybrid_top_k=hybrid_top_k,
|
|
|
+ final_top_k=final_top_k,
|
|
|
+ )
|
|
|
+ except Exception as exc:
|
|
|
+ print(f"✗ 召回失败: {exc}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ cost = time.time() - start
|
|
|
+ print(f"✓ 召回完成,耗时 {cost:.2f}s,返回 {len(results)} 条")
|
|
|
+
|
|
|
+ # 展示前 3 条
|
|
|
+ for i, r in enumerate(results[:3], 1):
|
|
|
+ print(
|
|
|
+ f" {i}. sim={r.get('similarity', 0):.4f} "
|
|
|
+ f"rerank={r.get('rerank_score', 0):.4f} "
|
|
|
+ f"hybrid={r.get('hybrid_score', 0):.4f}"
|
|
|
+ )
|
|
|
+ print(f" 内容: {r.get('text_content', '')[:80]}...")
|
|
|
+ print(f" 元数据: {r.get('metadata', {})}")
|
|
|
+
|
|
|
+ all_results.extend(results)
|
|
|
+ return all_results
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ collection_name = "batch_bfp_collection"
|
|
|
+
|
|
|
+ # 1) 准备集合与示例数据
|
|
|
+ try:
|
|
|
+ create_collection_with_samples(collection_name)
|
|
|
+ except Exception as exc:
|
|
|
+ print(f"✗ 集合准备失败: {exc}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ # 2) 定义批量查询
|
|
|
+ batch_queries = [
|
|
|
+ "什么是大模型?",
|
|
|
+ "Transformer架构有什么特点",
|
|
|
+ "苹果公司做什么的",
|
|
|
+ "深度学习和机器学习的关系",
|
|
|
+ "人工智能伦理",
|
|
|
+ ]
|
|
|
+
|
|
|
+ # 3) 批量召回
|
|
|
+ all_results = run_batch_queries(
|
|
|
+ collection_name=collection_name,
|
|
|
+ queries=batch_queries,
|
|
|
+ hybrid_top_k=12,
|
|
|
+ final_top_k=6,
|
|
|
+ )
|
|
|
+
|
|
|
+ # 4) 简单汇总
|
|
|
+ print("\n" + "=" * 60)
|
|
|
+ print("批量召回完成")
|
|
|
+ print(f"总查询数: {len(batch_queries)}")
|
|
|
+ print(f"累计结果数: {len(all_results)}")
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|