application_memory_service.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
# coding=utf-8
"""
Agent memory service.

Provides CRUD operations and retrieval for memories.
"""
  6. from typing import List, Optional
  7. from django.db.models import QuerySet
  8. from application.models.application_memory import ApplicationMemory, ApplicationMemoryType
  9. from common.utils.logger import maxkb_logger
  10. class ApplicationMemoryService:
  11. """智能体记忆服务"""
  12. @staticmethod
  13. def get_memories(application_id: str, memory_type: str = None,
  14. is_enabled: bool = True, limit: int = 100) -> QuerySet:
  15. """获取应用的记忆列表"""
  16. queryset = ApplicationMemory.objects.filter(
  17. application_id=application_id,
  18. is_enabled=is_enabled
  19. )
  20. if memory_type:
  21. queryset = queryset.filter(memory_type=memory_type)
  22. return queryset.order_by('-relevance_score', '-create_time')[:limit]
  23. @staticmethod
  24. def get_memory(memory_id: str) -> Optional[ApplicationMemory]:
  25. """获取单条记忆"""
  26. try:
  27. return ApplicationMemory.objects.get(id=memory_id)
  28. except ApplicationMemory.DoesNotExist:
  29. return None
  30. @staticmethod
  31. def create_memory(application_id: str, content: str,
  32. memory_type: str = ApplicationMemoryType.DIALOGUE,
  33. metadata: dict = None) -> ApplicationMemory:
  34. """创建记忆"""
  35. return ApplicationMemory.objects.create(
  36. application_id=application_id,
  37. content=content,
  38. memory_type=memory_type,
  39. metadata=metadata or {}
  40. )
  41. @staticmethod
  42. def update_memory(memory_id: str, content: str = None,
  43. memory_type: str = None, is_enabled: bool = None,
  44. relevance_score: float = None,
  45. metadata: dict = None) -> Optional[ApplicationMemory]:
  46. """更新记忆"""
  47. memory = ApplicationMemoryService.get_memory(memory_id)
  48. if not memory:
  49. return None
  50. if content is not None:
  51. memory.content = content
  52. if memory_type is not None:
  53. memory.memory_type = memory_type
  54. if is_enabled is not None:
  55. memory.is_enabled = is_enabled
  56. if relevance_score is not None:
  57. memory.relevance_score = relevance_score
  58. if metadata is not None:
  59. memory.metadata = metadata
  60. memory.save()
  61. return memory
  62. @staticmethod
  63. def delete_memory(memory_id: str) -> bool:
  64. """删除记忆"""
  65. memory = ApplicationMemoryService.get_memory(memory_id)
  66. if not memory:
  67. return False
  68. memory.delete()
  69. return True
  70. @staticmethod
  71. def batch_delete_memories(memory_ids: List[str]) -> int:
  72. """批量删除记忆"""
  73. return ApplicationMemory.objects.filter(id__in=memory_ids).delete()[0]
  74. @staticmethod
  75. def search_memories(application_id: str, query: str,
  76. limit: int = 10) -> List[ApplicationMemory]:
  77. """搜索记忆(简单文本匹配)"""
  78. return list(
  79. ApplicationMemory.objects.filter(
  80. application_id=application_id,
  81. is_enabled=True,
  82. content__icontains=query
  83. ).order_by('-relevance_score', '-create_time')[:limit]
  84. )
  85. @staticmethod
  86. def get_memory_context(application_id: str, query: str = None,
  87. max_tokens: int = 2000) -> str:
  88. """获取记忆上下文(用于注入到对话 prompt)"""
  89. memories = ApplicationMemoryService.get_memories(
  90. application_id=application_id,
  91. is_enabled=True
  92. )
  93. if query:
  94. # 如果有查询,优先返回相关记忆
  95. relevant = ApplicationMemoryService.search_memories(
  96. application_id=application_id,
  97. query=query
  98. )
  99. if relevant:
  100. memories = relevant
  101. context_parts = []
  102. current_tokens = 0
  103. for memory in memories:
  104. # 简单估算 token 数(中文约 1.5 字符/token)
  105. estimated_tokens = len(memory.content) / 1.5
  106. if current_tokens + estimated_tokens > max_tokens:
  107. break
  108. context_parts.append(memory.content)
  109. current_tokens += estimated_tokens
  110. return "\n".join(context_parts)