auth.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import logging
  2. import uuid
  3. from datetime import datetime, timedelta, timezone
  4. from urllib.parse import urlencode
  5. from fastapi import APIRouter, Depends, HTTPException, Query
  6. from fastapi.responses import RedirectResponse
  7. from pydantic import BaseModel
  8. from sqlalchemy import select
  9. from app.config import get_settings
  10. from app.core.auth import get_current_user
  11. from app.core.db import RefreshTokenModel, UserModel, async_session
  12. from app.core.security import create_access_token, create_refresh_token
  13. from app.core.sso_client import exchange_code_for_token, fetch_sso_userinfo
  14. router = APIRouter()
  15. settings = get_settings()
  16. logger = logging.getLogger(__name__)
class CodeExchangeRequest(BaseModel):
    """Request body for ``POST /api/oauth/exchange-code``."""

    # OAuth authorization code obtained from the SSO authorize redirect; it is
    # passed to exchange_code_for_token. An empty string passes validation and
    # is rejected by the handler itself.
    code: str
class RefreshRequest(BaseModel):
    """Request body for ``POST /api/v1/auth/refresh``."""

    # Refresh token previously issued by this service; matched against
    # RefreshTokenModel.token by the refresh endpoint.
    refresh_token: str
class LogoutRequest(BaseModel):
    """Request body for ``POST /api/v1/auth/logout``."""

    # Local access token of the session being closed.
    # NOTE(review): the logout handler in this module never reads this field,
    # so the access token is not revoked — confirm whether that is intended.
    token: str
    # Refresh token to mark as revoked.
    refresh_token: str
  24. async def _sync_user(sso_info: dict) -> UserModel:
  25. username = sso_info.get("username", sso_info.get("sub", "unknown"))
  26. role_codes = [r.get("code", "") for r in sso_info.get("roles", [])]
  27. async with async_session() as session:
  28. result = await session.execute(select(UserModel).where(UserModel.username == username))
  29. user = result.scalar_one_or_none()
  30. if not user:
  31. user = UserModel(
  32. id=str(uuid.uuid4()),
  33. username=username,
  34. email=sso_info.get("email"),
  35. real_name=sso_info.get("real_name"),
  36. avatar_url=sso_info.get("avatar_url"),
  37. company=sso_info.get("company"),
  38. department=sso_info.get("department"),
  39. position=sso_info.get("position"),
  40. roles=role_codes,
  41. is_active=1,
  42. )
  43. session.add(user)
  44. else:
  45. user.roles = role_codes
  46. user.email = sso_info.get("email", user.email)
  47. user.updated_at = datetime.now(timezone.utc)
  48. await session.commit()
  49. await session.refresh(user)
  50. return user
  51. @router.post("/api/oauth/exchange-code")
  52. async def exchange_code(req: CodeExchangeRequest):
  53. if not req.code:
  54. return {"code": "100001", "message": "缺少授权码", "data": None}
  55. logger.info("[SSO] exchange_code start, code=%s", req.code[:10])
  56. logger.info("[SSO] sso_base_url=%s", settings.sso_base_url)
  57. logger.info("[SSO] client_id=%s", settings.sso_client_id)
  58. logger.info("[SSO] redirect_uri=%s", settings.sso_redirect_uri)
  59. try:
  60. token_resp = await exchange_code_for_token(req.code)
  61. logger.info("[SSO] token response: %s", token_resp)
  62. sso_access_token = token_resp.get("access_token")
  63. if not sso_access_token:
  64. logger.error("[SSO] no access_token in response: %s", token_resp)
  65. raise HTTPException(status_code=500, detail="登录失败: 获取令牌失败")
  66. sso_userinfo = await fetch_sso_userinfo(sso_access_token)
  67. logger.info("[SSO] userinfo: %s", sso_userinfo)
  68. if not sso_userinfo.get("username") and not sso_userinfo.get("sub"):
  69. raise HTTPException(status_code=500, detail="登录失败: 用户信息格式异常")
  70. user = await _sync_user(sso_userinfo)
  71. local_token = create_access_token(
  72. user_id=user.id, username=user.username, roles=user.roles or [],
  73. )
  74. refresh_token_str = create_refresh_token()
  75. expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.jwt_refresh_expire_hours)
  76. async with async_session() as session:
  77. rt = RefreshTokenModel(
  78. id=str(uuid.uuid4()),
  79. user_id=user.id,
  80. token=refresh_token_str,
  81. expires_at=expires_at,
  82. )
  83. session.add(rt)
  84. await session.commit()
  85. return {
  86. "code": "000000",
  87. "message": "登录成功",
  88. "data": {
  89. "token": local_token,
  90. "refresh_token": refresh_token_str,
  91. "user": {
  92. "id": user.id,
  93. "username": user.username,
  94. "email": user.email,
  95. "phone": None,
  96. "is_superuser": bool(user.is_superuser),
  97. "is_active": bool(user.is_active),
  98. "roles": user.roles,
  99. },
  100. },
  101. }
  102. except HTTPException:
  103. raise
  104. except Exception as e:
  105. import traceback
  106. logger.error("[SSO] exchange_code failed: %s", traceback.format_exc())
  107. raise HTTPException(status_code=500, detail=f"登录失败: {str(e)}")
  108. @router.get("/auth/sso/authorize")
  109. async def sso_authorize(redirect: bool = Query(False)):
  110. params = urlencode({
  111. "response_type": "code",
  112. "client_id": settings.sso_client_id,
  113. "redirect_uri": settings.sso_redirect_uri,
  114. "scope": settings.sso_scope,
  115. })
  116. authorize_url = f"{settings.sso_base_url}/oauth/authorize?{params}"
  117. if redirect:
  118. return RedirectResponse(url=authorize_url)
  119. return {"code": "000000", "message": "获取授权URL成功", "data": {"authorize_url": authorize_url}}
  120. @router.post("/api/v1/auth/refresh")
  121. async def refresh_token_endpoint(req: RefreshRequest):
  122. async with async_session() as session:
  123. result = await session.execute(
  124. select(RefreshTokenModel).where(
  125. RefreshTokenModel.token == req.refresh_token,
  126. RefreshTokenModel.revoked == 0,
  127. RefreshTokenModel.expires_at > datetime.now(timezone.utc),
  128. )
  129. )
  130. rt = result.scalar_one_or_none()
  131. if not rt:
  132. raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
  133. result = await session.execute(select(UserModel).where(UserModel.id == rt.user_id))
  134. user = result.scalar_one_or_none()
  135. if not user or not user.is_active:
  136. raise HTTPException(status_code=401, detail="User not found")
  137. rt.revoked = 1
  138. new_token_str = create_refresh_token()
  139. new_expires = datetime.now(timezone.utc) + timedelta(hours=settings.jwt_refresh_expire_hours)
  140. new_rt = RefreshTokenModel(
  141. id=str(uuid.uuid4()), user_id=user.id, token=new_token_str, expires_at=new_expires,
  142. )
  143. session.add(new_rt)
  144. await session.commit()
  145. new_access = create_access_token(
  146. user_id=user.id, username=user.username, roles=user.roles or [],
  147. )
  148. return {
  149. "code": "000000",
  150. "message": "刷新成功",
  151. "data": {"token": new_access, "refresh_token": new_token_str},
  152. }
  153. @router.post("/api/v1/auth/logout")
  154. async def logout(req: LogoutRequest, current_user: dict = Depends(get_current_user)):
  155. async with async_session() as session:
  156. result = await session.execute(
  157. select(RefreshTokenModel).where(RefreshTokenModel.token == req.refresh_token)
  158. )
  159. rt = result.scalar_one_or_none()
  160. if rt:
  161. rt.revoked = 1
  162. await session.commit()
  163. return {
  164. "code": "000000",
  165. "message": "登出成功",
  166. "data": {"sso_logout_url": settings.sso_logout_redirect_url},
  167. }
  168. @router.get("/api/v1/auth/userinfo")
  169. async def get_userinfo(current_user: dict = Depends(get_current_user)):
  170. user_id = current_user.get("sub")
  171. async with async_session() as session:
  172. result = await session.execute(select(UserModel).where(UserModel.id == user_id))
  173. user = result.scalar_one_or_none()
  174. if not user:
  175. raise HTTPException(status_code=404, detail="User not found")
  176. return {
  177. "code": "000000",
  178. "data": {
  179. "id": user.id,
  180. "username": user.username,
  181. "email": user.email,
  182. "real_name": user.real_name,
  183. "roles": user.roles,
  184. "avatar_url": user.avatar_url,
  185. "permissions": [],
  186. },
  187. }
  188. @router.get("/api/v1/auth/me")
  189. async def get_me(current_user: dict = Depends(get_current_user)):
  190. return await get_userinfo(current_user)