domains.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. from fastapi import APIRouter, Depends, HTTPException, Query
  2. from pydantic import BaseModel
  3. from sqlalchemy import select, func, desc
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from app.database import get_db
  6. from app.models.domain import MonitoredDomain
  7. from app.models.monitoring import SuperAdmin, FetchScheduleConfig, FetchLog, SuperAdminTenant
  8. from app.schemas.domain import (
  9. MonitoredDomainCreate,
  10. MonitoredDomainResponse,
  11. )
  12. from app.services.domain_fetch import fetch_domain_transactions
  13. import logging
  14. logger = logging.getLogger(__name__)
  15. router = APIRouter(prefix="/api/domains", tags=["domains"])
  16. @router.post("/", response_model=MonitoredDomainResponse, status_code=201)
  17. async def add_domain(
  18. payload: MonitoredDomainCreate,
  19. db: AsyncSession = Depends(get_db),
  20. ):
  21. """添加需要监控的域名,同时创建对应的超管记录"""
  22. existing = await db.execute(
  23. select(MonitoredDomain).where(MonitoredDomain.domain == payload.domain)
  24. )
  25. if existing.scalar_one_or_none():
  26. raise HTTPException(status_code=409, detail="域名已在监控中")
  27. # 如果未指定超管,自动创建一条超管记录
  28. sa_id = payload.super_admin_id
  29. if sa_id is None:
  30. max_id_result = await db.execute(select(func.max(SuperAdmin.id)))
  31. max_id = max_id_result.scalar() or 0
  32. new_sa = SuperAdmin(
  33. id=max_id + 1,
  34. username=payload.domain,
  35. nickname=payload.domain,
  36. remark=payload.remark or None,
  37. )
  38. db.add(new_sa)
  39. await db.flush()
  40. sa_id = new_sa.id
  41. record = MonitoredDomain(domain=payload.domain, remark=payload.remark or None, super_admin_id=sa_id)
  42. db.add(record)
  43. await db.commit()
  44. await db.refresh(record)
  45. return record
  46. @router.get("/", response_model=list[MonitoredDomainResponse])
  47. async def list_domains(db: AsyncSession = Depends(get_db)):
  48. """获取所有已监控的域名列表"""
  49. result = await db.execute(select(MonitoredDomain))
  50. return result.scalars().all()
  51. class MonitoredDomainUpdate(BaseModel):
  52. """更新域名备注"""
  53. remark: str = ""
  54. @router.patch("/{domain_id}", response_model=MonitoredDomainResponse)
  55. async def update_domain_remark(
  56. domain_id: int,
  57. payload: MonitoredDomainUpdate,
  58. db: AsyncSession = Depends(get_db),
  59. ):
  60. """更新域名备注,并同步到关联的超管"""
  61. result = await db.execute(
  62. select(MonitoredDomain).where(MonitoredDomain.id == domain_id)
  63. )
  64. record = result.scalar_one_or_none()
  65. if not record:
  66. raise HTTPException(status_code=404, detail="域名不存在")
  67. record.remark = payload.remark or None
  68. # 同步到关联的超管
  69. if record.super_admin_id:
  70. sa_result = await db.execute(
  71. select(SuperAdmin).where(SuperAdmin.id == record.super_admin_id)
  72. )
  73. sa = sa_result.scalar_one_or_none()
  74. if sa:
  75. sa.remark = payload.remark or None
  76. await db.commit()
  77. await db.refresh(record)
  78. return record
  79. @router.delete("/{domain_id}", status_code=204)
  80. async def remove_domain(domain_id: int, db: AsyncSession = Depends(get_db)):
  81. """移除指定 ID 的监控域名,并级联清理所有关联数据"""
  82. result = await db.execute(
  83. select(MonitoredDomain).where(MonitoredDomain.id == domain_id)
  84. )
  85. record = result.scalar_one_or_none()
  86. if not record:
  87. raise HTTPException(status_code=404, detail="域名不存在")
  88. domain_name = record.domain
  89. sa_id = record.super_admin_id
  90. # 1. 删除该域名相关的爬取日志
  91. await db.execute(
  92. FetchLog.__table__.delete().where(FetchLog.domain == domain_name)
  93. )
  94. # 2. 删除域名记录本身
  95. await db.delete(record)
  96. # 3. 如果没有其他域名关联这个超管,级联清理超管及其所有关联数据
  97. if sa_id is not None:
  98. remaining = await db.execute(
  99. select(MonitoredDomain).where(MonitoredDomain.super_admin_id == sa_id)
  100. )
  101. if not remaining.scalar_one_or_none():
  102. # 3a. 删除该超管的所有 License
  103. from app.models.license import SuperAdminLicense
  104. await db.execute(
  105. SuperAdminLicense.__table__.delete().where(SuperAdminLicense.super_admin_id == sa_id)
  106. )
  107. # 3b. 查找该超管关联的所有租户
  108. tenant_result = await db.execute(
  109. select(SuperAdminTenant.tenant_id).where(SuperAdminTenant.super_admin_id == sa_id)
  110. )
  111. tenant_ids = [row[0] for row in tenant_result.all()]
  112. # 3c. 先删除超管-租户关联
  113. await db.execute(
  114. SuperAdminTenant.__table__.delete().where(SuperAdminTenant.super_admin_id == sa_id)
  115. )
  116. # 3d. 删除不再被任何超管关联的租户及其消费明细
  117. from app.models.monitoring import Tenant, UserConsumptionDetail
  118. if tenant_ids:
  119. for tid in tenant_ids:
  120. # 检查是否还有其他超管关联此租户
  121. other = await db.execute(
  122. select(SuperAdminTenant).where(SuperAdminTenant.tenant_id == tid)
  123. )
  124. if not other.scalar_one_or_none():
  125. # 删除消费明细
  126. await db.execute(
  127. UserConsumptionDetail.__table__.delete().where(
  128. UserConsumptionDetail.tenant_id == tid
  129. )
  130. )
  131. # 删除租户
  132. tenant_row = await db.execute(
  133. select(Tenant).where(Tenant.id == tid)
  134. )
  135. t = tenant_row.scalar_one_or_none()
  136. if t:
  137. await db.delete(t)
  138. # 3e. 删除超管本身
  139. sa_result = await db.execute(
  140. select(SuperAdmin).where(SuperAdmin.id == sa_id)
  141. )
  142. sa = sa_result.scalar_one_or_none()
  143. if sa:
  144. await db.delete(sa)
  145. await db.commit()
  146. @router.get("/{domain_id}/transactions")
  147. async def get_domain_transactions(
  148. domain_id: int,
  149. fetch_date: str | None = Query(None, description="爬取指定日期,格式 YYYY-MM-DD,不传则查全部"),
  150. db: AsyncSession = Depends(get_db),
  151. ):
  152. """爬取指定域名的监控数据并入库"""
  153. result = await db.execute(
  154. select(MonitoredDomain).where(MonitoredDomain.id == domain_id)
  155. )
  156. record = result.scalar_one_or_none()
  157. if not record:
  158. raise HTTPException(status_code=404, detail="域名不存在")
  159. try:
  160. data = await fetch_domain_transactions(record.domain, db, fetch_date=fetch_date)
  161. db.add(FetchLog(domain=record.domain, status="success", message="手动爬取成功"))
  162. await db.commit()
  163. return {"status": "ok", "domain": record.domain, "data": data}
  164. except Exception as e:
  165. error_msg = str(e)[:500]
  166. db.add(FetchLog(domain=record.domain, status="failed", message=error_msg))
  167. await db.commit()
  168. raise HTTPException(status_code=500, detail=error_msg)
  169. @router.post("/fetch-all", status_code=202)
  170. async def fetch_all_transactions(
  171. fetch_date: str | None = Query(None, description="爬取指定日期,格式 YYYY-MM-DD,不传则爬取当天"),
  172. db: AsyncSession = Depends(get_db),
  173. ):
  174. """批量爬取所有已启用域名的监控数据,默认只爬取当天"""
  175. if not fetch_date:
  176. from datetime import datetime, timezone, timedelta
  177. CST = timezone(timedelta(hours=8))
  178. fetch_date = datetime.now(CST).strftime("%Y-%m-%d")
  179. result = await db.execute(
  180. select(MonitoredDomain).where(MonitoredDomain.is_active == True)
  181. )
  182. domains = result.scalars().all()
  183. errors = []
  184. for d in domains:
  185. try:
  186. await fetch_domain_transactions(d.domain, db, fetch_date=fetch_date)
  187. db.add(FetchLog(domain=d.domain, status="success", message="手动批量爬取成功"))
  188. except Exception as e:
  189. error_msg = str(e)[:500]
  190. db.add(FetchLog(domain=d.domain, status="failed", message=error_msg))
  191. errors.append({"domain": d.domain, "error": error_msg})
  192. await db.commit()
  193. return {"status": "ok", "total": len(domains), "errors": errors}
  194. class ScheduleConfigUpdate(BaseModel):
  195. """更新定时爬取配置"""
  196. enabled: bool
  197. schedule_time: str # HH:MM
  198. @router.get("/schedule")
  199. async def get_schedule_config(db: AsyncSession = Depends(get_db)):
  200. """获取定时爬取配置"""
  201. result = await db.execute(select(FetchScheduleConfig).limit(1))
  202. config = result.scalar_one_or_none()
  203. if not config:
  204. config = FetchScheduleConfig(enabled=False, schedule_time="02:00")
  205. db.add(config)
  206. await db.commit()
  207. await db.refresh(config)
  208. return {"enabled": config.enabled, "schedule_time": config.schedule_time}
  209. @router.put("/schedule")
  210. async def update_schedule_config(
  211. payload: ScheduleConfigUpdate,
  212. db: AsyncSession = Depends(get_db),
  213. ):
  214. """更新定时爬取配置"""
  215. result = await db.execute(select(FetchScheduleConfig).limit(1))
  216. config = result.scalar_one_or_none()
  217. if not config:
  218. config = FetchScheduleConfig()
  219. db.add(config)
  220. config.enabled = payload.enabled
  221. config.schedule_time = payload.schedule_time
  222. await db.commit()
  223. return {"message": "配置已保存", "enabled": config.enabled, "schedule_time": config.schedule_time}
  224. @router.get("/fetch-logs")
  225. async def list_fetch_logs(
  226. domain: str | None = Query(None, description="按域名筛选"),
  227. status: str | None = Query(None, description="按状态筛选: success / failed / skipped"),
  228. page: int = Query(1, ge=1),
  229. size: int = Query(20, ge=1, le=100),
  230. db: AsyncSession = Depends(get_db),
  231. ):
  232. """获取爬取日志列表"""
  233. conditions = []
  234. if domain:
  235. conditions.append(FetchLog.domain == domain)
  236. if status:
  237. conditions.append(FetchLog.status == status)
  238. count_query = select(func.count()).select_from(FetchLog)
  239. if conditions:
  240. count_query = count_query.where(*conditions)
  241. total = await db.scalar(count_query) or 0
  242. logs_query = select(FetchLog)
  243. if conditions:
  244. logs_query = logs_query.where(*conditions)
  245. logs_query = logs_query.order_by(desc(FetchLog.created_at)).offset((page - 1) * size).limit(size)
  246. logs_result = await db.execute(logs_query)
  247. logs = logs_result.scalars().all()
  248. return {
  249. "items": [
  250. {
  251. "id": log.id,
  252. "domain": log.domain,
  253. "status": log.status,
  254. "message": log.message,
  255. "created_at": log.created_at.isoformat() if log.created_at else None,
  256. }
  257. for log in logs
  258. ],
  259. "total": total,
  260. "page": page,
  261. "size": size,
  262. }