| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- """样本中心 API 客户端服务。"""
- import httpx
- import time
- from typing import Any
- from app.config import get_settings
- from app.core.logging import logger
- settings = get_settings()
- # In-memory token cache; get_token() stores access_token, expires_in, expires_at and token_type here.
- _token_cache: dict[str, Any] = {}
def _get_base_url() -> str:
    """Return the configured Sample Center base URL without trailing slashes.

    Raises:
        ValueError: If ``settings.sample_center_base_url`` is empty or unset.
    """
    configured = settings.sample_center_base_url
    if configured:
        return configured.rstrip("/")
    raise ValueError("样本中心地址未配置,请检查 SAMPLE_CENTER_BASE_URL 环境变量")
def _get_credentials() -> tuple[str, str]:
    """Return the ``(app_id, app_secret)`` pair read from settings.

    Raises:
        ValueError: If either credential is empty or unset.
    """
    if settings.sample_center_app_id and settings.sample_center_app_secret:
        return settings.sample_center_app_id, settings.sample_center_app_secret
    raise ValueError("样本中心凭证未配置,请检查 SAMPLE_CENTER_APP_ID 和 SAMPLE_CENTER_APP_SECRET")
def _check_token_valid() -> bool:
    """Report whether the cached access token can still be used.

    The token counts as valid only when one is cached and the current time is
    more than 300 seconds before the cached ``expires_at`` timestamp.
    """
    if _token_cache.get("access_token"):
        # Treat the token as expired 5 minutes ahead of its recorded expiry.
        refresh_deadline = _token_cache.get("expires_at", 0) - 300
        return time.time() < refresh_deadline
    return False
async def get_token() -> str:
    """Return a Sample Center access token, requesting a new one when needed.

    The cached token is returned while ``_check_token_valid()`` holds.
    Otherwise the app credentials are POSTed to ``/api/v1/auth/token`` and the
    returned token data is written into ``_token_cache``.

    Returns:
        The access token string.

    Raises:
        ValueError: If the credentials or the base URL are not configured.
        httpx.HTTPStatusError: If the endpoint answers with an error status.
        RuntimeError: If the response ``code`` is not ``"000000"``.
    """
    if _check_token_valid():
        return _token_cache["access_token"]

    app_id, app_secret = _get_credentials()
    token_url = f"{_get_base_url()}/api/v1/auth/token"
    credentials = {"app_id": app_id, "app_secret": app_secret}

    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(token_url, json=credentials)
        resp.raise_for_status()
        envelope = resp.json()

    if envelope.get("code") != "000000":
        raise RuntimeError(f"获取样本中心 Token 失败: {envelope.get('message')}")

    token_info = envelope["data"]
    # Lifetime in seconds; 7200 is used when the response omits it.
    lifetime = token_info.get("expires_in", 7200)
    _token_cache["access_token"] = token_info["access_token"]
    _token_cache["expires_in"] = lifetime
    _token_cache["expires_at"] = time.time() + lifetime
    _token_cache["token_type"] = token_info.get("token_type", "Bearer")
    return token_info["access_token"]
def _auth_headers() -> dict[str, str]:
    """Build the headers sent with authenticated Sample Center requests.

    The bearer token is read straight from ``_token_cache`` and falls back to
    an empty string when nothing is cached; this function never requests a
    token itself.

    Raises:
        ValueError: If the credentials are not configured.
    """
    app_id = _get_credentials()[0]
    cached_token = _token_cache.get("access_token", "")
    headers = {"Authorization": f"Bearer {cached_token}"}
    headers["X-App-Id"] = app_id
    headers["Content-Type"] = "application/json"
    return headers
async def list_knowledge_bases(page: int = 1, page_size: int = 20) -> dict[str, Any]:
    """Fetch one page of the knowledge-base list.

    Args:
        page: Value sent as the ``page`` query parameter.
        page_size: Value sent as the ``page_size`` query parameter.

    Returns:
        The ``data`` member of the response body.

    Raises:
        ValueError: If the base URL or credentials are not configured.
        httpx.HTTPStatusError: If the endpoint answers with an error status.
        RuntimeError: If the response ``code`` is not ``"000000"``.
    """
    # Awaited only for its side effect: it refreshes the cached token that
    # _auth_headers() reads; the returned value is not needed here.
    await get_token()
    base_url = _get_base_url()

    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.get(
            f"{base_url}/api/v1/knowledge-bases",
            params={"page": page, "page_size": page_size},
            headers=_auth_headers(),
        )
        resp.raise_for_status()
        body = resp.json()

    if body.get("code") != "000000":
        raise RuntimeError(f"获取知识库列表失败: {body.get('message')}")
    return body["data"]
async def get_knowledge_base_detail(kb_id: str) -> dict[str, Any]:
    """Fetch the detail record of a single knowledge base.

    Args:
        kb_id: Knowledge-base identifier placed in the request path.

    Returns:
        The ``data`` member of the response body.

    Raises:
        ValueError: If the base URL or credentials are not configured.
        httpx.HTTPStatusError: If the endpoint answers with an error status.
        RuntimeError: If the response ``code`` is not ``"000000"``.
    """
    # Awaited only for its side effect: it refreshes the cached token that
    # _auth_headers() reads; the returned value is not needed here.
    await get_token()
    base_url = _get_base_url()

    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.get(
            f"{base_url}/api/v1/knowledge-bases/{kb_id}",
            headers=_auth_headers(),
        )
        resp.raise_for_status()
        body = resp.json()

    if body.get("code") != "000000":
        raise RuntimeError(f"获取知识库详情失败: {body.get('message')}")
    return body["data"]
async def batch_import(kb_id: str, parents: list[dict], children: list[dict] | None = None,
                       callback_url: str | None = None) -> dict[str, Any]:
    """Submit a batch-import task for a knowledge base.

    Args:
        kb_id: Knowledge-base identifier placed in the request path.
        parents: Records sent under the ``parents`` key.
        children: Optional records sent under ``children``; the key is omitted
            when this is ``None`` or empty.
        callback_url: Optional URL sent as ``callback_url``; the key is omitted
            when this is ``None`` or empty.

    Returns:
        The ``data`` member of the response body.

    Raises:
        ValueError: If the base URL or credentials are not configured.
        httpx.HTTPStatusError: If the endpoint answers with an error status.
        RuntimeError: If the response ``code`` is not ``"000000"``.
    """
    import uuid

    # Awaited only for its side effect: it refreshes the cached token that
    # _auth_headers() reads; the returned value is not needed here.
    await get_token()
    base_url = _get_base_url()

    # Task number: "IMP" + current Unix time in seconds + 8 random hex chars.
    task_no = f"IMP{int(time.time())}{uuid.uuid4().hex[:8]}"
    payload: dict[str, Any] = {
        "task_no": task_no,
        "parents": parents,
    }
    if children:
        payload["children"] = children
    if callback_url:
        payload["callback_url"] = callback_url

    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post(
            f"{base_url}/api/v1/knowledge-bases/{kb_id}/batch-import",
            json=payload,
            headers=_auth_headers(),
        )
        resp.raise_for_status()
        body = resp.json()

    if body.get("code") != "000000":
        raise RuntimeError(f"批量入库提交失败: {body.get('message')}")
    return body["data"]
async def query_import_task(task_id: str) -> dict[str, Any]:
    """Query the status of a batch-import task.

    Args:
        task_id: Task identifier placed in the request path.

    Returns:
        The ``data`` member of the response body.

    Raises:
        ValueError: If the base URL or credentials are not configured.
        httpx.HTTPStatusError: If the endpoint answers with an error status.
        RuntimeError: If the response ``code`` is not ``"000000"``.
    """
    # Awaited only for its side effect: it refreshes the cached token that
    # _auth_headers() reads; the returned value is not needed here.
    await get_token()
    base_url = _get_base_url()

    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.get(
            f"{base_url}/api/v1/knowledge-bases/batch-import/{task_id}",
            headers=_auth_headers(),
        )
        resp.raise_for_status()
        body = resp.json()

    if body.get("code") != "000000":
        raise RuntimeError(f"查询任务失败: {body.get('message')}")
    return body["data"]
async def import_kb_to_dataset(kb_id: str, kb_name: str) -> dict[str, Any]:
    """Return a summary of a knowledge base for a dataset-import flow.

    This function only fetches the knowledge-base detail and repackages a few
    of its fields; it does not convert or store any data itself.

    Args:
        kb_id: Identifier of the knowledge base to look up.
        kb_name: Display name to report; when empty, the ``name`` field of the
            fetched detail is used instead.

    Returns:
        A dict with ``kb_id``, ``kb_name``, ``document_count``,
        ``metadata_schema``, ``parent_table`` and ``child_table``.
    """
    detail = await get_knowledge_base_detail(kb_id)

    # Only knowledge-base information is returned here so the frontend can show
    # it to the user; the actual data import is done via the batch-import API.
    summary: dict[str, Any] = {
        "kb_id": kb_id,
        "kb_name": kb_name if kb_name else detail.get("name", ""),
    }
    passthrough_defaults = (
        ("document_count", 0),
        ("metadata_schema", []),
        ("parent_table", ""),
        ("child_table", ""),
    )
    for field_name, fallback in passthrough_defaults:
        summary[field_name] = detail.get(field_name, fallback)
    return summary
|