llm_router.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. """
  2. AI对话API路由
  3. 提供LLM对话的RESTful API端点,支持联网搜索功能
  4. 需求: 2.1, 3.1, 3.2, 5.5, 6.4, 10.5
  5. 流式搜索需求: 9.1, 9.2, 9.3, 9.4, 9.5
  6. """
  7. from typing import List, AsyncGenerator
  8. from fastapi import APIRouter, Depends, Request, HTTPException
  9. from fastapi.responses import StreamingResponse
  10. from sqlalchemy.orm import Session
  11. from app.database import get_db, SessionLocal
  12. from app.services.llm_service import LLMService
  13. from app.services.system_config_manager import get_config_bool
  14. from app.schemas.model_schema import ApiResponse, ModelResponse
  15. from app.schemas.llm_schema import (
  16. ChatRequest, ChatResponse,
  17. EnhancedChatRequest, EnhancedChatResponse
  18. )
  19. from app.models.user import User
  20. from app.middleware import get_current_user_from_request
  21. router = APIRouter(prefix="/api/llm", tags=["AI对话"])
  22. async def _stream_with_db(generator: AsyncGenerator) -> AsyncGenerator[str, None]:
  23. """
  24. 包装流式生成器,确保 db 连接在流完全结束后才关闭。
  25. 生成器本身负责持有 db 引用,此函数只负责透传数据。
  26. """
  27. async for chunk in generator:
  28. yield chunk
  29. @router.post("/chat")
  30. async def chat(
  31. request: EnhancedChatRequest,
  32. req: Request,
  33. db: Session = Depends(get_db),
  34. current_user: User = Depends(get_current_user_from_request)
  35. ):
  36. """
  37. 统一对话API端点
  38. 支持流式和非流式输出,自动检测搜索选项
  39. 需要用户认证,使用用户的apikey调用百炼平台
  40. 需要余额检查,余额不足时返回402错误
  41. 向后兼容:如果没有提供search_options,则使用普通对话模式
  42. """
  43. api_key = current_user.apikey
  44. if not api_key:
  45. return ApiResponse(code=403, message="未配置API密钥,请在用户设置中配置apikey", data=None)
  46. # 优先使用模型自带的 api_key(爬虫同步的),没有才 fallback 到用户自己配置的 apikey
  47. from app.services.crypto_utils import get_effective_api_key
  48. api_key = get_effective_api_key(db, request.model, api_key)
  49. search_enabled = (
  50. hasattr(request, 'search_options') and
  51. request.search_options and
  52. request.search_options.enable_search
  53. )
  54. if search_enabled and not get_config_bool("enable_search", True):
  55. return ApiResponse(code=403, message="系统暂未开放联网搜索功能", data=None)
  56. try:
  57. if request.stream:
  58. # 流式请求:手动管理 db 生命周期,确保流结束后才关闭连接
  59. stream_db = SessionLocal()
  60. async def stream_and_close():
  61. try:
  62. service = LLMService(stream_db, api_key=api_key, user_id=str(current_user.id))
  63. if search_enabled:
  64. gen = service.chat_stream_with_search(request, conversation_id=request.conversation_id)
  65. else:
  66. gen = service.chat_stream(request, conversation_id=request.conversation_id)
  67. async for chunk in gen:
  68. yield chunk
  69. finally:
  70. stream_db.close()
  71. return StreamingResponse(
  72. stream_and_close(),
  73. media_type="text/event-stream",
  74. headers={
  75. "Cache-Control": "no-cache",
  76. "Connection": "keep-alive",
  77. "X-Accel-Buffering": "no"
  78. }
  79. )
  80. else:
  81. service = LLMService(db, api_key=api_key, user_id=str(current_user.id))
  82. if search_enabled:
  83. data = service.chat_with_search(request, conversation_id=request.conversation_id)
  84. else:
  85. data = service.chat(request, conversation_id=request.conversation_id)
  86. return ApiResponse(code=200, message="success", data=data)
  87. except HTTPException:
  88. raise
  89. except Exception as e:
  90. return ApiResponse(code=500, message=f"对话服务异常: {str(e)}", data=None)
  91. @router.post("/chat/search")
  92. async def chat_with_search(
  93. request: EnhancedChatRequest,
  94. req: Request,
  95. db: Session = Depends(get_db),
  96. current_user: User = Depends(get_current_user_from_request)
  97. ):
  98. """
  99. 支持搜索的对话API端点
  100. 支持流式和非流式输出,集成联网搜索功能
  101. 需要用户认证,使用用户的apikey调用百炼平台
  102. 需求: 5.5, 6.4, 9.1, 9.2, 9.3, 9.4, 9.5
  103. """
  104. if not get_config_bool("enable_search", True):
  105. return ApiResponse(code=403, message="系统暂未开放联网搜索功能", data=None)
  106. api_key = current_user.apikey
  107. if not api_key:
  108. return ApiResponse(code=403, message="未配置API密钥,请在用户设置中配置apikey", data=None)
  109. # 优先使用模型自带的 api_key(爬虫同步的),没有才 fallback 到用户自己配置的 apikey
  110. from app.services.crypto_utils import get_effective_api_key
  111. api_key = get_effective_api_key(db, request.model, api_key)
  112. try:
  113. if request.stream:
  114. stream_db = SessionLocal()
  115. async def stream_and_close():
  116. try:
  117. service = LLMService(stream_db, api_key=api_key, user_id=str(current_user.id))
  118. async for chunk in service.chat_stream_with_search(request, conversation_id=request.conversation_id):
  119. yield chunk
  120. finally:
  121. stream_db.close()
  122. return StreamingResponse(
  123. stream_and_close(),
  124. media_type="text/event-stream",
  125. headers={
  126. "Cache-Control": "no-cache",
  127. "Connection": "keep-alive",
  128. "X-Accel-Buffering": "no"
  129. }
  130. )
  131. else:
  132. service = LLMService(db, api_key=api_key, user_id=str(current_user.id))
  133. data = service.chat_with_search(request, conversation_id=request.conversation_id)
  134. return ApiResponse(code=200, message="success", data=data)
  135. except HTTPException:
  136. raise
  137. except Exception as e:
  138. return ApiResponse(code=500, message=f"搜索增强对话服务异常: {str(e)}", data=None)
  139. @router.get("/search/models", response_model=ApiResponse[List[str]])
  140. def get_search_supported_models(
  141. db: Session = Depends(get_db),
  142. current_user: User = Depends(get_current_user_from_request)
  143. ):
  144. """
  145. 获取支持搜索功能的模型列表
  146. 返回支持联网搜索的模型名称列表
  147. 需求: 10.1, 10.4
  148. """
  149. api_key = current_user.apikey
  150. if not api_key:
  151. return ApiResponse(code=403, message="未配置API密钥,请在用户设置中配置apikey", data=None)
  152. service = LLMService(db, api_key=api_key, user_id=str(current_user.id))
  153. data = service.get_search_supported_models()
  154. return ApiResponse(code=200, message="success", data=data)
  155. @router.get("/search/check/{model}")
  156. def check_search_support(
  157. model: str,
  158. db: Session = Depends(get_db),
  159. current_user: User = Depends(get_current_user_from_request)
  160. ):
  161. """
  162. 检查指定模型是否支持搜索功能
  163. Args:
  164. model: 模型名称
  165. Returns:
  166. 是否支持搜索功能的布尔值
  167. 需求: 10.1, 10.4
  168. """
  169. api_key = current_user.apikey
  170. if not api_key:
  171. return ApiResponse(code=403, message="未配置API密钥,请在用户设置中配置apikey", data=None)
  172. service = LLMService(db, api_key=api_key, user_id=str(current_user.id))
  173. is_supported = service.is_search_supported(model)
  174. return ApiResponse(
  175. code=200,
  176. message="success",
  177. data={"model": model, "search_supported": is_supported}
  178. )
  179. @router.get("/models", response_model=ApiResponse[List[ModelResponse]])
  180. def get_llm_models(db: Session = Depends(get_db)):
  181. """
  182. 获取所有可用的LLM模型列表
  183. 返回type=0的语言模型
  184. """
  185. service = LLMService(db)
  186. data = service.get_llm_models()
  187. return ApiResponse(code=200, message="success", data=data)