exceptions.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. from typing import Optional
  2. from fastapi import FastAPI, Request, status
  3. from fastapi.exceptions import RequestValidationError
  4. from fastapi.responses import JSONResponse
  5. import httpx
  6. import logging
  7. from pydantic import BaseModel
  8. logger = logging.getLogger(__name__)
  9. class HTTPException(Exception):
  10. def __init__(self, status_code: int, reason: str, message: str):
  11. self.status_code = status_code
  12. self.reason = reason
  13. self.message = message
  14. class OpenAIAPIException(HTTPException):
  15. pass
  16. def http_exception_factory(
  17. status_code: int,
  18. reason: str,
  19. default_message: str,
  20. ):
  21. class_name = reason + "Exception"
  22. def init(self, message=default_message, is_openai_exception=False):
  23. if is_openai_exception:
  24. self.__class__.__bases__ = (OpenAIAPIException,)
  25. super(self.__class__, self).__init__(status_code, reason, message)
  26. return type(
  27. class_name,
  28. (HTTPException,),
  29. {"__init__": init},
  30. )
  31. AlreadyExistsException = http_exception_factory(
  32. status.HTTP_409_CONFLICT, "AlreadyExists", "Already exists"
  33. )
  34. ConflictException = http_exception_factory(
  35. status.HTTP_409_CONFLICT, "Conflict", "Conflict with existing resource"
  36. )
  37. NotFoundException = http_exception_factory(
  38. status.HTTP_404_NOT_FOUND, "NotFound", "Not found"
  39. )
  40. UnauthorizedException = http_exception_factory(
  41. status.HTTP_401_UNAUTHORIZED, "Unauthorized", "Unauthorized"
  42. )
  43. ForbiddenException = http_exception_factory(
  44. status.HTTP_403_FORBIDDEN, "Forbidden", "Forbidden"
  45. )
  46. InvalidException = http_exception_factory(
  47. status.HTTP_422_UNPROCESSABLE_ENTITY, "Invalid", "Invalid input"
  48. )
  49. BadRequestException = http_exception_factory(
  50. status.HTTP_400_BAD_REQUEST, "BadRequest", "Bad request"
  51. )
  52. InternalServerErrorException = http_exception_factory(
  53. status.HTTP_500_INTERNAL_SERVER_ERROR,
  54. "InternalServerError",
  55. "Internal server error",
  56. )
  57. ServiceUnavailableException = http_exception_factory(
  58. status.HTTP_503_SERVICE_UNAVAILABLE, "ServiceUnavailable", "Service unavailable"
  59. )
  60. GatewayTimeoutException = http_exception_factory(
  61. status.HTTP_504_GATEWAY_TIMEOUT, "GatewayTimeout", "Gateway timeout"
  62. )
  63. async def async_raise_if_response_error(response: httpx.Response): # noqa: C901
  64. if response.status_code < status.HTTP_400_BAD_REQUEST:
  65. return
  66. try:
  67. await response.aread()
  68. except httpx.ReadError as e:
  69. raise HTTPException(
  70. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  71. reason="Unknown",
  72. message=str(e),
  73. )
  74. raise_errors(response)
  75. def raise_if_response_error(response: httpx.Response): # noqa: C901
  76. if response.status_code < status.HTTP_400_BAD_REQUEST:
  77. return
  78. raise_errors(response)
  79. def raise_errors(response: httpx.Response):
  80. try:
  81. response_json = response.json()
  82. # Compatible with OpenAI API error format
  83. if "error" in response_json and isinstance(response_json["error"], dict):
  84. response_json = response_json["error"]
  85. if "type" in response_json and isinstance(response_json["type"], str):
  86. response_json["reason"] = response_json["type"]
  87. error = ErrorResponse.model_validate(response_json)
  88. except Exception:
  89. raise HTTPException(response.status_code, "Unknown", response.text)
  90. if response.status_code == status.HTTP_404_NOT_FOUND:
  91. raise NotFoundException(error.message)
  92. if (
  93. response.status_code == status.HTTP_409_CONFLICT
  94. and error.reason == "AlreadyExists"
  95. ):
  96. raise AlreadyExistsException(error.message)
  97. if response.status_code == status.HTTP_401_UNAUTHORIZED:
  98. raise UnauthorizedException(error.message)
  99. if response.status_code == status.HTTP_403_FORBIDDEN:
  100. raise ForbiddenException(error.message)
  101. if response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY:
  102. raise InvalidException(error.message)
  103. if response.status_code == status.HTTP_400_BAD_REQUEST:
  104. raise BadRequestException(error.message)
  105. if response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR:
  106. raise InternalServerErrorException(error.message)
  107. if response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE:
  108. raise ServiceUnavailableException(error.message)
  109. if response.status_code == status.HTTP_504_GATEWAY_TIMEOUT:
  110. raise GatewayTimeoutException(error.message)
  111. raise HTTPException(error.code, error.reason, error.message)
  112. class ErrorResponse(BaseModel):
  113. code: int
  114. reason: str
  115. message: str
  116. error_responses = {
  117. 404: {"model": ErrorResponse},
  118. 409: {"model": ErrorResponse},
  119. 401: {"model": ErrorResponse},
  120. 403: {"model": ErrorResponse},
  121. 422: {"model": ErrorResponse},
  122. 400: {"model": ErrorResponse},
  123. 500: {"model": ErrorResponse},
  124. 503: {"model": ErrorResponse},
  125. }
  126. class OpenAIAPIError(BaseModel):
  127. message: str
  128. type: Optional[str] = None
  129. code: Optional[int] = None
  130. param: Optional[str] = None
  131. class OpenAIAPIErrorResponse(BaseModel):
  132. error: OpenAIAPIError
  133. openai_api_error_responses = {
  134. 404: {"model": OpenAIAPIErrorResponse},
  135. 409: {"model": OpenAIAPIErrorResponse},
  136. 401: {"model": OpenAIAPIErrorResponse},
  137. 403: {"model": OpenAIAPIErrorResponse},
  138. 422: {"model": OpenAIAPIErrorResponse},
  139. 400: {"model": OpenAIAPIErrorResponse},
  140. 500: {"model": OpenAIAPIErrorResponse},
  141. 503: {"model": OpenAIAPIErrorResponse},
  142. }
  143. def register_handlers(app: FastAPI):
  144. @app.exception_handler(HTTPException)
  145. async def http_exception_handler(request: Request, exc: HTTPException):
  146. if exc.status_code >= 500:
  147. logger.error(
  148. "HTTP server error occurred: %s %s - %s (path=%s, method=%s)",
  149. exc.status_code,
  150. exc.reason,
  151. exc.message,
  152. request.url.path,
  153. request.method,
  154. )
  155. return JSONResponse(
  156. status_code=exc.status_code,
  157. content=ErrorResponse(
  158. code=exc.status_code,
  159. reason=exc.reason,
  160. message=exc.message,
  161. ).model_dump(),
  162. )
  163. @app.exception_handler(OpenAIAPIException)
  164. async def openai_api_exception_handler(request: Request, exc: OpenAIAPIException):
  165. """
  166. This handler is used to return error response in OpenAI API format.
  167. """
  168. return JSONResponse(
  169. status_code=exc.status_code,
  170. content={
  171. "error": {
  172. "message": exc.message,
  173. "code": exc.status_code,
  174. "type": exc.reason,
  175. }
  176. },
  177. )
  178. @app.exception_handler(RequestValidationError)
  179. async def validation_exception_handler(request, exc: RequestValidationError):
  180. message = f"{len(exc.errors())} validation errors:\n"
  181. for err in exc.errors():
  182. message += f" {err}\n"
  183. return JSONResponse(
  184. status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
  185. content=ErrorResponse(
  186. code=status.HTTP_422_UNPROCESSABLE_ENTITY,
  187. reason="Invalid",
  188. message=message,
  189. ).model_dump(),
  190. )