# task_service.py
import json
import logging
import uuid
from contextlib import contextmanager
from datetime import datetime, date
from typing import List, Dict, Any, Tuple, Optional
from urllib.parse import urlparse

import httpx

from app.base.async_mysql_connection import get_db_connection
from app.base.minio_connection import get_minio_manager
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  9. class TaskService:
  10. """任务管理服务类"""
  11. _schema_verified = False # 类级别变量,确保 DDL 逻辑只运行一次
  12. def __init__(self):
  13. self.minio_manager = get_minio_manager()
  14. async def _ensure_table_schema(self, cursor, conn):
  15. """确保表结构和索引正确 (DDL 操作)"""
  16. if TaskService._schema_verified:
  17. return
  18. try:
  19. # 1. 动态维护字段
  20. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'tag'")
  21. if not cursor.fetchone():
  22. cursor.execute("ALTER TABLE t_task_management ADD COLUMN tag json NULL COMMENT '标签' AFTER type")
  23. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'metadata'")
  24. if not cursor.fetchone():
  25. cursor.execute("ALTER TABLE t_task_management ADD COLUMN metadata json NULL COMMENT '业务元数据' AFTER tag")
  26. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'project_name'")
  27. if not cursor.fetchone():
  28. cursor.execute("ALTER TABLE t_task_management ADD COLUMN project_name varchar(255) NULL COMMENT '项目显示名称' AFTER project_id")
  29. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'external_completed_count'")
  30. if not cursor.fetchone():
  31. cursor.execute("ALTER TABLE t_task_management ADD COLUMN external_completed_count int NULL DEFAULT 0 COMMENT '外部平台完成数量' AFTER annotation_status")
  32. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'external_total_count'")
  33. if not cursor.fetchone():
  34. cursor.execute("ALTER TABLE t_task_management ADD COLUMN external_total_count int NULL DEFAULT 0 COMMENT '外部平台总任务数量' AFTER external_completed_count")
  35. conn.commit()
  36. # 2. 处理索引冲突
  37. cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'file_url'")
  38. if not cursor.fetchone():
  39. cursor.execute("ALTER TABLE t_task_management ADD COLUMN file_url VARCHAR(512) DEFAULT NULL AFTER metadata")
  40. logger.info("Added column file_url to t_task_management")
  41. cursor.execute("SHOW INDEX FROM t_task_management WHERE Column_name = 'business_id'")
  42. indexes = cursor.fetchall()
  43. for idx in indexes:
  44. if not idx['Non_unique'] and idx['Key_name'] != 'PRIMARY' and idx['Seq_in_index'] == 1:
  45. cursor.execute(f"SHOW INDEX FROM t_task_management WHERE Key_name = '{idx['Key_name']}'")
  46. if len(cursor.fetchall()) == 1:
  47. cursor.execute(f"DROP INDEX {idx['Key_name']} ON t_task_management")
  48. logger.info(f"Dropped old unique index: {idx['Key_name']}")
  49. cursor.execute("SHOW INDEX FROM t_task_management WHERE Key_name = 'uk_business_project'")
  50. if not cursor.fetchone():
  51. cursor.execute("CREATE UNIQUE INDEX uk_business_project ON t_task_management (business_id, project_id)")
  52. logger.info("Created new composite unique index: uk_business_project")
  53. cursor.execute("SHOW INDEX FROM t_task_management WHERE Key_name = 'idx_project_id'")
  54. if not cursor.fetchone():
  55. cursor.execute("CREATE INDEX idx_project_id ON t_task_management (project_id)")
  56. logger.info("Created index: idx_project_id")
  57. conn.commit()
  58. TaskService._schema_verified = True
  59. except Exception as e:
  60. logger.warning(f"表结构维护失败: {e}")
  61. conn.rollback()
  62. async def add_task(self, business_id: str, task_type: str, task_id: str = None, project_id: str = None, project_name: str = None, tag: str = None, metadata: str = None) -> Tuple[bool, str, Optional[int]]:
  63. """添加或更新任务记录 (适配单表结构)"""
  64. conn = get_db_connection()
  65. if not conn:
  66. return False, "数据库连接失败", None
  67. cursor = conn.cursor()
  68. try:
  69. # 确保表结构(仅在第一次调用时执行)
  70. await self._ensure_table_schema(cursor, conn)
  71. # 执行插入/更新
  72. sql = """
  73. INSERT INTO t_task_management (business_id, task_id, project_id, project_name, type, annotation_status, tag, metadata)
  74. VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
  75. ON DUPLICATE KEY UPDATE
  76. task_id = IFNULL(VALUES(task_id), task_id),
  77. project_id = IFNULL(VALUES(project_id), project_id),
  78. project_name = IFNULL(VALUES(project_name), project_name),
  79. annotation_status = IFNULL(VALUES(annotation_status), annotation_status),
  80. tag = IFNULL(VALUES(tag), tag),
  81. metadata = IFNULL(VALUES(metadata), metadata)
  82. """
  83. cursor.execute(sql, (business_id, task_id, project_id, project_name, task_type, 'pending', tag, metadata))
  84. record_id = cursor.lastrowid
  85. conn.commit()
  86. return True, "成功", record_id
  87. except Exception as e:
  88. conn.rollback()
  89. logger.exception(f"添加任务失败: {e}")
  90. return False, str(e), None
  91. finally:
  92. cursor.close()
  93. conn.close()
  94. async def delete_task_by_id(self, id: int) -> Tuple[bool, str]:
  95. """根据主键 id 删除任务记录"""
  96. conn = get_db_connection()
  97. if not conn:
  98. return False, "数据库连接失败"
  99. cursor = conn.cursor()
  100. try:
  101. sql = "DELETE FROM t_task_management WHERE id = %s"
  102. cursor.execute(sql, (id,))
  103. conn.commit()
  104. return True, "成功"
  105. except Exception as e:
  106. conn.rollback()
  107. logger.exception(f"删除任务失败: {e}")
  108. return False, str(e)
  109. finally:
  110. cursor.close()
  111. conn.close()
  112. def _serialize_datetime(self, obj: Any) -> Any:
  113. """递归遍历对象,将 datetime 转换为字符串"""
  114. if isinstance(obj, dict):
  115. return {k: self._serialize_datetime(v) for k, v in obj.items()}
  116. elif isinstance(obj, list):
  117. return [self._serialize_datetime(i) for i in obj]
  118. elif isinstance(obj, (datetime, date)):
  119. return obj.strftime('%Y-%m-%d %H:%M:%S') if isinstance(obj, datetime) else obj.strftime('%Y-%m-%d')
  120. return obj
  121. async def get_task_list(self, task_type: str) -> List[Dict[str, Any]]:
  122. """获取项目列表 (按 project_id 聚合)"""
  123. conn = get_db_connection()
  124. if not conn or task_type not in ['data', 'image']:
  125. return []
  126. cursor = conn.cursor()
  127. try:
  128. # 确保表结构
  129. await self._ensure_table_schema(cursor, conn)
  130. # 修改聚合逻辑:返回 project_id (UUID) 和 project_name (文字)
  131. sql = """
  132. SELECT
  133. project_id,
  134. MAX(project_name) as project_name,
  135. MAX(task_id) as task_id,
  136. MAX(tag) as tag,
  137. MAX(id) as sort_id,
  138. COALESCE(NULLIF(MAX(external_total_count), 0), COUNT(*)) as file_count,
  139. COALESCE(
  140. MAX(external_completed_count),
  141. SUM(CASE WHEN annotation_status = 'completed' THEN 1 ELSE 0 END),
  142. 0
  143. ) as completed_count
  144. FROM t_task_management
  145. WHERE type = %s
  146. GROUP BY project_id
  147. ORDER BY sort_id DESC
  148. """
  149. cursor.execute(sql, (task_type,))
  150. rows = cursor.fetchall()
  151. # 兼容旧数据:如果 project_name 为空,则用 project_id 代替
  152. for row in rows:
  153. if not row['project_name']:
  154. row['project_name'] = row['project_id']
  155. return self._serialize_datetime(rows)
  156. except Exception as e:
  157. logger.error(f"获取任务列表失败: {e}")
  158. return []
  159. finally:
  160. cursor.close()
  161. conn.close()
  162. async def get_project_details(self, project_id: str, task_type: str) -> List[Dict[str, Any]]:
  163. """获取项目详情 (按 project_id 查询)"""
  164. conn = get_db_connection()
  165. if not conn:
  166. return []
  167. cursor = conn.cursor()
  168. try:
  169. # 确保表结构
  170. await self._ensure_table_schema(cursor, conn)
  171. # 修改查询逻辑:获取 project_name 并在结果中包含它
  172. sql = """
  173. SELECT
  174. id, business_id, task_id, project_id, project_name,
  175. type, annotation_status, tag, metadata
  176. FROM t_task_management
  177. WHERE project_id = %s AND type = %s
  178. """
  179. cursor.execute(sql, (project_id, task_type))
  180. rows = cursor.fetchall()
  181. # 处理 tag 和 metadata 的 JSON 解析
  182. for row in rows:
  183. if row.get('tag'):
  184. try: row['tag'] = json.loads(row['tag']) if isinstance(row['tag'], str) else row['tag']
  185. except: pass
  186. if row.get('metadata'):
  187. try:
  188. meta = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata']
  189. row['metadata'] = meta
  190. # 提取名称供前端显示
  191. if row['type'] == 'data':
  192. row['name'] = meta.get('title') or meta.get('filename') or row['business_id']
  193. elif row['type'] == 'image':
  194. row['name'] = meta.get('image_name') or row['business_id']
  195. else:
  196. row['name'] = row['business_id']
  197. except:
  198. row['name'] = row['business_id']
  199. else:
  200. row['name'] = row['business_id']
  201. return self._serialize_datetime(rows)
  202. except Exception as e:
  203. logger.error(f"获取项目详情失败: {e}")
  204. return []
  205. finally:
  206. cursor.close()
  207. conn.close()
  208. async def delete_task(self, business_id: str) -> Tuple[bool, str]:
  209. """删除任务记录"""
  210. conn = get_db_connection()
  211. if not conn:
  212. return False, "数据库连接失败"
  213. cursor = conn.cursor()
  214. try:
  215. sql = "DELETE FROM t_task_management WHERE business_id = %s"
  216. cursor.execute(sql, (business_id,))
  217. conn.commit()
  218. return True, "删除成功"
  219. except Exception as e:
  220. logger.exception(f"删除任务记录失败: {e}")
  221. conn.rollback()
  222. return False, f"删除失败: {str(e)}"
  223. finally:
  224. cursor.close()
  225. conn.close()
  226. async def create_anno_project(self, data: Dict[str, Any]) -> Tuple[bool, str]:
  227. """创建标注项目并同步任务数据"""
  228. project_name = data.get('name')
  229. if not project_name:
  230. return False, "项目名称不能为空"
  231. # 0. 统一使用 UUID 方案获取或生成 project_id
  232. conn = get_db_connection()
  233. if not conn:
  234. return False, "数据库连接失败"
  235. cursor = conn.cursor()
  236. try:
  237. project_id = None
  238. cursor.execute("SELECT project_id FROM t_task_management WHERE project_name = %s LIMIT 1", (project_name,))
  239. existing_project = cursor.fetchone()
  240. if existing_project:
  241. project_id = existing_project['project_id']
  242. else:
  243. import uuid
  244. project_id = str(uuid.uuid4())
  245. task_type = data.get('task_type', 'data')
  246. # 映射回内部类型
  247. internal_type_map = {
  248. 'text_classification': 'data',
  249. 'image_classification': 'image'
  250. }
  251. internal_task_type = internal_type_map.get(task_type, task_type)
  252. tasks_data = data.get('data', [])
  253. if not tasks_data:
  254. return False, "任务数据不能为空"
  255. # 提取全局标签名列表
  256. global_tags = []
  257. if data.get('tags'):
  258. global_tags = [t['tag'] for t in data['tags'] if 'tag' in t]
  259. # 批量写入任务
  260. import json
  261. tag_str = json.dumps(global_tags, ensure_ascii=False) if global_tags else None
  262. for item in tasks_data:
  263. business_id = item.get('id')
  264. if not business_id: continue
  265. # 检查是否已存在 (使用联合主键逻辑)
  266. cursor.execute(
  267. "SELECT id FROM t_task_management WHERE business_id = %s AND project_id = %s",
  268. (business_id, project_id)
  269. )
  270. if cursor.fetchone():
  271. continue
  272. metadata_str = json.dumps(item.get('metadata', {}), ensure_ascii=False)
  273. sql = """
  274. INSERT INTO t_task_management
  275. (business_id, type, project_id, project_name, tag, metadata, annotation_status)
  276. VALUES (%s, %s, %s, %s, %s, %s, 'pending')
  277. """
  278. cursor.execute(sql, (business_id, internal_task_type, project_id, project_name, tag_str, metadata_str))
  279. conn.commit()
  280. # 自动推送至外部平台
  281. success, msg = await self.send_to_external_platform(project_id)
  282. if success:
  283. return True, project_id
  284. else:
  285. return False, f"任务已保存但推送失败: {msg}"
  286. except Exception as e:
  287. logger.exception(f"创建项目失败: {e}")
  288. conn.rollback()
  289. return False, str(e)
  290. finally:
  291. cursor.close()
  292. conn.close()
  293. async def get_project_progress(self, project_id: str) -> Dict[str, Any]:
  294. """获取项目进度统计 (单表化)"""
  295. conn = get_db_connection()
  296. if not conn:
  297. return {}
  298. cursor = conn.cursor()
  299. try:
  300. # 统计各状态数量
  301. sql = """
  302. SELECT
  303. annotation_status,
  304. COUNT(*) as count
  305. FROM t_task_management
  306. WHERE project_id = %s
  307. GROUP BY annotation_status
  308. """
  309. cursor.execute(sql, (project_id,))
  310. stats = cursor.fetchall()
  311. if not stats:
  312. return {}
  313. total = sum(s['count'] for s in stats)
  314. completed = sum(s['count'] for s in stats if s['annotation_status'] == 'completed')
  315. return {
  316. "project_id": project_id,
  317. "total": total,
  318. "completed": completed,
  319. "progress": f"{round(completed/total*100, 2)}%" if total > 0 else "0%",
  320. "details": stats
  321. }
  322. except Exception as e:
  323. logger.error(f"查询进度失败: {e}")
  324. return {}
  325. finally:
  326. cursor.close()
  327. conn.close()
  328. def _get_milvus_content(self, task_id: str, kb_info: Dict[str, Any]) -> List[str]:
  329. """
  330. 从 Milvus 获取文档分片内容
  331. """
  332. if not kb_info:
  333. return []
  334. # 优先使用数据库存储的子表名,父表名作为兜底
  335. collections = [c for c in [kb_info.get('collection_name_children'), kb_info.get('collection_name_parent')] if c]
  336. if not collections:
  337. return []
  338. from app.services.milvus_service import milvus_service
  339. contents = []
  340. # 简单的全局缓存,用于存储集合的字段探测结果,减少 describe_collection 调用
  341. if not hasattr(self, '_collection_schema_cache'):
  342. self._collection_schema_cache = {}
  343. for coll_name in collections:
  344. try:
  345. if not milvus_service.client.has_collection(coll_name):
  346. continue
  347. # 获取或更新缓存的字段名
  348. if coll_name not in self._collection_schema_cache:
  349. schema = milvus_service.client.describe_collection(coll_name)
  350. field_names = [f['name'] for f in schema.get('fields', [])]
  351. self._collection_schema_cache[coll_name] = {
  352. "id": "document_id", # 统一使用 document_id
  353. "content": "text" if "text" in field_names else "content"
  354. }
  355. fields = self._collection_schema_cache[coll_name]
  356. id_field = fields["id"]
  357. content_field = fields["content"]
  358. res = milvus_service.client.query(
  359. collection_name=coll_name,
  360. filter=f'{id_field} == "{task_id}"',
  361. output_fields=[content_field]
  362. )
  363. if res:
  364. for s in res:
  365. val = s.get(content_field)
  366. if val: contents.append(val)
  367. if contents:
  368. return contents
  369. except Exception as e:
  370. logger.error(f"查询 Milvus 集合 {coll_name} 异常: {e}")
  371. continue
  372. return []
  373. async def export_project_data(self, project_id: str, conn=None) -> Dict[str, Any]:
  374. """导出项目数据为标注平台要求的格式 (单表化)"""
  375. should_close = False
  376. if not conn:
  377. conn = get_db_connection()
  378. should_close = True
  379. if not conn:
  380. return {}
  381. cursor = conn.cursor()
  382. try:
  383. # 1. 获取任务记录
  384. sql_tasks = """
  385. SELECT business_id as id, type, task_id, project_name, tag, metadata
  386. FROM t_task_management
  387. WHERE project_id = %s
  388. """
  389. cursor.execute(sql_tasks, (project_id,))
  390. rows = cursor.fetchall()
  391. if not rows:
  392. return {}
  393. # 2. 解析基本信息
  394. first_row = rows[0]
  395. internal_task_type = first_row['type']
  396. project_name = first_row.get('project_name') or project_id
  397. remote_project_id = first_row.get('task_id') or project_id
  398. # 映射任务类型
  399. type_map = {'data': 'text_classification', 'image': 'image_classification'}
  400. external_task_type = type_map.get(internal_task_type, internal_task_type)
  401. # 3. 处理数据
  402. final_tasks = []
  403. all_project_tags = set()
  404. # 针对 'data' 类型的批量 Milvus 查询优化
  405. milvus_data_map = {}
  406. missing_milvus_ids = []
  407. if internal_task_type == 'data':
  408. all_task_ids = [r['id'] for r in rows]
  409. sql_kb = """
  410. SELECT kb.collection_name_parent, kb.collection_name_children
  411. FROM t_samp_document_main d
  412. LEFT JOIN t_samp_knowledge_base kb ON d.kb_id = kb.id
  413. WHERE d.id = %s
  414. """
  415. cursor.execute(sql_kb, (all_task_ids[0],))
  416. kb_info = cursor.fetchone()
  417. if kb_info:
  418. milvus_data_map = self._get_milvus_content_batch(all_task_ids, kb_info)
  419. # 记录缺失 Milvus 数据的 ID
  420. missing_milvus_ids = [tid for tid in all_task_ids if not milvus_data_map.get(tid)]
  421. # 针对缺失数据的批量标题查询
  422. title_map = {}
  423. if missing_milvus_ids:
  424. placeholders = ', '.join(['%s'] * len(missing_milvus_ids))
  425. cursor.execute(f"SELECT id, title FROM t_samp_document_main WHERE id IN ({placeholders})", missing_milvus_ids)
  426. title_map = {r['id']: r['title'] for r in cursor.fetchall()}
  427. # 针对 'image' 类型的批量 MinIO 查询优化
  428. image_data_map = {}
  429. if internal_task_type == 'image':
  430. all_image_ids = [r['id'] for r in rows]
  431. placeholders = ', '.join(['%s'] * len(all_image_ids))
  432. cursor.execute(f"SELECT id, image_url FROM t_image_info WHERE id IN ({placeholders})", all_image_ids)
  433. for img_row in cursor.fetchall():
  434. img_url = img_row['image_url']
  435. if img_url and not img_url.startswith('http'):
  436. img_url = self.minio_manager.get_full_url(img_url)
  437. image_data_map[img_row['id']] = [img_url] if img_url else []
  438. for item in rows:
  439. task_id = item['id']
  440. # 提取并处理标签
  441. doc_tags = []
  442. if item.get('tag'):
  443. try:
  444. doc_tags = json.loads(item['tag']) if isinstance(item['tag'], str) else item['tag']
  445. if doc_tags:
  446. for t in doc_tags: all_project_tags.add(t)
  447. except: pass
  448. # 解析数据库元数据 (提前序列化日期)
  449. db_metadata = {}
  450. if item.get('metadata'):
  451. try:
  452. db_metadata = json.loads(item['metadata']) if isinstance(item['metadata'], str) else item['metadata']
  453. if db_metadata:
  454. db_metadata = self._serialize_datetime(db_metadata)
  455. except: pass
  456. # 获取任务内容
  457. task_contents = []
  458. if internal_task_type == 'data':
  459. task_contents = milvus_data_map.get(task_id, [])
  460. if not task_contents:
  461. title = title_map.get(task_id)
  462. if title: task_contents = [title]
  463. elif internal_task_type == 'image':
  464. task_contents = image_data_map.get(task_id, [])
  465. # 构建最终任务列表
  466. for idx, content in enumerate(task_contents):
  467. if not content: continue
  468. # 合并元数据:数据库数据 + 动态 ID
  469. task_metadata = {
  470. "original_id": task_id,
  471. "chunk_index": idx
  472. }
  473. if db_metadata:
  474. task_metadata.update(db_metadata)
  475. if doc_tags:
  476. task_metadata['tags'] = [{"tag": tag} for tag in doc_tags]
  477. task_item = {
  478. "id": f"{task_id}_{idx}" if len(task_contents) > 1 else task_id,
  479. "content": content,
  480. "metadata": task_metadata
  481. }
  482. # 尝试从元数据中提取 annotation_result
  483. if db_metadata and 'annotation_result' in db_metadata:
  484. task_item['annotation_result'] = db_metadata['annotation_result']
  485. final_tasks.append(task_item)
  486. # 准备返回结果,不再进行全局递归序列化 (已在局部处理)
  487. return {
  488. "name": project_name,
  489. "description": "",
  490. "task_type": external_task_type,
  491. "data": final_tasks,
  492. "external_id": remote_project_id,
  493. "tags": [{"tag": t} for t in sorted(list(all_project_tags))]
  494. }
  495. except Exception as e:
  496. logger.exception(f"导出项目数据异常: {e}")
  497. return {}
  498. finally:
  499. cursor.close()
  500. if should_close:
  501. conn.close()
  502. def _get_milvus_content_batch(self, task_ids: List[str], kb_info: Dict[str, Any]) -> Dict[str, List[str]]:
  503. """
  504. 批量从 Milvus 获取文档分片内容
  505. """
  506. if not kb_info or not task_ids:
  507. return {}
  508. collections = [c for c in [kb_info.get('collection_name_children'), kb_info.get('collection_name_parent')] if c]
  509. if not collections:
  510. return {}
  511. from app.services.milvus_service import milvus_service
  512. result_map = {tid: [] for tid in task_ids}
  513. if not hasattr(self, '_collection_schema_cache'):
  514. self._collection_schema_cache = {}
  515. for coll_name in collections:
  516. try:
  517. if not milvus_service.client.has_collection(coll_name):
  518. continue
  519. if coll_name not in self._collection_schema_cache:
  520. schema = milvus_service.client.describe_collection(coll_name)
  521. field_names = [f['name'] for f in schema.get('fields', [])]
  522. self._collection_schema_cache[coll_name] = {
  523. "id": "document_id",
  524. "content": "text" if "text" in field_names else "content"
  525. }
  526. fields = self._collection_schema_cache[coll_name]
  527. id_field = fields["id"]
  528. content_field = fields["content"]
  529. # 使用 in 表达式进行批量查询 (分批处理以防 ID 过多)
  530. CHUNK_SIZE = 100
  531. for i in range(0, len(task_ids), CHUNK_SIZE):
  532. chunk_ids = task_ids[i:i + CHUNK_SIZE]
  533. id_list_str = ", ".join([f'"{tid}"' for tid in chunk_ids])
  534. logger.info(f"正在从 Milvus 集合 {coll_name} 查询分片内容 ({i}/{len(task_ids)})...")
  535. res = milvus_service.client.query(
  536. collection_name=coll_name,
  537. filter=f'{id_field} in [{id_list_str}]',
  538. output_fields=[id_field, content_field]
  539. )
  540. if res:
  541. for s in res:
  542. tid = s.get(id_field)
  543. val = s.get(content_field)
  544. if tid in result_map and val:
  545. result_map[tid].append(val)
  546. # 如果当前集合已经查到了内容,就不再查兜底集合 (除非结果仍为空)
  547. if any(result_map.values()):
  548. logger.info(f"从 Milvus 集合 {coll_name} 查得 {sum(len(v) for v in result_map.values())} 条内容分片")
  549. return result_map
  550. except Exception as e:
  551. logger.error(f"批量查询 Milvus 集合 {coll_name} 异常: {e}")
  552. continue
  553. return result_map
  554. async def get_project_progress(self, project_id: str) -> Dict[str, Any]:
  555. """获取外部标注项目的进度"""
  556. conn = get_db_connection()
  557. if not conn:
  558. return {"error": "数据库连接失败"}
  559. try:
  560. # 1. 查询 remote_project_id (task_id)
  561. cursor = conn.cursor()
  562. cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1", (project_id,))
  563. row = cursor.fetchone()
  564. cursor.close()
  565. if not row or not row['task_id']:
  566. return {"error": "未找到已推送的外部项目ID"}
  567. remote_project_id = row['task_id']
  568. # 2. 获取配置
  569. from app.core.config import config_handler
  570. api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
  571. progress_url = f"{api_base_url}/{remote_project_id}/progress"
  572. token = config_handler.get('external_api', 'admin_token', '')
  573. # 3. 发送请求
  574. async with httpx.AsyncClient(timeout=10.0) as client:
  575. headers = {"Authorization": f"Bearer {token}"}
  576. response = await client.get(progress_url, headers=headers)
  577. if response.status_code == 200:
  578. data = response.json()
  579. # 同步更新本地缓存的完成数量和总数
  580. completed_count = data.get('completed_tasks', 0)
  581. total_count = data.get('total_tasks', 0)
  582. try:
  583. conn_update = get_db_connection()
  584. if conn_update:
  585. cursor_update = conn_update.cursor()
  586. cursor_update.execute(
  587. """
  588. UPDATE t_task_management
  589. SET external_completed_count = %s, external_total_count = %s
  590. WHERE project_id = %s
  591. """,
  592. (completed_count, total_count, project_id)
  593. )
  594. conn_update.commit()
  595. cursor_update.close()
  596. conn_update.close()
  597. except Exception as ex:
  598. logger.warning(f"更新本地进度缓存失败: {ex}")
  599. return data
  600. else:
  601. logger.error(f"查询进度失败: {response.status_code} - {response.text}")
  602. return {"error": f"外部平台返回错误 ({response.status_code})"}
  603. except Exception as e:
  604. logger.exception(f"查询进度异常: {e}")
  605. return {"error": str(e)}
  606. finally:
  607. conn.close()
  608. async def export_labeled_data(self, project_id: str, export_format: str = 'json', completed_only: bool = True) -> Dict[str, Any]:
  609. """触发外部标注项目的数据导出"""
  610. conn = get_db_connection()
  611. if not conn:
  612. return {"error": "数据库连接失败"}
  613. try:
  614. # 1. 查询 remote_project_id (task_id)
  615. cursor = conn.cursor()
  616. cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1", (project_id,))
  617. row = cursor.fetchone()
  618. cursor.close()
  619. if not row or not row['task_id']:
  620. return {"error": "未找到已推送的外部项目ID"}
  621. remote_project_id = row['task_id']
  622. # 2. 获取配置
  623. from app.core.config import config_handler
  624. api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
  625. export_url = f"{api_base_url}/{remote_project_id}/export"
  626. token = config_handler.get('external_api', 'admin_token', '')
  627. # 3. 发送请求
  628. async with httpx.AsyncClient(timeout=30.0) as client:
  629. headers = {
  630. "Authorization": f"Bearer {token}",
  631. "Content-Type": "application/json"
  632. }
  633. payload = {
  634. "format": export_format,
  635. "completed_only": completed_only
  636. }
  637. response = await client.post(export_url, json=payload, headers=headers)
  638. if response.status_code in (200, 201):
  639. res_data = response.json()
  640. logger.info(f"外部平台导出响应数据类型: {type(res_data)}, 键名: {list(res_data.keys()) if isinstance(res_data, dict) else 'None'}")
  641. # 1. 统一获取下载地址并回写 (兼容 download_url 和 file_url)
  642. download_url = res_data.get('download_url') or res_data.get('file_url')
  643. if isinstance(res_data, dict) and download_url:
  644. try:
  645. cursor_update = conn.cursor()
  646. affected = cursor_update.execute(
  647. "UPDATE t_task_management SET file_url = %s WHERE project_id = %s",
  648. (download_url, project_id)
  649. )
  650. conn.commit()
  651. cursor_update.close()
  652. logger.info(f"导出时同步更新 file_url: {download_url}, 受影响行数: {affected}")
  653. except Exception as ex:
  654. logger.warning(f"导出同步回写 file_url 失败: {ex}")
  655. # 2. 同步回写 annotation_result 到 metadata
  656. # 情况 A: 接口直接返回了任务列表数据
  657. export_items = []
  658. if isinstance(res_data, dict) and 'data' in res_data and isinstance(res_data['data'], list):
  659. export_items = res_data['data']
  660. # 情况 B: 接口返回了文件链接,且格式为 JSON,尝试下载并解析以获取标注结果
  661. elif isinstance(res_data, dict) and download_url and res_data.get('format') == 'json':
  662. try:
  663. # 补全 URL 协议 (如果外部平台返回的是相对路径)
  664. full_download_url = download_url
  665. if not download_url.startswith('http'):
  666. from app.core.config import config_handler
  667. # 尝试从配置获取 base_url,如果没有则从 api_url 中提取
  668. base_url = config_handler.get('external_api', 'download_base_url', '')
  669. if not base_url:
  670. # 兜底:从 project_api_url 中提取域名部分
  671. api_base = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003')
  672. from urllib.parse import urlparse
  673. parsed = urlparse(api_base)
  674. base_url = f"{parsed.scheme}://{parsed.netloc}"
  675. full_download_url = f"{base_url.rstrip('/')}/{download_url.lstrip('/')}"
  676. logger.info(f"正在从导出链接获取详细标注数据以同步数据库: {full_download_url}")
  677. # 注意:这里需要带上 token
  678. file_res = await client.get(full_download_url, headers=headers)
  679. if file_res.status_code == 200:
  680. file_json = file_res.json()
  681. # 外部平台导出的 JSON 结构通常是 { "data": [...] } 或直接是 [...]
  682. if isinstance(file_json, dict) and 'data' in file_json:
  683. export_items = file_json['data']
  684. elif isinstance(file_json, list):
  685. export_items = file_json
  686. if export_items:
  687. logger.info(f"成功获取导出项,共 {len(export_items)} 条。")
  688. else:
  689. logger.warning("获取到的导出列表为空")
  690. except Exception as ex:
  691. logger.warning(f"从导出文件同步标注数据失败: {ex}")
  692. if export_items:
  693. updated_count = 0
  694. try:
  695. cursor_meta = conn.cursor()
  696. for ext_item in export_items:
  697. # 根据实际 JSON 结构提取数据
  698. # 1. 提取 original_id
  699. original_data = ext_item.get('original_data', {})
  700. meta = original_data.get('metadata', {}) if original_data else ext_item.get('metadata', {})
  701. # 增加兼容性:如果 metadata 里没有,尝试直接从 ext_item 找,或者从 external_id 提取
  702. original_id = meta.get('original_id')
  703. if not original_id:
  704. ext_id = ext_item.get('external_id', '')
  705. if '_' in ext_id: # 比如 "uuid_4" 这种结构,提取前面的 uuid
  706. original_id = ext_id.rsplit('_', 1)[0]
  707. else:
  708. original_id = ext_id
  709. # 2. 提取 annotation_result
  710. annotations = ext_item.get('annotations', [])
  711. if annotations and isinstance(annotations, list):
  712. annotation_res = annotations[0].get('result')
  713. else:
  714. annotation_res = ext_item.get('annotation_result')
  715. if original_id and annotation_res is not None:
  716. # 注意:这里需要根据 business_id 或 metadata 里的 original_id 来匹配
  717. # 样本中心 t_task_management 表的 business_id 存的是原始数据的唯一标识
  718. cursor_meta.execute(
  719. "SELECT id, metadata FROM t_task_management WHERE business_id = %s OR id = %s",
  720. (original_id, original_id)
  721. )
  722. row = cursor_meta.fetchone()
  723. if row:
  724. db_id = row['id']
  725. current_meta = json.loads(row['metadata']) if row['metadata'] else {}
  726. current_meta['annotation_result'] = annotation_res
  727. cursor_meta.execute(
  728. "UPDATE t_task_management SET metadata = %s WHERE id = %s",
  729. (json.dumps(current_meta, ensure_ascii=False), db_id)
  730. )
  731. updated_count += 1
  732. else:
  733. logger.debug(f"未在数据库中找到对应的任务: {original_id}")
  734. conn.commit()
  735. cursor_meta.close()
  736. logger.info(f"已从导出数据同步回写 {updated_count} 条任务的 annotation_result")
  737. except Exception as ex:
  738. logger.warning(f"同步回写 annotation_result 异常: {ex}")
  739. return res_data
  740. else:
  741. logger.error(f"导出数据失败: {response.status_code} - {response.text}")
  742. return {"error": f"外部平台返回错误 ({response.status_code})"}
  743. except Exception as e:
  744. logger.exception(f"导出数据异常: {e}")
  745. return {"error": str(e)}
  746. finally:
  747. conn.close()
  748. async def send_to_external_platform(self, project_id: str) -> Tuple[bool, str]:
  749. """将项目数据推送至外部标注平台 (单表化)"""
  750. # 1. 准备数据 (导出数据需要数据库连接)
  751. conn = get_db_connection()
  752. if not conn:
  753. return False, "数据库连接失败"
  754. try:
  755. logger.info(f"开始导出项目 {project_id} 数据...")
  756. payload = await self.export_project_data(project_id=project_id, conn=conn)
  757. if not payload:
  758. return False, "项目导出失败,请检查项目ID是否正确"
  759. if not payload.get('data'):
  760. return False, f"项目数据为空 (查询到0条有效任务),无法推送"
  761. except Exception as e:
  762. logger.exception(f"导出项目数据异常: {e}")
  763. return False, f"导出异常: {str(e)}"
  764. finally:
  765. # 及时释放连接,防止在 HTTP 请求期间占用
  766. conn.close()
  767. # 2. 获取配置
  768. try:
  769. from app.core.config import config_handler
  770. api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
  771. api_url = f"{api_base_url}/init"
  772. token = config_handler.get('external_api', 'admin_token', '')
  773. # 3. 发送请求 (不持有数据库连接)
  774. async with httpx.AsyncClient(timeout=120.0) as client: # 增加超时时间到 120s
  775. headers = {
  776. "Authorization": f"Bearer {token}",
  777. "Content-Type": "application/json"
  778. }
  779. logger.info(f"正在推送项目 {project_id} 至外部平台: {api_url}, 数据条数: {len(payload['data'])}")
  780. response = await client.post(api_url, json=payload, headers=headers)
  781. if response.status_code in (200, 201):
  782. res_data = response.json()
  783. logger.info(f"外部平台推送成功响应: {res_data}")
  784. remote_project_id = res_data.get('project_id')
  785. download_url = res_data.get('download_url') or res_data.get('file_url')
  786. # 4. 回写外部项目 ID 和下载地址 (重新获取连接)
  787. if remote_project_id:
  788. conn = get_db_connection()
  789. if conn:
  790. try:
  791. cursor = conn.cursor()
  792. if download_url:
  793. affected = cursor.execute(
  794. "UPDATE t_task_management SET task_id = %s, file_url = %s WHERE project_id = %s",
  795. (remote_project_id, download_url, project_id)
  796. )
  797. logger.info(f"已回写 task_id: {remote_project_id} 和 file_url: {download_url}, 受影响行数: {affected}")
  798. else:
  799. affected = cursor.execute(
  800. "UPDATE t_task_management SET task_id = %s WHERE project_id = %s",
  801. (remote_project_id, project_id)
  802. )
  803. logger.info(f"仅回写 task_id: {remote_project_id}, 受影响行数: {affected}")
  804. conn.commit()
  805. finally:
  806. cursor.close()
  807. conn.close()
  808. return True, f"推送成功!外部项目ID: {remote_project_id or '未知'}"
  809. else:
  810. error_msg = response.text
  811. logger.error(f"推送失败: {response.status_code} - {error_msg}")
  812. return False, f"外部平台返回错误 ({response.status_code})"
  813. except Exception as e:
  814. logger.exception(f"推送至外部平台异常: {e}")
  815. return False, f"推送异常: {str(e)}"
# Module-level TaskService instance, created at import time.
task_service = TaskService()