| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
import logging
import uuid
from datetime import datetime, timedelta, timezone
from urllib.parse import urlencode

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from sqlalchemy import select, update

from app.config import get_settings
from app.core.auth import get_current_user
from app.core.db import RefreshTokenModel, UserModel, async_session
from app.core.security import create_access_token, create_refresh_token
from app.core.sso_client import exchange_code_for_token, fetch_sso_userinfo
# Settings are loaded once when this module is imported; the handlers below
# read SSO / JWT configuration from this shared object.
settings = get_settings()
# Router that the OAuth / auth endpoints in this module register on.
router = APIRouter()
class CodeExchangeRequest(BaseModel):
    """Request body for ``POST /api/oauth/exchange-code``."""

    # OAuth authorization code obtained from the SSO authorize redirect.
    code: str
class RefreshRequest(BaseModel):
    """Request body for ``POST /api/v1/auth/refresh``."""

    # Refresh token previously issued by this service and stored in
    # RefreshTokenModel.
    refresh_token: str
class LogoutRequest(BaseModel):
    """Request body for ``POST /api/v1/auth/logout``."""

    # Local access token of the session being ended.
    token: str
    # Refresh token to revoke.
    refresh_token: str
async def _sync_user(sso_info: dict) -> UserModel:
    """Create or update the local user that mirrors an SSO identity.

    The local account is matched by username, taken from the SSO
    ``username`` claim and falling back to ``sub`` (then ``"unknown"``)
    whenever ``username`` is missing *or empty*. New users are inserted with
    the SSO profile fields and ``is_active=1``. Existing users get their
    roles replaced, every profile field present in ``sso_info`` refreshed,
    and ``updated_at`` stamped.

    Args:
        sso_info: User-info payload from the SSO provider. Keys read here
            (all optional): ``username``, ``sub``, ``email``, ``real_name``,
            ``avatar_url``, ``company``, ``department``, ``position`` and
            ``roles`` (a list of dicts whose ``code`` is stored).

    Returns:
        The committed and refreshed ``UserModel`` instance.
    """
    # Use ``or`` rather than nested ``dict.get`` defaults so an explicit
    # ``""``/``None`` username still falls back to ``sub``. Callers accept
    # payloads in which only ``sub`` is populated.
    username = sso_info.get("username") or sso_info.get("sub") or "unknown"
    # Tolerate ``"roles": null`` from the provider.
    role_codes = [r.get("code", "") for r in sso_info.get("roles") or []]
    profile_fields = ("email", "real_name", "avatar_url", "company", "department", "position")

    async with async_session() as session:
        result = await session.execute(select(UserModel).where(UserModel.username == username))
        user = result.scalar_one_or_none()
        if user is None:
            user = UserModel(
                id=str(uuid.uuid4()),
                username=username,
                roles=role_codes,
                is_active=1,
                **{field: sso_info.get(field) for field in profile_fields},
            )
            session.add(user)
        else:
            user.roles = role_codes
            # Only overwrite fields the provider actually sent, so a sparse
            # payload does not wipe locally stored values. This is the same
            # rule previously applied to ``email`` alone.
            for field in profile_fields:
                if field in sso_info:
                    setattr(user, field, sso_info[field])
            user.updated_at = datetime.now(timezone.utc)
        await session.commit()
        # Reload so attributes stay readable after the session closes.
        await session.refresh(user)
    return user
@router.post("/api/oauth/exchange-code")
async def exchange_code(req: CodeExchangeRequest):
    """Exchange an SSO authorization code for a local login session.

    Flow: authorization code -> SSO access token -> SSO user info ->
    local user upsert (``_sync_user``) -> local access token plus a
    refresh token persisted in ``RefreshTokenModel``.

    Returns:
        A ``{"code", "message", "data"}`` envelope. ``code`` is ``"100001"``
        with ``data=None`` when the request carries an empty code. Otherwise
        it is ``"000000"`` with both tokens and a summary of the user.

    Raises:
        HTTPException: 403 if the local account is deactivated; 500 if the
            SSO exchange yields no token or an unusable user-info payload,
            or on any other unexpected failure.
    """
    if not req.code:
        return {"code": "100001", "message": "缺少授权码", "data": None}
    try:
        token_resp = await exchange_code_for_token(req.code)
        sso_access_token = token_resp.get("access_token")
        if not sso_access_token:
            raise HTTPException(status_code=500, detail="登录失败: 获取令牌失败")

        sso_userinfo = await fetch_sso_userinfo(sso_access_token)
        # _sync_user keys the local account on username/sub, so refuse
        # payloads that carry neither.
        if not sso_userinfo.get("username") and not sso_userinfo.get("sub"):
            raise HTTPException(status_code=500, detail="登录失败: 用户信息格式异常")

        user = await _sync_user(sso_userinfo)
        # The refresh endpoint rejects inactive users. Apply the same rule
        # here, otherwise a locally disabled account could obtain fresh
        # tokens simply by logging in again through SSO.
        if not user.is_active:
            raise HTTPException(status_code=403, detail="登录失败: 用户已被禁用")

        local_token = create_access_token(
            user_id=user.id, username=user.username, roles=user.roles or [],
        )
        refresh_token_str = create_refresh_token()
        expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.jwt_refresh_expire_hours)
        async with async_session() as session:
            session.add(
                RefreshTokenModel(
                    id=str(uuid.uuid4()),
                    user_id=user.id,
                    token=refresh_token_str,
                    expires_at=expires_at,
                )
            )
            await session.commit()

        return {
            "code": "000000",
            "message": "登录成功",
            "data": {
                "token": local_token,
                "refresh_token": refresh_token_str,
                "user": {
                    "id": user.id,
                    "username": user.username,
                    "email": user.email,
                    "phone": None,
                    "is_superuser": bool(user.is_superuser),
                    "is_active": bool(user.is_active),
                    "roles": user.roles,
                },
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback server-side; ``from e`` also preserves the cause
        # on the re-raised HTTPException.
        logging.getLogger(__name__).exception("SSO authorization-code exchange failed")
        raise HTTPException(status_code=500, detail=f"登录失败: {e}") from e
@router.get("/auth/sso/authorize")
async def sso_authorize(redirect: bool = Query(False)):
    """Build the SSO authorization URL for the OAuth authorization-code flow.

    Args:
        redirect: When true, respond with a redirect to the URL. Otherwise
            return the URL inside the usual JSON envelope.
    """
    query = {
        "response_type": "code",
        "client_id": settings.sso_client_id,
        "redirect_uri": settings.sso_redirect_uri,
        "scope": settings.sso_scope,
    }
    # NOTE(review): no ``state`` parameter is sent, so the callback cannot be
    # bound to the initiating browser session (RFC 6749 §10.12). Confirm
    # whether the frontend adds and validates one.
    endpoint = f"{settings.sso_base_url}/oauth/authorize"
    authorize_url = f"{endpoint}?{urlencode(query)}"
    if not redirect:
        return {"code": "000000", "message": "获取授权URL成功", "data": {"authorize_url": authorize_url}}
    return RedirectResponse(url=authorize_url)
@router.post("/api/v1/auth/refresh")
async def refresh_token_endpoint(req: RefreshRequest):
    """Rotate a refresh token and issue a new access token.

    The presented refresh token must exist, be unrevoked and unexpired, and
    belong to an active user. It is revoked, and a replacement refresh token
    is stored in the same transaction.

    Returns:
        ``{"code": "000000", "message": ..., "data": {"token", "refresh_token"}}``.

    Raises:
        HTTPException: 401 if the token is unknown, revoked, expired or was
            consumed by a concurrent request, or if its user is missing or
            inactive.
    """
    now = datetime.now(timezone.utc)
    async with async_session() as session:
        result = await session.execute(
            select(RefreshTokenModel).where(
                RefreshTokenModel.token == req.refresh_token,
                RefreshTokenModel.revoked == 0,
                RefreshTokenModel.expires_at > now,
            )
        )
        rt = result.scalar_one_or_none()
        if not rt:
            raise HTTPException(status_code=401, detail="Invalid or expired refresh token")

        result = await session.execute(select(UserModel).where(UserModel.id == rt.user_id))
        user = result.scalar_one_or_none()
        if not user or not user.is_active:
            raise HTTPException(status_code=401, detail="User not found")

        # Revoke with a conditional UPDATE instead of ``rt.revoked = 1``. Two
        # concurrent requests can both pass the SELECT above; only the one
        # whose UPDATE actually flips revoked 0 -> 1 may continue. Without
        # this check, one refresh token could be redeemed twice.
        revoked = await session.execute(
            update(RefreshTokenModel)
            .where(RefreshTokenModel.id == rt.id, RefreshTokenModel.revoked == 0)
            .values(revoked=1)
        )
        if revoked.rowcount != 1:
            raise HTTPException(status_code=401, detail="Invalid or expired refresh token")

        # Capture the claims as plain values before commit, so they do not
        # depend on how async_session configures expire_on_commit. An expired
        # attribute would otherwise need a lazy load after commit.
        user_id, username, roles = user.id, user.username, user.roles or []

        new_token_str = create_refresh_token()
        session.add(
            RefreshTokenModel(
                id=str(uuid.uuid4()),
                user_id=user_id,
                token=new_token_str,
                expires_at=now + timedelta(hours=settings.jwt_refresh_expire_hours),
            )
        )
        await session.commit()

    new_access = create_access_token(user_id=user_id, username=username, roles=roles)
    return {
        "code": "000000",
        "message": "刷新成功",
        "data": {"token": new_access, "refresh_token": new_token_str},
    }
@router.post("/api/v1/auth/logout")
async def logout(req: LogoutRequest, current_user: dict = Depends(get_current_user)):
    """Revoke the supplied refresh token and return the SSO logout URL.

    Requires an authenticated caller (via ``get_current_user``). If no stored
    refresh token matches, nothing is revoked and the call still succeeds.

    NOTE(review): ``req.token`` (the access token) is neither revoked nor
    blacklisted here, and the refresh-token lookup is not restricted to
    ``current_user``'s tokens. Confirm both are intended.
    """
    async with async_session() as session:
        lookup = select(RefreshTokenModel).where(RefreshTokenModel.token == req.refresh_token)
        stored = (await session.execute(lookup)).scalar_one_or_none()
        if stored is not None:
            stored.revoked = 1
            await session.commit()
    payload = {"sso_logout_url": settings.sso_logout_redirect_url}
    return {"code": "000000", "message": "登出成功", "data": payload}
@router.get("/api/v1/auth/userinfo")
async def get_userinfo(current_user: dict = Depends(get_current_user)):
    """Return the profile of the authenticated user.

    The local user is looked up by the ``sub`` entry of ``current_user``.
    ``permissions`` is always returned as an empty list.

    Raises:
        HTTPException: 404 if no local user has that id.
    """
    async with async_session() as session:
        query = select(UserModel).where(UserModel.id == current_user.get("sub"))
        found = (await session.execute(query)).scalar_one_or_none()
        if found is None:
            raise HTTPException(status_code=404, detail="User not found")
        profile = dict(
            id=found.id,
            username=found.username,
            email=found.email,
            real_name=found.real_name,
            roles=found.roles,
            avatar_url=found.avatar_url,
            permissions=[],
        )
        return {"code": "000000", "data": profile}
@router.get("/api/v1/auth/me")
async def get_me(current_user: dict = Depends(get_current_user)):
    """Alias of ``/api/v1/auth/userinfo``: same lookup, same response."""
    profile_response = await get_userinfo(current_user=current_user)
    return profile_response
|