"""Open API routes for annotation projects and dataset export downloads.

Two routers are defined:

* ``router`` (prefix ``/api/v1/open/projects``) lists projects and returns
  the detail of a single project.
* ``download_router`` (prefix ``/api/v1/open``) exports a project's
  annotations to a file and serves that file through a download token.

Every endpoint depends on ``verify_open_api_token``.

NOTE(review): the handlers are ``async def`` but call service/database
helpers synchronously; if those helpers block, they block the event loop
while they run -- confirm whether that is acceptable for large exports.
"""
import os
import logging
from datetime import datetime, timezone, timedelta
from typing import Optional

from fastapi import APIRouter, Query, Depends, HTTPException, status
# Second handle on the status-code module for handlers whose own ``status``
# query parameter shadows the name imported above.
from fastapi import status as http_status
from fastapi.responses import FileResponse

from schemas.open_project import (
    OpenProjectListResponse,
    OpenProjectListData,
    OpenProjectDetailResponse,
    OpenProjectItem,
    OpenProjectDetailItem,
)
from schemas.open_dataset import (
    DatasetDownloadRequest,
    DatasetDownloadResponse,
    DatasetDownloadResponseData,
    TEXT_FORMATS,
    IMAGE_FORMATS,
)
from services.api.open_project_service import (
    list_projects,
    get_project_detail,
    TASK_TO_PROJECT_TYPE,
    create_download_token,
    get_download_info,
)
from services.open_middleware import verify_open_api_token
from services.export_service import ExportService
from database import get_db_connection

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="/api/v1/open/projects",
    tags=["open-api-projects"],
)


# The explicit ``description`` values on the route decorators below carry the
# text that was previously taken from the endpoint docstrings, so the
# generated OpenAPI document is unchanged.
@router.get(
    "",
    response_model=OpenProjectListResponse,
    description="查询标注项目列表",
)
async def list_projects_endpoint(
    _auth: dict = Depends(verify_open_api_token),
    name: Optional[str] = Query(None, description="项目名称(模糊匹配)"),
    project_type: Optional[str] = Query(None, alias="type", description="项目类型: image/text"),
    status: Optional[str] = Query(None, description="项目状态筛选"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
):
    """List annotation projects, optionally filtered and paginated.

    Args:
        _auth: Result of the open-API token dependency (unused here).
        name: Optional project-name filter, forwarded to ``list_projects``.
        project_type: Optional project-type filter (query parameter ``type``).
        status: Optional project-status filter.
        page: 1-based page number.
        page_size: Page size, between 1 and 100.

    Returns:
        ``OpenProjectListResponse`` wrapping the service result.

    Raises:
        HTTPException: 500 ``INTERNAL_ERROR`` when the service call fails.
    """
    try:
        result = list_projects(
            name=name,
            project_type=project_type,
            status=status,
            page=page,
            page_size=page_size,
        )
    except Exception as e:
        logger.exception("Query projects failed: %s", e)
        # ``status`` is the query parameter inside this function, so the
        # status-code constant must come from the ``http_status`` alias.
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error_code": "INTERNAL_ERROR", "message": str(e)},
        ) from e

    return OpenProjectListResponse(
        code=0,
        message="success",
        data=OpenProjectListData(**result),
    )


@router.get(
    "/{project_id}",
    response_model=OpenProjectDetailResponse,
    description="根据项目 ID 查询项目详细信息",
)
async def get_project_detail_endpoint(
    project_id: str,
    _auth: dict = Depends(verify_open_api_token),
):
    """Return the detail of one project looked up by its ID.

    Args:
        project_id: Identifier of the project (path parameter).
        _auth: Result of the open-API token dependency (unused here).

    Returns:
        ``OpenProjectDetailResponse`` wrapping the service result.

    Raises:
        HTTPException: 500 ``INTERNAL_ERROR`` when the service call fails;
            404 ``PROJECT_NOT_FOUND`` when the service returns nothing.
    """
    try:
        result = get_project_detail(project_id)
    except Exception as e:
        logger.exception("Query project detail failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error_code": "INTERNAL_ERROR", "message": str(e)},
        ) from e

    if not result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error_code": "PROJECT_NOT_FOUND", "message": "项目不存在"},
        )

    return OpenProjectDetailResponse(
        code=0,
        message="success",
        data=OpenProjectDetailItem(**result),
    )


# --- Dataset download router ---

download_router = APIRouter(
    prefix="/api/v1/open",
    tags=["open-api-datasets"],
)


@download_router.post(
    "/projects/{project_id}/datasets/download",
    response_model=DatasetDownloadResponse,
    description="根据项目 ID 和格式导出标注数据集",
)
async def download_dataset(
    project_id: str,
    req: DatasetDownloadRequest,
    _auth: dict = Depends(verify_open_api_token),
):
    """Export a project's annotations and return a tokenized download link.

    Steps: verify the project exists, verify the requested format matches
    the project type, verify there is data to export, run the export, then
    register a download token that expires two hours from now.

    Args:
        project_id: Identifier of the project (path parameter).
        req: Request body carrying ``format`` and ``completed_only``.
        _auth: Result of the open-API token dependency (unused here).

    Returns:
        ``DatasetDownloadResponse`` describing the exported file.

    Raises:
        HTTPException: 404 ``PROJECT_NOT_FOUND``, 400 ``FORMAT_NOT_COMPATIBLE``,
            404 ``NO_DATA_AVAILABLE``, 400 ``INVALID_FORMAT`` or
            500 ``EXPORT_FAILED``.
    """
    # 1. Check project exists
    project = _check_project(project_id)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error_code": "PROJECT_NOT_FOUND", "message": "项目不存在"},
        )

    # 2. Check format compatibility (unknown/missing task types count as "text")
    project_type = TASK_TO_PROJECT_TYPE.get(project["task_type"] or "", "text")
    if req.format in TEXT_FORMATS and project_type != "text":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "error_code": "FORMAT_NOT_COMPATIBLE",
                "message": f"格式 {req.format.value} 不适用于 {project_type} 类型项目",
            },
        )
    if req.format in IMAGE_FORMATS and project_type != "image":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "error_code": "FORMAT_NOT_COMPATIBLE",
                "message": f"格式 {req.format.value} 不适用于 {project_type} 类型项目",
            },
        )

    # 3. Check data availability
    status_filter = "completed" if req.completed_only else "all"
    tasks = ExportService.get_tasks_with_annotations(project_id, status_filter)
    if not tasks:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error_code": "NO_DATA_AVAILABLE", "message": "项目中没有可导出的数据"},
        )

    # 4. Execute export
    format_val = req.format.value
    export_method = {
        "json": ExportService.export_to_json,
        "csv": ExportService.export_to_csv,
        "coco": ExportService.export_to_coco,
        "yolo": ExportService.export_to_yolo,
        "pascal_voc": ExportService.export_to_pascal_voc,
        "sharegpt": ExportService.export_to_sharegpt,
        "alpaca": ExportService.export_to_alpaca,
    }.get(format_val)
    if not export_method:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error_code": "INVALID_FORMAT", "message": f"不支持的导出格式: {format_val}"},
        )

    try:
        file_path, total_tasks, total_annotations = export_method(
            project_id=project_id,
            status_filter=status_filter,
            include_metadata=False,
        )
    except Exception as e:
        logger.exception("Export failed for project %s: %s", project_id, e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error_code": "EXPORT_FAILED", "message": f"导出失败: {str(e)}"},
        ) from e

    # 5. Generate download token
    now = datetime.now(timezone.utc)
    expires_at = now + timedelta(hours=2)
    download_token = create_download_token(
        file_path=file_path,
        project_id=project_id,
        format_val=format_val,
        total_exported=total_tasks,
        expires_at=expires_at,
    )

    file_name = f"{project_id}_{format_val}_{now.strftime('%Y%m%d_%H%M%S')}.{_get_ext(format_val)}"
    # Ask for the size directly instead of exists()+getsize(): the file could
    # disappear between the two calls. A missing/unreadable file reports None.
    try:
        file_size = os.path.getsize(file_path)
    except (OSError, ValueError):
        file_size = None

    return DatasetDownloadResponse(
        code=0,
        message="success",
        data=DatasetDownloadResponseData(
            project_id=project_id,
            format=format_val,
            total_exported=total_tasks,
            file_url=f"/api/v1/open/datasets/downloads/{download_token}",
            file_name=file_name,
            file_size=file_size,
            expires_at=expires_at,
            status="completed",
        ),
    )


@download_router.get(
    "/datasets/downloads/{download_token}",
    description="根据下载令牌获取实际的数据集文件",
)
async def download_file(
    download_token: str,
    _auth: dict = Depends(verify_open_api_token),
):
    """Serve the exported dataset file registered under a download token.

    Args:
        download_token: Token previously returned by ``download_dataset``.
        _auth: Result of the open-API token dependency (unused here).

    Returns:
        ``FileResponse`` streaming the exported file; the download name is
        the base name of the stored file path.

    Raises:
        HTTPException: 410 ``DOWNLOAD_EXPIRED`` when no download info is
            found for the token; 404 ``FILE_NOT_FOUND`` when the stored path
            is not an existing regular file.
    """
    info = get_download_info(download_token)
    if not info:
        raise HTTPException(
            status_code=status.HTTP_410_GONE,
            detail={"error_code": "DOWNLOAD_EXPIRED", "message": "下载链接已过期"},
        )

    # FileResponse can only serve regular files, so a directory (or any other
    # non-file path) is reported the same way as a missing file.
    if not os.path.isfile(info["file_path"]):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error_code": "FILE_NOT_FOUND", "message": "导出文件不存在"},
        )

    filename = os.path.basename(info["file_path"])
    media_type = _get_media_type(info["format_val"])

    return FileResponse(
        path=info["file_path"],
        media_type=media_type,
        filename=filename,
    )


def _check_project(project_id: str) -> Optional[dict]:
    """Fetch the ``projects`` row with the given ID.

    Only existence is checked; the ``status`` column is selected and
    returned but not validated here.

    Args:
        project_id: Identifier of the project to look up.

    Returns:
        A dict with ``id``, ``name``, ``task_type`` and ``status``, or
        ``None`` when no row matches.
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id, name, task_type, status FROM projects WHERE id = %s",
            (project_id,),
        )
        row = cursor.fetchone()
        # NOTE(review): dict(row) presumably relies on the connection
        # returning mapping-style rows -- confirm against get_db_connection.
        if row:
            return dict(row)
        return None


# File extensions that differ from the default "json".
_EXTENSION_BY_FORMAT = {
    "pascal_voc": "xml",
    "csv": "csv",
}


def _get_ext(format_val: str) -> str:
    """Return the file extension (without dot) used for an export format.

    ``pascal_voc`` maps to ``xml`` and ``csv`` to ``csv``; every other
    format maps to ``json``.
    """
    return _EXTENSION_BY_FORMAT.get(format_val, "json")


def _get_media_type(format_val: str) -> str:
    """Return the HTTP media type used when serving an export format.

    ``csv`` maps to ``text/csv`` and ``pascal_voc`` to ``application/xml``;
    every other format maps to ``application/json``.
    """
    if format_val == "csv":
        return "text/csv"
    if format_val == "pascal_voc":
        return "application/xml"
    return "application/json"