edututor_client.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. """
  2. 阿里云百炼 EduTutor 客户端
  3. 提供拍照解题功能的API调用
  4. """
  5. import os
  6. import logging
  7. from typing import Optional, AsyncGenerator
  8. from alibabacloud_edututor20250707.client import Client as EduTutorClient
  9. from alibabacloud_edututor20250707 import models as edututor_models
  10. from alibabacloud_tea_openapi import models as open_api_models
  11. from alibabacloud_credentials.models import Config as CredConfig
  12. from alibabacloud_credentials.client import Client as CredClient
  13. logger = logging.getLogger(__name__)
  14. class BailianEduTutorClient:
  15. """百炼 EduTutor 客户端"""
  def __init__(
      self,
      api_key: str,
      workspace_id: Optional[str] = None,
      endpoint: Optional[str] = None,
  ):
      """
      Initialize the EduTutor client.

      Authentication uses a static AccessKey pair read from the
      ALIBABA_CLOUD_ACCESS_KEY_ID / ALIBABA_CLOUD_ACCESS_KEY_SECRET
      environment variables.

      Args:
          api_key: The user's DashScope API key. Reserved: it is stored on
              the instance but not used for authentication.
          workspace_id: Bailian workspace ID. If empty, falls back to the
              BAILIAN_WORKSPACE_ID environment variable, then to a built-in
              default workspace ID.
          endpoint: EduTutor endpoint host. If empty, the cn-hangzhou
              endpoint 'edututor.cn-hangzhou.aliyuncs.com' is used.

      Raises:
          ValueError: If either AccessKey environment variable is unset or empty.
      """
      self.api_key = api_key
      # `or` chaining (instead of getenv's default) also treats an env var
      # that is set but empty as "not configured".
      # NOTE(review): the hard-coded fallback ties unconfigured deployments to
      # one specific workspace; consider requiring explicit configuration.
      self.workspace_id = (
          workspace_id
          or os.getenv('BAILIAN_WORKSPACE_ID')
          or 'llm-uflun9q7q59osmbb'
      )

      access_key_id = os.getenv('ALIBABA_CLOUD_ACCESS_KEY_ID')
      access_key_secret = os.getenv('ALIBABA_CLOUD_ACCESS_KEY_SECRET')
      if not access_key_id or not access_key_secret:
          raise ValueError("环境变量 ALIBABA_CLOUD_ACCESS_KEY_ID 或 ALIBABA_CLOUD_ACCESS_KEY_SECRET 未配置")

      # Static AccessKey credential built from the environment values above.
      cred = CredClient(CredConfig(
          type='access_key',
          access_key_id=access_key_id,
          access_key_secret=access_key_secret,
      ))
      config = open_api_models.Config(
          credential=cred,
          endpoint=endpoint or 'edututor.cn-hangzhou.aliyuncs.com',
      )
      self.client = EduTutorClient(config)
      logger.info("EduTutor client initialized with workspace: %s", self.workspace_id)
  async def answer_sse_async(
      self,
      image_url: str,
      grade: int = 0,
      stage: str = 'other',
      subject: str = 'other'
  ) -> AsyncGenerator[dict, None]:
      """
      Solve a question from an image, streaming the answer over SSE.

      The SDK call ``answer_ssewith_sse`` and the event iterator it returns
      are synchronous. Both are driven through the event loop's default
      executor so that waiting on the network does not block the loop.

      Args:
          image_url: URL of the question image.
          grade: Grade level (0-17, 99 = other).
          stage: School stage.
          subject: Subject.

      Yields:
          dict: Either
              - ``{'type': 'chunk', 'content': str}`` for each non-empty text
                fragment, or
              - ``{'type': 'finish', 'finish_reason': str,
                'tokens': {'input': int, 'output': int}}`` once, when an
                event carries a finish reason; the generator then stops.
                If the SDK fails to decode the stream after at least one
                event was received, a finish event with
                ``finish_reason == 'error'`` and zero token counts is
                yielded instead.
          No 'start' event is produced here, and no 'finish' event is
          produced if the stream ends without a finish reason.

      Raises:
          Exception: ``"解题失败: ..."`` if the request cannot be built or
              sent, or if the stream fails before any event is received.
      """
      import asyncio
      import json

      def _body_to_dict(body) -> dict:
          # Tea models expose to_map(); otherwise accept a dict or read the
          # known fields as attributes.
          if hasattr(body, 'to_map'):
              return body.to_map()
          if isinstance(body, dict):
              return body
          return {
              attr: getattr(body, attr)
              for attr in ('code', 'data', 'message', 'request_id',
                           'finish_reason', 'input_tokens', 'output_tokens')
              if hasattr(body, attr)
          }

      def _extract_text(data, event_no: int) -> str:
          # 'data' is expected to be a JSON string shaped like
          # {"message": {"content": [{"text": "..."}]}}; any other shape yields ''.
          if not data:
              return ''
          if isinstance(data, dict):
              inner = data
          else:
              try:
                  inner = json.loads(data)
              except json.JSONDecodeError as e:
                  logger.warning("Event %s: Failed to parse data field: %s", event_no, e)
                  return ''
          message = inner.get('message') if isinstance(inner, dict) else None
          content = message.get('content') if isinstance(message, dict) else None
          if not isinstance(content, list) or not content or not isinstance(content[0], dict):
              return ''
          text = content[0].get('text')
          return text if isinstance(text, str) else ''

      def _close_quietly(stream) -> None:
          # Release the underlying SSE stream. A generator's close() raises
          # ValueError if a worker thread is still inside next() (possible
          # after cancellation); that must not mask the original outcome.
          close = getattr(stream, 'close', None)
          if close is None:
              return
          try:
              close()
          except Exception as e:
              logger.debug("Failed to close SSE stream: %s", e)

      try:
          request = edututor_models.AnswerSSERequest(
              workspace_id=self.workspace_id,
              # Message content is a list of dicts; the image goes in as {'image': url}.
              messages=[edututor_models.AnswerSSERequestMessages(
                  role='user',
                  content=[{'image': image_url}],
              )],
              parameters=edututor_models.AnswerSSERequestParameters(
                  grade=grade,
                  stage=stage,
                  subject=subject,
              ),
          )
          logger.info(
              "Calling EduTutor Answer API: image_url=%s, grade=%s, stage=%s, subject=%s",
              image_url, grade, stage, subject,
          )
          loop = asyncio.get_running_loop()

          try:
              from alibabacloud_tea_util import models as util_models
              # 1000 * 100 = 100 000; Tea RuntimeOptions timeouts are given in
              # milliseconds (i.e. ~100 s) — confirm against the SDK docs.
              runtime = util_models.RuntimeOptions(read_timeout=1000 * 100)
              logger.info("About to call self.client.answer_ssewith_sse...")
              # The SDK call is synchronous; run it off the event loop.
              sse_receiver = await loop.run_in_executor(
                  None, self.client.answer_ssewith_sse, request, {}, runtime
              )
              logger.info("Got SSE receiver: %s", type(sse_receiver))
              stream = iter(sse_receiver)
          except Exception as e:
              logger.error("Failed to call answer_ssewith_sse: %s: %s", type(e).__name__, e, exc_info=True)
              raise Exception(f"调用百炼 API 失败: {str(e)}") from e

          logger.info("Starting to parse SSE stream...")
          sentinel = object()
          event_count = 0
          try:
              while True:
                  # next() blocks on network reads, so pull each event in a worker thread.
                  response = await loop.run_in_executor(None, next, stream, sentinel)
                  if response is sentinel:
                      break
                  event_count += 1

                  # Parse outside of any yield so a bad event is skipped, not fatal.
                  try:
                      body_dict = _body_to_dict(response.body)
                      if body_dict.get('code') != 'SUCCESS':
                          logger.error("API error: %s", body_dict)
                          continue
                      text = _extract_text(body_dict.get('data'), event_count)
                      finish_reason = body_dict.get('finish_reason')
                  except Exception as e:
                      logger.warning("Event %s: Error processing SSE response: %s", event_count, e)
                      continue

                  if text:
                      yield {'type': 'chunk', 'content': text}

                  # Intermediate events may carry the literal string 'null'
                  # as finish_reason, so it is excluded explicitly.
                  if finish_reason and finish_reason != 'null':
                      logger.info("Stream finished: %s, processed %s events", finish_reason, event_count)
                      yield {
                          'type': 'finish',
                          'finish_reason': finish_reason,
                          'tokens': {
                              'input': body_dict.get('input_tokens') or 0,
                              'output': body_dict.get('output_tokens') or 0,
                          },
                      }
                      return
          except json.JSONDecodeError as e:
              # Raised by the SDK's own SSE parsing while fetching the next event.
              logger.error("SDK JSON decode error after %s events: %s", event_count, e)
              if event_count == 0:
                  # Nothing was received: let the outer handler report it (prefixed once).
                  raise
              logger.info("Ending stream early due to JSON error, processed %s events", event_count)
              yield {
                  'type': 'finish',
                  'finish_reason': 'error',
                  'tokens': {'input': 0, 'output': 0},
              }
              return
          finally:
              _close_quietly(stream)

          logger.info("SSE stream ended, processed %s events total", event_count)
          logger.info("EduTutor Answer API call completed")
      except Exception as e:
          logger.error("EduTutor Answer API error: %s", e)
          raise Exception(f"解题失败: {str(e)}") from e
  async def cut_questions_async(
      self,
      image_url: str,
      struct: bool = True,
      extract_images: bool = True
  ) -> dict:
      """
      Split an exam-paper image into individual questions (async).

      Args:
          image_url: URL of the exam-paper image.
          struct: Whether to request structured (OCR) information per question.
          extract_images: Whether to request per-question image links.

      Returns:
          dict: ``{'questions': list, 'count': int}``.
              - When the body has ``code == 'SUCCESS'`` and a ``data`` field,
                each question is normalized to ``{'question_id': str
                (1-based position), 'image_url': merged_image or '',
                'text': stem text or ''}``; non-dict entries are skipped.
              - When the body instead has a top-level ``questions`` field,
                that list is passed through unchanged.
              - If there is no body, or it has an unrecognized shape with no
                error code, an empty result is returned.

      Raises:
          Exception: ``"切题失败: ..."`` if the request fails, the service
              returns a non-SUCCESS code, or the payload cannot be parsed.
      """
      import json

      def _body_to_dict(body) -> dict:
          # Tea models expose to_map(); otherwise accept a dict or read the
          # known fields as attributes.
          if hasattr(body, 'to_map'):
              return body.to_map()
          if isinstance(body, dict):
              return body
          logger.debug("Body attributes: %s", dir(body))
          return {
              attr: getattr(body, attr)
              for attr in ('code', 'data', 'message', 'request_id')
              if hasattr(body, attr)
          }

      def _format_question(position: int, q: dict) -> dict:
          # Map one service question to the frontend shape; missing or
          # non-string fields become ''.
          info = q.get('info')
          stem = info.get('stem') if isinstance(info, dict) else None
          text = stem.get('text') if isinstance(stem, dict) else None
          image = q.get('merged_image')
          return {
              'question_id': str(position),
              'image_url': image if isinstance(image, str) else '',
              'text': text if isinstance(text, str) else '',
          }

      try:
          request = edututor_models.CutQuestionsRequest(
              image=image_url,
              parameters=edututor_models.CutQuestionsRequestParameters(
                  struct=struct,
                  extract_images=extract_images,
              ),
              workspace_id=self.workspace_id,
          )
          logger.info("Calling EduTutor CutQuestions API: image_url=%s", image_url)
          response = await self.client.cut_questions_async(request)
          logger.debug("CutQuestions API raw response type: %s", type(response))

          if not hasattr(response, 'body'):
              logger.warning("No valid response body found, returning empty result")
              return {'questions': [], 'count': 0}

          body_dict = _body_to_dict(response.body)
          logger.debug("CutQuestions API response keys: %s", list(body_dict.keys()))
          code = body_dict.get('code')

          if code == 'SUCCESS' and 'data' in body_dict:
              raw = body_dict['data']
              # 'data' normally arrives as a JSON string; also accept an
              # already-decoded dict and treat an empty payload as no questions.
              if not raw:
                  data = {}
              elif isinstance(raw, (str, bytes, bytearray)):
                  data = json.loads(raw)
              else:
                  data = raw
              if not isinstance(data, dict):
                  raise ValueError(f"切题结果 data 字段格式异常: {type(data).__name__}")

              formatted_questions = []
              for q in data.get('questions') or []:
                  if not isinstance(q, dict):
                      logger.warning("Skipping malformed question entry: %r", q)
                      continue
                  formatted_q = _format_question(len(formatted_questions) + 1, q)
                  logger.debug(
                      "Formatted question %s: has_image=%s, text_length=%s",
                      formatted_q['question_id'], bool(formatted_q['image_url']), len(formatted_q['text']),
                  )
                  formatted_questions.append(formatted_q)

              logger.info("Returning result with %s questions", len(formatted_questions))
              return {'questions': formatted_questions, 'count': len(formatted_questions)}

          if 'questions' in body_dict:
              # Alternative shape: questions at the top level, passed through as-is.
              logger.info("Found questions field directly in body")
              questions = body_dict.get('questions') or []
              return {'questions': questions, 'count': len(questions)}

          if code is not None and code != 'SUCCESS':
              # Surface service errors instead of reporting "no questions".
              raise RuntimeError(f"EduTutor 返回错误: code={code}, message={body_dict.get('message')}")

          logger.warning("Unexpected response format, available keys: %s", list(body_dict.keys()))
          logger.warning("No valid response body found, returning empty result")
          return {'questions': [], 'count': 0}
      except Exception as e:
          logger.error("EduTutor CutQuestions API error: %s", e, exc_info=True)
          raise Exception(f"切题失败: {str(e)}") from e
  def answer_sync(
      self,
      image_url: str,
      grade: int = 0,
      stage: str = 'other',
      subject: str = 'other'
  ) -> dict:
      """
      Solve a question from an image and return the complete answer (blocking).

      Consumes the SSE endpoint ``answer_ssewith_sse`` until an event carries
      a finish reason (or the stream ends), concatenating the text fragments.

      Args:
          image_url: URL of the question image.
          grade: Grade level.
          stage: School stage.
          subject: Subject.

      Returns:
          dict:
              - answer: Concatenated answer text.
              - input_tokens: Input token count from the finishing event (0 if none).
              - output_tokens: Output token count from the finishing event (0 if none).
          If the SDK fails to decode the stream after at least one event was
          received, the text gathered so far is returned with zero token counts.

      Raises:
          Exception: ``"解题失败: ..."`` if the request fails, or the stream
              fails before any event is received.
      """
      import json

      def _body_to_dict(body) -> dict:
          # Tea models expose to_map(); otherwise accept a dict or read the
          # known fields as attributes.
          if hasattr(body, 'to_map'):
              return body.to_map()
          if isinstance(body, dict):
              return body
          return {
              attr: getattr(body, attr)
              for attr in ('code', 'data', 'finish_reason', 'input_tokens', 'output_tokens')
              if hasattr(body, attr)
          }

      def _extract_text(data, event_no: int) -> str:
          # 'data' is expected to be a JSON string shaped like
          # {"message": {"content": [{"text": "..."}]}}; any other shape yields ''.
          if not data:
              return ''
          if isinstance(data, dict):
              inner = data
          else:
              try:
                  inner = json.loads(data)
              except json.JSONDecodeError as e:
                  logger.warning("Event %s: Failed to parse data field: %s", event_no, e)
                  return ''
          message = inner.get('message') if isinstance(inner, dict) else None
          content = message.get('content') if isinstance(message, dict) else None
          if not isinstance(content, list) or not content or not isinstance(content[0], dict):
              return ''
          text = content[0].get('text')
          return text if isinstance(text, str) else ''

      def _close_quietly(stream) -> None:
          # Release the underlying SSE stream without masking the outcome.
          close = getattr(stream, 'close', None)
          if close is None:
              return
          try:
              close()
          except Exception as e:
              logger.debug("Failed to close SSE stream: %s", e)

      try:
          request = edututor_models.AnswerSSERequest(
              workspace_id=self.workspace_id,
              # Message content is a list of dicts; the image goes in as {'image': url}.
              messages=[edututor_models.AnswerSSERequestMessages(
                  role='user',
                  content=[{'image': image_url}],
              )],
              parameters=edututor_models.AnswerSSERequestParameters(
                  grade=grade,
                  stage=stage,
                  subject=subject,
              ),
          )
          logger.info("Calling EduTutor Answer API (sync): image_url=%s", image_url)

          from alibabacloud_tea_util import models as util_models
          # 1000 * 100 = 100 000; Tea RuntimeOptions timeouts are given in
          # milliseconds (i.e. ~100 s) — confirm against the SDK docs.
          runtime = util_models.RuntimeOptions(read_timeout=1000 * 100)
          stream = iter(self.client.answer_ssewith_sse(request, {}, runtime))

          parts = []
          input_tokens = 0
          output_tokens = 0
          event_count = 0
          try:
              for response in stream:
                  event_count += 1
                  try:
                      body_dict = _body_to_dict(response.body)
                      if body_dict.get('code') != 'SUCCESS':
                          logger.error("API error: %s", body_dict)
                          continue
                      text = _extract_text(body_dict.get('data'), event_count)
                      finish_reason = body_dict.get('finish_reason')
                  except Exception as e:
                      logger.warning("Event %s: Error processing SSE response: %s", event_count, e)
                      continue

                  if text:
                      parts.append(text)

                  # Intermediate events may carry the literal string 'null'
                  # as finish_reason, so it is excluded explicitly.
                  if finish_reason and finish_reason != 'null':
                      input_tokens = body_dict.get('input_tokens') or 0
                      output_tokens = body_dict.get('output_tokens') or 0
                      break
          except json.JSONDecodeError as e:
              # Raised by the SDK's own SSE parsing while fetching the next event.
              logger.error("SDK JSON decode error after %s events: %s", event_count, e)
              if event_count == 0:
                  raise
              # Otherwise keep the partial answer collected so far.
          finally:
              _close_quietly(stream)

          return {
              'answer': ''.join(parts),
              'input_tokens': input_tokens,
              'output_tokens': output_tokens
          }
      except Exception as e:
          logger.error("EduTutor Answer API error: %s", e)
          raise Exception(f"解题失败: {str(e)}") from e