generate.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import traceback
  2. from celery_once import QueueOnce
  3. from django.db.models import QuerySet
  4. from django.db.models.functions import Reverse, Substr
  5. from django.utils.translation import gettext_lazy as _
  6. from langchain_core.messages import HumanMessage
  7. from common.config.embedding_config import ModelManage
  8. from common.event.listener_manage import ListenerManagement
  9. from common.utils.logger import maxkb_logger
  10. from common.utils.page_utils import page, page_desc
  11. from knowledge.models import Paragraph, Document, Status, TaskType, State
  12. from knowledge.task.handler import save_problem
  13. from models_provider.models import Model
  14. from models_provider.tools import get_model
  15. from ops import celery_app
  16. def get_llm_model(model_id, model_params_setting=None):
  17. model = QuerySet(Model).filter(id=model_id).first()
  18. return ModelManage.get_model(model_id, lambda _id: get_model(model, **(model_params_setting or {})))
  19. def generate_problem_by_paragraph(paragraph, llm_model, prompt):
  20. try:
  21. ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
  22. State.STARTED)
  23. res = llm_model.invoke(
  24. [HumanMessage(content=prompt.replace('{data}', paragraph.content).replace('{title}', paragraph.title))])
  25. if (res.content is None) or (len(res.content) == 0):
  26. return
  27. problems = res.content.split('\n')
  28. for problem in problems:
  29. save_problem(paragraph.knowledge_id, paragraph.document_id, paragraph.id, problem)
  30. ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
  31. State.SUCCESS)
  32. except Exception as e:
  33. ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
  34. State.FAILURE)
  35. def get_generate_problem(llm_model, prompt, post_apply=lambda: None, is_the_task_interrupted=lambda: False):
  36. def generate_problem(paragraph_list):
  37. for paragraph in paragraph_list:
  38. if is_the_task_interrupted():
  39. return
  40. generate_problem_by_paragraph(paragraph, llm_model, prompt)
  41. post_apply()
  42. return generate_problem
  43. def get_is_the_task_interrupted(document_id):
  44. def is_the_task_interrupted():
  45. document = QuerySet(Document).filter(id=document_id).first()
  46. if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
  47. return True
  48. return False
  49. return is_the_task_interrupted
  50. @celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']},
  51. name='celery:generate_related_by_knowledge')
  52. def generate_related_by_knowledge_id(knowledge_id, model_id, model_params_setting, prompt, state_list=None):
  53. document_list = QuerySet(Document).filter(knowledge_id=knowledge_id)
  54. for document in document_list:
  55. try:
  56. generate_related_by_document_id.delay(document.id, model_id, model_params_setting, prompt, state_list)
  57. except Exception as e:
  58. pass
  59. @celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
  60. name='celery:generate_related_by_document')
  61. def generate_related_by_document_id(document_id, model_id, model_params_setting, prompt, state_list=None):
  62. if state_list is None:
  63. state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
  64. State.REVOKE.value,
  65. State.REVOKED.value, State.IGNORED.value]
  66. try:
  67. is_the_task_interrupted = get_is_the_task_interrupted(document_id)
  68. if is_the_task_interrupted():
  69. return
  70. ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
  71. TaskType.GENERATE_PROBLEM,
  72. State.STARTED)
  73. llm_model = get_llm_model(model_id, model_params_setting)
  74. # 生成问题函数
  75. generate_problem = get_generate_problem(llm_model, prompt,
  76. ListenerManagement.get_aggregation_document_status(
  77. document_id), is_the_task_interrupted)
  78. query_set = QuerySet(Paragraph).annotate(
  79. reversed_status=Reverse('status'),
  80. task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
  81. 1),
  82. ).filter(task_type_status__in=state_list, document_id=document_id)
  83. page_desc(query_set, 10, generate_problem, is_the_task_interrupted)
  84. except Exception as e:
  85. maxkb_logger.error(f'根据文档生成问题:{document_id}出现错误{str(e)}{traceback.format_exc()}')
  86. maxkb_logger.error(_('Generate issue based on document: {document_id} error {error}{traceback}').format(
  87. document_id=document_id, error=str(e), traceback=traceback.format_exc()))
  88. finally:
  89. ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)
  90. maxkb_logger.info(_('End--->Generate problem: {document_id}').format(document_id=document_id))
  91. @celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']},
  92. name='celery:generate_related_by_paragraph_list')
  93. def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, model_params_setting, prompt):
  94. try:
  95. is_the_task_interrupted = get_is_the_task_interrupted(document_id)
  96. if is_the_task_interrupted():
  97. ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
  98. TaskType.GENERATE_PROBLEM,
  99. State.REVOKED)
  100. return
  101. ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
  102. TaskType.GENERATE_PROBLEM,
  103. State.STARTED)
  104. llm_model = get_llm_model(model_id, model_params_setting)
  105. # 生成问题函数
  106. generate_problem = get_generate_problem(llm_model, prompt, ListenerManagement.get_aggregation_document_status(
  107. document_id))
  108. def is_the_task_interrupted():
  109. document = QuerySet(Document).filter(id=document_id).first()
  110. if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
  111. return True
  112. return False
  113. page(QuerySet(Paragraph).filter(id__in=paragraph_id_list), 10, generate_problem, is_the_task_interrupted)
  114. finally:
  115. ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)