domains.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
import logging
import re

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import desc, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.models.domain import MonitoredDomain
from app.models.monitoring import FetchLog, FetchScheduleConfig, SuperAdmin
from app.schemas.domain import (
    MonitoredDomainCreate,
    MonitoredDomainResponse,
)
from app.services.domain_fetch import fetch_domain_transactions
  14. logger = logging.getLogger(__name__)
  15. router = APIRouter(prefix="/api/domains", tags=["domains"])
  16. @router.post("/", response_model=MonitoredDomainResponse, status_code=201)
  17. async def add_domain(
  18. payload: MonitoredDomainCreate,
  19. db: AsyncSession = Depends(get_db),
  20. ):
  21. """添加需要监控的域名,同时创建对应的超管记录"""
  22. existing = await db.execute(
  23. select(MonitoredDomain).where(MonitoredDomain.domain == payload.domain)
  24. )
  25. if existing.scalar_one_or_none():
  26. raise HTTPException(status_code=409, detail="域名已在监控中")
  27. # 如果未指定超管,自动创建一条超管记录
  28. sa_id = payload.super_admin_id
  29. if sa_id is None:
  30. max_id_result = await db.execute(select(func.max(SuperAdmin.id)))
  31. max_id = max_id_result.scalar() or 0
  32. new_sa = SuperAdmin(
  33. id=max_id + 1,
  34. username=payload.domain,
  35. nickname=payload.domain,
  36. remark=payload.remark or None,
  37. )
  38. db.add(new_sa)
  39. await db.flush()
  40. sa_id = new_sa.id
  41. record = MonitoredDomain(domain=payload.domain, remark=payload.remark or None, super_admin_id=sa_id)
  42. db.add(record)
  43. await db.commit()
  44. await db.refresh(record)
  45. return record
  46. @router.get("/", response_model=list[MonitoredDomainResponse])
  47. async def list_domains(db: AsyncSession = Depends(get_db)):
  48. """获取所有已监控的域名列表"""
  49. result = await db.execute(select(MonitoredDomain))
  50. return result.scalars().all()
  51. class MonitoredDomainUpdate(BaseModel):
  52. """更新域名备注"""
  53. remark: str = ""
  54. @router.patch("/{domain_id}", response_model=MonitoredDomainResponse)
  55. async def update_domain_remark(
  56. domain_id: int,
  57. payload: MonitoredDomainUpdate,
  58. db: AsyncSession = Depends(get_db),
  59. ):
  60. """更新域名备注,并同步到关联的超管"""
  61. result = await db.execute(
  62. select(MonitoredDomain).where(MonitoredDomain.id == domain_id)
  63. )
  64. record = result.scalar_one_or_none()
  65. if not record:
  66. raise HTTPException(status_code=404, detail="域名不存在")
  67. record.remark = payload.remark or None
  68. # 同步到关联的超管
  69. if record.super_admin_id:
  70. sa_result = await db.execute(
  71. select(SuperAdmin).where(SuperAdmin.id == record.super_admin_id)
  72. )
  73. sa = sa_result.scalar_one_or_none()
  74. if sa:
  75. sa.remark = payload.remark or None
  76. await db.commit()
  77. await db.refresh(record)
  78. return record
  79. @router.delete("/{domain_id}", status_code=204)
  80. async def remove_domain(domain_id: int, db: AsyncSession = Depends(get_db)):
  81. """移除指定 ID 的监控域名"""
  82. result = await db.execute(
  83. select(MonitoredDomain).where(MonitoredDomain.id == domain_id)
  84. )
  85. record = result.scalar_one_or_none()
  86. if not record:
  87. raise HTTPException(status_code=404, detail="域名不存在")
  88. await db.delete(record)
  89. await db.commit()
  90. @router.get("/{domain_id}/transactions")
  91. async def get_domain_transactions(
  92. domain_id: int,
  93. fetch_date: str | None = Query(None, description="爬取指定日期,格式 YYYY-MM-DD,不传则查全部"),
  94. db: AsyncSession = Depends(get_db),
  95. ):
  96. """爬取指定域名的监控数据并入库"""
  97. result = await db.execute(
  98. select(MonitoredDomain).where(MonitoredDomain.id == domain_id)
  99. )
  100. record = result.scalar_one_or_none()
  101. if not record:
  102. raise HTTPException(status_code=404, detail="域名不存在")
  103. async def get_domain_transactions(
  104. domain_id: int,
  105. fetch_date: str | None = Query(None, description="爬取指定日期,格式 YYYY-MM-DD,不传则查全部"),
  106. db: AsyncSession = Depends(get_db),
  107. ):
  108. """爬取指定域名的监控数据并入库"""
  109. result = await db.execute(
  110. select(MonitoredDomain).where(MonitoredDomain.id == domain_id)
  111. )
  112. record = result.scalar_one_or_none()
  113. if not record:
  114. raise HTTPException(status_code=404, detail="域名不存在")
  115. try:
  116. data = await fetch_domain_transactions(record.domain, db, fetch_date=fetch_date)
  117. db.add(FetchLog(domain=record.domain, status="success", message="手动爬取成功"))
  118. await db.commit()
  119. return {"status": "ok", "domain": record.domain, "data": data}
  120. except Exception as e:
  121. error_msg = str(e)[:500]
  122. db.add(FetchLog(domain=record.domain, status="failed", message=error_msg))
  123. await db.commit()
  124. raise HTTPException(status_code=500, detail=error_msg)
  125. @router.post("/fetch-all", status_code=202)
  126. async def fetch_all_transactions(
  127. fetch_date: str | None = Query(None, description="爬取指定日期,格式 YYYY-MM-DD,不传则爬取当天"),
  128. db: AsyncSession = Depends(get_db),
  129. ):
  130. """批量爬取所有已启用域名的监控数据,默认只爬取当天"""
  131. if not fetch_date:
  132. from datetime import datetime, timezone, timedelta
  133. CST = timezone(timedelta(hours=8))
  134. fetch_date = datetime.now(CST).strftime("%Y-%m-%d")
  135. result = await db.execute(
  136. select(MonitoredDomain).where(MonitoredDomain.is_active == True)
  137. )
  138. domains = result.scalars().all()
  139. errors = []
  140. for d in domains:
  141. try:
  142. await fetch_domain_transactions(d.domain, db, fetch_date=fetch_date)
  143. db.add(FetchLog(domain=d.domain, status="success", message="手动批量爬取成功"))
  144. except Exception as e:
  145. error_msg = str(e)[:500]
  146. db.add(FetchLog(domain=d.domain, status="failed", message=error_msg))
  147. errors.append({"domain": d.domain, "error": error_msg})
  148. await db.commit()
  149. return {"status": "ok", "total": len(domains), "errors": errors}
  150. class ScheduleConfigUpdate(BaseModel):
  151. """更新定时爬取配置"""
  152. enabled: bool
  153. schedule_time: str # HH:MM
  154. @router.get("/schedule")
  155. async def get_schedule_config(db: AsyncSession = Depends(get_db)):
  156. """获取定时爬取配置"""
  157. result = await db.execute(select(FetchScheduleConfig).limit(1))
  158. config = result.scalar_one_or_none()
  159. if not config:
  160. config = FetchScheduleConfig(enabled=False, schedule_time="02:00")
  161. db.add(config)
  162. await db.commit()
  163. await db.refresh(config)
  164. return {"enabled": config.enabled, "schedule_time": config.schedule_time}
  165. @router.put("/schedule")
  166. async def update_schedule_config(
  167. payload: ScheduleConfigUpdate,
  168. db: AsyncSession = Depends(get_db),
  169. ):
  170. """更新定时爬取配置"""
  171. result = await db.execute(select(FetchScheduleConfig).limit(1))
  172. config = result.scalar_one_or_none()
  173. if not config:
  174. config = FetchScheduleConfig()
  175. db.add(config)
  176. config.enabled = payload.enabled
  177. config.schedule_time = payload.schedule_time
  178. await db.commit()
  179. return {"message": "配置已保存", "enabled": config.enabled, "schedule_time": config.schedule_time}
  180. @router.get("/fetch-logs")
  181. async def list_fetch_logs(
  182. domain: str | None = Query(None, description="按域名筛选"),
  183. status: str | None = Query(None, description="按状态筛选: success / failed / skipped"),
  184. page: int = Query(1, ge=1),
  185. size: int = Query(20, ge=1, le=100),
  186. db: AsyncSession = Depends(get_db),
  187. ):
  188. """获取爬取日志列表"""
  189. conditions = []
  190. if domain:
  191. conditions.append(FetchLog.domain == domain)
  192. if status:
  193. conditions.append(FetchLog.status == status)
  194. count_query = select(func.count()).select_from(FetchLog)
  195. if conditions:
  196. count_query = count_query.where(*conditions)
  197. total = await db.scalar(count_query) or 0
  198. logs_query = select(FetchLog)
  199. if conditions:
  200. logs_query = logs_query.where(*conditions)
  201. logs_query = logs_query.order_by(desc(FetchLog.created_at)).offset((page - 1) * size).limit(size)
  202. logs_result = await db.execute(logs_query)
  203. logs = logs_result.scalars().all()
  204. return {
  205. "items": [
  206. {
  207. "id": log.id,
  208. "domain": log.domain,
  209. "status": log.status,
  210. "message": log.message,
  211. "created_at": log.created_at.isoformat() if log.created_at else None,
  212. }
  213. for log in logs
  214. ],
  215. "total": total,
  216. "page": page,
  217. "size": size,
  218. }