models.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. from __future__ import annotations
  2. from datetime import datetime
  3. from typing import List, Optional
  4. from fastapi import APIRouter, HTTPException
  5. from fastapi.responses import Response
  6. from pydantic import BaseModel
  7. from app.db import get_pool
  8. router = APIRouter(tags=["models"])
  9. async def _bump_all_domain_versions(conn) -> None:
  10. """api_key 有变动时,所有域名的版本号 +1,让客户端感知到数据变化。"""
  11. await conn.execute(
  12. "UPDATE domain_version SET version = version + 1, updated_at = NOW()"
  13. )
  14. await conn.execute(
  15. "UPDATE price_snapshot_version SET version = GREATEST(version + 1, 1), updated_at = NOW() WHERE id = 1"
  16. )
  17. class ModelIn(BaseModel):
  18. name: str
  19. url: str
  20. api_key_id: Optional[int] = None
  21. group_id: Optional[int] = None
  22. class ModelUpdate(BaseModel):
  23. name: Optional[str] = None
  24. url: Optional[str] = None
  25. api_key_id: Optional[int] = None
  26. group_id: Optional[int] = None
  27. class ModelOut(BaseModel):
  28. id: int
  29. name: str
  30. url: str
  31. api_key_id: Optional[int]
  32. api_key_name: Optional[str] = None
  33. group_id: Optional[int] = None
  34. group_name: Optional[str] = None
  35. created_at: datetime
  36. @router.get("/models", response_model=List[ModelOut])
  37. async def list_models() -> List[ModelOut]:
  38. pool = get_pool()
  39. async with pool.acquire() as conn:
  40. rows = await conn.fetch(
  41. """
  42. SELECT m.id, m.name, m.url, m.api_key_id, m.group_id, m.created_at,
  43. k.name AS api_key_name, g.name AS group_name
  44. FROM models m
  45. LEFT JOIN api_keys k ON k.id = m.api_key_id
  46. LEFT JOIN model_groups g ON g.id = m.group_id
  47. ORDER BY m.created_at DESC
  48. """
  49. )
  50. return [ModelOut(**dict(r)) for r in rows]
  51. @router.post("/models", response_model=ModelOut, status_code=201)
  52. async def create_model(body: ModelIn) -> ModelOut:
  53. pool = get_pool()
  54. async with pool.acquire() as conn:
  55. try:
  56. row = await conn.fetchrow(
  57. """
  58. INSERT INTO models (name, url, api_key_id, group_id) VALUES ($1, $2, $3, $4)
  59. RETURNING id, name, url, api_key_id, group_id, created_at
  60. """,
  61. body.name, body.url, body.api_key_id, body.group_id,
  62. )
  63. api_key_name = None
  64. if row["api_key_id"]:
  65. k = await conn.fetchrow("SELECT name FROM api_keys WHERE id = $1", row["api_key_id"])
  66. api_key_name = k["name"] if k else None
  67. group_name = None
  68. if row["group_id"]:
  69. g = await conn.fetchrow("SELECT name FROM model_groups WHERE id = $1", row["group_id"])
  70. group_name = g["name"] if g else None
  71. except Exception:
  72. raise HTTPException(status_code=409, detail="该 URL 已存在")
  73. return ModelOut(**dict(row), api_key_name=api_key_name, group_name=group_name)
  74. @router.put("/models/{model_id}", response_model=ModelOut)
  75. async def update_model(model_id: int, body: ModelUpdate) -> ModelOut:
  76. pool = get_pool()
  77. async with pool.acquire() as conn:
  78. existing = await conn.fetchrow(
  79. "SELECT id, name, url, api_key_id, group_id, created_at FROM models WHERE id = $1",
  80. model_id,
  81. )
  82. if existing is None:
  83. raise HTTPException(status_code=404, detail="模型不存在")
  84. new_name = body.name if body.name is not None else existing["name"]
  85. new_url = body.url if body.url is not None else existing["url"]
  86. new_api_key_id = body.api_key_id if body.api_key_id is not None else existing["api_key_id"]
  87. new_group_id = body.group_id if body.group_id is not None else existing["group_id"]
  88. api_key_changed = new_api_key_id != existing["api_key_id"]
  89. try:
  90. async with conn.transaction():
  91. row = await conn.fetchrow(
  92. """
  93. UPDATE models SET name = $1, url = $2, api_key_id = $3, group_id = $4
  94. WHERE id = $5
  95. RETURNING id, name, url, api_key_id, group_id, created_at
  96. """,
  97. new_name, new_url, new_api_key_id, new_group_id, model_id,
  98. )
  99. if api_key_changed:
  100. await _bump_all_domain_versions(conn)
  101. api_key_name = None
  102. if row["api_key_id"]:
  103. k = await conn.fetchrow("SELECT name FROM api_keys WHERE id = $1", row["api_key_id"])
  104. api_key_name = k["name"] if k else None
  105. group_name = None
  106. if row["group_id"]:
  107. g = await conn.fetchrow("SELECT name FROM model_groups WHERE id = $1", row["group_id"])
  108. group_name = g["name"] if g else None
  109. except HTTPException:
  110. raise
  111. except Exception:
  112. raise HTTPException(status_code=409, detail="该 URL 已存在")
  113. return ModelOut(**dict(row), api_key_name=api_key_name, group_name=group_name)
  114. @router.delete("/models/{model_id}", status_code=204, response_model=None)
  115. async def delete_model(model_id: int) -> Response:
  116. pool = get_pool()
  117. async with pool.acquire() as conn:
  118. result = await conn.execute("DELETE FROM models WHERE id = $1", model_id)
  119. if result == "DELETE 0":
  120. raise HTTPException(status_code=404, detail="模型不存在")
  121. return Response(status_code=204)
  122. class BatchDeleteIn(BaseModel):
  123. ids: List[int]
  124. @router.post("/models/batch-delete", status_code=200)
  125. async def batch_delete_models(body: BatchDeleteIn) -> dict:
  126. if not body.ids:
  127. raise HTTPException(status_code=400, detail="ids 不能为空")
  128. pool = get_pool()
  129. async with pool.acquire() as conn:
  130. result = await conn.execute(
  131. "DELETE FROM models WHERE id = ANY($1::int[])",
  132. body.ids,
  133. )
  134. deleted = int(result.split()[-1])
  135. return {"deleted": deleted}
  136. class UpsertModelIn(BaseModel):
  137. name: str
  138. url: str
  139. api_key_id: Optional[int] = None
  140. group_id: Optional[int] = None
  141. @router.post("/models/upsert", response_model=ModelOut, status_code=200)
  142. async def upsert_model(body: UpsertModelIn) -> ModelOut:
  143. """按 URL 做 upsert:URL 已存在则更新 name,不存在则插入。"""
  144. pool = get_pool()
  145. async with pool.acquire() as conn:
  146. row = await conn.fetchrow(
  147. """
  148. INSERT INTO models (name, url, api_key_id, group_id)
  149. VALUES ($1, $2, $3, $4)
  150. ON CONFLICT (url) DO UPDATE
  151. SET name = EXCLUDED.name,
  152. api_key_id = COALESCE(EXCLUDED.api_key_id, models.api_key_id),
  153. group_id = COALESCE(EXCLUDED.group_id, models.group_id)
  154. RETURNING id, name, url, api_key_id, group_id, created_at
  155. """,
  156. body.name, body.url, body.api_key_id, body.group_id,
  157. )
  158. api_key_name = None
  159. if row["api_key_id"]:
  160. k = await conn.fetchrow("SELECT name FROM api_keys WHERE id = $1", row["api_key_id"])
  161. api_key_name = k["name"] if k else None
  162. group_name = None
  163. if row["group_id"]:
  164. g = await conn.fetchrow("SELECT name FROM model_groups WHERE id = $1", row["group_id"])
  165. group_name = g["name"] if g else None
  166. return ModelOut(**dict(row), api_key_name=api_key_name, group_name=group_name)