knowledge.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
import logging
import time
from datetime import datetime, timedelta
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from sqlalchemy.orm import Session

from database import get_db
from models.total import PolicyFile
# Router that the knowledge-base endpoints below register on; presumably
# included into the FastAPI app elsewhere (not visible here).
router = APIRouter()
  9. @router.get("/get_chromadb_document")
  10. async def get_chromadb_document(
  11. query: str,
  12. n: int = 5,
  13. request: Request = None
  14. ):
  15. """获取ChromaDB文档"""
  16. from services.chromadb_service import chromadb_service
  17. try:
  18. results = await chromadb_service.query(query, n)
  19. return {
  20. "statusCode": 200,
  21. "msg": "success",
  22. "data": results
  23. }
  24. except Exception as e:
  25. # 如果 ChromaDB 服务不可用,返回模拟数据
  26. results = [
  27. {
  28. "content": f"关于{query}的文档内容{i}",
  29. "score": 0.9 - i * 0.1,
  30. "metadata": {"source": f"doc_{i}.pdf"}
  31. }
  32. for i in range(1, min(n + 1, 6))
  33. ]
  34. return {
  35. "statusCode": 200,
  36. "msg": "success (fallback)",
  37. "data": results
  38. }
  39. @router.get("/knowledge/files/advanced-search")
  40. async def advanced_search(
  41. keyword: Optional[str] = None,
  42. category: Optional[str] = None,
  43. date_from: Optional[str] = None,
  44. date_to: Optional[str] = None,
  45. page: int = 1,
  46. page_size: int = 20,
  47. request: Request = None,
  48. db: Session = Depends(get_db)
  49. ):
  50. """知识库高级搜索"""
  51. query = db.query(PolicyFile)
  52. # 关键词搜索
  53. if keyword:
  54. query = query.filter(PolicyFile.policy_name.like(f"%{keyword}%"))
  55. # 分类筛选
  56. if category:
  57. category_map = {
  58. "国家规范": 1,
  59. "行业规范": 2,
  60. "地方规范": 3,
  61. "内部条例": 4
  62. }
  63. if category in category_map:
  64. query = query.filter(PolicyFile.policy_type == category_map[category])
  65. # 日期筛选
  66. if date_from:
  67. query = query.filter(PolicyFile.created_at >= int(time.mktime(time.strptime(date_from, "%Y-%m-%d"))))
  68. if date_to:
  69. query = query.filter(PolicyFile.created_at <= int(time.mktime(time.strptime(date_to, "%Y-%m-%d"))))
  70. # 分页
  71. total = query.count()
  72. files = query.offset((page - 1) * page_size).limit(page_size).all()
  73. return {
  74. "statusCode": 200,
  75. "msg": "success",
  76. "data": {
  77. "total": total,
  78. "page": page,
  79. "page_size": page_size,
  80. "items": [
  81. {
  82. "id": f.id,
  83. "policy_name": f.policy_name,
  84. "policy_file_url": f.policy_file_url,
  85. "policy_type": f.policy_type,
  86. "view_count": f.view_count,
  87. "file_type": f.file_type,
  88. "created_at": f.created_at
  89. }
  90. for f in files
  91. ]
  92. }
  93. }