| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312 |
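"""Lightweight inference worker -- runs on the compute node (253).

Depends only on the Python standard library + torch + transformers
(fastapi/uvicorn are not required).
Receives JSON requests over TCP and returns JSON responses.

Protocol: 4-byte big-endian length prefix + JSON body (both directions).

Start:
    python inference_worker.py --model-path /path/to/merged/model --port 8100

Request format:
    {
        "prompt": "<|user|>\\n你好\\n<|assistant|>\\n",
        "max_new_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": true,
        "repetition_penalty": 1.0
    }

    Instead of "prompt", an OpenAI-style "messages" list may be sent; it takes
    precedence when non-empty.  "max_tokens" is accepted as an alias of
    "max_new_tokens" and wins when both are present.

Response format:
    {
        "generated_text": "你好!有什么可以帮你的吗?",
        "prompt_tokens": 12,
        "completion_tokens": 15,
        "total_tokens": 27
    }

    On failure the response is {"error": "<message>"}.
"""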
- import argparse
- import json
- import socket
- import struct
- import threading
- import sys
- def _build_prompt_from_messages(tokenizer, messages: list[dict]) -> str:
- """将 OpenAI 消息格式转为模型输入文本。
- 优先使用 tokenizer 自带的 apply_chat_template(Qwen3.5 等模型内建了正确的模板),
- 只有当 tokenizer 没有 chat_template 时才回退到手动拼接。
- """
- if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
- return tokenizer.apply_chat_template(
- messages, tokenize=False, add_generation_prompt=True
- )
- # 回退:手动拼接(兼容没有 chat_template 的模型)
- parts = []
- for msg in messages:
- role = msg.get("role", "")
- content = msg.get("content", "")
- if role == "system":
- parts.append(f"<|system|>\n{content}")
- elif role == "user":
- parts.append(f"<|user|>\n{content}")
- elif role == "assistant":
- parts.append(f"<|assistant|>\n{content}")
- parts.append("<|assistant|>\n")
- return "\n".join(parts)
def _build_stop_criteria(tokenizer, model_device):
    """Build stopping criteria that end generation at a turn boundary.

    Each stop phrase is encoded with ``tokenizer.encode(...,
    add_special_tokens=False)``; the resulting token-id sequences, plus the
    tokenizer's ``eos_token_id`` when it is set, become the stop sequences.
    Generation stops as soon as the tail of the sequence equals any of them.

    Args:
        tokenizer: Tokenizer providing ``encode`` and ``eos_token_id``.
        model_device: Stored on the criteria object for interface
            compatibility; the matching itself does not use it.

    Returns:
        A ``transformers.StoppingCriteriaList`` holding one criterion.
    """
    from transformers import StoppingCriteria, StoppingCriteriaList

    stop_phrases = [
        "<|im_end|>", "<|endoftext|>", "<|eob|>", "<|eol|>",
        "<|user|>", "<|system|>", "<|assistant|>",
    ]

    candidates = [
        tokenizer.encode(phrase, add_special_tokens=False)
        for phrase in stop_phrases
    ]
    if tokenizer.eos_token_id is not None:
        candidates.append([tokenizer.eos_token_id])

    # Drop empty encodings and duplicates (e.g. eos being the same id as one
    # of the phrases) so each distinct sequence is compared only once per step.
    stop_token_ids = []
    for ids in candidates:
        ids = list(ids)
        if ids and ids not in stop_token_ids:
            stop_token_ids.append(ids)

    class StopOnRoleToken(StoppingCriteria):
        """Stop when the newest tokens spell out one of the stop sequences."""

        def __init__(self, stop_sequences, device):
            super().__init__()
            self.stop_sequences = stop_sequences
            self.device = device
            # Only this many trailing tokens can ever take part in a match.
            self.window = max((len(seq) for seq in stop_sequences), default=0)

        def __call__(self, input_ids, scores, **kwargs):
            if not self.window:
                return False
            # Only row 0 is inspected (the worker generates one sequence at a
            # time).  Slicing before .tolist() keeps the per-step cost bounded
            # by the longest stop sequence instead of the full sequence length.
            tail = input_ids[0][-self.window:].tolist()
            # A tail shorter than `seq` yields a shorter slice and cannot match.
            return any(tail[-len(seq):] == seq for seq in self.stop_sequences)

    return StoppingCriteriaList([StopOnRoleToken(stop_token_ids, model_device)])
class InferenceWorker:
    """Own one causal language model plus its tokenizer and answer requests.

    The model is loaded once in ``__init__``; ``generate`` then serves one
    request dict at a time and returns a plain, JSON-serialisable dict.
    """

    # Textual markers at which decoded output is cut as a last resort, in case
    # the token-level stopping criteria did not fire.
    _STOP_MARKERS = (
        "<|eob|>", "<|im_end|>", "<|endoftext|>",
        "<|user|>", "<|system|>", "<|assistant|>",
    )

    def __init__(self, model_path: str):
        """Load tokenizer and model from ``model_path``.

        The model is loaded in float16 and switched to eval mode.  Placement:
        several CUDA devices -> the map from ``_build_device_map``; one CUDA
        device -> everything on device 0; no CUDA -> CPU.
        """
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        print(f"[worker] Loading tokenizer from: {model_path}", flush=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            # No pad token defined: reuse the eos token for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"[worker] Loading model from: {model_path}", flush=True)
        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            if num_gpus > 1:
                device_map = self._build_device_map(model_path, num_gpus)
                print(f"[worker] Multi-GPU device_map ({num_gpus} GPUs): {device_map}", flush=True)
            else:
                device_map = {"": 0}
                print("[worker] Single GPU device_map: cuda:0", flush=True)
        else:
            device_map = "cpu"
            print("[worker] CPU device_map", flush=True)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map=device_map,
        )
        self.model.eval()
        # Keep a handle on the lazily imported module for use in generate().
        self.torch = torch
        print("[worker] Model loaded successfully.", flush=True)

    @staticmethod
    def _spread_layers(num_layers: int, num_gpus: int) -> dict:
        """Assign ``model.layers.<i>`` to GPUs in contiguous, near-equal runs.

        The first ``num_layers % num_gpus`` GPUs receive one extra layer.
        """
        base, extra = divmod(num_layers, num_gpus)
        placement = {}
        layer_idx = 0
        for gpu in range(num_gpus):
            for _ in range(base + (1 if gpu < extra else 0)):
                placement[f"model.layers.{layer_idx}"] = gpu
                layer_idx += 1
        return placement

    @staticmethod
    def _build_device_map(model_path: str, num_gpus: int) -> "dict | str":
        """Build a multi-GPU device map that keeps tied weights together.

        Per the original author's note, ``device_map="auto"`` can place the
        tied ``embed_tokens`` / ``lm_head`` weights on different GPUs, so the
        map is built by hand: decoder layers are spread over all GPUs while
        the embedding, final norm and ``lm_head`` are pinned to GPU 0.

        Returns:
            The module-name -> GPU-index mapping, or the string ``"auto"``
            when the model config has no ``num_hidden_layers``.
        """
        from transformers import AutoConfig

        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        num_layers = getattr(config, "num_hidden_layers", None)
        if num_layers is None:
            return "auto"

        device_map = InferenceWorker._spread_layers(num_layers, num_gpus)

        # Tied weights are forced onto the same card (GPU 0).
        device_map["model.embed_tokens"] = 0
        device_map["model.norm"] = 0
        device_map["lm_head"] = 0
        # Models such as Qwen may expose a separate rotary-embedding module.
        if hasattr(config, "rope_theta") or hasattr(config, "rotary_emb"):
            device_map["model.rotary_emb"] = 0
        return device_map

    @staticmethod
    def _sampling_params(request: dict) -> dict:
        """Translate request fields into keyword arguments for ``generate``.

        ``max_tokens`` wins over ``max_new_tokens`` (default 512).  Unless the
        request sets ``do_sample`` explicitly, sampling is enabled exactly
        when the *requested* temperature is positive, so ``temperature: 0``
        selects greedy decoding.  ``temperature`` (clamped to at least 0.01)
        and ``top_p`` are only included when sampling, because they have no
        effect otherwise.
        """
        requested_temperature = request.get("temperature", 0.7)
        # Decide before clamping; after clamping the value is always > 0.
        do_sample = request.get("do_sample", requested_temperature > 0)

        params = {
            "max_new_tokens": request.get(
                "max_tokens", request.get("max_new_tokens", 512)
            ),
            "do_sample": do_sample,
            "repetition_penalty": request.get("repetition_penalty", 1.1),
        }
        if do_sample:
            params["temperature"] = max(requested_temperature, 0.01)
            params["top_p"] = request.get("top_p", 0.9)
        return params

    @classmethod
    def _truncate_at_stop_marker(cls, text: str) -> str:
        """Cut ``text`` at the earliest stop marker and strip whitespace."""
        hits = [pos for pos in map(text.find, cls._STOP_MARKERS) if pos != -1]
        return text[:min(hits, default=len(text))].strip()

    def generate(self, request: dict) -> dict:
        """Serve a single inference request.

        Args:
            request: Either ``messages`` (OpenAI format, preferred when
                non-empty) or ``prompt`` (raw text), plus the optional
                generation settings read by ``_sampling_params``.

        Returns:
            Dict with ``generated_text``, ``prompt_tokens``,
            ``completion_tokens`` and ``total_tokens``.
        """
        messages = request.get("messages")
        if messages:
            prompt = _build_prompt_from_messages(self.tokenizer, messages)
        else:
            prompt = request.get("prompt", "")

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_tokens = inputs["input_ids"].shape[1]

        # Stop at role markers so the model does not keep writing further turns.
        stopping_criteria = _build_stop_criteria(self.tokenizer, self.model.device)

        with self.torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **self._sampling_params(request),
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                stopping_criteria=stopping_criteria,
            )

        # Decode only the newly generated part of the first (only) sequence.
        decoded = self.tokenizer.decode(
            outputs[0][prompt_tokens:], skip_special_tokens=True
        )
        # Text-level safety net in case the token-level criteria did not fire
        # because of tokenizer encoding differences.
        generated = self._truncate_at_stop_marker(decoded)

        completion_tokens = outputs.shape[1] - prompt_tokens
        return {
            "generated_text": generated,
            "prompt_tokens": int(prompt_tokens),
            "completion_tokens": int(completion_tokens),
            "total_tokens": int(prompt_tokens + completion_tokens),
        }
- def _recv_exact(sock: socket.socket, n: int) -> bytes:
- """确保接收恰好 n 字节。"""
- buf = bytearray()
- while len(buf) < n:
- chunk = sock.recv(n - len(buf))
- if not chunk:
- raise ConnectionError("Connection closed while reading")
- buf.extend(chunk)
- return bytes(buf)
def _send_json(conn: socket.socket, payload: dict) -> None:
    """Send ``payload`` as one frame: 4-byte big-endian length + UTF-8 JSON."""
    body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    conn.sendall(struct.pack(">I", len(body)) + body)


def handle_client(worker: "InferenceWorker", conn: socket.socket, addr,
                  max_request_bytes: int = 64 * 1024 * 1024):
    """Serve one request on an accepted TCP connection, then close it.

    Reads a length-prefixed JSON request, runs ``worker.generate`` on it and
    writes the length-prefixed JSON result back.  Any failure is logged and
    reported to the client as ``{"error": "<message>"}`` on a best-effort basis.

    Args:
        worker: Object whose ``generate(request)`` returns the response dict.
        conn: Connected client socket; always closed before returning.
        addr: Peer address, used only in log lines.
        max_request_bytes: Largest request body accepted.  The length prefix
            comes from the peer, so it is checked before any buffer of that
            size is read.
    """
    try:
        # 4-byte big-endian length prefix, then the JSON body.
        (length,) = struct.unpack(">I", _recv_exact(conn, 4))
        if length > max_request_bytes:
            raise ValueError(
                f"Request body of {length} bytes exceeds limit of "
                f"{max_request_bytes} bytes"
            )
        request = json.loads(_recv_exact(conn, length).decode("utf-8"))
        if not isinstance(request, dict):
            raise ValueError("Request body must be a JSON object")
        print(f"[worker] Request from {addr}: {list(request.keys())}", flush=True)

        response = worker.generate(request)
        print(
            f"[worker] Response: {response['completion_tokens']} tokens generated",
            flush=True,
        )
        _send_json(conn, response)
    except Exception as e:
        # Top-level boundary of the connection handler: report and move on.
        print(f"[worker] Error handling {addr}: {e}", flush=True)
        try:
            _send_json(conn, {"error": str(e)})
        except OSError as send_err:
            # The peer is probably gone already; nothing more can be done.
            print(
                f"[worker] Could not send error to {addr}: {send_err}",
                flush=True,
            )
    finally:
        conn.close()
def main():
    """Parse CLI arguments, load the model and serve requests over TCP.

    After the listening socket is ready, the line ``[worker] READY`` is
    printed so a launcher script can detect readiness.  Each accepted
    connection is handled by ``handle_client`` in its own daemon thread.
    Ctrl-C stops the server; the listening socket is closed on every exit path.
    """
    import time

    parser = argparse.ArgumentParser(description="Lightweight Inference Worker")
    parser.add_argument("--model-path", type=str, required=True, help="模型目录路径")
    parser.add_argument("--port", type=int, required=True, help="监听端口")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="监听地址")
    args = parser.parse_args()

    print("[worker] Initializing...", flush=True)
    worker = InferenceWorker(args.model_path)

    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((args.host, args.port))
        server.listen(2)
        print(
            f"[worker] Listening on {args.host}:{args.port} (TCP, length-prefixed JSON)",
            flush=True,
        )
        # Tell the launcher script that the service is ready.
        print("[worker] READY", flush=True)

        def accept_loop():
            while True:
                try:
                    conn, addr = server.accept()
                except OSError as e:
                    if server.fileno() == -1:
                        break  # listening socket was closed: shutting down
                    # Transient failure (e.g. aborted connection, out of file
                    # descriptors): keep serving instead of ending the loop.
                    print(f"[worker] Accept error: {e}", flush=True)
                    time.sleep(0.1)
                    continue
                try:
                    threading.Thread(
                        target=handle_client,
                        args=(worker, conn, addr),
                        daemon=True,
                    ).start()
                except Exception as e:
                    print(f"[worker] Accept error: {e}", flush=True)
                    # No handler will run for this connection; do not leak it.
                    conn.close()

        accept_thread = threading.Thread(target=accept_loop, daemon=True)
        accept_thread.start()
        try:
            accept_thread.join()
        except KeyboardInterrupt:
            print("[worker] Shutting down...", flush=True)
    finally:
        server.close()


if __name__ == "__main__":
    main()
|