import json
import logging
import uuid
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Tuple

import httpx

from app.base.async_mysql_connection import get_db_connection
from app.base.minio_connection import get_minio_manager

logger = logging.getLogger(__name__)


class TaskService:
    """任务管理服务类

    Task management service backed by the single ``t_task_management`` table.
    Handles local task CRUD, per-project aggregation, and synchronization
    (push / progress query / labeled-data export) with an external
    annotation platform over HTTP.
    """

    # Class-level flag so the DDL maintenance in _ensure_table_schema
    # runs at most once per process.
    _schema_verified = False

    def __init__(self):
        self.minio_manager = get_minio_manager()

    async def _ensure_table_schema(self, cursor, conn):
        """Ensure columns and indexes of t_task_management exist (DDL).

        Idempotent: guarded by the class-level ``_schema_verified`` flag.
        Failures are logged as warnings and rolled back, never raised, so a
        schema hiccup does not break the calling request.
        """
        if TaskService._schema_verified:
            return
        try:
            # 1. Add any missing columns. Order matters: later AFTER clauses
            #    reference columns added earlier in this list.
            missing_columns = (
                ("tag",
                 "ALTER TABLE t_task_management ADD COLUMN tag json NULL COMMENT '标签' AFTER type"),
                ("metadata",
                 "ALTER TABLE t_task_management ADD COLUMN metadata json NULL COMMENT '业务元数据' AFTER tag"),
                ("project_name",
                 "ALTER TABLE t_task_management ADD COLUMN project_name varchar(255) NULL COMMENT '项目显示名称' AFTER project_id"),
                ("external_completed_count",
                 "ALTER TABLE t_task_management ADD COLUMN external_completed_count int NULL DEFAULT 0 COMMENT '外部平台完成数量' AFTER annotation_status"),
                ("external_total_count",
                 "ALTER TABLE t_task_management ADD COLUMN external_total_count int NULL DEFAULT 0 COMMENT '外部平台总任务数量' AFTER external_completed_count"),
            )
            for column_name, ddl in missing_columns:
                cursor.execute("SHOW COLUMNS FROM t_task_management LIKE %s", (column_name,))
                if not cursor.fetchone():
                    cursor.execute(ddl)
            conn.commit()

            # 2. file_url column + index migration: replace any legacy
            #    single-column unique index on business_id with a composite
            #    unique index on (business_id, project_id).
            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'file_url'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN file_url VARCHAR(512) DEFAULT NULL AFTER metadata")
                logger.info("Added column file_url to t_task_management")

            cursor.execute("SHOW INDEX FROM t_task_management WHERE Column_name = 'business_id'")
            indexes = cursor.fetchall()
            for idx in indexes:
                # Only drop unique, non-primary indexes whose leading column
                # is business_id and which cover exactly that one column.
                if not idx['Non_unique'] and idx['Key_name'] != 'PRIMARY' and idx['Seq_in_index'] == 1:
                    # NOTE: identifiers cannot be bound as SQL parameters;
                    # Key_name comes from SHOW INDEX output, not user input.
                    cursor.execute(f"SHOW INDEX FROM t_task_management WHERE Key_name = '{idx['Key_name']}'")
                    if len(cursor.fetchall()) == 1:
                        cursor.execute(f"DROP INDEX {idx['Key_name']} ON t_task_management")
                        logger.info(f"Dropped old unique index: {idx['Key_name']}")

            cursor.execute("SHOW INDEX FROM t_task_management WHERE Key_name = 'uk_business_project'")
            if not cursor.fetchone():
                cursor.execute("CREATE UNIQUE INDEX uk_business_project ON t_task_management (business_id, project_id)")
                logger.info("Created new composite unique index: uk_business_project")
            conn.commit()
            TaskService._schema_verified = True
        except Exception as e:
            logger.warning(f"表结构维护失败: {e}")
            conn.rollback()

    async def add_task(self, business_id: str, task_type: str, task_id: str = None,
                       project_id: str = None, project_name: str = None,
                       tag: str = None, metadata: str = None) -> Tuple[bool, str, Optional[int]]:
        """添加或更新任务记录 (适配单表结构)

        Upserts one row keyed by (business_id, project_id).

        Returns:
            (ok, message, record_id) — record_id is cursor.lastrowid on
            success, None on failure.
        """
        conn = get_db_connection()
        if not conn:
            return False, "数据库连接失败", None
        cursor = conn.cursor()
        try:
            # Ensure table schema (only actually runs on the first call).
            await self._ensure_table_schema(cursor, conn)
            # IFNULL(VALUES(col), col) keeps the existing value whenever the
            # incoming value is NULL, so partial updates don't wipe data.
            sql = """
                INSERT INTO t_task_management
                    (business_id, task_id, project_id, project_name, type, annotation_status, tag, metadata)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                    task_id = IFNULL(VALUES(task_id), task_id),
                    project_id = IFNULL(VALUES(project_id), project_id),
                    project_name = IFNULL(VALUES(project_name), project_name),
                    annotation_status = IFNULL(VALUES(annotation_status), annotation_status),
                    tag = IFNULL(VALUES(tag), tag),
                    metadata = IFNULL(VALUES(metadata), metadata)
            """
            cursor.execute(sql, (business_id, task_id, project_id, project_name,
                                 task_type, 'pending', tag, metadata))
            record_id = cursor.lastrowid
            conn.commit()
            return True, "成功", record_id
        except Exception as e:
            conn.rollback()
            logger.exception(f"添加任务失败: {e}")
            return False, str(e), None
        finally:
            cursor.close()
            conn.close()

    async def delete_task_by_id(self, id: int) -> Tuple[bool, str]:
        """根据主键 id 删除任务记录"""
        conn = get_db_connection()
        if not conn:
            return False, "数据库连接失败"
        cursor = conn.cursor()
        try:
            cursor.execute("DELETE FROM t_task_management WHERE id = %s", (id,))
            conn.commit()
            return True, "成功"
        except Exception as e:
            conn.rollback()
            logger.exception(f"删除任务失败: {e}")
            return False, str(e)
        finally:
            cursor.close()
            conn.close()

    def _serialize_datetime(self, obj: Any) -> Any:
        """递归遍历对象,将 datetime/date 转换为字符串 (JSON-safe)."""
        if isinstance(obj, dict):
            return {k: self._serialize_datetime(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._serialize_datetime(i) for i in obj]
        elif isinstance(obj, (datetime, date)):
            # datetime keeps time-of-day; plain date is date-only.
            return obj.strftime('%Y-%m-%d %H:%M:%S') if isinstance(obj, datetime) else obj.strftime('%Y-%m-%d')
        return obj

    async def get_task_list(self, task_type: str) -> List[Dict[str, Any]]:
        """获取项目列表 (按 project_id 聚合)

        Aggregates tasks into one row per project; external platform counts
        (external_total_count / external_completed_count) take precedence
        over local counts when present.
        """
        conn = get_db_connection()
        if not conn:
            return []
        if task_type not in ('data', 'image'):
            # BUGFIX: previously returned without closing a valid connection.
            conn.close()
            return []
        cursor = conn.cursor()
        try:
            await self._ensure_table_schema(cursor, conn)
            # Returns both project_id (UUID) and project_name (display text).
            sql = """
                SELECT
                    project_id,
                    MAX(project_name) as project_name,
                    MAX(task_id) as task_id,
                    MAX(tag) as tag,
                    MAX(id) as sort_id,
                    COALESCE(NULLIF(MAX(external_total_count), 0), COUNT(*)) as file_count,
                    COALESCE(
                        MAX(external_completed_count),
                        SUM(CASE WHEN annotation_status = 'completed' THEN 1 ELSE 0 END),
                        0
                    ) as completed_count
                FROM t_task_management
                WHERE type = %s
                GROUP BY project_id
                ORDER BY sort_id DESC
            """
            cursor.execute(sql, (task_type,))
            rows = cursor.fetchall()
            # Legacy-data compatibility: fall back to project_id when the
            # display name is missing.
            for row in rows:
                if not row['project_name']:
                    row['project_name'] = row['project_id']
            return self._serialize_datetime(rows)
        except Exception as e:
            logger.error(f"获取任务列表失败: {e}")
            return []
        finally:
            cursor.close()
            conn.close()

    async def get_project_details(self, project_id: str, task_type: str) -> List[Dict[str, Any]]:
        """获取项目详情 (按 project_id 查询)

        Parses the JSON ``tag``/``metadata`` columns and derives a display
        ``name`` for each task from its metadata.
        """
        conn = get_db_connection()
        if not conn:
            return []
        cursor = conn.cursor()
        try:
            await self._ensure_table_schema(cursor, conn)
            sql = """
                SELECT id, business_id, task_id, project_id, project_name, type,
                       annotation_status, tag, metadata
                FROM t_task_management
                WHERE project_id = %s AND type = %s
            """
            cursor.execute(sql, (project_id, task_type))
            rows = cursor.fetchall()
            for row in rows:
                if row.get('tag'):
                    try:
                        row['tag'] = json.loads(row['tag']) if isinstance(row['tag'], str) else row['tag']
                    except Exception:
                        # Keep the raw value if it is not valid JSON.
                        pass
                if row.get('metadata'):
                    try:
                        meta = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata']
                        row['metadata'] = meta
                        # Derive a human-readable name for the front end.
                        if row['type'] == 'data':
                            row['name'] = meta.get('title') or meta.get('filename') or row['business_id']
                        elif row['type'] == 'image':
                            row['name'] = meta.get('image_name') or row['business_id']
                        else:
                            row['name'] = row['business_id']
                    except Exception:
                        row['name'] = row['business_id']
                else:
                    row['name'] = row['business_id']
            return self._serialize_datetime(rows)
        except Exception as e:
            logger.error(f"获取项目详情失败: {e}")
            return []
        finally:
            cursor.close()
            conn.close()

    async def delete_task(self, business_id: str) -> Tuple[bool, str]:
        """删除任务记录 (by business_id; may delete multiple rows)."""
        conn = get_db_connection()
        if not conn:
            return False, "数据库连接失败"
        cursor = conn.cursor()
        try:
            cursor.execute("DELETE FROM t_task_management WHERE business_id = %s", (business_id,))
            conn.commit()
            return True, "删除成功"
        except Exception as e:
            logger.exception(f"删除任务记录失败: {e}")
            conn.rollback()
            return False, f"删除失败: {str(e)}"
        finally:
            cursor.close()
            conn.close()

    async def create_anno_project(self, data: Dict[str, Any]) -> Tuple[bool, str]:
        """创建标注项目并同步任务数据

        Persists the incoming task list under a (possibly newly generated)
        project UUID, then pushes the project to the external platform.

        Returns:
            (True, project_id) on success, or (False, error_message).
        """
        project_name = data.get('name')
        if not project_name:
            return False, "项目名称不能为空"

        conn = get_db_connection()
        if not conn:
            return False, "数据库连接失败"
        cursor = conn.cursor()
        try:
            # Reuse the project UUID if a project with this name exists;
            # otherwise mint a new one.
            cursor.execute("SELECT project_id FROM t_task_management WHERE project_name = %s LIMIT 1",
                           (project_name,))
            existing_project = cursor.fetchone()
            if existing_project:
                project_id = existing_project['project_id']
            else:
                project_id = str(uuid.uuid4())

            task_type = data.get('task_type', 'data')
            # Map external task-type names back to internal ones.
            internal_type_map = {
                'text_classification': 'data',
                'image_classification': 'image'
            }
            internal_task_type = internal_type_map.get(task_type, task_type)

            tasks_data = data.get('data', [])
            if not tasks_data:
                return False, "任务数据不能为空"

            # Project-level tag names shared by every task row.
            global_tags = []
            if data.get('tags'):
                global_tags = [t['tag'] for t in data['tags'] if 'tag' in t]
            tag_str = json.dumps(global_tags, ensure_ascii=False) if global_tags else None

            for item in tasks_data:
                business_id = item.get('id')
                if not business_id:
                    continue
                # Skip rows that already exist for this (business, project)
                # pair — mirrors the composite unique index.
                cursor.execute(
                    "SELECT id FROM t_task_management WHERE business_id = %s AND project_id = %s",
                    (business_id, project_id)
                )
                if cursor.fetchone():
                    continue
                metadata_str = json.dumps(item.get('metadata', {}), ensure_ascii=False)
                sql = """
                    INSERT INTO t_task_management
                        (business_id, type, project_id, project_name, tag, metadata, annotation_status)
                    VALUES (%s, %s, %s, %s, %s, %s, 'pending')
                """
                cursor.execute(sql, (business_id, internal_task_type, project_id,
                                     project_name, tag_str, metadata_str))
            conn.commit()

            # Automatically push the saved project to the external platform.
            success, msg = await self.send_to_external_platform(project_id)
            if success:
                return True, project_id
            else:
                return False, f"任务已保存但推送失败: {msg}"
        except Exception as e:
            logger.exception(f"创建项目失败: {e}")
            conn.rollback()
            return False, str(e)
        finally:
            cursor.close()
            conn.close()

    def _get_milvus_content(self, task_id: str, kb_info: Dict[str, Any]) -> List[str]:
        """从 Milvus 获取文档分片内容 (single document).

        Tries the children collection first, the parent collection as a
        fallback, returning as soon as any content is found.
        """
        if not kb_info:
            return []
        # Prefer the child-chunk collection; the parent collection is a fallback.
        collections = [c for c in [kb_info.get('collection_name_children'),
                                   kb_info.get('collection_name_parent')] if c]
        if not collections:
            return []

        from app.services.milvus_service import milvus_service

        contents = []
        # Lazy per-instance cache of collection field layouts, to avoid
        # repeated describe_collection round-trips.
        if not hasattr(self, '_collection_schema_cache'):
            self._collection_schema_cache = {}

        for coll_name in collections:
            try:
                if not milvus_service.client.has_collection(coll_name):
                    continue
                if coll_name not in self._collection_schema_cache:
                    schema = milvus_service.client.describe_collection(coll_name)
                    field_names = [f['name'] for f in schema.get('fields', [])]
                    self._collection_schema_cache[coll_name] = {
                        "id": "document_id",  # 统一使用 document_id
                        "content": "text" if "text" in field_names else "content"
                    }
                fields = self._collection_schema_cache[coll_name]
                id_field = fields["id"]
                content_field = fields["content"]

                res = milvus_service.client.query(
                    collection_name=coll_name,
                    filter=f'{id_field} == "{task_id}"',
                    output_fields=[content_field]
                )
                if res:
                    for s in res:
                        val = s.get(content_field)
                        if val:
                            contents.append(val)
                # Stop at the first collection that yields content.
                if contents:
                    return contents
            except Exception as e:
                logger.error(f"查询 Milvus 集合 {coll_name} 异常: {e}")
                continue
        return []

    async def export_project_data(self, project_id: str, conn=None) -> Dict[str, Any]:
        """导出项目数据为标注平台要求的格式 (单表化)

        Args:
            project_id: local project UUID.
            conn: optional existing DB connection to reuse; when omitted a
                  new one is opened and closed here.
        """
        should_close = False
        if not conn:
            conn = get_db_connection()
            should_close = True
        if not conn:
            return {}
        cursor = conn.cursor()
        try:
            # 1. Load the project's task rows.
            sql_tasks = """
                SELECT business_id as id, type, task_id, project_name, tag, metadata
                FROM t_task_management
                WHERE project_id = %s
            """
            cursor.execute(sql_tasks, (project_id,))
            rows = cursor.fetchall()
            if not rows:
                return {}

            # 2. Project-level info comes from the first row.
            first_row = rows[0]
            internal_task_type = first_row['type']
            project_name = first_row.get('project_name') or project_id
            remote_project_id = first_row.get('task_id') or project_id
            type_map = {'data': 'text_classification', 'image': 'image_classification'}
            external_task_type = type_map.get(internal_task_type, internal_task_type)

            # 3. Build the task payload.
            final_tasks = []
            all_project_tags = set()

            # Batch Milvus lookup optimization for 'data' projects: all tasks
            # are assumed to share one knowledge base (looked up via the
            # first task id).
            milvus_data_map = {}
            if internal_task_type == 'data':
                all_task_ids = [r['id'] for r in rows]
                sql_kb = """
                    SELECT kb.collection_name_parent, kb.collection_name_children
                    FROM t_samp_document_main d
                    LEFT JOIN t_samp_knowledge_base kb ON d.kb_id = kb.id
                    WHERE d.id COLLATE utf8mb4_unicode_ci = %s COLLATE utf8mb4_unicode_ci
                """
                cursor.execute(sql_kb, (all_task_ids[0],))
                kb_info = cursor.fetchone()
                if kb_info:
                    milvus_data_map = self._get_milvus_content_batch(all_task_ids, kb_info)

            for item in rows:
                task_id = item['id']

                # Safe debug preview of the row's metadata keys — must never
                # raise, or it would abort the whole export.
                meta_keys = 'None'
                if item.get('metadata'):
                    try:
                        preview = json.loads(item['metadata']) if isinstance(item['metadata'], str) else item['metadata']
                        if isinstance(preview, dict):
                            meta_keys = list(preview.keys())
                    except Exception:
                        pass
                logger.debug(f"正在处理导出任务 {task_id}, tag: {item.get('tag')}, metadata_keys: {meta_keys}")

                # Collect per-document tags into the project-level tag set.
                doc_tags = []
                if item.get('tag'):
                    try:
                        doc_tags = json.loads(item['tag']) if isinstance(item['tag'], str) else item['tag']
                        if doc_tags:
                            for t in doc_tags:
                                all_project_tags.add(t)
                    except Exception:
                        pass

                # Parse stored metadata.
                db_metadata = {}
                if item.get('metadata'):
                    try:
                        db_metadata = json.loads(item['metadata']) if isinstance(item['metadata'], str) else item['metadata']
                    except Exception:
                        pass

                # Resolve the actual task content per type.
                task_contents = []
                if internal_task_type == 'data':
                    task_contents = milvus_data_map.get(task_id, [])
                    if not task_contents:
                        # Fallback: use the document title when Milvus has
                        # no chunks for this id.
                        cursor.execute("SELECT title FROM t_samp_document_main WHERE id = %s", (task_id,))
                        res = cursor.fetchone()
                        if res:
                            task_contents = [res['title']]
                elif internal_task_type == 'image':
                    cursor.execute("SELECT image_url FROM t_image_info WHERE id = %s", (task_id,))
                    res = cursor.fetchone()
                    if res:
                        img_url = res['image_url']
                        if img_url and not img_url.startswith('http'):
                            img_url = self.minio_manager.get_full_url(img_url)
                        task_contents = [img_url]

                # One exported task per non-empty content chunk.
                for idx, content in enumerate(task_contents):
                    if not content:
                        continue
                    task_metadata = {
                        "original_id": task_id,
                        "chunk_index": idx
                    }
                    if db_metadata:
                        task_metadata.update(db_metadata)
                    if doc_tags:
                        task_metadata['tags'] = [{"tag": tag} for tag in doc_tags]
                    task_item = {
                        "id": f"{task_id}_{idx}" if len(task_contents) > 1 else task_id,
                        "content": content,
                        "metadata": task_metadata
                    }
                    # Carry over any previously synced annotation result.
                    if db_metadata and 'annotation_result' in db_metadata:
                        task_item['annotation_result'] = db_metadata['annotation_result']
                    final_tasks.append(task_item)

            # Single recursive serialization pass for datetimes.
            return self._serialize_datetime({
                "name": project_name,
                "description": "",
                "task_type": external_task_type,
                "data": final_tasks,
                "external_id": remote_project_id,
                "tags": [{"tag": t} for t in sorted(list(all_project_tags))]
            })
        except Exception as e:
            logger.exception(f"导出项目数据异常: {e}")
            return {}
        finally:
            cursor.close()
            if should_close:
                conn.close()

    def _get_milvus_content_batch(self, task_ids: List[str], kb_info: Dict[str, Any]) -> Dict[str, List[str]]:
        """批量从 Milvus 获取文档分片内容

        Returns a mapping of task_id -> list of content chunks. Like the
        single-document variant, the parent collection is only consulted
        when the children collection yields nothing.
        """
        if not kb_info or not task_ids:
            return {}
        collections = [c for c in [kb_info.get('collection_name_children'),
                                   kb_info.get('collection_name_parent')] if c]
        if not collections:
            return {}

        from app.services.milvus_service import milvus_service

        result_map = {tid: [] for tid in task_ids}
        if not hasattr(self, '_collection_schema_cache'):
            self._collection_schema_cache = {}

        for coll_name in collections:
            try:
                if not milvus_service.client.has_collection(coll_name):
                    continue
                if coll_name not in self._collection_schema_cache:
                    schema = milvus_service.client.describe_collection(coll_name)
                    field_names = [f['name'] for f in schema.get('fields', [])]
                    self._collection_schema_cache[coll_name] = {
                        "id": "document_id",
                        "content": "text" if "text" in field_names else "content"
                    }
                fields = self._collection_schema_cache[coll_name]
                id_field = fields["id"]
                content_field = fields["content"]

                # Batch query via `in` expression.
                # NOTE(review): for very large task_ids lists this may need
                # chunking (e.g. 100 ids per query).
                id_list_str = ", ".join([f'"{tid}"' for tid in task_ids])
                res = milvus_service.client.query(
                    collection_name=coll_name,
                    filter=f'{id_field} in [{id_list_str}]',
                    output_fields=[id_field, content_field]
                )
                if res:
                    for s in res:
                        tid = s.get(id_field)
                        val = s.get(content_field)
                        if tid in result_map and val:
                            result_map[tid].append(val)
                # If this collection produced any content, skip the fallback
                # collection (same logic as the single-document query).
                if any(result_map.values()):
                    return result_map
            except Exception as e:
                logger.error(f"批量查询 Milvus 集合 {coll_name} 异常: {e}")
                continue
        return result_map

    # NOTE(review): the original file defined get_project_progress twice;
    # the earlier local-stats version was dead code silently shadowed by
    # this external-platform version, and has been removed.
    async def get_project_progress(self, project_id: str) -> Dict[str, Any]:
        """获取外部标注项目的进度

        Queries the external platform for progress and caches the returned
        completed/total counts back into t_task_management.
        """
        conn = get_db_connection()
        if not conn:
            return {"error": "数据库连接失败"}
        try:
            # 1. Resolve the remote project id (stored in task_id).
            cursor = conn.cursor()
            cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1",
                           (project_id,))
            row = cursor.fetchone()
            cursor.close()
            if not row or not row['task_id']:
                return {"error": "未找到已推送的外部项目ID"}
            remote_project_id = row['task_id']

            # 2. External API configuration.
            from app.core.config import config_handler
            api_base_url = config_handler.get('external_api', 'project_api_url',
                                              'http://192.168.92.61:9003/api/external/projects').rstrip('/')
            progress_url = f"{api_base_url}/{remote_project_id}/progress"
            token = config_handler.get('external_api', 'admin_token', '')

            # 3. Fetch progress.
            async with httpx.AsyncClient(timeout=10.0) as client:
                headers = {"Authorization": f"Bearer {token}"}
                response = await client.get(progress_url, headers=headers)
                if response.status_code == 200:
                    data = response.json()
                    # Best-effort sync of the locally cached counters; a
                    # failure here must not break the progress response.
                    completed_count = data.get('completed_tasks', 0)
                    total_count = data.get('total_tasks', 0)
                    try:
                        conn_update = get_db_connection()
                        if conn_update:
                            cursor_update = conn_update.cursor()
                            cursor_update.execute(
                                """
                                UPDATE t_task_management
                                SET external_completed_count = %s, external_total_count = %s
                                WHERE project_id = %s
                                """,
                                (completed_count, total_count, project_id)
                            )
                            conn_update.commit()
                            cursor_update.close()
                            conn_update.close()
                    except Exception as ex:
                        logger.warning(f"更新本地进度缓存失败: {ex}")
                    return data
                else:
                    logger.error(f"查询进度失败: {response.status_code} - {response.text}")
                    return {"error": f"外部平台返回错误 ({response.status_code})"}
        except Exception as e:
            logger.exception(f"查询进度异常: {e}")
            return {"error": str(e)}
        finally:
            conn.close()

    async def export_labeled_data(self, project_id: str, export_format: str = 'json',
                                  completed_only: bool = True) -> Dict[str, Any]:
        """触发外部标注项目的数据导出

        Requests an export from the external platform, stores the returned
        download URL, and — when labeled items are available (inline or via
        the downloadable JSON) — writes each item's annotation_result back
        into the matching row's metadata.
        """
        conn = get_db_connection()
        if not conn:
            return {"error": "数据库连接失败"}
        try:
            # 1. Resolve the remote project id (stored in task_id).
            cursor = conn.cursor()
            cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1",
                           (project_id,))
            row = cursor.fetchone()
            cursor.close()
            if not row or not row['task_id']:
                return {"error": "未找到已推送的外部项目ID"}
            remote_project_id = row['task_id']

            # 2. External API configuration.
            from app.core.config import config_handler
            api_base_url = config_handler.get('external_api', 'project_api_url',
                                              'http://192.168.92.61:9003/api/external/projects').rstrip('/')
            export_url = f"{api_base_url}/{remote_project_id}/export"
            token = config_handler.get('external_api', 'admin_token', '')

            # 3. Trigger the export.
            async with httpx.AsyncClient(timeout=30.0) as client:
                headers = {
                    "Authorization": f"Bearer {token}",
                    "Content-Type": "application/json"
                }
                payload = {
                    "format": export_format,
                    "completed_only": completed_only
                }
                response = await client.post(export_url, json=payload, headers=headers)
                if response.status_code in (200, 201):
                    res_data = response.json()
                    logger.info(f"外部平台导出响应数据类型: {type(res_data)}, 键名: {list(res_data.keys()) if isinstance(res_data, dict) else 'None'}")

                    # 1. Persist the download URL (accepts either
                    #    download_url or file_url in the response).
                    download_url = res_data.get('download_url') or res_data.get('file_url')
                    if isinstance(res_data, dict) and download_url:
                        try:
                            cursor_update = conn.cursor()
                            affected = cursor_update.execute(
                                "UPDATE t_task_management SET file_url = %s WHERE project_id = %s",
                                (download_url, project_id)
                            )
                            conn.commit()
                            cursor_update.close()
                            logger.info(f"导出时同步更新 file_url: {download_url}, 受影响行数: {affected}")
                        except Exception as ex:
                            logger.warning(f"导出同步回写 file_url 失败: {ex}")

                    # 2. Collect labeled items for metadata write-back.
                    # Case A: the response carries the task list inline.
                    export_items = []
                    if isinstance(res_data, dict) and 'data' in res_data and isinstance(res_data['data'], list):
                        export_items = res_data['data']
                    # Case B: JSON export available behind a download URL —
                    # fetch it to obtain the annotation results.
                    elif isinstance(res_data, dict) and download_url and res_data.get('format') == 'json':
                        try:
                            # Absolutize relative download paths.
                            full_download_url = download_url
                            if not download_url.startswith('http'):
                                base_url = config_handler.get('external_api', 'download_base_url', '')
                                if not base_url:
                                    # Fallback: derive scheme://host from the
                                    # project API url.
                                    api_base = config_handler.get('external_api', 'project_api_url',
                                                                  'http://192.168.92.61:9003')
                                    from urllib.parse import urlparse
                                    parsed = urlparse(api_base)
                                    base_url = f"{parsed.scheme}://{parsed.netloc}"
                                full_download_url = f"{base_url.rstrip('/')}/{download_url.lstrip('/')}"
                            logger.info(f"正在从导出链接获取详细标注数据以同步数据库: {full_download_url}")
                            # The download also requires the auth token.
                            file_res = await client.get(full_download_url, headers=headers)
                            if file_res.status_code == 200:
                                file_json = file_res.json()
                                # Exported JSON is either {"data": [...]} or
                                # a bare list.
                                if isinstance(file_json, dict) and 'data' in file_json:
                                    export_items = file_json['data']
                                elif isinstance(file_json, list):
                                    export_items = file_json
                                if export_items:
                                    logger.info(f"成功获取导出项,共 {len(export_items)} 条。")
                                else:
                                    logger.warning("获取到的导出列表为空")
                        except Exception as ex:
                            logger.warning(f"从导出文件同步标注数据失败: {ex}")

                    if export_items:
                        updated_count = 0
                        try:
                            cursor_meta = conn.cursor()
                            for ext_item in export_items:
                                # Resolve the original task id: prefer
                                # metadata.original_id, then fall back to the
                                # external_id (stripping a "_<chunk>" suffix).
                                original_data = ext_item.get('original_data', {})
                                meta = original_data.get('metadata', {}) if original_data else ext_item.get('metadata', {})
                                original_id = meta.get('original_id')
                                if not original_id:
                                    ext_id = ext_item.get('external_id', '')
                                    if '_' in ext_id:
                                        original_id = ext_id.rsplit('_', 1)[0]
                                    else:
                                        original_id = ext_id

                                # Extract annotation_result from either the
                                # annotations list or the flat field.
                                annotations = ext_item.get('annotations', [])
                                if annotations and isinstance(annotations, list):
                                    annotation_res = annotations[0].get('result')
                                else:
                                    annotation_res = ext_item.get('annotation_result')

                                if original_id and annotation_res is not None:
                                    # Match by business_id (the original data
                                    # key) or by primary key.
                                    cursor_meta.execute(
                                        "SELECT id, metadata FROM t_task_management WHERE business_id = %s OR id = %s",
                                        (original_id, original_id)
                                    )
                                    task_row = cursor_meta.fetchone()
                                    if task_row:
                                        db_id = task_row['id']
                                        current_meta = json.loads(task_row['metadata']) if task_row['metadata'] else {}
                                        current_meta['annotation_result'] = annotation_res
                                        cursor_meta.execute(
                                            "UPDATE t_task_management SET metadata = %s WHERE id = %s",
                                            (json.dumps(current_meta, ensure_ascii=False), db_id)
                                        )
                                        updated_count += 1
                                    else:
                                        logger.debug(f"未在数据库中找到对应的任务: {original_id}")
                            conn.commit()
                            cursor_meta.close()
                            logger.info(f"已从导出数据同步回写 {updated_count} 条任务的 annotation_result")
                        except Exception as ex:
                            logger.warning(f"同步回写 annotation_result 异常: {ex}")
                    return res_data
                else:
                    logger.error(f"导出数据失败: {response.status_code} - {response.text}")
                    return {"error": f"外部平台返回错误 ({response.status_code})"}
        except Exception as e:
            logger.exception(f"导出数据异常: {e}")
            return {"error": str(e)}
        finally:
            conn.close()

    async def send_to_external_platform(self, project_id: str) -> Tuple[bool, str]:
        """将项目数据推送至外部标注平台 (单表化)

        Exports the project payload (reusing one DB connection), POSTs it to
        the platform's init endpoint, and writes the returned remote project
        id / download URL back to the local rows.
        """
        conn = get_db_connection()
        if not conn:
            return False, "数据库连接失败"
        try:
            # 1. Build the payload, sharing this connection with the export.
            payload = await self.export_project_data(project_id=project_id, conn=conn)
            if not payload:
                return False, "项目导出失败,请检查项目ID是否正确"
            if not payload.get('data'):
                return False, f"项目数据为空 (查询到0条有效任务),无法推送"

            # 2. External API configuration.
            from app.core.config import config_handler
            api_base_url = config_handler.get('external_api', 'project_api_url',
                                              'http://192.168.92.61:9003/api/external/projects').rstrip('/')
            api_url = f"{api_base_url}/init"
            token = config_handler.get('external_api', 'admin_token', '')

            # 3. Push.
            async with httpx.AsyncClient(timeout=60.0) as client:
                headers = {
                    "Authorization": f"Bearer {token}",
                    "Content-Type": "application/json"
                }
                logger.info(f"正在推送项目 {project_id} 至外部平台: {api_url}, 数据条数: {len(payload['data'])}")
                response = await client.post(api_url, json=payload, headers=headers)
                if response.status_code in (200, 201):
                    res_data = response.json()
                    logger.info(f"外部平台推送成功响应: {res_data}")
                    remote_project_id = res_data.get('project_id')
                    # Download URL may come under either key.
                    download_url = res_data.get('download_url') or res_data.get('file_url')
                    if download_url:
                        # 4. Write back remote id + download URL on this
                        #    same connection.
                        cursor = conn.cursor()
                        affected = cursor.execute(
                            "UPDATE t_task_management SET task_id = %s, file_url = %s WHERE project_id = %s",
                            (remote_project_id, download_url, project_id)
                        )
                        conn.commit()
                        cursor.close()
                        logger.info(f"已回写 task_id: {remote_project_id} 和 file_url: {download_url}, 受影响行数: {affected}")
                    elif remote_project_id:
                        # Only the remote project id is available.
                        cursor = conn.cursor()
                        affected = cursor.execute(
                            "UPDATE t_task_management SET task_id = %s WHERE project_id = %s",
                            (remote_project_id, project_id)
                        )
                        conn.commit()
                        cursor.close()
                        logger.info(f"仅回写 task_id: {remote_project_id}, 受影响行数: {affected}")
                    return True, f"推送成功!外部项目ID: {remote_project_id or '未知'}"
                else:
                    error_msg = response.text
                    logger.error(f"推送失败: {response.status_code} - {error_msg}")
                    return False, f"外部平台返回错误 ({response.status_code})"
        except Exception as e:
            logger.exception(f"推送至外部平台异常: {e}")
            return False, f"推送异常: {str(e)}"
        finally:
            conn.close()


task_service = TaskService()