annotation_platform_service.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. """标注平台 API 客户端服务。
  2. 对接标注平台的对外 API(HMAC-SHA256 签名认证)。
  3. 参考文档:标注平台对外API接口文档.md
  4. 功能:列出项目、获取项目详情、数据集导出与下载。
  5. """
  6. import asyncio
  7. import hashlib
  8. import hmac
  9. import json
  10. import time
  11. import uuid
  12. from datetime import datetime
  13. from pathlib import Path
  14. from typing import Any
  15. import httpx
  16. from app.config import get_settings
  17. from app.core.db import async_session, DatasetRecord
  18. from app.core.logging import logger
  19. settings = get_settings()
  20. # Token 缓存(内存中)
  21. _token_cache: dict[str, Any] = {}
  22. def _get_base_url() -> str:
  23. base_url = settings.annotation_platform_base_url
  24. if not base_url:
  25. raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL")
  26. return base_url.rstrip("/")
  27. def _get_credentials() -> tuple[str, str]:
  28. app_id = settings.annotation_platform_app_id
  29. app_secret = settings.annotation_platform_app_secret
  30. if not app_id or not app_secret:
  31. raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
  32. return app_id, app_secret
  33. def _build_token_headers() -> dict[str, str]:
  34. """构建获取 Token 的 HMAC-SHA256 签名请求头。"""
  35. app_id, app_secret = _get_credentials()
  36. timestamp = str(int(time.time()))
  37. nonce = uuid.uuid4().hex
  38. message = app_id + timestamp + nonce
  39. signature = hmac.new(
  40. key=app_secret.encode("utf-8"),
  41. msg=message.encode("utf-8"),
  42. digestmod=hashlib.sha256,
  43. ).hexdigest()
  44. return {
  45. "Content-Type": "application/json",
  46. "X-Api-Key": app_id,
  47. "X-Timestamp": timestamp,
  48. "X-Nonce": nonce,
  49. "X-Signature": signature,
  50. }
  51. def _is_token_valid() -> bool:
  52. """检查缓存的 Token 是否仍然有效(提前 5 分钟刷新)。"""
  53. if not _token_cache.get("access_token"):
  54. return False
  55. expires_at = _token_cache.get("expires_at", 0)
  56. return time.time() < expires_at - 300
  57. async def _refresh_token() -> str:
  58. """使用 Bearer Token 刷新 Access Token。"""
  59. old_token = _token_cache.get("access_token", "")
  60. base_url = _get_base_url()
  61. async with httpx.AsyncClient(timeout=30) as client:
  62. resp = await client.post(
  63. f"{base_url}/api/v1/open/auth/refresh",
  64. headers={"Authorization": f"Bearer {old_token}"},
  65. )
  66. resp.raise_for_status()
  67. body = resp.json()
  68. if body.get("code") != 0:
  69. raise RuntimeError(f"刷新标注平台 Token 失败: {body.get('message', body)}")
  70. data = body.get("data", {})
  71. _token_cache["access_token"] = data["access_token"]
  72. _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
  73. return data["access_token"]
  74. async def get_token() -> str:
  75. """获取 Access Token,带缓存和自动刷新。"""
  76. if _is_token_valid():
  77. return _token_cache["access_token"]
  78. # Token 存在但即将过期,尝试刷新
  79. if _token_cache.get("access_token"):
  80. try:
  81. return await _refresh_token()
  82. except Exception as e:
  83. logger.warning(f"Token 刷新失败,回退到重新获取: {e}")
  84. # 无 Token 或刷新失败,重新签名获取
  85. headers = _build_token_headers()
  86. base_url = _get_base_url()
  87. async with httpx.AsyncClient(timeout=30) as client:
  88. resp = await client.post(
  89. f"{base_url}/api/v1/open/auth/token",
  90. json={},
  91. headers=headers,
  92. )
  93. resp.raise_for_status()
  94. body = resp.json()
  95. if body.get("code") != 0:
  96. raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message', body)}")
  97. data = body.get("data", {})
  98. _token_cache["access_token"] = data["access_token"]
  99. _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
  100. return data["access_token"]
  101. def _auth_headers() -> dict[str, str]:
  102. """构建业务接口的认证请求头。"""
  103. return {
  104. "Authorization": f"Bearer {_token_cache.get('access_token', '')}",
  105. "Content-Type": "application/json",
  106. }
  107. async def _request(
  108. method: str, path: str, *, raise_on_error: bool = True, **kwargs
  109. ) -> dict[str, Any]:
  110. """统一的业务请求方法,自动携带 Token。
  111. raise_on_error=False 时,HTTP 4xx/5xx 不抛异常,返回带 _error 标记的 dict。
  112. """
  113. await get_token()
  114. base_url = _get_base_url()
  115. async with httpx.AsyncClient(timeout=60) as client:
  116. resp = await client.request(
  117. method,
  118. f"{base_url}{path}",
  119. headers=_auth_headers(),
  120. **kwargs,
  121. )
  122. if not raise_on_error and resp.status_code >= 400:
  123. try:
  124. body = resp.json()
  125. except Exception:
  126. body = {"message": resp.text}
  127. return {"_error": True, "_status_code": resp.status_code, **body}
  128. resp.raise_for_status()
  129. body = resp.json()
  130. if body.get("code") != 0:
  131. raise RuntimeError(f"标注平台请求失败: {body.get('message', body)}")
  132. return body.get("data", {})
  133. # ---------- 项目列表 ----------
  134. async def list_projects(
  135. page: int = 1,
  136. page_size: int = 20,
  137. name: str | None = None,
  138. project_type: str | None = None,
  139. status: str | None = None,
  140. ) -> dict[str, Any]:
  141. """获取标注平台项目列表。"""
  142. params: dict[str, Any] = {"page": page, "page_size": page_size}
  143. if name:
  144. params["name"] = name
  145. if project_type:
  146. params["type"] = project_type
  147. if status:
  148. params["status"] = status
  149. return await _request("GET", "/api/v1/open/projects", params=params)
  150. # ---------- 项目详情 ----------
  151. async def get_project_detail(project_id: str) -> dict[str, Any]:
  152. """获取项目详情。"""
  153. return await _request("GET", f"/api/v1/open/projects/{project_id}")
  154. # ---------- 数据集导出与下载 ----------
  155. # 文本项目和图片项目支持的导出格式(参考标注平台 API 文档)
  156. _TEXT_FORMATS = ["alpaca", "sharegpt"]
  157. _IMAGE_FORMATS = ["json", "csv", "coco"]
  158. async def _request_export(project_id: str, fmt: str) -> dict[str, Any] | None:
  159. """请求导出并返回导出信息,格式不兼容或失败时返回 None。"""
  160. data = await _request(
  161. "POST",
  162. f"/api/v1/open/projects/{project_id}/datasets/download",
  163. raise_on_error=False,
  164. json={"format": fmt, "completed_only": True},
  165. )
  166. if data.get("_error"):
  167. logger.info(f"导出格式 '{fmt}' 不可用: {data.get('message', data.get('detail', ''))}")
  168. return None
  169. if data.get("status") != "completed":
  170. logger.info(f"导出格式 '{fmt}' 状态异常: {data.get('status')}")
  171. return None
  172. if not data.get("file_url"):
  173. logger.info(f"导出格式 '{fmt}' 未返回下载链接")
  174. return None
  175. return data
  176. async def _download_file(download_token: str, max_retries: int = 3) -> bytes:
  177. """通过 download_token 下载导出文件,带重试(标注平台可能需要时间生成文件)。"""
  178. await get_token()
  179. base_url = _get_base_url()
  180. url = f"{base_url}/api/v1/open/datasets/downloads/{download_token}"
  181. for attempt in range(max_retries):
  182. async with httpx.AsyncClient(timeout=120) as client:
  183. resp = await client.get(url, headers=_auth_headers(), follow_redirects=True)
  184. resp.raise_for_status()
  185. content = resp.content
  186. if len(content) > 10:
  187. return content
  188. if attempt < max_retries - 1:
  189. wait = 2 ** attempt # 1, 2 秒
  190. logger.info(f"文件尚未就绪 ({len(content)} bytes),{wait}s 后重试...")
  191. await asyncio.sleep(wait)
  192. return content
  193. def _extract_download_token(file_url: str) -> str:
  194. """从 file_url 中提取 download_token。"""
  195. if "/datasets/downloads/" in file_url:
  196. return file_url.split("/datasets/downloads/")[-1].strip("/")
  197. return file_url.rstrip("/").split("/")[-1]
  198. async def import_project_dataset(
  199. project_id: str,
  200. project_name: str = "",
  201. format: str = "alpaca",
  202. ) -> dict[str, Any]:
  203. """导出并下载项目数据集,保存到本地并写入数据库。
  204. 流程:
  205. 1. 查询项目详情获取 project_type
  206. 2. 根据项目类型选择合适的导出格式,依次尝试
  207. 3. 下载文件并转换为 JSONL
  208. 4. 写入数据库
  209. """
  210. # 1. 查询项目类型,决定可用格式
  211. try:
  212. detail = await get_project_detail(project_id)
  213. project_type = detail.get("project_type", "text")
  214. except Exception:
  215. project_type = "text"
  216. # 2. 构建格式尝试列表(用户指定的格式优先)
  217. if project_type == "text":
  218. formats_to_try = list(_TEXT_FORMATS)
  219. else:
  220. formats_to_try = list(_IMAGE_FORMATS)
  221. if format and format not in formats_to_try:
  222. formats_to_try.insert(0, format)
  223. # 3. 依次尝试各格式:请求导出 → 下载文件
  224. file_content = b""
  225. total_exported = 0
  226. used_format = ""
  227. for fmt in formats_to_try:
  228. export_data = await _request_export(project_id, fmt)
  229. if not export_data:
  230. continue
  231. total_exported = export_data.get("total_exported", 0)
  232. download_token = _extract_download_token(export_data["file_url"])
  233. file_content = await _download_file(download_token)
  234. if len(file_content) > 10:
  235. used_format = fmt
  236. logger.info(
  237. f"标注平台导出成功: format={fmt}, {len(file_content)} bytes, "
  238. f"total_exported={total_exported}"
  239. )
  240. break
  241. logger.info(f"格式 '{fmt}' 导出文件为空 ({len(file_content)} bytes),尝试下一格式")
  242. if len(file_content) <= 10:
  243. raise RuntimeError(
  244. f"标注平台所有导出格式均返回空文件(project_type={project_type},"
  245. f"尝试格式: {formats_to_try}),请检查标注平台该项目的数据是否支持导出"
  246. )
  247. # 4. 保存到 uploads 目录
  248. upload_dir = settings.uploads_dir
  249. upload_dir.mkdir(parents=True, exist_ok=True)
  250. safe_name = "".join(c if c.isalnum() or c in "._-" else "_" for c in (project_name or project_id))
  251. file_path = upload_dir / f"{safe_name}.jsonl"
  252. if file_path.exists():
  253. file_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}.jsonl"
  254. file_path.write_bytes(file_content)
  255. # 5. 转为 JSONL 格式
  256. jsonl_path = _convert_to_jsonl(file_path)
  257. record_count = _count_records(jsonl_path, "jsonl")
  258. # 6. 写入数据库
  259. record_id = str(uuid.uuid4())
  260. record = DatasetRecord(
  261. id=record_id,
  262. name=project_name or project_id,
  263. format="jsonl",
  264. record_count=record_count,
  265. file_path=str(jsonl_path),
  266. created_at=datetime.utcnow(),
  267. )
  268. async with async_session() as session:
  269. session.add(record)
  270. await session.commit()
  271. logger.info(f"Imported annotation dataset: {project_name} ({record_count} records, format={used_format})")
  272. return {
  273. "project_id": project_id,
  274. "project_name": project_name or project_id,
  275. "format": "jsonl",
  276. "total_exported": total_exported,
  277. "dataset_id": record_id,
  278. "dataset_name": jsonl_path.name,
  279. }
  280. def _convert_to_jsonl(file_path: Path) -> Path:
  281. """将 JSON/JSONL 文件统一转为 JSONL 格式。"""
  282. jsonl_path = file_path.with_suffix(".jsonl")
  283. with open(file_path, "r", encoding="utf-8") as f:
  284. content = f.read().strip()
  285. if not content:
  286. jsonl_path.touch()
  287. return jsonl_path
  288. try:
  289. data = json.loads(content)
  290. items = _extract_items(data)
  291. if items is not None:
  292. with open(jsonl_path, "w", encoding="utf-8") as out:
  293. for item in items:
  294. out.write(json.dumps(item, ensure_ascii=False) + "\n")
  295. if jsonl_path != file_path and file_path.exists():
  296. file_path.unlink()
  297. return jsonl_path
  298. except json.JSONDecodeError:
  299. pass
  300. # JSONL 逐行验证
  301. valid_lines = [line.strip() for line in content.split("\n") if line.strip()]
  302. valid_lines = [line for line in valid_lines if _is_valid_json(line)]
  303. with open(jsonl_path, "w", encoding="utf-8") as out:
  304. out.write("\n".join(valid_lines) + ("\n" if valid_lines else ""))
  305. if jsonl_path != file_path and file_path.exists():
  306. file_path.unlink()
  307. return jsonl_path
  308. def _extract_items(data) -> list | None:
  309. """从 JSON 数据中提取记录列表。"""
  310. if isinstance(data, list):
  311. return data
  312. if isinstance(data, dict):
  313. for key in ("data", "items", "results", "records", "annotations", "samples"):
  314. if key in data and isinstance(data[key], list):
  315. return data[key]
  316. return [data]
  317. return None
  318. def _is_valid_json(s: str) -> bool:
  319. try:
  320. json.loads(s)
  321. return True
  322. except json.JSONDecodeError:
  323. return False
  324. def _count_records(file_path: Path, fmt: str) -> int:
  325. """计算文件中的记录数。"""
  326. if not file_path.exists():
  327. return 0
  328. try:
  329. if fmt == "jsonl":
  330. with open(file_path, "r", encoding="utf-8") as f:
  331. return sum(1 for line in f if line.strip())
  332. elif fmt == "json":
  333. with open(file_path, "r", encoding="utf-8") as f:
  334. data = json.load(f)
  335. return len(data) if isinstance(data, list) else 1
  336. elif fmt == "csv":
  337. import csv
  338. with open(file_path, "r", encoding="utf-8") as f:
  339. return sum(1 for _ in csv.DictReader(f))
  340. except Exception:
  341. return 0
  342. return 0