oauth.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
"""
OAuth 2.0 认证路由 (OAuth 2.0 authentication routes).

处理 OAuth 登录流程 (handles the OAuth login flow).
"""
  5. from fastapi import APIRouter, HTTPException, Query
  6. from fastapi.responses import RedirectResponse
  7. from pydantic import BaseModel
  8. from typing import Optional
  9. from config import settings
  10. from services.oauth_service import OAuthService
  11. from services.jwt_service import JWTService
  12. from schemas.auth import TokenResponse, UserResponse
# Router holding the OAuth endpoints defined below; every route path is
# prefixed with /api/oauth and tagged "oauth".
router = APIRouter(prefix="/api/oauth", tags=["oauth"])
  14. class OAuthLoginResponse(BaseModel):
  15. """OAuth 登录响应"""
  16. authorization_url: str
  17. state: str
  18. @router.get("/login", response_model=OAuthLoginResponse)
  19. async def oauth_login():
  20. """
  21. 启动 OAuth 登录流程
  22. 生成授权 URL 和 state 参数,前端需要保存 state 并重定向到授权 URL
  23. Returns:
  24. 包含授权 URL 和 state 的响应
  25. """
  26. if not settings.OAUTH_ENABLED:
  27. raise HTTPException(status_code=400, detail="OAuth 登录未启用")
  28. # 生成 state 参数
  29. state = OAuthService.generate_state()
  30. # 构建授权 URL
  31. authorization_url = OAuthService.get_authorization_url(state)
  32. return OAuthLoginResponse(
  33. authorization_url=authorization_url,
  34. state=state
  35. )
  36. @router.get("/callback", response_model=TokenResponse)
  37. async def oauth_callback(
  38. code: str = Query(..., description="OAuth 授权码"),
  39. state: str = Query(..., description="State 参数"),
  40. ):
  41. """
  42. OAuth 回调端点
  43. 处理 OAuth 认证中心的回调,用授权码换取令牌,获取用户信息,
  44. 并创建或更新本地用户记录
  45. Args:
  46. code: OAuth 授权码
  47. state: State 参数(前端需要验证)
  48. Returns:
  49. JWT tokens 和用户信息
  50. """
  51. if not settings.OAUTH_ENABLED:
  52. raise HTTPException(status_code=400, detail="OAuth 登录未启用")
  53. try:
  54. # 1. 用授权码换取访问令牌
  55. token_data = await OAuthService.exchange_code_for_token(code)
  56. access_token = token_data.get("access_token")
  57. if not access_token:
  58. raise HTTPException(status_code=400, detail="未能获取访问令牌")
  59. # 2. 使用访问令牌获取用户信息
  60. oauth_user_info = await OAuthService.get_user_info(access_token)
  61. # 3. 同步用户到本地数据库
  62. user = OAuthService.sync_user_from_oauth(oauth_user_info)
  63. # 4. 生成本地 JWT tokens
  64. user_data = {
  65. "id": user.id,
  66. "username": user.username,
  67. "email": user.email,
  68. "role": user.role
  69. }
  70. jwt_access_token = JWTService.create_access_token(user_data)
  71. jwt_refresh_token = JWTService.create_refresh_token(user_data)
  72. # 5. 返回 tokens 和用户信息
  73. return TokenResponse(
  74. access_token=jwt_access_token,
  75. refresh_token=jwt_refresh_token,
  76. token_type="bearer",
  77. user=UserResponse(
  78. id=user.id,
  79. username=user.username,
  80. email=user.email,
  81. role=user.role,
  82. created_at=user.created_at
  83. )
  84. )
  85. except Exception as e:
  86. raise HTTPException(
  87. status_code=400,
  88. detail=f"OAuth 登录失败: {str(e)}"
  89. )
  90. @router.get("/status")
  91. async def oauth_status():
  92. """
  93. 获取 OAuth 配置状态
  94. Returns:
  95. OAuth 是否启用及相关配置信息
  96. """
  97. return {
  98. "enabled": settings.OAUTH_ENABLED,
  99. "provider": "SSO" if settings.OAUTH_ENABLED else None,
  100. "base_url": settings.OAUTH_BASE_URL if settings.OAUTH_ENABLED else None
  101. }