|
@@ -0,0 +1,242 @@
|
|
|
|
|
+import os
|
|
|
|
|
+import logging
|
|
|
|
|
+from datetime import datetime, timezone, timedelta
|
|
|
|
|
+from typing import Optional
|
|
|
|
|
+from fastapi import APIRouter, Query, Depends, HTTPException, status
|
|
|
|
|
+from fastapi.responses import FileResponse
|
|
|
|
|
+from schemas.open_project import (
|
|
|
|
|
+ OpenProjectListResponse, OpenProjectListData, OpenProjectDetailResponse,
|
|
|
|
|
+ OpenProjectItem, OpenProjectDetailItem,
|
|
|
|
|
+)
|
|
|
|
|
+from schemas.open_dataset import (
|
|
|
|
|
+ DatasetDownloadRequest, DatasetDownloadResponse, DatasetDownloadResponseData,
|
|
|
|
|
+ TEXT_FORMATS, IMAGE_FORMATS,
|
|
|
|
|
+)
|
|
|
|
|
+from services.api.open_project_service import list_projects, get_project_detail, TASK_TO_PROJECT_TYPE, create_download_token, get_download_info
|
|
|
|
|
+from services.open_middleware import verify_open_api_token
|
|
|
|
|
+from services.export_service import ExportService
|
|
|
|
|
+from database import get_db_connection
|
|
|
|
|
+
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Router for the read-only project endpoints of the open API.
router = APIRouter(prefix="/api/v1/open/projects", tags=["open-api-projects"])
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
@router.get("", response_model=OpenProjectListResponse)
async def list_projects_endpoint(
    _auth: dict = Depends(verify_open_api_token),
    name: Optional[str] = Query(None, description="项目名称(模糊匹配)"),
    project_type: Optional[str] = Query(None, alias="type", description="项目类型: image/text"),
    status: Optional[str] = Query(None, description="项目状态筛选"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
):
    """List annotation projects with optional filters and pagination.

    Args:
        _auth: Result of the open-API token dependency; unused in the body.
        name: Optional project-name filter, forwarded to ``list_projects``.
        project_type: Optional type filter, exposed as the ``type`` query
            parameter.
        status: Optional project-status filter.
        page: 1-based page number.
        page_size: Page size, between 1 and 100.

    Returns:
        An ``OpenProjectListResponse`` with ``code=0`` whose ``data`` is built
        from the mapping returned by ``list_projects``.

    Raises:
        HTTPException: 500 with ``INTERNAL_ERROR`` when ``list_projects`` raises.
    """
    # The ``status`` query parameter shadows the module-level ``fastapi.status``
    # import inside this function. Looking up HTTP constants on it would raise
    # AttributeError on a str/None, so the constants are imported under a
    # different local name.
    from fastapi import status as http_status

    try:
        result = list_projects(
            name=name,
            project_type=project_type,
            status=status,
            page=page,
            page_size=page_size,
        )
    except Exception as e:
        logger.error(f"Query projects failed: {e}")
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error_code": "INTERNAL_ERROR", "message": str(e)},
        ) from e

    return OpenProjectListResponse(
        code=0,
        message="success",
        data=OpenProjectListData(**result),
    )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
@router.get("/{project_id}", response_model=OpenProjectDetailResponse)
async def get_project_detail_endpoint(
    project_id: str,
    _auth: dict = Depends(verify_open_api_token),
):
    """Return the detail record of a single project looked up by its ID.

    Responds with 404 ``PROJECT_NOT_FOUND`` when the lookup yields a falsy
    result, and with 500 ``INTERNAL_ERROR`` when the lookup itself raises.
    """
    try:
        detail = get_project_detail(project_id)
    except Exception as exc:
        logger.error(f"Query project detail failed: {exc}")
        failure = {"error_code": "INTERNAL_ERROR", "message": str(exc)}
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure,
        )

    if detail:
        payload = OpenProjectDetailItem(**detail)
        return OpenProjectDetailResponse(code=0, message="success", data=payload)

    # Falsy lookup result: the project does not exist.
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail={"error_code": "PROJECT_NOT_FOUND", "message": "项目不存在"},
    )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
# --- Dataset download router ---

# Separate router because the download routes live directly under
# /api/v1/open rather than under the /projects prefix.
download_router = APIRouter(prefix="/api/v1/open", tags=["open-api-datasets"])
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
@download_router.post("/projects/{project_id}/datasets/download", response_model=DatasetDownloadResponse)
async def download_dataset(
    project_id: str,
    req: DatasetDownloadRequest,
    _auth: dict = Depends(verify_open_api_token),
):
    """Export a project's annotated dataset and return a tokenised download link.

    Args:
        project_id: ID of the project to export.
        req: Request body; ``req.format`` selects the export format and
            ``req.completed_only`` restricts the export to completed tasks.
        _auth: Result of the open-API token dependency; unused in the body.

    Returns:
        A ``DatasetDownloadResponse`` with ``code=0`` describing the export:
        download URL, suggested file name, size (``None`` if the file cannot
        be stat'ed) and an expiry two hours after the export.

    Raises:
        HTTPException: 404 ``PROJECT_NOT_FOUND``, 400 ``FORMAT_NOT_COMPATIBLE``,
            404 ``NO_DATA_AVAILABLE``, 400 ``INVALID_FORMAT`` or
            500 ``EXPORT_FAILED``.
    """
    # 1. Check project exists
    project = _check_project(project_id)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error_code": "PROJECT_NOT_FOUND", "message": "项目不存在"},
        )

    # 2. Check format compatibility. Task types missing from the mapping
    #    (including a NULL task_type) are treated as "text".
    project_type = TASK_TO_PROJECT_TYPE.get(project["task_type"] or "", "text")
    text_mismatch = req.format in TEXT_FORMATS and project_type != "text"
    image_mismatch = req.format in IMAGE_FORMATS and project_type != "image"
    if text_mismatch or image_mismatch:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error_code": "FORMAT_NOT_COMPATIBLE", "message": f"格式 {req.format.value} 不适用于 {project_type} 类型项目"},
        )

    # 3. Check data availability
    status_filter = "completed" if req.completed_only else "all"
    tasks = ExportService.get_tasks_with_annotations(project_id, status_filter)
    if not tasks:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error_code": "NO_DATA_AVAILABLE", "message": "项目中没有可导出的数据"},
        )

    # 4. Execute export
    format_val = req.format.value
    export_method = {
        "json": ExportService.export_to_json,
        "csv": ExportService.export_to_csv,
        "coco": ExportService.export_to_coco,
        "yolo": ExportService.export_to_yolo,
        "pascal_voc": ExportService.export_to_pascal_voc,
        "sharegpt": ExportService.export_to_sharegpt,
        "alpaca": ExportService.export_to_alpaca,
    }.get(format_val)

    if not export_method:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error_code": "INVALID_FORMAT", "message": f"不支持的导出格式: {format_val}"},
        )

    try:
        # total_annotations is returned by the exporter but not used here.
        file_path, total_tasks, _total_annotations = export_method(
            project_id=project_id,
            status_filter=status_filter,
            include_metadata=False,
        )
    except Exception as e:
        logger.error(f"Export failed for project {project_id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error_code": "EXPORT_FAILED", "message": f"导出失败: {e}"},
        ) from e

    # 5. Generate download token (valid for two hours, timezone-aware UTC)
    now = datetime.now(timezone.utc)
    expires_at = now + timedelta(hours=2)
    download_token = create_download_token(
        file_path=file_path,
        project_id=project_id,
        format_val=format_val,
        total_exported=total_tasks,
        expires_at=expires_at,
    )

    file_name = f"{project_id}_{format_val}_{now.strftime('%Y%m%d_%H%M%S')}.{_get_ext(format_val)}"
    # EAFP: stat the file directly instead of exists()+getsize(), which can
    # race with a concurrent removal of the export file.
    try:
        file_size = os.path.getsize(file_path)
    except OSError:
        file_size = None

    return DatasetDownloadResponse(
        code=0,
        message="success",
        data=DatasetDownloadResponseData(
            project_id=project_id,
            format=format_val,
            total_exported=total_tasks,
            file_url=f"/api/v1/open/datasets/downloads/{download_token}",
            file_name=file_name,
            file_size=file_size,
            expires_at=expires_at,
            status="completed",
        ),
    )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
@download_router.get("/datasets/downloads/{download_token}")
async def download_file(
    download_token: str,
    _auth: dict = Depends(verify_open_api_token),
):
    """Serve the exported dataset file that a download token refers to.

    Responds with 410 ``DOWNLOAD_EXPIRED`` when no download info is found for
    the token, and with 404 ``FILE_NOT_FOUND`` when the recorded file path no
    longer exists on disk.
    """
    info = get_download_info(download_token)
    if not info:
        raise HTTPException(
            status_code=status.HTTP_410_GONE,
            detail={"error_code": "DOWNLOAD_EXPIRED", "message": "下载链接已过期"},
        )

    export_path = info["file_path"]
    if not os.path.exists(export_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error_code": "FILE_NOT_FOUND", "message": "导出文件不存在"},
        )

    # The served file keeps its on-disk base name; the content type follows
    # the export format recorded with the token.
    served_name = os.path.basename(export_path)
    content_type = _get_media_type(info["format_val"])
    return FileResponse(path=export_path, media_type=content_type, filename=served_name)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _check_project(project_id: str) -> Optional[dict]:
    """Look up a project row by ID.

    Selects ``id``, ``name``, ``task_type`` and ``status`` from ``projects``.
    Returns the row converted to a dict, or ``None`` when no row matches.
    The ``status`` column is returned but not validated here.
    """
    query = "SELECT id, name, task_type, status FROM projects WHERE id = %s"
    with get_db_connection() as conn:
        cur = conn.cursor()
        cur.execute(query, (project_id,))
        record = cur.fetchone()
        return dict(record) if record else None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _get_ext(format_val: str) -> str:
|
|
|
|
|
+ return "xml" if format_val == "pascal_voc" else "json" if format_val == "coco" else "csv" if format_val == "csv" else "json"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _get_media_type(format_val: str) -> str:
|
|
|
|
|
+ return "text/csv" if format_val == "csv" else "application/xml" if format_val == "pascal_voc" else "application/json"
|