license.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. import math
  2. from datetime import datetime, timezone
  3. from urllib.parse import urlparse
  4. from sqlalchemy import select
  5. from sqlalchemy.ext.asyncio import AsyncSession
  6. from app.models.license import SuperAdminLicense
  7. from app.models.monitoring import SuperAdmin
  8. from app.models.domain import MonitoredDomain
  9. from app.models.visitor import VisitorInfo
  10. from app.schemas.license import (
  11. LicenseCreate,
  12. LicenseResponse,
  13. LicenseStatusResponse,
  14. LicenseUpdate,
  15. SuperAdminOption,
  16. LicenseListResponse,
  17. )
  18. from app.services.sms import send_license_expired, send_license_restored
  19. import logging
  20. logger = logging.getLogger(__name__)
  21. def _to_str(val) -> str:
  22. if val is None:
  23. return ""
  24. if isinstance(val, datetime):
  25. return val.isoformat()
  26. return str(val)
  27. async def _get_contact_for_license(db: AsyncSession, lic: SuperAdminLicense) -> tuple[str | None, str | None]:
  28. """根据 License 获取关联的联系人 phone 和 company 名称。返回 (phone, company)"""
  29. domain_result = await db.execute(
  30. select(MonitoredDomain).where(
  31. MonitoredDomain.super_admin_id == lic.super_admin_id,
  32. MonitoredDomain.is_active == True,
  33. ).limit(1)
  34. )
  35. domain = domain_result.scalar_one_or_none()
  36. if not domain:
  37. return None, None
  38. sa_result = await db.execute(
  39. select(SuperAdmin).where(SuperAdmin.id == lic.super_admin_id)
  40. )
  41. sa = sa_result.scalar_one_or_none()
  42. company = sa.remark or sa.username if sa else str(lic.super_admin_id)
  43. visitor_result = await db.execute(
  44. select(VisitorInfo).where(VisitorInfo.domain_id == domain.id)
  45. )
  46. visitor = visitor_result.scalar_one_or_none()
  47. if not visitor or not visitor.phone:
  48. return None, company
  49. return visitor.phone, company
  50. async def _notify_on_expired(db: AsyncSession, lic: SuperAdminLicense):
  51. """License 过期时发送短信通知(静默失败,不阻断主流程)"""
  52. try:
  53. phone, company = await _get_contact_for_license(db, lic)
  54. if phone:
  55. await send_license_expired(phone, company)
  56. except Exception:
  57. logger.exception("发送过期短信失败,license_id=%d", lic.id)
  58. async def _notify_on_restored(db: AsyncSession, lic: SuperAdminLicense):
  59. """License 恢复时发送短信通知(静默失败,不阻断主流程)"""
  60. try:
  61. phone, company = await _get_contact_for_license(db, lic)
  62. if phone:
  63. await send_license_restored(phone, company)
  64. except Exception:
  65. logger.exception("发送恢复短信失败,license_id=%d", lic.id)
  66. def _calc_days_left(expires_at: datetime) -> int:
  67. now = datetime.now(timezone.utc) if expires_at.tzinfo else datetime.now()
  68. delta = expires_at - now
  69. return math.ceil(delta.total_seconds() / (60 * 60 * 24))
  70. async def get_super_admins(db: AsyncSession) -> list[SuperAdminOption]:
  71. """获取所有超级管理员,供下拉选择"""
  72. result = await db.execute(select(SuperAdmin).order_by(SuperAdmin.id))
  73. sas = result.scalars().all()
  74. return [SuperAdminOption(id=sa.id, username=sa.username, nickname=sa.nickname, remark=sa.remark) for sa in sas]
  75. async def create_license(
  76. db: AsyncSession, payload: LicenseCreate
  77. ) -> dict:
  78. """创建或更新 License。同一超管只保留一个 active 的 License,已存在则更新。"""
  79. existing = await db.execute(
  80. select(SuperAdminLicense).where(
  81. SuperAdminLicense.super_admin_id == payload.super_admin_id,
  82. SuperAdminLicense.status == "active",
  83. )
  84. )
  85. exist_lic = existing.scalar_one_or_none()
  86. expires_at = datetime.fromisoformat(payload.expires_at)
  87. if exist_lic:
  88. exist_lic.license_key = payload.license_key
  89. exist_lic.expires_at = expires_at
  90. exist_lic.max_tenants = payload.max_tenants
  91. exist_lic.max_users_per_tenant = payload.max_users_per_tenant
  92. exist_lic.remark = payload.remark
  93. await db.flush()
  94. await db.commit()
  95. return {"message": "License已更新", "license_id": exist_lic.id}
  96. new_lic = SuperAdminLicense(
  97. super_admin_id=payload.super_admin_id,
  98. license_key=payload.license_key,
  99. expires_at=expires_at,
  100. status="active",
  101. max_tenants=payload.max_tenants,
  102. max_users_per_tenant=payload.max_users_per_tenant,
  103. remark=payload.remark,
  104. )
  105. db.add(new_lic)
  106. await db.commit()
  107. await db.refresh(new_lic)
  108. return {"message": "License已创建", "license_id": new_lic.id}
  109. async def list_licenses(
  110. db: AsyncSession,
  111. super_admin_id: int | None = None,
  112. status: str | None = None,
  113. page: int = 1,
  114. size: int = 20,
  115. ) -> LicenseListResponse:
  116. """查询 License 列表"""
  117. stmt = select(SuperAdminLicense)
  118. if super_admin_id is not None:
  119. stmt = stmt.where(SuperAdminLicense.super_admin_id == super_admin_id)
  120. if status:
  121. stmt = stmt.where(SuperAdminLicense.status == status)
  122. stmt = stmt.order_by(SuperAdminLicense.created_at.desc())
  123. total_result = await db.execute(stmt)
  124. all_rows = total_result.scalars().all()
  125. total = len(all_rows)
  126. offset = (page - 1) * size
  127. paged = all_rows[offset: offset + size]
  128. items = []
  129. for r in paged:
  130. sa_result = await db.execute(
  131. select(SuperAdmin).where(SuperAdmin.id == r.super_admin_id)
  132. )
  133. sa = sa_result.scalar_one_or_none()
  134. sa_name = sa.remark or sa.username if sa else str(r.super_admin_id)
  135. # 查询关联的域名
  136. domain_result = await db.execute(
  137. select(MonitoredDomain).where(
  138. MonitoredDomain.super_admin_id == r.super_admin_id,
  139. MonitoredDomain.is_active == True,
  140. ).limit(1)
  141. )
  142. domain_row = domain_result.scalar_one_or_none()
  143. domain_name = domain_row.domain if domain_row else None
  144. # 查询关联的联系人信息
  145. contact = None
  146. if domain_row:
  147. visitor_result = await db.execute(
  148. select(VisitorInfo).where(VisitorInfo.domain_id == domain_row.id)
  149. )
  150. visitor = visitor_result.scalar_one_or_none()
  151. if visitor:
  152. contact = {"name": visitor.name, "phone": visitor.phone, "email": visitor.email}
  153. items.append(LicenseResponse(
  154. id=r.id,
  155. super_admin_id=r.super_admin_id,
  156. super_admin_name=sa_name,
  157. license_key=r.license_key,
  158. expires_at=_to_str(r.expires_at),
  159. status=r.status,
  160. max_tenants=r.max_tenants,
  161. max_users_per_tenant=r.max_users_per_tenant,
  162. remark=r.remark,
  163. created_at=_to_str(r.created_at),
  164. updated_at=_to_str(r.updated_at),
  165. domain=domain_name,
  166. contact=contact,
  167. ))
  168. return LicenseListResponse(total=total, items=items)
  169. async def get_license_status(
  170. db: AsyncSession, license_id: int
  171. ) -> LicenseStatusResponse:
  172. """获取单个 License 的详细状态"""
  173. result = await db.execute(
  174. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  175. )
  176. lic = result.scalar_one_or_none()
  177. if not lic:
  178. raise ValueError("License 不存在")
  179. days_left = _calc_days_left(lic.expires_at)
  180. if lic.status == "active" and days_left <= 0:
  181. lic.status = "expired"
  182. await db.commit()
  183. # 触发过期通知
  184. await _notify_on_expired(db, lic)
  185. sa_result = await db.execute(
  186. select(SuperAdmin).where(SuperAdmin.id == lic.super_admin_id)
  187. )
  188. sa = sa_result.scalar_one_or_none()
  189. return LicenseStatusResponse(
  190. id=lic.id,
  191. super_admin_id=lic.super_admin_id,
  192. super_admin_name=sa.remark or sa.username if sa else str(lic.super_admin_id),
  193. license_key=lic.license_key,
  194. expires_at=_to_str(lic.expires_at),
  195. status=lic.status,
  196. days_left=days_left,
  197. max_tenants=lic.max_tenants,
  198. max_users_per_tenant=lic.max_users_per_tenant,
  199. remark=lic.remark,
  200. )
  201. async def revoke_license(
  202. db: AsyncSession, license_id: int
  203. ) -> dict:
  204. """吊销 License"""
  205. result = await db.execute(
  206. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  207. )
  208. lic = result.scalar_one_or_none()
  209. if not lic:
  210. raise ValueError("License 不存在")
  211. lic.status = "revoked"
  212. await db.commit()
  213. return {"message": "License已吊销"}
  214. async def update_license(
  215. db: AsyncSession, license_id: int, payload: LicenseUpdate
  216. ) -> dict:
  217. """更新 License 的 key 或过期时间"""
  218. result = await db.execute(
  219. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  220. )
  221. lic = result.scalar_one_or_none()
  222. if not lic:
  223. raise ValueError("License 不存在")
  224. changed = False
  225. if payload.license_key is not None:
  226. lic.license_key = payload.license_key
  227. changed = True
  228. if payload.expires_at is not None:
  229. old_status = lic.status
  230. lic.expires_at = datetime.fromisoformat(payload.expires_at)
  231. if _calc_days_left(lic.expires_at) > 0:
  232. lic.status = "active"
  233. else:
  234. lic.status = "expired"
  235. changed = True
  236. # 状态变化时发送短信通知
  237. if old_status == "active" and lic.status == "expired":
  238. await _notify_on_expired(db, lic)
  239. elif old_status == "expired" and lic.status == "active":
  240. await _notify_on_restored(db, lic)
  241. if not changed:
  242. raise ValueError("没有需要更新的字段")
  243. await db.commit()
  244. return {"message": "License已更新", "license_id": lic.id}
  245. async def restore_license(
  246. db: AsyncSession, license_id: int
  247. ) -> dict:
  248. """恢复已吊销的 License"""
  249. result = await db.execute(
  250. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  251. )
  252. lic = result.scalar_one_or_none()
  253. if not lic:
  254. raise ValueError("License 不存在")
  255. if lic.status != "revoked":
  256. raise ValueError("仅可恢复已吊销的 License")
  257. lic.status = "active"
  258. await db.commit()
  259. # 触发恢复通知
  260. await _notify_on_restored(db, lic)
  261. return {"message": "License已恢复"}
  262. async def delete_license(
  263. db: AsyncSession, license_id: int
  264. ) -> dict:
  265. """删除 License 记录"""
  266. result = await db.execute(
  267. select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
  268. )
  269. lic = result.scalar_one_or_none()
  270. if not lic:
  271. raise ValueError("License 不存在")
  272. await db.delete(lic)
  273. await db.commit()
  274. return {"message": "License已删除"}
  275. async def check_license_by_referer(
  276. db: AsyncSession, referer: str | None
  277. ) -> dict:
  278. """通过 Referer 头匹配域名,检查对应超管的 License 状态。
  279. 流程:
  280. 1. 从 Referer 提取域名(host)
  281. 2. 在 MonitoredDomain 中匹配 domain 字段
  282. 3. 获取关联的 super_admin_id
  283. 4. 查询该超管当前 active 的 License 并返回状态
  284. """
  285. if not referer:
  286. return {
  287. "valid": False,
  288. "status": "unknown",
  289. "message": "缺少 Referer 头",
  290. }
  291. # 从 Referer URL 提取域名
  292. try:
  293. parsed = urlparse(referer)
  294. referer_host = parsed.hostname
  295. # 当 Referer 缺少 scheme(如 "127.0.0.1" 或 "example.com")时,
  296. # urlparse 会将其解析为 path 而非 netloc,导致 hostname 为 None。
  297. # 此时回退为使用原始字符串作为域名。
  298. if not referer_host:
  299. referer_host = referer.strip().rstrip("/")
  300. if not referer_host:
  301. return {
  302. "valid": False,
  303. "status": "unknown",
  304. "message": "Referer 格式无效",
  305. }
  306. # netloc 包含端口(如 127.0.0.1:8010),hostname 不包含
  307. referer_netloc = parsed.netloc if parsed.netloc else referer_host
  308. except Exception:
  309. return {
  310. "valid": False,
  311. "status": "unknown",
  312. "message": "Referer 解析失败",
  313. }
  314. # 匹配监控域名(支持多种存储格式:纯域名、带端口、带 http(s):// 前缀)
  315. result = await db.execute(
  316. select(MonitoredDomain).where(
  317. MonitoredDomain.is_active == True,
  318. MonitoredDomain.domain.in_([
  319. referer_netloc, # 例: 127.0.0.1:8010
  320. referer_host, # 例: 127.0.0.1
  321. f"http://{referer_netloc}", # 例: http://127.0.0.1:8010
  322. f"https://{referer_netloc}", # 例: https://127.0.0.1:8010
  323. ]),
  324. )
  325. )
  326. domain = result.scalars().first()
  327. if not domain or not domain.super_admin_id:
  328. return {
  329. "valid": False,
  330. "status": "unknown",
  331. "message": "域名未注册或无关联超管",
  332. }
  333. # 查询该超管最新的 license(不限 status,取最近创建的一条)
  334. license_result = await db.execute(
  335. select(SuperAdminLicense)
  336. .where(SuperAdminLicense.super_admin_id == domain.super_admin_id)
  337. .order_by(SuperAdminLicense.created_at.desc())
  338. )
  339. lic = license_result.scalars().first()
  340. if not lic:
  341. return {
  342. "valid": False,
  343. "status": "not_found",
  344. "message": "未找到有效 License",
  345. }
  346. days_left = _calc_days_left(lic.expires_at)
  347. if lic.status == "active" and days_left <= 0:
  348. lic.status = "expired"
  349. await db.commit()
  350. # 触发过期通知
  351. await _notify_on_expired(db, lic)
  352. sa_result = await db.execute(
  353. select(SuperAdmin).where(SuperAdmin.id == lic.super_admin_id)
  354. )
  355. sa = sa_result.scalar_one_or_none()
  356. return {
  357. "valid": lic.status == "active",
  358. "status": lic.status,
  359. "super_admin_name": sa.remark or sa.username if sa else str(lic.super_admin_id),
  360. "license_key": lic.license_key,
  361. "expires_at": _to_str(lic.expires_at),
  362. "days_left": days_left,
  363. "max_tenants": lic.max_tenants,
  364. "max_users_per_tenant": lic.max_users_per_tenant,
  365. "remark": lic.remark,
  366. }