190 lines
6.2 KiB
Python
190 lines
6.2 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
from typing import Any
|
|
from urllib.parse import quote
|
|
|
|
import requests
|
|
|
|
|
|
# Timeout (seconds) passed as ``timeout=`` to both the token request and every
# Graph request made through GraphClient.
REQUEST_TIMEOUT: int = 30
|
|
|
|
|
|
class TokenError(RuntimeError):
    """Raised when an access token cannot be obtained from the token endpoint."""
|
|
|
|
|
|
class GraphAPIError(RuntimeError):
|
|
def __init__(self, message: str, status_code: int | None = None, response: Any = None):
|
|
super().__init__(message)
|
|
self.message = message
|
|
self.status_code = status_code
|
|
self.response = response
|
|
|
|
|
|
class TokenManager:
    """Obtain and cache an access token using the OAuth 2.0 client-credentials grant.

    A token is requested from ``token_endpoint`` on first use. It is then
    reused until shortly before the lifetime reported by the endpoint
    (``expires_in``) runs out.
    """

    # Refresh this many seconds before the reported expiry, so a token is not
    # handed out just as it is about to lapse.
    _REFRESH_MARGIN = 300
    # Keep a token at least this long when the margin would leave almost
    # nothing (but never longer than the token's own lifetime).
    _MIN_CACHE_SECONDS = 60

    def __init__(self, client_id: str, client_secret: str, token_endpoint: str, scope: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_endpoint = token_endpoint
        self.scope = scope
        # Cached token; "" means nothing cached yet.
        self._token = ""
        # time.monotonic() deadline after which the cached token is refreshed.
        self._expires_at = 0.0

    @classmethod
    def _cache_lifetime(cls, expires_in: int) -> int:
        """Return how many seconds a token valid for *expires_in* seconds may be reused.

        The refresh margin is subtracted, with a floor of ``_MIN_CACHE_SECONDS``.
        The result is capped at *expires_in* itself, so a short-lived token is
        never cached past its real expiry, and it is never negative.
        """
        lifetime = max(expires_in - cls._REFRESH_MARGIN, cls._MIN_CACHE_SECONDS)
        return max(0, min(lifetime, expires_in))

    def get_access_token(self) -> str:
        """Return a usable access token, fetching a new one if the cache is stale.

        Returns:
            The ``access_token`` string from the token endpoint's response.

        Raises:
            TokenError: if the HTTP request fails or returns an error status,
                or if the response is not a JSON object with a non-empty
                ``access_token`` and an integer-convertible ``expires_in``.
        """
        # Monotonic clock: wall-clock adjustments must not stretch or cut
        # short the cache window.
        now = time.monotonic()
        if self._token and now < self._expires_at:
            return self._token

        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": self.scope,
        }

        try:
            response = requests.post(
                self.token_endpoint,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=REQUEST_TIMEOUT,
            )
            response.raise_for_status()
        except requests.RequestException as exc:
            message = f"获取访问令牌失败: {exc}"
            raise TokenError(message) from exc

        # A malformed body must surface as TokenError, not as a bare
        # ValueError / AttributeError from the parsing below.
        try:
            payload = response.json()
        except ValueError as exc:
            raise TokenError("访问令牌响应不是有效的 JSON。") from exc
        if not isinstance(payload, dict):
            raise TokenError("访问令牌响应格式无效。")

        token = payload.get("access_token")
        if not token:
            raise TokenError("访问令牌响应缺少 access_token。")

        try:
            expires_in = int(payload.get("expires_in", 3600))
        except (TypeError, ValueError) as exc:
            raise TokenError("访问令牌响应中的 expires_in 无效。") from exc

        self._token = token
        # `now` was taken before the request, so the deadline errs on the
        # early (safe) side.
        self._expires_at = now + self._cache_lifetime(expires_in)
        return token
|
|
|
|
|
|
class GraphClient:
    """Minimal Microsoft Graph REST client for user and license operations.

    Every request carries a bearer token from ``token_manager`` and is sent
    through one shared ``requests.Session``. Call :meth:`close`, or use the
    client in a ``with`` statement, to release the session.
    """

    def __init__(self, token_manager: TokenManager, base_url: str):
        self.token_manager = token_manager
        # Trailing slash removed so relative endpoints join with a single "/".
        self.base_url = base_url.rstrip("/")
        self.session = requests.Session()

    def close(self) -> None:
        """Close the underlying HTTP session and its pooled connections."""
        self.session.close()

    def __enter__(self) -> GraphClient:
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.close()

    def _headers(self, extra_headers: dict[str, str] | None = None) -> dict[str, str]:
        """Build request headers.

        Headers are a bearer ``Authorization`` and a JSON ``Content-Type``,
        with *extra_headers* merged on top.
        """
        headers = {
            "Authorization": f"Bearer {self.token_manager.get_access_token()}",
            "Content-Type": "application/json",
        }
        if extra_headers:
            headers.update(extra_headers)
        return headers

    def _request(
        self,
        method: str,
        endpoint: str,
        *,
        params: dict[str, Any] | None = None,
        json_data: dict[str, Any] | None = None,
        extra_headers: dict[str, str] | None = None,
        absolute_url: bool = False,
    ) -> dict[str, Any]:
        """Send one request to Graph and return the decoded JSON body.

        Args:
            method: HTTP verb (case-insensitive).
            endpoint: Path relative to ``base_url``, or a full URL when
                *absolute_url* is true (used for ``@odata.nextLink``).
            params: Query-string parameters.
            json_data: Request body, serialized as JSON.
            extra_headers: Headers merged over the defaults.
            absolute_url: Treat *endpoint* as a complete URL.

        Returns:
            The decoded JSON body, or ``{}`` for a 204 or empty response.

        Raises:
            GraphAPIError: on a transport failure, a non-2xx status, or a
                success response whose body is not valid JSON.
            Whatever ``token_manager.get_access_token()`` raises (TokenError
                for the TokenManager in this module).
        """
        url = endpoint if absolute_url else f"{self.base_url}/{endpoint.lstrip('/')}"
        try:
            response = self.session.request(
                method=method.upper(),
                url=url,
                params=params,
                json=json_data,
                headers=self._headers(extra_headers),
                timeout=REQUEST_TIMEOUT,
            )
        except requests.RequestException as exc:
            raise GraphAPIError(f"调用 Microsoft Graph 失败: {exc}") from exc

        if response.status_code == 204:
            return {}

        if not response.ok:
            details: Any
            try:
                details = response.json()
            except ValueError:
                details = {"error": response.text}
            message = self._extract_error_message(details) or f"Graph API 返回 HTTP {response.status_code}"
            raise GraphAPIError(message, status_code=response.status_code, response=details)

        if not response.content:
            return {}
        # A non-JSON 2xx body (e.g. a proxy HTML page) must surface as
        # GraphAPIError, the type callers of this client already handle.
        try:
            return response.json()
        except ValueError as exc:
            raise GraphAPIError(
                f"Graph API 返回的响应不是有效的 JSON (HTTP {response.status_code})",
                status_code=response.status_code,
                response=response.text,
            ) from exc

    def _get_all_pages(self, endpoint: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
        """GET *endpoint* and return every ``value`` item.

        Follows ``@odata.nextLink`` until no further page is advertised.
        """
        page = self._request("GET", endpoint, params=params)
        items = list(page.get("value", []))
        next_link = page.get("@odata.nextLink")
        while next_link:
            # No params are resent: per Graph paging docs, nextLink already
            # embeds the original query options.
            page = self._request("GET", next_link, absolute_url=True)
            items.extend(page.get("value", []))
            next_link = page.get("@odata.nextLink")
        return items

    @staticmethod
    def _select_params(select_fields: list[str] | None) -> dict[str, Any] | None:
        """Return ``{"$select": "a,b"}`` for *select_fields*, or ``None`` if none were requested."""
        return {"$select": ",".join(select_fields)} if select_fields else None

    @staticmethod
    def _extract_error_message(details: Any) -> str:
        """Pull a human-readable message out of a decoded error body.

        Checks ``error.message``, then ``error.code``, then a plain ``error``
        value, then a top-level ``message``. Returns ``""`` if none is found.
        """
        if isinstance(details, dict):
            graph_error = details.get("error")
            if isinstance(graph_error, dict):
                return str(graph_error.get("message") or graph_error.get("code") or "")
            if graph_error:
                return str(graph_error)
            if "message" in details:
                return str(details["message"])
        return ""

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Percent-encode *identifier* for use as a single URL path segment.

        ``@``, ``.``, ``_``, ``-`` and ``$`` are left as-is; characters such
        as ``#``, ``/`` and spaces are escaped.
        """
        return quote(str(identifier), safe="@._-$")

    def create_user(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST *payload* to ``/users`` and return the decoded response."""
        return self._request("POST", "/users", json_data=payload)

    def get_user(self, identifier: str, select_fields: list[str] | None = None) -> dict[str, Any]:
        """Fetch one user by *identifier*, optionally limited to *select_fields*."""
        return self._request(
            "GET",
            f"/users/{self._quote_identifier(identifier)}",
            params=self._select_params(select_fields),
        )

    def update_user(self, identifier: str, payload: dict[str, Any]) -> dict[str, Any]:
        """PATCH the user *identifier* with *payload*; returns ``{}`` on a 204 reply."""
        return self._request("PATCH", f"/users/{self._quote_identifier(identifier)}", json_data=payload)

    def delete_user(self, identifier: str) -> None:
        """DELETE the user *identifier*."""
        self._request("DELETE", f"/users/{self._quote_identifier(identifier)}")

    def list_users(self, select_fields: list[str] | None = None) -> list[dict[str, Any]]:
        """Return all users from ``/users``, following ``@odata.nextLink`` pages."""
        return self._get_all_pages("/users", params=self._select_params(select_fields))

    def list_subscribed_skus(self) -> list[dict[str, Any]]:
        """Return all entries from ``/subscribedSkus``, following ``@odata.nextLink`` pages if present."""
        return self._get_all_pages("/subscribedSkus")

    def assign_license(
        self,
        user_id: str,
        *,
        add_licenses: list[dict[str, Any]],
        remove_licenses: list[str] | None = None,
    ) -> dict[str, Any]:
        """Call ``assignLicense`` for *user_id*.

        Args:
            user_id: User identifier (URL-quoted before use).
            add_licenses: Items for ``addLicenses``; presumably Graph
                ``assignedLicense`` objects (``skuId`` / ``disabledPlans``).
                Confirm against callers.
            remove_licenses: Items for ``removeLicenses``; defaults to an
                empty list.
        """
        payload = {
            "addLicenses": add_licenses,
            "removeLicenses": remove_licenses or [],
        }
        return self._request(
            "POST",
            f"/users/{self._quote_identifier(user_id)}/assignLicense",
            json_data=payload,
        )
|
|
|