| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194 |
- from __future__ import annotations
- from datetime import datetime
- from typing import List, Optional
- from fastapi import APIRouter, HTTPException
- from fastapi.responses import Response
- from pydantic import BaseModel
- from app.db import get_pool
# Router holding the /models endpoints; the "models" tag groups them in the generated OpenAPI docs.
router = APIRouter(tags=["models"])
- async def _bump_all_domain_versions(conn) -> None:
- """api_key 有变动时,所有域名的版本号 +1,让客户端感知到数据变化。"""
- await conn.execute(
- "UPDATE domain_version SET version = version + 1, updated_at = NOW()"
- )
- await conn.execute(
- "UPDATE price_snapshot_version SET version = GREATEST(version + 1, 1), updated_at = NOW() WHERE id = 1"
- )
class ModelIn(BaseModel):
    """Request body for creating a model (POST /models)."""

    name: str
    url: str
    # Optional reference to api_keys.id; None stores NULL.
    api_key_id: Optional[int] = None
    # Optional reference to model_groups.id; None stores NULL.
    group_id: Optional[int] = None
class ModelUpdate(BaseModel):
    """Partial-update body for PUT /models/{model_id}.

    Every field is optional. A field that is omitted or sent as null keeps
    the model's current value, so this body cannot clear api_key_id or
    group_id back to NULL.
    """

    name: Optional[str] = None
    url: Optional[str] = None
    api_key_id: Optional[int] = None
    group_id: Optional[int] = None
class ModelOut(BaseModel):
    """A model as returned by the API, with related display names resolved."""

    id: int
    name: str
    url: str
    # Declared without a default, unlike the other optional fields below.
    api_key_id: Optional[int]
    # Name of the referenced api_keys row; None when there is no key or the row was not found.
    api_key_name: Optional[str] = None
    group_id: Optional[int] = None
    # Name of the referenced model_groups row; None when there is no group or the row was not found.
    group_name: Optional[str] = None
    created_at: datetime
@router.get("/models", response_model=List[ModelOut])
async def list_models() -> List[ModelOut]:
    """List all models, newest first.

    The API key and group names are pulled in with LEFT JOINs. A model with
    no key or group (or one pointing at a missing row) is still listed, with
    that name set to None.
    """
    sql = """
        SELECT m.id, m.name, m.url, m.api_key_id, m.group_id, m.created_at,
               k.name AS api_key_name, g.name AS group_name
        FROM models m
        LEFT JOIN api_keys k ON k.id = m.api_key_id
        LEFT JOIN model_groups g ON g.id = m.group_id
        ORDER BY m.created_at DESC
    """
    async with get_pool().acquire() as conn:
        records = await conn.fetch(sql)
    return [ModelOut(**dict(record)) for record in records]
@router.post("/models", response_model=ModelOut, status_code=201)
async def create_model(body: ModelIn) -> ModelOut:
    """Create a model and return it with its API key / group names resolved.

    Raises:
        HTTPException(409): the INSERT violated a unique constraint
            (SQLSTATE 23505); reported as a duplicate URL.
        HTTPException(400): api_key_id or group_id points at a row that does
            not exist (SQLSTATE 23503, foreign-key violation).

    Any other database error propagates unchanged instead of being
    misreported as a duplicate URL.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        # Only the INSERT can hit a constraint; keep the try narrow so that a
        # failure in the name lookups below is not reported as a 409.
        try:
            row = await conn.fetchrow(
                """
                INSERT INTO models (name, url, api_key_id, group_id) VALUES ($1, $2, $3, $4)
                RETURNING id, name, url, api_key_id, group_id, created_at
                """,
                body.name, body.url, body.api_key_id, body.group_id,
            )
        except Exception as exc:
            # The pool looks like asyncpg, whose errors carry the PostgreSQL
            # error code in ``sqlstate``. Reading it via getattr avoids
            # importing the driver here.
            sqlstate = getattr(exc, "sqlstate", None)
            if sqlstate == "23505":  # unique_violation
                raise HTTPException(status_code=409, detail="该 URL 已存在") from exc
            if sqlstate == "23503":  # foreign_key_violation
                raise HTTPException(status_code=400, detail="关联的 API Key 或分组不存在") from exc
            raise

        # Resolve display names; a missing referenced row yields None.
        api_key_name = None
        if row["api_key_id"]:
            k = await conn.fetchrow("SELECT name FROM api_keys WHERE id = $1", row["api_key_id"])
            api_key_name = k["name"] if k else None
        group_name = None
        if row["group_id"]:
            g = await conn.fetchrow("SELECT name FROM model_groups WHERE id = $1", row["group_id"])
            group_name = g["name"] if g else None
    return ModelOut(**dict(row), api_key_name=api_key_name, group_name=group_name)
@router.put("/models/{model_id}", response_model=ModelOut)
async def update_model(model_id: int, body: ModelUpdate) -> ModelOut:
    """Partially update a model; fields left as None keep their current value.

    If api_key_id actually changes, all domain versions are bumped in the
    same transaction as the UPDATE.

    Raises:
        HTTPException(404): no model with ``model_id``. This is checked up
            front, and again if the row disappears before the UPDATE runs.
        HTTPException(409): the UPDATE violated a unique constraint
            (SQLSTATE 23505); reported as a duplicate URL.
        HTTPException(400): api_key_id or group_id points at a missing row
            (SQLSTATE 23503).

    Any other database error propagates unchanged.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        existing = await conn.fetchrow(
            "SELECT id, name, url, api_key_id, group_id, created_at FROM models WHERE id = $1",
            model_id,
        )
        if existing is None:
            raise HTTPException(status_code=404, detail="模型不存在")

        # None means "not provided", so these fields can never be cleared to NULL here.
        new_name = body.name if body.name is not None else existing["name"]
        new_url = body.url if body.url is not None else existing["url"]
        new_api_key_id = body.api_key_id if body.api_key_id is not None else existing["api_key_id"]
        new_group_id = body.group_id if body.group_id is not None else existing["group_id"]
        api_key_changed = new_api_key_id != existing["api_key_id"]

        try:
            async with conn.transaction():
                row = await conn.fetchrow(
                    """
                    UPDATE models SET name = $1, url = $2, api_key_id = $3, group_id = $4
                    WHERE id = $5
                    RETURNING id, name, url, api_key_id, group_id, created_at
                    """,
                    new_name, new_url, new_api_key_id, new_group_id, model_id,
                )
                if row is None:
                    # The row was deleted between the SELECT above and this UPDATE.
                    raise HTTPException(status_code=404, detail="模型不存在")
                if api_key_changed:
                    await _bump_all_domain_versions(conn)
        except HTTPException:
            raise
        except Exception as exc:
            # Map only constraint violations (asyncpg-style ``sqlstate``);
            # everything else propagates instead of becoming a bogus 409.
            sqlstate = getattr(exc, "sqlstate", None)
            if sqlstate == "23505":  # unique_violation
                raise HTTPException(status_code=409, detail="该 URL 已存在") from exc
            if sqlstate == "23503":  # foreign_key_violation
                raise HTTPException(status_code=400, detail="关联的 API Key 或分组不存在") from exc
            raise

        # Resolve display names; a missing referenced row yields None.
        api_key_name = None
        if row["api_key_id"]:
            k = await conn.fetchrow("SELECT name FROM api_keys WHERE id = $1", row["api_key_id"])
            api_key_name = k["name"] if k else None
        group_name = None
        if row["group_id"]:
            g = await conn.fetchrow("SELECT name FROM model_groups WHERE id = $1", row["group_id"])
            group_name = g["name"] if g else None
    return ModelOut(**dict(row), api_key_name=api_key_name, group_name=group_name)
@router.delete("/models/{model_id}", status_code=204, response_model=None)
async def delete_model(model_id: int) -> Response:
    """Delete one model by id.

    Returns an empty 204 response on success. Raises HTTPException(404) when
    no row had that id (the command status is ``"DELETE 0"``).
    """
    async with get_pool().acquire() as conn:
        status = await conn.execute("DELETE FROM models WHERE id = $1", model_id)
    if status == "DELETE 0":
        raise HTTPException(status_code=404, detail="模型不存在")
    return Response(status_code=204)
class BatchDeleteIn(BaseModel):
    """Request body for POST /models/batch-delete: ids of the models to remove."""

    ids: List[int]
@router.post("/models/batch-delete", status_code=200)
async def batch_delete_models(body: BatchDeleteIn) -> dict:
    """Delete every model whose id appears in ``body.ids``.

    Ids that match no row are ignored. Returns ``{"deleted": n}``, where n is
    the row count parsed from the DELETE command status.

    Raises:
        HTTPException(400): ``ids`` is empty.
    """
    requested = body.ids
    if not requested:
        raise HTTPException(status_code=400, detail="ids 不能为空")
    async with get_pool().acquire() as conn:
        status = await conn.execute(
            "DELETE FROM models WHERE id = ANY($1::int[])",
            requested,
        )
    # Command status looks like "DELETE <count>"; take the last token.
    count = status.rsplit(maxsplit=1)[-1]
    return {"deleted": int(count)}
class UpsertModelIn(BaseModel):
    """Request body for POST /models/upsert; ``url`` is the conflict key."""

    name: str
    url: str
    # On a URL conflict, None keeps the stored value (COALESCE in the upsert SQL).
    api_key_id: Optional[int] = None
    group_id: Optional[int] = None
@router.post("/models/upsert", response_model=ModelOut, status_code=200)
async def upsert_model(body: UpsertModelIn) -> ModelOut:
    """Insert a model, or update the existing model that has the same URL.

    On a URL conflict the stored ``name`` is always overwritten. ``api_key_id``
    and ``group_id`` are overwritten only when the request supplies a
    non-null value (COALESCE); null keeps the stored value.

    If an existing model's ``api_key_id`` changes as a result, all domain
    versions are bumped in the same transaction, the same way update_model
    does. A fresh insert does not bump versions, the same way create_model
    behaves.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            # Lock the current row (if any) so the before/after api_key_id
            # comparison is not racing a concurrent writer.
            previous = await conn.fetchrow(
                "SELECT api_key_id FROM models WHERE url = $1 FOR UPDATE",
                body.url,
            )
            row = await conn.fetchrow(
                """
                INSERT INTO models (name, url, api_key_id, group_id)
                VALUES ($1, $2, $3, $4)
                ON CONFLICT (url) DO UPDATE
                SET name = EXCLUDED.name,
                    api_key_id = COALESCE(EXCLUDED.api_key_id, models.api_key_id),
                    group_id = COALESCE(EXCLUDED.group_id, models.group_id)
                RETURNING id, name, url, api_key_id, group_id, created_at
                """,
                body.name, body.url, body.api_key_id, body.group_id,
            )
            if previous is not None and row["api_key_id"] != previous["api_key_id"]:
                await _bump_all_domain_versions(conn)

        # Resolve display names; a missing referenced row yields None.
        api_key_name = None
        if row["api_key_id"]:
            k = await conn.fetchrow("SELECT name FROM api_keys WHERE id = $1", row["api_key_id"])
            api_key_name = k["name"] if k else None
        group_name = None
        if row["group_id"]:
            g = await conn.fetchrow("SELECT name FROM model_groups WHERE id = $1", row["group_id"])
            group_name = g["name"] if g else None
    return ModelOut(**dict(row), api_key_name=api_key_name, group_name=group_name)
|