user_local_model_permission_service.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. """
  2. 用户本地模型权限服务
  3. 处理用户对本地模型的权限管理逻辑
  4. """
  5. from typing import List, Dict, Optional
  6. from sqlalchemy.orm import Session
  7. from sqlalchemy.exc import IntegrityError
  8. from app.models.user_local_model_permission import UserLocalModelPermission
  9. from app.models.model import ModelNew
  10. class UserLocalModelPermissionService:
  11. """
  12. 用户本地模型权限服务类
  13. """
  14. def __init__(self, db: Session):
  15. self.db = db
  16. def get_user_model_permissions(self, user_id: str) -> List[Dict]:
  17. """
  18. 获取用户的所有本地模型权限
  19. Args:
  20. user_id: 用户ID
  21. Returns:
  22. 权限列表,包含模型信息和权限状态
  23. """
  24. try:
  25. # 获取所有本地模型
  26. local_models = self.db.query(ModelNew).filter(ModelNew.is_local == True).all()
  27. # 获取用户的权限设置
  28. permissions = self.db.query(UserLocalModelPermission).filter(
  29. UserLocalModelPermission.user_id == user_id
  30. ).all()
  31. # 构建权限字典
  32. permission_dict = {p.model_id: p.has_access for p in permissions}
  33. # 构建结果列表
  34. result = []
  35. for model in local_models:
  36. result.append({
  37. "model_id": model.id,
  38. "model_name": model.display_name,
  39. "model_title": model.model_code,
  40. "base_url": model.base_url,
  41. "has_access": permission_dict.get(model.id, False)
  42. })
  43. return result
  44. except Exception as e:
  45. # 如果发生错误,返回空列表
  46. return []
  47. async def update_user_model_permission(self, user_id: str, model_id: int, has_access: bool) -> bool:
  48. """
  49. 更新用户对本地模型的权限
  50. Args:
  51. user_id: 用户ID
  52. model_id: 模型ID
  53. has_access: 是否有权限访问
  54. Returns:
  55. 是否更新成功
  56. """
  57. # 检查模型是否存在且是本地模型
  58. model = self.db.query(ModelNew).filter(
  59. ModelNew.id == model_id,
  60. ModelNew.is_local == True
  61. ).first()
  62. if not model:
  63. return False
  64. # 查找现有权限记录
  65. permission = self.db.query(UserLocalModelPermission).filter(
  66. UserLocalModelPermission.user_id == user_id,
  67. UserLocalModelPermission.model_id == model_id
  68. ).first()
  69. if permission:
  70. # 更新现有记录
  71. permission.has_access = has_access
  72. else:
  73. # 创建新记录
  74. permission = UserLocalModelPermission(
  75. user_id=user_id,
  76. model_id=model_id,
  77. has_access=has_access
  78. )
  79. self.db.add(permission)
  80. try:
  81. self.db.commit()
  82. # 删除相关缓存
  83. from app.services.cache_service import CacheService
  84. await CacheService.delete_user_permission(user_id, model_id)
  85. await CacheService.delete_user_local_models(user_id)
  86. return True
  87. except IntegrityError:
  88. self.db.rollback()
  89. return False
  90. async def update_user_all_model_permissions(self, user_id: str, has_access: bool) -> bool:
  91. """
  92. 更新用户对所有本地模型的权限
  93. Args:
  94. user_id: 用户ID
  95. has_access: 是否有权限访问
  96. Returns:
  97. 是否更新成功
  98. """
  99. # 获取所有本地模型
  100. local_models = self.db.query(ModelNew).filter(ModelNew.is_local == True).all()
  101. try:
  102. for model in local_models:
  103. # 查找现有权限记录
  104. permission = self.db.query(UserLocalModelPermission).filter(
  105. UserLocalModelPermission.user_id == user_id,
  106. UserLocalModelPermission.model_id == model.id
  107. ).first()
  108. if permission:
  109. # 更新现有记录
  110. permission.has_access = has_access
  111. else:
  112. # 创建新记录
  113. permission = UserLocalModelPermission(
  114. user_id=user_id,
  115. model_id=model.id,
  116. has_access=has_access
  117. )
  118. self.db.add(permission)
  119. self.db.commit()
  120. # 删除相关缓存
  121. from app.services.cache_service import CacheService
  122. await CacheService.delete_user_local_models(user_id)
  123. # 也可以删除每个模型的权限缓存,但为了性能考虑,这里只删除用户本地模型列表缓存
  124. return True
  125. except Exception:
  126. self.db.rollback()
  127. return False
  128. async def check_user_model_access(self, user_id: str, model_id: int) -> bool:
  129. """
  130. 检查用户是否有权限访问指定本地模型
  131. Args:
  132. user_id: 用户ID
  133. model_id: 模型ID
  134. Returns:
  135. 是否有权限访问
  136. """
  137. # 检查模型是否存在且是本地模型
  138. model = self.db.query(ModelNew).filter(
  139. ModelNew.id == model_id,
  140. ModelNew.is_local == True
  141. ).first()
  142. if not model:
  143. return False
  144. # 检查本地模型是否启用
  145. from app.services.system_config_manager import get_config_bool
  146. if get_config_bool("enable_local_models", True):
  147. # 如果本地模型启用,所有用户都有权限访问所有本地模型
  148. return True
  149. # 从缓存获取权限
  150. from app.services.cache_service import CacheService
  151. has_access = await CacheService.get_user_permission(user_id, model_id)
  152. if has_access is not None:
  153. return has_access
  154. # 从数据库获取
  155. permission = self.db.query(UserLocalModelPermission).filter(
  156. UserLocalModelPermission.user_id == user_id,
  157. UserLocalModelPermission.model_id == model_id
  158. ).first()
  159. has_access = permission.has_access if permission else False
  160. # 缓存权限信息
  161. await CacheService.set_user_permission(user_id, model_id, has_access)
  162. return has_access