151 lines
5.4 KiB
Python
151 lines
5.4 KiB
Python
from __future__ import annotations

import logging
import time
from typing import Any
from urllib.parse import quote

import requests

# Module logger. The dotted name makes it a child of the
# "office365_self_service" logger, so handlers/levels configured there apply.
logger = logging.getLogger(name="office365_self_service.graph")

class TokenManager:
    """Acquire and cache an OAuth2 client-credentials access token.

    A token is requested from ``token_endpoint`` on demand and reused until it
    comes within ``refresh_margin`` seconds of its reported expiry.
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_endpoint: str,
        scope: str,
        *,
        refresh_margin: float = 60,
    ):
        """Store client credentials; no network request is made here.

        Args:
            client_id: Sent as the ``client_id`` form field.
            client_secret: Sent as the ``client_secret`` form field.
            token_endpoint: URL the token request is POSTed to.
            scope: Sent as the ``scope`` form field.
            refresh_margin: Seconds before the reported expiry at which a
                cached token is considered stale and is refreshed.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_endpoint = token_endpoint
        self.scope = scope
        self.refresh_margin = refresh_margin
        self._token: str | None = None
        # Deadline on the time.monotonic() clock, so wall-clock adjustments
        # cannot make a stale token look valid (or a valid one look stale).
        self._token_expires_at: float = 0

    def get_token(self) -> str:
        """Return a usable access token, requesting a new one when needed.

        Returns:
            The cached token while it is outside the refresh margin,
            otherwise a freshly requested token.

        Raises:
            GraphAPIError: If the token request fails, returns an error
                status, or its body is not a usable token payload.
        """
        if self._token and time.monotonic() < self._token_expires_at - self.refresh_margin:
            return self._token

        self.clear_token()
        response = self._request_token()
        try:
            token_data = response.json()
        except ValueError as exc:
            raise GraphAPIError("解析访问令牌响应失败", response.status_code) from exc

        token, expires_in = self._parse_token_payload(token_data, response.status_code)
        self._token = token
        self._token_expires_at = time.monotonic() + expires_in
        logger.info("Token refreshed, expires in %s seconds", expires_in)
        return token

    def _request_token(self) -> requests.Response:
        """POST the client-credentials grant and return the successful response.

        Raises:
            GraphAPIError: On transport failure or an HTTP error status; the
                message includes up to 200 characters of the response body.
        """
        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": self.scope,
        }
        try:
            response = requests.post(self.token_endpoint, data=data, timeout=30)
            response.raise_for_status()
        except requests.RequestException as exc:
            error_response = getattr(exc, "response", None)
            status_code = getattr(error_response, "status_code", 0) or 0
            response_payload = None
            response_text = ""
            if error_response is not None:
                response_text = error_response.text[:200]
                try:
                    response_payload = error_response.json()
                except ValueError:
                    response_payload = None
            message = "获取访问令牌失败"
            if response_text:
                message = f"{message}: {response_text}"
            raise GraphAPIError(message, status_code=status_code, response=response_payload) from exc
        return response

    @staticmethod
    def _parse_token_payload(token_data: Any, status_code: int = 0) -> tuple[str, int]:
        """Extract ``(access_token, expires_in)`` from a decoded token response.

        ``expires_in`` defaults to 3600 when absent and is coerced with
        ``int()`` because some OAuth2 endpoints (e.g. the Azure AD v1 token
        endpoint) return it as a numeric string such as ``"3599"``.

        Raises:
            GraphAPIError: If the payload is not a dict, lacks a non-empty
                ``access_token``, or has a non-numeric ``expires_in``.
        """
        if not isinstance(token_data, dict) or not token_data.get("access_token"):
            raise GraphAPIError(
                "解析访问令牌响应失败",
                status_code,
                token_data if isinstance(token_data, dict) else None,
            )
        raw_expires_in = token_data.get("expires_in")
        try:
            expires_in = 3600 if raw_expires_in is None else int(raw_expires_in)
        except (TypeError, ValueError) as exc:
            # Payload is deliberately not attached: it contains the token.
            raise GraphAPIError("解析访问令牌响应失败", status_code) from exc
        return token_data["access_token"], expires_in

    def clear_token(self) -> None:
        """Discard the cached token so the next get_token() requests a new one."""
        self._token = None
        self._token_expires_at = 0

class GraphAPIError(Exception):
    """Raised when a Microsoft Graph call or the token request fails.

    Attributes:
        message: Human-readable description; also the exception's ``str()``.
        status_code: HTTP status code, or 0 when none is known.
        response: Decoded JSON body of the failed response, when available.
    """

    def __init__(self, message: str, status_code: int = 0, response: dict | None = None):
        self.message = message
        self.status_code = status_code
        self.response = response
        super().__init__(message)

class GraphClient:
    """Thin JSON client for Microsoft Graph REST endpoints.

    Each request carries a bearer token from ``token_manager`` and is sent to
    ``base_url`` joined with the endpoint path.
    """

    def __init__(self, token_manager: TokenManager, base_url: str):
        """Store the token source and the base URL endpoint paths are joined to."""
        self.token_manager = token_manager
        self.base_url = base_url

    def _headers(self) -> dict[str, str]:
        """Build default request headers carrying a current bearer token."""
        return {
            "Authorization": f"Bearer {self.token_manager.get_token()}",
            "Content-Type": "application/json",
        }

    def _url(self, path: str) -> str:
        """Join ``base_url`` and ``path`` with exactly one slash between them."""
        return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"

    @staticmethod
    def _user_path(user_id: str) -> str:
        """Return ``/users/{user_id}`` with the id percent-encoded.

        Guest UPNs contain ``#`` (``...#EXT#@tenant``), which an unencoded URL
        treats as the start of a fragment, silently addressing the wrong user.
        ``/`` is encoded so an id cannot redirect the request to another path.
        ``@`` stays literal so ordinary UPNs and object ids are sent unchanged.
        """
        return f"/users/{quote(str(user_id), safe='@')}"

    @staticmethod
    def _error_message(payload: Any) -> str:
        """Extract a readable message from an error body of any JSON shape."""
        error = payload.get("error") if isinstance(payload, dict) else None
        if isinstance(error, dict) and error.get("message"):
            return str(error["message"])
        if isinstance(error, str) and error:
            return error
        return str(payload)

    def _request(self, method: str, path: str, **kwargs) -> dict[str, Any]:
        """Send a Graph request and return the decoded JSON body.

        Args:
            method: HTTP verb.
            path: Endpoint path joined to ``base_url``.
            **kwargs: Forwarded to ``requests.request``. ``headers`` are merged
                over the defaults; ``timeout`` defaults to 60 seconds.

        Returns:
            The decoded JSON body, or ``{}`` for a 204 response with no body.

        Raises:
            GraphAPIError: On transport failure, an undecodable body, or an
                HTTP status >= 400.
        """
        url = self._url(path)
        headers = self._headers()
        headers.update(kwargs.pop("headers", None) or {})
        # setdefault lets callers override the timeout; passing it alongside a
        # hard-coded timeout= keyword would raise TypeError.
        kwargs.setdefault("timeout", 60)

        try:
            response = requests.request(method, url, headers=headers, **kwargs)
        except requests.RequestException as exc:
            raise GraphAPIError(f"请求失败: {exc}") from exc

        if response.status_code == 401:
            # The cached token was rejected; drop it so the next call fetches a
            # fresh one instead of reusing it until its nominal expiry.
            self.token_manager.clear_token()

        try:
            payload = response.json()
        except ValueError:
            if response.status_code == 204:
                return {}
            raise GraphAPIError(f"解析响应失败: {response.text[:200]}", response.status_code)

        if response.status_code >= 400:
            raise GraphAPIError(self._error_message(payload), response.status_code, payload)
        return payload

    def get(self, path: str, **kwargs) -> dict[str, Any]:
        """Send a GET request to ``path``."""
        return self._request("GET", path, **kwargs)

    def post(self, path: str, **kwargs) -> dict[str, Any]:
        """Send a POST request to ``path``."""
        return self._request("POST", path, **kwargs)

    def patch(self, path: str, **kwargs) -> dict[str, Any]:
        """Send a PATCH request to ``path``."""
        return self._request("PATCH", path, **kwargs)

    def delete(self, path: str, **kwargs) -> dict[str, Any]:
        """Send a DELETE request to ``path``."""
        return self._request("DELETE", path, **kwargs)

    def list_subscribed_skus(self) -> list[dict[str, Any]]:
        """Return the ``value`` list of ``GET /subscribedSkus`` ([] if absent)."""
        return self.get("/subscribedSkus").get("value", [])

    def create_user(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create a user via ``POST /users`` and return the response body."""
        return self.post("/users", json=payload)

    def get_user(self, user_id: str, select: list[str] | None = None) -> dict[str, Any]:
        """Fetch one user; ``select`` limits returned properties via ``$select``."""
        params = {"$select": ",".join(select)} if select else {}
        return self.get(self._user_path(user_id), params=params)

    def update_user(self, user_id: str, payload: dict[str, Any]) -> dict[str, Any]:
        """PATCH the given properties onto a user and return the response body."""
        return self.patch(self._user_path(user_id), json=payload)

    def delete_user(self, user_id: str) -> None:
        """Delete a user."""
        self.delete(self._user_path(user_id))

    def assign_license(
        self,
        user_id: str,
        add_licenses: list[dict] | None = None,
        remove_licenses: list[str] | None = None,
    ) -> dict[str, Any]:
        """Add and/or remove licenses for a user via ``assignLicense``.

        Both ``addLicenses`` and ``removeLicenses`` are always sent; a missing
        or empty argument is sent as an empty list.
        """
        payload: dict[str, list] = {
            "addLicenses": add_licenses or [],
            "removeLicenses": remove_licenses or [],
        }
        return self.post(f"{self._user_path(user_id)}/assignLicense", json=payload)
