| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
"""
Diagnose GLM-4.7 tool-calling behavior.

Calls an OpenAI-compatible ``/chat/completions`` endpoint directly with
``stream=True`` and logs the raw SSE chunks, so you can see whether (and how)
the model emits ``tool_calls``.

Usage::

    python <this script> [API_BASE] [API_KEY] [MODEL_NAME]

Each positional argument, when omitted, falls back to an environment variable
(optionally loaded from a ``.env`` file next to this script) and then to a
built-in default. The default model is ``glm-4-flash``; pass the GLM-4.7 model
id as MODEL_NAME to test that model. Prefer the environment / ``.env`` for the
API key: command-line arguments are visible in the process list and shell
history.
"""
import os
import sys
import json
import requests

# Load .env from this script's own directory (not the current working
# directory) so the environment fallbacks below can pick it up.
env_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.env')
try:
    from dotenv import load_dotenv
except ImportError:
    # python-dotenv is optional: without it, real environment variables and
    # command-line arguments still work.
    print("Note: python-dotenv is not installed; .env will not be loaded.", file=sys.stderr)
else:
    load_dotenv(env_path)


def _arg_or_env(position, env_var, default):
    """Return ``sys.argv[position]`` if given, else ``$env_var`` if non-empty, else *default*."""
    if len(sys.argv) > position:
        return sys.argv[position]
    return os.environ.get(env_var) or default


# Model config: command line first, then environment (.env), then defaults.
# NOTE(review): these environment-variable names are assumptions - rename them
# to match the keys actually defined in this project's .env file.
API_BASE = _arg_or_env(1, "GLM_API_BASE", "https://open.bigmodel.cn/api/paas/v4/")
API_KEY = _arg_or_env(2, "GLM_API_KEY", "your-api-key-here")
MODEL_NAME = _arg_or_env(3, "GLM_MODEL_NAME", "glm-4-flash")
def _function_tool(name, description, parameters):
    """Wrap a JSON-Schema parameter spec in the ``{"type": "function", ...}`` tool envelope."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }


# The description embeds the table/column list so the model has the schema it
# needs to write SQL for the question it is asked.
_QUERY_DATABASE_DESCRIPTION = (
    "Execute a SQL SELECT query to retrieve data from a PostgreSQL database. Available tables:\n"
    "1. deep_collection (id, url, content, summary, status, error_msg, created_at, updated_at)\n"
    "2. spider_result (id, task_id, title, abstract, source, cover, link, created_at)\n"
    "3. collection_task (id, keyword, source, status, created_at, finished_at)\n"
    "Use PostgreSQL syntax. Use COUNT/GROUP BY for statistics. Always call this tool when the user asks about data."
)

# Tools offered to the model in the diagnostic request: one for reading data,
# one for producing a chart configuration.
TOOLS = [
    _function_tool(
        "query_database",
        _QUERY_DATABASE_DESCRIPTION,
        {
            "type": "object",
            "properties": {
                "sql": {
                    "type": "string",
                    "description": "The SQL SELECT query to execute.",
                },
            },
            "required": ["sql"],
        },
    ),
    _function_tool(
        "render_chart",
        "Generate a chart configuration for the user interface. You MUST use this tool to show charts.",
        {
            "type": "object",
            "properties": {
                "title": {"type": "string"},
                "chart_type": {"type": "string", "enum": ["bar", "line", "pie"]},
                "data": {"type": "object"},
            },
            "required": ["title", "chart_type", "data"],
        },
    ),
]
def _iter_sse_data(response):
    """Yield the payload string of every ``data:`` line in a streamed SSE response.

    Blank lines and SSE comment lines (starting with ``:``) are skipped. Any other
    line (e.g. ``event:`` fields, or an error body sent without the ``data:``
    prefix) is printed so it is not silently lost.
    """
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        # errors='replace': a diagnostic should show a mangled line, not crash on it.
        line = raw_line.decode('utf-8', errors='replace')
        if line.startswith('data:'):
            # SSE allows "data:" with or without one following space; strip() handles both.
            yield line[len('data:'):].strip()
        elif not line.startswith(':'):
            print(f"[non-data line]: {line[:200]}")


def _merge_tool_call_delta(merged, tc_chunk):
    """Fold one streamed ``delta.tool_calls`` value into *merged*.

    *merged* maps a tool-call key to ``{'id', 'name', 'arguments'}``. Argument
    fragments are concatenated in arrival order; ``id`` and ``name`` keep the
    last non-empty value seen.
    """
    if isinstance(tc_chunk, dict):
        # Tolerate a single object where a list is expected.
        tc_chunk = [tc_chunk]
    for position, part in enumerate(tc_chunk):
        if not isinstance(part, dict):
            continue
        # OpenAI-style streams tag each fragment with "index" so fragments of the
        # same call can be joined; fall back to the position within this chunk.
        key = part.get('index', position)
        entry = merged.setdefault(key, {'id': None, 'name': None, 'arguments': ''})
        if part.get('id'):
            entry['id'] = part['id']
        function = part.get('function') or {}
        if function.get('name'):
            entry['name'] = function['name']
        arguments = function.get('arguments')
        if arguments:
            if not isinstance(arguments, str):
                # Defensive: re-serialize a non-string value so it can still be shown.
                arguments = json.dumps(arguments, ensure_ascii=False)
            entry['arguments'] += arguments


def _print_tool_call_summary(merged):
    """Print each reassembled tool call and whether its arguments parse as JSON."""
    print("\nAssembled tool calls:")
    if not merged:
        print("  (none could be reassembled from the chunks above)")
        return
    for key, call in merged.items():
        arguments = call['arguments']
        if not arguments:
            status = "no arguments"
        else:
            try:
                json.loads(arguments)
                status = "valid JSON"
            except json.JSONDecodeError as exc:
                status = f"INVALID JSON: {exc}"
        print(f"  [{key}] id={call['id']} name={call['name']} ({status})")
        print(f"      arguments: {arguments}")


def main():
    """Send one streaming tool-calling request and report what comes back.

    Prints the HTTP status, then selected details of each SSE chunk (the first
    three content fragments, every non-empty ``tool_calls`` delta, finish
    reasons, usage), and finally a summary including the reassembled tool calls.
    Returns early (after printing the error) on a non-200 status, or on a
    request failure before any chunk arrived.
    """
    url = API_BASE.rstrip('/') + '/chat/completions'
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": MODEL_NAME,
        "messages": [
            # System prompt (zh): "You are a professional data-analysis assistant.
            # Whenever the user asks any question about data, you must first call
            # the query_database tool. SQL uses PostgreSQL syntax. Use
            # COUNT + GROUP BY for statistics questions."
            {"role": "system", "content": "你是一个专业的数据分析助手。当用户询问任何关于数据的问题时,你必须首先调用 query_database 工具查询数据。SQL 使用 PostgreSQL 语法。统计类问题使用 COUNT + GROUP BY。"},
            # User message (zh): "Give me the source distribution of recently collected news."
            {"role": "user", "content": "统计一下最近采集的新闻来源分布"}
        ],
        "tools": TOOLS,
        "tool_choice": "auto",
        "stream": True,
        "max_tokens": 4096,
        "temperature": 0.1
    }

    print("=== GLM-4.7 Tool Call Diagnostic ===")
    print(f"URL: {url}")
    print(f"Model: {MODEL_NAME}")
    print("Sending request...\n")

    chunk_count = 0
    tool_call_chunks = []   # raw non-empty delta.tool_calls values, in arrival order
    content_chunks = []     # delta.content fragments, in arrival order
    merged_tool_calls = {}  # tool-call key -> reassembled id / name / arguments
    done_received = False

    try:
        with requests.post(url, json=payload, headers=headers, stream=True, timeout=60) as response:
            print(f"HTTP Status: {response.status_code}")
            if response.status_code != 200:
                print(f"Error: {response.text}")
                return
            for data_str in _iter_sse_data(response):
                if data_str == '[DONE]':
                    print(f"\n[DONE] received after {chunk_count} chunks")
                    done_received = True
                    break
                chunk_count += 1
                try:
                    data_json = json.loads(data_str)
                except json.JSONDecodeError as e:
                    print(f"Chunk #{chunk_count} [JSON ERROR]: {e}")
                    print(f" Raw: {data_str[:200]}")
                    continue
                if not isinstance(data_json, dict):
                    print(f"Chunk #{chunk_count} [unexpected]: {data_str[:200]}")
                    continue

                choices = data_json.get('choices')
                choice = choices[0] if isinstance(choices, list) and choices else None
                if isinstance(choice, dict):
                    # "delta" may be present but null; treat that as an empty delta.
                    delta = choice.get('delta') or {}
                    finish_reason = choice.get('finish_reason')
                    # Log content (only the first few fragments, to keep output short)
                    c = delta.get('content')
                    if c:
                        content_chunks.append(c)
                        if len(content_chunks) <= 3:
                            print(f"Chunk #{chunk_count} [content]: {repr(c[:80])}")
                    # Log tool_calls; a null/empty value is not a tool call.
                    tc_chunk = delta.get('tool_calls')
                    if tc_chunk:
                        tool_call_chunks.append(tc_chunk)
                        _merge_tool_call_delta(merged_tool_calls, tc_chunk)
                        print(f"Chunk #{chunk_count} [tool_calls]: {json.dumps(tc_chunk, ensure_ascii=False)}")
                    # Log finish_reason
                    if finish_reason:
                        print(f"Chunk #{chunk_count} [finish_reason]: {finish_reason}")
                elif not data_json.get('usage'):
                    # Neither choices nor usage (e.g. an error object): show it raw.
                    print(f"Chunk #{chunk_count} [other]: {data_str[:200]}")
                # Log usage if present
                if data_json.get('usage'):
                    print(f"Chunk #{chunk_count} [usage]: {data_json['usage']}")
    except requests.RequestException as exc:
        print(f"\n*** REQUEST FAILED: {type(exc).__name__}: {exc} ***")
        if chunk_count == 0:
            return

    if not done_received:
        print("\n*** Stream ended without a [DONE] marker ***")

    print("\n=== Summary ===")
    print(f"Total chunks: {chunk_count}")
    print(f"Content chunks: {len(content_chunks)}")
    print(f"Tool call chunks: {len(tool_call_chunks)}")
    if content_chunks:
        full_content = ''.join(content_chunks)
        print(f"\nFull content ({len(full_content)} chars):")
        print(full_content[:500])
        if len(full_content) > 500:
            print("... (truncated)")
    if tool_call_chunks:
        print("\nAll tool call chunks:")
        for i, tc in enumerate(tool_call_chunks):
            print(f" {i}: {json.dumps(tc, ensure_ascii=False)}")
        _print_tool_call_summary(merged_tool_calls)
    else:
        print("\n*** NO TOOL CALLS RECEIVED ***")
        print("The model did NOT invoke any tools.")


if __name__ == '__main__':
    main()
|