r"""OpenAI-compatible model inference server.

Usage:
    python server.py --model-path ./model --port 8000 --host 0.0.0.0

API endpoints:
    POST /v1/chat/completions - OpenAI-compatible chat completion endpoint
    POST /v1/completions      - text completion endpoint
    GET  /v1/models           - model list
    GET  /health              - health check

Example:
    curl http://localhost:8000/v1/chat/completions \
      -H "Content-Type: application/json" \
      -d '{
        "model": "local-model",
        "messages": [{"role": "user", "content": "你好"}],
        "max_tokens": 512,
        "temperature": 0.7
      }'

Limitations:
    The ``stream`` request field is accepted for client compatibility but is
    ignored: every response is returned as a single JSON body.
"""

import argparse
import json
import threading
import time
import uuid
from pathlib import Path

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI(title="Model Serving API", version="1.0.0")
# NOTE(review): wildcard CORS lets any browser origin call this API; narrow
# allow_origins if the server is reachable from untrusted networks.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Populated by the ``__main__`` block once the model is loaded. They stay None
# if this module is imported without running that block.
model = None
tokenizer = None
model_name = "local-model"

# ``model.generate`` blocks for the whole generation. It runs in a worker
# thread so the event loop keeps serving (e.g. /health), and this lock keeps
# at most one tokenize/generate/decode pass on the shared model at a time.
_generate_lock = threading.Lock()


# --- Request / Response schemas ---


# One chat message in OpenAI format.
class Message(BaseModel):
    role: str
    content: str


# Body of POST /v1/chat/completions.
class ChatRequest(BaseModel):
    model: str = "local-model"
    messages: list[Message]
    max_tokens: int = 512
    temperature: float = 0.7  # <= 0 selects greedy decoding
    top_p: float = 0.9
    stream: bool = False  # accepted for compatibility; streaming is not implemented


# Body of POST /v1/completions.
class CompletionRequest(BaseModel):
    model: str = "local-model"
    prompt: str
    max_tokens: int = 512
    temperature: float = 0.7  # <= 0 selects greedy decoding
    top_p: float = 0.9
    stream: bool = False  # accepted for compatibility; streaming is not implemented


class ChoiceMessage(BaseModel):
    role: str = "assistant"
    content: str


class Choice(BaseModel):
    index: int = 0
    message: ChoiceMessage
    finish_reason: str = "stop"  # "stop" or "length"


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: list[Choice]
    usage: Usage


class CompletionChoice(BaseModel):
    index: int = 0
    text: str
    finish_reason: str = "stop"  # "stop" or "length"


class CompletionResponse(BaseModel):
    id: str
    object: str = "text_completion"
    created: int
    model: str
    choices: list[CompletionChoice]
    usage: Usage


class ModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int = 0
    owned_by: str = "local"


class ModelList(BaseModel):
    object: str = "list"
    data: list[ModelInfo]


# --- Generation helpers ---


def _generate_blocking(
    prompt: str, max_tokens: int, temperature: float, top_p: float
) -> tuple[str, int, int, str]:
    """Generate a continuation of *prompt* with the loaded model.

    This call blocks; invoke it through ``run_in_threadpool``.

    Args:
        prompt: Raw text fed to the tokenizer.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; ``<= 0`` means greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        ``(text, prompt_tokens, completion_tokens, finish_reason)``.
        ``finish_reason`` is ``"length"`` when ``max_tokens`` new tokens were
        produced and the last one is not the tokenizer's EOS token, else
        ``"stop"``.
    """
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling knobs only matter when sampling; passing them to greedy
        # decoding merely triggers transformers warnings.
        gen_kwargs["temperature"] = max(temperature, 0.01)
        gen_kwargs["top_p"] = top_p

    # The whole pass is under the lock: fast tokenizers are not safe to use
    # concurrently from several threads. no_grad is thread-local, so it must
    # be entered here, in the worker thread.
    with _generate_lock, torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_tokens = inputs["input_ids"].shape[1]
        outputs = model.generate(**inputs, **gen_kwargs)
        # Batch size is 1: row 0 is the only sequence; drop the echoed prompt.
        new_tokens = outputs[0][prompt_tokens:]
        text = tokenizer.decode(new_tokens, skip_special_tokens=True)

    completion_tokens = int(new_tokens.shape[0])
    ended_on_eos = (
        completion_tokens > 0 and int(new_tokens[-1]) == tokenizer.eos_token_id
    )
    finish_reason = (
        "length" if completion_tokens >= max_tokens and not ended_on_eos else "stop"
    )
    return text, prompt_tokens, completion_tokens, finish_reason


async def _run_generation(
    prompt: str, max_tokens: int, temperature: float, top_p: float
) -> tuple[str, int, int, str]:
    """Validate request parameters, then run generation off the event loop.

    Raises:
        HTTPException: 503 when no model is loaded, 400 when ``max_tokens`` < 1.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")
    if max_tokens < 1:
        raise HTTPException(status_code=400, detail="max_tokens must be at least 1")
    return await run_in_threadpool(
        _generate_blocking, prompt, max_tokens, temperature, top_p
    )


def _usage(prompt_tokens: int, completion_tokens: int) -> Usage:
    """Build the OpenAI ``usage`` object from token counts."""
    return Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )


# --- Endpoints ---


@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(req: ChatRequest):
    """Answer a chat conversation (OpenAI ``chat.completion`` format)."""
    prompt = _build_prompt(req.messages)
    text, prompt_tokens, completion_tokens, finish_reason = await _run_generation(
        prompt, req.max_tokens, req.temperature, req.top_p
    )
    return ChatResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
        created=int(time.time()),
        model=req.model,
        choices=[
            Choice(message=ChoiceMessage(content=text), finish_reason=finish_reason)
        ],
        usage=_usage(prompt_tokens, completion_tokens),
    )


@app.post("/v1/completions", response_model=CompletionResponse)
async def completions(req: CompletionRequest):
    """Continue a raw text prompt (OpenAI ``text_completion`` format)."""
    text, prompt_tokens, completion_tokens, finish_reason = await _run_generation(
        req.prompt, req.max_tokens, req.temperature, req.top_p
    )
    return CompletionResponse(
        id=f"cmpl-{uuid.uuid4().hex[:12]}",
        created=int(time.time()),
        model=req.model,
        choices=[CompletionChoice(text=text, finish_reason=finish_reason)],
        usage=_usage(prompt_tokens, completion_tokens),
    )


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """List the single model this server hosts."""
    return ModelList(data=[ModelInfo(id=model_name)])
@app.get("/health")
async def health():
    """Liveness probe; also reports the served model's name."""
    return {"status": "ok", "model": model_name}


def _build_prompt(messages: list[Message]) -> str:
    """Convert OpenAI-format chat messages into the model's input text.

    Each kept message is rendered as ``<|role|>\\n<content>``. The pieces are
    joined with newlines, and an open ``<|assistant|>\\n`` turn is appended so
    the model continues as the assistant.
    """
    rendered_roles = ("system", "user", "assistant")
    # NOTE(review): messages with any other role (e.g. "tool") are dropped
    # silently, and this fixed template ignores the tokenizer's own chat
    # template. Confirm it matches the served model's training format.
    parts = [
        f"<|{msg.role}|>\n{msg.content}"
        for msg in messages
        if msg.role in rendered_roles
    ]
    parts.append("<|assistant|>\n")
    return "\n".join(parts)


def _select_device_map_and_dtype(device: str):
    """Translate the ``--device`` flag into ``(device_map, torch_dtype)``.

    float16 is used only on CUDA. On CPU, half precision is slow and many
    kernels historically lack float16 support, so the model is loaded in
    float32 there.
    """
    if device == "cuda" or (device == "auto" and torch.cuda.is_available()):
        return {"": 0}, torch.float16
    if device == "auto":
        # No GPU visible: leave placement to transformers/accelerate.
        return "auto", torch.float32
    return "cpu", torch.float32


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Model Serving API Server")
    parser.add_argument("--model-path", type=str, default="./model", help="模型目录路径")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="监听地址")
    parser.add_argument("--port", type=int, default=8000, help="监听端口")
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cuda", "cpu"],
        help="设备 (auto/cuda/cpu)",
    )
    args = parser.parse_args()
    if args.device == "cuda" and not torch.cuda.is_available():
        parser.error("--device cuda was requested but CUDA is not available")

    model_path = Path(args.model_path)
    # Path(".").name is "", so keep the default id rather than advertise an empty one.
    model_name = model_path.name or model_name
    print(f"Loading model from: {model_path}")

    device_map, torch_dtype = _select_device_map_and_dtype(args.device)

    # NOTE(review): trust_remote_code=True allows custom code shipped with the
    # checkpoint to run while loading the tokenizer. Only point --model-path
    # at trusted checkpoints. (The model below is loaded without this flag.)
    tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        str(model_path),
        torch_dtype=torch_dtype,
        device_map=device_map,
    )
    model.eval()

    print(f"Model loaded. Starting server on {args.host}:{args.port}")
    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port)