external_service.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966
  1. """
  2. External API Service.
  3. Provides business logic for external system integration.
  4. """
  5. import uuid
  6. import json
  7. import logging
  8. import random
  9. from datetime import datetime
  10. from typing import Optional, List, Dict, Any
  11. from database import get_db_connection
  12. from schemas.external import (
  13. TaskType, ProjectInitRequest, ProjectInitResponse,
  14. ProgressResponse, AnnotatorProgress,
  15. ExternalExportFormat, ExternalExportRequest, ExternalExportResponse,
  16. TaskDataItem, TagItem
  17. )
# Module-level logger named after this module.
# NOTE(review): no calls to it appear in the visible part of this file.
logger = logging.getLogger(__name__)
  19. def generate_random_color() -> str:
  20. """
  21. 生成随机颜色
  22. Returns:
  23. str: #RRGGBB 格式的颜色字符串
  24. """
  25. return f"#{random.randint(0, 0xFFFFFF):06x}"
  26. # 预定义的颜色列表,用于生成更美观的颜色
  27. PRESET_COLORS = [
  28. "#FF5733", "#33FF57", "#3357FF", "#FF33F5", "#F5FF33",
  29. "#33FFF5", "#FF8C33", "#8C33FF", "#33FF8C", "#FF338C",
  30. "#5733FF", "#57FF33", "#FF3357", "#33F5FF", "#F533FF",
  31. "#8CFF33", "#338CFF", "#FF338C", "#33FF57", "#5733FF"
  32. ]
  33. def get_color_for_tag(index: int, specified_color: Optional[str] = None) -> str:
  34. """
  35. 获取标签颜色
  36. Args:
  37. index: 标签索引,用于从预设颜色中选择
  38. specified_color: 指定的颜色,如果有则直接使用
  39. Returns:
  40. str: #RRGGBB 格式的颜色字符串
  41. """
  42. if specified_color:
  43. return specified_color
  44. if index < len(PRESET_COLORS):
  45. return PRESET_COLORS[index]
  46. return generate_random_color()
# Default XML config template per task type. The templates deliberately contain
# no labels: an administrator configures them later (see the placeholder XML
# comment inside each template). Each template pairs one data element
# (<Text> or <Image>, bound to $text / $image) with one control element named
# "label" that targets it.
DEFAULT_CONFIGS = {
    # Text + single-choice classification.
    TaskType.TEXT_CLASSIFICATION: """<View>
<Text name="text" value="$text"/>
<Choices name="label" toName="text" choice="single">
<!-- 标签由管理员配置 -->
</Choices>
</View>""",
    # Image + single-choice classification.
    TaskType.IMAGE_CLASSIFICATION: """<View>
<Image name="image" value="$image"/>
<Choices name="label" toName="image" choice="single">
<!-- 标签由管理员配置 -->
</Choices>
</View>""",
    # Image + bounding-box labels.
    TaskType.OBJECT_DETECTION: """<View>
<Image name="image" value="$image"/>
<RectangleLabels name="label" toName="image">
<!-- 标签由管理员配置 -->
</RectangleLabels>
</View>""",
    # Text + span labels (named-entity recognition).
    TaskType.NER: """<View>
<Text name="text" value="$text"/>
<Labels name="label" toName="text">
<!-- 标签由管理员配置 -->
</Labels>
</View>""",
    # Image + polygon labels.
    TaskType.POLYGON: """<View>
<Image name="image" value="$image"/>
<PolygonLabels name="label" toName="image">
<!-- 标签由管理员配置 -->
</PolygonLabels>
</View>"""
}
  80. def generate_config_with_tags(task_type: TaskType, tags: Optional[List[TagItem]] = None) -> str:
  81. """
  82. 根据任务类型和标签生成XML配置
  83. Args:
  84. task_type: 任务类型
  85. tags: 标签列表,可选
  86. Returns:
  87. str: 生成的XML配置字符串
  88. """
  89. if not tags or len(tags) == 0:
  90. # 没有标签,返回默认配置
  91. return DEFAULT_CONFIGS.get(task_type, DEFAULT_CONFIGS[TaskType.TEXT_CLASSIFICATION])
  92. # 根据任务类型生成带标签的配置
  93. if task_type == TaskType.TEXT_CLASSIFICATION:
  94. labels_xml = "\n".join([
  95. f' <Choice value="{tag.tag}" style="background-color: {get_color_for_tag(i, tag.color)}"/>'
  96. for i, tag in enumerate(tags)
  97. ])
  98. return f"""<View>
  99. <Text name="text" value="$text"/>
  100. <Choices name="label" toName="text" choice="single">
  101. {labels_xml}
  102. </Choices>
  103. </View>"""
  104. elif task_type == TaskType.IMAGE_CLASSIFICATION:
  105. labels_xml = "\n".join([
  106. f' <Choice value="{tag.tag}" style="background-color: {get_color_for_tag(i, tag.color)}"/>'
  107. for i, tag in enumerate(tags)
  108. ])
  109. return f"""<View>
  110. <Image name="image" value="$image"/>
  111. <Choices name="label" toName="image" choice="single">
  112. {labels_xml}
  113. </Choices>
  114. </View>"""
  115. elif task_type == TaskType.OBJECT_DETECTION:
  116. labels_xml = "\n".join([
  117. f' <Label value="{tag.tag}" background="{get_color_for_tag(i, tag.color)}"/>'
  118. for i, tag in enumerate(tags)
  119. ])
  120. return f"""<View>
  121. <Image name="image" value="$image"/>
  122. <RectangleLabels name="label" toName="image">
  123. {labels_xml}
  124. </RectangleLabels>
  125. </View>"""
  126. elif task_type == TaskType.NER:
  127. labels_xml = "\n".join([
  128. f' <Label value="{tag.tag}" background="{get_color_for_tag(i, tag.color)}"/>'
  129. for i, tag in enumerate(tags)
  130. ])
  131. return f"""<View>
  132. <Text name="text" value="$text"/>
  133. <Labels name="label" toName="text">
  134. {labels_xml}
  135. </Labels>
  136. </View>"""
  137. elif task_type == TaskType.POLYGON:
  138. labels_xml = "\n".join([
  139. f' <Label value="{tag.tag}" background="{get_color_for_tag(i, tag.color)}"/>'
  140. for i, tag in enumerate(tags)
  141. ])
  142. return f"""<View>
  143. <Image name="image" value="$image"/>
  144. <PolygonLabels name="label" toName="image">
  145. {labels_xml}
  146. </PolygonLabels>
  147. </View>"""
  148. else:
  149. return DEFAULT_CONFIGS.get(task_type, DEFAULT_CONFIGS[TaskType.TEXT_CLASSIFICATION])
  150. class ExternalService:
  151. """对外API服务类"""
  152. @staticmethod
  153. def get_default_config(task_type: TaskType) -> str:
  154. """获取任务类型对应的默认XML配置"""
  155. return DEFAULT_CONFIGS.get(task_type, DEFAULT_CONFIGS[TaskType.TEXT_CLASSIFICATION])
  156. @staticmethod
  157. def init_project(request: ProjectInitRequest, user_id: str) -> ProjectInitResponse:
  158. """
  159. 初始化项目并创建任务
  160. Args:
  161. request: 项目初始化请求
  162. user_id: 创建者用户ID
  163. Returns:
  164. ProjectInitResponse: 项目初始化响应
  165. """
  166. # 生成项目ID
  167. project_id = f"proj_{uuid.uuid4().hex[:12]}"
  168. # 根据是否有标签生成配置
  169. if request.tags and len(request.tags) > 0:
  170. config = generate_config_with_tags(request.task_type, request.tags)
  171. else:
  172. config = ExternalService.get_default_config(request.task_type)
  173. with get_db_connection() as conn:
  174. cursor = conn.cursor()
  175. # 创建项目
  176. cursor.execute("""
  177. INSERT INTO projects (id, name, description, config, status, source, task_type, external_id, updated_at)
  178. VALUES (?, ?, ?, ?, 'draft', 'external', ?, ?, CURRENT_TIMESTAMP)
  179. """, (
  180. project_id,
  181. request.name,
  182. request.description or "",
  183. config,
  184. request.task_type.value,
  185. request.external_id
  186. ))
  187. # 创建任务
  188. task_count = 0
  189. for i, item in enumerate(request.data):
  190. task_id = f"task_{uuid.uuid4().hex[:12]}"
  191. task_name = f"Task {i + 1}"
  192. # 根据任务类型构建数据格式
  193. if request.task_type in [TaskType.TEXT_CLASSIFICATION, TaskType.NER]:
  194. task_data = {
  195. "text": item.content,
  196. "external_id": item.id,
  197. "metadata": item.metadata or {}
  198. }
  199. else:
  200. task_data = {
  201. "image": item.content,
  202. "external_id": item.id,
  203. "metadata": item.metadata or {}
  204. }
  205. cursor.execute("""
  206. INSERT INTO tasks (id, project_id, name, data, status)
  207. VALUES (?, ?, ?, ?, 'pending')
  208. """, (
  209. task_id,
  210. project_id,
  211. task_name,
  212. json.dumps(task_data)
  213. ))
  214. task_count += 1
  215. # 获取创建时间
  216. cursor.execute("SELECT created_at FROM projects WHERE id = ?", (project_id,))
  217. row = cursor.fetchone()
  218. created_at = row["created_at"] if row else datetime.now()
  219. return ProjectInitResponse(
  220. project_id=project_id,
  221. project_name=request.name,
  222. task_count=task_count,
  223. status="draft",
  224. created_at=created_at,
  225. config=config,
  226. external_id=request.external_id
  227. )
  228. @staticmethod
  229. def get_project_progress(project_id: str) -> Optional[ProgressResponse]:
  230. """
  231. 获取项目进度
  232. Args:
  233. project_id: 项目ID
  234. Returns:
  235. ProgressResponse: 进度响应,如果项目不存在返回None
  236. """
  237. with get_db_connection() as conn:
  238. cursor = conn.cursor()
  239. # 获取项目信息
  240. cursor.execute("""
  241. SELECT id, name, status, updated_at
  242. FROM projects
  243. WHERE id = ?
  244. """, (project_id,))
  245. project = cursor.fetchone()
  246. if not project:
  247. return None
  248. # 获取任务统计
  249. cursor.execute("""
  250. SELECT
  251. COUNT(*) as total,
  252. SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed,
  253. SUM(CASE WHEN status = 'in_progress' THEN 1 ELSE 0 END) as in_progress,
  254. SUM(CASE WHEN status = 'pending' THEN 1 ELSE 0 END) as pending
  255. FROM tasks
  256. WHERE project_id = ?
  257. """, (project_id,))
  258. stats = cursor.fetchone()
  259. total_tasks = stats["total"] or 0
  260. completed_tasks = stats["completed"] or 0
  261. in_progress_tasks = stats["in_progress"] or 0
  262. pending_tasks = stats["pending"] or 0
  263. # 计算完成百分比
  264. completion_percentage = 0.0
  265. if total_tasks > 0:
  266. completion_percentage = round((completed_tasks / total_tasks) * 100, 2)
  267. # 获取标注人员统计
  268. cursor.execute("""
  269. SELECT
  270. t.assigned_to,
  271. u.username,
  272. COUNT(*) as assigned_count,
  273. SUM(CASE WHEN t.status = 'completed' THEN 1 ELSE 0 END) as completed_count,
  274. SUM(CASE WHEN t.status = 'in_progress' THEN 1 ELSE 0 END) as in_progress_count
  275. FROM tasks t
  276. LEFT JOIN users u ON t.assigned_to = u.id
  277. WHERE t.project_id = ? AND t.assigned_to IS NOT NULL
  278. GROUP BY t.assigned_to, u.username
  279. """, (project_id,))
  280. annotators = []
  281. for row in cursor.fetchall():
  282. assigned_count = row["assigned_count"] or 0
  283. completed_count = row["completed_count"] or 0
  284. completion_rate = 0.0
  285. if assigned_count > 0:
  286. completion_rate = round((completed_count / assigned_count) * 100, 2)
  287. annotators.append(AnnotatorProgress(
  288. user_id=row["assigned_to"] or "",
  289. username=row["username"] or "Unknown",
  290. assigned_count=assigned_count,
  291. completed_count=completed_count,
  292. in_progress_count=row["in_progress_count"] or 0,
  293. completion_rate=completion_rate
  294. ))
  295. return ProgressResponse(
  296. project_id=project_id,
  297. project_name=project["name"],
  298. status=project["status"] or "draft",
  299. total_tasks=total_tasks,
  300. completed_tasks=completed_tasks,
  301. in_progress_tasks=in_progress_tasks,
  302. pending_tasks=pending_tasks,
  303. completion_percentage=completion_percentage,
  304. annotators=annotators,
  305. last_updated=project["updated_at"]
  306. )
  307. @staticmethod
  308. def check_project_exists(project_id: str) -> bool:
  309. """检查项目是否存在"""
  310. with get_db_connection() as conn:
  311. cursor = conn.cursor()
  312. cursor.execute("SELECT id FROM projects WHERE id = ?", (project_id,))
  313. return cursor.fetchone() is not None
  314. @staticmethod
  315. def export_project_data(
  316. project_id: str,
  317. request: ExternalExportRequest,
  318. base_url: str = ""
  319. ) -> Optional[ExternalExportResponse]:
  320. """
  321. 导出项目数据
  322. Args:
  323. project_id: 项目ID
  324. request: 导出请求
  325. base_url: 基础URL,用于生成下载链接
  326. Returns:
  327. ExternalExportResponse: 导出响应,如果项目不存在返回None
  328. """
  329. import os
  330. from datetime import timedelta
  331. with get_db_connection() as conn:
  332. cursor = conn.cursor()
  333. # 检查项目是否存在
  334. cursor.execute("SELECT id, name FROM projects WHERE id = ?", (project_id,))
  335. project = cursor.fetchone()
  336. if not project:
  337. return None
  338. # 构建查询条件
  339. status_filter = ""
  340. if request.completed_only:
  341. status_filter = "AND t.status = 'completed'"
  342. # 获取任务和标注数据
  343. cursor.execute(f"""
  344. SELECT
  345. t.id as task_id,
  346. t.data,
  347. t.status,
  348. t.assigned_to,
  349. u.username as annotator_name,
  350. a.id as annotation_id,
  351. a.result as annotation_result,
  352. a.updated_at as annotation_time
  353. FROM tasks t
  354. LEFT JOIN users u ON t.assigned_to = u.id
  355. LEFT JOIN annotations a ON t.id = a.task_id
  356. WHERE t.project_id = ? {status_filter}
  357. ORDER BY t.id
  358. """, (project_id,))
  359. rows = cursor.fetchall()
  360. # 组织数据
  361. tasks_data = {}
  362. for row in rows:
  363. task_id = row["task_id"]
  364. if task_id not in tasks_data:
  365. task_data = json.loads(row["data"]) if row["data"] else {}
  366. tasks_data[task_id] = {
  367. "task_id": task_id,
  368. "external_id": task_data.get("external_id"),
  369. "original_data": task_data,
  370. "annotations": [],
  371. "status": row["status"],
  372. "annotator": row["annotator_name"],
  373. "completed_at": None
  374. }
  375. if row["annotation_id"]:
  376. annotation_result = json.loads(row["annotation_result"]) if row["annotation_result"] else {}
  377. tasks_data[task_id]["annotations"].append(annotation_result)
  378. if row["annotation_time"]:
  379. tasks_data[task_id]["completed_at"] = row["annotation_time"]
  380. # 转换为列表
  381. export_data = list(tasks_data.values())
  382. total_exported = len(export_data)
  383. # 生成导出文件
  384. export_id = f"export_{uuid.uuid4().hex[:12]}"
  385. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  386. # 确保导出目录存在
  387. export_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "exports")
  388. os.makedirs(export_dir, exist_ok=True)
  389. # 根据格式生成文件
  390. file_name, file_content = ExternalService._generate_export_file(
  391. export_data,
  392. request.format,
  393. project_id,
  394. timestamp
  395. )
  396. file_path = os.path.join(export_dir, file_name)
  397. # 写入文件
  398. if isinstance(file_content, str):
  399. with open(file_path, 'w', encoding='utf-8') as f:
  400. f.write(file_content)
  401. else:
  402. with open(file_path, 'wb') as f:
  403. f.write(file_content)
  404. file_size = os.path.getsize(file_path)
  405. # 记录导出任务
  406. cursor.execute("""
  407. INSERT INTO export_jobs (id, project_id, format, status, status_filter, file_path, total_tasks, exported_tasks, completed_at)
  408. VALUES (?, ?, ?, 'completed', ?, ?, ?, ?, CURRENT_TIMESTAMP)
  409. """, (
  410. export_id,
  411. project_id,
  412. request.format.value,
  413. 'completed' if request.completed_only else 'all',
  414. file_path,
  415. total_exported,
  416. total_exported
  417. ))
  418. # 计算过期时间(7天后)
  419. expires_at = datetime.now() + timedelta(days=7)
  420. return ExternalExportResponse(
  421. project_id=project_id,
  422. format=request.format.value,
  423. total_exported=total_exported,
  424. file_url=f"/api/exports/{export_id}/download",
  425. file_name=file_name,
  426. file_size=file_size,
  427. expires_at=expires_at
  428. )
  429. @staticmethod
  430. def _generate_export_file(
  431. data: List[Dict],
  432. format: ExternalExportFormat,
  433. project_id: str,
  434. timestamp: str
  435. ) -> tuple:
  436. """
  437. 根据格式生成导出文件
  438. Returns:
  439. tuple: (文件名, 文件内容)
  440. """
  441. if format == ExternalExportFormat.JSON:
  442. return ExternalService._export_json(data, project_id, timestamp)
  443. elif format == ExternalExportFormat.CSV:
  444. return ExternalService._export_csv(data, project_id, timestamp)
  445. elif format == ExternalExportFormat.SHAREGPT:
  446. return ExternalService._export_sharegpt(data, project_id, timestamp)
  447. elif format == ExternalExportFormat.YOLO:
  448. return ExternalService._export_yolo(data, project_id, timestamp)
  449. elif format == ExternalExportFormat.COCO:
  450. return ExternalService._export_coco(data, project_id, timestamp)
  451. elif format == ExternalExportFormat.ALPACA:
  452. return ExternalService._export_alpaca(data, project_id, timestamp)
  453. elif format == ExternalExportFormat.PASCAL_VOC:
  454. return ExternalService._export_pascal_voc(data, project_id, timestamp)
  455. else:
  456. return ExternalService._export_json(data, project_id, timestamp)
  457. @staticmethod
  458. def _export_json(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  459. """导出JSON格式"""
  460. file_name = f"export_{project_id}_{timestamp}.json"
  461. content = json.dumps(data, ensure_ascii=False, indent=2)
  462. return file_name, content
  463. @staticmethod
  464. def _export_csv(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  465. """导出CSV格式"""
  466. import csv
  467. import io
  468. file_name = f"export_{project_id}_{timestamp}.csv"
  469. output = io.StringIO()
  470. writer = csv.writer(output)
  471. # 写入表头
  472. writer.writerow(['task_id', 'external_id', 'status', 'annotator', 'original_data', 'annotations'])
  473. # 写入数据
  474. for item in data:
  475. writer.writerow([
  476. item.get('task_id', ''),
  477. item.get('external_id', ''),
  478. item.get('status', ''),
  479. item.get('annotator', ''),
  480. json.dumps(item.get('original_data', {}), ensure_ascii=False),
  481. json.dumps(item.get('annotations', []), ensure_ascii=False)
  482. ])
  483. return file_name, output.getvalue()
  484. @staticmethod
  485. def _export_sharegpt(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  486. """导出ShareGPT对话格式"""
  487. file_name = f"export_{project_id}_sharegpt_{timestamp}.json"
  488. conversations = []
  489. for item in data:
  490. original = item.get('original_data', {})
  491. annotations = item.get('annotations', [])
  492. # 获取原始文本
  493. text = original.get('text', original.get('image', ''))
  494. # 获取标注结果
  495. label = ""
  496. if annotations:
  497. for ann in annotations:
  498. if isinstance(ann, list):
  499. for a in ann:
  500. if 'value' in a and 'choices' in a['value']:
  501. label = ', '.join(a['value']['choices'])
  502. break
  503. elif isinstance(ann, dict):
  504. if 'value' in ann and 'choices' in ann['value']:
  505. label = ', '.join(ann['value']['choices'])
  506. if text and label:
  507. conversations.append({
  508. "conversations": [
  509. {"from": "human", "value": text},
  510. {"from": "gpt", "value": label}
  511. ]
  512. })
  513. content = json.dumps(conversations, ensure_ascii=False, indent=2)
  514. return file_name, content
  515. @staticmethod
  516. def _export_yolo(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  517. """导出YOLO格式(简化版,返回JSON描述)"""
  518. file_name = f"export_{project_id}_yolo_{timestamp}.json"
  519. yolo_data = []
  520. for item in data:
  521. original = item.get('original_data', {})
  522. annotations = item.get('annotations', [])
  523. image_url = original.get('image', '')
  524. boxes = []
  525. polygons = []
  526. for ann in annotations:
  527. if isinstance(ann, list):
  528. for a in ann:
  529. if a.get('type') == 'rectanglelabels':
  530. value = a.get('value', {})
  531. boxes.append({
  532. "label": value.get('rectanglelabels', [''])[0],
  533. "x": value.get('x', 0) / 100,
  534. "y": value.get('y', 0) / 100,
  535. "width": value.get('width', 0) / 100,
  536. "height": value.get('height', 0) / 100
  537. })
  538. elif a.get('type') == 'polygonlabels':
  539. value = a.get('value', {})
  540. points = value.get('points', [])
  541. # 将点坐标归一化到0-1范围
  542. normalized_points = [[p[0] / 100, p[1] / 100] for p in points]
  543. polygons.append({
  544. "label": value.get('polygonlabels', [''])[0],
  545. "points": normalized_points
  546. })
  547. if image_url:
  548. entry = {"image": image_url}
  549. if boxes:
  550. entry["boxes"] = boxes
  551. if polygons:
  552. entry["polygons"] = polygons
  553. yolo_data.append(entry)
  554. content = json.dumps(yolo_data, ensure_ascii=False, indent=2)
  555. return file_name, content
  556. @staticmethod
  557. def _export_coco(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  558. """导出COCO格式"""
  559. file_name = f"export_{project_id}_coco_{timestamp}.json"
  560. coco_data = {
  561. "images": [],
  562. "annotations": [],
  563. "categories": []
  564. }
  565. category_map = {}
  566. annotation_id = 1
  567. for idx, item in enumerate(data):
  568. original = item.get('original_data', {})
  569. annotations = item.get('annotations', [])
  570. image_url = original.get('image', '')
  571. # 添加图像
  572. coco_data["images"].append({
  573. "id": idx + 1,
  574. "file_name": image_url,
  575. "width": 0,
  576. "height": 0
  577. })
  578. # 处理标注
  579. for ann in annotations:
  580. if isinstance(ann, list):
  581. for a in ann:
  582. ann_type = a.get('type', '')
  583. value = a.get('value', {})
  584. if ann_type == 'rectanglelabels':
  585. label = value.get('rectanglelabels', [''])[0]
  586. # 添加类别
  587. if label and label not in category_map:
  588. cat_id = len(category_map) + 1
  589. category_map[label] = cat_id
  590. coco_data["categories"].append({
  591. "id": cat_id,
  592. "name": label
  593. })
  594. if label:
  595. coco_data["annotations"].append({
  596. "id": annotation_id,
  597. "image_id": idx + 1,
  598. "category_id": category_map.get(label, 0),
  599. "bbox": [
  600. value.get('x', 0),
  601. value.get('y', 0),
  602. value.get('width', 0),
  603. value.get('height', 0)
  604. ],
  605. "area": value.get('width', 0) * value.get('height', 0),
  606. "iscrowd": 0
  607. })
  608. annotation_id += 1
  609. elif ann_type == 'polygonlabels':
  610. label = value.get('polygonlabels', [''])[0]
  611. points = value.get('points', [])
  612. # 添加类别
  613. if label and label not in category_map:
  614. cat_id = len(category_map) + 1
  615. category_map[label] = cat_id
  616. coco_data["categories"].append({
  617. "id": cat_id,
  618. "name": label
  619. })
  620. if label and points:
  621. # 将点列表转换为COCO segmentation格式 [x1, y1, x2, y2, ...]
  622. segmentation = []
  623. for p in points:
  624. segmentation.extend([p[0], p[1]])
  625. # 计算边界框
  626. x_coords = [p[0] for p in points]
  627. y_coords = [p[1] for p in points]
  628. x_min, x_max = min(x_coords), max(x_coords)
  629. y_min, y_max = min(y_coords), max(y_coords)
  630. width = x_max - x_min
  631. height = y_max - y_min
  632. # 计算面积(使用鞋带公式)
  633. n = len(points)
  634. area = 0
  635. for i in range(n):
  636. j = (i + 1) % n
  637. area += points[i][0] * points[j][1]
  638. area -= points[j][0] * points[i][1]
  639. area = abs(area) / 2
  640. coco_data["annotations"].append({
  641. "id": annotation_id,
  642. "image_id": idx + 1,
  643. "category_id": category_map.get(label, 0),
  644. "segmentation": [segmentation],
  645. "bbox": [x_min, y_min, width, height],
  646. "area": area,
  647. "iscrowd": 0
  648. })
  649. annotation_id += 1
  650. content = json.dumps(coco_data, ensure_ascii=False, indent=2)
  651. return file_name, content
  652. @staticmethod
  653. def _export_alpaca(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  654. """导出Alpaca指令微调格式"""
  655. file_name = f"export_{project_id}_alpaca_{timestamp}.json"
  656. alpaca_data = []
  657. for item in data:
  658. original = item.get('original_data', {})
  659. annotations = item.get('annotations', [])
  660. # 获取原始文本
  661. text = original.get('text', '')
  662. # 获取标注结果
  663. label = ""
  664. if annotations:
  665. for ann in annotations:
  666. if isinstance(ann, list):
  667. for a in ann:
  668. if 'value' in a and 'choices' in a['value']:
  669. label = ', '.join(a['value']['choices'])
  670. break
  671. elif isinstance(ann, dict):
  672. if 'value' in ann and 'choices' in ann['value']:
  673. label = ', '.join(ann['value']['choices'])
  674. if text:
  675. alpaca_data.append({
  676. "instruction": "请对以下文本进行分类",
  677. "input": text,
  678. "output": label or "未标注"
  679. })
  680. content = json.dumps(alpaca_data, ensure_ascii=False, indent=2)
  681. return file_name, content
  682. @staticmethod
  683. def _export_pascal_voc(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  684. """
  685. 导出PascalVOC XML格式
  686. PascalVOC格式是一种常用的目标检测数据集格式,每张图片对应一个XML文件。
  687. 由于我们需要返回单个文件,这里返回一个包含所有标注的JSON文件,
  688. 其中每个条目包含对应的PascalVOC XML内容。
  689. """
  690. file_name = f"export_{project_id}_pascal_voc_{timestamp}.json"
  691. voc_data = []
  692. for idx, item in enumerate(data):
  693. original = item.get('original_data', {})
  694. annotations = item.get('annotations', [])
  695. image_url = original.get('image', '')
  696. # 从URL中提取文件名
  697. image_filename = image_url.split('/')[-1] if image_url else f"image_{idx + 1}.jpg"
  698. # 获取图像尺寸(如果有的话)
  699. img_width = original.get('width', 0)
  700. img_height = original.get('height', 0)
  701. objects = []
  702. # 处理标注
  703. for ann in annotations:
  704. if isinstance(ann, list):
  705. for a in ann:
  706. ann_type = a.get('type', '')
  707. value = a.get('value', {})
  708. if ann_type == 'rectanglelabels':
  709. label = value.get('rectanglelabels', [''])[0]
  710. if label:
  711. # 转换百分比坐标为像素坐标
  712. x_pct = value.get('x', 0)
  713. y_pct = value.get('y', 0)
  714. w_pct = value.get('width', 0)
  715. h_pct = value.get('height', 0)
  716. # 如果有图像尺寸,转换为像素;否则保持百分比
  717. if img_width > 0 and img_height > 0:
  718. xmin = int(x_pct * img_width / 100)
  719. ymin = int(y_pct * img_height / 100)
  720. xmax = int((x_pct + w_pct) * img_width / 100)
  721. ymax = int((y_pct + h_pct) * img_height / 100)
  722. else:
  723. xmin = x_pct
  724. ymin = y_pct
  725. xmax = x_pct + w_pct
  726. ymax = y_pct + h_pct
  727. objects.append({
  728. "name": label,
  729. "pose": "Unspecified",
  730. "truncated": 0,
  731. "difficult": 0,
  732. "bndbox": {
  733. "xmin": xmin,
  734. "ymin": ymin,
  735. "xmax": xmax,
  736. "ymax": ymax
  737. }
  738. })
  739. elif ann_type == 'polygonlabels':
  740. label = value.get('polygonlabels', [''])[0]
  741. points = value.get('points', [])
  742. if label and points:
  743. # 计算边界框
  744. x_coords = [p[0] for p in points]
  745. y_coords = [p[1] for p in points]
  746. if img_width > 0 and img_height > 0:
  747. xmin = int(min(x_coords) * img_width / 100)
  748. ymin = int(min(y_coords) * img_height / 100)
  749. xmax = int(max(x_coords) * img_width / 100)
  750. ymax = int(max(y_coords) * img_height / 100)
  751. else:
  752. xmin = min(x_coords)
  753. ymin = min(y_coords)
  754. xmax = max(x_coords)
  755. ymax = max(y_coords)
  756. # 转换多边形点坐标
  757. if img_width > 0 and img_height > 0:
  758. polygon_points = [[int(p[0] * img_width / 100), int(p[1] * img_height / 100)] for p in points]
  759. else:
  760. polygon_points = points
  761. objects.append({
  762. "name": label,
  763. "pose": "Unspecified",
  764. "truncated": 0,
  765. "difficult": 0,
  766. "bndbox": {
  767. "xmin": xmin,
  768. "ymin": ymin,
  769. "xmax": xmax,
  770. "ymax": ymax
  771. },
  772. "polygon": polygon_points
  773. })
  774. # 生成PascalVOC XML内容
  775. xml_content = ExternalService._generate_voc_xml(
  776. image_filename,
  777. img_width or 0,
  778. img_height or 0,
  779. objects
  780. )
  781. voc_data.append({
  782. "image": image_url,
  783. "filename": image_filename,
  784. "xml_content": xml_content,
  785. "objects": objects
  786. })
  787. content = json.dumps(voc_data, ensure_ascii=False, indent=2)
  788. return file_name, content
  789. @staticmethod
  790. def _generate_voc_xml(filename: str, width: int, height: int, objects: List[Dict]) -> str:
  791. """生成PascalVOC格式的XML字符串"""
  792. xml_lines = [
  793. '<?xml version="1.0" encoding="UTF-8"?>',
  794. '<annotation>',
  795. f' <filename>{filename}</filename>',
  796. ' <source>',
  797. ' <database>Annotation Platform</database>',
  798. ' </source>',
  799. ' <size>',
  800. f' <width>{width}</width>',
  801. f' <height>{height}</height>',
  802. ' <depth>3</depth>',
  803. ' </size>',
  804. ' <segmented>0</segmented>'
  805. ]
  806. for obj in objects:
  807. xml_lines.append(' <object>')
  808. xml_lines.append(f' <name>{obj["name"]}</name>')
  809. xml_lines.append(f' <pose>{obj.get("pose", "Unspecified")}</pose>')
  810. xml_lines.append(f' <truncated>{obj.get("truncated", 0)}</truncated>')
  811. xml_lines.append(f' <difficult>{obj.get("difficult", 0)}</difficult>')
  812. xml_lines.append(' <bndbox>')
  813. xml_lines.append(f' <xmin>{obj["bndbox"]["xmin"]}</xmin>')
  814. xml_lines.append(f' <ymin>{obj["bndbox"]["ymin"]}</ymin>')
  815. xml_lines.append(f' <xmax>{obj["bndbox"]["xmax"]}</xmax>')
  816. xml_lines.append(f' <ymax>{obj["bndbox"]["ymax"]}</ymax>')
  817. xml_lines.append(' </bndbox>')
  818. # 如果有多边形数据,也添加进去
  819. if 'polygon' in obj:
  820. xml_lines.append(' <polygon>')
  821. for point in obj['polygon']:
  822. xml_lines.append(f' <pt><x>{point[0]}</x><y>{point[1]}</y></pt>')
  823. xml_lines.append(' </polygon>')
  824. xml_lines.append(' </object>')
  825. xml_lines.append('</annotation>')
  826. return '\n'.join(xml_lines)