sso_router.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. """
  2. OAuth2 统一认证平台 SSO 路由
  3. """
  4. import logging
  5. import uuid
  6. import os
  7. from typing import Optional
  8. import httpx
  9. from fastapi import APIRouter, Depends, HTTPException
  10. from pydantic import BaseModel
  11. from sqlalchemy.orm import Session
  12. from app.database import get_db
  13. from app.models.user import User
  14. from app.services.auth_service import AuthService
  15. from app.services.user_service import UserService
  16. from app.schemas.user_schema import UserCreate
  17. from app.services import local_config_service
  18. logger = logging.getLogger(__name__)
  19. router = APIRouter(tags=["OAuth2 SSO"])
  20. SSO_BASE_URL = os.getenv("SSO_BASE_URL", "http://192.168.92.61:8200")
  21. SSO_CLIENT_ID = os.getenv("SSO_CLIENT_ID", "")
  22. SSO_CLIENT_SECRET = os.getenv("SSO_CLIENT_SECRET", "")
  23. SSO_REDIRECT_URI = os.getenv("SSO_REDIRECT_URI", "http://localhost:3000/#/auth/callback")
  24. SSO_SCOPE = os.getenv("SSO_SCOPE", "profile email")
  25. SSO_LOGOUT_REDIRECT_URL = os.getenv("SSO_LOGOUT_REDIRECT_URL", "http://192.168.92.61:9200/login")
  26. def _get_sso_config() -> dict:
  27. try:
  28. data = local_config_service.get_all()
  29. sso = data.get("sso", {})
  30. except Exception:
  31. sso = {}
  32. # sso_enable_redirect 控制是否启用SSO重定向
  33. enabled = sso.get("sso_enable_redirect")
  34. if enabled is None:
  35. enabled = True # 默认启用SSO
  36. return {
  37. "base_url": (sso.get("sso_base_url") or SSO_BASE_URL).rstrip("/"),
  38. "client_id": sso.get("sso_client_id") or SSO_CLIENT_ID,
  39. "client_secret": sso.get("sso_client_secret") or SSO_CLIENT_SECRET,
  40. "redirect_uri": sso.get("sso_redirect_uri") or SSO_REDIRECT_URI,
  41. "scope": sso.get("sso_scope") or SSO_SCOPE,
  42. "logout_redirect_url": sso.get("sso_logout_redirect_url") or SSO_LOGOUT_REDIRECT_URL,
  43. "enabled": bool(enabled),
  44. }
class ExchangeCodeRequest(BaseModel):
    """Request body for ``POST /api/oauth/exchange-code``."""

    # Authorization code the frontend received on the SSO callback; it is
    # forwarded as-is to the SSO token endpoint by ``exchange_code``.
    code: str
  47. @router.get("/api/sso/config")
  48. def get_sso_public_config():
  49. """Return SSO config for frontend."""
  50. cfg = _get_sso_config()
  51. authorize_url = (
  52. f"{cfg['base_url']}/oauth/authorize"
  53. f"?response_type=code"
  54. f"&client_id={cfg['client_id']}"
  55. f"&redirect_uri={cfg['redirect_uri']}"
  56. f"&scope={cfg['scope']}"
  57. )
  58. return {
  59. "sso_enabled": cfg["enabled"] and bool(cfg["client_id"]),
  60. "authorize_url": authorize_url,
  61. "logout_redirect_url": cfg["logout_redirect_url"],
  62. }
  63. @router.post("/api/oauth/exchange-code")
  64. async def exchange_code(request: ExchangeCodeRequest, db: Session = Depends(get_db)):
  65. """
  66. OAuth2 authorization code exchange endpoint.
  67. Frontend calls this with the code from SSO callback to get a local JWT.
  68. """
  69. cfg = _get_sso_config()
  70. if not cfg["client_id"] or not cfg["client_secret"]:
  71. raise HTTPException(status_code=500, detail="SSO not configured")
  72. async with httpx.AsyncClient(timeout=15) as client:
  73. # Exchange code for access_token
  74. token_url = f"{cfg['base_url']}/oauth/token"
  75. token_data = {
  76. "grant_type": "authorization_code",
  77. "code": request.code,
  78. "redirect_uri": cfg["redirect_uri"],
  79. "client_id": cfg["client_id"],
  80. "client_secret": cfg["client_secret"],
  81. }
  82. try:
  83. token_resp = await client.post(
  84. token_url,
  85. data=token_data,
  86. headers={"Content-Type": "application/x-www-form-urlencoded"}
  87. )
  88. except httpx.RequestError as e:
  89. logger.error("SSO token exchange failed: %s", e)
  90. raise HTTPException(status_code=502, detail="Cannot connect to SSO server")
  91. if token_resp.status_code != 200:
  92. error_data = {}
  93. try:
  94. error_data = token_resp.json()
  95. except Exception:
  96. pass
  97. error_msg = error_data.get("error_description") or error_data.get("error") or "Invalid code"
  98. raise HTTPException(status_code=401, detail=f"SSO login failed: {error_msg}")
  99. token_result = token_resp.json()
  100. access_token = token_result.get("access_token")
  101. if not access_token:
  102. raise HTTPException(status_code=502, detail="SSO did not return access_token")
  103. # Get user info
  104. userinfo_url = f"{cfg['base_url']}/oauth/userinfo"
  105. try:
  106. userinfo_resp = await client.get(
  107. userinfo_url,
  108. headers={"Authorization": f"Bearer {access_token}"}
  109. )
  110. except httpx.RequestError as e:
  111. logger.error("SSO userinfo failed: %s", e)
  112. raise HTTPException(status_code=502, detail="Failed to get user info")
  113. if userinfo_resp.status_code != 200:
  114. raise HTTPException(status_code=502, detail="Failed to get user info")
  115. userinfo = userinfo_resp.json()
  116. # Sync user to local database
  117. username = userinfo.get("username") or userinfo.get("sub") or ""
  118. if not username:
  119. raise HTTPException(status_code=400, detail="SSO did not return username")
  120. user = db.query(User).filter(User.username == username).first()
  121. if not user:
  122. email = userinfo.get("email") or None
  123. nickname = userinfo.get("real_name") or userinfo.get("username") or username
  124. random_password = uuid.uuid4().hex[:16]
  125. user_service = UserService(db)
  126. user_payload = {"username": username, "password": random_password, "nickname": nickname}
  127. if email:
  128. user_payload["email"] = email
  129. user_create = UserCreate(**user_payload)
  130. user = user_service.create_user(user_create)
  131. logger.info("SSO created new user: %s (id=%s)", username, user.id)
  132. else:
  133. updated = False
  134. if userinfo.get("email") and user.email != userinfo["email"]:
  135. user.email = userinfo["email"]
  136. updated = True
  137. if userinfo.get("real_name") and user.nickname != userinfo["real_name"]:
  138. user.nickname = userinfo["real_name"]
  139. updated = True
  140. if updated:
  141. db.commit()
  142. # Issue local JWT
  143. local_token = AuthService.create_access_token(user.id)
  144. return {
  145. "code": 200,
  146. "message": "success",
  147. "data": {
  148. "token": local_token,
  149. "user": {
  150. "id": user.id,
  151. "username": user.username,
  152. "nickname": user.nickname,
  153. "email": user.email,
  154. "phone": user.phone,
  155. "avatar": user.avatar,
  156. }
  157. }
  158. }