瀏覽代碼

bugfix:修复与ai问答有关的调用错误问题

Cline 4 天之前
父節點
當前提交
174506e8ed

+ 35 - 0
shudao-chat-py/fix_user_id.py

@@ -0,0 +1,35 @@
+import re
+import os
+
+# 获取脚本所在目录
+script_dir = os.path.dirname(os.path.abspath(__file__))
+
+# 需要修复的文件列表
+files_to_fix = [
+    'routers/chat.py',
+    'routers/scene.py',
+    'routers/hazard.py'
+]
+
+for file_path in files_to_fix:
+    full_path = os.path.join(script_dir, file_path)
+    if not os.path.exists(full_path):
+        print(f'文件不存在: {full_path}')
+        continue
+    
+    print(f'正在修复: {file_path}')
+    
+    # 读取文件
+    with open(full_path, 'r', encoding='utf-8') as f:
+        content = f.read()
+    
+    # 替换所有 user.user_id 为 user.userCode
+    new_content = re.sub(r'user\.user_id', 'user.userCode', content)
+    
+    # 写回文件
+    with open(full_path, 'w', encoding='utf-8') as f:
+        f.write(new_content)
+    
+    print(f'✓ {file_path} 修复完成')
+
+print('\n所有文件修复完成!')

+ 4 - 0
shudao-chat-py/migrate_tracking_user_id.sql

@@ -0,0 +1,4 @@
+-- 修改 tracking_record 表的 user_id 字段类型
+-- 从 BIGINT 改为 VARCHAR(50) 以支持字符串类型的 userCode
+
+ALTER TABLE `tracking_record` MODIFY `user_id` VARCHAR(50);

+ 1 - 1
shudao-chat-py/models/tracking.py

@@ -6,7 +6,7 @@ class TrackingRecord(Base):
     __tablename__ = "tracking_record"
 
     id = Column(BigInteger, primary_key=True, autoincrement=True)
-    user_id = Column(BigInteger)
+    user_id = Column(String(50))
     api_path = Column(String(255))
     api_name = Column(String(255))
     method = Column(String(10), default="POST")

+ 12 - 12
shudao-chat-py/routers/chat.py

@@ -108,7 +108,7 @@ async def send_deepseek_message(
         # 创建或获取对话
         if not data.conversation_id:
             conversation = AIConversation(
-                user_id=user.user_id,
+                user_id=user.userCode,
                 content=message[:100],
                 business_type=data.business_type,
                 exam_name=data.exam_name if data.business_type == 3 else "",
@@ -241,7 +241,7 @@ async def send_deepseek_message(
             "data": {
                 "conversation_id": conv_id,
                 "response": response_text,
-                "user_id": user.user_id,
+                "user_id": user.userCode,
                 "business_type": data.business_type,
             },
         }
@@ -259,7 +259,7 @@ async def get_history_record(request: Request, db: Session = Depends(get_db)):
     conversations = (
         db.query(AIConversation)
         .filter(
-            AIConversation.user_id == user.user_id,
+            AIConversation.user_id == user.userCode,
             AIConversation.is_deleted == 0,
         )
         .order_by(AIConversation.created_at.desc())
@@ -301,7 +301,7 @@ async def delete_conversation(
 
     db.query(AIConversation).filter(
         AIConversation.id == data.ai_conversation_id,
-        AIConversation.user_id == user.user_id,
+        AIConversation.user_id == user.userCode,
     ).update({"is_deleted": 1, "updated_at": int(time.time())})
 
     db.query(AIMessage).filter(
@@ -326,7 +326,7 @@ async def delete_history_record(
         return {"statusCode": 401, "msg": "未授权"}
     db.query(AIConversation).filter(
         AIConversation.id == data.ai_conversation_id,
-        AIConversation.user_id == user.user_id,
+        AIConversation.user_id == user.userCode,
     ).update({"is_deleted": 1, "updated_at": int(time.time())})
     db.commit()
     return {"statusCode": 200, "msg": "删除成功"}
@@ -426,7 +426,7 @@ async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
             # 1. 创建或获取对话
             if data.ai_conversation_id == 0:
                 conversation = AIConversation(
-                    user_id=user.user_id,
+                    user_id=user.userCode,
                     content=message[:100],
                     business_type=data.business_type,
                     exam_name=data.exam_name,
@@ -444,7 +444,7 @@ async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
             # 2. 插入用户消息
             user_msg = AIMessage(
                 ai_conversation_id=conv_id,
-                user_id=user.user_id,
+                user_id=user.userCode,
                 type="user",
                 content=message,
                 created_at=int(time.time()),
@@ -458,7 +458,7 @@ async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
             # 3. 插入 AI 占位消息
             ai_msg = AIMessage(
                 ai_conversation_id=conv_id,
-                user_id=user.user_id,
+                user_id=user.userCode,
                 type="ai",
                 content="",
                 prev_user_id=user_msg.id,
@@ -661,7 +661,7 @@ async def online_search(question: str, request: Request, db: Session = Depends(g
                 "max_text_len": 4000  # 最大文本长度
             },
             "response_mode": "blocking",
-            "user": getattr(user, "account", str(user.user_id)),
+            "user": getattr(user, "account", str(user.userCode)),
         }
 
         async with httpx.AsyncClient(timeout=30.0) as client:
@@ -744,7 +744,7 @@ async def intent_recognition(
         if data.save_to_db and intent_type in ("greeting", "问候", "faq", "常见问题"):
             if data.ai_conversation_id == 0:
                 conversation = AIConversation(
-                    user_id=user.user_id,
+                    user_id=user.userCode,
                     content=data.message[:100],
                     business_type=0,
                     created_at=int(time.time()),
@@ -760,7 +760,7 @@ async def intent_recognition(
 
             user_msg = AIMessage(
                 ai_conversation_id=conv_id,
-                user_id=user.user_id,
+                user_id=user.userCode,
                 type="user",
                 content=data.message,
                 created_at=int(time.time()),
@@ -772,7 +772,7 @@ async def intent_recognition(
 
             ai_msg = AIMessage(
                 ai_conversation_id=conv_id,
-                user_id=user.user_id,
+                user_id=user.userCode,
                 type="ai",
                 content=response_text,
                 prev_user_id=user_msg.id,

+ 3 - 3
shudao-chat-py/routers/hazard.py

@@ -86,12 +86,12 @@ async def hazard(
         )
         
         # 5. 上传结果图片到 OSS
-        result_filename = f"hazard_detection/{user.user_id}/{int(time.time())}.jpg"
+        result_filename = f"hazard_detection/{user.userCode}/{int(time.time())}.jpg"
         result_url = await oss_service.upload_bytes(result_image_bytes, result_filename)
         
         # 6. 插入 RecognitionRecord
         record = RecognitionRecord(
-            user_id=user.user_id,
+            user_id=user.userCode,
             scene_type=data.scene_type,
             original_image_url=data.image_url,
             recognition_image_url=result_url,
@@ -145,7 +145,7 @@ async def save_step(
         # 更新步骤
         affected = db.query(RecognitionRecord).filter(
             RecognitionRecord.id == data.record_id,
-            RecognitionRecord.user_id == user.user_id
+            RecognitionRecord.user_id == user.userCode
         ).update({
             "current_step": data.current_step,
             "updated_at": int(time.time())

+ 5 - 5
shudao-chat-py/routers/scene.py

@@ -102,13 +102,13 @@ async def get_history_recognition_record(request: Request, db: Session = Depends
     
     # 获取所有记录(不限制数量)
     records = db.query(RecognitionRecord).filter(
-        RecognitionRecord.user_id == user.user_id,
+        RecognitionRecord.user_id == user.userCode,
         RecognitionRecord.is_deleted == 0
     ).order_by(RecognitionRecord.updated_at.desc()).all()
     
     # 计算总数
     total = db.query(RecognitionRecord).filter(
-        RecognitionRecord.user_id == user.user_id,
+        RecognitionRecord.user_id == user.userCode,
         RecognitionRecord.is_deleted == 0
     ).count()
     
@@ -178,7 +178,7 @@ async def delete_recognition_record(data: DeleteRecognitionRequest, request: Req
     
     db.query(RecognitionRecord).filter(
         RecognitionRecord.id == data.recognition_id,
-        RecognitionRecord.user_id == user.user_id
+        RecognitionRecord.user_id == user.userCode
     ).update({
         "is_deleted": 1,
         "deleted_at": int(time.time())
@@ -231,7 +231,7 @@ async def get_latest_recognition_record(request: Request, db: Session = Depends(
         return {"statusCode": 401, "msg": "未授权"}
     
     record = db.query(RecognitionRecord).filter(
-        RecognitionRecord.user_id == user.user_id,
+        RecognitionRecord.user_id == user.userCode,
         RecognitionRecord.is_deleted == 0
     ).order_by(RecognitionRecord.created_at.desc()).first()
     
@@ -357,7 +357,7 @@ async def get_recognition_records(
     
     # 构建查询条件
     query = db.query(RecognitionRecord).filter(
-        RecognitionRecord.user_id == user.user_id,
+        RecognitionRecord.user_id == user.userCode,
         RecognitionRecord.is_deleted == 0
     )
     

+ 2 - 2
shudao-chat-py/routers/tracking.py

@@ -30,7 +30,7 @@ async def record_tracking(
     
     # 创建埋点记录
     record = TrackingRecord(
-        user_id=user.user_id,
+        user_id=user.userCode,
         api_path=data.api_path,
         api_name=data.api_name,
         method=request.method,
@@ -55,7 +55,7 @@ async def get_tracking_records(
     user = request.state.user
     
     records = db.query(TrackingRecord).filter(
-        TrackingRecord.user_id == user.user_id
+        TrackingRecord.user_id == user.userCode
     ).order_by(TrackingRecord.created_at.desc()).limit(limit).all()
     
     return {

+ 36 - 0
shudao-chat-py/run_migration.py

@@ -0,0 +1,36 @@
+import pymysql
+import yaml
+import os
+
+# 切换到脚本所在目录
+script_dir = os.path.dirname(os.path.abspath(__file__))
+os.chdir(script_dir)
+
+# 读取配置
+with open('config.yaml', 'r', encoding='utf-8') as f:
+    config = yaml.safe_load(f)
+
+db_config = config['database']
+
+# 连接数据库
+connection = pymysql.connect(
+    host=db_config['host'],
+    port=db_config['port'],
+    user=db_config['user'],
+    password=db_config['password'],
+    database=db_config['database']
+)
+
+try:
+    with connection.cursor() as cursor:
+        # 执行迁移
+        sql = "ALTER TABLE tracking_record MODIFY user_id VARCHAR(50)"
+        print(f"执行 SQL: {sql}")
+        cursor.execute(sql)
+        connection.commit()
+        print("✓ 数据库迁移成功!tracking_record.user_id 字段已改为 VARCHAR(50)")
+except Exception as e:
+    print(f"✗ 迁移失败: {e}")
+    connection.rollback()
+finally:
+    connection.close()

+ 6 - 3
shudao-vue-frontend/src/request/axios.js

@@ -45,9 +45,10 @@ async function recordTrackingAsync(apiPath, method) {
             'Content-Type': 'application/json'
         }
         
-        // 如果有 token,添加到请求头
+        // 如果有 token,添加到请求头(同时添加 token 和 Authorization)
         if (token) {
             headers['Authorization'] = `Bearer ${token}`
+            headers['token'] = token
         }
         
         // 使用 fetch 发送埋点请求,避免使用 axios 造成循环
@@ -85,7 +86,9 @@ http.interceptors.request.use((config) => {
     if (token && !shouldSkipToken) {
         // 格式:Authorization: Bearer {refresh_token}
         config.headers['Authorization'] = `${tokenType.charAt(0).toUpperCase() + tokenType.slice(1)} ${token}`
-        console.log('🔑 已添加 Authorization 头:', `${tokenType} ${token.substring(0, 50)}...`)
+        // 同时添加 token 头,确保后端中间件能够获取到(防止代理丢失Authorization头)
+        config.headers['token'] = token
+        console.log('🔑 已添加 Authorization 和 token 头:', `${tokenType} ${token.substring(0, 50)}...`)
     } else if (shouldSkipToken) {
         console.log('🔧 开发模式:跳过 token 添加')
     }
@@ -162,4 +165,4 @@ http.interceptors.response.use((res) => {
     console.error('📡 请求错误:', message, error.response?.status)
     return Promise.reject(error)
 })
-export default http
+export default http