| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070 |
- import logging
- import json
- import httpx
- from datetime import datetime, date
- from typing import List, Dict, Any, Tuple, Optional
- from app.base.async_mysql_connection import get_db_connection
- from app.base.minio_connection import get_minio_manager
- logger = logging.getLogger(__name__)
def generate_tag_colors(tags: List[str]) -> Dict[str, str]:
    """Generate a distinct color for every tag.

    Args:
        tags: List of tag names.

    Returns:
        Mapping from tag name to a color value: hex codes from a predefined
        high-contrast palette first, then golden-angle HSL strings once the
        palette is exhausted.
    """
    # Predefined palette of high-contrast, visually distinct colors.
    # FIX: the original list contained duplicates ("#FF5733" and "#3357FF"
    # each appeared twice), so two different tags could receive the same
    # color despite the documented guarantee. Deduplicate preserving order.
    pre_defined_colors = [
        "#FF5733", "#33FF57", "#3357FF", "#FF33A1", "#33FFF5",
        "#F5FF33", "#FF8C33", "#8C33FF", "#FF3333", "#33FF8C",
        "#FF338C", "#8CFF33", "#338CFF", "#FF5733", "#57FF33",
        "#3357FF", "#FF33FF", "#33FFFF", "#FFFF33", "#FF8000",
        "#80FF00", "#0080FF", "#FF0080", "#00FF80", "#8000FF",
        "#FF0000", "#00FF00", "#0000FF", "#FFFF00", "#FF00FF",
        "#00FFFF", "#FFA500", "#A52A2A", "#800080", "#008080",
        "#000080", "#800000", "#008000", "#000000"
    ]
    unique_colors = list(dict.fromkeys(pre_defined_colors))

    color_map: Dict[str, str] = {}
    for i, tag in enumerate(tags):
        if i < len(unique_colors):
            color_map[tag] = unique_colors[i]
        else:
            # Golden-angle hue stepping spreads generated hues evenly
            # around the HSL color wheel, minimizing collisions.
            hue = (i * 137.508) % 360
            # High saturation/lightness keeps the generated colors vivid.
            color_map[tag] = f"hsl({hue:.0f}, 70%, 50%)"
    return color_map
class TaskService:
    """Task management service (single-table schema on t_task_management)."""

    # Class-level flag: DDL maintenance in _ensure_table_schema runs at
    # most once per process.
    _schema_verified = False

    def __init__(self):
        # MinIO manager is used later to resolve relative image URLs.
        self.minio_manager = get_minio_manager()
- async def _ensure_table_schema(self, cursor, conn):
- """确保表结构和索引正确 (DDL 操作)"""
- if TaskService._schema_verified:
- return
-
- try:
- # 1. 动态维护字段
- cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'tag'")
- if not cursor.fetchone():
- cursor.execute("ALTER TABLE t_task_management ADD COLUMN tag json NULL COMMENT '标签' AFTER type")
-
- cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'metadata'")
- if not cursor.fetchone():
- cursor.execute("ALTER TABLE t_task_management ADD COLUMN metadata json NULL COMMENT '业务元数据' AFTER tag")
-
- cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'project_name'")
- if not cursor.fetchone():
- cursor.execute("ALTER TABLE t_task_management ADD COLUMN project_name varchar(255) NULL COMMENT '项目显示名称' AFTER project_id")
-
- cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'external_completed_count'")
- if not cursor.fetchone():
- cursor.execute("ALTER TABLE t_task_management ADD COLUMN external_completed_count int NULL DEFAULT 0 COMMENT '外部平台完成数量' AFTER annotation_status")
-
- cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'external_total_count'")
- if not cursor.fetchone():
- cursor.execute("ALTER TABLE t_task_management ADD COLUMN external_total_count int NULL DEFAULT 0 COMMENT '外部平台总任务数量' AFTER external_completed_count")
-
- conn.commit()
- # 2. 处理索引冲突
- cursor.execute("SHOW COLUMNS FROM t_task_management LIKE 'file_url'")
- if not cursor.fetchone():
- cursor.execute("ALTER TABLE t_task_management ADD COLUMN file_url VARCHAR(512) DEFAULT NULL AFTER metadata")
- logger.info("Added column file_url to t_task_management")
- cursor.execute("SHOW INDEX FROM t_task_management WHERE Column_name = 'business_id'")
- indexes = cursor.fetchall()
- for idx in indexes:
- if not idx['Non_unique'] and idx['Key_name'] != 'PRIMARY' and idx['Seq_in_index'] == 1:
- cursor.execute(f"SHOW INDEX FROM t_task_management WHERE Key_name = '{idx['Key_name']}'")
- if len(cursor.fetchall()) == 1:
- cursor.execute(f"DROP INDEX {idx['Key_name']} ON t_task_management")
- logger.info(f"Dropped old unique index: {idx['Key_name']}")
-
- cursor.execute("SHOW INDEX FROM t_task_management WHERE Key_name = 'uk_business_project'")
- if not cursor.fetchone():
- cursor.execute("CREATE UNIQUE INDEX uk_business_project ON t_task_management (business_id, project_id)")
- logger.info("Created new composite unique index: uk_business_project")
-
- cursor.execute("SHOW INDEX FROM t_task_management WHERE Key_name = 'idx_project_id'")
- if not cursor.fetchone():
- cursor.execute("CREATE INDEX idx_project_id ON t_task_management (project_id)")
- logger.info("Created index: idx_project_id")
-
- conn.commit()
- TaskService._schema_verified = True
- except Exception as e:
- logger.warning(f"表结构维护失败: {e}")
- conn.rollback()
- async def add_task(self, business_id: str, task_type: str, task_id: str = None, project_id: str = None, project_name: str = None, tag: str = None, metadata: str = None) -> Tuple[bool, str, Optional[int]]:
- """添加或更新任务记录 (适配单表结构)"""
- conn = get_db_connection()
- if not conn:
- return False, "数据库连接失败", None
-
- cursor = conn.cursor()
- try:
- # 确保表结构(仅在第一次调用时执行)
- await self._ensure_table_schema(cursor, conn)
- # 执行插入/更新
- sql = """
- INSERT INTO t_task_management (business_id, task_id, project_id, project_name, type, annotation_status, tag, metadata)
- VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
- ON DUPLICATE KEY UPDATE
- task_id = IFNULL(VALUES(task_id), task_id),
- project_id = IFNULL(VALUES(project_id), project_id),
- project_name = IFNULL(VALUES(project_name), project_name),
- annotation_status = IFNULL(VALUES(annotation_status), annotation_status),
- tag = IFNULL(VALUES(tag), tag),
- metadata = IFNULL(VALUES(metadata), metadata)
- """
- cursor.execute(sql, (business_id, task_id, project_id, project_name, task_type, 'pending', tag, metadata))
- record_id = cursor.lastrowid
-
- conn.commit()
- return True, "成功", record_id
- except Exception as e:
- conn.rollback()
- logger.exception(f"添加任务失败: {e}")
- return False, str(e), None
- finally:
- cursor.close()
- conn.close()
- async def delete_task_by_id(self, id: int) -> Tuple[bool, str]:
- """根据主键 id 删除任务记录"""
- conn = get_db_connection()
- if not conn:
- return False, "数据库连接失败"
-
- cursor = conn.cursor()
- try:
- sql = "DELETE FROM t_task_management WHERE id = %s"
- cursor.execute(sql, (id,))
- conn.commit()
- return True, "成功"
- except Exception as e:
- conn.rollback()
- logger.exception(f"删除任务失败: {e}")
- return False, str(e)
- finally:
- cursor.close()
- conn.close()
- def _serialize_datetime(self, obj: Any) -> Any:
- """递归遍历对象,将 datetime 转换为字符串"""
- if isinstance(obj, dict):
- return {k: self._serialize_datetime(v) for k, v in obj.items()}
- elif isinstance(obj, list):
- return [self._serialize_datetime(i) for i in obj]
- elif isinstance(obj, (datetime, date)):
- return obj.strftime('%Y-%m-%d %H:%M:%S') if isinstance(obj, datetime) else obj.strftime('%Y-%m-%d')
- return obj
- async def get_task_list(self, task_type: str) -> List[Dict[str, Any]]:
- """获取项目列表 (按 project_id 聚合)"""
- conn = get_db_connection()
- if not conn or task_type not in ['data', 'image']:
- return []
-
- cursor = conn.cursor()
- try:
- # 确保表结构
- await self._ensure_table_schema(cursor, conn)
-
- # 修改聚合逻辑:返回 project_id (UUID) 和 project_name (文字)
- sql = """
- SELECT
- project_id,
- MAX(project_name) as project_name,
- MAX(task_id) as task_id,
- MAX(tag) as tag,
- MAX(id) as sort_id,
- COALESCE(NULLIF(MAX(external_total_count), 0), COUNT(*)) as file_count,
- COALESCE(
- MAX(external_completed_count),
- SUM(CASE WHEN annotation_status = 'completed' THEN 1 ELSE 0 END),
- 0
- ) as completed_count
- FROM t_task_management
- WHERE type = %s
- GROUP BY project_id
- ORDER BY sort_id DESC
- """
- cursor.execute(sql, (task_type,))
- rows = cursor.fetchall()
-
- # 兼容旧数据:如果 project_name 为空,则用 project_id 代替
- for row in rows:
- if not row['project_name']:
- row['project_name'] = row['project_id']
-
- return self._serialize_datetime(rows)
- except Exception as e:
- logger.error(f"获取任务列表失败: {e}")
- return []
- finally:
- cursor.close()
- conn.close()
- async def get_project_details(self, project_id: str, task_type: str) -> List[Dict[str, Any]]:
- """获取项目详情 (按 project_id 查询)"""
- conn = get_db_connection()
- if not conn:
- return []
-
- cursor = conn.cursor()
- try:
- # 确保表结构
- await self._ensure_table_schema(cursor, conn)
-
- # 修改查询逻辑:获取 project_name 并在结果中包含它
- sql = """
- SELECT
- id, business_id, task_id, project_id, project_name,
- type, annotation_status, tag, metadata
- FROM t_task_management
- WHERE project_id = %s AND type = %s
- """
- cursor.execute(sql, (project_id, task_type))
- rows = cursor.fetchall()
-
- # 处理 tag 和 metadata 的 JSON 解析
- for row in rows:
- if row.get('tag'):
- try: row['tag'] = json.loads(row['tag']) if isinstance(row['tag'], str) else row['tag']
- except: pass
- if row.get('metadata'):
- try:
- meta = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata']
- row['metadata'] = meta
- # 提取名称供前端显示
- if row['type'] == 'data':
- row['name'] = meta.get('title') or meta.get('filename') or row['business_id']
- elif row['type'] == 'image':
- row['name'] = meta.get('image_name') or row['business_id']
- else:
- row['name'] = row['business_id']
- except:
- row['name'] = row['business_id']
- else:
- row['name'] = row['business_id']
-
- return self._serialize_datetime(rows)
- except Exception as e:
- logger.error(f"获取项目详情失败: {e}")
- return []
- finally:
- cursor.close()
- conn.close()
- async def delete_task(self, business_id: str) -> Tuple[bool, str]:
- """删除任务记录"""
- conn = get_db_connection()
- if not conn:
- return False, "数据库连接失败"
-
- cursor = conn.cursor()
- try:
- sql = "DELETE FROM t_task_management WHERE business_id = %s"
- cursor.execute(sql, (business_id,))
- conn.commit()
- return True, "删除成功"
- except Exception as e:
- logger.exception(f"删除任务记录失败: {e}")
- conn.rollback()
- return False, f"删除失败: {str(e)}"
- finally:
- cursor.close()
- conn.close()
- async def create_anno_project(self, data: Dict[str, Any]) -> Tuple[bool, str]:
- """创建标注项目并同步任务数据"""
- project_name = data.get('name')
- if not project_name:
- return False, "项目名称不能为空"
-
- # 0. 统一使用 UUID 方案获取或生成 project_id
- conn = get_db_connection()
- if not conn:
- return False, "数据库连接失败"
-
- cursor = conn.cursor()
- try:
- project_id = None
- cursor.execute("SELECT project_id FROM t_task_management WHERE project_name = %s LIMIT 1", (project_name,))
- existing_project = cursor.fetchone()
- if existing_project:
- project_id = existing_project['project_id']
- else:
- import uuid
- project_id = str(uuid.uuid4())
- task_type = data.get('task_type', 'data')
- # 映射回内部类型
- internal_type_map = {
- 'text_classification': 'data',
- 'image_classification': 'image'
- }
- internal_task_type = internal_type_map.get(task_type, task_type)
-
- tasks_data = data.get('data', [])
- if not tasks_data:
- return False, "任务数据不能为空"
-
- # 提取全局标签名列表
- global_tags = []
- if data.get('tags'):
- global_tags = [t['tag'] for t in data['tags'] if 'tag' in t]
-
- # 批量写入任务
- import json
- tag_str = json.dumps(global_tags, ensure_ascii=False) if global_tags else None
-
- for item in tasks_data:
- business_id = item.get('id')
- if not business_id: continue
-
- # 检查是否已存在 (使用联合主键逻辑)
- cursor.execute(
- "SELECT id FROM t_task_management WHERE business_id = %s AND project_id = %s",
- (business_id, project_id)
- )
- if cursor.fetchone():
- continue
-
- metadata_str = json.dumps(item.get('metadata', {}), ensure_ascii=False)
-
- sql = """
- INSERT INTO t_task_management
- (business_id, type, project_id, project_name, tag, metadata, annotation_status)
- VALUES (%s, %s, %s, %s, %s, %s, 'pending')
- """
- cursor.execute(sql, (business_id, internal_task_type, project_id, project_name, tag_str, metadata_str))
-
- conn.commit()
-
- # 自动推送至外部平台
- success, msg = await self.send_to_external_platform(project_id)
- if success:
- return True, project_id
- else:
- return False, f"任务已保存但推送失败: {msg}"
-
- except Exception as e:
- logger.exception(f"创建项目失败: {e}")
- conn.rollback()
- return False, str(e)
- finally:
- cursor.close()
- conn.close()
- async def get_project_progress(self, project_id: str) -> Dict[str, Any]:
- """获取项目进度统计 (单表化)"""
- conn = get_db_connection()
- if not conn:
- return {}
-
- cursor = conn.cursor()
- try:
- # 统计各状态数量
- sql = """
- SELECT
- annotation_status,
- COUNT(*) as count
- FROM t_task_management
- WHERE project_id = %s
- GROUP BY annotation_status
- """
- cursor.execute(sql, (project_id,))
- stats = cursor.fetchall()
-
- if not stats:
- return {}
- total = sum(s['count'] for s in stats)
- completed = sum(s['count'] for s in stats if s['annotation_status'] == 'completed')
-
- return {
- "project_id": project_id,
- "total": total,
- "completed": completed,
- "progress": f"{round(completed/total*100, 2)}%" if total > 0 else "0%",
- "details": stats
- }
- except Exception as e:
- logger.error(f"查询进度失败: {e}")
- return {}
- finally:
- cursor.close()
- conn.close()
- def _get_milvus_content(self, task_id: str, kb_info: Dict[str, Any]) -> List[str]:
- """
- 从 Milvus 获取文档分片内容
- """
- if not kb_info:
- return []
- # 优先使用数据库存储的子表名,父表名作为兜底
- collections = [c for c in [kb_info.get('collection_name_children'), kb_info.get('collection_name_parent')] if c]
-
- if not collections:
- return []
- from app.services.milvus_service import milvus_service
- contents = []
-
- # 简单的全局缓存,用于存储集合的字段探测结果,减少 describe_collection 调用
- if not hasattr(self, '_collection_schema_cache'):
- self._collection_schema_cache = {}
- for coll_name in collections:
- try:
- if not milvus_service.client.has_collection(coll_name):
- continue
-
- # 获取或更新缓存的字段名
- if coll_name not in self._collection_schema_cache:
- schema = milvus_service.client.describe_collection(coll_name)
- field_names = [f['name'] for f in schema.get('fields', [])]
- self._collection_schema_cache[coll_name] = {
- "id": "document_id", # 统一使用 document_id
- "content": "text" if "text" in field_names else "content"
- }
-
- fields = self._collection_schema_cache[coll_name]
- id_field = fields["id"]
- content_field = fields["content"]
-
- res = milvus_service.client.query(
- collection_name=coll_name,
- filter=f'{id_field} == "{task_id}"',
- output_fields=[content_field]
- )
-
- if res:
- for s in res:
- val = s.get(content_field)
- if val: contents.append(val)
-
- if contents:
- return contents
- except Exception as e:
- logger.error(f"查询 Milvus 集合 {coll_name} 异常: {e}")
- continue
-
- return []
- async def export_project_data(self, project_id: str, conn=None) -> Dict[str, Any]:
- """导出项目数据为标注平台要求的格式 (单表化)"""
- should_close = False
- if not conn:
- conn = get_db_connection()
- should_close = True
-
- if not conn:
- return {}
-
- cursor = conn.cursor()
- try:
- # 1. 获取任务记录
- sql_tasks = """
- SELECT business_id as id, type, task_id, project_name, tag, metadata
- FROM t_task_management
- WHERE project_id = %s
- """
- cursor.execute(sql_tasks, (project_id,))
- rows = cursor.fetchall()
-
- if not rows:
- return {}
-
- # 2. 解析基本信息
- first_row = rows[0]
- internal_task_type = first_row['type']
- project_name = first_row.get('project_name') or project_id
- remote_project_id = first_row.get('task_id') or project_id
-
- # 映射任务类型
- type_map = {'data': 'text_classification', 'image': 'image_classification'}
- external_task_type = type_map.get(internal_task_type, internal_task_type)
- # 3. 处理数据
- final_tasks = []
- all_project_tags = set()
-
- # 针对 'data' 类型的批量 Milvus 查询优化
- milvus_data_map = {}
- missing_milvus_ids = []
- if internal_task_type == 'data':
- all_task_ids = [r['id'] for r in rows]
- sql_kb = """
- SELECT kb.collection_name_parent, kb.collection_name_children
- FROM t_samp_document_main d
- LEFT JOIN t_samp_knowledge_base kb ON d.kb_id = kb.id
- WHERE d.id = %s
- """
- cursor.execute(sql_kb, (all_task_ids[0],))
- kb_info = cursor.fetchone()
- if kb_info:
- milvus_data_map = self._get_milvus_content_batch(all_task_ids, kb_info)
-
- # 记录缺失 Milvus 数据的 ID
- missing_milvus_ids = [tid for tid in all_task_ids if not milvus_data_map.get(tid)]
-
- # 针对缺失数据的批量标题查询
- title_map = {}
- if missing_milvus_ids:
- placeholders = ', '.join(['%s'] * len(missing_milvus_ids))
- cursor.execute(f"SELECT id, title FROM t_samp_document_main WHERE id IN ({placeholders})", missing_milvus_ids)
- title_map = {r['id']: r['title'] for r in cursor.fetchall()}
- # 针对 'image' 类型的批量 MinIO 查询优化
- image_data_map = {}
- if internal_task_type == 'image':
- all_image_ids = [r['id'] for r in rows]
- placeholders = ', '.join(['%s'] * len(all_image_ids))
- cursor.execute(f"SELECT id, image_url FROM t_image_info WHERE id IN ({placeholders})", all_image_ids)
- for img_row in cursor.fetchall():
- img_url = img_row['image_url']
- if img_url and not img_url.startswith('http'):
- img_url = self.minio_manager.get_full_url(img_url)
- image_data_map[img_row['id']] = [img_url] if img_url else []
- # 第一阶段:收集所有任务数据(不含颜色)
- tasks_data = []
- for item in rows:
- task_id = item['id']
- # 提取并处理标签
- doc_tags = []
- if item.get('tag'):
- try:
- doc_tags = json.loads(item['tag']) if isinstance(item['tag'], str) else item['tag']
- if doc_tags:
- for t in doc_tags: all_project_tags.add(t)
- except: pass
- # 解析数据库元数据 (提前序列化日期)
- db_metadata = {}
- if item.get('metadata'):
- try:
- db_metadata = json.loads(item['metadata']) if isinstance(item['metadata'], str) else item['metadata']
- if db_metadata:
- db_metadata = self._serialize_datetime(db_metadata)
- except: pass
- # 获取任务内容
- task_contents = []
- if internal_task_type == 'data':
- task_contents = milvus_data_map.get(task_id, [])
- if not task_contents:
- title = title_map.get(task_id)
- if title: task_contents = [title]
- elif internal_task_type == 'image':
- task_contents = image_data_map.get(task_id, [])
- # 保存任务数据(不含标签颜色)
- tasks_data.append({
- 'task_id': task_id,
- 'doc_tags': doc_tags,
- 'db_metadata': db_metadata,
- 'task_contents': task_contents
- })
- # 第二阶段:生成颜色并构建最终任务列表
- sorted_tags = sorted(list(all_project_tags))
- tag_color_map = generate_tag_colors(sorted_tags)
- for task_data in tasks_data:
- task_id = task_data['task_id']
- doc_tags = task_data['doc_tags']
- db_metadata = task_data['db_metadata']
- task_contents = task_data['task_contents']
- for idx, content in enumerate(task_contents):
- if not content: continue
- # 合并元数据:数据库数据 + 动态 ID
- task_metadata = {
- "original_id": task_id,
- "chunk_index": idx
- }
- if db_metadata:
- task_metadata.update(db_metadata)
- # 为每个标签添加颜色
- if doc_tags:
- task_metadata['tags'] = [
- {"tag": tag, "color": tag_color_map.get(tag, "#999999")}
- for tag in doc_tags
- ]
- task_item = {
- "id": f"{task_id}_{idx}" if len(task_contents) > 1 else task_id,
- "content": content,
- "metadata": task_metadata
- }
- # 尝试从元数据中提取 annotation_result
- if db_metadata and 'annotation_result' in db_metadata:
- task_item['annotation_result'] = db_metadata['annotation_result']
- final_tasks.append(task_item)
- # 构建项目级别的标签列表(带颜色)
- tags_result = []
- for tag in sorted_tags:
- tag_obj = {"tag": tag}
- if tag in tag_color_map:
- tag_obj["color"] = tag_color_map[tag]
- tags_result.append(tag_obj)
- # 准备返回结果,不再进行全局递归序列化 (已在局部处理)
- return {
- "name": project_name,
- "description": "",
- "task_type": external_task_type,
- "data": final_tasks,
- "external_id": remote_project_id,
- "tags": tags_result
- }
- except Exception as e:
- logger.exception(f"导出项目数据异常: {e}")
- return {}
- finally:
- cursor.close()
- if should_close:
- conn.close()
- def _get_milvus_content_batch(self, task_ids: List[str], kb_info: Dict[str, Any]) -> Dict[str, List[str]]:
- """
- 批量从 Milvus 获取文档分片内容
- """
- if not kb_info or not task_ids:
- return {}
- collections = [c for c in [kb_info.get('collection_name_children'), kb_info.get('collection_name_parent')] if c]
- if not collections:
- return {}
- from app.services.milvus_service import milvus_service
- result_map = {tid: [] for tid in task_ids}
-
- if not hasattr(self, '_collection_schema_cache'):
- self._collection_schema_cache = {}
- for coll_name in collections:
- try:
- if not milvus_service.client.has_collection(coll_name):
- continue
-
- if coll_name not in self._collection_schema_cache:
- schema = milvus_service.client.describe_collection(coll_name)
- field_names = [f['name'] for f in schema.get('fields', [])]
- self._collection_schema_cache[coll_name] = {
- "id": "document_id",
- "content": "text" if "text" in field_names else "content"
- }
-
- fields = self._collection_schema_cache[coll_name]
- id_field = fields["id"]
- content_field = fields["content"]
-
- # 使用 in 表达式进行批量查询 (分批处理以防 ID 过多)
- CHUNK_SIZE = 100
- for i in range(0, len(task_ids), CHUNK_SIZE):
- chunk_ids = task_ids[i:i + CHUNK_SIZE]
- id_list_str = ", ".join([f'"{tid}"' for tid in chunk_ids])
-
- logger.info(f"正在从 Milvus 集合 {coll_name} 查询分片内容 ({i}/{len(task_ids)})...")
- res = milvus_service.client.query(
- collection_name=coll_name,
- filter=f'{id_field} in [{id_list_str}]',
- output_fields=[id_field, content_field]
- )
-
- if res:
- for s in res:
- tid = s.get(id_field)
- val = s.get(content_field)
- if tid in result_map and val:
- result_map[tid].append(val)
-
- # 如果当前集合已经查到了内容,就不再查兜底集合 (除非结果仍为空)
- if any(result_map.values()):
- logger.info(f"从 Milvus 集合 {coll_name} 查得 {sum(len(v) for v in result_map.values())} 条内容分片")
- return result_map
- except Exception as e:
- logger.error(f"批量查询 Milvus 集合 {coll_name} 异常: {e}")
- continue
-
- return result_map
- async def get_project_progress(self, project_id: str) -> Dict[str, Any]:
- """获取外部标注项目的进度"""
- conn = get_db_connection()
- if not conn:
- return {"error": "数据库连接失败"}
-
- try:
- # 1. 查询 remote_project_id (task_id)
- cursor = conn.cursor()
- cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1", (project_id,))
- row = cursor.fetchone()
- cursor.close()
-
- if not row or not row['task_id']:
- return {"error": "未找到已推送的外部项目ID"}
-
- remote_project_id = row['task_id']
-
- # 2. 获取配置
- from app.core.config import config_handler
- api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
- progress_url = f"{api_base_url}/{remote_project_id}/progress"
- token = config_handler.get('external_api', 'admin_token', '')
- # 3. 发送请求
- async with httpx.AsyncClient(timeout=10.0) as client:
- headers = {"Authorization": f"Bearer {token}"}
- response = await client.get(progress_url, headers=headers)
-
- if response.status_code == 200:
- data = response.json()
- # 同步更新本地缓存的完成数量和总数
- completed_count = data.get('completed_tasks', 0)
- total_count = data.get('total_tasks', 0)
- try:
- conn_update = get_db_connection()
- if conn_update:
- cursor_update = conn_update.cursor()
- cursor_update.execute(
- """
- UPDATE t_task_management
- SET external_completed_count = %s, external_total_count = %s
- WHERE project_id = %s
- """,
- (completed_count, total_count, project_id)
- )
- conn_update.commit()
- cursor_update.close()
- conn_update.close()
- except Exception as ex:
- logger.warning(f"更新本地进度缓存失败: {ex}")
-
- return data
- else:
- logger.error(f"查询进度失败: {response.status_code} - {response.text}")
- return {"error": f"外部平台返回错误 ({response.status_code})"}
-
- except Exception as e:
- logger.exception(f"查询进度异常: {e}")
- return {"error": str(e)}
- finally:
- conn.close()
- async def export_labeled_data(self, project_id: str, export_format: str = 'json', completed_only: bool = True) -> Dict[str, Any]:
- """触发外部标注项目的数据导出"""
- conn = get_db_connection()
- if not conn:
- return {"error": "数据库连接失败"}
-
- try:
- # 1. 查询 remote_project_id (task_id)
- cursor = conn.cursor()
- cursor.execute("SELECT task_id FROM t_task_management WHERE project_id = %s AND task_id IS NOT NULL LIMIT 1", (project_id,))
- row = cursor.fetchone()
- cursor.close()
-
- if not row or not row['task_id']:
- return {"error": "未找到已推送的外部项目ID"}
-
- remote_project_id = row['task_id']
-
- # 2. 获取配置
- from app.core.config import config_handler
- api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
- export_url = f"{api_base_url}/{remote_project_id}/export"
- token = config_handler.get('external_api', 'admin_token', '')
- # 3. 发送请求
- async with httpx.AsyncClient(timeout=30.0) as client:
- headers = {
- "Authorization": f"Bearer {token}",
- "Content-Type": "application/json"
- }
- payload = {
- "format": export_format,
- "completed_only": completed_only
- }
- response = await client.post(export_url, json=payload, headers=headers)
-
- if response.status_code in (200, 201):
- res_data = response.json()
- logger.info(f"外部平台导出响应数据类型: {type(res_data)}, 键名: {list(res_data.keys()) if isinstance(res_data, dict) else 'None'}")
-
- # 1. 统一获取下载地址并回写 (兼容 download_url 和 file_url)
- download_url = res_data.get('download_url') or res_data.get('file_url')
- if isinstance(res_data, dict) and download_url:
- try:
- cursor_update = conn.cursor()
- affected = cursor_update.execute(
- "UPDATE t_task_management SET file_url = %s WHERE project_id = %s",
- (download_url, project_id)
- )
- conn.commit()
- cursor_update.close()
- logger.info(f"导出时同步更新 file_url: {download_url}, 受影响行数: {affected}")
- except Exception as ex:
- logger.warning(f"导出同步回写 file_url 失败: {ex}")
-
- # 2. 同步回写 annotation_result 到 metadata
- # 情况 A: 接口直接返回了任务列表数据
- export_items = []
- if isinstance(res_data, dict) and 'data' in res_data and isinstance(res_data['data'], list):
- export_items = res_data['data']
-
- # 情况 B: 接口返回了文件链接,且格式为 JSON,尝试下载并解析以获取标注结果
- elif isinstance(res_data, dict) and download_url and res_data.get('format') == 'json':
- try:
- # 补全 URL 协议 (如果外部平台返回的是相对路径)
- full_download_url = download_url
- if not download_url.startswith('http'):
- from app.core.config import config_handler
- # 尝试从配置获取 base_url,如果没有则从 api_url 中提取
- base_url = config_handler.get('external_api', 'download_base_url', '')
- if not base_url:
- # 兜底:从 project_api_url 中提取域名部分
- api_base = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003')
- from urllib.parse import urlparse
- parsed = urlparse(api_base)
- base_url = f"{parsed.scheme}://{parsed.netloc}"
-
- full_download_url = f"{base_url.rstrip('/')}/{download_url.lstrip('/')}"
-
- logger.info(f"正在从导出链接获取详细标注数据以同步数据库: {full_download_url}")
- # 注意:这里需要带上 token
- file_res = await client.get(full_download_url, headers=headers)
- if file_res.status_code == 200:
- file_json = file_res.json()
- # 外部平台导出的 JSON 结构通常是 { "data": [...] } 或直接是 [...]
- if isinstance(file_json, dict) and 'data' in file_json:
- export_items = file_json['data']
- elif isinstance(file_json, list):
- export_items = file_json
-
- if export_items:
- logger.info(f"成功获取导出项,共 {len(export_items)} 条。")
- else:
- logger.warning("获取到的导出列表为空")
- except Exception as ex:
- logger.warning(f"从导出文件同步标注数据失败: {ex}")
- if export_items:
- updated_count = 0
- try:
- cursor_meta = conn.cursor()
- for ext_item in export_items:
- # 根据实际 JSON 结构提取数据
- # 1. 提取 original_id
- original_data = ext_item.get('original_data', {})
- meta = original_data.get('metadata', {}) if original_data else ext_item.get('metadata', {})
- # 增加兼容性:如果 metadata 里没有,尝试直接从 ext_item 找,或者从 external_id 提取
- original_id = meta.get('original_id')
- if not original_id:
- ext_id = ext_item.get('external_id', '')
- if '_' in ext_id: # 比如 "uuid_4" 这种结构,提取前面的 uuid
- original_id = ext_id.rsplit('_', 1)[0]
- else:
- original_id = ext_id
-
- # 2. 提取 annotation_result
- annotations = ext_item.get('annotations', [])
- if annotations and isinstance(annotations, list):
- annotation_res = annotations[0].get('result')
- else:
- annotation_res = ext_item.get('annotation_result')
-
- if original_id and annotation_res is not None:
- # 注意:这里需要根据 business_id 或 metadata 里的 original_id 来匹配
- # 样本中心 t_task_management 表的 business_id 存的是原始数据的唯一标识
- cursor_meta.execute(
- "SELECT id, metadata FROM t_task_management WHERE business_id = %s OR id = %s",
- (original_id, original_id)
- )
- row = cursor_meta.fetchone()
- if row:
- db_id = row['id']
- current_meta = json.loads(row['metadata']) if row['metadata'] else {}
- current_meta['annotation_result'] = annotation_res
-
- cursor_meta.execute(
- "UPDATE t_task_management SET metadata = %s WHERE id = %s",
- (json.dumps(current_meta, ensure_ascii=False), db_id)
- )
- updated_count += 1
- else:
- logger.debug(f"未在数据库中找到对应的任务: {original_id}")
-
- conn.commit()
- cursor_meta.close()
- logger.info(f"已从导出数据同步回写 {updated_count} 条任务的 annotation_result")
- except Exception as ex:
- logger.warning(f"同步回写 annotation_result 异常: {ex}")
-
- return res_data
- else:
- logger.error(f"导出数据失败: {response.status_code} - {response.text}")
- return {"error": f"外部平台返回错误 ({response.status_code})"}
-
- except Exception as e:
- logger.exception(f"导出数据异常: {e}")
- return {"error": str(e)}
- finally:
- conn.close()
- async def send_to_external_platform(self, project_id: str) -> Tuple[bool, str]:
- """将项目数据推送至外部标注平台 (单表化)"""
- # 1. 准备数据 (导出数据需要数据库连接)
- conn = get_db_connection()
- if not conn:
- return False, "数据库连接失败"
-
- try:
- logger.info(f"开始导出项目 {project_id} 数据...")
- payload = await self.export_project_data(project_id=project_id, conn=conn)
-
- if not payload:
- return False, "项目导出失败,请检查项目ID是否正确"
-
- if not payload.get('data'):
- return False, f"项目数据为空 (查询到0条有效任务),无法推送"
- except Exception as e:
- logger.exception(f"导出项目数据异常: {e}")
- return False, f"导出异常: {str(e)}"
- finally:
- # 及时释放连接,防止在 HTTP 请求期间占用
- conn.close()
-
- # 2. 获取配置
- try:
- from app.core.config import config_handler
-
- # 强制重新加载配置文件
- import os
- current_dir = os.path.dirname(os.path.abspath(__file__))
- config_file = os.path.join(current_dir, '..', 'config', 'config.ini')
- config_file = os.path.abspath(config_file)
-
- logger.info(f"配置文件路径: {config_file}")
- logger.info(f"配置文件是否存在: {os.path.exists(config_file)}")
-
- # 重新读取配置文件
- config_handler.config.read(config_file, encoding='utf-8')
-
- # 调试: 检查配置文件是否正确加载
- logger.info(f"配置文件sections: {config_handler.config.sections()}")
-
- # 调试: 检查external_api section的所有配置
- if config_handler.config.has_section('external_api'):
- logger.info(f"external_api section存在")
- logger.info(f"external_api所有选项: {config_handler.config.options('external_api')}")
- else:
- logger.error("external_api section不存在!")
-
- api_base_url = config_handler.get('external_api', 'project_api_url', 'http://192.168.92.61:9003/api/external/projects').rstrip('/')
- api_url = f"{api_base_url}/init"
-
- # 直接从ConfigParser读取,看看原始值
- raw_token = None
- if config_handler.config.has_option('external_api', 'admin_token'):
- raw_token = config_handler.config.get('external_api', 'admin_token')
- logger.info(f"原始token (直接从ConfigParser): 类型={type(raw_token)}, 长度={len(raw_token)}, 前50字符={raw_token[:50] if raw_token else 'None'}")
- else:
- logger.error("admin_token选项不存在!")
-
- token = config_handler.get('external_api', 'admin_token', '')
- logger.info(f"通过get方法获取的token: 类型={type(token)}, 长度={len(str(token)) if token else 0}, 前50字符={str(token)[:50] if token else 'None'}")
-
- # 确保token是字符串并去除首尾空格
- if token:
- token = str(token).strip()
-
- logger.info(f"清理后的token: 类型={type(token)}, 长度={len(token) if token else 0}, 前50字符={token[:50] if token else 'None'}")
-
- if not token:
- logger.error("外部平台Token未配置或为空")
- return False, "外部平台Token未配置"
- # 3. 发送请求 (不持有数据库连接)
- async with httpx.AsyncClient(timeout=120.0) as client: # 增加超时时间到 120s
- headers = {
- "Authorization": f"Bearer {token}",
- "Content-Type": "application/json"
- }
- logger.info(f"正在推送项目 {project_id} 至外部平台: {api_url}, 数据条数: {len(payload['data'])}")
- logger.info(f"正在推送项目: {payload['data']}")
- response = await client.post(api_url, json=payload, headers=headers)
-
- if response.status_code in (200, 201):
- res_data = response.json()
- logger.info(f"外部平台推送成功响应: {res_data}")
- remote_project_id = res_data.get('project_id')
- download_url = res_data.get('download_url') or res_data.get('file_url')
-
- # 4. 回写外部项目 ID 和下载地址 (重新获取连接)
- if remote_project_id:
- conn = get_db_connection()
- if conn:
- try:
- cursor = conn.cursor()
- if download_url:
- affected = cursor.execute(
- "UPDATE t_task_management SET task_id = %s, file_url = %s WHERE project_id = %s",
- (remote_project_id, download_url, project_id)
- )
- logger.info(f"已回写 task_id: {remote_project_id} 和 file_url: {download_url}, 受影响行数: {affected}")
- else:
- affected = cursor.execute(
- "UPDATE t_task_management SET task_id = %s WHERE project_id = %s",
- (remote_project_id, project_id)
- )
- logger.info(f"仅回写 task_id: {remote_project_id}, 受影响行数: {affected}")
- conn.commit()
- finally:
- cursor.close()
- conn.close()
-
- return True, f"推送成功!外部项目ID: {remote_project_id or '未知'}"
- else:
- error_msg = response.text
- logger.error(f"推送失败: {response.status_code} - {error_msg}")
- return False, f"外部平台返回错误 ({response.status_code})"
-
- except Exception as e:
- logger.exception(f"推送至外部平台异常: {e}")
- return False, f"推送异常: {str(e)}"
- task_service = TaskService()
|