from __future__ import annotations from dataclasses import dataclass import logging from typing import Any from .graph import GraphAPIError, GraphClient, TokenManager from .settings import Settings logger = logging.getLogger("office365_admin.service") USER_SELECT_FIELDS = [ "id", "displayName", "userPrincipalName", "mail", "givenName", "surname", "department", "jobTitle", "officeLocation", "mobilePhone", "usageLocation", "accountEnabled", "assignedLicenses", "createdDateTime", ] IDENTIFIER_ALIASES = [ "userPrincipalName", "user_principal_name", "user_id", "userId", "username", "email", "mail", "upn", "id", ] OPTIONAL_FIELD_ALIASES = { "displayName": ["displayName", "display_name"], "mailNickname": ["mailNickname", "mail_nickname", "nickname"], "givenName": ["givenName", "given_name", "firstName", "firstname"], "surname": ["surname", "lastName", "lastname", "last_name"], "department": ["department"], "jobTitle": ["jobTitle", "job_title"], "officeLocation": ["officeLocation", "office_location"], "mobilePhone": ["mobilePhone", "mobile", "phone"], "usageLocation": ["usageLocation", "usage_location"], "userPrincipalName": ["userPrincipalName", "user_principal_name", "upn"], } NULLABLE_FIELDS = { "displayName", "givenName", "surname", "department", "jobTitle", "officeLocation", "mobilePhone", } @dataclass class ServiceOperationError(RuntimeError): message: str status_code: int = 400 details: Any = None def __str__(self) -> str: return self.message class ServiceConfigurationError(RuntimeError): pass class Office365Service: def __init__(self, settings: Settings): self.settings = settings self._graph_client: GraphClient | None = None def status(self) -> dict[str, Any]: return { "ready": self.settings.graph_ready, "validationErrors": list(self.settings.validation_errors), "warnings": list(self.settings.warnings), "graphFlavor": "Microsoft Graph Global", } def _ensure_client(self) -> GraphClient: if not self.settings.graph_ready: joined = ";".join(self.settings.validation_errors) raise ServiceConfigurationError(f"Graph 配置不完整: {joined}") if self._graph_client is None: token_manager = TokenManager( client_id=self.settings.client_id, client_secret=self.settings.client_secret, token_endpoint=self.settings.token_endpoint, scope=self.settings.scope, ) self._graph_client = GraphClient(token_manager, self.settings.graph_base_url) return self._graph_client def list_licenses(self) -> list[dict[str, Any]]: client = self._ensure_client() try: skus = client.list_subscribed_skus() except GraphAPIError as exc: raise self._translate_graph_error(exc, "读取许可证列表失败") items = [] for sku in skus: total = int(sku.get("prepaidUnits", {}).get("enabled", 0) or 0) consumed = int(sku.get("consumedUnits", 0) or 0) items.append( { "skuId": sku.get("skuId"), "skuPartNumber": sku.get("skuPartNumber"), "availableUnits": max(total - consumed, 0), "totalUnits": total, "consumedUnits": consumed, } ) return sorted(items, key=lambda item: item["skuPartNumber"] or "") def list_users(self, search: str = "", page: int = 1, page_size: int | None = None) -> dict[str, Any]: requested_page_size = page_size or self.settings.default_page_size requested_page_size = min(max(requested_page_size, 1), self.settings.max_page_size) page = max(page, 1) users, total_before_search = self._list_filtered_users(search) total = len(users) start = (page - 1) * requested_page_size end = start + requested_page_size paged_users = users[start:end] return { "items": paged_users, "page": page, "pageSize": requested_page_size, "total": total, "totalBeforeSearch": total_before_search, "summary": { "active": sum(1 for user in users if user["accountEnabled"]), "disabled": sum(1 for user in users if not user["accountEnabled"]), }, } def list_user_identifiers(self, search: str = "") -> dict[str, Any]: users, _ = self._list_filtered_users(search) identifiers = [ user["userPrincipalName"] for user in users if user.get("userPrincipalName") ] return { "identifiers": identifiers, "total": len(identifiers), } def get_user(self, identifier: str) -> dict[str, Any]: client = self._ensure_client() identifier = self._normalize_identifier(identifier) sku_lookup = self._get_sku_lookup() try: user = client.get_user(identifier, USER_SELECT_FIELDS) except GraphAPIError as exc: raise self._translate_graph_error(exc, f"读取用户 {identifier} 失败") serialized = self._serialize_user(user, sku_lookup=sku_lookup) serialized["licenses"] = [ { "skuId": sku_id, "skuPartNumber": sku_lookup.get(sku_id, sku_id), } for sku_id in serialized["assignedLicenses"] ] return serialized def create_user(self, payload: dict[str, Any]) -> dict[str, Any]: client = self._ensure_client() identifier = self._resolve_identifier(payload, required=True) upn = self._normalize_identifier(identifier) username = upn.split("@", 1)[0] password = self._string_value(payload, ["password"]) or self.settings.default_password force_change_password = self._bool_value( payload, ["forceChangePasswordNextSignIn", "force_change_password"], self.settings.force_change_password, ) account_enabled = self._bool_value(payload, ["accountEnabled", "enabled"], True) create_payload = { "accountEnabled": account_enabled, "displayName": self._string_value(payload, OPTIONAL_FIELD_ALIASES["displayName"]) or username, "mailNickname": self._string_value(payload, OPTIONAL_FIELD_ALIASES["mailNickname"]) or username, "userPrincipalName": upn, "passwordProfile": { "password": password, "forceChangePasswordNextSignIn": force_change_password, }, } for graph_field, aliases in OPTIONAL_FIELD_ALIASES.items(): if graph_field in {"displayName", "mailNickname", "userPrincipalName"}: continue value = self._string_value(payload, aliases) if value: create_payload[graph_field] = value if "usageLocation" not in create_payload: create_payload["usageLocation"] = self.settings.default_usage_location try: user = client.create_user(create_payload) except GraphAPIError as exc: raise self._translate_graph_error(exc, f"创建用户 {upn} 失败") license_result = None sku_part_number = self._string_value(payload, ["skuPartNumber", "sku", "license"]) or self.settings.default_license_sku if sku_part_number: license_result = self._assign_license(user["id"], sku_part_number) return { "user": self.get_user(user["id"]), "temporaryPassword": password, "licenseAssigned": bool(license_result), "licenseResult": license_result, } def update_user( self, identifier: str, payload: dict[str, Any], *, blank_strategy: str = "clear", ) -> dict[str, Any]: client = self._ensure_client() identifier = self._normalize_identifier(identifier) patch_payload: dict[str, Any] = {} for graph_field, aliases in OPTIONAL_FIELD_ALIASES.items(): value = self._raw_value(payload, aliases) if value is None: continue if isinstance(value, str): value = value.strip() if value == "" and graph_field in NULLABLE_FIELDS: if blank_strategy == "clear": patch_payload[graph_field] = None continue if value != "": patch_payload[graph_field] = value if self._raw_value(payload, ["accountEnabled", "enabled"]) is not None: patch_payload["accountEnabled"] = self._bool_value(payload, ["accountEnabled", "enabled"], True) password = self._string_value(payload, ["password"]) if password: patch_payload["passwordProfile"] = { "password": password, "forceChangePasswordNextSignIn": self._bool_value( payload, ["forceChangePasswordNextSignIn", "force_change_password"], self.settings.force_change_password, ), } if patch_payload: try: client.update_user(identifier, patch_payload) except GraphAPIError as exc: raise self._translate_graph_error(exc, f"更新用户 {identifier} 失败") license_result = None sku_part_number = self._string_value(payload, ["skuPartNumber", "sku", "license"]) if sku_part_number: existing_user = self.get_user(identifier) license_result = self._assign_license(existing_user["id"], sku_part_number) updated_identifier = patch_payload.get("userPrincipalName", identifier) return { "user": self.get_user(updated_identifier), "licenseAssigned": bool(license_result), "licenseResult": license_result, } def delete_user(self, identifier: str) -> dict[str, Any]: client = self._ensure_client() identifier = self._normalize_identifier(identifier) try: user = self.get_user(identifier) client.delete_user(identifier) except GraphAPIError as exc: raise self._translate_graph_error(exc, f"删除用户 {identifier} 失败") return {"user": user} def reset_password(self, identifier: str, payload: dict[str, Any] | None = None) -> dict[str, Any]: client = self._ensure_client() identifier = self._normalize_identifier(identifier) payload = payload or {} password = self._string_value(payload, ["password"]) or self.settings.default_password force_change_password = self._bool_value( payload, ["forceChangePasswordNextSignIn", "force_change_password"], self.settings.force_change_password, ) reset_payload = { "passwordProfile": { "password": password, "forceChangePasswordNextSignIn": force_change_password, } } try: client.update_user(identifier, reset_payload) except GraphAPIError as exc: raise self._translate_graph_error(exc, f"重置用户 {identifier} 密码失败") return { "user": self.get_user(identifier), "temporaryPassword": password, } def batch_create(self, rows: list[dict[str, Any]], progress_callback=None) -> dict[str, Any]: return self._run_batch( operation="create", items=rows, callback=lambda row: self.create_user(row), identifier_getter=lambda row: self._resolve_identifier(row, required=False), progress_callback=progress_callback, ) def batch_update(self, rows: list[dict[str, Any]], progress_callback=None) -> dict[str, Any]: return self._run_batch( operation="update", items=rows, callback=lambda row: self.update_user( self._resolve_identifier(row, required=True), row, blank_strategy="ignore", ), identifier_getter=lambda row: self._resolve_identifier(row, required=False), progress_callback=progress_callback, ) def batch_delete(self, identifiers: list[str], progress_callback=None) -> dict[str, Any]: return self._run_batch( operation="delete", items=identifiers, callback=lambda identifier: self.delete_user(identifier), identifier_getter=lambda identifier: identifier, progress_callback=progress_callback, ) def batch_reset_password(self, rows: list[dict[str, Any]] | list[str], progress_callback=None) -> dict[str, Any]: return self._run_batch( operation="reset-password", items=rows, callback=self._batch_reset_callback, identifier_getter=self._batch_reset_identifier, progress_callback=progress_callback, ) def _batch_reset_callback(self, item: dict[str, Any] | str) -> dict[str, Any]: if isinstance(item, str): return self.reset_password(item) identifier = self._resolve_identifier(item, required=True) return self.reset_password(identifier, item) def _batch_reset_identifier(self, item: dict[str, Any] | str) -> str: if isinstance(item, str): return item return self._resolve_identifier(item, required=False) def _run_batch(self, operation: str, items: list[Any], callback, identifier_getter, progress_callback=None) -> dict[str, Any]: results = [] success_count = 0 logger.info("Batch %s started: total=%s", operation, len(items)) for index, item in enumerate(items, start=1): identifier = identifier_getter(item) or f"item-{index}" try: result = callback(item) success_count += 1 record = { "index": index, "identifier": identifier, "success": True, "message": "执行成功", "data": result, } logger.info("Batch %s item success: %s", operation, identifier) except (ServiceConfigurationError, ServiceOperationError, ValueError) as exc: record = { "index": index, "identifier": identifier, "success": False, "message": str(exc), } logger.warning("Batch %s item failed: %s - %s", operation, identifier, exc) except Exception as exc: record = { "index": index, "identifier": identifier, "success": False, "message": str(exc), } logger.exception("Batch %s item crashed: %s", operation, identifier) results.append(record) if progress_callback: progress_callback( { "completed": index, "total": len(items), "successCount": success_count, "failureCount": index - success_count, "identifier": identifier, "success": record["success"], "message": record["message"], } ) summary = { "operation": operation, "total": len(items), "successCount": success_count, "failureCount": len(items) - success_count, "results": results, } logger.info( "Batch %s finished: total=%s success=%s failure=%s", operation, summary["total"], summary["successCount"], summary["failureCount"], ) return summary def _assign_license(self, user_id: str, sku_part_number: str) -> dict[str, Any]: client = self._ensure_client() skus = self.list_licenses() matched = next( (sku for sku in skus if (sku["skuPartNumber"] or "").upper() == sku_part_number.upper()), None, ) if not matched: raise ServiceOperationError(f"未找到许可证 SKU: {sku_part_number}", status_code=404) if matched["availableUnits"] <= 0: raise ServiceOperationError(f"许可证 {sku_part_number} 已无可用席位。", status_code=409) try: return client.assign_license( user_id, add_licenses=[{"skuId": matched["skuId"], "disabledPlans": []}], ) except GraphAPIError as exc: raise self._translate_graph_error(exc, f"为用户分配许可证 {sku_part_number} 失败") def _get_sku_lookup(self) -> dict[str, str]: return { item["skuId"]: item["skuPartNumber"] for item in self.list_licenses() if item.get("skuId") } def _list_filtered_users(self, search: str = "") -> tuple[list[dict[str, Any]], int]: client = self._ensure_client() try: raw_users = client.list_users(USER_SELECT_FIELDS) except GraphAPIError as exc: raise self._translate_graph_error(exc, "读取用户列表失败") users = [self._serialize_user(user) for user in raw_users] total_before_search = len(users) if search.strip(): query = search.strip().lower() users = [ user for user in users if any( query in str(user.get(field, "") or "").lower() for field in ( "displayName", "userPrincipalName", "mail", "department", "jobTitle", "givenName", "surname", ) ) ] users.sort(key=lambda item: (item["userPrincipalName"] or "").lower()) return users, total_before_search def _serialize_user(self, user: dict[str, Any], sku_lookup: dict[str, str] | None = None) -> dict[str, Any]: assigned_license_ids = [ item.get("skuId") for item in (user.get("assignedLicenses") or []) if item.get("skuId") ] license_labels = [sku_lookup.get(item, item) for item in assigned_license_ids] if sku_lookup else [] return { "id": user.get("id"), "displayName": user.get("displayName") or "", "userPrincipalName": user.get("userPrincipalName") or "", "mail": user.get("mail") or "", "givenName": user.get("givenName") or "", "surname": user.get("surname") or "", "department": user.get("department") or "", "jobTitle": user.get("jobTitle") or "", "officeLocation": user.get("officeLocation") or "", "mobilePhone": user.get("mobilePhone") or "", "usageLocation": user.get("usageLocation") or "", "accountEnabled": bool(user.get("accountEnabled", True)), "assignedLicenses": assigned_license_ids, "assignedLicensesCount": len(assigned_license_ids), "licenseLabels": license_labels, "createdDateTime": user.get("createdDateTime") or "", } def _normalize_identifier(self, identifier: str) -> str: normalized = str(identifier).strip() if not normalized: raise ValueError("账号标识不能为空。") if "@" in normalized: return normalized if self.settings.default_domain: return f"{normalized}@{self.settings.default_domain}" return normalized def _resolve_identifier(self, payload: dict[str, Any], required: bool = False) -> str: value = self._string_value(payload, IDENTIFIER_ALIASES) if value: return value if required: raise ValueError("缺少账号标识字段,至少需要 userPrincipalName / user_id / username / email 之一。") return "" def _raw_value(self, payload: dict[str, Any], aliases: list[str]) -> Any: normalized_payload = {self._normalize_key(key): value for key, value in payload.items()} for alias in aliases: normalized_alias = self._normalize_key(alias) if normalized_alias in normalized_payload: return normalized_payload[normalized_alias] return None def _string_value(self, payload: dict[str, Any], aliases: list[str]) -> str: value = self._raw_value(payload, aliases) if value is None: return "" if isinstance(value, str): return value.strip() return str(value).strip() def _bool_value(self, payload: dict[str, Any], aliases: list[str], default: bool) -> bool: value = self._raw_value(payload, aliases) if value is None: return default if isinstance(value, bool): return value if isinstance(value, (int, float)): return bool(value) normalized = str(value).strip().lower() if normalized in {"1", "true", "yes", "y", "enabled", "on"}: return True if normalized in {"0", "false", "no", "n", "disabled", "off"}: return False return default @staticmethod def _normalize_key(key: str) -> str: return "".join(ch for ch in str(key).lower() if ch.isalnum()) @staticmethod def _translate_graph_error(exc: GraphAPIError, fallback_message: str) -> ServiceOperationError: message = fallback_message if exc.message: message = f"{fallback_message}: {exc.message}" status_code = exc.status_code or 502 lowered = message.lower() if "already exists" in lowered or "another object with the same value" in lowered: status_code = 409 return ServiceOperationError(message=message, status_code=status_code, details=exc.response)