main.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. import asyncio
  2. import logging
  3. import math
  4. from contextlib import asynccontextmanager
  5. from datetime import datetime, timezone, timedelta
  6. from fastapi import FastAPI
  7. from fastapi.middleware.cors import CORSMiddleware
  8. from app.config import settings
  9. from app.routers import domains, monitoring, license as license_router, auth as auth_router
  10. from app.routers.license import public_router as public_license_router
  11. from app.routers import sa_balance as sa_balance_router
  12. from app.database import async_session
  13. from app.services.license import _notify_on_expired, _notify_on_warning
  14. from app.services.domain_fetch import fetch_domain_transactions
  15. from app.services.sms import send_sa_balance_warning, send_sa_balance_depleted
  16. from app.services.compensation_service import process_pending
  17. from app.models.license import SuperAdminLicense
  18. from app.models.monitoring import FetchScheduleConfig, FetchLog, SuperAdmin
  19. from app.models.domain import MonitoredDomain
  20. from app.models.visitor import VisitorInfo
  21. from app.redis import close_redis
  22. from sqlalchemy import select, text
  23. from decimal import Decimal
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Fixed UTC+8 offset; used by the daily fetch to compute the local time and date.
CST = timezone(timedelta(hours=8))
  26. async def _check_licenses():
  27. """定期检查 License:过期状态更新 + 7 天预警"""
  28. while True:
  29. try:
  30. async with async_session() as session:
  31. # 1. 检查已过期
  32. expired_result = await session.execute(
  33. select(SuperAdminLicense).where(
  34. SuperAdminLicense.status == "active",
  35. SuperAdminLicense.expires_at <= text("NOW()"),
  36. )
  37. )
  38. for lic in expired_result.scalars().all():
  39. logger.info("检测到 License #%d 已过期,更新状态并发送预警短信", lic.id)
  40. lic.status = "expired"
  41. await session.commit()
  42. await _notify_on_expired(session, lic)
  43. # 2. 检查剩余 7 天预警(未发送过的)
  44. warning_result = await session.execute(
  45. select(SuperAdminLicense).where(
  46. SuperAdminLicense.status == "active",
  47. SuperAdminLicense.warning_sent == False,
  48. SuperAdminLicense.expires_at > text("NOW()"),
  49. SuperAdminLicense.expires_at <= text("NOW() + INTERVAL '7 days'"),
  50. )
  51. )
  52. for lic in warning_result.scalars().all():
  53. days_left = math.ceil((lic.expires_at - datetime.now(timezone.utc)).total_seconds() / 86400)
  54. logger.info("检测到 License #%d 剩余 %d 天,发送预警短信", lic.id, days_left)
  55. await _notify_on_warning(session, lic, days_left)
  56. lic.warning_sent = True
  57. await session.commit()
  58. except Exception:
  59. logger.exception("定时检查 License 异常")
  60. await asyncio.sleep(24 * 3600)
  61. async def _daily_fetch():
  62. """定时爬取:每分钟检查 DB 配置,到了目标时间就爬取当天流水"""
  63. while True:
  64. try:
  65. async with async_session() as session:
  66. result = await session.execute(select(FetchScheduleConfig).limit(1))
  67. config = result.scalar_one_or_none()
  68. if config and config.enabled:
  69. h, m = map(int, config.schedule_time.split(":"))
  70. now = datetime.now(CST)
  71. today = now.strftime("%Y-%m-%d")
  72. # 已过目标时间且今天还没爬过
  73. if now.hour * 60 + now.minute >= h * 60 + m and config.last_fetch_date != today:
  74. logger.info("开始定时爬取当天流水: %s", today)
  75. # 查询今日已失败的域名(跳过)
  76. failed_result = await session.execute(
  77. select(FetchLog.domain).where(
  78. FetchLog.status == "failed",
  79. FetchLog.created_at >= text(f"'{today} 00:00:00+08'"),
  80. ).distinct()
  81. )
  82. failed_domains = {r[0] for r in failed_result.all()}
  83. domain_result = await session.execute(
  84. select(MonitoredDomain).where(MonitoredDomain.is_active == True)
  85. )
  86. for d in domain_result.scalars().all():
  87. if d.domain in failed_domains:
  88. logger.info("域名 %s 今日已失败,跳过", d.domain)
  89. continue
  90. try:
  91. await fetch_domain_transactions(d.domain, session, fetch_date=today)
  92. logger.info("域名 %s 当天流水爬取完成", d.domain)
  93. session.add(FetchLog(
  94. domain=d.domain, status="success", message="当天流水爬取完成"
  95. ))
  96. except Exception as e:
  97. error_msg = str(e)[:500]
  98. logger.exception("域名 %s 当天流水爬取失败", d.domain)
  99. session.add(FetchLog(
  100. domain=d.domain, status="failed", message=error_msg
  101. ))
  102. config.last_fetch_date = today
  103. await session.commit()
  104. last_fetch_date = today
  105. logger.info("当天定时爬取全部完成")
  106. except Exception:
  107. logger.exception("定时爬取异常")
  108. await asyncio.sleep(60)
  109. async def _check_sa_balances():
  110. """定期检查超管余额:余额预警 + 余额耗尽通知(每 10 分钟)"""
  111. while True:
  112. try:
  113. async with async_session() as session:
  114. result = await session.execute(select(SuperAdmin).order_by(SuperAdmin.id))
  115. threshold = Decimal(str(settings.sa_balance_warning_threshold))
  116. for sa in result.scalars().all():
  117. balance = Decimal(str(sa.balance or 0))
  118. company = sa.remark or sa.username or str(sa.id)
  119. # 手机号:优先用超管自身 phone,否则从域名关联的 VisitorInfo 获取
  120. phone = sa.phone
  121. if not phone:
  122. domain_row = (await session.execute(
  123. select(MonitoredDomain).where(
  124. MonitoredDomain.super_admin_id == sa.id,
  125. MonitoredDomain.is_active == True,
  126. ).limit(1)
  127. )).scalar_one_or_none()
  128. if domain_row:
  129. visitor = (await session.execute(
  130. select(VisitorInfo).where(VisitorInfo.domain_id == domain_row.id)
  131. )).scalar_one_or_none()
  132. if visitor:
  133. phone = visitor.phone
  134. if balance <= 0:
  135. # 余额耗尽
  136. if not sa.balance_depleted_sent and phone:
  137. ok, reason = await send_sa_balance_depleted(phone, company)
  138. if ok:
  139. sa.balance_depleted_sent = True
  140. await session.commit()
  141. logger.info("超管 %d(%s) 余额耗尽,短信已发送", sa.id, company)
  142. else:
  143. logger.warning("超管 %d 耗尽短信发送失败: %s", sa.id, reason)
  144. elif balance <= threshold:
  145. # 余额不足但未耗尽
  146. if not sa.balance_warning_sent and phone:
  147. ok, reason = await send_sa_balance_warning(phone, company, f"{balance:.2f}")
  148. if ok:
  149. sa.balance_warning_sent = True
  150. await session.commit()
  151. logger.info("超管 %d(%s) 余额预警(%.2f),短信已发送", sa.id, company, balance)
  152. else:
  153. logger.warning("超管 %d 预警短信发送失败: %s", sa.id, reason)
  154. else:
  155. # 余额恢复,重置标记
  156. if sa.balance_warning_sent or sa.balance_depleted_sent:
  157. sa.balance_warning_sent = False
  158. sa.balance_depleted_sent = False
  159. await session.commit()
  160. logger.info("超管 %d(%s) 余额恢复至 %.2f,重置预警标记", sa.id, company, balance)
  161. except Exception:
  162. logger.exception("定时检查超管余额异常")
  163. await asyncio.sleep(600) # 10 分钟
  164. async def _process_compensation():
  165. """定期处理待补偿扣减记录(每 30 秒)"""
  166. while True:
  167. try:
  168. async with async_session() as session:
  169. count = await process_pending(session)
  170. if count > 0:
  171. logger.info("补偿任务处理了 %d 条记录", count)
  172. except Exception:
  173. logger.exception("补偿任务异常")
  174. await asyncio.sleep(30)
  175. @asynccontextmanager
  176. async def lifespan(app: FastAPI):
  177. # 启动时立即检查一次
  178. try:
  179. async with async_session() as session:
  180. # 1. 检查已过期
  181. expired_result = await session.execute(
  182. select(SuperAdminLicense).where(
  183. SuperAdminLicense.status == "active",
  184. SuperAdminLicense.expires_at <= text("NOW()"),
  185. )
  186. )
  187. for lic in expired_result.scalars().all():
  188. logger.info("启动时检测到 License #%d 已过期,更新状态并发送预警短信", lic.id)
  189. lic.status = "expired"
  190. await session.commit()
  191. await _notify_on_expired(session, lic)
  192. # 2. 检查剩余 7 天预警(未发送过的)
  193. from datetime import timedelta
  194. warning_result = await session.execute(
  195. select(SuperAdminLicense).where(
  196. SuperAdminLicense.status == "active",
  197. SuperAdminLicense.warning_sent == False,
  198. SuperAdminLicense.expires_at > text("NOW()"),
  199. SuperAdminLicense.expires_at <= text("NOW() + INTERVAL '7 days'"),
  200. )
  201. )
  202. for lic in warning_result.scalars().all():
  203. days_left = math.ceil((lic.expires_at - datetime.now(timezone.utc)).total_seconds() / 86400)
  204. logger.info("启动时检测到 License #%d 剩余 %d 天,发送预警短信", lic.id, days_left)
  205. await _notify_on_warning(session, lic, days_left)
  206. lic.warning_sent = True
  207. await session.commit()
  208. logger.info("启动检查完成")
  209. except Exception:
  210. logger.exception("启动时检查 License 异常")
  211. # 启动后台定时任务(错开启动,避免同时抢连接)
  212. license_task = asyncio.create_task(_check_licenses())
  213. await asyncio.sleep(2)
  214. fetch_task = asyncio.create_task(_daily_fetch())
  215. await asyncio.sleep(2)
  216. balance_task = asyncio.create_task(_check_sa_balances())
  217. await asyncio.sleep(2)
  218. compensation_task = asyncio.create_task(_process_compensation())
  219. logger.info("后台任务已启动:License 检查 + 定时爬取 + 超管余额预警 + 扣减补偿")
  220. yield
  221. license_task.cancel()
  222. fetch_task.cancel()
  223. balance_task.cancel()
  224. compensation_task.cancel()
  225. for task in (license_task, fetch_task, balance_task, compensation_task):
  226. try:
  227. await task
  228. except asyncio.CancelledError:
  229. pass
  230. await close_redis()
  231. app = FastAPI(
  232. title="域名流水监控",
  233. version="0.1.0",
  234. debug=settings.debug,
  235. lifespan=lifespan,
  236. )
  237. # CORS 配置,允许前端 Vite 开发服务器访问
  238. app.add_middleware(
  239. CORSMiddleware,
  240. allow_origins=["http://localhost:5173"],
  241. allow_credentials=True,
  242. allow_methods=["*"],
  243. allow_headers=["*"],
  244. )
  245. # 注册路由
  246. app.include_router(domains.router)
  247. app.include_router(monitoring.router)
  248. app.include_router(license_router.router)
  249. app.include_router(public_license_router)
  250. app.include_router(auth_router.router)
  251. app.include_router(sa_balance_router.router)
  252. app.include_router(sa_balance_router.public_router)
  253. @app.get("/health")
  254. async def health():
  255. """健康检查接口"""
  256. return {"status": "ok"}