cus_streamer.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # !/usr/bin/ python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :cus_streamer.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/14 12:04
  9. '''
  10. from langchain_core.messages import HumanMessage
  11. from typing import AsyncGenerator
  12. import asyncio
  class AdaptiveStreamer:
      """Re-chunk streamed model output into pieces whose size adapts over time.

      Text taken from the model stream is accumulated in ``buffer`` and handed
      to the consumer in slices of ``chunk_size`` characters.  After every
      slice the time the consumer needed before asking for the next one is
      measured and ``chunk_size`` is moved towards a smaller or larger value.

      Attributes are plain and public:
          min_chunk / max_chunk: bounds that ``chunk_size`` is steered into.
          chunk_size: current slice length, in characters.
          buffer: text received from the model but not yet emitted.
          last_latency: duration passed to the latest ``adjust_chunk_size``.
      """

      # Throughput thresholds in characters per second.  Below LOW_RATE the
      # consumer is treated as slow, above HIGH_RATE as fast.
      LOW_RATE = 10_000
      HIGH_RATE = 100_000

      def __init__(self, min_chunk: int = 256, max_chunk: int = 4096, initial_chunk: int = 4):
          """Store the size bounds and start with an empty buffer.

          ``initial_chunk`` may lie outside ``[min_chunk, max_chunk]``; it is
          used as given for the first slice and is then pulled into range
          step by step by ``adjust_chunk_size``.
          """
          self.min_chunk = min_chunk
          self.max_chunk = max_chunk
          self.chunk_size = initial_chunk
          self.buffer = ""
          self.last_latency = 0.0

      async def astream(
          self, model, prompt: str, config, stream_mode="values"
      ) -> AsyncGenerator[str, None]:
          """Stream the model's answer to ``prompt`` as adaptively sized text chunks.

          ``model.astream`` is called with ``{"messages": [HumanMessage(prompt)]}``
          plus ``config`` and ``stream_mode``.  Every item it produces is indexed
          with ``['messages'][-1]``, so each item must be a mapping holding a
          message list.
          NOTE(review): that shape looks like LangGraph's "values" stream mode;
          other modes may produce differently shaped items - confirm with callers.

          Yields:
              ``str`` slices of the accumulated text.  If a message's content
              cannot be appended to the text buffer, a ``"[错误] ..."`` string is
              yielded in its place and streaming continues.
          """
          loop = asyncio.get_running_loop()
          # Drop text left behind by an earlier stream that was abandoned
          # mid-way; otherwise it would be prepended to this response.
          self.buffer = ""

          request = {"messages": [HumanMessage(content=prompt)]}
          async for langchain_chunk in model.astream(request, config=config, stream_mode=stream_mode):
              last_message = langchain_chunk['messages'][-1]
              # The echoed user prompt is not part of the answer.
              if isinstance(last_message, HumanMessage):
                  continue
              content = getattr(last_message, 'content', None)
              if not content:
                  continue  # nothing to emit for this item

              try:
                  self.buffer += content
              except TypeError as e:
                  # Content that is not plain text cannot be buffered; report
                  # it in-band (as text, like every other chunk) and carry on.
                  yield f"[错误] {e}"
                  continue

              # Emit as many full slices as the buffer currently holds.  The
              # slice length is floored at 1 so the loop always makes progress.
              while len(self.buffer) >= max(1, self.chunk_size):
                  size = max(1, self.chunk_size)
                  output_chunk = self.buffer[:size]
                  self.buffer = self.buffer[size:]

                  # The generator is suspended at ``yield`` until the consumer
                  # requests the next item, so this measures consumer time.
                  start_time = loop.time()
                  yield output_chunk
                  self.adjust_chunk_size(loop.time() - start_time)

          # Flush whatever is shorter than one slice.  Clear the buffer first
          # so nothing stale remains if the consumer stops at this yield.
          if self.buffer:
              tail = self.buffer
              self.buffer = ""
              yield tail

      def adjust_chunk_size(self, send_duration: float):
          """Move ``chunk_size`` towards a size suited to the observed throughput.

          Throughput is ``chunk_size / send_duration`` in characters per second
          (infinite when ``send_duration`` is not positive).  A slow consumer
          targets 80% of the current size, a fast one 120%, anything else the
          current size.  The target is clamped to ``[min_chunk, max_chunk]``
          and the new size is a 70/30 blend of the current size and the
          target, rounded to the nearest integer.

          Integer arithmetic is used throughout so that rounding can neither
          stall the adjustment nor let an unchanged size drift.
          """
          self.last_latency = send_duration
          current = int(self.chunk_size)

          if send_duration > 0:
              send_rate = current / send_duration
          else:
              send_rate = float('inf')

          if send_rate < self.LOW_RATE:
              target = current * 4 // 5
          elif send_rate > self.HIGH_RATE:
              # Grow by at least one so very small sizes do not get stuck.
              target = max(current + 1, current * 6 // 5)
          else:
              target = current

          # Apply the configured bounds to every branch; never allow < 1.
          lower = max(1, self.min_chunk)
          upper = max(lower, self.max_chunk)
          target = min(upper, max(lower, target))

          # Smooth transition: 70% current + 30% target, rounded half up.
          smoothed = (7 * current + 3 * target + 5) // 10
          if smoothed == current and target != current:
              # Rounding cancelled the step; still move one unit to the target.
              smoothed += 1 if target > current else -1
          self.chunk_size = smoothed