sso_view.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # coding=utf-8
  2. """
  3. SSO 统一认证视图
  4. 对接 LQAI-middle-platform 统一认证平台
  5. """
  6. import uuid_utils.compat as uuid
  7. from django.utils.translation import gettext as _
  8. from rest_framework.request import Request
  9. from rest_framework.views import APIView
  10. from common.auth import TokenAuth
  11. from common.exception.app_exception import AppApiException
  12. from common.result import result
  13. from sso.services.sso_client import get_sso_client
  14. from users.models import User
  15. from common.utils.logger import maxkb_logger
  16. # SSO 角色 code → 本地角色映射
  17. SSO_ROLE_MAP = {
  18. 'super_admin': 'ADMIN',
  19. 'ws_admin': 'WORKSPACE_MANAGE',
  20. 'user': 'USER',
  21. }
  22. class SSOView(APIView):
  23. """SSO 统一认证"""
  24. authentication_classes = [TokenAuth]
  25. class AuthorizeUrl(APIView):
  26. """获取 SSO 授权 URL"""
  27. authentication_classes = []
  28. def get(self, request: Request):
  29. redirect = request.query_params.get('redirect', 'false') == 'true'
  30. state = request.query_params.get('state', str(uuid.uuid4()))
  31. client = get_sso_client()
  32. authorize_url = client.get_authorize_url(state=state)
  33. if redirect:
  34. from django.http import HttpResponseRedirect
  35. return HttpResponseRedirect(authorize_url)
  36. return result.success({
  37. 'authorize_url': authorize_url,
  38. 'state': state,
  39. })
  40. class ExchangeCode(APIView):
  41. """授权码交换(核心免登接口)
  42. POST /api/oauth/exchange-code
  43. """
  44. authentication_classes = []
  45. def post(self, request: Request):
  46. code = request.data.get('code', '')
  47. if not code:
  48. raise AppApiException(400, _('缺少授权码'))
  49. client = get_sso_client()
  50. # Step 1: 用 code 换 SSO access_token
  51. token_data = client.exchange_code(code)
  52. sso_access_token = token_data.get('access_token', '')
  53. if not sso_access_token:
  54. raise AppApiException(400, _('SSO token exchange failed'))
  55. # Step 2: 获取用户信息(含角色)
  56. userinfo = client.get_userinfo(sso_access_token)
  57. # Step 3: 同步用户到本地数据库
  58. user = SSOView.ExchangeCode._sync_user(userinfo)
  59. # Step 4: 同步角色
  60. roles = userinfo.get('roles', [])
  61. SSOView.ExchangeCode._sync_roles(user, roles)
  62. # Step 5: 签发本地 JWT(与常规登录保持一致:包含 id、type,并缓存用户)
  63. from common.constants.authentication_type import AuthenticationType
  64. from common.constants.cache_version import Cache_Version
  65. from django.core import signing
  66. from django.core.cache import cache
  67. from maxkb.const import CONFIG
  68. token = signing.dumps({
  69. 'username': user.username,
  70. 'id': str(user.id),
  71. 'email': user.email,
  72. 'type': AuthenticationType.SYSTEM_USER.value,
  73. })
  74. refresh_token = signing.dumps({'user_id': str(user.id), 'type': 'refresh'})
  75. version, get_key = Cache_Version.TOKEN.value
  76. timeout = CONFIG.get_session_timeout()
  77. cache.set(get_key(token), user, timeout=timeout, version=version)
  78. return result.success({
  79. 'token': token,
  80. 'refresh_token': refresh_token,
  81. 'user': {
  82. 'id': str(user.id),
  83. 'username': user.username,
  84. 'email': user.email,
  85. 'phone': user.phone,
  86. 'is_active': user.is_active,
  87. 'roles': [user.role],
  88. },
  89. })
  90. @staticmethod
  91. def _sync_user(userinfo: dict) -> User:
  92. """同步 SSO 用户到本地数据库"""
  93. username = userinfo.get('username', '')
  94. email = userinfo.get('email', '')
  95. real_name = userinfo.get('real_name', '')
  96. if not username:
  97. raise AppApiException(400, _('SSO user info missing username'))
  98. user, created = User.objects.get_or_create(
  99. username=username,
  100. defaults={
  101. 'email': email,
  102. 'nick_name': real_name or username,
  103. 'is_active': True,
  104. }
  105. )
  106. if not created:
  107. user.email = email
  108. if real_name:
  109. user.nick_name = real_name
  110. user.save()
  111. return user
  112. @staticmethod
  113. def _sync_roles(user: User, sso_roles: list):
  114. """同步 SSO 角色到本地用户角色字段"""
  115. if not sso_roles:
  116. return
  117. # 取第一个匹配的本地角色
  118. for role_info in sso_roles:
  119. code = role_info.get('code', '')
  120. local_role = SSO_ROLE_MAP.get(code)
  121. if local_role:
  122. user.role = local_role
  123. user.save()
  124. return
  125. # 没有匹配的角色,默认普通用户
  126. if not user.role:
  127. user.role = 'USER'
  128. user.save()