#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
debug_query_extract.py

Debug script for QueryRewriteManager.query_extract.
"""
import asyncio
import os
import sys
import time
import traceback
import uuid

# The parent of this file's directory is added to sys.path before the
# project imports below, so this line must stay ahead of them.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from foundation.ai.rag.retrieval.query_rewrite import QueryRewriteManager
from foundation.observability.logger.loggering import server_logger as logger
  12. def debug_query_extract():
  13. """
  14. 调试query_extract方法
  15. """
  16. print("="*60)
  17. print("调试QueryRewriteManager.query_extract方法")
  18. print("="*60)
  19. # 测试数据
  20. review_content = "深度大于3m的基坑开挖、有地下水侵扰的基坑清底封底,每个工作班至少巡查两遍。"
  21. print(f"原始输入内容: {review_content}")
  22. print(f"内容长度: {len(review_content)}")
  23. try:
  24. # 手动构建提示词模板进行调试
  25. from foundation.ai.rag.retrieval.query_rewrite import prompt_loader, generate_model_client
  26. import uuid
  27. import asyncio
  28. # 获取提示词模板
  29. task_prompt = prompt_loader.get_prompt_template(
  30. reviewer_type="query_extract",
  31. review_content=review_content
  32. )
  33. print(f"\n[DEBUG] 提示词模板类型: {type(task_prompt)}")
  34. # 尝试格式化消息
  35. try:
  36. messages = task_prompt.format_messages()
  37. print(f"[DEBUG] 消息数量: {len(messages)}")
  38. print(f"[DEBUG] 系统消息: {messages[0].content[:200]}...")
  39. print(f"[DEBUG] 用户消息: {messages[1].content[:300]}...")
  40. # 检查用户消息是否包含正确的review_content
  41. if review_content in messages[1].content:
  42. print("[OK] review_content 正确传递到提示词")
  43. else:
  44. print("[ERROR] review_content 未正确传递到提示词")
  45. print(f"[DEBUG] 期望内容: {review_content}")
  46. print(f"[DEBUG] 用户消息实际内容: {messages[1].content}")
  47. except Exception as e:
  48. print(f"[ERROR] 格式化消息失败: {e}")
  49. return
  50. # 构建任务提示信息
  51. task_prompt_info = {
  52. "task_prompt": task_prompt,
  53. "task_name": "query_extract"
  54. }
  55. # 生成trace_id
  56. trace_id = str(uuid.uuid4())
  57. print(f"[DEBUG] Trace ID: {trace_id}")
  58. # 调用模型生成接口
  59. print("[DEBUG] 开始调用模型...")
  60. model_response = asyncio.run(generate_model_client.get_model_generate_invoke(
  61. trace_id=trace_id,
  62. task_prompt_info=task_prompt_info
  63. ))
  64. print(f"[DEBUG] 模型响应: {model_response}")
  65. # 使用原始方法进行对比测试
  66. print("\n" + "="*40)
  67. print("使用原始QueryRewriteManager方法测试")
  68. print("="*40)
  69. query_rewrite_manager = QueryRewriteManager()
  70. start_time = time.time()
  71. result = query_rewrite_manager.query_extract(review_content)
  72. end_time = time.time()
  73. elapsed_time = end_time - start_time
  74. print(f"[OK] 原始方法提取完成,耗时: {elapsed_time:.2f}秒")
  75. print(f"[OK] 原始方法返回结果: {result}")
  76. except Exception as e:
  77. print(f"[ERROR] 调试失败: {str(e)}")
  78. import traceback
  79. traceback.print_exc()
  80. def main():
  81. """
  82. 主测试函数
  83. """
  84. print("开始调试 QueryRewriteManager.query_extract 方法")
  85. debug_query_extract()
  86. print("\n调试完成")
  87. if __name__ == "__main__":
  88. main()