deploy_server_template.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. """OpenAI 兼容的模型推理服务器。
  2. 使用方法:
  3. python server.py --port 8000 --host 0.0.0.0
  4. API 端点:
  5. POST /v1/chat/completions - OpenAI 兼容的聊天补全接口
  6. POST /v1/completions - 文本补全接口
  7. GET /v1/models - 模型列表
  8. GET /health - 健康检查
  9. 调用示例:
  10. curl http://localhost:8000/v1/chat/completions \
  11. -H "Content-Type: application/json" \
  12. -d '{
  13. "model": "local-model",
  14. "messages": [{"role": "user", "content": "你好"}],
  15. "max_tokens": 512,
  16. "temperature": 0.7
  17. }'
  18. """
import argparse
import json
import threading
import time
import uuid
from pathlib import Path

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
  29. app = FastAPI(title="Model Serving API", version="1.0.0")
  30. app.add_middleware(
  31. CORSMiddleware,
  32. allow_origins=["*"],
  33. allow_methods=["*"],
  34. allow_headers=["*"],
  35. )
  36. model = None
  37. tokenizer = None
  38. model_name = "local-model"
  39. # --- Request / Response schemas ---
# A single chat message from an OpenAI-style chat request.
# _build_prompt only renders the roles "system", "user" and "assistant";
# messages with any other role pass validation but are left out of the prompt.
class Message(BaseModel):
    role: str
    content: str  # must be a plain string
  43. class ChatRequest(BaseModel):
  44. model: str = "local-model"
  45. messages: list[Message]
  46. max_tokens: int = 512
  47. temperature: float = 0.7
  48. top_p: float = 0.9
  49. stream: bool = False
  50. class CompletionRequest(BaseModel):
  51. model: str = "local-model"
  52. prompt: str
  53. max_tokens: int = 512
  54. temperature: float = 0.7
  55. top_p: float = 0.9
  56. stream: bool = False
# The generated assistant message carried by a chat completion choice.
class ChoiceMessage(BaseModel):
    role: str = "assistant"
    content: str  # generated text
# One entry of ChatResponse.choices.
class Choice(BaseModel):
    index: int = 0  # position within ``choices``
    message: ChoiceMessage
    # Why generation ended; "stop" unless the caller sets another value.
    finish_reason: str = "stop"
# Token accounting attached to every completion response.
class Usage(BaseModel):
    prompt_tokens: int  # number of tokens in the tokenized prompt
    completion_tokens: int  # number of tokens generated after the prompt
    total_tokens: int  # prompt_tokens + completion_tokens
# Response body of POST /v1/chat/completions (OpenAI "chat.completion" object).
class ChatResponse(BaseModel):
    id: str  # completion id, generated as "chatcmpl-" + 12 hex chars
    object: str = "chat.completion"
    created: int  # creation time as integer Unix seconds
    model: str  # the ``model`` string echoed from the request
    choices: list[Choice]
    usage: Usage
# One entry of CompletionResponse.choices.
class CompletionChoice(BaseModel):
    index: int = 0  # position within ``choices``
    text: str  # generated continuation of the prompt
    # Why generation ended; "stop" unless the caller sets another value.
    finish_reason: str = "stop"
# Response body of POST /v1/completions (OpenAI "text_completion" object).
class CompletionResponse(BaseModel):
    id: str  # completion id, generated as "cmpl-" + 12 hex chars
    object: str = "text_completion"
    created: int  # creation time as integer Unix seconds
    model: str  # the ``model`` string echoed from the request
    choices: list[CompletionChoice]
    usage: Usage
# One entry of the GET /v1/models listing.
class ModelInfo(BaseModel):
    id: str  # name of the served model
    object: str = "model"
    created: int = 0  # fixed placeholder; no creation time is tracked
    owned_by: str = "local"
# Response body of GET /v1/models (OpenAI "list" object).
class ModelList(BaseModel):
    object: str = "list"
    data: list[ModelInfo]
  94. # --- Endpoints ---
  95. @app.post("/v1/chat/completions", response_model=ChatResponse)
  96. async def chat_completions(req: ChatRequest):
  97. prompt = _build_prompt(req.messages)
  98. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  99. prompt_tokens = inputs["input_ids"].shape[1]
  100. with torch.no_grad():
  101. outputs = model.generate(
  102. **inputs,
  103. max_new_tokens=req.max_tokens,
  104. temperature=max(req.temperature, 0.01),
  105. top_p=req.top_p,
  106. do_sample=req.temperature > 0,
  107. pad_token_id=tokenizer.eos_token_id,
  108. )
  109. generated = tokenizer.decode(outputs[0][prompt_tokens:], skip_special_tokens=True)
  110. completion_tokens = outputs.shape[1] - prompt_tokens
  111. return ChatResponse(
  112. id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
  113. created=int(time.time()),
  114. model=req.model,
  115. choices=[Choice(message=ChoiceMessage(content=generated))],
  116. usage=Usage(
  117. prompt_tokens=prompt_tokens,
  118. completion_tokens=completion_tokens,
  119. total_tokens=prompt_tokens + completion_tokens,
  120. ),
  121. )
  122. @app.post("/v1/completions", response_model=CompletionResponse)
  123. async def completions(req: CompletionRequest):
  124. inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
  125. prompt_tokens = inputs["input_ids"].shape[1]
  126. with torch.no_grad():
  127. outputs = model.generate(
  128. **inputs,
  129. max_new_tokens=req.max_tokens,
  130. temperature=max(req.temperature, 0.01),
  131. top_p=req.top_p,
  132. do_sample=req.temperature > 0,
  133. pad_token_id=tokenizer.eos_token_id,
  134. )
  135. generated = tokenizer.decode(outputs[0][prompt_tokens:], skip_special_tokens=True)
  136. completion_tokens = outputs.shape[1] - prompt_tokens
  137. return CompletionResponse(
  138. id=f"cmpl-{uuid.uuid4().hex[:12]}",
  139. created=int(time.time()),
  140. model=req.model,
  141. choices=[CompletionChoice(text=generated)],
  142. usage=Usage(
  143. prompt_tokens=prompt_tokens,
  144. completion_tokens=completion_tokens,
  145. total_tokens=prompt_tokens + completion_tokens,
  146. ),
  147. )
  148. @app.get("/v1/models", response_model=ModelList)
  149. async def list_models():
  150. return ModelList(data=[ModelInfo(id=model_name)])
  151. @app.get("/health")
  152. async def health():
  153. return {"status": "ok", "model": model_name}
  154. def _build_prompt(messages: list[Message]) -> str:
  155. """将 OpenAI 消息格式转为模型输入文本。"""
  156. parts = []
  157. for msg in messages:
  158. if msg.role == "system":
  159. parts.append(f"<|system|>\n{msg.content}")
  160. elif msg.role == "user":
  161. parts.append(f"<|user|>\n{msg.content}")
  162. elif msg.role == "assistant":
  163. parts.append(f"<|assistant|>\n{msg.content}")
  164. parts.append("<|assistant|>\n")
  165. return "\n".join(parts)
  166. if __name__ == "__main__":
  167. parser = argparse.ArgumentParser(description="Model Serving API Server")
  168. parser.add_argument("--model-path", type=str, default="./model", help="模型目录路径")
  169. parser.add_argument("--host", type=str, default="0.0.0.0", help="监听地址")
  170. parser.add_argument("--port", type=int, default=8000, help="监听端口")
  171. parser.add_argument("--device", type=str, default="auto", help="设备 (auto/cuda/cpu)")
  172. args = parser.parse_args()
  173. model_path = Path(args.model_path)
  174. model_name = model_path.name
  175. print(f"Loading model from: {model_path}")
  176. if args.device == "auto":
  177. device_map = {"": 0} if torch.cuda.is_available() else "auto"
  178. elif args.device == "cuda":
  179. device_map = {"": 0}
  180. else:
  181. device_map = "cpu"
  182. tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
  183. if tokenizer.pad_token is None:
  184. tokenizer.pad_token = tokenizer.eos_token
  185. model = AutoModelForCausalLM.from_pretrained(
  186. str(model_path), torch_dtype=torch.float16, device_map=device_map,
  187. )
  188. model.eval()
  189. print(f"Model loaded. Starting server on {args.host}:{args.port}")
  190. import uvicorn
  191. uvicorn.run(app, host=args.host, port=args.port)