chat.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: chat.py
  6. @date:2025/6/9 11:23
  7. @desc:
  8. """
  9. import json
  10. import os
  11. from gettext import gettext
  12. from typing import List, Dict
  13. import uuid_utils.compat as uuid
  14. from django.db.models import QuerySet
  15. from django.utils.translation import gettext_lazy as _
  16. from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
  17. from rest_framework import serializers
  18. from application.chat_pipeline.pipeline_manage import PipelineManage
  19. from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
  20. from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep
  21. from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
  22. BaseGenerateHumanMessageStep
  23. from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
  24. from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
  25. from application.flow.common import Answer, Workflow
  26. from application.flow.i_step_node import WorkFlowPostHandler
  27. from application.flow.tools import to_stream_response_simple
  28. from application.flow.workflow_manage import WorkflowManage
  29. from application.models import Application, ApplicationTypeChoices, \
  30. ChatUserType, ApplicationChatUserStats, ApplicationAccessToken, ChatRecord, Chat, ApplicationVersion
  31. from application.serializers.application import ApplicationOperateSerializer
  32. from application.serializers.common import ChatInfo
  33. from common.database_model_manage.database_model_manage import DatabaseModelManage
  34. from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed, ChatException
  35. from common.handle.base_to_response import BaseToResponse
  36. from common.handle.impl.response.openai_to_response import OpenaiToResponse
  37. from common.handle.impl.response.system_to_response import SystemToResponse
  38. from common.utils.common import flat_map, get_file_content, is_valid_uuid
  39. from knowledge.models import Document, Paragraph
  40. from maxkb.conf import PROJECT_DIR
  41. from models_provider.models import Model, Status
  42. from models_provider.tools import get_model_instance_by_model_workspace_id
  43. from system_manage.models.resource_mapping import ResourceMapping
  44. class ChatMessagesSerializers(serializers.Serializer):
  45. role = serializers.CharField(required=True, label=_("Role"))
  46. content = serializers.CharField(required=True, label=_("Content"))
  47. class GeneratePromptSerializers(serializers.Serializer):
  48. prompt = serializers.CharField(required=True, label=_("Prompt template"))
  49. messages = serializers.ListSerializer(child=ChatMessagesSerializers(), required=True, label=_("Chat context"))
  50. def is_valid(self, *, raise_exception=False):
  51. super().is_valid(raise_exception=True)
  52. messages = self.data.get("messages")
  53. if len(messages) > 30:
  54. raise AppApiException(400, _("Too many messages"))
  55. for index in range(len(messages)):
  56. role = messages[index].get('role')
  57. if role == 'ai' and index % 2 != 1:
  58. raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct."))
  59. if role == 'user' and index % 2 != 0:
  60. raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct."))
  61. if role not in ['user', 'ai']:
  62. raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct."))
  63. class ChatMessageSerializers(serializers.Serializer):
  64. message = serializers.CharField(required=True, label=_("User Questions"))
  65. stream = serializers.BooleanField(required=True,
  66. label=_("Is the answer in streaming mode"))
  67. re_chat = serializers.BooleanField(required=True, label=_("Do you want to reply again"))
  68. chat_record_id = serializers.UUIDField(required=False, allow_null=True,
  69. label=_("Conversation record id"))
  70. node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
  71. label=_("Node id"))
  72. runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
  73. label=_("Runtime node id"))
  74. node_data = serializers.DictField(required=False, allow_null=True,
  75. label=_("Node parameters"))
  76. form_data = serializers.DictField(required=False, label=_("Global variables"))
  77. image_list = serializers.ListField(required=False, label=_("picture"))
  78. document_list = serializers.ListField(required=False, label=_("document"))
  79. audio_list = serializers.ListField(required=False, label=_("Audio"))
  80. other_list = serializers.ListField(required=False, label=_("Other"))
  81. child_node = serializers.DictField(required=False, allow_null=True,
  82. label=_("Child Nodes"))
  83. def get_post_handler(chat_info: ChatInfo):
  84. class PostHandler(PostResponseHandler):
  85. def handler(self,
  86. chat_id,
  87. chat_record_id,
  88. paragraph_list: List[Paragraph],
  89. problem_text: str,
  90. answer_text,
  91. manage: PipelineManage,
  92. step: BaseChatStep,
  93. padding_problem_text: str = None,
  94. **kwargs):
  95. answer_list = [[Answer(answer_text, 'ai-chat-node', 'ai-chat-node', 'ai-chat-node', {}, 'ai-chat-node',
  96. kwargs.get('reasoning_content', '')).to_dict()]]
  97. chat_record = ChatRecord(id=chat_record_id,
  98. chat_id=chat_id,
  99. problem_text=problem_text,
  100. answer_text=answer_text,
  101. details=manage.get_details(),
  102. message_tokens=manage.context['message_tokens'],
  103. answer_tokens=manage.context['answer_tokens'],
  104. answer_text_list=answer_list,
  105. run_time=manage.context['run_time'],
  106. index=len(chat_info.chat_record_list) + 1,
  107. ip_address=chat_info.ip_address,
  108. source=chat_info.source
  109. )
  110. chat_info.append_chat_record(chat_record)
  111. # 重新设置缓存
  112. chat_info.set_cache()
  113. return PostHandler()
  114. class DebugChatSerializers(serializers.Serializer):
  115. chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
  116. def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToResponse()):
  117. self.is_valid(raise_exception=True)
  118. chat_id = self.data.get('chat_id')
  119. chat_info: ChatInfo = ChatInfo.get_cache(chat_id)
  120. application = QuerySet(Application).filter(id=chat_info.application_id).first()
  121. chat_info.application = application
  122. return ChatSerializers(data={
  123. 'chat_id': chat_id, "chat_user_id": chat_info.chat_user_id,
  124. "chat_user_type": chat_info.chat_user_type,
  125. "application_id": chat_info.application.id, "debug": True
  126. }).chat(instance, base_to_response)
  127. SYSTEM_ROLE = get_file_content(os.path.join(PROJECT_DIR, "apps", "chat", 'template', 'generate_prompt_system'))
  128. class PromptGenerateSerializer(serializers.Serializer):
  129. workspace_id = serializers.CharField(required=False, label=_('Workspace ID'))
  130. model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model"))
  131. application_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Application"))
  132. def is_valid(self, *, raise_exception=False):
  133. super().is_valid(raise_exception=True)
  134. workspace_id = self.data.get('workspace_id')
  135. query_set = QuerySet(Application).filter(id=self.data.get('application_id'))
  136. if workspace_id:
  137. query_set = query_set.filter(workspace_id=workspace_id)
  138. application = query_set.first()
  139. if application is None:
  140. raise AppApiException(500, _('Application id does not exist'))
  141. return application
  142. def generate_prompt(self, instance: dict):
  143. application = self.is_valid(raise_exception=True)
  144. GeneratePromptSerializers(data=instance).is_valid(raise_exception=True)
  145. workspace_id = self.data.get('workspace_id')
  146. model_id = self.data.get('model_id')
  147. prompt = instance.get('prompt')
  148. messages = instance.get('messages')
  149. message = messages[-1]['content']
  150. q = prompt.replace("{userInput}", message)
  151. messages[-1]['content'] = q
  152. SUPPORTED_MODEL_TYPES = ["LLM", "IMAGE"]
  153. model_exist = QuerySet(Model).filter(
  154. id=model_id,
  155. model_type__in=SUPPORTED_MODEL_TYPES
  156. ).exists()
  157. if not model_exist:
  158. raise Exception(_("Model does not exists or is not an LLM model"))
  159. def process():
  160. model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id,
  161. **application.model_params_setting)
  162. try:
  163. for r in model.stream([SystemMessage(content=SYSTEM_ROLE),
  164. *[HumanMessage(content=m.get('content')) if m.get(
  165. 'role') == 'user' else AIMessage(
  166. content=m.get('content')) for m in messages]]):
  167. yield 'data: ' + json.dumps({'content': r.content}) + '\n\n'
  168. except Exception as e:
  169. yield 'data: ' + json.dumps({'error': str(e)}) + '\n\n'
  170. return to_stream_response_simple(process())
  171. class OpenAIMessage(serializers.Serializer):
  172. content = serializers.CharField(required=True, label=_('content'))
  173. role = serializers.CharField(required=True, label=_('Role'))
  174. class OpenAIInstanceSerializer(serializers.Serializer):
  175. messages = serializers.ListField(child=OpenAIMessage())
  176. chat_id = serializers.UUIDField(required=False, label=_("Conversation ID"))
  177. re_chat = serializers.BooleanField(required=False, label=_("Regenerate"))
  178. stream = serializers.BooleanField(required=False, label=_("Streaming Output"))
  179. class OpenAIChatSerializer(serializers.Serializer):
  180. application_id = serializers.UUIDField(required=True, label=_("Application ID"))
  181. chat_user_id = serializers.CharField(required=True, label=_("Client id"))
  182. chat_user_type = serializers.CharField(required=True, label=_("Client Type"))
  183. ip_address = serializers.CharField(required=False, label=_("IP Address"))
  184. source = serializers.JSONField(required=False, label=_("Source"))
  185. @staticmethod
  186. def get_message(instance):
  187. return instance.get('messages')[-1].get('content')
  188. @staticmethod
  189. def generate_chat(chat_id, application_id, message, chat_user_id, chat_user_type, ip_address, source):
  190. if chat_id is None:
  191. chat_id = str(uuid.uuid1())
  192. chat_info = ChatInfo(chat_id, chat_user_id, chat_user_type, ip_address, source, [], [],
  193. application_id)
  194. chat_info.set_cache()
  195. else:
  196. chat_info = ChatInfo.get_cache(chat_id)
  197. if chat_info is None:
  198. open_chat = ChatSerializers(data={
  199. 'chat_id': chat_id,
  200. 'chat_user_id': chat_user_id,
  201. 'chat_user_type': chat_user_type,
  202. 'application_id': application_id,
  203. 'ip_address': ip_address,
  204. 'source': source,
  205. })
  206. open_chat.is_valid(raise_exception=True)
  207. chat_info = open_chat.re_open_chat(chat_id)
  208. chat_info.set_cache()
  209. return chat_id
  210. def chat(self, instance: Dict, with_valid=True):
  211. if with_valid:
  212. self.is_valid(raise_exception=True)
  213. OpenAIInstanceSerializer(data=instance).is_valid(raise_exception=True)
  214. chat_id = instance.get('chat_id')
  215. message = self.get_message(instance)
  216. re_chat = instance.get('re_chat', False)
  217. stream = instance.get('stream', False)
  218. application_id = self.data.get('application_id')
  219. chat_user_id = self.data.get('chat_user_id')
  220. chat_user_type = self.data.get('chat_user_type')
  221. ip_address = self.data.get('ip_address')
  222. source = self.data.get('source')
  223. chat_id = self.generate_chat(chat_id, application_id, message, chat_user_id, chat_user_type, ip_address, source)
  224. return ChatSerializers(
  225. data={
  226. 'chat_id': chat_id,
  227. 'chat_user_id': chat_user_id,
  228. 'chat_user_type': chat_user_type,
  229. 'application_id': application_id,
  230. 'ip_address': ip_address,
  231. 'source': source,
  232. }
  233. ).chat({'message': message,
  234. 're_chat': re_chat,
  235. 'stream': stream,
  236. 'form_data': instance.get('form_data', {}),
  237. 'image_list': instance.get('image_list', []),
  238. 'document_list': instance.get('document_list', []),
  239. 'audio_list': instance.get('audio_list', []),
  240. 'other_list': instance.get('other_list', [])},
  241. base_to_response=OpenaiToResponse())
  242. class ChatSerializers(serializers.Serializer):
  243. chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
  244. chat_user_id = serializers.CharField(required=True, label=_("Client id"))
  245. chat_user_type = serializers.CharField(required=True, label=_("Client Type"))
  246. application_id = serializers.UUIDField(required=True, allow_null=True,
  247. label=_("Application ID"))
  248. debug = serializers.BooleanField(required=False, label=_("Debug"))
  249. ip_address = serializers.CharField(required=False, label=_("IP Address"), allow_null=True, allow_blank=True)
  250. source = serializers.JSONField(required=False, label=_("Source"))
  251. def is_valid_application_workflow(self, *, raise_exception=False):
  252. self.is_valid_intraday_access_num()
  253. def is_valid_chat_id(self, chat_info: ChatInfo):
  254. if self.data.get('application_id') is not None and self.data.get('application_id') != str(
  255. chat_info.application_id):
  256. raise ChatException(500, _("Conversation does not exist"))
  257. def is_valid_intraday_access_num(self):
  258. if not self.data.get('debug') and [ChatUserType.ANONYMOUS_USER.value,
  259. ChatUserType.CHAT_USER.value].__contains__(
  260. self.data.get('chat_user_type')):
  261. access_client = QuerySet(ApplicationChatUserStats).filter(chat_user_id=self.data.get('chat_user_id'),
  262. application_id=self.data.get(
  263. 'application_id')).first()
  264. if access_client is None:
  265. access_client = ApplicationChatUserStats(chat_user_id=self.data.get('chat_user_id'),
  266. chat_user_type=self.data.get('chat_user_type'),
  267. application_id=self.data.get('application_id'),
  268. access_num=0,
  269. intraday_access_num=0)
  270. access_client.save()
  271. application_access_token = QuerySet(ApplicationAccessToken).filter(
  272. application_id=self.data.get('application_id')).first()
  273. if application_access_token.access_num <= access_client.intraday_access_num:
  274. raise AppChatNumOutOfBoundsFailed(1002, _("The number of visits exceeds today's visits"))
  275. def is_valid_application_simple(self, *, chat_info: ChatInfo, raise_exception=False):
  276. self.is_valid_intraday_access_num()
  277. model_id = chat_info.application.model_id
  278. if model_id is None:
  279. return chat_info
  280. model = QuerySet(Model).filter(id=model_id).first()
  281. if model is None:
  282. return chat_info
  283. if model.status == Status.ERROR:
  284. raise ChatException(500, _("The current model is not available"))
  285. if model.status == Status.DOWNLOAD:
  286. raise ChatException(500, _("The model is downloading, please try again later"))
  287. return chat_info
  288. def chat_simple(self, chat_info: ChatInfo, instance, base_to_response):
  289. message = instance.get('message')
  290. re_chat = instance.get('re_chat')
  291. stream = instance.get('stream')
  292. chat_user_id = self.data.get('chat_user_id')
  293. chat_user_type = self.data.get('chat_user_type')
  294. ip_address = self.data.get('ip_address')
  295. source = self.data.get('source')
  296. form_data = instance.get("form_data")
  297. chat_record_id = instance.get('chat_record_id')
  298. pipeline_manage_builder = PipelineManage.builder()
  299. # 如果开启了问题优化,则添加上问题优化步骤
  300. if chat_info.application.problem_optimization:
  301. pipeline_manage_builder.append_step(BaseResetProblemStep)
  302. # 构建流水线管理器
  303. pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
  304. .append_step(BaseGenerateHumanMessageStep)
  305. .append_step(BaseChatStep)
  306. .add_base_to_response(base_to_response)
  307. .add_debug(self.data.get('debug', False))
  308. .build())
  309. exclude_paragraph_id_list = []
  310. # 相同问题是否需要排除已经查询到的段落
  311. if re_chat:
  312. paragraph_id_list = flat_map(
  313. [[paragraph.get('id') for paragraph in chat_record.details['search_step']['paragraph_list']] for
  314. chat_record in chat_info.chat_record_list if
  315. chat_record.problem_text == message and 'search_step' in chat_record.details and 'paragraph_list' in
  316. chat_record.details['search_step']])
  317. exclude_paragraph_id_list = list(set(paragraph_id_list))
  318. # 构建运行参数
  319. params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
  320. chat_user_id, chat_user_type, ip_address, source, stream,
  321. form_data)
  322. if chat_record_id:
  323. params['chat_record_id'] = chat_record_id
  324. chat_info.set_chat(message)
  325. # 运行流水线作业
  326. pipeline_message.run(params)
  327. return pipeline_message.context['chat_result']
  328. @staticmethod
  329. def get_chat_record(chat_info, chat_record_id):
  330. if chat_info is not None:
  331. chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
  332. str(chat_record.id) == str(chat_record_id)]
  333. if chat_record_list is not None and len(chat_record_list):
  334. return chat_record_list[-1]
  335. chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_info.chat_id).first()
  336. if chat_record is None:
  337. if not is_valid_uuid(chat_record_id):
  338. raise ChatException(500, _("Conversation record does not exist"))
  339. chat_record = QuerySet(ChatRecord).filter(id=chat_record_id).first()
  340. return chat_record
  341. def chat_work_flow(self, chat_info: ChatInfo, instance: dict, base_to_response):
  342. message = instance.get('message')
  343. re_chat = instance.get('re_chat')
  344. stream = instance.get('stream')
  345. chat_user_id = self.data.get("chat_user_id")
  346. chat_user_type = self.data.get('chat_user_type')
  347. ip_address = self.data.get('ip_address')
  348. source = self.data.get('source')
  349. form_data = instance.get('form_data')
  350. image_list = instance.get('image_list')
  351. video_list = instance.get('video_list')
  352. document_list = instance.get('document_list')
  353. audio_list = instance.get('audio_list')
  354. other_list = instance.get('other_list')
  355. workspace_id = chat_info.application.workspace_id
  356. chat_record_id = instance.get('chat_record_id')
  357. debug = self.data.get('debug', False)
  358. chat_record = None
  359. history_chat_record = chat_info.chat_record_list
  360. if chat_record_id is not None:
  361. chat_record = self.get_chat_record(chat_info, chat_record_id)
  362. if chat_record:
  363. history_chat_record = [r for r in chat_info.chat_record_list if str(r.id) != chat_record_id]
  364. work_flow = chat_info.application.work_flow
  365. work_flow_manage = WorkflowManage(Workflow.new_instance(work_flow),
  366. {'history_chat_record': history_chat_record, 'question': message,
  367. 'chat_id': chat_info.chat_id, 'chat_record_id': str(
  368. uuid.uuid7()) if chat_record_id is None else str(chat_record_id),
  369. 'stream': stream,
  370. 're_chat': re_chat,
  371. 'chat_user_id': chat_user_id,
  372. 'chat_user_type': chat_user_type,
  373. 'ip_address': ip_address,
  374. 'source': source,
  375. 'workspace_id': workspace_id,
  376. 'debug': debug,
  377. 'chat_user': chat_info.get_chat_user(),
  378. 'chat_user_group': chat_info.get_chat_user_group(),
  379. 'application_id': str(chat_info.application_id)},
  380. WorkFlowPostHandler(chat_info),
  381. base_to_response, form_data, image_list, document_list, audio_list,
  382. video_list,
  383. other_list,
  384. instance.get('runtime_node_id'),
  385. instance.get('node_data'), chat_record, instance.get('child_node'))
  386. chat_info.set_chat(message)
  387. r = work_flow_manage.run()
  388. return r
  389. def is_valid_chat_user(self):
  390. chat_user_id = self.data.get('chat_user_id')
  391. application_id = self.data.get('application_id')
  392. chat_user_type = self.data.get('chat_user_type')
  393. is_auth_chat_user = DatabaseModelManage.get_model("is_auth_chat_user")
  394. application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application_id).first()
  395. if application_access_token and application_access_token.authentication and application_access_token.authentication_value.get(
  396. 'type') == 'login':
  397. if chat_user_type == ChatUserType.ANONYMOUS_USER.value:
  398. raise ChatException(500, _("The chat user is not authorized."))
  399. if chat_user_type == ChatUserType.CHAT_USER.value and is_auth_chat_user:
  400. is_auth = is_auth_chat_user(chat_user_id, application_id)
  401. if not is_auth:
  402. raise ChatException(500, _("The chat user is not authorized."))
  403. def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToResponse()):
  404. super().is_valid(raise_exception=True)
  405. ChatMessageSerializers(data=instance).is_valid(raise_exception=True)
  406. chat_info = self.get_chat_info()
  407. chat_info.get_application()
  408. chat_info.get_chat_user(asker=(instance.get('form_data') or {}).get('asker'))
  409. self.is_valid_chat_id(chat_info)
  410. if not self.data.get('debug'):
  411. self.is_valid_chat_user()
  412. if chat_info.application.type == ApplicationTypeChoices.SIMPLE:
  413. self.is_valid_application_simple(raise_exception=True, chat_info=chat_info)
  414. return self.chat_simple(chat_info, instance, base_to_response)
  415. else:
  416. self.is_valid_application_workflow(raise_exception=True)
  417. return self.chat_work_flow(chat_info, instance, base_to_response)
  418. def get_chat_info(self):
  419. self.is_valid(raise_exception=True)
  420. chat_id = self.data.get('chat_id')
  421. chat_info: ChatInfo = ChatInfo.get_cache(chat_id)
  422. if chat_info is None:
  423. chat_info: ChatInfo = self.re_open_chat(chat_id)
  424. chat_info.set_cache()
  425. return chat_info
  426. def re_open_chat(self, chat_id: str):
  427. chat = QuerySet(Chat).filter(id=chat_id).first()
  428. if chat is None:
  429. raise ChatException(500, _("Conversation does not exist"))
  430. application = QuerySet(Application).filter(id=chat.application_id).first()
  431. if application is None:
  432. raise ChatException(500, _("Application does not exist"))
  433. application_version = QuerySet(ApplicationVersion).filter(application_id=application.id).order_by(
  434. '-create_time')[0:1].first()
  435. if application_version is None:
  436. raise ChatException(500, _("The application has not been published. Please use it after publishing."))
  437. if application.type == ApplicationTypeChoices.SIMPLE:
  438. return self.re_open_chat_simple(chat_id, application)
  439. else:
  440. return self.re_open_chat_work_flow(chat_id, application)
  441. def re_open_chat_simple(self, chat_id, application):
  442. # 数据集id列表
  443. knowledge_id_list = [str(row.target_id) for row in
  444. QuerySet(ResourceMapping).filter(source_id=str(application.id),
  445. source_type='APPLICATION',
  446. target_type='KNOWLEDGE')]
  447. # 需要排除的文档
  448. exclude_document_id_list = [str(document.id) for document in
  449. QuerySet(Document).filter(
  450. knowledge_id__in=knowledge_id_list,
  451. is_active=False)]
  452. chat_info = ChatInfo(chat_id, self.data.get('chat_user_id'), self.data.get('chat_user_type'),
  453. self.data.get('ip_address'),
  454. self.data.get('source'), knowledge_id_list,
  455. exclude_document_id_list, application.id)
  456. chat_record_list = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5])
  457. chat_record_list.sort(key=lambda r: r.create_time)
  458. for chat_record in chat_record_list:
  459. chat_info.chat_record_list.append(chat_record)
  460. return chat_info
  461. def re_open_chat_work_flow(self, chat_id, application):
  462. chat_info = ChatInfo(chat_id, self.data.get('chat_user_id'), self.data.get('chat_user_type'),
  463. self.data.get('ip_address'),
  464. self.data.get('source'), [], [],
  465. application.id)
  466. chat_record_list = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5])
  467. chat_record_list.sort(key=lambda r: r.create_time)
  468. for chat_record in chat_record_list:
  469. chat_info.chat_record_list.append(chat_record)
  470. return chat_info
  471. class OpenChatSerializers(serializers.Serializer):
  472. workspace_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_("Workspace ID"))
  473. application_id = serializers.UUIDField(required=True)
  474. chat_user_id = serializers.CharField(required=True, label=_("Client id"))
  475. chat_user_type = serializers.CharField(required=True, label=_("Client Type"))
  476. debug = serializers.BooleanField(required=True, label=_("Debug"))
  477. ip_address = serializers.CharField(required=False, label=_("IP Address"))
  478. source = serializers.JSONField(required=False, label=_("Source"))
  479. def is_valid(self, *, raise_exception=False):
  480. super().is_valid(raise_exception=True)
  481. workspace_id = self.data.get('workspace_id')
  482. application_id = self.data.get('application_id')
  483. query_set = QuerySet(Application).filter(id=application_id)
  484. if workspace_id:
  485. query_set = query_set.filter(workspace_id=workspace_id)
  486. if not query_set.exists():
  487. raise AppApiException(500, gettext('Application does not exist'))
  488. def open(self):
  489. self.is_valid(raise_exception=True)
  490. application_id = self.data.get('application_id')
  491. application = QuerySet(Application).get(id=application_id)
  492. debug = self.data.get("debug")
  493. if not debug:
  494. application_version = QuerySet(ApplicationVersion).filter(application_id=application_id).order_by(
  495. '-create_time')[0:1].first()
  496. if application_version is None:
  497. raise AppApiException(500,
  498. _("The application has not been published. Please use it after publishing."))
  499. if application.type == ApplicationTypeChoices.SIMPLE:
  500. return self.open_simple(application)
  501. else:
  502. return self.open_work_flow(application)
  503. def open_work_flow(self, application):
  504. self.is_valid(raise_exception=True)
  505. application_id = self.data.get('application_id')
  506. chat_user_id = self.data.get("chat_user_id")
  507. chat_user_type = self.data.get("chat_user_type")
  508. ip_address = self.data.get("ip_address")
  509. source = self.data.get("source")
  510. debug = self.data.get("debug")
  511. chat_id = str(uuid.uuid7())
  512. ChatInfo(chat_id, chat_user_id, chat_user_type, ip_address, source, [],
  513. [],
  514. application_id, debug).set_cache()
  515. return chat_id
  516. def open_simple(self, application):
  517. application_id = self.data.get('application_id')
  518. chat_user_id = self.data.get("chat_user_id")
  519. chat_user_type = self.data.get("chat_user_type")
  520. ip_address = self.data.get("ip_address")
  521. source = self.data.get("source")
  522. debug = self.data.get("debug")
  523. knowledge_id_list = [str(row.target_id) for row in
  524. QuerySet(ResourceMapping).filter(source_id=str(application_id),
  525. source_type='APPLICATION',
  526. target_type='KNOWLEDGE')]
  527. chat_id = str(uuid.uuid7())
  528. ChatInfo(chat_id, chat_user_id, chat_user_type, ip_address, source, knowledge_id_list,
  529. [str(document.id) for document in
  530. QuerySet(Document).filter(
  531. knowledge_id__in=knowledge_id_list,
  532. is_active=False)],
  533. application_id,
  534. debug=debug).set_cache()
  535. return chat_id
  536. class TextToSpeechSerializers(serializers.Serializer):
  537. application_id = serializers.UUIDField(required=True, label=_("Application ID"))
  538. def text_to_speech(self, instance):
  539. self.is_valid(raise_exception=True)
  540. application_id = self.data.get('application_id')
  541. application = QuerySet(Application).filter(id=application_id).first()
  542. return ApplicationOperateSerializer(
  543. data={'application_id': application_id,
  544. 'user_id': application.user_id}).text_to_speech(instance, False)
  545. class SpeechToTextSerializers(serializers.Serializer):
  546. application_id = serializers.UUIDField(required=True, label=_("Application ID"))
  547. def speech_to_text(self, instance):
  548. self.is_valid(raise_exception=True)
  549. application_id = self.data.get('application_id')
  550. application = QuerySet(Application).filter(id=application_id).first()
  551. return ApplicationOperateSerializer(
  552. data={'application_id': application_id,
  553. 'user_id': application.user_id}).speech_to_text(instance, False)