import json
import logging
import threading
from datetime import datetime, timedelta

from app.sample_center_client import SampleCenterClient, SampleCenterError

logger = logging.getLogger(__name__)

# Poll budget per task: the due-task query skips tasks whose poll_count has
# reached this value.
MAX_POLL_COUNT = 20
# Backoff between two polls of the same task, in seconds:
#   POLL_INTERVAL_INIT * POLL_MULTIPLIER ** (poll_count - 1), capped at POLL_INTERVAL_MAX
POLL_INTERVAL_INIT = 2
POLL_INTERVAL_MAX = 30
POLL_MULTIPLIER = 1.5
# Seconds the background thread waits between scans for due tasks.
SCAN_INTERVAL = 5


class KnowledgePoller:
    """Background polling thread that periodically checks the status of
    pending/processing knowledge import tasks with the sample center.

    Args:
        app: application object providing ``config`` and ``app_context()``
            (presumably a Flask app -- confirm against the caller).
    """

    def __init__(self, app):
        self.app = app
        self._thread = None
        self._stop_event = threading.Event()

    def start(self):
        """Start the daemon polling thread; no-op if it is already running."""
        if self._thread and self._thread.is_alive():
            return
        self._stop_event.clear()
        self._thread = threading.Thread(
            target=self._run, name="knowledge-poller", daemon=True
        )
        self._thread.start()
        logger.info("Knowledge poller started")

    def stop(self):
        """Ask the polling thread to exit and wait up to 10 seconds for it."""
        self._stop_event.set()
        if self._thread:
            self._thread.join(timeout=10)
            if self._thread.is_alive():
                # join() timed out (e.g. stuck in a slow remote call); do not
                # report a clean stop.
                logger.warning("Knowledge poller did not stop within 10s")
                return
        logger.info("Knowledge poller stopped")

    def _run(self):
        """Thread body: scan for due tasks every SCAN_INTERVAL seconds until stopped."""
        while not self._stop_event.is_set():
            try:
                self._poll_due_tasks()
            except Exception:
                # One failed scan must not kill the thread.
                logger.exception("Poller error")
            self._stop_event.wait(SCAN_INTERVAL)

    def _poll_due_tasks(self):
        """Poll every pending/processing task that is due and under MAX_POLL_COUNT."""
        from app.models import KnowledgeImportTask

        with self.app.app_context():
            now = datetime.utcnow()
            tasks = KnowledgeImportTask.query.filter(
                KnowledgeImportTask.status.in_(['pending', 'processing']),
                KnowledgeImportTask.next_poll_at <= now,
                KnowledgeImportTask.poll_count < MAX_POLL_COUNT,
            ).all()
            if not tasks:
                return
            # Build the client once per scan rather than once per task.
            client = self._make_client()
            for task in tasks:
                self._poll_single_task(task, client)

    def _make_client(self):
        """Build a SampleCenterClient from the SAMPLE_CENTER_* config keys."""
        cfg = self.app.config
        return SampleCenterClient(
            base_url=cfg['SAMPLE_CENTER_BASE_URL'],
            app_id=cfg['SAMPLE_CENTER_APP_ID'],
            app_secret=cfg['SAMPLE_CENTER_APP_SECRET'],
        )

    def _poll_single_task(self, task, client=None):
        """Check one import task with the sample center and persist the result.

        Every call counts one poll (``poll_count`` += 1, ``last_poll_at`` set).
        A ``SampleCenterError`` is stored in ``error_message`` and the task is
        rescheduled with backoff. Any other error (including a failed commit)
        rolls the session back; the attempt is then re-recorded so it still
        consumes the poll budget.

        Args:
            task: KnowledgeImportTask row to update.
            client: SampleCenterClient to use; built from app config when None.
        """
        from app import db

        if client is None:
            client = self._make_client()

        # Read the identifier now: by default SQLAlchemy expires instances on
        # commit/rollback, so reading it later would re-query the database.
        task_no = task.task_no
        task.poll_count += 1
        task.last_poll_at = datetime.utcnow()

        remote_error = None
        try:
            try:
                self._apply_remote_status(task, client)
            except SampleCenterError as e:
                # Remote/transport error: keep the current status, record the
                # error and retry later.
                remote_error = e
                task.error_message = str(e)
                task.next_poll_at = self._next_poll_at(task.poll_count)
            status = task.status
            poll_count = task.poll_count
            still_polling = task.next_poll_at is not None
            db.session.commit()
        except Exception:
            db.session.rollback()
            logger.exception("Unexpected poll error for %s", task_no)
            self._record_failed_attempt(db, task, task_no)
            return

        if remote_error is not None:
            logger.warning("Poll error for %s: %s", task_no, remote_error)
        else:
            logger.info("Polled task %s: status=%s", task_no, status)
        if still_polling and poll_count >= MAX_POLL_COUNT:
            # The due-task query will skip this task from now on.
            logger.warning(
                "Task %s used all %d polls without reaching a final status",
                task_no, MAX_POLL_COUNT,
            )

    def _apply_remote_status(self, task, client):
        """Fetch the remote state of *task* and copy it onto the row (no commit).

        Remote status mapping:
        * ``completed`` -> status from ``_completed_status(progress)``; polling stops.
        * ``failed``    -> ``failed`` with the remote ``error``; polling stops.
        * anything else -> ``processing``; next poll scheduled with backoff.

        Raises:
            SampleCenterError: propagated from ``client.get_import_task``.
        """
        result = client.get_import_task(task.sample_task_id)
        # Treat a missing *or null* `data` field as an empty payload.
        sc_data = result.get('data') or {}
        sc_status = sc_data.get('status', '')
        task.status_detail = json.dumps(sc_data, ensure_ascii=False)
        progress = sc_data.get('progress')
        if progress:
            task.progress = json.dumps(progress, ensure_ascii=False)

        if sc_status == 'completed':
            task.status = self._completed_status(progress)
            task.next_poll_at = None
        elif sc_status == 'failed':
            task.status = 'failed'
            task.error_message = sc_data.get('error', '')
            task.next_poll_at = None
        else:
            task.status = 'processing'
            task.next_poll_at = self._next_poll_at(task.poll_count)

    @staticmethod
    def _completed_status(progress):
        """Map the progress counters of a completed remote task to a local status.

        Returns ``'failed'`` when items were submitted but none succeeded,
        ``'partial_success'`` when some succeeded and some failed, and
        ``'success'`` otherwise (including when no progress was reported).
        Missing or null counters count as 0.
        """
        if not progress:
            return 'success'
        succeeded = progress.get('succeeded') or 0
        failed = progress.get('failed') or 0
        total = progress.get('total') or 0
        if total > 0 and succeeded == 0:
            return 'failed'
        if total > 0 and succeeded > 0 and failed > 0:
            return 'partial_success'
        return 'success'

    @staticmethod
    def _backoff_seconds(poll_count):
        """Delay before the next poll after *poll_count* polls (exponential, capped)."""
        return min(
            POLL_INTERVAL_INIT * (POLL_MULTIPLIER ** (poll_count - 1)),
            POLL_INTERVAL_MAX,
        )

    @classmethod
    def _next_poll_at(cls, poll_count):
        """Naive UTC timestamp of the next poll after *poll_count* polls."""
        return datetime.utcnow() + timedelta(seconds=cls._backoff_seconds(poll_count))

    def _record_failed_attempt(self, db, task, task_no):
        """Re-apply the poll bookkeeping that a rollback discarded.

        Without this, the rollback undoes the poll_count increment and leaves
        next_poll_at in the past. A task that keeps failing unexpectedly would
        then be re-polled on every scan and never reach MAX_POLL_COUNT.
        """
        try:
            task.poll_count += 1
            task.last_poll_at = datetime.utcnow()
            task.next_poll_at = self._next_poll_at(task.poll_count)
            db.session.commit()
        except Exception:
            # e.g. the database itself is unavailable; leave the task for the next scan.
            db.session.rollback()
            logger.exception("Could not record failed poll attempt for %s", task_no)