from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel
from typing import Optional
from database import get_db
from models.total import (
    RecommendQuestion,
    PolicyFile,
    FunctionCard,
    HotQuestion,
    FeedbackQuestion,
)
from models.chat import AIMessage
from models.user_data import UserData
from services.oss_service import oss_service
from utils.crypto import decrypt_url
import time
import httpx

router = APIRouter()


@router.get("/recommend_question")
async def get_recommend_question(db: Session = Depends(get_db)):
    """Return up to 10 recommended questions as ``{id, question}`` items.

    No ORDER BY is applied, so which rows are returned is up to the database.
    """
    questions = db.query(RecommendQuestion).limit(10).all()
    return {
        "statusCode": 200,
        "msg": "success",
        "data": [{"id": q.id, "question": q.question} for q in questions],
    }


@router.get("/get_policy_file")
async def get_policy_file(
    policy_type: Optional[int] = None,
    page: int = 1,
    page_size: int = 20,
    db: Session = Depends(get_db),
):
    """List non-deleted policy files, most recently updated first, paginated.

    Args:
        policy_type: Optional type filter; ``None`` or ``0`` means "all types".
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Items per page; negative values are treated as 0.
        db: Session injected by ``get_db``.

    Returns:
        Envelope whose ``data`` holds ``total`` (number of matching files,
        ignoring pagination) and ``items`` (the requested page).
    """
    # Clamp pagination input: page <= 0 (or a negative page_size) would
    # produce a negative OFFSET/LIMIT, which databases generally reject.
    page = max(page, 1)
    page_size = max(page_size, 0)

    query = db.query(PolicyFile).filter(PolicyFile.is_deleted == 0)
    if policy_type:  # None and 0 both mean "no type filter"
        query = query.filter(PolicyFile.policy_type == policy_type)

    total = query.count()
    files = (
        query.order_by(PolicyFile.updated_at.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return {
        "statusCode": 200,
        "msg": "success",
        "data": {
            "total": total,
            "items": [
                {
                    "id": f.id,
                    "policy_name": f.policy_name,
                    "policy_file_url": f.policy_file_url,
                    "policy_type": f.policy_type,
                    "file_type": f.file_type,
                    "view_count": f.view_count,
                    "created_at": f.created_at,
                }
                for f in files
            ],
        },
    }


@router.get("/get_function_card")
async def get_function_card(db: Session = Depends(get_db)):
    """Return up to 4 function cards (title, icon, content, type).

    No ORDER BY is applied, so which rows are returned is up to the database.
    """
    cards = db.query(FunctionCard).limit(4).all()
    return {
        "statusCode": 200,
        "msg": "success",
        "data": [
            {
                "id": c.id,
                "function_title": c.function_title,
                "function_icon": c.function_icon,
                "function_content": c.function_content,
                "function_type": c.function_type,
            }
            for c in cards
        ],
    }
@router.get("/get_hot_question")
async def get_hot_question(db: Session = Depends(get_db)):
    """Return the 3 hot questions with the highest click counts.

    A NULL ``click_count`` is reported as ``0``.
    """
    questions = (
        db.query(HotQuestion)
        .order_by(HotQuestion.click_count.desc())
        .limit(3)
        .all()
    )
    return {
        "statusCode": 200,
        "msg": "success",
        "data": [
            {
                "id": q.id,
                "question": q.question,
                "click_count": q.click_count or 0,
            }
            for q in questions
        ],
    }


class SubmitFeedbackRequest(BaseModel):
    """Request body for ``POST /submit_feedback``."""

    # Chinese label or English key; see _FEEDBACK_TYPE_MAP (unknown -> 4).
    feedback_type: str
    content: str
    # Free-form contact info; stored in FeedbackQuestion.feedback_user_phone.
    contact: str = ""


# Maps the client-supplied feedback_type (Chinese label or English key) to the
# numeric FeedbackQuestion.feedback_type value.
_FEEDBACK_TYPE_MAP = {
    "功能建议": 1,
    "bug": 1,
    "问题反馈": 1,
    "ui": 2,
    "界面优化": 2,
    "experience": 3,
    "体验问题": 3,
    "other": 4,
    "其他": 4,
}
_DEFAULT_FEEDBACK_TYPE = 4


@router.post("/submit_feedback")
async def submit_feedback(request: SubmitFeedbackRequest, req: Request, db: Session = Depends(get_db)):
    """Store a feedback entry (kept in line with the Go implementation).

    The user id is taken from ``req.state.user`` when present; anonymous
    submissions are stored with ``user_id = 0``.
    """
    # Starlette's request.state raises AttributeError for attributes that were
    # never set, so read it defensively instead of failing with a 500.
    user = getattr(req.state, "user", None)
    user_id = user.id if user else 0

    now = int(time.time())  # single timestamp so created_at == updated_at
    feedback = FeedbackQuestion(
        feedback_type=_FEEDBACK_TYPE_MAP.get(request.feedback_type, _DEFAULT_FEEDBACK_TYPE),
        feedback_content=request.content,
        feedback_user_phone=request.contact,
        user_id=user_id,
        created_at=now,
        updated_at=now,
    )
    db.add(feedback)
    db.commit()
    return {
        "statusCode": 200,
        "msg": "感谢您的反馈!"
    }


class LikeDislikeRequest(BaseModel):
    """Request body for ``POST /like_and_dislike``."""

    ai_message_id: int
    action: str  # "like" or "dislike"


# action -> AIMessage.user_feedback: 2 = liked (satisfied), 3 = disliked.
_USER_FEEDBACK_BY_ACTION = {"like": 2, "dislike": 3}


@router.post("/like_and_dislike")
async def like_and_dislike(request: LikeDislikeRequest, db: Session = Depends(get_db)):
    """Record a like/dislike on an AI message (kept in line with the Go version).

    Returns a 404 envelope if the message does not exist and a 400 envelope
    if ``action`` is neither ``"like"`` nor ``"dislike"``.
    """
    message = (
        db.query(AIMessage)
        .filter(AIMessage.id == request.ai_message_id)
        .first()
    )
    if not message:
        return {"statusCode": 404, "msg": "消息不存在"}

    # Reject unknown actions rather than silently recording them as a dislike.
    user_feedback = _USER_FEEDBACK_BY_ACTION.get(request.action)
    if user_feedback is None:
        return {"statusCode": 400, "msg": "action 必须为 like 或 dislike"}

    message.user_feedback = user_feedback
    message.updated_at = int(time.time())
    db.commit()
    return {"statusCode": 200, "msg": "success"}


@router.get("/get_user_data_id")
async def get_user_data_id(request: Request, db: Session = Depends(get_db)):
    """Look up the UserData primary key for the authenticated user's account.

    Matches ``UserData.accountID`` against ``request.state.user.account``.
    Returns 401 when no user is attached and 404 when no row matches.
    """
    # request.state raises AttributeError for unset attributes; treat a
    # missing user the same as an unauthenticated one.
    user = getattr(request.state, "user", None)
    if not user:
        return {"statusCode": 401, "msg": "未认证"}

    user_data = db.query(UserData).filter(UserData.accountID == user.account).first()
    if not user_data:
        return {"statusCode": 404, "msg": "用户数据不存在"}

    return {
        "statusCode": 200,
        "msg": "success",
        "data": {"user_data_id": user_data.id},
    }


class PolicyFileCountRequest(BaseModel):
    """Request body for ``POST /policy_file_count``."""

    policy_file_id: int


@router.post("/policy_file_count")
async def get_policy_file_view_and_download_count(request: PolicyFileCountRequest, db: Session = Depends(get_db)):
    """Increment a policy file's view count (endpoint name matches the Go version).

    Also bumps ``updated_at``. Returns a 404 envelope if no file has the id.
    """
    from sqlalchemy import func

    # Increment inside the database (UPDATE ... SET view_count = view_count + 1)
    # so concurrent requests cannot overwrite each other's increments, which a
    # Python-side read-modify-write allows. COALESCE treats NULL as 0.
    matched = (
        db.query(PolicyFile)
        .filter(PolicyFile.id == request.policy_file_id)
        .update(
            {
                PolicyFile.view_count: func.coalesce(PolicyFile.view_count, 0) + 1,
                PolicyFile.updated_at: int(time.time()),
            },
            synchronize_session=False,
        )
    )
    if not matched:
        return {"statusCode": 404, "msg": "文件不存在"}

    db.commit()
    return {"statusCode": 200, "msg": "success"}


@router.get("/download_file")
async def get_pdf_oss_download_link(pdf_oss_download_link: str):
    """Stream an OSS file through this server (name matches the Go version).

    ``pdf_oss_download_link`` is an encrypted proxy URL; ``decrypt_url`` turns
    it into the real OSS URL, which is then fetched and streamed to the client.
    Upstream ``content-type`` / ``content-disposition`` headers are forwarded.
    Non-200 upstream responses and errors return a JSON error envelope.
    """
    from fastapi import BackgroundTasks

    client = httpx.AsyncClient()
    response = None

    async def close_upstream():
        # Safe to call more than once: httpx treats repeated aclose() as a no-op.
        if response is not None:
            await response.aclose()
        await client.aclose()

    try:
        real_url = decrypt_url(pdf_oss_download_link)
        # stream=True leaves the body unread. The client/response must stay
        # open until Starlette has finished streaming, i.e. after this function
        # returns, so they are not managed with `async with` here.
        response = await client.send(client.build_request("GET", real_url), stream=True)
    except Exception as e:
        await close_upstream()
        return {"statusCode": 500, "msg": f"下载失败: {str(e)}"}

    if response.status_code != 200:
        await close_upstream()
        return {"statusCode": response.status_code, "msg": "文件下载失败"}

    content_type = response.headers.get("content-type", "application/octet-stream")
    content_disposition = response.headers.get("content-disposition", "")

    async def generate():
        try:
            async for chunk in response.aiter_bytes():
                yield chunk
        finally:
            await close_upstream()

    # Starlette runs the response's background task after streaming ends
    # (including after a client disconnect), covering the case where the
    # generator above never reaches its finally clause.
    cleanup = BackgroundTasks()
    cleanup.add_task(close_upstream)

    return StreamingResponse(
        generate(),
        media_type=content_type,
        headers={"Content-Disposition": content_disposition} if content_disposition else {},
        background=cleanup,
    )