oauth.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
"""
OAuth 2.0 authentication routes.

Handles the SSO login flow, local JWT issuance, token refresh and
current-user lookup.
"""
  5. import logging
  6. import httpx
  7. from fastapi import APIRouter, HTTPException, Request, status
  8. from pydantic import BaseModel
  9. from config import settings
  10. from services.oauth_service import OAuthService
  11. from services.auth_service import AuthService
  12. from services import jwt_service
  13. logger = logging.getLogger(__name__)
  14. router = APIRouter(prefix="/api/oauth", tags=["oauth"])
class OAuthLoginResponse(BaseModel):
    """Response of GET /login: where to send the browser, plus the state value."""
    # SSO authorization URL built by OAuthService.get_authorization_url(state).
    authorization_url: str
    # Value from OAuthService.generate_state(); this module never checks it
    # again, so the caller presumably must verify it on callback — confirm.
    state: str
class SSOTokenResponse(BaseModel):
    """SSO token response carrying locally issued JWTs.

    NOTE(review): not referenced in the visible part of this module and
    field-for-field identical to ExchangeCodeResponse — check for other users
    before consolidating.
    """
    # Locally signed access token.
    token: str
    # Locally signed refresh token.
    refresh_token: str
    token_type: str = "bearer"
    # Public user fields (untyped dict).
    user: dict
class ExchangeCodeRequest(BaseModel):
    """Request body of POST /exchange-code."""
    # Authorization code the frontend received from the SSO callback.
    code: str
class ExchangeCodeResponse(BaseModel):
    """Response of POST /exchange-code: locally issued JWTs plus the user."""
    # Access token from jwt_service.create_access_token.
    token: str
    # Refresh token from jwt_service.create_refresh_token.
    refresh_token: str
    token_type: str = "bearer"
    # exchange_code fills this with id, username, email, role and created_at (as str).
    user: dict
class RefreshRequest(BaseModel):
    """Request body of POST /refresh."""
    # Local refresh JWT; oauth_refresh requires its "type" claim to be "refresh".
    refresh_token: str
class UserResponse(BaseModel):
    """Public view of a user, returned by GET /me."""
    id: str
    username: str
    email: str
    role: str
    # Callers pass str(user.created_at); the exact timestamp format depends on the model — confirm.
    created_at: str
  44. @router.get("/login", response_model=OAuthLoginResponse)
  45. async def oauth_login():
  46. """
  47. 启动 OAuth 登录流程。
  48. 生成授权 URL 和 state 参数。
  49. """
  50. state = OAuthService.generate_state()
  51. authorization_url = OAuthService.get_authorization_url(state)
  52. return OAuthLoginResponse(
  53. authorization_url=authorization_url,
  54. state=state,
  55. )
  56. @router.post("/exchange-code", response_model=ExchangeCodeResponse)
  57. async def exchange_code(request_body: ExchangeCodeRequest):
  58. """
  59. 授权码交换端点(前端调用)。
  60. 前端从 SSO 回调拿到 code 后,调用此接口换取本地 JWT。
  61. """
  62. logger.info(f"Exchange code received: code={request_body.code[:10]}...")
  63. # 1. 用授权码换取 SSO access_token
  64. token_data = await OAuthService.exchange_code_for_token(request_body.code)
  65. sso_access_token = token_data.get("access_token")
  66. if not sso_access_token:
  67. raise HTTPException(status_code=400, detail="未能获取访问令牌")
  68. # 2. 获取完整用户信息(含角色)
  69. user_info = await OAuthService.get_user_profile(sso_access_token)
  70. # 3. 同步用户到本地数据库
  71. user = OAuthService.sync_user_from_oauth(user_info)
  72. # 4. 签发本地 JWT
  73. access_token = jwt_service.create_access_token(
  74. user_id=user.id,
  75. username=user.username,
  76. email=user.email,
  77. role=user.role,
  78. )
  79. refresh_token = jwt_service.create_refresh_token(user_id=user.id)
  80. logger.info(f"Code exchange successful for user: {user.username}")
  81. return ExchangeCodeResponse(
  82. token=access_token,
  83. refresh_token=refresh_token,
  84. token_type="bearer",
  85. user={
  86. "id": user.id,
  87. "username": user.username,
  88. "email": user.email,
  89. "role": user.role,
  90. "created_at": str(user.created_at),
  91. },
  92. )
  93. @router.post("/refresh")
  94. async def oauth_refresh(request_body: RefreshRequest):
  95. """
  96. Token 刷新端点。
  97. 验证 refresh_token(本地 JWT),签发新的 access_token + refresh_token。
  98. """
  99. logger.info("Token refresh requested")
  100. try:
  101. payload = jwt_service.verify_token(request_body.refresh_token)
  102. if payload.get("type") != "refresh":
  103. raise HTTPException(
  104. status_code=status.HTTP_401_UNAUTHORIZED,
  105. detail="无效的 refresh token",
  106. )
  107. user_id = payload.get("sub")
  108. # 从数据库获取用户信息
  109. user = AuthService.get_current_user(user_id)
  110. if not user:
  111. raise HTTPException(
  112. status_code=status.HTTP_401_UNAUTHORIZED,
  113. detail="用户不存在",
  114. )
  115. new_access_token = jwt_service.create_access_token(
  116. user_id=user.id,
  117. username=user.username,
  118. email=user.email,
  119. role=user.role,
  120. )
  121. new_refresh_token = jwt_service.create_refresh_token(user_id=user.id)
  122. return {
  123. "token": new_access_token,
  124. "refresh_token": new_refresh_token,
  125. "token_type": "bearer",
  126. }
  127. except HTTPException:
  128. raise
  129. except Exception as e:
  130. logger.error(f"Token refresh unexpected error: {e}", exc_info=True)
  131. raise HTTPException(
  132. status_code=400,
  133. detail=f"Token 刷新失败: {str(e)}",
  134. )
  135. @router.post("/logout")
  136. async def oauth_logout(request: Request):
  137. """
  138. 登出端点。
  139. 可选通知 SSO 注销,返回统一认证平台登录页面 URL 供前端跳转。
  140. """
  141. auth_header = request.headers.get("Authorization", "")
  142. token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else None
  143. if token and settings.SSO_REVOKE_ENDPOINT:
  144. # 通知 SSO 注销
  145. try:
  146. async with httpx.AsyncClient(timeout=5.0) as client:
  147. await client.post(
  148. f"{settings.SSO_BASE_URL}{settings.SSO_REVOKE_ENDPOINT}",
  149. headers={"Authorization": f"Bearer {token}"},
  150. )
  151. except Exception as e:
  152. logger.warning(f"SSO revoke failed (non-critical): {e}")
  153. return {
  154. "message": "登出成功",
  155. "logout_url": settings.SSO_LOGOUT_REDIRECT_URL,
  156. }
  157. @router.get("/me", response_model=UserResponse)
  158. async def get_current_user(request: Request):
  159. """
  160. 获取当前认证用户信息。
  161. 用户信息由 AuthMiddleware 从 JWT 验证后填充到 request.state。
  162. """
  163. user_data = getattr(request.state, "user", None)
  164. if not user_data:
  165. raise HTTPException(
  166. status_code=status.HTTP_401_UNAUTHORIZED,
  167. detail="未认证",
  168. )
  169. user = AuthService.get_current_user(user_data["id"])
  170. return UserResponse(
  171. id=user.id,
  172. username=user.username,
  173. email=user.email,
  174. role=user.role,
  175. created_at=str(user.created_at),
  176. )
  177. @router.get("/status")
  178. async def oauth_status():
  179. """获取 SSO 配置状态"""
  180. return {
  181. "enabled": settings.SSO_ENABLED,
  182. "provider": "SSO" if settings.SSO_ENABLED else None,
  183. "base_url": settings.SSO_BASE_URL if settings.SSO_ENABLED else None,
  184. }