| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352 |
- """
- 知识片段业务逻辑服务
- """
import csv
import io
import json
import logging
import random
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from app.services.milvus_service import milvus_service
from app.schemas.base import PaginationSchema, PaginatedResponseSchema
from app.utils.vector_utils import text_to_vector_algo
class SnippetService:
    """Business logic for knowledge snippets stored in Milvus.

    Each Milvus collection is treated as one knowledge base ("kb"); a snippet
    is a single row inside such a collection.
    """

    # Maximum number of matching rows fetched per collection in keyword mode.
    # Matches beyond this cap are neither listed nor counted in the total.
    KEYWORD_QUERY_LIMIT = 100

    # Number of rows fetched per Milvus query while exporting.
    EXPORT_BATCH_SIZE = 1000

    _logger = logging.getLogger(__name__)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def get_list(
        self,
        page: int = 1,
        page_size: int = 10,
        kb: Optional[str] = None,
        keyword: Optional[str] = None,
        status: Optional[str] = None,
    ) -> "Tuple[List[Dict], PaginationSchema]":
        """Return one page of snippets, paginated across collections.

        The target collections are walked in order and treated as one
        concatenated sequence. The global offset ``(page - 1) * page_size`` is
        consumed collection by collection, then up to ``page_size`` rows are
        taken.

        Args:
            page: 1-based page number. Values below 1 are treated as 1 when
                computing the offset.
            page_size: Number of items per page.
            kb: Restrict the listing to this collection. When falsy, every
                collection returned by ``list_collections()`` is queried.
            keyword: If given, only rows whose ``text`` field matches
                ``%keyword%`` are returned; collections without a ``text``
                field are skipped. At most ``KEYWORD_QUERY_LIMIT`` matches per
                collection are considered.
            status: Accepted for API compatibility; currently ignored.

        Returns:
            ``(items, meta)``: ``items`` are dicts built by ``_format_snippet``
            and ``meta`` is the ``PaginationSchema`` summary.
        """
        target_collections = self._target_collections(kb)
        if not target_collections:
            return [], PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)

        global_total = 0
        items: List[Dict] = []
        # Global offset still to skip, and rows still needed for this page.
        # Clamped so an invalid page never produces a negative Milvus offset.
        skip_count = max(page - 1, 0) * max(page_size, 0)
        need_count = max(page_size, 0)

        for col_name in target_collections:
            if not milvus_service.has_collection(col_name):
                continue

            try:
                if keyword:
                    if not self._has_text_field(col_name):
                        continue
                    # The match count is only known after querying, so the
                    # (capped) hits are fetched up front and sliced locally.
                    hits = milvus_service.client.query(
                        collection_name=col_name,
                        filter=self._like_text_expr(keyword),
                        output_fields=["*"],
                        limit=self.KEYWORD_QUERY_LIMIT,
                    )
                    col_total = len(hits)
                else:
                    hits = None
                    col_total = self._row_count(col_name)

                # Every collection is counted, even after the page is full,
                # so that total/total_pages stay accurate in both modes.
                global_total += col_total

                if skip_count >= col_total:
                    skip_count -= col_total
                    continue
                if need_count <= 0:
                    continue

                take = min(need_count, col_total - skip_count)
                if hits is None:
                    rows = milvus_service.client.query(
                        collection_name=col_name,
                        filter="",
                        output_fields=["*"],
                        limit=take,
                        offset=skip_count,
                    )
                else:
                    rows = hits[skip_count:skip_count + take]

                items.extend(self._format_snippet(r, col_name) for r in rows)
                skip_count = 0
                need_count -= take
            except Exception:
                # Best effort: one failing collection must not break the listing.
                self._logger.exception("Collection %s query error", col_name)

        total_pages = (global_total + page_size - 1) // page_size if page_size else 0
        meta = PaginationSchema(
            page=page,
            page_size=page_size,
            total=global_total,
            total_pages=total_pages,
        )
        return items, meta

    def create(self, payload: Any) -> Dict:
        """Insert a new snippet built from ``payload`` and flush the collection.

        Args:
            payload: Object exposing ``content``, ``doc_name``,
                ``collection_name`` and, optionally, ``custom_fields`` (a
                mapping merged over the generated row).

        Returns:
            ``{"count": n}``, where ``n`` is ``insert_count`` from the Milvus
            response, or 1 if that key is absent.
        """
        row = self._build_row(payload, document_id="manual_add", doc_name=payload.doc_name)
        res = milvus_service.client.insert(
            collection_name=payload.collection_name,
            data=[row],
        )
        milvus_service.client.flush(payload.collection_name)
        return {"count": res.get("insert_count", 1)}

    def update(self, id: str, payload: Any) -> str:
        """Replace snippet ``id`` in ``payload.collection_name`` with new content.

        The row is deleted and a freshly built row is inserted. The returned
        message states that the ID changes (presumably the collection uses
        auto-generated primary keys; verify against the schema).

        Args:
            id: Primary-key value of the snippet to replace.
            payload: Same shape as for ``create``.

        Returns:
            A user-facing success message.

        Raises:
            ValueError: If the target collection does not exist.
        """
        kb = payload.collection_name
        if not milvus_service.has_collection(kb):
            raise ValueError("知识库不存在")

        # NOTE(review): delete happens before insert. If the insert fails, the
        # old row is already gone.
        milvus_service.client.delete(
            collection_name=kb,
            filter=self._pk_expr(self._pk_field(kb), id),
        )

        row = self._build_row(
            payload,
            document_id="updated",
            doc_name=payload.doc_name or "已更新",
        )
        milvus_service.client.insert(collection_name=kb, data=[row])
        milvus_service.client.flush(kb)

        return "更新成功 (ID已变更)"

    def delete(self, id: str, kb: str) -> None:
        """Delete snippet ``id`` from collection ``kb`` and flush.

        Raises:
            ValueError: If collection ``kb`` does not exist.
        """
        if not milvus_service.has_collection(kb):
            raise ValueError("知识库不存在")

        milvus_service.client.delete(
            collection_name=kb,
            filter=self._pk_expr(self._pk_field(kb), id),
        )
        milvus_service.client.flush(kb)

    def export_snippets(self, kb: Optional[str] = None, keyword: Optional[str] = None) -> Any:
        """Yield every matching snippet (formatted by ``_format_snippet``).

        Collections are selected as in ``get_list``. Each collection is read
        in batches of ``EXPORT_BATCH_SIZE`` using offset paging. A failing
        collection is logged and skipped.
        """
        batch_size = self.EXPORT_BATCH_SIZE

        for col_name in self._target_collections(kb):
            if not milvus_service.has_collection(col_name):
                continue

            try:
                if self._row_count(col_name) == 0:
                    continue

                expr = ""
                if keyword:
                    if not self._has_text_field(col_name):
                        continue
                    expr = self._like_text_expr(keyword)

                # NOTE(review): Milvus limits offset + limit per query (the
                # max query result window). Very large collections may fail
                # partway through here; consider query_iterator if available.
                offset = 0
                while True:
                    rows = milvus_service.client.query(
                        collection_name=col_name,
                        filter=expr,
                        output_fields=["*"],
                        limit=batch_size,
                        offset=offset,
                    )
                    if not rows:
                        break

                    for r in rows:
                        yield self._format_snippet(r, col_name)

                    offset += len(rows)
                    if len(rows) < batch_size:
                        break
            except Exception:
                self._logger.exception("Collection %s export error", col_name)

    def generate_csv_stream(self, kb: Optional[str] = None, keyword: Optional[str] = None):
        """Yield the export as CSV text chunks.

        The header comes first, then one chunk per exported snippet. Fields
        not listed in ``fieldnames`` are dropped, and missing ones are written
        as empty strings.
        """
        fieldnames = ["id", "collection_name", "doc_name", "content", "meta_info", "created_at", "status"]
        buffer = io.StringIO()
        writer = csv.DictWriter(buffer, fieldnames=fieldnames)

        def drain() -> str:
            # Hand out what was written so far and reset the buffer.
            chunk = buffer.getvalue()
            buffer.seek(0)
            buffer.truncate(0)
            return chunk

        writer.writeheader()
        yield drain()

        for item in self.export_snippets(kb, keyword):
            writer.writerow({k: item.get(k, "") for k in fieldnames})
            yield drain()

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    def _format_snippet(self, r: Dict, col_name: str) -> Dict:
        """Convert a raw Milvus row into the snippet dict returned to callers.

        The id is taken from ``id``, falling back to ``pk``. The content comes
        from the first non-empty of ``text``/``content``/``page_content``.
        Otherwise the row itself (minus ``dense``) is serialized as JSON. The
        document name is looked up among the top-level fields, then inside a
        ``metadata`` dict (where ``create``/``update`` store it).
        """
        id_val = r.get("id")
        if id_val is None:  # keep a legitimate id of 0
            id_val = r.get("pk")

        content = r.get("text") or r.get("content") or r.get("page_content") or ""
        if not content:
            try:
                debug_row = {k: v for k, v in r.items() if k != "dense"}
                content = json.dumps(debug_row, default=str, ensure_ascii=False)
            except Exception:
                content = "无法解析内容"

        metadata = r.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}
        doc_name = (
            r.get("file_name")
            or r.get("title")
            or r.get("source")
            or r.get("doc_name")
            or metadata.get("doc_name")
            or metadata.get("file_name")
            or metadata.get("title")
            or "未知文档"
        )
        meta_info = f"ParentID: {r.get('parent_id', '-')}"

        return {
            "id": str(id_val),
            "collection_name": col_name,
            "doc_name": doc_name,
            "code": f"SNIP-{id_val}",
            "content": content,
            "char_count": len(content) if content else 0,
            "meta_info": meta_info,
            "status": "normal",
            "created_at": "-",
            "updated_at": "-",
        }

    @staticmethod
    def _build_row(payload: Any, document_id: str, doc_name: Any) -> Dict:
        """Build the Milvus row written by ``create`` and ``update``.

        ``payload.custom_fields`` is merged last when present and non-empty,
        so it can override any generated field.
        """
        now = int(time.time() * 1000)  # epoch milliseconds
        row = {
            "dense": text_to_vector_algo(payload.content, dim=milvus_service.DENSE_DIM),
            "text": payload.content,
            "document_id": document_id,
            "tag_list": "",
            "permission": {},
            "metadata": {
                "doc_name": doc_name,
                "file_name": payload.doc_name,
                "title": payload.doc_name,
            },
            "index": 0,
            "is_deleted": 0,
            "created_by": "system",
            "created_time": now,
            "updated_by": "system",
            "updated_time": now,
        }

        custom_fields = getattr(payload, "custom_fields", None)
        if custom_fields:
            row.update(custom_fields)
        return row

    @staticmethod
    def _target_collections(kb: Optional[str]) -> List[str]:
        """Return ``[kb]`` when given, otherwise every collection Milvus lists."""
        if kb:
            return [kb]
        return milvus_service.client.list_collections()

    @staticmethod
    def _row_count(col_name: str) -> int:
        """Return ``row_count`` from the collection stats (0 if unavailable)."""
        stats = milvus_service.client.get_collection_stats(col_name)
        return int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0

    @staticmethod
    def _field_names(col_name: str) -> List[str]:
        """Return the field names declared by ``describe_collection``."""
        desc = milvus_service.client.describe_collection(col_name)
        return [f["name"] for f in desc.get("fields", [])]

    def _has_text_field(self, col_name: str) -> bool:
        """Whether the collection declares a ``text`` field."""
        return "text" in self._field_names(col_name)

    def _pk_field(self, kb: str) -> str:
        """Primary-key field name: ``pk`` if declared, otherwise ``id``."""
        return "pk" if "pk" in self._field_names(kb) else "id"

    @staticmethod
    def _like_text_expr(keyword: str) -> str:
        """Build ``text like "%keyword%"`` with the keyword safely escaped.

        Backslashes and double quotes are escaped so user input cannot
        terminate the string literal or alter the filter expression.
        """
        escaped = keyword.replace("\\", "\\\\").replace('"', '\\"')
        return f'text like "%{escaped}%"'

    @staticmethod
    def _pk_expr(pk_field: str, pk_value: str) -> str:
        """Build an ``in`` filter selecting the row whose primary key is ``pk_value``.

        ASCII-digit values become integer literals. Anything else becomes a
        single-quoted string literal with backslashes and quotes escaped.
        """
        if pk_value.isascii() and pk_value.isdigit():
            return f"{pk_field} in [{pk_value}]"
        escaped = pk_value.replace("\\", "\\\\").replace("'", "\\'")
        return f"{pk_field} in ['{escaped}']"
# Shared module-level instance of SnippetService.
snippet_service = SnippetService()
|