oauth_exchange_view.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. """
  2. SSO 免登授权码交换端点
  3. 前端从 SSO 回调拿到 code 后,调用此接口换取本地 JWT。
  4. """
  5. import sys
  6. import os
  7. # 添加src目录到Python路径
  8. sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
  9. sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..'))
  10. import logging
  11. import httpx
  12. from fastapi import APIRouter, Depends
  13. from sqlalchemy.ext.asyncio import AsyncSession
  14. from sqlalchemy import select
  15. from pydantic import BaseModel
  16. from app.base import get_db
  17. from app.schemas.base import ResponseSchema
  18. from app.core.config import config_handler
  19. logger = logging.getLogger(__name__)
  20. router = APIRouter(prefix="/api/oauth", tags=["SSO免登"])
  21. class ExchangeCodeRequest(BaseModel):
  22. code: str
  23. @router.post("/exchange-code", response_model=ResponseSchema)
  24. async def exchange_code(request_data: ExchangeCodeRequest, db: AsyncSession = Depends(get_db)):
  25. """
  26. SSO 免登授权码交换端点。
  27. 前端从 SSO 回调拿到 code 后,调用此接口换取本地 JWT。
  28. 流程:
  29. 1. 用 code 调 SSO /oauth/token 换取 SSO access_token
  30. 2. 用 SSO access_token 调 SSO /oauth/userinfo 获取用户信息+角色
  31. 3. 同步用户到本地数据库
  32. 4. 签发本地 JWT
  33. 5. 返回 { token, refresh_token, user }
  34. """
  35. try:
  36. logger.info(f"[exchange-code] ========== 收到授权码交换请求 ==========")
  37. logger.info(f"[exchange-code] code={request_data.code[:10]}...")
  38. # 读取 SSO 配置
  39. sso_base_url = config_handler.get("admin_sso", "SSO_BASE_URL", "http://localhost:8200")
  40. sso_client_id = config_handler.get("admin_sso", "SSO_CLIENT_ID", "lqadmin")
  41. sso_client_secret = config_handler.get("admin_sso", "SSO_CLIENT_SECRET", "")
  42. sso_redirect_uri = config_handler.get("admin_sso", "REDIRECT_URI", "http://localhost:3000/auth/callback")
  43. # ========== 步骤1:用 code 换 SSO access_token ==========
  44. logger.info(f"[exchange-code] 步骤1: 用 code 换 SSO access_token")
  45. async with httpx.AsyncClient(timeout=10.0) as client:
  46. token_resp = await client.post(
  47. f"{sso_base_url}/oauth/token",
  48. data={
  49. "grant_type": "authorization_code",
  50. "code": request_data.code,
  51. "redirect_uri": sso_redirect_uri,
  52. "client_id": sso_client_id,
  53. "client_secret": sso_client_secret,
  54. },
  55. )
  56. token_data = token_resp.json()
  57. logger.info(f"[exchange-code] SSO token 响应: status={token_resp.status_code}")
  58. sso_access_token = token_data.get("access_token")
  59. if not sso_access_token:
  60. error_desc = token_data.get("error_description", token_data.get("error", "未知错误"))
  61. logger.warning(f"[exchange-code] 未获取到 SSO access_token: {error_desc}")
  62. return ResponseSchema(code="400001", message=f"SSO 授权码无效: {error_desc}", data=None)
  63. logger.info(f"[exchange-code] SSO access_token 获取成功")
  64. # ========== 步骤2:获取用户信息 ==========
  65. logger.info(f"[exchange-code] 步骤2: 获取 SSO 用户信息")
  66. async with httpx.AsyncClient(timeout=10.0) as client:
  67. userinfo_resp = await client.get(
  68. f"{sso_base_url}/oauth/userinfo",
  69. headers={"Authorization": f"Bearer {sso_access_token}"},
  70. )
  71. sso_user_info = userinfo_resp.json()
  72. logger.info(f"[exchange-code] SSO userinfo 响应: {sso_user_info}")
  73. if "sub" not in sso_user_info:
  74. logger.warning(f"[exchange-code] SSO userinfo 缺少 sub 字段")
  75. return ResponseSchema(code="400002", message="SSO 用户信息格式异常", data=None)
  76. sso_user_id = sso_user_info.get("sub")
  77. sso_username = sso_user_info.get("username", sso_user_id)
  78. sso_email = sso_user_info.get("email", "")
  79. sso_roles = sso_user_info.get("roles", [])
  80. logger.info(f"[exchange-code] SSO 用户: id={sso_user_id}, username={sso_username}, roles={sso_roles}")
  81. # ========== 步骤3:同步用户到本地数据库 ==========
  82. logger.info(f"[exchange-code] 步骤3: 同步用户到本地DB")
  83. from app.models.user import User, UserProfile, Role, UserRole
  84. # 查找用户(通过 email 或 username)
  85. stmt = select(User).where(User.email == sso_email)
  86. result = await db.execute(stmt)
  87. user = result.scalar_one_or_none()
  88. if not user and sso_username:
  89. stmt = select(User).where(User.username == sso_username)
  90. result = await db.execute(stmt)
  91. user = result.scalar_one_or_none()
  92. if user:
  93. logger.info(f"[exchange-code] 更新已有用户: id={user.id}, username={user.username}")
  94. else:
  95. import bcrypt
  96. logger.info(f"[exchange-code] 创建新用户: username={sso_username}")
  97. default_password = "SsoLogin@123"
  98. hashed_password = bcrypt.hashpw(default_password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
  99. user = User(
  100. username=sso_username,
  101. email=sso_email or f"{sso_username}@sso.local",
  102. password_hash=hashed_password,
  103. is_active=True,
  104. is_superuser=False,
  105. )
  106. db.add(user)
  107. await db.flush()
  108. logger.info(f"[exchange-code] 新用户创建成功: id={user.id}")
  109. # 创建用户档案
  110. profile = UserProfile(
  111. user_id=user.id,
  112. real_name=sso_user_info.get("real_name", sso_username),
  113. company=sso_user_info.get("company", ""),
  114. department=sso_user_info.get("department", ""),
  115. position=sso_user_info.get("position", ""),
  116. )
  117. db.add(profile)
  118. # 处理角色映射
  119. logger.info(f"[exchange-code] 处理角色映射: sso_roles={sso_roles}")
  120. SSO_ROLE_MAPPING = {
  121. "ann_sys_admin": "admin",
  122. "ann_operator": "annotator",
  123. "ann_viewer": "viewer",
  124. "标注管理员": "admin",
  125. "标注员": "annotator",
  126. "查看者": "viewer",
  127. "super_admin": "admin",
  128. "admin": "admin",
  129. }
  130. # 从 roles 列表中提取 code 和 name 进行映射
  131. local_role_codes = []
  132. for role_item in sso_roles:
  133. if isinstance(role_item, dict):
  134. code = role_item.get("code", "")
  135. name = role_item.get("name", "")
  136. else:
  137. code = str(role_item)
  138. name = code
  139. mapped = SSO_ROLE_MAPPING.get(code) or SSO_ROLE_MAPPING.get(name)
  140. if mapped and mapped not in local_role_codes:
  141. local_role_codes.append(mapped)
  142. logger.info(f"[exchange-code] 映射后的本地角色: {local_role_codes}")
  143. # 查找并关联数据库角色
  144. if local_role_codes:
  145. stmt = select(Role).where(Role.code.in_(local_role_codes))
  146. result = await db.execute(stmt)
  147. db_roles = result.fetchall()
  148. db_role_list = [r[0] for r in db_roles]
  149. logger.info(f"[exchange-code] 找到数据库角色: {[r.code for r in db_role_list]}")
  150. # 清除用户现有角色
  151. stmt = select(UserRole).where(UserRole.user_id == user.id)
  152. result = await db.execute(stmt)
  153. existing_roles = result.fetchall()
  154. for er in existing_roles:
  155. await db.delete(er[0])
  156. # 添加新角色
  157. for db_role in db_role_list:
  158. user_role = UserRole(user_id=user.id, role_id=db_role.id)
  159. db.add(user_role)
  160. # 设置超级管理员标志
  161. if "admin" in local_role_codes:
  162. user.is_superuser = True
  163. await db.commit()
  164. await db.refresh(user)
  165. # 重新加载用户角色
  166. stmt = select(Role).join(UserRole).where(UserRole.user_id == user.id)
  167. result = await db.execute(stmt)
  168. user_roles = [r[0].code for r in result.fetchall()]
  169. logger.info(f"[exchange-code] 用户角色已更新: {user_roles}")
  170. # ========== 步骤4:签发本地 JWT ==========
  171. logger.info(f"[exchange-code] 步骤4: 签发本地 JWT")
  172. from app.services.jwt_token import create_access_token
  173. from app.utils import redis_token_manager as rtm
  174. access_payload = {
  175. "sub": str(user.id),
  176. "username": user.username,
  177. "email": user.email or "",
  178. "is_superuser": user.is_superuser,
  179. "roles": user_roles,
  180. "client_id": sso_client_id,
  181. }
  182. access_token = create_access_token(access_payload)
  183. refresh_payload = {
  184. "sub": str(user.id),
  185. "type": "refresh",
  186. }
  187. refresh_token = create_access_token(refresh_payload)
  188. # 存储 token 到 Redis(admin 通道)
  189. rtm.store_access_token(access_token, access_payload)
  190. # 同时存储 OAuth 通道 key,使 /oauth/userinfo 端点能验证该 token
  191. rtm.store_oauth_access_token(access_token, sso_client_id, str(user.id))
  192. rtm.store_refresh_token(refresh_token, str(user.id))
  193. # ========== 步骤5:返回结果 ==========
  194. user_info = {
  195. "id": str(user.id),
  196. "username": user.username,
  197. "email": user.email or "",
  198. "phone": user.phone if hasattr(user, "phone") else None,
  199. "is_superuser": user.is_superuser,
  200. "roles": user_roles,
  201. }
  202. logger.info(f"[exchange-code] ========== 授权码交换成功: user={user.username} ==========")
  203. return ResponseSchema(
  204. code="000000",
  205. message="登录成功",
  206. data={
  207. "token": access_token,
  208. "refresh_token": refresh_token,
  209. "token_type": "bearer",
  210. "user": user_info,
  211. }
  212. )
  213. except Exception as e:
  214. logger.error(f"[exchange-code] ========== 授权码交换错误 ==========")
  215. logger.error(f"[exchange-code] {type(e).__name__}: {str(e)}", exc_info=True)
  216. return ResponseSchema(
  217. code="500001",
  218. message=f"服务器内部错误: {str(e)}",
  219. data=None
  220. )