# external_service.py
"""
External API Service.

Provides business logic for external system integration.
"""
import csv
import io
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, Iterator, List, Optional, Tuple

from database import get_db_connection
from schemas.external import (
    TaskType, ProjectInitRequest, ProjectInitResponse,
    ProgressResponse, AnnotatorProgress,
    ExternalExportFormat, ExternalExportRequest, ExternalExportResponse,
    TaskDataItem
)
logger = logging.getLogger(__name__)

# Default XML labeling-config templates, keyed by task type.
# The templates intentionally contain no labels (only a placeholder XML
# comment); an administrator is expected to configure the labels afterwards.
DEFAULT_CONFIGS = {
    TaskType.TEXT_CLASSIFICATION: """<View>
<Text name="text" value="$text"/>
<Choices name="label" toName="text" choice="single">
<!-- 标签由管理员配置 -->
</Choices>
</View>""",
    TaskType.IMAGE_CLASSIFICATION: """<View>
<Image name="image" value="$image"/>
<Choices name="label" toName="image" choice="single">
<!-- 标签由管理员配置 -->
</Choices>
</View>""",
    TaskType.OBJECT_DETECTION: """<View>
<Image name="image" value="$image"/>
<RectangleLabels name="label" toName="image">
<!-- 标签由管理员配置 -->
</RectangleLabels>
</View>""",
    TaskType.NER: """<View>
<Text name="text" value="$text"/>
<Labels name="label" toName="text">
<!-- 标签由管理员配置 -->
</Labels>
</View>"""
}
  45. class ExternalService:
  46. """对外API服务类"""
  47. @staticmethod
  48. def get_default_config(task_type: TaskType) -> str:
  49. """获取任务类型对应的默认XML配置"""
  50. return DEFAULT_CONFIGS.get(task_type, DEFAULT_CONFIGS[TaskType.TEXT_CLASSIFICATION])
  51. @staticmethod
  52. def init_project(request: ProjectInitRequest, user_id: str) -> ProjectInitResponse:
  53. """
  54. 初始化项目并创建任务
  55. Args:
  56. request: 项目初始化请求
  57. user_id: 创建者用户ID
  58. Returns:
  59. ProjectInitResponse: 项目初始化响应
  60. """
  61. # 生成项目ID
  62. project_id = f"proj_{uuid.uuid4().hex[:12]}"
  63. # 获取默认配置
  64. config = ExternalService.get_default_config(request.task_type)
  65. with get_db_connection() as conn:
  66. cursor = conn.cursor()
  67. # 创建项目
  68. cursor.execute("""
  69. INSERT INTO projects (id, name, description, config, status, source, task_type, external_id, updated_at)
  70. VALUES (?, ?, ?, ?, 'draft', 'external', ?, ?, CURRENT_TIMESTAMP)
  71. """, (
  72. project_id,
  73. request.name,
  74. request.description or "",
  75. config,
  76. request.task_type.value,
  77. request.external_id
  78. ))
  79. # 创建任务
  80. task_count = 0
  81. for i, item in enumerate(request.data):
  82. task_id = f"task_{uuid.uuid4().hex[:12]}"
  83. task_name = f"Task {i + 1}"
  84. # 根据任务类型构建数据格式
  85. if request.task_type in [TaskType.TEXT_CLASSIFICATION, TaskType.NER]:
  86. task_data = {
  87. "text": item.content,
  88. "external_id": item.id,
  89. "metadata": item.metadata or {}
  90. }
  91. else:
  92. task_data = {
  93. "image": item.content,
  94. "external_id": item.id,
  95. "metadata": item.metadata or {}
  96. }
  97. cursor.execute("""
  98. INSERT INTO tasks (id, project_id, name, data, status)
  99. VALUES (?, ?, ?, ?, 'pending')
  100. """, (
  101. task_id,
  102. project_id,
  103. task_name,
  104. json.dumps(task_data)
  105. ))
  106. task_count += 1
  107. # 获取创建时间
  108. cursor.execute("SELECT created_at FROM projects WHERE id = ?", (project_id,))
  109. row = cursor.fetchone()
  110. created_at = row["created_at"] if row else datetime.now()
  111. return ProjectInitResponse(
  112. project_id=project_id,
  113. project_name=request.name,
  114. task_count=task_count,
  115. status="draft",
  116. created_at=created_at,
  117. config=config,
  118. external_id=request.external_id
  119. )
  120. @staticmethod
  121. def get_project_progress(project_id: str) -> Optional[ProgressResponse]:
  122. """
  123. 获取项目进度
  124. Args:
  125. project_id: 项目ID
  126. Returns:
  127. ProgressResponse: 进度响应,如果项目不存在返回None
  128. """
  129. with get_db_connection() as conn:
  130. cursor = conn.cursor()
  131. # 获取项目信息
  132. cursor.execute("""
  133. SELECT id, name, status, updated_at
  134. FROM projects
  135. WHERE id = ?
  136. """, (project_id,))
  137. project = cursor.fetchone()
  138. if not project:
  139. return None
  140. # 获取任务统计
  141. cursor.execute("""
  142. SELECT
  143. COUNT(*) as total,
  144. SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed,
  145. SUM(CASE WHEN status = 'in_progress' THEN 1 ELSE 0 END) as in_progress,
  146. SUM(CASE WHEN status = 'pending' THEN 1 ELSE 0 END) as pending
  147. FROM tasks
  148. WHERE project_id = ?
  149. """, (project_id,))
  150. stats = cursor.fetchone()
  151. total_tasks = stats["total"] or 0
  152. completed_tasks = stats["completed"] or 0
  153. in_progress_tasks = stats["in_progress"] or 0
  154. pending_tasks = stats["pending"] or 0
  155. # 计算完成百分比
  156. completion_percentage = 0.0
  157. if total_tasks > 0:
  158. completion_percentage = round((completed_tasks / total_tasks) * 100, 2)
  159. # 获取标注人员统计
  160. cursor.execute("""
  161. SELECT
  162. t.assigned_to,
  163. u.username,
  164. COUNT(*) as assigned_count,
  165. SUM(CASE WHEN t.status = 'completed' THEN 1 ELSE 0 END) as completed_count,
  166. SUM(CASE WHEN t.status = 'in_progress' THEN 1 ELSE 0 END) as in_progress_count
  167. FROM tasks t
  168. LEFT JOIN users u ON t.assigned_to = u.id
  169. WHERE t.project_id = ? AND t.assigned_to IS NOT NULL
  170. GROUP BY t.assigned_to, u.username
  171. """, (project_id,))
  172. annotators = []
  173. for row in cursor.fetchall():
  174. assigned_count = row["assigned_count"] or 0
  175. completed_count = row["completed_count"] or 0
  176. completion_rate = 0.0
  177. if assigned_count > 0:
  178. completion_rate = round((completed_count / assigned_count) * 100, 2)
  179. annotators.append(AnnotatorProgress(
  180. user_id=row["assigned_to"] or "",
  181. username=row["username"] or "Unknown",
  182. assigned_count=assigned_count,
  183. completed_count=completed_count,
  184. in_progress_count=row["in_progress_count"] or 0,
  185. completion_rate=completion_rate
  186. ))
  187. return ProgressResponse(
  188. project_id=project_id,
  189. project_name=project["name"],
  190. status=project["status"] or "draft",
  191. total_tasks=total_tasks,
  192. completed_tasks=completed_tasks,
  193. in_progress_tasks=in_progress_tasks,
  194. pending_tasks=pending_tasks,
  195. completion_percentage=completion_percentage,
  196. annotators=annotators,
  197. last_updated=project["updated_at"]
  198. )
  199. @staticmethod
  200. def check_project_exists(project_id: str) -> bool:
  201. """检查项目是否存在"""
  202. with get_db_connection() as conn:
  203. cursor = conn.cursor()
  204. cursor.execute("SELECT id FROM projects WHERE id = ?", (project_id,))
  205. return cursor.fetchone() is not None
  206. @staticmethod
  207. def export_project_data(
  208. project_id: str,
  209. request: ExternalExportRequest,
  210. base_url: str = ""
  211. ) -> Optional[ExternalExportResponse]:
  212. """
  213. 导出项目数据
  214. Args:
  215. project_id: 项目ID
  216. request: 导出请求
  217. base_url: 基础URL,用于生成下载链接
  218. Returns:
  219. ExternalExportResponse: 导出响应,如果项目不存在返回None
  220. """
  221. import os
  222. from datetime import timedelta
  223. with get_db_connection() as conn:
  224. cursor = conn.cursor()
  225. # 检查项目是否存在
  226. cursor.execute("SELECT id, name FROM projects WHERE id = ?", (project_id,))
  227. project = cursor.fetchone()
  228. if not project:
  229. return None
  230. # 构建查询条件
  231. status_filter = ""
  232. if request.completed_only:
  233. status_filter = "AND t.status = 'completed'"
  234. # 获取任务和标注数据
  235. cursor.execute(f"""
  236. SELECT
  237. t.id as task_id,
  238. t.data,
  239. t.status,
  240. t.assigned_to,
  241. u.username as annotator_name,
  242. a.id as annotation_id,
  243. a.result as annotation_result,
  244. a.updated_at as annotation_time
  245. FROM tasks t
  246. LEFT JOIN users u ON t.assigned_to = u.id
  247. LEFT JOIN annotations a ON t.id = a.task_id
  248. WHERE t.project_id = ? {status_filter}
  249. ORDER BY t.id
  250. """, (project_id,))
  251. rows = cursor.fetchall()
  252. # 组织数据
  253. tasks_data = {}
  254. for row in rows:
  255. task_id = row["task_id"]
  256. if task_id not in tasks_data:
  257. task_data = json.loads(row["data"]) if row["data"] else {}
  258. tasks_data[task_id] = {
  259. "task_id": task_id,
  260. "external_id": task_data.get("external_id"),
  261. "original_data": task_data,
  262. "annotations": [],
  263. "status": row["status"],
  264. "annotator": row["annotator_name"],
  265. "completed_at": None
  266. }
  267. if row["annotation_id"]:
  268. annotation_result = json.loads(row["annotation_result"]) if row["annotation_result"] else {}
  269. tasks_data[task_id]["annotations"].append(annotation_result)
  270. if row["annotation_time"]:
  271. tasks_data[task_id]["completed_at"] = row["annotation_time"]
  272. # 转换为列表
  273. export_data = list(tasks_data.values())
  274. total_exported = len(export_data)
  275. # 生成导出文件
  276. export_id = f"export_{uuid.uuid4().hex[:12]}"
  277. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  278. # 确保导出目录存在
  279. export_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "exports")
  280. os.makedirs(export_dir, exist_ok=True)
  281. # 根据格式生成文件
  282. file_name, file_content = ExternalService._generate_export_file(
  283. export_data,
  284. request.format,
  285. project_id,
  286. timestamp
  287. )
  288. file_path = os.path.join(export_dir, file_name)
  289. # 写入文件
  290. if isinstance(file_content, str):
  291. with open(file_path, 'w', encoding='utf-8') as f:
  292. f.write(file_content)
  293. else:
  294. with open(file_path, 'wb') as f:
  295. f.write(file_content)
  296. file_size = os.path.getsize(file_path)
  297. # 记录导出任务
  298. cursor.execute("""
  299. INSERT INTO export_jobs (id, project_id, format, status, status_filter, file_path, total_tasks, exported_tasks, completed_at)
  300. VALUES (?, ?, ?, 'completed', ?, ?, ?, ?, CURRENT_TIMESTAMP)
  301. """, (
  302. export_id,
  303. project_id,
  304. request.format.value,
  305. 'completed' if request.completed_only else 'all',
  306. file_path,
  307. total_exported,
  308. total_exported
  309. ))
  310. # 计算过期时间(7天后)
  311. expires_at = datetime.now() + timedelta(days=7)
  312. return ExternalExportResponse(
  313. project_id=project_id,
  314. format=request.format.value,
  315. total_exported=total_exported,
  316. file_url=f"/api/exports/{export_id}/download",
  317. file_name=file_name,
  318. file_size=file_size,
  319. expires_at=expires_at
  320. )
  321. @staticmethod
  322. def _generate_export_file(
  323. data: List[Dict],
  324. format: ExternalExportFormat,
  325. project_id: str,
  326. timestamp: str
  327. ) -> tuple:
  328. """
  329. 根据格式生成导出文件
  330. Returns:
  331. tuple: (文件名, 文件内容)
  332. """
  333. if format == ExternalExportFormat.JSON:
  334. return ExternalService._export_json(data, project_id, timestamp)
  335. elif format == ExternalExportFormat.CSV:
  336. return ExternalService._export_csv(data, project_id, timestamp)
  337. elif format == ExternalExportFormat.SHAREGPT:
  338. return ExternalService._export_sharegpt(data, project_id, timestamp)
  339. elif format == ExternalExportFormat.YOLO:
  340. return ExternalService._export_yolo(data, project_id, timestamp)
  341. elif format == ExternalExportFormat.COCO:
  342. return ExternalService._export_coco(data, project_id, timestamp)
  343. elif format == ExternalExportFormat.ALPACA:
  344. return ExternalService._export_alpaca(data, project_id, timestamp)
  345. else:
  346. return ExternalService._export_json(data, project_id, timestamp)
  347. @staticmethod
  348. def _export_json(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  349. """导出JSON格式"""
  350. file_name = f"export_{project_id}_{timestamp}.json"
  351. content = json.dumps(data, ensure_ascii=False, indent=2)
  352. return file_name, content
  353. @staticmethod
  354. def _export_csv(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  355. """导出CSV格式"""
  356. import csv
  357. import io
  358. file_name = f"export_{project_id}_{timestamp}.csv"
  359. output = io.StringIO()
  360. writer = csv.writer(output)
  361. # 写入表头
  362. writer.writerow(['task_id', 'external_id', 'status', 'annotator', 'original_data', 'annotations'])
  363. # 写入数据
  364. for item in data:
  365. writer.writerow([
  366. item.get('task_id', ''),
  367. item.get('external_id', ''),
  368. item.get('status', ''),
  369. item.get('annotator', ''),
  370. json.dumps(item.get('original_data', {}), ensure_ascii=False),
  371. json.dumps(item.get('annotations', []), ensure_ascii=False)
  372. ])
  373. return file_name, output.getvalue()
  374. @staticmethod
  375. def _export_sharegpt(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  376. """导出ShareGPT对话格式"""
  377. file_name = f"export_{project_id}_sharegpt_{timestamp}.json"
  378. conversations = []
  379. for item in data:
  380. original = item.get('original_data', {})
  381. annotations = item.get('annotations', [])
  382. # 获取原始文本
  383. text = original.get('text', original.get('image', ''))
  384. # 获取标注结果
  385. label = ""
  386. if annotations:
  387. for ann in annotations:
  388. if isinstance(ann, list):
  389. for a in ann:
  390. if 'value' in a and 'choices' in a['value']:
  391. label = ', '.join(a['value']['choices'])
  392. break
  393. elif isinstance(ann, dict):
  394. if 'value' in ann and 'choices' in ann['value']:
  395. label = ', '.join(ann['value']['choices'])
  396. if text and label:
  397. conversations.append({
  398. "conversations": [
  399. {"from": "human", "value": text},
  400. {"from": "gpt", "value": label}
  401. ]
  402. })
  403. content = json.dumps(conversations, ensure_ascii=False, indent=2)
  404. return file_name, content
  405. @staticmethod
  406. def _export_yolo(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  407. """导出YOLO格式(简化版,返回JSON描述)"""
  408. file_name = f"export_{project_id}_yolo_{timestamp}.json"
  409. yolo_data = []
  410. for item in data:
  411. original = item.get('original_data', {})
  412. annotations = item.get('annotations', [])
  413. image_url = original.get('image', '')
  414. boxes = []
  415. for ann in annotations:
  416. if isinstance(ann, list):
  417. for a in ann:
  418. if a.get('type') == 'rectanglelabels':
  419. value = a.get('value', {})
  420. boxes.append({
  421. "label": value.get('rectanglelabels', [''])[0],
  422. "x": value.get('x', 0) / 100,
  423. "y": value.get('y', 0) / 100,
  424. "width": value.get('width', 0) / 100,
  425. "height": value.get('height', 0) / 100
  426. })
  427. if image_url:
  428. yolo_data.append({
  429. "image": image_url,
  430. "boxes": boxes
  431. })
  432. content = json.dumps(yolo_data, ensure_ascii=False, indent=2)
  433. return file_name, content
  434. @staticmethod
  435. def _export_coco(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  436. """导出COCO格式"""
  437. file_name = f"export_{project_id}_coco_{timestamp}.json"
  438. coco_data = {
  439. "images": [],
  440. "annotations": [],
  441. "categories": []
  442. }
  443. category_map = {}
  444. annotation_id = 1
  445. for idx, item in enumerate(data):
  446. original = item.get('original_data', {})
  447. annotations = item.get('annotations', [])
  448. image_url = original.get('image', '')
  449. # 添加图像
  450. coco_data["images"].append({
  451. "id": idx + 1,
  452. "file_name": image_url,
  453. "width": 0,
  454. "height": 0
  455. })
  456. # 处理标注
  457. for ann in annotations:
  458. if isinstance(ann, list):
  459. for a in ann:
  460. if a.get('type') == 'rectanglelabels':
  461. value = a.get('value', {})
  462. label = value.get('rectanglelabels', [''])[0]
  463. # 添加类别
  464. if label and label not in category_map:
  465. cat_id = len(category_map) + 1
  466. category_map[label] = cat_id
  467. coco_data["categories"].append({
  468. "id": cat_id,
  469. "name": label
  470. })
  471. if label:
  472. coco_data["annotations"].append({
  473. "id": annotation_id,
  474. "image_id": idx + 1,
  475. "category_id": category_map.get(label, 0),
  476. "bbox": [
  477. value.get('x', 0),
  478. value.get('y', 0),
  479. value.get('width', 0),
  480. value.get('height', 0)
  481. ],
  482. "area": value.get('width', 0) * value.get('height', 0),
  483. "iscrowd": 0
  484. })
  485. annotation_id += 1
  486. content = json.dumps(coco_data, ensure_ascii=False, indent=2)
  487. return file_name, content
  488. @staticmethod
  489. def _export_alpaca(data: List[Dict], project_id: str, timestamp: str) -> tuple:
  490. """导出Alpaca指令微调格式"""
  491. file_name = f"export_{project_id}_alpaca_{timestamp}.json"
  492. alpaca_data = []
  493. for item in data:
  494. original = item.get('original_data', {})
  495. annotations = item.get('annotations', [])
  496. # 获取原始文本
  497. text = original.get('text', '')
  498. # 获取标注结果
  499. label = ""
  500. if annotations:
  501. for ann in annotations:
  502. if isinstance(ann, list):
  503. for a in ann:
  504. if 'value' in a and 'choices' in a['value']:
  505. label = ', '.join(a['value']['choices'])
  506. break
  507. elif isinstance(ann, dict):
  508. if 'value' in ann and 'choices' in ann['value']:
  509. label = ', '.join(ann['value']['choices'])
  510. if text:
  511. alpaca_data.append({
  512. "instruction": "请对以下文本进行分类",
  513. "input": text,
  514. "output": label or "未标注"
  515. })
  516. content = json.dumps(alpaca_data, ensure_ascii=False, indent=2)
  517. return file_name, content