"""Minimal Microsoft Graph client using the OAuth2 client-credentials flow."""

from __future__ import annotations

import time
from typing import Any
from urllib.parse import quote

import requests

# Timeout in seconds applied to every outgoing HTTP request.
REQUEST_TIMEOUT = 30


class TokenError(RuntimeError):
    """Raised when an access token cannot be obtained from the token endpoint."""


class GraphAPIError(RuntimeError):
    """Raised when a Microsoft Graph request fails.

    Attributes:
        message: Human-readable error message.
        status_code: HTTP status of the failed response, or None when no
            response was received (e.g. a connection error).
        response: Error details decoded from the response body, when available.
    """

    def __init__(self, message: str, status_code: int | None = None, response: Any = None):
        super().__init__(message)
        self.message = message
        self.status_code = status_code
        self.response = response


class TokenManager:
    """Fetch and cache an access token using the client-credentials grant.

    The token is kept in memory and reused until shortly before the expiry
    reported by the token endpoint.
    """

    # Seconds before the reported expiry at which a cached token is refreshed.
    REFRESH_MARGIN = 300
    # Lower bound for how long a token is cached (still capped by its real lifetime).
    MIN_CACHE_SECONDS = 60

    def __init__(self, client_id: str, client_secret: str, token_endpoint: str, scope: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_endpoint = token_endpoint
        self.scope = scope
        self._token = ""
        # Deadline on the time.monotonic() clock after which _token must be refreshed.
        self._expires_at = 0.0

    @classmethod
    def _cache_lifetime(cls, expires_in: int) -> int:
        """Return how many seconds a token valid for *expires_in* seconds may be cached.

        Normally ``expires_in - REFRESH_MARGIN`` with a floor of
        ``MIN_CACHE_SECONDS``, but never longer than the token's actual
        lifetime (and never negative), so a short-lived token is not served
        after it has expired.
        """
        lifetime = max(expires_in - cls.REFRESH_MARGIN, cls.MIN_CACHE_SECONDS)
        return max(0, min(lifetime, expires_in))

    def get_access_token(self) -> str:
        """Return a valid access token, requesting a new one if the cache is stale.

        Raises:
            TokenError: If the token request fails, or the response is not a
                JSON object with a non-empty ``access_token`` and a numeric
                ``expires_in``.
        """
        # Monotonic clock: the cache deadline is unaffected by wall-clock adjustments.
        now = time.monotonic()
        if self._token and now < self._expires_at:
            return self._token

        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": self.scope,
        }
        try:
            response = requests.post(
                self.token_endpoint,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=REQUEST_TIMEOUT,
            )
            response.raise_for_status()
        except requests.RequestException as exc:
            message = f"获取访问令牌失败: {exc}"
            raise TokenError(message) from exc

        # requests raises a ValueError subclass when the body is not valid JSON.
        try:
            payload = response.json()
        except ValueError as exc:
            raise TokenError("访问令牌响应不是有效的 JSON。") from exc
        if not isinstance(payload, dict):
            raise TokenError("访问令牌响应格式无效。")

        token = payload.get("access_token")
        if not token:
            raise TokenError("访问令牌响应缺少 access_token。")

        raw_expires_in = payload.get("expires_in", 3600)
        try:
            # int() also accepts numeric strings such as "3599".
            expires_in = int(raw_expires_in)
        except (TypeError, ValueError) as exc:
            raise TokenError(f"访问令牌响应中的 expires_in 无效: {raw_expires_in!r}") from exc

        self._token = token
        # Measured from before the request was sent, so the deadline errs early.
        self._expires_at = now + self._cache_lifetime(expires_in)
        return token


class GraphClient:
    """Thin wrapper around the Microsoft Graph REST API.

    Requests share one ``requests.Session``. The client can be used as a
    context manager; the session is closed on exit.
    """

    def __init__(self, token_manager: TokenManager, base_url: str):
        self.token_manager = token_manager
        self.base_url = base_url.rstrip("/")
        self.session = requests.Session()

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self.session.close()

    def __enter__(self) -> GraphClient:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def _headers(self, extra_headers: dict[str, str] | None = None) -> dict[str, str]:
        """Build request headers: bearer token and JSON content type, overlaid by *extra_headers*.

        Raises:
            TokenError: If an access token cannot be obtained.
        """
        headers = {
            "Authorization": f"Bearer {self.token_manager.get_access_token()}",
            "Content-Type": "application/json",
        }
        if extra_headers:
            headers.update(extra_headers)
        return headers

    def _request(
        self,
        method: str,
        endpoint: str,
        *,
        params: dict[str, Any] | None = None,
        json_data: dict[str, Any] | None = None,
        extra_headers: dict[str, str] | None = None,
        absolute_url: bool = False,
    ) -> dict[str, Any]:
        """Send a request to Graph and return the decoded JSON body.

        Args:
            method: HTTP method (case-insensitive).
            endpoint: Path relative to ``base_url``, or a full URL when
                *absolute_url* is true (used to follow ``@odata.nextLink``).
            params: Query-string parameters.
            json_data: Request body, serialized as JSON.
            extra_headers: Headers merged over the defaults.
            absolute_url: Use *endpoint* verbatim as the request URL.

        Returns:
            The decoded JSON body, or ``{}`` for 204 or empty responses.

        Raises:
            TokenError: If an access token cannot be obtained.
            GraphAPIError: On transport failure, a non-2xx status, or a
                success response whose body is not valid JSON.
        """
        url = endpoint if absolute_url else f"{self.base_url}/{endpoint.lstrip('/')}"
        headers = self._headers(extra_headers)
        try:
            response = self.session.request(
                method=method.upper(),
                url=url,
                params=params,
                json=json_data,
                headers=headers,
                timeout=REQUEST_TIMEOUT,
            )
        except requests.RequestException as exc:
            raise GraphAPIError(f"调用 Microsoft Graph 失败: {exc}") from exc

        if response.status_code == 204:
            return {}

        if not response.ok:
            details: Any
            try:
                details = response.json()
            except ValueError:
                details = {"error": response.text}
            message = self._extract_error_message(details) or f"Graph API 返回 HTTP {response.status_code}"
            raise GraphAPIError(message, status_code=response.status_code, response=details)

        if not response.content:
            return {}
        try:
            return response.json()
        except ValueError as exc:
            # A 2xx response with a malformed body should still surface as GraphAPIError.
            raise GraphAPIError(
                f"Graph API 返回了无法解析的响应 (HTTP {response.status_code})",
                status_code=response.status_code,
                response={"error": response.text},
            ) from exc

    @staticmethod
    def _extract_error_message(details: Any) -> str:
        """Extract a human-readable message from a Graph error body.

        Prefers ``error.message``, then ``error.code``, then a non-dict
        ``error`` value, then a top-level ``message``. Returns "" if none apply.
        """
        if isinstance(details, dict):
            graph_error = details.get("error")
            if isinstance(graph_error, dict):
                return str(graph_error.get("message") or graph_error.get("code") or "")
            if graph_error:
                return str(graph_error)
            if "message" in details:
                return str(details["message"])
        return ""

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Percent-encode *identifier* for use as a single URL path segment.

        ``/`` is not in the safe set, so it is encoded and cannot split the path.
        """
        return quote(str(identifier), safe="@._-$")

    def create_user(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create a user from *payload* via ``POST /users`` and return the response body."""
        return self._request("POST", "/users", json_data=payload)

    def get_user(self, identifier: str, select_fields: list[str] | None = None) -> dict[str, Any]:
        """Fetch one user by *identifier*, optionally limited to *select_fields* via ``$select``."""
        params: dict[str, Any] = {}
        if select_fields:
            params["$select"] = ",".join(select_fields)
        return self._request("GET", f"/users/{self._quote_identifier(identifier)}", params=params or None)

    def update_user(self, identifier: str, payload: dict[str, Any]) -> dict[str, Any]:
        """Apply *payload* to a user via ``PATCH``; returns ``{}`` on a 204 response."""
        return self._request("PATCH", f"/users/{self._quote_identifier(identifier)}", json_data=payload)

    def delete_user(self, identifier: str) -> None:
        """Delete the user identified by *identifier*."""
        self._request("DELETE", f"/users/{self._quote_identifier(identifier)}")

    def list_users(self, select_fields: list[str] | None = None) -> list[dict[str, Any]]:
        """Return all users, following ``@odata.nextLink`` until no pages remain."""
        params: dict[str, Any] = {}
        if select_fields:
            params["$select"] = ",".join(select_fields)
        response = self._request("GET", "/users", params=params or None)
        users = list(response.get("value", []))
        next_link = response.get("@odata.nextLink")
        while next_link:
            # nextLink is a complete URL that already carries the original query options.
            response = self._request("GET", next_link, absolute_url=True)
            users.extend(response.get("value", []))
            next_link = response.get("@odata.nextLink")
        return users

    def list_subscribed_skus(self) -> list[dict[str, Any]]:
        """Return the ``value`` list from ``GET /subscribedSkus``."""
        return self._request("GET", "/subscribedSkus").get("value", [])

    def assign_license(
        self,
        user_id: str,
        *,
        add_licenses: list[dict[str, Any]],
        remove_licenses: list[str] | None = None,
    ) -> dict[str, Any]:
        """Add and/or remove licenses for a user via ``assignLicense``.

        Args:
            user_id: User identifier used in the request path.
            add_licenses: Objects sent as ``addLicenses``.
            remove_licenses: Values sent as ``removeLicenses``; defaults to an empty list.
        """
        payload = {
            "addLicenses": add_licenses,
            "removeLicenses": remove_licenses or [],
        }
        return self._request(
            "POST",
            f"/users/{self._quote_identifier(user_id)}/assignLicense",
            json_data=payload,
        )