| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- """
- Authentication Service for user management and authentication.
- Handles user registration, login, token refresh, and user queries.
- """
- import bcrypt
- import uuid
- from typing import Dict
- from database import get_db_connection
- from models import User
- from services.jwt_service import JWTService
- from fastapi import HTTPException, status
class AuthService:
    """Service for authentication operations.

    Every public method is a static method that opens its own connection via
    ``get_db_connection`` and reports failures to callers as FastAPI
    ``HTTPException``s.
    """

    # Role assigned to every self-registered account.
    DEFAULT_ROLE = "annotator"

    # Lazily computed bcrypt hash that login checks against when the username
    # does not exist, so known and unknown usernames both cost one bcrypt
    # verification. Populated on first use by ``_get_dummy_password_hash``.
    _dummy_password_hash = None

    @staticmethod
    def register_user(username: str, email: str, password: str) -> User:
        """
        Register a new user with the default role.

        Args:
            username: Unique username
            email: Unique email address
            password: Plain text password (hashed with bcrypt before storage)

        Returns:
            Created User object, re-read from the database

        Raises:
            HTTPException: 409 if username or email already exists
        """
        with get_db_connection() as conn:
            cursor = conn.cursor()

            # NOTE(review): check-then-insert is not atomic; two concurrent
            # registrations can both pass these checks. Only UNIQUE constraints
            # on users.username / users.email (not visible here) would stop the
            # duplicate insert -- confirm the schema has them.
            cursor.execute("SELECT id FROM users WHERE username = ?", (username,))
            if cursor.fetchone():
                raise HTTPException(
                    status_code=status.HTTP_409_CONFLICT,
                    detail="用户名已被使用",
                )

            cursor.execute("SELECT id FROM users WHERE email = ?", (email,))
            if cursor.fetchone():
                raise HTTPException(
                    status_code=status.HTTP_409_CONFLICT,
                    detail="邮箱已被使用",
                )

            # NOTE(review): bcrypt only uses the first 72 bytes of its input;
            # depending on the installed bcrypt version, longer passwords are
            # silently truncated or rejected with ValueError. Consider
            # enforcing a length limit before this point.
            password_hash = bcrypt.hashpw(
                password.encode("utf-8"),
                bcrypt.gensalt(),
            ).decode("utf-8")

            user_id = f"user_{uuid.uuid4().hex[:12]}"
            cursor.execute(
                """
                INSERT INTO users (
                    id, username, email, password_hash, role
                )
                VALUES (?, ?, ?, ?, ?)
                """,
                (user_id, username, email, password_hash, AuthService.DEFAULT_ROLE),
            )

            # Re-read the row so columns the INSERT does not set (presumably
            # DB defaults such as created_at) are populated on the result.
            cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
            return User.from_row(cursor.fetchone())

    @staticmethod
    def login_user(username: str, password: str) -> Dict:
        """
        Authenticate a user and issue a fresh access/refresh token pair.

        Args:
            username: Username
            password: Plain text password

        Returns:
            Dict with "access_token", "refresh_token" and "user" (the public
            user fields embedded in the tokens)

        Raises:
            HTTPException: 401 if the username is unknown or the password does
                not match; the same message is used for both so the response
                does not reveal which one failed.
        """
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
            row = cursor.fetchone()
            user = User.from_row(row) if row else None

        # Always run exactly one bcrypt verification -- against a dummy hash
        # when the user does not exist -- so response latency does not leak
        # whether the username is registered.
        stored_hash = (
            user.password_hash
            if user is not None
            else AuthService._get_dummy_password_hash()
        )
        password_valid = AuthService._verify_password(password, stored_hash)

        if user is None or not password_valid:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户名或密码错误",
            )

        user_data = AuthService._user_payload(user)
        return {
            "access_token": JWTService.create_access_token(user_data),
            "refresh_token": JWTService.create_refresh_token(user_data),
            "user": user_data,
        }

    @staticmethod
    def refresh_tokens(refresh_token: str) -> Dict:
        """
        Exchange a refresh token for a new access/refresh token pair.

        Args:
            refresh_token: Refresh token previously issued by this service

        Returns:
            Dict with "access_token", "refresh_token" and "user"

        Raises:
            HTTPException: 401 if the refresh token cannot be verified, carries
                no subject, or names a user that no longer exists. Database
                errors are not converted and propagate to the caller.
        """
        try:
            payload = JWTService.verify_token(refresh_token, "refresh")
        except Exception as exc:
            # verify_token's failure modes (expiry, bad signature, ...) are not
            # visible from here, so any error while verifying is reported as
            # an expired or invalid token.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="刷新令牌已过期或无效,请重新登录",
            ) from exc

        if not payload:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="无效的刷新令牌",
            )

        try:
            user_id = payload["sub"]
        except (KeyError, TypeError) as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="无效的刷新令牌",
            ) from exc

        # Re-read the user so the new tokens reflect the current role/email and
        # a deleted account can no longer refresh.
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
            row = cursor.fetchone()

            if not row:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="用户不存在",
                )

            user = User.from_row(row)

        user_data = AuthService._user_payload(user)

        # Token rotation: a new refresh token is issued on every call.
        # NOTE(review): nothing here revokes the presented refresh token; unless
        # JWTService tracks issued tokens, it stays usable until it expires.
        return {
            "access_token": JWTService.create_access_token(user_data),
            "refresh_token": JWTService.create_refresh_token(user_data),
            "user": user_data,
        }

    @staticmethod
    def get_current_user(user_id: str) -> User:
        """
        Get a user by ID.

        Args:
            user_id: User unique identifier

        Returns:
            User object

        Raises:
            HTTPException: 404 if user not found
        """
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
            row = cursor.fetchone()

            if not row:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="用户不存在",
                )

            return User.from_row(row)

    @staticmethod
    def _user_payload(user: User) -> Dict:
        """Return the public user fields embedded in tokens and API responses."""
        return {
            "id": user.id,
            "username": user.username,
            "email": user.email,
            "role": user.role,
            "created_at": user.created_at,
        }

    @staticmethod
    def _verify_password(password: str, password_hash: "str | None") -> bool:
        """Return True if *password* matches the stored bcrypt *password_hash*.

        A missing, corrupted, or non-bcrypt stored hash (for example a password
        stored in plaintext) is treated as a mismatch instead of an error.
        """
        if not password_hash:
            return False
        try:
            return bcrypt.checkpw(
                password.encode("utf-8"),
                password_hash.encode("utf-8"),
            )
        except (ValueError, TypeError):
            # Invalid salt/hash format, or input bcrypt refuses to process.
            return False

    @classmethod
    def _get_dummy_password_hash(cls) -> str:
        """Return a throwaway bcrypt hash with the same cost as real hashes.

        It is computed on first use rather than at import time, because bcrypt
        hashing is deliberately slow. Concurrent first calls may each compute a
        hash; any of them is equally valid for timing equalisation.
        """
        if cls._dummy_password_hash is None:
            cls._dummy_password_hash = bcrypt.hashpw(
                b"timing-equalisation-dummy-password",
                bcrypt.gensalt(),
            ).decode("utf-8")
        return cls._dummy_password_hash
|