application_chat.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: application_chat.py
  6. @date:2025/6/10 11:06
  7. @desc:
  8. """
  9. import datetime
  10. import os
  11. import re
  12. from io import BytesIO
  13. from typing import Dict
  14. import openpyxl
  15. import pytz
  16. from django.core import validators
  17. from django.db import models
  18. from django.db.models import QuerySet, Q
  19. from django.http import StreamingHttpResponse
  20. from django.utils import timezone
  21. from django.utils.translation import gettext_lazy as _, gettext
  22. from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
  23. from rest_framework import serializers
  24. from application.models import Chat, Application, ChatRecord, ChatSourceChoices
  25. from common.db.search import get_dynamics_model, native_search, native_page_search, native_page_handler
  26. from common.exception.app_exception import AppApiException
  27. from common.utils.common import get_file_content
  28. from maxkb.conf import PROJECT_DIR
  29. from maxkb.settings import TIME_ZONE, edition
  30. class ApplicationChatResponseSerializers(serializers.Serializer):
  31. id = serializers.UUIDField(required=True, label=_("chat id"))
  32. abstract = serializers.CharField(required=True, label=_("summary"))
  33. chat_user_id = serializers.UUIDField(required=True, label=_("Chat User ID"))
  34. chat_user_type = serializers.CharField(required=True, label=_("Chat User Type"))
  35. is_deleted = serializers.BooleanField(required=True, label=_("Is delete"))
  36. application_id = serializers.UUIDField(required=True, label=_("Application ID"))
  37. chat_record_count = serializers.IntegerField(required=True, label=_("Number of conversations"))
  38. star_num = serializers.IntegerField(required=True, label=_("Number of Likes"))
  39. trample_num = serializers.IntegerField(required=True, label=_("Number of thumbs-downs"))
  40. mark_sum = serializers.IntegerField(required=True, label=_("Number of tags"))
  41. class ApplicationChatRecordExportRequest(serializers.Serializer):
  42. select_ids = serializers.ListField(required=True, label=_("Chat ID List"),
  43. child=serializers.UUIDField(required=True, label=_("Chat ID")))
  44. class ApplicationChatQuerySerializers(serializers.Serializer):
  45. workspace_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_("Workspace ID"))
  46. abstract = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("summary"))
  47. username = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("username"))
  48. start_time = serializers.DateField(format='%Y-%m-%d', label=_("Start time"))
  49. end_time = serializers.DateField(format='%Y-%m-%d', label=_("End time"))
  50. application_id = serializers.UUIDField(required=True, label=_("Application ID"))
  51. min_star = serializers.IntegerField(required=False, min_value=0,
  52. label=_("Minimum number of likes"))
  53. min_trample = serializers.IntegerField(required=False, min_value=0,
  54. label=_("Minimum number of clicks"))
  55. comparer = serializers.CharField(required=False, label=_("Comparator"), validators=[
  56. validators.RegexValidator(regex=re.compile("^and|or$"),
  57. message=_("Only supports and|or"), code=500)
  58. ])
  59. def is_valid(self, *, raise_exception=False):
  60. super().is_valid(raise_exception=True)
  61. workspace_id = self.data.get('workspace_id')
  62. query_set = QuerySet(Application).filter(id=self.data.get('application_id'))
  63. if workspace_id:
  64. query_set = query_set.filter(workspace_id=workspace_id)
  65. if not query_set.exists():
  66. raise AppApiException(500, _('Application id does not exist'))
  67. def get_end_time(self):
  68. d = datetime.datetime.strptime(self.data.get('end_time'), '%Y-%m-%d').date()
  69. naive = datetime.datetime.combine(d, datetime.time.max)
  70. return timezone.make_aware(naive, timezone.get_default_timezone())
  71. def get_start_time(self):
  72. d = datetime.datetime.strptime(self.data.get('start_time'), '%Y-%m-%d').date()
  73. naive = datetime.datetime.combine(d, datetime.time.min)
  74. return timezone.make_aware(naive, timezone.get_default_timezone())
  75. def get_query_set(self, select_ids=None):
  76. end_time = self.get_end_time()
  77. start_time = self.get_start_time()
  78. query_set = QuerySet(model=get_dynamics_model(
  79. {'application_chat.application_id': models.CharField(),
  80. 'application_chat.abstract': models.CharField(),
  81. 'application_chat.asker': models.JSONField(),
  82. "star_num": models.IntegerField(),
  83. 'trample_num': models.IntegerField(),
  84. 'comparer': models.CharField(),
  85. 'application_chat.update_time': models.DateTimeField(),
  86. 'application_chat.id': models.UUIDField(),
  87. 'application_chat_record_temp.id': models.UUIDField()}))
  88. base_query_dict = {'application_chat.application_id': self.data.get("application_id"),
  89. 'application_chat.update_time__gte': start_time,
  90. 'application_chat.update_time__lte': end_time,
  91. }
  92. if 'abstract' in self.data and self.data.get('abstract') is not None:
  93. base_query_dict['application_chat.abstract__icontains'] = self.data.get('abstract')
  94. if 'username' in self.data and self.data.get('username') is not None:
  95. base_query_dict['application_chat.asker__username__icontains'] = self.data.get('username')
  96. if select_ids is not None and len(select_ids) > 0:
  97. base_query_dict['application_chat.id__in'] = select_ids
  98. base_condition = Q(**base_query_dict)
  99. min_star_query = None
  100. min_trample_query = None
  101. if 'min_star' in self.data and self.data.get('min_star') is not None:
  102. min_star_query = Q(star_num__gte=self.data.get('min_star'))
  103. if 'min_trample' in self.data and self.data.get('min_trample') is not None:
  104. min_trample_query = Q(trample_num__gte=self.data.get('min_trample'))
  105. if min_star_query is not None and min_trample_query is not None:
  106. if self.data.get(
  107. 'comparer') is not None and self.data.get('comparer') == 'or':
  108. condition = base_condition & (min_star_query | min_trample_query)
  109. else:
  110. condition = base_condition & (min_star_query & min_trample_query)
  111. elif min_star_query is not None:
  112. condition = base_condition & min_star_query
  113. elif min_trample_query is not None:
  114. condition = base_condition & min_trample_query
  115. else:
  116. condition = base_condition
  117. return {
  118. 'default_queryset': query_set.filter(condition).order_by("-application_chat.update_time")
  119. }
  120. def list(self, with_valid=True):
  121. if with_valid:
  122. self.is_valid(raise_exception=True)
  123. return native_search(self.get_query_set(), select_string=get_file_content(
  124. os.path.join(PROJECT_DIR, "apps", "application", 'sql',
  125. ('list_application_chat_ee.sql' if ['PE', 'EE'].__contains__(
  126. edition) else 'list_application_chat.sql'))),
  127. with_table_name=False)
  128. @staticmethod
  129. def paragraph_list_to_string(paragraph_list):
  130. return "\n**********\n".join(
  131. [f"{paragraph.get('title')}:\n{paragraph.get('content')}" for paragraph in
  132. paragraph_list] if paragraph_list is not None else '')
  133. @staticmethod
  134. def to_row(row: Dict):
  135. details = row.get('details') or {}
  136. padding_problem_text = ' '.join((node.get("answer", "") or "") for key, node in details.items() if
  137. node.get("type") == 'question-node')
  138. search_dataset_node_list = [(key, node) for key, node in details.items() if
  139. node.get("type") == 'search-dataset-node' or node.get(
  140. "step_type") == 'search_step' or node.get("type") == 'search-knowledge-node']
  141. reference_paragraph_len = '\n'.join([str(len(node.get('paragraph_list',
  142. []))) if key == 'search_step' else node.get(
  143. 'name') + ':' + str(
  144. len(node.get('paragraph_list', [])) if node.get('paragraph_list', []) is not None else '0') for
  145. key, node in search_dataset_node_list])
  146. reference_paragraph = '\n----------\n'.join(
  147. [ApplicationChatQuerySerializers.paragraph_list_to_string(node.get('paragraph_list',
  148. [])) if key == 'search_step' else node.get(
  149. 'name') + ':\n' + ApplicationChatQuerySerializers.paragraph_list_to_string(node.get('paragraph_list',
  150. [])) for
  151. key, node in search_dataset_node_list])
  152. improve_paragraph_list = row.get('improve_paragraph_list') or []
  153. vote_status_map = {'-1': '未投票', '0': '赞同', '1': '反对'}
  154. vote_reason_map = {'accurate': gettext('accurate'), 'complete': gettext('complete'),
  155. 'inaccurate': gettext('inaccurate'), 'incomplete': gettext('incomplete'),
  156. 'other': gettext('Other'), }
  157. return [str(row.get('chat_id')), row.get('abstract'), row.get('problem_text'), padding_problem_text,
  158. row.get('answer_text'), vote_status_map.get(row.get('vote_status')),
  159. vote_reason_map.get(row.get('vote_reason')),
  160. row.get('vote_other_content'),
  161. reference_paragraph_len,
  162. reference_paragraph,
  163. "\n".join([
  164. f"{improve_paragraph_list[index].get('title')}\n{improve_paragraph_list[index].get('content')}"
  165. for index in range(len(improve_paragraph_list))]),
  166. row.get('asker').get('username'),
  167. (row.get('message_tokens') or 0) + (row.get('answer_tokens') or 0),
  168. row.get('ip_address') or '-',
  169. get_source_display(row.get('source')),
  170. row.get('run_time'),
  171. str(row.get('create_time').astimezone(pytz.timezone(TIME_ZONE)).strftime('%Y-%m-%d %H:%M:%S')
  172. if row.get('create_time') is not None else None)]
  173. @staticmethod
  174. def reset_value(value):
  175. if isinstance(value, str):
  176. value = re.sub(ILLEGAL_CHARACTERS_RE, '', value)
  177. if value.startswith(('=', '+', '-', '@')):
  178. value = "'" + value
  179. if isinstance(value, datetime.datetime):
  180. eastern = pytz.timezone(TIME_ZONE)
  181. c = datetime.timezone(eastern._utcoffset)
  182. value = value.astimezone(c)
  183. return value
  184. def export(self, data, with_valid=True):
  185. if with_valid:
  186. self.is_valid(raise_exception=True)
  187. ApplicationChatRecordExportRequest(data=data).is_valid(raise_exception=True)
  188. def stream_response():
  189. workbook = openpyxl.Workbook(write_only=True)
  190. worksheet = workbook.create_sheet(title='Sheet1')
  191. current_page = 1
  192. page_size = 500
  193. headers = [gettext('Conversation ID'), gettext('summary'), gettext('User Questions'),
  194. gettext('Problem after optimization'),
  195. gettext('answer'), gettext('User feedback'), gettext('Feedback reason'),
  196. gettext('Other reason content'),
  197. gettext('Reference segment number'),
  198. gettext('Section title + content'),
  199. gettext('Annotation'), gettext('USER'), gettext('Consuming tokens'),
  200. gettext('Ip Address'), gettext('source'),
  201. gettext('Time consumed (s)'),
  202. gettext('Question Time')]
  203. worksheet.append(headers)
  204. for data_list in native_page_handler(page_size, self.get_query_set(data.get('select_ids')),
  205. primary_key='application_chat_record_temp.id',
  206. primary_queryset='default_queryset',
  207. get_primary_value=lambda item: item.get('id'),
  208. select_string=get_file_content(
  209. os.path.join(PROJECT_DIR, "apps", "application", 'sql',
  210. ('export_application_chat_ee.sql' if ['PE',
  211. 'EE'].__contains__(
  212. edition) else 'export_application_chat.sql'))),
  213. with_table_name=False):
  214. for item in data_list:
  215. row = [self.reset_value(v) for v in self.to_row(item)]
  216. worksheet.append(row)
  217. current_page = current_page + 1
  218. output = BytesIO()
  219. workbook.save(output)
  220. output.seek(0)
  221. yield output.getvalue()
  222. output.close()
  223. workbook.close()
  224. response = StreamingHttpResponse(stream_response(),
  225. content_type='application/vnd.open.xmlformats-officedocument.spreadsheetml.sheet')
  226. response['Content-Disposition'] = 'attachment; filename="data.xlsx"'
  227. return response
  228. def page(self, current_page: int, page_size: int, with_valid=True):
  229. if with_valid:
  230. self.is_valid(raise_exception=True)
  231. return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
  232. os.path.join(PROJECT_DIR, "apps", "application", 'sql',
  233. ('list_application_chat_ee.sql' if ['PE', 'EE'].__contains__(
  234. edition) else 'list_application_chat.sql'))),
  235. with_table_name=False)
  236. class ChatCountSerializer(serializers.Serializer):
  237. chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
  238. def get_query_set(self):
  239. return QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))
  240. def update_chat(self):
  241. self.is_valid(raise_exception=True)
  242. count_chat_record = native_search(self.get_query_set(), get_file_content(
  243. os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'count_chat_record.sql')), with_search_one=True)
  244. QuerySet(Chat).filter(id=self.data.get('chat_id')).update(star_num=count_chat_record.get('star_num', 0) or 0,
  245. trample_num=count_chat_record.get('trample_num',
  246. 0) or 0,
  247. chat_record_count=count_chat_record.get(
  248. 'chat_record_count', 0) or 0,
  249. mark_sum=count_chat_record.get('mark_sum', 0) or 0)
  250. return True
  251. def get_source_display(source):
  252. if not source or not isinstance(source, dict) or 'type' not in source:
  253. return '-'
  254. source_type = source.get('type')
  255. # 定义映射关系
  256. source_mapping = {
  257. ChatSourceChoices.ONLINE.value: gettext('Online Usage'),
  258. ChatSourceChoices.API_CALL.value: gettext('API Call'),
  259. ChatSourceChoices.ENTERPRISE_WECHAT.value: gettext('Enterprise WeChat'),
  260. ChatSourceChoices.WECHAT_PUBLIC_ACCOUNT.value: gettext('WeChat Public Account'),
  261. ChatSourceChoices.LARK.value: gettext('Lark'),
  262. ChatSourceChoices.DINGTALK.value: gettext('DingTalk'),
  263. ChatSourceChoices.ENTERPRISE_WECHAT_ROBOT.value: gettext('Enterprise WeChat Robot'),
  264. ChatSourceChoices.TRIGGER.value: gettext('Trigger'),
  265. ChatSourceChoices.SLACK.value: gettext('Slack'),
  266. }
  267. return source_mapping.get(source_type, str(source_type))