- import logging
- import json
- import httpx
- from datetime import datetime, date
- from typing import List, Dict, Any, Tuple, Optional
- from app.base.async_mysql_connection import get_db_connection
- from app.base.minio_connection import get_minio_manager
- logger = logging.getLogger(__name__)
class TaskService:
    """Task-management service.

    Persists annotation tasks in MySQL (table ``t_task_management``) and
    synchronizes them with an external labeling platform over HTTP.
    """

    # Class-level flag so the DDL maintenance in _ensure_table_schema
    # runs at most once per process.
    _schema_verified = False

    def __init__(self):
        # MinIO client, used to resolve full image URLs when exporting tasks.
        self.minio_manager = get_minio_manager()
    async def _ensure_table_schema(self, cursor, conn):
        """Ensure t_task_management has the expected columns and indexes (DDL).

        Runs at most once per process (guarded by the class-level
        ``_schema_verified`` flag). Failures are logged and rolled back but
        not re-raised, so schema drift never blocks request handling.
        """
        if TaskService._schema_verified:
            return

        try:
            # 1. Dynamically maintained columns: add each one only if missing.
            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'tag'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN tag json NULL COMMENT '标签' AFTER type")

            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'metadata'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN metadata json NULL COMMENT '业务元数据' AFTER tag")

            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'project_name'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN project_name varchar(255) NULL COMMENT '项目显示名称' AFTER project_id")

            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'external_completed_count'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN external_completed_count int NULL DEFAULT 0 COMMENT '外部平台完成数量' AFTER annotation_status")

            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'external_total_count'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN external_total_count int NULL DEFAULT 0 COMMENT '外部平台总任务数量' AFTER external_completed_count")

            conn.commit()
            # 2. Resolve index conflicts: drop legacy single-column UNIQUE
            #    indexes on business_id and replace with a composite key.
            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'file_url'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN file_url VARCHAR(512) DEFAULT NULL AFTER metadata")
                logger.info("Added column file_url to t_task_management")
            cursor.execute("SHOW INDEX FROM t_task_management WHERE Column_name = 'business_id'")
            indexes = cursor.fetchall()
            for idx in indexes:
                # Non_unique == 0 means UNIQUE; only drop an index that is
                # single-column (exactly one SHOW INDEX row for its key name).
                if not idx['Non_unique'] and idx['Key_name'] != 'PRIMARY' and idx['Seq_in_index'] == 1:
                    cursor.execute(f"SHOW INDEX FROM t_task_management WHERE Key_name = '{idx['Key_name']}'")
                    if len(cursor.fetchall()) == 1:
                        cursor.execute(f"DROP INDEX {idx['Key_name']} ON t_task_management")
                        logger.info(f"Dropped old unique index: {idx['Key_name']}")

            cursor.execute("SHOW INDEX FROM t_task_management WHERE Key_name = 'uk_business_project'")
            if not cursor.fetchone():
                cursor.execute("CREATE UNIQUE INDEX uk_business_project ON t_task_management (business_id, project_id)")
                logger.info("Created new composite unique index: uk_business_project")

            cursor.execute("SHOW INDEX FROM t_task_management WHERE Key_name = 'idx_project_id'")
            if not cursor.fetchone():
                cursor.execute("CREATE INDEX idx_project_id ON t_task_management (project_id)")
                logger.info("Created index: idx_project_id")

            conn.commit()
            TaskService._schema_verified = True
        except Exception as e:
            logger.warning(f"表结构维护失败: {e}")
            conn.rollback()
    async def add_task(self, business_id: str, task_type: str, task_id: Optional[str] = None, project_id: Optional[str] = None, project_name: Optional[str] = None, tag: Optional[str] = None, metadata: Optional[str] = None) -> Tuple[bool, str, Optional[int]]:
        """Insert or update a task record (single-table upsert).

        Args:
            business_id: business key of the underlying data item.
            task_type: internal task type ('data' or 'image').
            task_id: remote (external platform) project id, if known.
            project_id / project_name: project UUID and display name.
            tag / metadata: JSON strings stored as-is.

        Returns:
            (ok, message, inserted row id or None).
        """
        conn = get_db_connection()
        if not conn:
            return False, "数据库连接失败", None

        cursor = conn.cursor()
        try:
            # Make sure the schema exists (only runs on the first call).
            await self._ensure_table_schema(cursor, conn)
            # Upsert keyed on the (business_id, project_id) unique index;
            # IFNULL keeps existing column values when the new value is NULL.
            sql = """
            INSERT INTO t_task_management (business_id, task_id, project_id, project_name, type, annotation_status, tag, metadata)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                task_id = IFNULL(VALUES(task_id), task_id),
                project_id = IFNULL(VALUES(project_id), project_id),
                project_name = IFNULL(VALUES(project_name), project_name),
                annotation_status = IFNULL(VALUES(annotation_status), annotation_status),
                tag = IFNULL(VALUES(tag), tag),
                metadata = IFNULL(VALUES(metadata), metadata)
            """
            cursor.execute(sql, (business_id, task_id, project_id, project_name, task_type, 'pending', tag, metadata))
            record_id = cursor.lastrowid

            conn.commit()
            return True, "成功", record_id
        except Exception as e:
            conn.rollback()
            logger.exception(f"添加任务失败: {e}")
            return False, str(e), None
        finally:
            cursor.close()
            conn.close()
- async def delete_task_by_id(self, id: int) -> Tuple[bool, str]:
- """根据主键 id 删除任务记录"""
- conn = get_db_connection()
- if not conn:
- return False, "数据库连接失败"
-
- cursor = conn.cursor()
- try:
- sql = "DELETE FROM t_task_management WHERE id = %s"
- cursor.execute(sql, (id,))
- conn.commit()
- return True, "成功"
- except Exception as e:
- conn.rollback()
- logger.exception(f"删除任务失败: {e}")
- return False, str(e)
- finally:
- cursor.close()
- conn.close()
- def _serialize_datetime(self, obj: Any) -> Any:
- """递归遍历对象,将 datetime 转换为字符串"""
- if isinstance(obj, dict):
- return {k: self._serialize_datetime(v) for k, v in obj.items()}
- elif isinstance(obj, list):
- return [self._serialize_datetime(i) for i in obj]
- elif isinstance(obj, (datetime, date)):
- return obj.strftime('%Y-%m-%d %H:%M:%S') if isinstance(obj, datetime) else obj.strftime('%Y-%m-%d')
- return obj
    async def get_task_list(self, task_type: str) -> List[Dict[str, Any]]:
        """Return the project list for a task type, aggregated by project_id.

        Each row carries project identity, a file count and a completed
        count; non-zero external-platform counters take precedence over
        locally aggregated counts. Returns [] on error or bad task_type.
        """
        conn = get_db_connection()
        if not conn or task_type not in ['data', 'image']:
            return []

        cursor = conn.cursor()
        try:
            # Make sure the schema exists
            await self._ensure_table_schema(cursor, conn)

            # Aggregation returns project_id (UUID) and project_name (label);
            # COALESCE/NULLIF prefer the external_* counters when present.
            sql = """
            SELECT
                project_id,
                MAX(project_name) as project_name,
                MAX(task_id) as task_id,
                MAX(tag) as tag,
                MAX(id) as sort_id,
                COALESCE(NULLIF(MAX(external_total_count), 0), COUNT(*)) as file_count,
                COALESCE(
                    MAX(external_completed_count),
                    SUM(CASE WHEN annotation_status = 'completed' THEN 1 ELSE 0 END),
                    0
                ) as completed_count
            FROM t_task_management
            WHERE type = %s
            GROUP BY project_id
            ORDER BY sort_id DESC
            """
            cursor.execute(sql, (task_type,))
            rows = cursor.fetchall()

            # Legacy-data compatibility: fall back to project_id when the
            # display name is missing.
            for row in rows:
                if not row['project_name']:
                    row['project_name'] = row['project_id']

            return self._serialize_datetime(rows)
        except Exception as e:
            logger.error(f"获取任务列表失败: {e}")
            return []
        finally:
            cursor.close()
            conn.close()
- async def get_project_details(self, project_id: str, task_type: str) -> List[Dict[str, Any]]:
- """获取项目详情 (按 project_id 查询)"""
- conn = get_db_connection()
- if not conn:
- return []
-
- cursor = conn.cursor()
- try:
- # 确保表结构
- await self._ensure_table_schema(cursor, conn)
-
- # 修改查询逻辑:获取 project_name 并在结果中包含它
- sql = """
- SELECT
- id, business_id, task_id, project_id, project_name,
- type, annotation_status, tag, metadata
- FROM t_task_management
- WHERE project_id = %s AND type = %s
- """
- cursor.execute(sql, (project_id, task_type))
- rows = cursor.fetchall()
-
- # 处理 tag 和 metadata 的 JSON 解析
- for row in rows:
- if row.get('tag'):
- try: row['tag'] = json.loads(row['tag']) if isinstance(row['tag'], str) else row['tag']
- except: pass
- if row.get('metadata'):
- try:
- meta = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata']
- row['metadata'] = meta
- # 提取名称供前端显示
- if row['type'] == 'data':
- row['name'] = meta.get('title') or meta.get('filename') or row['business_id']
- elif row['type'] == 'image':
- row['name'] = meta.get('image_name') or row['business_id']
- else:
- row['name'] = row['business_id']
- except:
- row['name'] = row['business_id']
- else:
- row['name'] = row['business_id']
-
- return self._serialize_datetime(rows)
- except Exception as e:
- logger.error(f"获取项目详情失败: {e}")
- return []
- finally:
- cursor.close()
- conn.close()
- async def delete_task(self, business_id: str) -> Tuple[bool, str]:
- """删除任务记录"""
- conn = get_db_connection()
- if not conn:
- return False, "数据库连接失败"
-
- cursor = conn.cursor()
- try:
- sql = "DELETE FROM t_task_management WHERE business_id = %s"
- cursor.execute(sql, (business_id,))
- conn.commit()
- return True, "删除成功"
- except Exception as e:
- logger.exception(f"删除任务记录失败: {e}")
- conn.rollback()
- return False, f"删除失败: {str(e)}"
- finally:
- cursor.close()
- conn.close()
- async def create_anno_project(self, data: Dict[str, Any]) -> Tuple[bool, str]:
- """创建标注项目并同步任务数据"""
- project_name = data.get('name')
- if not project_name:
- return False, "项目名称不能为空"
-
- # 0. 统一使用 UUID 方案获取或生成 project_id
- conn = get_db_connection()
- if not conn:
- return False, "数据库连接失败"
-
- cursor = conn.cursor()
- try:
- project_id = None
- cursor.execute("SELECT project_id FROM t_task_management WHERE project_name = %s LIMIT 1", (project_name,))
- existing_project = cursor.fetchone()
- if existing_project:
- project_id = existing_project['project_id']
- else:
- import uuid
- project_id = str(uuid.uuid4())
- task_type = data.get('task_type', 'data')
- # 映射回内部类型
- internal_type_map = {
- 'text_classification': 'data',
- 'image_classification': 'image'
- }
- internal_task_type = internal_type_map.get(task_type, task_type)
-
- tasks_data = data.get('data', [])
- if not tasks_data:
- return False, "任务数据不能为空"
-
- # 提取全局标签名列表
- global_tags = []
- if data.get('tags'):
- global_tags = [t['tag'] for t in data['tags'] if 'tag' in t]
-
- # 批量写入任务
- import json
- tag_str = json.dumps(global_tags, ensure_ascii=False) if global_tags else None
-
- for item in tasks_data:
- business_id = item.get('id')
- if not business_id: continue
-
- # 检查是否已存在 (使用联合主键逻辑)
- cursor.execute(
- "SELECT id FROM t_task_management WHERE business_id = %s AND project_id = %s",
- (business_id, project_id)
- )
- if cursor.fetchone():
- continue
-
- metadata_str = json.dumps(item.get('metadata', {}), ensure_ascii=False)
-
- sql = """
- INSERT INTO t_task_management
- (business_id, type, project_id, project_name, tag, metadata, annotation_status)
- VALUES (%s, %s, %s, %s, %s, %s, 'pending')
- """
- cursor.execute(sql, (business_id, internal_task_type, project_id, project_name, tag_str, metadata_str))
-
- conn.commit()
-
- # 自动推送至外部平台
- success, msg = await self.send_to_external_platform(project_id)
- if success:
- return True, project_id
- else:
- return False, f"任务已保存但推送失败: {msg}"
-
- except Exception as e:
- logger.exception(f"创建项目失败: {e}")
- conn.rollback()
- return False, str(e)
- finally:
- cursor.close()
- conn.close()
    async def get_project_progress(self, project_id: str) -> Dict[str, Any]:
        """Aggregate local progress statistics for a project (single table).

        NOTE(review): dead code — a later method with the same name (the
        external-platform progress query further down in this class)
        replaces this definition at class-creation time, so this body is
        never reachable. Rename or remove one of the two.
        """
        conn = get_db_connection()
        if not conn:
            return {}

        cursor = conn.cursor()
        try:
            # Count rows per annotation status.
            sql = """
            SELECT
                annotation_status,
                COUNT(*) as count
            FROM t_task_management
            WHERE project_id = %s
            GROUP BY annotation_status
            """
            cursor.execute(sql, (project_id,))
            stats = cursor.fetchall()

            if not stats:
                return {}
            total = sum(s['count'] for s in stats)
            completed = sum(s['count'] for s in stats if s['annotation_status'] == 'completed')

            return {
                "project_id": project_id,
                "total": total,
                "completed": completed,
                "progress": f"{round(completed/total*100, 2)}%" if total > 0 else "0%",
                "details": stats
            }
        except Exception as e:
            logger.error(f"查询进度失败: {e}")
            return {}
        finally:
            cursor.close()
            conn.close()
    def _get_milvus_content(self, task_id: str, kb_info: Dict[str, Any]) -> List[str]:
        """
        Fetch the chunk contents of one document from Milvus.

        Tries the child-chunk collection first, then the parent collection
        as a fallback; returns a list of text chunks (possibly empty).
        """
        if not kb_info:
            return []
        # Prefer the child-collection name stored in the DB; the parent
        # collection is the fallback.
        collections = [c for c in [kb_info.get('collection_name_children'), kb_info.get('collection_name_parent')] if c]

        if not collections:
            return []
        from app.services.milvus_service import milvus_service
        contents = []

        # Simple per-instance cache of collection field layouts to reduce
        # describe_collection calls.
        if not hasattr(self, '_collection_schema_cache'):
            self._collection_schema_cache = {}
        for coll_name in collections:
            try:
                if not milvus_service.client.has_collection(coll_name):
                    continue

                # Probe (or reuse) the cached field names for this collection.
                if coll_name not in self._collection_schema_cache:
                    schema = milvus_service.client.describe_collection(coll_name)
                    field_names = [f['name'] for f in schema.get('fields', [])]
                    self._collection_schema_cache[coll_name] = {
                        "id": "document_id",  # id field is always document_id
                        "content": "text" if "text" in field_names else "content"
                    }

                fields = self._collection_schema_cache[coll_name]
                id_field = fields["id"]
                content_field = fields["content"]

                res = milvus_service.client.query(
                    collection_name=coll_name,
                    filter=f'{id_field} == "{task_id}"',
                    output_fields=[content_field]
                )

                if res:
                    for s in res:
                        val = s.get(content_field)
                        if val: contents.append(val)

                # Stop at the first collection that yielded any content.
                if contents:
                    return contents
            except Exception as e:
                logger.error(f"查询 Milvus 集合 {coll_name} 异常: {e}")
                continue

        return []
    async def export_project_data(self, project_id: str, conn=None) -> Dict[str, Any]:
        """Export a project's tasks in the external annotation-platform format.

        Args:
            project_id: internal project UUID.
            conn: optional existing DB connection; when omitted, a new one
                  is opened and closed here.

        Returns:
            Payload dict with name/task_type/data/tags, or {} on failure.
        """
        should_close = False
        if not conn:
            conn = get_db_connection()
            should_close = True

        if not conn:
            return {}

        cursor = conn.cursor()
        try:
            # 1. Fetch the task rows of the project.
            sql_tasks = """
            SELECT business_id as id, type, task_id, project_name, tag, metadata
            FROM t_task_management
            WHERE project_id = %s
            """
            cursor.execute(sql_tasks, (project_id,))
            rows = cursor.fetchall()

            if not rows:
                return {}

            # 2. Basic project info comes from the first row.
            first_row = rows[0]
            internal_task_type = first_row['type']
            project_name = first_row.get('project_name') or project_id
            remote_project_id = first_row.get('task_id') or project_id

            # Map internal task types to the external platform's names.
            type_map = {'data': 'text_classification', 'image': 'image_classification'}
            external_task_type = type_map.get(internal_task_type, internal_task_type)
            # 3. Build the task payloads.
            final_tasks = []
            all_project_tags = set()

            # Batched Milvus lookup for 'data' (text) tasks.
            milvus_data_map = {}
            missing_milvus_ids = []
            if internal_task_type == 'data':
                all_task_ids = [r['id'] for r in rows]
                # NOTE(review): the knowledge-base info is looked up from the
                # first task only — assumes all tasks share one KB; confirm.
                sql_kb = """
                SELECT kb.collection_name_parent, kb.collection_name_children
                FROM t_samp_document_main d
                LEFT JOIN t_samp_knowledge_base kb ON d.kb_id = kb.id
                WHERE d.id = %s
                """
                cursor.execute(sql_kb, (all_task_ids[0],))
                kb_info = cursor.fetchone()
                if kb_info:
                    milvus_data_map = self._get_milvus_content_batch(all_task_ids, kb_info)

                # IDs with no Milvus content fall back to the document title.
                missing_milvus_ids = [tid for tid in all_task_ids if not milvus_data_map.get(tid)]

                # Batched title lookup for the missing IDs.
                title_map = {}
                if missing_milvus_ids:
                    placeholders = ', '.join(['%s'] * len(missing_milvus_ids))
                    cursor.execute(f"SELECT id, title FROM t_samp_document_main WHERE id IN ({placeholders})", missing_milvus_ids)
                    title_map = {r['id']: r['title'] for r in cursor.fetchall()}
            # Batched image-URL lookup for 'image' tasks (via MinIO).
            image_data_map = {}
            if internal_task_type == 'image':
                all_image_ids = [r['id'] for r in rows]
                placeholders = ', '.join(['%s'] * len(all_image_ids))
                cursor.execute(f"SELECT id, image_url FROM t_image_info WHERE id IN ({placeholders})", all_image_ids)
                for img_row in cursor.fetchall():
                    img_url = img_row['image_url']
                    if img_url and not img_url.startswith('http'):
                        img_url = self.minio_manager.get_full_url(img_url)
                    image_data_map[img_row['id']] = [img_url] if img_url else []
            for item in rows:
                task_id = item['id']

                # Parse per-document tags.
                doc_tags = []
                if item.get('tag'):
                    try:
                        doc_tags = json.loads(item['tag']) if isinstance(item['tag'], str) else item['tag']
                        if doc_tags:
                            for t in doc_tags: all_project_tags.add(t)
                    except: pass

                # Parse stored metadata (dates serialized up-front).
                db_metadata = {}
                if item.get('metadata'):
                    try:
                        db_metadata = json.loads(item['metadata']) if isinstance(item['metadata'], str) else item['metadata']
                        if db_metadata:
                            db_metadata = self._serialize_datetime(db_metadata)
                    except: pass

                # Resolve task contents (text chunks or image URLs).
                task_contents = []
                if internal_task_type == 'data':
                    task_contents = milvus_data_map.get(task_id, [])
                    if not task_contents:
                        title = title_map.get(task_id)
                        if title: task_contents = [title]
                elif internal_task_type == 'image':
                    task_contents = image_data_map.get(task_id, [])
                # Build the final task list, one entry per content chunk.
                for idx, content in enumerate(task_contents):
                    if not content: continue

                    # Merge metadata: stored values plus the dynamic IDs.
                    task_metadata = {
                        "original_id": task_id,
                        "chunk_index": idx
                    }
                    if db_metadata:
                        task_metadata.update(db_metadata)

                    if doc_tags:
                        task_metadata['tags'] = [{"tag": tag} for tag in doc_tags]

                    task_item = {
                        "id": f"{task_id}_{idx}" if len(task_contents) > 1 else task_id,
                        "content": content,
                        "metadata": task_metadata
                    }

                    # Carry over any pre-existing annotation result.
                    if db_metadata and 'annotation_result' in db_metadata:
                        task_item['annotation_result'] = db_metadata['annotation_result']

                    final_tasks.append(task_item)
            # Result payload; no global recursive serialization here — dates
            # were already handled per-row above.
            return {
                "name": project_name,
                "description": "",
                "task_type": external_task_type,
                "data": final_tasks,
                "external_id": remote_project_id,
                "tags": [{"tag": t} for t in sorted(list(all_project_tags))]
            }
        except Exception as e:
            logger.exception(f"导出项目数据异常: {e}")
            return {}
        finally:
            cursor.close()
            if should_close:
                conn.close()
    def _get_milvus_content_batch(self, task_ids: List[str], kb_info: Dict[str, Any]) -> Dict[str, List[str]]:
        """
        Batch-fetch document chunk contents from Milvus.

        Returns a mapping of task_id -> list of text chunks; ids with no
        hits map to an empty list.
        """
        if not kb_info or not task_ids:
            return {}
        collections = [c for c in [kb_info.get('collection_name_children'), kb_info.get('collection_name_parent')] if c]
        if not collections:
            return {}
        from app.services.milvus_service import milvus_service
        result_map = {tid: [] for tid in task_ids}

        # Per-instance cache of collection field layouts (shared with
        # _get_milvus_content).
        if not hasattr(self, '_collection_schema_cache'):
            self._collection_schema_cache = {}
        for coll_name in collections:
            try:
                if not milvus_service.client.has_collection(coll_name):
                    continue

                if coll_name not in self._collection_schema_cache:
                    schema = milvus_service.client.describe_collection(coll_name)
                    field_names = [f['name'] for f in schema.get('fields', [])]
                    self._collection_schema_cache[coll_name] = {
                        "id": "document_id",
                        "content": "text" if "text" in field_names else "content"
                    }

                fields = self._collection_schema_cache[coll_name]
                id_field = fields["id"]
                content_field = fields["content"]

                # Batch query via an `in` expression, chunked so the filter
                # expression stays bounded even for many IDs.
                CHUNK_SIZE = 100
                for i in range(0, len(task_ids), CHUNK_SIZE):
                    chunk_ids = task_ids[i:i + CHUNK_SIZE]
                    id_list_str = ", ".join([f'"{tid}"' for tid in chunk_ids])

                    logger.info(f"正在从 Milvus 集合 {coll_name} 查询分片内容 ({i}/{len(task_ids)})...")
                    res = milvus_service.client.query(
                        collection_name=coll_name,
                        filter=f'{id_field} in [{id_list_str}]',
                        output_fields=[id_field, content_field]
                    )

                    if res:
                        for s in res:
                            tid = s.get(id_field)
                            val = s.get(content_field)
                            if tid in result_map and val:
                                result_map[tid].append(val)

                # Skip the fallback collection once any content was found.
                if any(result_map.values()):
                    logger.info(f"从 Milvus 集合 {coll_name} 查得 {sum(len(v) for v in result_map.values())} 条内容分片")
                    return result_map
            except Exception as e:
                logger.error(f"批量查询 Milvus 集合 {coll_name} 异常: {e}")
                continue

        return result_map
    async def get_project_progress(self, project_id: str) -> Dict[str, Any]:
        """Query the external annotation platform for project progress.

        NOTE(review): this definition shadows the earlier local-statistics
        method of the same name defined above in this class — only this one
        is effective at runtime.
        """
        conn = get_db_connection()
        if not conn:
            return {"error": "数据库连接失败"}

        try:
            # 1. Look up the remote project id (stored in task_id).
            cursor = conn.cursor()
            cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1", (project_id,))
            row = cursor.fetchone()
            cursor.close()

            if not row or not row['task_id']:
                return {"error": "未找到已推送的外部项目ID"}

            remote_project_id = row['task_id']

            # 2. Read the external API configuration.
            from app.core.config import config_handler
            api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
            progress_url = f"{api_base_url}/{remote_project_id}/progress"
            token = config_handler.get('external_api', 'admin_token', '')
            # 3. Call the progress endpoint.
            async with httpx.AsyncClient(timeout=10.0) as client:
                headers = {"Authorization": f"Bearer {token}"}
                response = await client.get(progress_url, headers=headers)

                if response.status_code == 200:
                    data = response.json()
                    # Mirror the remote completed/total counters locally.
                    completed_count = data.get('completed_tasks', 0)
                    total_count = data.get('total_tasks', 0)
                    try:
                        # A second connection is used for the write-back.
                        conn_update = get_db_connection()
                        if conn_update:
                            cursor_update = conn_update.cursor()
                            cursor_update.execute(
                                """
                                UPDATE t_task_management
                                SET external_completed_count = %s, external_total_count = %s
                                WHERE project_id = %s
                                """,
                                (completed_count, total_count, project_id)
                            )
                            conn_update.commit()
                            cursor_update.close()
                            conn_update.close()
                    except Exception as ex:
                        logger.warning(f"更新本地进度缓存失败: {ex}")

                    return data
                else:
                    logger.error(f"查询进度失败: {response.status_code} - {response.text}")
                    return {"error": f"外部平台返回错误 ({response.status_code})"}

        except Exception as e:
            logger.exception(f"查询进度异常: {e}")
            return {"error": str(e)}
        finally:
            conn.close()
    async def export_labeled_data(self, project_id: str, export_format: str = 'json', completed_only: bool = True) -> Dict[str, Any]:
        """Trigger an export of labeled data on the external platform.

        Side effects: writes the returned download URL into
        t_task_management.file_url, and syncs each task's annotation_result
        back into its metadata JSON when export items are obtainable.

        Args:
            project_id: internal project UUID.
            export_format: format requested from the external platform.
            completed_only: export only completed tasks when True.

        Returns:
            The external platform's response dict, or {"error": ...}.
        """
        conn = get_db_connection()
        if not conn:
            return {"error": "数据库连接失败"}

        try:
            # 1. Look up the remote project id (stored in task_id).
            cursor = conn.cursor()
            cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1", (project_id,))
            row = cursor.fetchone()
            cursor.close()

            if not row or not row['task_id']:
                return {"error": "未找到已推送的外部项目ID"}

            remote_project_id = row['task_id']

            # 2. Read the external API configuration.
            from app.core.config import config_handler
            api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
            export_url = f"{api_base_url}/{remote_project_id}/export"
            token = config_handler.get('external_api', 'admin_token', '')
            # 3. Call the export endpoint.
            async with httpx.AsyncClient(timeout=30.0) as client:
                headers = {
                    "Authorization": f"Bearer {token}",
                    "Content-Type": "application/json"
                }
                payload = {
                    "format": export_format,
                    "completed_only": completed_only
                }
                response = await client.post(export_url, json=payload, headers=headers)

                if response.status_code in (200, 201):
                    res_data = response.json()
                    logger.info(f"外部平台导出响应数据类型: {type(res_data)}, 键名: {list(res_data.keys()) if isinstance(res_data, dict) else 'None'}")

                    # 1. Extract the download URL and write it back
                    #    (accepts both 'download_url' and 'file_url' keys).
                    download_url = res_data.get('download_url') or res_data.get('file_url')
                    if isinstance(res_data, dict) and download_url:
                        try:
                            cursor_update = conn.cursor()
                            affected = cursor_update.execute(
                                "UPDATE t_task_management SET file_url = %s WHERE project_id = %s",
                                (download_url, project_id)
                            )
                            conn.commit()
                            cursor_update.close()
                            logger.info(f"导出时同步更新 file_url: {download_url}, 受影响行数: {affected}")
                        except Exception as ex:
                            logger.warning(f"导出同步回写 file_url 失败: {ex}")

                    # 2. Sync annotation_result back into metadata.
                    # Case A: the response body already contains the task list.
                    export_items = []
                    if isinstance(res_data, dict) and 'data' in res_data and isinstance(res_data['data'], list):
                        export_items = res_data['data']

                    # Case B: only a file link was returned; when its format
                    # is JSON, download and parse it for annotation results.
                    elif isinstance(res_data, dict) and download_url and res_data.get('format') == 'json':
                        try:
                            # Complete the URL scheme if the platform returned
                            # a relative path.
                            full_download_url = download_url
                            if not download_url.startswith('http'):
                                from app.core.config import config_handler
                                # Prefer a configured base URL; otherwise
                                # derive it from the project API URL.
                                base_url = config_handler.get('external_api', 'download_base_url', '')
                                if not base_url:
                                    # Fallback: take scheme://host from project_api_url.
                                    api_base = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003')
                                    from urllib.parse import urlparse
                                    parsed = urlparse(api_base)
                                    base_url = f"{parsed.scheme}://{parsed.netloc}"

                                full_download_url = f"{base_url.rstrip('/')}/{download_url.lstrip('/')}"

                            logger.info(f"正在从导出链接获取详细标注数据以同步数据库: {full_download_url}")
                            # NOTE: the download request must carry the token.
                            file_res = await client.get(full_download_url, headers=headers)
                            if file_res.status_code == 200:
                                file_json = file_res.json()
                                # Exported JSON is usually {"data": [...]} or a bare list.
                                if isinstance(file_json, dict) and 'data' in file_json:
                                    export_items = file_json['data']
                                elif isinstance(file_json, list):
                                    export_items = file_json

                                if export_items:
                                    logger.info(f"成功获取导出项,共 {len(export_items)} 条。")
                                else:
                                    logger.warning("获取到的导出列表为空")
                        except Exception as ex:
                            logger.warning(f"从导出文件同步标注数据失败: {ex}")
                    if export_items:
                        updated_count = 0
                        try:
                            cursor_meta = conn.cursor()
                            for ext_item in export_items:
                                # Pull fields out of the platform's JSON shape.
                                # 1. Resolve original_id.
                                original_data = ext_item.get('original_data', {})
                                meta = original_data.get('metadata', {}) if original_data else ext_item.get('metadata', {})
                                # Compatibility: when metadata lacks it, fall
                                # back to the external_id.
                                original_id = meta.get('original_id')
                                if not original_id:
                                    ext_id = ext_item.get('external_id', '')
                                    if '_' in ext_id:  # e.g. "uuid_4": strip the trailing chunk index
                                        original_id = ext_id.rsplit('_', 1)[0]
                                    else:
                                        original_id = ext_id

                                # 2. Resolve annotation_result.
                                annotations = ext_item.get('annotations', [])
                                if annotations and isinstance(annotations, list):
                                    annotation_res = annotations[0].get('result')
                                else:
                                    annotation_res = ext_item.get('annotation_result')

                                if original_id and annotation_res is not None:
                                    # Match by business_id (the original-data
                                    # key in t_task_management) or, failing
                                    # that, by primary-key id.
                                    cursor_meta.execute(
                                        "SELECT id, metadata FROM t_task_management WHERE business_id = %s OR id = %s",
                                        (original_id, original_id)
                                    )
                                    row = cursor_meta.fetchone()
                                    if row:
                                        db_id = row['id']
                                        current_meta = json.loads(row['metadata']) if row['metadata'] else {}
                                        current_meta['annotation_result'] = annotation_res

                                        cursor_meta.execute(
                                            "UPDATE t_task_management SET metadata = %s WHERE id = %s",
                                            (json.dumps(current_meta, ensure_ascii=False), db_id)
                                        )
                                        updated_count += 1
                                    else:
                                        logger.debug(f"未在数据库中找到对应的任务: {original_id}")

                            conn.commit()
                            cursor_meta.close()
                            logger.info(f"已从导出数据同步回写 {updated_count} 条任务的 annotation_result")
                        except Exception as ex:
                            logger.warning(f"同步回写 annotation_result 异常: {ex}")

                    return res_data
                else:
                    logger.error(f"导出数据失败: {response.status_code} - {response.text}")
                    return {"error": f"外部平台返回错误 ({response.status_code})"}

        except Exception as e:
            logger.exception(f"导出数据异常: {e}")
            return {"error": str(e)}
        finally:
            conn.close()
    async def send_to_external_platform(self, project_id: str) -> Tuple[bool, str]:
        """Push a project's exported payload to the external labeling platform.

        Returns (ok, message); on success, writes the remote project id (and
        any returned download URL) back into t_task_management.
        """
        # 1. Build the payload (export needs a DB connection).
        conn = get_db_connection()
        if not conn:
            return False, "数据库连接失败"

        try:
            logger.info(f"开始导出项目 {project_id} 数据...")
            payload = await self.export_project_data(project_id=project_id, conn=conn)

            if not payload:
                return False, "项目导出失败,请检查项目ID是否正确"

            if not payload.get('data'):
                return False, f"项目数据为空 (查询到0条有效任务),无法推送"
        except Exception as e:
            logger.exception(f"导出项目数据异常: {e}")
            return False, f"导出异常: {str(e)}"
        finally:
            # Release the connection promptly so it is not held during HTTP I/O.
            conn.close()

        # 2. Read the external API configuration.
        try:
            from app.core.config import config_handler
            api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
            api_url = f"{api_base_url}/init"
            token = config_handler.get('external_api', 'admin_token', '')
            # 3. Send the request (no DB connection held); 120s timeout.
            async with httpx.AsyncClient(timeout=120.0) as client:
                headers = {
                    "Authorization": f"Bearer {token}",
                    "Content-Type": "application/json"
                }
                logger.info(f"正在推送项目 {project_id} 至外部平台: {api_url}, 数据条数: {len(payload['data'])}")
                response = await client.post(api_url, json=payload, headers=headers)

                if response.status_code in (200, 201):
                    res_data = response.json()
                    logger.info(f"外部平台推送成功响应: {res_data}")
                    remote_project_id = res_data.get('project_id')
                    download_url = res_data.get('download_url') or res_data.get('file_url')

                    # 4. Write back the remote project id and download URL
                    #    (re-acquire a fresh connection for this).
                    if remote_project_id:
                        conn = get_db_connection()
                        if conn:
                            try:
                                # NOTE(review): if conn.cursor() itself raised,
                                # the finally below would hit an unbound
                                # `cursor` (NameError) — worth hardening.
                                cursor = conn.cursor()
                                if download_url:
                                    affected = cursor.execute(
                                        "UPDATE t_task_management SET task_id = %s, file_url = %s WHERE project_id = %s",
                                        (remote_project_id, download_url, project_id)
                                    )
                                    logger.info(f"已回写 task_id: {remote_project_id} 和 file_url: {download_url}, 受影响行数: {affected}")
                                else:
                                    affected = cursor.execute(
                                        "UPDATE t_task_management SET task_id = %s WHERE project_id = %s",
                                        (remote_project_id, project_id)
                                    )
                                    logger.info(f"仅回写 task_id: {remote_project_id}, 受影响行数: {affected}")
                                conn.commit()
                            finally:
                                cursor.close()
                                conn.close()

                    return True, f"推送成功!外部项目ID: {remote_project_id or '未知'}"
                else:
                    error_msg = response.text
                    logger.error(f"推送失败: {response.status_code} - {error_msg}")
                    return False, f"外部平台返回错误 ({response.status_code})"

        except Exception as e:
            logger.exception(f"推送至外部平台异常: {e}")
            return False, f"推送异常: {str(e)}"
# Module-level singleton shared by the API layer.
task_service = TaskService()