| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- from __future__ import annotations
- from datetime import datetime
- from typing import List, Optional
- from fastapi import APIRouter, HTTPException
- from fastapi.responses import Response
- from pydantic import BaseModel
- from app.db import get_pool
- from app.routers.models import _bump_all_domain_versions
# Router for the API-key management endpoints; the "api_keys" tag groups them
# in FastAPI's generated OpenAPI docs.
router = APIRouter(tags=["api_keys"])
class ApiKeyIn(BaseModel):
    """Request body for creating an API key."""

    name: str
    key_value: str  # plaintext secret; create_api_key inserts it as-is
    note: Optional[str] = None  # optional free-form description
class ApiKeyUpdate(BaseModel):
    """Request body for partially updating an API key.

    Every field is optional. A field that is omitted or ``None`` keeps its
    current stored value (see update_api_key), which also means ``note``
    cannot be cleared back to NULL through this model.
    """

    name: Optional[str] = None
    key_value: Optional[str] = None  # new plaintext secret, if rotating the key
    note: Optional[str] = None
class ApiKeyOut(BaseModel):
    """Response shape for a single API key."""

    id: int
    name: str
    key_value: str  # returns the masked value; never exposes the plaintext
    note: Optional[str]
    # Number of models bound to this key (list_api_keys computes it with a
    # COUNT over models.api_key_id); stays 0 when a query does not supply it.
    model_count: int = 0
    created_at: datetime
    updated_at: datetime
- def _mask_key(v: str) -> str:
- """脱敏:保留前6位和后4位,中间替换为 ••••••••"""
- if len(v) <= 10:
- return '••••••••'
- return v[:6] + '••••••••' + v[-4:]
class BatchAssignIn(BaseModel):
    """Bind one api_key to a batch of models in bulk."""

    api_key_id: Optional[int] = None  # None means clear the binding
    model_ids: List[int]  # ids of the models to (re)bind
@router.get("/api-keys", response_model=List[ApiKeyOut])
async def list_api_keys() -> List[ApiKeyOut]:
    """Return every API key, newest first, with masked values.

    ``model_count`` is the number of ``models`` rows whose ``api_key_id``
    points at the key; the LEFT JOIN makes unbound keys report 0.
    """
    sql = """
        SELECT k.id, k.name, k.key_value, k.note, k.created_at, k.updated_at,
               COUNT(m.id) AS model_count
        FROM api_keys k
        LEFT JOIN models m ON m.api_key_id = k.id
        GROUP BY k.id
        ORDER BY k.created_at DESC
    """
    records = await get_pool().fetch(sql)
    keys: List[ApiKeyOut] = []
    for record in records:
        fields = dict(record)
        # Only the masked form of the secret ever leaves the server.
        fields["key_value"] = _mask_key(record["key_value"])
        keys.append(ApiKeyOut(**fields))
    return keys
@router.post("/api-keys", response_model=ApiKeyOut, status_code=201)
async def create_api_key(body: ApiKeyIn) -> ApiKeyOut:
    """Create an API key and return it with a masked ``key_value``.

    ``model_count`` is not queried and keeps its default of 0: the row was
    just inserted, so no model can be bound to it yet.
    """
    pool = get_pool()
    row = await pool.fetchrow(
        """
        INSERT INTO api_keys (name, key_value, note)
        VALUES ($1, $2, $3)
        RETURNING id, name, key_value, note, created_at, updated_at
        """,
        body.name, body.key_value, body.note,
    )
    # Never echo the plaintext secret back; mask it exactly like list_api_keys.
    return ApiKeyOut(**{**dict(row), "key_value": _mask_key(row["key_value"])})
@router.put("/api-keys/{key_id}", response_model=ApiKeyOut)
async def update_api_key(key_id: int, body: ApiKeyUpdate) -> ApiKeyOut:
    """Partially update an API key and return it with a masked ``key_value``.

    Fields that are ``None`` in *body* keep their current value. When the
    stored key value actually changes, ``_bump_all_domain_versions`` is
    called inside the same transaction so bound clients are notified.

    Raises:
        HTTPException: 404 if no API key with *key_id* exists.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            # Read and lock the row inside the transaction: the "did the key
            # change?" decision must not rely on a stale snapshot when two
            # updates race, and a concurrent delete must yield a 404 rather
            # than a None row from the UPDATE below.
            existing = await conn.fetchrow(
                "SELECT id, name, key_value, note, created_at, updated_at "
                "FROM api_keys WHERE id = $1 FOR UPDATE",
                key_id,
            )
            if existing is None:
                raise HTTPException(status_code=404, detail="API Key 不存在")

            old_value = existing["key_value"]
            new_value = body.key_value
            # The API only ever hands out masked values. A client that echoes
            # the masked string back must not overwrite the real secret with it.
            if new_value is None or new_value == _mask_key(old_value):
                new_value = old_value
            new_name = body.name if body.name is not None else existing["name"]
            new_note = body.note if body.note is not None else existing["note"]

            row = await conn.fetchrow(
                """
                UPDATE api_keys SET name = $1, key_value = $2, note = $3, updated_at = NOW()
                WHERE id = $4
                RETURNING id, name, key_value, note, created_at, updated_at
                """,
                new_name, new_value, new_note, key_id,
            )
            # Report the real binding count, consistent with list_api_keys.
            model_count = await conn.fetchval(
                "SELECT COUNT(*) FROM models WHERE api_key_id = $1",
                key_id,
            )
            # Key value itself changed -> notify every client bound to this key.
            if new_value != old_value:
                await _bump_all_domain_versions(conn)
    return ApiKeyOut(
        **{
            **dict(row),
            "key_value": _mask_key(row["key_value"]),
            "model_count": model_count,
        }
    )
@router.delete("/api-keys/{key_id}", status_code=204, response_model=None)
async def delete_api_key(key_id: int) -> Response:
    """Delete an API key.

    If any models were bound to the key, ``_bump_all_domain_versions`` is
    called in the same transaction, because those models' bindings change.
    This mirrors batch_assign_api_key, which bumps whenever bindings change.

    Raises:
        HTTPException: 404 if no API key with *key_id* exists.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            # Count bindings before deleting: afterwards they may already have
            # been nulled by the foreign key. NOTE(review): assumes
            # models.api_key_id is ON DELETE SET NULL or has no FK; with a
            # RESTRICT FK the DELETE raises here exactly as it did before.
            bound = await conn.fetchval(
                "SELECT COUNT(*) FROM models WHERE api_key_id = $1",
                key_id,
            )
            result = await conn.execute("DELETE FROM api_keys WHERE id = $1", key_id)
            if result == "DELETE 0":
                raise HTTPException(status_code=404, detail="API Key 不存在")
            if bound:
                await _bump_all_domain_versions(conn)
    return Response(status_code=204)
@router.post("/api-keys/batch-assign", status_code=200)
async def batch_assign_api_key(body: BatchAssignIn) -> dict:
    """Bind one API key to many models at once, or unbind them.

    ``api_key_id=None`` clears the binding on every listed model. The
    response reports how many model rows the UPDATE touched; when that is
    non-zero, ``_bump_all_domain_versions`` runs inside the same transaction.

    Raises:
        HTTPException: 404 if a non-None ``api_key_id`` does not exist.
    """
    model_ids = body.model_ids
    if not model_ids:
        return {"updated": 0}

    pool = get_pool()
    target = body.api_key_id
    if target is not None:
        found = await pool.fetchrow("SELECT id FROM api_keys WHERE id = $1", target)
        if found is None:
            raise HTTPException(status_code=404, detail="API Key 不存在")

    async with pool.acquire() as conn, conn.transaction():
        status = await conn.execute(
            "UPDATE models SET api_key_id = $1 WHERE id = ANY($2::bigint[])",
            target, model_ids,
        )
        # The command status looks like "UPDATE <n>"; take the trailing count.
        count = 0 if not status else int(status.split()[-1])
        if count > 0:
            await _bump_all_domain_versions(conn)
    return {"updated": count}
|