| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242 |
- import os
- import logging
- from datetime import datetime, timezone, timedelta
- from typing import Optional
- from fastapi import APIRouter, Query, Depends, HTTPException, status
- from fastapi.responses import FileResponse
- from schemas.open_project import (
- OpenProjectListResponse, OpenProjectListData, OpenProjectDetailResponse,
- OpenProjectItem, OpenProjectDetailItem,
- )
- from schemas.open_dataset import (
- DatasetDownloadRequest, DatasetDownloadResponse, DatasetDownloadResponseData,
- TEXT_FORMATS, IMAGE_FORMATS,
- )
- from services.api.open_project_service import list_projects, get_project_detail, TASK_TO_PROJECT_TYPE, create_download_token, get_download_info
- from services.open_middleware import verify_open_api_token
- from services.export_service import ExportService
- from database import get_db_connection
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router for the project query endpoints; every route registered on it is
# served under the /api/v1/open/projects prefix.
router = APIRouter(
    prefix="/api/v1/open/projects",
    tags=["open-api-projects"],
)
@router.get("", response_model=OpenProjectListResponse)
async def list_projects_endpoint(
    _auth: dict = Depends(verify_open_api_token),
    name: Optional[str] = Query(None, description="项目名称(模糊匹配)"),
    project_type: Optional[str] = Query(None, alias="type", description="项目类型: image/text"),
    status: Optional[str] = Query(None, description="项目状态筛选"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
):
    """List annotation projects, optionally filtered and paginated.

    Args:
        _auth: Result of the ``verify_open_api_token`` dependency; not used
            in the body.
        name: Optional project-name filter (described as a fuzzy match in
            the query parameter description).
        project_type: Optional project-type filter, exposed to clients as
            the ``type`` query parameter.
        status: Optional project-status filter.
        page: Page number, at least 1.
        page_size: Number of items per page, between 1 and 100.

    Returns:
        An ``OpenProjectListResponse`` with ``code=0`` wrapping the result
        of ``list_projects``.

    Raises:
        HTTPException: 500 with ``error_code`` ``INTERNAL_ERROR`` when
            ``list_projects`` raises.
    """
    # The ``status`` query parameter shadows the module-level
    # ``fastapi.status`` import inside this function, so the HTTP status
    # constants are imported again under a distinct local name. Without
    # this, the error path would look up HTTP_500_INTERNAL_SERVER_ERROR on
    # the filter value (a str or None) and fail with AttributeError.
    from fastapi import status as http_status

    try:
        result = list_projects(
            name=name,
            project_type=project_type,
            status=status,
            page=page,
            page_size=page_size,
        )
    except Exception as e:
        # logger.exception keeps the ERROR level and adds the traceback.
        logger.exception("Query projects failed: %s", e)
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error_code": "INTERNAL_ERROR", "message": str(e)},
        ) from e
    return OpenProjectListResponse(
        code=0,
        message="success",
        data=OpenProjectListData(**result),
    )
@router.get("/{project_id}", response_model=OpenProjectDetailResponse)
async def get_project_detail_endpoint(
    project_id: str,
    _auth: dict = Depends(verify_open_api_token),
):
    """Return the detail record of the project identified by ``project_id``.

    Raises:
        HTTPException: 500 (``INTERNAL_ERROR``) when ``get_project_detail``
            raises, 404 (``PROJECT_NOT_FOUND``) when it returns a falsy
            result.
    """
    try:
        detail = get_project_detail(project_id)
    except Exception as exc:
        logger.error(f"Query project detail failed: {exc}")
        failure = {"error_code": "INTERNAL_ERROR", "message": str(exc)}
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure,
        )

    if detail:
        item = OpenProjectDetailItem(**detail)
        return OpenProjectDetailResponse(code=0, message="success", data=item)

    # A falsy lookup result means there is no such project.
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail={"error_code": "PROJECT_NOT_FOUND", "message": "项目不存在"},
    )
# --- Dataset download router ---
# Separate router for the dataset export/download endpoints. It uses the
# shorter /api/v1/open prefix because its routes live under both
# /projects/... and /datasets/... paths.
download_router = APIRouter(
    prefix="/api/v1/open",
    tags=["open-api-datasets"],
)
@download_router.post("/projects/{project_id}/datasets/download", response_model=DatasetDownloadResponse)
async def download_dataset(
    project_id: str,
    req: DatasetDownloadRequest,
    _auth: dict = Depends(verify_open_api_token),
):
    """Export a project's annotated dataset in the requested format.

    The export file is produced by the matching ``ExportService`` method and
    registered under a download token that expires two hours after the
    request. The response carries the token-based download URL.

    Raises:
        HTTPException: 404 ``PROJECT_NOT_FOUND`` when the project lookup is
            empty; 400 ``FORMAT_NOT_COMPATIBLE`` when the format does not
            fit the project type; 404 ``NO_DATA_AVAILABLE`` when there is
            nothing to export; 400 ``INVALID_FORMAT`` when no exporter is
            registered for the format; 500 ``EXPORT_FAILED`` when the
            exporter raises.
    """
    # Step 1: the project has to exist.
    project = _check_project(project_id)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error_code": "PROJECT_NOT_FOUND", "message": "项目不存在"},
        )

    # Step 2: the requested format must match the project type. A missing
    # or unmapped task type is treated as a "text" project.
    project_type = TASK_TO_PROJECT_TYPE.get(project["task_type"] or "", "text")
    incompatible = (
        (req.format in TEXT_FORMATS and project_type != "text")
        or (req.format in IMAGE_FORMATS and project_type != "image")
    )
    if incompatible:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error_code": "FORMAT_NOT_COMPATIBLE", "message": f"格式 {req.format.value} 不适用于 {project_type} 类型项目"},
        )

    # Step 3: make sure there is something to export.
    if req.completed_only:
        status_filter = "completed"
    else:
        status_filter = "all"
    if not ExportService.get_tasks_with_annotations(project_id, status_filter):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error_code": "NO_DATA_AVAILABLE", "message": "项目中没有可导出的数据"},
        )

    # Step 4: pick the exporter for the format and run it.
    format_val = req.format.value
    exporters = {
        "json": ExportService.export_to_json,
        "csv": ExportService.export_to_csv,
        "coco": ExportService.export_to_coco,
        "yolo": ExportService.export_to_yolo,
        "pascal_voc": ExportService.export_to_pascal_voc,
        "sharegpt": ExportService.export_to_sharegpt,
        "alpaca": ExportService.export_to_alpaca,
    }
    exporter = exporters.get(format_val)
    if not exporter:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error_code": "INVALID_FORMAT", "message": f"不支持的导出格式: {format_val}"},
        )

    try:
        file_path, total_tasks, total_annotations = exporter(
            project_id=project_id,
            status_filter=status_filter,
            include_metadata=False,
        )
    except Exception as exc:
        logger.error(f"Export failed for project {project_id}: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error_code": "EXPORT_FAILED", "message": f"导出失败: {str(exc)}"},
        )

    # Step 5: register a download token valid for two hours (UTC).
    now = datetime.now(timezone.utc)
    expires_at = now + timedelta(hours=2)
    download_token = create_download_token(
        file_path=file_path,
        project_id=project_id,
        format_val=format_val,
        total_exported=total_tasks,
        expires_at=expires_at,
    )

    timestamp = now.strftime("%Y%m%d_%H%M%S")
    extension = _get_ext(format_val)
    file_name = f"{project_id}_{format_val}_{timestamp}.{extension}"

    # The size is reported only when the export file is present on disk.
    if os.path.exists(file_path):
        file_size = os.path.getsize(file_path)
    else:
        file_size = None

    payload = DatasetDownloadResponseData(
        project_id=project_id,
        format=format_val,
        total_exported=total_tasks,
        file_url=f"/api/v1/open/datasets/downloads/{download_token}",
        file_name=file_name,
        file_size=file_size,
        expires_at=expires_at,
        status="completed",
    )
    return DatasetDownloadResponse(code=0, message="success", data=payload)
@download_router.get("/datasets/downloads/{download_token}")
async def download_file(
    download_token: str,
    _auth: dict = Depends(verify_open_api_token),
):
    """Serve the exported dataset file that a download token refers to.

    Raises:
        HTTPException: 410 ``DOWNLOAD_EXPIRED`` when ``get_download_info``
            returns nothing for the token; 404 ``FILE_NOT_FOUND`` when the
            recorded file path does not exist on disk.
    """
    info = get_download_info(download_token)
    if not info:
        raise HTTPException(
            status_code=status.HTTP_410_GONE,
            detail={"error_code": "DOWNLOAD_EXPIRED", "message": "下载链接已过期"},
        )

    export_path = info["file_path"]
    if not os.path.exists(export_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error_code": "FILE_NOT_FOUND", "message": "导出文件不存在"},
        )

    # The client-facing name is the export file's own base name.
    served_name = os.path.basename(export_path)
    content_type = _get_media_type(info["format_val"])
    return FileResponse(path=export_path, media_type=content_type, filename=served_name)
def _check_project(project_id: str) -> Optional[dict]:
    """Look up a project row by its ID.

    Runs a parameterized SELECT of ``id``, ``name``, ``task_type`` and
    ``status`` from the ``projects`` table. The status value is returned
    as-is; no filtering on it happens here.

    Returns:
        The matching row converted to a ``dict``, or ``None`` when no row
        matches.
    """
    query = "SELECT id, name, task_type, status FROM projects WHERE id = %s"
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(query, (project_id,))
        record = cursor.fetchone()
        return dict(record) if record else None
- def _get_ext(format_val: str) -> str:
- return "xml" if format_val == "pascal_voc" else "json" if format_val == "coco" else "csv" if format_val == "csv" else "json"
- def _get_media_type(format_val: str) -> str:
- return "text/csv" if format_val == "csv" else "application/xml" if format_val == "pascal_voc" else "application/json"
|