| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- import json
- import logging
- import threading
- from datetime import datetime, timedelta
- from app.sample_center_client import SampleCenterClient, SampleCenterError
logger = logging.getLogger(__name__)

# --- Polling schedule for Sample Center import tasks -----------------------
# Tasks whose poll_count has reached this value are excluded from the
# due-task query in KnowledgePoller._poll_due_tasks.
MAX_POLL_COUNT: int = 20
# Back-off delay (seconds) used after the first poll of a still-running task.
POLL_INTERVAL_INIT: int = 2
# Growth factor applied to the delay for every additional poll.
POLL_MULTIPLIER: float = 1.5
# Upper bound (seconds) on the back-off delay.
POLL_INTERVAL_MAX: int = 30
class KnowledgePoller:
    """Background polling thread that periodically checks pending/processing import tasks.

    Every ``SCAN_INTERVAL`` seconds the worker loads ``KnowledgeImportTask`` rows
    whose status is ``pending`` or ``processing``, whose ``next_poll_at`` is due
    and whose ``poll_count`` is below ``MAX_POLL_COUNT``. It queries Sample Center
    for each one and writes the outcome back. Tasks that are still running are
    rescheduled with exponential back-off:
    ``POLL_INTERVAL_INIT * POLL_MULTIPLIER ** (poll_count - 1)`` seconds, capped
    at ``POLL_INTERVAL_MAX``.
    """

    # Seconds the worker sleeps between scans of the task table.
    SCAN_INTERVAL = 5

    def __init__(self, app):
        """Bind the poller to *app*.

        *app* must expose ``config`` and ``app_context()`` (presumably a Flask
        application).
        """
        self.app = app
        self._thread = None
        self._stop_event = threading.Event()

    def start(self):
        """Start the daemon polling thread; no-op if it is already running."""
        if self._thread and self._thread.is_alive():
            return
        self._stop_event.clear()
        self._thread = threading.Thread(
            target=self._run, name="knowledge-poller", daemon=True
        )
        self._thread.start()
        logger.info("Knowledge poller started")

    def stop(self, timeout=10):
        """Signal the worker to exit and wait up to *timeout* seconds for it."""
        self._stop_event.set()
        if self._thread:
            self._thread.join(timeout=timeout)
            if self._thread.is_alive():
                # join() returned because of the timeout, not because the thread ended.
                logger.warning("Knowledge poller did not stop within %ss", timeout)
                return
        logger.info("Knowledge poller stopped")

    def _run(self):
        """Thread body: scan for due tasks until stop() is requested."""
        while not self._stop_event.is_set():
            try:
                self._poll_due_tasks()
            except Exception:
                # Keep the worker alive; one failed scan must not kill the thread.
                logger.exception("Poller error")
            self._stop_event.wait(self.SCAN_INTERVAL)

    def _poll_due_tasks(self):
        """Load every task that is due for a poll and poll it."""
        from app.models import KnowledgeImportTask

        with self.app.app_context():
            now = datetime.utcnow()
            tasks = KnowledgeImportTask.query.filter(
                KnowledgeImportTask.status.in_(['pending', 'processing']),
                KnowledgeImportTask.next_poll_at <= now,
                KnowledgeImportTask.poll_count < MAX_POLL_COUNT,
            ).all()
            if not tasks:
                return
            # One client per scan rather than one per task.
            client = self._make_client()
            for task in tasks:
                if self._stop_event.is_set():
                    # Let stop() return promptly instead of finishing the batch.
                    break
                self._poll_single_task(task, client)

    def _make_client(self):
        """Build a SampleCenterClient from the SAMPLE_CENTER_* app config keys."""
        cfg = self.app.config
        return SampleCenterClient(
            base_url=cfg['SAMPLE_CENTER_BASE_URL'],
            app_id=cfg['SAMPLE_CENTER_APP_ID'],
            app_secret=cfg['SAMPLE_CENTER_APP_SECRET'],
        )

    @staticmethod
    def _backoff_seconds(poll_count):
        """Return the delay (seconds) before the next poll after *poll_count* polls."""
        return min(
            POLL_INTERVAL_INIT * (POLL_MULTIPLIER ** (poll_count - 1)),
            POLL_INTERVAL_MAX,
        )

    @staticmethod
    def _completed_status(progress):
        """Map the progress counters of a completed remote task to a local status.

        Returns ``'failed'`` when items existed but none succeeded.
        Returns ``'partial_success'`` when some succeeded and some failed.
        Otherwise (including missing or non-dict *progress*) returns ``'success'``.
        """
        if not isinstance(progress, dict):
            return 'success'
        # ``or 0`` also covers counters that are present but null.
        succeeded = progress.get('succeeded') or 0
        failed = progress.get('failed') or 0
        total = progress.get('total') or 0
        if total > 0 and succeeded == 0:
            return 'failed'
        if total > 0 and succeeded > 0 and failed > 0:
            return 'partial_success'
        return 'success'

    def _schedule_retry(self, task):
        """Set ``task.next_poll_at`` using back-off based on ``task.poll_count``."""
        delay = self._backoff_seconds(task.poll_count)
        task.next_poll_at = datetime.utcnow() + timedelta(seconds=delay)
        if task.poll_count >= MAX_POLL_COUNT:
            # The due-task query filters on poll_count < MAX_POLL_COUNT, so this
            # task will not be polled again; make that visible.
            logger.warning(
                "Task %s reached MAX_POLL_COUNT (%d) with status=%s; polling stops",
                task.task_no, MAX_POLL_COUNT, task.status,
            )

    @staticmethod
    def _commit_quietly(db, task_no):
        """Commit the session; on failure roll back and log instead of raising."""
        try:
            db.session.commit()
        except Exception:
            db.session.rollback()
            logger.exception("Failed to save poll result for %s", task_no)

    def _reschedule_after_rollback(self, db, task, task_no):
        """Re-count the poll and back off after a rollback discarded those changes.

        Without this, the task keeps its old poll_count/next_poll_at. It would be
        retried on every scan indefinitely and never reach MAX_POLL_COUNT.
        """
        try:
            # rollback() expired the instance, so these reads reload DB values.
            task.poll_count += 1
            task.last_poll_at = datetime.utcnow()
            self._schedule_retry(task)
            db.session.commit()
        except Exception:
            db.session.rollback()
            logger.exception("Failed to reschedule task %s after poll error", task_no)

    def _poll_single_task(self, task, client=None):
        """Query Sample Center for *task* and persist the outcome.

        Args:
            task: a KnowledgeImportTask row attached to the current session.
            client: optional SampleCenterClient to reuse. One is built from the
                app config when omitted.

        Errors from the remote call, from response handling and from the commit
        are logged rather than raised. The poll is always counted and the task
        rescheduled, so a bad task cannot be retried on every scan forever.
        """
        from app import db

        # Read before anything can fail: after a rollback, attribute access
        # triggers a reload that may itself fail if the DB is the problem.
        task_no = task.task_no
        if client is None:
            client = self._make_client()

        task.poll_count += 1
        task.last_poll_at = datetime.utcnow()
        try:
            result = client.get_import_task(task.sample_task_id)
            # Tolerate a missing envelope or an explicit "data": null.
            sc_data = (result or {}).get('data') or {}
            sc_status = sc_data.get('status', '')
            task.status_detail = json.dumps(sc_data, ensure_ascii=False)
            progress = sc_data.get('progress')
            if progress:
                task.progress = json.dumps(progress, ensure_ascii=False)

            if sc_status == 'completed':
                task.status = self._completed_status(progress)
                task.next_poll_at = None
            elif sc_status == 'failed':
                task.status = 'failed'
                task.error_message = sc_data.get('error', '')
                task.next_poll_at = None
            else:
                # Any other remote status is treated as still in flight.
                task.status = 'processing'
                self._schedule_retry(task)
            db.session.commit()
            logger.info("Polled task %s: status=%s", task_no, task.status)
        except SampleCenterError as e:
            # Remote/API failure: keep the counted poll, record the error, back off.
            task.error_message = str(e)
            self._schedule_retry(task)
            self._commit_quietly(db, task_no)
            logger.warning("Poll error for %s: %s", task_no, e)
        except Exception:
            db.session.rollback()
            logger.exception("Unexpected poll error for %s", task_no)
            self._reschedule_after_rollback(db, task, task_no)
|