# domain_model_prices.py
  1. from __future__ import annotations
  2. from datetime import datetime
  3. from typing import List, Optional
  4. from fastapi import APIRouter, HTTPException
  5. from fastapi.responses import Response
  6. from pydantic import BaseModel, Field
  7. from app.db import get_pool
  8. router = APIRouter(tags=["domain_model_prices"])
# ── Pydantic models ──────────────────────────────────────────────────────────
  10. class DomainModelPriceIn(BaseModel):
  11. model_name: str
  12. discount: float = Field(..., gt=0, le=1, description="折扣系数,如 0.8 表示八折")
  13. note: Optional[str] = None
  14. class DomainModelPriceOut(BaseModel):
  15. id: int
  16. domain: str
  17. model_name: str
  18. discount: float
  19. note: Optional[str]
  20. created_at: datetime
  21. updated_at: datetime
# ── Helpers ──────────────────────────────────────────────────────────────────
  23. async def _bump_domain_version(conn, domain: str) -> None:
  24. await conn.execute(
  25. """
  26. INSERT INTO domain_version (domain, version, updated_at)
  27. VALUES ($1, 1, NOW())
  28. ON CONFLICT (domain) DO UPDATE
  29. SET version = domain_version.version + 1,
  30. updated_at = NOW()
  31. """,
  32. domain,
  33. )
# ── Routes ───────────────────────────────────────────────────────────────────
  35. @router.get("/discounts/{domain}/model-prices", response_model=List[DomainModelPriceOut])
  36. async def list_domain_model_prices(domain: str) -> List[DomainModelPriceOut]:
  37. """列出某域名下所有已配置的模型折扣。"""
  38. pool = get_pool()
  39. rows = await pool.fetch(
  40. "SELECT * FROM domain_model_prices WHERE domain = $1 ORDER BY model_name",
  41. domain,
  42. )
  43. return [DomainModelPriceOut(**dict(r)) for r in rows]
  44. @router.put("/discounts/{domain}/model-prices/{model_name}", response_model=DomainModelPriceOut)
  45. async def upsert_domain_model_price(
  46. domain: str,
  47. model_name: str,
  48. body: DomainModelPriceIn,
  49. ) -> DomainModelPriceOut:
  50. """新增或更新某域名下某模型的折扣系数。"""
  51. pool = get_pool()
  52. async with pool.acquire() as conn:
  53. async with conn.transaction():
  54. row = await conn.fetchrow(
  55. """
  56. INSERT INTO domain_model_prices (domain, model_name, discount, note)
  57. VALUES ($1, $2, $3, $4)
  58. ON CONFLICT (domain, model_name) DO UPDATE
  59. SET discount = EXCLUDED.discount,
  60. note = EXCLUDED.note,
  61. updated_at = NOW()
  62. RETURNING *
  63. """,
  64. domain,
  65. model_name,
  66. body.discount,
  67. body.note,
  68. )
  69. await _bump_domain_version(conn, domain)
  70. return DomainModelPriceOut(**dict(row))
  71. @router.delete(
  72. "/discounts/{domain}/model-prices/{model_name}",
  73. status_code=204,
  74. response_model=None,
  75. )
  76. async def delete_domain_model_price(domain: str, model_name: str) -> Response:
  77. """删除某域名下某模型的折扣配置。"""
  78. pool = get_pool()
  79. async with pool.acquire() as conn:
  80. async with conn.transaction():
  81. result = await conn.execute(
  82. "DELETE FROM domain_model_prices WHERE domain = $1 AND model_name = $2",
  83. domain,
  84. model_name,
  85. )
  86. if result == "DELETE 0":
  87. raise HTTPException(status_code=404, detail="配置不存在")
  88. await _bump_domain_version(conn, domain)
  89. return Response(status_code=204)
# ── Batch upsert ─────────────────────────────────────────────────────────────
  91. class BatchItem(BaseModel):
  92. model_name: str
  93. discount: float = Field(..., gt=0, le=1)
  94. note: Optional[str] = None
  95. class BatchIn(BaseModel):
  96. items: List[BatchItem]
  97. @router.put("/discounts/{domain}/model-prices", response_model=List[DomainModelPriceOut])
  98. async def batch_upsert_domain_model_prices(
  99. domain: str,
  100. body: BatchIn,
  101. ) -> List[DomainModelPriceOut]:
  102. """批量新增或更新某域名下多个模型的折扣系数。"""
  103. if not body.items:
  104. return []
  105. pool = get_pool()
  106. async with pool.acquire() as conn:
  107. async with conn.transaction():
  108. rows = []
  109. for item in body.items:
  110. row = await conn.fetchrow(
  111. """
  112. INSERT INTO domain_model_prices (domain, model_name, discount, note)
  113. VALUES ($1, $2, $3, $4)
  114. ON CONFLICT (domain, model_name) DO UPDATE
  115. SET discount = EXCLUDED.discount,
  116. note = EXCLUDED.note,
  117. updated_at = NOW()
  118. RETURNING *
  119. """,
  120. domain,
  121. item.model_name,
  122. item.discount,
  123. item.note,
  124. )
  125. rows.append(DomainModelPriceOut(**dict(row)))
  126. await _bump_domain_version(conn, domain)
  127. return rows