login.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: login.py
  6. @date:2025/4/14 11:08
  7. @desc:
  8. """
  9. import base64
  10. import json
  11. import logging
  12. from captcha.image import ImageCaptcha
  13. from django.core import signing
  14. from django.core.cache import cache
  15. from django.db.models import QuerySet
  16. from django.utils.translation import gettext_lazy as _
  17. from rest_framework import serializers
  18. from application.models import ApplicationAccessToken
  19. from common.constants.authentication_type import AuthenticationType
  20. from common.constants.cache_version import Cache_Version
  21. from common.database_model_manage.database_model_manage import DatabaseModelManage
  22. from common.exception.app_exception import AppApiException
  23. from common.utils.common import password_encrypt, get_random_chars
  24. from common.utils.rsa_util import decrypt
  25. from maxkb.const import CONFIG
  26. from users.models import User
  27. from common.utils.logger import maxkb_logger
  class LoginRequest(serializers.Serializer):
      """Shape of an incoming login request.

      ``username`` and ``password`` must be present; ``captcha`` and
      ``encryptedData`` may be omitted, null, or blank.
      """
      username = serializers.CharField(
          required=True,
          max_length=64,
          help_text=_("Username"),
          label=_("Username"),
      )
      password = serializers.CharField(
          required=True,
          max_length=128,
          label=_("Password"),
      )
      captcha = serializers.CharField(
          required=False,
          allow_null=True,
          allow_blank=True,
          max_length=64,
          label=_('captcha'),
      )
      encryptedData = serializers.CharField(
          required=False,
          allow_null=True,
          allow_blank=True,
          label=_('encryptedData'),
      )
  # Cache version and key builder for the system-login bookkeeping entries
  # (failure counters and lock flags) used throughout this module.
  (system_version, system_get_key) = Cache_Version.SYSTEM.value
  class LoginResponse(serializers.Serializer):
      """Login response payload: the issued session token."""

      token = serializers.CharField(label=_("token"), required=True)
  def record_login_fail(username: str, expire: int = 600):
      """Increment the per-user login-failure counter and return its new value.

      This counter (key ``system_{username}``) is the one compared against
      ``max_attempts`` to decide whether a captcha is required.

      :param username: account name; an empty value records nothing.
      :param expire: counter lifetime in seconds, applied when the counter is
          first created.
      :return: the failure count after this increment, or 0 for an empty
          username.
      """
      if not username:
          return 0
      fail_key = system_get_key(f'system_{username}')
      # cache.add() only writes when the key is absent, so concurrent "first"
      # failures all increment the same counter. The previous
      # incr -> ValueError -> set(1) sequence let each of them reset the
      # counter to 1, losing increments.
      cache.add(fail_key, 0, timeout=expire, version=system_version)
      try:
          return cache.incr(fail_key, 1, version=system_version)
      except ValueError:
          # The key expired between add() and incr(): start a fresh window.
          cache.set(fail_key, 1, timeout=expire, version=system_version)
          return 1
  def record_login_fail_lock(username: str, expire: int = 10):
      """Increment the per-user lock counter and return its new value.

      This counter (key ``system_{username}_lock_count``) is compared with
      the ``failed_attempts`` threshold to decide whether to lock the account.
      Using atomic cache increments avoids a racy get -> set sequence.

      :param username: account name; an empty value records nothing.
      :param expire: counter lifetime in MINUTES, applied when the counter is
          first created.
      :return: the lock-failure count after this increment, or 0 for an
          empty username.
      """
      if not username:
          return 0
      fail_key = system_get_key(f'system_{username}_lock_count')
      timeout = expire * 60  # minutes -> seconds
      # add() is a no-op when the counter already exists, so concurrent first
      # failures cannot overwrite each other's increments with set(1).
      cache.add(fail_key, 0, timeout=timeout, version=system_version)
      try:
          return cache.incr(fail_key, 1, version=system_version)
      except ValueError:
          # The key expired between add() and incr(): start a fresh window.
          cache.set(fail_key, 1, timeout=timeout, version=system_version)
          return 1
  class LoginSerializer(serializers.Serializer):
      """Authenticate system users and issue session tokens.

      Failure bookkeeping uses these entries in the SYSTEM cache namespace:

      * ``system_{username}``: failure counter that switches on the captcha;
      * ``system_{username}_lock_count``: failure counter for account locking;
      * ``system_{username}_lock``: present while the account is locked.
      """

      @staticmethod
      def get_auth_setting():
          """Return the stored authentication settings as a dict.

          The value comes from the ``auth_setting`` model row whose
          ``param_key`` is ``'auth_setting'``. An empty dict is returned when
          the model is unavailable, the row is missing, or the value is not a
          JSON object. Callers can therefore always use ``.get()`` with their
          own defaults.
          """
          auth_setting_model = DatabaseModelManage.get_model('auth_setting')
          if not auth_setting_model:
              return {}
          setting_obj = auth_setting_model.objects.filter(param_key='auth_setting').first()
          if not setting_obj:
              return {}
          raw_value = setting_obj.param_value
          if isinstance(raw_value, dict):
              # Tolerate an already-deserialized value instead of silently
              # discarding it (json.loads would raise TypeError on a dict).
              return raw_value
          try:
              auth_setting = json.loads(raw_value) if raw_value else {}
          except (TypeError, ValueError):
              maxkb_logger.warning("Invalid auth_setting value, falling back to defaults")
              return {}
          # Callers use .get(); anything other than a JSON object is ignored.
          return auth_setting if isinstance(auth_setting, dict) else {}

      @staticmethod
      def _apply_encrypted_data(instance) -> None:
          """Merge the decrypted ``encryptedData`` JSON object into ``instance``.

          Decrypted keys override plain ones. A payload that decrypts to
          something other than a JSON object is ignored.

          :raises AppApiException: (500) when decryption or JSON parsing fails.
          """
          encrypted_data = instance.get("encryptedData", "")
          if not encrypted_data:
              return
          try:
              decrypted_raw = decrypt(encrypted_data)
              # decrypt() may return a non-JSON string; guard the parse.
              decrypted_data = json.loads(decrypted_raw) if decrypted_raw else {}
              if isinstance(decrypted_data, dict):
                  instance.update(decrypted_data)
          except Exception as e:
              maxkb_logger.exception("Failed to decrypt/parse encryptedData for user %s: %s",
                                     instance.get("username", ""), e)
              raise AppApiException(500, _("Invalid encrypted data")) from e

      @staticmethod
      def login(instance):
          """Authenticate a system user and issue a session token.

          :param instance: request data dict. Any ``encryptedData`` is
              decrypted and merged into it in place before validation.
          :return: ``{'token': <signed token>}``; the user is also cached
              under the token for the configured session timeout.
          :raises serializers.ValidationError: the data fails ``LoginRequest``.
          :raises AppApiException: invalid encrypted data, locked account,
              missing/wrong captcha, wrong credentials, or disabled user.
          """
          LoginSerializer._apply_encrypted_data(instance)
          try:
              LoginRequest(data=instance).is_valid(raise_exception=True)
          except serializers.ValidationError:
              raise
          except Exception as e:
              raise AppApiException(500, str(e))
          # Read credentials only AFTER decryption. The username may have been
          # supplied inside encryptedData, and the user lookup and every
          # counter/lock key must use the value that was just validated.
          username = instance.get("username", "")
          password = instance.get("password")
          captcha = instance.get("captcha", "")

          auth_setting = LoginSerializer.get_auth_setting()
          max_attempts = auth_setting.get("max_attempts", 1)
          failed_attempts = auth_setting.get("failed_attempts", 5)
          lock_time = auth_setting.get("lock_time", 10)

          # Account locking is only enforced with a valid license.
          license_validator = DatabaseModelManage.get_model('license_is_valid') or (lambda: False)
          license_state = license_validator()  # evaluate the check once
          is_license_valid = license_state if license_state is not None else False
          if is_license_valid and LoginSerializer._is_account_locked(username, failed_attempts):
              raise AppApiException(
                  1005,
                  _("This account has been locked for %s minutes, please try again later") % lock_time
              )

          if LoginSerializer._need_captcha(username, max_attempts):
              LoginSerializer._validate_captcha(username, captcha)

          user = User.objects.filter(
              username=username,
              password=password_encrypt(password)
          ).first()
          if not user:
              LoginSerializer._handle_failed_login(username, is_license_valid, failed_attempts, lock_time)
              raise AppApiException(500, _('The username or password is incorrect'))
          if not user.is_active:
              raise AppApiException(1005, _("The user has been disabled, please contact the administrator!"))

          # Successful login: reset every failure counter, including the lock
          # counter. Otherwise failures from before this login would keep
          # counting toward the next lock.
          for suffix in ('', '_lock', '_lock_count'):
              cache.delete(system_get_key(f'system_{username}{suffix}'), version=system_version)
          token = signing.dumps({
              'username': user.username,
              'id': str(user.id),
              'email': user.email,
              'type': AuthenticationType.SYSTEM_USER.value
          })
          version, get_key = Cache_Version.TOKEN.value
          timeout = CONFIG.get_session_timeout()
          cache.set(get_key(token), user, timeout=timeout, version=version)
          return {'token': token}

      @staticmethod
      def _is_account_locked(username: str, failed_attempts: int) -> bool:
          """Return True while the lock flag for ``username`` is set.

          ``failed_attempts == -1`` disables locking entirely.
          """
          if failed_attempts == -1:
              return False
          lock_cache = cache.get(system_get_key(f'system_{username}_lock'), version=system_version)
          return bool(lock_cache)

      @staticmethod
      def _need_captcha(username: str, max_attempts: int) -> bool:
          """Decide whether this login attempt must carry a captcha.

          ``-1`` disables the captcha. A positive value requires one once the
          ``system_{username}`` failure counter reaches it. Any other value
          (e.g. 0) always requires one.
          """
          if max_attempts == -1:
              return False
          elif max_attempts > 0:
              fail_count = cache.get(system_get_key(f'system_{username}'), version=system_version) or 0
              return fail_count >= max_attempts
          return True

      @staticmethod
      def _validate_captcha(username: str, captcha: str) -> None:
          """Check ``captcha`` against the one issued for ``username``, consuming it.

          :raises AppApiException: (1005) when the captcha is missing, wrong,
              or expired.
          """
          if not captcha:
              raise AppApiException(1005, _("Captcha is required"))
          captcha_key = Cache_Version.CAPTCHA.get_key(captcha=f"system_{username}")
          captcha_version = Cache_Version.CAPTCHA.get_version()
          captcha_cache = cache.get(captcha_key, version=captcha_version)
          # Stored captchas are lower-cased, so the comparison is case-insensitive.
          if captcha_cache is None or captcha.lower() != captcha_cache:
              raise AppApiException(1005, _("Captcha code error or expiration"))
          # One-time use: a solved captcha must not be replayable for further
          # password guesses during its remaining lifetime.
          cache.delete(captcha_key, version=captcha_version)

      @staticmethod
      def _handle_failed_login(username: str, is_license_valid: bool, failed_attempts: int, lock_time: int) -> None:
          """Record a failed login and, when locking applies, warn or lock.

          Both failure counters are bumped with atomic cache increments. The
          lock decision uses ``count >= threshold`` rather than an exact
          match. The lock flag is created with ``cache.add`` so only the first
          request writes it. Concurrent requests that also reach the threshold
          still get the "locked" response, so the lock cannot be bypassed.

          :raises AppApiException: (1005) with the remaining-attempts warning,
              or the "locked" message once the threshold is reached. Nothing is
              raised when unlicensed or when locking is disabled
              (``failed_attempts <= 0``); the caller raises its own error.
          """
          # Captcha-trigger counter.
          try:
              record_login_fail(username)
          except Exception:
              maxkb_logger.exception("Failed to record login fail for user %s", username)
          # Lock counter; lock_time (minutes) is its initial lifetime.
          lock_fail_count = 0
          try:
              lock_fail_count = record_login_fail_lock(username, lock_time)
          except Exception:
              maxkb_logger.exception("Failed to record lock fail count for user %s", username)
          # Unlicensed, or locking disabled: counts are recorded but nothing more.
          if not is_license_valid or failed_attempts <= 0:
              return
          if lock_fail_count < failed_attempts:
              remain_attempts = failed_attempts - lock_fail_count
              raise AppApiException(
                  1005,
                  _("Login failed %s times, account will be locked, you have %s more chances !") % (
                      failed_attempts, remain_attempts
                  )
              )
          # Threshold reached: create the lock flag atomically. Whether add()
          # wins or another request already set it, the answer is "locked".
          try:
              locked = cache.add(
                  system_get_key(f'system_{username}_lock'),
                  1,
                  timeout=lock_time * 60,
                  version=system_version
              )
              if locked:
                  maxkb_logger.info("Account %s locked by setting cache key", username)
              else:
                  maxkb_logger.info("Account %s lock key already present (another request set it)", username)
          except Exception:
              maxkb_logger.exception("Failed to set lock key for user %s", username)
          raise AppApiException(
              1005,
              _("This account has been locked for %s minutes, please try again later") % lock_time
          )
  class CaptchaResponse(serializers.Serializer):
      """Captcha response payload.

      ``captcha`` is a ``data:image/png;base64,...`` data URI, or an empty
      string when no captcha is currently required.
      """
      captcha = serializers.CharField(required=True, label=_("captcha"))
  class CaptchaSerializer(serializers.Serializer):
      """Issue image captchas for system and chat logins when one is required."""

      @staticmethod
      def generate(username: str, type: str = 'system'):
          """Return a captcha for a system-user login.

          Whether one is needed follows the system ``max_attempts`` setting and
          the ``system_{username}`` failure counter.

          :return: ``{'captcha': <data URI or ''>}``.
          """
          auth_setting = LoginSerializer.get_auth_setting()
          max_attempts = auth_setting.get("max_attempts", 1)
          need_captcha = CaptchaSerializer._is_captcha_needed(f'system_{username}', max_attempts)
          return CaptchaSerializer._generate_captcha_if_needed(username, type, need_captcha)

      @staticmethod
      def chat_generate(username: str, type: str = 'chat', access_token: str = ''):
          """Return a captcha for a chat login behind ``access_token``.

          The application's own authentication settings decide whether one is
          needed, together with the ``{type}_{username}`` failure counter.

          :raises AppApiException: (1005) when the access token is unknown.
          """
          application_access_token = ApplicationAccessToken.objects.filter(
              access_token=access_token
          ).first()
          if not application_access_token:
              raise AppApiException(1005, _('Invalid access token'))
          # authentication_value may be unset (None); fall back to defaults
          # instead of failing on .get().
          auth_setting = application_access_token.authentication_value or {}
          max_attempts = auth_setting.get("max_attempts", 1)
          need_captcha = CaptchaSerializer._is_captcha_needed(f'{type}_{username}', max_attempts)
          return CaptchaSerializer._generate_captcha_if_needed(username, type, need_captcha)

      @staticmethod
      def _is_captcha_needed(fail_key_name: str, max_attempts) -> bool:
          """Shared captcha policy for both login flows.

          ``-1`` disables the captcha. A positive value requires one once the
          failure counter stored under ``fail_key_name`` reaches it. Any other
          value (e.g. 0) always requires one.
          """
          if max_attempts == -1:
              return False
          if max_attempts > 0:
              fail_count = cache.get(system_get_key(fail_key_name), version=system_version) or 0
              return fail_count >= max_attempts
          return True

      @staticmethod
      def _generate_captcha_if_needed(username: str, type: str, need_captcha: bool):
          """Render a captcha image and cache its text when ``need_captcha`` is true.

          The lower-cased text is cached for 300 seconds under the
          ``{type}_{username}`` captcha key.

          :return: ``{'captcha': 'data:image/png;base64,...'}``, or
              ``{'captcha': ''}`` when no captcha is needed.
          """
          if need_captcha:
              chars = get_random_chars()
              image = ImageCaptcha()
              data = image.generate(chars)
              captcha = base64.b64encode(data.getbuffer())
              cache.set(Cache_Version.CAPTCHA.get_key(captcha=f'{type}_{username}'), chars.lower(),
                        timeout=300, version=Cache_Version.CAPTCHA.get_version())
              return {'captcha': 'data:image/png;base64,' + captcha.decode()}
          return {'captcha': ''}