from __future__ import annotations

from datetime import datetime
from typing import List, Optional

from fastapi import APIRouter, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel

from app.db import get_pool
from app.routers.models import _bump_all_domain_versions

router = APIRouter(tags=["api_keys"])


class ApiKeyIn(BaseModel):
    """Request body for creating an API key."""

    name: str
    key_value: str
    note: Optional[str] = None


class ApiKeyUpdate(BaseModel):
    """Request body for a partial update; fields left as None keep their stored value."""

    name: Optional[str] = None
    key_value: Optional[str] = None
    note: Optional[str] = None


class ApiKeyOut(BaseModel):
    """API key as returned to clients."""

    id: int
    name: str
    key_value: str  # masked value; the plaintext key is never returned
    note: Optional[str]
    model_count: int = 0
    created_at: datetime
    updated_at: datetime


def _mask_key(v: str) -> str:
    """Mask a key: keep the first 6 and last 4 characters, replace the middle.

    Values of 10 characters or fewer are replaced entirely by the mask.
    """
    if len(v) <= 10:
        return '••••••••'
    return v[:6] + '••••••••' + v[-4:]


def _row_to_out(row) -> ApiKeyOut:
    """Build an ApiKeyOut from an api_keys row, masking ``key_value``."""
    fields = dict(row)
    fields["key_value"] = _mask_key(fields["key_value"])
    return ApiKeyOut(**fields)


class BatchAssignIn(BaseModel):
    """Bind one api_key to a batch of models."""

    api_key_id: Optional[int] = None  # None clears the binding
    model_ids: List[int]


@router.get("/api-keys", response_model=List[ApiKeyOut])
async def list_api_keys() -> List[ApiKeyOut]:
    """List all API keys, newest first, with masked values and bound-model counts."""
    pool = get_pool()
    rows = await pool.fetch(
        """
        SELECT k.id, k.name, k.key_value, k.note, k.created_at, k.updated_at,
               COUNT(m.id) AS model_count
        FROM api_keys k
        LEFT JOIN models m ON m.api_key_id = k.id
        GROUP BY k.id
        ORDER BY k.created_at DESC
        """
    )
    return [_row_to_out(r) for r in rows]


@router.post("/api-keys", response_model=ApiKeyOut, status_code=201)
async def create_api_key(body: ApiKeyIn) -> ApiKeyOut:
    """Insert a new API key and return it with the key value masked.

    ``model_count`` is not queried here, so the response carries the model
    default of 0.
    """
    pool = get_pool()
    row = await pool.fetchrow(
        """
        INSERT INTO api_keys (name, key_value, note)
        VALUES ($1, $2, $3)
        RETURNING id, name, key_value, note, created_at, updated_at
        """,
        body.name,
        body.key_value,
        body.note,
    )
    # Mask like list_api_keys does; previously the plaintext key was echoed back.
    return _row_to_out(row)


@router.put("/api-keys/{key_id}", response_model=ApiKeyOut)
async def update_api_key(key_id: int, body: ApiKeyUpdate) -> ApiKeyOut:
    """Update an API key; body fields that are None keep their stored value.

    Returns the updated row with the key value masked. ``model_count`` is not
    queried here, so the response carries the model default of 0.

    Raises:
        HTTPException: 404 if no api_keys row has id ``key_id``.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            # Read and lock the row inside the transaction so a concurrent
            # delete/update cannot slip in between the read and the write.
            existing = await conn.fetchrow(
                "SELECT name, key_value, note FROM api_keys WHERE id = $1 FOR UPDATE",
                key_id,
            )
            if existing is None:
                raise HTTPException(status_code=404, detail="API Key 不存在")

            new_name = body.name if body.name is not None else existing["name"]
            new_value = (
                body.key_value if body.key_value is not None else existing["key_value"]
            )
            new_note = body.note if body.note is not None else existing["note"]

            row = await conn.fetchrow(
                """
                UPDATE api_keys
                SET name = $1, key_value = $2, note = $3, updated_at = NOW()
                WHERE id = $4
                RETURNING id, name, key_value, note, created_at, updated_at
                """,
                new_name,
                new_value,
                new_note,
                key_id,
            )

            # The key value itself changed -> notify clients bound to this key.
            if new_value != existing["key_value"]:
                await _bump_all_domain_versions(conn)

    return _row_to_out(row)


@router.delete("/api-keys/{key_id}", status_code=204, response_model=None)
async def delete_api_key(key_id: int) -> Response:
    """Delete an API key.

    Raises:
        HTTPException: 404 if no api_keys row has id ``key_id``.
    """
    pool = get_pool()
    # NOTE(review): unlike update/batch-assign, no domain versions are bumped
    # here even if models were bound to this key - confirm this is intended.
    result = await pool.execute("DELETE FROM api_keys WHERE id = $1", key_id)
    if result == "DELETE 0":
        raise HTTPException(status_code=404, detail="API Key 不存在")
    return Response(status_code=204)


@router.post("/api-keys/batch-assign", status_code=200)
async def batch_assign_api_key(body: BatchAssignIn) -> dict:
    """Bind the given api_key to a batch of models (api_key_id=None clears the binding).

    Returns:
        ``{"updated": n}`` where n is the row count reported by the UPDATE.

    Raises:
        HTTPException: 404 if ``api_key_id`` is given but no such key exists.
    """
    if not body.model_ids:
        return {"updated": 0}

    pool = get_pool()

    if body.api_key_id is not None:
        key_row = await pool.fetchrow(
            "SELECT id FROM api_keys WHERE id = $1", body.api_key_id
        )
        if key_row is None:
            raise HTTPException(status_code=404, detail="API Key 不存在")

    async with pool.acquire() as conn:
        async with conn.transaction():
            result = await conn.execute(
                "UPDATE models SET api_key_id = $1 WHERE id = ANY($2::bigint[])",
                body.api_key_id,
                body.model_ids,
            )
            # The command status looks like "UPDATE <n>"; the last token is the row count.
            updated = int(result.split()[-1]) if result else 0
            if updated > 0:
                await _bump_all_domain_versions(conn)

    return {"updated": updated}