"""
Template API tests.
Tests for template management and XML validation endpoints.
"""
import pytest
import uuid
from fastapi.testclient import TestClient
from main import app
from database import init_database
from services.jwt_service import JWTService
import bcrypt
from database import get_db_connection
# Module-level test client wrapping the FastAPI `app`; shared by every test in this file.
client: TestClient = TestClient(app)
@pytest.fixture(scope="module")
def setup_database():
    """Initialize the test database once per test module.

    The fixture yields nothing and performs no teardown after the module's
    tests have finished.
    """
    init_database()
    yield None
@pytest.fixture
def auth_token(setup_database):
    """Create a temporary annotator user and yield a JWT access token for it.

    The user row is inserted directly into the ``users`` table and deleted
    again once the requesting test has finished.
    """
    uid = f"user_{uuid.uuid4().hex[:8]}"
    # Single source of truth for the identity written to the DB and encoded in the token.
    profile = {
        "id": uid,
        "username": f"test_user_{uid}",
        "email": f"user_{uid}@test.com",
        "role": "annotator",
    }
    hashed = bcrypt.hashpw("test123".encode(), bcrypt.gensalt()).decode()
    insert_sql = (
        "\n"
        "INSERT INTO users (id, username, email, password_hash, role)\n"
        "VALUES (?, ?, ?, ?, 'annotator')\n"
    )
    with get_db_connection() as conn:
        cur = conn.cursor()
        cur.execute(
            insert_sql,
            (profile["id"], profile["username"], profile["email"], hashed),
        )

    yield JWTService.create_access_token(profile)

    # Teardown: remove the temporary user created above.
    with get_db_connection() as conn:
        cur = conn.cursor()
        cur.execute("DELETE FROM users WHERE id = ?", (uid,))
class TestTemplateList:
    """Tests for the template listing endpoint, including filtering and search."""

    def test_list_templates_without_auth(self, setup_database):
        """Listing templates without credentials must return 401."""
        response = client.get("/api/templates")
        assert response.status_code == 401

    def test_list_templates(self, auth_token):
        """An authenticated request returns a non-empty template list plus a total."""
        headers = {"Authorization": f"Bearer {auth_token}"}
        response = client.get("/api/templates", headers=headers)
        assert response.status_code == 200
        data = response.json()
        assert "templates" in data
        assert "total" in data
        assert len(data["templates"]) > 0
        assert data["total"] > 0

    def test_list_templates_by_category(self, auth_token):
        """Filtering by category returns only templates of that category."""
        headers = {"Authorization": f"Bearer {auth_token}"}
        response = client.get("/api/templates?category=image_classification", headers=headers)
        assert response.status_code == 200
        data = response.json()
        templates = data["templates"]
        # Without this check the loop below passes vacuously on an empty result,
        # hiding a filter that drops every template.
        assert templates, "category filter returned no templates"
        # Every returned template must belong to the image_classification category.
        for template in templates:
            assert template["category"] == "image_classification"

    def test_list_templates_with_search(self, auth_token):
        """Searching by keyword returns at least one matching template."""
        headers = {"Authorization": f"Bearer {auth_token}"}
        # The search term "图像" means "image".
        response = client.get("/api/templates?search=图像", headers=headers)
        assert response.status_code == 200
        data = response.json()
        # The search result should contain at least one matching template.
        assert len(data["templates"]) > 0
class TestTemplateCategories:
    """Tests for the template category listing endpoint."""

    def test_list_categories_without_auth(self, setup_database):
        """Requests without a bearer token are rejected with 401."""
        resp = client.get("/api/templates/categories")
        assert resp.status_code == 401

    def test_list_categories(self, auth_token):
        """An authenticated user receives a non-empty list of categories."""
        resp = client.get(
            "/api/templates/categories",
            headers={"Authorization": f"Bearer {auth_token}"},
        )
        assert resp.status_code == 200
        payload = resp.json()
        assert "categories" in payload
        categories = payload["categories"]
        assert len(categories) > 0
        # Check the structure of the first category entry.
        first = categories[0]
        for key in ("id", "name", "description"):
            assert key in first
class TestTemplateDetail:
    """Tests for fetching a single template by id."""

    def test_get_template_without_auth(self, setup_database):
        """Fetching a template without credentials returns 401."""
        resp = client.get("/api/templates/image_classification_basic")
        assert resp.status_code == 401

    def test_get_template(self, auth_token):
        """Fetching a known template returns its id and descriptive fields."""
        resp = client.get(
            "/api/templates/image_classification_basic",
            headers={"Authorization": f"Bearer {auth_token}"},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["id"] == "image_classification_basic"
        for field in ("name", "category", "config", "description"):
            assert field in body

    def test_get_nonexistent_template(self, auth_token):
        """Fetching an unknown template id returns 404."""
        resp = client.get(
            "/api/templates/nonexistent_template",
            headers={"Authorization": f"Bearer {auth_token}"},
        )
        assert resp.status_code == 404
class TestConfigValidation:
    """Tests for the XML config validation endpoint."""

    def test_validate_config_without_auth(self, setup_database):
        """Validation without credentials must return 401."""
        # Send a non-empty config: an empty one is rejected with 422 on the
        # authenticated path, so using it here would make this test depend on
        # whether auth runs before body validation.
        response = client.post(
            "/api/templates/validate",
            json={"config": "<View></View>"}
        )
        assert response.status_code == 401

    def test_validate_valid_config(self, auth_token):
        """A well-formed XML config is reported as valid with no errors."""
        headers = {"Authorization": f"Bearer {auth_token}"}
        # NOTE(review): the original literal was empty (the XML appears to have
        # been lost) and therefore contradicted test_validate_whitespace_config.
        # This config assumes a Label Studio-style schema; confirm it matches the
        # validator's rules.
        valid_config = (
            "<View>\n"
            '  <Image name="image" value="$image"/>\n'
            '  <Choices name="choice" toName="image">\n'
            '    <Choice value="Cat"/>\n'
            '    <Choice value="Dog"/>\n'
            "  </Choices>\n"
            "</View>\n"
        )
        response = client.post(
            "/api/templates/validate",
            json={"config": valid_config},
            headers=headers
        )
        assert response.status_code == 200
        data = response.json()
        assert data["valid"] is True
        assert len(data["errors"]) == 0

    def test_validate_invalid_config(self, auth_token):
        """A malformed XML config is reported as invalid with structured errors."""
        headers = {"Authorization": f"Bearer {auth_token}"}
        # <Choices> is never closed, so </View> is a mismatched end tag and any
        # XML parser reports a positioned well-formedness error.
        invalid_config = (
            "<View>\n"
            '  <Image name="image" value="$image"/>\n'
            '  <Choices name="choice" toName="image">\n'
            '    <Choice value="Cat"/>\n'
            "</View>\n"
        )
        response = client.post(
            "/api/templates/validate",
            json={"config": invalid_config},
            headers=headers
        )
        assert response.status_code == 200
        data = response.json()
        assert data["valid"] is False
        assert len(data["errors"]) > 0
        # Check the structure of an error entry.
        error = data["errors"][0]
        assert "line" in error
        assert "column" in error
        assert "message" in error

    def test_validate_empty_config(self, auth_token):
        """An empty config string is rejected at the request-model level."""
        headers = {"Authorization": f"Bearer {auth_token}"}
        response = client.post(
            "/api/templates/validate",
            json={"config": ""},
            headers=headers
        )
        # The empty string is expected to be rejected by Pydantic validation.
        assert response.status_code == 422

    def test_validate_whitespace_config(self, auth_token):
        """A whitespace-only config passes request validation but is not valid XML."""
        headers = {"Authorization": f"Bearer {auth_token}"}
        response = client.post(
            "/api/templates/validate",
            json={"config": " "},
            headers=headers
        )
        assert response.status_code == 200
        data = response.json()
        assert data["valid"] is False
class TestTemplateContent:
    """Checks on the content of the preset templates."""

    def test_all_templates_have_valid_config(self, auth_token):
        """Every preset template's config must pass the XML validation endpoint."""
        headers = {"Authorization": f"Bearer {auth_token}"}
        # Fetch all templates.
        # NOTE(review): the response carries a "total" field; if the endpoint
        # paginates, only the first page is checked here. Confirm.
        response = client.get("/api/templates", headers=headers)
        assert response.status_code == 200
        templates = response.json()["templates"]
        # Guard against a vacuous pass: with no templates the loop checks nothing.
        assert templates, "no preset templates returned"
        # Validate each template's config.
        for template in templates:
            validate_response = client.post(
                "/api/templates/validate",
                json={"config": template["config"]},
                headers=headers
            )
            assert validate_response.status_code == 200
            validation_result = validate_response.json()
            assert validation_result["valid"] is True, \
                f"模板 {template['id']} 的配置无效: {validation_result['errors']}"