auth_service.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. """
  2. Authentication Service for user management and authentication.
  3. Handles user registration, login, token refresh, and user queries.
  4. """
  5. import bcrypt
  6. import uuid
  7. from typing import Dict
  8. from database import get_db_connection
  9. from models import User
  10. from services.jwt_service import JWTService
  11. from fastapi import HTTPException, status
  12. class AuthService:
  13. """Service for authentication operations."""
  14. @staticmethod
  15. def register_user(username: str, email: str, password: str) -> User:
  16. """
  17. Register a new user.
  18. Args:
  19. username: Unique username
  20. email: Unique email address
  21. password: Plain text password
  22. Returns:
  23. Created User object
  24. Raises:
  25. HTTPException: 409 if username or email already exists
  26. """
  27. with get_db_connection() as conn:
  28. cursor = conn.cursor()
  29. # Check username uniqueness
  30. cursor.execute(
  31. "SELECT id FROM users WHERE username = ?",
  32. (username,)
  33. )
  34. if cursor.fetchone():
  35. raise HTTPException(
  36. status_code=status.HTTP_409_CONFLICT,
  37. detail="用户名已被使用"
  38. )
  39. # Check email uniqueness
  40. cursor.execute(
  41. "SELECT id FROM users WHERE email = ?",
  42. (email,)
  43. )
  44. if cursor.fetchone():
  45. raise HTTPException(
  46. status_code=status.HTTP_409_CONFLICT,
  47. detail="邮箱已被使用"
  48. )
  49. # Hash password
  50. password_hash = bcrypt.hashpw(
  51. password.encode('utf-8'),
  52. bcrypt.gensalt()
  53. ).decode('utf-8')
  54. # Create user
  55. user_id = f"user_{uuid.uuid4().hex[:12]}"
  56. cursor.execute("""
  57. INSERT INTO users (
  58. id, username, email, password_hash, role
  59. )
  60. VALUES (?, ?, ?, ?, ?)
  61. """, (user_id, username, email, password_hash, "annotator"))
  62. # Fetch created user
  63. cursor.execute(
  64. "SELECT * FROM users WHERE id = ?",
  65. (user_id,)
  66. )
  67. row = cursor.fetchone()
  68. return User.from_row(row)
  69. @staticmethod
  70. def login_user(username: str, password: str) -> Dict:
  71. """
  72. Authenticate user and generate tokens.
  73. Args:
  74. username: Username
  75. password: Plain text password
  76. Returns:
  77. Dict containing access_token, refresh_token, and user info
  78. Raises:
  79. HTTPException: 401 if credentials are invalid
  80. """
  81. with get_db_connection() as conn:
  82. cursor = conn.cursor()
  83. # Find user
  84. cursor.execute(
  85. "SELECT * FROM users WHERE username = ?",
  86. (username,)
  87. )
  88. row = cursor.fetchone()
  89. if not row:
  90. raise HTTPException(
  91. status_code=status.HTTP_401_UNAUTHORIZED,
  92. detail="用户名或密码错误"
  93. )
  94. user = User.from_row(row)
  95. # Verify password
  96. if not bcrypt.checkpw(
  97. password.encode('utf-8'),
  98. user.password_hash.encode('utf-8')
  99. ):
  100. raise HTTPException(
  101. status_code=status.HTTP_401_UNAUTHORIZED,
  102. detail="用户名或密码错误"
  103. )
  104. # Generate tokens
  105. user_data = {
  106. "id": user.id,
  107. "username": user.username,
  108. "email": user.email,
  109. "role": user.role,
  110. "created_at": user.created_at
  111. }
  112. access_token = JWTService.create_access_token(user_data)
  113. refresh_token = JWTService.create_refresh_token(user_data)
  114. return {
  115. "access_token": access_token,
  116. "refresh_token": refresh_token,
  117. "user": user_data
  118. }
  119. @staticmethod
  120. def refresh_tokens(refresh_token: str) -> Dict:
  121. """
  122. Refresh access token using refresh token.
  123. Args:
  124. refresh_token: Valid refresh token
  125. Returns:
  126. Dict containing new access_token and refresh_token
  127. Raises:
  128. HTTPException: 401 if refresh token is invalid or expired
  129. """
  130. try:
  131. payload = JWTService.verify_token(refresh_token, "refresh")
  132. if not payload:
  133. raise HTTPException(
  134. status_code=status.HTTP_401_UNAUTHORIZED,
  135. detail="无效的刷新令牌"
  136. )
  137. user_id = payload["sub"]
  138. # Fetch user from database
  139. with get_db_connection() as conn:
  140. cursor = conn.cursor()
  141. cursor.execute(
  142. "SELECT * FROM users WHERE id = ?",
  143. (user_id,)
  144. )
  145. row = cursor.fetchone()
  146. if not row:
  147. raise HTTPException(
  148. status_code=status.HTTP_401_UNAUTHORIZED,
  149. detail="用户不存在"
  150. )
  151. user = User.from_row(row)
  152. user_data = {
  153. "id": user.id,
  154. "username": user.username,
  155. "email": user.email,
  156. "role": user.role,
  157. "created_at": user.created_at
  158. }
  159. # Generate new tokens (token rotation)
  160. new_access_token = JWTService.create_access_token(user_data)
  161. new_refresh_token = JWTService.create_refresh_token(user_data)
  162. return {
  163. "access_token": new_access_token,
  164. "refresh_token": new_refresh_token,
  165. "user": user_data
  166. }
  167. except Exception as e:
  168. raise HTTPException(
  169. status_code=status.HTTP_401_UNAUTHORIZED,
  170. detail="刷新令牌已过期或无效,请重新登录"
  171. )
  172. @staticmethod
  173. def get_current_user(user_id: str) -> User:
  174. """
  175. Get user by ID.
  176. Args:
  177. user_id: User unique identifier
  178. Returns:
  179. User object
  180. Raises:
  181. HTTPException: 404 if user not found
  182. """
  183. with get_db_connection() as conn:
  184. cursor = conn.cursor()
  185. cursor.execute(
  186. "SELECT * FROM users WHERE id = ?",
  187. (user_id,)
  188. )
  189. row = cursor.fetchone()
  190. if not row:
  191. raise HTTPException(
  192. status_code=status.HTTP_404_NOT_FOUND,
  193. detail="用户不存在"
  194. )
  195. return User.from_row(row)