|
@@ -0,0 +1,215 @@
|
|
|
|
|
+import uuid
|
|
|
|
|
+from datetime import datetime, timedelta, timezone
|
|
|
|
|
+from urllib.parse import urlencode
|
|
|
|
|
+
|
|
|
|
|
+from fastapi import APIRouter, Depends, HTTPException, Query
|
|
|
|
|
+from fastapi.responses import RedirectResponse
|
|
|
|
|
+from pydantic import BaseModel
|
|
|
|
|
+from sqlalchemy import select
|
|
|
|
|
+
|
|
|
|
|
+from app.config import get_settings
|
|
|
|
|
+from app.core.auth import get_current_user
|
|
|
|
|
+from app.core.db import RefreshTokenModel, UserModel, async_session
|
|
|
|
|
+from app.core.security import create_access_token, create_refresh_token
|
|
|
|
|
+from app.core.sso_client import exchange_code_for_token, fetch_sso_userinfo
|
|
|
|
|
+
|
|
|
|
|
+router = APIRouter()
|
|
|
|
|
+settings = get_settings()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class CodeExchangeRequest(BaseModel):
|
|
|
|
|
+ code: str
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class RefreshRequest(BaseModel):
|
|
|
|
|
+ refresh_token: str
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class LogoutRequest(BaseModel):
|
|
|
|
|
+ token: str
|
|
|
|
|
+ refresh_token: str
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def _sync_user(sso_info: dict) -> UserModel:
|
|
|
|
|
+ username = sso_info.get("username", sso_info.get("sub", "unknown"))
|
|
|
|
|
+ role_codes = [r.get("code", "") for r in sso_info.get("roles", [])]
|
|
|
|
|
+
|
|
|
|
|
+ async with async_session() as session:
|
|
|
|
|
+ result = await session.execute(select(UserModel).where(UserModel.username == username))
|
|
|
|
|
+ user = result.scalar_one_or_none()
|
|
|
|
|
+
|
|
|
|
|
+ if not user:
|
|
|
|
|
+ user = UserModel(
|
|
|
|
|
+ id=str(uuid.uuid4()),
|
|
|
|
|
+ username=username,
|
|
|
|
|
+ email=sso_info.get("email"),
|
|
|
|
|
+ real_name=sso_info.get("real_name"),
|
|
|
|
|
+ avatar_url=sso_info.get("avatar_url"),
|
|
|
|
|
+ company=sso_info.get("company"),
|
|
|
|
|
+ department=sso_info.get("department"),
|
|
|
|
|
+ position=sso_info.get("position"),
|
|
|
|
|
+ roles=role_codes,
|
|
|
|
|
+ is_active=1,
|
|
|
|
|
+ )
|
|
|
|
|
+ session.add(user)
|
|
|
|
|
+ else:
|
|
|
|
|
+ user.roles = role_codes
|
|
|
|
|
+ user.email = sso_info.get("email", user.email)
|
|
|
|
|
+ user.updated_at = datetime.now(timezone.utc)
|
|
|
|
|
+
|
|
|
|
|
+ await session.commit()
|
|
|
|
|
+ await session.refresh(user)
|
|
|
|
|
+ return user
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.post("/api/oauth/exchange-code")
|
|
|
|
|
+async def exchange_code(req: CodeExchangeRequest):
|
|
|
|
|
+ if not req.code:
|
|
|
|
|
+ return {"code": "100001", "message": "缺少授权码", "data": None}
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ token_resp = await exchange_code_for_token(req.code)
|
|
|
|
|
+ sso_access_token = token_resp.get("access_token")
|
|
|
|
|
+ if not sso_access_token:
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="登录失败: 获取令牌失败")
|
|
|
|
|
+
|
|
|
|
|
+ sso_userinfo = await fetch_sso_userinfo(sso_access_token)
|
|
|
|
|
+ if not sso_userinfo.get("username") and not sso_userinfo.get("sub"):
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="登录失败: 用户信息格式异常")
|
|
|
|
|
+
|
|
|
|
|
+ user = await _sync_user(sso_userinfo)
|
|
|
|
|
+
|
|
|
|
|
+ local_token = create_access_token(
|
|
|
|
|
+ user_id=user.id, username=user.username, roles=user.roles or [],
|
|
|
|
|
+ )
|
|
|
|
|
+ refresh_token_str = create_refresh_token()
|
|
|
|
|
+ expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.jwt_refresh_expire_hours)
|
|
|
|
|
+
|
|
|
|
|
+ async with async_session() as session:
|
|
|
|
|
+ rt = RefreshTokenModel(
|
|
|
|
|
+ id=str(uuid.uuid4()),
|
|
|
|
|
+ user_id=user.id,
|
|
|
|
|
+ token=refresh_token_str,
|
|
|
|
|
+ expires_at=expires_at,
|
|
|
|
|
+ )
|
|
|
|
|
+ session.add(rt)
|
|
|
|
|
+ await session.commit()
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ "code": "000000",
|
|
|
|
|
+ "message": "登录成功",
|
|
|
|
|
+ "data": {
|
|
|
|
|
+ "token": local_token,
|
|
|
|
|
+ "refresh_token": refresh_token_str,
|
|
|
|
|
+ "user": {
|
|
|
|
|
+ "id": user.id,
|
|
|
|
|
+ "username": user.username,
|
|
|
|
|
+ "email": user.email,
|
|
|
|
|
+ "phone": None,
|
|
|
|
|
+ "is_superuser": bool(user.is_superuser),
|
|
|
|
|
+ "is_active": bool(user.is_active),
|
|
|
|
|
+ "roles": user.roles,
|
|
|
|
|
+ },
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+ except HTTPException:
|
|
|
|
|
+ raise
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"登录失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.get("/auth/sso/authorize")
|
|
|
|
|
+async def sso_authorize(redirect: bool = Query(False)):
|
|
|
|
|
+ params = urlencode({
|
|
|
|
|
+ "response_type": "code",
|
|
|
|
|
+ "client_id": settings.sso_client_id,
|
|
|
|
|
+ "redirect_uri": settings.sso_redirect_uri,
|
|
|
|
|
+ "scope": settings.sso_scope,
|
|
|
|
|
+ })
|
|
|
|
|
+ authorize_url = f"{settings.sso_base_url}/oauth/authorize?{params}"
|
|
|
|
|
+ if redirect:
|
|
|
|
|
+ return RedirectResponse(url=authorize_url)
|
|
|
|
|
+ return {"code": "000000", "message": "获取授权URL成功", "data": {"authorize_url": authorize_url}}
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.post("/api/v1/auth/refresh")
|
|
|
|
|
+async def refresh_token_endpoint(req: RefreshRequest):
|
|
|
|
|
+ async with async_session() as session:
|
|
|
|
|
+ result = await session.execute(
|
|
|
|
|
+ select(RefreshTokenModel).where(
|
|
|
|
|
+ RefreshTokenModel.token == req.refresh_token,
|
|
|
|
|
+ RefreshTokenModel.revoked == 0,
|
|
|
|
|
+ RefreshTokenModel.expires_at > datetime.now(timezone.utc),
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ rt = result.scalar_one_or_none()
|
|
|
|
|
+ if not rt:
|
|
|
|
|
+ raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
|
|
|
|
|
+
|
|
|
|
|
+ result = await session.execute(select(UserModel).where(UserModel.id == rt.user_id))
|
|
|
|
|
+ user = result.scalar_one_or_none()
|
|
|
|
|
+ if not user or not user.is_active:
|
|
|
|
|
+ raise HTTPException(status_code=401, detail="User not found")
|
|
|
|
|
+
|
|
|
|
|
+ rt.revoked = 1
|
|
|
|
|
+ new_token_str = create_refresh_token()
|
|
|
|
|
+ new_expires = datetime.now(timezone.utc) + timedelta(hours=settings.jwt_refresh_expire_hours)
|
|
|
|
|
+ new_rt = RefreshTokenModel(
|
|
|
|
|
+ id=str(uuid.uuid4()), user_id=user.id, token=new_token_str, expires_at=new_expires,
|
|
|
|
|
+ )
|
|
|
|
|
+ session.add(new_rt)
|
|
|
|
|
+ await session.commit()
|
|
|
|
|
+
|
|
|
|
|
+ new_access = create_access_token(
|
|
|
|
|
+ user_id=user.id, username=user.username, roles=user.roles or [],
|
|
|
|
|
+ )
|
|
|
|
|
+ return {
|
|
|
|
|
+ "code": "000000",
|
|
|
|
|
+ "message": "刷新成功",
|
|
|
|
|
+ "data": {"token": new_access, "refresh_token": new_token_str},
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.post("/api/v1/auth/logout")
|
|
|
|
|
+async def logout(req: LogoutRequest, current_user: dict = Depends(get_current_user)):
|
|
|
|
|
+ async with async_session() as session:
|
|
|
|
|
+ result = await session.execute(
|
|
|
|
|
+ select(RefreshTokenModel).where(RefreshTokenModel.token == req.refresh_token)
|
|
|
|
|
+ )
|
|
|
|
|
+ rt = result.scalar_one_or_none()
|
|
|
|
|
+ if rt:
|
|
|
|
|
+ rt.revoked = 1
|
|
|
|
|
+ await session.commit()
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ "code": "000000",
|
|
|
|
|
+ "message": "登出成功",
|
|
|
|
|
+ "data": {"sso_logout_url": settings.sso_logout_redirect_url},
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.get("/api/v1/auth/userinfo")
|
|
|
|
|
+async def get_userinfo(current_user: dict = Depends(get_current_user)):
|
|
|
|
|
+ user_id = current_user.get("sub")
|
|
|
|
|
+ async with async_session() as session:
|
|
|
|
|
+ result = await session.execute(select(UserModel).where(UserModel.id == user_id))
|
|
|
|
|
+ user = result.scalar_one_or_none()
|
|
|
|
|
+ if not user:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
|
+ return {
|
|
|
|
|
+ "code": "000000",
|
|
|
|
|
+ "data": {
|
|
|
|
|
+ "id": user.id,
|
|
|
|
|
+ "username": user.username,
|
|
|
|
|
+ "email": user.email,
|
|
|
|
|
+ "real_name": user.real_name,
|
|
|
|
|
+ "roles": user.roles,
|
|
|
|
|
+ "avatar_url": user.avatar_url,
|
|
|
|
|
+ "permissions": [],
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.get("/api/v1/auth/me")
|
|
|
|
|
+async def get_me(current_user: dict = Depends(get_current_user)):
|
|
|
|
|
+ return await get_userinfo(current_user)
|