#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Hybrid-search diagnostic script (test_hybrid_search_debug.py).

Used to troubleshoot hybrid_search returning 0 results.
"""
  7. import sys
  8. import os
  9. sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  10. from pymilvus import connections, Collection, utility
  11. from foundation.ai.models.model_handler import model_handler
  12. from foundation.observability.logger.loggering import server_logger as logger
  13. def check_milvus_connection():
  14. """检查 Milvus 连接"""
  15. print("\n" + "="*60)
  16. print("1. 检查 Milvus 连接")
  17. print("="*60)
  18. try:
  19. from foundation.infrastructure.config.config import config_handler
  20. host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
  21. port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
  22. connections.connect(
  23. alias="debug",
  24. host=host,
  25. port=port,
  26. db_name="lq_db"
  27. )
  28. print(f"✅ Milvus 连接成功: {host}:{port}")
  29. return True
  30. except Exception as e:
  31. print(f"❌ Milvus 连接失败: {e}")
  32. return False
  33. def check_collection_exists(collection_name: str):
  34. """检查 Collection 是否存在"""
  35. print(f"\n2. 检查 Collection 是否存在: {collection_name}")
  36. print("-"*60)
  37. exists = utility.has_collection(collection_name, using="debug")
  38. if exists:
  39. print(f"✅ Collection '{collection_name}' 存在")
  40. else:
  41. print(f"❌ Collection '{collection_name}' 不存在!")
  42. return exists
  43. def check_collection_schema(collection_name: str):
  44. """检查 Collection Schema 结构"""
  45. print(f"\n3. 检查 Collection Schema 结构")
  46. print("-"*60)
  47. try:
  48. col = Collection(collection_name, using="debug")
  49. schema = col.schema
  50. print(f"Collection: {collection_name}")
  51. print(f"Description: {schema.description}")
  52. print(f"\n字段列表:")
  53. has_dense = False
  54. has_sparse = False
  55. field_names = []
  56. for field in schema.fields:
  57. field_names.append(field.name)
  58. print(f" - {field.name}: {field.dtype.name}", end="")
  59. if hasattr(field, 'dim') and field.dim:
  60. print(f" (dim={field.dim})", end="")
  61. if field.is_primary:
  62. print(" [PRIMARY]", end="")
  63. print()
  64. # 检查关键字段
  65. if field.name == "dense":
  66. has_dense = True
  67. if field.name == "sparse":
  68. has_sparse = True
  69. print(f"\n混合搜索所需字段检查:")
  70. print(f" - dense 字段: {'✅ 存在' if has_dense else '❌ 不存在'}")
  71. print(f" - sparse 字段: {'✅ 存在' if has_sparse else '❌ 不存在'}")
  72. if not has_dense or not has_sparse:
  73. print(f"\n⚠️ 警告: Collection 缺少混合搜索所需的字段!")
  74. print(f" 混合搜索需要 'dense' 和 'sparse' 两个字段")
  75. print(f" 当前字段: {field_names}")
  76. return has_dense and has_sparse
  77. except Exception as e:
  78. print(f"❌ 获取 Schema 失败: {e}")
  79. return False
  80. def check_collection_data(collection_name: str):
  81. """检查 Collection 数据量"""
  82. print(f"\n4. 检查 Collection 数据量")
  83. print("-"*60)
  84. try:
  85. col = Collection(collection_name, using="debug")
  86. col.load()
  87. num_entities = col.num_entities
  88. print(f"数据量: {num_entities} 条")
  89. if num_entities == 0:
  90. print("❌ Collection 为空,没有数据!")
  91. return False
  92. else:
  93. print("✅ Collection 有数据")
  94. return True
  95. except Exception as e:
  96. print(f"❌ 获取数据量失败: {e}")
  97. return False
  98. def check_collection_index(collection_name: str):
  99. """检查 Collection 索引"""
  100. print(f"\n5. 检查 Collection 索引")
  101. print("-"*60)
  102. try:
  103. col = Collection(collection_name, using="debug")
  104. indexes = col.indexes
  105. if not indexes:
  106. print("❌ 没有索引!")
  107. return False
  108. for idx in indexes:
  109. print(f" - 字段: {idx.field_name}")
  110. print(f" 索引参数: {idx.params}")
  111. print("✅ 索引存在")
  112. return True
  113. except Exception as e:
  114. print(f"❌ 获取索引失败: {e}")
  115. return False
  116. def test_traditional_search(collection_name: str, query_text: str):
  117. """测试传统向量搜索(不使用混合搜索)"""
  118. print(f"\n6. 测试传统向量搜索")
  119. print("-"*60)
  120. try:
  121. col = Collection(collection_name, using="debug")
  122. col.load()
  123. # 获取 embedding
  124. emdmodel = model_handler.get_embedding_model()
  125. query_vector = emdmodel.embed_query(query_text)
  126. print(f"查询文本: {query_text}")
  127. print(f"向量维度: {len(query_vector)}")
  128. # 确定向量字段名
  129. vector_field = None
  130. for field in col.schema.fields:
  131. if "FLOAT_VECTOR" in str(field.dtype):
  132. vector_field = field.name
  133. break
  134. if not vector_field:
  135. print("❌ 未找到向量字段")
  136. return False
  137. print(f"向量字段: {vector_field}")
  138. # 执行搜索
  139. search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
  140. results = col.search(
  141. data=[query_vector],
  142. anns_field=vector_field,
  143. param=search_params,
  144. limit=5,
  145. output_fields=["text"]
  146. )
  147. print(f"\n搜索结果: {len(results[0])} 条")
  148. for i, hit in enumerate(results[0]):
  149. print(f" {i+1}. ID={hit.id}, 距离={hit.distance:.4f}")
  150. if len(results[0]) > 0:
  151. print("✅ 传统向量搜索正常")
  152. return True
  153. else:
  154. print("❌ 传统向量搜索也返回0结果")
  155. return False
  156. except Exception as e:
  157. print(f"❌ 传统搜索失败: {e}")
  158. import traceback
  159. traceback.print_exc()
  160. return False
  161. def test_langchain_hybrid_search(collection_name: str, query_text: str):
  162. """测试 LangChain Milvus 混合搜索"""
  163. print(f"\n7. 测试 LangChain Milvus 混合搜索")
  164. print("-"*60)
  165. try:
  166. from langchain_milvus import Milvus, BM25BuiltInFunction
  167. from foundation.infrastructure.config.config import config_handler
  168. host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
  169. port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
  170. connection_args = {
  171. "uri": f"http://{host}:{port}",
  172. "db_name": "lq_db"
  173. }
  174. emdmodel = model_handler.get_embedding_model()
  175. print(f"尝试连接 Collection: {collection_name}")
  176. print(f"连接参数: {connection_args}")
  177. # 尝试创建 vectorstore
  178. vectorstore = Milvus(
  179. embedding_function=emdmodel,
  180. collection_name=collection_name,
  181. connection_args=connection_args,
  182. consistency_level="Strong",
  183. builtin_function=BM25BuiltInFunction(),
  184. vector_field=["dense", "sparse"]
  185. )
  186. print("✅ Vectorstore 创建成功")
  187. # 执行混合搜索
  188. print(f"\n执行混合搜索,查询: {query_text}")
  189. results = vectorstore.similarity_search_with_score(
  190. query=query_text,
  191. k=5,
  192. ranker_type="weighted",
  193. ranker_params={"weights": [0.7, 0.3]}
  194. )
  195. print(f"搜索结果: {len(results)} 条")
  196. for i, (doc, score) in enumerate(results):
  197. content = doc.page_content[:50] if doc.page_content else "N/A"
  198. print(f" {i+1}. score={score:.4f}, content={content}...")
  199. if len(results) > 0:
  200. print("✅ 混合搜索正常")
  201. return True
  202. else:
  203. print("❌ 混合搜索返回0结果")
  204. return False
  205. except Exception as e:
  206. print(f"❌ 混合搜索失败: {e}")
  207. import traceback
  208. traceback.print_exc()
  209. return False
  210. def test_retrieval_manager(collection_name: str, query_text: str):
  211. """测试 RetrievalManager 的混合搜索"""
  212. print(f"\n8. 测试 RetrievalManager 混合搜索")
  213. print("-"*60)
  214. try:
  215. from foundation.ai.rag.retrieval.retrieval import retrieval_manager
  216. results = retrieval_manager.hybrid_search_recall(
  217. collection_name=collection_name,
  218. query_text=query_text,
  219. top_k=5,
  220. ranker_type="weighted",
  221. dense_weight=0.7,
  222. sparse_weight=0.3
  223. )
  224. print(f"搜索结果: {len(results)} 条")
  225. for i, result in enumerate(results):
  226. content = result.get('text_content', '')[:50]
  227. print(f" {i+1}. {content}...")
  228. if len(results) > 0:
  229. print("✅ RetrievalManager 混合搜索正常")
  230. return True
  231. else:
  232. print("❌ RetrievalManager 混合搜索返回0结果")
  233. return False
  234. except Exception as e:
  235. print(f"❌ RetrievalManager 测试失败: {e}")
  236. import traceback
  237. traceback.print_exc()
  238. return False
  239. def main():
  240. """主诊断函数"""
  241. print("\n" + "="*60)
  242. print("混合检索问题诊断")
  243. print("="*60)
  244. # 配置
  245. collection_name = "first_bfp_collection_entity"
  246. query_text = "高空作业"
  247. print(f"\n诊断目标:")
  248. print(f" - Collection: {collection_name}")
  249. print(f" - 查询文本: {query_text}")
  250. # 执行诊断
  251. results = {}
  252. # 1. 检查连接
  253. results['connection'] = check_milvus_connection()
  254. if not results['connection']:
  255. print("\n❌ Milvus 连接失败,无法继续诊断")
  256. return
  257. # 2. 检查 Collection 存在
  258. results['exists'] = check_collection_exists(collection_name)
  259. if not results['exists']:
  260. print(f"\n❌ Collection '{collection_name}' 不存在,无法继续诊断")
  261. return
  262. # 3. 检查 Schema
  263. results['schema'] = check_collection_schema(collection_name)
  264. # 4. 检查数据量
  265. results['data'] = check_collection_data(collection_name)
  266. # 5. 检查索引
  267. results['index'] = check_collection_index(collection_name)
  268. # 6. 测试传统搜索
  269. results['traditional'] = test_traditional_search(collection_name, query_text)
  270. # 7. 测试 LangChain 混合搜索
  271. results['langchain'] = test_langchain_hybrid_search(collection_name, query_text)
  272. # 8. 测试 RetrievalManager
  273. results['retrieval'] = test_retrieval_manager(collection_name, query_text)
  274. # 总结
  275. print("\n" + "="*60)
  276. print("诊断总结")
  277. print("="*60)
  278. for key, value in results.items():
  279. status = "✅" if value else "❌"
  280. print(f" {status} {key}")
  281. # 给出建议
  282. print("\n" + "="*60)
  283. print("问题分析与建议")
  284. print("="*60)
  285. if not results.get('schema'):
  286. print("""
  287. ⚠️ 主要问题: Collection Schema 不支持混合搜索
  288. 原因: Collection 缺少 'dense' 和 'sparse' 字段
  289. 混合搜索需要在创建 Collection 时使用 BM25BuiltInFunction
  290. 解决方案:
  291. 1. 使用 create_hybrid_collection 方法重新创建 Collection
  292. 2. 或者修改代码,对不支持混合搜索的 Collection 使用传统向量搜索
  293. """)
  294. if results.get('traditional') and not results.get('langchain'):
  295. print("""
  296. ⚠️ 问题: 传统搜索正常,但混合搜索失败
  297. 可能原因:
  298. 1. Collection 创建时未启用 BM25 功能
  299. 2. LangChain Milvus 版本兼容性问题
  300. 3. vector_field 配置与实际字段名不匹配
  301. 建议:
  302. 1. 检查 Collection 创建方式
  303. 2. 确认 langchain-milvus 版本
  304. """)
  305. if not results.get('data'):
  306. print("""
  307. ⚠️ 问题: Collection 为空
  308. 解决方案: 先向 Collection 中导入数据
  309. """)
  310. if __name__ == "__main__":
  311. main()