phrase_service.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. """
  2. 热词词表管理服务
  3. """
  4. import json
  5. import logging
  6. from typing import Dict, Any, Optional, List
  7. from datetime import datetime
  8. from sqlalchemy.orm import Session
  9. from sqlalchemy import desc
  10. from app.models.tingwu import TranscriptionPhrase
  11. from app.services.tingwu_client import TingwuClient, TingwuAPIError
  12. logger = logging.getLogger(__name__)
  13. class PhraseService:
  14. """热词词表管理服务"""
  15. def __init__(self, db: Session, user_id: str, api_key: str):
  16. """
  17. 初始化服务
  18. Args:
  19. db: 数据库会话
  20. user_id: 用户 ID
  21. api_key: DashScope API Key
  22. """
  23. self.db = db
  24. self.user_id = user_id
  25. self.client = TingwuClient(api_key)
  26. async def create_phrase(
  27. self,
  28. name: str,
  29. word_weights: Dict[str, int],
  30. description: Optional[str] = None
  31. ) -> TranscriptionPhrase:
  32. """
  33. 创建热词词表
  34. Args:
  35. name: 词表名称
  36. word_weights: 热词及权重字典
  37. description: 词表描述
  38. Returns:
  39. 创建的词表对象
  40. Raises:
  41. ValueError: 参数错误
  42. TingwuAPIError: API 调用失败
  43. """
  44. # 1. 参数验证
  45. if not name or len(name) > 100:
  46. raise ValueError("词表名称不能为空且长度不超过 100 字符")
  47. if not word_weights or not isinstance(word_weights, dict):
  48. raise ValueError("热词权重必须是非空字典")
  49. # 验证权重值
  50. for word, weight in word_weights.items():
  51. if not isinstance(weight, int) or weight < 1 or weight > 10:
  52. raise ValueError(f"热词 '{word}' 的权重必须是 1-10 之间的整数")
  53. # 2. 调用 DashScope API 创建词表
  54. result = await self.client.create_phrase(name, word_weights, description)
  55. phrase_id = result.get('PhraseId')
  56. if not phrase_id:
  57. raise ValueError("API 返回的词表 ID 为空")
  58. # 3. 保存到数据库
  59. phrase = TranscriptionPhrase(
  60. user_id=self.user_id,
  61. phrase_id=phrase_id,
  62. name=name,
  63. description=description,
  64. word_weights=json.dumps(word_weights, ensure_ascii=False),
  65. status='ACTIVE'
  66. )
  67. self.db.add(phrase)
  68. self.db.commit()
  69. self.db.refresh(phrase)
  70. logger.info(f"创建热词词表成功: phrase_id={phrase_id}, user_id={self.user_id}")
  71. return phrase
  72. async def update_phrase(
  73. self,
  74. phrase_id: str,
  75. name: Optional[str] = None,
  76. word_weights: Optional[Dict[str, int]] = None,
  77. description: Optional[str] = None
  78. ) -> TranscriptionPhrase:
  79. """
  80. 更新热词词表
  81. Args:
  82. phrase_id: 词表 ID
  83. name: 新名称
  84. word_weights: 新的热词权重
  85. description: 新描述
  86. Returns:
  87. 更新后的词表对象
  88. Raises:
  89. ValueError: 词表不存在或参数错误
  90. TingwuAPIError: API 调用失败
  91. """
  92. # 1. 查询词表
  93. phrase = self.db.query(TranscriptionPhrase).filter_by(
  94. phrase_id=phrase_id,
  95. user_id=self.user_id,
  96. status='ACTIVE'
  97. ).first()
  98. if not phrase:
  99. raise ValueError(f"词表不存在或已删除: {phrase_id}")
  100. # 2. 参数验证
  101. if name and len(name) > 100:
  102. raise ValueError("词表名称长度不超过 100 字符")
  103. if word_weights:
  104. if not isinstance(word_weights, dict):
  105. raise ValueError("热词权重必须是字典")
  106. for word, weight in word_weights.items():
  107. if not isinstance(weight, int) or weight < 1 or weight > 10:
  108. raise ValueError(f"热词 '{word}' 的权重必须是 1-10 之间的整数")
  109. # 3. 调用 DashScope API 更新
  110. await self.client.update_phrase(phrase_id, name, word_weights, description)
  111. # 4. 更新数据库
  112. if name:
  113. phrase.name = name
  114. if word_weights:
  115. phrase.word_weights = json.dumps(word_weights, ensure_ascii=False)
  116. if description is not None:
  117. phrase.description = description
  118. phrase.updated_at = datetime.now()
  119. self.db.commit()
  120. self.db.refresh(phrase)
  121. logger.info(f"更新热词词表成功: phrase_id={phrase_id}")
  122. return phrase
  123. async def delete_phrase(self, phrase_id: str) -> bool:
  124. """
  125. 删除热词词表
  126. Args:
  127. phrase_id: 词表 ID
  128. Returns:
  129. 是否删除成功
  130. Raises:
  131. ValueError: 词表不存在
  132. TingwuAPIError: API 调用失败
  133. """
  134. # 1. 查询词表
  135. phrase = self.db.query(TranscriptionPhrase).filter_by(
  136. phrase_id=phrase_id,
  137. user_id=self.user_id,
  138. status='ACTIVE'
  139. ).first()
  140. if not phrase:
  141. raise ValueError(f"词表不存在或已删除: {phrase_id}")
  142. # 2. 调用 DashScope API 删除
  143. await self.client.delete_phrase(phrase_id)
  144. # 3. 软删除(更新状态)
  145. phrase.status = 'DELETED'
  146. phrase.updated_at = datetime.now()
  147. self.db.commit()
  148. logger.info(f"删除热词词表成功: phrase_id={phrase_id}")
  149. return True
  150. def get_user_phrases(
  151. self,
  152. page: int = 1,
  153. page_size: int = 20,
  154. status: str = 'ACTIVE'
  155. ) -> Dict[str, Any]:
  156. """
  157. 获取用户热词词表列表
  158. Args:
  159. page: 页码
  160. page_size: 每页数量
  161. status: 状态筛选
  162. Returns:
  163. 包含词表列表和总数的字典
  164. """
  165. query = self.db.query(TranscriptionPhrase).filter_by(
  166. user_id=self.user_id,
  167. status=status
  168. )
  169. # 总数
  170. total = query.count()
  171. # 分页查询
  172. phrases = query.order_by(desc(TranscriptionPhrase.created_at)).offset(
  173. (page - 1) * page_size
  174. ).limit(page_size).all()
  175. return {
  176. 'total': total,
  177. 'page': page,
  178. 'page_size': page_size,
  179. 'items': [phrase.to_dict() for phrase in phrases]
  180. }
  181. def get_phrase_by_id(self, phrase_id: str) -> Optional[TranscriptionPhrase]:
  182. """
  183. 根据 ID 获取词表
  184. Args:
  185. phrase_id: 词表 ID
  186. Returns:
  187. 词表对象,不存在返回 None
  188. """
  189. return self.db.query(TranscriptionPhrase).filter_by(
  190. phrase_id=phrase_id,
  191. user_id=self.user_id,
  192. status='ACTIVE'
  193. ).first()