common.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: common.py
  6. @date:2025/6/9 13:42
  7. @desc:
  8. """
  9. from typing import List
  10. from django.core.cache import cache
  11. from django.db.models import QuerySet
  12. from django.utils import timezone
  13. from django.utils.translation import gettext_lazy as _
  14. from application.models import Application, ChatRecord, Chat, ApplicationVersion, ChatUserType, ApplicationTypeChoices
  15. from application.serializers.application_chat import ChatCountSerializer
  16. from common.constants.cache_version import Cache_Version
  17. from common.database_model_manage.database_model_manage import DatabaseModelManage
  18. from common.exception.app_exception import ChatException
  19. from knowledge.models import Document
  20. from models_provider.models import Model
  21. from models_provider.tools import get_model_credential
  22. from system_manage.models.resource_mapping import ResourceMapping
  23. from tools.models import ToolRecord
  24. class ToolExecute:
  25. def __init__(self, tool_id: str,
  26. tool_record_id: str,
  27. workspace_id: str,
  28. source_type,
  29. source_id,
  30. debug=False):
  31. self.tool_id = tool_id
  32. self.workspace_id = workspace_id
  33. self.source_type = source_type
  34. self.source_id = source_id
  35. self.tool_record_id = tool_record_id
  36. self.debug = debug
  37. def get_record(self):
  38. if self.tool_record_id:
  39. if self.debug:
  40. return self.to_record(cache.get(Cache_Version.TOOL_WORKFLOW_EXECUTE.get_key(key=self.tool_record_id),
  41. version=Cache_Version.TOOL_WORKFLOW_EXECUTE.get_version()))
  42. else:
  43. return QuerySet(ToolRecord).filter(tool_id=self.tool_id, id=self.tool_record_id).first()
  44. return None
  45. def to_record(self, tool_record_dict):
  46. if tool_record_dict is None:
  47. return None
  48. return ToolRecord(id=tool_record_dict.get('id'),
  49. tool_id=tool_record_dict.get('tool_id'),
  50. workspace_id=tool_record_dict.get('workspace_id'),
  51. source_type=tool_record_dict.get('source_type'),
  52. source_id=tool_record_dict.get('source_id'),
  53. meta=tool_record_dict.get('meta'),
  54. state=tool_record_dict.get('state'),
  55. run_time=tool_record_dict.get('run_time'))
  56. def to_dict(self, tool_record):
  57. return {'id': tool_record.id,
  58. 'tool_id': tool_record.tool_id,
  59. 'workspace_id': tool_record.workspace_id,
  60. 'source_type': tool_record.source_type,
  61. 'source_id': tool_record.source_id,
  62. 'meta': tool_record.meta,
  63. 'state': tool_record.state,
  64. 'run_time': tool_record.run_time}
  65. def set_record(self, tool_record):
  66. cache.set(Cache_Version.TOOL_WORKFLOW_EXECUTE.get_key(key=self.tool_record_id), self.to_dict(tool_record),
  67. version=Cache_Version.TOOL_WORKFLOW_EXECUTE.get_version(),
  68. timeout=60 * 30)
  69. if not self.debug:
  70. QuerySet(ToolRecord).update_or_create(id=tool_record.id,
  71. create_defaults={'id': tool_record.id,
  72. 'tool_id': tool_record.tool_id,
  73. 'state': tool_record.state,
  74. 'workspace_id': tool_record.workspace_id,
  75. "source_type": tool_record.source_type,
  76. 'source_id': tool_record.source_id,
  77. 'meta': tool_record.meta,
  78. 'run_time': tool_record.run_time},
  79. defaults={
  80. 'workspace_id': tool_record.workspace_id,
  81. 'tool_id': tool_record.tool_id,
  82. "source_type": tool_record.source_type,
  83. 'source_id': tool_record.source_id,
  84. 'state': tool_record.state,
  85. 'meta': tool_record.meta,
  86. 'run_time': tool_record.run_time
  87. })
  88. class ChatInfo:
  89. def __init__(self,
  90. chat_id: str,
  91. chat_user_id: str,
  92. chat_user_type: str,
  93. ip_address: str,
  94. source: {},
  95. knowledge_id_list: List[str],
  96. exclude_document_id_list: list[str],
  97. application_id: str,
  98. debug=False):
  99. """
  100. :param chat_id: 对话id
  101. :param chat_user_id 对话用户id
  102. :param chat_user_type 对话用户类型
  103. :param knowledge_id_list: 知识库列表
  104. :param exclude_document_id_list: 排除的文档
  105. :param application_id 应用id
  106. :param debug 是否是调试
  107. :param ip_address: 用户ip地址
  108. :param source: 用户来源
  109. """
  110. self.chat_id = chat_id
  111. self.chat_user_id = chat_user_id
  112. self.chat_user_type = chat_user_type
  113. self.knowledge_id_list = knowledge_id_list
  114. self.exclude_document_id_list = exclude_document_id_list
  115. self.application_id = application_id
  116. self.chat_record_list: List[ChatRecord] = []
  117. self.application = None
  118. self.chat_user = None
  119. self.ip_address = ip_address
  120. self.source = source
  121. self.debug = debug
  122. @staticmethod
  123. def get_no_references_setting(knowledge_setting, model_setting):
  124. no_references_setting = knowledge_setting.get(
  125. 'no_references_setting', {
  126. 'status': 'ai_questioning',
  127. 'value': '{question}'})
  128. if no_references_setting.get('status') == 'ai_questioning':
  129. no_references_prompt = model_setting.get('no_references_prompt', '{question}')
  130. no_references_setting['value'] = no_references_prompt if len(no_references_prompt) > 0 else "{question}"
  131. return no_references_setting
  132. def get_application(self):
  133. if self.debug:
  134. application = QuerySet(Application).filter(id=self.application_id).first()
  135. if not application:
  136. raise ChatException(500, _('The application does not exist'))
  137. else:
  138. application = QuerySet(ApplicationVersion).filter(application_id=self.application_id).order_by(
  139. '-create_time')[0:1].first()
  140. if not application:
  141. raise ChatException(500, _("The application has not been published. Please use it after publishing."))
  142. if application.type == ApplicationTypeChoices.SIMPLE.value:
  143. # 数据集id列表
  144. knowledge_id_list = [str(row.target_id) for row in
  145. QuerySet(ResourceMapping).filter(source_id=self.application_id,
  146. source_type='APPLICATION',
  147. target_type='KNOWLEDGE')]
  148. # 需要排除的文档
  149. exclude_document_id_list = [str(document.id) for document in
  150. QuerySet(Document).filter(
  151. knowledge_id__in=knowledge_id_list,
  152. is_active=False)]
  153. self.knowledge_id_list = knowledge_id_list
  154. self.exclude_document_id_list = exclude_document_id_list
  155. self.application = application
  156. return application
  157. def get_chat_user(self, asker=None):
  158. if self.chat_user:
  159. return self.chat_user
  160. chat_user_model = DatabaseModelManage.get_model("chat_user")
  161. if self.chat_user_type == ChatUserType.CHAT_USER.value and chat_user_model:
  162. chat_user = QuerySet(chat_user_model).filter(id=self.chat_user_id).first()
  163. return {
  164. 'id': str(chat_user.id),
  165. 'email': chat_user.email,
  166. 'phone': chat_user.phone,
  167. 'nick_name': chat_user.nick_name,
  168. 'username': chat_user.username,
  169. 'source': chat_user.source
  170. }
  171. else:
  172. if asker:
  173. if isinstance(asker, dict):
  174. self.chat_user = asker
  175. else:
  176. self.chat_user = {'username': asker}
  177. else:
  178. self.chat_user = {'username': '游客'}
  179. return self.chat_user
  180. def get_chat_user_group(self, asker=None):
  181. chat_user = self.get_chat_user(asker=asker)
  182. chat_user_id = chat_user.get('id')
  183. if not chat_user_id:
  184. return []
  185. user_group_relation_model = DatabaseModelManage.get_model("user_group_relation")
  186. if user_group_relation_model:
  187. return [{
  188. 'id': user_group_relation.group_id,
  189. 'name': user_group_relation.group.name
  190. } for user_group_relation in
  191. QuerySet(user_group_relation_model).select_related('group').filter(user_id=chat_user_id)]
  192. return []
  193. def to_base_pipeline_manage_params(self):
  194. self.get_application()
  195. self.get_chat_user()
  196. knowledge_setting = self.application.knowledge_setting
  197. model_setting = self.application.model_setting
  198. model_id = self.application.model_id
  199. model_params_setting = None
  200. if model_id is not None:
  201. model = QuerySet(Model).filter(id=model_id).first()
  202. credential = get_model_credential(model.provider, model.model_type, model.model_name)
  203. model_params_setting = credential.get_model_params_setting_form(model.model_name).get_default_form_data()
  204. return {
  205. 'knowledge_id_list': self.knowledge_id_list,
  206. 'exclude_document_id_list': self.exclude_document_id_list,
  207. 'exclude_paragraph_id_list': [],
  208. 'top_n': 3 if knowledge_setting.get('top_n') is None else knowledge_setting.get('top_n'),
  209. 'similarity': 0.6 if knowledge_setting.get('similarity') is None else knowledge_setting.get('similarity'),
  210. 'max_paragraph_char_number': knowledge_setting.get('max_paragraph_char_number') or 5000,
  211. 'history_chat_record': self.chat_record_list,
  212. 'chat_id': self.chat_id,
  213. 'dialogue_number': self.application.dialogue_number,
  214. 'problem_optimization_prompt': self.application.problem_optimization_prompt if self.application.problem_optimization_prompt is not None and len(
  215. self.application.problem_optimization_prompt) > 0 else _(
  216. "() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag"),
  217. 'prompt': model_setting.get(
  218. 'prompt') if 'prompt' in model_setting and len(model_setting.get(
  219. 'prompt')) > 0 else Application.get_default_model_prompt(),
  220. 'system': model_setting.get(
  221. 'system', None),
  222. 'model_id': model_id,
  223. 'problem_optimization': self.application.problem_optimization,
  224. 'stream': True,
  225. 'model_setting': model_setting,
  226. 'model_params_setting': model_params_setting if self.application.model_params_setting is None or len(
  227. self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting,
  228. 'search_mode': self.application.knowledge_setting.get('search_mode') or 'embedding',
  229. 'no_references_setting': self.get_no_references_setting(self.application.knowledge_setting, model_setting),
  230. 'workspace_id': self.application.workspace_id,
  231. 'application_id': self.application_id,
  232. 'mcp_enable': self.application.mcp_enable,
  233. 'mcp_tool_ids': self.application.mcp_tool_ids,
  234. 'mcp_servers': self.application.mcp_servers,
  235. 'mcp_source': self.application.mcp_source,
  236. 'tool_enable': self.application.tool_enable,
  237. 'tool_ids': self.application.tool_ids,
  238. 'application_enable': self.application.application_enable,
  239. 'application_ids': self.application.application_ids,
  240. 'skill_tool_ids': self.application.skill_tool_ids,
  241. 'mcp_output_enable': self.application.mcp_output_enable,
  242. }
  243. def to_pipeline_manage_params(self, problem_text: str, post_response_handler,
  244. exclude_paragraph_id_list, chat_user_id: str, chat_user_type, ip_address, source,
  245. stream=True,
  246. form_data=None):
  247. if form_data is None:
  248. form_data = {}
  249. params = self.to_base_pipeline_manage_params()
  250. return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
  251. 'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'chat_user_id': chat_user_id,
  252. 'chat_user_type': chat_user_type, 'ip_address': ip_address, 'source': source, 'form_data': form_data}
  253. def set_chat(self, question):
  254. if not self.debug:
  255. if not QuerySet(Chat).filter(id=self.chat_id).exists():
  256. Chat(id=self.chat_id, application_id=self.application_id, abstract=question[0:1024],
  257. chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type,
  258. ip_address=self.ip_address, source=self.source,
  259. asker=self.get_chat_user()).save()
  260. def set_chat_variable(self, chat_context):
  261. if not self.debug:
  262. chat = QuerySet(Chat).filter(id=self.chat_id).first()
  263. if chat:
  264. chat.meta = {**(chat.meta if isinstance(chat.meta, dict) else {}), **chat_context}
  265. chat.save()
  266. else:
  267. cache.set(Cache_Version.CHAT_VARIABLE.get_key(key=self.chat_id), chat_context,
  268. version=Cache_Version.CHAT_VARIABLE.get_version(),
  269. timeout=60 * 30)
  270. def get_chat_variable(self):
  271. if not self.debug:
  272. chat = QuerySet(Chat).filter(id=self.chat_id).first()
  273. if chat:
  274. return chat.meta
  275. return {}
  276. else:
  277. return cache.get(Cache_Version.CHAT_VARIABLE.get_key(key=self.chat_id),
  278. version=Cache_Version.CHAT_VARIABLE.get_version()) or {}
  279. def append_chat_record(self, chat_record: ChatRecord):
  280. chat_record.problem_text = chat_record.problem_text[0:10240] if chat_record.problem_text is not None else ""
  281. chat_record.answer_text = chat_record.answer_text[0:40960] if chat_record.problem_text is not None else ""
  282. is_save = True
  283. # 存入缓存中
  284. for index in range(len(self.chat_record_list)):
  285. record = self.chat_record_list[index]
  286. if record.id == chat_record.id:
  287. self.chat_record_list[index] = chat_record
  288. is_save = False
  289. break
  290. if is_save:
  291. self.chat_record_list.append(chat_record)
  292. if not self.debug:
  293. if not QuerySet(Chat).filter(id=self.chat_id).exists():
  294. Chat(id=self.chat_id, application_id=self.application_id, abstract=chat_record.problem_text[0:1024],
  295. chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type,
  296. ip_address=self.ip_address, source=self.source,
  297. asker=self.get_chat_user()).save()
  298. else:
  299. QuerySet(Chat).filter(id=self.chat_id).update(update_time=timezone.now())
  300. # 插入会话记录
  301. QuerySet(ChatRecord).update_or_create(id=chat_record.id,
  302. create_defaults={'id': chat_record.id,
  303. 'chat_id': chat_record.chat_id,
  304. "vote_status": chat_record.vote_status,
  305. 'problem_text': chat_record.problem_text,
  306. 'answer_text': chat_record.answer_text,
  307. 'answer_text_list': chat_record.answer_text_list,
  308. 'message_tokens': chat_record.message_tokens,
  309. 'answer_tokens': chat_record.answer_tokens,
  310. 'const': chat_record.const,
  311. 'details': chat_record.details,
  312. 'improve_paragraph_id_list': chat_record.improve_paragraph_id_list,
  313. 'run_time': chat_record.run_time,
  314. 'source': chat_record.source,
  315. 'ip_address': chat_record.ip_address or '',
  316. 'index': chat_record.index},
  317. defaults={
  318. "vote_status": chat_record.vote_status,
  319. 'problem_text': chat_record.problem_text,
  320. 'answer_text': chat_record.answer_text,
  321. 'answer_text_list': chat_record.answer_text_list,
  322. 'message_tokens': chat_record.message_tokens,
  323. 'answer_tokens': chat_record.answer_tokens,
  324. 'const': chat_record.const,
  325. 'details': chat_record.details,
  326. 'improve_paragraph_id_list': chat_record.improve_paragraph_id_list,
  327. 'run_time': chat_record.run_time,
  328. 'index': chat_record.index,
  329. 'source': chat_record.source,
  330. 'ip_address': chat_record.ip_address or '',
  331. })
  332. ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat()
  333. def to_dict(self):
  334. return {
  335. 'chat_id': self.chat_id,
  336. 'chat_user_id': self.chat_user_id,
  337. 'chat_user_type': self.chat_user_type,
  338. 'ip_address': self.ip_address,
  339. 'source': self.source,
  340. 'knowledge_id_list': self.knowledge_id_list,
  341. 'exclude_document_id_list': self.exclude_document_id_list,
  342. 'application_id': self.application_id,
  343. 'chat_record_list': [self.chat_record_to_map(c) for c in self.chat_record_list][-20:],
  344. 'debug': self.debug
  345. }
  346. def chat_record_to_map(self, chat_record):
  347. return {'id': chat_record.id,
  348. 'chat_id': chat_record.chat_id,
  349. 'vote_status': chat_record.vote_status,
  350. 'problem_text': chat_record.problem_text,
  351. 'answer_text': chat_record.answer_text,
  352. 'answer_text_list': chat_record.answer_text_list,
  353. 'message_tokens': chat_record.message_tokens,
  354. 'answer_tokens': chat_record.answer_tokens,
  355. 'const': chat_record.const,
  356. 'details': chat_record.details,
  357. 'improve_paragraph_id_list': chat_record.improve_paragraph_id_list,
  358. 'run_time': chat_record.run_time,
  359. 'source': chat_record.source,
  360. 'ip_address': chat_record.ip_address,
  361. 'index': chat_record.index}
  362. @staticmethod
  363. def map_to_chat_record(chat_record_dict):
  364. return ChatRecord(id=chat_record_dict.get('id'),
  365. chat_id=chat_record_dict.get('chat_id'),
  366. vote_status=chat_record_dict.get('vote_status'),
  367. problem_text=chat_record_dict.get('problem_text'),
  368. answer_text=chat_record_dict.get('answer_text'),
  369. answer_text_list=chat_record_dict.get('answer_text_list'),
  370. message_tokens=chat_record_dict.get('message_tokens'),
  371. answer_tokens=chat_record_dict.get('answer_tokens'),
  372. const=chat_record_dict.get('const'),
  373. details=chat_record_dict.get('details'),
  374. improve_paragraph_id_list=chat_record_dict.get('improve_paragraph_id_list'),
  375. run_time=chat_record_dict.get('run_time'),
  376. index=chat_record_dict.get('index'),
  377. source=chat_record_dict.get('source'),
  378. ip_address=chat_record_dict.get('ip_address'))
  379. def set_cache(self):
  380. cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self.to_dict(),
  381. version=Cache_Version.CHAT_INFO.get_version(),
  382. timeout=60 * 30)
  383. @staticmethod
  384. def map_to_chat_info(chat_info_dict):
  385. c = ChatInfo(chat_info_dict.get('chat_id'), chat_info_dict.get('chat_user_id'),
  386. chat_info_dict.get('chat_user_type'), chat_info_dict.get('ip_address'),
  387. chat_info_dict.get('source'),
  388. chat_info_dict.get('knowledge_id_list'),
  389. chat_info_dict.get('exclude_document_id_list'),
  390. chat_info_dict.get('application_id'),
  391. debug=chat_info_dict.get('debug'))
  392. c.chat_record_list = [ChatInfo.map_to_chat_record(c_r) for c_r in chat_info_dict.get('chat_record_list')]
  393. return c
  394. @staticmethod
  395. def get_cache(chat_id):
  396. chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id),
  397. version=Cache_Version.CHAT_INFO.get_version())
  398. if chat_info_dict:
  399. return ChatInfo.map_to_chat_info(chat_info_dict)
  400. return None
  401. def update_resource_mapping_by_application(application_id: str, other_resource_mapping=None):
  402. from application.flow.tools import get_instance_resource, save_workflow_mapping, \
  403. application_instance_field_call_dict
  404. from system_manage.models.resource_mapping import ResourceType
  405. if other_resource_mapping is None:
  406. other_resource_mapping = []
  407. application = QuerySet(Application).filter(id=application_id).first()
  408. instance_mapping = get_instance_resource(application, ResourceType.APPLICATION, str(application.id),
  409. application_instance_field_call_dict)
  410. if application.type == 'WORK_FLOW':
  411. save_workflow_mapping(application.work_flow, ResourceType.APPLICATION, str(application_id),
  412. instance_mapping + other_resource_mapping)
  413. return
  414. else:
  415. save_workflow_mapping({}, ResourceType.APPLICATION, str(application_id),
  416. instance_mapping + other_resource_mapping)