173 lines
5.7 KiB
Python
173 lines
5.7 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import threading
|
|
import uuid
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from typing import Any, Callable
|
|
|
|
|
|
# Module-level logger, registered under the fixed name "office365_admin.tasks".
logger = logging.getLogger("office365_admin.tasks")
|
|
|
|
|
|
def _now_iso() -> str:
|
|
return datetime.now().isoformat(timespec="seconds")
|
|
|
|
|
|
@dataclass
class TaskRecord:
    """Mutable state of one background task, plus its dictionary projection.

    Only ``id``, ``operation`` and ``total`` are required. Every other field
    starts at the default shown here and is overwritten while the task runs.
    """

    id: str
    operation: str
    total: int
    # Starts as "queued"; the manager in this file later stores "running",
    # "succeeded" or "failed".
    status: str = "queued"
    message: str = "任务已提交"
    # Timestamps are strings as produced by ``_now_iso``; empty means "not yet".
    created_at: str = field(default_factory=_now_iso)
    started_at: str = ""
    finished_at: str = ""
    completed: int = 0
    success_count: int = 0
    failure_count: int = 0
    current_item: str = ""
    current_message: str = ""
    # Each entry is a {"identifier": ..., "message": ...} mapping.
    recent_failures: list[dict[str, str]] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a detached snapshot of the record with camelCase keys.

        ``progressPercent`` is ``completed`` as a whole-number percentage of
        ``total``, limited to the range 0..100, and 0 when ``total`` is not
        positive. ``recentFailures`` is a copy, so mutating the returned
        dictionary never touches the record itself.
        """
        percent = 0
        if self.total > 0:
            # Integer arithmetic on purpose: int(completed / total * 100)
            # truncates float error downward, e.g. 29 of 100 would report 28.
            percent = int(self.completed * 100 // self.total)
            percent = max(0, min(100, percent))

        return {
            "id": self.id,
            "operation": self.operation,
            "status": self.status,
            "message": self.message,
            "createdAt": self.created_at,
            "startedAt": self.started_at,
            "finishedAt": self.finished_at,
            "total": self.total,
            "completed": self.completed,
            "successCount": self.success_count,
            "failureCount": self.failure_count,
            "progressPercent": percent,
            "currentItem": self.current_item,
            "currentMessage": self.current_message,
            # Copy the list and its entries so the caller cannot observe or
            # cause later changes to the record's own failure list.
            "recentFailures": [dict(entry) for entry in self.recent_failures],
        }
|
|
|
|
|
|
class TaskNotFoundError(KeyError):
    """Raised when a requested task id has no matching record."""
|
|
|
|
|
|
class BackgroundTaskManager:
    """Run bulk operations on a thread pool and track their progress in memory.

    Every submitted job gets a ``TaskRecord`` stored under a random hex id.
    All reads and writes of those records happen while holding ``self._lock``.
    Records are only removed when the executor rejects a submission; otherwise
    they stay in memory for as long as the manager exists.
    """

    # Upper bound on the failure samples kept per task.
    _MAX_RECENT_FAILURES = 5

    # Display prefixes for the known operation names.
    _OPERATION_LABELS = {
        "create": "批量创建",
        "update": "批量更新",
        "delete": "批量删除",
        "reset-password": "批量改密",
    }

    def __init__(self, max_workers: int = 4):
        """Create the worker pool and an empty task registry.

        Args:
            max_workers: Size of the underlying ``ThreadPoolExecutor``.
        """
        self._executor = ThreadPoolExecutor(
            max_workers=max_workers,
            thread_name_prefix="office365-task",
        )
        self._tasks: dict[str, TaskRecord] = {}
        self._lock = threading.Lock()

    def submit(
        self,
        operation: str,
        total: int,
        runner: Callable[[Callable[[dict[str, Any]], None]], dict[str, Any]],
    ) -> dict[str, Any]:
        """Register a new task, schedule it, and return its initial snapshot.

        Args:
            operation: Operation name; also used to pick the message prefix.
            total: Number of items the task is expected to process.
            runner: Callable that does the work. It receives a progress
                callback and returns a summary dictionary.

        Returns:
            The ``to_dict()`` view of the record as it was when registered.

        Raises:
            RuntimeError: Propagated from the executor when it refuses the job.
        """
        task_id = uuid.uuid4().hex
        record = TaskRecord(
            id=task_id,
            operation=operation,
            total=total,
            message=f"{self._label(operation)}任务已提交",
        )
        with self._lock:
            self._tasks[task_id] = record
            # Snapshot before the worker can touch the record, so the caller
            # always receives a consistent view.
            snapshot = record.to_dict()

        logger.info("Task %s queued: operation=%s total=%s", task_id, operation, total)
        try:
            self._executor.submit(self._run_task, task_id, runner)
        except RuntimeError:
            # The pool refused the job; drop the record so it does not sit in
            # the registry as "queued" forever.
            with self._lock:
                self._tasks.pop(task_id, None)
            raise
        return snapshot

    def get_task(self, task_id: str) -> dict[str, Any]:
        """Return the current ``to_dict()`` view of a task.

        Raises:
            TaskNotFoundError: If no record is stored under ``task_id``.
        """
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                raise TaskNotFoundError(task_id)
            return record.to_dict()

    def _run_task(
        self,
        task_id: str,
        runner: Callable[[Callable[[dict[str, Any]], None]], dict[str, Any]],
    ) -> None:
        """Execute ``runner`` and record its outcome on the task.

        Any exception raised by ``runner`` is logged and stored on the record
        as a "failed" status instead of propagating.
        """
        self._update(
            task_id,
            status="running",
            started_at=_now_iso(),
            message="任务执行中",
        )
        logger.info("Task %s started", task_id)

        try:
            result = runner(lambda update: self._handle_progress(task_id, update))
        except Exception as exc:
            logger.exception("Task %s failed", task_id)
            self._update(
                task_id,
                status="failed",
                finished_at=_now_iso(),
                message=f"任务执行失败: {exc}",
                current_message=str(exc),
            )
            return

        # A runner that finished without raising is a success even when it
        # returned something other than a dictionary; treat that as "no summary".
        summary = result if isinstance(result, dict) else {}
        summary_message = self._finish_success(task_id, summary)
        logger.info("Task %s finished: %s", task_id, summary_message)

    def _finish_success(self, task_id: str, summary: dict[str, Any]) -> str:
        """Mark the task succeeded and return the summary line that was built.

        Keys missing from ``summary`` fall back to what the record already
        holds, so an incomplete summary cannot zero the reported progress.
        """
        with self._lock:
            record = self._tasks.get(task_id)
            known_total = record.total if record is not None else 0
            known_success = record.success_count if record is not None else 0
            known_failure = record.failure_count if record is not None else 0

            success_count = summary.get("successCount", known_success)
            failure_count = summary.get("failureCount", known_failure)
            summary_message = f"任务完成，成功 {success_count}，失败 {failure_count}"

            if record is not None:
                record.status = "succeeded"
                record.finished_at = _now_iso()
                record.completed = summary.get("total", known_total)
                record.success_count = success_count
                record.failure_count = failure_count
                record.current_message = summary_message
                record.message = summary_message
        return summary_message

    def _handle_progress(self, task_id: str, update: dict[str, Any]) -> None:
        """Apply one progress update from a runner to the task record.

        The failure sample and the counters are written in one critical
        section, so a reader never sees one without the other. Counter keys
        absent from ``update`` leave the stored value unchanged. Updates for
        unknown task ids are ignored.
        """
        identifier = update.get("identifier", "")
        detail = update.get("message", "")

        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                return

            is_failure = bool(identifier) and not update.get("success")
            if is_failure and len(record.recent_failures) < self._MAX_RECENT_FAILURES:
                record.recent_failures.append(
                    {"identifier": identifier, "message": detail}
                )

            record.completed = update.get("completed", record.completed)
            record.success_count = update.get("successCount", record.success_count)
            record.failure_count = update.get("failureCount", record.failure_count)
            record.current_item = identifier
            record.current_message = detail
            shown_total = update.get("total", record.total)
            record.message = f"正在执行 {record.completed} / {shown_total}"

    def _update(self, task_id: str, **changes: Any) -> None:
        """Set the given attributes on a task record; no-op for unknown ids."""
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                return
            for key, value in changes.items():
                setattr(record, key, value)

    @staticmethod
    def _label(operation: str) -> str:
        """Return the display label for ``operation``, or the name itself if unknown."""
        return BackgroundTaskManager._OPERATION_LABELS.get(operation, operation)
|