# Admin routes for per-domain, per-model discount factors.
#
# Every mutating route runs inside a single DB transaction and bumps the
# domain's row in ``domain_version`` so the price change and the version
# increment are committed (or rolled back) together.
from __future__ import annotations

from datetime import datetime
from typing import List, Optional

from fastapi import APIRouter, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel, Field

from app.db import get_pool

router = APIRouter(tags=["domain_model_prices"])


# ── Pydantic models ──────────────────────────────────────────────────────────


class DomainModelPriceIn(BaseModel):
    """Request body for upserting a single model discount."""

    # NOTE(review): the single-item PUT route takes the model name from the
    # URL path and never reads this field, yet it is still required in the
    # body. Confirm with API consumers before relaxing or validating it.
    model_name: str
    # Discount factor in the half-open range (0, 1]; e.g. 0.8 means 20% off.
    discount: float = Field(..., gt=0, le=1, description="折扣系数,如 0.8 表示八折")
    note: Optional[str] = None


class DomainModelPriceOut(BaseModel):
    """Response shape built from one ``domain_model_prices`` row."""

    id: int
    domain: str
    model_name: str
    discount: float
    note: Optional[str]
    created_at: datetime
    updated_at: datetime


# ── Helpers ──────────────────────────────────────────────────────────────────

# Shared by the single and batch upsert routes so both always execute the
# exact same statement.
_UPSERT_PRICE_SQL = """
    INSERT INTO domain_model_prices (domain, model_name, discount, note)
    VALUES ($1, $2, $3, $4)
    ON CONFLICT (domain, model_name) DO UPDATE
        SET discount = EXCLUDED.discount,
            note = EXCLUDED.note,
            updated_at = NOW()
    RETURNING *
"""


def _row_to_out(row) -> DomainModelPriceOut:
    """Convert a DB record (any mapping-like row) into the response model."""
    return DomainModelPriceOut(**dict(row))


async def _upsert_price(
    conn,
    domain: str,
    model_name: str,
    discount: float,
    note: Optional[str],
):
    """Insert or update one (domain, model_name) price row.

    Args:
        conn: An acquired DB connection (presumably asyncpg; it must offer
            ``fetchrow`` with ``$n`` placeholders). The caller owns the
            transaction.
        domain: Domain the discount applies to.
        model_name: Model the discount applies to.
        discount: Discount factor to store.
        note: Optional free-form note stored alongside the discount.

    Returns:
        The resulting row as returned by ``RETURNING *``.
    """
    return await conn.fetchrow(_UPSERT_PRICE_SQL, domain, model_name, discount, note)


async def _bump_domain_version(conn, domain: str) -> None:
    """Create the domain's version row at 1, or increment it if it exists.

    Must be called on the same connection/transaction as the price change so
    the two are committed atomically.
    """
    await conn.execute(
        """
        INSERT INTO domain_version (domain, version, updated_at)
        VALUES ($1, 1, NOW())
        ON CONFLICT (domain) DO UPDATE
            SET version = domain_version.version + 1,
                updated_at = NOW()
        """,
        domain,
    )


# ── Routes ───────────────────────────────────────────────────────────────────


@router.get("/discounts/{domain}/model-prices", response_model=List[DomainModelPriceOut])
async def list_domain_model_prices(domain: str) -> List[DomainModelPriceOut]:
    """List every configured model discount for a domain, ordered by model name."""
    pool = get_pool()
    rows = await pool.fetch(
        "SELECT * FROM domain_model_prices WHERE domain = $1 ORDER BY model_name",
        domain,
    )
    return [_row_to_out(row) for row in rows]


@router.put("/discounts/{domain}/model-prices/{model_name}", response_model=DomainModelPriceOut)
async def upsert_domain_model_price(
    domain: str,
    model_name: str,
    body: DomainModelPriceIn,
) -> DomainModelPriceOut:
    """Create or update the discount factor of one model under a domain.

    The model is identified by the ``model_name`` path parameter; only
    ``discount`` and ``note`` are read from the request body.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            row = await _upsert_price(conn, domain, model_name, body.discount, body.note)
            await _bump_domain_version(conn, domain)
    return _row_to_out(row)


@router.delete(
    "/discounts/{domain}/model-prices/{model_name}",
    status_code=204,
    response_model=None,
)
async def delete_domain_model_price(domain: str, model_name: str) -> Response:
    """Delete the discount configured for one model under a domain.

    Raises:
        HTTPException: 404 when no matching row exists. Raising inside the
            transaction block aborts it, so the domain version is not bumped.
    """
    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            result = await conn.execute(
                "DELETE FROM domain_model_prices WHERE domain = $1 AND model_name = $2",
                domain,
                model_name,
            )
            # The command status string reports the number of deleted rows.
            if result == "DELETE 0":
                raise HTTPException(status_code=404, detail="配置不存在")
            await _bump_domain_version(conn, domain)
    return Response(status_code=204)


# ── Batch upsert ─────────────────────────────────────────────────────────────


class BatchItem(BaseModel):
    """One entry of a batch upsert request."""

    model_name: str
    # Discount factor in the half-open range (0, 1].
    discount: float = Field(..., gt=0, le=1)
    note: Optional[str] = None


class BatchIn(BaseModel):
    """Request body of the batch upsert route."""

    items: List[BatchItem]


@router.put("/discounts/{domain}/model-prices", response_model=List[DomainModelPriceOut])
async def batch_upsert_domain_model_prices(
    domain: str,
    body: BatchIn,
) -> List[DomainModelPriceOut]:
    """Create or update the discount factors of several models under a domain.

    All items are written in one transaction and the domain version is bumped
    once for the whole batch. An empty ``items`` list returns ``[]`` without
    touching the database.
    """
    if not body.items:
        return []

    pool = get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            results: List[DomainModelPriceOut] = []
            # Items are applied sequentially in request order, so a repeated
            # model_name is simply overwritten by its later entry.
            for item in body.items:
                row = await _upsert_price(
                    conn, domain, item.model_name, item.discount, item.note
                )
                results.append(_row_to_out(row))
            await _bump_domain_version(conn, domain)
    return results