# annotation_platform_service.py
"""Annotation-platform API client service.

Client for the annotation platform's external (open) API, whose token requests
are authenticated with HMAC-SHA256 signatures.
Reference: the platform's external API document ("标注平台对外API接口文档.md").
Features: list projects, get project details, export and download datasets.
"""
import hashlib
import hmac
import secrets
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any
from urllib.parse import urlsplit

import httpx

from app.config import get_settings
from app.core.db import async_session, DatasetRecord
from app.core.logging import logger
  18. settings = get_settings()
  19. # Token 缓存(内存中)
  20. _token_cache: dict[str, Any] = {}
  21. def _get_base_url() -> str:
  22. base_url = settings.annotation_platform_base_url
  23. if not base_url:
  24. raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL")
  25. return base_url.rstrip("/")
  26. def _get_credentials() -> tuple[str, str]:
  27. app_id = settings.annotation_platform_app_id
  28. app_secret = settings.annotation_platform_app_secret
  29. if not app_id or not app_secret:
  30. raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
  31. return app_id, app_secret
  32. def _build_token_headers() -> dict[str, str]:
  33. """构建获取 Token 的 HMAC-SHA256 签名请求头。"""
  34. app_id, app_secret = _get_credentials()
  35. timestamp = str(int(time.time()))
  36. nonce = uuid.uuid4().hex # 使用 uuid4 保证每次唯一,16+ 位随机字符串
  37. message = app_id + timestamp + nonce
  38. signature = hmac.new(
  39. key=app_secret.encode("utf-8"),
  40. msg=message.encode("utf-8"),
  41. digestmod=hashlib.sha256,
  42. ).hexdigest()
  43. return {
  44. "Content-Type": "application/json",
  45. "X-Api-Key": app_id,
  46. "X-Timestamp": timestamp,
  47. "X-Nonce": nonce,
  48. "X-Signature": signature,
  49. }
  50. def _is_token_valid() -> bool:
  51. """检查缓存的 Token 是否仍然有效(提前 5 分钟刷新)。"""
  52. if not _token_cache.get("access_token"):
  53. return False
  54. expires_at = _token_cache.get("expires_at", 0)
  55. return time.time() < expires_at - 300
  56. async def _refresh_token() -> str:
  57. """使用 Bearer Token 刷新 Access Token。
  58. POST /api/v1/open/auth/refresh
  59. 比重新签名更高效,仅在 Token 存在但即将过期时调用。
  60. """
  61. old_token = _token_cache.get("access_token", "")
  62. base_url = _get_base_url()
  63. async with httpx.AsyncClient(timeout=30) as client:
  64. resp = await client.post(
  65. f"{base_url}/api/v1/open/auth/refresh",
  66. headers={"Authorization": f"Bearer {old_token}"},
  67. )
  68. resp.raise_for_status()
  69. body = resp.json()
  70. if body.get("code") != 0:
  71. raise RuntimeError(f"刷新标注平台 Token 失败: {body.get('message', body)}")
  72. data = body.get("data", {})
  73. _token_cache["access_token"] = data["access_token"]
  74. _token_cache["expires_in"] = data.get("expires_in", 7200)
  75. _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
  76. return data["access_token"]
  77. async def get_token() -> str:
  78. """获取 Access Token,带缓存和自动刷新。
  79. 优先级:
  80. 1. Token 有效 → 直接返回
  81. 2. Token 存在但快过期 → 调用 /auth/refresh 刷新
  82. 3. 无 Token → 调用 /auth/token 重新签名获取
  83. """
  84. if _is_token_valid():
  85. return _token_cache["access_token"]
  86. # Token 存在但即将过期,尝试刷新
  87. if _token_cache.get("access_token"):
  88. try:
  89. return await _refresh_token()
  90. except Exception as e:
  91. logger.warning(f"Token 刷新失败,回退到重新获取: {e}")
  92. # 无 Token 或刷新失败,重新签名获取
  93. headers = _build_token_headers()
  94. base_url = _get_base_url()
  95. async with httpx.AsyncClient(timeout=30) as client:
  96. resp = await client.post(
  97. f"{base_url}/api/v1/open/auth/token",
  98. json={},
  99. headers=headers,
  100. )
  101. resp.raise_for_status()
  102. body = resp.json()
  103. # 标注平台返回 code: 0 表示成功
  104. if body.get("code") != 0:
  105. raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message', body)}")
  106. data = body.get("data", {})
  107. _token_cache["access_token"] = data["access_token"]
  108. _token_cache["expires_in"] = data.get("expires_in", 7200)
  109. _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
  110. return data["access_token"]
  111. def _auth_headers() -> dict[str, str]:
  112. """构建业务接口的认证请求头。"""
  113. return {
  114. "Authorization": f"Bearer {_token_cache.get('access_token', '')}",
  115. "Content-Type": "application/json",
  116. }
  117. async def _request(method: str, path: str, **kwargs) -> dict[str, Any]:
  118. """统一的业务请求方法,自动携带 Token。"""
  119. await get_token()
  120. base_url = _get_base_url()
  121. async with httpx.AsyncClient(timeout=60) as client:
  122. resp = await client.request(
  123. method,
  124. f"{base_url}{path}",
  125. headers=_auth_headers(),
  126. **kwargs,
  127. )
  128. resp.raise_for_status()
  129. body = resp.json()
  130. if body.get("code") != 0:
  131. raise RuntimeError(f"标注平台请求失败: {body.get('message', body)}")
  132. return body.get("data", {})
  133. # ---------- 项目列表 ----------
  134. async def list_projects(
  135. page: int = 1,
  136. page_size: int = 20,
  137. name: str | None = None,
  138. project_type: str | None = None,
  139. status: str | None = None,
  140. ) -> dict[str, Any]:
  141. """获取标注平台项目列表。
  142. GET /api/v1/open/projects
  143. """
  144. params: dict[str, Any] = {"page": page, "page_size": page_size}
  145. if name:
  146. params["name"] = name
  147. if project_type:
  148. params["type"] = project_type
  149. if status:
  150. params["status"] = status
  151. return await _request("GET", "/api/v1/open/projects", params=params)
  152. # ---------- 项目详情 ----------
  153. async def get_project_detail(project_id: str) -> dict[str, Any]:
  154. """获取项目详情。
  155. GET /api/v1/open/projects/{project_id}
  156. """
  157. return await _request("GET", f"/api/v1/open/projects/{project_id}")
  158. # ---------- 数据集导出与下载 ----------
  159. async def import_project_dataset(
  160. project_id: str,
  161. project_name: str = "",
  162. format: str = "alpaca",
  163. ) -> dict[str, Any]:
  164. """导出并下载项目数据集,保存到本地并写入数据库。
  165. 流程:
  166. 1. POST /api/v1/open/projects/{project_id}/datasets/download → 获取 file_url
  167. 2. GET /api/v1/open/datasets/downloads/{download_token} → 下载文件
  168. 3. 保存到 uploads 目录
  169. 4. 写入 DatasetRecord 数据库
  170. """
  171. # 1. 请求导出(先尝试只导出已完成,若为空则回退导出全部)
  172. export_data = await _request(
  173. "POST",
  174. f"/api/v1/open/projects/{project_id}/datasets/download",
  175. json={"format": format, "completed_only": True},
  176. )
  177. file_url = export_data.get("file_url", "")
  178. file_name = export_data.get("file_name", f"{project_id}_{format}.json")
  179. total_exported = export_data.get("total_exported", 0)
  180. logger.info(
  181. f"Annotation export (completed_only=True): total_exported={total_exported}, "
  182. f"file_url={file_url}, file_name={file_name}"
  183. )
  184. # 如果已完成数据为空,回退导出全部(包括未完成的)
  185. if total_exported == 0 or not file_url:
  186. logger.info(f"No completed items found, retrying with completed_only=False")
  187. export_data = await _request(
  188. "POST",
  189. f"/api/v1/open/projects/{project_id}/datasets/download",
  190. json={"format": format, "completed_only": False},
  191. )
  192. file_url = export_data.get("file_url", "")
  193. file_name = export_data.get("file_name", f"{project_id}_{format}.json")
  194. total_exported = export_data.get("total_exported", 0)
  195. logger.info(
  196. f"Annotation export (completed_only=False): total_exported={total_exported}, "
  197. f"file_url={file_url}, file_name={file_name}"
  198. )
  199. if not file_url:
  200. raise RuntimeError("标注平台未返回下载链接")
  201. # 2. 从 file_url 中提取 download_token
  202. # file_url 格式如: /api/v1/open/datasets/downloads/dl_abc123
  203. if "/datasets/downloads/" in file_url:
  204. download_token = file_url.split("/datasets/downloads/")[-1].strip("/")
  205. else:
  206. # 兜底:直接使用 file_url 的最后一段
  207. download_token = file_url.rstrip("/").split("/")[-1]
  208. # 3. 通过独立的下载接口获取文件(文档 4.6 节)
  209. await get_token()
  210. base_url = _get_base_url()
  211. download_url = f"{base_url}/api/v1/open/datasets/downloads/{download_token}"
  212. async with httpx.AsyncClient(timeout=120) as client:
  213. # 先手动处理重定向,确保每次请求都带上认证头
  214. resp = await client.get(
  215. download_url,
  216. headers=_auth_headers(),
  217. follow_redirects=False,
  218. )
  219. # 手动跟随重定向,每次都带上认证头
  220. redirect_count = 0
  221. while resp.is_redirect and redirect_count < 5:
  222. redirect_url = resp.next_request.url
  223. logger.info(f"Download redirect to: {redirect_url}")
  224. resp = await client.get(
  225. str(redirect_url),
  226. headers=_auth_headers(),
  227. follow_redirects=False,
  228. )
  229. redirect_count += 1
  230. resp.raise_for_status()
  231. file_content = resp.content
  232. logger.info(
  233. f"Downloaded annotation file: {len(file_content)} bytes, "
  234. f"content_type={resp.headers.get('content-type', 'unknown')}, "
  235. f"url={resp.url}, redirects={redirect_count}"
  236. )
  237. if len(file_content) < 200:
  238. logger.warning(f"Annotation file content suspiciously small: {file_content!r}")
  239. # 4. 保存到 uploads 目录
  240. upload_dir = settings.uploads_dir
  241. upload_dir.mkdir(parents=True, exist_ok=True)
  242. safe_name = "".join(c if c.isalnum() or c in "._-" else "_" for c in (project_name or project_id))
  243. file_path = upload_dir / f"{safe_name}.jsonl"
  244. if file_path.exists():
  245. file_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}.jsonl"
  246. file_path.write_bytes(file_content)
  247. # 5. 统一转为 JSONL 格式(和 ModelScope/HF 下载的数据格式一致)
  248. jsonl_path = _convert_to_jsonl(file_path)
  249. record_count = _count_records(jsonl_path, "jsonl")
  250. logger.info(f"Annotation file converted: {jsonl_path.name}, record_count={record_count}")
  251. # 6. 写入数据库(格式统一为 jsonl)
  252. record_id = str(uuid.uuid4())
  253. record = DatasetRecord(
  254. id=record_id,
  255. name=project_name or project_id,
  256. format="jsonl",
  257. record_count=record_count,
  258. file_path=str(jsonl_path),
  259. created_at=datetime.utcnow(),
  260. )
  261. async with async_session() as session:
  262. session.add(record)
  263. await session.commit()
  264. logger.info(f"Imported dataset from annotation platform: {project_id} -> {jsonl_path.name} ({record_count} records)")
  265. return {
  266. "project_id": project_id,
  267. "project_name": project_name or project_id,
  268. "format": "jsonl",
  269. "total_exported": total_exported,
  270. "dataset_id": record_id,
  271. "dataset_name": jsonl_path.name,
  272. }
  273. def _convert_to_jsonl(file_path: Path) -> Path:
  274. """将 JSON/JSONL 文件统一转为 JSONL 格式。"""
  275. import json as _json
  276. jsonl_path = file_path.with_suffix(".jsonl")
  277. with open(file_path, "r", encoding="utf-8") as f:
  278. content = f.read().strip()
  279. if not content:
  280. jsonl_path.touch()
  281. return jsonl_path
  282. try:
  283. # 尝试作为 JSON 解析
  284. data = _json.loads(content)
  285. if isinstance(data, list):
  286. # JSON 数组
  287. items = data
  288. elif isinstance(data, dict):
  289. # JSON 对象:查找嵌套的数组字段
  290. items = None
  291. for key in ("data", "items", "results", "records", "annotations", "samples"):
  292. if key in data and isinstance(data[key], list):
  293. items = data[key]
  294. break
  295. if items is None:
  296. # 单个对象,包装为数组
  297. items = [data]
  298. else:
  299. items = None
  300. if items is not None:
  301. with open(jsonl_path, "w", encoding="utf-8") as out:
  302. for item in items:
  303. out.write(_json.dumps(item, ensure_ascii=False) + "\n")
  304. # 只有当原始文件与新文件不同时才删除(避免删除刚写入的文件)
  305. if jsonl_path != file_path and file_path.exists():
  306. file_path.unlink()
  307. return jsonl_path
  308. except _json.JSONDecodeError:
  309. pass
  310. # 不是标准 JSON,可能是 JSONL,逐行验证
  311. lines = content.split("\n")
  312. valid_lines = []
  313. for line in lines:
  314. line = line.strip()
  315. if line:
  316. try:
  317. _json.loads(line)
  318. valid_lines.append(line)
  319. except _json.JSONDecodeError:
  320. continue # 跳过无效行
  321. with open(jsonl_path, "w", encoding="utf-8") as out:
  322. out.write("\n".join(valid_lines) + ("\n" if valid_lines else ""))
  323. if jsonl_path != file_path and file_path.exists():
  324. file_path.unlink()
  325. return jsonl_path
  326. def _detect_format(filename: str) -> str:
  327. """根据文件名推断格式。"""
  328. name = filename.lower()
  329. if name.endswith(".jsonl"):
  330. return "jsonl"
  331. if name.endswith(".csv"):
  332. return "csv"
  333. if name.endswith(".parquet"):
  334. return "parquet"
  335. if name.endswith(".json"):
  336. return "json"
  337. return "json"
  338. def _count_records(file_path: Path, fmt: str) -> int:
  339. """计算文件中的记录数。"""
  340. import json
  341. if not file_path.exists():
  342. return 0
  343. try:
  344. if fmt == "jsonl":
  345. with open(file_path, "r", encoding="utf-8") as f:
  346. return sum(1 for line in f if line.strip())
  347. elif fmt == "json":
  348. with open(file_path, "r", encoding="utf-8") as f:
  349. data = json.load(f)
  350. if isinstance(data, list):
  351. return len(data)
  352. return 1
  353. elif fmt == "csv":
  354. import csv
  355. with open(file_path, "r", encoding="utf-8") as f:
  356. return sum(1 for _ in csv.DictReader(f))
  357. except Exception:
  358. return 0
  359. return 0