total.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. from fastapi import APIRouter, Depends, Request
  2. from fastapi.responses import StreamingResponse
  3. from sqlalchemy.orm import Session
  4. from pydantic import BaseModel
  5. from typing import Optional
  6. from database import get_db
  7. from models.total import RecommendQuestion, PolicyFile, FunctionCard, HotQuestion, FeedbackQuestion
  8. from models.chat import AIMessage
  9. from models.user_data import UserData
  10. from services.oss_service import oss_service
  11. from utils.crypto import decrypt_url
  12. import time
  13. import httpx
  14. router = APIRouter()
  15. @router.get("/recommend_question")
  16. async def get_recommend_question(db: Session = Depends(get_db)):
  17. """获取推荐问题"""
  18. questions = db.query(RecommendQuestion).limit(10).all()
  19. return {
  20. "statusCode": 200,
  21. "msg": "success",
  22. "data": [{"id": q.id, "question": q.question} for q in questions]
  23. }
  24. @router.get("/get_policy_file")
  25. async def get_policy_file(
  26. policy_type: Optional[int] = None,
  27. page: int = 1,
  28. page_size: int = 20,
  29. db: Session = Depends(get_db)
  30. ):
  31. """获取策略文件列表"""
  32. query = db.query(PolicyFile).filter(PolicyFile.is_deleted == 0)
  33. if policy_type is not None and policy_type != 0:
  34. query = query.filter(PolicyFile.policy_type == policy_type)
  35. total = query.count()
  36. offset = (page - 1) * page_size
  37. files = query.order_by(PolicyFile.updated_at.desc()).offset(
  38. offset).limit(page_size).all()
  39. return {
  40. "statusCode": 200,
  41. "msg": "success",
  42. "data": {
  43. "total": total,
  44. "items": [
  45. {
  46. "id": f.id,
  47. "policy_name": f.policy_name,
  48. "policy_file_url": f.policy_file_url,
  49. "policy_type": f.policy_type,
  50. "file_type": f.file_type,
  51. "view_count": f.view_count,
  52. "created_at": f.created_at
  53. }
  54. for f in files
  55. ]
  56. }
  57. }
  58. @router.get("/get_function_card")
  59. async def get_function_card(db: Session = Depends(get_db)):
  60. """获取功能卡片"""
  61. cards = db.query(FunctionCard).limit(4).all()
  62. return {
  63. "statusCode": 200,
  64. "msg": "success",
  65. "data": [
  66. {
  67. "id": c.id,
  68. "function_title": c.function_title,
  69. "function_icon": c.function_icon,
  70. "function_content": c.function_content,
  71. "function_type": c.function_type
  72. }
  73. for c in cards
  74. ]
  75. }
  76. @router.get("/get_hot_question")
  77. async def get_hot_question(db: Session = Depends(get_db)):
  78. """获取热点问题(按点击量排序)"""
  79. questions = db.query(HotQuestion).order_by(
  80. HotQuestion.click_count.desc()).limit(3).all()
  81. return {
  82. "statusCode": 200,
  83. "msg": "success",
  84. "data": [
  85. {
  86. "id": q.id,
  87. "question": q.question,
  88. "click_count": q.click_count or 0
  89. }
  90. for q in questions
  91. ]
  92. }
  93. class SubmitFeedbackRequest(BaseModel):
  94. feedback_type: str
  95. content: str
  96. contact: str = ""
  97. @router.post("/submit_feedback")
  98. async def submit_feedback(request: SubmitFeedbackRequest, req: Request, db: Session = Depends(get_db)):
  99. """提交意见反馈(对齐Go版本)"""
  100. # 从token获取user_id
  101. user = req.state.user
  102. user_id = user.id if user else 0
  103. # 映射反馈类型:支持中文描述和英文标识
  104. type_map = {
  105. "功能建议": 1, "bug": 1, "问题反馈": 1,
  106. "ui": 2, "界面优化": 2,
  107. "experience": 3, "体验问题": 3,
  108. "other": 4, "其他": 4
  109. }
  110. feedback_type_id = type_map.get(request.feedback_type, 4)
  111. feedback = FeedbackQuestion(
  112. feedback_type=feedback_type_id,
  113. feedback_content=request.content,
  114. feedback_user_phone=request.contact,
  115. user_id=user_id,
  116. created_at=int(time.time()),
  117. updated_at=int(time.time())
  118. )
  119. db.add(feedback)
  120. db.commit()
  121. return {
  122. "statusCode": 200,
  123. "msg": "感谢您的反馈!"
  124. }
  125. class LikeDislikeRequest(BaseModel):
  126. ai_message_id: int
  127. action: str # "like" 或 "dislike"
  128. @router.post("/like_and_dislike")
  129. async def like_and_dislike(request: LikeDislikeRequest, db: Session = Depends(get_db)):
  130. """点赞/踩(对齐Go版本)"""
  131. message = db.query(AIMessage).filter(
  132. AIMessage.id == request.ai_message_id).first()
  133. if not message:
  134. return {"statusCode": 404, "msg": "消息不存在"}
  135. # 将action转换为user_feedback:like=2(满意/赞), dislike=3(不满意/踩)
  136. user_feedback = 2 if request.action == "like" else 3
  137. message.user_feedback = user_feedback
  138. message.updated_at = int(time.time())
  139. db.commit()
  140. return {"statusCode": 200, "msg": "success"}
  141. @router.get("/get_user_data_id")
  142. async def get_user_data_id(request: Request, db: Session = Depends(get_db)):
  143. """通过 token 的 account_id 查询 UserData 主键"""
  144. user = request.state.user
  145. if not user:
  146. return {"statusCode": 401, "msg": "未认证"}
  147. user_data = db.query(UserData).filter(
  148. UserData.accountID == user.account).first()
  149. if not user_data:
  150. return {"statusCode": 404, "msg": "用户数据不存在"}
  151. return {
  152. "statusCode": 200,
  153. "msg": "success",
  154. "data": {"user_data_id": user_data.id}
  155. }
  156. class PolicyFileCountRequest(BaseModel):
  157. policy_file_id: int
  158. @router.post("/policy_file_count")
  159. async def get_policy_file_view_and_download_count(request: PolicyFileCountRequest, db: Session = Depends(get_db)):
  160. """更新政策文件查看计数 - 对齐Go版本接口名"""
  161. policy_file = db.query(PolicyFile).filter(
  162. PolicyFile.id == request.policy_file_id).first()
  163. if not policy_file:
  164. return {"statusCode": 404, "msg": "文件不存在"}
  165. policy_file.view_count = (policy_file.view_count or 0) + 1
  166. policy_file.updated_at = int(time.time())
  167. db.commit()
  168. return {"statusCode": 200, "msg": "success"}
  169. @router.get("/download_file")
  170. async def get_pdf_oss_download_link(pdf_oss_download_link: str):
  171. """流式代理下载 OSS 文件(解密代理 URL)- 对齐Go版本接口名"""
  172. try:
  173. # 解密代理 URL 获取真实 OSS URL
  174. real_url = decrypt_url(pdf_oss_download_link)
  175. # 流式代理下载
  176. async with httpx.AsyncClient() as client:
  177. async with client.stream("GET", real_url) as response:
  178. if response.status_code != 200:
  179. return {"statusCode": response.status_code, "msg": "文件下载失败"}
  180. # 获取文件名和内容类型
  181. content_type = response.headers.get(
  182. "content-type", "application/octet-stream")
  183. content_disposition = response.headers.get(
  184. "content-disposition", "")
  185. async def generate():
  186. async for chunk in response.aiter_bytes():
  187. yield chunk
  188. return StreamingResponse(
  189. generate(),
  190. media_type=content_type,
  191. headers={
  192. "Content-Disposition": content_disposition} if content_disposition else {}
  193. )
  194. except Exception as e:
  195. return {"statusCode": 500, "msg": f"下载失败: {str(e)}"}