license.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. import math
  2. from datetime import datetime, timezone, timedelta
  3. from urllib.parse import urlparse
  4. CST = timezone(timedelta(hours=8)) # 东八区
  5. from sqlalchemy import select
  6. from sqlalchemy.ext.asyncio import AsyncSession
  7. from app.models.license import SuperAdminLicense
  8. from app.models.monitoring import SuperAdmin
  9. from app.models.domain import MonitoredDomain
  10. from app.models.visitor import VisitorInfo
  11. from app.schemas.license import (
  12. LicenseCreate,
  13. LicenseResponse,
  14. LicenseStatusResponse,
  15. LicenseUpdate,
  16. SuperAdminOption,
  17. LicenseListResponse,
  18. )
  19. from app.services.sms import send_license_expired, send_license_restored, send_license_warning
  20. import logging
  21. logger = logging.getLogger(__name__)
  22. def _to_str(val) -> str:
  23. if val is None:
  24. return ""
  25. if isinstance(val, datetime):
  26. return val.astimezone(CST).isoformat()
  27. return str(val)
  28. async def _get_contact_for_license(db: AsyncSession, lic: SuperAdminLicense) -> tuple[str | None, str | None]:
  29. """根据 License 获取关联的联系人 phone 和 company 名称。返回 (phone, company)"""
  30. domain_result = await db.execute(
  31. select(MonitoredDomain).where(
  32. MonitoredDomain.super_admin_id == lic.super_admin_id,
  33. MonitoredDomain.is_active == True,
  34. ).limit(1)
  35. )
  36. domain = domain_result.scalar_one_or_none()
  37. if not domain:
  38. return None, None
  39. sa_result = await db.execute(
  40. select(SuperAdmin).where(SuperAdmin.id == lic.super_admin_id)
  41. )
  42. sa = sa_result.scalar_one_or_none()
  43. company = sa.remark or sa.username if sa else str(lic.super_admin_id)
  44. visitor_result = await db.execute(
  45. select(VisitorInfo).where(VisitorInfo.domain_id == domain.id)
  46. )
  47. visitor = visitor_result.scalar_one_or_none()
  48. if not visitor or not visitor.phone:
  49. return None, company
  50. return visitor.phone, company
  51. async def _notify_on_expired(db: AsyncSession, lic: SuperAdminLicense, retry: bool = False) -> tuple[str, str]:
  52. """License 过期通知短信。返回 (status, reason)"""
  53. try:
  54. if lic.expired_sent:
  55. return "skipped", "已发送过过期短信"
  56. phone, company = await _get_contact_for_license(db, lic)
  57. if not phone:
  58. return "skipped", "无联系人手机号"
  59. ok, reason = await send_license_expired(phone, company, retry=retry)
  60. if ok:
  61. lic.expired_sent = True
  62. await db.commit()
  63. return "sent", ""
  64. if "冷却" in reason and not retry:
  65. return "pending", "短信冷却中,稍后自动重试发送"
  66. return "failed", reason
  67. except Exception as e:
  68. logger.exception("发送过期短信异常,license_id=%d", lic.id)
  69. return "failed", str(e)
  70. async def _notify_on_restored(db: AsyncSession, lic: SuperAdminLicense, retry: bool = False) -> tuple[str, str]:
  71. """License 恢复通知短信。返回 (status, reason)"""
  72. try:
  73. phone, company = await _get_contact_for_license(db, lic)
  74. if not phone:
  75. return "skipped", "无联系人手机号"
  76. ok, reason = await send_license_restored(phone, company, retry=retry)
  77. if ok:
  78. lic.expired_sent = False
  79. lic.warning_sent = False
  80. await db.commit()
  81. return "sent", ""
  82. if "冷却" in reason and not retry:
  83. return "pending", "短信冷却中,稍后自动重试发送"
  84. return "failed", reason
  85. except Exception as e:
  86. logger.exception("发送恢复短信异常,license_id=%d", lic.id)
  87. return "failed", str(e)
  88. async def _notify_on_warning(db: AsyncSession, lic: SuperAdminLicense, days_left: int, retry: bool = False) -> tuple[str, str]:
  89. """License 即将过期预警短信。返回 (status, reason)"""
  90. try:
  91. if lic.warning_sent:
  92. return "skipped", "已发送过预警短信"
  93. phone, company = await _get_contact_for_license(db, lic)
  94. if not phone:
  95. return "skipped", "无联系人手机号"
  96. ok, reason = await send_license_warning(phone, company, days_left, retry=retry)
  97. if ok:
  98. lic.warning_sent = True
  99. await db.commit()
  100. return "sent", ""
  101. if "冷却" in reason and not retry:
  102. return "pending", "短信冷却中,稍后自动重试发送"
  103. return "failed", reason
  104. except Exception as e:
  105. logger.exception("发送预警短信异常,license_id=%d", lic.id)
  106. return "failed", str(e)
  107. async def _retry_send_sms(db: AsyncSession, lic: SuperAdminLicense, sms_type: str, days_left: int = 0) -> None:
  108. """后台任务:等待冷却后重试发送 License 相关短信"""
  109. import asyncio
  110. from app.database import async_session
  111. async with async_session() as session:
  112. result = await session.execute(
  113. select(SuperAdminLicense).where(SuperAdminLicense.id == lic.id)
  114. )
  115. lic = result.scalar_one_or_none()
  116. if not lic:
  117. return
  118. # 冷却 30 秒后重试
  119. await asyncio.sleep(30)
  120. if sms_type == "expired":
  121. await _notify_on_expired(session, lic, retry=True)
  122. elif sms_type == "restored":
  123. await _notify_on_restored(session, lic, retry=True)
  124. elif sms_type == "warning":
  125. await _notify_on_warning(session, lic, days_left, retry=True)
  126. def _calc_days_left(expires_at: datetime) -> int:
  127. now = datetime.now(timezone.utc) if expires_at.tzinfo else datetime.now()
  128. delta = expires_at - now
  129. return math.ceil(delta.total_seconds() / (60 * 60 * 24))
  130. async def get_super_admins(db: AsyncSession) -> list[SuperAdminOption]:
  131. """获取所有超级管理员,供下拉选择"""
  132. result = await db.execute(select(SuperAdmin).order_by(SuperAdmin.id))
  133. sas = result.scalars().all()
  134. return [SuperAdminOption(id=sa.id, username=sa.username, nickname=sa.nickname, remark=sa.remark) for sa in sas]
  135. async def create_license(
  136. db: AsyncSession, payload: LicenseCreate
  137. ) -> dict:
  138. """创建或更新 License。同一超管只保留一个 active 的 License,已存在则更新。"""
  139. existing = await db.execute(
  140. select(SuperAdminLicense).where(
  141. SuperAdminLicense.super_admin_id == payload.super_admin_id,
  142. SuperAdminLicense.status == "active",
  143. )
  144. )
  145. exist_lic = existing.scalar_one_or_none()
  146. expires_at = datetime.fromisoformat(payload.expires_at)
  147. if expires_at.tzinfo is None:
  148. expires_at = expires_at.replace(tzinfo=CST)
  149. if exist_lic:
  150. exist_lic.license_key = payload.license_key
  151. exist_lic.expires_at = expires_at
  152. exist_lic.max_tenants = payload.max_tenants
  153. exist_lic.max_users_per_tenant = payload.max_users_per_tenant
  154. exist_lic.remark = payload.remark
  155. await db.flush()
  156. await db.commit()
  157. return {"message": "License已更新", "license_id": exist_lic.id}
  158. new_lic = SuperAdminLicense(
  159. super_admin_id=payload.super_admin_id,
  160. license_key=payload.license_key,
  161. expires_at=expires_at,
  162. status="active",
  163. max_tenants=payload.max_tenants,
  164. max_users_per_tenant=payload.max_users_per_tenant,
  165. remark=payload.remark,
  166. )
  167. db.add(new_lic)
  168. await db.commit()
  169. await db.refresh(new_lic)
  170. return {"message": "License已创建", "license_id": new_lic.id}
  171. async def list_licenses(
  172. db: AsyncSession,
  173. super_admin_id: int | None = None,
  174. status: str | None = None,
  175. page: int = 1,
  176. size: int = 20,
  177. ) -> LicenseListResponse:
  178. """查询 License 列表"""
  179. stmt = select(SuperAdminLicense)
  180. if super_admin_id is not None:
  181. stmt = stmt.where(SuperAdminLicense.super_admin_id == super_admin_id)
  182. if status:
  183. stmt = stmt.where(SuperAdminLicense.status == status)
  184. stmt = stmt.order_by(SuperAdminLicense.created_at.desc())
  185. total_result = await db.execute(stmt)
  186. all_rows = total_result.scalars().all()
  187. total = len(all_rows)
  188. offset = (page - 1) * size
  189. paged = all_rows[offset: offset + size]
  190. items = []
  191. for r in paged:
  192. sa_result = await db.execute(
  193. select(SuperAdmin).where(SuperAdmin.id == r.super_admin_id)
  194. )
  195. sa = sa_result.scalar_one_or_none()
  196. sa_name = sa.remark or sa.username if sa else str(r.super_admin_id)
  197. # 查询关联的域名
  198. domain_result = await db.execute(
  199. select(MonitoredDomain).where(
  200. MonitoredDomain.super_admin_id == r.super_admin_id,
  201. MonitoredDomain.is_active == True,
  202. ).limit(1)
  203. )
  204. domain_row = domain_result.scalar_one_or_none()
  205. domain_name = domain_row.domain if domain_row else None
  206. # 查询关联的联系人信息
  207. contact = None
  208. if domain_row:
  209. visitor_result = await db.execute(
  210. select(VisitorInfo).where(VisitorInfo.domain_id == domain_row.id)
  211. )
  212. visitor = visitor_result.scalar_one_or_none()
  213. if visitor:
  214. contact = {"name": visitor.name, "phone": visitor.phone, "email": visitor.email}
  215. items.append(LicenseResponse(
  216. id=r.id,
  217. super_admin_id=r.super_admin_id,
  218. super_admin_name=sa_name,
  219. license_key=r.license_key,
  220. expires_at=_to_str(r.expires_at),
  221. status=r.status,
  222. max_tenants=r.max_tenants,
  223. max_users_per_tenant=r.max_users_per_tenant,
  224. remark=r.remark,
  225. created_at=_to_str(r.created_at),
  226. updated_at=_to_str(r.updated_at),
  227. domain=domain_name,
  228. contact=contact,
  229. ))
  230. return LicenseListResponse(total=total, items=items)
  231. async def get_license_status(
  232. db: AsyncSession, license_id: int
  233. ) -> LicenseStatusResponse:
  234. """获取单个 License 的详细状态"""
  235. result = await db.execute(
  236. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  237. )
  238. lic = result.scalar_one_or_none()
  239. if not lic:
  240. raise ValueError("License 不存在")
  241. days_left = _calc_days_left(lic.expires_at)
  242. sms_status = None
  243. sms_reason = ""
  244. if lic.status == "active" and days_left <= 0:
  245. lic.status = "expired"
  246. await db.commit()
  247. sms_status, sms_reason = await _notify_on_expired(db, lic)
  248. sa_result = await db.execute(
  249. select(SuperAdmin).where(SuperAdmin.id == lic.super_admin_id)
  250. )
  251. sa = sa_result.scalar_one_or_none()
  252. return LicenseStatusResponse(
  253. id=lic.id,
  254. super_admin_id=lic.super_admin_id,
  255. super_admin_name=sa.remark or sa.username if sa else str(lic.super_admin_id),
  256. license_key=lic.license_key,
  257. expires_at=_to_str(lic.expires_at),
  258. status=lic.status,
  259. days_left=days_left,
  260. max_tenants=lic.max_tenants,
  261. max_users_per_tenant=lic.max_users_per_tenant,
  262. remark=lic.remark,
  263. sms_status=sms_status,
  264. )
  265. async def revoke_license(
  266. db: AsyncSession, license_id: int, background_tasks=None
  267. ) -> dict:
  268. """吊销 License。状态立即生效,短信后台发送。"""
  269. result = await db.execute(
  270. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  271. )
  272. lic = result.scalar_one_or_none()
  273. if not lic:
  274. raise ValueError("License 不存在")
  275. old_status = lic.status
  276. lic.status = "revoked"
  277. await db.commit()
  278. # 吊销即 valid 变为 false,不管原状态如何都发预警短信
  279. sms_status, sms_reason = await _notify_on_expired(db, lic)
  280. if sms_status == "pending" and background_tasks:
  281. background_tasks.add_task(_retry_send_sms, db, lic, "expired")
  282. return {"message": "License已吊销", "sms_status": sms_status, "sms_reason": sms_reason}
  283. async def update_license(
  284. db: AsyncSession, license_id: int, payload: LicenseUpdate, background_tasks=None
  285. ) -> dict:
  286. """更新 License 的 key 或过期时间。状态立即生效,短信后台发送。"""
  287. result = await db.execute(
  288. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  289. )
  290. lic = result.scalar_one_or_none()
  291. if not lic:
  292. raise ValueError("License 不存在")
  293. changed = False
  294. old_status = lic.status
  295. sms_status = "none"
  296. if payload.license_key is not None:
  297. lic.license_key = payload.license_key
  298. changed = True
  299. if payload.expires_at is not None:
  300. expires_at = datetime.fromisoformat(payload.expires_at)
  301. if expires_at.tzinfo is None:
  302. expires_at = expires_at.replace(tzinfo=CST)
  303. lic.expires_at = expires_at
  304. if _calc_days_left(lic.expires_at) > 0:
  305. lic.status = "active"
  306. else:
  307. lic.status = "expired"
  308. changed = True
  309. if not changed:
  310. raise ValueError("没有需要更新的字段")
  311. # 修改过期时间后已超出 7 天窗口,重置预警标记
  312. new_days_left = _calc_days_left(lic.expires_at) if payload.expires_at else None
  313. if new_days_left is not None and new_days_left > 7:
  314. lic.warning_sent = False
  315. lic.expired_sent = False
  316. elif old_status != "active" and lic.status == "active":
  317. lic.warning_sent = False
  318. lic.expired_sent = False
  319. elif new_days_left is not None and 0 < new_days_left <= 7 and not lic.warning_sent:
  320. # 修改后进入7天内,立即发送预警短信
  321. lic.warning_sent = True
  322. sms_status, sms_reason = await _notify_on_warning(db, lic, new_days_left)
  323. if sms_status == "pending" and background_tasks:
  324. background_tasks.add_task(_retry_send_sms, db, lic, "warning", new_days_left)
  325. await db.commit()
  326. return {"message": "License 已更新", "license_id": lic.id, "sms_status": sms_status, "sms_type": "warning", "sms_reason": sms_reason}
  327. await db.commit()
  328. # valid 变化时发送短信
  329. sms_reason = ""
  330. sms_status = "none"
  331. sms_type = None
  332. if old_status == "active" and lic.status != "active":
  333. sms_type = "expired"
  334. elif old_status != "active" and lic.status == "active":
  335. sms_type = "restored"
  336. if sms_type:
  337. if sms_type == "expired":
  338. sms_status, sms_reason = await _notify_on_expired(db, lic)
  339. else:
  340. sms_status, sms_reason = await _notify_on_restored(db, lic)
  341. if sms_status == "pending" and background_tasks:
  342. background_tasks.add_task(_retry_send_sms, db, lic, sms_type)
  343. return {"message": "License已更新", "license_id": lic.id, "sms_status": sms_status, "sms_type": sms_type, "sms_reason": sms_reason}
  344. async def restore_license(
  345. db: AsyncSession, license_id: int, background_tasks=None
  346. ) -> dict:
  347. """恢复已吊销的 License。状态立即生效,短信后台发送。"""
  348. result = await db.execute(
  349. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  350. )
  351. lic = result.scalar_one_or_none()
  352. if not lic:
  353. raise ValueError("License 不存在")
  354. if lic.status != "revoked":
  355. raise ValueError("仅可恢复已吊销的 License")
  356. lic.status = "active"
  357. lic.warning_sent = False
  358. lic.expired_sent = False
  359. await db.commit()
  360. sms_status, sms_reason = await _notify_on_restored(db, lic)
  361. if sms_status == "pending" and background_tasks:
  362. background_tasks.add_task(_retry_send_sms, db, lic, "restored")
  363. return {"message": "License已恢复", "sms_status": sms_status, "sms_reason": sms_reason}
  364. async def delete_license(
  365. db: AsyncSession, license_id: int
  366. ) -> dict:
  367. """删除 License 记录"""
  368. result = await db.execute(
  369. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  370. )
  371. lic = result.scalar_one_or_none()
  372. if not lic:
  373. raise ValueError("License 不存在")
  374. await db.delete(lic)
  375. await db.commit()
  376. return {"message": "License已删除"}
  377. async def check_license_by_referer(
  378. db: AsyncSession, referer: str | None
  379. ) -> dict:
  380. """通过 Referer 头匹配域名,检查对应超管的 License 状态。
  381. 流程:
  382. 1. 从 Referer 提取域名(host)
  383. 2. 在 MonitoredDomain 中匹配 domain 字段
  384. 3. 获取关联的 super_admin_id
  385. 4. 查询该超管当前 active 的 License 并返回状态
  386. """
  387. if not referer:
  388. return {
  389. "valid": False,
  390. "status": "unknown",
  391. "message": "缺少 Referer 头",
  392. }
  393. # 从 Referer URL 提取域名
  394. try:
  395. parsed = urlparse(referer)
  396. referer_host = parsed.hostname
  397. # 当 Referer 缺少 scheme(如 "127.0.0.1" 或 "example.com")时,
  398. # urlparse 会将其解析为 path 而非 netloc,导致 hostname 为 None。
  399. # 此时回退为使用原始字符串作为域名。
  400. if not referer_host:
  401. referer_host = referer.strip().rstrip("/")
  402. if not referer_host:
  403. return {
  404. "valid": False,
  405. "status": "unknown",
  406. "message": "Referer 格式无效",
  407. }
  408. # netloc 包含端口(如 127.0.0.1:8010),hostname 不包含
  409. referer_netloc = parsed.netloc if parsed.netloc else referer_host
  410. except Exception:
  411. return {
  412. "valid": False,
  413. "status": "unknown",
  414. "message": "Referer 解析失败",
  415. }
  416. # 匹配监控域名(支持多种存储格式:纯域名、带端口、带 http(s):// 前缀)
  417. result = await db.execute(
  418. select(MonitoredDomain).where(
  419. MonitoredDomain.is_active == True,
  420. MonitoredDomain.domain.in_([
  421. referer_netloc, # 例: 127.0.0.1:8010
  422. referer_host, # 例: 127.0.0.1
  423. f"http://{referer_netloc}", # 例: http://127.0.0.1:8010
  424. f"https://{referer_netloc}", # 例: https://127.0.0.1:8010
  425. ]),
  426. )
  427. )
  428. domain = result.scalars().first()
  429. if not domain or not domain.super_admin_id:
  430. return {
  431. "valid": False,
  432. "status": "unknown",
  433. "message": "域名未注册或无关联超管",
  434. }
  435. # 查询该超管最新的 license(不限 status,取最近创建的一条)
  436. license_result = await db.execute(
  437. select(SuperAdminLicense)
  438. .where(SuperAdminLicense.super_admin_id == domain.super_admin_id)
  439. .order_by(SuperAdminLicense.created_at.desc())
  440. )
  441. lic = license_result.scalars().first()
  442. if not lic:
  443. return {
  444. "valid": False,
  445. "status": "not_found",
  446. "message": "未找到有效 License",
  447. }
  448. days_left = _calc_days_left(lic.expires_at)
  449. if lic.status == "active" and days_left <= 0:
  450. lic.status = "expired"
  451. await db.commit()
  452. # 触发过期通知
  453. await _notify_on_expired(db, lic)
  454. sa_result = await db.execute(
  455. select(SuperAdmin).where(SuperAdmin.id == lic.super_admin_id)
  456. )
  457. sa = sa_result.scalar_one_or_none()
  458. # 余额检查:License 有效时额外检查超管余额
  459. license_valid = lic.status == "active"
  460. sa_balance = float(sa.balance or 0) if sa else 0
  461. sa_balance_ok = sa_balance > 0
  462. # 如果 License 有效但超管余额不足,整体 valid 为 false
  463. valid = license_valid and sa_balance_ok
  464. status = lic.status
  465. if license_valid and not sa_balance_ok:
  466. status = "insufficient_balance"
  467. return {
  468. "valid": valid,
  469. "status": status,
  470. "super_admin_name": sa.remark or sa.username if sa else str(lic.super_admin_id),
  471. "super_admin_id": lic.super_admin_id,
  472. "sa_balance": sa_balance,
  473. "sa_balance_ok": sa_balance_ok,
  474. "license_key": lic.license_key,
  475. "expires_at": _to_str(lic.expires_at),
  476. "days_left": days_left,
  477. "max_tenants": lic.max_tenants,
  478. "max_users_per_tenant": lic.max_users_per_tenant,
  479. "remark": lic.remark,
  480. }