"""
Open API project query service.

Handles project listing, detail, and dataset export queries.
"""
import math
import uuid
import logging
from datetime import datetime, timezone
from typing import Optional, Dict, Tuple

from database import get_db_connection

logger = logging.getLogger(__name__)

# Mapping from the database ``task_type`` to the ``project_type`` exposed by the API.
TASK_TO_PROJECT_TYPE = {
    "text_classification": "text",
    "ner": "text",
    "image_classification": "image",
    "object_detection": "image",
    "polygon": "image",
}

# ``project_type`` reported for task types that are NULL/empty or missing from
# TASK_TO_PROJECT_TYPE.
_DEFAULT_PROJECT_TYPE = "text"

# Project statuses the open API is allowed to expose.
ALLOWED_STATUSES = ("ready", "in_progress", "completed")

# Download token store: {token: {file_path, project_id, format_val, total_exported, expires_at}}
# NOTE(review): this is an in-process module-level dict, so tokens are lost on
# restart and are not shared between worker processes - confirm the deployment
# serves downloads from the same process that created the token.
_DOWNLOAD_STORE: Dict[str, Dict] = {}


def _map_project_type(task_type: str) -> str:
    """Map a database ``task_type`` to the API ``project_type``.

    Unknown (or empty) task types fall back to ``_DEFAULT_PROJECT_TYPE``.
    """
    return TASK_TO_PROJECT_TYPE.get(task_type, _DEFAULT_PROJECT_TYPE)


def _project_type_condition(project_type: str) -> Optional[Tuple[str, tuple]]:
    """Build the SQL ``(clause, params)`` selecting projects of ``project_type``.

    The filter mirrors ``_map_project_type`` so that it agrees with the
    ``project_type`` reported for each item: the default type also matches
    task types that are NULL or absent from TASK_TO_PROJECT_TYPE.

    Returns None when no filter should be applied (an unrecognised
    ``project_type``, or every task type maps to the default type).
    """
    if project_type == _DEFAULT_PROJECT_TYPE:
        other_types = tuple(
            t for t, pt in TASK_TO_PROJECT_TYPE.items() if pt != project_type
        )
        if not other_types:
            # Every project maps to the default type; avoid an invalid "NOT IN ()".
            return None
        return "(p.task_type IS NULL OR p.task_type NOT IN %s)", (other_types,)

    matching = tuple(t for t, pt in TASK_TO_PROJECT_TYPE.items() if pt == project_type)
    if not matching:
        return None
    return "p.task_type IN %s", (matching,)


def _serialize_project_row(row) -> dict:
    """Convert a project row carrying task aggregates into the API item dict.

    ``row`` is expected to be a mapping with the columns selected by
    ``list_projects`` / ``get_project_detail`` (id, name, description,
    task_type, status, created_at, updated_at, task_count,
    completed_task_count).
    """
    task_type = row["task_type"] or ""
    return {
        "project_id": row["id"],
        "project_name": row["name"],
        "description": row["description"] or "",
        "project_type": _map_project_type(task_type),
        "task_type": task_type,
        "status": row["status"],
        "created_by": "",
        "created_at": row["created_at"],
        "updated_at": row["updated_at"],
        # int() normalises driver-specific numeric types (e.g. MySQL returns
        # SUM() over integers as DECIMAL).
        "task_count": int(row["task_count"]),
        "completed_task_count": int(row["completed_task_count"]),
    }


def list_projects(
    name: Optional[str] = None,
    project_type: Optional[str] = None,
    status: Optional[str] = None,
    page: int = 1,
    page_size: int = 20,
) -> dict:
    """Query a paginated list of projects.

    Only projects whose status is in ALLOWED_STATUSES (ready / in_progress /
    completed) are returned.

    Args:
        name: Optional substring matched against the project name via SQL LIKE.
        project_type: Optional API project type (e.g. "text", "image");
            an unrecognised value is ignored.
        status: Optional status filter; ignored unless it is in ALLOWED_STATUSES.
        page: 1-based page number; values below 1 are clamped to 1.
        page_size: Items per page, clamped to the range [1, 100].

    Returns:
        A dict with ``items``, ``total``, ``page``, ``page_size``,
        ``total_pages``, ``has_next`` and ``has_prev``.
    """
    page = max(page, 1)
    page_size = min(max(page_size, 1), 100)

    # Every fragment below is a fixed string built here; user input only ever
    # travels through the bound ``params``.
    conditions = ["p.status IN %s"]
    params: list = [ALLOWED_STATUSES]

    if name:
        # NOTE(review): '%' and '_' inside ``name`` are not escaped, so they act
        # as LIKE wildcards; escape them if a literal substring match is intended.
        conditions.append("p.name LIKE %s")
        params.append(f"%{name}%")
    if project_type:
        type_filter = _project_type_condition(project_type)
        if type_filter is not None:
            clause, clause_params = type_filter
            conditions.append(clause)
            params.extend(clause_params)
    if status and status in ALLOWED_STATUSES:
        conditions.append("p.status = %s")
        params.append(status)

    where = " AND ".join(conditions)

    with get_db_connection() as conn:
        cursor = conn.cursor()

        cursor.execute(f"SELECT COUNT(*) AS cnt FROM projects p WHERE {where}", tuple(params))
        total = cursor.fetchone()["cnt"]
        # Always report at least one page, even when nothing matches.
        total_pages = max(math.ceil(total / page_size), 1)
        offset = (page - 1) * page_size

        # p.id breaks ties on updated_at so LIMIT/OFFSET pages are stable and
        # rows are neither repeated nor skipped between pages.
        cursor.execute(
            f"""
            SELECT p.id, p.name, p.description, p.task_type, p.status,
                   p.created_at, p.updated_at,
                   COUNT(t.id) AS task_count,
                   SUM(CASE WHEN t.status = 'completed' THEN 1 ELSE 0 END) AS completed_task_count
            FROM projects p
            LEFT JOIN tasks t ON t.project_id = p.id
            WHERE {where}
            GROUP BY p.id
            ORDER BY p.updated_at DESC, p.id DESC
            LIMIT %s OFFSET %s
            """,
            tuple(params) + (page_size, offset),
        )
        rows = cursor.fetchall()

    items = [_serialize_project_row(row) for row in rows]

    return {
        "items": items,
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": total_pages,
        "has_next": page < total_pages,
        "has_prev": page > 1,
    }


def get_project_detail(project_id: str) -> Optional[dict]:
    """Look up one project by ID, including task statistics.

    Args:
        project_id: The project's primary key.

    Returns:
        The same fields as a ``list_projects`` item plus
        ``assigned_task_count`` and ``completion_percentage`` (a float rounded
        to one decimal place), or None when no project has this ID.
    """
    # NOTE(review): unlike list_projects, no ALLOWED_STATUSES filter is applied
    # here - confirm callers restrict the detail endpoint if that is intended.
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT p.id, p.name, p.description, p.task_type, p.status,
                   p.created_at, p.updated_at,
                   COUNT(t.id) AS task_count,
                   SUM(CASE WHEN t.status = 'completed' THEN 1 ELSE 0 END) AS completed_task_count,
                   SUM(CASE WHEN t.assigned_to IS NOT NULL THEN 1 ELSE 0 END) AS assigned_task_count
            FROM projects p
            LEFT JOIN tasks t ON t.project_id = p.id
            WHERE p.id = %s
            GROUP BY p.id
            """,
            (project_id,),
        )
        row = cursor.fetchone()

    if not row:
        return None

    detail = _serialize_project_row(row)
    detail["assigned_task_count"] = int(row["assigned_task_count"])
    # Use the int-normalised counts so the percentage is a float rather than a
    # driver Decimal; max(..., 1) avoids dividing by zero for task-less projects.
    denominator = max(detail["task_count"], 1)
    detail["completion_percentage"] = round(
        detail["completed_task_count"] / denominator * 100, 1
    )
    return detail


# --- Download token management ---

def create_download_token(
    file_path: str,
    project_id: str,
    format_val: str,
    total_exported: int,
    expires_at: datetime,
) -> str:
    """Create a download token and store the download info under it.

    ``expires_at`` is later compared with ``datetime.now(timezone.utc)``, so it
    must be timezone-aware (a naive datetime would raise TypeError). Expired
    tokens are purged on every call.

    Returns:
        The new token string, of the form ``dl_<12 hex chars>``.
    """
    # NOTE(review): 12 hex chars of uuid4 carry 48 random bits; consider a
    # longer token if nothing depends on the current format.
    token = f"dl_{uuid.uuid4().hex[:12]}"
    _DOWNLOAD_STORE[token] = {
        "file_path": file_path,
        "project_id": project_id,
        "format_val": format_val,
        "total_exported": total_exported,
        "expires_at": expires_at,
    }
    _cleanup_expired_tokens()
    return token


def get_download_info(token: str) -> Optional[Dict]:
    """Return the stored download info for ``token``.

    Returns None if the token is unknown or expired; an expired entry is
    removed from the store.
    """
    info = _DOWNLOAD_STORE.get(token)
    if not info:
        return None
    if info["expires_at"] < datetime.now(timezone.utc):
        # pop() rather than del: the entry may already have been purged by
        # _cleanup_expired_tokens during another request.
        _DOWNLOAD_STORE.pop(token, None)
        return None
    return info


def _cleanup_expired_tokens() -> None:
    """Remove every token whose ``expires_at`` is in the past."""
    now = datetime.now(timezone.utc)
    # Scan a copy: if requests are served from multiple threads, inserting into
    # the live dict mid-iteration would raise RuntimeError.
    expired = [t for t, info in _DOWNLOAD_STORE.copy().items() if info["expires_at"] < now]
    for t in expired:
        # The token may already have been removed by get_download_info.
        _DOWNLOAD_STORE.pop(t, None)