|
@@ -1,5 +1,6 @@
|
|
|
from fastapi import APIRouter, Depends, Request
|
|
from fastapi import APIRouter, Depends, Request
|
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
|
|
|
+from sqlalchemy import or_
|
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.orm import Session
|
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
|
from typing import Optional
|
|
from typing import Optional
|
|
@@ -36,6 +37,28 @@ def _build_conversation_title(conversation: AIConversation) -> str:
|
|
|
return _build_conversation_preview(conversation.content or "", limit=30)
|
|
return _build_conversation_preview(conversation.content or "", limit=30)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def _is_exam_workshop_conversation(conversation: Optional[AIConversation]) -> bool:
|
|
|
|
|
+ if not conversation:
|
|
|
|
|
+ return False
|
|
|
|
|
+ return conversation.business_type == 3 or bool((conversation.exam_name or "").strip())
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _resolve_conversation_metadata(
|
|
|
|
|
+ conversation: Optional[AIConversation],
|
|
|
|
|
+ requested_business_type: int,
|
|
|
|
|
+ requested_exam_name: str,
|
|
|
|
|
+) -> tuple[int, str]:
|
|
|
|
|
+ requested_exam_name = (requested_exam_name or "").strip()
|
|
|
|
|
+
|
|
|
|
|
+ if _is_exam_workshop_conversation(conversation):
|
|
|
|
|
+ return 3, requested_exam_name or (conversation.exam_name or "").strip()
|
|
|
|
|
+
|
|
|
|
|
+ if requested_business_type == 3:
|
|
|
|
|
+ return 3, requested_exam_name
|
|
|
|
|
+
|
|
|
|
|
+ return requested_business_type, ""
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def _refresh_conversation_snapshot(db: Session, conversation_id: int, user_id: int) -> None:
|
|
def _refresh_conversation_snapshot(db: Session, conversation_id: int, user_id: int) -> None:
|
|
|
latest_message = (
|
|
latest_message = (
|
|
|
db.query(AIMessage)
|
|
db.query(AIMessage)
|
|
@@ -72,7 +95,8 @@ def _refresh_conversation_snapshot(db: Session, conversation_id: int, user_id: i
|
|
|
if latest_user_message and latest_user_message.content
|
|
if latest_user_message and latest_user_message.content
|
|
|
else latest_message.content
|
|
else latest_message.content
|
|
|
)
|
|
)
|
|
|
- preview_content = _build_conversation_preview(preview_source or "", limit=100)
|
|
|
|
|
|
|
+ preview_content = _build_conversation_preview(
|
|
|
|
|
+ preview_source or "", limit=100)
|
|
|
|
|
|
|
|
db.query(AIConversation).filter(
|
|
db.query(AIConversation).filter(
|
|
|
AIConversation.id == conversation_id,
|
|
AIConversation.id == conversation_id,
|
|
@@ -176,13 +200,31 @@ async def send_deepseek_message(
|
|
|
|
|
|
|
|
conversation_id = data.conversation_id or data.ai_conversation_id
|
|
conversation_id = data.conversation_id or data.ai_conversation_id
|
|
|
|
|
|
|
|
|
|
+ existing_conversation = None
|
|
|
|
|
+ if conversation_id:
|
|
|
|
|
+ existing_conversation = (
|
|
|
|
|
+ db.query(AIConversation)
|
|
|
|
|
+ .filter(
|
|
|
|
|
+ AIConversation.id == conversation_id,
|
|
|
|
|
+ AIConversation.user_id == user.user_id,
|
|
|
|
|
+ AIConversation.is_deleted == 0,
|
|
|
|
|
+ )
|
|
|
|
|
+ .first()
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ effective_business_type, effective_exam_name = _resolve_conversation_metadata(
|
|
|
|
|
+ existing_conversation,
|
|
|
|
|
+ data.business_type,
|
|
|
|
|
+ data.exam_name,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
# 创建或获取对话
|
|
# 创建或获取对话
|
|
|
if not conversation_id:
|
|
if not conversation_id:
|
|
|
conversation = AIConversation(
|
|
conversation = AIConversation(
|
|
|
user_id=user.user_id,
|
|
user_id=user.user_id,
|
|
|
content=message[:100],
|
|
content=message[:100],
|
|
|
- business_type=data.business_type,
|
|
|
|
|
- exam_name=data.exam_name if data.business_type == 3 else "",
|
|
|
|
|
|
|
+ business_type=effective_business_type,
|
|
|
|
|
+ exam_name=effective_exam_name,
|
|
|
created_at=int(time.time()),
|
|
created_at=int(time.time()),
|
|
|
updated_at=int(time.time()),
|
|
updated_at=int(time.time()),
|
|
|
is_deleted=0,
|
|
is_deleted=0,
|
|
@@ -199,8 +241,8 @@ async def send_deepseek_message(
|
|
|
AIConversation.is_deleted == 0,
|
|
AIConversation.is_deleted == 0,
|
|
|
).update({
|
|
).update({
|
|
|
"content": message[:100],
|
|
"content": message[:100],
|
|
|
- "business_type": data.business_type,
|
|
|
|
|
- "exam_name": data.exam_name if data.business_type == 3 else "",
|
|
|
|
|
|
|
+ "business_type": effective_business_type,
|
|
|
|
|
+ "exam_name": effective_exam_name,
|
|
|
"updated_at": int(time.time()),
|
|
"updated_at": int(time.time()),
|
|
|
})
|
|
})
|
|
|
db.commit()
|
|
db.commit()
|
|
@@ -214,7 +256,8 @@ async def send_deepseek_message(
|
|
|
intent_type = ""
|
|
intent_type = ""
|
|
|
if isinstance(intent_result, dict):
|
|
if isinstance(intent_result, dict):
|
|
|
intent_type = (
|
|
intent_type = (
|
|
|
- intent_result.get("intent_type") or intent_result.get("intent") or ""
|
|
|
|
|
|
|
+ intent_result.get("intent_type") or intent_result.get(
|
|
|
|
|
+ "intent") or ""
|
|
|
).lower()
|
|
).lower()
|
|
|
|
|
|
|
|
rag_context = ""
|
|
rag_context = ""
|
|
@@ -237,7 +280,8 @@ async def send_deepseek_message(
|
|
|
try:
|
|
try:
|
|
|
if isinstance(qwen_response, str) and qwen_response.strip().startswith("{"):
|
|
if isinstance(qwen_response, str) and qwen_response.strip().startswith("{"):
|
|
|
response_json = json.loads(qwen_response)
|
|
response_json = json.loads(qwen_response)
|
|
|
- response_text = response_json.get("natural_language_answer", qwen_response)
|
|
|
|
|
|
|
+ response_text = response_json.get(
|
|
|
|
|
+ "natural_language_answer", qwen_response)
|
|
|
else:
|
|
else:
|
|
|
response_text = qwen_response
|
|
response_text = qwen_response
|
|
|
except Exception:
|
|
except Exception:
|
|
@@ -294,6 +338,8 @@ async def send_deepseek_message(
|
|
|
system_content = (
|
|
system_content = (
|
|
|
"你是一个专业的考试题目生成助手,专注于路桥隧轨施工安全领域。\n"
|
|
"你是一个专业的考试题目生成助手,专注于路桥隧轨施工安全领域。\n"
|
|
|
"请根据用户需求生成专业的考试题目,包括单选题、多选题、判断题、简答题等。\n"
|
|
"请根据用户需求生成专业的考试题目,包括单选题、多选题、判断题、简答题等。\n"
|
|
|
|
|
+ "用户消息中已经包含考试标题、题型要求和出题依据内容,必须以其中的出题依据内容为核心生成题目,不能脱离依据内容自由发挥。\n"
|
|
|
|
|
+ "题干、选项、答案和解析都要与出题依据内容中的知识点、专业术语、操作流程、规范要求或培训主题直接相关。\n"
|
|
|
"输出必须是可直接 JSON.parse 的纯 JSON,不要包含 markdown 代码块、解释文字或额外前后缀。\n"
|
|
"输出必须是可直接 JSON.parse 的纯 JSON,不要包含 markdown 代码块、解释文字或额外前后缀。\n"
|
|
|
"JSON 顶层结构必须包含 singleChoice、judge、multiple、short 四个字段。\n"
|
|
"JSON 顶层结构必须包含 singleChoice、judge、multiple、short 四个字段。\n"
|
|
|
"singleChoice.questions 和 multiple.questions 中每道题必须包含 text、options、answer、analysis。\n"
|
|
"singleChoice.questions 和 multiple.questions 中每道题必须包含 text、options、answer、analysis。\n"
|
|
@@ -310,9 +356,41 @@ async def send_deepseek_message(
|
|
|
|
|
|
|
|
response_text = await qwen_service.chat(messages)
|
|
response_text = await qwen_service.chat(messages)
|
|
|
|
|
|
|
|
- if data.exam_name:
|
|
|
|
|
|
|
+ now_ts = int(time.time())
|
|
|
|
|
+ user_message = AIMessage(
|
|
|
|
|
+ ai_conversation_id=conv_id,
|
|
|
|
|
+ user_id=user.user_id,
|
|
|
|
|
+ type="user",
|
|
|
|
|
+ content=message,
|
|
|
|
|
+ created_at=now_ts,
|
|
|
|
|
+ updated_at=now_ts,
|
|
|
|
|
+ is_deleted=0,
|
|
|
|
|
+ )
|
|
|
|
|
+ db.add(user_message)
|
|
|
|
|
+ db.commit()
|
|
|
|
|
+ db.refresh(user_message)
|
|
|
|
|
+
|
|
|
|
|
+ ai_message = AIMessage(
|
|
|
|
|
+ ai_conversation_id=conv_id,
|
|
|
|
|
+ user_id=user.user_id,
|
|
|
|
|
+ type="ai",
|
|
|
|
|
+ content=response_text,
|
|
|
|
|
+ prev_user_id=user_message.id,
|
|
|
|
|
+ created_at=now_ts,
|
|
|
|
|
+ updated_at=now_ts,
|
|
|
|
|
+ is_deleted=0,
|
|
|
|
|
+ )
|
|
|
|
|
+ db.add(ai_message)
|
|
|
|
|
+ db.commit()
|
|
|
|
|
+
|
|
|
|
|
+ _refresh_conversation_snapshot(db, conv_id, user.user_id)
|
|
|
|
|
+ db.commit()
|
|
|
|
|
+
|
|
|
|
|
+ if effective_exam_name:
|
|
|
db.query(AIConversation).filter(AIConversation.id == conv_id).update(
|
|
db.query(AIConversation).filter(AIConversation.id == conv_id).update(
|
|
|
- {"exam_name": data.exam_name, "updated_at": int(time.time())}
|
|
|
|
|
|
|
+ {"business_type": 3,
|
|
|
|
|
+ "exam_name": effective_exam_name,
|
|
|
|
|
+ "updated_at": int(time.time())}
|
|
|
)
|
|
)
|
|
|
db.commit()
|
|
db.commit()
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
@@ -333,7 +411,7 @@ async def send_deepseek_message(
|
|
|
"content": response_text,
|
|
"content": response_text,
|
|
|
"message": response_text,
|
|
"message": response_text,
|
|
|
"user_id": user.user_id,
|
|
"user_id": user.user_id,
|
|
|
- "business_type": data.business_type,
|
|
|
|
|
|
|
+ "business_type": effective_business_type,
|
|
|
},
|
|
},
|
|
|
}
|
|
}
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
@@ -393,9 +471,18 @@ async def get_history_record(
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
if business_type is not None:
|
|
if business_type is not None:
|
|
|
- conversations_query = conversations_query.filter(
|
|
|
|
|
- AIConversation.business_type == business_type
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ if business_type == 3:
|
|
|
|
|
+ conversations_query = conversations_query.filter(
|
|
|
|
|
+ or_(
|
|
|
|
|
+ AIConversation.business_type == 3,
|
|
|
|
|
+ AIConversation.exam_name.isnot(None),
|
|
|
|
|
+ AIConversation.exam_name != "",
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ conversations_query = conversations_query.filter(
|
|
|
|
|
+ AIConversation.business_type == business_type
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
total = conversations_query.count()
|
|
total = conversations_query.count()
|
|
|
conversations = (
|
|
conversations = (
|
|
@@ -469,7 +556,8 @@ async def delete_conversation(
|
|
|
AIMessage.ai_conversation_id == ai_message.ai_conversation_id,
|
|
AIMessage.ai_conversation_id == ai_message.ai_conversation_id,
|
|
|
).update({"is_deleted": 1, "updated_at": now_ts})
|
|
).update({"is_deleted": 1, "updated_at": now_ts})
|
|
|
|
|
|
|
|
- _refresh_conversation_snapshot(db, ai_message.ai_conversation_id, user.user_id)
|
|
|
|
|
|
|
+ _refresh_conversation_snapshot(
|
|
|
|
|
+ db, ai_message.ai_conversation_id, user.user_id)
|
|
|
db.commit()
|
|
db.commit()
|
|
|
return {"statusCode": 200, "msg": "删除成功"}
|
|
return {"statusCode": 200, "msg": "删除成功"}
|
|
|
|
|
|
|
@@ -532,7 +620,8 @@ async def stream_chat(request: Request, data: StreamChatRequest):
|
|
|
intent_result = await qwen_service.intent_recognition(message)
|
|
intent_result = await qwen_service.intent_recognition(message)
|
|
|
if isinstance(intent_result, dict):
|
|
if isinstance(intent_result, dict):
|
|
|
intent_type = (
|
|
intent_type = (
|
|
|
- intent_result.get("intent_type") or intent_result.get("intent") or ""
|
|
|
|
|
|
|
+ intent_result.get("intent_type") or intent_result.get(
|
|
|
|
|
+ "intent") or ""
|
|
|
).lower()
|
|
).lower()
|
|
|
except Exception as ie:
|
|
except Exception as ie:
|
|
|
logger.warning(f"[stream/chat] 意图识别异常: {ie}")
|
|
logger.warning(f"[stream/chat] 意图识别异常: {ie}")
|
|
@@ -644,7 +733,8 @@ async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
|
|
|
else:
|
|
else:
|
|
|
conversation = AIConversation(
|
|
conversation = AIConversation(
|
|
|
user_id=user.user_id,
|
|
user_id=user.user_id,
|
|
|
- content=_build_conversation_preview(message, limit=100),
|
|
|
|
|
|
|
+ content=_build_conversation_preview(
|
|
|
|
|
+ message, limit=100),
|
|
|
business_type=data.business_type,
|
|
business_type=data.business_type,
|
|
|
exam_name=data.exam_name if data.business_type == 3 else "",
|
|
exam_name=data.exam_name if data.business_type == 3 else "",
|
|
|
created_at=int(time.time()),
|
|
created_at=int(time.time()),
|
|
@@ -717,9 +807,10 @@ async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
|
|
|
context_parts.append(f"知识库内容:\n{rag_context}")
|
|
context_parts.append(f"知识库内容:\n{rag_context}")
|
|
|
if data.online_search_content:
|
|
if data.online_search_content:
|
|
|
context_parts.append(f"联网搜索结果:\n{data.online_search_content}")
|
|
context_parts.append(f"联网搜索结果:\n{data.online_search_content}")
|
|
|
-
|
|
|
|
|
- context_json = "\n\n".join(context_parts) if context_parts else "暂无相关知识库内容"
|
|
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
|
|
+ context_json = "\n\n".join(
|
|
|
|
|
+ context_parts) if context_parts else "暂无相关知识库内容"
|
|
|
|
|
+
|
|
|
# 使用prompt加载器加载最终回答prompt
|
|
# 使用prompt加载器加载最终回答prompt
|
|
|
system_content = load_prompt(
|
|
system_content = load_prompt(
|
|
|
"final_answer",
|
|
"final_answer",
|
|
@@ -817,7 +908,8 @@ async def guess_you_want(
|
|
|
try:
|
|
try:
|
|
|
# 尝试从响应中提取 JSON
|
|
# 尝试从响应中提取 JSON
|
|
|
import re
|
|
import re
|
|
|
- json_match = re.search(r'\{[^{}]*"questions"[^{}]*\}', response, re.DOTALL)
|
|
|
|
|
|
|
+ json_match = re.search(
|
|
|
|
|
+ r'\{[^{}]*"questions"[^{}]*\}', response, re.DOTALL)
|
|
|
if json_match:
|
|
if json_match:
|
|
|
response_json = json.loads(json_match.group())
|
|
response_json = json.loads(json_match.group())
|
|
|
else:
|
|
else:
|
|
@@ -894,10 +986,12 @@ async def online_search(question: str, request: Request, db: Session = Depends(g
|
|
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
|
|
resp = await client.post(dify_config.workflow_url, headers=headers, json=payload)
|
|
resp = await client.post(dify_config.workflow_url, headers=headers, json=payload)
|
|
|
if resp.status_code != 200:
|
|
if resp.status_code != 200:
|
|
|
- logger.error(f"[online_search] Dify 调用失败: {resp.status_code}, 响应: {resp.text}")
|
|
|
|
|
|
|
+ logger.error(
|
|
|
|
|
+ f"[online_search] Dify 调用失败: {resp.status_code}, 响应: {resp.text}")
|
|
|
return {"statusCode": 500, "msg": f"搜索服务异常: {resp.status_code}"}
|
|
return {"statusCode": 500, "msg": f"搜索服务异常: {resp.status_code}"}
|
|
|
result = resp.json()
|
|
result = resp.json()
|
|
|
- search_text = result.get("data", {}).get("outputs", {}).get("text", "")
|
|
|
|
|
|
|
+ search_text = result.get("data", {}).get(
|
|
|
|
|
+ "outputs", {}).get("text", "")
|
|
|
|
|
|
|
|
return {
|
|
return {
|
|
|
"statusCode": 200,
|
|
"statusCode": 200,
|
|
@@ -928,7 +1022,8 @@ async def save_online_search_result(
|
|
|
|
|
|
|
|
try:
|
|
try:
|
|
|
db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
|
|
db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
|
|
|
- {"search_source": data.search_result, "updated_at": int(time.time())}
|
|
|
|
|
|
|
+ {"search_source": data.search_result,
|
|
|
|
|
+ "updated_at": int(time.time())}
|
|
|
)
|
|
)
|
|
|
db.commit()
|
|
db.commit()
|
|
|
return {"statusCode": 200, "msg": "保存成功"}
|
|
return {"statusCode": 200, "msg": "保存成功"}
|
|
@@ -964,7 +1059,8 @@ async def intent_recognition(
|
|
|
response_text = ""
|
|
response_text = ""
|
|
|
if isinstance(intent_result, dict):
|
|
if isinstance(intent_result, dict):
|
|
|
intent_type = (
|
|
intent_type = (
|
|
|
- intent_result.get("intent_type") or intent_result.get("intent") or ""
|
|
|
|
|
|
|
+ intent_result.get("intent_type") or intent_result.get(
|
|
|
|
|
+ "intent") or ""
|
|
|
).lower()
|
|
).lower()
|
|
|
response_text = intent_result.get("response", "")
|
|
response_text = intent_result.get("response", "")
|
|
|
|
|
|
|
@@ -1050,10 +1146,13 @@ async def get_user_recommend_question(
|
|
|
):
|
|
):
|
|
|
"""获取推荐问题(支持模糊查询)"""
|
|
"""获取推荐问题(支持模糊查询)"""
|
|
|
try:
|
|
try:
|
|
|
- query = db.query(RecommendQuestion).filter(RecommendQuestion.is_deleted == 0)
|
|
|
|
|
|
|
+ query = db.query(RecommendQuestion).filter(
|
|
|
|
|
+ RecommendQuestion.is_deleted == 0)
|
|
|
if keyword:
|
|
if keyword:
|
|
|
- query = query.filter(RecommendQuestion.question.like(f"%{keyword}%"))
|
|
|
|
|
- questions = query.order_by(RecommendQuestion.id.desc()).limit(limit).all()
|
|
|
|
|
|
|
+ query = query.filter(
|
|
|
|
|
+ RecommendQuestion.question.like(f"%{keyword}%"))
|
|
|
|
|
+ questions = query.order_by(
|
|
|
|
|
+ RecommendQuestion.id.desc()).limit(limit).all()
|
|
|
|
|
|
|
|
return {
|
|
return {
|
|
|
"statusCode": 200,
|
|
"statusCode": 200,
|