discounts.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from __future__ import annotations
  2. from datetime import datetime
  3. from typing import List, Optional
  4. from fastapi import APIRouter, HTTPException
  5. from fastapi.responses import Response
  6. from pydantic import BaseModel, Field
  7. from app.db import get_pool
  8. router = APIRouter(tags=["discounts"])
  9. async def _bump_domain_version(conn, domain: str) -> None:
  10. """Upsert domain_version: insert with version=1 for new domains, increment for existing."""
  11. await conn.execute(
  12. """
  13. INSERT INTO domain_version (domain, version, updated_at)
  14. VALUES ($1, 1, NOW())
  15. ON CONFLICT (domain) DO UPDATE
  16. SET version = domain_version.version + 1,
  17. updated_at = NOW()
  18. """,
  19. domain,
  20. )
  21. class DiscountIn(BaseModel):
  22. domain: str
  23. discount: float = Field(..., gt=0, le=1, description="折扣系数,如 0.8 表示八折")
  24. note: Optional[str] = None
  25. class DiscountOut(BaseModel):
  26. id: int
  27. domain: str
  28. discount: float
  29. note: Optional[str]
  30. created_at: datetime
  31. updated_at: datetime
  32. @router.get("/discounts", response_model=List[DiscountOut])
  33. async def list_discounts() -> List[DiscountOut]:
  34. pool = get_pool()
  35. rows = await pool.fetch("SELECT * FROM discounts ORDER BY updated_at DESC")
  36. return [DiscountOut(**dict(r)) for r in rows]
  37. @router.post("/discounts", response_model=DiscountOut, status_code=201)
  38. async def create_discount(body: DiscountIn) -> DiscountOut:
  39. pool = get_pool()
  40. async with pool.acquire() as conn:
  41. async with conn.transaction():
  42. row = await conn.fetchrow(
  43. """
  44. INSERT INTO discounts (domain, discount, note)
  45. VALUES ($1, $2, $3)
  46. ON CONFLICT (domain) DO UPDATE
  47. SET discount = EXCLUDED.discount,
  48. note = EXCLUDED.note,
  49. updated_at = NOW()
  50. RETURNING *
  51. """,
  52. body.domain, body.discount, body.note,
  53. )
  54. await _bump_domain_version(conn, body.domain)
  55. return DiscountOut(**dict(row))
  56. @router.put("/discounts/{discount_id}", response_model=DiscountOut)
  57. async def update_discount(discount_id: int, body: DiscountIn) -> DiscountOut:
  58. pool = get_pool()
  59. async with pool.acquire() as conn:
  60. async with conn.transaction():
  61. row = await conn.fetchrow(
  62. """
  63. UPDATE discounts SET domain=$1, discount=$2, note=$3, updated_at=NOW()
  64. WHERE id=$4 RETURNING *
  65. """,
  66. body.domain, body.discount, body.note, discount_id,
  67. )
  68. if not row:
  69. raise HTTPException(status_code=404, detail="不存在")
  70. await _bump_domain_version(conn, body.domain)
  71. return DiscountOut(**dict(row))
  72. @router.delete("/discounts/{discount_id}", status_code=204, response_model=None)
  73. async def delete_discount(discount_id: int) -> Response:
  74. pool = get_pool()
  75. async with pool.acquire() as conn:
  76. async with conn.transaction():
  77. # 先查出 domain,再删除,再 bump 版本
  78. existing = await conn.fetchrow("SELECT domain FROM discounts WHERE id=$1", discount_id)
  79. if not existing:
  80. raise HTTPException(status_code=404, detail="不存在")
  81. result = await conn.execute("DELETE FROM discounts WHERE id=$1", discount_id)
  82. if result == "DELETE 0":
  83. raise HTTPException(status_code=404, detail="不存在")
  84. await conn.execute("DELETE FROM domain_version WHERE domain=$1", existing["domain"])
  85. return Response(status_code=204)