search.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: search.py
  6. @date:2023/10/7 18:20
  7. @desc:
  8. """
  9. import hashlib
  10. from typing import Dict, Any
  11. from django.db import DEFAULT_DB_ALIAS, models, connections
  12. from django.db.models import QuerySet
  13. from common.db.compiler import AppSQLCompiler
  14. from common.db.sql_execute import select_one, select_list, update_execute
  15. from common.result import Page
# Module-level cache of dynamically created model classes.
# Keys are the MD5 hex digests computed in get_dynamics_model; values are the generated classes.
_model_cache: Dict[str, type] = {}
  18. def get_dynamics_model(attr: dict, table_name='dynamics'):
  19. """
  20. 获取一个动态的django模型
  21. :param attr: 模型字段
  22. :param table_name: 表名
  23. :return: django 模型
  24. """
  25. # 创建缓存键,基于属性和表名
  26. cache_key = hashlib.md5(f"{table_name}_{str(sorted(attr.items()))}".encode()).hexdigest()
  27. # print(f'cache_key: {cache_key}')
  28. # 如果模型已存在,直接返回缓存的模型
  29. if cache_key in _model_cache:
  30. return _model_cache[cache_key]
  31. attributes = {
  32. "__module__": "knowledge.models",
  33. "Meta": type("Meta", (), {'db_table': table_name}),
  34. **attr
  35. }
  36. # 使用唯一的类名避免冲突
  37. class_name = f'Dynamics_{cache_key[:8]}'
  38. model_class = type(class_name, (models.Model,), attributes)
  39. # 缓存模型
  40. _model_cache[cache_key] = model_class
  41. return model_class
  42. def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str,
  43. field_replace_dict: None | Dict[str, Dict[str, str]] = None, with_table_name=False):
  44. """
  45. 生成 查询sql
  46. :param with_table_name:
  47. :param queryset_dict: 多条件 查询条件
  48. :param select_string: 查询sql
  49. :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
  50. :return: sql:需要查询的sql params: sql 参数
  51. """
  52. params_dict: Dict[int, Any] = {}
  53. result_params = []
  54. for key in queryset_dict.keys():
  55. value = queryset_dict.get(key)
  56. sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key),
  57. with_table_name)
  58. params_dict = {**params_dict, select_string.index("${" + key + "}"): params}
  59. select_string = select_string.replace("${" + key + "}", sql)
  60. for key in sorted(list(params_dict.keys())):
  61. result_params = [*result_params, *params_dict.get(key)]
  62. return select_string, result_params
  63. def generate_sql_by_query(queryset: QuerySet, select_string: str,
  64. field_replace_dict: None | Dict[str, str] = None, with_table_name=False):
  65. """
  66. 生成 查询sql
  67. :param queryset: 查询条件
  68. :param select_string: 原始sql
  69. :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
  70. :return: sql:需要查询的sql params: sql 参数
  71. """
  72. sql, params = compiler_queryset(queryset, field_replace_dict, with_table_name)
  73. return select_string + " " + sql, params
  74. def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None, with_table_name=False):
  75. """
  76. 解析 queryset查询对象
  77. :param with_table_name:
  78. :param queryset: 查询对象
  79. :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
  80. :return: sql:需要查询的sql params: sql 参数
  81. """
  82. q = queryset.query
  83. compiler = q.get_compiler(DEFAULT_DB_ALIAS)
  84. if field_replace_dict is None:
  85. field_replace_dict = get_field_replace_dict(queryset)
  86. app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
  87. field_replace_dict=field_replace_dict)
  88. sql, params = app_sql_compiler.get_query_str(with_table_name=with_table_name)
  89. return sql, params
  90. def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
  91. field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
  92. with_search_one=False, with_table_name=False):
  93. """
  94. 复杂查询
  95. :param with_table_name: 生成sql是否包含表名
  96. :param queryset: 查询条件构造器
  97. :param select_string: 查询前缀 不包括 where limit 等信息
  98. :param field_replace_dict: 需要替换的字段
  99. :param with_search_one: 查询
  100. :return: 查询结果
  101. """
  102. if isinstance(queryset, Dict):
  103. exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
  104. else:
  105. exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
  106. if with_search_one:
  107. return select_one(exec_sql, exec_params)
  108. else:
  109. return select_list(exec_sql, exec_params)
  110. def native_update(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
  111. field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
  112. with_table_name=False):
  113. """
  114. 复杂查询
  115. :param with_table_name: 生成sql是否包含表名
  116. :param queryset: 查询条件构造器
  117. :param select_string: 查询前缀 不包括 where limit 等信息
  118. :param field_replace_dict: 需要替换的字段
  119. :return: 查询结果
  120. """
  121. if isinstance(queryset, Dict):
  122. exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
  123. else:
  124. exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
  125. return update_execute(exec_sql, exec_params)
  126. def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
  127. """
  128. 分页查询
  129. :param current_page: 当前页
  130. :param page_size: 每页大小
  131. :param queryset: 查询条件
  132. :param post_records_handler: 数据处理器
  133. :return: 分页结果
  134. """
  135. total = QuerySet(query=queryset.query.clone(), model=queryset.model).count()
  136. result = queryset.all()[((current_page - 1) * page_size):(current_page * page_size)]
  137. return Page(total, list(map(post_records_handler, result)), current_page, page_size)
  138. def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str,
  139. field_replace_dict=None,
  140. post_records_handler=lambda r: r,
  141. with_table_name=False):
  142. """
  143. 复杂分页查询
  144. :param with_table_name:
  145. :param current_page: 当前页
  146. :param page_size: 每页大小
  147. :param queryset: 查询条件
  148. :param select_string: 查询
  149. :param field_replace_dict: 特殊字段替换
  150. :param post_records_handler: 数据row处理器
  151. :return: 分页结果
  152. """
  153. if isinstance(queryset, Dict):
  154. exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
  155. else:
  156. exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
  157. total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql
  158. total = select_one(total_sql, exec_params)
  159. limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql(
  160. ((current_page - 1) * page_size), (current_page * page_size)
  161. )
  162. page_sql = exec_sql + " " + limit_sql
  163. result = select_list(page_sql, exec_params)
  164. return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size)
  165. def native_page_handler(page_size: int,
  166. queryset: QuerySet | Dict[str, QuerySet],
  167. select_string: str,
  168. field_replace_dict=None,
  169. with_table_name=False,
  170. primary_key=None,
  171. get_primary_value=None,
  172. primary_queryset: str = None,
  173. ):
  174. if isinstance(queryset, Dict):
  175. exec_sql, exec_params = generate_sql_by_query_dict({**queryset,
  176. primary_queryset: queryset[primary_queryset].order_by(
  177. primary_key)}, select_string, field_replace_dict, with_table_name)
  178. else:
  179. exec_sql, exec_params = generate_sql_by_query(queryset.order_by(
  180. primary_key), select_string, field_replace_dict, with_table_name)
  181. total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql
  182. total = select_one(total_sql, exec_params)
  183. processed_count = 0
  184. last_id = None
  185. while processed_count < total.get("count"):
  186. if last_id is not None:
  187. if isinstance(queryset, Dict):
  188. exec_sql, exec_params = generate_sql_by_query_dict({**queryset,
  189. primary_queryset: queryset[primary_queryset].filter(
  190. **{f"{primary_key}__gt": last_id}).order_by(
  191. primary_key)},
  192. select_string, field_replace_dict,
  193. with_table_name)
  194. else:
  195. exec_sql, exec_params = generate_sql_by_query(
  196. queryset.filter(**{f"{primary_key}__gt": last_id}).order_by(
  197. primary_key),
  198. select_string, field_replace_dict,
  199. with_table_name)
  200. limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql(
  201. 0, page_size
  202. )
  203. page_sql = exec_sql + " " + limit_sql
  204. result = select_list(page_sql, exec_params)
  205. yield result
  206. processed_count += page_size
  207. last_id = get_primary_value(result[-1])
  208. def get_field_replace_dict(queryset: QuerySet):
  209. """
  210. 获取需要替换的字段 默认 “xxx.xxx”需要被替换成 “xxx”."xxx"
  211. :param queryset: 查询对象
  212. :return: 需要替换的字典
  213. """
  214. result = {}
  215. for field in queryset.model._meta.local_fields:
  216. if field.attname.__contains__("."):
  217. replace_field = to_replace_field(field.attname)
  218. result.__setitem__('"' + field.attname + '"', replace_field)
  219. return result
  220. def to_replace_field(field: str):
  221. """
  222. 将field 转换为 需要替换的field “xxx.xxx”需要被替换成 “xxx”."xxx" 只替换 field包含.的字段
  223. :param field: django field字段
  224. :return: 替换字段
  225. """
  226. split_field = field.split(".")
  227. return ".".join(list(map(lambda sf: '"' + sf + '"', split_field)))