license.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. import math
  2. from datetime import datetime, timezone, timedelta
  3. from urllib.parse import urlparse
  4. CST = timezone(timedelta(hours=8)) # 东八区
  5. from sqlalchemy import select
  6. from sqlalchemy.ext.asyncio import AsyncSession
  7. from app.models.license import SuperAdminLicense
  8. from app.models.monitoring import SuperAdmin
  9. from app.models.domain import MonitoredDomain
  10. from app.models.visitor import VisitorInfo
  11. from app.schemas.license import (
  12. LicenseCreate,
  13. LicenseResponse,
  14. LicenseStatusResponse,
  15. LicenseUpdate,
  16. SuperAdminOption,
  17. LicenseListResponse,
  18. )
  19. from app.services.sms import send_license_expired, send_license_restored, send_license_warning
  20. from app.redis import get_redis
  21. import logging
  22. logger = logging.getLogger(__name__)
  23. def _to_str(val) -> str:
  24. if val is None:
  25. return ""
  26. if isinstance(val, datetime):
  27. return val.astimezone(CST).isoformat()
  28. return str(val)
  29. async def _get_contact_for_license(db: AsyncSession, lic: SuperAdminLicense) -> tuple[str | None, str | None]:
  30. """根据 License 获取关联的联系人 phone 和 company 名称。返回 (phone, company)"""
  31. domain_result = await db.execute(
  32. select(MonitoredDomain).where(
  33. MonitoredDomain.super_admin_id == lic.super_admin_id,
  34. MonitoredDomain.is_active == True,
  35. ).limit(1)
  36. )
  37. domain = domain_result.scalar_one_or_none()
  38. if not domain:
  39. return None, None
  40. sa_result = await db.execute(
  41. select(SuperAdmin).where(SuperAdmin.id == lic.super_admin_id)
  42. )
  43. sa = sa_result.scalar_one_or_none()
  44. company = sa.remark or sa.username if sa else str(lic.super_admin_id)
  45. visitor_result = await db.execute(
  46. select(VisitorInfo).where(VisitorInfo.domain_id == domain.id)
  47. )
  48. visitor = visitor_result.scalar_one_or_none()
  49. if not visitor or not visitor.phone:
  50. return None, company
  51. return visitor.phone, company
  52. async def _notify_on_expired(db: AsyncSession, lic: SuperAdminLicense) -> str:
  53. """License valid 从 true 变为 false 时发送过期短信。返回: sent | skipped | failed"""
  54. try:
  55. r = await get_redis()
  56. if await r.get(f"sms:expired_sent:{lic.id}"):
  57. logger.info("License #%d 已发送过过期短信,跳过", lic.id)
  58. return "skipped"
  59. phone, company = await _get_contact_for_license(db, lic)
  60. if not phone:
  61. logger.warning("License #%d 无联系人手机号,跳过预警短信", lic.id)
  62. return "skipped"
  63. logger.info("发送 License 预警短信(过期): license_id=%d, phone=%s, company=%s", lic.id, phone, company)
  64. ok, reason = await send_license_expired(phone, company)
  65. if ok:
  66. logger.info("License 预警短信发送成功: license_id=%d", lic.id)
  67. await r.setex(f"sms:expired_sent:{lic.id}", 86400 * 30, "1") # 30 天过期
  68. return "sent"
  69. logger.error("License 预警短信发送失败: license_id=%d, 原因=%s", lic.id, reason)
  70. return "failed"
  71. except Exception:
  72. logger.exception("发送预警短信异常,license_id=%d", lic.id)
  73. return "failed"
  74. async def _notify_on_restored(db: AsyncSession, lic: SuperAdminLicense) -> str:
  75. """License valid 从 false 变为 true 时发送恢复短信。返回: sent | skipped | failed"""
  76. try:
  77. phone, company = await _get_contact_for_license(db, lic)
  78. if not phone:
  79. logger.warning("License #%d 无联系人手机号,跳过恢复短信", lic.id)
  80. return "skipped"
  81. logger.info("发送 License 恢复短信: license_id=%d, phone=%s, company=%s", lic.id, phone, company)
  82. ok, reason = await send_license_restored(phone, company)
  83. if ok:
  84. logger.info("License 恢复短信发送成功: license_id=%d", lic.id)
  85. r = await get_redis()
  86. await r.delete(f"sms:expired_sent:{lic.id}")
  87. await r.delete(f"sms:warning_sent:{lic.id}")
  88. return "sent"
  89. logger.error("License 恢复短信发送失败: license_id=%d, 原因=%s", lic.id, reason)
  90. return "failed"
  91. except Exception:
  92. logger.exception("发送恢复短信异常,license_id=%d", lic.id)
  93. return "failed"
  94. async def _notify_on_warning(db: AsyncSession, lic: SuperAdminLicense, days_left: int) -> str:
  95. """License 即将过期(剩余 7 天)预警短信。返回: sent | skipped | failed"""
  96. try:
  97. r = await get_redis()
  98. if await r.get(f"sms:warning_sent:{lic.id}"):
  99. logger.info("License #%d 已发送过预警短信,跳过", lic.id)
  100. return "skipped"
  101. phone, company = await _get_contact_for_license(db, lic)
  102. if not phone:
  103. logger.warning("License #%d 无联系人手机号,跳过即将过期预警短信", lic.id)
  104. return "skipped"
  105. logger.info("发送 License 即将过期预警: license_id=%d, phone=%s, company=%s, days_left=%d", lic.id, phone, company, days_left)
  106. ok, reason = await send_license_warning(phone, company, days_left)
  107. if ok:
  108. logger.info("License 即将过期预警短信发送成功: license_id=%d", lic.id)
  109. await r.setex(f"sms:warning_sent:{lic.id}", 86400 * 30, "1")
  110. return "sent"
  111. logger.error("License 即将过期预警短信发送失败: license_id=%d, 原因=%s", lic.id, reason)
  112. return "failed"
  113. except Exception:
  114. logger.exception("发送即将过期预警短信异常,license_id=%d", lic.id)
  115. return "failed"
  116. def _calc_days_left(expires_at: datetime) -> int:
  117. now = datetime.now(timezone.utc) if expires_at.tzinfo else datetime.now()
  118. delta = expires_at - now
  119. return math.ceil(delta.total_seconds() / (60 * 60 * 24))
  120. async def get_super_admins(db: AsyncSession) -> list[SuperAdminOption]:
  121. """获取所有超级管理员,供下拉选择"""
  122. result = await db.execute(select(SuperAdmin).order_by(SuperAdmin.id))
  123. sas = result.scalars().all()
  124. return [SuperAdminOption(id=sa.id, username=sa.username, nickname=sa.nickname, remark=sa.remark) for sa in sas]
  125. async def create_license(
  126. db: AsyncSession, payload: LicenseCreate
  127. ) -> dict:
  128. """创建或更新 License。同一超管只保留一个 active 的 License,已存在则更新。"""
  129. existing = await db.execute(
  130. select(SuperAdminLicense).where(
  131. SuperAdminLicense.super_admin_id == payload.super_admin_id,
  132. SuperAdminLicense.status == "active",
  133. )
  134. )
  135. exist_lic = existing.scalar_one_or_none()
  136. expires_at = datetime.fromisoformat(payload.expires_at)
  137. if expires_at.tzinfo is None:
  138. expires_at = expires_at.replace(tzinfo=CST)
  139. if exist_lic:
  140. exist_lic.license_key = payload.license_key
  141. exist_lic.expires_at = expires_at
  142. exist_lic.max_tenants = payload.max_tenants
  143. exist_lic.max_users_per_tenant = payload.max_users_per_tenant
  144. exist_lic.remark = payload.remark
  145. await db.flush()
  146. await db.commit()
  147. return {"message": "License已更新", "license_id": exist_lic.id}
  148. new_lic = SuperAdminLicense(
  149. super_admin_id=payload.super_admin_id,
  150. license_key=payload.license_key,
  151. expires_at=expires_at,
  152. status="active",
  153. max_tenants=payload.max_tenants,
  154. max_users_per_tenant=payload.max_users_per_tenant,
  155. remark=payload.remark,
  156. )
  157. db.add(new_lic)
  158. await db.commit()
  159. await db.refresh(new_lic)
  160. return {"message": "License已创建", "license_id": new_lic.id}
  161. async def list_licenses(
  162. db: AsyncSession,
  163. super_admin_id: int | None = None,
  164. status: str | None = None,
  165. page: int = 1,
  166. size: int = 20,
  167. ) -> LicenseListResponse:
  168. """查询 License 列表"""
  169. stmt = select(SuperAdminLicense)
  170. if super_admin_id is not None:
  171. stmt = stmt.where(SuperAdminLicense.super_admin_id == super_admin_id)
  172. if status:
  173. stmt = stmt.where(SuperAdminLicense.status == status)
  174. stmt = stmt.order_by(SuperAdminLicense.created_at.desc())
  175. total_result = await db.execute(stmt)
  176. all_rows = total_result.scalars().all()
  177. total = len(all_rows)
  178. offset = (page - 1) * size
  179. paged = all_rows[offset: offset + size]
  180. items = []
  181. for r in paged:
  182. sa_result = await db.execute(
  183. select(SuperAdmin).where(SuperAdmin.id == r.super_admin_id)
  184. )
  185. sa = sa_result.scalar_one_or_none()
  186. sa_name = sa.remark or sa.username if sa else str(r.super_admin_id)
  187. # 查询关联的域名
  188. domain_result = await db.execute(
  189. select(MonitoredDomain).where(
  190. MonitoredDomain.super_admin_id == r.super_admin_id,
  191. MonitoredDomain.is_active == True,
  192. ).limit(1)
  193. )
  194. domain_row = domain_result.scalar_one_or_none()
  195. domain_name = domain_row.domain if domain_row else None
  196. # 查询关联的联系人信息
  197. contact = None
  198. if domain_row:
  199. visitor_result = await db.execute(
  200. select(VisitorInfo).where(VisitorInfo.domain_id == domain_row.id)
  201. )
  202. visitor = visitor_result.scalar_one_or_none()
  203. if visitor:
  204. contact = {"name": visitor.name, "phone": visitor.phone, "email": visitor.email}
  205. items.append(LicenseResponse(
  206. id=r.id,
  207. super_admin_id=r.super_admin_id,
  208. super_admin_name=sa_name,
  209. license_key=r.license_key,
  210. expires_at=_to_str(r.expires_at),
  211. status=r.status,
  212. max_tenants=r.max_tenants,
  213. max_users_per_tenant=r.max_users_per_tenant,
  214. remark=r.remark,
  215. created_at=_to_str(r.created_at),
  216. updated_at=_to_str(r.updated_at),
  217. domain=domain_name,
  218. contact=contact,
  219. ))
  220. return LicenseListResponse(total=total, items=items)
  221. async def get_license_status(
  222. db: AsyncSession, license_id: int
  223. ) -> LicenseStatusResponse:
  224. """获取单个 License 的详细状态"""
  225. result = await db.execute(
  226. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  227. )
  228. lic = result.scalar_one_or_none()
  229. if not lic:
  230. raise ValueError("License 不存在")
  231. days_left = _calc_days_left(lic.expires_at)
  232. sms_status = None
  233. if lic.status == "active" and days_left <= 0:
  234. lic.status = "expired"
  235. await db.commit()
  236. sms_status = await _notify_on_expired(db, lic)
  237. sa_result = await db.execute(
  238. select(SuperAdmin).where(SuperAdmin.id == lic.super_admin_id)
  239. )
  240. sa = sa_result.scalar_one_or_none()
  241. return LicenseStatusResponse(
  242. id=lic.id,
  243. super_admin_id=lic.super_admin_id,
  244. super_admin_name=sa.remark or sa.username if sa else str(lic.super_admin_id),
  245. license_key=lic.license_key,
  246. expires_at=_to_str(lic.expires_at),
  247. status=lic.status,
  248. days_left=days_left,
  249. max_tenants=lic.max_tenants,
  250. max_users_per_tenant=lic.max_users_per_tenant,
  251. remark=lic.remark,
  252. sms_status=sms_status,
  253. )
  254. async def revoke_license(
  255. db: AsyncSession, license_id: int
  256. ) -> dict:
  257. """吊销 License"""
  258. result = await db.execute(
  259. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  260. )
  261. lic = result.scalar_one_or_none()
  262. if not lic:
  263. raise ValueError("License 不存在")
  264. old_status = lic.status
  265. lic.status = "revoked"
  266. await db.commit()
  267. # 吊销即 valid 变为 false,不管原状态如何都发预警短信
  268. sms_status = await _notify_on_expired(db, lic)
  269. return {"message": "License已吊销", "sms_status": sms_status}
  270. async def update_license(
  271. db: AsyncSession, license_id: int, payload: LicenseUpdate
  272. ) -> dict:
  273. """更新 License 的 key 或过期时间"""
  274. result = await db.execute(
  275. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  276. )
  277. lic = result.scalar_one_or_none()
  278. if not lic:
  279. raise ValueError("License 不存在")
  280. changed = False
  281. old_status = lic.status
  282. sms_status = "none"
  283. if payload.license_key is not None:
  284. lic.license_key = payload.license_key
  285. changed = True
  286. if payload.expires_at is not None:
  287. expires_at = datetime.fromisoformat(payload.expires_at)
  288. if expires_at.tzinfo is None:
  289. expires_at = expires_at.replace(tzinfo=CST)
  290. lic.expires_at = expires_at
  291. if _calc_days_left(lic.expires_at) > 0:
  292. lic.status = "active"
  293. else:
  294. lic.status = "expired"
  295. changed = True
  296. if not changed:
  297. raise ValueError("没有需要更新的字段")
  298. # 修改过期时间后已超出 7 天窗口,重置预警标记
  299. if payload.expires_at is not None and _calc_days_left(lic.expires_at) > 7:
  300. r = await get_redis()
  301. await r.delete(f"sms:warning_sent:{lic.id}")
  302. await r.delete(f"sms:expired_sent:{lic.id}")
  303. elif old_status != "active" and lic.status == "active":
  304. r = await get_redis()
  305. await r.delete(f"sms:warning_sent:{lic.id}")
  306. await r.delete(f"sms:expired_sent:{lic.id}")
  307. await db.commit()
  308. # valid 变化时发送短信
  309. if old_status == "active" and lic.status != "active":
  310. sms_status = await _notify_on_expired(db, lic)
  311. elif old_status != "active" and lic.status == "active":
  312. sms_status = await _notify_on_restored(db, lic)
  313. return {"message": "License已更新", "license_id": lic.id, "sms_status": sms_status}
  314. async def restore_license(
  315. db: AsyncSession, license_id: int
  316. ) -> dict:
  317. """恢复已吊销的 License"""
  318. result = await db.execute(
  319. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  320. )
  321. lic = result.scalar_one_or_none()
  322. if not lic:
  323. raise ValueError("License 不存在")
  324. if lic.status != "revoked":
  325. raise ValueError("仅可恢复已吊销的 License")
  326. lic.status = "active"
  327. await db.commit()
  328. # 清除 Redis 去重标记
  329. r = await get_redis()
  330. await r.delete(f"sms:warning_sent:{lic.id}")
  331. await r.delete(f"sms:expired_sent:{lic.id}")
  332. sms_status = await _notify_on_restored(db, lic)
  333. return {"message": "License已恢复", "sms_status": sms_status}
  334. async def delete_license(
  335. db: AsyncSession, license_id: int
  336. ) -> dict:
  337. """删除 License 记录"""
  338. result = await db.execute(
  339. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  340. )
  341. lic = result.scalar_one_or_none()
  342. if not lic:
  343. raise ValueError("License 不存在")
  344. await db.delete(lic)
  345. await db.commit()
  346. return {"message": "License已删除"}
  347. async def check_license_by_referer(
  348. db: AsyncSession, referer: str | None
  349. ) -> dict:
  350. """通过 Referer 头匹配域名,检查对应超管的 License 状态。
  351. 流程:
  352. 1. 从 Referer 提取域名(host)
  353. 2. 在 MonitoredDomain 中匹配 domain 字段
  354. 3. 获取关联的 super_admin_id
  355. 4. 查询该超管当前 active 的 License 并返回状态
  356. """
  357. if not referer:
  358. return {
  359. "valid": False,
  360. "status": "unknown",
  361. "message": "缺少 Referer 头",
  362. }
  363. # 从 Referer URL 提取域名
  364. try:
  365. parsed = urlparse(referer)
  366. referer_host = parsed.hostname
  367. # 当 Referer 缺少 scheme(如 "127.0.0.1" 或 "example.com")时,
  368. # urlparse 会将其解析为 path 而非 netloc,导致 hostname 为 None。
  369. # 此时回退为使用原始字符串作为域名。
  370. if not referer_host:
  371. referer_host = referer.strip().rstrip("/")
  372. if not referer_host:
  373. return {
  374. "valid": False,
  375. "status": "unknown",
  376. "message": "Referer 格式无效",
  377. }
  378. # netloc 包含端口(如 127.0.0.1:8010),hostname 不包含
  379. referer_netloc = parsed.netloc if parsed.netloc else referer_host
  380. except Exception:
  381. return {
  382. "valid": False,
  383. "status": "unknown",
  384. "message": "Referer 解析失败",
  385. }
  386. # 匹配监控域名(支持多种存储格式:纯域名、带端口、带 http(s):// 前缀)
  387. result = await db.execute(
  388. select(MonitoredDomain).where(
  389. MonitoredDomain.is_active == True,
  390. MonitoredDomain.domain.in_([
  391. referer_netloc, # 例: 127.0.0.1:8010
  392. referer_host, # 例: 127.0.0.1
  393. f"http://{referer_netloc}", # 例: http://127.0.0.1:8010
  394. f"https://{referer_netloc}", # 例: https://127.0.0.1:8010
  395. ]),
  396. )
  397. )
  398. domain = result.scalars().first()
  399. if not domain or not domain.super_admin_id:
  400. return {
  401. "valid": False,
  402. "status": "unknown",
  403. "message": "域名未注册或无关联超管",
  404. }
  405. # 查询该超管最新的 license(不限 status,取最近创建的一条)
  406. license_result = await db.execute(
  407. select(SuperAdminLicense)
  408. .where(SuperAdminLicense.super_admin_id == domain.super_admin_id)
  409. .order_by(SuperAdminLicense.created_at.desc())
  410. )
  411. lic = license_result.scalars().first()
  412. if not lic:
  413. return {
  414. "valid": False,
  415. "status": "not_found",
  416. "message": "未找到有效 License",
  417. }
  418. days_left = _calc_days_left(lic.expires_at)
  419. if lic.status == "active" and days_left <= 0:
  420. lic.status = "expired"
  421. await db.commit()
  422. # 触发过期通知
  423. await _notify_on_expired(db, lic)
  424. sa_result = await db.execute(
  425. select(SuperAdmin).where(SuperAdmin.id == lic.super_admin_id)
  426. )
  427. sa = sa_result.scalar_one_or_none()
  428. return {
  429. "valid": lic.status == "active",
  430. "status": lic.status,
  431. "super_admin_name": sa.remark or sa.username if sa else str(lic.super_admin_id),
  432. "license_key": lic.license_key,
  433. "expires_at": _to_str(lic.expires_at),
  434. "days_left": days_left,
  435. "max_tenants": lic.max_tenants,
  436. "max_users_per_tenant": lic.max_users_per_tenant,
  437. "remark": lic.remark,
  438. }