import time
from typing import Optional

import httpx
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel

# Local imports kept in their original order.
from database import get_db
from models.total import RecommendQuestion, PolicyFile, FunctionCard, HotQuestion, FeedbackQuestion, User
from models.chat import AIMessage
from models.user_data import UserData
from services.oss_service import oss_service  # noqa: F401  (unused in this module; kept as-is)
from utils.crypto import decrypt_url
from utils.config import get_proxy_url

router = APIRouter()


@router.get("/recommend_question")
async def get_recommend_question(db: Session = Depends(get_db)):
    """Return up to 10 recommended questions.

    NOTE(review): no ORDER BY and no soft-delete filter are applied, so which
    10 rows are returned is left to the database — confirm this is intended.
    """
    questions = db.query(RecommendQuestion).limit(10).all()
    return {
        "statusCode": 200,
        "msg": "success",
        "data": [{"id": q.id, "question": q.question} for q in questions],
    }


@router.get("/get_policy_file")
async def get_policy_file(
    policy_type: Optional[str] = None,
    page: int = 1,
    pageSize: int = 20,
    search: str = "",
    db: Session = Depends(get_db),
):
    """Return one page of non-deleted policy files, most recently updated first.

    Args:
        policy_type: Numeric type filter passed as a string. ``None``, ``""``,
            ``"0"`` and non-numeric values all mean "no type filter".
        page: 1-based page number; values below 1 are treated as 1.
        pageSize: Rows per page; negative values are treated as 0.
        search: Substring matched against ``policy_name`` via SQL ``LIKE``
            (``%`` / ``_`` inside the input act as wildcards).
    """
    query = db.query(PolicyFile).filter(PolicyFile.is_deleted == 0)

    # Only filter by type for a non-empty value other than "0".
    if policy_type and policy_type != "0":
        try:
            policy_type_int = int(policy_type)
        except ValueError:
            pass  # Invalid type values are deliberately ignored (no filter).
        else:
            query = query.filter(PolicyFile.policy_type == policy_type_int)

    if search:
        query = query.filter(PolicyFile.policy_name.like(f"%{search}%"))

    # Clamp pagination so bad input cannot produce a negative OFFSET/LIMIT,
    # which databases such as MySQL and PostgreSQL reject.
    page = max(page, 1)
    pageSize = max(pageSize, 0)
    offset = (page - 1) * pageSize

    files = (
        query.order_by(PolicyFile.updated_at.desc())
        .offset(offset)
        .limit(pageSize)
        .all()
    )
    return {
        "statusCode": 200,
        "msg": "success",
        "data": [
            {
                "id": f.id,
                "policy_name": f.policy_name,
                "policy_file_url": get_proxy_url(f.policy_file_url),
                "policy_type": f.policy_type,
                "file_type": f.file_type,
                "file_tag": getattr(f, 'file_tag', ''),
                "publish_time": getattr(f, 'publish_time', f.created_at),
                "view_count": f.view_count,
                "created_at": f.created_at,
                "updated_at": f.updated_at,
            }
            for f in files
        ],
    }


@router.get("/get_function_card")
async def get_function_card(
    function_type: Optional[int] = None,
    db: Session = Depends(get_db),
):
    """Return up to 4 non-deleted function cards ordered by ascending id.

    Args:
        function_type: Optional exact-match filter on ``function_type``.
    """
    query = db.query(FunctionCard).filter(FunctionCard.is_deleted == 0)
    if function_type is not None:
        query = query.filter(FunctionCard.function_type == function_type)

    cards = query.order_by(FunctionCard.id.asc()).limit(4).all()
    return {
        "statusCode": 200,
        "msg": "success",
        "data": [
            {
                "id": c.id,
                "function_title": c.function_title,
                "function_icon": c.function_icon,
                "function_content": c.function_content,
                "function_type": c.function_type,
            }
            for c in cards
        ],
    }


@router.get("/get_hot_question")
async def get_hot_question(
    question_type: Optional[int] = None,
    db: Session = Depends(get_db),
):
    """Return the top 3 non-deleted hot questions by click count.

    Ties on ``click_count`` are broken by ascending id.

    Args:
        question_type: Optional exact-match filter on ``question_type``.
    """
    query = db.query(HotQuestion).filter(HotQuestion.is_deleted == 0)
    if question_type is not None:
        query = query.filter(HotQuestion.question_type == question_type)

    questions = query.order_by(
        HotQuestion.click_count.desc(),
        HotQuestion.id.asc(),
    ).limit(3).all()

    return {
        "statusCode": 200,
        "msg": "success",
        "data": [
            {
                "id": q.id,
                "question": q.question,
                "click_count": q.click_count or 0,  # NULL counts are reported as 0
                "question_type": q.question_type,
            }
            for q in questions
        ],
    }


class SubmitFeedbackRequest(BaseModel):
    """Request body for ``POST /submit_feedback``."""

    feedback_type: str  # Chinese label or English key; see _FEEDBACK_TYPE_IDS
    content: str
    contact: str = ""  # stored as feedback_user_phone


# Maps a client-supplied feedback_type (Chinese label or English key) to the
# stored numeric type id. Unknown values fall back to 4 in submit_feedback.
_FEEDBACK_TYPE_IDS = {
    "功能建议": 1, "bug": 1, "问题反馈": 1,
    "ui": 2, "界面优化": 2,
    "experience": 3, "体验问题": 3,
    "other": 4, "其他": 4,
}


@router.post("/submit_feedback")
async def submit_feedback(request: SubmitFeedbackRequest, req: Request, db: Session = Depends(get_db)):
    """Persist a user feedback entry (aligned with the Go implementation).

    The submitting user comes from ``req.state.user``; anonymous requests are
    stored with ``user_id = 0``.
    """
    user = req.state.user
    # NOTE(review): this reads ``user.id`` whereas like_and_dislike reads
    # ``user_info.user_id`` from the same request.state.user — confirm which
    # attribute the auth middleware actually sets.
    user_id = user.id if user else 0

    feedback_type_id = _FEEDBACK_TYPE_IDS.get(request.feedback_type, 4)

    # Single timestamp so created_at and updated_at are equal on insert.
    now = int(time.time())
    feedback = FeedbackQuestion(
        feedback_type=feedback_type_id,
        feedback_content=request.content,
        feedback_user_phone=request.contact,
        user_id=user_id,
        created_at=now,
        updated_at=now,
    )
    db.add(feedback)
    db.commit()
    return {
        "statusCode": 200,
        "msg": "感谢您的反馈!",
    }


class LikeDislikeRequest(BaseModel):
    """Request body for ``POST /like_and_dislike``.

    The message id may be sent as ``ai_message_id`` or ``id``. The feedback
    may be sent as a numeric ``user_feedback`` or a textual ``action``;
    ``user_feedback`` wins when both are present.
    """

    ai_message_id: Optional[int] = None
    action: Optional[str] = None
    id: Optional[int] = None
    user_feedback: Optional[int] = None


# Stored values of AIMessage.user_feedback.
FEEDBACK_NONE = 0
FEEDBACK_LIKE = 2
FEEDBACK_DISLIKE = 3

# Points credited for each feedback state. Changing feedback applies the
# difference between the new and previous state's reward.
FEEDBACK_REWARD_POINTS = {
    FEEDBACK_NONE: 0,
    FEEDBACK_LIKE: 2,
    FEEDBACK_DISLIKE: 1,
}


def _resolve_like_dislike_payload(data: LikeDislikeRequest):
    """Normalize a like/dislike request.

    Returns:
        ``(message_id, feedback, None)`` on success, or
        ``(None, None, error_message)`` when the id is missing/falsy or the
        feedback value is not one of the FEEDBACK_* constants.
    """
    message_id = data.ai_message_id or data.id
    if not message_id:
        return None, None, "缺少消息ID"

    if data.user_feedback is not None:
        feedback = int(data.user_feedback)
    else:
        action = (data.action or "").strip().lower()
        if action in ("like", "2"):
            feedback = FEEDBACK_LIKE
        elif action in ("dislike", "3"):
            feedback = FEEDBACK_DISLIKE
        elif action in ("", "none", "cancel", "0"):
            feedback = FEEDBACK_NONE
        else:
            return None, None, "反馈类型错误"

    if feedback not in (FEEDBACK_NONE, FEEDBACK_LIKE, FEEDBACK_DISLIKE):
        return None, None, "反馈类型错误"

    return message_id, feedback, None


def _find_current_points_holder(db: Session, user_info):
    """Return the row whose ``points`` should be credited for ``user_info``.

    Prefers a non-deleted ``User`` matching ``user_info.user_id``. Otherwise
    falls back to the ``UserData`` row whose ``accountID`` equals
    ``user_info.account``. Returns ``None`` if neither exists.
    """
    user = db.query(User).filter(
        User.id == user_info.user_id,
        User.is_deleted == 0,
    ).first()
    if user:
        return user

    return db.query(UserData).filter(
        UserData.accountID == user_info.account,
    ).first()


@router.post("/like_and_dislike")
async def like_and_dislike(data: LikeDislikeRequest, request: Request, db: Session = Depends(get_db)):
    """Save AI message feedback and reward points to the current user.

    Only the delta between the previous and the new feedback reward is
    applied, so repeating the same feedback credits 0 points.
    """
    user_info = request.state.user
    if not user_info:
        return {"statusCode": 401, "msg": "未认证"}

    message_id, feedback, error = _resolve_like_dislike_payload(data)
    if error:
        return {"statusCode": 400, "msg": error}

    message = db.query(AIMessage).filter(AIMessage.id == message_id).first()
    if not message:
        return {"statusCode": 404, "msg": "消息不存在"}

    # If AIMessage has no ``user_id`` attribute, the getattr default makes
    # this ownership check pass unconditionally.
    if getattr(message, "user_id", user_info.user_id) != user_info.user_id:
        return {"statusCode": 403, "msg": "无权评价该消息"}

    points_holder = _find_current_points_holder(db, user_info)
    if not points_holder:
        return {"statusCode": 404, "msg": "未找到用户数据"}

    previous_feedback = int(message.user_feedback or 0)
    previous_points = FEEDBACK_REWARD_POINTS.get(previous_feedback, 0)
    current_points = FEEDBACK_REWARD_POINTS.get(feedback, 0)
    points_delta = current_points - previous_points

    try:
        message.user_feedback = feedback
        message.updated_at = int(time.time())
        new_balance = (points_holder.points or 0) + points_delta
        points_holder.points = new_balance
        db.commit()
    except Exception as e:
        db.rollback()
        return {"statusCode": 500, "msg": f"反馈提交失败: {str(e)}"}

    return {
        "statusCode": 200,
        "msg": "success",
        "data": {
            "ai_message_id": message_id,
            "user_feedback": feedback,
            "points_added": points_delta,  # same value as points_delta; both keys kept for clients
            "points_delta": points_delta,
            "new_balance": new_balance,
        },
    }


@router.get("/get_user_data_id")
async def get_user_data_id(request: Request, db: Session = Depends(get_db)):
    """Return the ``UserData`` primary key for the token's account id."""
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未认证"}

    user_data = db.query(UserData).filter(
        UserData.accountID == user.account).first()
    if not user_data:
        return {"statusCode": 404, "msg": "用户数据不存在"}

    return {
        "statusCode": 200,
        "msg": "success",
        "data": {"user_data_id": user_data.id},
    }


class PolicyFileCountRequest(BaseModel):
    """Request body for ``POST /policy_file_count``."""

    policy_file_id: int


@router.post("/policy_file_count")
async def get_policy_file_view_and_download_count(request: PolicyFileCountRequest, db: Session = Depends(get_db)):
    """Increment a policy file's view count (endpoint name aligned with the Go version).

    NOTE(review): this does not filter on ``is_deleted``, so deleted files can
    still be counted. It also bumps ``updated_at``, which moves the file to
    the top of get_policy_file's ``updated_at DESC`` listing on every view.
    Confirm both match the Go behavior.
    """
    policy_file = db.query(PolicyFile).filter(
        PolicyFile.id == request.policy_file_id).first()
    if not policy_file:
        return {"statusCode": 404, "msg": "文件不存在"}

    policy_file.view_count = (policy_file.view_count or 0) + 1
    policy_file.updated_at = int(time.time())
    db.commit()
    return {"statusCode": 200, "msg": "success"}


async def _close_upstream(response, client):
    """Close an upstream httpx response and client; either may be ``None``.

    Both ``aclose`` calls are safe to repeat.
    """
    try:
        if response is not None:
            await response.aclose()
    finally:
        if client is not None:
            await client.aclose()


@router.get("/download_file")
async def get_pdf_oss_download_link(pdf_oss_download_link: str):
    """Stream-proxy an OSS file, given an encrypted proxy URL (name aligned with the Go version).

    The upstream response and client stay open until the body has been fully
    streamed. They are closed by the generator's ``finally``. They must not be
    opened with ``async with`` in this function, because those contexts would
    close the stream as soon as the handler returned the StreamingResponse,
    before any bytes were read.

    Errors before streaming starts (decryption, connection, non-200 upstream)
    are reported as a JSON body carrying a ``statusCode``, as before.
    """
    client = None
    response = None
    try:
        # Decrypt the proxy URL to obtain the real OSS URL.
        real_url = decrypt_url(pdf_oss_download_link)

        client = httpx.AsyncClient()
        response = await client.send(client.build_request("GET", real_url), stream=True)

        if response.status_code != 200:
            await _close_upstream(response, client)
            return {"statusCode": response.status_code, "msg": "文件下载失败"}

        content_type = response.headers.get(
            "content-type", "application/octet-stream")
        content_disposition = response.headers.get(
            "content-disposition", "")
    except Exception as e:
        await _close_upstream(response, client)
        return {"statusCode": 500, "msg": f"下载失败: {str(e)}"}

    async def generate():
        try:
            async for chunk in response.aiter_bytes():
                yield chunk
        finally:
            # Runs when streaming completes or fails, or when the generator is
            # closed (e.g. on client disconnect).
            await _close_upstream(response, client)

    return StreamingResponse(
        generate(),
        media_type=content_type,
        headers={"Content-Disposition": content_disposition} if content_disposition else {},
    )