background_tasks.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. """轻量后台任务管理器,按类型控制并发,用 asyncio.create_task 执行。"""
  2. import asyncio
  3. from datetime import datetime, timedelta, timezone
  4. from typing import Any, Callable, Coroutine, Optional
  5. from app.core.logging import logger
  6. class BackgroundTaskManager:
  7. def __init__(self):
  8. self._tasks: dict[str, dict[str, Any]] = {}
  9. self._type_semaphores: dict[str, asyncio.Semaphore] = {}
  10. self._callbacks: list[Callable[[str, dict], None]] = []
  11. def set_concurrency(self, task_type: str, limit: int) -> None:
  12. self._type_semaphores[task_type] = asyncio.Semaphore(limit)
  13. def register_callback(self, callback: Callable[[str, dict], None]) -> None:
  14. self._callbacks.append(callback)
  15. def register_task(self, task_id: str, task_type: str, metadata: dict | None = None) -> None:
  16. self._tasks[task_id] = {
  17. "task_type": task_type,
  18. "status": "pending",
  19. "progress": 0.0,
  20. "error": None,
  21. "created_at": datetime.now(timezone.utc),
  22. **(metadata or {}),
  23. }
  24. def update_task(self, task_id: str, **kwargs) -> None:
  25. if task_id in self._tasks:
  26. self._tasks[task_id].update(kwargs)
  27. for cb in self._callbacks:
  28. try:
  29. cb(task_id, dict(self._tasks[task_id]))
  30. except Exception:
  31. pass
  32. def get_task(self, task_id: str) -> Optional[dict[str, Any]]:
  33. return self._tasks.get(task_id)
  34. def list_tasks_by_type(self, task_type: str) -> dict[str, dict[str, Any]]:
  35. return {tid: t for tid, t in self._tasks.items() if t.get("task_type") == task_type}
  36. async def run(self, task_id: str, task_type: str, coro: Coroutine) -> None:
  37. sem = self._type_semaphores.get(task_type)
  38. async def _wrapped() -> None:
  39. if sem:
  40. async with sem:
  41. self.update_task(task_id, status="running", progress=0.0)
  42. try:
  43. result = await coro
  44. self.update_task(
  45. task_id, status="completed", progress=100.0, **(result or {})
  46. )
  47. except Exception as e:
  48. self.update_task(task_id, status="failed", error=str(e))
  49. logger.error(f"Background task {task_id} ({task_type}) failed: {e}")
  50. else:
  51. self.update_task(task_id, status="running", progress=0.0)
  52. try:
  53. result = await coro
  54. self.update_task(
  55. task_id, status="completed", progress=100.0, **(result or {})
  56. )
  57. except Exception as e:
  58. self.update_task(task_id, status="failed", error=str(e))
  59. logger.error(f"Background task {task_id} ({task_type}) failed: {e}")
  60. asyncio.create_task(_wrapped())
  61. def cancel_task(self, task_id: str) -> bool:
  62. if task_id in self._tasks and self._tasks[task_id]["status"] in ("pending", "running"):
  63. self.update_task(task_id, status="cancelled", error="Cancelled by user")
  64. return True
  65. return False
  66. def cleanup_old_tasks(self, max_age_hours: int = 24) -> None:
  67. cutoff = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)
  68. to_remove = [
  69. tid
  70. for tid, t in self._tasks.items()
  71. if t["status"] in ("completed", "failed", "cancelled")
  72. and t.get("created_at")
  73. and t["created_at"] < cutoff
  74. ]
  75. for tid in to_remove:
  76. del self._tasks[tid]
  77. @property
  78. def tasks(self) -> dict[str, dict[str, Any]]:
  79. return dict(self._tasks)
  80. # 全局单例,在 main.py lifespan 中初始化
  81. background_task_manager = BackgroundTaskManager()