# sample_center_service.py
  1. """样本中心 API 客户端服务。"""
  2. import httpx
  3. import time
  4. from typing import Any
  5. from app.config import get_settings
  6. from app.core.logging import logger
settings = get_settings()

# Token cache (in memory). get_token() in this module writes the keys
# "access_token", "expires_in", "expires_at" and "token_type" into it.
_token_cache: dict[str, Any] = {}
  10. def _get_base_url() -> str:
  11. if not settings.sample_center_base_url:
  12. raise ValueError("样本中心地址未配置,请检查 SAMPLE_CENTER_BASE_URL 环境变量")
  13. return settings.sample_center_base_url.rstrip("/")
  14. def _get_credentials() -> tuple[str, str]:
  15. if not settings.sample_center_app_id or not settings.sample_center_app_secret:
  16. raise ValueError("样本中心凭证未配置,请检查 SAMPLE_CENTER_APP_ID 和 SAMPLE_CENTER_APP_SECRET")
  17. return settings.sample_center_app_id, settings.sample_center_app_secret
  18. def _check_token_valid() -> bool:
  19. if not _token_cache.get("access_token"):
  20. return False
  21. expires_at = _token_cache.get("expires_at", 0)
  22. return time.time() < expires_at - 300 # 提前 5 分钟过期
  23. async def get_token() -> str:
  24. if _check_token_valid():
  25. return _token_cache["access_token"]
  26. app_id, app_secret = _get_credentials()
  27. base_url = _get_base_url()
  28. async with httpx.AsyncClient(timeout=30) as client:
  29. resp = await client.post(
  30. f"{base_url}/api/v1/auth/token",
  31. json={"app_id": app_id, "app_secret": app_secret},
  32. )
  33. resp.raise_for_status()
  34. body = resp.json()
  35. if body.get("code") != "000000":
  36. raise RuntimeError(f"获取样本中心 Token 失败: {body.get('message')}")
  37. data = body["data"]
  38. _token_cache["access_token"] = data["access_token"]
  39. _token_cache["expires_in"] = data.get("expires_in", 7200)
  40. _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
  41. _token_cache["token_type"] = data.get("token_type", "Bearer")
  42. return data["access_token"]
  43. def _auth_headers() -> dict[str, str]:
  44. app_id, _ = _get_credentials()
  45. token = _token_cache.get("access_token", "")
  46. return {
  47. "Authorization": f"Bearer {token}",
  48. "X-App-Id": app_id,
  49. "Content-Type": "application/json",
  50. }
  51. async def list_knowledge_bases(page: int = 1, page_size: int = 20) -> dict[str, Any]:
  52. """获取知识库列表。"""
  53. token = await get_token()
  54. base_url = _get_base_url()
  55. async with httpx.AsyncClient(timeout=30) as client:
  56. resp = await client.get(
  57. f"{base_url}/api/v1/knowledge-bases",
  58. params={"page": page, "page_size": page_size},
  59. headers=_auth_headers(),
  60. )
  61. resp.raise_for_status()
  62. body = resp.json()
  63. if body.get("code") != "000000":
  64. raise RuntimeError(f"获取知识库列表失败: {body.get('message')}")
  65. return body["data"]
  66. async def get_knowledge_base_detail(kb_id: str) -> dict[str, Any]:
  67. """获取知识库详情。"""
  68. token = await get_token()
  69. base_url = _get_base_url()
  70. async with httpx.AsyncClient(timeout=30) as client:
  71. resp = await client.get(
  72. f"{base_url}/api/v1/knowledge-bases/{kb_id}",
  73. headers=_auth_headers(),
  74. )
  75. resp.raise_for_status()
  76. body = resp.json()
  77. if body.get("code") != "000000":
  78. raise RuntimeError(f"获取知识库详情失败: {body.get('message')}")
  79. return body["data"]
  80. async def batch_import(kb_id: str, parents: list[dict], children: list[dict] | None = None,
  81. callback_url: str | None = None) -> dict[str, Any]:
  82. """提交批量入库任务。"""
  83. import uuid
  84. token = await get_token()
  85. base_url = _get_base_url()
  86. task_no = f"IMP{int(time.time())}{uuid.uuid4().hex[:8]}"
  87. payload: dict[str, Any] = {
  88. "task_no": task_no,
  89. "parents": parents,
  90. }
  91. if children:
  92. payload["children"] = children
  93. if callback_url:
  94. payload["callback_url"] = callback_url
  95. async with httpx.AsyncClient(timeout=60) as client:
  96. resp = await client.post(
  97. f"{base_url}/api/v1/knowledge-bases/{kb_id}/batch-import",
  98. json=payload,
  99. headers=_auth_headers(),
  100. )
  101. resp.raise_for_status()
  102. body = resp.json()
  103. if body.get("code") != "000000":
  104. raise RuntimeError(f"批量入库提交失败: {body.get('message')}")
  105. return body["data"]
  106. async def query_import_task(task_id: str) -> dict[str, Any]:
  107. """查询批量入库任务状态。"""
  108. token = await get_token()
  109. base_url = _get_base_url()
  110. async with httpx.AsyncClient(timeout=30) as client:
  111. resp = await client.get(
  112. f"{base_url}/api/v1/knowledge-bases/batch-import/{task_id}",
  113. headers=_auth_headers(),
  114. )
  115. resp.raise_for_status()
  116. body = resp.json()
  117. if body.get("code") != "000000":
  118. raise RuntimeError(f"查询任务失败: {body.get('message')}")
  119. return body["data"]
  120. async def import_kb_to_dataset(kb_id: str, kb_name: str) -> dict[str, Any]:
  121. """从知识库导入数据:查询知识库详情,将数据转为训练格式并保存为数据集。
  122. 由于样本中心的批量入库是异步任务,这里采用直接查询知识库内容的方式。
  123. 先获取知识库详情,然后根据 metadata_schema 构建训练数据集。
  124. """
  125. kb_detail = await get_knowledge_base_detail(kb_id)
  126. # 这里返回知识库信息,前端可据此展示给用户
  127. # 实际的数据导入由批量入库 API 完成
  128. return {
  129. "kb_id": kb_id,
  130. "kb_name": kb_name or kb_detail.get("name", ""),
  131. "document_count": kb_detail.get("document_count", 0),
  132. "metadata_schema": kb_detail.get("metadata_schema", []),
  133. "parent_table": kb_detail.get("parent_table", ""),
  134. "child_table": kb_detail.get("child_table", ""),
  135. }