annotation_platform_service.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. """标注平台 API 客户端服务。
  2. 对接标注平台的对外 API(HMAC-SHA256 签名认证)。
  3. 参考文档:标注平台对外API接口文档.md
  4. 功能:列出项目、获取项目详情、数据集导出与下载。
  5. """
  6. import hashlib
  7. import hmac
  8. import secrets
  9. import time
  10. import uuid
  11. from datetime import datetime
  12. from pathlib import Path
  13. from typing import Any
  14. import httpx
  15. from app.config import get_settings
  16. from app.core.db import async_session, DatasetRecord
  17. from app.core.logging import logger
  18. settings = get_settings()
  19. # Token 缓存(内存中)
  20. _token_cache: dict[str, Any] = {}
  21. def _get_base_url() -> str:
  22. base_url = settings.annotation_platform_base_url
  23. if not base_url:
  24. raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL")
  25. return base_url.rstrip("/")
  26. def _get_credentials() -> tuple[str, str]:
  27. app_id = settings.annotation_platform_app_id
  28. app_secret = settings.annotation_platform_app_secret
  29. if not app_id or not app_secret:
  30. raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
  31. return app_id, app_secret
  32. def _build_token_headers() -> dict[str, str]:
  33. """构建获取 Token 的 HMAC-SHA256 签名请求头。"""
  34. app_id, app_secret = _get_credentials()
  35. timestamp = str(int(time.time()))
  36. nonce = uuid.uuid4().hex # 使用 uuid4 保证每次唯一,16+ 位随机字符串
  37. message = app_id + timestamp + nonce
  38. signature = hmac.new(
  39. key=app_secret.encode("utf-8"),
  40. msg=message.encode("utf-8"),
  41. digestmod=hashlib.sha256,
  42. ).hexdigest()
  43. return {
  44. "Content-Type": "application/json",
  45. "X-Api-Key": app_id,
  46. "X-Timestamp": timestamp,
  47. "X-Nonce": nonce,
  48. "X-Signature": signature,
  49. }
  50. def _is_token_valid() -> bool:
  51. """检查缓存的 Token 是否仍然有效(提前 5 分钟刷新)。"""
  52. if not _token_cache.get("access_token"):
  53. return False
  54. expires_at = _token_cache.get("expires_at", 0)
  55. return time.time() < expires_at - 300
  56. async def _refresh_token() -> str:
  57. """使用 Bearer Token 刷新 Access Token。
  58. POST /api/v1/open/auth/refresh
  59. 比重新签名更高效,仅在 Token 存在但即将过期时调用。
  60. """
  61. old_token = _token_cache.get("access_token", "")
  62. base_url = _get_base_url()
  63. async with httpx.AsyncClient(timeout=30) as client:
  64. resp = await client.post(
  65. f"{base_url}/api/v1/open/auth/refresh",
  66. headers={"Authorization": f"Bearer {old_token}"},
  67. )
  68. resp.raise_for_status()
  69. body = resp.json()
  70. if body.get("code") != 0:
  71. raise RuntimeError(f"刷新标注平台 Token 失败: {body.get('message', body)}")
  72. data = body.get("data", {})
  73. _token_cache["access_token"] = data["access_token"]
  74. _token_cache["expires_in"] = data.get("expires_in", 7200)
  75. _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
  76. return data["access_token"]
  77. async def get_token() -> str:
  78. """获取 Access Token,带缓存和自动刷新。
  79. 优先级:
  80. 1. Token 有效 → 直接返回
  81. 2. Token 存在但快过期 → 调用 /auth/refresh 刷新
  82. 3. 无 Token → 调用 /auth/token 重新签名获取
  83. """
  84. if _is_token_valid():
  85. return _token_cache["access_token"]
  86. # Token 存在但即将过期,尝试刷新
  87. if _token_cache.get("access_token"):
  88. try:
  89. return await _refresh_token()
  90. except Exception as e:
  91. logger.warning(f"Token 刷新失败,回退到重新获取: {e}")
  92. # 无 Token 或刷新失败,重新签名获取
  93. headers = _build_token_headers()
  94. base_url = _get_base_url()
  95. async with httpx.AsyncClient(timeout=30) as client:
  96. resp = await client.post(
  97. f"{base_url}/api/v1/open/auth/token",
  98. json={},
  99. headers=headers,
  100. )
  101. resp.raise_for_status()
  102. body = resp.json()
  103. # 标注平台返回 code: 0 表示成功
  104. if body.get("code") != 0:
  105. raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message', body)}")
  106. data = body.get("data", {})
  107. _token_cache["access_token"] = data["access_token"]
  108. _token_cache["expires_in"] = data.get("expires_in", 7200)
  109. _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
  110. return data["access_token"]
  111. def _auth_headers() -> dict[str, str]:
  112. """构建业务接口的认证请求头。"""
  113. return {
  114. "Authorization": f"Bearer {_token_cache.get('access_token', '')}",
  115. "Content-Type": "application/json",
  116. }
  117. async def _request(method: str, path: str, raise_on_error: bool = True, **kwargs) -> dict[str, Any]:
  118. """统一的业务请求方法,自动携带 Token。
  119. raise_on_error=False 时,400 等错误不抛异常,返回原始响应体。
  120. """
  121. await get_token()
  122. base_url = _get_base_url()
  123. async with httpx.AsyncClient(timeout=60) as client:
  124. resp = await client.request(
  125. method,
  126. f"{base_url}{path}",
  127. headers=_auth_headers(),
  128. **kwargs,
  129. )
  130. if not raise_on_error and resp.status_code >= 400:
  131. try:
  132. body = resp.json()
  133. except Exception:
  134. body = {"status_code": resp.status_code, "text": resp.text}
  135. return {"_error": True, "_status_code": resp.status_code, **body}
  136. resp.raise_for_status()
  137. body = resp.json()
  138. if body.get("code") != 0:
  139. raise RuntimeError(f"标注平台请求失败: {body.get('message', body)}")
  140. return body.get("data", {})
  141. # ---------- 项目列表 ----------
  142. async def list_projects(
  143. page: int = 1,
  144. page_size: int = 20,
  145. name: str | None = None,
  146. project_type: str | None = None,
  147. status: str | None = None,
  148. ) -> dict[str, Any]:
  149. """获取标注平台项目列表。
  150. GET /api/v1/open/projects
  151. """
  152. params: dict[str, Any] = {"page": page, "page_size": page_size}
  153. if name:
  154. params["name"] = name
  155. if project_type:
  156. params["type"] = project_type
  157. if status:
  158. params["status"] = status
  159. return await _request("GET", "/api/v1/open/projects", params=params)
  160. # ---------- 项目详情 ----------
  161. async def get_project_detail(project_id: str) -> dict[str, Any]:
  162. """获取项目详情。
  163. GET /api/v1/open/projects/{project_id}
  164. """
  165. return await _request("GET", f"/api/v1/open/projects/{project_id}")
  166. # ---------- 数据集导出与下载 ----------
  167. async def import_project_dataset(
  168. project_id: str,
  169. project_name: str = "",
  170. format: str = "alpaca",
  171. ) -> dict[str, Any]:
  172. """导出并下载项目数据集,保存到本地并写入数据库。
  173. 流程:
  174. 1. 查询项目详情获取 task_type
  175. 2. POST /api/v1/open/projects/{project_id}/datasets/download → 获取 file_url
  176. 3. GET /api/v1/open/datasets/downloads/{download_token} → 下载文件
  177. 4. 保存到 uploads 目录
  178. 5. 写入 DatasetRecord 数据库
  179. """
  180. # 1. 查询项目详情,获取 task_type 和 project_type
  181. try:
  182. project_detail = await get_project_detail(project_id)
  183. task_type = project_detail.get("task_type", "")
  184. project_type = project_detail.get("project_type", "text")
  185. logger.info(f"Project {project_id}: task_type={task_type}, project_type={project_type}")
  186. except Exception as e:
  187. logger.warning(f"Failed to get project detail: {e}, using default formats")
  188. task_type = ""
  189. project_type = "text"
  190. # 2. 根据项目类型选择导出格式
  191. # 文本项目: alpaca / sharegpt
  192. # 图片项目: json / csv / coco / yolo / pascal_voc
  193. if project_type == "text":
  194. formats_to_try = ["alpaca", "sharegpt"]
  195. else:
  196. formats_to_try = ["json", "csv", "coco"]
  197. # 如果用户指定了格式,优先使用
  198. if format and format not in formats_to_try:
  199. formats_to_try.insert(0, format)
  200. file_content = b""
  201. file_name = ""
  202. total_exported = 0
  203. used_format = ""
  204. last_error = ""
  205. for try_format in formats_to_try:
  206. # 请求导出
  207. export_data = await _request(
  208. "POST",
  209. f"/api/v1/open/projects/{project_id}/datasets/download",
  210. raise_on_error=False,
  211. json={"format": try_format, "completed_only": True},
  212. )
  213. # 检查是否返回错误
  214. if export_data.get("_error"):
  215. status_code = export_data.get("_status_code", 0)
  216. error_msg = export_data.get("message", export_data)
  217. last_error = f"HTTP {status_code}: {error_msg}"
  218. logger.warning(f"Format '{try_format}' failed: {last_error}, trying next...")
  219. continue
  220. logger.info(f"Annotation export response (format={try_format}): {export_data}")
  221. file_url = export_data.get("file_url", "")
  222. file_name = export_data.get("file_name", f"{project_id}_{try_format}.json")
  223. total_exported = export_data.get("total_exported", 0)
  224. export_status = export_data.get("status", "completed")
  225. # 检查导出状态
  226. if export_status != "completed":
  227. logger.warning(f"Export status={export_status} for format={try_format}, trying next...")
  228. last_error = f"导出状态: {export_status}"
  229. continue
  230. if not file_url:
  231. logger.warning(f"No file_url for format={try_format}, trying next...")
  232. last_error = "未返回下载链接"
  233. continue
  234. # 从 file_url 中提取 download_token
  235. if "/datasets/downloads/" in file_url:
  236. download_token = file_url.split("/datasets/downloads/")[-1].strip("/")
  237. else:
  238. download_token = file_url.rstrip("/").split("/")[-1]
  239. # 下载文件,带轮询(标注平台生成文件可能需要时间)
  240. await get_token()
  241. base_url = _get_base_url()
  242. download_url = f"{base_url}/api/v1/open/datasets/downloads/{download_token}"
  243. file_content = b""
  244. max_retries = 4
  245. for attempt in range(max_retries):
  246. async with httpx.AsyncClient(timeout=120) as client:
  247. resp = await client.get(
  248. download_url,
  249. headers=_auth_headers(),
  250. follow_redirects=False,
  251. )
  252. redirect_count = 0
  253. while resp.is_redirect and redirect_count < 5:
  254. redirect_url = resp.next_request.url
  255. logger.info(f"Download redirect to: {redirect_url}")
  256. resp = await client.get(
  257. str(redirect_url),
  258. headers=_auth_headers(),
  259. follow_redirects=False,
  260. )
  261. redirect_count += 1
  262. resp.raise_for_status()
  263. file_content = resp.content
  264. if len(file_content) > 10:
  265. break
  266. if attempt < max_retries - 1:
  267. import asyncio
  268. wait = 2 ** attempt # 1, 2, 4 秒
  269. logger.info(
  270. f"Download attempt {attempt + 1}/{max_retries} (format={try_format}): "
  271. f"file too small ({len(file_content)} bytes), retrying in {wait}s..."
  272. )
  273. await asyncio.sleep(wait)
  274. logger.info(
  275. f"Downloaded (format={try_format}): {len(file_content)} bytes, "
  276. f"content_type={resp.headers.get('content-type', 'unknown')}"
  277. )
  278. if len(file_content) > 10:
  279. used_format = try_format
  280. break
  281. logger.warning(
  282. f"Format '{try_format}' returned empty file ({len(file_content)} bytes), "
  283. f"trying next format..."
  284. )
  285. last_error = f"格式 {try_format} 导出文件为空"
  286. if len(file_content) <= 10:
  287. logger.warning(f"Annotation file content: {file_content!r}")
  288. raise RuntimeError(
  289. f"标注平台导出文件为空(task_type={task_type}, 尝试了格式: {formats_to_try}),"
  290. f"total_exported={total_exported}。最后错误: {last_error}"
  291. )
  292. # 保存到 uploads 目录
  293. upload_dir = settings.uploads_dir
  294. upload_dir.mkdir(parents=True, exist_ok=True)
  295. safe_name = "".join(c if c.isalnum() or c in "._-" else "_" for c in (project_name or project_id))
  296. file_path = upload_dir / f"{safe_name}.jsonl"
  297. if file_path.exists():
  298. file_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}.jsonl"
  299. file_path.write_bytes(file_content)
  300. # 统一转为 JSONL 格式
  301. jsonl_path = _convert_to_jsonl(file_path)
  302. record_count = _count_records(jsonl_path, "jsonl")
  303. logger.info(f"Annotation file converted: {jsonl_path.name}, record_count={record_count}, format={used_format}")
  304. # 6. 写入数据库(格式统一为 jsonl)
  305. record_id = str(uuid.uuid4())
  306. record = DatasetRecord(
  307. id=record_id,
  308. name=project_name or project_id,
  309. format="jsonl",
  310. record_count=record_count,
  311. file_path=str(jsonl_path),
  312. created_at=datetime.utcnow(),
  313. )
  314. async with async_session() as session:
  315. session.add(record)
  316. await session.commit()
  317. logger.info(f"Imported dataset from annotation platform: {project_id} -> {jsonl_path.name} ({record_count} records)")
  318. return {
  319. "project_id": project_id,
  320. "project_name": project_name or project_id,
  321. "format": "jsonl",
  322. "total_exported": total_exported,
  323. "dataset_id": record_id,
  324. "dataset_name": jsonl_path.name,
  325. }
  326. def _convert_to_jsonl(file_path: Path) -> Path:
  327. """将 JSON/JSONL 文件统一转为 JSONL 格式。"""
  328. import json as _json
  329. jsonl_path = file_path.with_suffix(".jsonl")
  330. with open(file_path, "r", encoding="utf-8") as f:
  331. content = f.read().strip()
  332. if not content:
  333. jsonl_path.touch()
  334. return jsonl_path
  335. try:
  336. # 尝试作为 JSON 解析
  337. data = _json.loads(content)
  338. if isinstance(data, list):
  339. # JSON 数组
  340. items = data
  341. elif isinstance(data, dict):
  342. # JSON 对象:查找嵌套的数组字段
  343. items = None
  344. for key in ("data", "items", "results", "records", "annotations", "samples"):
  345. if key in data and isinstance(data[key], list):
  346. items = data[key]
  347. break
  348. if items is None:
  349. # 单个对象,包装为数组
  350. items = [data]
  351. else:
  352. items = None
  353. if items is not None:
  354. with open(jsonl_path, "w", encoding="utf-8") as out:
  355. for item in items:
  356. out.write(_json.dumps(item, ensure_ascii=False) + "\n")
  357. # 只有当原始文件与新文件不同时才删除(避免删除刚写入的文件)
  358. if jsonl_path != file_path and file_path.exists():
  359. file_path.unlink()
  360. return jsonl_path
  361. except _json.JSONDecodeError:
  362. pass
  363. # 不是标准 JSON,可能是 JSONL,逐行验证
  364. lines = content.split("\n")
  365. valid_lines = []
  366. for line in lines:
  367. line = line.strip()
  368. if line:
  369. try:
  370. _json.loads(line)
  371. valid_lines.append(line)
  372. except _json.JSONDecodeError:
  373. continue # 跳过无效行
  374. with open(jsonl_path, "w", encoding="utf-8") as out:
  375. out.write("\n".join(valid_lines) + ("\n" if valid_lines else ""))
  376. if jsonl_path != file_path and file_path.exists():
  377. file_path.unlink()
  378. return jsonl_path
  379. def _detect_format(filename: str) -> str:
  380. """根据文件名推断格式。"""
  381. name = filename.lower()
  382. if name.endswith(".jsonl"):
  383. return "jsonl"
  384. if name.endswith(".csv"):
  385. return "csv"
  386. if name.endswith(".parquet"):
  387. return "parquet"
  388. if name.endswith(".json"):
  389. return "json"
  390. return "json"
  391. def _count_records(file_path: Path, fmt: str) -> int:
  392. """计算文件中的记录数。"""
  393. import json
  394. if not file_path.exists():
  395. return 0
  396. try:
  397. if fmt == "jsonl":
  398. with open(file_path, "r", encoding="utf-8") as f:
  399. return sum(1 for line in f if line.strip())
  400. elif fmt == "json":
  401. with open(file_path, "r", encoding="utf-8") as f:
  402. data = json.load(f)
  403. if isinstance(data, list):
  404. return len(data)
  405. return 1
  406. elif fmt == "csv":
  407. import csv
  408. with open(file_path, "r", encoding="utf-8") as f:
  409. return sum(1 for _ in csv.DictReader(f))
  410. except Exception:
  411. return 0
  412. return 0