FanHong před 4 dny
rodič
revize
dd6a256752

+ 27 - 26
shudao-chat-py/services/aichat_proxy.py

@@ -12,46 +12,46 @@ from utils.logger import logger
 
 class AIChatProxy:
     """AIChat 服务代理"""
-    
+
     def __init__(self):
         self.base_url = settings.aichat.api_url.rstrip('/')
         self.timeout = settings.aichat.timeout
-    
+
     def _get_auth_headers(self, request: Request) -> dict:
         """提取并转发认证 headers"""
         headers = {}
-        
+
         # 支持多种 header 名称
         for header_name in ["Authorization", "Token", "token"]:
             header_value = request.headers.get(header_name, "").strip()
             if header_value:
                 headers[header_name] = header_value
-        
+
         return headers
-    
+
     async def proxy_sse(
-        self, 
-        path: str, 
+        self,
+        path: str,
         request: Request,
         request_body: bytes
     ) -> StreamingResponse:
         """
         代理 SSE 流式请求到 aichat
-        
+
         Args:
             path: API 路径(如 /report/complete-flow)
             request: FastAPI Request 对象
             request_body: 请求体
-            
+
         Returns:
             StreamingResponse
         """
         url = f"{self.base_url}{path}"
         headers = self._get_auth_headers(request)
         headers["Content-Type"] = "application/json"
-        
+
         logger.info(f"[AIChat代理] SSE 请求: {url}")
-        
+
         async def stream_generator() -> AsyncGenerator[bytes, None]:
             try:
                 async with httpx.AsyncClient(timeout=self.timeout) as client:
@@ -63,26 +63,27 @@ class AIChatProxy:
                     ) as response:
                         if response.status_code != 200:
                             error_text = await response.aread()
-                            logger.error(f"[AIChat代理] SSE 请求失败: {response.status_code} {error_text.decode()}")
+                            logger.error(
+                                f"[AIChat代理] SSE 请求失败: {response.status_code} {error_text.decode()}")
                             error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"AIChat服务返回异常: {response.status_code}\"}}\n\n"
                             yield error_msg.encode('utf-8')
                             yield b"data: {\"type\": \"completed\"}\n\n"
                             return
-                        
+
                         # 流式转发响应
                         async for chunk in response.aiter_bytes(chunk_size=4096):
                             yield chunk
-                            
+
             except httpx.TimeoutException:
                 logger.error("[AIChat代理] SSE 请求超时")
-                yield b"data: {\"type\": \"online_error\", \"message\": \"AIChat服务请求超时\"}\n\n"
+                yield f'data: {{"type": "online_error", "message": "AIChat服务请求超时"}}\n\n'.encode('utf-8')
                 yield b"data: {\"type\": \"completed\"}\n\n"
             except Exception as e:
                 logger.error(f"[AIChat代理] SSE 请求异常: {e}")
-                error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"AIChat服务异常: {str(e)}\"}}\n\n"
+                yield f'data: {{"type": "online_error", "message": "AIChat服务请求超时"}}\n\n'.encode('utf-8')
                 yield error_msg.encode('utf-8')
                 yield b"data: {\"type\": \"completed\"}\n\n"
-        
+
         return StreamingResponse(
             stream_generator(),
             media_type="text/event-stream",
@@ -92,7 +93,7 @@ class AIChatProxy:
                 "Access-Control-Allow-Origin": "*",
             }
         )
-    
+
     async def proxy_json(
         self,
         path: str,
@@ -101,21 +102,21 @@ class AIChatProxy:
     ) -> JSONResponse:
         """
         代理 JSON 请求到 aichat
-        
+
         Args:
             path: API 路径(如 /report/update-ai-message)
             request: FastAPI Request 对象
             request_body: 请求体
-            
+
         Returns:
             JSONResponse
         """
         url = f"{self.base_url}{path}"
         headers = self._get_auth_headers(request)
         headers["Content-Type"] = "application/json"
-        
+
         logger.info(f"[AIChat代理] JSON 请求: {url}")
-        
+
         try:
             async with httpx.AsyncClient(timeout=30) as client:
                 response = await client.post(
@@ -123,13 +124,13 @@ class AIChatProxy:
                     content=request_body,
                     headers=headers
                 )
-                
+
                 # 转发响应
                 return JSONResponse(
                     content=response.json(),
                     status_code=response.status_code
                 )
-                
+
         except httpx.TimeoutException:
             logger.error("[AIChat代理] JSON 请求超时")
             return JSONResponse(
@@ -142,11 +143,11 @@ class AIChatProxy:
                 content={"success": False, "message": f"AIChat服务异常: {str(e)}"},
                 status_code=500
             )
-    
+
     async def health_check(self) -> bool:
         """
         检查 aichat 服务健康状态
-        
+
         Returns:
             True 表示服务可用,False 表示不可用
         """

+ 2 - 1
shudao-chat-py/utils/__init__.py

@@ -1,5 +1,6 @@
 from .config import settings, get_base_url, get_proxy_url
-from .token import TokenUserInfo, verify_token, get_user_info_from_token
+# from .token import TokenUserInfo, verify_token, get_user_info_from_token
+from .token import verify_local_token
 from .crypto import encrypt_url, decrypt_url
 from .string_match import levenshtein_distance, string_similarity, find_best_match
 

+ 15 - 13
shudao-chat-py/utils/auth_middleware.py

@@ -1,6 +1,6 @@
 from fastapi import Request, HTTPException, status
 from fastapi.responses import JSONResponse
-from .token import verify_token
+from .token import verify_local_token
 from .logger import logger
 
 
@@ -18,7 +18,7 @@ async def auth_middleware(request: Request, call_next):
         "/apiv1/auth/local_login",
         "/apiv1/auth/register"
     ]
-    
+
     # 检查是否在白名单中
     path = request.url.path
     for whitelist_path in whitelist_paths:
@@ -26,35 +26,37 @@ async def auth_middleware(request: Request, call_next):
             # 白名单路径也设置一个默认user,避免后续访问出错
             request.state.user = None
             return await call_next(request)
-    
+
     # 获取Token
-    token = request.headers.get("token") or request.headers.get("Authorization", "").replace("Bearer ", "")
-    
+    token = request.headers.get("token") or request.headers.get(
+        "Authorization", "").replace("Bearer ", "")
+
     logger.info(f"认证中间件 - 路径: {path}")
     logger.info(f"认证中间件 - Token (前20字符): {token[:20] if token else 'None'}...")
-    
+
     if not token:
         logger.warning("认证中间件 - 未提供Token")
         return JSONResponse(
             status_code=status.HTTP_401_UNAUTHORIZED,
             content={"code": 401, "msg": "未提供认证Token"}
         )
-    
+
     # 验证Token
     logger.info("认证中间件 - 开始验证Token")
-    user_info = await verify_token(token)
-    
+    user_info = await verify_local_token(token)
+
     if not user_info:
         logger.error("认证中间件 - Token验证失败,返回401")
         return JSONResponse(
             status_code=status.HTTP_401_UNAUTHORIZED,
             content={"code": 401, "msg": "Token验证失败"}
         )
-    
-    logger.info(f"认证中间件 - Token验证成功,用户: {user_info.username} ({user_info.account})")
-    
+
+    logger.info(
+        f"认证中间件 - Token验证成功,用户: {user_info.username} ({user_info.account})")
+
     # 将用户信息存储到request.state中
     request.state.user = user_info
-    
+
     response = await call_next(request)
     return response