auth_service.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. """
  2. Authentication Service for user management and authentication.
  3. Handles user registration, login, token refresh, and user queries.
  4. """
  5. import bcrypt
  6. import uuid
  7. from typing import Dict
  8. from database import get_db_connection
  9. from models import User
  10. from services.jwt_service import JWTService
  11. from fastapi import HTTPException, status
  12. class AuthService:
  13. """Service for authentication operations."""
  14. @staticmethod
  15. def register_user(username: str, email: str, password: str) -> User:
  16. """
  17. Register a new user.
  18. Args:
  19. username: Unique username
  20. email: Unique email address
  21. password: Plain text password
  22. Returns:
  23. Created User object
  24. Raises:
  25. HTTPException: 409 if username or email already exists
  26. """
  27. with get_db_connection() as conn:
  28. cursor = conn.cursor()
  29. # Check username uniqueness
  30. cursor.execute(
  31. "SELECT id FROM users WHERE username = ?",
  32. (username,)
  33. )
  34. if cursor.fetchone():
  35. raise HTTPException(
  36. status_code=status.HTTP_409_CONFLICT,
  37. detail="用户名已被使用"
  38. )
  39. # Check email uniqueness
  40. cursor.execute(
  41. "SELECT id FROM users WHERE email = ?",
  42. (email,)
  43. )
  44. if cursor.fetchone():
  45. raise HTTPException(
  46. status_code=status.HTTP_409_CONFLICT,
  47. detail="邮箱已被使用"
  48. )
  49. # Hash password
  50. password_hash = bcrypt.hashpw(
  51. password.encode('utf-8'),
  52. bcrypt.gensalt()
  53. ).decode('utf-8')
  54. # Create user
  55. user_id = f"user_{uuid.uuid4().hex[:12]}"
  56. cursor.execute("""
  57. INSERT INTO users (
  58. id, username, email, password_hash, role
  59. )
  60. VALUES (?, ?, ?, ?, ?)
  61. """, (user_id, username, email, password_hash, "annotator"))
  62. # Fetch created user
  63. cursor.execute(
  64. "SELECT * FROM users WHERE id = ?",
  65. (user_id,)
  66. )
  67. row = cursor.fetchone()
  68. return User.from_row(row)
  69. @staticmethod
  70. def login_user(username: str, password: str) -> Dict:
  71. """
  72. Authenticate user and generate tokens.
  73. Args:
  74. username: Username
  75. password: Plain text password
  76. Returns:
  77. Dict containing access_token, refresh_token, and user info
  78. Raises:
  79. HTTPException: 401 if credentials are invalid
  80. """
  81. with get_db_connection() as conn:
  82. cursor = conn.cursor()
  83. # Find user
  84. cursor.execute(
  85. "SELECT * FROM users WHERE username = ?",
  86. (username,)
  87. )
  88. row = cursor.fetchone()
  89. if not row:
  90. raise HTTPException(
  91. status_code=status.HTTP_401_UNAUTHORIZED,
  92. detail="用户名或密码错误"
  93. )
  94. user = User.from_row(row)
  95. # Verify password
  96. try:
  97. password_valid = bcrypt.checkpw(
  98. password.encode('utf-8'),
  99. user.password_hash.encode('utf-8')
  100. )
  101. except (ValueError, TypeError) as e:
  102. # Invalid salt or hash format - password hash is corrupted or not bcrypt
  103. # This can happen if password was stored in plaintext or different format
  104. raise HTTPException(
  105. status_code=status.HTTP_401_UNAUTHORIZED,
  106. detail="用户名或密码错误"
  107. )
  108. if not password_valid:
  109. raise HTTPException(
  110. status_code=status.HTTP_401_UNAUTHORIZED,
  111. detail="用户名或密码错误"
  112. )
  113. # Generate tokens
  114. user_data = {
  115. "id": user.id,
  116. "username": user.username,
  117. "email": user.email,
  118. "role": user.role,
  119. "created_at": user.created_at
  120. }
  121. access_token = JWTService.create_access_token(user_data)
  122. refresh_token = JWTService.create_refresh_token(user_data)
  123. return {
  124. "access_token": access_token,
  125. "refresh_token": refresh_token,
  126. "user": user_data
  127. }
  128. @staticmethod
  129. def refresh_tokens(refresh_token: str) -> Dict:
  130. """
  131. Refresh access token using refresh token.
  132. Args:
  133. refresh_token: Valid refresh token
  134. Returns:
  135. Dict containing new access_token and refresh_token
  136. Raises:
  137. HTTPException: 401 if refresh token is invalid or expired
  138. """
  139. try:
  140. payload = JWTService.verify_token(refresh_token, "refresh")
  141. if not payload:
  142. raise HTTPException(
  143. status_code=status.HTTP_401_UNAUTHORIZED,
  144. detail="无效的刷新令牌"
  145. )
  146. user_id = payload["sub"]
  147. # Fetch user from database
  148. with get_db_connection() as conn:
  149. cursor = conn.cursor()
  150. cursor.execute(
  151. "SELECT * FROM users WHERE id = ?",
  152. (user_id,)
  153. )
  154. row = cursor.fetchone()
  155. if not row:
  156. raise HTTPException(
  157. status_code=status.HTTP_401_UNAUTHORIZED,
  158. detail="用户不存在"
  159. )
  160. user = User.from_row(row)
  161. user_data = {
  162. "id": user.id,
  163. "username": user.username,
  164. "email": user.email,
  165. "role": user.role,
  166. "created_at": user.created_at
  167. }
  168. # Generate new tokens (token rotation)
  169. new_access_token = JWTService.create_access_token(user_data)
  170. new_refresh_token = JWTService.create_refresh_token(user_data)
  171. return {
  172. "access_token": new_access_token,
  173. "refresh_token": new_refresh_token,
  174. "user": user_data
  175. }
  176. except Exception as e:
  177. raise HTTPException(
  178. status_code=status.HTTP_401_UNAUTHORIZED,
  179. detail="刷新令牌已过期或无效,请重新登录"
  180. )
  181. @staticmethod
  182. def get_current_user(user_id: str) -> User:
  183. """
  184. Get user by ID.
  185. Args:
  186. user_id: User unique identifier
  187. Returns:
  188. User object
  189. Raises:
  190. HTTPException: 404 if user not found
  191. """
  192. with get_db_connection() as conn:
  193. cursor = conn.cursor()
  194. cursor.execute(
  195. "SELECT * FROM users WHERE id = ?",
  196. (user_id,)
  197. )
  198. row = cursor.fetchone()
  199. if not row:
  200. raise HTTPException(
  201. status_code=status.HTTP_404_NOT_FOUND,
  202. detail="用户不存在"
  203. )
  204. return User.from_row(row)