|
|
@@ -0,0 +1,261 @@
|
|
|
+"""
|
|
|
+SSO 免登授权码交换端点
|
|
|
+前端从 SSO 回调拿到 code 后,调用此接口换取本地 JWT。
|
|
|
+"""
|
|
|
+import sys
|
|
|
+import os
|
|
|
+
|
|
|
+# 添加src目录到Python路径
|
|
|
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
|
|
|
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..'))
|
|
|
+
|
|
|
+import logging
|
|
|
+import httpx
|
|
|
+from fastapi import APIRouter, Depends
|
|
|
+from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
+from sqlalchemy import select
|
|
|
+from pydantic import BaseModel
|
|
|
+
|
|
|
+from app.base import get_db
|
|
|
+from app.schemas.base import ResponseSchema
|
|
|
+from app.core.config import config_handler
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+router = APIRouter(prefix="/api/oauth", tags=["SSO免登"])
|
|
|
+
|
|
|
+
|
|
|
+class ExchangeCodeRequest(BaseModel):
|
|
|
+ code: str
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/exchange-code", response_model=ResponseSchema)
|
|
|
+async def exchange_code(request_data: ExchangeCodeRequest, db: AsyncSession = Depends(get_db)):
|
|
|
+ """
|
|
|
+ SSO 免登授权码交换端点。
|
|
|
+ 前端从 SSO 回调拿到 code 后,调用此接口换取本地 JWT。
|
|
|
+
|
|
|
+ 流程:
|
|
|
+ 1. 用 code 调 SSO /oauth/token 换取 SSO access_token
|
|
|
+ 2. 用 SSO access_token 调 SSO /oauth/userinfo 获取用户信息+角色
|
|
|
+ 3. 同步用户到本地数据库
|
|
|
+ 4. 签发本地 JWT
|
|
|
+ 5. 返回 { token, refresh_token, user }
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ logger.info(f"[exchange-code] ========== 收到授权码交换请求 ==========")
|
|
|
+ logger.info(f"[exchange-code] code={request_data.code[:10]}...")
|
|
|
+
|
|
|
+ # 读取 SSO 配置
|
|
|
+ sso_base_url = config_handler.get("admin_sso", "SSO_BASE_URL", "http://localhost:8200")
|
|
|
+ sso_client_id = config_handler.get("admin_sso", "SSO_CLIENT_ID", "lqadmin")
|
|
|
+ sso_client_secret = config_handler.get("admin_sso", "SSO_CLIENT_SECRET", "")
|
|
|
+ sso_redirect_uri = config_handler.get("admin_sso", "REDIRECT_URI", "http://localhost:3000/auth/callback")
|
|
|
+
|
|
|
+ # ========== 步骤1:用 code 换 SSO access_token ==========
|
|
|
+ logger.info(f"[exchange-code] 步骤1: 用 code 换 SSO access_token")
|
|
|
+ async with httpx.AsyncClient(timeout=10.0) as client:
|
|
|
+ token_resp = await client.post(
|
|
|
+ f"{sso_base_url}/oauth/token",
|
|
|
+ data={
|
|
|
+ "grant_type": "authorization_code",
|
|
|
+ "code": request_data.code,
|
|
|
+ "redirect_uri": sso_redirect_uri,
|
|
|
+ "client_id": sso_client_id,
|
|
|
+ "client_secret": sso_client_secret,
|
|
|
+ },
|
|
|
+ )
|
|
|
+ token_data = token_resp.json()
|
|
|
+ logger.info(f"[exchange-code] SSO token 响应: status={token_resp.status_code}")
|
|
|
+
|
|
|
+ sso_access_token = token_data.get("access_token")
|
|
|
+ if not sso_access_token:
|
|
|
+ error_desc = token_data.get("error_description", token_data.get("error", "未知错误"))
|
|
|
+ logger.warning(f"[exchange-code] 未获取到 SSO access_token: {error_desc}")
|
|
|
+ return ResponseSchema(code="400001", message=f"SSO 授权码无效: {error_desc}", data=None)
|
|
|
+
|
|
|
+ logger.info(f"[exchange-code] SSO access_token 获取成功")
|
|
|
+
|
|
|
+ # ========== 步骤2:获取用户信息 ==========
|
|
|
+ logger.info(f"[exchange-code] 步骤2: 获取 SSO 用户信息")
|
|
|
+ async with httpx.AsyncClient(timeout=10.0) as client:
|
|
|
+ userinfo_resp = await client.get(
|
|
|
+ f"{sso_base_url}/oauth/userinfo",
|
|
|
+ headers={"Authorization": f"Bearer {sso_access_token}"},
|
|
|
+ )
|
|
|
+ sso_user_info = userinfo_resp.json()
|
|
|
+ logger.info(f"[exchange-code] SSO userinfo 响应: {sso_user_info}")
|
|
|
+
|
|
|
+ if "sub" not in sso_user_info:
|
|
|
+ logger.warning(f"[exchange-code] SSO userinfo 缺少 sub 字段")
|
|
|
+ return ResponseSchema(code="400002", message="SSO 用户信息格式异常", data=None)
|
|
|
+
|
|
|
+ sso_user_id = sso_user_info.get("sub")
|
|
|
+ sso_username = sso_user_info.get("username", sso_user_id)
|
|
|
+ sso_email = sso_user_info.get("email", "")
|
|
|
+ sso_roles = sso_user_info.get("roles", [])
|
|
|
+ logger.info(f"[exchange-code] SSO 用户: id={sso_user_id}, username={sso_username}, roles={sso_roles}")
|
|
|
+
|
|
|
+ # ========== 步骤3:同步用户到本地数据库 ==========
|
|
|
+ logger.info(f"[exchange-code] 步骤3: 同步用户到本地DB")
|
|
|
+ from app.models.user import User, UserProfile, Role, UserRole
|
|
|
+
|
|
|
+ # 查找用户(通过 email 或 username)
|
|
|
+ stmt = select(User).where(User.email == sso_email)
|
|
|
+ result = await db.execute(stmt)
|
|
|
+ user = result.scalar_one_or_none()
|
|
|
+
|
|
|
+ if not user and sso_username:
|
|
|
+ stmt = select(User).where(User.username == sso_username)
|
|
|
+ result = await db.execute(stmt)
|
|
|
+ user = result.scalar_one_or_none()
|
|
|
+
|
|
|
+ if user:
|
|
|
+ logger.info(f"[exchange-code] 更新已有用户: id={user.id}, username={user.username}")
|
|
|
+ else:
|
|
|
+ import bcrypt
|
|
|
+
|
|
|
+ logger.info(f"[exchange-code] 创建新用户: username={sso_username}")
|
|
|
+ default_password = "SsoLogin@123"
|
|
|
+ hashed_password = bcrypt.hashpw(default_password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
|
|
+
|
|
|
+ user = User(
|
|
|
+ username=sso_username,
|
|
|
+ email=sso_email or f"{sso_username}@sso.local",
|
|
|
+ password_hash=hashed_password,
|
|
|
+ is_active=True,
|
|
|
+ is_superuser=False,
|
|
|
+ )
|
|
|
+ db.add(user)
|
|
|
+ await db.flush()
|
|
|
+ logger.info(f"[exchange-code] 新用户创建成功: id={user.id}")
|
|
|
+
|
|
|
+ # 创建用户档案
|
|
|
+ profile = UserProfile(
|
|
|
+ user_id=user.id,
|
|
|
+ real_name=sso_user_info.get("real_name", sso_username),
|
|
|
+ company=sso_user_info.get("company", ""),
|
|
|
+ department=sso_user_info.get("department", ""),
|
|
|
+ position=sso_user_info.get("position", ""),
|
|
|
+ )
|
|
|
+ db.add(profile)
|
|
|
+
|
|
|
+ # 处理角色映射
|
|
|
+ logger.info(f"[exchange-code] 处理角色映射: sso_roles={sso_roles}")
|
|
|
+ SSO_ROLE_MAPPING = {
|
|
|
+ "label_admin": "admin",
|
|
|
+ "annotator": "annotator",
|
|
|
+ "viewer": "viewer",
|
|
|
+ "标注管理员": "admin",
|
|
|
+ "标注员": "annotator",
|
|
|
+ "查看者": "viewer",
|
|
|
+ "super_admin": "admin",
|
|
|
+ "admin": "admin",
|
|
|
+ }
|
|
|
+
|
|
|
+ # 从 roles 列表中提取 code 和 name 进行映射
|
|
|
+ local_role_codes = []
|
|
|
+ for role_item in sso_roles:
|
|
|
+ if isinstance(role_item, dict):
|
|
|
+ code = role_item.get("code", "")
|
|
|
+ name = role_item.get("name", "")
|
|
|
+ else:
|
|
|
+ code = str(role_item)
|
|
|
+ name = code
|
|
|
+
|
|
|
+ mapped = SSO_ROLE_MAPPING.get(code) or SSO_ROLE_MAPPING.get(name)
|
|
|
+ if mapped and mapped not in local_role_codes:
|
|
|
+ local_role_codes.append(mapped)
|
|
|
+
|
|
|
+ logger.info(f"[exchange-code] 映射后的本地角色: {local_role_codes}")
|
|
|
+
|
|
|
+ # 查找并关联数据库角色
|
|
|
+ if local_role_codes:
|
|
|
+ stmt = select(Role).where(Role.code.in_(local_role_codes))
|
|
|
+ result = await db.execute(stmt)
|
|
|
+ db_roles = result.fetchall()
|
|
|
+ db_role_list = [r[0] for r in db_roles]
|
|
|
+
|
|
|
+ logger.info(f"[exchange-code] 找到数据库角色: {[r.code for r in db_role_list]}")
|
|
|
+
|
|
|
+ # 清除用户现有角色
|
|
|
+ stmt = select(UserRole).where(UserRole.user_id == user.id)
|
|
|
+ result = await db.execute(stmt)
|
|
|
+ existing_roles = result.fetchall()
|
|
|
+ for er in existing_roles:
|
|
|
+ await db.delete(er[0])
|
|
|
+
|
|
|
+ # 添加新角色
|
|
|
+ for db_role in db_role_list:
|
|
|
+ user_role = UserRole(user_id=user.id, role_id=db_role.id)
|
|
|
+ db.add(user_role)
|
|
|
+
|
|
|
+ # 设置超级管理员标志
|
|
|
+ if "admin" in local_role_codes:
|
|
|
+ user.is_superuser = True
|
|
|
+
|
|
|
+ await db.commit()
|
|
|
+ await db.refresh(user)
|
|
|
+
|
|
|
+ # 重新加载用户角色
|
|
|
+ stmt = select(Role).join(UserRole).where(UserRole.user_id == user.id)
|
|
|
+ result = await db.execute(stmt)
|
|
|
+ user_roles = [r[0].code for r in result.fetchall()]
|
|
|
+
|
|
|
+ logger.info(f"[exchange-code] 用户角色已更新: {user_roles}")
|
|
|
+
|
|
|
+ # ========== 步骤4:签发本地 JWT ==========
|
|
|
+ logger.info(f"[exchange-code] 步骤4: 签发本地 JWT")
|
|
|
+ from app.services.jwt_token import create_access_token
|
|
|
+ from app.utils import redis_token_manager as rtm
|
|
|
+
|
|
|
+ access_payload = {
|
|
|
+ "sub": str(user.id),
|
|
|
+ "username": user.username,
|
|
|
+ "email": user.email or "",
|
|
|
+ "is_superuser": user.is_superuser,
|
|
|
+ "roles": user_roles,
|
|
|
+ }
|
|
|
+ access_token = create_access_token(access_payload)
|
|
|
+
|
|
|
+ refresh_payload = {
|
|
|
+ "sub": str(user.id),
|
|
|
+ "type": "refresh",
|
|
|
+ }
|
|
|
+ refresh_token = create_access_token(refresh_payload)
|
|
|
+
|
|
|
+ # 存储 token 到 Redis
|
|
|
+ rtm.store_access_token(access_token, access_payload)
|
|
|
+ rtm.store_refresh_token(refresh_token, str(user.id))
|
|
|
+
|
|
|
+ # ========== 步骤5:返回结果 ==========
|
|
|
+ user_info = {
|
|
|
+ "id": str(user.id),
|
|
|
+ "username": user.username,
|
|
|
+ "email": user.email or "",
|
|
|
+ "phone": user.phone if hasattr(user, "phone") else None,
|
|
|
+ "is_superuser": user.is_superuser,
|
|
|
+ "roles": user_roles,
|
|
|
+ }
|
|
|
+
|
|
|
+ logger.info(f"[exchange-code] ========== 授权码交换成功: user={user.username} ==========")
|
|
|
+
|
|
|
+ return ResponseSchema(
|
|
|
+ code="000000",
|
|
|
+ message="登录成功",
|
|
|
+ data={
|
|
|
+ "token": access_token,
|
|
|
+ "refresh_token": refresh_token,
|
|
|
+ "token_type": "bearer",
|
|
|
+ "user": user_info,
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[exchange-code] ========== 授权码交换错误 ==========")
|
|
|
+ logger.error(f"[exchange-code] {type(e).__name__}: {str(e)}", exc_info=True)
|
|
|
+ return ResponseSchema(
|
|
|
+ code="500001",
|
|
|
+ message=f"服务器内部错误: {str(e)}",
|
|
|
+ data=None
|
|
|
+ )
|