| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495 |
# License service layer: CRUD for SuperAdminLicense records plus SMS
# notifications (expired / restored / expiring-soon) when validity changes.
import math
from datetime import datetime, timezone, timedelta
from urllib.parse import urlparse
CST = timezone(timedelta(hours=8))  # UTC+8 (China Standard Time); attached to naive input datetimes and used when rendering timestamps
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.license import SuperAdminLicense
from app.models.monitoring import SuperAdmin
from app.models.domain import MonitoredDomain
from app.models.visitor import VisitorInfo
from app.schemas.license import (
    LicenseCreate,
    LicenseResponse,
    LicenseStatusResponse,
    LicenseUpdate,
    SuperAdminOption,
    LicenseListResponse,
)
from app.services.sms import send_license_expired, send_license_restored, send_license_warning
from app.redis import get_redis
import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _to_str(val) -> str:
    """Render a value as a string for API responses.

    ``None`` becomes ``""``. A ``datetime`` is converted to CST (UTC+8) and
    rendered as ISO-8601; note that ``astimezone`` interprets a *naive*
    datetime as server-local time. Anything else goes through ``str()``.
    """
    if isinstance(val, datetime):
        return val.astimezone(CST).isoformat()
    return "" if val is None else str(val)
async def _get_contact_for_license(db: AsyncSession, lic: SuperAdminLicense) -> tuple[str | None, str | None]:
    """Resolve the SMS contact for a license.

    Lookup chain: license -> first active ``MonitoredDomain`` of the same
    super admin -> ``VisitorInfo`` bound to that domain.

    Returns:
        ``(phone, company)``:
        - ``(None, None)`` when the super admin has no active domain;
        - ``(None, company)`` when there is no visitor or its phone is empty;
        - ``(phone, company)`` otherwise.
        ``company`` is the super admin's remark, else its username, else the
        raw ``super_admin_id`` as a string when the admin row is missing.
    """
    owner_id = lic.super_admin_id

    domain_stmt = select(MonitoredDomain).where(
        MonitoredDomain.super_admin_id == owner_id,
        MonitoredDomain.is_active == True,  # noqa: E712 - SQL expression, not a Python comparison
    ).limit(1)
    active_domain = (await db.execute(domain_stmt)).scalar_one_or_none()
    if not active_domain:
        return None, None

    admin_stmt = select(SuperAdmin).where(SuperAdmin.id == owner_id)
    admin = (await db.execute(admin_stmt)).scalar_one_or_none()
    if admin:
        company = admin.remark or admin.username
    else:
        company = str(owner_id)

    visitor_stmt = select(VisitorInfo).where(VisitorInfo.domain_id == active_domain.id)
    visitor = (await db.execute(visitor_stmt)).scalar_one_or_none()
    phone = visitor.phone if visitor and visitor.phone else None
    return phone, company
async def _notify_on_expired(db: AsyncSession, lic: SuperAdminLicense) -> str:
    """Send the "license expired" SMS after a license stops being valid.

    Deduplicated through the Redis flag ``sms:expired_sent:<id>``, which is
    written (30-day TTL) only after a successful send. Every error is logged
    and reported as ``"failed"``; nothing is raised to the caller.

    Returns:
        ``"sent"``, ``"skipped"`` (flag already set, or no contact phone),
        or ``"failed"``.
    """
    # NOTE(review): the get-then-setex check below is not atomic; concurrent
    # callers can each pass the check and send the SMS before the flag is set.
    try:
        redis = await get_redis()
        flag_key = f"sms:expired_sent:{lic.id}"
        if await redis.get(flag_key):
            logger.info("License #%d 已发送过过期短信,跳过", lic.id)
            return "skipped"

        phone, company = await _get_contact_for_license(db, lic)
        if not phone:
            logger.warning("License #%d 无联系人手机号,跳过预警短信", lic.id)
            return "skipped"

        logger.info("发送 License 预警短信(过期): license_id=%d, phone=%s, company=%s", lic.id, phone, company)
        ok, reason = await send_license_expired(phone, company)
        if not ok:
            logger.error("License 预警短信发送失败: license_id=%d, 原因=%s", lic.id, reason)
            return "failed"

        logger.info("License 预警短信发送成功: license_id=%d", lic.id)
        await redis.setex(flag_key, 86400 * 30, "1")  # keep the flag for 30 days
        return "sent"
    except Exception:
        logger.exception("发送预警短信异常,license_id=%d", lic.id)
        return "failed"
async def _notify_on_restored(db: AsyncSession, lic: SuperAdminLicense) -> str:
    """Send the "license restored" SMS after a license becomes valid again.

    On a successful send, the ``sms:expired_sent:<id>`` and
    ``sms:warning_sent:<id>`` Redis flags are deleted (in that order) so a
    later expiry or warning is notified again. Every error is logged and
    reported as ``"failed"``; nothing is raised to the caller.

    Returns:
        ``"sent"``, ``"skipped"`` (no contact phone), or ``"failed"``.
    """
    try:
        phone, company = await _get_contact_for_license(db, lic)
        if not phone:
            logger.warning("License #%d 无联系人手机号,跳过恢复短信", lic.id)
            return "skipped"

        logger.info("发送 License 恢复短信: license_id=%d, phone=%s, company=%s", lic.id, phone, company)
        ok, reason = await send_license_restored(phone, company)
        if not ok:
            logger.error("License 恢复短信发送失败: license_id=%d, 原因=%s", lic.id, reason)
            return "failed"

        logger.info("License 恢复短信发送成功: license_id=%d", lic.id)
        redis = await get_redis()
        for flag in ("expired", "warning"):
            await redis.delete(f"sms:{flag}_sent:{lic.id}")
        return "sent"
    except Exception:
        logger.exception("发送恢复短信异常,license_id=%d", lic.id)
        return "failed"
async def _notify_on_warning(db: AsyncSession, lic: SuperAdminLicense, days_left: int) -> str:
    """Send the "license expiring soon" SMS.

    Intended for licenses about 7 days from expiry; the threshold is decided
    by the caller, which is not in this block. Deduplicated through the Redis
    flag ``sms:warning_sent:<id>``, which is written (30-day TTL) only after
    a successful send. Every error is logged and reported as ``"failed"``.

    Args:
        days_left: remaining days, passed to the SMS template and the log.

    Returns:
        ``"sent"``, ``"skipped"`` (flag already set, or no contact phone),
        or ``"failed"``.
    """
    # NOTE(review): the get-then-setex dedupe is not atomic under concurrency.
    try:
        redis = await get_redis()
        flag_key = f"sms:warning_sent:{lic.id}"
        if await redis.get(flag_key):
            logger.info("License #%d 已发送过预警短信,跳过", lic.id)
            return "skipped"

        phone, company = await _get_contact_for_license(db, lic)
        if not phone:
            logger.warning("License #%d 无联系人手机号,跳过即将过期预警短信", lic.id)
            return "skipped"

        logger.info(
            "发送 License 即将过期预警: license_id=%d, phone=%s, company=%s, days_left=%d",
            lic.id, phone, company, days_left,
        )
        ok, reason = await send_license_warning(phone, company, days_left)
        if not ok:
            logger.error("License 即将过期预警短信发送失败: license_id=%d, 原因=%s", lic.id, reason)
            return "failed"

        logger.info("License 即将过期预警短信发送成功: license_id=%d", lic.id)
        await redis.setex(flag_key, 86400 * 30, "1")
        return "sent"
    except Exception:
        logger.exception("发送即将过期预警短信异常,license_id=%d", lic.id)
        return "failed"
def _calc_days_left(expires_at: datetime) -> int:
    """Return the number of days until *expires_at*, rounded up.

    Any positive remainder counts as a whole day, so a license expiring in
    one hour reports 1. Callers treat a result of 0 or below as expired.

    Aware datetimes are compared with the current UTC time. Naive datetimes
    are compared with ``datetime.now()``, i.e. the server's local clock.
    NOTE(review): naive *input* is tagged as CST elsewhere in this module,
    but a naive value read back from the DB is treated as server-local
    here. Confirm the column type and server timezone make these agree.
    """
    if expires_at.tzinfo:
        now = datetime.now(timezone.utc)
    else:
        now = datetime.now()
    seconds_per_day = 24 * 60 * 60
    return math.ceil((expires_at - now).total_seconds() / seconds_per_day)
async def get_super_admins(db: AsyncSession) -> list[SuperAdminOption]:
    """Return every super admin, ordered by id, as dropdown options."""
    query = select(SuperAdmin).order_by(SuperAdmin.id)
    admins = (await db.execute(query)).scalars().all()
    return [
        SuperAdminOption(
            id=admin.id,
            username=admin.username,
            nickname=admin.nickname,
            remark=admin.remark,
        )
        for admin in admins
    ]
async def create_license(
    db: AsyncSession, payload: LicenseCreate
) -> dict:
    """Create a license for a super admin, or renew the existing active one.

    Each super admin keeps at most one ``active`` license. If one exists,
    its key, expiry, quotas and remark are overwritten in place. Otherwise
    a new ``active`` row is inserted.

    ``payload.expires_at`` is an ISO-8601 string; a naive value is
    interpreted as CST (UTC+8).

    Returns:
        ``{"message": ..., "license_id": ...}``.

    Raises:
        ValueError: if ``payload.expires_at`` is not valid ISO-8601
            (raised by ``datetime.fromisoformat``).
    """
    # Parse first so a malformed date fails before touching the database.
    expires_at = datetime.fromisoformat(payload.expires_at)
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=CST)

    existing = await db.execute(
        select(SuperAdminLicense)
        .where(
            SuperAdminLicense.super_admin_id == payload.super_admin_id,
            SuperAdminLicense.status == "active",
        )
        .order_by(SuperAdminLicense.created_at.desc())
    )
    # first() rather than scalar_one_or_none(): if duplicate active rows ever
    # exist, renew the newest instead of raising MultipleResultsFound on
    # every call.
    exist_lic = existing.scalars().first()

    if exist_lic:
        license_id = exist_lic.id
        exist_lic.license_key = payload.license_key
        exist_lic.expires_at = expires_at
        exist_lic.max_tenants = payload.max_tenants
        exist_lic.max_users_per_tenant = payload.max_users_per_tenant
        exist_lic.remark = payload.remark
        await db.commit()  # commit flushes; an explicit flush() is redundant

        # A renewal that pushes expiry outside the 7-day warning window starts
        # a fresh notification cycle (as update_license does). Otherwise a
        # leftover 30-day dedupe flag would suppress the next warning or
        # expiry SMS.
        if _calc_days_left(expires_at) > 7:
            try:
                r = await get_redis()
                await r.delete(f"sms:warning_sent:{license_id}")
                await r.delete(f"sms:expired_sent:{license_id}")
            except Exception:
                # Best-effort: the renewal itself is already committed.
                logger.exception("清除 License 短信去重标记失败,license_id=%d", license_id)
        return {"message": "License已更新", "license_id": license_id}

    new_lic = SuperAdminLicense(
        super_admin_id=payload.super_admin_id,
        license_key=payload.license_key,
        expires_at=expires_at,
        status="active",
        max_tenants=payload.max_tenants,
        max_users_per_tenant=payload.max_users_per_tenant,
        remark=payload.remark,
    )
    db.add(new_lic)
    await db.commit()
    await db.refresh(new_lic)  # load the DB-generated id
    return {"message": "License已创建", "license_id": new_lic.id}
async def _load_owner_info(
    db: AsyncSession, super_admin_id: int
) -> tuple[str, str | None, dict | None]:
    """Return ``(display name, first active domain, contact dict)`` for a super admin."""
    sa_result = await db.execute(
        select(SuperAdmin).where(SuperAdmin.id == super_admin_id)
    )
    sa = sa_result.scalar_one_or_none()
    sa_name = (sa.remark or sa.username) if sa else str(super_admin_id)

    domain_result = await db.execute(
        select(MonitoredDomain).where(
            MonitoredDomain.super_admin_id == super_admin_id,
            MonitoredDomain.is_active == True,  # noqa: E712 - SQL expression
        ).limit(1)
    )
    domain_row = domain_result.scalar_one_or_none()
    if not domain_row:
        return sa_name, None, None

    visitor_result = await db.execute(
        select(VisitorInfo).where(VisitorInfo.domain_id == domain_row.id)
    )
    visitor = visitor_result.scalar_one_or_none()
    contact = (
        {"name": visitor.name, "phone": visitor.phone, "email": visitor.email}
        if visitor
        else None
    )
    return sa_name, domain_row.domain, contact


async def list_licenses(
    db: AsyncSession,
    super_admin_id: int | None = None,
    status: str | None = None,
    page: int = 1,
    size: int = 20,
) -> LicenseListResponse:
    """List licenses (newest first) with owner name, domain and contact.

    Args:
        super_admin_id: only licenses of this super admin, when given.
        status: only licenses with this status; empty/None means no filter.
        page: 1-based page number; values below 1 are treated as 1.
        size: page size; negative values are treated as 0.

    Returns:
        ``LicenseListResponse`` with the total number of matching licenses
        and the requested page of items.
    """
    from sqlalchemy import func  # only needed for the count query

    filters = []
    if super_admin_id is not None:
        filters.append(SuperAdminLicense.super_admin_id == super_admin_id)
    if status:
        filters.append(SuperAdminLicense.status == status)

    # Count and paginate in SQL instead of loading every matching row into
    # memory just to take len() and slice it.
    count_stmt = select(func.count()).select_from(SuperAdminLicense)
    page_stmt = select(SuperAdminLicense)
    if filters:
        count_stmt = count_stmt.where(*filters)
        page_stmt = page_stmt.where(*filters)
    total = (await db.execute(count_stmt)).scalar_one()

    page = max(page, 1)
    size = max(size, 0)
    page_stmt = (
        page_stmt
        # id breaks created_at ties so rows don't shift between pages
        .order_by(SuperAdminLicense.created_at.desc(), SuperAdminLicense.id.desc())
        .offset((page - 1) * size)
        .limit(size)
    )
    rows = (await db.execute(page_stmt)).scalars().all()

    # Owner name, domain and contact depend only on super_admin_id; cache
    # them so licenses sharing an owner cost one set of lookups.
    owner_cache: dict[int, tuple[str, str | None, dict | None]] = {}
    items = []
    for r in rows:
        if r.super_admin_id not in owner_cache:
            owner_cache[r.super_admin_id] = await _load_owner_info(db, r.super_admin_id)
        sa_name, domain_name, contact = owner_cache[r.super_admin_id]
        items.append(LicenseResponse(
            id=r.id,
            super_admin_id=r.super_admin_id,
            super_admin_name=sa_name,
            license_key=r.license_key,
            expires_at=_to_str(r.expires_at),
            status=r.status,
            max_tenants=r.max_tenants,
            max_users_per_tenant=r.max_users_per_tenant,
            remark=r.remark,
            created_at=_to_str(r.created_at),
            updated_at=_to_str(r.updated_at),
            domain=domain_name,
            contact=contact,
        ))
    return LicenseListResponse(total=total, items=items)
async def get_license_status(
    db: AsyncSession, license_id: int
) -> LicenseStatusResponse:
    """Return the detailed status of one license.

    Side effect: an ``active`` license whose expiry has passed
    (``days_left <= 0``) is switched to ``expired`` and committed, and the
    expired SMS is triggered. Its result is reported as ``sms_status``,
    which is ``None`` when no transition happened.

    Raises:
        ValueError: if the license does not exist.
    """
    lookup = select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
    lic = (await db.execute(lookup)).scalar_one_or_none()
    if not lic:
        raise ValueError("License 不存在")

    days_left = _calc_days_left(lic.expires_at)
    sms_status = None
    just_lapsed = lic.status == "active" and days_left <= 0
    if just_lapsed:
        lic.status = "expired"
        await db.commit()
        sms_status = await _notify_on_expired(db, lic)

    owner_lookup = select(SuperAdmin).where(SuperAdmin.id == lic.super_admin_id)
    owner = (await db.execute(owner_lookup)).scalar_one_or_none()
    owner_name = (owner.remark or owner.username) if owner else str(lic.super_admin_id)

    return LicenseStatusResponse(
        id=lic.id,
        super_admin_id=lic.super_admin_id,
        super_admin_name=owner_name,
        license_key=lic.license_key,
        expires_at=_to_str(lic.expires_at),
        status=lic.status,
        days_left=days_left,
        max_tenants=lic.max_tenants,
        max_users_per_tenant=lic.max_users_per_tenant,
        remark=lic.remark,
        sms_status=sms_status,
    )
async def revoke_license(
    db: AsyncSession, license_id: int
) -> dict:
    """Revoke a license.

    Sets the status to ``revoked`` unconditionally, commits, then triggers
    the expired SMS.

    Returns:
        ``{"message": "License已吊销", "sms_status": ...}`` where
        ``sms_status`` is the result of the expired-SMS attempt.

    Raises:
        ValueError: if the license does not exist.
    """
    result = await db.execute(
        select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
    )
    lic = result.scalar_one_or_none()
    if not lic:
        raise ValueError("License 不存在")
    lic.status = "revoked"
    await db.commit()
    # Revocation makes the license invalid, so the contact is notified
    # whatever the previous status was; repeats are suppressed by the Redis
    # dedupe flag checked inside _notify_on_expired.
    sms_status = await _notify_on_expired(db, lic)
    return {"message": "License已吊销", "sms_status": sms_status}
async def update_license(
    db: AsyncSession, license_id: int, payload: LicenseUpdate
) -> dict:
    """Update a license's key and/or expiry time.

    When ``payload.expires_at`` is given (ISO-8601; naive means CST), the
    status is recomputed from it, whatever the previous status was:
    ``active`` if the expiry is in the future, otherwise ``expired``.

    Before commit, the Redis SMS dedupe flags are cleared when the new expiry
    is more than 7 days away, or when the license goes from non-active to
    active. After commit, a validity change triggers the expired or restored
    SMS.

    Returns:
        ``{"message", "license_id", "sms_status"}``; ``sms_status`` is
        ``"none"`` when validity did not change.

    Raises:
        ValueError: license not found, nothing to update, or a malformed
            ``expires_at``.
    """
    lookup = select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
    lic = (await db.execute(lookup)).scalar_one_or_none()
    if not lic:
        raise ValueError("License 不存在")

    previous_status = lic.status
    sms_status = "none"
    touched = False

    if payload.license_key is not None:
        lic.license_key = payload.license_key
        touched = True

    if payload.expires_at is not None:
        new_expiry = datetime.fromisoformat(payload.expires_at)
        if new_expiry.tzinfo is None:
            new_expiry = new_expiry.replace(tzinfo=CST)
        lic.expires_at = new_expiry
        lic.status = "active" if _calc_days_left(lic.expires_at) > 0 else "expired"
        touched = True

    if not touched:
        raise ValueError("没有需要更新的字段")

    # Reset the dedupe flags when the new expiry leaves the 7-day warning
    # window, or when the license becomes valid again.
    outside_warning_window = payload.expires_at is not None and _calc_days_left(lic.expires_at) > 7
    if outside_warning_window or (previous_status != "active" and lic.status == "active"):
        redis = await get_redis()
        await redis.delete(f"sms:warning_sent:{lic.id}")
        await redis.delete(f"sms:expired_sent:{lic.id}")

    await db.commit()

    # Notify only when validity flips.
    was_valid = previous_status == "active"
    now_valid = lic.status == "active"
    if was_valid and not now_valid:
        sms_status = await _notify_on_expired(db, lic)
    elif now_valid and not was_valid:
        sms_status = await _notify_on_restored(db, lic)
    return {"message": "License已更新", "license_id": lic.id, "sms_status": sms_status}
async def restore_license(
    db: AsyncSession, license_id: int
) -> dict:
    """Restore a revoked license.

    The restored status follows the expiry time, the same rule
    ``update_license`` applies: ``active`` while ``expires_at`` is in the
    future, otherwise ``expired``. Only when the license becomes ``active``
    are the Redis SMS dedupe flags cleared and the "restored" SMS sent.

    Returns:
        ``{"message": ..., "sms_status": ...}``. ``sms_status`` is the
        restored-SMS result, or ``"skipped"`` when the license comes back as
        ``expired``.

    Raises:
        ValueError: if the license does not exist or is not revoked.
    """
    result = await db.execute(
        select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
    )
    lic = result.scalar_one_or_none()
    if not lic:
        raise ValueError("License 不存在")
    if lic.status != "revoked":
        raise ValueError("仅可恢复已吊销的 License")

    # Forcing "active" on a past-due license sent a "restored" SMS. It also
    # cleared the expired flag, so the next status check flipped the license
    # straight back to "expired" and sent an "expired" SMS right after.
    if _calc_days_left(lic.expires_at) > 0:
        lic.status = "active"
    else:
        lic.status = "expired"
    await db.commit()

    if lic.status != "active":
        # Still past due: the contact gets no "restored" SMS. The dedupe flags
        # are kept; an expiry update later can reactivate the license.
        return {"message": "License已恢复,但已过期", "sms_status": "skipped"}

    # Clear the Redis dedupe flags so future warnings/expiry are notified.
    r = await get_redis()
    await r.delete(f"sms:warning_sent:{lic.id}")
    await r.delete(f"sms:expired_sent:{lic.id}")
    sms_status = await _notify_on_restored(db, lic)
    return {"message": "License已恢复", "sms_status": sms_status}
async def delete_license(
    db: AsyncSession, license_id: int
) -> dict:
    """Permanently delete a license row and commit.

    The license's Redis SMS dedupe flags are not touched here.

    Raises:
        ValueError: if the license does not exist.
    """
    lookup = select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise ValueError("License 不存在")

    await db.delete(target)
    await db.commit()
    return {"message": "License已删除"}
async def check_license_by_referer(
    db: AsyncSession, referer: str | None
) -> dict:
    """Check the license of the super admin that owns the Referer's domain.

    Steps:
    1. Extract the host, and host:port, from the Referer header. Values
       with or without a scheme are both accepted.
    2. Match it against active ``MonitoredDomain.domain`` values stored as a
       bare host, host:port, or with an ``http://`` / ``https://`` prefix.
    3. Take the matched domain's ``super_admin_id``.
    4. Load that super admin's most recently created license, of any
       status. An ``active`` license whose expiry has passed is switched to
       ``expired`` and committed, and the expired SMS is triggered.

    Returns:
        On failure: ``{"valid": False, "status": "unknown" | "not_found",
        "message": ...}``. Otherwise ``valid`` (status is ``active``),
        ``status``, ``super_admin_name``, ``license_key``, ``expires_at``,
        ``days_left``, ``max_tenants``, ``max_users_per_tenant`` and
        ``remark``.
    """
    if not referer:
        return {
            "valid": False,
            "status": "unknown",
            "message": "缺少 Referer 头",
        }

    # Extract the host from the (untrusted) Referer value.
    try:
        raw = referer.strip()
        parsed = urlparse(raw)
        if not parsed.hostname:
            # Without a scheme, urlparse treats the value as a path
            # ("example.com/page") or mistakes the host for a scheme
            # ("localhost:3000"), so hostname is None. Prefixing "//" makes
            # the authority part parse as netloc and drops any path/query.
            parsed = urlparse("//" + raw)
        referer_host = parsed.hostname
        if not referer_host:
            return {
                "valid": False,
                "status": "unknown",
                "message": "Referer 格式无效",
            }
        # netloc keeps the port (e.g. 127.0.0.1:8010), hostname does not;
        # drop any "user:password@" prefix.
        referer_netloc = parsed.netloc.rpartition("@")[2]
    except Exception:
        return {
            "valid": False,
            "status": "unknown",
            "message": "Referer 解析失败",
        }

    # Match against the storage formats used for monitored domains.
    result = await db.execute(
        select(MonitoredDomain).where(
            MonitoredDomain.is_active == True,  # noqa: E712 - SQL expression
            MonitoredDomain.domain.in_([
                referer_netloc,               # e.g. 127.0.0.1:8010
                referer_host,                 # e.g. 127.0.0.1
                f"http://{referer_netloc}",   # e.g. http://127.0.0.1:8010
                f"https://{referer_netloc}",  # e.g. https://127.0.0.1:8010
            ]),
        )
    )
    domain = result.scalars().first()
    if not domain or not domain.super_admin_id:
        return {
            "valid": False,
            "status": "unknown",
            "message": "域名未注册或无关联超管",
        }

    # Latest license of this super admin, whatever its status.
    license_result = await db.execute(
        select(SuperAdminLicense)
        .where(SuperAdminLicense.super_admin_id == domain.super_admin_id)
        .order_by(SuperAdminLicense.created_at.desc(), SuperAdminLicense.id.desc())
        .limit(1)
    )
    lic = license_result.scalars().first()
    if not lic:
        return {
            "valid": False,
            "status": "not_found",
            "message": "未找到有效 License",
        }

    days_left = _calc_days_left(lic.expires_at)
    if lic.status == "active" and days_left <= 0:
        # Lazily transition to expired and notify the contact.
        lic.status = "expired"
        await db.commit()
        await _notify_on_expired(db, lic)

    sa_result = await db.execute(
        select(SuperAdmin).where(SuperAdmin.id == lic.super_admin_id)
    )
    sa = sa_result.scalar_one_or_none()
    return {
        "valid": lic.status == "active",
        "status": lic.status,
        "super_admin_name": (sa.remark or sa.username) if sa else str(lic.super_admin_id),
        "license_key": lic.license_key,
        "expires_at": _to_str(lic.expires_at),
        "days_left": days_left,
        "max_tenants": lic.max_tenants,
        "max_users_per_tenant": lic.max_users_per_tenant,
        "remark": lic.remark,
    }
|