task_service.py 43 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927
  1. import logging
  2. import json
  3. import httpx
  4. from datetime import datetime, date
  5. from typing import List, Dict, Any, Tuple, Optional
  6. from app.base.async_mysql_connection import get_db_connection
  7. from app.base.minio_connection import get_minio_manager
  8. logger = logging.getLogger(__name__)
  9. class TaskService:
  10. """任务管理服务类"""
  11. _schema_verified = False # 类级别变量,确保 DDL 逻辑只运行一次
  12. def __init__(self):
  13. self.minio_manager = get_minio_manager()
  14. async def _ensure_table_schema(self, cursor, conn):
  15. """确保表结构和索引正确 (DDL 操作)"""
  16. if TaskService._schema_verified:
  17. return
  18. try:
  19. # 1. 动态维护字段
  20. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'tag'")
  21. if not cursor.fetchone():
  22. cursor.execute("ALTER TABLE t_task_management ADD COLUMN tag json NULL COMMENT '标签' AFTER type")
  23. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'metadata'")
  24. if not cursor.fetchone():
  25. cursor.execute("ALTER TABLE t_task_management ADD COLUMN metadata json NULL COMMENT '业务元数据' AFTER tag")
  26. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'project_name'")
  27. if not cursor.fetchone():
  28. cursor.execute("ALTER TABLE t_task_management ADD COLUMN project_name varchar(255) NULL COMMENT '项目显示名称' AFTER project_id")
  29. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'external_completed_count'")
  30. if not cursor.fetchone():
  31. cursor.execute("ALTER TABLE t_task_management ADD COLUMN external_completed_count int NULL DEFAULT 0 COMMENT '外部平台完成数量' AFTER annotation_status")
  32. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'external_total_count'")
  33. if not cursor.fetchone():
  34. cursor.execute("ALTER TABLE t_task_management ADD COLUMN external_total_count int NULL DEFAULT 0 COMMENT '外部平台总任务数量' AFTER external_completed_count")
  35. conn.commit()
  36. # 2. 处理索引冲突
  37. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'file_url'")
  38. if not cursor.fetchone():
  39. cursor.execute("ALTER TABLE t_task_management ADD COLUMN file_url VARCHAR(512) DEFAULT NULL AFTER metadata")
  40. logger.info("Added column file_url to t_task_management")
  41. cursor.execute("SHOW INDEX FROM t_task_management WHERE Column_name = 'business_id'")
  42. indexes = cursor.fetchall()
  43. for idx in indexes:
  44. if not idx['Non_unique'] and idx['Key_name'] != 'PRIMARY' and idx['Seq_in_index'] == 1:
  45. cursor.execute(f"SHOW INDEX FROM t_task_management WHERE Key_name = '{idx['Key_name']}'")
  46. if len(cursor.fetchall()) == 1:
  47. cursor.execute(f"DROP INDEX {idx['Key_name']} ON t_task_management")
  48. logger.info(f"Dropped old unique index: {idx['Key_name']}")
  49. cursor.execute("SHOW INDEX FROM t_task_management WHERE Key_name = 'uk_business_project'")
  50. if not cursor.fetchone():
  51. cursor.execute("CREATE UNIQUE INDEX uk_business_project ON t_task_management (business_id, project_id)")
  52. logger.info("Created new composite unique index: uk_business_project")
  53. conn.commit()
  54. TaskService._schema_verified = True
  55. except Exception as e:
  56. logger.warning(f"表结构维护失败: {e}")
  57. conn.rollback()
  58. async def add_task(self, business_id: str, task_type: str, task_id: str = None, project_id: str = None, project_name: str = None, tag: str = None, metadata: str = None) -> Tuple[bool, str, Optional[int]]:
  59. """添加或更新任务记录 (适配单表结构)"""
  60. conn = get_db_connection()
  61. if not conn:
  62. return False, "数据库连接失败", None
  63. cursor = conn.cursor()
  64. try:
  65. # 确保表结构(仅在第一次调用时执行)
  66. await self._ensure_table_schema(cursor, conn)
  67. # 执行插入/更新
  68. sql = """
  69. INSERT INTO t_task_management (business_id, task_id, project_id, project_name, type, annotation_status, tag, metadata)
  70. VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
  71. ON DUPLICATE KEY UPDATE
  72. task_id = IFNULL(VALUES(task_id), task_id),
  73. project_id = IFNULL(VALUES(project_id), project_id),
  74. project_name = IFNULL(VALUES(project_name), project_name),
  75. annotation_status = IFNULL(VALUES(annotation_status), annotation_status),
  76. tag = IFNULL(VALUES(tag), tag),
  77. metadata = IFNULL(VALUES(metadata), metadata)
  78. """
  79. cursor.execute(sql, (business_id, task_id, project_id, project_name, task_type, 'pending', tag, metadata))
  80. record_id = cursor.lastrowid
  81. conn.commit()
  82. return True, "成功", record_id
  83. except Exception as e:
  84. conn.rollback()
  85. logger.exception(f"添加任务失败: {e}")
  86. return False, str(e), None
  87. finally:
  88. cursor.close()
  89. conn.close()
  90. async def delete_task_by_id(self, id: int) -> Tuple[bool, str]:
  91. """根据主键 id 删除任务记录"""
  92. conn = get_db_connection()
  93. if not conn:
  94. return False, "数据库连接失败"
  95. cursor = conn.cursor()
  96. try:
  97. sql = "DELETE FROM t_task_management WHERE id = %s"
  98. cursor.execute(sql, (id,))
  99. conn.commit()
  100. return True, "成功"
  101. except Exception as e:
  102. conn.rollback()
  103. logger.exception(f"删除任务失败: {e}")
  104. return False, str(e)
  105. finally:
  106. cursor.close()
  107. conn.close()
  108. def _serialize_datetime(self, obj: Any) -> Any:
  109. """递归遍历对象,将 datetime 转换为字符串"""
  110. if isinstance(obj, dict):
  111. return {k: self._serialize_datetime(v) for k, v in obj.items()}
  112. elif isinstance(obj, list):
  113. return [self._serialize_datetime(i) for i in obj]
  114. elif isinstance(obj, (datetime, date)):
  115. return obj.strftime('%Y-%m-%d %H:%M:%S') if isinstance(obj, datetime) else obj.strftime('%Y-%m-%d')
  116. return obj
  117. async def get_task_list(self, task_type: str) -> List[Dict[str, Any]]:
  118. """获取项目列表 (按 project_id 聚合)"""
  119. conn = get_db_connection()
  120. if not conn or task_type not in ['data', 'image']:
  121. return []
  122. cursor = conn.cursor()
  123. try:
  124. # 确保表结构
  125. await self._ensure_table_schema(cursor, conn)
  126. # 修改聚合逻辑:返回 project_id (UUID) 和 project_name (文字)
  127. sql = """
  128. SELECT
  129. project_id,
  130. MAX(project_name) as project_name,
  131. MAX(task_id) as task_id,
  132. MAX(tag) as tag,
  133. MAX(id) as sort_id,
  134. COALESCE(NULLIF(MAX(external_total_count), 0), COUNT(*)) as file_count,
  135. COALESCE(
  136. MAX(external_completed_count),
  137. SUM(CASE WHEN annotation_status = 'completed' THEN 1 ELSE 0 END),
  138. 0
  139. ) as completed_count
  140. FROM t_task_management
  141. WHERE type = %s
  142. GROUP BY project_id
  143. ORDER BY sort_id DESC
  144. """
  145. cursor.execute(sql, (task_type,))
  146. rows = cursor.fetchall()
  147. # 兼容旧数据:如果 project_name 为空,则用 project_id 代替
  148. for row in rows:
  149. if not row['project_name']:
  150. row['project_name'] = row['project_id']
  151. return self._serialize_datetime(rows)
  152. except Exception as e:
  153. logger.error(f"获取任务列表失败: {e}")
  154. return []
  155. finally:
  156. cursor.close()
  157. conn.close()
  158. async def get_project_details(self, project_id: str, task_type: str) -> List[Dict[str, Any]]:
  159. """获取项目详情 (按 project_id 查询)"""
  160. conn = get_db_connection()
  161. if not conn:
  162. return []
  163. cursor = conn.cursor()
  164. try:
  165. # 确保表结构
  166. await self._ensure_table_schema(cursor, conn)
  167. # 修改查询逻辑:获取 project_name 并在结果中包含它
  168. sql = """
  169. SELECT
  170. id, business_id, task_id, project_id, project_name,
  171. type, annotation_status, tag, metadata
  172. FROM t_task_management
  173. WHERE project_id = %s AND type = %s
  174. """
  175. cursor.execute(sql, (project_id, task_type))
  176. rows = cursor.fetchall()
  177. # 处理 tag 和 metadata 的 JSON 解析
  178. for row in rows:
  179. if row.get('tag'):
  180. try: row['tag'] = json.loads(row['tag']) if isinstance(row['tag'], str) else row['tag']
  181. except: pass
  182. if row.get('metadata'):
  183. try:
  184. meta = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata']
  185. row['metadata'] = meta
  186. # 提取名称供前端显示
  187. if row['type'] == 'data':
  188. row['name'] = meta.get('title') or meta.get('filename') or row['business_id']
  189. elif row['type'] == 'image':
  190. row['name'] = meta.get('image_name') or row['business_id']
  191. else:
  192. row['name'] = row['business_id']
  193. except:
  194. row['name'] = row['business_id']
  195. else:
  196. row['name'] = row['business_id']
  197. return self._serialize_datetime(rows)
  198. except Exception as e:
  199. logger.error(f"获取项目详情失败: {e}")
  200. return []
  201. finally:
  202. cursor.close()
  203. conn.close()
  204. async def delete_task(self, business_id: str) -> Tuple[bool, str]:
  205. """删除任务记录"""
  206. conn = get_db_connection()
  207. if not conn:
  208. return False, "数据库连接失败"
  209. cursor = conn.cursor()
  210. try:
  211. sql = "DELETE FROM t_task_management WHERE business_id = %s"
  212. cursor.execute(sql, (business_id,))
  213. conn.commit()
  214. return True, "删除成功"
  215. except Exception as e:
  216. logger.exception(f"删除任务记录失败: {e}")
  217. conn.rollback()
  218. return False, f"删除失败: {str(e)}"
  219. finally:
  220. cursor.close()
  221. conn.close()
  222. async def create_anno_project(self, data: Dict[str, Any]) -> Tuple[bool, str]:
  223. """创建标注项目并同步任务数据"""
  224. project_name = data.get('name')
  225. if not project_name:
  226. return False, "项目名称不能为空"
  227. # 0. 统一使用 UUID 方案获取或生成 project_id
  228. conn = get_db_connection()
  229. if not conn:
  230. return False, "数据库连接失败"
  231. cursor = conn.cursor()
  232. try:
  233. project_id = None
  234. cursor.execute("SELECT project_id FROM t_task_management WHERE project_name = %s LIMIT 1", (project_name,))
  235. existing_project = cursor.fetchone()
  236. if existing_project:
  237. project_id = existing_project['project_id']
  238. else:
  239. import uuid
  240. project_id = str(uuid.uuid4())
  241. task_type = data.get('task_type', 'data')
  242. # 映射回内部类型
  243. internal_type_map = {
  244. 'text_classification': 'data',
  245. 'image_classification': 'image'
  246. }
  247. internal_task_type = internal_type_map.get(task_type, task_type)
  248. tasks_data = data.get('data', [])
  249. if not tasks_data:
  250. return False, "任务数据不能为空"
  251. # 提取全局标签名列表
  252. global_tags = []
  253. if data.get('tags'):
  254. global_tags = [t['tag'] for t in data['tags'] if 'tag' in t]
  255. # 批量写入任务
  256. import json
  257. tag_str = json.dumps(global_tags, ensure_ascii=False) if global_tags else None
  258. for item in tasks_data:
  259. business_id = item.get('id')
  260. if not business_id: continue
  261. # 检查是否已存在 (使用联合主键逻辑)
  262. cursor.execute(
  263. "SELECT id FROM t_task_management WHERE business_id = %s AND project_id = %s",
  264. (business_id, project_id)
  265. )
  266. if cursor.fetchone():
  267. continue
  268. metadata_str = json.dumps(item.get('metadata', {}), ensure_ascii=False)
  269. sql = """
  270. INSERT INTO t_task_management
  271. (business_id, type, project_id, project_name, tag, metadata, annotation_status)
  272. VALUES (%s, %s, %s, %s, %s, %s, 'pending')
  273. """
  274. cursor.execute(sql, (business_id, internal_task_type, project_id, project_name, tag_str, metadata_str))
  275. conn.commit()
  276. # 自动推送至外部平台
  277. success, msg = await self.send_to_external_platform(project_id)
  278. if success:
  279. return True, project_id
  280. else:
  281. return False, f"任务已保存但推送失败: {msg}"
  282. except Exception as e:
  283. logger.exception(f"创建项目失败: {e}")
  284. conn.rollback()
  285. return False, str(e)
  286. finally:
  287. cursor.close()
  288. conn.close()
  289. async def get_project_progress(self, project_id: str) -> Dict[str, Any]:
  290. """获取项目进度统计 (单表化)"""
  291. conn = get_db_connection()
  292. if not conn:
  293. return {}
  294. cursor = conn.cursor()
  295. try:
  296. # 统计各状态数量
  297. sql = """
  298. SELECT
  299. annotation_status,
  300. COUNT(*) as count
  301. FROM t_task_management
  302. WHERE project_id = %s
  303. GROUP BY annotation_status
  304. """
  305. cursor.execute(sql, (project_id,))
  306. stats = cursor.fetchall()
  307. if not stats:
  308. return {}
  309. total = sum(s['count'] for s in stats)
  310. completed = sum(s['count'] for s in stats if s['annotation_status'] == 'completed')
  311. return {
  312. "project_id": project_id,
  313. "total": total,
  314. "completed": completed,
  315. "progress": f"{round(completed/total*100, 2)}%" if total > 0 else "0%",
  316. "details": stats
  317. }
  318. except Exception as e:
  319. logger.error(f"查询进度失败: {e}")
  320. return {}
  321. finally:
  322. cursor.close()
  323. conn.close()
  324. def _get_milvus_content(self, task_id: str, kb_info: Dict[str, Any]) -> List[str]:
  325. """
  326. 从 Milvus 获取文档分片内容
  327. """
  328. if not kb_info:
  329. return []
  330. # 优先使用数据库存储的子表名,父表名作为兜底
  331. collections = [c for c in [kb_info.get('collection_name_children'), kb_info.get('collection_name_parent')] if c]
  332. if not collections:
  333. return []
  334. from app.services.milvus_service import milvus_service
  335. contents = []
  336. # 简单的全局缓存,用于存储集合的字段探测结果,减少 describe_collection 调用
  337. if not hasattr(self, '_collection_schema_cache'):
  338. self._collection_schema_cache = {}
  339. for coll_name in collections:
  340. try:
  341. if not milvus_service.client.has_collection(coll_name):
  342. continue
  343. # 获取或更新缓存的字段名
  344. if coll_name not in self._collection_schema_cache:
  345. schema = milvus_service.client.describe_collection(coll_name)
  346. field_names = [f['name'] for f in schema.get('fields', [])]
  347. self._collection_schema_cache[coll_name] = {
  348. "id": "document_id", # 统一使用 document_id
  349. "content": "text" if "text" in field_names else "content"
  350. }
  351. fields = self._collection_schema_cache[coll_name]
  352. id_field = fields["id"]
  353. content_field = fields["content"]
  354. res = milvus_service.client.query(
  355. collection_name=coll_name,
  356. filter=f'{id_field} == "{task_id}"',
  357. output_fields=[content_field]
  358. )
  359. if res:
  360. for s in res:
  361. val = s.get(content_field)
  362. if val: contents.append(val)
  363. if contents:
  364. return contents
  365. except Exception as e:
  366. logger.error(f"查询 Milvus 集合 {coll_name} 异常: {e}")
  367. continue
  368. return []
  369. async def export_project_data(self, project_id: str, conn=None) -> Dict[str, Any]:
  370. """导出项目数据为标注平台要求的格式 (单表化)"""
  371. should_close = False
  372. if not conn:
  373. conn = get_db_connection()
  374. should_close = True
  375. if not conn:
  376. return {}
  377. cursor = conn.cursor()
  378. try:
  379. # 1. 获取任务记录
  380. sql_tasks = """
  381. SELECT business_id as id, type, task_id, project_name, tag, metadata
  382. FROM t_task_management
  383. WHERE project_id = %s
  384. """
  385. cursor.execute(sql_tasks, (project_id,))
  386. rows = cursor.fetchall()
  387. if not rows:
  388. return {}
  389. # 2. 解析基本信息
  390. first_row = rows[0]
  391. internal_task_type = first_row['type']
  392. project_name = first_row.get('project_name') or project_id
  393. remote_project_id = first_row.get('task_id') or project_id
  394. # 映射任务类型
  395. type_map = {'data': 'text_classification', 'image': 'image_classification'}
  396. external_task_type = type_map.get(internal_task_type, internal_task_type)
  397. # 3. 处理数据
  398. final_tasks = []
  399. all_project_tags = set()
  400. # 针对 'data' 类型的批量 Milvus 查询优化
  401. milvus_data_map = {}
  402. if internal_task_type == 'data':
  403. all_task_ids = [r['id'] for r in rows]
  404. sql_kb = """
  405. SELECT kb.collection_name_parent, kb.collection_name_children
  406. FROM t_samp_document_main d
  407. LEFT JOIN t_samp_knowledge_base kb ON d.kb_id = kb.id
  408. WHERE d.id COLLATE utf8mb4_unicode_ci = %s COLLATE utf8mb4_unicode_ci
  409. """
  410. cursor.execute(sql_kb, (all_task_ids[0],))
  411. kb_info = cursor.fetchone()
  412. if kb_info:
  413. milvus_data_map = self._get_milvus_content_batch(all_task_ids, kb_info)
  414. for item in rows:
  415. task_id = item['id']
  416. # 记录原始数据状态
  417. logger.debug(f"正在处理导出任务 {task_id}, tag: {item.get('tag')}, metadata_keys: {list(json.loads(item['metadata']).keys()) if item.get('metadata') else 'None'}")
  418. # 提取并处理标签
  419. doc_tags = []
  420. if item.get('tag'):
  421. try:
  422. doc_tags = json.loads(item['tag']) if isinstance(item['tag'], str) else item['tag']
  423. if doc_tags:
  424. for t in doc_tags: all_project_tags.add(t)
  425. except: pass
  426. # 解析数据库元数据
  427. db_metadata = {}
  428. if item.get('metadata'):
  429. try:
  430. db_metadata = json.loads(item['metadata']) if isinstance(item['metadata'], str) else item['metadata']
  431. except: pass
  432. # 获取任务内容
  433. task_contents = []
  434. annotation_results = []
  435. if internal_task_type == 'data':
  436. task_contents = milvus_data_map.get(task_id, [])
  437. if not task_contents:
  438. cursor.execute("SELECT title FROM t_samp_document_main WHERE id = %s", (task_id,))
  439. res = cursor.fetchone()
  440. if res: task_contents = [res['title']]
  441. elif internal_task_type == 'image':
  442. cursor.execute("SELECT image_url FROM t_image_info WHERE id = %s", (task_id,))
  443. res = cursor.fetchone()
  444. if res:
  445. img_url = res['image_url']
  446. if img_url and not img_url.startswith('http'):
  447. img_url = self.minio_manager.get_full_url(img_url)
  448. task_contents = [img_url]
  449. # 构建最终任务列表
  450. for idx, content in enumerate(task_contents):
  451. if not content: continue
  452. # 合并元数据:数据库数据 + 动态 ID
  453. task_metadata = {
  454. "original_id": task_id,
  455. "chunk_index": idx
  456. }
  457. if db_metadata:
  458. task_metadata.update(db_metadata)
  459. if doc_tags:
  460. task_metadata['tags'] = [{"tag": tag} for tag in doc_tags]
  461. task_item = {
  462. "id": f"{task_id}_{idx}" if len(task_contents) > 1 else task_id,
  463. "content": content,
  464. "metadata": task_metadata
  465. }
  466. # 尝试从元数据中提取 annotation_result
  467. if db_metadata and 'annotation_result' in db_metadata:
  468. task_item['annotation_result'] = db_metadata['annotation_result']
  469. final_tasks.append(task_item)
  470. # 统一进行一次递归序列化处理
  471. return self._serialize_datetime({
  472. "name": project_name,
  473. "description": "",
  474. "task_type": external_task_type,
  475. "data": final_tasks,
  476. "external_id": remote_project_id,
  477. "tags": [{"tag": t} for t in sorted(list(all_project_tags))]
  478. })
  479. except Exception as e:
  480. logger.exception(f"导出项目数据异常: {e}")
  481. return {}
  482. finally:
  483. cursor.close()
  484. if should_close:
  485. conn.close()
  486. def _get_milvus_content_batch(self, task_ids: List[str], kb_info: Dict[str, Any]) -> Dict[str, List[str]]:
  487. """
  488. 批量从 Milvus 获取文档分片内容
  489. """
  490. if not kb_info or not task_ids:
  491. return {}
  492. collections = [c for c in [kb_info.get('collection_name_children'), kb_info.get('collection_name_parent')] if c]
  493. if not collections:
  494. return {}
  495. from app.services.milvus_service import milvus_service
  496. result_map = {tid: [] for tid in task_ids}
  497. if not hasattr(self, '_collection_schema_cache'):
  498. self._collection_schema_cache = {}
  499. for coll_name in collections:
  500. try:
  501. if not milvus_service.client.has_collection(coll_name):
  502. continue
  503. if coll_name not in self._collection_schema_cache:
  504. schema = milvus_service.client.describe_collection(coll_name)
  505. field_names = [f['name'] for f in schema.get('fields', [])]
  506. self._collection_schema_cache[coll_name] = {
  507. "id": "document_id",
  508. "content": "text" if "text" in field_names else "content"
  509. }
  510. fields = self._collection_schema_cache[coll_name]
  511. id_field = fields["id"]
  512. content_field = fields["content"]
  513. # 使用 in 表达式进行批量查询
  514. # 注意:如果 task_ids 非常多,可能需要分批(如每批 100 个)
  515. id_list_str = ", ".join([f'"{tid}"' for tid in task_ids])
  516. res = milvus_service.client.query(
  517. collection_name=coll_name,
  518. filter=f'{id_field} in [{id_list_str}]',
  519. output_fields=[id_field, content_field]
  520. )
  521. if res:
  522. for s in res:
  523. tid = s.get(id_field)
  524. val = s.get(content_field)
  525. if tid in result_map and val:
  526. result_map[tid].append(val)
  527. # 如果当前集合已经查到了内容,就不再查兜底集合(逻辑同单条查询)
  528. if any(result_map.values()):
  529. return result_map
  530. except Exception as e:
  531. logger.error(f"批量查询 Milvus 集合 {coll_name} 异常: {e}")
  532. continue
  533. return result_map
  534. async def get_project_progress(self, project_id: str) -> Dict[str, Any]:
  535. """获取外部标注项目的进度"""
  536. conn = get_db_connection()
  537. if not conn:
  538. return {"error": "数据库连接失败"}
  539. try:
  540. # 1. 查询 remote_project_id (task_id)
  541. cursor = conn.cursor()
  542. cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1", (project_id,))
  543. row = cursor.fetchone()
  544. cursor.close()
  545. if not row or not row['task_id']:
  546. return {"error": "未找到已推送的外部项目ID"}
  547. remote_project_id = row['task_id']
  548. # 2. 获取配置
  549. from app.core.config import config_handler
  550. api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
  551. progress_url = f"{api_base_url}/{remote_project_id}/progress"
  552. token = config_handler.get('external_api', 'admin_token', '')
  553. # 3. 发送请求
  554. async with httpx.AsyncClient(timeout=10.0) as client:
  555. headers = {"Authorization": f"Bearer {token}"}
  556. response = await client.get(progress_url, headers=headers)
  557. if response.status_code == 200:
  558. data = response.json()
  559. # 同步更新本地缓存的完成数量和总数
  560. completed_count = data.get('completed_tasks', 0)
  561. total_count = data.get('total_tasks', 0)
  562. try:
  563. conn_update = get_db_connection()
  564. if conn_update:
  565. cursor_update = conn_update.cursor()
  566. cursor_update.execute(
  567. """
  568. UPDATE t_task_management
  569. SET external_completed_count = %s, external_total_count = %s
  570. WHERE project_id = %s
  571. """,
  572. (completed_count, total_count, project_id)
  573. )
  574. conn_update.commit()
  575. cursor_update.close()
  576. conn_update.close()
  577. except Exception as ex:
  578. logger.warning(f"更新本地进度缓存失败: {ex}")
  579. return data
  580. else:
  581. logger.error(f"查询进度失败: {response.status_code} - {response.text}")
  582. return {"error": f"外部平台返回错误 ({response.status_code})"}
  583. except Exception as e:
  584. logger.exception(f"查询进度异常: {e}")
  585. return {"error": str(e)}
  586. finally:
  587. conn.close()
  588. async def export_labeled_data(self, project_id: str, export_format: str = 'json', completed_only: bool = True) -> Dict[str, Any]:
  589. """触发外部标注项目的数据导出"""
  590. conn = get_db_connection()
  591. if not conn:
  592. return {"error": "数据库连接失败"}
  593. try:
  594. # 1. 查询 remote_project_id (task_id)
  595. cursor = conn.cursor()
  596. cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1", (project_id,))
  597. row = cursor.fetchone()
  598. cursor.close()
  599. if not row or not row['task_id']:
  600. return {"error": "未找到已推送的外部项目ID"}
  601. remote_project_id = row['task_id']
  602. # 2. 获取配置
  603. from app.core.config import config_handler
  604. api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
  605. export_url = f"{api_base_url}/{remote_project_id}/export"
  606. token = config_handler.get('external_api', 'admin_token', '')
  607. # 3. 发送请求
  608. async with httpx.AsyncClient(timeout=30.0) as client:
  609. headers = {
  610. "Authorization": f"Bearer {token}",
  611. "Content-Type": "application/json"
  612. }
  613. payload = {
  614. "format": export_format,
  615. "completed_only": completed_only
  616. }
  617. response = await client.post(export_url, json=payload, headers=headers)
  618. if response.status_code in (200, 201):
  619. res_data = response.json()
  620. logger.info(f"外部平台导出响应数据类型: {type(res_data)}, 键名: {list(res_data.keys()) if isinstance(res_data, dict) else 'None'}")
  621. # 1. 统一获取下载地址并回写 (兼容 download_url 和 file_url)
  622. download_url = res_data.get('download_url') or res_data.get('file_url')
  623. if isinstance(res_data, dict) and download_url:
  624. try:
  625. cursor_update = conn.cursor()
  626. affected = cursor_update.execute(
  627. "UPDATE t_task_management SET file_url = %s WHERE project_id = %s",
  628. (download_url, project_id)
  629. )
  630. conn.commit()
  631. cursor_update.close()
  632. logger.info(f"导出时同步更新 file_url: {download_url}, 受影响行数: {affected}")
  633. except Exception as ex:
  634. logger.warning(f"导出同步回写 file_url 失败: {ex}")
  635. # 2. 同步回写 annotation_result 到 metadata
  636. # 情况 A: 接口直接返回了任务列表数据
  637. export_items = []
  638. if isinstance(res_data, dict) and 'data' in res_data and isinstance(res_data['data'], list):
  639. export_items = res_data['data']
  640. # 情况 B: 接口返回了文件链接,且格式为 JSON,尝试下载并解析以获取标注结果
  641. elif isinstance(res_data, dict) and download_url and res_data.get('format') == 'json':
  642. try:
  643. # 补全 URL 协议 (如果外部平台返回的是相对路径)
  644. full_download_url = download_url
  645. if not download_url.startswith('http'):
  646. from app.core.config import config_handler
  647. # 尝试从配置获取 base_url,如果没有则从 api_url 中提取
  648. base_url = config_handler.get('external_api', 'download_base_url', '')
  649. if not base_url:
  650. # 兜底:从 project_api_url 中提取域名部分
  651. api_base = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003')
  652. from urllib.parse import urlparse
  653. parsed = urlparse(api_base)
  654. base_url = f"{parsed.scheme}://{parsed.netloc}"
  655. full_download_url = f"{base_url.rstrip('/')}/{download_url.lstrip('/')}"
  656. logger.info(f"正在从导出链接获取详细标注数据以同步数据库: {full_download_url}")
  657. # 注意:这里需要带上 token
  658. file_res = await client.get(full_download_url, headers=headers)
  659. if file_res.status_code == 200:
  660. file_json = file_res.json()
  661. # 外部平台导出的 JSON 结构通常是 { "data": [...] } 或直接是 [...]
  662. if isinstance(file_json, dict) and 'data' in file_json:
  663. export_items = file_json['data']
  664. elif isinstance(file_json, list):
  665. export_items = file_json
  666. if export_items:
  667. logger.info(f"成功获取导出项,共 {len(export_items)} 条。")
  668. else:
  669. logger.warning("获取到的导出列表为空")
  670. except Exception as ex:
  671. logger.warning(f"从导出文件同步标注数据失败: {ex}")
  672. if export_items:
  673. updated_count = 0
  674. try:
  675. cursor_meta = conn.cursor()
  676. for ext_item in export_items:
  677. # 根据实际 JSON 结构提取数据
  678. # 1. 提取 original_id
  679. original_data = ext_item.get('original_data', {})
  680. meta = original_data.get('metadata', {}) if original_data else ext_item.get('metadata', {})
  681. # 增加兼容性:如果 metadata 里没有,尝试直接从 ext_item 找,或者从 external_id 提取
  682. original_id = meta.get('original_id')
  683. if not original_id:
  684. ext_id = ext_item.get('external_id', '')
  685. if '_' in ext_id: # 比如 "uuid_4" 这种结构,提取前面的 uuid
  686. original_id = ext_id.rsplit('_', 1)[0]
  687. else:
  688. original_id = ext_id
  689. # 2. 提取 annotation_result
  690. annotations = ext_item.get('annotations', [])
  691. if annotations and isinstance(annotations, list):
  692. annotation_res = annotations[0].get('result')
  693. else:
  694. annotation_res = ext_item.get('annotation_result')
  695. if original_id and annotation_res is not None:
  696. # 注意:这里需要根据 business_id 或 metadata 里的 original_id 来匹配
  697. # 样本中心 t_task_management 表的 business_id 存的是原始数据的唯一标识
  698. cursor_meta.execute(
  699. "SELECT id, metadata FROM t_task_management WHERE business_id = %s OR id = %s",
  700. (original_id, original_id)
  701. )
  702. row = cursor_meta.fetchone()
  703. if row:
  704. db_id = row['id']
  705. current_meta = json.loads(row['metadata']) if row['metadata'] else {}
  706. current_meta['annotation_result'] = annotation_res
  707. cursor_meta.execute(
  708. "UPDATE t_task_management SET metadata = %s WHERE id = %s",
  709. (json.dumps(current_meta, ensure_ascii=False), db_id)
  710. )
  711. updated_count += 1
  712. else:
  713. logger.debug(f"未在数据库中找到对应的任务: {original_id}")
  714. conn.commit()
  715. cursor_meta.close()
  716. logger.info(f"已从导出数据同步回写 {updated_count} 条任务的 annotation_result")
  717. except Exception as ex:
  718. logger.warning(f"同步回写 annotation_result 异常: {ex}")
  719. return res_data
  720. else:
  721. logger.error(f"导出数据失败: {response.status_code} - {response.text}")
  722. return {"error": f"外部平台返回错误 ({response.status_code})"}
  723. except Exception as e:
  724. logger.exception(f"导出数据异常: {e}")
  725. return {"error": str(e)}
  726. finally:
  727. conn.close()
  728. async def send_to_external_platform(self, project_id: str) -> Tuple[bool, str]:
  729. """将项目数据推送至外部标注平台 (单表化)"""
  730. conn = get_db_connection()
  731. if not conn:
  732. return False, "数据库连接失败"
  733. try:
  734. # 1. 准备数据 (复用数据库连接)
  735. payload = await self.export_project_data(project_id=project_id, conn=conn)
  736. if not payload:
  737. return False, "项目导出失败,请检查项目ID是否正确"
  738. if not payload.get('data'):
  739. return False, f"项目数据为空 (查询到0条有效任务),无法推送"
  740. # 2. 获取配置
  741. from app.core.config import config_handler
  742. api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
  743. api_url = f"{api_base_url}/init"
  744. token = config_handler.get('external_api', 'admin_token', '')
  745. # 3. 发送请求
  746. async with httpx.AsyncClient(timeout=60.0) as client:
  747. headers = {
  748. "Authorization": f"Bearer {token}",
  749. "Content-Type": "application/json"
  750. }
  751. logger.info(f"正在推送项目 {project_id} 至外部平台: {api_url}, 数据条数: {len(payload['data'])}")
  752. response = await client.post(api_url, json=payload, headers=headers)
  753. if response.status_code in (200, 201):
  754. res_data = response.json()
  755. logger.info(f"外部平台推送成功响应: {res_data}")
  756. remote_project_id = res_data.get('project_id')
  757. # 获取下载地址并回写 (兼容多种返回格式)
  758. download_url = res_data.get('download_url') or res_data.get('file_url')
  759. if download_url:
  760. # 4. 回写外部项目 ID 和下载地址 (复用当前连接)
  761. cursor = conn.cursor()
  762. # 先检查受影响行数
  763. affected = cursor.execute(
  764. "UPDATE t_task_management SET task_id = %s, file_url = %s WHERE project_id = %s",
  765. (remote_project_id, download_url, project_id)
  766. )
  767. conn.commit()
  768. cursor.close()
  769. logger.info(f"已回写 task_id: {remote_project_id} 和 file_url: {download_url}, 受影响行数: {affected}")
  770. elif remote_project_id:
  771. # 仅回写外部项目 ID
  772. cursor = conn.cursor()
  773. affected = cursor.execute(
  774. "UPDATE t_task_management SET task_id = %s WHERE project_id = %s",
  775. (remote_project_id, project_id)
  776. )
  777. conn.commit()
  778. cursor.close()
  779. logger.info(f"仅回写 task_id: {remote_project_id}, 受影响行数: {affected}")
  780. return True, f"推送成功!外部项目ID: {remote_project_id or '未知'}"
  781. else:
  782. error_msg = response.text
  783. logger.error(f"推送失败: {response.status_code} - {error_msg}")
  784. return False, f"外部平台返回错误 ({response.status_code})"
  785. except Exception as e:
  786. logger.exception(f"推送至外部平台异常: {e}")
  787. return False, f"推送异常: {str(e)}"
  788. finally:
  789. conn.close()
  790. task_service = TaskService()