| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- """轻量后台任务管理器,按类型控制并发,用 asyncio.create_task 执行。"""
- import asyncio
- from datetime import datetime, timedelta, timezone
- from typing import Any, Callable, Coroutine, Optional
- from app.core.logging import logger
- class BackgroundTaskManager:
- def __init__(self):
- self._tasks: dict[str, dict[str, Any]] = {}
- self._type_semaphores: dict[str, asyncio.Semaphore] = {}
- self._callbacks: list[Callable[[str, dict], None]] = []
- def set_concurrency(self, task_type: str, limit: int) -> None:
- self._type_semaphores[task_type] = asyncio.Semaphore(limit)
- def register_callback(self, callback: Callable[[str, dict], None]) -> None:
- self._callbacks.append(callback)
- def register_task(self, task_id: str, task_type: str, metadata: dict | None = None) -> None:
- self._tasks[task_id] = {
- "task_type": task_type,
- "status": "pending",
- "progress": 0.0,
- "error": None,
- "created_at": datetime.now(timezone.utc),
- **(metadata or {}),
- }
- def update_task(self, task_id: str, **kwargs) -> None:
- if task_id in self._tasks:
- self._tasks[task_id].update(kwargs)
- for cb in self._callbacks:
- try:
- cb(task_id, dict(self._tasks[task_id]))
- except Exception:
- pass
- def get_task(self, task_id: str) -> Optional[dict[str, Any]]:
- return self._tasks.get(task_id)
- def list_tasks_by_type(self, task_type: str) -> dict[str, dict[str, Any]]:
- return {tid: t for tid, t in self._tasks.items() if t.get("task_type") == task_type}
- async def run(self, task_id: str, task_type: str, coro: Coroutine) -> None:
- sem = self._type_semaphores.get(task_type)
- async def _wrapped() -> None:
- if sem:
- async with sem:
- self.update_task(task_id, status="running", progress=0.0)
- try:
- result = await coro
- self.update_task(
- task_id, status="completed", progress=100.0, **(result or {})
- )
- except Exception as e:
- self.update_task(task_id, status="failed", error=str(e))
- logger.error(f"Background task {task_id} ({task_type}) failed: {e}")
- else:
- self.update_task(task_id, status="running", progress=0.0)
- try:
- result = await coro
- self.update_task(
- task_id, status="completed", progress=100.0, **(result or {})
- )
- except Exception as e:
- self.update_task(task_id, status="failed", error=str(e))
- logger.error(f"Background task {task_id} ({task_type}) failed: {e}")
- asyncio.create_task(_wrapped())
- def cancel_task(self, task_id: str) -> bool:
- if task_id in self._tasks and self._tasks[task_id]["status"] in ("pending", "running"):
- self.update_task(task_id, status="cancelled", error="Cancelled by user")
- return True
- return False
- def cleanup_old_tasks(self, max_age_hours: int = 24) -> None:
- cutoff = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)
- to_remove = [
- tid
- for tid, t in self._tasks.items()
- if t["status"] in ("completed", "failed", "cancelled")
- and t.get("created_at")
- and t["created_at"] < cutoff
- ]
- for tid in to_remove:
- del self._tasks[tid]
- @property
- def tasks(self) -> dict[str, dict[str, Any]]:
- return dict(self._tasks)
# Module-level singleton shared by every importer of this module.
# The original note says it is initialized in main.py's lifespan handler;
# presumably that is where set_concurrency()/register_callback() are called.
# TODO confirm against main.py.
background_task_manager: BackgroundTaskManager = BackgroundTaskManager()
|