""" 认证API端点 """ import sys import os # 添加src目录到Python路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..')) from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from sqlalchemy.ext.asyncio import AsyncSession from typing import Optional from app.base import get_db from app.schemas.auth import ( LoginRequest, TokenResponse, RefreshTokenRequest, LogoutRequest, UserInfoResponse, CaptchaResponse ) from app.services.auth_service import AuthService from app.core.exceptions import AuthenticationError, ValidationError from app.schemas.base import ResponseSchema import base64 import io from PIL import Image, ImageDraw, ImageFont import random import string import logging # 配置日志记录器 logger = logging.getLogger(__name__) router = APIRouter(prefix="/auth", tags=["登录认证"]) security = HTTPBearer() def get_client_ip(request: Request) -> str: """获取客户端IP地址""" forwarded = request.headers.get("X-Forwarded-For") if forwarded: return forwarded.split(",")[0].strip() return request.client.host def get_user_agent(request: Request) -> str: """获取用户代理""" return request.headers.get("User-Agent", "") @router.post("/login", response_model=ResponseSchema) async def login( request: Request, login_data: LoginRequest, db: AsyncSession = Depends(get_db) ): """用户登录""" try: logger.info(f"收到登录请求: username={login_data.username}") logger.debug(f"登录详情: username={login_data.username}, ip={get_client_ip(request)}") auth_service = AuthService(db) user, token_response = await auth_service.authenticate_user( username=login_data.username, password=login_data.password, ip_address=get_client_ip(request), user_agent=get_user_agent(request) ) logger.info(f"登录成功: username={login_data.username}, user_id={user.id}") return ResponseSchema( code="000000", message="登录成功", data=token_response.dict() ) except AuthenticationError as e: logger.warning(f"认证失败: username={login_data.username}, reason={e.message}") return ResponseSchema( code=e.code, message=e.message, data=None ) except Exception as e: logger.error(f"登录错误: {type(e).__name__}: {str(e)}", exc_info=True) return ResponseSchema( code="500001", message=f"服务器内部错误: {str(e)}", data=None ) @router.post("/refresh", response_model=ResponseSchema) async def refresh_token( refresh_data: RefreshTokenRequest, db: AsyncSession = Depends(get_db) ): """刷新访问令牌""" try: auth_service = AuthService(db) token_response = await auth_service.refresh_access_token( refresh_token=refresh_data.refresh_token ) return ResponseSchema( code="000000", message="令牌刷新成功", data=token_response.dict() ) except AuthenticationError as e: return ResponseSchema( code=e.code, message=e.message, data=None ) except Exception as e: return ResponseSchema( code="500001", message="服务器内部错误", data=None ) @router.post("/logout", response_model=ResponseSchema) async def logout( credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db), logout_data: Optional[LogoutRequest] = None ): """用户登出""" try: auth_service = AuthService(db) # 优先使用Authorization头中的token token = credentials.credentials if credentials else None refresh_token = None # 如果提供了logout_data,使用其中的token和refresh_token if logout_data: if not token and logout_data.token: token = logout_data.token if logout_data.refresh_token: refresh_token = logout_data.refresh_token await auth_service.logout( token=token, refresh_token=refresh_token ) return ResponseSchema( code="000000", message="登出成功", data=None ) except Exception as e: logger.exception("登出内部错误") return ResponseSchema( code="500001", message="服务器内部错误", data=None ) @router.get("/userinfo", response_model=ResponseSchema) async def get_user_info( credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db) ): """获取用户信息""" try: auth_service = AuthService(db) # get_current_user 现在返回 (user, new_token) 元组 user, new_token = await auth_service.get_current_user(credentials.credentials) logger.info(f"user={user.username}, new_token={'存在' if new_token else '无'}") logger.info(f"user.id={user.id}") user_info = await auth_service.get_user_info(user) response_data = ResponseSchema( code="000000", message="获取用户信息成功", data=user_info.dict() ) # 如果token被刷新,添加到响应中 if new_token: response_dict = response_data.dict() response_dict["token_refreshed"] = True response_dict["new_token"] = new_token logger.info(f"用户信息API - Token已刷新,用户: {user.username}") return response_dict return response_data except AuthenticationError as e: return ResponseSchema( code=e.code, message=e.message, data=None ) except Exception as e: logger.exception("获取用户信息内部错误") return ResponseSchema( code="500001", message="服务器内部错误", data=None ) @router.get("/captcha", response_model=ResponseSchema) async def get_captcha(): """获取验证码""" try: # 生成随机验证码 captcha_text = ''.join(random.choices(string.ascii_uppercase + string.digits, k=4)) # 创建图片 width, height = 120, 40 image = Image.new('RGB', (width, height), color='white') draw = ImageDraw.Draw(image) # 绘制验证码文字 try: # 尝试使用系统字体 font = ImageFont.truetype("arial.ttf", 20) except: # 使用默认字体 font = ImageFont.load_default() # 计算文字位置 text_width = draw.textlength(captcha_text, font=font) text_height = 20 x = (width - text_width) // 2 y = (height - text_height) // 2 # 绘制文字 draw.text((x, y), captcha_text, fill='black', font=font) # 添加干扰线 for _ in range(5): x1 = random.randint(0, width) y1 = random.randint(0, height) x2 = random.randint(0, width) y2 = random.randint(0, height) draw.line([(x1, y1), (x2, y2)], fill='gray', width=1) # 转换为base64 buffer = io.BytesIO() image.save(buffer, format='PNG') image_base64 = base64.b64encode(buffer.getvalue()).decode() # 生成验证码ID(实际应用中应该存储到Redis等缓存中) captcha_id = ''.join(random.choices(string.ascii_letters + string.digits, k=32)) captcha_response = CaptchaResponse( captcha_id=captcha_id, captcha_image=f"data:image/png;base64,{image_base64}" ) return ResponseSchema( code="000000", message="获取验证码成功", data=captcha_response.dict() ) except Exception as e: return ResponseSchema( code="500001", message="生成验证码失败", data=None ) @router.get("/me", response_model=ResponseSchema) async def get_current_user_info( credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db) ): """获取当前用户信息""" return await get_user_info(credentials, db)