"""
Template API router.
Provides endpoints for template management and XML configuration validation.
"""
from typing import List, Optional
from fastapi import APIRouter, HTTPException, status, Query, Request
from schemas.template import (
TemplateResponse,
TemplateListResponse,
TemplateCategory,
TemplateCategoryListResponse,
ConfigValidationRequest,
ConfigValidationResponse,
ValidationError
)
import xml.etree.ElementTree as ET
import re
router = APIRouter(
prefix="/api/templates",
tags=["templates"]
)
# 预设模板类别
TEMPLATE_CATEGORIES = [
TemplateCategory(
id="image_classification",
name="图像分类",
description="对图像进行分类标注",
icon="image"
),
TemplateCategory(
id="object_detection",
name="目标检测",
description="在图像中标注目标边界框",
icon="box"
),
TemplateCategory(
id="image_segmentation",
name="图像分割",
description="对图像进行像素级分割标注",
icon="layers"
),
TemplateCategory(
id="text_classification",
name="文本分类",
description="对文本进行分类标注",
icon="file-text"
),
TemplateCategory(
id="ner",
name="命名实体识别",
description="标注文本中的命名实体",
icon="tag"
),
TemplateCategory(
id="text_labeling",
name="文本标注",
description="对文本进行通用标注",
icon="edit"
),
TemplateCategory(
id="audio_transcription",
name="音频转写",
description="将音频转写为文本",
icon="mic"
),
TemplateCategory(
id="video_annotation",
name="视频标注",
description="对视频进行标注",
icon="video"
),
]
# 预设模板
PREDEFINED_TEMPLATES = [
# 图像分类模板
TemplateResponse(
id="image_classification_basic",
name="图像分类 - 基础",
category="image_classification",
description="简单的图像分类任务,支持单选分类",
config="""
""",
tags=["图像", "分类", "单选"]
),
TemplateResponse(
id="image_classification_multi",
name="图像分类 - 多标签",
category="image_classification",
description="支持多标签的图像分类任务",
config="""
""",
tags=["图像", "分类", "多选"]
),
# 目标检测模板
TemplateResponse(
id="object_detection_bbox",
name="目标检测 - 边界框",
category="object_detection",
description="使用矩形框标注图像中的目标",
config="""
""",
tags=["图像", "检测", "边界框"]
),
TemplateResponse(
id="object_detection_keypoint",
name="目标检测 - 关键点",
category="object_detection",
description="使用关键点标注图像中的目标",
config="""
""",
tags=["图像", "检测", "关键点"]
),
# 图像分割模板
TemplateResponse(
id="image_segmentation_polygon",
name="图像分割 - 多边形",
category="image_segmentation",
description="使用多边形进行图像分割标注",
config="""
""",
tags=["图像", "分割", "多边形"]
),
TemplateResponse(
id="image_segmentation_brush",
name="图像分割 - 画笔",
category="image_segmentation",
description="使用画笔工具进行图像分割标注",
config="""
""",
tags=["图像", "分割", "画笔"]
),
# 文本分类模板
TemplateResponse(
id="text_classification_sentiment",
name="文本分类 - 情感分析",
category="text_classification",
description="对文本进行情感分类",
config="""
""",
tags=["文本", "分类", "情感"]
),
TemplateResponse(
id="text_classification_topic",
name="文本分类 - 主题分类",
category="text_classification",
description="对文本进行主题分类",
config="""
""",
tags=["文本", "分类", "主题"]
),
# 命名实体识别模板
TemplateResponse(
id="ner_basic",
name="命名实体识别 - 基础",
category="ner",
description="标注文本中的命名实体",
config="""
""",
tags=["文本", "NER", "实体"]
),
TemplateResponse(
id="ner_relation",
name="命名实体识别 - 关系抽取",
category="ner",
description="标注实体及其关系",
config="""
""",
tags=["文本", "NER", "关系"]
),
# 文本标注模板
TemplateResponse(
id="text_labeling_qa",
name="文本标注 - 问答对",
category="text_labeling",
description="标注问答对数据",
config="""
""",
tags=["文本", "问答", "标注"]
),
TemplateResponse(
id="text_labeling_summary",
name="文本标注 - 摘要生成",
category="text_labeling",
description="为文本生成摘要",
config="""
""",
tags=["文本", "摘要", "标注"]
),
# 音频转写模板
TemplateResponse(
id="audio_transcription_basic",
name="音频转写 - 基础",
category="audio_transcription",
description="将音频转写为文本",
config="""
""",
tags=["音频", "转写", "ASR"]
),
TemplateResponse(
id="audio_transcription_segment",
name="音频转写 - 分段标注",
category="audio_transcription",
description="对音频进行分段转写",
config="""
""",
tags=["音频", "转写", "分段"]
),
# 视频标注模板
TemplateResponse(
id="video_annotation_basic",
name="视频标注 - 基础",
category="video_annotation",
description="对视频进行基础标注",
config="""
""",
tags=["视频", "标注", "分类"]
),
TemplateResponse(
id="video_annotation_bbox",
name="视频标注 - 目标跟踪",
category="video_annotation",
description="在视频中跟踪目标",
config="""
""",
tags=["视频", "跟踪", "检测"]
),
]
def validate_xml_config(config: str) -> ConfigValidationResponse:
"""
验证 XML 配置的有效性。
Args:
config: XML 配置字符串
Returns:
ConfigValidationResponse 包含验证结果
"""
errors = []
# 检查是否为空
if not config or not config.strip():
errors.append(ValidationError(
line=1,
column=1,
message="配置不能为空"
))
return ConfigValidationResponse(valid=False, errors=errors)
try:
# 尝试解析 XML
ET.fromstring(config)
return ConfigValidationResponse(valid=True, errors=[])
except ET.ParseError as e:
# 解析错误信息
error_msg = str(e)
line = 1
column = 1
# 尝试从错误信息中提取行号和列号
match = re.search(r'line (\d+), column (\d+)', error_msg)
if match:
line = int(match.group(1))
column = int(match.group(2))
errors.append(ValidationError(
line=line,
column=column,
message=f"XML 解析错误: {error_msg}"
))
return ConfigValidationResponse(valid=False, errors=errors)
@router.get("", response_model=TemplateListResponse)
async def list_templates(
request: Request,
category: Optional[str] = Query(None, description="按类别筛选"),
search: Optional[str] = Query(None, description="搜索关键词")
):
"""
获取模板列表。
支持按类别筛选和关键词搜索。
"""
templates = PREDEFINED_TEMPLATES.copy()
# 按类别筛选
if category:
templates = [t for t in templates if t.category == category]
# 按关键词搜索
if search:
search_lower = search.lower()
templates = [
t for t in templates
if search_lower in t.name.lower()
or search_lower in t.description.lower()
or any(search_lower in tag.lower() for tag in t.tags)
]
return TemplateListResponse(
templates=templates,
total=len(templates)
)
@router.get("/categories", response_model=TemplateCategoryListResponse)
async def list_categories(request: Request):
"""
获取模板类别列表。
"""
return TemplateCategoryListResponse(categories=TEMPLATE_CATEGORIES)
@router.get("/{template_id}", response_model=TemplateResponse)
async def get_template(request: Request, template_id: str):
"""
获取指定模板详情。
"""
for template in PREDEFINED_TEMPLATES:
if template.id == template_id:
return template
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"模板 '{template_id}' 不存在"
)
@router.post("/validate", response_model=ConfigValidationResponse)
async def validate_config(request: Request, validation_request: ConfigValidationRequest):
"""
验证 XML 配置的有效性。
检查 XML 语法是否正确,返回验证结果和错误信息。
"""
return validate_xml_config(validation_request.config)