deployment.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. """部署 API —— 导出、在线服务管理、OpenAI 兼容代理。
  2. 路由前缀:
  3. /api/v1/deployment - 管理端点(JWT 认证)
  4. /api/v1/deployment/proxy - 代理端点(API Key 认证)
  5. 代理路由:
  6. POST /proxy/{task_id}/v1/chat/completions - OpenAI 兼容聊天补全
  7. POST /proxy/{task_id}/v1/completions - OpenAI 兼容文本补全
  8. GET /proxy/{task_id}/v1/models - 模型列表
  9. GET /proxy/{task_id}/health - 健康检查
  10. """
  11. import time
  12. import uuid
  13. from fastapi import APIRouter, Depends, HTTPException, Request
  14. from fastapi.responses import JSONResponse
  15. from app.core.auth import get_current_user
  16. from app.schemas.deployment import (
  17. DeployConfig, DeployResponse, DeployServeConfig, DeployedServiceInfo,
  18. )
  19. from app.services import api_key_service, deploy_service
# Management endpoints (require a JWT login).
router = APIRouter()
# Proxy endpoints (require an API key, not a JWT).
proxy_router = APIRouter()
  24. # ---------------------------------------------------------------------------
  25. # API Key 验证(代理端点专用)
  26. # ---------------------------------------------------------------------------
  27. async def _extract_api_key(request: Request) -> str | None:
  28. """从 Authorization: Bearer sk-xxx 提取 API Key。"""
  29. auth_header = request.headers.get("Authorization", "")
  30. if auth_header.startswith("Bearer "):
  31. token = auth_header[7:].strip()
  32. if token.startswith("sk-"):
  33. return token
  34. return None
  35. async def _validate_proxy_auth(task_id: str, request: Request) -> None:
  36. """验证代理请求的 API Key,并检查用户是否拥有该部署任务。"""
  37. api_key = await _extract_api_key(request)
  38. if not api_key:
  39. raise HTTPException(
  40. status_code=401,
  41. detail={"error": {"message": "Missing API key. Use Authorization: Bearer sk-xxx", "type": "auth_error"}},
  42. )
  43. key_info = await api_key_service.validate_api_key(api_key)
  44. if not key_info:
  45. raise HTTPException(
  46. status_code=401,
  47. detail={"error": {"message": "Invalid or revoked API key", "type": "auth_error"}},
  48. )
  49. # 检查用户是否拥有该部署任务
  50. if not await api_key_service.check_deploy_ownership(task_id, key_info["user_id"]):
  51. raise HTTPException(
  52. status_code=403,
  53. detail={"error": {"message": "Access denied: you do not own this deployment", "type": "permission_error"}},
  54. )
  55. # ---------------------------------------------------------------------------
  56. # 管理端点(JWT 认证)
  57. # ---------------------------------------------------------------------------
  58. @router.post("/export", response_model=DeployResponse)
  59. async def export_adapter(
  60. config: DeployConfig,
  61. current_user: dict = Depends(get_current_user),
  62. ):
  63. """启动导出后台任务(导出模型文件),立即返回 task_id。"""
  64. user_id = current_user.get("sub")
  65. result = await deploy_service.export_adapter(
  66. config.job_id,
  67. {"merge_with_base": config.merge_with_base, "export_format": config.export_format},
  68. user_id=user_id,
  69. )
  70. return DeployResponse(**result)
  71. @router.post("/serve", response_model=DeployResponse)
  72. async def serve_model(
  73. config: DeployServeConfig,
  74. current_user: dict = Depends(get_current_user),
  75. ):
  76. """部署为在线推理服务(OpenAI 兼容 API)。
  77. 151 提供代理 API,253 运行纯推理 worker。
  78. 启动后通过 base_url 调用 /v1/chat/completions 等接口。
  79. """
  80. user_id = current_user.get("sub")
  81. try:
  82. result = await deploy_service.start_serving(
  83. config.job_id,
  84. {"merge_with_base": config.merge_with_base, "port": config.port, "host": config.host},
  85. user_id=user_id,
  86. )
  87. return DeployResponse(**result)
  88. except RuntimeError as e:
  89. raise HTTPException(status_code=400, detail=str(e))
  90. @router.get("/services", response_model=list[DeployedServiceInfo])
  91. async def list_services(current_user: dict = Depends(get_current_user)):
  92. """列出当前用户已部署的在线服务。"""
  93. user_id = current_user.get("sub")
  94. services = await deploy_service.list_deployed_services(user_id)
  95. return [DeployedServiceInfo(**s) for s in services]
  96. @router.post("/{task_id}/stop")
  97. async def stop_serving(
  98. task_id: str,
  99. current_user: dict = Depends(get_current_user),
  100. ):
  101. """停止已部署的在线服务。"""
  102. user_id = current_user.get("sub")
  103. result = await deploy_service.stop_serving(task_id, user_id)
  104. if "error" in result:
  105. raise HTTPException(status_code=400, detail=result["error"])
  106. return result
  107. @router.get("/{deploy_id}/status", response_model=DeployResponse)
  108. async def get_deployment_status(deploy_id: str):
  109. """获取导出/部署任务状态。"""
  110. result = await deploy_service.get_deploy_status(deploy_id)
  111. return DeployResponse(**result)
  112. # ---------------------------------------------------------------------------
  113. # OpenAI 兼容代理端点(API Key 认证)
  114. # ---------------------------------------------------------------------------
  115. @proxy_router.post("/proxy/{task_id}/v1/chat/completions")
  116. async def proxy_chat_completions(task_id: str, request: Request):
  117. """OpenAI 兼容的聊天补全代理。"""
  118. await _validate_proxy_auth(task_id, request)
  119. try:
  120. body = await request.json()
  121. except Exception:
  122. raise HTTPException(status_code=400, detail="Invalid JSON")
  123. messages = body.get("messages", [])
  124. if not messages:
  125. raise HTTPException(status_code=400, detail="messages is required")
  126. worker_req = {
  127. "messages": messages,
  128. "max_tokens": body.get("max_tokens", 512),
  129. "temperature": body.get("temperature", 0.7),
  130. "top_p": body.get("top_p", 0.9),
  131. "do_sample": body.get("temperature", 0.7) > 0,
  132. "repetition_penalty": body.get("repetition_penalty", 1.0),
  133. }
  134. worker_resp = await deploy_service.proxy_to_worker(task_id, worker_req)
  135. if "error" in worker_resp:
  136. return JSONResponse(
  137. status_code=502,
  138. content={"error": {"message": worker_resp["error"], "type": "upstream_error"}},
  139. )
  140. model = body.get("model", "local-model")
  141. return {
  142. "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
  143. "object": "chat.completion",
  144. "created": int(time.time()),
  145. "model": model,
  146. "choices": [{
  147. "index": 0,
  148. "message": {
  149. "role": "assistant",
  150. "content": worker_resp.get("generated_text", ""),
  151. },
  152. "finish_reason": "stop",
  153. }],
  154. "usage": {
  155. "prompt_tokens": worker_resp.get("prompt_tokens", 0),
  156. "completion_tokens": worker_resp.get("completion_tokens", 0),
  157. "total_tokens": worker_resp.get("total_tokens", 0),
  158. },
  159. }
  160. @proxy_router.post("/proxy/{task_id}/v1/completions")
  161. async def proxy_completions(task_id: str, request: Request):
  162. """OpenAI 兼容的文本补全代理。"""
  163. await _validate_proxy_auth(task_id, request)
  164. try:
  165. body = await request.json()
  166. except Exception:
  167. raise HTTPException(status_code=400, detail="Invalid JSON")
  168. prompt = body.get("prompt", "")
  169. if not prompt:
  170. raise HTTPException(status_code=400, detail="prompt is required")
  171. worker_req = {
  172. "prompt": prompt,
  173. "max_tokens": body.get("max_tokens", 512),
  174. "temperature": body.get("temperature", 0.7),
  175. "top_p": body.get("top_p", 0.9),
  176. "do_sample": body.get("temperature", 0.7) > 0,
  177. "repetition_penalty": body.get("repetition_penalty", 1.0),
  178. }
  179. worker_resp = await deploy_service.proxy_to_worker(task_id, worker_req)
  180. if "error" in worker_resp:
  181. return JSONResponse(
  182. status_code=502,
  183. content={"error": {"message": worker_resp["error"], "type": "upstream_error"}},
  184. )
  185. model = body.get("model", "local-model")
  186. return {
  187. "id": f"cmpl-{uuid.uuid4().hex[:12]}",
  188. "object": "text_completion",
  189. "created": int(time.time()),
  190. "model": model,
  191. "choices": [{
  192. "index": 0,
  193. "text": worker_resp.get("generated_text", ""),
  194. "finish_reason": "stop",
  195. }],
  196. "usage": {
  197. "prompt_tokens": worker_resp.get("prompt_tokens", 0),
  198. "completion_tokens": worker_resp.get("completion_tokens", 0),
  199. "total_tokens": worker_resp.get("total_tokens", 0),
  200. },
  201. }
  202. @proxy_router.get("/proxy/{task_id}/v1/models")
  203. async def proxy_models(task_id: str, request: Request):
  204. """返回模型列表(代理)。"""
  205. await _validate_proxy_auth(task_id, request)
  206. result = await deploy_service.get_deploy_status(task_id)
  207. model_name = f"finetuned-{result.get('job_id', '')[:8]}" if result.get("job_id") else "local-model"
  208. return {
  209. "object": "list",
  210. "data": [{
  211. "id": model_name,
  212. "object": "model",
  213. "created": int(time.time()),
  214. "owned_by": "local",
  215. }],
  216. }
  217. @proxy_router.get("/proxy/{task_id}/health")
  218. async def proxy_health(task_id: str, request: Request):
  219. """健康检查。"""
  220. await _validate_proxy_auth(task_id, request)
  221. result = await deploy_service.get_deploy_status(task_id)
  222. if result.get("status") != "running":
  223. return {"status": "error", "message": f"服务状态: {result.get('status', 'unknown')}"}
  224. return {"status": "ok", "task_id": task_id}