# local_model_service.py
"""
Local model service layer.

Provides the CRUD business logic for local models.
Requirements: 2.8, 4.1, 4.5, 4.6, 2.5, 2.6, 2.7, 14.1, 14.3, 14.4
"""
import ipaddress
import re
import time
from typing import List, Optional, Tuple
from urllib.parse import urlparse

import httpx
from fastapi import HTTPException
from sqlalchemy import or_, and_
from sqlalchemy.orm import Session

from app.models.model import ModelNew
from app.schemas.local_model import LocalModelResponse
from app.services.crypto_utils import encrypt_api_key


  17. class LocalModelService:
  18. """本地模型业务服务类"""
  19. # 内网敏感主机名(用于SSRF防护)
  20. PRIVATE_HOSTNAMES = [
  21. 'localhost',
  22. 'localhost.localdomain',
  23. ]
  24. def __init__(self, db: Session):
  25. self.db = db
  26. @staticmethod
  27. def validate_base_url(base_url: str) -> Tuple[bool, str]:
  28. """
  29. 验证Base URL的安全性
  30. 验证URL格式合法性,防止SSRF攻击(拒绝内网敏感地址)
  31. Args:
  32. base_url: 要验证的URL
  33. Returns:
  34. (是否有效, 错误消息)
  35. 需求: 14.1, 14.3
  36. """
  37. if not base_url:
  38. return False, "Base URL不能为空"
  39. # 去除首尾空格
  40. base_url = base_url.strip()
  41. # 解析URL
  42. try:
  43. parsed = urlparse(base_url)
  44. except Exception:
  45. return False, "URL格式无效"
  46. # 验证协议(必须是http或https)
  47. if parsed.scheme not in ('http', 'https'):
  48. return False, "URL协议必须是http或https"
  49. # 验证主机名存在
  50. hostname = parsed.hostname
  51. if not hostname:
  52. return False, "URL缺少主机名"
  53. hostname_lower = hostname.lower()
  54. # 检查是否为敏感主机名
  55. if hostname_lower in LocalModelService.PRIVATE_HOSTNAMES:
  56. return False, f"不允许使用内网地址:{hostname}"
  57. # 检查是否为IP地址
  58. try:
  59. ip = ipaddress.ip_address(hostname)
  60. # 检查IPv4私有地址
  61. if isinstance(ip, ipaddress.IPv4Address):
  62. if ip.is_loopback: # 127.x.x.x
  63. return False, f"不允许使用回环地址:{hostname}"
  64. if ip.is_private: # 10.x, 172.16-31.x, 192.168.x
  65. return False, f"不允许使用内网地址:{hostname}"
  66. if ip.is_link_local: # 169.254.x.x
  67. return False, f"不允许使用链路本地地址:{hostname}"
  68. if ip.is_reserved:
  69. return False, f"不允许使用保留地址:{hostname}"
  70. if ip.is_unspecified: # 0.0.0.0
  71. return False, f"不允许使用未指定地址:{hostname}"
  72. # 检查IPv6私有地址
  73. elif isinstance(ip, ipaddress.IPv6Address):
  74. if ip.is_loopback: # ::1
  75. return False, f"不允许使用回环地址:{hostname}"
  76. if ip.is_private:
  77. return False, f"不允许使用内网地址:{hostname}"
  78. if ip.is_link_local: # fe80::
  79. return False, f"不允许使用链路本地地址:{hostname}"
  80. if ip.is_reserved:
  81. return False, f"不允许使用保留地址:{hostname}"
  82. if ip.is_unspecified: # ::
  83. return False, f"不允许使用未指定地址:{hostname}"
  84. except ValueError:
  85. # 不是IP地址,是域名,进行额外检查
  86. # 检查是否包含可疑的内网域名模式
  87. suspicious_patterns = [
  88. r'\.local$',
  89. r'\.internal$',
  90. r'\.intranet$',
  91. r'\.corp$',
  92. r'\.lan$',
  93. ]
  94. for pattern in suspicious_patterns:
  95. if re.search(pattern, hostname_lower):
  96. return False, f"不允许使用内网域名:{hostname}"
  97. return True, ""
  98. def create_local_model(
  99. self,
  100. user_id: Optional[str],
  101. name: str,
  102. supplier: str,
  103. base_url: str,
  104. api_key: Optional[str] = None,
  105. visibility: str = 'global',
  106. categories: list = None
  107. ) -> ModelNew:
  108. """
  109. 创建本地模型(仅管理员)
  110. """
  111. is_valid, error_msg = self.validate_base_url(base_url)
  112. if not is_valid:
  113. raise HTTPException(status_code=400, detail=error_msg)
  114. encrypted_api_key = encrypt_api_key(api_key) if api_key else None
  115. import time
  116. timestamp = int(time.time() * 1000)
  117. model_code = f"local_admin_{timestamp}"
  118. local_model = ModelNew(
  119. model_code=model_code,
  120. display_name=name,
  121. supplier=supplier,
  122. img="",
  123. categories=categories or [0],
  124. is_local=True,
  125. user_id=None,
  126. base_url=base_url,
  127. local_api_key=encrypted_api_key,
  128. is_show_enabled=False,
  129. is_api_enabled=True,
  130. visibility=visibility,
  131. )
  132. self.db.add(local_model)
  133. self.db.commit()
  134. self.db.refresh(local_model)
  135. return local_model
  136. async def get_user_local_models(self, user_id: str) -> List[ModelNew]:
  137. """
  138. 获取用户的本地模型列表
  139. Args:
  140. user_id: 用户ID
  141. Returns:
  142. 用户的本地模型列表
  143. """
  144. from app.models.user_local_model_permission import UserLocalModelPermission
  145. from app.services.cache_service import CacheService
  146. # 从缓存获取用户本地模型列表
  147. permitted_model_ids = await CacheService.get_user_local_models(user_id)
  148. if not permitted_model_ids:
  149. # 从数据库获取用户有权限的模型ID列表
  150. user_permissions = self.db.query(UserLocalModelPermission).filter(
  151. UserLocalModelPermission.user_id == user_id,
  152. UserLocalModelPermission.has_access == True
  153. ).all()
  154. # 提取有权限的模型ID
  155. permitted_model_ids = [perm.model_id for perm in user_permissions]
  156. # 缓存用户本地模型列表
  157. await CacheService.set_user_local_models(user_id, permitted_model_ids)
  158. # 获取用户可以访问的模型:
  159. # 1. 用户自己创建的模型
  160. # 2. 全局可见且用户有权限的模型
  161. return self.db.query(ModelNew).filter(
  162. ModelNew.is_local == True,
  163. or_(
  164. ModelNew.user_id == user_id,
  165. and_(
  166. ModelNew.visibility == "global",
  167. ModelNew.id.in_(permitted_model_ids)
  168. )
  169. )
  170. ).order_by(ModelNew.created_at.desc()).all()
  171. def update_local_model(
  172. self,
  173. model_id: int,
  174. user_id: Optional[str],
  175. **kwargs
  176. ) -> ModelNew:
  177. """更新本地模型配置(仅管理员)"""
  178. model = self.db.query(ModelNew).filter(
  179. ModelNew.id == model_id,
  180. ModelNew.is_local == True
  181. ).first()
  182. if not model:
  183. raise HTTPException(status_code=404, detail="本地模型不存在")
  184. # 如果更新base_url,需要验证URL安全性(需求: 14.1, 14.3)
  185. if "base_url" in kwargs and kwargs["base_url"] is not None:
  186. is_valid, error_msg = self.validate_base_url(kwargs["base_url"])
  187. if not is_valid:
  188. raise HTTPException(status_code=400, detail=error_msg)
  189. # 更新允许的字段
  190. if "name" in kwargs and kwargs["name"] is not None:
  191. model.display_name = kwargs["name"]
  192. if "supplier" in kwargs and kwargs["supplier"] is not None:
  193. model.supplier = kwargs["supplier"]
  194. if "base_url" in kwargs and kwargs["base_url"] is not None:
  195. model.base_url = kwargs["base_url"]
  196. if "api_key" in kwargs:
  197. # api_key 可以设置为空字符串来清除
  198. if kwargs["api_key"]:
  199. model.local_api_key = encrypt_api_key(kwargs["api_key"])
  200. else:
  201. model.local_api_key = None
  202. if "visibility" in kwargs and kwargs["visibility"] is not None:
  203. model.visibility = kwargs["visibility"]
  204. if "category" in kwargs and kwargs["category"] is not None:
  205. cats = kwargs["category"]
  206. model.categories = [cats] if isinstance(cats, int) else cats
  207. self.db.commit()
  208. self.db.refresh(model)
  209. return model
  210. def delete_local_model(self, model_id: int, user_id: Optional[str]) -> bool:
  211. """
  212. 删除本地模型(仅管理员)
  213. Args:
  214. model_id: 模型ID
  215. user_id: 用户ID(本地模型不关联用户,传None)
  216. Returns:
  217. 删除是否成功
  218. Raises:
  219. HTTPException: 模型不存在时抛出错误
  220. """
  221. # 查询本地模型(不验证user_id,因为本地模型由管理员管理)
  222. model = self.db.query(ModelNew).filter(
  223. ModelNew.id == model_id,
  224. ModelNew.is_local == True
  225. ).first()
  226. if not model:
  227. raise HTTPException(status_code=404, detail="本地模型不存在")
  228. # 删除相关的用户权限记录
  229. from app.models.user_local_model_permission import UserLocalModelPermission
  230. self.db.query(UserLocalModelPermission).filter(
  231. UserLocalModelPermission.model_id == model_id
  232. ).delete()
  233. self.db.delete(model)
  234. self.db.commit()
  235. return True
  236. async def test_connection(
  237. self,
  238. base_url: str,
  239. api_key: Optional[str] = None,
  240. model_name: Optional[str] = None,
  241. category: Optional[int] = None
  242. ) -> dict:
  243. """
  244. 测试本地模型连接
  245. 根据模型分类测试不同的 OpenAI 兼容端点:
  246. - category=0 (对话): /chat/completions
  247. - category=2 (TTS): /audio/speech
  248. - category=3 (STT): /audio/transcriptions
  249. - category=4 (文生图): /images/generations
  250. - category=5 (视频生成): /videos/generations
  251. - category=6 (图像编辑): /images/edits
  252. - category=7 (向量嵌入): /embeddings
  253. - category=8 (重排序): /rerank
  254. - 其他: 默认测试 /chat/completions
  255. 注意:本地模型必须是 OpenAI 兼容的接口,不支持其他格式。
  256. Args:
  257. base_url: 本地模型API基础地址(如 https://api.example.com/v1)
  258. api_key: 本地模型访问密钥(可选)
  259. model_name: 模型名称(可选,用于测试)
  260. category: 模型分类(可选,用于选择测试端点)
  261. Returns:
  262. {"success": bool, "message": str}
  263. 需求: 2.5, 2.6, 2.7, 14.4
  264. """
  265. # 根据分类选择测试端点和请求体
  266. if category == 5: # 视频生成
  267. endpoint = '/videos/generations'
  268. test_payload = {
  269. "model": model_name or "video-model",
  270. "prompt": "a beautiful sunset timelapse",
  271. "size": "1280x720", # OpenAI标准格式
  272. "duration": 5
  273. }
  274. elif category == 4: # 文生图
  275. endpoint = '/images/generations'
  276. test_payload = {
  277. "model": model_name or "image-model",
  278. "prompt": "a beautiful landscape",
  279. "n": 1,
  280. "size": "1024x1024"
  281. }
  282. elif category == 6: # 图像编辑(图生图)
  283. endpoint = '/images/edits'
  284. test_payload = {
  285. "model": model_name or "dall-e-2",
  286. "prompt": "add a rainbow in the sky",
  287. "image": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==", # 1x1 透明PNG
  288. "size": "1024x1024"
  289. }
  290. elif category == 2: # 语音合成 (TTS)
  291. endpoint = '/audio/speech'
  292. test_payload = {
  293. "model": model_name or "tts-1",
  294. "input": "Hello, this is a test.",
  295. "voice": "alloy",
  296. "response_format": "mp3"
  297. }
  298. elif category == 3: # 语音识别 (STT)
  299. endpoint = '/audio/transcriptions'
  300. test_payload = {
  301. "model": model_name or "whisper-1",
  302. "file": "data:audio/mp3;base64,SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4Ljc2LjEwMAAAAAAAAAAAAAAA//tQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWGluZwAAAA8AAAACAAADhAC7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7//////////////////////////////////////////////////////////////////8AAAAATGF2YzU4LjEzAAAAAAAAAAAAAAAAJAAAAAAAAAAAA4T8DeHGAAAAAAD/+xDEAAPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/+xDEHgPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/+xDEKAPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/+xDEPgPAAAGkAAAAIAAANIAAAARMQU1FMy4xMDAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
  303. "language": "en"
  304. }
  305. elif category == 7: # 向量嵌入 (Embedding)
  306. endpoint = '/embeddings'
  307. test_payload = {
  308. "model": model_name or "text-embedding-ada-002",
  309. "input": "Hello, this is a test.",
  310. "encoding_format": "float"
  311. }
  312. elif category == 8: # 重排序 (Rerank)
  313. endpoint = '/rerank'
  314. test_payload = {
  315. "model": model_name or "bge-reranker-v2-m3",
  316. "query": "What is machine learning?",
  317. "documents": [
  318. "Machine learning is a branch of artificial intelligence.",
  319. "Deep learning uses neural networks.",
  320. "Python is a programming language."
  321. ],
  322. "top_n": 2
  323. }
  324. else: # 默认测试对话接口 (LLM, MULTIMODAL)
  325. endpoint = '/chat/completions'
  326. test_payload = {
  327. "model": model_name or "gpt-3.5-turbo",
  328. "messages": [
  329. {"role": "user", "content": "hi"}
  330. ],
  331. "max_tokens": 5,
  332. "stream": False
  333. }
  334. # 构建请求URL(拼接端点路径)
  335. url = base_url.rstrip('/') + endpoint
  336. # 构建请求头
  337. headers = {
  338. "Content-Type": "application/json"
  339. }
  340. if api_key:
  341. headers["Authorization"] = f"Bearer {api_key}"
  342. try:
  343. async with httpx.AsyncClient(timeout=10.0) as client:
  344. response = await client.post(
  345. url,
  346. json=test_payload,
  347. headers=headers
  348. )
  349. # 尝试解析响应体
  350. try:
  351. response_data = response.json()
  352. error_info = response_data.get("error", {})
  353. error_msg = error_info.get("message", "") if isinstance(error_info, dict) else str(error_info)
  354. error_type = error_info.get("type", "") if isinstance(error_info, dict) else ""
  355. except:
  356. response_data = {}
  357. error_msg = ""
  358. error_type = ""
  359. # 检查响应状态码
  360. if response.status_code == 200:
  361. return {"success": True, "message": "连接成功,接口支持OpenAI格式"}
  362. elif response.status_code == 401:
  363. # 401说明接口存在且能识别OpenAI格式,只是API Key的问题
  364. if response_data:
  365. return {"success": True, "message": "接口支持OpenAI格式(API Key认证失败,请确保填写了正确的API Key)"}
  366. return {"success": False, "message": "认证失败:API Key无效或缺失"}
  367. elif response.status_code == 403:
  368. # 403也说明接口存在,只是权限问题
  369. if response_data:
  370. return {"success": True, "message": "接口支持OpenAI格式(请检查API Key权限)"}
  371. return {"success": False, "message": "访问被拒绝:权限不足"}
  372. elif response.status_code == 404:
  373. # 404可能是模型不存在或端点不存在
  374. error_lower = error_msg.lower()
  375. if "model" in error_lower or "not found" in error_lower or "does not exist" in error_lower:
  376. return {"success": True, "message": "接口支持OpenAI格式(测试模型名称不存在,实际使用时请填写正确的模型名称)"}
  377. # 如果返回了JSON响应,说明接口是通的
  378. if response_data:
  379. return {"success": True, "message": "接口支持OpenAI格式(请确认模型名称)"}
  380. return {"success": False, "message": f"接口不存在(404 Not Found):{endpoint} 端点不可用,请检查 Base URL 和模型分类是否正确"}
  381. elif response.status_code == 400:
  382. # 400错误可能是参数问题,但接口本身是通的
  383. if response_data:
  384. return {"success": True, "message": "接口支持OpenAI格式(请求参数可能需要调整,实际使用时请填写正确的模型名称)"}
  385. return {"success": True, "message": "接口支持OpenAI格式(请求参数可能需要调整)"}
  386. elif response.status_code >= 500:
  387. return {"success": False, "message": f"服务器错误:{response.status_code},请检查服务是否正常运行"}
  388. else:
  389. # 其他状态码,如果返回了JSON就认为接口是通的
  390. if response_data:
  391. return {"success": True, "message": f"接口可能支持OpenAI格式(状态码:{response.status_code})"}
  392. return {"success": False, "message": f"未知响应:状态码 {response.status_code}"}
  393. except httpx.TimeoutException:
  394. return {"success": False, "message": "连接超时:请检查服务是否正常运行(超时时间:10秒)"}
  395. except httpx.ConnectError:
  396. return {"success": False, "message": "连接失败:无法连接到指定地址,请检查URL是否正确"}
  397. except httpx.RequestError as e:
  398. return {"success": False, "message": f"请求错误:{str(e)}"}
  399. except Exception as e:
  400. return {"success": False, "message": f"未知错误:{str(e)}"}
  401. def get_local_model_by_id(self, model_id: int, user_id: str) -> ModelNew:
  402. """
  403. 根据ID获取本地模型
  404. Args:
  405. model_id: 模型ID
  406. user_id: 用户ID(用于权限验证)
  407. Returns:
  408. 模型对象
  409. Raises:
  410. HTTPException: 模型不存在或无权限时抛出错误
  411. """
  412. model = self.db.query(ModelNew).filter(
  413. ModelNew.id == model_id,
  414. ModelNew.is_local == True,
  415. ModelNew.user_id == user_id
  416. ).first()
  417. if not model:
  418. raise HTTPException(status_code=404, detail="本地模型不存在或无权限访问")
  419. return model
  420. @staticmethod
  421. def mask_base_url(base_url: str) -> str:
  422. """
  423. 对Base URL进行脱敏处理
  424. Args:
  425. base_url: 原始URL
  426. Returns:
  427. 脱敏后的URL(显示协议和域名,隐藏路径细节)
  428. """
  429. if not base_url:
  430. return ""
  431. # 简单脱敏:保留协议和域名部分
  432. try:
  433. parsed = urlparse(base_url)
  434. if parsed.scheme and parsed.netloc:
  435. return f"{parsed.scheme}://{parsed.netloc}/***"
  436. return base_url[:20] + "***" if len(base_url) > 20 else base_url
  437. except Exception:
  438. return base_url[:20] + "***" if len(base_url) > 20 else base_url
  439. def to_response(self, model: ModelNew) -> LocalModelResponse:
  440. """
  441. 将模型对象转换为响应对象
  442. Args:
  443. model: 模型对象
  444. Returns:
  445. 本地模型响应对象
  446. """
  447. return LocalModelResponse(
  448. id=model.id,
  449. name=model.display_name,
  450. base_url=self.mask_base_url(model.base_url),
  451. is_local=model.is_local,
  452. user_id=model.user_id,
  453. created_at=model.created_at,
  454. updated_at=model.updated_at
  455. )
  456. async def get_all_local_models(self) -> List[ModelNew]:
  457. return self.db.query(ModelNew).filter(
  458. ModelNew.is_local == True
  459. ).order_by(ModelNew.created_at.desc()).all()