| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- from __future__ import annotations
- from datetime import datetime
- from typing import List, Optional
- from fastapi import APIRouter, HTTPException
- from fastapi.responses import Response
- from pydantic import BaseModel
- from app.db import get_pool
- from app.routers.models import _bump_all_domain_versions
# Router for the model-group endpoints defined below; every route registered on it
# carries the "model_groups" tag (FastAPI uses tags to group operations in OpenAPI).
router = APIRouter(tags=["model_groups"])
class GroupIn(BaseModel):
    """Request body for creating a model group."""

    name: str  # group name, inserted as-is into model_groups.name
    note: Optional[str] = None  # optional free-form note; omitted -> NULL
class GroupUpdate(BaseModel):
    """Request body for partially updating a model group.

    A field that is None (or omitted) keeps the group's current value in
    update_group, so this model cannot be used to set ``note`` back to NULL.
    """

    name: Optional[str] = None  # new name, or None to keep the current one
    note: Optional[str] = None  # new note, or None to keep the current one
class GroupOut(BaseModel):
    """Response model for a model group, including how many models belong to it."""

    # Field order below defines the order of keys in the serialized response.
    id: int
    name: str
    note: Optional[str]  # may be NULL in the database; declared without a default
    model_count: int = 0  # number of rows in `models` whose group_id is this group
    created_at: datetime
    updated_at: datetime
class BatchAssignGroupIn(BaseModel):
    """Request body for assigning many models to one group in a single call."""

    group_id: Optional[int] = None  # None means clear the group (ungroup the models)
    model_ids: List[int]  # ids of the models whose group_id is set
@router.get("/model-groups", response_model=List[GroupOut])
async def list_groups() -> List[GroupOut]:
    """List all model groups, oldest first, each with the number of models in it.

    Groups that contain no models are still returned, with ``model_count == 0``
    (LEFT JOIN, and COUNT over the nullable ``m.id`` ignores the unmatched NULLs).

    Returns:
        One ``GroupOut`` per row of ``model_groups``.
    """
    pool = get_pool()
    # Sort by id as a tie-breaker: rows sharing the same created_at would otherwise
    # come back in an unspecified order that may differ from one call to the next.
    rows = await pool.fetch(
        """
        SELECT g.id, g.name, g.note, g.created_at, g.updated_at,
               COUNT(m.id) AS model_count
        FROM model_groups g
        LEFT JOIN models m ON m.group_id = g.id
        GROUP BY g.id
        ORDER BY g.created_at ASC, g.id ASC
        """
    )
    return [GroupOut(**dict(r)) for r in rows]
@router.post("/model-groups", response_model=GroupOut, status_code=201)
async def create_group(body: GroupIn) -> GroupOut:
    """Insert a new model group and return it with HTTP 201.

    A group that was just created cannot contain any models yet, so
    ``model_count`` is reported as 0 without querying the ``models`` table.
    """
    insert_sql = """
        INSERT INTO model_groups (name, note)
        VALUES ($1, $2)
        RETURNING id, name, note, created_at, updated_at
        """
    created = await get_pool().fetchrow(insert_sql, body.name, body.note)
    return GroupOut(model_count=0, **dict(created))
@router.put("/model-groups/{group_id}", response_model=GroupOut)
async def update_group(group_id: int, body: GroupUpdate) -> GroupOut:
    """Partially update a model group's name and/or note.

    Fields that are None (or omitted) in ``body`` keep their stored value;
    ``updated_at`` is refreshed on every successful call.

    Args:
        group_id: id of the group to update.
        body: new values; None means "leave unchanged".

    Returns:
        The updated group together with its current model count.

    Raises:
        HTTPException: 404 if no group with ``group_id`` exists.
    """
    pool = get_pool()
    # One atomic UPDATE instead of SELECT-then-UPDATE: COALESCE keeps the stored
    # value for None fields, and a group deleted concurrently simply yields no row
    # (-> 404) rather than a None row reaching dict() and causing a 500.
    row = await pool.fetchrow(
        """
        UPDATE model_groups
        SET name = COALESCE($1, name),
            note = COALESCE($2, note),
            updated_at = NOW()
        WHERE id = $3
        RETURNING id, name, note, created_at, updated_at
        """,
        body.name, body.note, group_id,
    )
    if row is None:
        raise HTTPException(status_code=404, detail="分组不存在")
    count = await pool.fetchval("SELECT COUNT(*) FROM models WHERE group_id = $1", group_id)
    return GroupOut(**dict(row), model_count=count or 0)
@router.delete("/model-groups/{group_id}", status_code=204, response_model=None)
async def delete_group(group_id: int) -> Response:
    """Delete a model group, ungrouping (not deleting) the models that belonged to it.

    Both steps run in one transaction: member models get ``group_id = NULL`` and
    the group row is removed. When models were ungrouped, domain versions are
    bumped, as batch_assign_group does for any change to ``models.group_id``.

    Raises:
        HTTPException: 404 if no group with ``group_id`` exists; the transaction is
            rolled back in that case.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            # Detach member models first so none keeps pointing at the deleted group.
            ungroup_status = await conn.execute(
                "UPDATE models SET group_id = NULL WHERE group_id = $1", group_id
            )
            result = await conn.execute("DELETE FROM model_groups WHERE id = $1", group_id)
            if result == "DELETE 0":
                # Raising inside the transaction block rolls back the UPDATE above.
                raise HTTPException(status_code=404, detail="分组不存在")
            # Ungrouping models changes their group assignment exactly like a
            # batch-assign with group_id=None, so bump versions the same way.
            ungrouped = int(ungroup_status.split()[-1]) if ungroup_status else 0
            if ungrouped > 0:
                await _bump_all_domain_versions(conn)
    return Response(status_code=204)
@router.post("/model-groups/batch-assign", status_code=200)
async def batch_assign_group(body: BatchAssignGroupIn) -> dict:
    """Batch-assign models to a group (``group_id=None`` clears their group).

    Returns:
        ``{"updated": n}`` where n is the number of model rows matched by the
        UPDATE. Domain versions are bumped when n > 0.

    Raises:
        HTTPException: 404 if ``group_id`` is given but no such group exists.
    """
    if not body.model_ids:
        return {"updated": 0}
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            if body.group_id is not None:
                # Check existence inside the same transaction as the UPDATE and hold a
                # FOR KEY SHARE lock, so a concurrent DELETE of the group waits until we
                # commit instead of leaving models pointing at a vanished group.
                # Compare with `is None` so a legitimate id of 0 is not treated as missing.
                exists = await conn.fetchval(
                    "SELECT 1 FROM model_groups WHERE id = $1 FOR KEY SHARE",
                    body.group_id,
                )
                if exists is None:
                    raise HTTPException(status_code=404, detail="分组不存在")
            result = await conn.execute(
                "UPDATE models SET group_id = $1 WHERE id = ANY($2::bigint[])",
                body.group_id, body.model_ids,
            )
            # The status string is a command tag like "UPDATE 3"; the last token is the row count.
            updated = int(result.split()[-1]) if result else 0
            if updated > 0:
                await _bump_all_domain_versions(conn)
    return {"updated": updated}
|