"""知识片段业务逻辑服务 (knowledge-snippet business-logic service)."""
from typing import List, Optional, Tuple, Dict, Any
import json
import logging
import random
import csv
import io
import time
from datetime import datetime

from app.services.milvus_service import milvus_service
from app.schemas.base import PaginationSchema, PaginatedResponseSchema
from app.utils.vector_utils import text_to_vector_algo

logger = logging.getLogger(__name__)


def _escape_filter_value(value: str) -> str:
    """Escape backslashes and quotes so user-supplied text cannot break out of
    (or inject into) a Milvus filter expression."""
    return value.replace("\\", "\\\\").replace('"', '\\"').replace("'", "\\'")


class SnippetService:
    """CRUD + export operations for knowledge snippets stored in Milvus."""

    # Keyword search fetches at most this many rows per collection; matches
    # beyond the cap are invisible to pagination.  NOTE(review): raise or
    # switch to batched querying if collections grow past this size.
    KEYWORD_SEARCH_LIMIT = 100

    def get_list(
        self,
        page: int = 1,
        page_size: int = 10,
        kb: Optional[str] = None,
        keyword: Optional[str] = None,
        status: Optional[str] = None,
    ) -> Tuple[List[Dict], PaginationSchema]:
        """获取知识片段列表 (跨集合查询).

        Pages globally across every target collection: rows are skipped
        collection by collection until the requested offset is consumed,
        then rows are taken until ``page_size`` is filled.

        Args:
            page: 1-based page number.
            page_size: rows per page.
            kb: restrict to one collection; ``None`` means all collections.
            keyword: optional substring match against the ``text`` field.
            status: accepted for interface compatibility; not used here.

        Returns:
            (items, pagination-metadata) tuple.
        """
        # 1. Determine target collections: a single KB, or every Milvus collection.
        target_collections = [kb] if kb else milvus_service.client.list_collections()
        if not target_collections:
            return [], PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)

        # 2. Cross-collection pagination bookkeeping.
        global_total = 0
        items: List[Dict] = []
        skip_count = (page - 1) * page_size   # global rows still to skip
        need_count = page_size                # rows still to collect

        for col_name in target_collections:
            if not milvus_service.has_collection(col_name):
                continue
            try:
                stats = milvus_service.client.get_collection_stats(col_name)
                col_count = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0

                if keyword:
                    # Keyword mode: must actually query to know how many rows match.
                    desc = milvus_service.client.describe_collection(col_name)
                    existing_fields = [f['name'] for f in desc.get('fields', [])]
                    if 'text' not in existing_fields:
                        continue
                    # Escape the keyword so quotes in user input cannot alter the filter.
                    expr = f'text like "%{_escape_filter_value(keyword)}%"'
                    res = milvus_service.client.query(
                        col_name,
                        filter=expr,
                        output_fields=["*"],
                        limit=self.KEYWORD_SEARCH_LIMIT,
                    )
                    col_hits = len(res)
                    global_total += col_hits
                    if skip_count >= col_hits:
                        skip_count -= col_hits
                        continue
                    take = min(need_count, col_hits - skip_count)
                    for r in res[skip_count:skip_count + take]:
                        items.append(self._format_snippet(r, col_name))
                    skip_count = 0
                    need_count -= take
                    if need_count <= 0:
                        break
                else:
                    # No-keyword mode: collection stats give the count for free.
                    global_total += col_count
                    if skip_count >= col_count:
                        skip_count -= col_count
                        continue
                    if need_count > 0:
                        current_offset = skip_count
                        current_limit = min(need_count, col_count - current_offset)
                        res = milvus_service.client.query(
                            collection_name=col_name,
                            filter="",
                            output_fields=["*"],
                            limit=current_limit,
                            offset=current_offset,
                        )
                        for r in res:
                            items.append(self._format_snippet(r, col_name))
                        skip_count = 0
                        need_count -= current_limit
            except Exception as e:
                # Best-effort: one broken collection must not abort the listing.
                logger.warning("Collection %s query error: %s", col_name, e)
                continue

        total_pages = (global_total + page_size - 1) // page_size if page_size else 0
        meta = PaginationSchema(
            page=page,
            page_size=page_size,
            total=global_total,
            total_pages=total_pages,
        )
        return items, meta

    def _build_item(
        self,
        content: str,
        doc_name: Any,
        document_id: str,
        custom_fields: Optional[Dict] = None,
        display_name: Any = None,
    ) -> Dict:
        """Build one insert row; shared payload layout for ``create``/``update``.

        ``display_name`` overrides the ``metadata.doc_name`` value only
        (``update`` substitutes a placeholder when the name is empty).
        """
        # Vector is generated by the project's unified hashing algorithm.
        fake_vector = text_to_vector_algo(content, dim=milvus_service.DENSE_DIM)
        now = int(time.time() * 1000)  # epoch milliseconds
        item = {
            "dense": fake_vector,
            "text": content,
            "document_id": document_id,
            "tag_list": "",
            "permission": {},
            "metadata": {
                "doc_name": display_name if display_name is not None else doc_name,
                "file_name": doc_name,
                "title": doc_name,
            },
            "index": 0,
            "is_deleted": 0,
            "created_by": "system",
            "created_time": now,
            "updated_by": "system",
            "updated_time": now,
        }
        # Caller-supplied custom fields may override any of the defaults above.
        if custom_fields:
            item.update(custom_fields)
        return item

    def _pk_filter(self, kb: str, id: str) -> str:
        """Build a primary-key filter expression for ``kb``.

        Uses whichever PK field the collection declares ('pk' preferred,
        falling back to 'id'); string PKs are escaped against injection.
        """
        desc = milvus_service.client.describe_collection(kb)
        fields = [f['name'] for f in desc.get('fields', [])]
        pk_field = "pk" if "pk" in fields else "id"
        if id.isdigit():
            return f"{pk_field} in [{id}]"
        return f"{pk_field} in ['{_escape_filter_value(id)}']"

    def create(self, payload: Any) -> Dict:
        """创建知识片段 — insert one manually-added snippet.

        Returns:
            ``{"count": <inserted row count>}``.
        """
        item = self._build_item(
            content=payload.content,
            doc_name=payload.doc_name,
            document_id="manual_add",
            custom_fields=getattr(payload, 'custom_fields', None),
        )
        res = milvus_service.client.insert(
            collection_name=payload.collection_name,
            data=[item],
        )
        milvus_service.client.flush(payload.collection_name)
        return {"count": res.get("insert_count", 1)}

    def update(self, id: str, payload: Any) -> str:
        """更新知识片段 — implemented as delete-then-insert, so the snippet
        gets a NEW primary key (reflected in the returned message)."""
        kb = payload.collection_name
        # 1. Delete the old row by primary key.
        milvus_service.client.delete(collection_name=kb, filter=self._pk_filter(kb, id))
        # 2. Insert the replacement row.
        item = self._build_item(
            content=payload.content,
            doc_name=payload.doc_name,
            document_id="updated",
            custom_fields=getattr(payload, 'custom_fields', None),
            display_name=payload.doc_name or "已更新",
        )
        milvus_service.client.insert(collection_name=kb, data=[item])
        milvus_service.client.flush(kb)
        return "更新成功 (ID已变更)"

    def delete(self, id: str, kb: str) -> None:
        """删除知识片段.

        Raises:
            ValueError: if the collection ``kb`` does not exist.
        """
        if not milvus_service.has_collection(kb):
            raise ValueError("知识库不存在")
        milvus_service.client.delete(collection_name=kb, filter=self._pk_filter(kb, id))
        milvus_service.client.flush(kb)

    def _format_snippet(self, r: Dict, col_name: str) -> Dict:
        """Normalize one raw Milvus row into the API's snippet dict shape."""
        id_val = r.get("id") or r.get("pk")
        content = r.get("text") or r.get("content") or r.get("page_content") or ""
        if not content:
            # Fall back to dumping the row (minus the vector) so the UI shows
            # something inspectable instead of an empty snippet.
            try:
                debug_content = {k: v for k, v in r.items() if k != "dense"}
                content = json.dumps(debug_content, default=str, ensure_ascii=False)
            except Exception:
                content = "无法解析内容"
        doc_name = (r.get("file_name") or r.get("title") or r.get("source")
                    or r.get("doc_name") or "未知文档")
        return {
            "id": str(id_val),
            "collection_name": col_name,
            "doc_name": doc_name,
            "code": f"SNIP-{id_val}",
            "content": content,
            "char_count": len(content) if content else 0,
            "meta_info": f"ParentID: {r.get('parent_id', '-')}",
            "status": "normal",
            # Timestamps are not read back from the row; placeholders only.
            "created_at": "-",
            "updated_at": "-",
        }

    def export_snippets(self, kb: Optional[str] = None,
                        keyword: Optional[str] = None) -> Any:
        """导出知识片段 (生成器) — yield every matching snippet across collections."""
        target_collections = [kb] if kb else milvus_service.client.list_collections()
        for col_name in target_collections:
            if not milvus_service.has_collection(col_name):
                continue
            try:
                stats = milvus_service.client.get_collection_stats(col_name)
                col_count = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0
                if col_count == 0:
                    continue
                expr = ""
                if keyword:
                    desc = milvus_service.client.describe_collection(col_name)
                    existing_fields = [f['name'] for f in desc.get('fields', [])]
                    if 'text' not in existing_fields:
                        continue
                    # Escaped to guard against filter injection.
                    expr = f'text like "%{_escape_filter_value(keyword)}%"'
                # Page through the collection in fixed-size batches.
                batch_size = 1000
                offset = 0
                while True:
                    res = milvus_service.client.query(
                        collection_name=col_name,
                        filter=expr,
                        output_fields=["*"],
                        limit=batch_size,
                        offset=offset,
                    )
                    if not res:
                        break
                    for r in res:
                        yield self._format_snippet(r, col_name)
                    offset += len(res)
                    if len(res) < batch_size:
                        break
            except Exception as e:
                # Skip broken collections; export the rest.
                logger.warning("Collection %s export error: %s", col_name, e)
                continue

    def generate_csv_stream(self, kb: Optional[str] = None,
                            keyword: Optional[str] = None):
        """生成CSV流 — yield CSV text chunks: the header row first, then one
        chunk per exported snippet (suitable for a streaming HTTP response)."""
        output = io.StringIO()
        fieldnames = ["id", "collection_name", "doc_name", "content",
                      "meta_info", "created_at", "status"]
        writer = csv.DictWriter(output, fieldnames=fieldnames)
        writer.writeheader()
        yield output.getvalue()
        output.seek(0)
        output.truncate(0)
        for item in self.export_snippets(kb, keyword):
            # Keep only the declared columns; drop everything else.
            writer.writerow({k: item.get(k, "") for k in fieldnames})
            yield output.getvalue()
            output.seek(0)
            output.truncate(0)


snippet_service = SnippetService()