import logging
import math
from datetime import datetime, timedelta, timezone
from urllib.parse import urlparse

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.license import SuperAdminLicense
from app.models.monitoring import SuperAdmin
from app.models.domain import MonitoredDomain
from app.models.visitor import VisitorInfo
from app.schemas.license import (
    LicenseCreate,
    LicenseResponse,
    LicenseStatusResponse,
    LicenseUpdate,
    SuperAdminOption,
    LicenseListResponse,
)
from app.services.sms import send_license_expired, send_license_restored, send_license_warning
from app.redis import get_redis

logger = logging.getLogger(__name__)

CST = timezone(timedelta(hours=8))  # UTC+8 (China Standard Time)

# Lifetime of the Redis "SMS already sent" de-duplication flags (30 days).
_SMS_FLAG_TTL_SECONDS = 86400 * 30

# Licenses that expire more than this many days from now are outside the
# expiry-warning window, so stale "warning already sent" flags may be cleared.
_WARNING_WINDOW_DAYS = 7


# ---------------------------------------------------------------------------
# Small helpers
# ---------------------------------------------------------------------------

def _to_str(val) -> str:
    """Serialize a value for API responses.

    ``None`` becomes ``""``; datetimes are converted to UTC+8 and rendered as
    ISO 8601; anything else goes through ``str()``.
    """
    if val is None:
        return ""
    if isinstance(val, datetime):
        # A naive datetime is interpreted as server-local time by astimezone().
        return val.astimezone(CST).isoformat()
    return str(val)


def _expired_flag_key(license_id: int) -> str:
    """Redis key marking that the "expired" SMS was already sent for a license."""
    return f"sms:expired_sent:{license_id}"


def _warning_flag_key(license_id: int) -> str:
    """Redis key marking that the "expiring soon" SMS was already sent for a license."""
    return f"sms:warning_sent:{license_id}"


async def _clear_sms_flags(license_id: int) -> None:
    """Delete both SMS de-duplication flags so future notifications can be sent again."""
    r = await get_redis()
    await r.delete(_warning_flag_key(license_id))
    await r.delete(_expired_flag_key(license_id))


def _calc_days_left(expires_at: datetime) -> int:
    """Return the whole days until ``expires_at``, rounded up.

    The result is <= 0 once the expiry moment has passed. Aware datetimes are
    compared with the current UTC time, naive ones with server-local time.
    """
    now = datetime.now(timezone.utc) if expires_at.tzinfo else datetime.now()
    return math.ceil((expires_at - now).total_seconds() / (60 * 60 * 24))


def _parse_expires_at(value: str) -> datetime:
    """Parse an ISO 8601 expiry timestamp coming from an API payload.

    A trailing ``Z``/``z`` (UTC designator, e.g. from JavaScript's
    ``Date.toISOString()``) is rewritten to ``+00:00`` because
    ``datetime.fromisoformat`` only accepts it from Python 3.11 on.
    Timestamps without an offset are taken to be UTC+8.

    Raises:
        ValueError: if ``value`` is not a valid ISO 8601 timestamp.
    """
    text = value.strip()
    if text.endswith(("Z", "z")):
        text = text[:-1] + "+00:00"
    parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=CST)
    return parsed


# ---------------------------------------------------------------------------
# Lookups
# ---------------------------------------------------------------------------

async def _get_license_or_raise(db: AsyncSession, license_id: int) -> SuperAdminLicense:
    """Load a license by primary key.

    Raises:
        ValueError: if no license with ``license_id`` exists.
    """
    result = await db.execute(
        select(SuperAdminLicense).where(SuperAdminLicense.id == license_id)
    )
    lic = result.scalar_one_or_none()
    if not lic:
        raise ValueError("License 不存在")
    return lic


async def _get_super_admin_name(db: AsyncSession, super_admin_id: int) -> str:
    """Return a super admin's display name.

    The name is the remark, else the username, else the id as text when the
    super admin row does not exist.
    """
    result = await db.execute(select(SuperAdmin).where(SuperAdmin.id == super_admin_id))
    sa = result.scalar_one_or_none()
    if sa is None:
        return str(super_admin_id)
    return sa.remark or sa.username


async def _get_active_domain(db: AsyncSession, super_admin_id: int) -> MonitoredDomain | None:
    """Return one active monitored domain of the super admin, or ``None``.

    No ordering is applied, so when several domains are active the database
    decides which one is returned.
    """
    result = await db.execute(
        select(MonitoredDomain)
        .where(
            MonitoredDomain.super_admin_id == super_admin_id,
            MonitoredDomain.is_active == True,  # noqa: E712 - SQL expression, not a bool test
        )
        .limit(1)
    )
    return result.scalar_one_or_none()


async def _get_visitor(db: AsyncSession, domain_id: int) -> VisitorInfo | None:
    """Return the visitor (contact) record attached to a domain, or ``None``.

    ``limit(1)`` keeps ``scalar_one_or_none`` from raising
    ``MultipleResultsFound`` if a domain ever has more than one visitor row.
    """
    result = await db.execute(
        select(VisitorInfo).where(VisitorInfo.domain_id == domain_id).limit(1)
    )
    return result.scalar_one_or_none()


async def _get_contact_for_license(
    db: AsyncSession, lic: SuperAdminLicense
) -> tuple[str | None, str | None]:
    """Resolve the SMS contact for a license.

    Returns:
        ``(phone, company)``. Both are ``None`` when the super admin has no
        active domain. ``phone`` is ``None`` when the domain has no visitor or
        the visitor has no phone number.
    """
    domain = await _get_active_domain(db, lic.super_admin_id)
    if not domain:
        return None, None

    company = await _get_super_admin_name(db, lic.super_admin_id)

    visitor = await _get_visitor(db, domain.id)
    if not visitor or not visitor.phone:
        return None, company
    return visitor.phone, company


async def _get_owner_context(
    db: AsyncSession, super_admin_id: int
) -> tuple[str, str | None, dict | None]:
    """Collect the list-view details of a license owner.

    Returns:
        ``(super_admin_name, domain_name, contact)``. ``contact`` is a dict
        with ``name``/``phone``/``email`` keys, or ``None`` when there is no
        active domain or no visitor record.
    """
    sa_name = await _get_super_admin_name(db, super_admin_id)
    domain = await _get_active_domain(db, super_admin_id)
    if domain is None:
        return sa_name, None, None

    contact = None
    visitor = await _get_visitor(db, domain.id)
    if visitor:
        contact = {"name": visitor.name, "phone": visitor.phone, "email": visitor.email}
    return sa_name, domain.domain, contact


# ---------------------------------------------------------------------------
# SMS notifications (never raise; return "sent" | "skipped" | "failed")
# ---------------------------------------------------------------------------

async def _notify_on_expired(db: AsyncSession, lic: SuperAdminLicense) -> str:
    """Send the "license expired" SMS when a license's validity turns from true to false.

    Sends are de-duplicated with a Redis flag that lives for 30 days.
    Exceptions are logged and reported as ``"failed"``.

    Returns:
        ``"sent"``, ``"skipped"`` (already sent, or no contact phone) or ``"failed"``.
    """
    try:
        r = await get_redis()
        if await r.get(_expired_flag_key(lic.id)):
            logger.info("License #%d 已发送过过期短信,跳过", lic.id)
            return "skipped"
        phone, company = await _get_contact_for_license(db, lic)
        if not phone:
            logger.warning("License #%d 无联系人手机号,跳过预警短信", lic.id)
            return "skipped"
        logger.info("发送 License 预警短信(过期): license_id=%d, phone=%s, company=%s", lic.id, phone, company)
        ok, reason = await send_license_expired(phone, company)
        if ok:
            logger.info("License 预警短信发送成功: license_id=%d", lic.id)
            await r.setex(_expired_flag_key(lic.id), _SMS_FLAG_TTL_SECONDS, "1")
            return "sent"
        logger.error("License 预警短信发送失败: license_id=%d, 原因=%s", lic.id, reason)
        return "failed"
    except Exception:
        logger.exception("发送预警短信异常,license_id=%d", lic.id)
        return "failed"


async def _notify_on_restored(db: AsyncSession, lic: SuperAdminLicense) -> str:
    """Send the "license restored" SMS when a license's validity turns from false to true.

    On success, both de-duplication flags are cleared so that a later
    expiry or expiry warning is notified again. Exceptions are logged and
    reported as ``"failed"``.

    Returns:
        ``"sent"``, ``"skipped"`` (no contact phone) or ``"failed"``.
    """
    try:
        phone, company = await _get_contact_for_license(db, lic)
        if not phone:
            logger.warning("License #%d 无联系人手机号,跳过恢复短信", lic.id)
            return "skipped"
        logger.info("发送 License 恢复短信: license_id=%d, phone=%s, company=%s", lic.id, phone, company)
        ok, reason = await send_license_restored(phone, company)
        if ok:
            logger.info("License 恢复短信发送成功: license_id=%d", lic.id)
            await _clear_sms_flags(lic.id)
            return "sent"
        logger.error("License 恢复短信发送失败: license_id=%d, 原因=%s", lic.id, reason)
        return "failed"
    except Exception:
        logger.exception("发送恢复短信异常,license_id=%d", lic.id)
        return "failed"


async def _notify_on_warning(db: AsyncSession, lic: SuperAdminLicense, days_left: int) -> str:
    """Send the "license expiring soon" SMS (intended for the 7-day window).

    Sends are de-duplicated with a Redis flag that lives for 30 days.
    Exceptions are logged and reported as ``"failed"``.

    Returns:
        ``"sent"``, ``"skipped"`` (already sent, or no contact phone) or ``"failed"``.
    """
    try:
        r = await get_redis()
        if await r.get(_warning_flag_key(lic.id)):
            logger.info("License #%d 已发送过预警短信,跳过", lic.id)
            return "skipped"
        phone, company = await _get_contact_for_license(db, lic)
        if not phone:
            logger.warning("License #%d 无联系人手机号,跳过即将过期预警短信", lic.id)
            return "skipped"
        logger.info(
            "发送 License 即将过期预警: license_id=%d, phone=%s, company=%s, days_left=%d",
            lic.id, phone, company, days_left,
        )
        ok, reason = await send_license_warning(phone, company, days_left)
        if ok:
            logger.info("License 即将过期预警短信发送成功: license_id=%d", lic.id)
            await r.setex(_warning_flag_key(lic.id), _SMS_FLAG_TTL_SECONDS, "1")
            return "sent"
        logger.error("License 即将过期预警短信发送失败: license_id=%d, 原因=%s", lic.id, reason)
        return "failed"
    except Exception:
        logger.exception("发送即将过期预警短信异常,license_id=%d", lic.id)
        return "failed"


async def _expire_if_due(db: AsyncSession, lic: SuperAdminLicense, days_left: int) -> str | None:
    """Flip an overdue active license to ``"expired"``, commit, and send the expiry SMS.

    Returns:
        The SMS status when the license was just expired, otherwise ``None``.
    """
    if lic.status != "active" or days_left > 0:
        return None
    lic.status = "expired"
    await db.commit()
    return await _notify_on_expired(db, lic)


# ---------------------------------------------------------------------------
# Public service functions
# ---------------------------------------------------------------------------

async def get_super_admins(db: AsyncSession) -> list[SuperAdminOption]:
    """Return all super admins, ordered by id, for a selection dropdown."""
    result = await db.execute(select(SuperAdmin).order_by(SuperAdmin.id))
    return [
        SuperAdminOption(id=sa.id, username=sa.username, nickname=sa.nickname, remark=sa.remark)
        for sa in result.scalars().all()
    ]


async def create_license(
    db: AsyncSession, payload: LicenseCreate
) -> dict:
    """Create a license, or update the super admin's existing active license.

    Each super admin keeps at most one ``"active"`` license. If one exists,
    its key, expiry, quotas and remark are overwritten. When the new expiry
    is outside the warning window, the Redis SMS flags are reset, as
    ``update_license`` does, so the renewed license can be warned about again.

    Raises:
        ValueError: if ``payload.expires_at`` is not a valid ISO 8601 timestamp.
    """
    existing = await db.execute(
        select(SuperAdminLicense).where(
            SuperAdminLicense.super_admin_id == payload.super_admin_id,
            SuperAdminLicense.status == "active",
        )
    )
    exist_lic = existing.scalar_one_or_none()

    expires_at = _parse_expires_at(payload.expires_at)

    if exist_lic:
        exist_lic.license_key = payload.license_key
        exist_lic.expires_at = expires_at
        exist_lic.max_tenants = payload.max_tenants
        exist_lic.max_users_per_tenant = payload.max_users_per_tenant
        exist_lic.remark = payload.remark
        if _calc_days_left(expires_at) > _WARNING_WINDOW_DAYS:
            await _clear_sms_flags(exist_lic.id)
        await db.commit()
        return {"message": "License已更新", "license_id": exist_lic.id}

    new_lic = SuperAdminLicense(
        super_admin_id=payload.super_admin_id,
        license_key=payload.license_key,
        expires_at=expires_at,
        status="active",
        max_tenants=payload.max_tenants,
        max_users_per_tenant=payload.max_users_per_tenant,
        remark=payload.remark,
    )
    db.add(new_lic)
    await db.commit()
    await db.refresh(new_lic)
    return {"message": "License已创建", "license_id": new_lic.id}


async def list_licenses(
    db: AsyncSession,
    super_admin_id: int | None = None,
    status: str | None = None,
    page: int = 1,
    size: int = 20,
) -> LicenseListResponse:
    """List licenses, newest first, with owner name, domain and contact.

    Args:
        super_admin_id: only licenses of this super admin, if given.
        status: only licenses with this status, if given (empty means all).
        page: 1-based page number; values below 1 are treated as 1.
        size: page size; negative values are treated as 0.

    Returns:
        ``total`` (count of all matching licenses) and the requested page of items.
    """
    page = max(page, 1)
    size = max(size, 0)

    stmt = select(SuperAdminLicense)
    if super_admin_id is not None:
        stmt = stmt.where(SuperAdminLicense.super_admin_id == super_admin_id)
    if status:
        stmt = stmt.where(SuperAdminLicense.status == status)

    # Count and paginate in SQL instead of loading every row into memory.
    total = (await db.execute(select(func.count()).select_from(stmt.subquery()))).scalar_one()

    page_stmt = (
        # id tiebreaker keeps pagination stable for equal created_at values.
        stmt.order_by(SuperAdminLicense.created_at.desc(), SuperAdminLicense.id.desc())
        .offset((page - 1) * size)
        .limit(size)
    )
    rows = (await db.execute(page_stmt)).scalars().all()

    # Several licenses may share an owner; resolve each owner only once.
    owner_cache: dict[int, tuple[str, str | None, dict | None]] = {}
    items = []
    for lic in rows:
        if lic.super_admin_id not in owner_cache:
            owner_cache[lic.super_admin_id] = await _get_owner_context(db, lic.super_admin_id)
        sa_name, domain_name, contact = owner_cache[lic.super_admin_id]

        items.append(LicenseResponse(
            id=lic.id,
            super_admin_id=lic.super_admin_id,
            super_admin_name=sa_name,
            license_key=lic.license_key,
            expires_at=_to_str(lic.expires_at),
            status=lic.status,
            max_tenants=lic.max_tenants,
            max_users_per_tenant=lic.max_users_per_tenant,
            remark=lic.remark,
            created_at=_to_str(lic.created_at),
            updated_at=_to_str(lic.updated_at),
            domain=domain_name,
            contact=contact,
        ))

    return LicenseListResponse(total=total, items=items)


async def get_license_status(
    db: AsyncSession, license_id: int
) -> LicenseStatusResponse:
    """Return the detailed status of one license.

    An active license whose expiry has passed is switched to ``"expired"``
    (committed) and the expiry SMS is attempted. ``sms_status`` is ``None``
    when no such transition happened.

    Raises:
        ValueError: if the license does not exist.
    """
    lic = await _get_license_or_raise(db, license_id)

    days_left = _calc_days_left(lic.expires_at)
    sms_status = await _expire_if_due(db, lic, days_left)
    sa_name = await _get_super_admin_name(db, lic.super_admin_id)

    return LicenseStatusResponse(
        id=lic.id,
        super_admin_id=lic.super_admin_id,
        super_admin_name=sa_name,
        license_key=lic.license_key,
        expires_at=_to_str(lic.expires_at),
        status=lic.status,
        days_left=days_left,
        max_tenants=lic.max_tenants,
        max_users_per_tenant=lic.max_users_per_tenant,
        remark=lic.remark,
        sms_status=sms_status,
    )


async def revoke_license(
    db: AsyncSession, license_id: int
) -> dict:
    """Revoke a license (status ``"revoked"``) and send the expiry SMS.

    Raises:
        ValueError: if the license does not exist.
    """
    lic = await _get_license_or_raise(db, license_id)
    lic.status = "revoked"
    await db.commit()

    # Revoking always makes the license invalid, so the expiry SMS is attempted
    # whatever the previous status was (the Redis flag still de-duplicates it).
    sms_status = await _notify_on_expired(db, lic)
    return {"message": "License已吊销", "sms_status": sms_status}


async def update_license(
    db: AsyncSession, license_id: int, payload: LicenseUpdate
) -> dict:
    """Update a license's key and/or expiry time.

    A new expiry also recomputes the status (``"active"`` if still in the
    future, else ``"expired"``). The SMS de-duplication flags are reset when
    the expiry moves outside the warning window or the license becomes valid
    again. An SMS is sent when validity changes.

    Raises:
        ValueError: if the license does not exist, nothing was given to
            update, or ``expires_at`` is not a valid ISO 8601 timestamp.
    """
    lic = await _get_license_or_raise(db, license_id)

    old_status = lic.status
    sms_status = "none"
    changed = False

    if payload.license_key is not None:
        lic.license_key = payload.license_key
        changed = True

    if payload.expires_at is not None:
        lic.expires_at = _parse_expires_at(payload.expires_at)
        # NOTE(review): this also overwrites a "revoked" status, effectively
        # un-revoking the license; confirm that is intended (restore_license
        # exists for that purpose).
        lic.status = "active" if _calc_days_left(lic.expires_at) > 0 else "expired"
        changed = True

    if not changed:
        raise ValueError("没有需要更新的字段")

    became_valid = old_status != "active" and lic.status == "active"
    outside_warning_window = (
        payload.expires_at is not None
        and _calc_days_left(lic.expires_at) > _WARNING_WINDOW_DAYS
    )
    if outside_warning_window or became_valid:
        await _clear_sms_flags(lic.id)

    await db.commit()

    # Notify only when validity actually changed.
    if old_status == "active" and lic.status != "active":
        sms_status = await _notify_on_expired(db, lic)
    elif became_valid:
        sms_status = await _notify_on_restored(db, lic)

    return {"message": "License已更新", "license_id": lic.id, "sms_status": sms_status}


async def restore_license(
    db: AsyncSession, license_id: int
) -> dict:
    """Restore a revoked license to ``"active"`` and send the restore SMS.

    The status is set to ``"active"`` even if the expiry has already passed;
    a later status check then marks it ``"expired"``.

    Raises:
        ValueError: if the license does not exist or is not revoked.
    """
    lic = await _get_license_or_raise(db, license_id)
    if lic.status != "revoked":
        raise ValueError("仅可恢复已吊销的 License")

    lic.status = "active"
    await db.commit()

    await _clear_sms_flags(lic.id)
    sms_status = await _notify_on_restored(db, lic)
    return {"message": "License已恢复", "sms_status": sms_status}


async def delete_license(
    db: AsyncSession, license_id: int
) -> dict:
    """Delete a license record.

    Raises:
        ValueError: if the license does not exist.
    """
    lic = await _get_license_or_raise(db, license_id)
    await db.delete(lic)
    await db.commit()
    return {"message": "License已删除"}


def _split_referer(referer: str) -> tuple[str, str]:
    """Extract ``(host, netloc)`` from a Referer header value.

    ``host`` has no port (``127.0.0.1``); ``netloc`` keeps it
    (``127.0.0.1:8010``). A value without a scheme (``example.com``,
    ``127.0.0.1:8010/page``) is not given a hostname by ``urlparse``, so it is
    re-parsed as a network-path reference (``//...``). Both strings are empty
    when no host can be found.

    Raises:
        ValueError: from ``urlparse`` for malformed input, e.g. an unbalanced
            IPv6 bracket.
    """
    text = referer.strip()
    parsed = urlparse(text)
    if not parsed.hostname:
        parsed = urlparse("//" + text)
    host = parsed.hostname or ""
    return host, (parsed.netloc or host)


async def check_license_by_referer(
    db: AsyncSession, referer: str | None
) -> dict:
    """Check the license of the super admin that owns the Referer's domain.

    Steps:
        1. Extract the host (with and without port) from the Referer.
        2. Match it against active ``MonitoredDomain.domain`` values.
        3. Take the matched domain's ``super_admin_id``.
        4. Load that super admin's most recently created license (any
           status). An overdue active license is marked expired, which also
           triggers the expiry SMS.

    Returns:
        A dict with ``valid``/``status``/``message`` when the check cannot
        proceed. Otherwise it holds ``valid``, ``status``, owner name, key,
        expiry, ``days_left``, quotas and remark.
    """
    if not referer:
        return {
            "valid": False,
            "status": "unknown",
            "message": "缺少 Referer 头",
        }

    try:
        referer_host, referer_netloc = _split_referer(referer)
    except ValueError:
        return {
            "valid": False,
            "status": "unknown",
            "message": "Referer 解析失败",
        }
    if not referer_host:
        return {
            "valid": False,
            "status": "unknown",
            "message": "Referer 格式无效",
        }

    # Stored domains may be a bare host, host:port, or carry an http(s):// prefix.
    result = await db.execute(
        select(MonitoredDomain).where(
            MonitoredDomain.is_active == True,  # noqa: E712 - SQL expression, not a bool test
            MonitoredDomain.domain.in_([
                referer_netloc,               # e.g. 127.0.0.1:8010
                referer_host,                 # e.g. 127.0.0.1
                f"http://{referer_netloc}",   # e.g. http://127.0.0.1:8010
                f"https://{referer_netloc}",  # e.g. https://127.0.0.1:8010
            ]),
        )
    )
    domain = result.scalars().first()
    if not domain or not domain.super_admin_id:
        return {
            "valid": False,
            "status": "unknown",
            "message": "域名未注册或无关联超管",
        }

    # Most recently created license of the owner, regardless of status.
    license_result = await db.execute(
        select(SuperAdminLicense)
        .where(SuperAdminLicense.super_admin_id == domain.super_admin_id)
        .order_by(SuperAdminLicense.created_at.desc())
    )
    lic = license_result.scalars().first()
    if not lic:
        return {
            "valid": False,
            "status": "not_found",
            "message": "未找到有效 License",
        }

    days_left = _calc_days_left(lic.expires_at)
    await _expire_if_due(db, lic, days_left)
    sa_name = await _get_super_admin_name(db, lic.super_admin_id)

    return {
        "valid": lic.status == "active",
        "status": lic.status,
        "super_admin_name": sa_name,
        "license_key": lic.license_key,
        "expires_at": _to_str(lic.expires_at),
        "days_left": days_left,
        "max_tenants": lic.max_tenants,
        "max_users_per_tenant": lic.max_users_per_tenant,
        "remark": lic.remark,
    }