"""Tests for the Office 365 self-service app.

Covers the HTTP routes (with the provisioning service replaced by a fake),
``Office365Service.create_user`` (with the Graph client replaced by a fake),
model serialization, and settings loading.
"""

import os
import tempfile
import unittest
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import patch

from office365_self_service import create_app, db
from office365_self_service.models import AuditEvent, RedemptionCode
from office365_self_service.services import (
    Office365Service,
    ServiceConfigurationError,
    ServiceOperationError,
)
from office365_self_service.settings import (
    GRAPH_BASE_URL,
    GRAPH_SCOPE,
    Settings,
    TOKEN_ENDPOINT_TEMPLATE,
    load_settings,
)

# Placeholder database URL for tests that build an Office365Service directly
# and never create the Flask app in this file.
# NOTE(review): presumably the service never opens this database - confirm.
UNUSED_DATABASE_URL = "sqlite:////tmp/unused.db"


def build_settings(database_url: str, **overrides) -> Settings:
    """Return a ``Settings`` instance filled with test defaults.

    Args:
        database_url: Value stored in the ``database_url`` field.
        **overrides: Field values that replace the defaults below. A
            ``tenant_id`` override is also used to format the default
            ``token_endpoint``.
    """
    # Popped first so the default token endpoint is derived from the same
    # tenant id that ends up in the settings.
    tenant_id = overrides.pop("tenant_id", "tenant-id")
    base = {
        "app_name": "Office 365 Self Service Test",
        "host": "127.0.0.1",
        "port": 5000,
        "debug": False,
        "session_secret": "test-secret",
        "auth_enabled": False,
        "admin_username": "",
        "admin_password": "",
        "client_id": "client-id",
        "tenant_id": tenant_id,
        "client_secret": "client-secret",
        "default_password": "TempPassw0rd!",
        "default_domain": "example.com",
        "default_usage_location": "US",
        "default_license_sku": "",
        "license_assignment_required": False,
        "force_change_password": True,
        "graph_base_url": GRAPH_BASE_URL,
        "token_endpoint": TOKEN_ENDPOINT_TEMPLATE.format(tenant_id=tenant_id),
        "scope": GRAPH_SCOPE,
        "database_url": database_url,
        "validation_errors": (),
        "warnings": (),
    }
    base.update(overrides)
    return Settings(**base)


class FakeService:
    """Stand-in for the provisioning service handed to ``create_app``.

    Records every username passed to ``create_user`` in ``calls`` and either
    returns a canned result or raises a configured error.
    """

    def __init__(self, result=None, error=None):
        # A falsy ``result`` (including an empty dict) falls back to the
        # canned success payload.
        self.result = result or {
            "userPrincipalName": "alice@example.com",
            "temporaryPassword": "TempPassw0rd!",
            "licenseAssigned": True,
            "licenseMessage": None,
        }
        self.error = error
        self.calls = []

    def create_user(self, username: str, **kwargs):
        """Record the call, then raise ``self.error`` if set, else return ``self.result``."""
        self.calls.append(username)
        if self.error:
            raise self.error
        return self.result


class FakeGraphClient:
    """Stand-in for the Graph client that ``Office365Service`` talks to.

    ``payloads`` collects the bodies passed to ``create_user`` and
    ``deleted_users`` collects the ids passed to ``delete_user``.
    """

    def __init__(self, skus=None, assign_result=None, assign_error=None, delete_error=None):
        self.payloads = []
        self.deleted_users = []
        self.skus = skus or []
        self.assign_result = assign_result or {"status": "ok"}
        self.assign_error = assign_error
        self.delete_error = delete_error

    def create_user(self, payload):
        """Record ``payload`` and return a fixed user object with id ``user-1``."""
        self.payloads.append(payload)
        return {"id": "user-1"}

    def list_subscribed_skus(self):
        """Return the SKU list supplied at construction time."""
        return self.skus

    def assign_license(self, user_id, add_licenses=None, remove_licenses=None):
        """Raise ``assign_error`` if configured, otherwise return ``assign_result``."""
        if self.assign_error:
            raise self.assign_error
        return self.assign_result

    def delete_user(self, user_id):
        """Raise ``delete_error`` if configured, otherwise record the deleted id."""
        if self.delete_error:
            raise self.delete_error
        self.deleted_users.append(user_id)


class AppRouteTests(unittest.TestCase):
    """Route-level tests run against a SQLite file in a per-test temp directory."""

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()
        # Registered immediately: unittest skips tearDown when setUp raises,
        # but cleanups registered with addCleanup still run.
        self.addCleanup(self.temp_dir.cleanup)
        db_path = Path(self.temp_dir.name) / "test.db"
        self.settings = build_settings(f"sqlite:///{db_path}")
        self._install_app(FakeService())

    def _install_app(self, service):
        """Create the app and test client around ``service`` and reset the schema."""
        self.service = service
        self.app = create_app(
            settings_override=self.settings,
            # The factory reads self.service at call time rather than
            # capturing the argument.
            service_factory=lambda _settings: self.service,
        )
        self.app.testing = True
        self.client = self.app.test_client()
        with self.app.app_context():
            db.drop_all()
            db.create_all()

    def add_code(self, code: str, status: str = "available"):
        """Insert and commit a redemption code with the given status."""
        with self.app.app_context():
            db.session.add(RedemptionCode(code=code, status=status))
            db.session.commit()

    def fetch_code(self, code: str) -> RedemptionCode:
        """Return the first stored redemption code matching ``code`` exactly."""
        with self.app.app_context():
            return RedemptionCode.query.filter_by(code=code).first()

    def fetch_audit_events(self) -> list[AuditEvent]:
        """Return all audit events ordered by ``created_at`` then ``id``, ascending."""
        with self.app.app_context():
            return AuditEvent.query.order_by(AuditEvent.created_at.asc(), AuditEvent.id.asc()).all()

    def test_redeem_marks_code_used_and_prevents_second_use(self):
        self.add_code("CODE-001")

        # The request deliberately uses a lowercase code against the
        # uppercase stored one.
        response = self.client.post("/api/redeem", json={"code": "code-001", "username": "alice"})
        payload = response.get_json()

        self.assertEqual(response.status_code, 201)
        self.assertTrue(payload["success"])
        self.assertEqual(payload["data"]["userPrincipalName"], "alice@example.com")
        self.assertEqual(self.service.calls, ["alice"])

        code = self.fetch_code("CODE-001")
        self.assertEqual(code.status, "used")
        self.assertEqual(code.used_by_username, "alice")
        self.assertEqual(code.used_by_principal_name, "alice@example.com")

        second = self.client.post("/api/redeem", json={"code": "CODE-001", "username": "bob"})
        second_payload = second.get_json()

        self.assertEqual(second.status_code, 404)
        self.assertFalse(second_payload["success"])
        # The service must not be called for the rejected second attempt.
        self.assertEqual(self.service.calls, ["alice"])

    def test_redeem_releases_code_when_service_fails(self):
        # Rebuild the app around a service that always fails with a 409.
        self._install_app(
            FakeService(error=ServiceOperationError("用户名已存在。", status_code=409))
        )
        with self.app.app_context():
            # No explicit status here, so the model's own default applies.
            db.session.add(RedemptionCode(code="CODE-002"))
            db.session.commit()

        response = self.client.post("/api/redeem", json={"code": "CODE-002", "username": "alice"})
        payload = response.get_json()

        self.assertEqual(response.status_code, 409)
        self.assertFalse(payload["success"])

        code = self.fetch_code("CODE-002")
        self.assertEqual(code.status, "available")
        self.assertIsNone(code.used_by_username)
        self.assertEqual(self.service.calls, ["alice"])

    def test_public_health_endpoint_is_available(self):
        response = self.client.get("/api/health")
        payload = response.get_json()

        self.assertEqual(response.status_code, 200)
        self.assertTrue(payload["success"])
        self.assertIn("platform", payload["data"])

    def test_generate_delete_and_failed_redeem_are_audited(self):
        generate = self.client.post("/admin/api/codes/generate", json={"count": 1})
        generated_code = generate.get_json()["data"]["codes"][0]

        delete = self.client.delete(f"/admin/api/codes/{generated_code}")
        self.assertEqual(delete.status_code, 200)

        failed = self.client.post("/api/redeem", json={"code": "MISSING-CODE", "username": "alice"})
        self.assertEqual(failed.status_code, 404)

        audit_response = self.client.get("/admin/api/audit-events?page=1&pageSize=10")
        audit_payload = audit_response.get_json()

        self.assertEqual(audit_response.status_code, 200)
        self.assertEqual(audit_payload["data"]["total"], 3)
        # The expected list is the reverse of the order in which the three
        # actions above were performed.
        self.assertEqual(
            [event["eventType"] for event in audit_payload["data"]["events"]],
            ["redeem_completed", "code_deleted", "code_generated"],
        )

    def test_successful_redeem_creates_success_audit_event(self):
        self.add_code("CODE-004")

        response = self.client.post("/api/redeem", json={"code": "CODE-004", "username": "alice"})
        self.assertEqual(response.status_code, 201)

        events = self.fetch_audit_events()
        self.assertEqual(len(events), 1)

        event_payload = events[0].to_dict()
        self.assertEqual(event_payload["eventType"], "redeem_completed")
        self.assertEqual(event_payload["status"], "success")
        self.assertEqual(event_payload["code"], "CODE-004")
        self.assertEqual(event_payload["principalName"], "alice@example.com")
        self.assertEqual(event_payload["details"]["licenseAssigned"], True)

    def test_codes_api_uses_database_pagination(self):
        for index in range(5):
            self.add_code(f"CODE-{index:03d}")

        response = self.client.get("/admin/api/codes?page=2&pageSize=2")
        payload = response.get_json()

        self.assertEqual(response.status_code, 200)
        self.assertEqual(payload["data"]["page"], 2)
        self.assertEqual(payload["data"]["pageSize"], 2)
        self.assertEqual(payload["data"]["total"], 5)
        self.assertEqual(payload["data"]["pages"], 3)
        self.assertEqual(len(payload["data"]["codes"]), 2)

    def test_records_api_uses_database_pagination(self):
        with self.app.app_context():
            for index in range(5):
                db.session.add(
                    RedemptionCode(
                        code=f"USED-{index:03d}",
                        status="used",
                        used_by_username=f"user{index}",
                        used_by_principal_name=f"user{index}@example.com",
                    )
                )
            db.session.commit()

        response = self.client.get("/admin/api/records?page=2&pageSize=2")
        payload = response.get_json()

        self.assertEqual(response.status_code, 200)
        self.assertEqual(payload["data"]["page"], 2)
        self.assertEqual(payload["data"]["pageSize"], 2)
        self.assertEqual(payload["data"]["total"], 5)
        self.assertEqual(payload["data"]["pages"], 3)
        self.assertEqual(len(payload["data"]["records"]), 2)

    def test_delete_rejects_non_available_codes(self):
        self.add_code("CODE-003", status="used")

        response = self.client.delete("/admin/api/codes/CODE-003")
        payload = response.get_json()

        self.assertEqual(response.status_code, 409)
        self.assertFalse(payload["success"])


class ServiceBehaviorTests(unittest.TestCase):
    """Tests for ``Office365Service.create_user`` with an injected fake Graph client."""

    def test_create_user_accepts_full_upn_without_default_domain(self):
        settings = build_settings(UNUSED_DATABASE_URL, default_domain="")
        service = Office365Service(settings)
        fake_client = FakeGraphClient()
        # Swap in the fake through the service's private client attribute.
        service._graph_client = fake_client

        result = service.create_user("alice@example.com")

        self.assertEqual(result["userPrincipalName"], "alice@example.com")
        self.assertEqual(fake_client.payloads[0]["mailNickname"], "alice")
        self.assertEqual(fake_client.payloads[0]["userPrincipalName"], "alice@example.com")

    def test_create_user_requires_full_upn_when_default_domain_missing(self):
        settings = build_settings(UNUSED_DATABASE_URL, default_domain="")
        service = Office365Service(settings)
        service._graph_client = FakeGraphClient()

        with self.assertRaises(ServiceConfigurationError):
            service.create_user("alice")

    def test_create_user_returns_license_warning_when_not_strict(self):
        settings = build_settings(
            UNUSED_DATABASE_URL,
            default_license_sku="ENTERPRISEPACK",
            license_assignment_required=False,
        )
        service = Office365Service(settings)
        # An empty SKU list means the configured SKU cannot be found.
        service._graph_client = FakeGraphClient(skus=[])

        result = service.create_user("alice")

        self.assertFalse(result["licenseAssigned"])
        self.assertEqual(result["licenseMessage"], "未找到许可证 SKU: ENTERPRISEPACK")

    def test_create_user_rolls_back_when_license_required(self):
        settings = build_settings(
            UNUSED_DATABASE_URL,
            default_license_sku="ENTERPRISEPACK",
            license_assignment_required=True,
        )
        service = Office365Service(settings)
        fake_client = FakeGraphClient(skus=[])
        service._graph_client = fake_client

        with self.assertRaises(ServiceOperationError) as context:
            service.create_user("alice")

        self.assertIn("已回滚删除账号 alice@example.com", str(context.exception))
        # The user created by the fake ("user-1") must have been deleted again.
        self.assertEqual(fake_client.deleted_users, ["user-1"])


class ModelSerializationTests(unittest.TestCase):
    """Tests for ``RedemptionCode.to_dict`` datetime formatting."""

    def test_redemption_code_serializes_datetimes_as_utc_z(self):
        # One naive and one UTC-aware datetime; both must serialize with "Z".
        code = RedemptionCode(
            code="CODE-UTC",
            created_at=datetime(2026, 3, 31, 12, 0, 0),
            used_at=datetime(2026, 3, 31, 13, 0, 0, tzinfo=timezone.utc),
        )

        payload = code.to_dict()

        self.assertEqual(payload["createdAt"], "2026-03-31T12:00:00Z")
        self.assertEqual(payload["usedAt"], "2026-03-31T13:00:00Z")


class SettingsTests(unittest.TestCase):
    """Tests for ``load_settings`` reading from environment variables."""

    def test_container_database_url_is_remapped_locally(self):
        env = {
            "CLIENT_ID": "client-id",
            "TENANT_ID": "tenant-id",
            "CLIENT_SECRET": "client-secret",
            "DEFAULT_PASSWORD": "TempPassw0rd!",
            "DATABASE_URL": "sqlite:////app/data/redemption.db",
        }

        # clear=False layers these values over the existing environment;
        # patch.dict restores os.environ when the block exits.
        with patch.dict(os.environ, env, clear=False):
            settings = load_settings()

        self.assertTrue(settings.database_url.endswith("/office365-self-service/data/redemption.db"))
        self.assertIn("已自动映射到本地路径", " ".join(settings.warnings))


if __name__ == "__main__":
    unittest.main()