  1. """
  2. 知识片段业务逻辑服务
  3. """
  4. from typing import List, Optional, Tuple, Dict, Any
  5. import json
  6. import random
  7. import csv
  8. import io
  9. import time
  10. from datetime import datetime
  11. from app.services.milvus_service import milvus_service
  12. from app.schemas.base import PaginationSchema, PaginatedResponseSchema
  13. from app.utils.vector_utils import text_to_vector_algo
  14. class SnippetService:
  15. def get_list(
  16. self,
  17. page: int = 1,
  18. page_size: int = 10,
  19. kb: Optional[str] = None,
  20. keyword: Optional[str] = None,
  21. status: Optional[str] = None
  22. ) -> Tuple[List[Dict], PaginationSchema]:
  23. """获取知识片段列表 (跨集合查询)"""
  24. # 1. 确定要查询的目标集合列表
  25. target_collections = []
  26. if kb:
  27. target_collections = [kb]
  28. else:
  29. # 简单起见,先查 Milvus 的所有集合
  30. target_collections = milvus_service.client.list_collections()
  31. if not target_collections:
  32. return [], PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)
  33. # 2. 计算分页逻辑 (跨集合分页算法)
  34. global_total = 0
  35. items = []
  36. # 需要跳过的全局偏移量
  37. skip_count = (page - 1) * page_size
  38. # 需要获取的目标数量
  39. need_count = page_size
  40. # 遍历所有集合
  41. for col_name in target_collections:
  42. if not milvus_service.has_collection(col_name):
  43. continue
  44. try:
  45. # 获取该集合总数
  46. stats = milvus_service.client.get_collection_stats(col_name)
  47. col_count = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0
  48. if keyword:
  49. # 关键词模式:必须实际查询
  50. desc = milvus_service.client.describe_collection(col_name)
  51. existing_fields = [f['name'] for f in desc.get('fields', [])]
  52. # 尝试获取所有字段
  53. output_fields = ["*"]
  54. expr = f'text like "%{keyword}%"' if 'text' in existing_fields else ""
  55. if not expr: continue
  56. res = milvus_service.client.query(col_name, filter=expr, output_fields=output_fields, limit=100)
  57. col_hits = len(res)
  58. global_total += col_hits
  59. if skip_count >= col_hits:
  60. skip_count -= col_hits
  61. continue
  62. take = min(need_count, col_hits - skip_count)
  63. chunk = res[skip_count : skip_count + take]
  64. for r in chunk:
  65. items.append(self._format_snippet(r, col_name))
  66. skip_count = 0
  67. need_count -= take
  68. if need_count <= 0: break
  69. else:
  70. # 无关键词模式
  71. global_total += col_count
  72. if skip_count >= col_count:
  73. skip_count -= col_count
  74. continue
  75. if need_count > 0:
  76. current_offset = skip_count
  77. current_limit = min(need_count, col_count - current_offset)
  78. output_fields = ["*"]
  79. res = milvus_service.client.query(
  80. collection_name=col_name,
  81. filter="",
  82. output_fields=output_fields,
  83. limit=current_limit,
  84. offset=current_offset
  85. )
  86. for r in res:
  87. items.append(self._format_snippet(r, col_name))
  88. skip_count = 0
  89. need_count -= current_limit
  90. except Exception as e:
  91. print(f"Collection {col_name} query error: {e}")
  92. continue
  93. total_pages = (global_total + page_size - 1) // page_size if page_size else 0
  94. meta = PaginationSchema(
  95. page=page,
  96. page_size=page_size,
  97. total=global_total,
  98. total_pages=total_pages
  99. )
  100. return items, meta
  101. def create(self, payload: Any) -> Dict:
  102. """创建知识片段"""
  103. # 使用统一算法生成向量
  104. dim = milvus_service.DENSE_DIM
  105. fake_vector = text_to_vector_algo(payload.content, dim=dim)
  106. # 基础数据
  107. now = int(time.time() * 1000)
  108. item = {
  109. "dense": fake_vector,
  110. "text": payload.content,
  111. "document_id": "manual_add",
  112. "tag_list": "",
  113. "permission": {},
  114. "metadata": {
  115. "doc_name": payload.doc_name,
  116. "file_name": payload.doc_name,
  117. "title": payload.doc_name
  118. },
  119. "index": 0,
  120. "is_deleted": 0,
  121. "created_by": "system",
  122. "created_time": now,
  123. "updated_by": "system",
  124. "updated_time": now
  125. }
  126. # 合并自定义字段
  127. if hasattr(payload, 'custom_fields') and payload.custom_fields:
  128. item.update(payload.custom_fields)
  129. data = [item]
  130. res = milvus_service.client.insert(
  131. collection_name=payload.collection_name,
  132. data=data
  133. )
  134. milvus_service.client.flush(payload.collection_name)
  135. return {"count": res.get("insert_count", 1)}
  136. def update(self, id: str, payload: Any) -> str:
  137. """更新知识片段"""
  138. kb = payload.collection_name
  139. # 1. 删除旧数据
  140. desc = milvus_service.client.describe_collection(kb)
  141. fields = [f['name'] for f in desc.get('fields', [])]
  142. pk_field = "pk" if "pk" in fields else "id"
  143. if id.isdigit():
  144. expr = f"{pk_field} in [{id}]"
  145. else:
  146. expr = f"{pk_field} in ['{id}']"
  147. milvus_service.client.delete(collection_name=kb, filter=expr)
  148. # 2. 插入新数据
  149. # 使用统一算法生成向量
  150. dim = milvus_service.DENSE_DIM
  151. fake_vector = text_to_vector_algo(payload.content, dim=dim)
  152. now = int(time.time() * 1000)
  153. item = {
  154. "dense": fake_vector,
  155. "text": payload.content,
  156. "document_id": "updated",
  157. "tag_list": "",
  158. "permission": {},
  159. "metadata": {
  160. "doc_name": payload.doc_name or "已更新",
  161. "file_name": payload.doc_name,
  162. "title": payload.doc_name
  163. },
  164. "index": 0,
  165. "is_deleted": 0,
  166. "created_by": "system",
  167. "created_time": now,
  168. "updated_by": "system",
  169. "updated_time": now
  170. }
  171. # 合并自定义字段
  172. if hasattr(payload, 'custom_fields') and payload.custom_fields:
  173. item.update(payload.custom_fields)
  174. data = [item]
  175. milvus_service.client.insert(collection_name=kb, data=data)
  176. milvus_service.client.flush(kb)
  177. return "更新成功 (ID已变更)"
  178. def delete(self, id: str, kb: str) -> None:
  179. """删除知识片段"""
  180. if not milvus_service.has_collection(kb):
  181. raise ValueError("知识库不存在")
  182. desc = milvus_service.client.describe_collection(kb)
  183. fields = [f['name'] for f in desc.get('fields', [])]
  184. pk_field = "pk" if "pk" in fields else "id"
  185. if id.isdigit():
  186. expr = f"{pk_field} in [{id}]"
  187. else:
  188. expr = f"{pk_field} in ['{id}']"
  189. milvus_service.client.delete(
  190. collection_name=kb,
  191. filter=expr
  192. )
  193. milvus_service.client.flush(kb)
  194. def _format_snippet(self, r: Dict, col_name: str) -> Dict:
  195. id_val = r.get("id") or r.get("pk")
  196. content = r.get("text") or r.get("content") or r.get("page_content") or ""
  197. if not content:
  198. try:
  199. debug_content = r.copy()
  200. if "dense" in debug_content: del debug_content["dense"]
  201. content = json.dumps(debug_content, default=str, ensure_ascii=False)
  202. except:
  203. content = "无法解析内容"
  204. doc_name = r.get("file_name") or r.get("title") or r.get("source") or r.get("doc_name") or "未知文档"
  205. meta_info = f"ParentID: {r.get('parent_id', '-')}"
  206. return {
  207. "id": str(id_val),
  208. "collection_name": col_name,
  209. "doc_name": doc_name,
  210. "code": f"SNIP-{id_val}",
  211. "content": content,
  212. "char_count": len(content) if content else 0,
  213. "meta_info": meta_info,
  214. "status": "normal",
  215. "created_at": "-",
  216. "updated_at": "-"
  217. }
  218. def export_snippets(self, kb: Optional[str] = None, keyword: Optional[str] = None) -> Any:
  219. """导出知识片段 (生成器)"""
  220. # 1. 确定要查询的目标集合列表
  221. target_collections = []
  222. if kb:
  223. target_collections = [kb]
  224. else:
  225. target_collections = milvus_service.client.list_collections()
  226. for col_name in target_collections:
  227. if not milvus_service.has_collection(col_name):
  228. continue
  229. try:
  230. # 获取该集合总数
  231. stats = milvus_service.client.get_collection_stats(col_name)
  232. col_count = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0
  233. if col_count == 0:
  234. continue
  235. output_fields = ["*"]
  236. expr = ""
  237. if keyword:
  238. desc = milvus_service.client.describe_collection(col_name)
  239. existing_fields = [f['name'] for f in desc.get('fields', [])]
  240. if 'text' in existing_fields:
  241. expr = f'text like "%{keyword}%"'
  242. else:
  243. continue
  244. # 分批获取所有数据
  245. batch_size = 1000
  246. offset = 0
  247. while True:
  248. res = milvus_service.client.query(
  249. collection_name=col_name,
  250. filter=expr,
  251. output_fields=output_fields,
  252. limit=batch_size,
  253. offset=offset
  254. )
  255. if not res:
  256. break
  257. for r in res:
  258. yield self._format_snippet(r, col_name)
  259. offset += len(res)
  260. if len(res) < batch_size:
  261. break
  262. except Exception as e:
  263. print(f"Collection {col_name} export error: {e}")
  264. continue
  265. def generate_csv_stream(self, kb: Optional[str] = None, keyword: Optional[str] = None):
  266. """生成CSV流"""
  267. output = io.StringIO()
  268. fieldnames = ["id", "collection_name", "doc_name", "content", "meta_info", "created_at", "status"]
  269. writer = csv.DictWriter(output, fieldnames=fieldnames)
  270. # 写入表头
  271. writer.writeheader()
  272. yield output.getvalue()
  273. output.seek(0)
  274. output.truncate(0)
  275. for item in self.export_snippets(kb, keyword):
  276. # 过滤掉不在 fieldnames 中的字段
  277. row = {k: item.get(k, "") for k in fieldnames}
  278. writer.writerow(row)
  279. yield output.getvalue()
  280. output.seek(0)
  281. output.truncate(0)
# Module-level singleton shared by the API layer.
snippet_service = SnippetService()