responses.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from typing import List, Tuple
  2. from fastapi.responses import StreamingResponse
  3. from fastapi import status
  4. from starlette.types import Send
  5. from gpustack.api.exceptions import (
  6. OpenAIAPIError,
  7. OpenAIAPIErrorResponse,
  8. )
  9. class StreamingResponseWithStatusCode(StreamingResponse):
  10. '''
  11. Variation of StreamingResponse that can dynamically decide the HTTP status code, based on the returns from the content iterator (parameter 'content').
  12. Expects the content to yield tuples of (content: str, status_code: int), instead of just content as it was in the original StreamingResponse.
  13. The parameter status_code in the constructor is ignored, but kept for compatibility with StreamingResponse.
  14. '''
  15. async def stream_response(self, send: Send) -> None:
  16. try:
  17. first_chunk_content, headers, self.status_code = (
  18. await self.body_iterator.__anext__()
  19. )
  20. if not isinstance(first_chunk_content, bytes):
  21. first_chunk_content = first_chunk_content.encode(self.charset)
  22. asgi_headers: List[Tuple[bytes, bytes]] = [
  23. (key.encode("latin-1"), value.encode("latin-1"))
  24. for key, value in headers.items()
  25. ]
  26. await send(
  27. {
  28. "type": "http.response.start",
  29. "status": self.status_code,
  30. "headers": asgi_headers,
  31. }
  32. )
  33. await send(
  34. {
  35. "type": "http.response.body",
  36. "body": first_chunk_content,
  37. "more_body": True,
  38. }
  39. )
  40. async for chunk_content, _, _ in self.body_iterator:
  41. if not isinstance(chunk_content, bytes):
  42. chunk_content = chunk_content.encode(self.charset)
  43. await send(
  44. {
  45. "type": "http.response.body",
  46. "body": chunk_content,
  47. "more_body": True,
  48. }
  49. )
  50. await send({"type": "http.response.body", "body": b"", "more_body": False})
  51. except StopAsyncIteration:
  52. self.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
  53. await send(
  54. {
  55. "type": "http.response.start",
  56. "status": self.status_code,
  57. "headers": self.raw_headers,
  58. }
  59. )
  60. error_response = OpenAIAPIErrorResponse(
  61. error=OpenAIAPIError(
  62. message="Service unavailable. Please retry your requests after a brief wait.",
  63. code=status.HTTP_503_SERVICE_UNAVAILABLE,
  64. type="ServiceUnavailable",
  65. ),
  66. )
  67. await send(
  68. {
  69. "type": "http.response.body",
  70. "body": error_response.model_dump_json().encode(),
  71. "more_body": False,
  72. }
  73. )
  74. except Exception as e:
  75. self.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
  76. await send(
  77. {
  78. "type": "http.response.start",
  79. "status": self.status_code,
  80. "headers": self.raw_headers,
  81. }
  82. )
  83. error_response = OpenAIAPIErrorResponse(
  84. error=OpenAIAPIError(
  85. message=str(e),
  86. code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  87. type="InternalServerError",
  88. ),
  89. )
  90. await send(
  91. {
  92. "type": "http.response.body",
  93. "body": error_response.model_dump_json().encode(),
  94. "more_body": False,
  95. }
  96. )