| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287 |
- """
- 认证API端点
- """
- import sys
- import os
- # 添加src目录到Python路径
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..'))
- from fastapi import APIRouter, Depends, HTTPException, Request, Response
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
- from sqlalchemy.ext.asyncio import AsyncSession
- from typing import Optional
- from app.base import get_db
- from app.schemas.auth import (
- LoginRequest,
- TokenResponse,
- RefreshTokenRequest,
- LogoutRequest,
- UserInfoResponse,
- CaptchaResponse
- )
- from app.services.auth_service import AuthService
- from app.core.exceptions import AuthenticationError, ValidationError
- from app.schemas.base import ResponseSchema
- import base64
- import io
- from PIL import Image, ImageDraw, ImageFont
- import random
- import string
- import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the endpoints defined in this module; all paths live under /auth.
router = APIRouter(prefix="/auth", tags=["登录认证"])
# HTTP Bearer scheme, injected via Depends() into the endpoints that read an
# Authorization header (logout, userinfo, me).
security = HTTPBearer()
def get_client_ip(request: Request) -> str:
    """Return the best-effort client IP address for a request.

    The first entry of the ``X-Forwarded-For`` header is preferred; when the
    header is missing or its first entry is blank, the address of the direct
    peer is used instead.

    Args:
        request: The incoming request.

    Returns:
        The client IP as a string, or ``"unknown"`` when neither the header
        nor the peer address is available.
    """
    # NOTE(review): X-Forwarded-For is supplied by the caller unless a trusted
    # proxy overwrites it, so this value should not be used for access control.
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        first_hop = forwarded.split(",")[0].strip()
        if first_hop:
            return first_hop

    # request.client may be None when the server provides no peer information;
    # dereferencing it unconditionally would raise AttributeError.
    client = request.client
    if client is None:
        return "unknown"
    return client.host
def get_user_agent(request: Request) -> str:
    """Return the request's ``User-Agent`` header, or ``""`` when absent."""
    headers = request.headers
    return headers.get("User-Agent", "")
@router.post("/login", response_model=ResponseSchema)
async def login(
    request: Request,
    login_data: LoginRequest,
    db: AsyncSession = Depends(get_db)
):
    """Authenticate a user and return the issued tokens.

    Args:
        request: Incoming request; used for the client IP and user agent
            passed on to the auth service.
        login_data: Submitted username and password.
        db: Database session handed to ``AuthService``.

    Returns:
        A ``ResponseSchema``. On success the code is ``"000000"`` and ``data``
        is ``token_response.dict()``. On ``AuthenticationError`` the error's
        own ``code`` and ``message`` are returned. Any other exception yields
        code ``"500001"`` with a generic message.
    """
    try:
        logger.info(f"收到登录请求: username={login_data.username}")
        logger.debug(f"登录详情: username={login_data.username}, ip={get_client_ip(request)}")

        auth_service = AuthService(db)

        user, token_response = await auth_service.authenticate_user(
            username=login_data.username,
            password=login_data.password,
            ip_address=get_client_ip(request),
            user_agent=get_user_agent(request)
        )

        logger.info(f"登录成功: username={login_data.username}, user_id={user.id}")
        return ResponseSchema(
            code="000000",
            message="登录成功",
            data=token_response.dict()
        )

    except AuthenticationError as e:
        logger.warning(f"认证失败: username={login_data.username}, reason={e.message}")
        return ResponseSchema(
            code=e.code,
            message=e.message,
            data=None
        )
    except Exception as e:
        # Full details go to the server log only. The exception text is not
        # echoed to the client, because it can expose internals such as SQL
        # or file paths; the message matches the other endpoints here.
        logger.error(f"登录错误: {type(e).__name__}: {str(e)}", exc_info=True)
        return ResponseSchema(
            code="500001",
            message="服务器内部错误",
            data=None
        )
@router.post("/refresh", response_model=ResponseSchema)
async def refresh_token(
    refresh_data: RefreshTokenRequest,
    db: AsyncSession = Depends(get_db)
):
    """Exchange a refresh token for a new access token.

    Args:
        refresh_data: Request body carrying the refresh token.
        db: Database session handed to ``AuthService``.

    Returns:
        A ``ResponseSchema``. On success the code is ``"000000"`` and ``data``
        is ``token_response.dict()``. On ``AuthenticationError`` the error's
        own ``code`` and ``message`` are returned. Any other exception yields
        code ``"500001"`` with a generic message.
    """
    try:
        auth_service = AuthService(db)

        token_response = await auth_service.refresh_access_token(
            refresh_token=refresh_data.refresh_token
        )

        return ResponseSchema(
            code="000000",
            message="令牌刷新成功",
            data=token_response.dict()
        )

    except AuthenticationError as e:
        return ResponseSchema(
            code=e.code,
            message=e.message,
            data=None
        )
    except Exception:
        # Record the traceback before answering with a generic error, as the
        # other endpoints in this module do; otherwise the failure is lost.
        logger.exception("刷新令牌内部错误")
        return ResponseSchema(
            code="500001",
            message="服务器内部错误",
            data=None
        )
@router.post("/logout", response_model=ResponseSchema)
async def logout(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
    logout_data: Optional[LogoutRequest] = None
):
    """Log the caller out.

    The access token from the Authorization header takes precedence; the
    token in the optional request body is only used when the header yields
    none. A refresh token, if any, comes from the body alone.

    Returns:
        A ``ResponseSchema`` with code ``"000000"`` on success, or code
        ``"500001"`` when any exception is raised.
    """
    try:
        service = AuthService(db)

        # Header token first; the body may only fill a gap.
        access_token = credentials.credentials if credentials else None
        body_refresh_token = None

        if logout_data:
            if not access_token and logout_data.token:
                access_token = logout_data.token
            # A falsy refresh token in the body is treated as "not provided".
            body_refresh_token = logout_data.refresh_token or None

        await service.logout(token=access_token, refresh_token=body_refresh_token)

        return ResponseSchema(code="000000", message="登出成功", data=None)

    except Exception:
        logger.exception("登出内部错误")
        return ResponseSchema(code="500001", message="服务器内部错误", data=None)
@router.get("/userinfo", response_model=ResponseSchema)
async def get_user_info(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db)
):
    """Return the profile of the user identified by the bearer token.

    Returns:
        On success, a ``ResponseSchema`` with code ``"000000"`` whose ``data``
        is ``user_info.dict()``. When the auth service also hands back a new
        token, a plain dict of that response plus ``token_refreshed`` and
        ``new_token`` keys is returned instead. On ``AuthenticationError`` the
        error's own ``code`` and ``message`` are returned; any other exception
        yields code ``"500001"``.
    """
    try:
        service = AuthService(db)

        # get_current_user yields a (user, new_token) pair.
        user, new_token = await service.get_current_user(credentials.credentials)
        logger.info(f"user={user.username}, new_token={'存在' if new_token else '无'}")
        logger.info(f"user.id={user.id}")

        profile = await service.get_user_info(user)

        success = ResponseSchema(
            code="000000",
            message="获取用户信息成功",
            data=profile.dict()
        )

        if not new_token:
            return success

        # The token was refreshed: attach it next to the standard fields.
        # NOTE(review): these two keys are outside the declared response_model;
        # confirm they are not filtered out of the serialized response.
        refreshed_payload = {
            **success.dict(),
            "token_refreshed": True,
            "new_token": new_token,
        }
        logger.info(f"用户信息API - Token已刷新,用户: {user.username}")
        return refreshed_payload

    except AuthenticationError as e:
        return ResponseSchema(code=e.code, message=e.message, data=None)
    except Exception:
        logger.exception("获取用户信息内部错误")
        return ResponseSchema(code="500001", message="服务器内部错误", data=None)
@router.get("/captcha", response_model=ResponseSchema)
async def get_captcha():
    """Generate a 4-character image captcha and return it as a data URI.

    The image is a 120x40 white PNG with the text drawn in black and five
    gray interference lines on top.

    Returns:
        A ``ResponseSchema``. On success the code is ``"000000"`` and ``data``
        is a ``CaptchaResponse`` dict with ``captcha_id`` and a base64
        ``captcha_image``. Any exception yields code ``"500001"``.
    """
    try:
        # Draw from the OS entropy source: the default Mersenne Twister
        # generator is predictable and unsuitable for captcha text and ids.
        rng = random.SystemRandom()

        captcha_text = ''.join(rng.choices(string.ascii_uppercase + string.digits, k=4))

        width, height = 120, 40
        image = Image.new('RGB', (width, height), color='white')
        draw = ImageDraw.Draw(image)

        try:
            # Prefer a system TrueType font when one is available.
            font = ImageFont.truetype("arial.ttf", 20)
        except (OSError, ImportError):
            # Font file missing/unreadable or FreeType support unavailable.
            font = ImageFont.load_default()

        # Center the text; the height is a fixed estimate of the glyph size.
        text_width = draw.textlength(captcha_text, font=font)
        text_height = 20
        x = (width - text_width) // 2
        y = (height - text_height) // 2

        draw.text((x, y), captcha_text, fill='black', font=font)

        # Interference lines between random points of the canvas.
        for _ in range(5):
            x1 = rng.randint(0, width)
            y1 = rng.randint(0, height)
            x2 = rng.randint(0, width)
            y2 = rng.randint(0, height)
            draw.line([(x1, y1), (x2, y2)], fill='gray', width=1)

        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        image_base64 = base64.b64encode(buffer.getvalue()).decode()

        # TODO: the captcha text is not persisted anywhere in this function;
        # a real deployment should store it (e.g. in Redis) keyed by this id
        # so that it can be verified later.
        captcha_id = ''.join(rng.choices(string.ascii_letters + string.digits, k=32))

        captcha_response = CaptchaResponse(
            captcha_id=captcha_id,
            captcha_image=f"data:image/png;base64,{image_base64}"
        )

        return ResponseSchema(
            code="000000",
            message="获取验证码成功",
            data=captcha_response.dict()
        )

    except Exception:
        # Keep the traceback in the server log instead of dropping it.
        logger.exception("生成验证码失败")
        return ResponseSchema(
            code="500001",
            message="生成验证码失败",
            data=None
        )
@router.get("/me", response_model=ResponseSchema)
async def get_current_user_info(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db)
):
    """Alias of the ``/userinfo`` endpoint: delegate to ``get_user_info``."""
    user_info_response = await get_user_info(credentials, db)
    return user_info_response
|