|
@@ -0,0 +1,286 @@
|
|
|
|
|
+"""
|
|
|
|
|
+知识片段视图路由
|
|
|
|
|
+"""
|
|
|
|
|
+from fastapi import APIRouter, Depends, Query, Path, Body
|
|
|
|
|
+from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
+from typing import Optional, List, Dict, Any
|
|
|
|
|
+from datetime import datetime, timezone
|
|
|
|
|
+import json
|
|
|
|
|
+
|
|
|
|
|
+from app.base.async_mysql_connection import get_db
|
|
|
|
|
+from app.services.milvus_service import milvus_service
|
|
|
|
|
+from app.schemas.base import ResponseSchema, PaginatedResponseSchema, PaginationSchema
|
|
|
|
|
+from app.services.jwt_token import verify_token
|
|
|
|
|
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
|
|
|
+from pydantic import BaseModel
|
|
|
|
|
+
|
|
|
|
|
# Router for the knowledge-snippet endpoints; every route registered on it
# below is served under the "/document/snippet" prefix.
router = APIRouter(prefix="/document/snippet", tags=["样本中心-知识片段"])
# HTTP Bearer scheme that the route handlers in this file declare as a dependency.
security = HTTPBearer()
|
|
|
|
|
+
|
|
|
|
|
+# Schemas
|
|
|
|
|
class SnippetCreate(BaseModel):
    """Request body for creating a knowledge snippet (POST on this router)."""

    # Name of the Milvus collection the snippet is inserted into.
    collection_name: str
    # Document label stored with the snippet; defaults to "手动添加" when omitted.
    doc_name: str = "手动添加"
    # Snippet text; required.
    content: str
    # Optional free-form metadata. NOTE(review): create_snippet in this file
    # does not read this field — confirm whether it should be persisted.
    meta_info: Optional[str] = None
|
|
|
|
|
+
|
|
|
|
|
class SnippetUpdate(BaseModel):
    """Request body for updating a knowledge snippet (PUT /{id} on this router)."""

    # Name of the Milvus collection that holds the snippet being replaced.
    collection_name: str
    # New document label; optional, the handler substitutes a fallback when None.
    doc_name: Optional[str] = None
    # Replacement snippet text; required.
    content: str
|
|
|
|
|
+
|
|
|
|
|
@router.get("", response_model=PaginatedResponseSchema)
async def get_snippets(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1),
    kb: Optional[str] = Query(None, description="知识库集合名称"),
    keyword: Optional[str] = Query(None),
    status: Optional[str] = Query(None),
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """List knowledge snippets, paginating across one or many Milvus collections.

    Collections are walked in order and treated as one concatenated sequence:
    the global offset ``(page - 1) * page_size`` is consumed collection by
    collection, then up to ``page_size`` rows are collected.

    Args:
        page: 1-based page number.
        page_size: Number of items per page.
        kb: Collection to query. When omitted, every collection returned by
            ``list_collections()`` is queried.
        keyword: Optional substring matched against the ``text`` field.
            Collections without a ``text`` field are skipped in this mode.
        status: Accepted for interface compatibility; not used for filtering.
        credentials: Bearer credentials injected by the ``security`` dependency.

    Returns:
        PaginatedResponseSchema with ``code=0`` on success, or ``code=500`` and
        an empty page when an unexpected error occurs.

    Notes:
        * NOTE(review): this handler never inspects ``credentials``; only the
          ``security`` dependency runs. Confirm whether the token should be
          validated here.
        * In keyword mode at most ``keyword_hit_cap`` rows are fetched per
          collection, so the reported total is capped accordingly.
        * In non-keyword mode the total comes from collection statistics
          (``row_count``).
        * NOTE(review): the Milvus client calls appear to be synchronous and
          would block the event loop inside this ``async`` handler — confirm.
    """
    # Upper bound on rows fetched per collection for a keyword search.
    keyword_hit_cap = 100

    try:
        # 1. Decide which collections to query.
        if kb:
            target_collections = [kb]
        else:
            # No DB session is injected here, so all Milvus collections are listed.
            target_collections = milvus_service.client.list_collections()

        if not target_collections:
            return PaginatedResponseSchema(
                code=0, message="没有可用的知识库", data=[],
                meta=PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)
            )

        # Build the keyword filter once. Backslashes and double quotes are
        # escaped so user input cannot terminate the string literal and inject
        # additional filter clauses.
        keyword_filter = ""
        if keyword:
            escaped_keyword = keyword.replace("\\", "\\\\").replace('"', '\\"')
            keyword_filter = f'text like "%{escaped_keyword}%"'

        # 2. Cross-collection pagination state.
        global_total = 0
        items = []
        skip_count = (page - 1) * page_size  # rows still to skip globally
        need_count = page_size               # rows still needed for this page

        for col_name in target_collections:
            if not milvus_service.has_collection(col_name):
                continue

            try:
                if keyword_filter:
                    # Keyword mode: the hit count is only known after querying.
                    desc = milvus_service.client.describe_collection(col_name)
                    field_names = {f['name'] for f in desc.get('fields', [])}
                    if 'text' not in field_names:
                        continue

                    res = milvus_service.client.query(
                        col_name,
                        filter=keyword_filter,
                        output_fields=["*"],
                        limit=keyword_hit_cap
                    )
                    col_hits = len(res)
                    global_total += col_hits

                    # Page already full: keep looping only to finish the total.
                    # (Stopping here would under-report total / total_pages.)
                    if need_count <= 0:
                        continue

                    if skip_count >= col_hits:
                        skip_count -= col_hits
                        continue

                    take = min(need_count, col_hits - skip_count)
                    items.extend(
                        format_snippet(r, col_name)
                        for r in res[skip_count:skip_count + take]
                    )
                    skip_count = 0
                    need_count -= take
                else:
                    # Non-keyword mode: use collection statistics for the count.
                    stats = milvus_service.client.get_collection_stats(col_name)
                    col_count = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0
                    global_total += col_count

                    # Page already full: only the total still needs updating.
                    if need_count <= 0:
                        continue

                    if skip_count >= col_count:
                        skip_count -= col_count
                        continue

                    current_limit = min(need_count, col_count - skip_count)
                    res = milvus_service.client.query(
                        collection_name=col_name,
                        filter="",
                        output_fields=["*"],
                        limit=current_limit,
                        offset=skip_count
                    )
                    items.extend(format_snippet(r, col_name) for r in res)

                    skip_count = 0
                    need_count -= current_limit

            except Exception as e:
                # Best effort: one failing collection must not fail the listing.
                print(f"Collection {col_name} query error: {e}")
                continue

        total_pages = (global_total + page_size - 1) // page_size

        return PaginatedResponseSchema(
            code=0,
            message="获取成功",
            data=items,
            meta=PaginationSchema(total=global_total, page=page, page_size=page_size, total_pages=total_pages)
        )

    except Exception as e:
        print(f"Query Snippets Error: {e}")
        return PaginatedResponseSchema(
            code=500, message=f"查询失败: {str(e)}", data=[],
            meta=PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)
        )
|
|
|
|
|
+
|
|
|
|
|
def format_snippet(r: Dict, col_name: str) -> Dict:
    """Convert one raw Milvus query row into the snippet dict returned by the API.

    Args:
        r: A row as returned by a Milvus query (field name -> value).
        col_name: Name of the collection the row came from.

    Returns:
        A dict with the keys ``id``, ``collection_name``, ``doc_name``,
        ``code``, ``content``, ``char_count``, ``meta_info``, ``status``,
        ``created_at`` and ``updated_at``. ``status``, ``created_at`` and
        ``updated_at`` are fixed placeholder values.
    """
    # Prefer "id", fall back to "pk". Explicit checks (instead of `or`) keep a
    # legitimate falsy key such as the integer 0 from being discarded.
    id_val = r.get("id")
    if id_val is None or id_val == "":
        id_val = r.get("pk")

    content = r.get("text") or r.get("content") or r.get("page_content") or ""

    # Fallback: when no text field is present, show the row itself (minus the
    # embedding fields) so the record can still be inspected.
    if not content:
        debug_view = {
            key: value
            for key, value in r.items()
            if key not in ("dense", "vector")
        }
        try:
            content = json.dumps(debug_view, default=str, ensure_ascii=False)
        except (TypeError, ValueError):
            content = "无法解析内容"

    doc_name = r.get("file_name") or r.get("title") or r.get("source") or r.get("doc_name") or "未知文档"
    meta_info = f"ParentID: {r.get('parent_id', '-')}"

    return {
        "id": str(id_val),
        "collection_name": col_name,
        "doc_name": doc_name,
        "code": f"SNIP-{id_val}",
        "content": content,
        "char_count": len(content),
        "meta_info": meta_info,
        "status": "normal",
        "created_at": "-",
        "updated_at": "-"
    }
|
|
|
|
|
+
|
|
|
|
|
@router.post("", response_model=ResponseSchema)
async def create_snippet(
    payload: SnippetCreate,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create a knowledge snippet in the given Milvus collection.

    Args:
        payload: Target collection, document label and snippet text.
        credentials: Bearer credentials injected by the ``security`` dependency.

    Returns:
        ResponseSchema with ``code=0`` and ``data={"count": <inserted rows>}``
        on success; ``code=400`` for blank content; ``code=404`` when the
        collection does not exist; ``code=500`` on any other failure.

    Notes:
        * NOTE(review): the stored embedding is a random 768-dimensional
          placeholder, not derived from the text. Confirm the collection's
          vector field name and dimension and whether a real embedding should
          be computed.
        * ``payload.meta_info`` is not persisted.
    """
    try:
        # Reject blank content up front instead of storing an empty snippet.
        if not payload.content or not payload.content.strip():
            return ResponseSchema(code=400, message="内容不能为空")

        # Same existence check delete_snippet performs, so a wrong collection
        # name yields 404 rather than a raw Milvus error.
        if not milvus_service.has_collection(payload.collection_name):
            return ResponseSchema(code=404, message="知识库不存在")

        import random
        placeholder_vector = [random.random() for _ in range(768)]

        row = {
            "vector": placeholder_vector,
            "text": payload.content,
            "source": payload.doc_name,
            "doc_id": "manual_add",
            # The same label is written to every field format_snippet may read.
            "file_name": payload.doc_name,
            "title": payload.doc_name
        }

        res = milvus_service.client.insert(
            collection_name=payload.collection_name,
            data=[row]
        )

        milvus_service.client.flush(payload.collection_name)

        return ResponseSchema(code=0, message="创建成功", data={"count": res.get("insert_count", 1)})
    except Exception as e:
        print(f"Create Snippet Error: {e}")
        return ResponseSchema(code=500, message=str(e))
|
|
|
|
|
+
|
|
|
|
|
@router.put("/{id}", response_model=ResponseSchema)
async def update_snippet(
    id: str,
    payload: SnippetUpdate,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Update a knowledge snippet by inserting a replacement row and deleting the old one.

    Args:
        id: Primary-key value of the snippet to replace (path parameter).
        payload: Target collection, optional new document label, new text.
        credentials: Bearer credentials injected by the ``security`` dependency.

    Returns:
        ResponseSchema with ``code=0`` on success, ``code=404`` when the
        collection does not exist, ``code=500`` on any other failure.

    Notes:
        * The insert payload carries no primary key, so on success the snippet
          has a new key, as the success message says.
        * NOTE(review): the stored embedding is a random 768-dimensional
          placeholder, not derived from the text — confirm.
    """
    try:
        kb = payload.collection_name

        if not milvus_service.has_collection(kb):
            return ResponseSchema(code=404, message="知识库不存在")

        # Work out the delete filter for the old row.
        desc = milvus_service.client.describe_collection(kb)
        fields = [f['name'] for f in desc.get('fields', [])]
        pk_field = "pk" if "pk" in fields else "id"

        if id.isascii() and id.isdigit():
            # ASCII digits only: str.isdigit() alone also accepts characters
            # such as superscript digits, which are not valid numeric literals.
            expr = f"{pk_field} in [{id}]"
        else:
            # Escape backslashes and quotes so the path parameter cannot break
            # out of the string literal and widen the delete filter.
            safe_id = id.replace("\\", "\\\\").replace("'", "\\'")
            expr = f"{pk_field} in ['{safe_id}']"

        # 1. Insert the replacement first, so a failed insert leaves the old
        #    snippet in place instead of losing it.
        import random
        placeholder_vector = [random.random() for _ in range(768)]

        # One fallback for all label fields, so none of them is written as None.
        doc_name = payload.doc_name or "已更新"

        row = {
            "vector": placeholder_vector,
            "text": payload.content,
            "source": doc_name,
            "doc_id": "updated",
            "file_name": doc_name,
            "title": doc_name
        }

        milvus_service.client.insert(collection_name=kb, data=[row])

        # 2. Remove the old row now that the replacement is stored.
        milvus_service.client.delete(collection_name=kb, filter=expr)
        milvus_service.client.flush(kb)

        return ResponseSchema(code=0, message="更新成功 (ID已变更)")
    except Exception as e:
        print(f"Update Snippet Error: {e}")
        return ResponseSchema(code=500, message=str(e))
|
|
|
|
|
+
|
|
|
|
|
@router.delete("/{id}", response_model=ResponseSchema)
async def delete_snippet(
    id: str,
    kb: str = Query(..., description="知识库名称"),
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Delete a knowledge snippet from a Milvus collection by primary key.

    Args:
        id: Primary-key value of the snippet to delete (path parameter).
        kb: Name of the collection that holds the snippet.
        credentials: Bearer credentials injected by the ``security`` dependency.

    Returns:
        ResponseSchema with ``code=0`` on success, ``code=404`` when the
        collection does not exist, ``code=500`` on any other failure.
    """
    try:
        if not milvus_service.has_collection(kb):
            return ResponseSchema(code=404, message="知识库不存在")

        desc = milvus_service.client.describe_collection(kb)
        fields = [f['name'] for f in desc.get('fields', [])]
        pk_field = "pk" if "pk" in fields else "id"

        if id.isascii() and id.isdigit():
            # ASCII digits only: str.isdigit() alone also accepts characters
            # such as superscript digits, which are not valid numeric literals.
            expr = f"{pk_field} in [{id}]"
        else:
            # Escape backslashes and quotes so the path parameter cannot break
            # out of the string literal and widen the delete filter.
            safe_id = id.replace("\\", "\\\\").replace("'", "\\'")
            expr = f"{pk_field} in ['{safe_id}']"

        milvus_service.client.delete(
            collection_name=kb,
            filter=expr
        )
        milvus_service.client.flush(kb)

        return ResponseSchema(code=0, message="删除成功")
    except Exception as e:
        # Log like the other handlers in this file instead of failing silently.
        print(f"Delete Snippet Error: {e}")
        return ResponseSchema(code=500, message=str(e))
|