# diagnose_tool_call.py
"""
Diagnose GLM-4.7 tool calling behavior.

Calls the chat-completions API directly and logs the raw SSE chunks.

Usage: python diagnose_tool_call.py [API_BASE] [API_KEY] [MODEL_NAME]
"""
  5. import os
  6. import sys
  7. import json
  8. import requests
  9. # Load .env
  10. from dotenv import load_dotenv
  11. env_path = os.path.join(os.path.dirname(__file__), '.env')
  12. load_dotenv(env_path)
  13. # Get model config from command line or use defaults
  14. API_BASE = sys.argv[1] if len(sys.argv) > 1 else "https://open.bigmodel.cn/api/paas/v4/"
  15. API_KEY = sys.argv[2] if len(sys.argv) > 2 else "your-api-key-here"
  16. MODEL_NAME = sys.argv[3] if len(sys.argv) > 3 else "glm-4-flash"
  17. TOOLS = [
  18. {
  19. "type": "function",
  20. "function": {
  21. "name": "query_database",
  22. "description": "Execute a SQL SELECT query to retrieve data from a PostgreSQL database. Available tables:\n1. deep_collection (id, url, content, summary, status, error_msg, created_at, updated_at)\n2. spider_result (id, task_id, title, abstract, source, cover, link, created_at)\n3. collection_task (id, keyword, source, status, created_at, finished_at)\nUse PostgreSQL syntax. Use COUNT/GROUP BY for statistics. Always call this tool when the user asks about data.",
  23. "parameters": {
  24. "type": "object",
  25. "properties": {
  26. "sql": {
  27. "type": "string",
  28. "description": "The SQL SELECT query to execute."
  29. }
  30. },
  31. "required": ["sql"]
  32. }
  33. }
  34. },
  35. {
  36. "type": "function",
  37. "function": {
  38. "name": "render_chart",
  39. "description": "Generate a chart configuration for the user interface. You MUST use this tool to show charts.",
  40. "parameters": {
  41. "type": "object",
  42. "properties": {
  43. "title": {"type": "string"},
  44. "chart_type": {"type": "string", "enum": ["bar", "line", "pie"]},
  45. "data": {"type": "object"}
  46. },
  47. "required": ["title", "chart_type", "data"]
  48. }
  49. }
  50. }
  51. ]
  52. def main():
  53. url = API_BASE.rstrip('/') + '/chat/completions'
  54. headers = {
  55. "Authorization": f"Bearer {API_KEY}",
  56. "Content-Type": "application/json"
  57. }
  58. payload = {
  59. "model": MODEL_NAME,
  60. "messages": [
  61. {"role": "system", "content": "你是一个专业的数据分析助手。当用户询问任何关于数据的问题时,你必须首先调用 query_database 工具查询数据。SQL 使用 PostgreSQL 语法。统计类问题使用 COUNT + GROUP BY。"},
  62. {"role": "user", "content": "统计一下最近采集的新闻来源分布"}
  63. ],
  64. "tools": TOOLS,
  65. "tool_choice": "auto",
  66. "stream": True,
  67. "max_tokens": 4096,
  68. "temperature": 0.1
  69. }
  70. print(f"=== GLM-4.7 Tool Call Diagnostic ===")
  71. print(f"URL: {url}")
  72. print(f"Model: {MODEL_NAME}")
  73. print(f"Sending request...\n")
  74. chunk_count = 0
  75. tool_call_chunks = []
  76. content_chunks = []
  77. with requests.post(url, json=payload, headers=headers, stream=True, timeout=60) as response:
  78. print(f"HTTP Status: {response.status_code}")
  79. if response.status_code != 200:
  80. print(f"Error: {response.text}")
  81. return
  82. for line in response.iter_lines():
  83. if not line:
  84. continue
  85. decoded_line = line.decode('utf-8')
  86. if decoded_line.startswith('data: '):
  87. data_str = decoded_line[6:].strip()
  88. if data_str == '[DONE]':
  89. print(f"\n[DONE] received after {chunk_count} chunks")
  90. break
  91. chunk_count += 1
  92. try:
  93. data_json = json.loads(data_str)
  94. if 'choices' in data_json and len(data_json['choices']) > 0:
  95. choice = data_json['choices'][0]
  96. delta = choice.get('delta', {})
  97. finish_reason = choice.get('finish_reason')
  98. # Log content
  99. c = delta.get('content', '')
  100. if c:
  101. content_chunks.append(c)
  102. if len(content_chunks) <= 3:
  103. print(f"Chunk #{chunk_count} [content]: {repr(c[:80])}")
  104. # Log tool_calls
  105. if 'tool_calls' in delta:
  106. tc_chunk = delta['tool_calls']
  107. tool_call_chunks.append(tc_chunk)
  108. print(f"Chunk #{chunk_count} [tool_calls]: {json.dumps(tc_chunk, ensure_ascii=False)}")
  109. # Log finish_reason
  110. if finish_reason:
  111. print(f"Chunk #{chunk_count} [finish_reason]: {finish_reason}")
  112. # Log usage if present
  113. if 'usage' in data_json and data_json['usage']:
  114. print(f"Chunk #{chunk_count} [usage]: {data_json['usage']}")
  115. except json.JSONDecodeError as e:
  116. print(f"Chunk #{chunk_count} [JSON ERROR]: {e}")
  117. print(f" Raw: {data_str[:200]}")
  118. print(f"\n=== Summary ===")
  119. print(f"Total chunks: {chunk_count}")
  120. print(f"Content chunks: {len(content_chunks)}")
  121. print(f"Tool call chunks: {len(tool_call_chunks)}")
  122. if content_chunks:
  123. full_content = ''.join(content_chunks)
  124. print(f"\nFull content ({len(full_content)} chars):")
  125. print(full_content[:500])
  126. if tool_call_chunks:
  127. print(f"\nAll tool call chunks:")
  128. for i, tc in enumerate(tool_call_chunks):
  129. print(f" {i}: {json.dumps(tc, ensure_ascii=False)}")
  130. else:
  131. print("\n*** NO TOOL CALLS RECEIVED ***")
  132. print("The model did NOT invoke any tools.")
  133. if __name__ == '__main__':
  134. main()