|
|
@@ -1,171 +0,0 @@
|
|
|
-#!/usr/bin/env python
|
|
|
-# -*- coding: utf-8 -*-
|
|
|
-"""
|
|
|
-批量查询多阶段召回示例脚本
|
|
|
-
|
|
|
-用途:
|
|
|
-- 复用 Milvus 混合检索 + rerank 的多阶段召回能力
|
|
|
-- 一次传入多条查询,逐条输出召回摘要,便于对比效果
|
|
|
-
|
|
|
-运行:
|
|
|
- python test/batch_multi_stage_recall.py
|
|
|
-"""
|
|
|
-
|
|
|
-import os
|
|
|
-import sys
|
|
|
-import time
|
|
|
-from typing import List, Dict, Any
|
|
|
-
|
|
|
-# 确保可以从项目根目录导入
|
|
|
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
-
|
|
|
-from foundation.ai.rag.retrieval.retrieval import retrieval_manager
|
|
|
-from foundation.database.base.vector.milvus_vector import MilvusVectorManager
|
|
|
-
|
|
|
-
|
|
|
-def build_sample_documents() -> List[Dict[str, Any]]:
|
|
|
- """构造示例文档,便于快速跑通脚本。"""
|
|
|
- return [
|
|
|
- {
|
|
|
- "content": "大模型是一类具有大量参数的人工智能模型,通常拥有数十亿到数千亿个参数。",
|
|
|
- "metadata": {"source": "ai_basics", "category": "technology"},
|
|
|
- },
|
|
|
- {
|
|
|
- "content": "深度学习是机器学习的一个分支,使用多层神经网络来学习数据的复杂模式。",
|
|
|
- "metadata": {"source": "ml_basics", "category": "technology"},
|
|
|
- },
|
|
|
- {
|
|
|
- "content": "自然语言处理是人工智能的一个重要领域,专注于计算机与人类语言的交互。",
|
|
|
- "metadata": {"source": "nlp_basics", "category": "ai_field"},
|
|
|
- },
|
|
|
- {
|
|
|
- "content": "Transformer架构是现代大语言模型的基础,由Google在2017年提出。",
|
|
|
- "metadata": {"source": "transformer_info", "category": "architecture"},
|
|
|
- },
|
|
|
- {
|
|
|
- "content": "GPT、BERT、T5等都是基于Transformer架构的著名预训练语言模型。",
|
|
|
- "metadata": {"source": "model_examples", "category": "models"},
|
|
|
- },
|
|
|
- {
|
|
|
- "content": "苹果公司是一家美国跨国技术公司,总部位于加利福尼亚州库比蒂诺。",
|
|
|
- "metadata": {"source": "company_info", "category": "business"},
|
|
|
- },
|
|
|
- {
|
|
|
- "content": "机器学习算法需要大量的训练数据来学习有效的特征表示。",
|
|
|
- "metadata": {"source": "ml_training", "category": "technology"},
|
|
|
- },
|
|
|
- {
|
|
|
- "content": "计算机视觉是人工智能的另一个重要分支,专注于图像和视频的理解。",
|
|
|
- "metadata": {"source": "cv_basics", "category": "ai_field"},
|
|
|
- },
|
|
|
- {
|
|
|
- "content": "强化学习通过与环境交互来学习最优策略,广泛应用于游戏和机器人控制。",
|
|
|
- "metadata": {"source": "rl_basics", "category": "technology"},
|
|
|
- },
|
|
|
- {
|
|
|
- "content": "预训练-微调范式已经成为大模型训练的标准方法,可以显著提高模型性能。",
|
|
|
- "metadata": {"source": "training_paradigm", "category": "technology"},
|
|
|
- },
|
|
|
- ]
|
|
|
-
|
|
|
-
|
|
|
-def create_collection_with_samples(collection_name: str) -> str:
|
|
|
- """创建(或覆盖)测试集合并写入示例文档。"""
|
|
|
- docs = build_sample_documents()
|
|
|
- vector_manager = MilvusVectorManager()
|
|
|
- vectorstore = vector_manager.create_hybrid_collection(
|
|
|
- collection_name=collection_name,
|
|
|
- documents=docs,
|
|
|
- )
|
|
|
- if vectorstore:
|
|
|
- print(f"✓ 集合 '{collection_name}' 准备就绪,文档数: {len(docs)}")
|
|
|
- return collection_name
|
|
|
- raise RuntimeError("创建或写入集合失败")
|
|
|
-
|
|
|
-
|
|
|
-def run_batch_queries(
|
|
|
- collection_name: str,
|
|
|
- queries: List[str],
|
|
|
- hybrid_top_k: int = 10,
|
|
|
- final_top_k: int = 5,
|
|
|
-) -> List[Dict[str, Any]]:
|
|
|
- """
|
|
|
- 对多条查询依次执行多阶段召回。
|
|
|
- 返回所有结果(按查询顺序拼接)。
|
|
|
- """
|
|
|
- all_results: List[Dict[str, Any]] = []
|
|
|
- for idx, query in enumerate(queries, 1):
|
|
|
- print("\n" + "-" * 60)
|
|
|
- print(f"[{idx}/{len(queries)}] 查询: {query}")
|
|
|
- start = time.time()
|
|
|
- try:
|
|
|
- results = retrieval_manager.multi_stage_recall(
|
|
|
- collection_name=collection_name,
|
|
|
- query_text=query,
|
|
|
- hybrid_top_k=hybrid_top_k,
|
|
|
- final_top_k=final_top_k,
|
|
|
- )
|
|
|
- except Exception as exc:
|
|
|
- print(f"✗ 召回失败: {exc}")
|
|
|
- continue
|
|
|
-
|
|
|
- cost = time.time() - start
|
|
|
- print(f"✓ 召回完成,耗时 {cost:.2f}s,返回 {len(results)} 条")
|
|
|
-
|
|
|
- # 展示前 3 条
|
|
|
- for i, r in enumerate(results[:3], 1):
|
|
|
- print(
|
|
|
- f" {i}. sim={r.get('similarity', 0):.4f} "
|
|
|
- f"rerank={r.get('rerank_score', 0):.4f} "
|
|
|
- f"hybrid={r.get('hybrid_score', 0):.4f}"
|
|
|
- )
|
|
|
- print(f" 内容: {r.get('text_content', '')[:80]}...")
|
|
|
- print(f" 元数据: {r.get('metadata', {})}")
|
|
|
-
|
|
|
- all_results.extend(results)
|
|
|
- return all_results
|
|
|
-
|
|
|
-
|
|
|
-def main():
|
|
|
- collection_name = "batch_bfp_collection"
|
|
|
-
|
|
|
- # 1) 准备集合与示例数据
|
|
|
- try:
|
|
|
- create_collection_with_samples(collection_name)
|
|
|
- except Exception as exc:
|
|
|
- print(f"✗ 集合准备失败: {exc}")
|
|
|
- sys.exit(1)
|
|
|
-
|
|
|
- # 2) 定义批量查询
|
|
|
- batch_queries = [
|
|
|
- "什么是大模型?",
|
|
|
- "Transformer架构有什么特点",
|
|
|
- "苹果公司做什么的",
|
|
|
- "深度学习和机器学习的关系",
|
|
|
- "人工智能伦理",
|
|
|
- ]
|
|
|
-
|
|
|
- # 3) 批量召回
|
|
|
- all_results = run_batch_queries(
|
|
|
- collection_name=collection_name,
|
|
|
- queries=batch_queries,
|
|
|
- hybrid_top_k=12,
|
|
|
- final_top_k=6,
|
|
|
- )
|
|
|
-
|
|
|
- # 4) 简单汇总
|
|
|
- print("\n" + "=" * 60)
|
|
|
- print("批量召回完成")
|
|
|
- print(f"总查询数: {len(batch_queries)}")
|
|
|
- print(f"累计结果数: {len(all_results)}")
|
|
|
-
|
|
|
-
|
|
|
-if __name__ == "__main__":
|
|
|
- main()
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|