open_project_view.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import os
  2. import logging
  3. from datetime import datetime, timezone, timedelta
  4. from typing import Optional
  5. from fastapi import APIRouter, Query, Depends, HTTPException, status
  6. from fastapi.responses import FileResponse
  7. from schemas.open_project import (
  8. OpenProjectListResponse, OpenProjectListData, OpenProjectDetailResponse,
  9. OpenProjectItem, OpenProjectDetailItem,
  10. )
  11. from schemas.open_dataset import (
  12. DatasetDownloadRequest, DatasetDownloadResponse, DatasetDownloadResponseData,
  13. TEXT_FORMATS, IMAGE_FORMATS,
  14. )
  15. from services.api.open_project_service import list_projects, get_project_detail, TASK_TO_PROJECT_TYPE, create_download_token, get_download_info
  16. from services.open_middleware import verify_open_api_token
  17. from services.export_service import ExportService
  18. from database import get_db_connection
  19. logger = logging.getLogger(__name__)
  20. router = APIRouter(
  21. prefix="/api/v1/open/projects",
  22. tags=["open-api-projects"],
  23. )
  24. @router.get("", response_model=OpenProjectListResponse)
  25. async def list_projects_endpoint(
  26. _auth: dict = Depends(verify_open_api_token),
  27. name: Optional[str] = Query(None, description="项目名称(模糊匹配)"),
  28. project_type: Optional[str] = Query(None, alias="type", description="项目类型: image/text"),
  29. status: Optional[str] = Query(None, description="项目状态筛选"),
  30. page: int = Query(1, ge=1, description="页码"),
  31. page_size: int = Query(20, ge=1, le=100, description="每页数量"),
  32. ):
  33. """查询标注项目列表"""
  34. try:
  35. result = list_projects(
  36. name=name,
  37. project_type=project_type,
  38. status=status,
  39. page=page,
  40. page_size=page_size,
  41. )
  42. except Exception as e:
  43. logger.error(f"Query projects failed: {e}")
  44. raise HTTPException(
  45. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  46. detail={"error_code": "INTERNAL_ERROR", "message": str(e)},
  47. )
  48. return OpenProjectListResponse(
  49. code=0,
  50. message="success",
  51. data=OpenProjectListData(**result),
  52. )
  53. @router.get("/{project_id}", response_model=OpenProjectDetailResponse)
  54. async def get_project_detail_endpoint(
  55. project_id: str,
  56. _auth: dict = Depends(verify_open_api_token),
  57. ):
  58. """根据项目 ID 查询项目详细信息"""
  59. try:
  60. result = get_project_detail(project_id)
  61. except Exception as e:
  62. logger.error(f"Query project detail failed: {e}")
  63. raise HTTPException(
  64. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  65. detail={"error_code": "INTERNAL_ERROR", "message": str(e)},
  66. )
  67. if not result:
  68. raise HTTPException(
  69. status_code=status.HTTP_404_NOT_FOUND,
  70. detail={"error_code": "PROJECT_NOT_FOUND", "message": "项目不存在"},
  71. )
  72. return OpenProjectDetailResponse(
  73. code=0,
  74. message="success",
  75. data=OpenProjectDetailItem(**result),
  76. )
  77. # --- Dataset download router ---
  78. download_router = APIRouter(
  79. prefix="/api/v1/open",
  80. tags=["open-api-datasets"],
  81. )
  82. @download_router.post("/projects/{project_id}/datasets/download", response_model=DatasetDownloadResponse)
  83. async def download_dataset(
  84. project_id: str,
  85. req: DatasetDownloadRequest,
  86. _auth: dict = Depends(verify_open_api_token),
  87. ):
  88. """根据项目 ID 和格式导出标注数据集"""
  89. # 1. Check project exists
  90. project = _check_project(project_id)
  91. if not project:
  92. raise HTTPException(
  93. status_code=status.HTTP_404_NOT_FOUND,
  94. detail={"error_code": "PROJECT_NOT_FOUND", "message": "项目不存在"},
  95. )
  96. # 2. Check format compatibility
  97. project_type = TASK_TO_PROJECT_TYPE.get(project["task_type"] or "", "text")
  98. if req.format in TEXT_FORMATS and project_type != "text":
  99. raise HTTPException(
  100. status_code=status.HTTP_400_BAD_REQUEST,
  101. detail={"error_code": "FORMAT_NOT_COMPATIBLE", "message": f"格式 {req.format.value} 不适用于 {project_type} 类型项目"},
  102. )
  103. if req.format in IMAGE_FORMATS and project_type != "image":
  104. raise HTTPException(
  105. status_code=status.HTTP_400_BAD_REQUEST,
  106. detail={"error_code": "FORMAT_NOT_COMPATIBLE", "message": f"格式 {req.format.value} 不适用于 {project_type} 类型项目"},
  107. )
  108. # 3. Check data availability
  109. status_filter = "completed" if req.completed_only else "all"
  110. tasks = ExportService.get_tasks_with_annotations(project_id, status_filter)
  111. if not tasks:
  112. raise HTTPException(
  113. status_code=status.HTTP_404_NOT_FOUND,
  114. detail={"error_code": "NO_DATA_AVAILABLE", "message": "项目中没有可导出的数据"},
  115. )
  116. # 4. Execute export
  117. format_val = req.format.value
  118. export_method = {
  119. "json": ExportService.export_to_json,
  120. "csv": ExportService.export_to_csv,
  121. "coco": ExportService.export_to_coco,
  122. "yolo": ExportService.export_to_yolo,
  123. "pascal_voc": ExportService.export_to_pascal_voc,
  124. "sharegpt": ExportService.export_to_sharegpt,
  125. "alpaca": ExportService.export_to_alpaca,
  126. }.get(format_val)
  127. if not export_method:
  128. raise HTTPException(
  129. status_code=status.HTTP_400_BAD_REQUEST,
  130. detail={"error_code": "INVALID_FORMAT", "message": f"不支持的导出格式: {format_val}"},
  131. )
  132. try:
  133. file_path, total_tasks, total_annotations = export_method(
  134. project_id=project_id,
  135. status_filter=status_filter,
  136. include_metadata=False,
  137. )
  138. except Exception as e:
  139. logger.error(f"Export failed for project {project_id}: {e}")
  140. raise HTTPException(
  141. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  142. detail={"error_code": "EXPORT_FAILED", "message": f"导出失败: {str(e)}"},
  143. )
  144. # 5. Generate download token
  145. now = datetime.now(timezone.utc)
  146. expires_at = now + timedelta(hours=2)
  147. download_token = create_download_token(
  148. file_path=file_path,
  149. project_id=project_id,
  150. format_val=format_val,
  151. total_exported=total_tasks,
  152. expires_at=expires_at,
  153. )
  154. file_name = f"{project_id}_{format_val}_{now.strftime('%Y%m%d_%H%M%S')}.{_get_ext(format_val)}"
  155. file_size = os.path.getsize(file_path) if os.path.exists(file_path) else None
  156. return DatasetDownloadResponse(
  157. code=0,
  158. message="success",
  159. data=DatasetDownloadResponseData(
  160. project_id=project_id,
  161. format=format_val,
  162. total_exported=total_tasks,
  163. file_url=f"/api/v1/open/datasets/downloads/{download_token}",
  164. file_name=file_name,
  165. file_size=file_size,
  166. expires_at=expires_at,
  167. status="completed",
  168. ),
  169. )
  170. @download_router.get("/datasets/downloads/{download_token}")
  171. async def download_file(
  172. download_token: str,
  173. _auth: dict = Depends(verify_open_api_token),
  174. ):
  175. """根据下载令牌获取实际的数据集文件"""
  176. info = get_download_info(download_token)
  177. if not info:
  178. raise HTTPException(
  179. status_code=status.HTTP_410_GONE,
  180. detail={"error_code": "DOWNLOAD_EXPIRED", "message": "下载链接已过期"},
  181. )
  182. if not os.path.exists(info["file_path"]):
  183. raise HTTPException(
  184. status_code=status.HTTP_404_NOT_FOUND,
  185. detail={"error_code": "FILE_NOT_FOUND", "message": "导出文件不存在"},
  186. )
  187. filename = os.path.basename(info["file_path"])
  188. media_type = _get_media_type(info["format_val"])
  189. return FileResponse(
  190. path=info["file_path"],
  191. media_type=media_type,
  192. filename=filename,
  193. )
  194. def _check_project(project_id: str) -> Optional[dict]:
  195. """Check project exists and is in allowed status."""
  196. with get_db_connection() as conn:
  197. cursor = conn.cursor()
  198. cursor.execute(
  199. "SELECT id, name, task_type, status FROM projects WHERE id = %s",
  200. (project_id,),
  201. )
  202. row = cursor.fetchone()
  203. if row:
  204. return dict(row)
  205. return None
  206. def _get_ext(format_val: str) -> str:
  207. return "xml" if format_val == "pascal_voc" else "json" if format_val == "coco" else "csv" if format_val == "csv" else "json"
  208. def _get_media_type(format_val: str) -> str:
  209. return "text/csv" if format_val == "csv" else "application/xml" if format_val == "pascal_voc" else "application/json"