"""轻量推理 worker —— 在算力节点(253)上运行。 只依赖 Python 标准库 + torch + transformers(不需要 fastapi/uvicorn)。 通过 TCP 接收 JSON 请求,返回 JSON 响应。 协议:4 字节大端长度前缀 + JSON body 启动: python inference_worker.py --model-path /path/to/merged/model --port 8100 请求格式: { "prompt": "<|user|>\\n你好\\n<|assistant|>\\n", "max_new_tokens": 512, "temperature": 0.7, "top_p": 0.9, "do_sample": true, "repetition_penalty": 1.0 } 响应格式: { "generated_text": "你好!有什么可以帮你的吗?", "prompt_tokens": 12, "completion_tokens": 15, "total_tokens": 27 } """ import argparse import json import socket import struct import threading import sys def _build_prompt_from_messages(tokenizer, messages: list[dict]) -> str: """将 OpenAI 消息格式转为模型输入文本。 优先使用 tokenizer 自带的 apply_chat_template(Qwen3.5 等模型内建了正确的模板), 只有当 tokenizer 没有 chat_template 时才回退到手动拼接。 """ if hasattr(tokenizer, "chat_template") and tokenizer.chat_template: return tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # 回退:手动拼接(兼容没有 chat_template 的模型) parts = [] for msg in messages: role = msg.get("role", "") content = msg.get("content", "") if role == "system": parts.append(f"<|system|>\n{content}") elif role == "user": parts.append(f"<|user|>\n{content}") elif role == "assistant": parts.append(f"<|assistant|>\n{content}") parts.append("<|assistant|>\n") return "\n".join(parts) def _build_stop_criteria(tokenizer, model_device): """构建 StoppingCriteria,遇到角色切换标记或 eos 时停止生成,防止复读。""" from transformers import StoppingCriteria, StoppingCriteriaList # 收集所有 stop 短语 stop_phrases = ["<|im_end|>", "<|endoftext|>", "<|eob|>", "<|eol|>", "<|user|>", "<|system|>", "<|assistant|>"] stop_token_ids = [] for phrase in stop_phrases: ids = tokenizer.encode(phrase, add_special_tokens=False) if ids: stop_token_ids.append(ids) # 也加入 eos_token_id(如果有) if tokenizer.eos_token_id is not None: stop_token_ids.append([tokenizer.eos_token_id]) class StopOnRoleToken(StoppingCriteria): def __init__(self, stop_sequences, device): self.stop_sequences = stop_sequences self.device = device def __call__(self, input_ids, scores, **kwargs): # 检查最近生成的 token 是否匹配任意 stop 序列 gen_seq = input_ids[0].tolist() for stop_ids in self.stop_sequences: if len(gen_seq) >= len(stop_ids): if gen_seq[-len(stop_ids):] == stop_ids: return True return False return StoppingCriteriaList([StopOnRoleToken(stop_token_ids, model_device)]) class InferenceWorker: def __init__(self, model_path: str): import torch from transformers import AutoModelForCausalLM, AutoTokenizer print(f"[worker] Loading tokenizer from: {model_path}", flush=True) self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token print(f"[worker] Loading model from: {model_path}", flush=True) # 单卡加载,与训练 DDP 模式一致:每个进程只用一张 GPU # 避免 device_map="auto" 拆分模型导致 rotary_emb 等共享模块跨卡报错 # CUDA_VISIBLE_DEVICES 由启动脚本设置,cuda:0 就是第一张可见 GPU if torch.cuda.is_available(): device_map = {"": 0} print(f"[worker] Single GPU device_map: cuda:0", flush=True) else: device_map = "cpu" print("[worker] CPU device_map", flush=True) self.model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.float16, device_map=device_map, ) self.model.eval() self.torch = torch print("[worker] Model loaded successfully.", flush=True) def generate(self, request: dict) -> dict: """处理一次推理请求。""" # 支持两种输入:messages(OpenAI 格式)或 prompt(原始文本) messages = request.get("messages") if messages: prompt = _build_prompt_from_messages(self.tokenizer, messages) else: prompt = request.get("prompt", "") max_new_tokens = request.get("max_tokens", request.get("max_new_tokens", 512)) temperature = max(request.get("temperature", 0.7), 0.01) top_p = request.get("top_p", 0.9) do_sample = request.get("do_sample", temperature > 0) repetition_penalty = request.get("repetition_penalty", 1.1) inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) prompt_tokens = inputs["input_ids"].shape[1] # 构建 stop criteria:遇到角色标记就停止,防止复读 stopping_criteria = _build_stop_criteria(self.tokenizer, self.model.device) with self.torch.no_grad(): outputs = self.model.generate( **inputs, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample, repetition_penalty=repetition_penalty, pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id, stopping_criteria=stopping_criteria, ) generated = self.tokenizer.decode( outputs[0][prompt_tokens:], skip_special_tokens=True ) # 文本级兜底截断:在生成文本中找到最早的 stop 标记并截断 # 防止 StoppingCriteria 因 tokenizer 编码差异未能触发 _stop_markers = ["<|eob|>", "<|im_end|>", "<|endoftext|>", "<|user|>", "<|system|>", "<|assistant|>"] earliest = len(generated) for marker in _stop_markers: idx = generated.find(marker) if idx != -1 and idx < earliest: earliest = idx generated = generated[:earliest].strip() completion_tokens = outputs.shape[1] - prompt_tokens return { "generated_text": generated, "prompt_tokens": int(prompt_tokens), "completion_tokens": int(completion_tokens), "total_tokens": int(prompt_tokens + completion_tokens), } def _recv_exact(sock: socket.socket, n: int) -> bytes: """确保接收恰好 n 字节。""" buf = bytearray() while len(buf) < n: chunk = sock.recv(n - len(buf)) if not chunk: raise ConnectionError("Connection closed while reading") buf.extend(chunk) return bytes(buf) def handle_client(worker: InferenceWorker, conn: socket.socket, addr): """处理单个 TCP 客户端连接。""" try: # 读取 4 字节长度前缀 len_data = _recv_exact(conn, 4) length = struct.unpack(">I", len_data)[0] # 读取 JSON body body_data = _recv_exact(conn, length) request = json.loads(body_data.decode("utf-8")) print(f"[worker] Request from {addr}: {list(request.keys())}", flush=True) # 执行推理 response = worker.generate(request) print( f"[worker] Response: {response['completion_tokens']} tokens generated", flush=True, ) # 发送响应 resp_bytes = json.dumps(response, ensure_ascii=False).encode("utf-8") conn.sendall(struct.pack(">I", len(resp_bytes))) conn.sendall(resp_bytes) except Exception as e: print(f"[worker] Error handling {addr}: {e}", flush=True) try: error_resp = json.dumps({"error": str(e)}).encode("utf-8") conn.sendall(struct.pack(">I", len(error_resp))) conn.sendall(error_resp) except Exception: pass finally: conn.close() def main(): parser = argparse.ArgumentParser(description="Lightweight Inference Worker") parser.add_argument("--model-path", type=str, required=True, help="模型目录路径") parser.add_argument("--port", type=int, required=True, help="监听端口") parser.add_argument("--host", type=str, default="0.0.0.0", help="监听地址") args = parser.parse_args() print(f"[worker] Initializing...", flush=True) worker = InferenceWorker(args.model_path) server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) server.bind((args.host, args.port)) server.listen(2) print( f"[worker] Listening on {args.host}:{args.port} (TCP, length-prefixed JSON)", flush=True, ) # 通知启动脚本:服务已就绪 print("[worker] READY", flush=True) def accept_loop(): while True: try: conn, addr = server.accept() t = threading.Thread(target=handle_client, args=(worker, conn, addr)) t.daemon = True t.start() except OSError: break # server closed except Exception as e: print(f"[worker] Accept error: {e}", flush=True) accept_thread = threading.Thread(target=accept_loop, daemon=True) accept_thread.start() try: accept_thread.join() except KeyboardInterrupt: print("[worker] Shutting down...", flush=True) server.close() if __name__ == "__main__": main()