"""
Local model service layer.

Provides the CRUD business logic for local models.

Requirements: 2.8, 4.1, 4.5, 4.6, 2.5, 2.6, 2.7, 14.1, 14.3, 14.4
"""
import ipaddress
import re
import time
from typing import List, Optional, Tuple
from urllib.parse import urlparse

import httpx
from fastapi import HTTPException
from sqlalchemy import or_, and_
from sqlalchemy.orm import Session

from app.models.model import ModelNew
from app.schemas.local_model import LocalModelResponse
from app.services.crypto_utils import encrypt_api_key


class LocalModelService:
    """Business service for local (self-hosted, OpenAI-compatible) models."""

    # Sensitive internal host names (used for SSRF protection).
    PRIVATE_HOSTNAMES = [
        'localhost',
        'localhost.localdomain',
    ]

    # Domain suffixes that conventionally denote internal networks.
    _INTERNAL_DOMAIN_RE = re.compile(
        r'\.(?:local|internal|intranet|corp|lan)$'
    )

    def __init__(self, db: Session):
        self.db = db

    @staticmethod
    def validate_base_url(base_url: str) -> Tuple[bool, str]:
        """
        Validate that a Base URL is well-formed and safe to call.

        Checks the scheme and host, and rejects hosts that point at
        internal resources (SSRF protection). Only the literal host in the
        URL is inspected.

        NOTE(review): no DNS resolution is performed, so a public-looking
        domain that resolves to an internal address, or a non-canonical
        IPv4 spelling such as ``2130706433``, is not detected here.

        Args:
            base_url: The URL to validate.

        Returns:
            ``(is_valid, error_message)``; the message is empty when valid.

        Requirements: 14.1, 14.3
        """
        if not base_url:
            return False, "Base URL不能为空"

        # Surrounding whitespace is not significant.
        base_url = base_url.strip()

        try:
            parsed = urlparse(base_url)
        except Exception:
            return False, "URL格式无效"

        # Only http and https are allowed.
        if parsed.scheme not in ('http', 'https'):
            return False, "URL协议必须是http或https"

        hostname = parsed.hostname
        if not hostname:
            return False, "URL缺少主机名"

        # A trailing dot is a valid fully-qualified spelling of the same
        # host ("localhost." == "localhost"); normalise it away so it
        # cannot be used to slip past the checks below.
        host = hostname.lower().rstrip('.')
        if not host:
            return False, "URL缺少主机名"

        if host in LocalModelService.PRIVATE_HOSTNAMES:
            return False, f"不允许使用内网地址:{hostname}"

        try:
            ip = ipaddress.ip_address(host)
        except ValueError:
            ip = None

        if ip is None:
            # Not an IP literal: treat as a domain name and reject the
            # well-known internal suffixes.
            if LocalModelService._INTERNAL_DOMAIN_RE.search(host):
                return False, f"不允许使用内网域名:{hostname}"
            return True, ""

        # The same checks apply to IPv4 and IPv6. The order matters because
        # the first matching category decides the error message.
        if ip.is_loopback:  # 127.x.x.x / ::1
            return False, f"不允许使用回环地址:{hostname}"
        if ip.is_private:  # 10.x, 172.16-31.x, 192.168.x, fc00::/7, ...
            return False, f"不允许使用内网地址:{hostname}"
        if ip.is_link_local:  # 169.254.x.x / fe80::
            return False, f"不允许使用链路本地地址:{hostname}"
        if ip.is_reserved:
            return False, f"不允许使用保留地址:{hostname}"
        if ip.is_unspecified:  # 0.0.0.0 / ::
            return False, f"不允许使用未指定地址:{hostname}"

        return True, ""

    def create_local_model(
        self,
        user_id: Optional[str],
        name: str,
        supplier: str,
        base_url: str,
        api_key: Optional[str] = None,
        visibility: str = 'global',
        categories: Optional[list] = None
    ) -> ModelNew:
        """
        Create a local model (administrators only).

        Args:
            user_id: Accepted for interface compatibility; it is not used,
                and the created row always has ``user_id=None``.
            name: Display name of the model.
            supplier: Supplier name.
            base_url: Base URL of the model API; validated with
                ``validate_base_url`` and stored without surrounding
                whitespace.
            api_key: Optional API key; stored encrypted when non-empty.
            visibility: Visibility value stored on the model.
            categories: Category list; defaults to ``[0]`` when empty.

        Returns:
            The persisted and refreshed ``ModelNew`` instance.

        Raises:
            HTTPException: 400 when ``base_url`` fails validation.
        """
        is_valid, error_msg = self.validate_base_url(base_url)
        if not is_valid:
            raise HTTPException(status_code=400, detail=error_msg)

        # The validator checks the stripped URL, so persist that same value
        # rather than the raw input.
        base_url = base_url.strip()

        encrypted_api_key = encrypt_api_key(api_key) if api_key else None

        # Millisecond timestamp used as the model code suffix.
        timestamp = int(time.time() * 1000)
        model_code = f"local_admin_{timestamp}"

        local_model = ModelNew(
            model_code=model_code,
            display_name=name,
            supplier=supplier,
            img="",
            categories=categories or [0],
            is_local=True,
            user_id=None,
            base_url=base_url,
            local_api_key=encrypted_api_key,
            is_show_enabled=False,
            is_api_enabled=True,
            visibility=visibility,
        )

        self.db.add(local_model)
        self.db.commit()
        self.db.refresh(local_model)
        return local_model

    async def get_user_local_models(self, user_id: str) -> List[ModelNew]:
        """
        Return the local models a user can access.

        A model is accessible when the user owns it, or when it is globally
        visible and the user holds an access permission for it.

        Args:
            user_id: ID of the user.

        Returns:
            Accessible local models, newest first.
        """
        # Imported locally, as in the original code (presumably to avoid
        # import cycles - TODO confirm).
        from app.models.user_local_model_permission import UserLocalModelPermission
        from app.services.cache_service import CacheService

        # Try the cache first.
        permitted_model_ids = await CacheService.get_user_local_models(user_id)

        # NOTE(review): an empty cached list is treated like a cache miss,
        # so users without permissions hit the database on every call.
        # Confirm what CacheService returns on a miss before changing this.
        if not permitted_model_ids:
            user_permissions = self.db.query(UserLocalModelPermission).filter(
                UserLocalModelPermission.user_id == user_id,
                UserLocalModelPermission.has_access == True  # noqa: E712
            ).all()

            permitted_model_ids = [perm.model_id for perm in user_permissions]

            await CacheService.set_user_local_models(user_id, permitted_model_ids)

        # Accessible models:
        # 1. models created by the user;
        # 2. globally visible models the user has permission for.
        return self.db.query(ModelNew).filter(
            ModelNew.is_local == True,  # noqa: E712
            or_(
                ModelNew.user_id == user_id,
                and_(
                    ModelNew.visibility == "global",
                    ModelNew.id.in_(permitted_model_ids)
                )
            )
        ).order_by(ModelNew.created_at.desc()).all()

    def update_local_model(
        self,
        model_id: int,
        user_id: Optional[str],
        **kwargs
    ) -> ModelNew:
        """
        Update a local model's configuration (administrators only).

        Recognised keyword arguments: ``name``, ``supplier``, ``base_url``,
        ``api_key``, ``visibility`` and ``category``. Fields that are absent
        or ``None`` are left unchanged, except ``api_key`` (see below).

        Args:
            model_id: ID of the model to update.
            user_id: Accepted for interface compatibility; not used.
            **kwargs: Fields to update.

        Returns:
            The updated and refreshed ``ModelNew`` instance.

        Raises:
            HTTPException: 404 when the model does not exist, 400 when a
                new ``base_url`` fails validation.
        """
        model = self.db.query(ModelNew).filter(
            ModelNew.id == model_id,
            ModelNew.is_local == True  # noqa: E712
        ).first()

        if not model:
            raise HTTPException(status_code=404, detail="本地模型不存在")

        # A new base_url must pass the SSRF checks (requirements 14.1, 14.3)
        # before anything is modified.
        new_base_url = kwargs.get("base_url")
        if new_base_url is not None:
            is_valid, error_msg = self.validate_base_url(new_base_url)
            if not is_valid:
                raise HTTPException(status_code=400, detail=error_msg)
            # Store the same stripped value that was validated.
            new_base_url = new_base_url.strip()

        if kwargs.get("name") is not None:
            model.display_name = kwargs["name"]

        if kwargs.get("supplier") is not None:
            model.supplier = kwargs["supplier"]

        if new_base_url is not None:
            model.base_url = new_base_url

        if "api_key" in kwargs:
            # An empty value clears the stored key.
            # NOTE(review): an explicit api_key=None also clears it; confirm
            # callers omit the argument when the key should be kept.
            if kwargs["api_key"]:
                model.local_api_key = encrypt_api_key(kwargs["api_key"])
            else:
                model.local_api_key = None

        if kwargs.get("visibility") is not None:
            model.visibility = kwargs["visibility"]

        if kwargs.get("category") is not None:
            cats = kwargs["category"]
            # A single int is wrapped into a one-element list.
            model.categories = [cats] if isinstance(cats, int) else cats

        self.db.commit()
        self.db.refresh(model)
        return model

    def delete_local_model(self, model_id: int, user_id: Optional[str]) -> bool:
        """
        Delete a local model (administrators only).

        The user permission records that reference the model are deleted
        in the same transaction.

        Args:
            model_id: ID of the model.
            user_id: Accepted for interface compatibility; not used for
                filtering because local models are managed by
                administrators.

        Returns:
            ``True`` when the deletion was committed.

        Raises:
            HTTPException: 404 when the model does not exist.
        """
        model = self.db.query(ModelNew).filter(
            ModelNew.id == model_id,
            ModelNew.is_local == True  # noqa: E712
        ).first()

        if not model:
            raise HTTPException(status_code=404, detail="本地模型不存在")

        # Remove the related user permission records first.
        from app.models.user_local_model_permission import UserLocalModelPermission
        self.db.query(UserLocalModelPermission).filter(
            UserLocalModelPermission.model_id == model_id
        ).delete()

        self.db.delete(model)
        self.db.commit()
        return True

    @staticmethod
    def _build_test_request(
        category: Optional[int],
        model_name: Optional[str]
    ) -> Tuple[str, dict]:
        """Return ``(endpoint, payload)`` for the probe matching *category*."""
        if category == 5:  # video generation
            return '/videos/generations', {
                "model": model_name or "video-model",
                "prompt": "a beautiful sunset timelapse",
                "size": "1280x720",  # OpenAI standard format
                "duration": 5
            }
        if category == 4:  # text-to-image
            return '/images/generations', {
                "model": model_name or "image-model",
                "prompt": "a beautiful landscape",
                "n": 1,
                "size": "1024x1024"
            }
        if category == 6:  # image editing (image-to-image)
            return '/images/edits', {
                "model": model_name or "dall-e-2",
                "prompt": "add a rainbow in the sky",
                "image": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",  # 1x1 transparent PNG
                "size": "1024x1024"
            }
        if category == 2:  # speech synthesis (TTS)
            return '/audio/speech', {
                "model": model_name or "tts-1",
                "input": "Hello, this is a test.",
                "voice": "alloy",
                "response_format": "mp3"
            }
        if category == 3:  # speech recognition (STT)
            return '/audio/transcriptions', {
                "model": model_name or "whisper-1",
                "file": "data:audio/mp3;base64,SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4Ljc2LjEwMAAAAAAAAAAAAAAA//tQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWGluZwAAAA8AAAACAAADhAC7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7//////////////////////////////////////////////////////////////////8AAAAATGF2YzU4LjEzAAAAAAAAAAAAAAAAJAAAAAAAAAAAA4T8DeHGAAAAAAD/+xDEAAPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/+xDEHgPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/+xDEKAPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/+xDEPgPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
                "language": "en"
            }
        if category == 7:  # vector embedding
            return '/embeddings', {
                "model": model_name or "text-embedding-ada-002",
                "input": "Hello, this is a test.",
                "encoding_format": "float"
            }
        if category == 8:  # rerank
            return '/rerank', {
                "model": model_name or "bge-reranker-v2-m3",
                "query": "What is machine learning?",
                "documents": [
                    "Machine learning is a branch of artificial intelligence.",
                    "Deep learning uses neural networks.",
                    "Python is a programming language."
                ],
                "top_n": 2
            }
        # Default: chat endpoint (LLM, multimodal).
        return '/chat/completions', {
            "model": model_name or "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"}
            ],
            "max_tokens": 5,
            "stream": False
        }

    @staticmethod
    def _parse_error_body(response) -> Tuple[dict, str]:
        """Return ``(json_body, error_message)``; empty values when the body is not a JSON object."""
        try:
            body = response.json()
        except Exception:
            # Best effort: the body is only used as a hint, so any decoding
            # problem is treated as "no JSON body".
            return {}, ""

        if not isinstance(body, dict):
            return {}, ""

        error_info = body.get("error", {})
        message = (
            error_info.get("message", "")
            if isinstance(error_info, dict)
            else error_info
        )
        if message is None:
            return body, ""
        # Always hand back a string so callers can safely call .lower().
        return body, message if isinstance(message, str) else str(message)

    @staticmethod
    def _interpret_response(
        status_code: int,
        response_data: dict,
        error_msg: str,
        endpoint: str
    ) -> dict:
        """Map the probe's HTTP status and body to a ``{"success", "message"}`` result."""
        if status_code == 200:
            return {"success": True, "message": "连接成功,接口支持OpenAI格式"}

        if status_code == 401:
            # With a JSON body, a 401 is treated as "endpoint exists, only
            # the API key is wrong".
            if response_data:
                return {"success": True, "message": "接口支持OpenAI格式(API Key认证失败,请确保填写了正确的API Key)"}
            return {"success": False, "message": "认证失败:API Key无效或缺失"}

        if status_code == 403:
            # Likewise for 403: the endpoint exists, permissions are missing.
            if response_data:
                return {"success": True, "message": "接口支持OpenAI格式(请检查API Key权限)"}
            return {"success": False, "message": "访问被拒绝:权限不足"}

        if status_code == 404:
            # 404 may mean either an unknown model or a missing endpoint.
            error_lower = error_msg.lower()
            if (
                "model" in error_lower
                or "not found" in error_lower
                or "does not exist" in error_lower
            ):
                return {"success": True, "message": "接口支持OpenAI格式(测试模型名称不存在,实际使用时请填写正确的模型名称)"}
            # Any JSON response is taken as a sign the API is reachable.
            if response_data:
                return {"success": True, "message": "接口支持OpenAI格式(请确认模型名称)"}
            return {"success": False, "message": f"接口不存在(404 Not Found):{endpoint} 端点不可用,请检查 Base URL 和模型分类是否正确"}

        if status_code == 400:
            # Bad request: parameters were rejected, but the endpoint answered.
            if response_data:
                return {"success": True, "message": "接口支持OpenAI格式(请求参数可能需要调整,实际使用时请填写正确的模型名称)"}
            return {"success": True, "message": "接口支持OpenAI格式(请求参数可能需要调整)"}

        if status_code >= 500:
            return {"success": False, "message": f"服务器错误:{status_code},请检查服务是否正常运行"}

        # Any other status: a JSON body counts as "probably compatible".
        if response_data:
            return {"success": True, "message": f"接口可能支持OpenAI格式(状态码:{status_code})"}
        return {"success": False, "message": f"未知响应:状态码 {status_code}"}

    async def test_connection(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        category: Optional[int] = None
    ) -> dict:
        """
        Test the connection to a local model.

        Sends one small POST request to the OpenAI-compatible endpoint that
        matches the model category:
        - category=0 (chat): /chat/completions
        - category=2 (TTS): /audio/speech
        - category=3 (STT): /audio/transcriptions
        - category=4 (text-to-image): /images/generations
        - category=5 (video generation): /videos/generations
        - category=6 (image editing): /images/edits
        - category=7 (embeddings): /embeddings
        - category=8 (rerank): /rerank
        - anything else: /chat/completions

        Local models must expose an OpenAI-compatible interface; other
        formats are not supported.

        The URL is checked with ``validate_base_url`` first, because this
        method makes the server issue a request to a caller-supplied
        address; a rejected URL yields a failure result and no request is
        sent.

        Args:
            base_url: Base address of the model API
                (e.g. https://api.example.com/v1).
            api_key: Optional access key, sent as a Bearer token.
            model_name: Optional model name used in the probe payload.
            category: Optional model category selecting the endpoint.

        Returns:
            ``{"success": bool, "message": str}``

        Requirements: 2.5, 2.6, 2.7, 14.4
        """
        is_valid, error_msg = LocalModelService.validate_base_url(base_url)
        if not is_valid:
            return {"success": False, "message": error_msg}

        endpoint, test_payload = LocalModelService._build_test_request(
            category, model_name
        )

        # Join the base URL and the endpoint path.
        url = base_url.strip().rstrip('/') + endpoint

        headers = {
            "Content-Type": "application/json"
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.post(
                    url,
                    json=test_payload,
                    headers=headers
                )

            response_data, error_msg = LocalModelService._parse_error_body(response)
            return LocalModelService._interpret_response(
                response.status_code, response_data, error_msg, endpoint
            )

        except httpx.TimeoutException:
            return {"success": False, "message": "连接超时:请检查服务是否正常运行(超时时间:10秒)"}
        except httpx.ConnectError:
            return {"success": False, "message": "连接失败:无法连接到指定地址,请检查URL是否正确"}
        except httpx.RequestError as e:
            return {"success": False, "message": f"请求错误:{str(e)}"}
        except Exception as e:
            # Boundary handler: this method reports failures as a result
            # dict instead of raising.
            return {"success": False, "message": f"未知错误:{str(e)}"}

    def get_local_model_by_id(self, model_id: int, user_id: str) -> ModelNew:
        """
        Fetch a local model by ID for a given owner.

        NOTE(review): the query requires ``ModelNew.user_id == user_id``,
        while ``create_local_model`` always stores ``user_id=None``; confirm
        whether admin-created models are meant to be reachable here.

        Args:
            model_id: ID of the model.
            user_id: ID of the user (used as an ownership filter).

        Returns:
            The matching model.

        Raises:
            HTTPException: 404 when no matching model exists for this user.
        """
        model = self.db.query(ModelNew).filter(
            ModelNew.id == model_id,
            ModelNew.is_local == True,  # noqa: E712
            ModelNew.user_id == user_id
        ).first()

        if not model:
            raise HTTPException(status_code=404, detail="本地模型不存在或无权限访问")

        return model

    @staticmethod
    def mask_base_url(base_url: str) -> str:
        """
        Mask a Base URL for display.

        Keeps the scheme and host (with port) and hides the path. Any
        ``user:password@`` part of the authority is dropped so credentials
        embedded in the URL never reach the response.

        Args:
            base_url: The original URL.

        Returns:
            ``"<scheme>://<host>/***"`` for URLs with a scheme and host;
            otherwise the input truncated to 20 characters plus ``"***"``
            (or unchanged when it is 20 characters or shorter). An empty
            input yields ``""``.
        """
        if not base_url:
            return ""

        try:
            parsed = urlparse(base_url)
            if parsed.scheme and parsed.netloc:
                # netloc may be "user:pass@host:port"; keep only the part
                # after the last "@".
                host = parsed.netloc.rpartition("@")[2]
                return f"{parsed.scheme}://{host}/***"
            return base_url[:20] + "***" if len(base_url) > 20 else base_url
        except Exception:
            return base_url[:20] + "***" if len(base_url) > 20 else base_url

    def to_response(self, model: ModelNew) -> LocalModelResponse:
        """
        Convert a model object into its API response object.

        Args:
            model: The model object.

        Returns:
            A ``LocalModelResponse`` whose ``base_url`` has been passed
            through ``mask_base_url``.
        """
        return LocalModelResponse(
            id=model.id,
            name=model.display_name,
            base_url=self.mask_base_url(model.base_url),
            is_local=model.is_local,
            user_id=model.user_id,
            created_at=model.created_at,
            updated_at=model.updated_at
        )

    async def get_all_local_models(self) -> List[ModelNew]:
        """Return every local model, newest first."""
        return self.db.query(ModelNew).filter(
            ModelNew.is_local == True  # noqa: E712
        ).order_by(ModelNew.created_at.desc()).all()