|
@@ -0,0 +1,401 @@
|
|
|
|
|
+"""
|
|
|
|
|
+知识库对外API业务逻辑
|
|
|
|
|
+"""
|
|
|
|
|
+import logging
|
|
|
|
|
+import json
|
|
|
|
|
+from typing import Tuple, List, Dict, Any, Optional
|
|
|
|
|
+from app.base.async_mysql_connection import get_db_connection
|
|
|
|
|
+
|
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class KnowledgeBaseApiService:
|
|
|
|
|
+ """知识库对外API服务"""
|
|
|
|
|
+
|
|
|
|
|
+ async def get_knowledge_base_list(self, page: int, page_size: int) -> Tuple[int, List[Dict[str, Any]]]:
|
|
|
|
|
+ """查询知识库列表(分页)"""
|
|
|
|
|
+ conn = get_db_connection()
|
|
|
|
|
+ if not conn:
|
|
|
|
|
+ return 0, []
|
|
|
|
|
+
|
|
|
|
|
+ cursor = conn.cursor()
|
|
|
|
|
+ try:
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "SELECT COUNT(*) as cnt FROM t_samp_knowledge_base "
|
|
|
|
|
+ "WHERE status = 'normal' AND is_deleted = 0"
|
|
|
|
|
+ )
|
|
|
|
|
+ total = cursor.fetchone()['cnt']
|
|
|
|
|
+
|
|
|
|
|
+ offset = (page - 1) * page_size
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "SELECT id, name, "
|
|
|
|
|
+ "collection_name_parent as parent_table, "
|
|
|
|
|
+ "collection_name_children as child_table, "
|
|
|
|
|
+ "document_count, status, "
|
|
|
|
|
+ "created_time as created_at, created_by "
|
|
|
|
|
+ "FROM t_samp_knowledge_base "
|
|
|
|
|
+ "WHERE status = 'normal' AND is_deleted = 0 "
|
|
|
|
|
+ "ORDER BY created_time DESC LIMIT %s OFFSET %s",
|
|
|
|
|
+ (page_size, offset)
|
|
|
|
|
+ )
|
|
|
|
|
+ kbs = cursor.fetchall()
|
|
|
|
|
+
|
|
|
|
|
+ items = []
|
|
|
|
|
+ for kb in kbs:
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "SELECT field_zh_name as field_name_cn, field_en_name as field_name_en, "
|
|
|
|
|
+ "field_type, remark as description "
|
|
|
|
|
+ "FROM t_samp_metadata WHERE knowledge_base_id = %s",
|
|
|
|
|
+ (kb['id'],)
|
|
|
|
|
+ )
|
|
|
|
|
+ metadata_schema = cursor.fetchall()
|
|
|
|
|
+ kb['metadata_schema'] = metadata_schema
|
|
|
|
|
+ items.append(kb)
|
|
|
|
|
+
|
|
|
|
|
+ return total, items
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"查询知识库列表失败: {e}")
|
|
|
|
|
+ return 0, []
|
|
|
|
|
+ finally:
|
|
|
|
|
+ cursor.close()
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ async def get_knowledge_base_detail(self, kb_id: str) -> Optional[Dict[str, Any]]:
|
|
|
|
|
+ """查询知识库详情"""
|
|
|
|
|
+ conn = get_db_connection()
|
|
|
|
|
+ if not conn:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ cursor = conn.cursor()
|
|
|
|
|
+ try:
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "SELECT id, name, description, "
|
|
|
|
|
+ "collection_name_parent as parent_table, "
|
|
|
|
|
+ "collection_name_children as child_table, "
|
|
|
|
|
+ "document_count, status, "
|
|
|
|
|
+ "created_time as created_at, created_by, "
|
|
|
|
|
+ "updated_time as updated_at "
|
|
|
|
|
+ "FROM t_samp_knowledge_base WHERE id = %s",
|
|
|
|
|
+ (kb_id,)
|
|
|
|
|
+ )
|
|
|
|
|
+ kb = cursor.fetchone()
|
|
|
|
|
+ if not kb:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "SELECT field_zh_name as field_name_cn, field_en_name as field_name_en, "
|
|
|
|
|
+ "field_type, remark as description "
|
|
|
|
|
+ "FROM t_samp_metadata WHERE knowledge_base_id = %s",
|
|
|
|
|
+ (kb_id,)
|
|
|
|
|
+ )
|
|
|
|
|
+ kb['metadata_schema'] = cursor.fetchall()
|
|
|
|
|
+
|
|
|
|
|
+ return kb
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"查询知识库详情失败: {e}")
|
|
|
|
|
+ return None
|
|
|
|
|
+ finally:
|
|
|
|
|
+ cursor.close()
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ async def create_batch_import_task(
|
|
|
|
|
+ self,
|
|
|
|
|
+ kb_id: str,
|
|
|
|
|
+ task_no: str,
|
|
|
|
|
+ callback_url: str,
|
|
|
|
|
+ parents: list,
|
|
|
|
|
+ children: list
|
|
|
|
|
+ ) -> Tuple[Optional[str], Optional[str]]:
|
|
|
|
|
+ """创建批量入库任务"""
|
|
|
|
|
+ conn = get_db_connection()
|
|
|
|
|
+ if not conn:
|
|
|
|
|
+ return None, None
|
|
|
|
|
+
|
|
|
|
|
+ cursor = conn.cursor()
|
|
|
|
|
+ try:
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "SELECT id, name, collection_name_parent, collection_name_children "
|
|
|
|
|
+ "FROM t_samp_knowledge_base WHERE id = %s AND status = 'normal'",
|
|
|
|
|
+ (kb_id,)
|
|
|
|
|
+ )
|
|
|
|
|
+ kb = cursor.fetchone()
|
|
|
|
|
+ if not kb:
|
|
|
|
|
+ return None, None
|
|
|
|
|
+
|
|
|
|
|
+ if not task_no:
|
|
|
|
|
+ logger.warning(f"批量入库任务缺少task_no, kb_id: {kb_id}")
|
|
|
|
|
+ return None, None
|
|
|
|
|
+
|
|
|
|
|
+ import uuid, time
|
|
|
|
|
+ task_id = f"task_{time.strftime('%Y%m%d')}{uuid.uuid4().hex[:12]}"
|
|
|
|
|
+
|
|
|
|
|
+ task_params = json.dumps({
|
|
|
|
|
+ "kb_id": kb_id,
|
|
|
|
|
+ "parents": parents,
|
|
|
|
|
+ "children": children
|
|
|
|
|
+ }, ensure_ascii=False)
|
|
|
|
|
+
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "INSERT INTO t_samp_task_management "
|
|
|
|
|
+ "(task_id, task_no, task_type, task_params, task_source, callback_url, status) "
|
|
|
|
|
+ "VALUES (%s, %s, %s, %s, %s, %s, %s)",
|
|
|
|
|
+ (task_id, task_no, "bi", task_params, "col", callback_url, "pending")
|
|
|
|
|
+ )
|
|
|
|
|
+ conn.commit()
|
|
|
|
|
+
|
|
|
|
|
+ import asyncio
|
|
|
|
|
+ asyncio.create_task(
|
|
|
|
|
+ self._process_batch_import(task_id, task_no, kb, parents, children, callback_url)
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ return task_id, "pending"
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"创建入库任务失败: {e}")
|
|
|
|
|
+ conn.rollback()
|
|
|
|
|
+ return None, None
|
|
|
|
|
+ finally:
|
|
|
|
|
+ cursor.close()
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ async def _process_batch_import(
|
|
|
|
|
+ self,
|
|
|
|
|
+ task_id: str,
|
|
|
|
|
+ task_no: str,
|
|
|
|
|
+ kb: dict,
|
|
|
|
|
+ parents: list,
|
|
|
|
|
+ children: list,
|
|
|
|
|
+ callback_url: str
|
|
|
|
|
+ ):
|
|
|
|
|
+ """异步处理批量入库(后台执行)"""
|
|
|
|
|
+ conn = get_db_connection()
|
|
|
|
|
+ if not conn:
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ cursor = conn.cursor()
|
|
|
|
|
+ total = len(parents) + len(children)
|
|
|
|
|
+ succeeded = 0
|
|
|
|
|
+ failed = 0
|
|
|
|
|
+ failures = []
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "UPDATE t_samp_task_management SET status = %s, updated_time = NOW() WHERE task_id = %s",
|
|
|
|
|
+ ("processing", task_id)
|
|
|
|
|
+ )
|
|
|
|
|
+ conn.commit()
|
|
|
|
|
+
|
|
|
|
|
+ for item in parents:
|
|
|
|
|
+ try:
|
|
|
|
|
+ self._insert_to_milvus(kb, item, is_parent=True)
|
|
|
|
|
+ succeeded += 1
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ failed += 1
|
|
|
|
|
+ failures.append({
|
|
|
|
|
+ "index": item.get("index", 0),
|
|
|
|
|
+ "parent_id": item.get("parent_id"),
|
|
|
|
|
+ "error": str(e)
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ for item in children:
|
|
|
|
|
+ try:
|
|
|
|
|
+ self._insert_to_milvus(kb, item, is_parent=False)
|
|
|
|
|
+ succeeded += 1
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ failed += 1
|
|
|
|
|
+ failures.append({
|
|
|
|
|
+ "index": item.get("index", 0),
|
|
|
|
|
+ "parent_id": item.get("parent_id"),
|
|
|
|
|
+ "error": str(e)
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "UPDATE t_samp_task_management "
|
|
|
|
|
+ "SET status = %s, error_message = %s, completed_time = NOW(), updated_time = NOW() "
|
|
|
|
|
+ "WHERE task_id = %s",
|
|
|
|
|
+ ("completed", json.dumps(failures, ensure_ascii=False) if failures else None, task_id)
|
|
|
|
|
+ )
|
|
|
|
|
+ conn.commit()
|
|
|
|
|
+
|
|
|
|
|
+ if callback_url:
|
|
|
|
|
+ await self._send_callback(callback_url, task_id, task_no, kb['id'], "completed", total, succeeded, failed, failures)
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "UPDATE t_samp_task_management "
|
|
|
|
|
+ "SET status = %s, error_message = %s, completed_time = NOW(), updated_time = NOW() "
|
|
|
|
|
+ "WHERE task_id = %s",
|
|
|
|
|
+ ("failed", str(e), task_id)
|
|
|
|
|
+ )
|
|
|
|
|
+ conn.commit()
|
|
|
|
|
+
|
|
|
|
|
+ if callback_url:
|
|
|
|
|
+ await self._send_callback(callback_url, task_id, task_no, kb['id'], "failed", total, 0, 0, [])
|
|
|
|
|
+ finally:
|
|
|
|
|
+ cursor.close()
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ def _insert_to_milvus(self, kb: dict, item: dict, is_parent: bool):
|
|
|
|
|
+ """将单条数据写入Milvus"""
|
|
|
|
|
+ from app.services.milvus_service import milvus_service
|
|
|
|
|
+ import time as _time
|
|
|
|
|
+
|
|
|
|
|
+ coll_name = kb['collection_name_parent'] if is_parent else kb['collection_name_children']
|
|
|
|
|
+ if not coll_name:
|
|
|
|
|
+ raise ValueError("集合名称为空")
|
|
|
|
|
+
|
|
|
|
|
+ # 如果集合不存在,自动创建
|
|
|
|
|
+ milvus_service.ensure_collection_exists(coll_name)
|
|
|
|
|
+
|
|
|
|
|
+ text = item.get("text", "")
|
|
|
|
|
+ if not text:
|
|
|
|
|
+ raise ValueError("文本内容为空")
|
|
|
|
|
+
|
|
|
|
|
+ from app.base.embedding_connection import get_embedding_model
|
|
|
|
|
+ model = get_embedding_model()
|
|
|
|
|
+ vector = model.embed_query(text)
|
|
|
|
|
+
|
|
|
|
|
+ now_ms = int(_time.time() * 1000)
|
|
|
|
|
+ record = {
|
|
|
|
|
+ "text": text,
|
|
|
|
|
+ "dense": vector,
|
|
|
|
|
+ "document_id": str(item.get("doc_id", item.get("parent_id", ""))),
|
|
|
|
|
+ "parent_id": str(item.get("parent_id", "")),
|
|
|
|
|
+ "index": item.get("index", 0),
|
|
|
|
|
+ "hierarchy": item.get("hierarchy", ""),
|
|
|
|
|
+ "metadata": json.dumps(item.get("metadata", {}), ensure_ascii=False),
|
|
|
|
|
+ "tag_list": json.dumps(item.get("tag_list", []), ensure_ascii=False),
|
|
|
|
|
+ "permission": json.dumps(item.get("permission", {}), ensure_ascii=False),
|
|
|
|
|
+ "is_deleted": False,
|
|
|
|
|
+ "created_by": "api_import",
|
|
|
|
|
+ "created_time": now_ms,
|
|
|
|
|
+ "updated_by": "api_import",
|
|
|
|
|
+ "updated_time": now_ms,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ milvus_service.client.insert(collection_name=coll_name, data=[record])
|
|
|
|
|
+
|
|
|
|
|
+ async def _send_callback(self, callback_url: str, task_id: str, task_no: str, kb_id: str, status: str, total: int, succeeded: int, failed: int, failures: list):
|
|
|
|
|
+ """发送回调通知"""
|
|
|
|
|
+ import httpx
|
|
|
|
|
+ payload = {
|
|
|
|
|
+ "task_id": task_id,
|
|
|
|
|
+ "task_no": task_no,
|
|
|
|
|
+ "kb_id": kb_id,
|
|
|
|
|
+ "status": status,
|
|
|
|
|
+ "progress": {
|
|
|
|
|
+ "total": total,
|
|
|
|
|
+ "processed": total,
|
|
|
|
|
+ "succeeded": succeeded,
|
|
|
|
|
+ "failed": failed
|
|
|
|
|
+ },
|
|
|
|
|
+ "failures": failures
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ max_retries = 3
|
|
|
|
|
+ delays = [10, 30, 60]
|
|
|
|
|
+
|
|
|
|
|
+ for i in range(max_retries):
|
|
|
|
|
+ try:
|
|
|
|
|
+ async with httpx.AsyncClient(timeout=10) as client:
|
|
|
|
|
+ resp = await client.post(callback_url, json=payload)
|
|
|
|
|
+ if resp.status_code == 200:
|
|
|
|
|
+ return
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.warning(f"回调失败(第{i+1}次): {e}")
|
|
|
|
|
+ if i < max_retries - 1:
|
|
|
|
|
+ import asyncio
|
|
|
|
|
+ await asyncio.sleep(delays[i])
|
|
|
|
|
+
|
|
|
|
|
+ logger.error(f"回调最终失败: {callback_url}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ async def get_batch_import_task(self, task_id: str) -> Optional[Dict[str, Any]]:
|
|
|
|
|
+ """查询批量入库任务状态"""
|
|
|
|
|
+ conn = get_db_connection()
|
|
|
|
|
+ if not conn:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ cursor = conn.cursor()
|
|
|
|
|
+ try:
|
|
|
|
|
+ task_id = task_id.strip()
|
|
|
|
|
+
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "SELECT task_id, task_no, status, task_params, error_message, "
|
|
|
|
|
+ "completed_time, created_time, updated_time "
|
|
|
|
|
+ "FROM t_samp_task_management WHERE task_id = %s",
|
|
|
|
|
+ (task_id,)
|
|
|
|
|
+ )
|
|
|
|
|
+ task = cursor.fetchone()
|
|
|
|
|
+ if not task:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ status = task['status']
|
|
|
|
|
+
|
|
|
|
|
+ params = json.loads(task['task_params']) if task['task_params'] else {}
|
|
|
|
|
+ total = len(params.get('parents', [])) + len(params.get('children', []))
|
|
|
|
|
+
|
|
|
|
|
+ failures = []
|
|
|
|
|
+ if task['error_message']:
|
|
|
|
|
+ try:
|
|
|
|
|
+ failures = json.loads(task['error_message'])
|
|
|
|
|
+ if not isinstance(failures, list):
|
|
|
|
|
+ failures = []
|
|
|
|
|
+ except:
|
|
|
|
|
+ failures = []
|
|
|
|
|
+
|
|
|
|
|
+ result = {
|
|
|
|
|
+ "task_id": task['task_id'],
|
|
|
|
|
+ "task_no": task['task_no'],
|
|
|
|
|
+ "status": status
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if status in ('pending', 'processing'):
|
|
|
|
|
+ succeeded = max(0, total - len(failures))
|
|
|
|
|
+ result["progress"] = {
|
|
|
|
|
+ "total": total,
|
|
|
|
|
+ "processed": succeeded + len(failures),
|
|
|
|
|
+ "succeeded": succeeded,
|
|
|
|
|
+ "failed": len(failures)
|
|
|
|
|
+ }
|
|
|
|
|
+ result["created_at"] = task['created_time']
|
|
|
|
|
+ result["updated_at"] = task['updated_time']
|
|
|
|
|
+
|
|
|
|
|
+ elif status == 'completed':
|
|
|
|
|
+ succeeded = max(0, total - len(failures))
|
|
|
|
|
+ result["progress"] = {
|
|
|
|
|
+ "total": total,
|
|
|
|
|
+ "processed": total,
|
|
|
|
|
+ "succeeded": succeeded,
|
|
|
|
|
+ "failed": len(failures)
|
|
|
|
|
+ }
|
|
|
|
|
+ result["created_at"] = task['created_time']
|
|
|
|
|
+ result["completed_at"] = task['completed_time']
|
|
|
|
|
+ result["failures"] = failures
|
|
|
|
|
+
|
|
|
|
|
+ elif status == 'failed':
|
|
|
|
|
+ result["error"] = task['error_message'] if task['error_message'] else "任务执行失败"
|
|
|
|
|
+ result["created_at"] = task['created_time']
|
|
|
|
|
+ result["completed_at"] = task['completed_time']
|
|
|
|
|
+
|
|
|
|
|
+ return result
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"查询任务状态失败: {e}")
|
|
|
|
|
+ return None
|
|
|
|
|
+ finally:
|
|
|
|
|
+ cursor.close()
|
|
|
|
|
+ conn.close()
|