test_model_stress.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 模型压力测试脚本
  5. 测试当前系统的模型调用压力水平,支持:
  6. - 可配置并发数(--concurrency)
  7. - 可配置测试总次数(--count)
  8. - 可配置上下文长度(--context-size):1k/2k/4k/8k/16k tokens
  9. - 选择不同 LLM / Embedding 模型
  10. - 输出延迟统计(avg/p50/p95/p99/min/max)、吞吐量、错误率
  11. 运行方式:
  12. # 使用默认配置(10并发,50次请求,蜀天35B)
  13. python utils_test/Model_Test/test_model_stress.py
  14. # python utils_test/Model_Test/test_model_stress.py --concurrency 150 --count 150 --model shutian_qwen3_6_27b --context-size 8k
  15. # 避免服务端 KV 缓存命中(注入随机值)
  16. python utils_test/Model_Test/test_model_stress.py --concurrency 10 --count 50 --bust-cache
  17. # 自定义参数
  18. python utils_test/Model_Test/test_model_stress.py --concurrency 20 --count 100 --model shutian_qwen3_5_122b
  19. # 测试不同上下文长度
  20. python utils_test/Model_Test/test_model_stress.py --context-size 4k
  21. python utils_test/Model_Test/test_model_stress.py --context-size 8k -c 5 -n 20
  22. # 自动遍历所有上下文长度(1k/2k/4k/8k/16k)生成对比报告
  23. python utils_test/Model_Test/test_model_stress.py --context-size all
  24. # 测试 Embedding 模型
  25. python utils_test/Model_Test/test_model_stress.py --type embedding --count 100 --concurrency 10
  26. # 测试所有 LLM 模型(逐个)
  27. python utils_test/Model_Test/test_model_stress.py --all-models
  28. # 使用 function_name(从 model_setting.yaml 读取模型)
  29. python utils_test/Model_Test/test_model_stress.py --function completeness_review_generate
  30. """
  31. import sys
  32. import asyncio
  33. import argparse
  34. import time
  35. import statistics
  36. import uuid
  37. from pathlib import Path
  38. from dataclasses import dataclass, field
  39. from typing import List, Optional, Tuple
  40. PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
  41. sys.path.insert(0, str(PROJECT_ROOT))
  42. # ============================================================
  43. # 可用模型列表(与 model_setting.yaml available_models 一致)
  44. # ============================================================
  45. LLM_MODELS = [
  46. "qwen3_5_35b_a3b",
  47. "qwen3_5_27b",
  48. "qwen3_5_122b_a10b",
  49. "doubao",
  50. "doubao-1.5-pro-256k",
  51. "doubao-1.5-lite-32k",
  52. "deepseek",
  53. "deepseek-v3",
  54. "lq_qwen3_8b",
  55. "lq_qwen3_8b_lq_lora",
  56. "lq_qwen3_4b",
  57. "qwen_local_14b",
  58. "shutian_qwen3_5_122b",
  59. "shutian_qwen3_8b",
  60. "shutian_qwen3_5_35b",
  61. "shutian_qwen3_6_27b",
  62. ]
  63. EMBEDDING_MODELS = [
  64. "siliconflow_embed",
  65. "shutian_qwen3_embed",
  66. ]
  67. # 测试用 prompt
  68. TEST_SYSTEM_PROMPT = "你是一个测试助手,请简洁回答问题。"
  69. TEST_USER_PROMPT = "请用一句话回答:1+1等于几?"
  70. TEST_EMBED_TEXT = "这是一个模型Embedding压力测试文本,用于验证向量化服务的并发能力和响应延迟。"
  71. # 上下文填充文本(中文段落,约 1.5 字/token,用于模拟指定 token 数的上下文)
  72. _PADDING_SENTENCES = [
  73. "施工方案编制应结合工程实际,充分考虑施工环境、地质条件、气候因素等影响。",
  74. "桥梁工程的施工质量直接关系到结构安全和使用寿命,必须严格按照设计图纸和规范要求执行。",
  75. "混凝土浇筑前应检查模板支撑体系的稳定性,确保钢筋绑扎间距和保护层厚度满足设计要求。",
  76. "预应力张拉施工应按照设计张拉顺序进行,控制张拉力和伸长量在允许偏差范围内。",
  77. "基坑开挖过程中应加强监测,及时掌握围护结构变形和周边建筑物沉降情况。",
  78. "路基填筑应分层压实,每层压实厚度不宜超过30cm,压实度应满足设计和规范要求。",
  79. "隧道施工应遵循短开挖、强支护、早封闭、勤量测的原则,确保施工安全。",
  80. "钢结构焊接应由持证焊工操作,焊缝质量应符合设计要求和相关标准规定。",
  81. "施工测量放线应采用全站仪或GPS定位,确保平面位置和高程精度满足规范要求。",
  82. "安全生产管理应建立健全责任制,定期开展安全教育培训和隐患排查治理工作。",
  83. ]
  84. # 预生成不同长度的填充文本(避免每次请求重复生成)
  85. _CONTEXT_CACHE: dict = {}
  86. def _generate_context_text(target_tokens: int) -> str:
  87. """生成约 target_tokens 个 token 的中文上下文填充文本。
  88. Qwen 系列 tokenizer 中文约 1.8~2.2 字/token,取 2.2 倍冗余确保实际 token 数达标。
  89. 为保证填充质量,循环使用有意义的工程语句而非随机字符。
  90. """
  91. if target_tokens in _CONTEXT_CACHE:
  92. return _CONTEXT_CACHE[target_tokens]
  93. # 2.2 字/token 冗余系数,宁多勿少
  94. target_chars = int(target_tokens * 2.2)
  95. parts = []
  96. idx = 0
  97. while len("".join(parts)) < target_chars:
  98. parts.append(_PADDING_SENTENCES[idx % len(_PADDING_SENTENCES)])
  99. idx += 1
  100. text = "\n".join(parts)
  101. # 精确截断
  102. text = text[:target_chars]
  103. _CONTEXT_CACHE[target_tokens] = text
  104. return text
  105. CONTEXT_SIZE_PRESETS = {
  106. "1k": 1024,
  107. "2k": 2048,
  108. "4k": 4096,
  109. "8k": 8192,
  110. "16k": 16384,
  111. }
  112. # ============================================================
  113. # 统计数据结构
  114. # ============================================================
  115. @dataclass
  116. class RequestResult:
  117. success: bool
  118. latency_ms: float
  119. error: Optional[str] = None
  120. completion_tokens: int = 0 # 模型输出 token 数
  121. prompt_tokens: int = 0 # 模型输入 token 数
  122. @dataclass
  123. class StressTestResult:
  124. model_name: str
  125. model_type: str # "llm" or "embedding"
  126. concurrency: int
  127. total_requests: int
  128. context_size_tokens: int = 0 # 0 表示未指定
  129. success_count: int = 0
  130. fail_count: int = 0
  131. latencies_ms: List[float] = field(default_factory=list)
  132. completion_tokens_list: List[int] = field(default_factory=list)
  133. prompt_tokens_list: List[int] = field(default_factory=list)
  134. errors: List[str] = field(default_factory=list)
  135. total_time_s: float = 0.0
  136. # ============================================================
  137. # 测试执行器
  138. # ============================================================
  139. def _extract_token_usage(response) -> Tuple[int, int]:
  140. """从 LangChain AIMessage 中提取 token 使用量。
  141. 返回 (prompt_tokens, completion_tokens),提取失败返回 (0, 0)。
  142. """
  143. prompt_tokens = 0
  144. completion_tokens = 0
  145. # 方式1: usage_metadata(LangChain 标准)
  146. if hasattr(response, "usage_metadata") and response.usage_metadata:
  147. meta = response.usage_metadata
  148. prompt_tokens = getattr(meta, "input_tokens", 0) or 0
  149. completion_tokens = getattr(meta, "output_tokens", 0) or 0
  150. # 方式2: response_metadata.token_usage(OpenAI 兼容)
  151. if prompt_tokens == 0 and completion_tokens == 0:
  152. rm = getattr(response, "response_metadata", {}) or {}
  153. token_usage = rm.get("token_usage", {})
  154. prompt_tokens = token_usage.get("prompt_tokens", 0)
  155. completion_tokens = token_usage.get("completion_tokens", 0)
  156. return prompt_tokens, completion_tokens
  157. async def _run_llm_request(trace_id: str, model_name: Optional[str] = None,
  158. function_name: Optional[str] = None,
  159. context_size: int = 0,
  160. bust_cache: bool = False) -> RequestResult:
  161. """执行单次 LLM 调用并记录延迟和 token 用量
  162. Args:
  163. context_size: 上下文 token 数,>0 时在 user_prompt 前拼接填充文本
  164. bust_cache: 在 prompt 末尾追加随机值避免 KV 缓存命中
  165. """
  166. from foundation.ai.models.model_handler import model_handler
  167. from foundation.ai.models.model_config_loader import get_model_for_function, get_thinking_mode_for_function
  168. from langchain_core.messages import SystemMessage, HumanMessage
  169. # 解析最终使用的模型名称和 thinking 配置
  170. resolved_model = model_name
  171. enable_thinking = False
  172. if function_name:
  173. try:
  174. cfg_model = get_model_for_function(function_name)
  175. cfg_thinking = get_thinking_mode_for_function(function_name)
  176. if cfg_model:
  177. resolved_model = cfg_model
  178. if cfg_thinking is not None:
  179. enable_thinking = cfg_thinking
  180. except Exception:
  181. pass
  182. if not resolved_model:
  183. resolved_model = "shutian_qwen3_5_35b"
  184. user_prompt = TEST_USER_PROMPT
  185. if context_size > 0:
  186. padding = _generate_context_text(context_size)
  187. user_prompt = f"{padding}\n\n---\n\n{TEST_USER_PROMPT}"
  188. if bust_cache:
  189. rand = uuid.uuid4().hex[:12]
  190. user_prompt = f"[noise:{rand}]\n{user_prompt}"
  191. messages = [SystemMessage(content=TEST_SYSTEM_PROMPT), HumanMessage(content=user_prompt)]
  192. start = time.perf_counter()
  193. try:
  194. model = model_handler.get_model_by_name(resolved_model)
  195. # 处理 Qwen3.5 thinking 模式绑定
  196. is_qwen35 = "qwen3.5" in resolved_model.lower() or "qwen3_5" in resolved_model.lower()
  197. if is_qwen35:
  198. model = model.bind(extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}})
  199. response = await model.ainvoke(messages)
  200. latency = (time.perf_counter() - start) * 1000
  201. if not response or not response.content:
  202. return RequestResult(success=False, latency_ms=latency, error="空响应")
  203. prompt_tokens, completion_tokens = _extract_token_usage(response)
  204. return RequestResult(
  205. success=True, latency_ms=latency,
  206. completion_tokens=completion_tokens,
  207. prompt_tokens=prompt_tokens,
  208. )
  209. except Exception as e:
  210. latency = (time.perf_counter() - start) * 1000
  211. return RequestResult(success=False, latency_ms=latency, error=str(e)[:200])
  212. async def _run_embedding_request(trace_id: str, model_name: str) -> RequestResult:
  213. """执行单次 Embedding 调用并记录延迟"""
  214. from foundation.ai.models.model_handler import model_handler
  215. start = time.perf_counter()
  216. try:
  217. embed_model = model_handler.get_embedding_model()
  218. # OpenAIEmbeddings 使用 aembed_query
  219. result = await embed_model.aembed_query(TEST_EMBED_TEXT)
  220. latency = (time.perf_counter() - start) * 1000
  221. if not result:
  222. return RequestResult(success=False, latency_ms=latency, error="空响应")
  223. return RequestResult(success=True, latency_ms=latency)
  224. except Exception as e:
  225. latency = (time.perf_counter() - start) * 1000
  226. return RequestResult(success=False, latency_ms=latency, error=str(e)[:200])
  227. async def run_stress_test(
  228. model_name: str,
  229. model_type: str,
  230. concurrency: int,
  231. total_count: int,
  232. function_name: Optional[str] = None,
  233. context_size: int = 0,
  234. bust_cache: bool = False,
  235. ) -> StressTestResult:
  236. """执行压力测试
  237. Args:
  238. model_name: 模型名称
  239. model_type: "llm" 或 "embedding"
  240. concurrency: 并发数
  241. total_count: 总请求次数
  242. function_name: 功能名称(可选,仅 LLM 有效)
  243. context_size: 上下文 token 数(0=不填充)
  244. """
  245. display_name = function_name or model_name
  246. result = StressTestResult(
  247. model_name=display_name,
  248. model_type=model_type,
  249. concurrency=concurrency,
  250. total_requests=total_count,
  251. context_size_tokens=context_size,
  252. )
  253. semaphore = asyncio.Semaphore(concurrency)
  254. progress_done = 0
  255. async def _task(idx: int):
  256. nonlocal progress_done
  257. async with semaphore:
  258. trace_id = f"stress_{model_name}_{idx}"
  259. if model_type == "embedding":
  260. return await _run_embedding_request(trace_id, model_name)
  261. else:
  262. return await _run_llm_request(trace_id, model_name, function_name, context_size, bust_cache)
  263. ctx_label = f" | 上下文: {context_size//1024}k tokens" if context_size > 0 else ""
  264. print(f"\n{'='*60}")
  265. print(f" 模型压力测试: {display_name} ({model_type.upper()})")
  266. print(f" 并发数: {concurrency} | 总请求: {total_count}{ctx_label}")
  267. print(f"{'='*60}")
  268. wall_start = time.perf_counter()
  269. # 创建所有任务
  270. tasks = [_task(i) for i in range(total_count)]
  271. # 逐批执行并打印进度
  272. batch_size = min(concurrency * 2, total_count)
  273. for batch_start in range(0, total_count, batch_size):
  274. batch_end = min(batch_start + batch_size, total_count)
  275. batch = tasks[batch_start:batch_end]
  276. batch_results = await asyncio.gather(*batch, return_exceptions=True)
  277. for r in batch_results:
  278. if isinstance(r, Exception):
  279. result.fail_count += 1
  280. result.errors.append(str(r)[:200])
  281. result.latencies_ms.append(0)
  282. else:
  283. if r.success:
  284. result.success_count += 1
  285. result.completion_tokens_list.append(r.completion_tokens)
  286. result.prompt_tokens_list.append(r.prompt_tokens)
  287. else:
  288. result.fail_count += 1
  289. if r.error:
  290. result.errors.append(r.error)
  291. result.latencies_ms.append(r.latency_ms)
  292. progress_done = batch_end
  293. pct = progress_done / total_count * 100
  294. print(f" 进度: {progress_done}/{total_count} ({pct:.0f}%)", end="\r")
  295. result.total_time_s = time.perf_counter() - wall_start
  296. print()
  297. return result
  298. # ============================================================
  299. # 结果报告
  300. # ============================================================
  301. def print_report(result: StressTestResult):
  302. """打印测试报告"""
  303. successful = [l for l, r in zip(result.latencies_ms,
  304. [True]*result.success_count + [False]*result.fail_count)
  305. if r and l > 0]
  306. all_latencies = [l for l in result.latencies_ms if l > 0]
  307. ctx_label = f"{result.context_size_tokens//1024}k tokens" if result.context_size_tokens > 0 else "默认"
  308. print(f"\n{'─'*60}")
  309. print(f" 测试报告: {result.model_name} ({result.model_type.upper()})")
  310. print(f"{'─'*60}")
  311. print(f" 并发数: {result.concurrency}")
  312. print(f" 上下文长度: {ctx_label}")
  313. print(f" 总请求: {result.total_requests}")
  314. print(f" 成功: {result.success_count}")
  315. print(f" 失败: {result.fail_count}")
  316. print(f" 错误率: {result.fail_count/result.total_requests*100:.1f}%")
  317. print(f" 总耗时: {result.total_time_s:.2f}s")
  318. if result.total_time_s > 0:
  319. throughput = result.success_count / result.total_time_s
  320. print(f" 吞吐量: {throughput:.2f} req/s")
  321. # LLM token 统计
  322. total_completion_tokens = sum(result.completion_tokens_list)
  323. total_prompt_tokens = sum(result.prompt_tokens_list)
  324. has_token_data = total_completion_tokens > 0 or total_prompt_tokens > 0
  325. if has_token_data and result.total_time_s > 0:
  326. tokens_per_sec = total_completion_tokens / result.total_time_s
  327. avg_completion = total_completion_tokens / len(result.completion_tokens_list) if result.completion_tokens_list else 0
  328. avg_prompt = total_prompt_tokens / len(result.prompt_tokens_list) if result.prompt_tokens_list else 0
  329. print(f"\n Token 统计:")
  330. print(f" 总输入 token: {total_prompt_tokens}")
  331. print(f" 总输出 token: {total_completion_tokens}")
  332. print(f" 平均输入/请求: {avg_prompt:.0f}")
  333. print(f" 平均输出/请求: {avg_completion:.0f}")
  334. print(f" 输出 tokens/s: {tokens_per_sec:.1f}")
  335. elif has_token_data:
  336. print(f"\n Token 统计: (总耗时为0,无法计算吞吐)")
  337. if all_latencies:
  338. print(f"\n 延迟统计 (ms):")
  339. print(f" 最小值: {min(all_latencies):.0f}")
  340. print(f" 最大值: {max(all_latencies):.0f}")
  341. print(f" 平均值: {statistics.mean(all_latencies):.0f}")
  342. sorted_lat = sorted(all_latencies)
  343. p50 = sorted_lat[int(len(sorted_lat) * 0.5)]
  344. p95 = sorted_lat[min(int(len(sorted_lat) * 0.95), len(sorted_lat)-1)]
  345. p99 = sorted_lat[min(int(len(sorted_lat) * 0.99), len(sorted_lat)-1)]
  346. print(f" P50: {p50:.0f}")
  347. print(f" P95: {p95:.0f}")
  348. print(f" P99: {p99:.0f}")
  349. if len(all_latencies) > 1:
  350. print(f" 标准差: {statistics.stdev(all_latencies):.0f}")
  351. if result.errors:
  352. unique_errors = list(set(result.errors))[:5]
  353. print(f"\n 错误样本 (最多5条):")
  354. for err in unique_errors:
  355. print(f" - {err}")
  356. print(f"{'─'*60}")
  357. total_completion_tokens = sum(result.completion_tokens_list)
  358. total_prompt_tokens = sum(result.prompt_tokens_list)
  359. tokens_per_sec = round(total_completion_tokens / result.total_time_s, 1) if result.total_time_s > 0 and total_completion_tokens > 0 else 0
  360. return {
  361. "model": result.model_name,
  362. "type": result.model_type,
  363. "concurrency": result.concurrency,
  364. "context_tokens": result.context_size_tokens,
  365. "total": result.total_requests,
  366. "success": result.success_count,
  367. "fail": result.fail_count,
  368. "error_rate": f"{result.fail_count/result.total_requests*100:.1f}%",
  369. "total_time_s": round(result.total_time_s, 2),
  370. "throughput_rps": round(result.success_count / result.total_time_s, 2) if result.total_time_s > 0 else 0,
  371. "latency_avg_ms": round(statistics.mean(all_latencies), 0) if all_latencies else 0,
  372. "latency_p95_ms": round(sorted(all_latencies)[min(int(len(all_latencies)*0.95), len(all_latencies)-1)], 0) if all_latencies else 0,
  373. "completion_tokens": total_completion_tokens,
  374. "prompt_tokens": total_prompt_tokens,
  375. "tokens_per_sec": tokens_per_sec,
  376. }
  377. # ============================================================
  378. # 主入口
  379. # ============================================================
  380. def parse_args():
  381. parser = argparse.ArgumentParser(
  382. description="模型压力测试 - 测试 LLM / Embedding 模型的并发能力和延迟",
  383. formatter_class=argparse.RawDescriptionHelpFormatter,
  384. epilog="""
  385. 示例:
  386. # 默认测试(蜀天35B, 10并发, 50次)
  387. python utils_test/Model_Test/test_model_stress.py
  388. # 测试蜀天122B, 20并发, 100次
  389. python utils_test/Model_Test/test_model_stress.py --model shutian_qwen3_5_122b -c 20 -n 100
  390. # 测试 4k 上下文长度
  391. python utils_test/Model_Test/test_model_stress.py --context-size 4k
  392. # 自动遍历 1k/2k/4k/8k 上下文长度,输出对比
  393. python utils_test/Model_Test/test_model_stress.py --context-size all
  394. # 测试 Embedding
  395. python utils_test/Model_Test/test_model_stress.py --type embedding -c 20 -n 200
  396. # 使用 function_name
  397. python utils_test/Model_Test/test_model_stress.py --function completeness_review_generate
  398. # 测试所有 LLM 模型
  399. python utils_test/Model_Test/test_model_stress.py --all-models -c 5 -n 10
  400. """,
  401. )
  402. parser.add_argument(
  403. "--type", choices=["llm", "embedding"], default="llm",
  404. help="模型类型: llm 或 embedding (默认: llm)",
  405. )
  406. parser.add_argument(
  407. "--model", "-m", type=str, default=None,
  408. help="模型名称,如 shutian_qwen3_5_35b (默认: shutian_qwen3_5_35b)",
  409. )
  410. parser.add_argument(
  411. "--function", "-f", type=str, default=None,
  412. help="功能名称(从 model_setting.yaml 加载模型配置),如 completeness_review_generate",
  413. )
  414. parser.add_argument(
  415. "--concurrency", "-c", type=int, default=10,
  416. help="并发数 (默认: 10)",
  417. )
  418. parser.add_argument(
  419. "--count", "-n", type=int, default=50,
  420. help="总请求次数 (默认: 50)",
  421. )
  422. parser.add_argument(
  423. "--context-size", "-ctx", type=str, default=None,
  424. help="上下文长度: 1k / 2k / 4k / 8k / 16k / all(逐个测试)/ 数字如 2048 (默认: 不填充)",
  425. )
  426. parser.add_argument(
  427. "--all-models", action="store_true",
  428. help="逐个测试所有可用模型(使用 -c 和 -n 指定并发和次数)",
  429. )
  430. parser.add_argument(
  431. "--all-embeddings", action="store_true",
  432. help="逐个测试所有 Embedding 模型",
  433. )
  434. parser.add_argument(
  435. "--bust-cache", action="store_true",
  436. help="在每次请求的 prompt 末尾注入随机值,避免服务端 KV 缓存命中",
  437. )
  438. return parser.parse_args()
  439. def _parse_context_size(raw: Optional[str]) -> List[int]:
  440. """解析 --context-size 参数,返回 token 数列表。
  441. 支持: 1k, 2k, 4k, 8k, all, 或纯数字如 2048
  442. """
  443. if raw is None:
  444. return [0] # 0 = 不填充
  445. raw = raw.strip().lower()
  446. if raw == "all":
  447. return list(CONTEXT_SIZE_PRESETS.values())
  448. if raw in CONTEXT_SIZE_PRESETS:
  449. return [CONTEXT_SIZE_PRESETS[raw]]
  450. # 纯数字
  451. try:
  452. return [int(raw)]
  453. except ValueError:
  454. print(f" [错误] 不支持的 --context-size 值: {raw},可选: 1k/2k/4k/8k/all/数字")
  455. sys.exit(1)
  456. async def _run_single_model_test(args, model_name: str, function_name: Optional[str],
  457. context_sizes: List[int]) -> List[dict]:
  458. """对单个模型执行一组上下文长度的压力测试"""
  459. results_summary = []
  460. for ctx_size in context_sizes:
  461. ctx_display = f"{ctx_size//1024}k" if ctx_size > 0 else "默认"
  462. try:
  463. result = await run_stress_test(
  464. model_name=model_name or "via_function",
  465. model_type=args.type,
  466. concurrency=args.concurrency,
  467. total_count=args.count,
  468. function_name=function_name,
  469. context_size=ctx_size,
  470. bust_cache=args.bust_cache,
  471. )
  472. summary = print_report(result)
  473. summary["context_display"] = ctx_display
  474. results_summary.append(summary)
  475. except Exception as e:
  476. print(f"\n [跳过] {ctx_display}: {e}")
  477. results_summary.append({
  478. "model": function_name or model_name,
  479. "type": args.type,
  480. "context_display": ctx_display,
  481. "error": str(e)[:200],
  482. })
  483. return results_summary
  484. async def main():
  485. args = parse_args()
  486. context_sizes = _parse_context_size(args.context_size)
  487. if args.all_models or args.all_embeddings:
  488. # 逐个测试所有模型
  489. models = LLM_MODELS if args.all_models else EMBEDDING_MODELS
  490. model_type = "llm" if args.all_models else "embedding"
  491. results_summary = []
  492. for model_name in models:
  493. try:
  494. result = await run_stress_test(
  495. model_name=model_name,
  496. model_type=model_type,
  497. concurrency=args.concurrency,
  498. total_count=args.count,
  499. bust_cache=args.bust_cache,
  500. )
  501. summary = print_report(result)
  502. results_summary.append(summary)
  503. except Exception as e:
  504. print(f"\n [跳过] {model_name}: {e}")
  505. results_summary.append({
  506. "model": model_name,
  507. "type": model_type,
  508. "error": str(e)[:200],
  509. })
  510. # 汇总表
  511. print(f"\n\n{'='*90}")
  512. print(f" 汇总对比")
  513. print(f"{'='*90}")
  514. print(f" {'模型':<30} {'成功':>6} {'失败':>6} {'错误率':>8} {'吞吐量':>10} {'延迟avg':>10} {'P95':>10} {'tok/s':>8}")
  515. print(f" {'─'*30} {'─'*6} {'─'*6} {'─'*8} {'─'*10} {'─'*10} {'─'*10} {'─'*8}")
  516. for s in results_summary:
  517. if "error" in s:
  518. print(f" {s['model']:<30} {'SKIP':>6} - {s['error'][:40]}")
  519. else:
  520. tps = s.get('tokens_per_sec', 0)
  521. tps_str = f"{tps:.1f}" if tps > 0 else "n/a"
  522. print(f" {s['model']:<30} {s['success']:>6} {s['fail']:>6} {s['error_rate']:>8}"
  523. f" {s['throughput_rps']:>8.1f}/s {s['latency_avg_ms']:>8.0f}ms {s['latency_p95_ms']:>8.0f}ms {tps_str:>8}")
  524. print(f"{'='*90}")
  525. else:
  526. # 单模型测试(支持多上下文长度)
  527. model_name = args.model
  528. if not model_name and not args.function:
  529. if args.type == "embedding":
  530. model_name = "shutian_qwen3_embed"
  531. else:
  532. model_name = "shutian_qwen3_5_35b"
  533. results_summary = await _run_single_model_test(
  534. args, model_name, args.function, context_sizes,
  535. )
  536. # 多上下文长度时输出对比汇总
  537. if len(context_sizes) > 1:
  538. print(f"\n\n{'='*90}")
  539. print(f" 上下文长度对比: {args.function or model_name}")
  540. print(f"{'='*90}")
  541. print(f" {'上下文':<10} {'成功':>6} {'失败':>6} {'错误率':>8} {'吞吐量':>10} {'延迟avg':>10} {'P95':>10} {'tok/s':>8}")
  542. print(f" {'─'*10} {'─'*6} {'─'*6} {'─'*8} {'─'*10} {'─'*10} {'─'*10} {'─'*8}")
  543. for s in results_summary:
  544. if "error" in s:
  545. print(f" {s['context_display']:<10} {'SKIP':>6} - {s['error'][:40]}")
  546. else:
  547. tps = s.get('tokens_per_sec', 0)
  548. tps_str = f"{tps:.1f}" if tps > 0 else "n/a"
  549. print(f" {s['context_display']:<10} {s['success']:>6} {s['fail']:>6} {s['error_rate']:>8}"
  550. f" {s['throughput_rps']:>8.1f}/s {s['latency_avg_ms']:>8.0f}ms {s['latency_p95_ms']:>8.0f}ms {tps_str:>8}")
  551. print(f"{'='*90}")
  552. if __name__ == "__main__":
  553. asyncio.run(main())