Files

190 lines
6.2 KiB
Python

from __future__ import annotations
import time
from typing import Any
from urllib.parse import quote
import requests
# Timeout in seconds passed as `timeout=` to every HTTP call made in this module
# (token requests and Microsoft Graph requests alike).
REQUEST_TIMEOUT = 30
class TokenError(RuntimeError):
    """Raised by TokenManager when an access token cannot be obtained or the token response is unusable."""
class GraphAPIError(RuntimeError):
    """Raised when a Microsoft Graph call fails.

    Attributes:
        message: Human-readable description; also used as the exception text.
        status_code: HTTP status of the failed response, or None when the
            failure happened before any response was received.
        response: Decoded error body returned by the service, if any.
    """

    def __init__(self, message: str, status_code: int | None = None, response: Any = None):
        super().__init__(message)
        # Keep the pieces separately so callers can branch on status/body.
        self.message, self.status_code, self.response = message, status_code, response
class TokenManager:
    """Fetch and cache access tokens via the OAuth 2.0 client-credentials grant.

    The first call to :meth:`get_access_token` POSTs the client credentials to
    ``token_endpoint``. The returned token is reused until shortly before its
    reported expiry, so callers may ask for a token before every request.
    """

    # Refresh this many seconds before the token's reported expiry.
    _REFRESH_MARGIN = 300
    # Cache lifetime used for tokens too short-lived to honour the full margin.
    # It is further capped at half the token lifetime (see _cache_seconds).
    _SHORT_TOKEN_CACHE_SECONDS = 60
    # Lifetime assumed when the token response omits ``expires_in``.
    _DEFAULT_EXPIRES_IN = 3600

    def __init__(self, client_id: str, client_secret: str, token_endpoint: str, scope: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_endpoint = token_endpoint
        self.scope = scope
        self._token = ""
        # Unix time (whole seconds) from which the cached token is treated as stale.
        self._expires_at = 0

    def get_access_token(self) -> str:
        """Return a usable access token, requesting a new one when the cache is stale.

        Returns:
            The bearer token string.

        Raises:
            TokenError: If the token request fails, the response is not a JSON
                object, or it lacks a usable ``access_token`` / ``expires_in``.
        """
        now = int(time.time())
        if self._token and now < self._expires_at:
            return self._token

        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": self.scope,
        }
        try:
            response = requests.post(
                self.token_endpoint,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=REQUEST_TIMEOUT,
            )
            response.raise_for_status()
        except requests.RequestException as exc:
            message = f"获取访问令牌失败: {exc}"
            raise TokenError(message) from exc

        token, expires_in = self._parse_token_response(response)
        self._token = token
        # `now` was sampled before the request, so the expiry errs on the early side.
        self._expires_at = now + self._cache_seconds(expires_in)
        return token

    @classmethod
    def _parse_token_response(cls, response: Any) -> tuple[Any, int]:
        """Extract ``(access_token, expires_in)`` from a token endpoint response.

        Every malformed-response case is reported as TokenError, so callers
        only need to handle one exception type.
        """
        try:
            payload = response.json()
        except ValueError as exc:  # requests' JSON decode errors subclass ValueError
            raise TokenError("访问令牌响应不是有效的 JSON。") from exc
        token = payload.get("access_token") if isinstance(payload, dict) else None
        if not token:
            raise TokenError("访问令牌响应缺少 access_token。")
        try:
            expires_in = int(payload.get("expires_in", cls._DEFAULT_EXPIRES_IN))
        except (TypeError, ValueError) as exc:
            raise TokenError("访问令牌响应中的 expires_in 无效。") from exc
        return token, expires_in

    @classmethod
    def _cache_seconds(cls, expires_in: int) -> int:
        """Return how many seconds a token with lifetime *expires_in* may be cached.

        Tokens are normally refreshed ``_REFRESH_MARGIN`` seconds early. For
        short-lived tokens, the 60-second fallback is kept but capped at half
        the real lifetime, so a token is never served after it has expired.
        """
        floor = min(cls._SHORT_TOKEN_CACHE_SECONDS, max(expires_in, 0) // 2)
        return max(expires_in - cls._REFRESH_MARGIN, floor)
class GraphClient:
    """Minimal Microsoft Graph REST client built on one shared ``requests.Session``.

    Each request is authorised with a bearer token obtained from
    ``token_manager``. Call :meth:`close`, or use the client as a context
    manager, to release the session's pooled connections when done.
    """

    def __init__(self, token_manager: TokenManager, base_url: str):
        """Create a client.

        Args:
            token_manager: Source of bearer tokens; its ``get_access_token()``
                is called once per request.
            base_url: Root URL that endpoint paths are appended to
                (e.g. ``https://graph.microsoft.com/v1.0``). A trailing slash
                is ignored.
        """
        self.token_manager = token_manager
        self.base_url = base_url.rstrip("/")
        self.session = requests.Session()

    def close(self) -> None:
        """Close the underlying HTTP session, releasing its pooled connections."""
        self.session.close()

    def __enter__(self) -> GraphClient:
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        self.close()

    def _headers(self, extra_headers: dict[str, str] | None = None) -> dict[str, str]:
        """Return per-request headers: bearer authorisation plus a JSON content type.

        Entries in ``extra_headers`` are applied last and override the defaults.
        """
        headers = {
            "Authorization": f"Bearer {self.token_manager.get_access_token()}",
            "Content-Type": "application/json",
        }
        if extra_headers:
            headers.update(extra_headers)
        return headers

    def _request(
        self,
        method: str,
        endpoint: str,
        *,
        params: dict[str, Any] | None = None,
        json_data: dict[str, Any] | None = None,
        extra_headers: dict[str, str] | None = None,
        absolute_url: bool = False,
    ) -> dict[str, Any]:
        """Send one Graph request and return the decoded JSON body.

        Args:
            method: HTTP verb (upper-cased before sending).
            endpoint: Path relative to ``base_url``, or a full URL when
                ``absolute_url`` is True.
            params: Query-string parameters.
            json_data: Request body, serialised as JSON.
            extra_headers: Headers merged over the defaults from ``_headers``.
            absolute_url: Use ``endpoint`` verbatim (e.g. an ``@odata.nextLink``).

        Returns:
            The decoded JSON body, or ``{}`` for a 204 or an empty body.

        Raises:
            GraphAPIError: On transport failure, on a non-2xx status (with
                ``status_code`` and the decoded error body attached), or when a
                successful response body is not valid JSON.
        """
        url = endpoint if absolute_url else f"{self.base_url}/{endpoint.lstrip('/')}"
        try:
            response = self.session.request(
                method=method.upper(),
                url=url,
                params=params,
                json=json_data,
                headers=self._headers(extra_headers),
                timeout=REQUEST_TIMEOUT,
            )
        except requests.RequestException as exc:
            raise GraphAPIError(f"调用 Microsoft Graph 失败: {exc}") from exc

        if response.status_code == 204:
            return {}

        if not response.ok:
            details: Any
            try:
                details = response.json()
            except ValueError:
                details = {"error": response.text}
            message = self._extract_error_message(details) or f"Graph API 返回 HTTP {response.status_code}"
            raise GraphAPIError(message, status_code=response.status_code, response=details)

        if not response.content:
            return {}
        try:
            return response.json()
        except ValueError as exc:
            # A 2xx with a non-JSON body must still surface as GraphAPIError,
            # not as a raw decode error.
            raise GraphAPIError(
                f"Graph API 响应不是有效的 JSON (HTTP {response.status_code})",
                status_code=response.status_code,
                response=response.text,
            ) from exc

    @staticmethod
    def _extract_error_message(details: Any) -> str:
        """Pull a human-readable message out of a decoded error body.

        The order of preference is ``error.message``, then ``error.code``, then
        a non-dict ``error`` value, then a top-level ``message``. Returns ``""``
        when nothing usable is found.
        """
        if isinstance(details, dict):
            graph_error = details.get("error")
            if isinstance(graph_error, dict):
                return str(graph_error.get("message") or graph_error.get("code") or "")
            if graph_error:
                return str(graph_error)
            if "message" in details:
                return str(details["message"])
        return ""

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Percent-encode *identifier* for use as a single URL path segment.

        ``@ . _ - $`` stay literal, so UPNs remain readable. Characters such
        as ``/``, ``#`` and spaces are escaped.
        """
        return quote(str(identifier), safe="@._-$")

    @classmethod
    def _user_path(cls, identifier: str) -> str:
        """Return the ``/users/...`` path that addresses a single user.

        Graph rejects ``/users/$name@domain`` because OData reserves a leading
        ``$`` for system query options. Such principal names are addressed as
        ``/users('$name@domain')`` instead, with embedded single quotes doubled
        per OData key-literal rules.
        """
        text = str(identifier)
        if text.startswith("$"):
            escaped = text.replace("'", "''")
            return f"/users('{cls._quote_identifier(escaped)}')"
        return f"/users/{cls._quote_identifier(text)}"

    @staticmethod
    def _select_params(select_fields: list[str] | None) -> dict[str, str] | None:
        """Build the ``$select`` query parameter, or None when no fields are requested."""
        if not select_fields:
            return None
        return {"$select": ",".join(select_fields)}

    def create_user(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create a user (POST /users) and return the decoded response body."""
        return self._request("POST", "/users", json_data=payload)

    def get_user(self, identifier: str, select_fields: list[str] | None = None) -> dict[str, Any]:
        """Fetch one user by id or UPN, optionally restricted to *select_fields*."""
        return self._request("GET", self._user_path(identifier), params=self._select_params(select_fields))

    def update_user(self, identifier: str, payload: dict[str, Any]) -> dict[str, Any]:
        """PATCH a user; returns the decoded body (``{}`` for 204 No Content)."""
        return self._request("PATCH", self._user_path(identifier), json_data=payload)

    def delete_user(self, identifier: str) -> None:
        """Delete a user by id or UPN."""
        self._request("DELETE", self._user_path(identifier))

    def list_users(self, select_fields: list[str] | None = None) -> list[dict[str, Any]]:
        """Return every user, following ``@odata.nextLink`` paging to the end."""
        response = self._request("GET", "/users", params=self._select_params(select_fields))
        users = list(response.get("value", []))
        next_link = response.get("@odata.nextLink")
        while next_link:
            # nextLink is an absolute URL that (per Graph paging docs) already
            # carries the original query options, so it is requested verbatim.
            response = self._request("GET", next_link, absolute_url=True)
            users.extend(response.get("value", []))
            next_link = response.get("@odata.nextLink")
        return users

    def list_subscribed_skus(self) -> list[dict[str, Any]]:
        """Return the ``value`` list from GET /subscribedSkus (``[]`` if absent)."""
        return self._request("GET", "/subscribedSkus").get("value", [])

    def assign_license(
        self,
        user_id: str,
        *,
        add_licenses: list[dict[str, Any]],
        remove_licenses: list[str] | None = None,
    ) -> dict[str, Any]:
        """Add and/or remove licenses for a user via POST /users/{id}/assignLicense.

        Args:
            user_id: User id or UPN.
            add_licenses: License objects sent as ``addLicenses`` (presumably
                each carries a ``skuId``; verify against callers).
            remove_licenses: Values sent as ``removeLicenses`` (presumably SKU
                ids). None sends an empty list.
        """
        payload = {
            "addLicenses": add_licenses,
            "removeLicenses": remove_licenses or [],
        }
        return self._request(
            "POST",
            f"{self._user_path(user_id)}/assignLicense",
            json_data=payload,
        )