| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- """
- SSO 免登授权码交换端点
- 前端从 SSO 回调拿到 code 后,调用此接口换取本地 JWT。
- """
- import sys
- import os
- # 添加src目录到Python路径
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..'))
- import logging
- import httpx
- from fastapi import APIRouter, Depends
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy import select
- from pydantic import BaseModel
- from app.base import get_db
- from app.schemas.base import ResponseSchema
- from app.core.config import config_handler
# Logger for this endpoint module, named after the module path.
logger = logging.getLogger(__name__)

# Every route below is mounted under /api/oauth and grouped under the
# "SSO免登" tag in the generated OpenAPI documentation.
router = APIRouter(tags=["SSO免登"], prefix="/api/oauth")
class ExchangeCodeRequest(BaseModel):
    """Request body for the SSO authorization-code exchange endpoint."""

    # Authorization code the frontend received from the SSO redirect callback.
    code: str
# SSO role -> local role code. Both machine codes and display names are listed;
# an SSO role maps if either its code or its name appears here.
_SSO_ROLE_MAPPING = {
    "ann_sys_admin": "admin",
    "ann_operator": "annotator",
    "ann_viewer": "viewer",
    "标注管理员": "admin",
    "标注员": "annotator",
    "查看者": "viewer",
    "super_admin": "admin",
    "admin": "admin",
}


def _map_sso_roles(sso_roles):
    """Translate SSO roles into a de-duplicated list of local role codes.

    Each item may be a dict with ``code``/``name`` keys, or any other value,
    which is converted with ``str`` and used as both code and name. The code is
    looked up first, then the name. Unmapped items are ignored, and the order
    of first appearance is preserved.
    """
    local_role_codes = []
    for role_item in sso_roles:
        if isinstance(role_item, dict):
            code = role_item.get("code", "")
            name = role_item.get("name", "")
        else:
            code = name = str(role_item)
        mapped = _SSO_ROLE_MAPPING.get(code) or _SSO_ROLE_MAPPING.get(name)
        if mapped and mapped not in local_role_codes:
            local_role_codes.append(mapped)
    return local_role_codes


@router.post("/exchange-code", response_model=ResponseSchema)
async def exchange_code(request_data: ExchangeCodeRequest, db: AsyncSession = Depends(get_db)):
    """Exchange an SSO authorization code for a locally issued JWT.

    The frontend calls this after the SSO callback hands it a ``code``.

    Flow:
      1. POST the code to SSO ``/oauth/token`` to obtain an SSO access token.
      2. GET SSO ``/oauth/userinfo`` with that token (user fields + roles).
      3. Find or create the matching local user and sync its roles.
      4. Sign a local access token and refresh token and store them in Redis.
      5. Return ``{token, refresh_token, token_type, user}``.

    Always returns a ``ResponseSchema``:
      * ``000000`` on success;
      * ``400001`` when SSO returns no access token for the code;
      * ``400002`` when the SSO userinfo payload has no usable ``sub``;
      * ``500001`` on any other exception (details are logged, not returned).
    """
    try:
        logger.info("[exchange-code] ========== 收到授权码交换请求 ==========")
        logger.info(f"[exchange-code] code={request_data.code[:10]}...")

        # SSO configuration (section "admin_sso"), with local-dev fallbacks.
        sso_base_url = config_handler.get("admin_sso", "SSO_BASE_URL", "http://localhost:8200")
        sso_client_id = config_handler.get("admin_sso", "SSO_CLIENT_ID", "lqadmin")
        sso_client_secret = config_handler.get("admin_sso", "SSO_CLIENT_SECRET", "")
        sso_redirect_uri = config_handler.get("admin_sso", "REDIRECT_URI", "http://localhost:3000/auth/callback")

        # One client (and connection pool) serves both SSO calls.
        async with httpx.AsyncClient(timeout=10.0) as client:
            # ========== Step 1: exchange the code for an SSO access_token ==========
            logger.info("[exchange-code] 步骤1: 用 code 换 SSO access_token")
            token_resp = await client.post(
                f"{sso_base_url}/oauth/token",
                data={
                    "grant_type": "authorization_code",
                    "code": request_data.code,
                    "redirect_uri": sso_redirect_uri,
                    "client_id": sso_client_id,
                    "client_secret": sso_client_secret,
                },
            )
            # A non-JSON body raises here and is reported as 500001 below.
            token_data = token_resp.json()
            logger.info(f"[exchange-code] SSO token 响应: status={token_resp.status_code}")
            if not isinstance(token_data, dict):
                token_data = {}
            sso_access_token = token_data.get("access_token")
            if not sso_access_token:
                error_desc = token_data.get("error_description", token_data.get("error", "未知错误"))
                logger.warning(f"[exchange-code] 未获取到 SSO access_token: {error_desc}")
                return ResponseSchema(code="400001", message=f"SSO 授权码无效: {error_desc}", data=None)
            logger.info("[exchange-code] SSO access_token 获取成功")

            # ========== Step 2: fetch the SSO user info ==========
            logger.info("[exchange-code] 步骤2: 获取 SSO 用户信息")
            userinfo_resp = await client.get(
                f"{sso_base_url}/oauth/userinfo",
                headers={"Authorization": f"Bearer {sso_access_token}"},
            )
            sso_user_info = userinfo_resp.json()
        # Log only the status; the payload carries PII (email, name, ...).
        logger.info(f"[exchange-code] SSO userinfo 响应: status={userinfo_resp.status_code}")

        # Reject non-object payloads and a missing/null/empty subject.
        if not isinstance(sso_user_info, dict) or sso_user_info.get("sub") in (None, ""):
            logger.warning("[exchange-code] SSO userinfo 缺少 sub 字段")
            return ResponseSchema(code="400002", message="SSO 用户信息格式异常", data=None)

        sso_user_id = sso_user_info["sub"]
        # `or` rather than a .get() default so an explicit null/empty value also
        # falls back; str() keeps the username a string when `sub` is numeric.
        sso_username = sso_user_info.get("username") or str(sso_user_id)
        sso_email = sso_user_info.get("email") or ""
        sso_roles = sso_user_info.get("roles") or []
        logger.info(f"[exchange-code] SSO 用户: id={sso_user_id}, username={sso_username}, roles={sso_roles}")

        # ========== Step 3: sync the user into the local database ==========
        logger.info("[exchange-code] 步骤3: 同步用户到本地DB")
        from app.models.user import User, UserProfile, Role, UserRole

        # Match by email first, then by username. Skip the email lookup when SSO
        # sent none: `email == ""` could hit an unrelated account with a blank email.
        # NOTE(review): email linking assumes SSO emails are verified and unique, and
        # inactive local users (is_active=False) are not rejected here — confirm both.
        user = None
        if sso_email:
            result = await db.execute(select(User).where(User.email == sso_email))
            user = result.scalar_one_or_none()
        if user is None:
            result = await db.execute(select(User).where(User.username == sso_username))
            user = result.scalar_one_or_none()

        if user:
            logger.info(f"[exchange-code] 更新已有用户: id={user.id}, username={user.username}")
        else:
            import secrets

            import bcrypt

            logger.info(f"[exchange-code] 创建新用户: username={sso_username}")
            # SSO-provisioned users sign in through SSO. Give each one an unguessable
            # random password instead of one shared hard-coded value that would open
            # every such account to a local password login, if one exists.
            random_password = secrets.token_urlsafe(32)
            hashed_password = bcrypt.hashpw(random_password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
            user = User(
                username=sso_username,
                email=sso_email or f"{sso_username}@sso.local",
                password_hash=hashed_password,
                is_active=True,
                is_superuser=False,
            )
            db.add(user)
            await db.flush()  # assigns user.id for the profile row below
            logger.info(f"[exchange-code] 新用户创建成功: id={user.id}")

            profile = UserProfile(
                user_id=user.id,
                real_name=sso_user_info.get("real_name", sso_username),
                company=sso_user_info.get("company", ""),
                department=sso_user_info.get("department", ""),
                position=sso_user_info.get("position", ""),
            )
            db.add(profile)

        # Role mapping
        logger.info(f"[exchange-code] 处理角色映射: sso_roles={sso_roles}")
        local_role_codes = _map_sso_roles(sso_roles)
        logger.info(f"[exchange-code] 映射后的本地角色: {local_role_codes}")

        # Roles are replaced only when at least one SSO role maps; otherwise the
        # user's existing local roles are left untouched.
        if local_role_codes:
            result = await db.execute(select(Role).where(Role.code.in_(local_role_codes)))
            db_role_list = result.scalars().all()
            logger.info(f"[exchange-code] 找到数据库角色: {[r.code for r in db_role_list]}")

            result = await db.execute(select(UserRole).where(UserRole.user_id == user.id))
            for existing_role in result.scalars().all():
                await db.delete(existing_role)
            # Emit the DELETEs before the INSERTs: within a single flush SQLAlchemy may
            # insert before deleting, which would trip a (user_id, role_id) uniqueness
            # constraint when the same roles are re-added.
            await db.flush()
            for db_role in db_role_list:
                db.add(UserRole(user_id=user.id, role_id=db_role.id))

            # NOTE(review): is_superuser is only ever set, never cleared, so losing the
            # SSO admin role does not revoke it. Confirm whether it should track SSO roles.
            if "admin" in local_role_codes:
                user.is_superuser = True

        await db.commit()
        await db.refresh(user)

        # Reload the roles as persisted.
        result = await db.execute(select(Role).join(UserRole).where(UserRole.user_id == user.id))
        user_roles = [role.code for role in result.scalars().all()]
        logger.info(f"[exchange-code] 用户角色已更新: {user_roles}")

        # ========== Step 4: sign the local JWTs ==========
        logger.info("[exchange-code] 步骤4: 签发本地 JWT")
        from app.services.jwt_token import create_access_token
        from app.utils import redis_token_manager as rtm

        access_payload = {
            "sub": str(user.id),
            "username": user.username,
            "email": user.email or "",
            "is_superuser": user.is_superuser,
            "roles": user_roles,
            "client_id": sso_client_id,
        }
        access_token = create_access_token(access_payload)
        # NOTE(review): the refresh token is minted with create_access_token, so it
        # presumably gets the access-token lifetime. Confirm that is intended.
        refresh_payload = {
            "sub": str(user.id),
            "type": "refresh",
        }
        refresh_token = create_access_token(refresh_payload)

        # NOTE(review): these rtm calls are not awaited. Confirm they are synchronous;
        # if they are coroutines they never run (and if sync, they block the loop).
        # Admin-channel key for the access token.
        rtm.store_access_token(access_token, access_payload)
        # OAuth-channel key so /oauth/userinfo also accepts this token.
        rtm.store_oauth_access_token(access_token, sso_client_id, str(user.id))
        rtm.store_refresh_token(refresh_token, str(user.id))

        # ========== Step 5: build the response ==========
        user_info = {
            "id": str(user.id),
            "username": user.username,
            "email": user.email or "",
            "phone": getattr(user, "phone", None),
            "is_superuser": user.is_superuser,
            "roles": user_roles,
        }
        logger.info(f"[exchange-code] ========== 授权码交换成功: user={user.username} ==========")
        return ResponseSchema(
            code="000000",
            message="登录成功",
            data={
                "token": access_token,
                "refresh_token": refresh_token,
                "token_type": "bearer",
                "user": user_info,
            },
        )
    except Exception as e:
        logger.error("[exchange-code] ========== 授权码交换错误 ==========")
        logger.error(f"[exchange-code] {type(e).__name__}: {str(e)}", exc_info=True)
        # Discard half-applied user/role changes so the session is not left in a
        # failed state. What get_db does with it afterwards is not visible here.
        try:
            await db.rollback()
        except Exception:
            logger.exception("[exchange-code] 回滚数据库会话失败")
        # Internal details are logged above, not echoed to the client.
        return ResponseSchema(
            code="500001",
            message="服务器内部错误",
            data=None,
        )
|