| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538 |
- """
- 本地模型服务层
- 提供本地模型的CRUD业务逻辑处理
- 需求: 2.8, 4.1, 4.5, 4.6, 2.5, 2.6, 2.7, 14.1, 14.3, 14.4
- """
- from typing import List, Optional, Tuple
- from urllib.parse import urlparse
- import ipaddress
- import re
- from sqlalchemy.orm import Session
- from sqlalchemy import or_, and_
- from fastapi import HTTPException
- import httpx
- from app.models.model import ModelNew
- from app.schemas.local_model import LocalModelResponse
- from app.services.crypto_utils import encrypt_api_key
class LocalModelService:
    """Business-logic service for local (self-hosted, OpenAI-compatible) models.

    Provides create / read / update / delete operations on ``ModelNew`` rows
    flagged ``is_local``, Base URL safety validation (SSRF protection), a
    connectivity probe, and conversion to the API response schema.
    """

    # Host names that always designate the local machine (SSRF protection).
    PRIVATE_HOSTNAMES = [
        'localhost',
        'localhost.localdomain',
    ]

    # Domain suffixes conventionally used for internal-only names.
    _PRIVATE_DOMAIN_SUFFIXES = (
        '.localhost',
        '.local',
        '.internal',
        '.intranet',
        '.corp',
        '.lan',
    )

    # Legacy numeric IPv4 spellings ("2130706433", "0x7f.0.0.1", "127.1") that
    # ``ipaddress`` refuses to parse but that OS resolvers commonly accept.
    _LEGACY_NUMERIC_HOST = re.compile(
        r'(?:0x[0-9a-f]+|\d+)(?:\.(?:0x[0-9a-f]+|\d+)){0,3}'
    )

    def __init__(self, db: "Session"):
        """Bind the service to a SQLAlchemy session.

        Args:
            db: Session used for every query and commit issued by this service.
        """
        self.db = db

    @staticmethod
    def validate_base_url(base_url: str) -> Tuple[bool, str]:
        """Check that a Base URL is well formed and not an internal address.

        Rejects non-http(s) schemes, missing hosts, loopback / private /
        link-local / reserved / unspecified / multicast / non-global IP
        literals, non-canonical numeric IPv4 spellings, and host names that
        conventionally designate internal machines.

        The host name is NOT resolved here, so a public-looking domain that
        resolves to an internal address is not detected by this check.

        Args:
            base_url: URL to validate; surrounding whitespace is ignored.

        Returns:
            ``(True, "")`` when the URL is acceptable, otherwise
            ``(False, <user-facing error message>)``.
        """
        # Treat whitespace-only input as empty instead of reporting a
        # misleading "bad protocol" error further down.
        if not base_url or not base_url.strip():
            return False, "Base URL不能为空"

        base_url = base_url.strip()

        try:
            parsed = urlparse(base_url)
            hostname = parsed.hostname
        except ValueError:
            return False, "URL格式无效"

        if parsed.scheme not in ('http', 'https'):
            return False, "URL协议必须是http或https"

        if not hostname:
            return False, "URL缺少主机名"

        # "localhost." and "localhost" name the same host; drop the root dot
        # so every check below sees the canonical spelling.
        host = hostname.lower().rstrip('.')
        if not host:
            return False, "URL缺少主机名"

        if host in LocalModelService.PRIVATE_HOSTNAMES:
            return False, f"不允许使用内网地址:{hostname}"

        try:
            ip = ipaddress.ip_address(host)
        except ValueError:
            ip = None

        if ip is not None:
            # The same predicates apply to IPv4 and IPv6 addresses. The order
            # is kept from the original so existing messages do not change.
            if ip.is_loopback:  # 127.x.x.x / ::1
                return False, f"不允许使用回环地址:{hostname}"
            if ip.is_private:  # 10.x, 172.16-31.x, 192.168.x, fc00::/7, ...
                return False, f"不允许使用内网地址:{hostname}"
            if ip.is_link_local:  # 169.254.x.x / fe80::
                return False, f"不允许使用链路本地地址:{hostname}"
            if ip.is_reserved:
                return False, f"不允许使用保留地址:{hostname}"
            if ip.is_unspecified:  # 0.0.0.0 / ::
                return False, f"不允许使用未指定地址:{hostname}"
            if ip.is_multicast:  # 224.0.0.0/4 / ff00::/8
                return False, f"不允许使用组播地址:{hostname}"
            if not ip.is_global:  # e.g. 100.64.0.0/10 shared address space
                return False, f"不允许使用非公网地址:{hostname}"
            return True, ""

        # Not a canonical IP literal. Reject integer / hex / short-form IPv4
        # spellings, which would otherwise slip through as "domain names".
        if LocalModelService._LEGACY_NUMERIC_HOST.fullmatch(host):
            return False, f"不允许使用内网地址:{hostname}"

        if host.endswith(LocalModelService._PRIVATE_DOMAIN_SUFFIXES):
            return False, f"不允许使用内网域名:{hostname}"

        return True, ""

    def create_local_model(
        self,
        user_id: Optional[str],
        name: str,
        supplier: str,
        base_url: str,
        api_key: Optional[str] = None,
        visibility: str = 'global',
        categories: Optional[list] = None
    ) -> "ModelNew":
        """Create a local model record (admin operation).

        Args:
            user_id: Accepted for interface compatibility; not used. The row
                is always stored with ``user_id=None``.
            name: Display name of the model.
            supplier: Supplier label.
            base_url: API base URL; must pass ``validate_base_url``.
            api_key: Optional access key, stored encrypted.
            visibility: Visibility value stored on the row.
            categories: Category ids; defaults to ``[0]`` when empty or None.

        Returns:
            The persisted, refreshed ``ModelNew`` instance.

        Raises:
            HTTPException: 400 when ``base_url`` fails validation.
        """
        import time

        # Store exactly the value that was validated: validation ignores
        # surrounding whitespace, so the persisted URL must not carry it.
        if isinstance(base_url, str):
            base_url = base_url.strip()

        is_valid, error_msg = self.validate_base_url(base_url)
        if not is_valid:
            raise HTTPException(status_code=400, detail=error_msg)

        encrypted_api_key = encrypt_api_key(api_key) if api_key else None

        # Millisecond timestamp keeps generated model codes distinct.
        timestamp = int(time.time() * 1000)
        model_code = f"local_admin_{timestamp}"

        local_model = ModelNew(
            model_code=model_code,
            display_name=name,
            supplier=supplier,
            img="",
            categories=categories or [0],
            is_local=True,
            user_id=None,
            base_url=base_url,
            local_api_key=encrypted_api_key,
            is_show_enabled=False,
            is_api_enabled=True,
            visibility=visibility,
        )
        self.db.add(local_model)
        self.db.commit()
        self.db.refresh(local_model)
        return local_model

    async def get_user_local_models(self, user_id: str) -> List["ModelNew"]:
        """Return the local models a user may access, newest first.

        A model is included when the user owns it, or when it is globally
        visible and the user holds an access permission for it.

        Args:
            user_id: Id of the user whose models are listed.

        Returns:
            Matching ``ModelNew`` rows ordered by ``created_at`` descending.
        """
        from app.models.user_local_model_permission import UserLocalModelPermission
        from app.services.cache_service import CacheService

        permitted_model_ids = await CacheService.get_user_local_models(user_id)

        # NOTE(review): an empty cached list is indistinguishable from a
        # cache miss here, so users without permissions always hit the
        # database. Confirm what CacheService returns on a miss before
        # switching this to an ``is None`` test.
        if not permitted_model_ids:
            user_permissions = self.db.query(UserLocalModelPermission).filter(
                UserLocalModelPermission.user_id == user_id,
                UserLocalModelPermission.has_access == True
            ).all()

            permitted_model_ids = [perm.model_id for perm in user_permissions]
            await CacheService.set_user_local_models(user_id, permitted_model_ids)

        return self.db.query(ModelNew).filter(
            ModelNew.is_local == True,
            or_(
                ModelNew.user_id == user_id,
                and_(
                    ModelNew.visibility == "global",
                    ModelNew.id.in_(permitted_model_ids)
                )
            )
        ).order_by(ModelNew.created_at.desc()).all()

    def update_local_model(
        self,
        model_id: int,
        user_id: Optional[str],
        **kwargs
    ) -> "ModelNew":
        """Update a local model's configuration (admin operation).

        Recognised keyword arguments: ``name``, ``supplier``, ``base_url``,
        ``api_key``, ``visibility`` and ``category``. A value of ``None``
        leaves the field untouched, except ``api_key`` where any falsy value
        clears the stored key.

        Args:
            model_id: Id of the model to update.
            user_id: Accepted for interface compatibility; not used.
            **kwargs: Fields to update (see above).

        Returns:
            The updated, refreshed ``ModelNew`` instance.

        Raises:
            HTTPException: 404 when no local model has this id; 400 when a
                new ``base_url`` fails validation.
        """
        model = self.db.query(ModelNew).filter(
            ModelNew.id == model_id,
            ModelNew.is_local == True
        ).first()

        if not model:
            raise HTTPException(status_code=404, detail="本地模型不存在")

        # Validate before mutating anything, and keep the normalised value so
        # the stored URL is exactly the one that was validated.
        new_base_url = kwargs.get("base_url")
        if new_base_url is not None:
            if isinstance(new_base_url, str):
                new_base_url = new_base_url.strip()
            is_valid, error_msg = self.validate_base_url(new_base_url)
            if not is_valid:
                raise HTTPException(status_code=400, detail=error_msg)
            model.base_url = new_base_url

        if kwargs.get("name") is not None:
            model.display_name = kwargs["name"]

        if kwargs.get("supplier") is not None:
            model.supplier = kwargs["supplier"]

        if "api_key" in kwargs:
            # An empty string (or None) clears the stored key.
            if kwargs["api_key"]:
                model.local_api_key = encrypt_api_key(kwargs["api_key"])
            else:
                model.local_api_key = None

        if kwargs.get("visibility") is not None:
            model.visibility = kwargs["visibility"]

        # NOTE(review): this reads ``category`` while create_local_model takes
        # ``categories``; a ``categories=`` keyword is ignored here. Confirm
        # against callers.
        if kwargs.get("category") is not None:
            cats = kwargs["category"]
            model.categories = [cats] if isinstance(cats, int) else cats

        self.db.commit()
        self.db.refresh(model)

        return model

    def delete_local_model(self, model_id: int, user_id: Optional[str]) -> bool:
        """Delete a local model and its per-user permission rows (admin operation).

        Args:
            model_id: Id of the model to delete.
            user_id: Accepted for interface compatibility; not used in the
                lookup.

        Returns:
            ``True`` once the deletion has been committed.

        Raises:
            HTTPException: 404 when no local model has this id.
        """
        model = self.db.query(ModelNew).filter(
            ModelNew.id == model_id,
            ModelNew.is_local == True
        ).first()

        if not model:
            raise HTTPException(status_code=404, detail="本地模型不存在")

        # Remove dependent permission rows first, in the same transaction.
        from app.models.user_local_model_permission import UserLocalModelPermission
        self.db.query(UserLocalModelPermission).filter(
            UserLocalModelPermission.model_id == model_id
        ).delete()

        self.db.delete(model)
        self.db.commit()

        return True

    @staticmethod
    def _build_test_request(
        category: Optional[int],
        model_name: Optional[str]
    ) -> Tuple[str, dict]:
        """Return the (endpoint path, JSON payload) probe for a model category."""
        probes = {
            5: ('/videos/generations', "video-model", {  # video generation
                "prompt": "a beautiful sunset timelapse",
                "size": "1280x720",
                "duration": 5
            }),
            4: ('/images/generations', "image-model", {  # text-to-image
                "prompt": "a beautiful landscape",
                "n": 1,
                "size": "1024x1024"
            }),
            6: ('/images/edits', "dall-e-2", {  # image editing
                "prompt": "add a rainbow in the sky",
                # 1x1 transparent PNG
                "image": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
                "size": "1024x1024"
            }),
            2: ('/audio/speech', "tts-1", {  # text-to-speech
                "input": "Hello, this is a test.",
                "voice": "alloy",
                "response_format": "mp3"
            }),
            3: ('/audio/transcriptions', "whisper-1", {  # speech-to-text
                "file": "data:audio/mp3;base64,SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4Ljc2LjEwMAAAAAAAAAAAAAAA//tQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWGluZwAAAA8AAAACAAADhAC7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7//////////////////////////////////////////////////////////////////8AAAAATGF2YzU4LjEzAAAAAAAAAAAAAAAAJAAAAAAAAAAAA4T8DeHGAAAAAAD/+xDEAAPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/+xDEHgPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/+xDEKAPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/+xDEPgPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
                "language": "en"
            }),
            7: ('/embeddings', "text-embedding-ada-002", {  # embeddings
                "input": "Hello, this is a test.",
                "encoding_format": "float"
            }),
            8: ('/rerank', "bge-reranker-v2-m3", {  # rerank
                "query": "What is machine learning?",
                "documents": [
                    "Machine learning is a branch of artificial intelligence.",
                    "Deep learning uses neural networks.",
                    "Python is a programming language."
                ],
                "top_n": 2
            }),
        }
        # Every other category (including None) probes the chat endpoint.
        chat_probe = ('/chat/completions', "gpt-3.5-turbo", {
            "messages": [
                {"role": "user", "content": "hi"}
            ],
            "max_tokens": 5,
            "stream": False
        })
        endpoint, default_model, extra = probes.get(category, chat_probe)
        return endpoint, {"model": model_name or default_model, **extra}

    @staticmethod
    def _parse_error_body(response) -> Tuple[dict, str]:
        """Return (JSON object body or {}, error message text) for a response."""
        try:
            data = response.json()
        except ValueError:
            # Body is not JSON (json.JSONDecodeError / UnicodeDecodeError).
            return {}, ""
        if not isinstance(data, dict):
            return {}, ""
        error_info = data.get("error", {})
        message = error_info.get("message", "") if isinstance(error_info, dict) else error_info
        # Coerce to text so the caller can safely call .lower() on it.
        return data, "" if message is None else str(message)

    @staticmethod
    def _interpret_test_response(
        status_code: int,
        has_json_body: bool,
        error_msg: str,
        endpoint: str
    ) -> dict:
        """Map a probe's HTTP status and body to a {"success", "message"} result."""
        if status_code == 200:
            return {"success": True, "message": "连接成功,接口支持OpenAI格式"}

        if status_code == 401:
            # A structured 401 means the endpoint exists; only the key is wrong.
            if has_json_body:
                return {"success": True, "message": "接口支持OpenAI格式(API Key认证失败,请确保填写了正确的API Key)"}
            return {"success": False, "message": "认证失败:API Key无效或缺失"}

        if status_code == 403:
            # Same reasoning as 401: the endpoint exists, access is denied.
            if has_json_body:
                return {"success": True, "message": "接口支持OpenAI格式(请检查API Key权限)"}
            return {"success": False, "message": "访问被拒绝:权限不足"}

        if status_code == 404:
            # 404 may mean "unknown model" (endpoint fine) or "no such endpoint".
            error_lower = error_msg.lower()
            if "model" in error_lower or "not found" in error_lower or "does not exist" in error_lower:
                return {"success": True, "message": "接口支持OpenAI格式(测试模型名称不存在,实际使用时请填写正确的模型名称)"}
            if has_json_body:
                return {"success": True, "message": "接口支持OpenAI格式(请确认模型名称)"}
            return {"success": False, "message": f"接口不存在(404 Not Found):{endpoint} 端点不可用,请检查 Base URL 和模型分类是否正确"}

        if status_code == 400:
            # The request was understood well enough to be rejected.
            if has_json_body:
                return {"success": True, "message": "接口支持OpenAI格式(请求参数可能需要调整,实际使用时请填写正确的模型名称)"}
            return {"success": True, "message": "接口支持OpenAI格式(请求参数可能需要调整)"}

        if status_code >= 500:
            return {"success": False, "message": f"服务器错误:{status_code},请检查服务是否正常运行"}

        # Any other status: a JSON body is taken as a sign the API is reachable.
        if has_json_body:
            return {"success": True, "message": f"接口可能支持OpenAI格式(状态码:{status_code})"}
        return {"success": False, "message": f"未知响应:状态码 {status_code}"}

    async def test_connection(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        category: Optional[int] = None
    ) -> dict:
        """Probe an OpenAI-compatible endpoint and report whether it responds.

        The endpoint is chosen from the model category:

        - 0 (chat): /chat/completions
        - 2 (TTS): /audio/speech
        - 3 (STT): /audio/transcriptions
        - 4 (text-to-image): /images/generations
        - 5 (video generation): /videos/generations
        - 6 (image editing): /images/edits
        - 7 (embeddings): /embeddings
        - 8 (rerank): /rerank
        - anything else: /chat/completions

        The URL goes through ``validate_base_url`` first, so the probe cannot
        be pointed at addresses that create/update would reject.

        Args:
            base_url: API base URL (e.g. https://api.example.com/v1).
            api_key: Optional key sent as a Bearer token.
            model_name: Optional model name placed in the probe payload.
            category: Optional category id used to select the endpoint.

        Returns:
            ``{"success": bool, "message": str}``; failures are reported in
            the result rather than raised.
        """
        if isinstance(base_url, str):
            base_url = base_url.strip()

        is_valid, error_msg = self.validate_base_url(base_url)
        if not is_valid:
            return {"success": False, "message": error_msg}

        endpoint, test_payload = self._build_test_request(category, model_name)
        url = base_url.rstrip('/') + endpoint

        headers = {
            "Content-Type": "application/json"
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.post(
                    url,
                    json=test_payload,
                    headers=headers
                )

            response_data, error_msg = self._parse_error_body(response)
            return self._interpret_test_response(
                response.status_code,
                bool(response_data),
                error_msg,
                endpoint
            )

        except httpx.TimeoutException:
            return {"success": False, "message": "连接超时:请检查服务是否正常运行(超时时间:10秒)"}
        except httpx.ConnectError:
            return {"success": False, "message": "连接失败:无法连接到指定地址,请检查URL是否正确"}
        except httpx.RequestError as e:
            return {"success": False, "message": f"请求错误:{str(e)}"}
        except Exception as e:
            # Boundary handler: the probe reports problems, it never raises.
            return {"success": False, "message": f"未知错误:{str(e)}"}

    def get_local_model_by_id(self, model_id: int, user_id: str) -> "ModelNew":
        """Fetch one local model owned by the given user.

        Args:
            model_id: Id of the model.
            user_id: Id the row's ``user_id`` must equal.

        Returns:
            The matching ``ModelNew`` instance.

        Raises:
            HTTPException: 404 when no local model matches both ids.
        """
        # NOTE(review): create_local_model stores user_id=None, so models
        # created through it only match when user_id is None. Confirm this
        # ownership filter is intended.
        model = self.db.query(ModelNew).filter(
            ModelNew.id == model_id,
            ModelNew.is_local == True,
            ModelNew.user_id == user_id
        ).first()

        if not model:
            raise HTTPException(status_code=404, detail="本地模型不存在或无权限访问")

        return model

    @staticmethod
    def mask_base_url(base_url: str) -> str:
        """Return a display-safe form of a Base URL.

        Keeps the scheme and host[:port] and replaces the path with ``***``.
        Any ``user:password@`` part is removed so embedded credentials are
        never echoed back.

        Args:
            base_url: The original URL.

        Returns:
            ``"<scheme>://<host[:port]>/***"`` for parseable URLs; otherwise
            the first 20 characters followed by ``***`` (or the input itself
            when it is at most 20 characters). Empty input yields ``""``.
        """
        if not base_url:
            return ""

        try:
            parsed = urlparse(base_url)
        except ValueError:
            parsed = None

        if parsed is not None and parsed.scheme and parsed.netloc:
            # netloc is "[userinfo@]host[:port]"; keep only what follows the
            # last '@' so credentials are dropped.
            host = parsed.netloc.rpartition('@')[2]
            if host:
                return f"{parsed.scheme}://{host}/***"

        return base_url[:20] + "***" if len(base_url) > 20 else base_url

    def to_response(self, model: "ModelNew") -> "LocalModelResponse":
        """Convert a model row into the API response schema.

        Args:
            model: The ``ModelNew`` instance to convert.

        Returns:
            A ``LocalModelResponse`` whose ``base_url`` is masked with
            ``mask_base_url``.
        """
        return LocalModelResponse(
            id=model.id,
            name=model.display_name,
            base_url=self.mask_base_url(model.base_url),
            is_local=model.is_local,
            user_id=model.user_id,
            created_at=model.created_at,
            updated_at=model.updated_at
        )

    async def get_all_local_models(self) -> List["ModelNew"]:
        """Return every local model, newest first."""
        return self.db.query(ModelNew).filter(
            ModelNew.is_local == True
        ).order_by(ModelNew.created_at.desc()).all()
|