auth_view.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. """
  2. 认证API端点
  3. """
  4. import sys
  5. import os
  6. # 添加src目录到Python路径
  7. sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
  8. sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..'))
  9. from fastapi import APIRouter, Depends, HTTPException, Request, Response
  10. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  11. from sqlalchemy.ext.asyncio import AsyncSession
  12. from typing import Optional
  13. from app.base import get_db
  14. from app.schemas.auth import (
  15. LoginRequest,
  16. TokenResponse,
  17. RefreshTokenRequest,
  18. LogoutRequest,
  19. UserInfoResponse,
  20. CaptchaResponse
  21. )
  22. from app.services.auth_service import AuthService
  23. from app.core.exceptions import AuthenticationError, ValidationError
  24. from app.schemas.base import ResponseSchema
  25. import base64
  26. import io
  27. from PIL import Image, ImageDraw, ImageFont
  28. import random
  29. import string
  30. import logging
  31. # 配置日志记录器
  32. logger = logging.getLogger(__name__)
  33. router = APIRouter(prefix="/auth", tags=["登录认证"])
  34. security = HTTPBearer()
  35. def get_client_ip(request: Request) -> str:
  36. """获取客户端IP地址"""
  37. forwarded = request.headers.get("X-Forwarded-For")
  38. if forwarded:
  39. return forwarded.split(",")[0].strip()
  40. return request.client.host
  41. def get_user_agent(request: Request) -> str:
  42. """获取用户代理"""
  43. return request.headers.get("User-Agent", "")
  44. @router.post("/login", response_model=ResponseSchema)
  45. async def login(
  46. request: Request,
  47. login_data: LoginRequest,
  48. db: AsyncSession = Depends(get_db)
  49. ):
  50. """用户登录"""
  51. try:
  52. logger.info(f"收到登录请求: username={login_data.username}")
  53. logger.debug(f"登录详情: username={login_data.username}, ip={get_client_ip(request)}")
  54. auth_service = AuthService(db)
  55. user, token_response = await auth_service.authenticate_user(
  56. username=login_data.username,
  57. password=login_data.password,
  58. ip_address=get_client_ip(request),
  59. user_agent=get_user_agent(request)
  60. )
  61. logger.info(f"登录成功: username={login_data.username}, user_id={user.id}")
  62. return ResponseSchema(
  63. code="000000",
  64. message="登录成功",
  65. data=token_response.dict()
  66. )
  67. except AuthenticationError as e:
  68. logger.warning(f"认证失败: username={login_data.username}, reason={e.message}")
  69. return ResponseSchema(
  70. code=e.code,
  71. message=e.message,
  72. data=None
  73. )
  74. except Exception as e:
  75. logger.error(f"登录错误: {type(e).__name__}: {str(e)}", exc_info=True)
  76. return ResponseSchema(
  77. code="500001",
  78. message=f"服务器内部错误: {str(e)}",
  79. data=None
  80. )
  81. @router.post("/refresh", response_model=ResponseSchema)
  82. async def refresh_token(
  83. refresh_data: RefreshTokenRequest,
  84. db: AsyncSession = Depends(get_db)
  85. ):
  86. """刷新访问令牌"""
  87. try:
  88. auth_service = AuthService(db)
  89. token_response = await auth_service.refresh_access_token(
  90. refresh_token=refresh_data.refresh_token
  91. )
  92. return ResponseSchema(
  93. code="000000",
  94. message="令牌刷新成功",
  95. data=token_response.dict()
  96. )
  97. except AuthenticationError as e:
  98. return ResponseSchema(
  99. code=e.code,
  100. message=e.message,
  101. data=None
  102. )
  103. except Exception as e:
  104. return ResponseSchema(
  105. code="500001",
  106. message="服务器内部错误",
  107. data=None
  108. )
  109. @router.post("/logout", response_model=ResponseSchema)
  110. async def logout(
  111. credentials: HTTPAuthorizationCredentials = Depends(security),
  112. db: AsyncSession = Depends(get_db),
  113. logout_data: Optional[LogoutRequest] = None
  114. ):
  115. """用户登出"""
  116. try:
  117. auth_service = AuthService(db)
  118. # 优先使用Authorization头中的token
  119. token = credentials.credentials if credentials else None
  120. refresh_token = None
  121. # 如果提供了logout_data,使用其中的token和refresh_token
  122. if logout_data:
  123. if not token and logout_data.token:
  124. token = logout_data.token
  125. if logout_data.refresh_token:
  126. refresh_token = logout_data.refresh_token
  127. await auth_service.logout(
  128. token=token,
  129. refresh_token=refresh_token
  130. )
  131. return ResponseSchema(
  132. code="000000",
  133. message="登出成功",
  134. data=None
  135. )
  136. except Exception as e:
  137. logger.exception("登出内部错误")
  138. return ResponseSchema(
  139. code="500001",
  140. message="服务器内部错误",
  141. data=None
  142. )
  143. @router.get("/userinfo", response_model=ResponseSchema)
  144. async def get_user_info(
  145. credentials: HTTPAuthorizationCredentials = Depends(security),
  146. db: AsyncSession = Depends(get_db)
  147. ):
  148. """获取用户信息"""
  149. try:
  150. auth_service = AuthService(db)
  151. # get_current_user 现在返回 (user, new_token) 元组
  152. user, new_token = await auth_service.get_current_user(credentials.credentials)
  153. logger.info(f"user={user.username}, new_token={'存在' if new_token else '无'}")
  154. logger.info(f"user.id={user.id}")
  155. user_info = await auth_service.get_user_info(user)
  156. response_data = ResponseSchema(
  157. code="000000",
  158. message="获取用户信息成功",
  159. data=user_info.dict()
  160. )
  161. # 如果token被刷新,添加到响应中
  162. if new_token:
  163. response_dict = response_data.dict()
  164. response_dict["token_refreshed"] = True
  165. response_dict["new_token"] = new_token
  166. logger.info(f"用户信息API - Token已刷新,用户: {user.username}")
  167. return response_dict
  168. return response_data
  169. except AuthenticationError as e:
  170. return ResponseSchema(
  171. code=e.code,
  172. message=e.message,
  173. data=None
  174. )
  175. except Exception as e:
  176. logger.exception("获取用户信息内部错误")
  177. return ResponseSchema(
  178. code="500001",
  179. message="服务器内部错误",
  180. data=None
  181. )
  182. @router.get("/captcha", response_model=ResponseSchema)
  183. async def get_captcha():
  184. """获取验证码"""
  185. try:
  186. # 生成随机验证码
  187. captcha_text = ''.join(random.choices(string.ascii_uppercase + string.digits, k=4))
  188. # 创建图片
  189. width, height = 120, 40
  190. image = Image.new('RGB', (width, height), color='white')
  191. draw = ImageDraw.Draw(image)
  192. # 绘制验证码文字
  193. try:
  194. # 尝试使用系统字体
  195. font = ImageFont.truetype("arial.ttf", 20)
  196. except:
  197. # 使用默认字体
  198. font = ImageFont.load_default()
  199. # 计算文字位置
  200. text_width = draw.textlength(captcha_text, font=font)
  201. text_height = 20
  202. x = (width - text_width) // 2
  203. y = (height - text_height) // 2
  204. # 绘制文字
  205. draw.text((x, y), captcha_text, fill='black', font=font)
  206. # 添加干扰线
  207. for _ in range(5):
  208. x1 = random.randint(0, width)
  209. y1 = random.randint(0, height)
  210. x2 = random.randint(0, width)
  211. y2 = random.randint(0, height)
  212. draw.line([(x1, y1), (x2, y2)], fill='gray', width=1)
  213. # 转换为base64
  214. buffer = io.BytesIO()
  215. image.save(buffer, format='PNG')
  216. image_base64 = base64.b64encode(buffer.getvalue()).decode()
  217. # 生成验证码ID(实际应用中应该存储到Redis等缓存中)
  218. captcha_id = ''.join(random.choices(string.ascii_letters + string.digits, k=32))
  219. captcha_response = CaptchaResponse(
  220. captcha_id=captcha_id,
  221. captcha_image=f"data:image/png;base64,{image_base64}"
  222. )
  223. return ResponseSchema(
  224. code="000000",
  225. message="获取验证码成功",
  226. data=captcha_response.dict()
  227. )
  228. except Exception as e:
  229. return ResponseSchema(
  230. code="500001",
  231. message="生成验证码失败",
  232. data=None
  233. )
  234. @router.get("/me", response_model=ResponseSchema)
  235. async def get_current_user_info(
  236. credentials: HTTPAuthorizationCredentials = Depends(security),
  237. db: AsyncSession = Depends(get_db)
  238. ):
  239. """获取当前用户信息"""
  240. return await get_user_info(credentials, db)