api_keys.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. from __future__ import annotations
  2. from datetime import datetime
  3. from typing import List, Optional
  4. from fastapi import APIRouter, HTTPException
  5. from fastapi.responses import Response
  6. from pydantic import BaseModel
  7. from app.db import get_pool
  8. from app.routers.models import _bump_all_domain_versions
# Router for the API-key management endpoints defined in this module.
router = APIRouter(tags=["api_keys"])
class ApiKeyIn(BaseModel):
    """Request body for creating an API key (``POST /api-keys``)."""

    name: str
    key_value: str  # plaintext secret; create_api_key inserts it as given
    note: Optional[str] = None
class ApiKeyUpdate(BaseModel):
    """Request body for partially updating an API key (``PUT /api-keys/{key_id}``).

    Every field is optional; update_api_key keeps the stored value for any
    field left as ``None``.
    """

    name: Optional[str] = None
    key_value: Optional[str] = None
    # NOTE(review): because None means "keep the current value", a client has
    # no way to clear an existing note through this model.
    note: Optional[str] = None
class ApiKeyOut(BaseModel):
    """Response schema for a single API key."""

    id: int
    name: str
    key_value: str  # meant to carry the masked value, never the plaintext secret
    note: Optional[str]
    # Number of models whose api_key_id references this key (list_api_keys
    # computes it with a COUNT); falls back to 0 when the producer omits it.
    model_count: int = 0
    created_at: datetime
    updated_at: datetime
  26. def _mask_key(v: str) -> str:
  27. """脱敏:保留前6位和后4位,中间替换为 ••••••••"""
  28. if len(v) <= 10:
  29. return '••••••••'
  30. return v[:6] + '••••••••' + v[-4:]
class BatchAssignIn(BaseModel):
    """Request body for binding one API key to a batch of models."""

    api_key_id: Optional[int] = None  # None clears the binding on the listed models
    model_ids: List[int]
  35. @router.get("/api-keys", response_model=List[ApiKeyOut])
  36. async def list_api_keys() -> List[ApiKeyOut]:
  37. pool = get_pool()
  38. rows = await pool.fetch(
  39. """
  40. SELECT k.id, k.name, k.key_value, k.note, k.created_at, k.updated_at,
  41. COUNT(m.id) AS model_count
  42. FROM api_keys k
  43. LEFT JOIN models m ON m.api_key_id = k.id
  44. GROUP BY k.id
  45. ORDER BY k.created_at DESC
  46. """
  47. )
  48. return [ApiKeyOut(**{**dict(r), "key_value": _mask_key(r["key_value"])}) for r in rows]
  49. @router.post("/api-keys", response_model=ApiKeyOut, status_code=201)
  50. async def create_api_key(body: ApiKeyIn) -> ApiKeyOut:
  51. pool = get_pool()
  52. row = await pool.fetchrow(
  53. """
  54. INSERT INTO api_keys (name, key_value, note)
  55. VALUES ($1, $2, $3)
  56. RETURNING id, name, key_value, note, created_at, updated_at
  57. """,
  58. body.name, body.key_value, body.note,
  59. )
  60. return ApiKeyOut(**dict(row))
  61. @router.put("/api-keys/{key_id}", response_model=ApiKeyOut)
  62. async def update_api_key(key_id: int, body: ApiKeyUpdate) -> ApiKeyOut:
  63. pool = get_pool()
  64. existing = await pool.fetchrow(
  65. "SELECT id, name, key_value, note, created_at, updated_at FROM api_keys WHERE id = $1",
  66. key_id,
  67. )
  68. if existing is None:
  69. raise HTTPException(status_code=404, detail="API Key 不存在")
  70. new_name = body.name if body.name is not None else existing["name"]
  71. new_value = body.key_value if body.key_value is not None else existing["key_value"]
  72. new_note = body.note if body.note is not None else existing["note"]
  73. key_value_changed = new_value != existing["key_value"]
  74. async with pool.acquire() as conn:
  75. async with conn.transaction():
  76. row = await conn.fetchrow(
  77. """
  78. UPDATE api_keys SET name = $1, key_value = $2, note = $3, updated_at = NOW()
  79. WHERE id = $4
  80. RETURNING id, name, key_value, note, created_at, updated_at
  81. """,
  82. new_name, new_value, new_note, key_id,
  83. )
  84. # key 值本身变了 → 通知所有绑定了此 key 的客户端
  85. if key_value_changed:
  86. await _bump_all_domain_versions(conn)
  87. return ApiKeyOut(**dict(row))
  88. @router.delete("/api-keys/{key_id}", status_code=204, response_model=None)
  89. async def delete_api_key(key_id: int) -> Response:
  90. pool = get_pool()
  91. result = await pool.execute("DELETE FROM api_keys WHERE id = $1", key_id)
  92. if result == "DELETE 0":
  93. raise HTTPException(status_code=404, detail="API Key 不存在")
  94. return Response(status_code=204)
  95. @router.post("/api-keys/batch-assign", status_code=200)
  96. async def batch_assign_api_key(body: BatchAssignIn) -> dict:
  97. """批量将指定 api_key 绑定到一批模型(api_key_id=None 则清除绑定)"""
  98. if not body.model_ids:
  99. return {"updated": 0}
  100. pool = get_pool()
  101. if body.api_key_id is not None:
  102. key_row = await pool.fetchrow("SELECT id FROM api_keys WHERE id = $1", body.api_key_id)
  103. if key_row is None:
  104. raise HTTPException(status_code=404, detail="API Key 不存在")
  105. async with pool.acquire() as conn:
  106. async with conn.transaction():
  107. result = await conn.execute(
  108. "UPDATE models SET api_key_id = $1 WHERE id = ANY($2::bigint[])",
  109. body.api_key_id, body.model_ids,
  110. )
  111. updated = int(result.split()[-1]) if result else 0
  112. if updated > 0:
  113. await _bump_all_domain_versions(conn)
  114. return {"updated": updated}