# total.py 11 KB
  1. from fastapi import APIRouter, Depends, Request
  2. from fastapi.responses import StreamingResponse
  3. from sqlalchemy.orm import Session
  4. from pydantic import BaseModel
  5. from typing import Optional
  6. from database import get_db
  7. from models.total import RecommendQuestion, PolicyFile, FunctionCard, HotQuestion, FeedbackQuestion, User
  8. from models.chat import AIMessage
  9. from models.user_data import UserData
  10. from services.oss_service import oss_service
  11. from utils.crypto import decrypt_url
  12. from utils.config import get_proxy_url
  13. import time
  14. import httpx
  15. router = APIRouter()
  16. @router.get("/recommend_question")
  17. async def get_recommend_question(db: Session = Depends(get_db)):
  18. """获取推荐问题"""
  19. questions = db.query(RecommendQuestion).limit(10).all()
  20. return {
  21. "statusCode": 200,
  22. "msg": "success",
  23. "data": [{"id": q.id, "question": q.question} for q in questions]
  24. }
  25. @router.get("/get_policy_file")
  26. async def get_policy_file(
  27. policy_type: Optional[str] = None,
  28. page: int = 1,
  29. pageSize: int = 20,
  30. search: str = "",
  31. db: Session = Depends(get_db)
  32. ):
  33. """获取策略文件列表"""
  34. query = db.query(PolicyFile).filter(PolicyFile.is_deleted == 0)
  35. # 只有当policy_type有效且不为0或空字符串时才添加类型过滤
  36. if policy_type and policy_type != "" and policy_type != "0":
  37. try:
  38. policy_type_int = int(policy_type)
  39. query = query.filter(PolicyFile.policy_type == policy_type_int)
  40. except ValueError:
  41. pass # 忽略无效的类型值
  42. if search:
  43. query = query.filter(PolicyFile.policy_name.like(f"%{search}%"))
  44. offset = (page - 1) * pageSize
  45. files = query.order_by(PolicyFile.updated_at.desc()).offset(
  46. offset).limit(pageSize).all()
  47. return {
  48. "statusCode": 200,
  49. "msg": "success",
  50. "data": [
  51. {
  52. "id": f.id,
  53. "policy_name": f.policy_name,
  54. "policy_file_url": get_proxy_url(f.policy_file_url),
  55. "policy_type": f.policy_type,
  56. "file_type": f.file_type,
  57. "file_tag": getattr(f, 'file_tag', ''),
  58. "publish_time": getattr(f, 'publish_time', f.created_at),
  59. "view_count": f.view_count,
  60. "created_at": f.created_at,
  61. "updated_at": f.updated_at
  62. }
  63. for f in files
  64. ]
  65. }
  66. @router.get("/get_function_card")
  67. async def get_function_card(
  68. function_type: Optional[int] = None,
  69. db: Session = Depends(get_db)
  70. ):
  71. """获取功能卡片"""
  72. query = db.query(FunctionCard).filter(FunctionCard.is_deleted == 0)
  73. if function_type is not None:
  74. query = query.filter(FunctionCard.function_type == function_type)
  75. cards = query.order_by(FunctionCard.id.asc()).limit(4).all()
  76. return {
  77. "statusCode": 200,
  78. "msg": "success",
  79. "data": [
  80. {
  81. "id": c.id,
  82. "function_title": c.function_title,
  83. "function_icon": c.function_icon,
  84. "function_content": c.function_content,
  85. "function_type": c.function_type
  86. }
  87. for c in cards
  88. ]
  89. }
  90. @router.get("/get_hot_question")
  91. async def get_hot_question(
  92. question_type: Optional[int] = None,
  93. db: Session = Depends(get_db)
  94. ):
  95. """获取热点问题(按点击量排序)"""
  96. query = db.query(HotQuestion).filter(HotQuestion.is_deleted == 0)
  97. if question_type is not None:
  98. query = query.filter(HotQuestion.question_type == question_type)
  99. questions = query.order_by(
  100. HotQuestion.click_count.desc(),
  101. HotQuestion.id.asc()
  102. ).limit(3).all()
  103. return {
  104. "statusCode": 200,
  105. "msg": "success",
  106. "data": [
  107. {
  108. "id": q.id,
  109. "question": q.question,
  110. "click_count": q.click_count or 0,
  111. "question_type": q.question_type
  112. }
  113. for q in questions
  114. ]
  115. }
  116. class SubmitFeedbackRequest(BaseModel):
  117. feedback_type: str
  118. content: str
  119. contact: str = ""
  120. @router.post("/submit_feedback")
  121. async def submit_feedback(request: SubmitFeedbackRequest, req: Request, db: Session = Depends(get_db)):
  122. """提交意见反馈(对齐Go版本)"""
  123. # 从token获取user_id
  124. user = req.state.user
  125. user_id = user.id if user else 0
  126. # 映射反馈类型:支持中文描述和英文标识
  127. type_map = {
  128. "功能建议": 1, "bug": 1, "问题反馈": 1,
  129. "ui": 2, "界面优化": 2,
  130. "experience": 3, "体验问题": 3,
  131. "other": 4, "其他": 4
  132. }
  133. feedback_type_id = type_map.get(request.feedback_type, 4)
  134. feedback = FeedbackQuestion(
  135. feedback_type=feedback_type_id,
  136. feedback_content=request.content,
  137. feedback_user_phone=request.contact,
  138. user_id=user_id,
  139. created_at=int(time.time()),
  140. updated_at=int(time.time())
  141. )
  142. db.add(feedback)
  143. db.commit()
  144. return {
  145. "statusCode": 200,
  146. "msg": "感谢您的反馈!"
  147. }
  148. class LikeDislikeRequest(BaseModel):
  149. ai_message_id: Optional[int] = None
  150. action: Optional[str] = None
  151. id: Optional[int] = None
  152. user_feedback: Optional[int] = None
  153. FEEDBACK_NONE = 0
  154. FEEDBACK_LIKE = 2
  155. FEEDBACK_DISLIKE = 3
  156. FEEDBACK_REWARD_POINTS = {
  157. FEEDBACK_NONE: 0,
  158. FEEDBACK_LIKE: 2,
  159. FEEDBACK_DISLIKE: 1,
  160. }
  161. def _resolve_like_dislike_payload(data: LikeDislikeRequest):
  162. message_id = data.ai_message_id or data.id
  163. if not message_id:
  164. return None, None, "缺少消息ID"
  165. if data.user_feedback is not None:
  166. feedback = int(data.user_feedback)
  167. else:
  168. action = (data.action or "").strip().lower()
  169. if action in ("like", "2"):
  170. feedback = FEEDBACK_LIKE
  171. elif action in ("dislike", "3"):
  172. feedback = FEEDBACK_DISLIKE
  173. elif action in ("", "none", "cancel", "0"):
  174. feedback = FEEDBACK_NONE
  175. else:
  176. return None, None, "反馈类型错误"
  177. if feedback not in (FEEDBACK_NONE, FEEDBACK_LIKE, FEEDBACK_DISLIKE):
  178. return None, None, "反馈类型错误"
  179. return message_id, feedback, None
  180. def _find_current_points_holder(db: Session, user_info):
  181. user = db.query(User).filter(
  182. User.id == user_info.user_id,
  183. User.is_deleted == 0,
  184. ).first()
  185. if user:
  186. return user
  187. return db.query(UserData).filter(
  188. UserData.accountID == user_info.account,
  189. ).first()
  190. @router.post("/like_and_dislike")
  191. async def like_and_dislike(data: LikeDislikeRequest, request: Request, db: Session = Depends(get_db)):
  192. """Save AI message feedback and reward points to the current user."""
  193. user_info = request.state.user
  194. if not user_info:
  195. return {"statusCode": 401, "msg": "未认证"}
  196. message_id, feedback, error = _resolve_like_dislike_payload(data)
  197. if error:
  198. return {"statusCode": 400, "msg": error}
  199. message = db.query(AIMessage).filter(AIMessage.id == message_id).first()
  200. if not message:
  201. return {"statusCode": 404, "msg": "消息不存在"}
  202. if getattr(message, "user_id", user_info.user_id) != user_info.user_id:
  203. return {"statusCode": 403, "msg": "无权评价该消息"}
  204. points_holder = _find_current_points_holder(db, user_info)
  205. if not points_holder:
  206. return {"statusCode": 404, "msg": "未找到用户数据"}
  207. previous_feedback = int(message.user_feedback or 0)
  208. previous_points = FEEDBACK_REWARD_POINTS.get(previous_feedback, 0)
  209. current_points = FEEDBACK_REWARD_POINTS.get(feedback, 0)
  210. points_delta = current_points - previous_points
  211. try:
  212. message.user_feedback = feedback
  213. message.updated_at = int(time.time())
  214. new_balance = (points_holder.points or 0) + points_delta
  215. points_holder.points = new_balance
  216. db.commit()
  217. except Exception as e:
  218. db.rollback()
  219. return {"statusCode": 500, "msg": f"反馈提交失败: {str(e)}"}
  220. return {
  221. "statusCode": 200,
  222. "msg": "success",
  223. "data": {
  224. "ai_message_id": message_id,
  225. "user_feedback": feedback,
  226. "points_added": points_delta,
  227. "points_delta": points_delta,
  228. "new_balance": new_balance,
  229. },
  230. }
  231. @router.get("/get_user_data_id")
  232. async def get_user_data_id(request: Request, db: Session = Depends(get_db)):
  233. """通过 token 的 account_id 查询 UserData 主键"""
  234. user = request.state.user
  235. if not user:
  236. return {"statusCode": 401, "msg": "未认证"}
  237. user_data = db.query(UserData).filter(
  238. UserData.accountID == user.account).first()
  239. if not user_data:
  240. return {"statusCode": 404, "msg": "用户数据不存在"}
  241. return {
  242. "statusCode": 200,
  243. "msg": "success",
  244. "data": {"user_data_id": user_data.id}
  245. }
  246. class PolicyFileCountRequest(BaseModel):
  247. policy_file_id: int
  248. @router.post("/policy_file_count")
  249. async def get_policy_file_view_and_download_count(request: PolicyFileCountRequest, db: Session = Depends(get_db)):
  250. """更新政策文件查看计数 - 对齐Go版本接口名"""
  251. policy_file = db.query(PolicyFile).filter(
  252. PolicyFile.id == request.policy_file_id).first()
  253. if not policy_file:
  254. return {"statusCode": 404, "msg": "文件不存在"}
  255. policy_file.view_count = (policy_file.view_count or 0) + 1
  256. policy_file.updated_at = int(time.time())
  257. db.commit()
  258. return {"statusCode": 200, "msg": "success"}
  259. @router.get("/download_file")
  260. async def get_pdf_oss_download_link(pdf_oss_download_link: str):
  261. """流式代理下载 OSS 文件(解密代理 URL)- 对齐Go版本接口名"""
  262. try:
  263. # 解密代理 URL 获取真实 OSS URL
  264. real_url = decrypt_url(pdf_oss_download_link)
  265. # 流式代理下载
  266. async with httpx.AsyncClient() as client:
  267. async with client.stream("GET", real_url) as response:
  268. if response.status_code != 200:
  269. return {"statusCode": response.status_code, "msg": "文件下载失败"}
  270. # 获取文件名和内容类型
  271. content_type = response.headers.get(
  272. "content-type", "application/octet-stream")
  273. content_disposition = response.headers.get(
  274. "content-disposition", "")
  275. async def generate():
  276. async for chunk in response.aiter_bytes():
  277. yield chunk
  278. return StreamingResponse(
  279. generate(),
  280. media_type=content_type,
  281. headers={
  282. "Content-Disposition": content_disposition} if content_disposition else {}
  283. )
  284. except Exception as e:
  285. return {"statusCode": 500, "msg": f"下载失败: {str(e)}"}