annotation_platform_service.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. """标注平台 API 客户端服务。
  2. 对接标注平台的对外 API(HMAC-SHA256 签名认证)。
  3. 参考文档:标注平台对外API接口文档.md
  4. 功能:列出项目、获取项目详情、数据集导出与下载。
  5. """
  6. import hashlib
  7. import hmac
  8. import secrets
  9. import time
  10. import uuid
  11. from datetime import datetime
  12. from pathlib import Path
  13. from typing import Any
  14. import httpx
  15. from app.config import get_settings
  16. from app.core.db import async_session, DatasetRecord
  17. from app.core.logging import logger
# Application settings, resolved once when this module is imported.
settings = get_settings()
# In-memory token cache for this process. get_token() stores the keys
# "access_token", "expires_in" and "expires_at" here.
_token_cache: dict[str, Any] = {}
  21. def _get_base_url() -> str:
  22. base_url = settings.annotation_platform_base_url
  23. if not base_url:
  24. raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL")
  25. return base_url.rstrip("/")
  26. def _get_credentials() -> tuple[str, str]:
  27. app_id = settings.annotation_platform_app_id
  28. app_secret = settings.annotation_platform_app_secret
  29. if not app_id or not app_secret:
  30. raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
  31. return app_id, app_secret
  32. def _build_token_headers() -> dict[str, str]:
  33. """构建获取 Token 的 HMAC-SHA256 签名请求头。"""
  34. app_id, app_secret = _get_credentials()
  35. timestamp = str(int(time.time()))
  36. nonce = uuid.uuid4().hex # 使用 uuid4 保证每次唯一,16+ 位随机字符串
  37. message = app_id + timestamp + nonce
  38. signature = hmac.new(
  39. key=app_secret.encode("utf-8"),
  40. msg=message.encode("utf-8"),
  41. digestmod=hashlib.sha256,
  42. ).hexdigest()
  43. return {
  44. "Content-Type": "application/json",
  45. "X-Api-Key": app_id,
  46. "X-Timestamp": timestamp,
  47. "X-Nonce": nonce,
  48. "X-Signature": signature,
  49. }
  50. def _is_token_valid() -> bool:
  51. """检查缓存的 Token 是否仍然有效(提前 5 分钟刷新)。"""
  52. if not _token_cache.get("access_token"):
  53. return False
  54. expires_at = _token_cache.get("expires_at", 0)
  55. return time.time() < expires_at - 300
  56. async def get_token() -> str:
  57. """获取 Access Token,带缓存。
  58. POST /api/v1/open/auth/token
  59. 使用 HMAC-SHA256 签名认证,无请求体。
  60. """
  61. if _is_token_valid():
  62. return _token_cache["access_token"]
  63. headers = _build_token_headers()
  64. base_url = _get_base_url()
  65. async with httpx.AsyncClient(timeout=30) as client:
  66. resp = await client.post(
  67. f"{base_url}/api/v1/open/auth/token",
  68. json={},
  69. headers=headers,
  70. )
  71. resp.raise_for_status()
  72. body = resp.json()
  73. # 标注平台返回 code: 0 表示成功
  74. if body.get("code") != 0:
  75. raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message', body)}")
  76. data = body.get("data", {})
  77. _token_cache["access_token"] = data["access_token"]
  78. _token_cache["expires_in"] = data.get("expires_in", 7200)
  79. _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
  80. return data["access_token"]
  81. def _auth_headers() -> dict[str, str]:
  82. """构建业务接口的认证请求头。"""
  83. return {
  84. "Authorization": f"Bearer {_token_cache.get('access_token', '')}",
  85. "Content-Type": "application/json",
  86. }
  87. async def _request(method: str, path: str, **kwargs) -> dict[str, Any]:
  88. """统一的业务请求方法,自动携带 Token。"""
  89. await get_token()
  90. base_url = _get_base_url()
  91. async with httpx.AsyncClient(timeout=60) as client:
  92. resp = await client.request(
  93. method,
  94. f"{base_url}{path}",
  95. headers=_auth_headers(),
  96. **kwargs,
  97. )
  98. resp.raise_for_status()
  99. body = resp.json()
  100. if body.get("code") != 0:
  101. raise RuntimeError(f"标注平台请求失败: {body.get('message', body)}")
  102. return body.get("data", {})
  103. # ---------- 项目列表 ----------
  104. async def list_projects(
  105. page: int = 1,
  106. page_size: int = 20,
  107. name: str | None = None,
  108. project_type: str | None = None,
  109. status: str | None = None,
  110. ) -> dict[str, Any]:
  111. """获取标注平台项目列表。
  112. GET /api/v1/open/projects
  113. """
  114. params: dict[str, Any] = {"page": page, "page_size": page_size}
  115. if name:
  116. params["name"] = name
  117. if project_type:
  118. params["type"] = project_type
  119. if status:
  120. params["status"] = status
  121. return await _request("GET", "/api/v1/open/projects", params=params)
  122. # ---------- 项目详情 ----------
  123. async def get_project_detail(project_id: str) -> dict[str, Any]:
  124. """获取项目详情。
  125. GET /api/v1/open/projects/{project_id}
  126. """
  127. return await _request("GET", f"/api/v1/open/projects/{project_id}")
  128. # ---------- 数据集导出与下载 ----------
  129. async def import_project_dataset(
  130. project_id: str,
  131. project_name: str = "",
  132. format: str = "alpaca",
  133. ) -> dict[str, Any]:
  134. """导出并下载项目数据集,保存到本地并写入数据库。
  135. 流程:
  136. 1. POST /api/v1/open/projects/{project_id}/datasets/download → 获取 file_url
  137. 2. GET /api/v1/open/datasets/downloads/{download_token} → 下载文件
  138. 3. 保存到 uploads 目录
  139. 4. 写入 DatasetRecord 数据库
  140. """
  141. # 1. 请求导出
  142. export_data = await _request(
  143. "POST",
  144. f"/api/v1/open/projects/{project_id}/datasets/download",
  145. json={"format": format, "completed_only": True},
  146. )
  147. file_url = export_data.get("file_url", "")
  148. file_name = export_data.get("file_name", f"{project_id}_{format}.json")
  149. total_exported = export_data.get("total_exported", 0)
  150. if not file_url:
  151. raise RuntimeError("标注平台未返回下载链接")
  152. # 2. 从 file_url 中提取 download_token
  153. # file_url 格式如: /api/v1/open/datasets/downloads/dl_abc123
  154. if "/datasets/downloads/" in file_url:
  155. download_token = file_url.split("/datasets/downloads/")[-1].strip("/")
  156. else:
  157. # 兜底:直接使用 file_url 的最后一段
  158. download_token = file_url.rstrip("/").split("/")[-1]
  159. # 3. 通过独立的下载接口获取文件(文档 4.6 节)
  160. await get_token()
  161. base_url = _get_base_url()
  162. async with httpx.AsyncClient(timeout=120) as client:
  163. resp = await client.get(
  164. f"{base_url}/api/v1/open/datasets/downloads/{download_token}",
  165. headers=_auth_headers(),
  166. follow_redirects=True,
  167. )
  168. resp.raise_for_status()
  169. file_content = resp.content
  170. # 4. 保存到 uploads 目录
  171. upload_dir = settings.uploads_dir
  172. upload_dir.mkdir(parents=True, exist_ok=True)
  173. safe_name = f"{project_name or project_id}_{file_name}" if project_name else file_name
  174. # 清理文件名中的非法字符
  175. safe_name = "".join(c if c.isalnum() or c in "._-" else "_" for c in safe_name)
  176. file_path = upload_dir / safe_name
  177. if file_path.exists():
  178. file_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}"
  179. file_path.write_bytes(file_content)
  180. # 5. 检测格式和记录数
  181. fmt = _detect_format(file_path.name)
  182. record_count = _count_records(file_path, fmt)
  183. # 6. 写入数据库
  184. record_id = str(uuid.uuid4())
  185. record = DatasetRecord(
  186. id=record_id,
  187. name=file_path.name,
  188. format=fmt,
  189. record_count=record_count,
  190. file_path=str(file_path),
  191. created_at=datetime.utcnow(),
  192. )
  193. async with async_session() as session:
  194. session.add(record)
  195. await session.commit()
  196. logger.info(f"Imported dataset from annotation platform: {project_id} -> {file_path.name} ({record_count} records)")
  197. return {
  198. "project_id": project_id,
  199. "project_name": project_name or project_id,
  200. "format": format,
  201. "total_exported": total_exported,
  202. "dataset_id": record_id,
  203. "dataset_name": file_path.name,
  204. }
  205. def _detect_format(filename: str) -> str:
  206. """根据文件名推断格式。"""
  207. name = filename.lower()
  208. if name.endswith(".jsonl"):
  209. return "jsonl"
  210. if name.endswith(".csv"):
  211. return "csv"
  212. if name.endswith(".parquet"):
  213. return "parquet"
  214. if name.endswith(".json"):
  215. return "json"
  216. return "json"
  217. def _count_records(file_path: Path, fmt: str) -> int:
  218. """计算文件中的记录数。"""
  219. import json
  220. if not file_path.exists():
  221. return 0
  222. try:
  223. if fmt == "jsonl":
  224. with open(file_path, "r", encoding="utf-8") as f:
  225. return sum(1 for line in f if line.strip())
  226. elif fmt == "json":
  227. with open(file_path, "r", encoding="utf-8") as f:
  228. data = json.load(f)
  229. if isinstance(data, list):
  230. return len(data)
  231. return 1
  232. elif fmt == "csv":
  233. import csv
  234. with open(file_path, "r", encoding="utf-8") as f:
  235. return sum(1 for _ in csv.DictReader(f))
  236. except Exception:
  237. return 0
  238. return 0