from __future__ import annotations

from datetime import datetime
from typing import List, Optional

from fastapi import APIRouter, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel

from app.db import get_pool
from app.routers.models import _bump_all_domain_versions

router = APIRouter(tags=["model_groups"])


class GroupIn(BaseModel):
    """Request body for creating a model group."""

    name: str
    note: Optional[str] = None


class GroupUpdate(BaseModel):
    """Request body for partially updating a model group.

    A field that is ``None`` (omitted, or sent as null) keeps its current value.
    """

    name: Optional[str] = None
    note: Optional[str] = None


class GroupOut(BaseModel):
    """A model group as returned by the API."""

    id: int
    name: str
    note: Optional[str]
    # Number of rows in `models` whose group_id references this group.
    model_count: int = 0
    created_at: datetime
    updated_at: datetime


class BatchAssignGroupIn(BaseModel):
    """Request body for assigning many models to one group at once."""

    # None means: clear the group of every listed model.
    group_id: Optional[int] = None
    model_ids: List[int]


def _affected_rows(status: Optional[str]) -> int:
    """Return the row count from a command status tag such as ``"UPDATE 3"``.

    Returns 0 for an empty or missing status.
    """
    if not status:
        return 0
    return int(status.split()[-1])


@router.get("/model-groups", response_model=List[GroupOut])
async def list_groups() -> List[GroupOut]:
    """List all groups with their model counts, oldest first (ties broken by id)."""
    pool = get_pool()
    # Grouping only by g.id while selecting other g.* columns relies on g.id
    # being the primary key of model_groups.
    rows = await pool.fetch(
        """
        SELECT g.id, g.name, g.note, g.created_at, g.updated_at,
               COUNT(m.id) AS model_count
        FROM model_groups g
        LEFT JOIN models m ON m.group_id = g.id
        GROUP BY g.id
        ORDER BY g.created_at ASC, g.id ASC
        """
    )
    return [GroupOut(**dict(r)) for r in rows]


@router.post("/model-groups", response_model=GroupOut, status_code=201)
async def create_group(body: GroupIn) -> GroupOut:
    """Create a new, empty group and return it (model_count is always 0)."""
    pool = get_pool()
    row = await pool.fetchrow(
        """
        INSERT INTO model_groups (name, note)
        VALUES ($1, $2)
        RETURNING id, name, note, created_at, updated_at
        """,
        body.name,
        body.note,
    )
    return GroupOut(**dict(row), model_count=0)


@router.put("/model-groups/{group_id}", response_model=GroupOut)
async def update_group(group_id: int, body: GroupUpdate) -> GroupOut:
    """Partially update a group's name and/or note.

    Fields left as ``None`` keep their stored value. ``updated_at`` is always
    refreshed. Responds 404 if the group does not exist.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            # A single atomic UPDATE: COALESCE keeps the stored value when the
            # parameter is NULL. This avoids a read-then-write window in which
            # the group could be deleted.
            row = await conn.fetchrow(
                """
                UPDATE model_groups
                SET name = COALESCE($1, name),
                    note = COALESCE($2, note),
                    updated_at = NOW()
                WHERE id = $3
                RETURNING id, name, note, created_at, updated_at
                """,
                body.name,
                body.note,
                group_id,
            )
            if row is None:
                raise HTTPException(status_code=404, detail="分组不存在")
            count = await conn.fetchval(
                "SELECT COUNT(*) FROM models WHERE group_id = $1", group_id
            )
    return GroupOut(**dict(row), model_count=count or 0)


@router.delete("/model-groups/{group_id}", status_code=204, response_model=None)
async def delete_group(group_id: int) -> Response:
    """Delete a group, first detaching its models (their group_id becomes NULL).

    Responds 404 if the group does not exist.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            detached = await conn.execute(
                "UPDATE models SET group_id = NULL WHERE group_id = $1", group_id
            )
            result = await conn.execute(
                "DELETE FROM model_groups WHERE id = $1", group_id
            )
            if _affected_rows(result) == 0:
                # Raising inside the transaction block rolls back the UPDATE above.
                raise HTTPException(status_code=404, detail="分组不存在")
            # Clearing models' group_id here is the same change batch_assign_group
            # makes with group_id=None, so bump versions the same way.
            if _affected_rows(detached) > 0:
                await _bump_all_domain_versions(conn)
    return Response(status_code=204)


@router.post("/model-groups/batch-assign", status_code=200)
async def batch_assign_group(body: BatchAssignGroupIn) -> dict:
    """Assign models to a group in bulk (group_id=None clears their group).

    Returns ``{"updated": n}``, where n is the number of model rows the UPDATE
    matched. Responds 404 if ``group_id`` names a group that does not exist.
    """
    if not body.model_ids:
        return {"updated": 0}
    pool = get_pool()
    if body.group_id is not None:
        # Test for the absence of a row, not the truthiness of a fetched id.
        found = await pool.fetchval(
            "SELECT 1 FROM model_groups WHERE id = $1", body.group_id
        )
        if found is None:
            raise HTTPException(status_code=404, detail="分组不存在")
    async with pool.acquire() as conn:
        async with conn.transaction():
            result = await conn.execute(
                "UPDATE models SET group_id = $1 WHERE id = ANY($2::bigint[])",
                body.group_id,
                body.model_ids,
            )
            updated = _affected_rows(result)
            if updated > 0:
                await _bump_all_domain_versions(conn)
    return {"updated": updated}