# File viewer header: 194 lines, 7.4 KiB, Python
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
from .graph import GraphAPIError, GraphClient, TokenManager
|
||
from .settings import Settings
|
||
|
||
|
||
# Module-level logger with a fixed, explicit name (not derived from __name__).
logger = logging.getLogger("office365_self_service.service")
|
||
|
||
|
||
class ServiceOperationError(RuntimeError):
    """Error raised when a service operation fails.

    Attributes:
        message: Human-readable description; also used as the exception text.
        status_code: Numeric status associated with the failure (defaults to 400).
        details: Optional extra payload describing the failure.
    """

    def __init__(self, message: str, status_code: int = 400, details=None):
        # Store the structured fields first, then let RuntimeError record the text.
        self.message, self.status_code, self.details = message, status_code, details
        super().__init__(message)
|
||
|
||
|
||
class ServiceConfigurationError(RuntimeError):
    """Error raised when the service configuration is incomplete or unusable."""
|
||
|
||
|
||
class Office365Service:
    """Create Office 365 users through a Graph client, optionally assigning a license.

    All tunables (credentials, default domain, default password, license SKU,
    rollback policy) are read from the ``settings`` object passed at
    construction time.
    """

    def __init__(self, settings: Settings):
        """Store *settings*; the Graph client is created on first use."""
        self.settings = settings
        # Populated by _ensure_client() the first time a client is needed.
        self._graph_client: GraphClient | None = None

    def _ensure_client(self) -> GraphClient:
        """Return the cached Graph client, building it on first call.

        Raises:
            ServiceConfigurationError: if ``settings.graph_ready`` is falsy; the
                message lists ``settings.validation_errors``.
        """
        if not self.settings.graph_ready:
            joined = ";".join(self.settings.validation_errors)
            raise ServiceConfigurationError(f"Graph 配置不完整: {joined}")

        if self._graph_client is None:
            token_manager = TokenManager(
                client_id=self.settings.client_id,
                client_secret=self.settings.client_secret,
                token_endpoint=self.settings.token_endpoint,
                scope=self.settings.scope,
            )
            self._graph_client = GraphClient(token_manager, self.settings.graph_base_url)
        return self._graph_client

    def create_user(
        self,
        username: str,
        password: str | None = None,
        display_name: str | None = None,
        retry: bool = True,
    ) -> dict[str, Any]:
        """Create a user and, when a default SKU is configured, assign a license.

        Args:
            username: Bare local part or a full address; see
                ``_build_user_identifiers``.
            password: Initial password; falls back to ``settings.default_password``.
            display_name: Display name; falls back to the mail nickname.
            retry: When true, a "token is expired" failure clears the cached
                token and retries the whole call exactly once.

        Returns:
            A dict with the created user, its UPN, the password used and the
            outcome of the license assignment.

        Raises:
            ValueError: if *username* is empty or not a usable address.
            ServiceConfigurationError: if the Graph settings or the default
                domain are missing.
            ServiceOperationError: if user creation fails, or if a required
                license could not be assigned (the account is rolled back).
        """
        upn, mail_nickname = self._build_user_identifiers(username)
        client = self._ensure_client()

        password = password or self.settings.default_password
        display_name = display_name or mail_nickname

        create_payload = {
            "accountEnabled": True,
            "displayName": display_name,
            "mailNickname": mail_nickname,
            "userPrincipalName": upn,
            "passwordProfile": {
                "password": password,
                "forceChangePasswordNextSignIn": self.settings.force_change_password,
            },
            "usageLocation": self.settings.default_usage_location,
        }

        try:
            user = client.create_user(create_payload)
        except GraphAPIError as exc:
            if retry and "token is expired" in str(exc).lower():
                logger.info("Token expired, refreshing and retrying...")
                # `client` is the same object as self._graph_client here.
                client.token_manager.clear_token()
                # retry=False guarantees at most one retry.
                return self.create_user(username, password, display_name, retry=False)
            # Chain explicitly so the original Graph failure stays attached as the cause.
            raise self._translate_graph_error(exc, f"创建用户 {upn} 失败") from exc

        license_result = None
        license_message = None
        if self.settings.default_license_sku:
            license_result, license_message, license_status = self._assign_license(user["id"])
            if license_message and self.settings.license_assignment_required:
                # Always raises: either "rolled back" or "rollback failed".
                self._rollback_user_for_license_failure(
                    client=client,
                    user_id=user["id"],
                    user_principal_name=upn,
                    license_message=license_message,
                    license_status=license_status,
                )

        return {
            "user": user,
            "userPrincipalName": upn,
            "temporaryPassword": password,
            "licenseAssigned": bool(license_result),
            "licenseResult": license_result,
            "licenseMessage": license_message,
        }

    def _build_user_identifiers(self, username: str) -> tuple[str, str]:
        """Derive ``(user_principal_name, mail_nickname)`` from user input.

        The input is stripped and lower-cased. A value containing "@" is used
        as the full address and its local part becomes the nickname; otherwise
        ``settings.default_domain`` is appended.

        Raises:
            ValueError: if the input is empty, or is an address with an empty
                local part, an empty domain, or more than one "@".
            ServiceConfigurationError: if no "@" was given and
                ``settings.default_domain`` is not set.
        """
        normalized = (username or "").strip().lower()
        if not normalized:
            raise ValueError("请输入用户名。")

        if "@" in normalized:
            local_part, _, domain = normalized.partition("@")
            # partition() splits on the first "@" only, so any further "@"
            # ends up in `domain`; such input is not a well-formed address.
            if not local_part or not domain or "@" in domain:
                raise ValueError("请输入有效的完整邮箱地址。")
            return normalized, local_part

        if not self.settings.default_domain:
            raise ServiceConfigurationError("DEFAULT_DOMAIN 未配置,请输入完整邮箱地址后重试。")

        return f"{normalized}@{self.settings.default_domain}", normalized

    def _assign_license(self, user_id: str) -> tuple[dict[str, Any] | None, str | None, int]:
        """Assign ``settings.default_license_sku`` to *user_id*.

        Never raises for Graph failures; they are reported in the return value.

        Returns:
            ``(result, message, status)``. On success ``message`` is ``None``
            and ``status`` is 200. On failure ``result`` is ``None``,
            ``message`` describes the problem and ``status`` is the Graph
            status code (502 if absent) or 409 for an unknown / exhausted SKU.
        """
        client = self._ensure_client()
        sku_part_number = self.settings.default_license_sku

        try:
            skus = client.list_subscribed_skus()
        except GraphAPIError as exc:
            message = f"获取许可证列表失败: {exc.message or exc}"
            logger.warning(message)
            return None, message, exc.status_code or 502

        # Case-insensitive match on the SKU part number; upper-case the target once.
        wanted = sku_part_number.upper()
        matched = next(
            (sku for sku in skus if (sku.get("skuPartNumber") or "").upper() == wanted),
            None,
        )
        if not matched:
            message = f"未找到许可证 SKU: {sku_part_number}"
            logger.warning(message)
            return None, message, 409

        consumed = int(matched.get("consumedUnits", 0) or 0)
        # `prepaidUnits` may be present but null; treat that as zero enabled
        # seats instead of failing with AttributeError after the user exists.
        enabled = int((matched.get("prepaidUnits") or {}).get("enabled", 0) or 0)
        if consumed >= enabled:
            message = f"许可证 {sku_part_number} 已无可用席位"
            logger.warning(message)
            return None, message, 409

        try:
            assigned = client.assign_license(
                user_id,
                add_licenses=[{"skuId": matched["skuId"], "disabledPlans": []}],
            )
        except GraphAPIError as exc:
            message = f"分配许可证失败: {exc.message or exc}"
            logger.warning(message)
            return None, message, exc.status_code or 502
        return assigned, None, 200

    def _rollback_user_for_license_failure(
        self,
        client: GraphClient,
        user_id: str,
        user_principal_name: str,
        license_message: str,
        license_status: int,
    ) -> None:
        """Delete a just-created user after a failed license assignment.

        This method always raises; it never returns normally.

        Raises:
            ServiceOperationError: with status 502 and ``rolledBack: False`` if
                the delete itself failed, otherwise with *license_status*
                (409 if falsy) and ``rolledBack: True``.
        """
        try:
            client.delete_user(user_id)
        except GraphAPIError as exc:
            delete_message = exc.message or str(exc)
            raise ServiceOperationError(
                message=(
                    f"账号 {user_principal_name} 已创建,但许可证分配失败且回滚删除失败。"
                    f"{license_message};删除失败: {delete_message}"
                ),
                status_code=502,
                details={
                    "userPrincipalName": user_principal_name,
                    "licenseError": license_message,
                    "rollbackDeleteError": delete_message,
                    "rolledBack": False,
                },
            ) from exc

        raise ServiceOperationError(
            message=f"许可证分配失败,已回滚删除账号 {user_principal_name}。{license_message}",
            status_code=license_status or 409,
            details={
                "userPrincipalName": user_principal_name,
                "licenseError": license_message,
                "rolledBack": True,
            },
        )

    def _translate_graph_error(self, exc: GraphAPIError, fallback_message: str) -> ServiceOperationError:
        """Convert a Graph error into a ``ServiceOperationError`` (returned, not raised).

        The Graph message, when present, is appended to *fallback_message*.
        The status code is taken from the Graph error (502 if absent) and is
        forced to 409 when the text indicates a duplicate object.
        """
        message = fallback_message
        if exc.message:
            message = f"{fallback_message}: {exc.message}"
        status_code = exc.status_code or 502
        lowered = message.lower()
        if "already exists" in lowered or "another object with the same value" in lowered:
            status_code = 409
        return ServiceOperationError(message=message, status_code=status_code, details=exc.response)
|