annotation_platform_service.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. """标注平台 API 客户端服务。
  2. 对接标注平台的对外 API,支持 HMAC-SHA256 签名认证。
  3. 功能:列出项目、获取项目详情、导出并下载数据集。
  4. """
  5. import hashlib
  6. import hmac
  7. import secrets
  8. import time
  9. import uuid
  10. from datetime import datetime
  11. from pathlib import Path
  12. from typing import Any
  13. import httpx
  14. from app.config import get_settings
  15. from app.core.db import async_session, DatasetRecord
  16. from app.core.logging import logger
  17. settings = get_settings()
  18. # Token 缓存(内存中)
  19. _token_cache: dict[str, Any] = {}
  20. def _get_base_url() -> str:
  21. if not settings.annotation_platform_base_url:
  22. raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL 环境变量")
  23. return settings.annotation_platform_base_url.rstrip("/")
  24. def _get_credentials() -> tuple[str, str]:
  25. if not settings.annotation_platform_app_id or not settings.annotation_platform_app_secret:
  26. raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
  27. return settings.annotation_platform_app_id, settings.annotation_platform_app_secret
  28. def _sign(app_secret: str, app_id: str, timestamp: str, nonce: str) -> str:
  29. """HMAC-SHA256 签名。"""
  30. message = app_id + timestamp + nonce
  31. return hmac.new(app_secret.encode(), message.encode(), hashlib.sha256).hexdigest()
  32. def _check_token_valid() -> bool:
  33. if not _token_cache.get("access_token"):
  34. return False
  35. expires_at = _token_cache.get("expires_at", 0)
  36. return time.time() < expires_at - 300 # 提前 5 分钟刷新
  37. async def get_token() -> str:
  38. """获取 Access Token,带缓存。"""
  39. if _check_token_valid():
  40. return _token_cache["access_token"]
  41. app_id, app_secret = _get_credentials()
  42. base_url = _get_base_url()
  43. timestamp = str(int(time.time()))
  44. nonce = secrets.token_hex(8) # 16 位十六进制随机字符串
  45. signature = _sign(app_secret, app_id, timestamp, nonce)
  46. async with httpx.AsyncClient(timeout=30) as client:
  47. resp = await client.post(
  48. f"{base_url}/api/v1/open/auth/token",
  49. headers={
  50. "X-Api-Key": app_id,
  51. "X-Signature": signature,
  52. "X-Timestamp": timestamp,
  53. "X-Nonce": nonce,
  54. },
  55. )
  56. resp.raise_for_status()
  57. body = resp.json()
  58. if body.get("code") != 0:
  59. raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message')}")
  60. data = body["data"]
  61. _token_cache["access_token"] = data["access_token"]
  62. _token_cache["expires_in"] = data.get("expires_in", 7200)
  63. _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
  64. return data["access_token"]
  65. def _auth_headers() -> dict[str, str]:
  66. token = _token_cache.get("access_token", "")
  67. return {
  68. "Authorization": f"Bearer {token}",
  69. "Content-Type": "application/json",
  70. }
  71. async def _request(method: str, path: str, **kwargs) -> dict[str, Any]:
  72. """统一的请求方法,自动携带 Token 并处理错误。"""
  73. await get_token()
  74. base_url = _get_base_url()
  75. async with httpx.AsyncClient(timeout=60) as client:
  76. resp = await client.request(
  77. method,
  78. f"{base_url}{path}",
  79. headers=_auth_headers(),
  80. **kwargs,
  81. )
  82. resp.raise_for_status()
  83. body = resp.json()
  84. if body.get("code") != 0:
  85. raise RuntimeError(f"标注平台请求失败: {body.get('message')}")
  86. return body.get("data", {})
  87. # ---------- 项目列表 ----------
  88. async def list_projects(
  89. page: int = 1,
  90. page_size: int = 20,
  91. name: str | None = None,
  92. project_type: str | None = None,
  93. status: str | None = None,
  94. ) -> dict[str, Any]:
  95. """获取标注平台项目列表。"""
  96. params: dict[str, Any] = {"page": page, "page_size": page_size}
  97. if name:
  98. params["name"] = name
  99. if project_type:
  100. params["type"] = project_type
  101. if status:
  102. params["status"] = status
  103. return await _request("GET", "/api/v1/open/projects", params=params)
  104. # ---------- 项目详情 ----------
  105. async def get_project_detail(project_id: str) -> dict[str, Any]:
  106. """获取项目详情。"""
  107. return await _request("GET", f"/api/v1/open/projects/{project_id}")
  108. # ---------- 数据集导出与下载 ----------
  109. async def import_project_dataset(
  110. project_id: str,
  111. project_name: str = "",
  112. format: str = "alpaca",
  113. ) -> dict[str, Any]:
  114. """导出并下载项目数据集,保存到本地并写入数据库。
  115. 流程:
  116. 1. POST 请求导出 → 获取 file_url
  117. 2. GET 下载文件 → 保存到 uploads 目录
  118. 3. 写入 DatasetRecord 数据库
  119. """
  120. # 1. 请求导出
  121. export_data = await _request(
  122. "POST",
  123. f"/api/v1/open/projects/{project_id}/datasets/download",
  124. json={"format": format, "completed_only": True},
  125. )
  126. file_url = export_data.get("file_url", "")
  127. file_name = export_data.get("file_name", f"{project_id}_{format}.json")
  128. total_exported = export_data.get("total_exported", 0)
  129. if not file_url:
  130. raise RuntimeError("标注平台未返回下载链接")
  131. # 2. 下载文件
  132. await get_token()
  133. base_url = _get_base_url()
  134. download_url = f"{base_url}{file_url}" if file_url.startswith("/") else file_url
  135. async with httpx.AsyncClient(timeout=120) as client:
  136. resp = await client.get(
  137. download_url,
  138. headers={"Authorization": f"Bearer {_token_cache.get('access_token', '')}"},
  139. follow_redirects=True,
  140. )
  141. resp.raise_for_status()
  142. file_content = resp.content
  143. # 3. 保存到 uploads 目录
  144. upload_dir = settings.uploads_dir
  145. upload_dir.mkdir(parents=True, exist_ok=True)
  146. safe_name = f"{project_name or project_id}_{file_name}" if project_name else file_name
  147. # 清理文件名中的非法字符
  148. safe_name = "".join(c if c.isalnum() or c in "._-" else "_" for c in safe_name)
  149. file_path = upload_dir / safe_name
  150. if file_path.exists():
  151. file_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}"
  152. file_path.write_bytes(file_content)
  153. # 4. 检测格式和记录数
  154. fmt = _detect_format(file_path.name)
  155. record_count = _count_records(file_path, fmt)
  156. # 5. 写入数据库
  157. record_id = str(uuid.uuid4())
  158. record = DatasetRecord(
  159. id=record_id,
  160. name=file_path.name,
  161. format=fmt,
  162. record_count=record_count,
  163. file_path=str(file_path),
  164. created_at=datetime.utcnow(),
  165. )
  166. async with async_session() as session:
  167. session.add(record)
  168. await session.commit()
  169. logger.info(f"Imported dataset from annotation platform: {project_id} -> {file_path.name} ({record_count} records)")
  170. return {
  171. "project_id": project_id,
  172. "project_name": project_name or project_id,
  173. "format": format,
  174. "total_exported": total_exported,
  175. "dataset_id": record_id,
  176. "dataset_name": file_path.name,
  177. }
  178. def _detect_format(filename: str) -> str:
  179. """根据文件名推断格式。"""
  180. name = filename.lower()
  181. if name.endswith(".jsonl"):
  182. return "jsonl"
  183. if name.endswith(".csv"):
  184. return "csv"
  185. if name.endswith(".parquet"):
  186. return "parquet"
  187. if name.endswith(".json"):
  188. return "json"
  189. return "json"
  190. def _count_records(file_path: Path, fmt: str) -> int:
  191. """计算文件中的记录数。"""
  192. import json
  193. if not file_path.exists():
  194. return 0
  195. try:
  196. if fmt == "jsonl":
  197. with open(file_path, "r", encoding="utf-8") as f:
  198. return sum(1 for line in f if line.strip())
  199. elif fmt == "json":
  200. with open(file_path, "r", encoding="utf-8") as f:
  201. data = json.load(f)
  202. if isinstance(data, list):
  203. return len(data)
  204. return 1
  205. elif fmt == "csv":
  206. import csv
  207. with open(file_path, "r", encoding="utf-8") as f:
  208. return sum(1 for _ in csv.DictReader(f))
  209. except Exception:
  210. return 0
  211. return 0