inference_worker.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. """轻量推理 worker —— 在算力节点(253)上运行。
  2. 只依赖 Python 标准库 + torch + transformers(不需要 fastapi/uvicorn)。
  3. 通过 TCP 接收 JSON 请求,返回 JSON 响应。
  4. 协议:4 字节大端长度前缀 + JSON body
  5. 启动:
  6. python inference_worker.py --model-path /path/to/merged/model --port 8100
  7. 请求格式:
  8. {
  9. "prompt": "<|user|>\\n你好\\n<|assistant|>\\n",
  10. "max_new_tokens": 512,
  11. "temperature": 0.7,
  12. "top_p": 0.9,
  13. "do_sample": true,
  14. "repetition_penalty": 1.0
  15. }
  16. 响应格式:
  17. {
  18. "generated_text": "你好!有什么可以帮你的吗?",
  19. "prompt_tokens": 12,
  20. "completion_tokens": 15,
  21. "total_tokens": 27
  22. }
  23. """
  24. import argparse
  25. import json
  26. import socket
  27. import struct
  28. import threading
  29. import sys
  30. def _build_prompt_from_messages(tokenizer, messages: list[dict]) -> str:
  31. """将 OpenAI 消息格式转为模型输入文本。
  32. 优先使用 tokenizer 自带的 apply_chat_template(Qwen3.5 等模型内建了正确的模板),
  33. 只有当 tokenizer 没有 chat_template 时才回退到手动拼接。
  34. """
  35. if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
  36. return tokenizer.apply_chat_template(
  37. messages, tokenize=False, add_generation_prompt=True
  38. )
  39. # 回退:手动拼接(兼容没有 chat_template 的模型)
  40. parts = []
  41. for msg in messages:
  42. role = msg.get("role", "")
  43. content = msg.get("content", "")
  44. if role == "system":
  45. parts.append(f"<|system|>\n{content}")
  46. elif role == "user":
  47. parts.append(f"<|user|>\n{content}")
  48. elif role == "assistant":
  49. parts.append(f"<|assistant|>\n{content}")
  50. parts.append("<|assistant|>\n")
  51. return "\n".join(parts)
  52. def _build_stop_criteria(tokenizer, model_device):
  53. """构建 StoppingCriteria,遇到角色切换标记或 eos 时停止生成,防止复读。"""
  54. from transformers import StoppingCriteria, StoppingCriteriaList
  55. # 收集所有 stop 短语
  56. stop_phrases = ["<|im_end|>", "<|endoftext|>", "<|eob|>", "<|eol|>", "<|user|>", "<|system|>", "<|assistant|>"]
  57. stop_token_ids = []
  58. for phrase in stop_phrases:
  59. ids = tokenizer.encode(phrase, add_special_tokens=False)
  60. if ids:
  61. stop_token_ids.append(ids)
  62. # 也加入 eos_token_id(如果有)
  63. if tokenizer.eos_token_id is not None:
  64. stop_token_ids.append([tokenizer.eos_token_id])
  65. class StopOnRoleToken(StoppingCriteria):
  66. def __init__(self, stop_sequences, device):
  67. self.stop_sequences = stop_sequences
  68. self.device = device
  69. def __call__(self, input_ids, scores, **kwargs):
  70. # 检查最近生成的 token 是否匹配任意 stop 序列
  71. gen_seq = input_ids[0].tolist()
  72. for stop_ids in self.stop_sequences:
  73. if len(gen_seq) >= len(stop_ids):
  74. if gen_seq[-len(stop_ids):] == stop_ids:
  75. return True
  76. return False
  77. return StoppingCriteriaList([StopOnRoleToken(stop_token_ids, model_device)])
  78. class InferenceWorker:
  79. def __init__(self, model_path: str):
  80. import torch
  81. from transformers import AutoModelForCausalLM, AutoTokenizer
  82. print(f"[worker] Loading tokenizer from: {model_path}", flush=True)
  83. self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
  84. if self.tokenizer.pad_token is None:
  85. self.tokenizer.pad_token = self.tokenizer.eos_token
  86. print(f"[worker] Loading model from: {model_path}", flush=True)
  87. # 单卡加载,与训练 DDP 模式一致:每个进程只用一张 GPU
  88. # 避免 device_map="auto" 拆分模型导致 rotary_emb 等共享模块跨卡报错
  89. # CUDA_VISIBLE_DEVICES 由启动脚本设置,cuda:0 就是第一张可见 GPU
  90. if torch.cuda.is_available():
  91. device_map = {"": 0}
  92. print(f"[worker] Single GPU device_map: cuda:0", flush=True)
  93. else:
  94. device_map = "cpu"
  95. print("[worker] CPU device_map", flush=True)
  96. self.model = AutoModelForCausalLM.from_pretrained(
  97. model_path, torch_dtype=torch.float16, device_map=device_map,
  98. )
  99. self.model.eval()
  100. self.torch = torch
  101. print("[worker] Model loaded successfully.", flush=True)
  102. def generate(self, request: dict) -> dict:
  103. """处理一次推理请求。"""
  104. # 支持两种输入:messages(OpenAI 格式)或 prompt(原始文本)
  105. messages = request.get("messages")
  106. if messages:
  107. prompt = _build_prompt_from_messages(self.tokenizer, messages)
  108. else:
  109. prompt = request.get("prompt", "")
  110. max_new_tokens = request.get("max_tokens", request.get("max_new_tokens", 512))
  111. temperature = max(request.get("temperature", 0.7), 0.01)
  112. top_p = request.get("top_p", 0.9)
  113. do_sample = request.get("do_sample", temperature > 0)
  114. repetition_penalty = request.get("repetition_penalty", 1.1)
  115. inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
  116. prompt_tokens = inputs["input_ids"].shape[1]
  117. # 构建 stop criteria:遇到角色标记就停止,防止复读
  118. stopping_criteria = _build_stop_criteria(self.tokenizer, self.model.device)
  119. with self.torch.no_grad():
  120. outputs = self.model.generate(
  121. **inputs,
  122. max_new_tokens=max_new_tokens,
  123. temperature=temperature,
  124. top_p=top_p,
  125. do_sample=do_sample,
  126. repetition_penalty=repetition_penalty,
  127. pad_token_id=self.tokenizer.eos_token_id,
  128. eos_token_id=self.tokenizer.eos_token_id,
  129. stopping_criteria=stopping_criteria,
  130. )
  131. generated = self.tokenizer.decode(
  132. outputs[0][prompt_tokens:], skip_special_tokens=True
  133. )
  134. # 文本级兜底截断:在生成文本中找到最早的 stop 标记并截断
  135. # 防止 StoppingCriteria 因 tokenizer 编码差异未能触发
  136. _stop_markers = ["<|eob|>", "<|im_end|>", "<|endoftext|>",
  137. "<|user|>", "<|system|>", "<|assistant|>"]
  138. earliest = len(generated)
  139. for marker in _stop_markers:
  140. idx = generated.find(marker)
  141. if idx != -1 and idx < earliest:
  142. earliest = idx
  143. generated = generated[:earliest].strip()
  144. completion_tokens = outputs.shape[1] - prompt_tokens
  145. return {
  146. "generated_text": generated,
  147. "prompt_tokens": int(prompt_tokens),
  148. "completion_tokens": int(completion_tokens),
  149. "total_tokens": int(prompt_tokens + completion_tokens),
  150. }
  151. def _recv_exact(sock: socket.socket, n: int) -> bytes:
  152. """确保接收恰好 n 字节。"""
  153. buf = bytearray()
  154. while len(buf) < n:
  155. chunk = sock.recv(n - len(buf))
  156. if not chunk:
  157. raise ConnectionError("Connection closed while reading")
  158. buf.extend(chunk)
  159. return bytes(buf)
  160. def handle_client(worker: InferenceWorker, conn: socket.socket, addr):
  161. """处理单个 TCP 客户端连接。"""
  162. try:
  163. # 读取 4 字节长度前缀
  164. len_data = _recv_exact(conn, 4)
  165. length = struct.unpack(">I", len_data)[0]
  166. # 读取 JSON body
  167. body_data = _recv_exact(conn, length)
  168. request = json.loads(body_data.decode("utf-8"))
  169. print(f"[worker] Request from {addr}: {list(request.keys())}", flush=True)
  170. # 执行推理
  171. response = worker.generate(request)
  172. print(
  173. f"[worker] Response: {response['completion_tokens']} tokens generated",
  174. flush=True,
  175. )
  176. # 发送响应
  177. resp_bytes = json.dumps(response, ensure_ascii=False).encode("utf-8")
  178. conn.sendall(struct.pack(">I", len(resp_bytes)))
  179. conn.sendall(resp_bytes)
  180. except Exception as e:
  181. print(f"[worker] Error handling {addr}: {e}", flush=True)
  182. try:
  183. error_resp = json.dumps({"error": str(e)}).encode("utf-8")
  184. conn.sendall(struct.pack(">I", len(error_resp)))
  185. conn.sendall(error_resp)
  186. except Exception:
  187. pass
  188. finally:
  189. conn.close()
  190. def main():
  191. parser = argparse.ArgumentParser(description="Lightweight Inference Worker")
  192. parser.add_argument("--model-path", type=str, required=True, help="模型目录路径")
  193. parser.add_argument("--port", type=int, required=True, help="监听端口")
  194. parser.add_argument("--host", type=str, default="0.0.0.0", help="监听地址")
  195. args = parser.parse_args()
  196. print(f"[worker] Initializing...", flush=True)
  197. worker = InferenceWorker(args.model_path)
  198. server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  199. server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  200. server.bind((args.host, args.port))
  201. server.listen(2)
  202. print(
  203. f"[worker] Listening on {args.host}:{args.port} (TCP, length-prefixed JSON)",
  204. flush=True,
  205. )
  206. # 通知启动脚本:服务已就绪
  207. print("[worker] READY", flush=True)
  208. def accept_loop():
  209. while True:
  210. try:
  211. conn, addr = server.accept()
  212. t = threading.Thread(target=handle_client, args=(worker, conn, addr))
  213. t.daemon = True
  214. t.start()
  215. except OSError:
  216. break # server closed
  217. except Exception as e:
  218. print(f"[worker] Accept error: {e}", flush=True)
  219. accept_thread = threading.Thread(target=accept_loop, daemon=True)
  220. accept_thread.start()
  221. try:
  222. accept_thread.join()
  223. except KeyboardInterrupt:
  224. print("[worker] Shutting down...", flush=True)
  225. server.close()
  226. if __name__ == "__main__":
  227. main()