Przeglądaj źródła

增加日志打印信息

lingmin_package@163.com 3 tygodni temu
rodzic
commit
8223532171
1 zmienionych plików z 261 dodań i 0 usunięć
  1. 261 0
      src/views/oauth_exchange_view.py

+ 261 - 0
src/views/oauth_exchange_view.py

@@ -0,0 +1,261 @@
+"""
+SSO 免登授权码交换端点
+前端从 SSO 回调拿到 code 后,调用此接口换取本地 JWT。
+"""
+import sys
+import os
+
+# 添加src目录到Python路径
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..'))
+
+import logging
+import httpx
+from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+from pydantic import BaseModel
+
+from app.base import get_db
+from app.schemas.base import ResponseSchema
+from app.core.config import config_handler
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/api/oauth", tags=["SSO免登"])
+
+
+class ExchangeCodeRequest(BaseModel):
+    code: str
+
+
+@router.post("/exchange-code", response_model=ResponseSchema)
+async def exchange_code(request_data: ExchangeCodeRequest, db: AsyncSession = Depends(get_db)):
+    """
+    SSO 免登授权码交换端点。
+    前端从 SSO 回调拿到 code 后,调用此接口换取本地 JWT。
+
+    流程:
+    1. 用 code 调 SSO /oauth/token 换取 SSO access_token
+    2. 用 SSO access_token 调 SSO /oauth/userinfo 获取用户信息+角色
+    3. 同步用户到本地数据库
+    4. 签发本地 JWT
+    5. 返回 { token, refresh_token, user }
+    """
+    try:
+        logger.info(f"[exchange-code] ========== 收到授权码交换请求 ==========")
+        logger.info(f"[exchange-code] code={request_data.code[:10]}...")
+
+        # 读取 SSO 配置
+        sso_base_url = config_handler.get("admin_sso", "SSO_BASE_URL", "http://localhost:8200")
+        sso_client_id = config_handler.get("admin_sso", "SSO_CLIENT_ID", "lqadmin")
+        sso_client_secret = config_handler.get("admin_sso", "SSO_CLIENT_SECRET", "")
+        sso_redirect_uri = config_handler.get("admin_sso", "REDIRECT_URI", "http://localhost:3000/auth/callback")
+
+        # ========== 步骤1:用 code 换 SSO access_token ==========
+        logger.info(f"[exchange-code] 步骤1: 用 code 换 SSO access_token")
+        async with httpx.AsyncClient(timeout=10.0) as client:
+            token_resp = await client.post(
+                f"{sso_base_url}/oauth/token",
+                data={
+                    "grant_type": "authorization_code",
+                    "code": request_data.code,
+                    "redirect_uri": sso_redirect_uri,
+                    "client_id": sso_client_id,
+                    "client_secret": sso_client_secret,
+                },
+            )
+            token_data = token_resp.json()
+            logger.info(f"[exchange-code] SSO token 响应: status={token_resp.status_code}")
+
+        sso_access_token = token_data.get("access_token")
+        if not sso_access_token:
+            error_desc = token_data.get("error_description", token_data.get("error", "未知错误"))
+            logger.warning(f"[exchange-code] 未获取到 SSO access_token: {error_desc}")
+            return ResponseSchema(code="400001", message=f"SSO 授权码无效: {error_desc}", data=None)
+
+        logger.info(f"[exchange-code] SSO access_token 获取成功")
+
+        # ========== 步骤2:获取用户信息 ==========
+        logger.info(f"[exchange-code] 步骤2: 获取 SSO 用户信息")
+        async with httpx.AsyncClient(timeout=10.0) as client:
+            userinfo_resp = await client.get(
+                f"{sso_base_url}/oauth/userinfo",
+                headers={"Authorization": f"Bearer {sso_access_token}"},
+            )
+            sso_user_info = userinfo_resp.json()
+            logger.info(f"[exchange-code] SSO userinfo 响应: {sso_user_info}")
+
+        if "sub" not in sso_user_info:
+            logger.warning(f"[exchange-code] SSO userinfo 缺少 sub 字段")
+            return ResponseSchema(code="400002", message="SSO 用户信息格式异常", data=None)
+
+        sso_user_id = sso_user_info.get("sub")
+        sso_username = sso_user_info.get("username", sso_user_id)
+        sso_email = sso_user_info.get("email", "")
+        sso_roles = sso_user_info.get("roles", [])
+        logger.info(f"[exchange-code] SSO 用户: id={sso_user_id}, username={sso_username}, roles={sso_roles}")
+
+        # ========== 步骤3:同步用户到本地数据库 ==========
+        logger.info(f"[exchange-code] 步骤3: 同步用户到本地DB")
+        from app.models.user import User, UserProfile, Role, UserRole
+
+        # 查找用户(通过 email 或 username)
+        stmt = select(User).where(User.email == sso_email)
+        result = await db.execute(stmt)
+        user = result.scalar_one_or_none()
+
+        if not user and sso_username:
+            stmt = select(User).where(User.username == sso_username)
+            result = await db.execute(stmt)
+            user = result.scalar_one_or_none()
+
+        if user:
+            logger.info(f"[exchange-code] 更新已有用户: id={user.id}, username={user.username}")
+        else:
+            import bcrypt
+
+            logger.info(f"[exchange-code] 创建新用户: username={sso_username}")
+            default_password = "SsoLogin@123"
+            hashed_password = bcrypt.hashpw(default_password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
+
+            user = User(
+                username=sso_username,
+                email=sso_email or f"{sso_username}@sso.local",
+                password_hash=hashed_password,
+                is_active=True,
+                is_superuser=False,
+            )
+            db.add(user)
+            await db.flush()
+            logger.info(f"[exchange-code] 新用户创建成功: id={user.id}")
+
+            # 创建用户档案
+            profile = UserProfile(
+                user_id=user.id,
+                real_name=sso_user_info.get("real_name", sso_username),
+                company=sso_user_info.get("company", ""),
+                department=sso_user_info.get("department", ""),
+                position=sso_user_info.get("position", ""),
+            )
+            db.add(profile)
+
+        # 处理角色映射
+        logger.info(f"[exchange-code] 处理角色映射: sso_roles={sso_roles}")
+        SSO_ROLE_MAPPING = {
+            "label_admin": "admin",
+            "annotator": "annotator",
+            "viewer": "viewer",
+            "标注管理员": "admin",
+            "标注员": "annotator",
+            "查看者": "viewer",
+            "super_admin": "admin",
+            "admin": "admin",
+        }
+
+        # 从 roles 列表中提取 code 和 name 进行映射
+        local_role_codes = []
+        for role_item in sso_roles:
+            if isinstance(role_item, dict):
+                code = role_item.get("code", "")
+                name = role_item.get("name", "")
+            else:
+                code = str(role_item)
+                name = code
+
+            mapped = SSO_ROLE_MAPPING.get(code) or SSO_ROLE_MAPPING.get(name)
+            if mapped and mapped not in local_role_codes:
+                local_role_codes.append(mapped)
+
+        logger.info(f"[exchange-code] 映射后的本地角色: {local_role_codes}")
+
+        # 查找并关联数据库角色
+        if local_role_codes:
+            stmt = select(Role).where(Role.code.in_(local_role_codes))
+            result = await db.execute(stmt)
+            db_roles = result.fetchall()
+            db_role_list = [r[0] for r in db_roles]
+
+            logger.info(f"[exchange-code] 找到数据库角色: {[r.code for r in db_role_list]}")
+
+            # 清除用户现有角色
+            stmt = select(UserRole).where(UserRole.user_id == user.id)
+            result = await db.execute(stmt)
+            existing_roles = result.fetchall()
+            for er in existing_roles:
+                await db.delete(er[0])
+
+            # 添加新角色
+            for db_role in db_role_list:
+                user_role = UserRole(user_id=user.id, role_id=db_role.id)
+                db.add(user_role)
+
+            # 设置超级管理员标志
+            if "admin" in local_role_codes:
+                user.is_superuser = True
+
+        await db.commit()
+        await db.refresh(user)
+
+        # 重新加载用户角色
+        stmt = select(Role).join(UserRole).where(UserRole.user_id == user.id)
+        result = await db.execute(stmt)
+        user_roles = [r[0].code for r in result.fetchall()]
+
+        logger.info(f"[exchange-code] 用户角色已更新: {user_roles}")
+
+        # ========== 步骤4:签发本地 JWT ==========
+        logger.info(f"[exchange-code] 步骤4: 签发本地 JWT")
+        from app.services.jwt_token import create_access_token
+        from app.utils import redis_token_manager as rtm
+
+        access_payload = {
+            "sub": str(user.id),
+            "username": user.username,
+            "email": user.email or "",
+            "is_superuser": user.is_superuser,
+            "roles": user_roles,
+        }
+        access_token = create_access_token(access_payload)
+
+        refresh_payload = {
+            "sub": str(user.id),
+            "type": "refresh",
+        }
+        refresh_token = create_access_token(refresh_payload)
+
+        # 存储 token 到 Redis
+        rtm.store_access_token(access_token, access_payload)
+        rtm.store_refresh_token(refresh_token, str(user.id))
+
+        # ========== 步骤5:返回结果 ==========
+        user_info = {
+            "id": str(user.id),
+            "username": user.username,
+            "email": user.email or "",
+            "phone": user.phone if hasattr(user, "phone") else None,
+            "is_superuser": user.is_superuser,
+            "roles": user_roles,
+        }
+
+        logger.info(f"[exchange-code] ========== 授权码交换成功: user={user.username} ==========")
+
+        return ResponseSchema(
+            code="000000",
+            message="登录成功",
+            data={
+                "token": access_token,
+                "refresh_token": refresh_token,
+                "token_type": "bearer",
+                "user": user_info,
+            }
+        )
+
+    except Exception as e:
+        logger.error(f"[exchange-code] ========== 授权码交换错误 ==========")
+        logger.error(f"[exchange-code] {type(e).__name__}: {str(e)}", exc_info=True)
+        return ResponseSchema(
+            code="500001",
+            message=f"服务器内部错误: {str(e)}",
+            data=None
+        )