From de130f105282817674d250fad853073354d37a4e Mon Sep 17 00:00:00 2001 From: youbin Date: Tue, 31 Mar 2026 08:13:38 +0800 Subject: [PATCH] Harden redemption flow and improve operational safety --- .env.example | 4 +- README.md | 9 +- docker-compose.yml | 3 +- office365_self_service/__init__.py | 13 +- office365_self_service/graph.py | 28 +- office365_self_service/models.py | 57 ++- office365_self_service/routes.py | 339 +++++++++++++++-- office365_self_service/services.py | 104 +++++- office365_self_service/settings.py | 27 +- .../templates/admin_dashboard.html | 275 ++++++++++++-- .../templates/admin_login.html | 13 +- .../templates/user_redemption.html | 27 +- tests/test_app.py | 345 ++++++++++++++++++ 13 files changed, 1138 insertions(+), 106 deletions(-) create mode 100644 tests/test_app.py diff --git a/.env.example b/.env.example index 01c52f8..b3d7056 100644 --- a/.env.example +++ b/.env.example @@ -9,6 +9,7 @@ PORT=8000 DEBUG=false # 数据库 (SQLite) +# 默认 sqlite:///redemption.db 会落到 Flask 的 instance/redemption.db DATABASE_URL=sqlite:///redemption.db # Flask会话密钥 (建议使用随机长字符串) @@ -30,4 +31,5 @@ DEFAULT_DOMAIN=yourtenant.onmicrosoft.com DEFAULT_PASSWORD=P@ssw0rd123! DEFAULT_USAGE_LOCATION=US DEFAULT_LICENSE_SKU=ENTERPRISEPACK -FORCE_CHANGE_PASSWORD=true \ No newline at end of file +LICENSE_ASSIGNMENT_REQUIRED=false +FORCE_CHANGE_PASSWORD=true diff --git a/README.md b/README.md index 509bf85..7a33ec0 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ - **自助开通**:用户输入兑换码和用户名自助开通 Office 365 账号 - **自动授权**:账号开通时自动分配许可证 - **兑换记录**:后台记录兑换码与已开通账号的对应关系 +- **审计日志**:后台分页查看生成、删除、兑换成功/失败等关键事件 ## 项目结构 @@ -80,6 +81,7 @@ docker compose down | `DEFAULT_PASSWORD` | 是 | 新建账号默认密码 | 自定义高强度密码 | | `DEFAULT_DOMAIN` | 建议 | 默认域名 | 例如 `yourtenant.onmicrosoft.com` | | `DEFAULT_LICENSE_SKU` | 可选 | 默认许可证 SKU | 例如 `ENTERPRISEPACK`、`M365_BUSINESS_PREMIUM` | +| `LICENSE_ASSIGNMENT_REQUIRED` | 可选 | 许可证分配失败时是否回滚删除新账号 | 默认 `false` | | `DEFAULT_USAGE_LOCATION` | 建议 | 默认使用地区 | 国际版常用:`US`、`SG`、`JP` | | `WEB_AUTH_ENABLED` | 可选 | 后台登录保护 | `true` 或 `false` | | `ADMIN_USERNAME` | 建议 | 后台登录用户名 | 自定义 | @@ -89,6 +91,8 @@ docker compose down | `PORT` | 可选 | 服务监听端口 | 默认 `8000` | | `DEBUG` | 可选 | 调试模式 | 默认 `false` | +提示:如果本地误用了容器内的 SQLite 路径(例如 `sqlite:////app/data/redemption.db`),项目现在会自动映射到当前仓库下的对应本地路径。 + ### Entra ID (Azure AD) 应用配置 1. **创建应用注册** @@ -126,7 +130,7 @@ docker compose down 1. 使用设置的 admin 账号登录 2. 点击「生成兑换码」批量生成兑换码 -3. 可以查看所有兑换码及兑换记录 +3. 可以查看兑换码、兑换记录和审计日志 ### 用户自助开通 @@ -148,5 +152,6 @@ docker compose down ## 注意事项 - `DEFAULT_LICENSE_SKU` 必须是租户中实际存在的 SKU 名称 +- 如果希望“建号和授权”保持强一致,可设置 `LICENSE_ASSIGNMENT_REQUIRED=true` - 兑换码使用后立即失效,无法重复使用 -- 生产环境建议使用 `DEBUG=false` 并配置反向代理 \ No newline at end of file +- 生产环境建议使用 `DEBUG=false` 并配置反向代理 diff --git a/docker-compose.yml b/docker-compose.yml index b318e88..0388c48 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,12 +5,13 @@ services: - "8000:8000" volumes: - ./data:/app/data + - ./instance:/app/instance - ./logs:/app/logs env_file: - .env restart: unless-stopped healthcheck: - test: ["CMD", "python", "-c", "from urllib.request import urlopen; urlopen('http://localhost:8000/admin/api/health', timeout=5).read()"] + test: ["CMD", "python", "-c", "from urllib.request import urlopen; urlopen('http://localhost:8000/api/health', timeout=5).read()"] interval: 30s timeout: 10s retries: 3 diff --git a/office365_self_service/__init__.py b/office365_self_service/__init__.py index a3993b6..33af12d 100644 --- a/office365_self_service/__init__.py +++ b/office365_self_service/__init__.py @@ -9,6 +9,7 @@ from flask import Flask from flask_sqlalchemy import SQLAlchemy from sqlalchemy import event from sqlalchemy.engine import Engine +from sqlalchemy.engine.url import make_url from .services import Office365Service from .settings import Settings, load_settings @@ -17,6 +18,13 @@ from .settings import Settings, load_settings db = SQLAlchemy() +def _ensure_sqlite_directory(database_url: str) -> None: + url = make_url(database_url) + if url.drivername != "sqlite" or not url.database or url.database == ":memory:": + return + Path(url.database).parent.mkdir(parents=True, exist_ok=True) + + def _configure_logging(app: Flask) -> None: log_dir = Path(app.root_path).parent / "logs" log_dir.mkdir(parents=True, exist_ok=True) @@ -62,11 +70,12 @@ def create_app( app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False _configure_logging(app) + _ensure_sqlite_directory(settings.database_url) db.init_app(app) with app.app_context(): - from .models import RedemptionCode + from .models import AuditEvent, RedemptionCode db.create_all() if service_factory is None: @@ -83,4 +92,4 @@ def create_app( app.register_blueprint(bp_admin) app.register_blueprint(bp_user) - return app \ No newline at end of file + return app diff --git a/office365_self_service/graph.py b/office365_self_service/graph.py index 5233fb7..2b8a16b 100644 --- a/office365_self_service/graph.py +++ b/office365_self_service/graph.py @@ -30,9 +30,29 @@ class TokenManager: "client_secret": self.client_secret, "scope": self.scope, } - response = requests.post(self.token_endpoint, data=data, timeout=30) - response.raise_for_status() - token_data = response.json() + try: + response = requests.post(self.token_endpoint, data=data, timeout=30) + response.raise_for_status() + except requests.RequestException as exc: + status_code = getattr(getattr(exc, "response", None), "status_code", 0) or 0 + response_payload = None + response_text = "" + if getattr(exc, "response", None) is not None: + response_text = exc.response.text[:200] + try: + response_payload = exc.response.json() + except ValueError: + response_payload = None + message = "获取访问令牌失败" + if response_text: + message = f"{message}: {response_text}" + raise GraphAPIError(message, status_code=status_code, response=response_payload) from exc + + try: + token_data = response.json() + except ValueError as exc: + raise GraphAPIError("解析访问令牌响应失败", response.status_code) from exc + self._token = token_data["access_token"] expires_in = token_data.get("expires_in", 3600) self._token_expires_at = time.time() + expires_in @@ -127,4 +147,4 @@ class GraphClient: else: payload["addLicenses"] = [] payload["removeLicenses"] = remove_licenses if remove_licenses else [] - return self.post(f"/users/{user_id}/assignLicense", json=payload) \ No newline at end of file + return self.post(f"/users/{user_id}/assignLicense", json=payload) diff --git a/office365_self_service/models.py b/office365_self_service/models.py index 8134e9f..1b54163 100644 --- a/office365_self_service/models.py +++ b/office365_self_service/models.py @@ -1,17 +1,30 @@ from __future__ import annotations -from datetime import datetime +import json +from datetime import datetime, timezone from . import db +def utc_now() -> datetime: + return datetime.now(timezone.utc) + + +def serialize_datetime(value: datetime | None) -> str | None: + if value is None: + return None + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z") + + class RedemptionCode(db.Model): __tablename__ = "redemption_codes" id = db.Column(db.Integer, primary_key=True) code = db.Column(db.String(64), unique=True, nullable=False, index=True) status = db.Column(db.String(16), nullable=False, default="available") - created_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now()) + created_at = db.Column(db.DateTime, nullable=False, default=utc_now) used_at = db.Column(db.DateTime, nullable=True) used_by_username = db.Column(db.String(256), nullable=True) used_by_principal_name = db.Column(db.String(256), nullable=True) @@ -21,8 +34,42 @@ class RedemptionCode(db.Model): "id": self.id, "code": self.code, "status": self.status, - "createdAt": self.created_at.isoformat() if self.created_at else None, - "usedAt": self.used_at.isoformat() if self.used_at else None, + "createdAt": serialize_datetime(self.created_at), + "usedAt": serialize_datetime(self.used_at), "usedByUsername": self.used_by_username, "usedByPrincipalName": self.used_by_principal_name, - } \ No newline at end of file + } + + +class AuditEvent(db.Model): + __tablename__ = "audit_events" + + id = db.Column(db.Integer, primary_key=True) + event_type = db.Column(db.String(64), nullable=False, index=True) + status = db.Column(db.String(16), nullable=False, default="success", index=True) + actor = db.Column(db.String(128), nullable=False, default="system") + code = db.Column(db.String(64), nullable=True, index=True) + username = db.Column(db.String(256), nullable=True) + principal_name = db.Column(db.String(256), nullable=True) + details = db.Column(db.Text, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, default=utc_now, index=True) + + def to_dict(self): + parsed_details = None + if self.details: + try: + parsed_details = json.loads(self.details) + except ValueError: + parsed_details = {"raw": self.details} + + return { + "id": self.id, + "eventType": self.event_type, + "status": self.status, + "actor": self.actor, + "code": self.code, + "username": self.username, + "principalName": self.principal_name, + "details": parsed_details, + "createdAt": serialize_datetime(self.created_at), + } diff --git a/office365_self_service/routes.py b/office365_self_service/routes.py index ac85f37..442d6b3 100644 --- a/office365_self_service/routes.py +++ b/office365_self_service/routes.py @@ -1,19 +1,25 @@ from __future__ import annotations +import json +import logging import secrets -from datetime import datetime, timezone, timedelta from functools import wraps from flask import Blueprint, current_app, jsonify, render_template, request, session -from sqlalchemy import func +from sqlalchemy import func, update from . import db -from .models import RedemptionCode +from .models import AuditEvent, RedemptionCode, utc_now from .services import Office365Service, ServiceConfigurationError, ServiceOperationError bp_admin = Blueprint("admin", __name__, url_prefix="/admin") bp_user = Blueprint("user", __name__) +logger = logging.getLogger("office365_self_service.routes") + +STATUS_AVAILABLE = "available" +STATUS_PROCESSING = "processing" +STATUS_USED = "used" def _settings(): @@ -67,6 +73,133 @@ def _json_payload() -> dict: return request.get_json(silent=True) or {} +def _code_match(code: str): + return func.lower(RedemptionCode.code) == code.lower() + + +def _health_payload() -> dict: + settings = _settings() + return { + "platform": settings.to_public_dict(), + "authenticated": _authenticated(), + } + + +def _current_actor(default: str = "system") -> str: + if _authenticated(): + return session.get("admin_username") or _settings().admin_username or default + return default + + +def _build_audit_event( + event_type: str, + *, + status: str = "success", + actor: str | None = None, + code: str | None = None, + username: str | None = None, + principal_name: str | None = None, + details: dict | None = None, +) -> AuditEvent: + return AuditEvent( + event_type=event_type, + status=status, + actor=actor or "system", + code=code, + username=username, + principal_name=principal_name, + details=json.dumps(details, ensure_ascii=False, sort_keys=True) if details is not None else None, + ) + + +def _record_audit_events(*events: AuditEvent) -> None: + pending = [event for event in events if event is not None] + if not pending: + return + + try: + db.session.add_all(pending) + db.session.commit() + except Exception: + db.session.rollback() + logger.exception("写入审计日志失败,共 %s 条事件。", len(pending)) + + +def _pagination_params() -> tuple[int, int, int]: + settings = _settings() + + try: + page = int(request.args.get("page", "1")) + except ValueError: + page = 1 + + try: + page_size = int(request.args.get("pageSize", str(settings.default_page_size))) + except ValueError: + page_size = settings.default_page_size + + page = max(page, 1) + page_size = min(max(page_size, 1), settings.max_page_size) + return page, page_size, (page - 1) * page_size + + +def _pagination_payload(page: int, page_size: int, total: int) -> dict[str, int]: + pages = (total + page_size - 1) // page_size if total else 0 + return { + "page": page, + "pageSize": page_size, + "total": total, + "pages": pages, + } + + +def _normalize_pagination(page: int, page_size: int, total: int) -> tuple[int, int, int]: + pages = (total + page_size - 1) // page_size if total else 0 + if pages and page > pages: + page = pages + return page, (page - 1) * page_size, pages + + +def _reserve_code(code: str) -> bool: + result = db.session.execute( + update(RedemptionCode) + .where(_code_match(code), RedemptionCode.status == STATUS_AVAILABLE) + .values(status=STATUS_PROCESSING) + ) + db.session.commit() + return result.rowcount == 1 + + +def _release_code(code: str) -> None: + db.session.rollback() + db.session.execute( + update(RedemptionCode) + .where(_code_match(code), RedemptionCode.status == STATUS_PROCESSING) + .values( + status=STATUS_AVAILABLE, + used_at=None, + used_by_username=None, + used_by_principal_name=None, + ) + ) + db.session.commit() + + +def _complete_redemption(code: str, username: str, principal_name: str | None) -> bool: + result = db.session.execute( + update(RedemptionCode) + .where(_code_match(code), RedemptionCode.status == STATUS_PROCESSING) + .values( + status=STATUS_USED, + used_at=utc_now(), + used_by_username=username, + used_by_principal_name=principal_name, + ) + ) + db.session.commit() + return result.rowcount == 1 + + @bp_admin.get("/") def admin_index(): if not _authenticated(): @@ -76,13 +209,7 @@ def admin_index(): @bp_admin.get("/api/health") def health(): - settings = _settings() - return _success( - { - "platform": settings.to_public_dict(), - "authenticated": _authenticated(), - } - ) + return _success(_health_payload()) @bp_admin.get("/api/session") @@ -110,6 +237,7 @@ def login(): if username == settings.admin_username and password == settings.admin_password: session["authenticated"] = True session.permanent = True + session["admin_username"] = username return _success({"authenticated": True}, message="登录成功。") return _error("用户名或密码错误。", status=401) @@ -126,20 +254,54 @@ def config_info(): return _success(_settings().to_public_dict()) +@bp_admin.get("/api/audit-events") +@require_auth +def list_audit_events(): + event_type = request.args.get("eventType", "").strip() + status = request.args.get("status", "").strip() + page, page_size, offset = _pagination_params() + + query = db.select(AuditEvent) + count_query = db.select(func.count()).select_from(AuditEvent) + + if event_type: + query = query.where(AuditEvent.event_type == event_type) + count_query = count_query.where(AuditEvent.event_type == event_type) + if status: + query = query.where(AuditEvent.status == status) + count_query = count_query.where(AuditEvent.status == status) + + total = db.session.execute(count_query).scalar_one() + page, offset, _ = _normalize_pagination(page, page_size, total) + events = db.session.execute( + query.order_by(AuditEvent.created_at.desc(), AuditEvent.id.desc()).offset(offset).limit(page_size) + ).scalars().all() + + return _success({ + "events": [event.to_dict() for event in events], + **_pagination_payload(page, page_size, total), + }) + + @bp_admin.get("/api/codes") @require_auth def list_codes(): status = request.args.get("status") + page, page_size, offset = _pagination_params() query = db.select(RedemptionCode) + count_query = db.select(func.count()).select_from(RedemptionCode) - if status == "available": - query = query.where(RedemptionCode.status == "available") - elif status == "used": - query = query.where(RedemptionCode.status == "used") + if status in {STATUS_AVAILABLE, STATUS_PROCESSING, STATUS_USED}: + query = query.where(RedemptionCode.status == status) + count_query = count_query.where(RedemptionCode.status == status) - result = db.session.execute(query.order_by(RedemptionCode.created_at.desc())).scalars().all() + total = db.session.execute(count_query).scalar_one() + page, offset, _ = _normalize_pagination(page, page_size, total) + result = db.session.execute( + query.order_by(RedemptionCode.created_at.desc()).offset(offset).limit(page_size) + ).scalars().all() codes = [code.to_dict() for code in result] - return _success({"codes": codes, "total": len(codes)}) + return _success({"codes": codes, **_pagination_payload(page, page_size, total)}) @bp_admin.post("/api/codes/generate") @@ -147,6 +309,7 @@ def list_codes(): def generate_codes(): payload = _json_payload() count = payload.get("count", 1) + actor = _current_actor("auth-disabled-admin") if count < 1: count = 1 if count > 100: @@ -163,40 +326,58 @@ def generate_codes(): codes.append(code) db.session.commit() + _record_audit_events( + *[ + _build_audit_event( + "code_generated", + actor=actor, + code=generated_code, + details={"batchCount": len(codes)}, + ) + for generated_code in codes + ] + ) return _success({"codes": codes, "count": len(codes)}, f"成功生成 {count} 个兑换码。") @bp_admin.delete("/api/codes/") @require_auth def delete_code(code: str): + actor = _current_actor("auth-disabled-admin") redemption_code = RedemptionCode.query.filter_by(code=code).first() if not redemption_code: return _error("兑换码不存在。", status=404) + if redemption_code.status != STATUS_AVAILABLE: + return _error("仅可删除未使用的兑换码。", status=409) db.session.delete(redemption_code) db.session.commit() + _record_audit_events( + _build_audit_event( + "code_deleted", + actor=actor, + code=code, + ) + ) return _success(message="兑换码已删除。") @bp_admin.get("/api/records") @require_auth def list_records(): - page = int(request.args.get("page", "1")) - page_size = int(request.args.get("pageSize", "25")) + page, page_size, offset = _pagination_params() - query = db.select(RedemptionCode).where(RedemptionCode.status == "used") - result = db.session.execute(query.order_by(RedemptionCode.used_at.desc())).scalars().all() - - total = len(result) - start = (page - 1) * page_size - end = start + page_size - records = result[start:end] + query = db.select(RedemptionCode).where(RedemptionCode.status == STATUS_USED) + count_query = db.select(func.count()).select_from(RedemptionCode).where(RedemptionCode.status == STATUS_USED) + total = db.session.execute(count_query).scalar_one() + page, offset, _ = _normalize_pagination(page, page_size, total) + records = db.session.execute( + query.order_by(RedemptionCode.used_at.desc()).offset(offset).limit(page_size) + ).scalars().all() return _success({ "records": [code.to_dict() for code in records], - "page": page, - "pageSize": page_size, - "total": total, + **_pagination_payload(page, page_size, total), }) @@ -205,43 +386,121 @@ def index(): return render_template("user_redemption.html", settings=_settings()) +@bp_user.get("/api/health") +def user_health(): + return _success(_health_payload()) + + @bp_user.post("/api/redeem") def redeem(): payload = _json_payload() code = str(payload.get("code", "")).strip().upper() username = str(payload.get("username", "")).strip().lower() + actor = _current_actor("public") if not code: return _error("请输入兑换码。", status=400) if not username: return _error("请输入用户名。", status=400) - redemption_code = RedemptionCode.query.filter( - func.lower(RedemptionCode.code) == code.lower(), - RedemptionCode.status == "available" - ).first() - if not redemption_code: + if not _reserve_code(code): + _record_audit_events( + _build_audit_event( + "redeem_completed", + status="failed", + actor=actor, + code=code, + username=username, + details={"message": "兑换码无效或已被使用。"}, + ) + ) return _error("兑换码无效或已被使用。", status=404) try: user_result = _service().create_user(username=username) + except ServiceConfigurationError as exc: + _release_code(code) + _record_audit_events( + _build_audit_event( + "redeem_completed", + status="failed", + actor=actor, + code=code, + username=username, + details={"message": str(exc)}, + ) + ) + return _error(str(exc), status=503) except ServiceOperationError as exc: - return _error(str(exc), status=500) + _release_code(code) + _record_audit_events( + _build_audit_event( + "redeem_completed", + status="failed", + actor=actor, + code=code, + username=username, + principal_name=(exc.details or {}).get("userPrincipalName") if isinstance(exc.details, dict) else None, + details={"message": exc.message, "serviceDetails": exc.details}, + ) + ) + return _error(exc.message, status=exc.status_code, details=exc.details) except Exception as exc: + logger.exception("兑换码 %s 开通账号时发生未预期错误", code) + _release_code(code) + _record_audit_events( + _build_audit_event( + "redeem_completed", + status="failed", + actor=actor, + code=code, + username=username, + details={"message": f"创建账号失败: {exc}"}, + ) + ) return _error(f"创建账号失败: {exc}", status=500) - redemption_code.status = "used" - redemption_code.used_at = datetime.now(timezone.utc) - redemption_code.used_by_username = username - redemption_code.used_by_principal_name = user_result.get("userPrincipalName") - db.session.commit() + if not _complete_redemption(code, username, user_result.get("userPrincipalName")): + logger.error("账号 %s 已创建,但兑换码 %s 未能完成最终状态更新。", user_result.get("userPrincipalName"), code) + _record_audit_events( + _build_audit_event( + "redeem_completed", + status="warning", + actor=actor, + code=code, + username=username, + principal_name=user_result.get("userPrincipalName"), + details={"message": "账号已创建,但兑换码状态更新失败。"}, + ) + ) + return _error( + "账号已创建,但兑换码状态更新失败,请联系管理员处理。", + status=500, + details={"userPrincipalName": user_result.get("userPrincipalName")}, + ) + _record_audit_events( + _build_audit_event( + "redeem_completed", + status="success", + actor=actor, + code=code, + username=username, + principal_name=user_result.get("userPrincipalName"), + details={ + "licenseAssigned": user_result.get("licenseAssigned"), + "licenseMessage": user_result.get("licenseMessage"), + }, + ) + ) return _success({ "userPrincipalName": user_result.get("userPrincipalName"), "temporaryPassword": user_result.get("temporaryPassword"), + "licenseAssigned": user_result.get("licenseAssigned"), + "licenseMessage": user_result.get("licenseMessage"), }, "账号开通成功!", status=201) @bp_user.get("/api/config") def config(): - return _success(_settings().to_public_dict()) \ No newline at end of file + return _success(_settings().to_public_dict()) diff --git a/office365_self_service/services.py b/office365_self_service/services.py index a4b4544..89a7be5 100644 --- a/office365_self_service/services.py +++ b/office365_self_service/services.py @@ -43,16 +43,16 @@ class Office365Service: return self._graph_client def create_user(self, username: str, password: str | None = None, display_name: str | None = None, retry: bool = True) -> dict[str, Any]: + upn, mail_nickname = self._build_user_identifiers(username) client = self._ensure_client() - upn = f"{username}@{self.settings.default_domain}" password = password or self.settings.default_password - display_name = display_name or username + display_name = display_name or mail_nickname create_payload = { "accountEnabled": True, "displayName": display_name, - "mailNickname": username, + "mailNickname": mail_nickname, "userPrincipalName": upn, "passwordProfile": { "password": password, @@ -71,8 +71,17 @@ class Office365Service: raise self._translate_graph_error(exc, f"创建用户 {upn} 失败") license_result = None + license_message = None if self.settings.default_license_sku: - license_result = self._assign_license(user["id"]) + license_result, license_message, license_status = self._assign_license(user["id"]) + if license_message and self.settings.license_assignment_required: + self._rollback_user_for_license_failure( + client=client, + user_id=user["id"], + user_principal_name=upn, + license_message=license_message, + license_status=license_status, + ) return { "user": user, @@ -80,37 +89,98 @@ class Office365Service: "temporaryPassword": password, "licenseAssigned": bool(license_result), "licenseResult": license_result, + "licenseMessage": license_message, } - def _assign_license(self, user_id: str) -> dict[str, Any]: + def _build_user_identifiers(self, username: str) -> tuple[str, str]: + normalized = (username or "").strip().lower() + if not normalized: + raise ValueError("请输入用户名。") + + if "@" in normalized: + local_part, _, domain = normalized.partition("@") + if not local_part or not domain: + raise ValueError("请输入有效的完整邮箱地址。") + return normalized, local_part + + if not self.settings.default_domain: + raise ServiceConfigurationError("DEFAULT_DOMAIN 未配置,请输入完整邮箱地址后重试。") + + return f"{normalized}@{self.settings.default_domain}", normalized + + def _assign_license(self, user_id: str) -> tuple[dict[str, Any] | None, str | None, int]: client = self._ensure_client() sku_part_number = self.settings.default_license_sku try: skus = client.list_subscribed_skus() except GraphAPIError as exc: - logger.warning("获取许可证列表失败: %s", exc) - return None + message = f"获取许可证列表失败: {exc.message or exc}" + logger.warning(message) + return None, message, exc.status_code or 502 matched = next( (sku for sku in skus if (sku.get("skuPartNumber") or "").upper() == sku_part_number.upper()), None, ) if not matched: - logger.warning("未找到许可证 SKU: %s", sku_part_number) - return None + message = f"未找到许可证 SKU: {sku_part_number}" + logger.warning(message) + return None, message, 409 if int(matched.get("consumedUnits", 0) or 0) >= int(matched.get("prepaidUnits", {}).get("enabled", 0) or 0): - logger.warning("许可证 %s 已无可用席位", sku_part_number) - return None + message = f"许可证 {sku_part_number} 已无可用席位" + logger.warning(message) + return None, message, 409 try: - return client.assign_license( - user_id, - add_licenses=[{"skuId": matched["skuId"], "disabledPlans": []}], + return ( + client.assign_license( + user_id, + add_licenses=[{"skuId": matched["skuId"], "disabledPlans": []}], + ), + None, + 200, ) except GraphAPIError as exc: - logger.warning("分配许可证失败: %s", exc) - return None + message = f"分配许可证失败: {exc.message or exc}" + logger.warning(message) + return None, message, exc.status_code or 502 + + def _rollback_user_for_license_failure( + self, + client: GraphClient, + user_id: str, + user_principal_name: str, + license_message: str, + license_status: int, + ) -> None: + try: + client.delete_user(user_id) + except GraphAPIError as exc: + delete_message = exc.message or str(exc) + raise ServiceOperationError( + message=( + f"账号 {user_principal_name} 已创建,但许可证分配失败且回滚删除失败。" + f"{license_message};删除失败: {delete_message}" + ), + status_code=502, + details={ + "userPrincipalName": user_principal_name, + "licenseError": license_message, + "rollbackDeleteError": delete_message, + "rolledBack": False, + }, + ) from exc + + raise ServiceOperationError( + message=f"许可证分配失败,已回滚删除账号 {user_principal_name}。{license_message}", + status_code=license_status or 409, + details={ + "userPrincipalName": user_principal_name, + "licenseError": license_message, + "rolledBack": True, + }, + ) def _translate_graph_error(self, exc: GraphAPIError, fallback_message: str) -> ServiceOperationError: message = fallback_message @@ -120,4 +190,4 @@ class Office365Service: lowered = message.lower() if "already exists" in lowered or "another object with the same value" in lowered: status_code = 409 - return ServiceOperationError(message=message, status_code=status_code, details=exc.response) \ No newline at end of file + return ServiceOperationError(message=message, status_code=status_code, details=exc.response) diff --git a/office365_self_service/settings.py b/office365_self_service/settings.py index 41ee4d9..42550c4 100644 --- a/office365_self_service/settings.py +++ b/office365_self_service/settings.py @@ -2,6 +2,7 @@ from __future__ import annotations import os from dataclasses import dataclass, field +from pathlib import Path from dotenv import load_dotenv @@ -22,6 +23,24 @@ def _env_int(name: str, default: int) -> int: return default +def _normalize_database_url(database_url: str, warnings: list[str]) -> str: + normalized = database_url.strip() + if not normalized: + return "sqlite:///redemption.db" + + container_prefix = "sqlite:////app/" + if normalized.startswith(container_prefix) and not Path("/.dockerenv").exists(): + local_relative = normalized.removeprefix(container_prefix) + project_root = Path(__file__).resolve().parent.parent + local_path = (project_root / local_relative).resolve() + warnings.append( + f"DATABASE_URL 使用容器路径时,已自动映射到本地路径 {local_path}。" + ) + return f"sqlite:///{local_path}" + + return normalized + + @dataclass class Settings: app_name: str @@ -39,6 +58,7 @@ class Settings: default_domain: str default_usage_location: str default_license_sku: str + license_assignment_required: bool force_change_password: bool graph_base_url: str token_endpoint: str @@ -68,6 +88,7 @@ class Settings: "defaultDomain": self.default_domain, "defaultUsageLocation": self.default_usage_location, "defaultLicenseSku": self.default_license_sku, + "licenseAssignmentRequired": self.license_assignment_required, "forceChangePassword": self.force_change_password, "pageSize": self.default_page_size, "maxPageSize": self.max_page_size, @@ -84,6 +105,7 @@ def load_settings() -> Settings: validation_errors: list[str] = [] warnings: list[str] = [] + database_url = _normalize_database_url(os.getenv("DATABASE_URL", "sqlite:///redemption.db"), warnings) required_fields = { "CLIENT_ID": os.getenv("CLIENT_ID", "").strip(), @@ -121,13 +143,14 @@ def load_settings() -> Settings: default_domain=os.getenv("DEFAULT_DOMAIN", "").strip(), default_usage_location=os.getenv("DEFAULT_USAGE_LOCATION", "US").strip() or "US", default_license_sku=os.getenv("DEFAULT_LICENSE_SKU", "").strip(), + license_assignment_required=_env_bool("LICENSE_ASSIGNMENT_REQUIRED", False), force_change_password=_env_bool("FORCE_CHANGE_PASSWORD", True), graph_base_url=graph_base_url, token_endpoint=token_endpoint, scope=scope, - database_url=os.getenv("DATABASE_URL", "sqlite:///redemption.db").strip(), + database_url=database_url, default_page_size=min(max(_env_int("DEFAULT_PAGE_SIZE", 25), 1), 100), max_page_size=min(max(_env_int("MAX_PAGE_SIZE", 100), 10), 500), validation_errors=tuple(validation_errors), warnings=tuple(warnings), - ) \ No newline at end of file + ) diff --git a/office365_self_service/templates/admin_dashboard.html b/office365_self_service/templates/admin_dashboard.html index a6bfe10..33910cd 100644 --- a/office365_self_service/templates/admin_dashboard.html +++ b/office365_self_service/templates/admin_dashboard.html @@ -21,6 +21,7 @@
@@ -42,6 +43,7 @@
+
@@ -59,6 +61,18 @@
+
+
+ + + 共 0 条 +
+
+ + + +
+
@@ -80,6 +94,54 @@ +
+
+ + + 共 0 条 +
+
+ + + +
+
+ + + + +
+

审计日志

+
+
+
+ + + + + + + + + + + + + +
时间事件状态操作人兑换码账号详情
+
+
+
+ + + 共 0 条 +
+
+ + + +
+
@@ -114,42 +176,174 @@ - \ No newline at end of file + diff --git a/office365_self_service/templates/admin_login.html b/office365_self_service/templates/admin_login.html index ee4800e..f1493bb 100644 --- a/office365_self_service/templates/admin_login.html +++ b/office365_self_service/templates/admin_login.html @@ -27,6 +27,15 @@ - \ No newline at end of file + diff --git a/office365_self_service/templates/user_redemption.html b/office365_self_service/templates/user_redemption.html index 0228256..d019b54 100644 --- a/office365_self_service/templates/user_redemption.html +++ b/office365_self_service/templates/user_redemption.html @@ -25,11 +25,16 @@
+ {% if settings.default_domain %}
@{{ settings.default_domain }}
请输入您想要的用户名,将自动拼接域名为完整邮箱地址
+ {% else %} + +
当前未配置默认域名,请直接输入完整邮箱地址。
+ {% endif %}
@@ -50,11 +55,21 @@
提示:首次登录后系统会要求您更改密码,请使用临时密码登录。
+
- \ No newline at end of file + diff --git a/tests/test_app.py b/tests/test_app.py new file mode 100644 index 0000000..1a661a9 --- /dev/null +++ b/tests/test_app.py @@ -0,0 +1,345 @@ +import os +from datetime import datetime, timezone +import tempfile +import unittest +from unittest.mock import patch +from pathlib import Path + +from office365_self_service import create_app, db +from office365_self_service.models import AuditEvent, RedemptionCode +from office365_self_service.services import Office365Service, ServiceConfigurationError, ServiceOperationError +from office365_self_service.settings import GRAPH_BASE_URL, GRAPH_SCOPE, Settings, TOKEN_ENDPOINT_TEMPLATE, load_settings + + +def build_settings(database_url: str, **overrides) -> Settings: + tenant_id = overrides.pop("tenant_id", "tenant-id") + base = { + "app_name": "Office 365 Self Service Test", + "host": "127.0.0.1", + "port": 5000, + "debug": False, + "session_secret": "test-secret", + "auth_enabled": False, + "admin_username": "", + "admin_password": "", + "client_id": "client-id", + "tenant_id": tenant_id, + "client_secret": "client-secret", + "default_password": "TempPassw0rd!", + "default_domain": "example.com", + "default_usage_location": "US", + "default_license_sku": "", + "license_assignment_required": False, + "force_change_password": True, + "graph_base_url": GRAPH_BASE_URL, + "token_endpoint": TOKEN_ENDPOINT_TEMPLATE.format(tenant_id=tenant_id), + "scope": GRAPH_SCOPE, + "database_url": database_url, + "validation_errors": (), + "warnings": (), + } + base.update(overrides) + return Settings(**base) + + +class FakeService: + def __init__(self, result=None, error=None): + self.result = result or { + "userPrincipalName": "alice@example.com", + "temporaryPassword": "TempPassw0rd!", + "licenseAssigned": True, + "licenseMessage": None, + } + self.error = error + self.calls = [] + + def create_user(self, username: str, **kwargs): + self.calls.append(username) + if self.error: + raise self.error + return self.result + + +class FakeGraphClient: + def __init__(self, skus=None, assign_result=None, assign_error=None, delete_error=None): + self.payloads = [] + self.deleted_users = [] + self.skus = skus or [] + self.assign_result = assign_result or {"status": "ok"} + self.assign_error = assign_error + self.delete_error = delete_error + + def create_user(self, payload): + self.payloads.append(payload) + return {"id": "user-1"} + + def list_subscribed_skus(self): + return self.skus + + def assign_license(self, user_id, add_licenses=None, remove_licenses=None): + if self.assign_error: + raise self.assign_error + return self.assign_result + + def delete_user(self, user_id): + if self.delete_error: + raise self.delete_error + self.deleted_users.append(user_id) + + +class AppRouteTests(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + db_path = Path(self.temp_dir.name) / "test.db" + self.settings = build_settings(f"sqlite:///{db_path}") + self.service = FakeService() + self.app = create_app( + settings_override=self.settings, + service_factory=lambda _settings: self.service, + ) + self.app.testing = True + self.client = self.app.test_client() + + with self.app.app_context(): + db.drop_all() + db.create_all() + + def tearDown(self): + self.temp_dir.cleanup() + + def add_code(self, code: str, status: str = "available"): + with self.app.app_context(): + db.session.add(RedemptionCode(code=code, status=status)) + db.session.commit() + + def fetch_code(self, code: str) -> RedemptionCode: + with self.app.app_context(): + return RedemptionCode.query.filter_by(code=code).first() + + def fetch_audit_events(self) -> list[AuditEvent]: + with self.app.app_context(): + return AuditEvent.query.order_by(AuditEvent.created_at.asc(), AuditEvent.id.asc()).all() + + def test_redeem_marks_code_used_and_prevents_second_use(self): + self.add_code("CODE-001") + + response = self.client.post("/api/redeem", json={"code": "code-001", "username": "alice"}) + payload = response.get_json() + + self.assertEqual(response.status_code, 201) + self.assertTrue(payload["success"]) + self.assertEqual(payload["data"]["userPrincipalName"], "alice@example.com") + self.assertEqual(self.service.calls, ["alice"]) + + code = self.fetch_code("CODE-001") + self.assertEqual(code.status, "used") + self.assertEqual(code.used_by_username, "alice") + self.assertEqual(code.used_by_principal_name, "alice@example.com") + + second = self.client.post("/api/redeem", json={"code": "CODE-001", "username": "bob"}) + second_payload = second.get_json() + + self.assertEqual(second.status_code, 404) + self.assertFalse(second_payload["success"]) + self.assertEqual(self.service.calls, ["alice"]) + + def test_redeem_releases_code_when_service_fails(self): + self.service = FakeService(error=ServiceOperationError("用户名已存在。", status_code=409)) + self.app = create_app( + settings_override=self.settings, + service_factory=lambda _settings: self.service, + ) + self.app.testing = True + self.client = self.app.test_client() + + with self.app.app_context(): + db.drop_all() + db.create_all() + db.session.add(RedemptionCode(code="CODE-002")) + db.session.commit() + + response = self.client.post("/api/redeem", json={"code": "CODE-002", "username": "alice"}) + payload = response.get_json() + + self.assertEqual(response.status_code, 409) + self.assertFalse(payload["success"]) + code = self.fetch_code("CODE-002") + self.assertEqual(code.status, "available") + self.assertIsNone(code.used_by_username) + self.assertEqual(self.service.calls, ["alice"]) + + def test_public_health_endpoint_is_available(self): + response = self.client.get("/api/health") + payload = response.get_json() + + self.assertEqual(response.status_code, 200) + self.assertTrue(payload["success"]) + self.assertIn("platform", payload["data"]) + + def test_generate_delete_and_failed_redeem_are_audited(self): + generate = self.client.post("/admin/api/codes/generate", json={"count": 1}) + generated_code = generate.get_json()["data"]["codes"][0] + + delete = self.client.delete(f"/admin/api/codes/{generated_code}") + self.assertEqual(delete.status_code, 200) + + failed = self.client.post("/api/redeem", json={"code": "MISSING-CODE", "username": "alice"}) + self.assertEqual(failed.status_code, 404) + + audit_response = self.client.get("/admin/api/audit-events?page=1&pageSize=10") + audit_payload = audit_response.get_json() + + self.assertEqual(audit_response.status_code, 200) + self.assertEqual(audit_payload["data"]["total"], 3) + self.assertEqual( + [event["eventType"] for event in audit_payload["data"]["events"]], + ["redeem_completed", "code_deleted", "code_generated"], + ) + + def test_successful_redeem_creates_success_audit_event(self): + self.add_code("CODE-004") + + response = self.client.post("/api/redeem", json={"code": "CODE-004", "username": "alice"}) + self.assertEqual(response.status_code, 201) + + events = self.fetch_audit_events() + self.assertEqual(len(events), 1) + event_payload = events[0].to_dict() + self.assertEqual(event_payload["eventType"], "redeem_completed") + self.assertEqual(event_payload["status"], "success") + self.assertEqual(event_payload["code"], "CODE-004") + self.assertEqual(event_payload["principalName"], "alice@example.com") + self.assertEqual(event_payload["details"]["licenseAssigned"], True) + + def test_codes_api_uses_database_pagination(self): + for index in range(5): + self.add_code(f"CODE-{index:03d}") + + response = self.client.get("/admin/api/codes?page=2&pageSize=2") + payload = response.get_json() + + self.assertEqual(response.status_code, 200) + self.assertEqual(payload["data"]["page"], 2) + self.assertEqual(payload["data"]["pageSize"], 2) + self.assertEqual(payload["data"]["total"], 5) + self.assertEqual(payload["data"]["pages"], 3) + self.assertEqual(len(payload["data"]["codes"]), 2) + + def test_records_api_uses_database_pagination(self): + with self.app.app_context(): + for index in range(5): + db.session.add( + RedemptionCode( + code=f"USED-{index:03d}", + status="used", + used_by_username=f"user{index}", + used_by_principal_name=f"user{index}@example.com", + ) + ) + db.session.commit() + + response = self.client.get("/admin/api/records?page=2&pageSize=2") + payload = response.get_json() + + self.assertEqual(response.status_code, 200) + self.assertEqual(payload["data"]["page"], 2) + self.assertEqual(payload["data"]["pageSize"], 2) + self.assertEqual(payload["data"]["total"], 5) + self.assertEqual(payload["data"]["pages"], 3) + self.assertEqual(len(payload["data"]["records"]), 2) + + def test_delete_rejects_non_available_codes(self): + self.add_code("CODE-003", status="used") + + response = self.client.delete("/admin/api/codes/CODE-003") + payload = response.get_json() + + self.assertEqual(response.status_code, 409) + self.assertFalse(payload["success"]) + + +class ServiceBehaviorTests(unittest.TestCase): + def test_create_user_accepts_full_upn_without_default_domain(self): + settings = build_settings("sqlite:////tmp/unused.db", default_domain="") + service = Office365Service(settings) + fake_client = FakeGraphClient() + service._graph_client = fake_client + + result = service.create_user("alice@example.com") + + self.assertEqual(result["userPrincipalName"], "alice@example.com") + self.assertEqual(fake_client.payloads[0]["mailNickname"], "alice") + self.assertEqual(fake_client.payloads[0]["userPrincipalName"], "alice@example.com") + + def test_create_user_requires_full_upn_when_default_domain_missing(self): + settings = build_settings("sqlite:////tmp/unused.db", default_domain="") + service = Office365Service(settings) + service._graph_client = FakeGraphClient() + + with self.assertRaises(ServiceConfigurationError): + service.create_user("alice") + + def test_create_user_returns_license_warning_when_not_strict(self): + settings = build_settings( + "sqlite:////tmp/unused.db", + default_license_sku="ENTERPRISEPACK", + license_assignment_required=False, + ) + service = Office365Service(settings) + service._graph_client = FakeGraphClient(skus=[]) + + result = service.create_user("alice") + + self.assertFalse(result["licenseAssigned"]) + self.assertEqual(result["licenseMessage"], "未找到许可证 SKU: ENTERPRISEPACK") + + def test_create_user_rolls_back_when_license_required(self): + settings = build_settings( + "sqlite:////tmp/unused.db", + default_license_sku="ENTERPRISEPACK", + license_assignment_required=True, + ) + service = Office365Service(settings) + fake_client = FakeGraphClient(skus=[]) + service._graph_client = fake_client + + with self.assertRaises(ServiceOperationError) as context: + service.create_user("alice") + + self.assertIn("已回滚删除账号 alice@example.com", str(context.exception)) + self.assertEqual(fake_client.deleted_users, ["user-1"]) + + +class ModelSerializationTests(unittest.TestCase): + def test_redemption_code_serializes_datetimes_as_utc_z(self): + code = RedemptionCode( + code="CODE-UTC", + created_at=datetime(2026, 3, 31, 12, 0, 0), + used_at=datetime(2026, 3, 31, 13, 0, 0, tzinfo=timezone.utc), + ) + + payload = code.to_dict() + + self.assertEqual(payload["createdAt"], "2026-03-31T12:00:00Z") + self.assertEqual(payload["usedAt"], "2026-03-31T13:00:00Z") + + +class SettingsTests(unittest.TestCase): + def test_container_database_url_is_remapped_locally(self): + env = { + "CLIENT_ID": "client-id", + "TENANT_ID": "tenant-id", + "CLIENT_SECRET": "client-secret", + "DEFAULT_PASSWORD": "TempPassw0rd!", + "DATABASE_URL": "sqlite:////app/data/redemption.db", + } + + with patch.dict(os.environ, env, clear=False): + settings = load_settings() + + self.assertTrue(settings.database_url.endswith("/office365-self-service/data/redemption.db")) + self.assertIn("已自动映射到本地路径", " ".join(settings.warnings)) + + +if __name__ == "__main__": + unittest.main()