# task_service.py
  1. import logging
  2. import json
  3. import httpx
  4. from datetime import datetime, date
  5. from typing import List, Dict, Any, Tuple, Optional
  6. from app.base.async_mysql_connection import get_db_connection
  7. from app.base.minio_connection import get_minio_manager
  8. logger = logging.getLogger(__name__)
  9. def generate_tag_colors(tags: List[str]) -> Dict[str, str]:
  10. """
  11. 为标签生成不重复的颜色
  12. Args:
  13. tags: 标签列表
  14. Returns:
  15. 标签到颜色值的映射字典
  16. """
  17. # 预定义的一组高对比度、视觉区分度高的颜色
  18. pre_defined_colors = [
  19. "#FF5733", "#33FF57", "#3357FF", "#FF33A1", "#33FFF5",
  20. "#F5FF33", "#FF8C33", "#8C33FF", "#FF3333", "#33FF8C",
  21. "#FF338C", "#8CFF33", "#338CFF", "#FF5733", "#57FF33",
  22. "#3357FF", "#FF33FF", "#33FFFF", "#FFFF33", "#FF8000",
  23. "#80FF00", "#0080FF", "#FF0080", "#00FF80", "#8000FF",
  24. "#FF0000", "#00FF00", "#0000FF", "#FFFF00", "#FF00FF",
  25. "#00FFFF", "#FFA500", "#A52A2A", "#800080", "#008080",
  26. "#000080", "#800000", "#008000", "#000000"
  27. ]
  28. color_map = {}
  29. for i, tag in enumerate(tags):
  30. # 使用预定义颜色,如果标签数量超过预定义颜色,则通过算法生成
  31. if i < len(pre_defined_colors):
  32. color_map[tag] = pre_defined_colors[i]
  33. else:
  34. # 使用 HSL 色相环生成颜色,确保颜色不重复
  35. hue = (i * 137.508) % 360 # 黄金角度,确保颜色分布均匀
  36. # 高饱和度和亮度,保证颜色鲜艳
  37. color_map[tag] = f"hsl({hue:.0f}, 70%, 50%)"
  38. return color_map
class TaskService:
    """Task-management service backed by the single t_task_management table."""
    # Class-level flag: the DDL maintenance in _ensure_table_schema runs at
    # most once per process.
    _schema_verified = False
    def __init__(self):
        # MinIO manager, used to build absolute image URLs during export.
        self.minio_manager = get_minio_manager()
    async def _ensure_table_schema(self, cursor, conn):
        """Ensure t_task_management has the expected columns and indexes (DDL).

        Runs at most once per process (guarded by TaskService._schema_verified).
        Failures are logged and rolled back but never raised, so schema
        maintenance cannot block the calling operation.

        NOTE(review): despite ``async def``, no cursor call here is awaited —
        presumably the connection object is a blocking client; confirm.
        """
        if TaskService._schema_verified:
            return
        try:
            # 1. Add any column that older deployments may lack (one
            #    SHOW COLUMNS probe per column).
            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'tag'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN tag json NULL COMMENT '标签' AFTER type")
            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'metadata'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN metadata json NULL COMMENT '业务元数据' AFTER tag")
            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'project_name'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN project_name varchar(255) NULL COMMENT '项目显示名称' AFTER project_id")
            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'external_completed_count'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN external_completed_count int NULL DEFAULT 0 COMMENT '外部平台完成数量' AFTER annotation_status")
            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'external_total_count'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN external_total_count int NULL DEFAULT 0 COMMENT '外部平台总任务数量' AFTER external_completed_count")
            conn.commit()
            # 2. Index migration: drop any old single-column unique index on
            #    business_id and replace it with the composite
            #    (business_id, project_id) unique key.
            cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'file_url'")
            if not cursor.fetchone():
                cursor.execute("ALTER TABLE t_task_management ADD COLUMN file_url VARCHAR(512) DEFAULT NULL AFTER metadata")
                logger.info("Added column file_url to t_task_management")
            cursor.execute("SHOW INDEX FROM t_task_management WHERE Column_name = 'business_id'")
            indexes = cursor.fetchall()
            for idx in indexes:
                # Only unique, non-primary indexes whose first column is
                # business_id are candidates for removal.
                if not idx['Non_unique'] and idx['Key_name'] != 'PRIMARY' and idx['Seq_in_index'] == 1:
                    # Drop only if the index covers exactly one column
                    # (i.e. it is the legacy single-column unique key).
                    cursor.execute(f"SHOW INDEX FROM t_task_management WHERE Key_name = '{idx['Key_name']}'")
                    if len(cursor.fetchall()) == 1:
                        cursor.execute(f"DROP INDEX {idx['Key_name']} ON t_task_management")
                        logger.info(f"Dropped old unique index: {idx['Key_name']}")
            cursor.execute("SHOW INDEX FROM t_task_management WHERE Key_name = 'uk_business_project'")
            if not cursor.fetchone():
                cursor.execute("CREATE UNIQUE INDEX uk_business_project ON t_task_management (business_id, project_id)")
                logger.info("Created new composite unique index: uk_business_project")
            cursor.execute("SHOW INDEX FROM t_task_management WHERE Key_name = 'idx_project_id'")
            if not cursor.fetchone():
                cursor.execute("CREATE INDEX idx_project_id ON t_task_management (project_id)")
                logger.info("Created index: idx_project_id")
            conn.commit()
            TaskService._schema_verified = True
        except Exception as e:
            logger.warning(f"表结构维护失败: {e}")
            conn.rollback()
  92. async def add_task(self, business_id: str, task_type: str, task_id: str = None, project_id: str = None, project_name: str = None, tag: str = None, metadata: str = None) -> Tuple[bool, str, Optional[int]]:
  93. """添加或更新任务记录 (适配单表结构)"""
  94. conn = get_db_connection()
  95. if not conn:
  96. return False, "数据库连接失败", None
  97. cursor = conn.cursor()
  98. try:
  99. # 确保表结构(仅在第一次调用时执行)
  100. await self._ensure_table_schema(cursor, conn)
  101. # 执行插入/更新
  102. sql = """
  103. INSERT INTO t_task_management (business_id, task_id, project_id, project_name, type, annotation_status, tag, metadata)
  104. VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
  105. ON DUPLICATE KEY UPDATE
  106. task_id = IFNULL(VALUES(task_id), task_id),
  107. project_id = IFNULL(VALUES(project_id), project_id),
  108. project_name = IFNULL(VALUES(project_name), project_name),
  109. annotation_status = IFNULL(VALUES(annotation_status), annotation_status),
  110. tag = IFNULL(VALUES(tag), tag),
  111. metadata = IFNULL(VALUES(metadata), metadata)
  112. """
  113. cursor.execute(sql, (business_id, task_id, project_id, project_name, task_type, 'pending', tag, metadata))
  114. record_id = cursor.lastrowid
  115. conn.commit()
  116. return True, "成功", record_id
  117. except Exception as e:
  118. conn.rollback()
  119. logger.exception(f"添加任务失败: {e}")
  120. return False, str(e), None
  121. finally:
  122. cursor.close()
  123. conn.close()
  124. async def delete_task_by_id(self, id: int) -> Tuple[bool, str]:
  125. """根据主键 id 删除任务记录"""
  126. conn = get_db_connection()
  127. if not conn:
  128. return False, "数据库连接失败"
  129. cursor = conn.cursor()
  130. try:
  131. sql = "DELETE FROM t_task_management WHERE id = %s"
  132. cursor.execute(sql, (id,))
  133. conn.commit()
  134. return True, "成功"
  135. except Exception as e:
  136. conn.rollback()
  137. logger.exception(f"删除任务失败: {e}")
  138. return False, str(e)
  139. finally:
  140. cursor.close()
  141. conn.close()
  142. def _serialize_datetime(self, obj: Any) -> Any:
  143. """递归遍历对象,将 datetime 转换为字符串"""
  144. if isinstance(obj, dict):
  145. return {k: self._serialize_datetime(v) for k, v in obj.items()}
  146. elif isinstance(obj, list):
  147. return [self._serialize_datetime(i) for i in obj]
  148. elif isinstance(obj, (datetime, date)):
  149. return obj.strftime('%Y-%m-%d %H:%M:%S') if isinstance(obj, datetime) else obj.strftime('%Y-%m-%d')
  150. return obj
  151. async def get_task_list(self, task_type: str) -> List[Dict[str, Any]]:
  152. """获取项目列表 (按 project_id 聚合)"""
  153. conn = get_db_connection()
  154. if not conn or task_type not in ['data', 'image']:
  155. return []
  156. cursor = conn.cursor()
  157. try:
  158. # 确保表结构
  159. await self._ensure_table_schema(cursor, conn)
  160. # 修改聚合逻辑:返回 project_id (UUID) 和 project_name (文字)
  161. sql = """
  162. SELECT
  163. project_id,
  164. MAX(project_name) as project_name,
  165. MAX(task_id) as task_id,
  166. MAX(tag) as tag,
  167. MAX(id) as sort_id,
  168. COALESCE(NULLIF(MAX(external_total_count), 0), COUNT(*)) as file_count,
  169. COALESCE(
  170. MAX(external_completed_count),
  171. SUM(CASE WHEN annotation_status = 'completed' THEN 1 ELSE 0 END),
  172. 0
  173. ) as completed_count
  174. FROM t_task_management
  175. WHERE type = %s
  176. GROUP BY project_id
  177. ORDER BY sort_id DESC
  178. """
  179. cursor.execute(sql, (task_type,))
  180. rows = cursor.fetchall()
  181. # 兼容旧数据:如果 project_name 为空,则用 project_id 代替
  182. for row in rows:
  183. if not row['project_name']:
  184. row['project_name'] = row['project_id']
  185. return self._serialize_datetime(rows)
  186. except Exception as e:
  187. logger.error(f"获取任务列表失败: {e}")
  188. return []
  189. finally:
  190. cursor.close()
  191. conn.close()
  192. async def get_project_details(self, project_id: str, task_type: str) -> List[Dict[str, Any]]:
  193. """获取项目详情 (按 project_id 查询)"""
  194. conn = get_db_connection()
  195. if not conn:
  196. return []
  197. cursor = conn.cursor()
  198. try:
  199. # 确保表结构
  200. await self._ensure_table_schema(cursor, conn)
  201. # 修改查询逻辑:获取 project_name 并在结果中包含它
  202. sql = """
  203. SELECT
  204. id, business_id, task_id, project_id, project_name,
  205. type, annotation_status, tag, metadata
  206. FROM t_task_management
  207. WHERE project_id = %s AND type = %s
  208. """
  209. cursor.execute(sql, (project_id, task_type))
  210. rows = cursor.fetchall()
  211. # 处理 tag 和 metadata 的 JSON 解析
  212. for row in rows:
  213. if row.get('tag'):
  214. try: row['tag'] = json.loads(row['tag']) if isinstance(row['tag'], str) else row['tag']
  215. except: pass
  216. if row.get('metadata'):
  217. try:
  218. meta = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata']
  219. row['metadata'] = meta
  220. # 提取名称供前端显示
  221. if row['type'] == 'data':
  222. row['name'] = meta.get('title') or meta.get('filename') or row['business_id']
  223. elif row['type'] == 'image':
  224. row['name'] = meta.get('image_name') or row['business_id']
  225. else:
  226. row['name'] = row['business_id']
  227. except:
  228. row['name'] = row['business_id']
  229. else:
  230. row['name'] = row['business_id']
  231. return self._serialize_datetime(rows)
  232. except Exception as e:
  233. logger.error(f"获取项目详情失败: {e}")
  234. return []
  235. finally:
  236. cursor.close()
  237. conn.close()
  238. async def delete_task(self, business_id: str) -> Tuple[bool, str]:
  239. """删除任务记录"""
  240. conn = get_db_connection()
  241. if not conn:
  242. return False, "数据库连接失败"
  243. cursor = conn.cursor()
  244. try:
  245. sql = "DELETE FROM t_task_management WHERE business_id = %s"
  246. cursor.execute(sql, (business_id,))
  247. conn.commit()
  248. return True, "删除成功"
  249. except Exception as e:
  250. logger.exception(f"删除任务记录失败: {e}")
  251. conn.rollback()
  252. return False, f"删除失败: {str(e)}"
  253. finally:
  254. cursor.close()
  255. conn.close()
  256. async def create_anno_project(self, data: Dict[str, Any]) -> Tuple[bool, str]:
  257. """创建标注项目并同步任务数据"""
  258. project_name = data.get('name')
  259. if not project_name:
  260. return False, "项目名称不能为空"
  261. # 0. 统一使用 UUID 方案获取或生成 project_id
  262. conn = get_db_connection()
  263. if not conn:
  264. return False, "数据库连接失败"
  265. cursor = conn.cursor()
  266. try:
  267. project_id = None
  268. cursor.execute("SELECT project_id FROM t_task_management WHERE project_name = %s LIMIT 1", (project_name,))
  269. existing_project = cursor.fetchone()
  270. if existing_project:
  271. project_id = existing_project['project_id']
  272. else:
  273. import uuid
  274. project_id = str(uuid.uuid4())
  275. task_type = data.get('task_type', 'data')
  276. # 映射回内部类型
  277. internal_type_map = {
  278. 'text_classification': 'data',
  279. 'image_classification': 'image'
  280. }
  281. internal_task_type = internal_type_map.get(task_type, task_type)
  282. tasks_data = data.get('data', [])
  283. if not tasks_data:
  284. return False, "任务数据不能为空"
  285. # 提取全局标签名列表
  286. global_tags = []
  287. if data.get('tags'):
  288. global_tags = [t['tag'] for t in data['tags'] if 'tag' in t]
  289. # 批量写入任务
  290. import json
  291. tag_str = json.dumps(global_tags, ensure_ascii=False) if global_tags else None
  292. for item in tasks_data:
  293. business_id = item.get('id')
  294. if not business_id: continue
  295. # 检查是否已存在 (使用联合主键逻辑)
  296. cursor.execute(
  297. "SELECT id FROM t_task_management WHERE business_id = %s AND project_id = %s",
  298. (business_id, project_id)
  299. )
  300. if cursor.fetchone():
  301. continue
  302. metadata_str = json.dumps(item.get('metadata', {}), ensure_ascii=False)
  303. sql = """
  304. INSERT INTO t_task_management
  305. (business_id, type, project_id, project_name, tag, metadata, annotation_status)
  306. VALUES (%s, %s, %s, %s, %s, %s, 'pending')
  307. """
  308. cursor.execute(sql, (business_id, internal_task_type, project_id, project_name, tag_str, metadata_str))
  309. conn.commit()
  310. # 自动推送至外部平台
  311. success, msg = await self.send_to_external_platform(project_id)
  312. if success:
  313. return True, project_id
  314. else:
  315. return False, f"任务已保存但推送失败: {msg}"
  316. except Exception as e:
  317. logger.exception(f"创建项目失败: {e}")
  318. conn.rollback()
  319. return False, str(e)
  320. finally:
  321. cursor.close()
  322. conn.close()
    async def get_project_progress(self, project_id: str) -> Dict[str, Any]:
        """Local progress statistics for a project (single-table version).

        NOTE(review): DEAD CODE — a second ``get_project_progress`` is defined
        later in this class and silently overrides this one at class-body
        evaluation time. Remove or rename one of the two definitions.
        """
        conn = get_db_connection()
        if not conn:
            return {}
        cursor = conn.cursor()
        try:
            # Count rows per annotation_status.
            sql = """
                SELECT
                    annotation_status,
                    COUNT(*) as count
                FROM t_task_management
                WHERE project_id = %s
                GROUP BY annotation_status
            """
            cursor.execute(sql, (project_id,))
            stats = cursor.fetchall()
            if not stats:
                return {}
            total = sum(s['count'] for s in stats)
            completed = sum(s['count'] for s in stats if s['annotation_status'] == 'completed')
            return {
                "project_id": project_id,
                "total": total,
                "completed": completed,
                "progress": f"{round(completed/total*100, 2)}%" if total > 0 else "0%",
                "details": stats
            }
        except Exception as e:
            logger.error(f"查询进度失败: {e}")
            return {}
        finally:
            cursor.close()
            conn.close()
    def _get_milvus_content(self, task_id: str, kb_info: Dict[str, Any]) -> List[str]:
        """Fetch the chunk texts of one document from Milvus.

        Args:
            task_id: Document id matched against the Milvus ``document_id``
                field.
            kb_info: Knowledge-base row providing collection_name_children /
                collection_name_parent.

        Returns:
            List of chunk strings; empty list when nothing is found.
        """
        if not kb_info:
            return []
        # Prefer the child-chunk collection; the parent collection is the
        # fallback.
        collections = [c for c in [kb_info.get('collection_name_children'), kb_info.get('collection_name_parent')] if c]
        if not collections:
            return []
        from app.services.milvus_service import milvus_service
        contents = []
        # Per-instance cache of probed collection field names, to avoid
        # repeated describe_collection round-trips.
        if not hasattr(self, '_collection_schema_cache'):
            self._collection_schema_cache = {}
        for coll_name in collections:
            try:
                if not milvus_service.client.has_collection(coll_name):
                    continue
                # Probe field names once per collection: the content field is
                # either 'text' or 'content' depending on the schema version.
                if coll_name not in self._collection_schema_cache:
                    schema = milvus_service.client.describe_collection(coll_name)
                    field_names = [f['name'] for f in schema.get('fields', [])]
                    self._collection_schema_cache[coll_name] = {
                        "id": "document_id",  # id field name is fixed
                        "content": "text" if "text" in field_names else "content"
                    }
                fields = self._collection_schema_cache[coll_name]
                id_field = fields["id"]
                content_field = fields["content"]
                # NOTE(review): task_id is interpolated into the filter
                # expression — assumed to be a trusted internal id; confirm.
                res = milvus_service.client.query(
                    collection_name=coll_name,
                    filter=f'{id_field} == "{task_id}"',
                    output_fields=[content_field]
                )
                if res:
                    for s in res:
                        val = s.get(content_field)
                        if val: contents.append(val)
                # Stop at the first collection that yielded content.
                if contents:
                    return contents
            except Exception as e:
                logger.error(f"查询 Milvus 集合 {coll_name} 异常: {e}")
                continue
        return []
    async def export_project_data(self, project_id: str, conn=None) -> Dict[str, Any]:
        """Export a project's tasks in the external annotation-platform format.

        Args:
            project_id: Internal project UUID.
            conn: Optional already-open DB connection; when omitted one is
                acquired and closed here.

        Returns:
            Dict with name/description/task_type/data/external_id/tags, or {}
            when the project has no rows or an error occurs.
        """
        should_close = False
        if not conn:
            conn = get_db_connection()
            should_close = True
        if not conn:
            return {}
        cursor = conn.cursor()
        try:
            # 1. Load the project's task rows.
            sql_tasks = """
                SELECT business_id as id, type, task_id, project_name, tag, metadata
                FROM t_task_management
                WHERE project_id = %s
            """
            cursor.execute(sql_tasks, (project_id,))
            rows = cursor.fetchall()
            if not rows:
                return {}
            # 2. Project-level info taken from the first row.
            first_row = rows[0]
            internal_task_type = first_row['type']
            project_name = first_row.get('project_name') or project_id
            remote_project_id = first_row.get('task_id') or project_id
            # Internal -> external task-type mapping.
            type_map = {'data': 'text_classification', 'image': 'image_classification'}
            external_task_type = type_map.get(internal_task_type, internal_task_type)
            # 3. Build task contents.
            final_tasks = []
            all_project_tags = set()
            # Batched Milvus lookup for 'data' projects.
            milvus_data_map = {}
            missing_milvus_ids = []
            if internal_task_type == 'data':
                all_task_ids = [r['id'] for r in rows]
                sql_kb = """
                    SELECT kb.collection_name_parent, kb.collection_name_children
                    FROM t_samp_document_main d
                    LEFT JOIN t_samp_knowledge_base kb ON d.kb_id = kb.id
                    WHERE d.id = %s
                """
                # Only the first id is probed — assumes every document of a
                # project lives in the same knowledge base; TODO confirm.
                cursor.execute(sql_kb, (all_task_ids[0],))
                kb_info = cursor.fetchone()
                if kb_info:
                    milvus_data_map = self._get_milvus_content_batch(all_task_ids, kb_info)
                # Ids whose chunks were not found in Milvus.
                missing_milvus_ids = [tid for tid in all_task_ids if not milvus_data_map.get(tid)]
                # Batched title lookup as fallback content for missing ids.
                title_map = {}
                if missing_milvus_ids:
                    placeholders = ', '.join(['%s'] * len(missing_milvus_ids))
                    cursor.execute(f"SELECT id, title FROM t_samp_document_main WHERE id IN ({placeholders})", missing_milvus_ids)
                    title_map = {r['id']: r['title'] for r in cursor.fetchall()}
            # Batched MinIO URL lookup for 'image' projects.
            image_data_map = {}
            if internal_task_type == 'image':
                all_image_ids = [r['id'] for r in rows]
                placeholders = ', '.join(['%s'] * len(all_image_ids))
                cursor.execute(f"SELECT id, image_url FROM t_image_info WHERE id IN ({placeholders})", all_image_ids)
                for img_row in cursor.fetchall():
                    img_url = img_row['image_url']
                    # Relative object keys are expanded to absolute URLs.
                    if img_url and not img_url.startswith('http'):
                        img_url = self.minio_manager.get_full_url(img_url)
                    image_data_map[img_row['id']] = [img_url] if img_url else []
            # Phase 1: collect per-task data (colors assigned later, once the
            # full project tag set is known).
            tasks_data = []
            for item in rows:
                task_id = item['id']
                # Decode the tag list and fold into the project-wide tag set.
                doc_tags = []
                if item.get('tag'):
                    try:
                        doc_tags = json.loads(item['tag']) if isinstance(item['tag'], str) else item['tag']
                        if doc_tags:
                            for t in doc_tags: all_project_tags.add(t)
                    except: pass
                # Decode DB metadata; serialize dates up front so no global
                # recursive pass is needed at the end.
                db_metadata = {}
                if item.get('metadata'):
                    try:
                        db_metadata = json.loads(item['metadata']) if isinstance(item['metadata'], str) else item['metadata']
                        if db_metadata:
                            db_metadata = self._serialize_datetime(db_metadata)
                    except: pass
                # Resolve the task's content chunks.
                task_contents = []
                if internal_task_type == 'data':
                    task_contents = milvus_data_map.get(task_id, [])
                    if not task_contents:
                        title = title_map.get(task_id)
                        if title: task_contents = [title]
                elif internal_task_type == 'image':
                    task_contents = image_data_map.get(task_id, [])
                # Stash the task data (tag colors are added in phase 2).
                tasks_data.append({
                    'task_id': task_id,
                    'doc_tags': doc_tags,
                    'db_metadata': db_metadata,
                    'task_contents': task_contents
                })
            # Phase 2: assign tag colors and build the final task list.
            sorted_tags = sorted(list(all_project_tags))
            tag_color_map = generate_tag_colors(sorted_tags)
            for task_data in tasks_data:
                task_id = task_data['task_id']
                doc_tags = task_data['doc_tags']
                db_metadata = task_data['db_metadata']
                task_contents = task_data['task_contents']
                for idx, content in enumerate(task_contents):
                    if not content: continue
                    # Merge DB metadata with the dynamic chunk identifiers.
                    task_metadata = {
                        "original_id": task_id,
                        "chunk_index": idx
                    }
                    if db_metadata:
                        task_metadata.update(db_metadata)
                    # Attach a color to every tag.
                    if doc_tags:
                        task_metadata['tags'] = [
                            {"tag": tag, "color": tag_color_map.get(tag, "#999999")}
                            for tag in doc_tags
                        ]
                    task_item = {
                        "id": f"{task_id}_{idx}" if len(task_contents) > 1 else task_id,
                        "content": content,
                        "metadata": task_metadata
                    }
                    # Carry over any previously stored annotation result.
                    if db_metadata and 'annotation_result' in db_metadata:
                        task_item['annotation_result'] = db_metadata['annotation_result']
                    final_tasks.append(task_item)
            # Project-level tag list (with colors).
            tags_result = []
            for tag in sorted_tags:
                tag_obj = {"tag": tag}
                if tag in tag_color_map:
                    tag_obj["color"] = tag_color_map[tag]
                tags_result.append(tag_obj)
            # Dates were serialized locally above; no global recursion here.
            return {
                "name": project_name,
                "description": "",
                "task_type": external_task_type,
                "data": final_tasks,
                "external_id": remote_project_id,
                "tags": tags_result
            }
        except Exception as e:
            logger.exception(f"导出项目数据异常: {e}")
            return {}
        finally:
            cursor.close()
            if should_close:
                conn.close()
    def _get_milvus_content_batch(self, task_ids: List[str], kb_info: Dict[str, Any]) -> Dict[str, List[str]]:
        """Batch-fetch chunk texts for many documents from Milvus.

        Args:
            task_ids: Document ids to look up.
            kb_info: Knowledge-base row providing collection_name_children /
                collection_name_parent.

        Returns:
            Mapping of task id -> list of chunk strings (empty list when the
            id was not found in any collection).
        """
        if not kb_info or not task_ids:
            return {}
        # Prefer the child-chunk collection; parent collection is fallback.
        collections = [c for c in [kb_info.get('collection_name_children'), kb_info.get('collection_name_parent')] if c]
        if not collections:
            return {}
        from app.services.milvus_service import milvus_service
        result_map = {tid: [] for tid in task_ids}
        # Per-instance schema probe cache (shared with _get_milvus_content).
        if not hasattr(self, '_collection_schema_cache'):
            self._collection_schema_cache = {}
        for coll_name in collections:
            try:
                if not milvus_service.client.has_collection(coll_name):
                    continue
                if coll_name not in self._collection_schema_cache:
                    schema = milvus_service.client.describe_collection(coll_name)
                    field_names = [f['name'] for f in schema.get('fields', [])]
                    self._collection_schema_cache[coll_name] = {
                        "id": "document_id",
                        "content": "text" if "text" in field_names else "content"
                    }
                fields = self._collection_schema_cache[coll_name]
                id_field = fields["id"]
                content_field = fields["content"]
                # Query with an `in [...]` filter, chunked so the filter
                # expression stays bounded for large id sets.
                CHUNK_SIZE = 100
                for i in range(0, len(task_ids), CHUNK_SIZE):
                    chunk_ids = task_ids[i:i + CHUNK_SIZE]
                    id_list_str = ", ".join([f'"{tid}"' for tid in chunk_ids])
                    logger.info(f"正在从 Milvus 集合 {coll_name} 查询分片内容 ({i}/{len(task_ids)})...")
                    res = milvus_service.client.query(
                        collection_name=coll_name,
                        filter=f'{id_field} in [{id_list_str}]',
                        output_fields=[id_field, content_field]
                    )
                    if res:
                        for s in res:
                            tid = s.get(id_field)
                            val = s.get(content_field)
                            if tid in result_map and val:
                                result_map[tid].append(val)
                # If this collection yielded ANY content, skip the fallback
                # collection entirely (even for ids that stayed empty).
                if any(result_map.values()):
                    logger.info(f"从 Milvus 集合 {coll_name} 查得 {sum(len(v) for v in result_map.values())} 条内容分片")
                    return result_map
            except Exception as e:
                logger.error(f"批量查询 Milvus 集合 {coll_name} 异常: {e}")
                continue
        return result_map
  611. async def get_project_progress(self, project_id: str) -> Dict[str, Any]:
  612. """获取外部标注项目的进度"""
  613. conn = get_db_connection()
  614. if not conn:
  615. return {"error": "数据库连接失败"}
  616. try:
  617. # 1. 查询 remote_project_id (task_id)
  618. cursor = conn.cursor()
  619. cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1", (project_id,))
  620. row = cursor.fetchone()
  621. cursor.close()
  622. if not row or not row['task_id']:
  623. return {"error": "未找到已推送的外部项目ID"}
  624. remote_project_id = row['task_id']
  625. # 2. 获取配置
  626. from app.core.config import config_handler
  627. api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
  628. progress_url = f"{api_base_url}/{remote_project_id}/progress"
  629. token = config_handler.get('external_api', 'admin_token', '')
  630. # 3. 发送请求
  631. async with httpx.AsyncClient(timeout=10.0) as client:
  632. headers = {"Authorization": f"Bearer {token}"}
  633. response = await client.get(progress_url, headers=headers)
  634. if response.status_code == 200:
  635. data = response.json()
  636. # 同步更新本地缓存的完成数量和总数
  637. completed_count = data.get('completed_tasks', 0)
  638. total_count = data.get('total_tasks', 0)
  639. try:
  640. conn_update = get_db_connection()
  641. if conn_update:
  642. cursor_update = conn_update.cursor()
  643. cursor_update.execute(
  644. """
  645. UPDATE t_task_management
  646. SET external_completed_count = %s, external_total_count = %s
  647. WHERE project_id = %s
  648. """,
  649. (completed_count, total_count, project_id)
  650. )
  651. conn_update.commit()
  652. cursor_update.close()
  653. conn_update.close()
  654. except Exception as ex:
  655. logger.warning(f"更新本地进度缓存失败: {ex}")
  656. return data
  657. else:
  658. logger.error(f"查询进度失败: {response.status_code} - {response.text}")
  659. return {"error": f"外部平台返回错误 ({response.status_code})"}
  660. except Exception as e:
  661. logger.exception(f"查询进度异常: {e}")
  662. return {"error": str(e)}
  663. finally:
  664. conn.close()
  665. async def export_labeled_data(self, project_id: str, export_format: str = 'json', completed_only: bool = True) -> Dict[str, Any]:
  666. """触发外部标注项目的数据导出"""
  667. conn = get_db_connection()
  668. if not conn:
  669. return {"error": "数据库连接失败"}
  670. try:
  671. # 1. 查询 remote_project_id (task_id)
  672. cursor = conn.cursor()
  673. cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1", (project_id,))
  674. row = cursor.fetchone()
  675. cursor.close()
  676. if not row or not row['task_id']:
  677. return {"error": "未找到已推送的外部项目ID"}
  678. remote_project_id = row['task_id']
  679. # 2. 获取配置
  680. from app.core.config import config_handler
  681. api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
  682. export_url = f"{api_base_url}/{remote_project_id}/export"
  683. token = config_handler.get('external_api', 'admin_token', '')
  684. # 3. 发送请求
  685. async with httpx.AsyncClient(timeout=30.0) as client:
  686. headers = {
  687. "Authorization": f"Bearer {token}",
  688. "Content-Type": "application/json"
  689. }
  690. payload = {
  691. "format": export_format,
  692. "completed_only": completed_only
  693. }
  694. response = await client.post(export_url, json=payload, headers=headers)
  695. if response.status_code in (200, 201):
  696. res_data = response.json()
  697. logger.info(f"外部平台导出响应数据类型: {type(res_data)}, 键名: {list(res_data.keys()) if isinstance(res_data, dict) else 'None'}")
  698. # 1. 统一获取下载地址并回写 (兼容 download_url 和 file_url)
  699. download_url = res_data.get('download_url') or res_data.get('file_url')
  700. if isinstance(res_data, dict) and download_url:
  701. try:
  702. cursor_update = conn.cursor()
  703. affected = cursor_update.execute(
  704. "UPDATE t_task_management SET file_url = %s WHERE project_id = %s",
  705. (download_url, project_id)
  706. )
  707. conn.commit()
  708. cursor_update.close()
  709. logger.info(f"导出时同步更新 file_url: {download_url}, 受影响行数: {affected}")
  710. except Exception as ex:
  711. logger.warning(f"导出同步回写 file_url 失败: {ex}")
  712. # 2. 同步回写 annotation_result 到 metadata
  713. # 情况 A: 接口直接返回了任务列表数据
  714. export_items = []
  715. if isinstance(res_data, dict) and 'data' in res_data and isinstance(res_data['data'], list):
  716. export_items = res_data['data']
  717. # 情况 B: 接口返回了文件链接,且格式为 JSON,尝试下载并解析以获取标注结果
  718. elif isinstance(res_data, dict) and download_url and res_data.get('format') == 'json':
  719. try:
  720. # 补全 URL 协议 (如果外部平台返回的是相对路径)
  721. full_download_url = download_url
  722. if not download_url.startswith('http'):
  723. from app.core.config import config_handler
  724. # 尝试从配置获取 base_url,如果没有则从 api_url 中提取
  725. base_url = config_handler.get('external_api', 'download_base_url', '')
  726. if not base_url:
  727. # 兜底:从 project_api_url 中提取域名部分
  728. api_base = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003')
  729. from urllib.parse import urlparse
  730. parsed = urlparse(api_base)
  731. base_url = f"{parsed.scheme}://{parsed.netloc}"
  732. full_download_url = f"{base_url.rstrip('/')}/{download_url.lstrip('/')}"
  733. logger.info(f"正在从导出链接获取详细标注数据以同步数据库: {full_download_url}")
  734. # 注意:这里需要带上 token
  735. file_res = await client.get(full_download_url, headers=headers)
  736. if file_res.status_code == 200:
  737. file_json = file_res.json()
  738. # 外部平台导出的 JSON 结构通常是 { "data": [...] } 或直接是 [...]
  739. if isinstance(file_json, dict) and 'data' in file_json:
  740. export_items = file_json['data']
  741. elif isinstance(file_json, list):
  742. export_items = file_json
  743. if export_items:
  744. logger.info(f"成功获取导出项,共 {len(export_items)} 条。")
  745. else:
  746. logger.warning("获取到的导出列表为空")
  747. except Exception as ex:
  748. logger.warning(f"从导出文件同步标注数据失败: {ex}")
  749. if export_items:
  750. updated_count = 0
  751. try:
  752. cursor_meta = conn.cursor()
  753. for ext_item in export_items:
  754. # 根据实际 JSON 结构提取数据
  755. # 1. 提取 original_id
  756. original_data = ext_item.get('original_data', {})
  757. meta = original_data.get('metadata', {}) if original_data else ext_item.get('metadata', {})
  758. # 增加兼容性:如果 metadata 里没有,尝试直接从 ext_item 找,或者从 external_id 提取
  759. original_id = meta.get('original_id')
  760. if not original_id:
  761. ext_id = ext_item.get('external_id', '')
  762. if '_' in ext_id: # 比如 "uuid_4" 这种结构,提取前面的 uuid
  763. original_id = ext_id.rsplit('_', 1)[0]
  764. else:
  765. original_id = ext_id
  766. # 2. 提取 annotation_result
  767. annotations = ext_item.get('annotations', [])
  768. if annotations and isinstance(annotations, list):
  769. annotation_res = annotations[0].get('result')
  770. else:
  771. annotation_res = ext_item.get('annotation_result')
  772. if original_id and annotation_res is not None:
  773. # 注意:这里需要根据 business_id 或 metadata 里的 original_id 来匹配
  774. # 样本中心 t_task_management 表的 business_id 存的是原始数据的唯一标识
  775. cursor_meta.execute(
  776. "SELECT id, metadata FROM t_task_management WHERE business_id = %s OR id = %s",
  777. (original_id, original_id)
  778. )
  779. row = cursor_meta.fetchone()
  780. if row:
  781. db_id = row['id']
  782. current_meta = json.loads(row['metadata']) if row['metadata'] else {}
  783. current_meta['annotation_result'] = annotation_res
  784. cursor_meta.execute(
  785. "UPDATE t_task_management SET metadata = %s WHERE id = %s",
  786. (json.dumps(current_meta, ensure_ascii=False), db_id)
  787. )
  788. updated_count += 1
  789. else:
  790. logger.debug(f"未在数据库中找到对应的任务: {original_id}")
  791. conn.commit()
  792. cursor_meta.close()
  793. logger.info(f"已从导出数据同步回写 {updated_count} 条任务的 annotation_result")
  794. except Exception as ex:
  795. logger.warning(f"同步回写 annotation_result 异常: {ex}")
  796. return res_data
  797. else:
  798. logger.error(f"导出数据失败: {response.status_code} - {response.text}")
  799. return {"error": f"外部平台返回错误 ({response.status_code})"}
  800. except Exception as e:
  801. logger.exception(f"导出数据异常: {e}")
  802. return {"error": str(e)}
  803. finally:
  804. conn.close()
  805. async def send_to_external_platform(self, project_id: str) -> Tuple[bool, str]:
  806. """将项目数据推送至外部标注平台 (单表化)"""
  807. # 1. 准备数据 (导出数据需要数据库连接)
  808. conn = get_db_connection()
  809. if not conn:
  810. return False, "数据库连接失败"
  811. try:
  812. logger.info(f"开始导出项目 {project_id} 数据...")
  813. payload = await self.export_project_data(project_id=project_id, conn=conn)
  814. if not payload:
  815. return False, "项目导出失败,请检查项目ID是否正确"
  816. if not payload.get('data'):
  817. return False, f"项目数据为空 (查询到0条有效任务),无法推送"
  818. except Exception as e:
  819. logger.exception(f"导出项目数据异常: {e}")
  820. return False, f"导出异常: {str(e)}"
  821. finally:
  822. # 及时释放连接,防止在 HTTP 请求期间占用
  823. conn.close()
  824. # 2. 获取配置
  825. try:
  826. from app.core.config import config_handler
  827. # 强制重新加载配置文件
  828. import os
  829. current_dir = os.path.dirname(os.path.abspath(__file__))
  830. config_file = os.path.join(current_dir, '..', 'config', 'config.ini')
  831. config_file = os.path.abspath(config_file)
  832. logger.info(f"配置文件路径: {config_file}")
  833. logger.info(f"配置文件是否存在: {os.path.exists(config_file)}")
  834. # 重新读取配置文件
  835. config_handler.config.read(config_file, encoding='utf-8')
  836. # 调试: 检查配置文件是否正确加载
  837. logger.info(f"配置文件sections: {config_handler.config.sections()}")
  838. # 调试: 检查external_api section的所有配置
  839. if config_handler.config.has_section('external_api'):
  840. logger.info(f"external_api section存在")
  841. logger.info(f"external_api所有选项: {config_handler.config.options('external_api')}")
  842. else:
  843. logger.error("external_api section不存在!")
  844. api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
  845. api_url = f"{api_base_url}/init"
  846. # 直接从ConfigParser读取,看看原始值
  847. raw_token = None
  848. if config_handler.config.has_option('external_api', 'admin_token'):
  849. raw_token = config_handler.config.get('external_api', 'admin_token')
  850. logger.info(f"原始token (直接从ConfigParser): 类型={type(raw_token)}, 长度={len(raw_token)}, 前50字符={raw_token[:50] if raw_token else 'None'}")
  851. else:
  852. logger.error("admin_token选项不存在!")
  853. token = config_handler.get('external_api', 'admin_token', '')
  854. logger.info(f"通过get方法获取的token: 类型={type(token)}, 长度={len(str(token)) if token else 0}, 前50字符={str(token)[:50] if token else 'None'}")
  855. # 确保token是字符串并去除首尾空格
  856. if token:
  857. token = str(token).strip()
  858. logger.info(f"清理后的token: 类型={type(token)}, 长度={len(token) if token else 0}, 前50字符={token[:50] if token else 'None'}")
  859. if not token:
  860. logger.error("外部平台Token未配置或为空")
  861. return False, "外部平台Token未配置"
  862. # 3. 发送请求 (不持有数据库连接)
  863. async with httpx.AsyncClient(timeout=120.0) as client: # 增加超时时间到 120s
  864. headers = {
  865. "Authorization": f"Bearer {token}",
  866. "Content-Type": "application/json"
  867. }
  868. logger.info(f"正在推送项目 {project_id} 至外部平台: {api_url}, 数据条数: {len(payload['data'])}")
  869. logger.info(f"正在推送项目: {payload['data']}")
  870. response = await client.post(api_url, json=payload, headers=headers)
  871. if response.status_code in (200, 201):
  872. res_data = response.json()
  873. logger.info(f"外部平台推送成功响应: {res_data}")
  874. remote_project_id = res_data.get('project_id')
  875. download_url = res_data.get('download_url') or res_data.get('file_url')
  876. # 4. 回写外部项目 ID 和下载地址 (重新获取连接)
  877. if remote_project_id:
  878. conn = get_db_connection()
  879. if conn:
  880. try:
  881. cursor = conn.cursor()
  882. if download_url:
  883. affected = cursor.execute(
  884. "UPDATE t_task_management SET task_id = %s, file_url = %s WHERE project_id = %s",
  885. (remote_project_id, download_url, project_id)
  886. )
  887. logger.info(f"已回写 task_id: {remote_project_id} 和 file_url: {download_url}, 受影响行数: {affected}")
  888. else:
  889. affected = cursor.execute(
  890. "UPDATE t_task_management SET task_id = %s WHERE project_id = %s",
  891. (remote_project_id, project_id)
  892. )
  893. logger.info(f"仅回写 task_id: {remote_project_id}, 受影响行数: {affected}")
  894. conn.commit()
  895. finally:
  896. cursor.close()
  897. conn.close()
  898. return True, f"推送成功!外部项目ID: {remote_project_id or '未知'}"
  899. else:
  900. error_msg = response.text
  901. logger.error(f"推送失败: {response.status_code} - {error_msg}")
  902. return False, f"外部平台返回错误 ({response.status_code})"
  903. except Exception as e:
  904. logger.exception(f"推送至外部平台异常: {e}")
  905. return False, f"推送异常: {str(e)}"
# Module-level singleton instance; importing modules share this one service object.
task_service = TaskService()