|
|
@@ -2,14 +2,10 @@
|
|
|
知识片段视图路由
|
|
|
"""
|
|
|
from fastapi import APIRouter, Depends, Query, Path, Body
|
|
|
-from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
-from typing import Optional, List, Dict, Any
|
|
|
-from datetime import datetime, timezone
|
|
|
-import json
|
|
|
+from typing import Optional
|
|
|
|
|
|
-from app.base.async_mysql_connection import get_db
|
|
|
-from app.services.milvus_service import milvus_service
|
|
|
-from app.schemas.base import ResponseSchema, PaginatedResponseSchema, PaginationSchema
|
|
|
+from app.services.snippet_service import snippet_service
|
|
|
+from app.schemas.base import ResponseSchema, PaginatedResponseSchema
|
|
|
from app.services.jwt_token import verify_token
|
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
|
from pydantic import BaseModel
|
|
|
@@ -39,147 +35,14 @@ async def get_snippets(
|
|
|
credentials: HTTPAuthorizationCredentials = Depends(security)
|
|
|
):
|
|
|
"""获取知识片段列表 (跨集合查询)"""
|
|
|
- try:
|
|
|
- # 1. 确定要查询的目标集合列表
|
|
|
- target_collections = []
|
|
|
- if kb:
|
|
|
- target_collections = [kb]
|
|
|
- else:
|
|
|
- # 查询所有启用的知识库 (这里简单起见,直接查 Milvus 或者需要注入 DB 查询)
|
|
|
- # 为了解耦,这里直接查 Milvus 的所有集合,或者如果需要 DB 过滤,则需要注入 db session
|
|
|
- # 简单起见,先查 Milvus
|
|
|
- target_collections = milvus_service.client.list_collections()
|
|
|
-
|
|
|
- if not target_collections:
|
|
|
- return PaginatedResponseSchema(
|
|
|
- code=0, message="没有可用的知识库", data=[],
|
|
|
- meta=PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)
|
|
|
- )
|
|
|
-
|
|
|
- # 2. 计算分页逻辑 (跨集合分页算法)
|
|
|
- global_total = 0
|
|
|
- items = []
|
|
|
-
|
|
|
- # 需要跳过的全局偏移量
|
|
|
- skip_count = (page - 1) * page_size
|
|
|
- # 需要获取的目标数量
|
|
|
- need_count = page_size
|
|
|
-
|
|
|
- # 遍历所有集合
|
|
|
- for col_name in target_collections:
|
|
|
- if not milvus_service.has_collection(col_name):
|
|
|
- continue
|
|
|
-
|
|
|
- try:
|
|
|
- # 获取该集合总数
|
|
|
- stats = milvus_service.client.get_collection_stats(col_name)
|
|
|
- col_count = int(stats.get("row_count", 0)) if isinstance(stats, dict) else 0
|
|
|
-
|
|
|
- if keyword:
|
|
|
- # 关键词模式:必须实际查询
|
|
|
- desc = milvus_service.client.describe_collection(col_name)
|
|
|
- existing_fields = [f['name'] for f in desc.get('fields', [])]
|
|
|
-
|
|
|
- # 尝试获取所有字段
|
|
|
- output_fields = ["*"]
|
|
|
-
|
|
|
- expr = f'text like "%{keyword}%"' if 'text' in existing_fields else ""
|
|
|
- if not expr: continue
|
|
|
-
|
|
|
- res = milvus_service.client.query(col_name, filter=expr, output_fields=output_fields, limit=100)
|
|
|
- col_hits = len(res)
|
|
|
- global_total += col_hits
|
|
|
-
|
|
|
- if skip_count >= col_hits:
|
|
|
- skip_count -= col_hits
|
|
|
- continue
|
|
|
-
|
|
|
- take = min(need_count, col_hits - skip_count)
|
|
|
- chunk = res[skip_count : skip_count + take]
|
|
|
-
|
|
|
- for r in chunk:
|
|
|
- items.append(format_snippet(r, col_name))
|
|
|
-
|
|
|
- skip_count = 0
|
|
|
- need_count -= take
|
|
|
- if need_count <= 0: break
|
|
|
-
|
|
|
- else:
|
|
|
- # 无关键词模式
|
|
|
- global_total += col_count
|
|
|
-
|
|
|
- if skip_count >= col_count:
|
|
|
- skip_count -= col_count
|
|
|
- continue
|
|
|
-
|
|
|
- if need_count > 0:
|
|
|
- current_offset = skip_count
|
|
|
- current_limit = min(need_count, col_count - current_offset)
|
|
|
-
|
|
|
- output_fields = ["*"]
|
|
|
-
|
|
|
- res = milvus_service.client.query(
|
|
|
- collection_name=col_name,
|
|
|
- filter="",
|
|
|
- output_fields=output_fields,
|
|
|
- limit=current_limit,
|
|
|
- offset=current_offset
|
|
|
- )
|
|
|
-
|
|
|
- for r in res:
|
|
|
- items.append(format_snippet(r, col_name))
|
|
|
-
|
|
|
- skip_count = 0
|
|
|
- need_count -= current_limit
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"Collection {col_name} query error: {e}")
|
|
|
- continue
|
|
|
-
|
|
|
- total_pages = (global_total + page_size - 1) // page_size
|
|
|
-
|
|
|
- return PaginatedResponseSchema(
|
|
|
- code=0,
|
|
|
- message="获取成功",
|
|
|
- data=items,
|
|
|
- meta=PaginationSchema(total=global_total, page=page, page_size=page_size, total_pages=total_pages)
|
|
|
- )
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"Query Snippets Error: {e}")
|
|
|
- return PaginatedResponseSchema(
|
|
|
- code=500, message=f"查询失败: {str(e)}", data=[],
|
|
|
- meta=PaginationSchema(total=0, page=page, page_size=page_size, total_pages=0)
|
|
|
- )
|
|
|
-
|
|
|
-def format_snippet(r: Dict, col_name: str) -> Dict:
|
|
|
- id_val = r.get("id") or r.get("pk")
|
|
|
- content = r.get("text") or r.get("content") or r.get("page_content") or ""
|
|
|
-
|
|
|
- # 兜底:如果内容为空,显示 Keys 以便调试
|
|
|
- if not content:
|
|
|
- try:
|
|
|
- debug_content = r.copy()
|
|
|
- if "dense" in debug_content: del debug_content["dense"]
|
|
|
- content = json.dumps(debug_content, default=str, ensure_ascii=False)
|
|
|
- except:
|
|
|
- content = "无法解析内容"
|
|
|
-
|
|
|
- doc_name = r.get("file_name") or r.get("title") or r.get("source") or r.get("doc_name") or "未知文档"
|
|
|
- meta_info = f"ParentID: {r.get('parent_id', '-')}"
|
|
|
+ items, meta = snippet_service.get_list(page, page_size, kb, keyword, status)
|
|
|
|
|
|
- return {
|
|
|
- "id": str(id_val),
|
|
|
- "collection_name": col_name,
|
|
|
- "doc_name": doc_name,
|
|
|
- "code": f"SNIP-{id_val}",
|
|
|
- "content": content,
|
|
|
- "char_count": len(content) if content else 0,
|
|
|
- "meta_info": meta_info,
|
|
|
- "status": "normal",
|
|
|
- "created_at": "-",
|
|
|
- "updated_at": "-"
|
|
|
- }
|
|
|
+ return PaginatedResponseSchema(
|
|
|
+ code=0,
|
|
|
+ message="获取成功",
|
|
|
+ data=items,
|
|
|
+ meta=meta
|
|
|
+ )
|
|
|
|
|
|
@router.post("", response_model=ResponseSchema)
|
|
|
async def create_snippet(
|
|
|
@@ -187,100 +50,37 @@ async def create_snippet(
|
|
|
credentials: HTTPAuthorizationCredentials = Depends(security)
|
|
|
):
|
|
|
"""创建知识片段"""
|
|
|
- try:
|
|
|
- import random
|
|
|
- fake_vector = [random.random() for _ in range(768)]
|
|
|
-
|
|
|
- data = [{
|
|
|
- "vector": fake_vector,
|
|
|
- "text": payload.content,
|
|
|
- "source": payload.doc_name,
|
|
|
- "doc_id": "manual_add",
|
|
|
- "file_name": payload.doc_name, # 确保这些字段都有值
|
|
|
- "title": payload.doc_name
|
|
|
- }]
|
|
|
-
|
|
|
- res = milvus_service.client.insert(
|
|
|
- collection_name=payload.collection_name,
|
|
|
- data=data
|
|
|
- )
|
|
|
-
|
|
|
- milvus_service.client.flush(payload.collection_name)
|
|
|
-
|
|
|
- return ResponseSchema(code=0, message="创建成功", data={"count": res.get("insert_count", 1)})
|
|
|
- except Exception as e:
|
|
|
- print(f"Create Snippet Error: {e}")
|
|
|
- return ResponseSchema(code=500, message=str(e))
|
|
|
+ payload_token = verify_token(credentials.credentials)
|
|
|
+ if not payload_token:
|
|
|
+ return ResponseSchema(code=401, message="无效的访问令牌")
|
|
|
+
|
|
|
+ data = snippet_service.create(payload)
|
|
|
+ return ResponseSchema(code=0, message="创建成功", data=data)
|
|
|
|
|
|
-@router.put("/{id}", response_model=ResponseSchema)
|
|
|
+@router.post("/{id}", response_model=ResponseSchema)
|
|
|
async def update_snippet(
|
|
|
id: str,
|
|
|
payload: SnippetUpdate,
|
|
|
credentials: HTTPAuthorizationCredentials = Depends(security)
|
|
|
):
|
|
|
"""更新知识片段"""
|
|
|
- try:
|
|
|
- kb = payload.collection_name
|
|
|
-
|
|
|
- # 1. 删除旧数据
|
|
|
- desc = milvus_service.client.describe_collection(kb)
|
|
|
- fields = [f['name'] for f in desc.get('fields', [])]
|
|
|
- pk_field = "pk" if "pk" in fields else "id"
|
|
|
-
|
|
|
- if id.isdigit():
|
|
|
- expr = f"{pk_field} in [{id}]"
|
|
|
- else:
|
|
|
- expr = f"{pk_field} in ['{id}']"
|
|
|
-
|
|
|
- milvus_service.client.delete(collection_name=kb, filter=expr)
|
|
|
-
|
|
|
- # 2. 插入新数据
|
|
|
- import random
|
|
|
- fake_vector = [random.random() for _ in range(768)]
|
|
|
-
|
|
|
- data = [{
|
|
|
- "vector": fake_vector,
|
|
|
- "text": payload.content,
|
|
|
- "source": payload.doc_name or "已更新",
|
|
|
- "doc_id": "updated",
|
|
|
- "file_name": payload.doc_name,
|
|
|
- "title": payload.doc_name
|
|
|
- }]
|
|
|
-
|
|
|
- milvus_service.client.insert(collection_name=kb, data=data)
|
|
|
- milvus_service.client.flush(kb)
|
|
|
-
|
|
|
- return ResponseSchema(code=0, message="更新成功 (ID已变更)")
|
|
|
- except Exception as e:
|
|
|
- print(f"Update Snippet Error: {e}")
|
|
|
- return ResponseSchema(code=500, message=str(e))
|
|
|
+ payload_token = verify_token(credentials.credentials)
|
|
|
+ if not payload_token:
|
|
|
+ return ResponseSchema(code=401, message="无效的访问令牌")
|
|
|
+
|
|
|
+ msg = snippet_service.update(id, payload)
|
|
|
+ return ResponseSchema(code=0, message=msg)
|
|
|
|
|
|
-@router.delete("/{id}", response_model=ResponseSchema)
|
|
|
+@router.post("/{id}/delete", response_model=ResponseSchema)
|
|
|
async def delete_snippet(
|
|
|
id: str,
|
|
|
kb: str = Query(..., description="知识库名称"),
|
|
|
credentials: HTTPAuthorizationCredentials = Depends(security)
|
|
|
):
|
|
|
"""删除知识片段"""
|
|
|
- try:
|
|
|
- if not milvus_service.has_collection(kb):
|
|
|
- return ResponseSchema(code=404, message="知识库不存在")
|
|
|
-
|
|
|
- desc = milvus_service.client.describe_collection(kb)
|
|
|
- fields = [f['name'] for f in desc.get('fields', [])]
|
|
|
- pk_field = "pk" if "pk" in fields else "id"
|
|
|
-
|
|
|
- if id.isdigit():
|
|
|
- expr = f"{pk_field} in [{id}]"
|
|
|
- else:
|
|
|
- expr = f"{pk_field} in ['{id}']"
|
|
|
-
|
|
|
- milvus_service.client.delete(
|
|
|
- collection_name=kb,
|
|
|
- filter=expr
|
|
|
- )
|
|
|
- milvus_service.client.flush(kb)
|
|
|
+ payload_token = verify_token(credentials.credentials)
|
|
|
+ if not payload_token:
|
|
|
+ return ResponseSchema(code=401, message="无效的访问令牌")
|
|
|
|
|
|
- return ResponseSchema(code=0, message="删除成功")
|
|
|
- except Exception as e:
|
|
|
- return ResponseSchema(code=500, message=str(e))
|
|
|
+ snippet_service.delete(id, kb)
|
|
|
+ return ResponseSchema(code=0, message="删除成功")
|