"""OpenAI-compatible model inference server.

Usage:
    python server.py --port 8000 --host 0.0.0.0

API endpoints:
    POST /v1/chat/completions  - OpenAI-compatible chat completion endpoint
    POST /v1/completions       - text completion endpoint
    GET  /v1/models            - model list
    GET  /health               - health check

Example call:
    curl http://localhost:8000/v1/chat/completions \
        -H "Content-Type: application/json" \
        -d '{
            "model": "local-model",
            "messages": [{"role": "user", "content": "你好"}],
            "max_tokens": 512,
            "temperature": 0.7
        }'
"""
import argparse
import json
import time
import uuid
from pathlib import Path

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Application setup ---
app = FastAPI(title="Model Serving API", version="1.0.0")

# CORS is fully permissive: every origin, method and header is allowed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Serving state. These placeholders are overwritten by the ``__main__`` block
# once the tokenizer and model have been loaded from disk.
model, tokenizer = None, None
model_name = "local-model"
# --- Request / Response schemas ---
class Message(BaseModel):
    """One chat turn: the role of the speaker and the text content."""

    role: str
    content: str
class ChatRequest(BaseModel):
    """Request body for ``POST /v1/chat/completions``.

    ``max_tokens`` and ``top_p`` are range-checked at validation time so that
    out-of-range values are rejected as a request error instead of failing
    later inside ``model.generate``.
    """

    model: str = "local-model"
    messages: list[Message]
    # Forwarded to ``generate`` as ``max_new_tokens``; must be at least 1.
    max_tokens: int = Field(512, ge=1)
    # Values <= 0 select greedy decoding in the chat endpoint (do_sample is
    # only enabled for temperature > 0), so no lower bound is enforced here.
    temperature: float = 0.7
    # Nucleus-sampling probability mass; only meaningful within [0, 1].
    top_p: float = Field(0.9, ge=0.0, le=1.0)
    # Accepted for OpenAI client compatibility; the handlers visible in this
    # file do not read it, so responses are never streamed.
    stream: bool = False
class CompletionRequest(BaseModel):
    """Request body for ``POST /v1/completions``.

    ``max_tokens`` and ``top_p`` are range-checked at validation time so that
    out-of-range values are rejected as a request error instead of failing
    later inside ``model.generate``.
    """

    model: str = "local-model"
    # Raw text that is tokenized and continued as-is (no chat template).
    prompt: str
    # Forwarded to ``generate`` as ``max_new_tokens``; must be at least 1.
    max_tokens: int = Field(512, ge=1)
    # Values <= 0 select greedy decoding in the completion endpoint (do_sample
    # is only enabled for temperature > 0), so no lower bound is enforced here.
    temperature: float = 0.7
    # Nucleus-sampling probability mass; only meaningful within [0, 1].
    top_p: float = Field(0.9, ge=0.0, le=1.0)
    # Accepted for OpenAI client compatibility; the handlers visible in this
    # file do not read it, so responses are never streamed.
    stream: bool = False
class ChoiceMessage(BaseModel):
    """Message payload of a chat choice; ``role`` defaults to ``"assistant"``."""

    role: str = "assistant"
    content: str
class Choice(BaseModel):
    """A single chat completion candidate."""

    # Position of this candidate in the ``choices`` list.
    index: int = 0
    message: ChoiceMessage
    finish_reason: str = "stop"
class Usage(BaseModel):
    """Token accounting for one request: prompt, completion and their sum."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
class ChatResponse(BaseModel):
    """Response body returned by ``POST /v1/chat/completions``."""

    id: str
    # Fixed object tag for chat completion responses.
    object: str = "chat.completion"
    # Creation time as integer seconds (the chat endpoint fills it from time.time()).
    created: int
    model: str
    choices: list[Choice]
    usage: Usage
class CompletionChoice(BaseModel):
    """A single text completion candidate."""

    # Position of this candidate in the ``choices`` list.
    index: int = 0
    text: str
    finish_reason: str = "stop"
class CompletionResponse(BaseModel):
    """Response body returned by ``POST /v1/completions``."""

    id: str
    # Fixed object tag for text completion responses.
    object: str = "text_completion"
    # Creation time as integer seconds (the completion endpoint fills it from time.time()).
    created: int
    model: str
    choices: list[CompletionChoice]
    usage: Usage
class ModelInfo(BaseModel):
    """Description of one served model as listed by ``GET /v1/models``."""

    id: str
    object: str = "model"
    created: int = 0
    owned_by: str = "local"
class ModelList(BaseModel):
    """Envelope for the model listing: an ``object`` tag plus the model entries."""

    object: str = "list"
    data: list[ModelInfo]
# --- Endpoints ---
def _ensure_model_loaded() -> None:
    """Raise HTTP 503 if the module-level model or tokenizer is still ``None``.

    Both globals are only assigned by the ``__main__`` block, so they remain
    ``None`` when this module is imported instead of being run as a script.

    Raises:
        HTTPException: with status 503 when either global is missing.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")


def _stop_token_ids() -> set[int]:
    """Collect the token ids that count as a natural end of generation.

    Combines the tokenizer's ``eos_token_id`` with the ``eos_token_id`` of the
    model's generation config (an int or a list of ints), when present.
    """
    generation_config = getattr(model, "generation_config", None)
    candidates = (
        tokenizer.eos_token_id,
        getattr(generation_config, "eos_token_id", None),
    )
    stop_ids: set[int] = set()
    for candidate in candidates:
        if candidate is None:
            continue
        if isinstance(candidate, int):
            stop_ids.add(candidate)
        else:
            stop_ids.update(candidate)
    return stop_ids


def _run_generation(
    prompt: str, max_tokens: int, temperature: float, top_p: float
) -> "tuple[str, Usage, str]":
    """Generate a continuation of ``prompt`` with the loaded model.

    Shared by the chat and text completion endpoints so both apply identical
    decoding rules and token accounting.

    Args:
        prompt: Fully rendered input text.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        ``(text, usage, finish_reason)`` where ``finish_reason`` is
        ``"length"`` if the token budget was exhausted without ending on a
        stop token, otherwise ``"stop"``.

    Raises:
        HTTPException: with status 503 when the model is not loaded.
    """
    _ensure_model_loaded()

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_tokens = inputs["input_ids"].shape[1]

    sampling = temperature > 0
    generate_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": sampling,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if sampling:
        # Sampling-only knobs are passed only when sampling is on; greedy
        # decoding ignores them. Tiny temperatures are clamped to 0.01.
        generate_kwargs["temperature"] = max(temperature, 0.01)
        generate_kwargs["top_p"] = top_p

    # NOTE(review): generate() is a synchronous call inside an async handler,
    # so the event loop is blocked for the whole generation.
    with torch.no_grad():
        outputs = model.generate(**inputs, **generate_kwargs)

    # The output sequence starts with the prompt; keep only the new tokens.
    new_tokens = outputs[0][prompt_tokens:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    completion_tokens = outputs.shape[1] - prompt_tokens

    budget_exhausted = completion_tokens >= max_tokens
    ended_on_stop = completion_tokens > 0 and int(new_tokens[-1]) in _stop_token_ids()
    finish_reason = "length" if budget_exhausted and not ended_on_stop else "stop"

    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    return text, usage, finish_reason


@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(req: ChatRequest):
    """Handle an OpenAI-style chat completion request.

    The messages are rendered into a single prompt with ``_build_prompt`` and
    one completion is generated. ``req.stream`` is not consulted; the reply is
    always a single JSON body.

    Raises:
        HTTPException: with status 503 when the model is not loaded.
    """
    text, usage, finish_reason = _run_generation(
        _build_prompt(req.messages), req.max_tokens, req.temperature, req.top_p
    )
    return ChatResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
        created=int(time.time()),
        model=req.model,
        choices=[
            Choice(
                message=ChoiceMessage(content=text),
                finish_reason=finish_reason,
            )
        ],
        usage=usage,
    )


@app.post("/v1/completions", response_model=CompletionResponse)
async def completions(req: CompletionRequest):
    """Handle an OpenAI-style text completion request.

    ``req.prompt`` is fed to the model unchanged and one completion is
    generated. ``req.stream`` is not consulted; the reply is always a single
    JSON body.

    Raises:
        HTTPException: with status 503 when the model is not loaded.
    """
    text, usage, finish_reason = _run_generation(
        req.prompt, req.max_tokens, req.temperature, req.top_p
    )
    return CompletionResponse(
        id=f"cmpl-{uuid.uuid4().hex[:12]}",
        created=int(time.time()),
        model=req.model,
        choices=[CompletionChoice(text=text, finish_reason=finish_reason)],
        usage=usage,
    )
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Return a one-entry model list whose id is the module-level ``model_name``."""
    served = ModelInfo(id=model_name)
    return ModelList(data=[served])
@app.get("/health")
async def health():
    """Report a constant ``"ok"`` status together with the served model name."""
    payload = {"status": "ok", "model": model_name}
    return payload
- def _build_prompt(messages: list[Message]) -> str:
- """将 OpenAI 消息格式转为模型输入文本。"""
- parts = []
- for msg in messages:
- if msg.role == "system":
- parts.append(f"<|system|>\n{msg.content}")
- elif msg.role == "user":
- parts.append(f"<|user|>\n{msg.content}")
- elif msg.role == "assistant":
- parts.append(f"<|assistant|>\n{msg.content}")
- parts.append("<|assistant|>\n")
- return "\n".join(parts)
# Script entry point: parse CLI options, load tokenizer and model into the
# module-level globals used by the endpoints, then start the HTTP server.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Model Serving API Server")
    parser.add_argument("--model-path", type=str, default="./model", help="模型目录路径")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="监听地址")
    parser.add_argument("--port", type=int, default=8000, help="监听端口")
    parser.add_argument("--device", type=str, default="auto", help="设备 (auto/cuda/cpu)")
    # Weight precision; "auto" picks float16 on CUDA and float32 otherwise.
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="权重精度 (auto/float16/bfloat16/float32)",
    )
    args = parser.parse_args()

    model_path = Path(args.model_path)
    # The directory name doubles as the model id reported by the API.
    model_name = model_path.name
    print(f"Loading model from: {model_path}")

    # Resolve the placement. Any --device value other than "auto" or "cuda"
    # falls back to CPU, as before.
    if args.device == "auto":
        on_cuda = torch.cuda.is_available()
        device_map = {"": 0} if on_cuda else "auto"
    elif args.device == "cuda":
        on_cuda = True
        device_map = {"": 0}
    else:
        on_cuda = False
        device_map = "cpu"

    # Half precision is only the default on CUDA; off-GPU the weights are
    # loaded in float32 because float16 CPU kernels are slow or unavailable.
    if args.dtype == "auto":
        torch_dtype = torch.float16 if on_cuda else torch.float32
    else:
        dtype_by_name = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        torch_dtype = dtype_by_name[args.dtype]

    # NOTE(review): trust_remote_code=True allows code shipped with the
    # tokenizer to run; the model load below does not pass this flag.
    tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS as padding when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        str(model_path),
        torch_dtype=torch_dtype,
        device_map=device_map,
    )
    model.eval()
    print(f"Model loaded. Starting server on {args.host}:{args.port}")

    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port)
|