total.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
import logging
import time
from typing import Optional

import httpx
from fastapi import APIRouter, BackgroundTasks, Depends, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session

from database import get_db
from models.chat import AIMessage
from models.total import RecommendQuestion, PolicyFile, FunctionCard, HotQuestion, FeedbackQuestion
from models.user_data import UserData
from services.oss_service import oss_service
from utils.crypto import decrypt_url
  14. router = APIRouter()
  15. @router.get("/recommend_question")
  16. async def get_recommend_question(db: Session = Depends(get_db)):
  17. """获取推荐问题"""
  18. questions = db.query(RecommendQuestion).limit(10).all()
  19. return {
  20. "statusCode": 200,
  21. "msg": "success",
  22. "data": [{"id": q.id, "question": q.question} for q in questions]
  23. }
  24. @router.get("/get_policy_file")
  25. async def get_policy_file(
  26. policy_type: Optional[int] = None,
  27. page: int = 1,
  28. page_size: int = 20,
  29. db: Session = Depends(get_db)
  30. ):
  31. """获取策略文件列表"""
  32. query = db.query(PolicyFile).filter(PolicyFile.is_deleted == 0)
  33. if policy_type is not None and policy_type != 0:
  34. query = query.filter(PolicyFile.policy_type == policy_type)
  35. total = query.count()
  36. offset = (page - 1) * page_size
  37. files = query.order_by(PolicyFile.updated_at.desc()).offset(
  38. offset).limit(page_size).all()
  39. return {
  40. "statusCode": 200,
  41. "msg": "success",
  42. "data": {
  43. "total": total,
  44. "items": [
  45. {
  46. "id": f.id,
  47. "policy_name": f.policy_name,
  48. "policy_file_url": f.policy_file_url,
  49. "policy_type": f.policy_type,
  50. "file_type": f.file_type,
  51. "view_count": f.view_count,
  52. "created_at": f.created_at
  53. }
  54. for f in files
  55. ]
  56. }
  57. }
  58. @router.get("/get_function_card")
  59. async def get_function_card(
  60. function_type: Optional[int] = None,
  61. db: Session = Depends(get_db)
  62. ):
  63. """获取功能卡片"""
  64. query = db.query(FunctionCard).filter(FunctionCard.is_deleted == 0)
  65. if function_type is not None:
  66. query = query.filter(FunctionCard.function_type == function_type)
  67. cards = query.order_by(FunctionCard.id.asc()).limit(4).all()
  68. return {
  69. "statusCode": 200,
  70. "msg": "success",
  71. "data": [
  72. {
  73. "id": c.id,
  74. "function_title": c.function_title,
  75. "function_icon": c.function_icon,
  76. "function_content": c.function_content,
  77. "function_type": c.function_type
  78. }
  79. for c in cards
  80. ]
  81. }
  82. @router.get("/get_hot_question")
  83. async def get_hot_question(
  84. question_type: Optional[int] = None,
  85. db: Session = Depends(get_db)
  86. ):
  87. """获取热点问题(按点击量排序)"""
  88. query = db.query(HotQuestion).filter(HotQuestion.is_deleted == 0)
  89. if question_type is not None:
  90. query = query.filter(HotQuestion.question_type == question_type)
  91. questions = query.order_by(
  92. HotQuestion.click_count.desc(),
  93. HotQuestion.id.asc()
  94. ).limit(3).all()
  95. return {
  96. "statusCode": 200,
  97. "msg": "success",
  98. "data": [
  99. {
  100. "id": q.id,
  101. "question": q.question,
  102. "click_count": q.click_count or 0,
  103. "question_type": q.question_type
  104. }
  105. for q in questions
  106. ]
  107. }
  108. class SubmitFeedbackRequest(BaseModel):
  109. feedback_type: str
  110. content: str
  111. contact: str = ""
  112. @router.post("/submit_feedback")
  113. async def submit_feedback(request: SubmitFeedbackRequest, req: Request, db: Session = Depends(get_db)):
  114. """提交意见反馈(对齐Go版本)"""
  115. # 从token获取user_id
  116. user = req.state.user
  117. user_id = user.id if user else 0
  118. # 映射反馈类型:支持中文描述和英文标识
  119. type_map = {
  120. "功能建议": 1, "bug": 1, "问题反馈": 1,
  121. "ui": 2, "界面优化": 2,
  122. "experience": 3, "体验问题": 3,
  123. "other": 4, "其他": 4
  124. }
  125. feedback_type_id = type_map.get(request.feedback_type, 4)
  126. feedback = FeedbackQuestion(
  127. feedback_type=feedback_type_id,
  128. feedback_content=request.content,
  129. feedback_user_phone=request.contact,
  130. user_id=user_id,
  131. created_at=int(time.time()),
  132. updated_at=int(time.time())
  133. )
  134. db.add(feedback)
  135. db.commit()
  136. return {
  137. "statusCode": 200,
  138. "msg": "感谢您的反馈!"
  139. }
  140. class LikeDislikeRequest(BaseModel):
  141. ai_message_id: int
  142. action: str # "like" 或 "dislike"
  143. @router.post("/like_and_dislike")
  144. async def like_and_dislike(request: LikeDislikeRequest, db: Session = Depends(get_db)):
  145. """点赞/踩(对齐Go版本)"""
  146. message = db.query(AIMessage).filter(
  147. AIMessage.id == request.ai_message_id).first()
  148. if not message:
  149. return {"statusCode": 404, "msg": "消息不存在"}
  150. # 将action转换为user_feedback:like=2(满意/赞), dislike=3(不满意/踩)
  151. user_feedback = 2 if request.action == "like" else 3
  152. message.user_feedback = user_feedback
  153. message.updated_at = int(time.time())
  154. db.commit()
  155. return {"statusCode": 200, "msg": "success"}
  156. @router.get("/get_user_data_id")
  157. async def get_user_data_id(request: Request, db: Session = Depends(get_db)):
  158. """通过 token 的 account_id 查询 UserData 主键"""
  159. user = request.state.user
  160. if not user:
  161. return {"statusCode": 401, "msg": "未认证"}
  162. user_data = db.query(UserData).filter(
  163. UserData.accountID == user.account).first()
  164. if not user_data:
  165. return {"statusCode": 404, "msg": "用户数据不存在"}
  166. return {
  167. "statusCode": 200,
  168. "msg": "success",
  169. "data": {"user_data_id": user_data.id}
  170. }
  171. class PolicyFileCountRequest(BaseModel):
  172. policy_file_id: int
  173. @router.post("/policy_file_count")
  174. async def get_policy_file_view_and_download_count(request: PolicyFileCountRequest, db: Session = Depends(get_db)):
  175. """更新政策文件查看计数 - 对齐Go版本接口名"""
  176. policy_file = db.query(PolicyFile).filter(
  177. PolicyFile.id == request.policy_file_id).first()
  178. if not policy_file:
  179. return {"statusCode": 404, "msg": "文件不存在"}
  180. policy_file.view_count = (policy_file.view_count or 0) + 1
  181. policy_file.updated_at = int(time.time())
  182. db.commit()
  183. return {"statusCode": 200, "msg": "success"}
  184. @router.get("/download_file")
  185. async def get_pdf_oss_download_link(pdf_oss_download_link: str):
  186. """流式代理下载 OSS 文件(解密代理 URL)- 对齐Go版本接口名"""
  187. try:
  188. # 解密代理 URL 获取真实 OSS URL
  189. real_url = decrypt_url(pdf_oss_download_link)
  190. # 流式代理下载
  191. async with httpx.AsyncClient() as client:
  192. async with client.stream("GET", real_url) as response:
  193. if response.status_code != 200:
  194. return {"statusCode": response.status_code, "msg": "文件下载失败"}
  195. # 获取文件名和内容类型
  196. content_type = response.headers.get(
  197. "content-type", "application/octet-stream")
  198. content_disposition = response.headers.get(
  199. "content-disposition", "")
  200. async def generate():
  201. async for chunk in response.aiter_bytes():
  202. yield chunk
  203. return StreamingResponse(
  204. generate(),
  205. media_type=content_type,
  206. headers={
  207. "Content-Disposition": content_disposition} if content_disposition else {}
  208. )
  209. except Exception as e:
  210. return {"statusCode": 500, "msg": f"下载失败: {str(e)}"}