deploy_server_template.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. """OpenAI 兼容的模型推理服务器。
  2. 使用方法:
  3. python server.py --port 8000 --host 0.0.0.0
  4. API 端点:
  5. POST /v1/chat/completions - OpenAI 兼容的聊天补全接口
  6. POST /v1/completions - 文本补全接口
  7. GET /v1/models - 模型列表
  8. GET /health - 健康检查
  9. 调用示例:
  10. curl http://localhost:8000/v1/chat/completions \
  11. -H "Content-Type: application/json" \
  12. -d '{
  13. "model": "local-model",
  14. "messages": [{"role": "user", "content": "你好"}],
  15. "max_tokens": 512,
  16. "temperature": 0.7
  17. }'
  18. """
import argparse
import asyncio
import json
import threading
import time
import uuid
from pathlib import Path

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
  29. app = FastAPI(title="Model Serving API", version="1.0.0")
  30. app.add_middleware(
  31. CORSMiddleware,
  32. allow_origins=["*"],
  33. allow_methods=["*"],
  34. allow_headers=["*"],
  35. )
  36. model = None
  37. tokenizer = None
  38. model_name = "local-model"
  39. # --- Request / Response schemas ---
  40. class Message(BaseModel):
  41. role: str
  42. content: str
  43. class ChatRequest(BaseModel):
  44. model: str = "local-model"
  45. messages: list[Message]
  46. max_tokens: int = 512
  47. temperature: float = 0.7
  48. top_p: float = 0.9
  49. stream: bool = False
  50. class CompletionRequest(BaseModel):
  51. model: str = "local-model"
  52. prompt: str
  53. max_tokens: int = 512
  54. temperature: float = 0.7
  55. top_p: float = 0.9
  56. stream: bool = False
  57. class ChoiceMessage(BaseModel):
  58. role: str = "assistant"
  59. content: str
  60. class Choice(BaseModel):
  61. index: int = 0
  62. message: ChoiceMessage
  63. finish_reason: str = "stop"
  64. class Usage(BaseModel):
  65. prompt_tokens: int
  66. completion_tokens: int
  67. total_tokens: int
  68. class ChatResponse(BaseModel):
  69. id: str
  70. object: str = "chat.completion"
  71. created: int
  72. model: str
  73. choices: list[Choice]
  74. usage: Usage
  75. class CompletionChoice(BaseModel):
  76. index: int = 0
  77. text: str
  78. finish_reason: str = "stop"
  79. class CompletionResponse(BaseModel):
  80. id: str
  81. object: str = "text_completion"
  82. created: int
  83. model: str
  84. choices: list[CompletionChoice]
  85. usage: Usage
  86. class ModelInfo(BaseModel):
  87. id: str
  88. object: str = "model"
  89. created: int = 0
  90. owned_by: str = "local"
  91. class ModelList(BaseModel):
  92. object: str = "list"
  93. data: list[ModelInfo]
  94. # --- Endpoints ---
  95. @app.post("/v1/chat/completions", response_model=ChatResponse)
  96. async def chat_completions(req: ChatRequest):
  97. prompt = _build_prompt(req.messages)
  98. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  99. prompt_tokens = inputs["input_ids"].shape[1]
  100. stopping_criteria = _build_stop_criteria()
  101. with torch.no_grad():
  102. outputs = model.generate(
  103. **inputs,
  104. max_new_tokens=req.max_tokens,
  105. temperature=max(req.temperature, 0.01),
  106. top_p=req.top_p,
  107. do_sample=req.temperature > 0,
  108. repetition_penalty=1.1,
  109. pad_token_id=tokenizer.eos_token_id,
  110. eos_token_id=tokenizer.eos_token_id,
  111. stopping_criteria=stopping_criteria,
  112. )
  113. generated = tokenizer.decode(outputs[0][prompt_tokens:], skip_special_tokens=True)
  114. generated = _clean_generated(generated)
  115. completion_tokens = outputs.shape[1] - prompt_tokens
  116. return ChatResponse(
  117. id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
  118. created=int(time.time()),
  119. model=req.model,
  120. choices=[Choice(message=ChoiceMessage(content=generated))],
  121. usage=Usage(
  122. prompt_tokens=prompt_tokens,
  123. completion_tokens=completion_tokens,
  124. total_tokens=prompt_tokens + completion_tokens,
  125. ),
  126. )
  127. @app.post("/v1/completions", response_model=CompletionResponse)
  128. async def completions(req: CompletionRequest):
  129. inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
  130. prompt_tokens = inputs["input_ids"].shape[1]
  131. stopping_criteria = _build_stop_criteria()
  132. with torch.no_grad():
  133. outputs = model.generate(
  134. **inputs,
  135. max_new_tokens=req.max_tokens,
  136. temperature=max(req.temperature, 0.01),
  137. top_p=req.top_p,
  138. do_sample=req.temperature > 0,
  139. repetition_penalty=1.1,
  140. pad_token_id=tokenizer.eos_token_id,
  141. eos_token_id=tokenizer.eos_token_id,
  142. stopping_criteria=stopping_criteria,
  143. )
  144. generated = tokenizer.decode(outputs[0][prompt_tokens:], skip_special_tokens=True)
  145. generated = _clean_generated(generated)
  146. completion_tokens = outputs.shape[1] - prompt_tokens
  147. return CompletionResponse(
  148. id=f"cmpl-{uuid.uuid4().hex[:12]}",
  149. created=int(time.time()),
  150. model=req.model,
  151. choices=[CompletionChoice(text=generated)],
  152. usage=Usage(
  153. prompt_tokens=prompt_tokens,
  154. completion_tokens=completion_tokens,
  155. total_tokens=prompt_tokens + completion_tokens,
  156. ),
  157. )
  158. @app.get("/v1/models", response_model=ModelList)
  159. async def list_models():
  160. return ModelList(data=[ModelInfo(id=model_name)])
  161. @app.get("/health")
  162. async def health():
  163. return {"status": "ok", "model": model_name}
  164. def _build_prompt(messages: list[Message]) -> str:
  165. """将 OpenAI 消息格式转为模型输入文本。"""
  166. parts = []
  167. for msg in messages:
  168. if msg.role == "system":
  169. parts.append(f"<|system|>\n{msg.content}")
  170. elif msg.role == "user":
  171. parts.append(f"<|user|>\n{msg.content}")
  172. elif msg.role == "assistant":
  173. parts.append(f"<|assistant|>\n{msg.content}")
  174. parts.append("<|assistant|>\n")
  175. return "\n".join(parts)
  176. def _build_stop_criteria():
  177. """构建 StoppingCriteria,遇到角色切换标记时停止生成,防止复读。"""
  178. stop_phrases = ["<|user|>", "<|system|>", "<|assistant|>"]
  179. stop_token_ids = []
  180. for phrase in stop_phrases:
  181. ids = tokenizer.encode(phrase, add_special_tokens=False)
  182. stop_token_ids.append(ids)
  183. class StopOnRoleToken(StoppingCriteria):
  184. def __init__(self, stop_sequences):
  185. self.stop_sequences = stop_sequences
  186. def __call__(self, input_ids, scores, **kwargs):
  187. gen_seq = input_ids[0].tolist()
  188. for stop_ids in self.stop_sequences:
  189. if len(gen_seq) >= len(stop_ids):
  190. if gen_seq[-len(stop_ids):] == stop_ids:
  191. return True
  192. return False
  193. return StoppingCriteriaList([StopOnRoleToken(stop_token_ids)])
  194. def _clean_generated(generated: str) -> str:
  195. """清理可能残留的角色标记。"""
  196. for marker in ["<|user|>", "<|system|>", "<|assistant|>"]:
  197. if marker in generated:
  198. generated = generated[:generated.index(marker)]
  199. return generated.strip()
  200. if __name__ == "__main__":
  201. parser = argparse.ArgumentParser(description="Model Serving API Server")
  202. parser.add_argument("--model-path", type=str, default="./model", help="模型目录路径")
  203. parser.add_argument("--host", type=str, default="0.0.0.0", help="监听地址")
  204. parser.add_argument("--port", type=int, default=8000, help="监听端口")
  205. parser.add_argument("--device", type=str, default="auto", help="设备 (auto/cuda/cpu)")
  206. args = parser.parse_args()
  207. model_path = Path(args.model_path)
  208. model_name = model_path.name
  209. print(f"Loading model from: {model_path}")
  210. if args.device == "auto":
  211. # device_map="auto" 自动将模型层分散到所有可见 GPU
  212. device_map = "auto" if torch.cuda.is_available() else "cpu"
  213. elif args.device == "cuda":
  214. device_map = "auto" # 自动多卡分散
  215. else:
  216. device_map = "cpu"
  217. tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
  218. if tokenizer.pad_token is None:
  219. tokenizer.pad_token = tokenizer.eos_token
  220. model = AutoModelForCausalLM.from_pretrained(
  221. str(model_path), torch_dtype=torch.float16, device_map=device_map,
  222. )
  223. model.eval()
  224. print(f"Model loaded. Starting server on {args.host}:{args.port}")
  225. import uvicorn
  226. uvicorn.run(app, host=args.host, port=args.port)