"""样本中心 API 客户端服务。""" import httpx import time from typing import Any from app.config import get_settings from app.core.logging import logger settings = get_settings() # Token 缓存(内存中) _token_cache: dict[str, Any] = {} def _get_base_url() -> str: if not settings.sample_center_base_url: raise ValueError("样本中心地址未配置,请检查 SAMPLE_CENTER_BASE_URL 环境变量") return settings.sample_center_base_url.rstrip("/") def _get_credentials() -> tuple[str, str]: if not settings.sample_center_app_id or not settings.sample_center_app_secret: raise ValueError("样本中心凭证未配置,请检查 SAMPLE_CENTER_APP_ID 和 SAMPLE_CENTER_APP_SECRET") return settings.sample_center_app_id, settings.sample_center_app_secret def _check_token_valid() -> bool: if not _token_cache.get("access_token"): return False expires_at = _token_cache.get("expires_at", 0) return time.time() < expires_at - 300 # 提前 5 分钟过期 async def get_token() -> str: if _check_token_valid(): return _token_cache["access_token"] app_id, app_secret = _get_credentials() base_url = _get_base_url() async with httpx.AsyncClient(timeout=30) as client: resp = await client.post( f"{base_url}/api/v1/auth/token", json={"app_id": app_id, "app_secret": app_secret}, ) resp.raise_for_status() body = resp.json() if body.get("code") != "000000": raise RuntimeError(f"获取样本中心 Token 失败: {body.get('message')}") data = body["data"] _token_cache["access_token"] = data["access_token"] _token_cache["expires_in"] = data.get("expires_in", 7200) _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200) _token_cache["token_type"] = data.get("token_type", "Bearer") return data["access_token"] def _auth_headers() -> dict[str, str]: app_id, _ = _get_credentials() token = _token_cache.get("access_token", "") return { "Authorization": f"Bearer {token}", "X-App-Id": app_id, "Content-Type": "application/json", } async def list_knowledge_bases(page: int = 1, page_size: int = 20) -> dict[str, Any]: """获取知识库列表。""" token = await get_token() base_url = _get_base_url() async with httpx.AsyncClient(timeout=30) as client: resp = await client.get( f"{base_url}/api/v1/knowledge-bases", params={"page": page, "page_size": page_size}, headers=_auth_headers(), ) resp.raise_for_status() body = resp.json() if body.get("code") != "000000": raise RuntimeError(f"获取知识库列表失败: {body.get('message')}") return body["data"] async def get_knowledge_base_detail(kb_id: str) -> dict[str, Any]: """获取知识库详情。""" token = await get_token() base_url = _get_base_url() async with httpx.AsyncClient(timeout=30) as client: resp = await client.get( f"{base_url}/api/v1/knowledge-bases/{kb_id}", headers=_auth_headers(), ) resp.raise_for_status() body = resp.json() if body.get("code") != "000000": raise RuntimeError(f"获取知识库详情失败: {body.get('message')}") return body["data"] async def batch_import(kb_id: str, parents: list[dict], children: list[dict] | None = None, callback_url: str | None = None) -> dict[str, Any]: """提交批量入库任务。""" import uuid token = await get_token() base_url = _get_base_url() task_no = f"IMP{int(time.time())}{uuid.uuid4().hex[:8]}" payload: dict[str, Any] = { "task_no": task_no, "parents": parents, } if children: payload["children"] = children if callback_url: payload["callback_url"] = callback_url async with httpx.AsyncClient(timeout=60) as client: resp = await client.post( f"{base_url}/api/v1/knowledge-bases/{kb_id}/batch-import", json=payload, headers=_auth_headers(), ) resp.raise_for_status() body = resp.json() if body.get("code") != "000000": raise RuntimeError(f"批量入库提交失败: {body.get('message')}") return body["data"] async def query_import_task(task_id: str) -> dict[str, Any]: """查询批量入库任务状态。""" token = await get_token() base_url = _get_base_url() async with httpx.AsyncClient(timeout=30) as client: resp = await client.get( f"{base_url}/api/v1/knowledge-bases/batch-import/{task_id}", headers=_auth_headers(), ) resp.raise_for_status() body = resp.json() if body.get("code") != "000000": raise RuntimeError(f"查询任务失败: {body.get('message')}") return body["data"] async def import_kb_to_dataset(kb_id: str, kb_name: str) -> dict[str, Any]: """从知识库导入数据:查询知识库详情,将数据转为训练格式并保存为数据集。 由于样本中心的批量入库是异步任务,这里采用直接查询知识库内容的方式。 先获取知识库详情,然后根据 metadata_schema 构建训练数据集。 """ kb_detail = await get_knowledge_base_detail(kb_id) # 这里返回知识库信息,前端可据此展示给用户 # 实际的数据导入由批量入库 API 完成 return { "kb_id": kb_id, "kb_name": kb_name or kb_detail.get("name", ""), "document_count": kb_detail.get("document_count", 0), "metadata_schema": kb_detail.get("metadata_schema", []), "parent_table": kb_detail.get("parent_table", ""), "child_table": kb_detail.get("child_table", ""), }