346 lines
13 KiB
Python
346 lines
13 KiB
Python
import os
|
|
from datetime import datetime, timezone
|
|
import tempfile
|
|
import unittest
|
|
from unittest.mock import patch
|
|
from pathlib import Path
|
|
|
|
from office365_self_service import create_app, db
|
|
from office365_self_service.models import AuditEvent, RedemptionCode
|
|
from office365_self_service.services import Office365Service, ServiceConfigurationError, ServiceOperationError
|
|
from office365_self_service.settings import GRAPH_BASE_URL, GRAPH_SCOPE, Settings, TOKEN_ENDPOINT_TEMPLATE, load_settings
|
|
|
|
|
|
def build_settings(database_url: str, **overrides) -> Settings:
    """Build a ``Settings`` object filled with test defaults.

    Args:
        database_url: Value stored in the ``database_url`` field.
        **overrides: Field values that replace the defaults. A ``tenant_id``
            override is also used to format the default ``token_endpoint``.

    Returns:
        The ``Settings`` instance built from the merged field values.
    """
    # Resolve the tenant first: the default token endpoint is derived from it.
    tenant = overrides.pop("tenant_id", "tenant-id")
    defaults = dict(
        app_name="Office 365 Self Service Test",
        host="127.0.0.1",
        port=5000,
        debug=False,
        session_secret="test-secret",
        auth_enabled=False,
        admin_username="",
        admin_password="",
        client_id="client-id",
        tenant_id=tenant,
        client_secret="client-secret",
        default_password="TempPassw0rd!",
        default_domain="example.com",
        default_usage_location="US",
        default_license_sku="",
        license_assignment_required=False,
        force_change_password=True,
        graph_base_url=GRAPH_BASE_URL,
        token_endpoint=TOKEN_ENDPOINT_TEMPLATE.format(tenant_id=tenant),
        scope=GRAPH_SCOPE,
        database_url=database_url,
        validation_errors=(),
        warnings=(),
    )
    # Remaining overrides win over the defaults, key by key.
    return Settings(**{**defaults, **overrides})
|
|
|
|
|
|
class FakeService:
    """In-memory stand-in for the user-creation service.

    Every requested username is recorded in ``calls``; ``create_user`` then
    either raises the configured error or returns the configured result.
    """

    def __init__(self, result=None, error=None):
        """Store the canned outcome.

        Args:
            result: Value returned by ``create_user``. Any falsy value selects
                the built-in sample payload instead.
            error: Exception raised by ``create_user`` instead of returning.
        """
        self.calls = []
        self.error = error
        if result:
            self.result = result
        else:
            self.result = {
                "userPrincipalName": "alice@example.com",
                "temporaryPassword": "TempPassw0rd!",
                "licenseAssigned": True,
                "licenseMessage": None,
            }

    def create_user(self, username: str, **kwargs):
        """Record ``username``, then raise the stored error or return the result.

        Extra keyword arguments are accepted and ignored.
        """
        self.calls.append(username)
        if not self.error:
            return self.result
        raise self.error
|
|
|
|
|
|
class FakeGraphClient:
    """Scripted replacement for the Graph client used in service tests.

    Created payloads are collected in ``payloads`` and deleted user ids in
    ``deleted_users``; license assignment and deletion can be made to fail.
    """

    def __init__(self, skus=None, assign_result=None, assign_error=None, delete_error=None):
        """Configure the canned responses.

        Args:
            skus: Value returned by ``list_subscribed_skus``; falsy means ``[]``.
            assign_result: Value returned by ``assign_license``; falsy means
                ``{"status": "ok"}``.
            assign_error: Exception raised by ``assign_license`` when set.
            delete_error: Exception raised by ``delete_user`` when set.
        """
        self.assign_error = assign_error
        self.delete_error = delete_error
        self.skus = skus if skus else []
        self.assign_result = assign_result if assign_result else {"status": "ok"}
        self.payloads = []
        self.deleted_users = []

    def create_user(self, payload):
        """Remember ``payload`` and return a fixed user record."""
        self.payloads.append(payload)
        return {"id": "user-1"}

    def list_subscribed_skus(self):
        """Return the configured SKU list."""
        return self.skus

    def assign_license(self, user_id, add_licenses=None, remove_licenses=None):
        """Raise the configured assignment error, or return the canned result."""
        if not self.assign_error:
            return self.assign_result
        raise self.assign_error

    def delete_user(self, user_id):
        """Raise the configured deletion error, or record ``user_id`` as deleted."""
        if self.delete_error:
            # A failed deletion is not recorded.
            raise self.delete_error
        self.deleted_users.append(user_id)
|
|
|
|
|
|
class AppRouteTests(unittest.TestCase):
    """Route-level tests run against a temporary SQLite file and a fake service."""

    def setUp(self):
        """Create the temp database location, test settings, and a fresh app."""
        self.temp_dir = tempfile.TemporaryDirectory()
        # Registered right away: unittest skips tearDown when setUp raises,
        # but cleanups registered before the failure still run.
        self.addCleanup(self.temp_dir.cleanup)
        db_path = Path(self.temp_dir.name) / "test.db"
        self.settings = build_settings(f"sqlite:///{db_path}")
        self._install_app(FakeService())

    def _install_app(self, service):
        """(Re)build the app and test client around ``service`` with empty tables."""
        self.service = service
        self.app = create_app(
            settings_override=self.settings,
            # Looked up lazily so the app always sees the current fake.
            service_factory=lambda _settings: self.service,
        )
        self.app.testing = True
        self.client = self.app.test_client()

        with self.app.app_context():
            db.drop_all()
            db.create_all()

        # Cleanups run last-in-first-out, so this runs before the temporary
        # directory registered in setUp is removed.
        self.addCleanup(self._release_database, self.app)

    @staticmethod
    def _release_database(app):
        """Drop the session and close pooled connections held for ``app``.

        Done so that no open connection still references the database file
        when the temporary directory is deleted.
        """
        with app.app_context():
            db.session.remove()
            db.engine.dispose()

    def add_code(self, code: str, status: str = "available"):
        """Insert a ``RedemptionCode`` row with the given code and status."""
        with self.app.app_context():
            db.session.add(RedemptionCode(code=code, status=status))
            db.session.commit()

    def fetch_code(self, code: str) -> RedemptionCode:
        """Return the first ``RedemptionCode`` matching ``code`` (or ``None``)."""
        with self.app.app_context():
            return RedemptionCode.query.filter_by(code=code).first()

    def fetch_audit_events(self) -> list[AuditEvent]:
        """Return all audit events ordered by creation time, then id."""
        with self.app.app_context():
            return AuditEvent.query.order_by(AuditEvent.created_at.asc(), AuditEvent.id.asc()).all()

    def test_redeem_marks_code_used_and_prevents_second_use(self):
        """A redeemed code is marked used; a second redeem gets 404 without a service call."""
        self.add_code("CODE-001")

        # The request deliberately uses a lower-cased form of the stored code.
        response = self.client.post("/api/redeem", json={"code": "code-001", "username": "alice"})
        payload = response.get_json()

        self.assertEqual(response.status_code, 201)
        self.assertTrue(payload["success"])
        self.assertEqual(payload["data"]["userPrincipalName"], "alice@example.com")
        self.assertEqual(self.service.calls, ["alice"])

        code = self.fetch_code("CODE-001")
        self.assertEqual(code.status, "used")
        self.assertEqual(code.used_by_username, "alice")
        self.assertEqual(code.used_by_principal_name, "alice@example.com")

        second = self.client.post("/api/redeem", json={"code": "CODE-001", "username": "bob"})
        second_payload = second.get_json()

        self.assertEqual(second.status_code, 404)
        self.assertFalse(second_payload["success"])
        # The service must not have been called for the rejected attempt.
        self.assertEqual(self.service.calls, ["alice"])

    def test_redeem_releases_code_when_service_fails(self):
        """A service failure is surfaced as 409 and the code stays available."""
        self._install_app(FakeService(error=ServiceOperationError("用户名已存在。", status_code=409)))

        with self.app.app_context():
            # No explicit status: the row relies on the model's own default.
            db.session.add(RedemptionCode(code="CODE-002"))
            db.session.commit()

        response = self.client.post("/api/redeem", json={"code": "CODE-002", "username": "alice"})
        payload = response.get_json()

        self.assertEqual(response.status_code, 409)
        self.assertFalse(payload["success"])
        code = self.fetch_code("CODE-002")
        self.assertEqual(code.status, "available")
        self.assertIsNone(code.used_by_username)
        self.assertEqual(self.service.calls, ["alice"])

    def test_public_health_endpoint_is_available(self):
        """``GET /api/health`` answers 200 with a ``platform`` entry in its data."""
        response = self.client.get("/api/health")
        payload = response.get_json()

        self.assertEqual(response.status_code, 200)
        self.assertTrue(payload["success"])
        self.assertIn("platform", payload["data"])

    def test_generate_delete_and_failed_redeem_are_audited(self):
        """Generating, deleting, and a failed redeem each produce one audit event."""
        generate = self.client.post("/admin/api/codes/generate", json={"count": 1})
        generated_code = generate.get_json()["data"]["codes"][0]

        delete = self.client.delete(f"/admin/api/codes/{generated_code}")
        self.assertEqual(delete.status_code, 200)

        failed = self.client.post("/api/redeem", json={"code": "MISSING-CODE", "username": "alice"})
        self.assertEqual(failed.status_code, 404)

        audit_response = self.client.get("/admin/api/audit-events?page=1&pageSize=10")
        audit_payload = audit_response.get_json()

        self.assertEqual(audit_response.status_code, 200)
        self.assertEqual(audit_payload["data"]["total"], 3)
        # Expected listing order is the reverse of the order the actions ran in.
        self.assertEqual(
            [event["eventType"] for event in audit_payload["data"]["events"]],
            ["redeem_completed", "code_deleted", "code_generated"],
        )

    def test_successful_redeem_creates_success_audit_event(self):
        """A successful redeem stores exactly one ``redeem_completed`` success event."""
        self.add_code("CODE-004")

        response = self.client.post("/api/redeem", json={"code": "CODE-004", "username": "alice"})
        self.assertEqual(response.status_code, 201)

        events = self.fetch_audit_events()
        self.assertEqual(len(events), 1)
        event_payload = events[0].to_dict()
        self.assertEqual(event_payload["eventType"], "redeem_completed")
        self.assertEqual(event_payload["status"], "success")
        self.assertEqual(event_payload["code"], "CODE-004")
        self.assertEqual(event_payload["principalName"], "alice@example.com")
        self.assertEqual(event_payload["details"]["licenseAssigned"], True)

    def test_codes_api_uses_database_pagination(self):
        """Page 2 of 5 codes at page size 2 reports 3 pages and returns 2 codes."""
        for index in range(5):
            self.add_code(f"CODE-{index:03d}")

        response = self.client.get("/admin/api/codes?page=2&pageSize=2")
        payload = response.get_json()

        self.assertEqual(response.status_code, 200)
        self.assertEqual(payload["data"]["page"], 2)
        self.assertEqual(payload["data"]["pageSize"], 2)
        self.assertEqual(payload["data"]["total"], 5)
        self.assertEqual(payload["data"]["pages"], 3)
        self.assertEqual(len(payload["data"]["codes"]), 2)

    def test_records_api_uses_database_pagination(self):
        """Page 2 of 5 used codes at page size 2 reports 3 pages and returns 2 records."""
        with self.app.app_context():
            for index in range(5):
                db.session.add(
                    RedemptionCode(
                        code=f"USED-{index:03d}",
                        status="used",
                        used_by_username=f"user{index}",
                        used_by_principal_name=f"user{index}@example.com",
                    )
                )
            db.session.commit()

        response = self.client.get("/admin/api/records?page=2&pageSize=2")
        payload = response.get_json()

        self.assertEqual(response.status_code, 200)
        self.assertEqual(payload["data"]["page"], 2)
        self.assertEqual(payload["data"]["pageSize"], 2)
        self.assertEqual(payload["data"]["total"], 5)
        self.assertEqual(payload["data"]["pages"], 3)
        self.assertEqual(len(payload["data"]["records"]), 2)

    def test_delete_rejects_non_available_codes(self):
        """Deleting a code whose status is ``used`` is rejected with 409."""
        self.add_code("CODE-003", status="used")

        response = self.client.delete("/admin/api/codes/CODE-003")
        payload = response.get_json()

        self.assertEqual(response.status_code, 409)
        self.assertFalse(payload["success"])
|
|
|
|
|
|
class ServiceBehaviorTests(unittest.TestCase):
    """Tests for ``Office365Service.create_user`` with its Graph client replaced by a fake."""

    @staticmethod
    def _wire_service(graph, **setting_overrides):
        """Return an ``Office365Service`` built from test settings that talks to ``graph``."""
        service = Office365Service(build_settings("sqlite:////tmp/unused.db", **setting_overrides))
        service._graph_client = graph
        return service

    def test_create_user_accepts_full_upn_without_default_domain(self):
        """A full UPN is accepted as-is when no default domain is configured."""
        graph = FakeGraphClient()
        service = self._wire_service(graph, default_domain="")

        created = service.create_user("alice@example.com")

        self.assertEqual(created["userPrincipalName"], "alice@example.com")
        sent = graph.payloads[0]
        self.assertEqual(sent["mailNickname"], "alice")
        self.assertEqual(sent["userPrincipalName"], "alice@example.com")

    def test_create_user_requires_full_upn_when_default_domain_missing(self):
        """A bare username raises ``ServiceConfigurationError`` without a default domain."""
        service = self._wire_service(FakeGraphClient(), default_domain="")

        self.assertRaises(ServiceConfigurationError, service.create_user, "alice")

    def test_create_user_returns_license_warning_when_not_strict(self):
        """With licensing optional, a missing SKU yields a warning message, not an error."""
        service = self._wire_service(
            FakeGraphClient(skus=[]),
            default_license_sku="ENTERPRISEPACK",
            license_assignment_required=False,
        )

        outcome = service.create_user("alice")

        self.assertFalse(outcome["licenseAssigned"])
        self.assertEqual(outcome["licenseMessage"], "未找到许可证 SKU: ENTERPRISEPACK")

    def test_create_user_rolls_back_when_license_required(self):
        """With licensing required, a missing SKU raises and the created user is deleted."""
        graph = FakeGraphClient(skus=[])
        service = self._wire_service(
            graph,
            default_license_sku="ENTERPRISEPACK",
            license_assignment_required=True,
        )

        with self.assertRaises(ServiceOperationError) as caught:
            service.create_user("alice")

        self.assertIn("已回滚删除账号 alice@example.com", str(caught.exception))
        # "user-1" is the id the fake client hands back from create_user.
        self.assertEqual(graph.deleted_users, ["user-1"])
|
|
|
|
|
|
class ModelSerializationTests(unittest.TestCase):
    """Checks on how model objects render themselves through ``to_dict``."""

    def test_redemption_code_serializes_datetimes_as_utc_z(self):
        """Both a naive and a UTC-aware datetime are rendered with a trailing ``Z``."""
        naive_created = datetime(2026, 3, 31, 12, 0, 0)
        aware_used = datetime(2026, 3, 31, 13, 0, 0, tzinfo=timezone.utc)
        record = RedemptionCode(code="CODE-UTC", created_at=naive_created, used_at=aware_used)

        serialized = record.to_dict()

        self.assertEqual(serialized["createdAt"], "2026-03-31T12:00:00Z")
        self.assertEqual(serialized["usedAt"], "2026-03-31T13:00:00Z")
|
|
|
|
|
|
class SettingsTests(unittest.TestCase):
    """Tests for ``load_settings`` reading its configuration from the environment."""

    def test_container_database_url_is_remapped_locally(self):
        """A database URL under ``/app/data`` is remapped and a warning is recorded."""
        container_env = dict(
            CLIENT_ID="client-id",
            TENANT_ID="tenant-id",
            CLIENT_SECRET="client-secret",
            DEFAULT_PASSWORD="TempPassw0rd!",
            DATABASE_URL="sqlite:////app/data/redemption.db",
        )

        # Existing environment variables are kept; only these keys are overlaid.
        with patch.dict(os.environ, container_env, clear=False):
            loaded = load_settings()

        remapped_suffix = "/office365-self-service/data/redemption.db"
        self.assertTrue(loaded.database_url.endswith(remapped_suffix))
        combined_warnings = " ".join(loaded.warnings)
        self.assertIn("已自动映射到本地路径", combined_warnings)
|
|
|
|
|
|
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|