# datasets.py -- FastAPI routes for dataset download, upload, preview, validation, listing and deletion.
  1. from fastapi import APIRouter, UploadFile, File, Query, HTTPException
  2. from app.schemas.dataset import (
  3. DatasetDownloadRequest,
  4. DatasetDownloadResponse,
  5. DatasetPreviewResponse,
  6. DatasetUploadResponse,
  7. DatasetValidationResult,
  8. )
  9. from app.schemas.background_task import DatasetDownloadTaskResponse
  10. from app.services import dataset_service
  11. router = APIRouter()
  12. @router.post("/download", response_model=DatasetDownloadResponse, status_code=200)
  13. async def download_dataset(req: DatasetDownloadRequest):
  14. """启动数据集下载后台任务,立即返回 task_id。"""
  15. result = await dataset_service.download_dataset(req)
  16. return result
  17. @router.get("/download/{task_id}")
  18. async def get_dataset_download_status(task_id: str):
  19. """查询数据集下载任务状态。"""
  20. result = await dataset_service.get_dataset_download_status(task_id)
  21. if result.get("status") == "not_found":
  22. raise HTTPException(status_code=404, detail="Download task not found")
  23. return result
  24. @router.get("/downloads")
  25. async def list_dataset_downloads():
  26. """列出所有数据集下载任务。"""
  27. return await dataset_service.list_dataset_downloads()
  28. @router.post("/download/{task_id}/cancel")
  29. async def cancel_dataset_download(task_id: str):
  30. """取消数据集下载任务。"""
  31. return await dataset_service.cancel_dataset_download(task_id)
  32. @router.post("/upload", response_model=DatasetUploadResponse, status_code=201)
  33. async def upload_dataset(file: UploadFile = File(...)):
  34. """上传数据集文件(JSONL / CSV / Parquet / JSON)。"""
  35. result = await dataset_service.upload_dataset(file)
  36. return DatasetUploadResponse(**result)
  37. @router.get("/{dataset_id}/preview", response_model=DatasetPreviewResponse)
  38. async def preview_dataset(dataset_id: str, rows: int = Query(default=10, le=100)):
  39. """预览数据集前 N 行。"""
  40. result = await dataset_service.preview_dataset(dataset_id, rows)
  41. return DatasetPreviewResponse(**result)
  42. @router.post("/{dataset_id}/validate", response_model=DatasetValidationResult)
  43. async def validate_dataset(dataset_id: str):
  44. """校验数据集格式和 Schema。"""
  45. result = await dataset_service.validate_dataset(dataset_id)
  46. return DatasetValidationResult(**result)
  47. @router.get("/", response_model=list[DatasetUploadResponse])
  48. async def list_datasets():
  49. """列出所有已上传数据集。"""
  50. items = await dataset_service.list_datasets()
  51. return [DatasetUploadResponse(**item) for item in items]
  52. @router.delete("/{dataset_id}", status_code=200)
  53. async def delete_dataset(dataset_id: str):
  54. """删除数据集。"""
  55. return await dataset_service.delete_dataset(dataset_id)