Browse Source

修复标注中心接口,前端非文本不支持ppo和dpo

lxylxy123321 20 hours ago
parent
commit
ce484efe8e
4 changed files with 125 additions and 101 deletions
  1. 1 1
      backend/app/config.py
  2. 73 43
      backend/app/services/annotation_platform_service.py
  3. 11 2
      frontend/src/pages/Training.tsx
  4. 40 55
      result.txt

+ 1 - 1
backend/app/config.py

@@ -113,7 +113,7 @@ class Settings(BaseSettings):
     compute_node_ssh_timeout: int = 300  # SSH 命令超时(秒)
 
     # --- 标注平台 ---
-    annotation_platform_base_url: str = "http://192.168.92.61:9003"
+    annotation_platform_base_url: str = "http://192.168.92.61:8003"
     annotation_platform_app_id: str = "nlKLQJdJK3f5ub7UDfQ_E71z2Lo3YSQx"
     annotation_platform_app_secret: str = "wh0HU_9T83rYMjfLFToNxFOKcrk_8H7Ba_27nNGlPqtTf9ROCytsOgp2ue0ol5mm"
 

+ 73 - 43
backend/app/services/annotation_platform_service.py

@@ -1,7 +1,8 @@
 """标注平台 API 客户端服务。
 
-对接标注平台的对外 API,支持 HMAC-SHA256 签名认证。
-功能:列出项目、获取项目详情、导出并下载数据集。
+对接标注平台的对外 API(HMAC-SHA256 签名认证)。
+参考文档:标注平台对外API接口文档.md
+功能:列出项目、获取项目详情、数据集导出与下载。
 """
 
 import hashlib
@@ -26,59 +27,74 @@ _token_cache: dict[str, Any] = {}
 
 
 def _get_base_url() -> str:
-    if not settings.annotation_platform_base_url:
-        raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL 环境变量")
-    return settings.annotation_platform_base_url.rstrip("/")
+    base_url = settings.annotation_platform_base_url
+    if not base_url:
+        raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL")
+    return base_url.rstrip("/")
 
 
 def _get_credentials() -> tuple[str, str]:
-    if not settings.annotation_platform_app_id or not settings.annotation_platform_app_secret:
+    app_id = settings.annotation_platform_app_id
+    app_secret = settings.annotation_platform_app_secret
+    if not app_id or not app_secret:
         raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
-    return settings.annotation_platform_app_id, settings.annotation_platform_app_secret
+    return app_id, app_secret
 
 
-def _sign(app_secret: str, app_id: str, timestamp: str, nonce: str) -> str:
-    """HMAC-SHA256 签名。"""
+def _build_token_headers() -> dict[str, str]:
+    """构建获取 Token 的 HMAC-SHA256 签名请求头。"""
+    app_id, app_secret = _get_credentials()
+    timestamp = str(int(time.time()))
+    nonce = uuid.uuid4().hex  # 使用 uuid4 保证每次唯一,16+ 位随机字符串
     message = app_id + timestamp + nonce
-    return hmac.new(app_secret.encode(), message.encode(), hashlib.sha256).hexdigest()
+    signature = hmac.new(
+        key=app_secret.encode("utf-8"),
+        msg=message.encode("utf-8"),
+        digestmod=hashlib.sha256,
+    ).hexdigest()
 
+    return {
+        "Content-Type": "application/json",
+        "X-Api-Key": app_id,
+        "X-Timestamp": timestamp,
+        "X-Nonce": nonce,
+        "X-Signature": signature,
+    }
 
-def _check_token_valid() -> bool:
+
+def _is_token_valid() -> bool:
+    """检查缓存的 Token 是否仍然有效(提前 5 分钟刷新)。"""
     if not _token_cache.get("access_token"):
         return False
     expires_at = _token_cache.get("expires_at", 0)
-    return time.time() < expires_at - 300  # 提前 5 分钟刷新
+    return time.time() < expires_at - 300
 
 
 async def get_token() -> str:
-    """获取 Access Token,带缓存。"""
-    if _check_token_valid():
+    """获取 Access Token,带缓存。
+
+    POST /api/v1/open/auth/token
+    使用 HMAC-SHA256 签名认证,无请求体。
+    """
+    if _is_token_valid():
         return _token_cache["access_token"]
 
-    app_id, app_secret = _get_credentials()
+    headers = _build_token_headers()
     base_url = _get_base_url()
 
-    timestamp = str(int(time.time()))
-    nonce = secrets.token_hex(8)  # 16 位十六进制随机字符串
-    signature = _sign(app_secret, app_id, timestamp, nonce)
-
     async with httpx.AsyncClient(timeout=30) as client:
         resp = await client.post(
             f"{base_url}/api/v1/open/auth/token",
-            headers={
-                "X-Api-Key": app_id,
-                "X-Signature": signature,
-                "X-Timestamp": timestamp,
-                "X-Nonce": nonce,
-            },
+            headers=headers,
         )
         resp.raise_for_status()
         body = resp.json()
 
+    # 标注平台返回 code: 0 表示成功
     if body.get("code") != 0:
-        raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message')}")
+        raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message', body)}")
 
-    data = body["data"]
+    data = body.get("data", {})
     _token_cache["access_token"] = data["access_token"]
     _token_cache["expires_in"] = data.get("expires_in", 7200)
     _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
@@ -87,15 +103,15 @@ async def get_token() -> str:
 
 
 def _auth_headers() -> dict[str, str]:
-    token = _token_cache.get("access_token", "")
+    """构建业务接口的认证请求头。"""
     return {
-        "Authorization": f"Bearer {token}",
+        "Authorization": f"Bearer {_token_cache.get('access_token', '')}",
         "Content-Type": "application/json",
     }
 
 
 async def _request(method: str, path: str, **kwargs) -> dict[str, Any]:
-    """统一的请求方法,自动携带 Token 并处理错误。"""
+    """统一的业务请求方法,自动携带 Token。"""
     await get_token()
     base_url = _get_base_url()
 
@@ -110,7 +126,7 @@ async def _request(method: str, path: str, **kwargs) -> dict[str, Any]:
         body = resp.json()
 
     if body.get("code") != 0:
-        raise RuntimeError(f"标注平台请求失败: {body.get('message')}")
+        raise RuntimeError(f"标注平台请求失败: {body.get('message', body)}")
 
     return body.get("data", {})
 
@@ -124,7 +140,10 @@ async def list_projects(
     project_type: str | None = None,
     status: str | None = None,
 ) -> dict[str, Any]:
-    """获取标注平台项目列表。"""
+    """获取标注平台项目列表。
+
+    GET /api/v1/open/projects
+    """
     params: dict[str, Any] = {"page": page, "page_size": page_size}
     if name:
         params["name"] = name
@@ -139,7 +158,10 @@ async def list_projects(
 # ---------- 项目详情 ----------
 
 async def get_project_detail(project_id: str) -> dict[str, Any]:
-    """获取项目详情。"""
+    """获取项目详情。
+
+    GET /api/v1/open/projects/{project_id}
+    """
     return await _request("GET", f"/api/v1/open/projects/{project_id}")
 
 
@@ -153,9 +175,10 @@ async def import_project_dataset(
     """导出并下载项目数据集,保存到本地并写入数据库。
 
     流程:
-    1. POST 请求导出 → 获取 file_url
-    2. GET 下载文件 → 保存到 uploads 目录
-    3. 写入 DatasetRecord 数据库
+    1. POST /api/v1/open/projects/{project_id}/datasets/download → 获取 file_url
+    2. GET /api/v1/open/datasets/downloads/{download_token} → 下载文件
+    3. 保存到 uploads 目录
+    4. 写入 DatasetRecord 数据库
     """
     # 1. 请求导出
     export_data = await _request(
@@ -171,21 +194,28 @@ async def import_project_dataset(
     if not file_url:
         raise RuntimeError("标注平台未返回下载链接")
 
-    # 2. 下载文件
+    # 2. 从 file_url 中提取 download_token
+    # file_url 格式如: /api/v1/open/datasets/downloads/dl_abc123
+    if "/datasets/downloads/" in file_url:
+        download_token = file_url.split("/datasets/downloads/")[-1].strip("/")
+    else:
+        # 兜底:直接使用 file_url 的最后一段
+        download_token = file_url.rstrip("/").split("/")[-1]
+
+    # 3. 通过独立的下载接口获取文件(文档 4.6 节)
     await get_token()
     base_url = _get_base_url()
-    download_url = f"{base_url}{file_url}" if file_url.startswith("/") else file_url
 
     async with httpx.AsyncClient(timeout=120) as client:
         resp = await client.get(
-            download_url,
-            headers={"Authorization": f"Bearer {_token_cache.get('access_token', '')}"},
+            f"{base_url}/api/v1/open/datasets/downloads/{download_token}",
+            headers=_auth_headers(),
             follow_redirects=True,
         )
         resp.raise_for_status()
         file_content = resp.content
 
-    # 3. 保存到 uploads 目录
+    # 4. 保存到 uploads 目录
     upload_dir = settings.uploads_dir
     upload_dir.mkdir(parents=True, exist_ok=True)
 
@@ -198,11 +228,11 @@ async def import_project_dataset(
 
     file_path.write_bytes(file_content)
 
-    # 4. 检测格式和记录数
+    # 5. 检测格式和记录数
     fmt = _detect_format(file_path.name)
     record_count = _count_records(file_path, fmt)
 
-    # 5. 写入数据库
+    # 6. 写入数据库
     record_id = str(uuid.uuid4())
     record = DatasetRecord(
         id=record_id,

+ 11 - 2
frontend/src/pages/Training.tsx

@@ -481,7 +481,7 @@ export function Training() {
 
     api.training.create({
       model_id: modelId, model_type: modelType, dataset_id: datasetId,
-      peft_method: peftMethod, task_type: taskType, dataset_template: template,
+      peft_method: peftMethod, task_type: modelType === 'text' ? taskType : 'sft', dataset_template: template,
       epochs, batch_size: batchSize, gradient_accumulation: gradAcc,
       max_seq_length: seqLen, learning_rate: parseFloat(lr),
       lora_r: loraR, lora_alpha: loraR * 2, num_gpus: numGpus,
@@ -545,7 +545,16 @@ export function Training() {
           </div>
           <div>
             <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>训练方法</label>
-            <Select options={TASK_TYPES} value={taskType} onChange={v => setTaskType(String(v))} />
+            {modelType === 'text' ? (
+              <Select options={TASK_TYPES} value={taskType} onChange={v => setTaskType(String(v))} />
+            ) : (
+              <>
+                <Select options={[{ value: 'sft', label: 'SFT (监督微调)' }]} value="sft" onChange={() => {}} />
+                <div style={{ fontSize: 11, color: '#f59e0b', marginTop: 4 }}>
+                  视觉/多模态模型仅支持 SFT 训练
+                </div>
+              </>
+            )}
           </div>
           <div>
             <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>数据模板</label>

+ 40 - 55
result.txt

@@ -1,56 +1,41 @@
-(base) [root@localhost ~]# docker exec finetune-trainer cat /tmp/train_f3038ef4-bb2c-44e5-bba5-fc481d1415e8.log | grep -A 30 "Traceback"
-Traceback (most recent call last):
-  File "/opt/conda/lib/python3.10/site-packages/bitsandbytes/cextension.py", line 320, in <module>
-    lib = get_native_library()
-  File "/opt/conda/lib/python3.10/site-packages/bitsandbytes/cextension.py", line 288, in get_native_library
-    raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}")
-RuntimeError: Configured CUDA binary not found at /opt/conda/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda116.so
-[transformers] warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.
-[transformers] warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.
-trainable params: 5,070,848 || all params: 757,463,872 || trainable%: 0.6695
-Map: 100%|██████████| 5/5 [00:00<00:00, 155.69 examples/s]
-  0%|          | 0/1 [00:00<?, ?it/s]Training failed for job f3038ef4-bb2c-44e5-bba5-fc481d1415e8: DPOTrainer.compute_loss() got an unexpected keyword argument 'num_items_in_batch'
-[remote_train] [rank 0] ERROR: DPOTrainer.compute_loss() got an unexpected keyword argument 'num_items_in_batch'
-[remote_train] Traceback (most recent call last):
-  File "/root/Fine-tuning/backend/app/engines/remote_train.py", line 236, in run_training
-    adapter_path = await engine.train(
-  File "/root/Fine-tuning/backend/app/engines/text_engine.py", line 546, in train
-    trainer.train()
-  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1427, in train
-    return inner_training_loop(
-  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1509, in _inner_training_loop
-    self._run_epoch(
-  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1737, in _run_epoch
-    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
-  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1909, in training_step
-    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
-TypeError: DPOTrainer.compute_loss() got an unexpected keyword argument 'num_items_in_batch'
+import time
+import hmac
+import hashlib
+import requests
+import uuid
 
-[remote_train] === Training job failed: f3038ef4-bb2c-44e5-bba5-fc481d1415e8 ===
-Traceback (most recent call last):
-  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
-    return _run_code(code, main_globals, None,
-  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
-    exec(code, run_globals)
-  File "/root/Fine-tuning/backend/app/engines/remote_train.py", line 466, in <module>
-    main()
-  File "/root/Fine-tuning/backend/app/engines/remote_train.py", line 461, in main
-    asyncio.run(run_training(job_id, model_id, model_type, dataset_id, config,
-  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
-    return loop.run_until_complete(main)
-  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
-    return future.result()
-  File "/root/Fine-tuning/backend/app/engines/remote_train.py", line 236, in run_training
-    adapter_path = await engine.train(
-  File "/root/Fine-tuning/backend/app/engines/text_engine.py", line 546, in train
-    trainer.train()
-  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1427, in train
-    return inner_training_loop(
-  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1509, in _inner_training_loop
-    self._run_epoch(
-  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1737, in _run_epoch
-    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
-  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1909, in training_step
-    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
-TypeError: DPOTrainer.compute_loss() got an unexpected keyword argument 'num_items_in_batch'
-  0%|          | 0/1 [00:12<?, ?it/s]
+# 1. 配置(直接填你的)
+app_id = "hmDeOtXZVbeo2AZ-x58yPssZLg4Tcb1W"
+app_secret = "pj9UirhGUFPsFnCizCz-Qo1dOGi3kxRIrDKKmJZu2aRCPgtTogTubDRW1weM4KNL"
+url = "http://192.168.92.61:8003/api/v1/open/auth/token"  # 注意带 /open/
+
+# 2. 生成参数
+timestamp = str(int(time.time()))  # 秒级时间戳
+nonce = uuid.uuid4().hex  # 随机字符串
+message = app_id + timestamp + nonce
+
+# 3. 计算 HMAC-SHA256 签名
+signature = hmac.new(
+    key=app_secret.encode("utf-8"),
+    msg=message.encode("utf-8"),
+    digestmod=hashlib.sha256
+).hexdigest()
+
+# 4. 构造请求头
+headers = {
+    "Content-Type": "application/json",
+    "X-Api-Key": app_id,
+    "X-Timestamp": timestamp,
+    "X-Nonce": nonce,
+    "X-Signature": signature
+}
+
+# 5. 发送请求(body 为空 {})
+body = {}
+resp = requests.post(url, json=body, headers=headers)
+
+# 6. 打印结果
+print("=== 请求信息 ===")
+print("headers:", headers)
+print("status_code:", resp.status_code)
+print("response:", resp.text)