from __future__ import annotations

import logging
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable

# Background execution and progress tracking for Office 365 admin batch
# operations: runners execute on a thread pool while callers poll task state
# by id.
logger = logging.getLogger("office365_admin.tasks")


def _now_iso() -> str:
    """Return the current local time as an ISO-8601 string with second precision."""
    # NOTE(review): naive local timestamp (no UTC offset); consumers must
    # interpret it in the server's timezone.
    return datetime.now().isoformat(timespec="seconds")


@dataclass
class TaskRecord:
    """Mutable bookkeeping for a single background task.

    Records are owned by BackgroundTaskManager and mutated only while its lock
    is held; callers receive independent snapshots via to_dict().
    """

    id: str
    operation: str  # operation key, e.g. "create", "update", "delete", "reset-password"
    total: int  # number of items the task is expected to process
    status: str = "queued"  # "queued" -> "running" -> "succeeded" | "failed"
    message: str = "任务已提交"
    created_at: str = field(default_factory=_now_iso)
    started_at: str = ""
    finished_at: str = ""
    completed: int = 0
    success_count: int = 0
    failure_count: int = 0
    current_item: str = ""
    current_message: str = ""
    # Each entry is {"identifier": ..., "message": ...}.
    recent_failures: list[dict[str, str]] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a camelCase snapshot of this record for serialisation.

        The snapshot shares no mutable objects with the record, so it stays
        valid after the manager's lock is released.
        """
        progress_percent = 0
        if self.total > 0:
            # Integer arithmetic avoids float truncation (29 / 100 * 100 is
            # 28.999...), and clamping keeps the result within [0, 100] even
            # if a runner over- or under-reports its progress.
            done = min(max(self.completed, 0), self.total)
            progress_percent = done * 100 // self.total
        return {
            "id": self.id,
            "operation": self.operation,
            "status": self.status,
            "message": self.message,
            "createdAt": self.created_at,
            "startedAt": self.started_at,
            "finishedAt": self.finished_at,
            "total": self.total,
            "completed": self.completed,
            "successCount": self.success_count,
            "failureCount": self.failure_count,
            "progressPercent": progress_percent,
            "currentItem": self.current_item,
            "currentMessage": self.current_message,
            # Copy so callers never alias the list worker threads append to.
            "recentFailures": [dict(entry) for entry in self.recent_failures],
        }


class TaskNotFoundError(KeyError):
    """Raised by BackgroundTaskManager.get_task when no task has the given id."""


class BackgroundTaskManager:
    """Run batch operations on a thread pool and expose their progress by task id."""

    # Upper bound on entries kept in TaskRecord.recent_failures.
    # NOTE(review): despite the field name, the earliest failures are kept,
    # not the latest ones.
    _MAX_RECENT_FAILURES = 5

    # Human-readable (Chinese) labels for known operation keys.
    _LABELS = {
        "create": "批量创建",
        "update": "批量更新",
        "delete": "批量删除",
        "reset-password": "批量改密",
    }

    def __init__(self, max_workers: int = 4):
        self._executor = ThreadPoolExecutor(
            max_workers=max_workers, thread_name_prefix="office365-task"
        )
        # NOTE(review): finished records are never evicted, so this dict grows
        # with every task submitted during the process lifetime.
        self._tasks: dict[str, TaskRecord] = {}
        self._lock = threading.Lock()

    def submit(
        self,
        operation: str,
        total: int,
        runner: Callable[[Callable[[dict[str, Any]], None]], dict[str, Any]],
    ) -> dict[str, Any]:
        """Register a new task and schedule *runner* on the worker pool.

        Args:
            operation: Operation key; also used to build the initial message.
            total: Number of items the runner is expected to process.
            runner: Called on a worker thread with a progress callback; its
                return value is read for "total", "successCount" and
                "failureCount".

        Returns:
            A snapshot of the newly registered task (status "queued").

        Raises:
            RuntimeError: if the executor refuses new work (e.g. after
                shutdown()); the task is not left registered in that case.
        """
        task_id = uuid.uuid4().hex
        record = TaskRecord(
            id=task_id,
            operation=operation,
            total=total,
            message=f"{self._label(operation)}任务已提交",
        )
        with self._lock:
            self._tasks[task_id] = record
            # Snapshot before scheduling: once the worker starts, the record is
            # mutated concurrently.
            snapshot = record.to_dict()
        logger.info("Task %s queued: operation=%s total=%s", task_id, operation, total)
        try:
            self._executor.submit(self._run_task, task_id, runner)
        except RuntimeError:
            # Do not leave behind a task that would stay "queued" forever.
            with self._lock:
                self._tasks.pop(task_id, None)
            raise
        return snapshot

    def get_task(self, task_id: str) -> dict[str, Any]:
        """Return a snapshot of the task *task_id*.

        Raises:
            TaskNotFoundError: if no task with that id was submitted.
        """
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                raise TaskNotFoundError(task_id)
            return record.to_dict()

    def shutdown(self, wait: bool = True) -> None:
        """Stop accepting new tasks; if *wait*, block until submitted tasks finish."""
        self._executor.shutdown(wait=wait)

    def _run_task(
        self,
        task_id: str,
        runner: Callable[[Callable[[dict[str, Any]], None]], dict[str, Any]],
    ) -> None:
        """Worker-thread entry point: run *runner* and record its outcome."""
        self._update(
            task_id,
            status="running",
            started_at=_now_iso(),
            message="任务执行中",
        )
        logger.info("Task %s started", task_id)
        try:
            result = runner(lambda update: self._handle_progress(task_id, update))
            summary_message = (
                f"任务完成,成功 {result.get('successCount', 0)},失败 {result.get('failureCount', 0)}"
            )
            # If the runner does not report a total, a finished task has
            # processed everything it was submitted with.
            with self._lock:
                record = self._tasks.get(task_id)
                expected_total = record.total if record is not None else 0
            self._update(
                task_id,
                status="succeeded",
                finished_at=_now_iso(),
                completed=result.get("total", expected_total),
                success_count=result.get("successCount", 0),
                failure_count=result.get("failureCount", 0),
                current_message=summary_message,
                message=summary_message,
            )
            logger.info("Task %s finished: %s", task_id, summary_message)
        except Exception as exc:
            # Top-level boundary of a worker thread: record the failure on the
            # task instead of letting it vanish inside an unobserved Future.
            logger.exception("Task %s failed", task_id)
            self._update(
                task_id,
                status="failed",
                finished_at=_now_iso(),
                message=f"任务执行失败: {exc}",
                current_message=str(exc),
            )

    def _handle_progress(self, task_id: str, update: dict[str, Any]) -> None:
        """Apply one runner progress update to the task record atomically.

        Reads "completed", "total", "successCount", "failureCount",
        "identifier", "message" and "success" from *update*. Missing counters
        default to 0; a missing "total" falls back to the submitted total. An
        update with an identifier and a falsy/missing "success" is recorded in
        recent_failures (up to _MAX_RECENT_FAILURES entries).
        """
        completed = update.get("completed", 0)
        identifier = update.get("identifier", "")
        item_message = update.get("message", "")
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                return
            total = update.get("total", record.total)
            is_failure = not update.get("success") and identifier
            if is_failure and len(record.recent_failures) < self._MAX_RECENT_FAILURES:
                record.recent_failures.append(
                    {"identifier": identifier, "message": item_message}
                )
            record.completed = completed
            record.success_count = update.get("successCount", 0)
            record.failure_count = update.get("failureCount", 0)
            record.current_item = identifier
            record.current_message = item_message
            record.message = f"正在执行 {completed} / {total}"

    def _update(self, task_id: str, **changes: Any) -> None:
        """Set the given TaskRecord attributes under the lock; no-op for unknown ids."""
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                return
            for key, value in changes.items():
                setattr(record, key, value)

    @staticmethod
    def _label(operation: str) -> str:
        """Return the display label for *operation*, or the key itself if unknown."""
        return BackgroundTaskManager._LABELS.get(operation, operation)