model_groups.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from __future__ import annotations
  2. from datetime import datetime
  3. from typing import List, Optional
  4. from fastapi import APIRouter, HTTPException
  5. from fastapi.responses import Response
  6. from pydantic import BaseModel
  7. from app.db import get_pool
  8. from app.routers.models import _bump_all_domain_versions
# Router for the model-group endpoints defined in this module.
router = APIRouter(tags=["model_groups"])
# Request body for creating a model group (POST /model-groups).
class GroupIn(BaseModel):
    name: str
    note: Optional[str] = None  # optional note; omitted -> None
# Request body for partially updating a model group (PUT /model-groups/{group_id}).
# A field left as None keeps the group's current value, so None cannot be used
# to clear a stored value.
class GroupUpdate(BaseModel):
    name: Optional[str] = None
    note: Optional[str] = None
# Response shape for a model group, as returned by the model-group endpoints.
class GroupOut(BaseModel):
    id: int
    name: str
    note: Optional[str]
    # Number of rows in `models` whose group_id references this group.
    model_count: int = 0
    created_at: datetime
    updated_at: datetime
# Request body for POST /model-groups/batch-assign.
class BatchAssignGroupIn(BaseModel):
    group_id: Optional[int] = None  # None means clear the group of the listed models
    # IDs of the models whose group_id should be set.
    model_ids: List[int]
  26. @router.get("/model-groups", response_model=List[GroupOut])
  27. async def list_groups() -> List[GroupOut]:
  28. pool = get_pool()
  29. rows = await pool.fetch(
  30. """
  31. SELECT g.id, g.name, g.note, g.created_at, g.updated_at,
  32. COUNT(m.id) AS model_count
  33. FROM model_groups g
  34. LEFT JOIN models m ON m.group_id = g.id
  35. GROUP BY g.id
  36. ORDER BY g.created_at ASC
  37. """
  38. )
  39. return [GroupOut(**dict(r)) for r in rows]
  40. @router.post("/model-groups", response_model=GroupOut, status_code=201)
  41. async def create_group(body: GroupIn) -> GroupOut:
  42. pool = get_pool()
  43. row = await pool.fetchrow(
  44. """
  45. INSERT INTO model_groups (name, note)
  46. VALUES ($1, $2)
  47. RETURNING id, name, note, created_at, updated_at
  48. """,
  49. body.name, body.note,
  50. )
  51. return GroupOut(**dict(row), model_count=0)
  52. @router.put("/model-groups/{group_id}", response_model=GroupOut)
  53. async def update_group(group_id: int, body: GroupUpdate) -> GroupOut:
  54. pool = get_pool()
  55. existing = await pool.fetchrow(
  56. "SELECT id, name, note, created_at, updated_at FROM model_groups WHERE id = $1",
  57. group_id,
  58. )
  59. if existing is None:
  60. raise HTTPException(status_code=404, detail="分组不存在")
  61. new_name = body.name if body.name is not None else existing["name"]
  62. new_note = body.note if body.note is not None else existing["note"]
  63. row = await pool.fetchrow(
  64. """
  65. UPDATE model_groups SET name = $1, note = $2, updated_at = NOW()
  66. WHERE id = $3
  67. RETURNING id, name, note, created_at, updated_at
  68. """,
  69. new_name, new_note, group_id,
  70. )
  71. count = await pool.fetchval("SELECT COUNT(*) FROM models WHERE group_id = $1", group_id)
  72. return GroupOut(**dict(row), model_count=count or 0)
  73. @router.delete("/model-groups/{group_id}", status_code=204, response_model=None)
  74. async def delete_group(group_id: int) -> Response:
  75. pool = get_pool()
  76. # 删除分组时,将该分组下的模型 group_id 置为 NULL
  77. async with pool.acquire() as conn:
  78. async with conn.transaction():
  79. await conn.execute("UPDATE models SET group_id = NULL WHERE group_id = $1", group_id)
  80. result = await conn.execute("DELETE FROM model_groups WHERE id = $1", group_id)
  81. if result == "DELETE 0":
  82. raise HTTPException(status_code=404, detail="分组不存在")
  83. return Response(status_code=204)
  84. @router.post("/model-groups/batch-assign", status_code=200)
  85. async def batch_assign_group(body: BatchAssignGroupIn) -> dict:
  86. """批量将模型归入某个分组(group_id=None 则清除分组)"""
  87. if not body.model_ids:
  88. return {"updated": 0}
  89. pool = get_pool()
  90. if body.group_id is not None:
  91. exists = await pool.fetchval("SELECT id FROM model_groups WHERE id = $1", body.group_id)
  92. if not exists:
  93. raise HTTPException(status_code=404, detail="分组不存在")
  94. async with pool.acquire() as conn:
  95. async with conn.transaction():
  96. result = await conn.execute(
  97. "UPDATE models SET group_id = $1 WHERE id = ANY($2::bigint[])",
  98. body.group_id, body.model_ids,
  99. )
  100. updated = int(result.split()[-1]) if result else 0
  101. if updated > 0:
  102. await _bump_all_domain_versions(conn)
  103. return {"updated": updated}