Files
office365manage/office365_admin/services.py

598 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from dataclasses import dataclass
import logging
from typing import Any
from .graph import GraphAPIError, GraphClient, TokenManager
from .settings import Settings
# Explicit logger name (not ``__name__``); kept verbatim because logging
# configuration may refer to it.
logger = logging.getLogger("office365_admin.service")

# User properties requested from the Graph client whenever users are read
# (presumably forwarded as ``$select``).
USER_SELECT_FIELDS = [
    "id", "displayName", "userPrincipalName", "mail",
    "givenName", "surname", "department", "jobTitle",
    "officeLocation", "mobilePhone", "usageLocation", "accountEnabled",
    "assignedLicenses", "createdDateTime",
]

# Payload keys accepted as the account identifier, tried in this order.
# Keys are compared after Office365Service._normalize_key (lower-cased,
# non-alphanumerics dropped), so "user_id" and "userId" are equivalent.
IDENTIFIER_ALIASES = [
    "userPrincipalName", "user_principal_name",
    "user_id", "userId",
    "username",
    "email", "mail",
    "upn",
    "id",
]

# Graph user property -> payload key aliases accepted when creating or
# updating a user. Insertion order is the order fields are processed.
OPTIONAL_FIELD_ALIASES = {
    "displayName": ["displayName", "display_name"],
    "mailNickname": ["mailNickname", "mail_nickname", "nickname"],
    "givenName": ["givenName", "given_name", "firstName", "firstname"],
    "surname": ["surname", "lastName", "lastname", "last_name"],
    "department": ["department"],
    "jobTitle": ["jobTitle", "job_title"],
    "officeLocation": ["officeLocation", "office_location"],
    "mobilePhone": ["mobilePhone", "mobile", "phone"],
    "usageLocation": ["usageLocation", "usage_location"],
    "userPrincipalName": ["userPrincipalName", "user_principal_name", "upn"],
}

# Properties that a blank value may clear (sent as null) during an update.
NULLABLE_FIELDS = {
    "displayName", "givenName", "surname", "department",
    "jobTitle", "officeLocation", "mobilePhone",
}
@dataclass(eq=False)
class ServiceOperationError(RuntimeError):
    """Service-level failure carrying an HTTP-style status code.

    ``eq=False`` keeps identity-based ``__eq__``/``__hash__``; with the
    dataclass default (``eq=True``) ``__hash__`` is set to ``None`` and the
    exception becomes unhashable. ``__post_init__`` fills ``args`` so the
    exception pickles correctly even when built with keyword arguments.
    """

    message: str  # human-readable error text; also what str() returns
    status_code: int = 400  # HTTP-style status code to report to the caller
    details: Any = None  # optional raw context (e.g. a Graph error response)

    def __post_init__(self) -> None:
        # The generated __init__ never calls RuntimeError.__init__, so keyword
        # construction left ``args == ()`` and unpickling failed with a
        # missing-argument TypeError.
        super().__init__(self.message)

    def __str__(self) -> str:
        return self.message
class ServiceConfigurationError(RuntimeError):
    """Raised when the Graph settings are not ready and no client can be built."""
class Office365Service:
    """User and license management facade over the Microsoft Graph client.

    Graph failures (``GraphAPIError``) are translated into
    ``ServiceOperationError``; incomplete configuration raises
    ``ServiceConfigurationError``. Users are returned as flat dictionaries
    produced by ``_serialize_user``.
    """

    # Serialized user fields that a search query is matched against
    # (case-insensitive substring match).
    _SEARCH_FIELDS = (
        "displayName",
        "userPrincipalName",
        "mail",
        "department",
        "jobTitle",
        "givenName",
        "surname",
    )
    # String spellings accepted for boolean flags (compared after strip/lower).
    _TRUE_STRINGS = frozenset({"1", "true", "yes", "y", "enabled", "on"})
    _FALSE_STRINGS = frozenset({"0", "false", "no", "n", "disabled", "off"})
    # Characters allowed in each segment of a GUID-shaped object id.
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")

    def __init__(self, settings: Settings):
        """Store *settings*; the Graph client is created lazily on first use."""
        self.settings = settings
        self._graph_client: GraphClient | None = None

    def status(self) -> dict[str, Any]:
        """Return a readiness snapshot derived from the current settings."""
        return {
            "ready": self.settings.graph_ready,
            "validationErrors": list(self.settings.validation_errors),
            "warnings": list(self.settings.warnings),
            "graphFlavor": "Microsoft Graph Global",
        }

    def _ensure_client(self) -> GraphClient:
        """Return the cached Graph client, building it on the first call.

        Raises:
            ServiceConfigurationError: if ``settings.graph_ready`` is false.
        """
        if not self.settings.graph_ready:
            # Keep the individual validation messages apart; joining with ""
            # ran them together into one unreadable string.
            joined = "; ".join(self.settings.validation_errors)
            raise ServiceConfigurationError(f"Graph 配置不完整: {joined}")
        if self._graph_client is None:
            token_manager = TokenManager(
                client_id=self.settings.client_id,
                client_secret=self.settings.client_secret,
                token_endpoint=self.settings.token_endpoint,
                scope=self.settings.scope,
            )
            self._graph_client = GraphClient(token_manager, self.settings.graph_base_url)
        return self._graph_client

    def list_licenses(self) -> list[dict[str, Any]]:
        """Return subscribed SKUs with seat counts, sorted by ``skuPartNumber``.

        ``availableUnits`` is ``prepaidUnits.enabled - consumedUnits``,
        clamped at 0.

        Raises:
            ServiceConfigurationError: if the Graph settings are incomplete.
            ServiceOperationError: if the Graph call fails.
        """
        client = self._ensure_client()
        try:
            skus = client.list_subscribed_skus()
        except GraphAPIError as exc:
            raise self._translate_graph_error(exc, "读取许可证列表失败") from exc
        items = []
        for sku in skus:
            # ``or {}`` also covers an explicit null ``prepaidUnits``, which
            # ``.get(..., {})`` alone would pass through as None.
            prepaid = sku.get("prepaidUnits") or {}
            total = int(prepaid.get("enabled", 0) or 0)
            consumed = int(sku.get("consumedUnits", 0) or 0)
            items.append(
                {
                    "skuId": sku.get("skuId"),
                    "skuPartNumber": sku.get("skuPartNumber"),
                    "availableUnits": max(total - consumed, 0),
                    "totalUnits": total,
                    "consumedUnits": consumed,
                }
            )
        return sorted(items, key=lambda item: item["skuPartNumber"] or "")

    def list_users(self, search: str = "", page: int = 1, page_size: int | None = None) -> dict[str, Any]:
        """Return one page of users, optionally filtered by *search*.

        *page* is 1-based and clamped to at least 1. *page_size* falls back to
        ``settings.default_page_size`` and is clamped to
        ``[1, settings.max_page_size]``. ``total`` and ``summary`` describe the
        whole filtered set, not only the returned page.
        """
        requested_page_size = page_size or self.settings.default_page_size
        requested_page_size = min(max(requested_page_size, 1), self.settings.max_page_size)
        page = max(page, 1)
        users, total_before_search = self._list_filtered_users(search)
        start = (page - 1) * requested_page_size
        end = start + requested_page_size
        return {
            "items": users[start:end],
            "page": page,
            "pageSize": requested_page_size,
            "total": len(users),
            "totalBeforeSearch": total_before_search,
            "summary": {
                "active": sum(1 for user in users if user["accountEnabled"]),
                "disabled": sum(1 for user in users if not user["accountEnabled"]),
            },
        }

    def list_user_identifiers(self, search: str = "") -> dict[str, Any]:
        """Return the non-empty UPNs of all users matching *search*."""
        users, _ = self._list_filtered_users(search)
        identifiers = [user["userPrincipalName"] for user in users if user.get("userPrincipalName")]
        return {
            "identifiers": identifiers,
            "total": len(identifiers),
        }

    def get_user(self, identifier: str) -> dict[str, Any]:
        """Fetch one user and annotate its licenses with SKU part numbers.

        *identifier* is passed through ``_normalize_identifier``. The result is
        the ``_serialize_user`` shape plus ``licenses``: a list of
        ``{"skuId", "skuPartNumber"}`` entries. ``skuPartNumber`` falls back to
        the raw skuId when the SKU is not among the subscribed SKUs.

        Raises:
            ValueError: if *identifier* is blank.
            ServiceOperationError: if a Graph call fails.
        """
        client = self._ensure_client()
        identifier = self._normalize_identifier(identifier)
        sku_lookup = self._get_sku_lookup()
        try:
            user = client.get_user(identifier, USER_SELECT_FIELDS)
        except GraphAPIError as exc:
            raise self._translate_graph_error(exc, f"读取用户 {identifier} 失败") from exc
        serialized = self._serialize_user(user, sku_lookup=sku_lookup)
        serialized["licenses"] = [
            {
                "skuId": sku_id,
                "skuPartNumber": sku_lookup.get(sku_id, sku_id),
            }
            for sku_id in serialized["assignedLicenses"]
        ]
        return serialized

    def create_user(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create a user from *payload* and optionally assign a license.

        The identifier is resolved through ``IDENTIFIER_ALIASES`` and
        normalized to a UPN. Defaults are as follows:
        - the password falls back to ``settings.default_password``;
        - ``displayName`` and ``mailNickname`` fall back to the UPN's local part;
        - ``usageLocation`` falls back to ``settings.default_usage_location``;
        - the license SKU falls back to ``settings.default_license_sku``.

        Returns:
            ``{"user", "temporaryPassword", "licenseAssigned", "licenseResult"}``.

        Raises:
            ValueError: if no identifier is present.
            ServiceOperationError: on Graph failures or license problems.
        """
        client = self._ensure_client()
        identifier = self._resolve_identifier(payload, required=True)
        upn = self._normalize_identifier(identifier)
        username = upn.split("@", 1)[0]
        password = self._string_value(payload, ["password"]) or self.settings.default_password
        force_change_password = self._bool_value(
            payload,
            ["forceChangePasswordNextSignIn", "force_change_password"],
            self.settings.force_change_password,
        )
        account_enabled = self._bool_value(payload, ["accountEnabled", "enabled"], True)
        create_payload = {
            "accountEnabled": account_enabled,
            "displayName": self._string_value(payload, OPTIONAL_FIELD_ALIASES["displayName"]) or username,
            "mailNickname": self._string_value(payload, OPTIONAL_FIELD_ALIASES["mailNickname"]) or username,
            "userPrincipalName": upn,
            "passwordProfile": {
                "password": password,
                "forceChangePasswordNextSignIn": force_change_password,
            },
        }
        for graph_field, aliases in OPTIONAL_FIELD_ALIASES.items():
            # These three were already filled in above.
            if graph_field in {"displayName", "mailNickname", "userPrincipalName"}:
                continue
            value = self._string_value(payload, aliases)
            if value:
                create_payload[graph_field] = value
        if "usageLocation" not in create_payload:
            create_payload["usageLocation"] = self.settings.default_usage_location
        try:
            user = client.create_user(create_payload)
        except GraphAPIError as exc:
            raise self._translate_graph_error(exc, f"创建用户 {upn} 失败") from exc
        license_result = None
        sku_part_number = (
            self._string_value(payload, ["skuPartNumber", "sku", "license"])
            or self.settings.default_license_sku
        )
        if sku_part_number:
            # NOTE(review): a license failure here raises after the account
            # already exists, so the caller never sees the temporary password.
            # Consider reporting the failure in the result instead.
            license_result = self._assign_license(user["id"], sku_part_number)
        return {
            "user": self.get_user(user["id"]),
            "temporaryPassword": password,
            "licenseAssigned": bool(license_result),
            "licenseResult": license_result,
        }

    def update_user(
        self,
        identifier: str,
        payload: dict[str, Any],
        *,
        blank_strategy: str = "clear",
    ) -> dict[str, Any]:
        """Patch an existing user and optionally assign a license.

        Args:
            identifier: UPN, bare username (default domain appended) or
                object id of the user to update.
            payload: fields to change. Keys are matched through
                ``OPTIONAL_FIELD_ALIASES`` and also through
                ``accountEnabled``/``enabled``, ``password``,
                ``forceChangePasswordNextSignIn`` and
                ``skuPartNumber``/``sku``/``license``.
            blank_strategy: ``"clear"`` sends ``null`` for blank values of
                ``NULLABLE_FIELDS``. Any other value leaves them untouched
                (``batch_update`` uses ``"ignore"``).

        Returns:
            ``{"user", "licenseAssigned", "licenseResult"}``. ``user`` is
            re-read after the update, under the new UPN if it was renamed.

        Raises:
            ValueError: blank identifier or unrecognized enabled flag.
            ServiceOperationError: on Graph failures or license problems.
        """
        client = self._ensure_client()
        identifier = self._normalize_identifier(identifier)
        patch_payload: dict[str, Any] = {}
        for graph_field, aliases in OPTIONAL_FIELD_ALIASES.items():
            value = self._raw_value(payload, aliases)
            if value is None:
                continue
            # Non-string values (e.g. numeric cells) are stringified, matching
            # how create_user reads the same fields.
            value = value.strip() if isinstance(value, str) else str(value).strip()
            if value == "":
                if graph_field in NULLABLE_FIELDS and blank_strategy == "clear":
                    patch_payload[graph_field] = None
                continue
            if graph_field == "userPrincipalName":
                # Batch rows reuse the UPN column as the identifier, possibly
                # without a domain. Normalize it like the identifier and skip
                # no-op renames, instead of sending a bare username to Graph.
                value = self._normalize_identifier(value)
                if value.lower() == identifier.lower():
                    continue
            patch_payload[graph_field] = value
        enabled_raw = self._raw_value(payload, ["accountEnabled", "enabled"])
        # A blank cell means "leave unchanged". It previously fell back to
        # True and silently re-enabled disabled accounts.
        if enabled_raw is not None and not (isinstance(enabled_raw, str) and not enabled_raw.strip()):
            enabled = self._parse_bool(enabled_raw)
            if enabled is None:
                raise ValueError(f"无法识别的账号启用状态: {enabled_raw}")
            patch_payload["accountEnabled"] = enabled
        password = self._string_value(payload, ["password"])
        if password:
            patch_payload["passwordProfile"] = {
                "password": password,
                "forceChangePasswordNextSignIn": self._bool_value(
                    payload,
                    ["forceChangePasswordNextSignIn", "force_change_password"],
                    self.settings.force_change_password,
                ),
            }
        if patch_payload:
            try:
                client.update_user(identifier, patch_payload)
            except GraphAPIError as exc:
                raise self._translate_graph_error(exc, f"更新用户 {identifier} 失败") from exc
        # After a rename the old UPN may no longer resolve, so every follow-up
        # lookup (license assignment included) uses the new one.
        updated_identifier = patch_payload.get("userPrincipalName", identifier)
        license_result = None
        sku_part_number = self._string_value(payload, ["skuPartNumber", "sku", "license"])
        if sku_part_number:
            target_user = self.get_user(updated_identifier)
            license_result = self._assign_license(target_user["id"], sku_part_number)
        return {
            "user": self.get_user(updated_identifier),
            "licenseAssigned": bool(license_result),
            "licenseResult": license_result,
        }

    def delete_user(self, identifier: str) -> dict[str, Any]:
        """Delete a user and return the record as it was just before deletion.

        Raises:
            ValueError: if *identifier* is blank.
            ServiceOperationError: if reading or deleting the user fails.
        """
        client = self._ensure_client()
        identifier = self._normalize_identifier(identifier)
        # get_user translates its own Graph errors, so only the delete call
        # needs the try block.
        user = self.get_user(identifier)
        try:
            client.delete_user(identifier)
        except GraphAPIError as exc:
            raise self._translate_graph_error(exc, f"删除用户 {identifier} 失败") from exc
        return {"user": user}

    def reset_password(self, identifier: str, payload: dict[str, Any] | None = None) -> dict[str, Any]:
        """Set a new password (from *payload* or the configured default).

        Returns:
            ``{"user", "temporaryPassword"}``.
        """
        client = self._ensure_client()
        identifier = self._normalize_identifier(identifier)
        payload = payload or {}
        password = self._string_value(payload, ["password"]) or self.settings.default_password
        force_change_password = self._bool_value(
            payload,
            ["forceChangePasswordNextSignIn", "force_change_password"],
            self.settings.force_change_password,
        )
        reset_payload = {
            "passwordProfile": {
                "password": password,
                "forceChangePasswordNextSignIn": force_change_password,
            }
        }
        try:
            client.update_user(identifier, reset_payload)
        except GraphAPIError as exc:
            raise self._translate_graph_error(exc, f"重置用户 {identifier} 密码失败") from exc
        return {
            "user": self.get_user(identifier),
            "temporaryPassword": password,
        }

    def batch_create(self, rows: list[dict[str, Any]], progress_callback=None) -> dict[str, Any]:
        """Run ``create_user`` for each row; see ``_run_batch`` for the result shape."""
        return self._run_batch(
            operation="create",
            items=rows,
            callback=lambda row: self.create_user(row),
            identifier_getter=lambda row: self._resolve_identifier(row, required=False),
            progress_callback=progress_callback,
        )

    def batch_update(self, rows: list[dict[str, Any]], progress_callback=None) -> dict[str, Any]:
        """Run ``update_user`` for each row, ignoring blank cells."""
        return self._run_batch(
            operation="update",
            items=rows,
            callback=lambda row: self.update_user(
                self._resolve_identifier(row, required=True),
                row,
                blank_strategy="ignore",
            ),
            identifier_getter=lambda row: self._resolve_identifier(row, required=False),
            progress_callback=progress_callback,
        )

    def batch_delete(self, identifiers: list[str], progress_callback=None) -> dict[str, Any]:
        """Run ``delete_user`` for each identifier."""
        return self._run_batch(
            operation="delete",
            items=identifiers,
            callback=lambda identifier: self.delete_user(identifier),
            identifier_getter=lambda identifier: identifier,
            progress_callback=progress_callback,
        )

    def batch_reset_password(self, rows: list[dict[str, Any]] | list[str], progress_callback=None) -> dict[str, Any]:
        """Run ``reset_password`` for each row (a dict) or plain identifier."""
        return self._run_batch(
            operation="reset-password",
            items=rows,
            callback=self._batch_reset_callback,
            identifier_getter=self._batch_reset_identifier,
            progress_callback=progress_callback,
        )

    def _batch_reset_callback(self, item: dict[str, Any] | str) -> dict[str, Any]:
        """Reset one item: a bare identifier string or a row dict."""
        if isinstance(item, str):
            return self.reset_password(item)
        identifier = self._resolve_identifier(item, required=True)
        return self.reset_password(identifier, item)

    def _batch_reset_identifier(self, item: dict[str, Any] | str) -> str:
        """Return the display identifier of a reset-password batch item."""
        if isinstance(item, str):
            return item
        return self._resolve_identifier(item, required=False)

    def _run_batch(self, operation: str, items: list[Any], callback, identifier_getter, progress_callback=None) -> dict[str, Any]:
        """Apply *callback* to every item, collecting per-item results.

        A failing item never aborts the batch; its error message is recorded
        instead. After each item, *progress_callback* (if given) receives a
        dict with running counts.

        Returns:
            ``{"operation", "total", "successCount", "failureCount", "results"}``.
        """
        total = len(items)
        results = []
        success_count = 0
        logger.info("Batch %s started: total=%s", operation, total)
        for index, item in enumerate(items, start=1):
            identifier = self._batch_item_label(identifier_getter, item, index)
            try:
                result = callback(item)
            except (ServiceConfigurationError, ServiceOperationError, ValueError) as exc:
                record = {
                    "index": index,
                    "identifier": identifier,
                    "success": False,
                    "message": str(exc),
                }
                logger.warning("Batch %s item failed: %s - %s", operation, identifier, exc)
            except Exception as exc:
                # Unexpected errors are still recorded per item but logged with
                # a traceback.
                record = {
                    "index": index,
                    "identifier": identifier,
                    "success": False,
                    "message": str(exc),
                }
                logger.exception("Batch %s item crashed: %s", operation, identifier)
            else:
                success_count += 1
                record = {
                    "index": index,
                    "identifier": identifier,
                    "success": True,
                    "message": "执行成功",
                    "data": result,
                }
                logger.info("Batch %s item success: %s", operation, identifier)
            results.append(record)
            if progress_callback:
                progress_callback(
                    {
                        "completed": index,
                        "total": total,
                        "successCount": success_count,
                        "failureCount": index - success_count,
                        "identifier": identifier,
                        "success": record["success"],
                        "message": record["message"],
                    }
                )
        summary = {
            "operation": operation,
            "total": total,
            "successCount": success_count,
            "failureCount": total - success_count,
            "results": results,
        }
        logger.info(
            "Batch %s finished: total=%s success=%s failure=%s",
            operation,
            summary["total"],
            summary["successCount"],
            summary["failureCount"],
        )
        return summary

    @staticmethod
    def _batch_item_label(identifier_getter, item: Any, index: int) -> Any:
        """Best-effort display label for a batch item; never raises.

        A malformed row (e.g. not a dict) used to crash the whole batch here,
        outside the per-item try block. The callback now reports the real error.
        """
        try:
            identifier = identifier_getter(item)
        except Exception:  # deliberate: a label must not abort the batch
            identifier = None
        return identifier or f"item-{index}"

    def _assign_license(self, user_id: str, sku_part_number: str) -> dict[str, Any]:
        """Assign a subscribed SKU to *user_id* and return the client's response.

        *sku_part_number* is matched case-insensitively against each SKU's
        ``skuPartNumber`` and, as a convenience, its ``skuId``.

        Raises:
            ServiceOperationError: 404 if no SKU matches, 409 if the SKU has no
                available units, or a translated Graph error.
        """
        client = self._ensure_client()
        wanted = sku_part_number.upper()
        matched = next(
            (
                sku
                for sku in self.list_licenses()
                if wanted in ((sku.get("skuPartNumber") or "").upper(), (sku.get("skuId") or "").upper())
            ),
            None,
        )
        if not matched:
            raise ServiceOperationError(f"未找到许可证 SKU: {sku_part_number}", status_code=404)
        if matched["availableUnits"] <= 0:
            raise ServiceOperationError(f"许可证 {sku_part_number} 已无可用席位。", status_code=409)
        try:
            return client.assign_license(
                user_id,
                add_licenses=[{"skuId": matched["skuId"], "disabledPlans": []}],
            )
        except GraphAPIError as exc:
            raise self._translate_graph_error(exc, f"为用户分配许可证 {sku_part_number} 失败") from exc

    def _get_sku_lookup(self) -> dict[str, str]:
        """Map skuId -> skuPartNumber for every subscribed SKU that has an id."""
        return {
            item["skuId"]: item["skuPartNumber"]
            for item in self.list_licenses()
            if item.get("skuId")
        }

    def _list_filtered_users(self, search: str = "") -> tuple[list[dict[str, Any]], int]:
        """Return serialized users matching *search*, sorted by UPN.

        The second tuple element is the user count before filtering. An empty
        or ``None`` *search* returns all users.
        """
        client = self._ensure_client()
        try:
            raw_users = client.list_users(USER_SELECT_FIELDS)
        except GraphAPIError as exc:
            raise self._translate_graph_error(exc, "读取用户列表失败") from exc
        users = [self._serialize_user(user) for user in raw_users]
        total_before_search = len(users)
        query = (search or "").strip().lower()
        if query:
            users = [
                user
                for user in users
                if any(query in str(user.get(field, "") or "").lower() for field in self._SEARCH_FIELDS)
            ]
        users.sort(key=lambda item: (item["userPrincipalName"] or "").lower())
        return users, total_before_search

    def _serialize_user(self, user: dict[str, Any], sku_lookup: dict[str, str] | None = None) -> dict[str, Any]:
        """Flatten a Graph user dict into the shape returned by this service.

        Missing string fields become ``""``. ``accountEnabled`` defaults to
        True when absent. ``licenseLabels`` is populated only when
        *sku_lookup* is given.
        """
        assigned_license_ids = [
            item.get("skuId")
            for item in (user.get("assignedLicenses") or [])
            if item.get("skuId")
        ]
        license_labels = [sku_lookup.get(item, item) for item in assigned_license_ids] if sku_lookup else []
        return {
            "id": user.get("id"),
            "displayName": user.get("displayName") or "",
            "userPrincipalName": user.get("userPrincipalName") or "",
            "mail": user.get("mail") or "",
            "givenName": user.get("givenName") or "",
            "surname": user.get("surname") or "",
            "department": user.get("department") or "",
            "jobTitle": user.get("jobTitle") or "",
            "officeLocation": user.get("officeLocation") or "",
            "mobilePhone": user.get("mobilePhone") or "",
            "usageLocation": user.get("usageLocation") or "",
            "accountEnabled": bool(user.get("accountEnabled", True)),
            "assignedLicenses": assigned_license_ids,
            "assignedLicensesCount": len(assigned_license_ids),
            "licenseLabels": license_labels,
            "createdDateTime": user.get("createdDateTime") or "",
        }

    def _normalize_identifier(self, identifier: str) -> str:
        """Return *identifier* stripped, completed with the default domain.

        Values containing ``@`` and GUID-shaped object ids are returned as-is.
        Other bare names get ``@settings.default_domain`` appended when one is
        configured. Previously GUIDs were suffixed too, which broke lookups by
        object id such as ``create_user``'s ``get_user(user["id"])``.

        Raises:
            ValueError: if *identifier* is ``None`` or blank.
        """
        normalized = "" if identifier is None else str(identifier).strip()
        if not normalized:
            raise ValueError("账号标识不能为空。")
        if "@" in normalized or self._looks_like_object_id(normalized):
            return normalized
        if self.settings.default_domain:
            return f"{normalized}@{self.settings.default_domain}"
        return normalized

    @classmethod
    def _looks_like_object_id(cls, value: str) -> bool:
        """Return True for the canonical 8-4-4-4-12 hex GUID form (Graph object ids)."""
        parts = value.split("-")
        return [len(part) for part in parts] == [8, 4, 4, 4, 12] and all(
            char in cls._HEX_DIGITS for part in parts for char in part
        )

    def _resolve_identifier(self, payload: dict[str, Any], required: bool = False) -> str:
        """Return the first non-blank identifier found via ``IDENTIFIER_ALIASES``.

        Raises:
            ValueError: if none is found and *required* is true.
        """
        value = self._string_value(payload, IDENTIFIER_ALIASES)
        if value:
            return value
        if required:
            raise ValueError("缺少账号标识字段,至少需要 userPrincipalName / user_id / username / email 之一。")
        return ""

    def _normalize_payload(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Re-key *payload* with ``_normalize_key``; later duplicate keys win."""
        return {self._normalize_key(key): value for key, value in payload.items()}

    def _raw_value(self, payload: dict[str, Any], aliases: list[str]) -> Any:
        """Return the value of the first alias present in *payload*, else None.

        Presence counts even if the value is blank, which lets ``update_user``
        tell "clear this field" apart from "field not supplied".
        """
        normalized_payload = self._normalize_payload(payload)
        for alias in aliases:
            normalized_alias = self._normalize_key(alias)
            if normalized_alias in normalized_payload:
                return normalized_payload[normalized_alias]
        return None

    def _string_value(self, payload: dict[str, Any], aliases: list[str]) -> str:
        """Return the first non-blank alias value as a stripped string, else "".

        Unlike ``_raw_value``, a present-but-blank alias does not hide later
        aliases. For example, an empty ``userPrincipalName`` column no longer
        masks a filled ``email`` column.
        """
        normalized_payload = self._normalize_payload(payload)
        for alias in aliases:
            value = normalized_payload.get(self._normalize_key(alias))
            if value is None:
                continue
            text = value.strip() if isinstance(value, str) else str(value).strip()
            if text:
                return text
        return ""

    @classmethod
    def _parse_bool(cls, value: Any) -> bool | None:
        """Interpret *value* as a boolean; return None when it is not recognized.

        Booleans pass through and numbers use truthiness. Strings are matched
        against ``_TRUE_STRINGS`` / ``_FALSE_STRINGS`` after strip/lower.
        """
        if value is None:
            return None
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return bool(value)
        normalized = str(value).strip().lower()
        if normalized in cls._TRUE_STRINGS:
            return True
        if normalized in cls._FALSE_STRINGS:
            return False
        return None

    def _bool_value(self, payload: dict[str, Any], aliases: list[str], default: bool) -> bool:
        """Return the parsed boolean for *aliases*, or *default* if absent/unrecognized."""
        parsed = self._parse_bool(self._raw_value(payload, aliases))
        return default if parsed is None else parsed

    @staticmethod
    def _normalize_key(key: str) -> str:
        """Lower-case *key* and drop every non-alphanumeric character."""
        return "".join(ch for ch in str(key).lower() if ch.isalnum())

    @staticmethod
    def _translate_graph_error(exc: GraphAPIError, fallback_message: str) -> ServiceOperationError:
        """Convert a ``GraphAPIError`` into a ``ServiceOperationError``.

        The Graph message is appended to *fallback_message*. The status code
        defaults to 502 when Graph supplied none. It becomes 409 when the
        message reports a duplicate ("already exists" / "another object with
        the same value").
        """
        message = fallback_message
        if exc.message:
            message = f"{fallback_message}: {exc.message}"
        status_code = exc.status_code or 502
        lowered = message.lower()
        if "already exists" in lowered or "another object with the same value" in lowered:
            status_code = 409
        return ServiceOperationError(message=message, status_code=status_code, details=exc.response)