proxy.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import os
  2. from urllib.parse import urlparse
  3. import aiohttp
  4. from fastapi.responses import JSONResponse
  5. import logging
  6. from typing import Callable, Optional
  7. from functools import partial
  8. from fastapi import APIRouter, Request, Response
  9. from gpustack.api.exceptions import (
  10. BadRequestException,
  11. ForbiddenException,
  12. )
  13. from gpustack.config.config import get_global_config
  14. from gpustack.utils.network import use_proxy_env_for_url
  15. router = APIRouter()
  16. logger = logging.getLogger(__name__)
  17. ALLOWED_SITES = [
  18. "https://modelscope.cn",
  19. "https://www.modelscope.cn",
  20. "https://huggingface.co",
  21. ]
  22. HEADER_FORWARDED_PREFIX = "x-forwarded-"
  23. HEADER_SKIPPED = [
  24. "date",
  25. "set-cookie",
  26. "host",
  27. "port",
  28. "proto",
  29. "referer",
  30. "server",
  31. "content-length",
  32. "transfer-encoding",
  33. "content-encoding",
  34. "cookie",
  35. "x-forwarded-host",
  36. "x-forwarded-port",
  37. "x-forwarded-proto",
  38. "x-forwarded-server",
  39. ]
  40. HF_ENDPOINT = os.getenv("HF_ENDPOINT")
  41. timeout = aiohttp.ClientTimeout(
  42. connect=15.0,
  43. sock_read=60.0,
  44. sock_connect=10.0,
  45. )
  46. def hf_token_process(url: str, headers: dict) -> dict:
  47. global_config = get_global_config()
  48. if global_config.huggingface_token and (
  49. url.startswith("https://huggingface.co") or HF_ENDPOINT
  50. ):
  51. headers["Authorization"] = f"Bearer {global_config.huggingface_token}"
  52. return headers
  53. @router.api_route("", methods=["GET", "POST", "PUT", "DELETE"])
  54. async def proxy(request: Request, url: str):
  55. validate_http_method(request.method)
  56. validate_url(url)
  57. url = replace_hf_endpoint(url)
  58. return await proxy_to(request, url, header_func=partial(hf_token_process, url))
  59. async def proxy_to(
  60. request: Request, url: str, header_func: Optional[Callable[[dict], dict]] = None
  61. ):
  62. forwarded_headers = process_headers(request.headers)
  63. if header_func is not None:
  64. forwarded_headers = header_func(forwarded_headers)
  65. try:
  66. data = (
  67. await request.body()
  68. if request.method in ["POST", "PUT", "DELETE"]
  69. else None
  70. )
  71. use_proxy_env = use_proxy_env_for_url(url)
  72. async with aiohttp.ClientSession(
  73. timeout=timeout, trust_env=use_proxy_env
  74. ) as session:
  75. async with session.request(
  76. method=request.method,
  77. url=url,
  78. headers=forwarded_headers,
  79. data=data,
  80. ) as resp:
  81. content = await resp.read()
  82. headers = {
  83. k: v
  84. for k, v in resp.headers.items()
  85. if k.lower() not in HEADER_SKIPPED
  86. }
  87. return Response(
  88. status_code=resp.status,
  89. content=content,
  90. headers=headers,
  91. media_type=headers.get("Content-Type"),
  92. )
  93. except Exception as e:
  94. return JSONResponse(
  95. status_code=500,
  96. content={"detail": str(e)},
  97. media_type="application/json",
  98. )
  99. def validate_http_method(method: str):
  100. allowed_methods = ["GET", "POST", "PUT", "DELETE"]
  101. if method not in allowed_methods:
  102. raise BadRequestException(message=f"HTTP method '{method}' is not allowed")
  103. def validate_url(url: str):
  104. if not url:
  105. raise BadRequestException(message="Missing 'url' query parameter")
  106. try:
  107. parsed_url = urlparse(url)
  108. except Exception:
  109. raise BadRequestException(message="Invalid 'url' query parameter")
  110. if not parsed_url.netloc or not parsed_url.scheme:
  111. raise BadRequestException(message="Invalid 'url' query parameter")
  112. for allowed_site in ALLOWED_SITES:
  113. parsed_allowed_site_url = urlparse(allowed_site)
  114. if (
  115. parsed_url.netloc == parsed_allowed_site_url.netloc
  116. and parsed_url.scheme == parsed_allowed_site_url.scheme
  117. ):
  118. return
  119. raise ForbiddenException(message="This site is not allowed")
  120. def replace_hf_endpoint(url: str) -> str:
  121. """
  122. Replace the huggingface.co domain with the specified endpoint if set.
  123. """
  124. if HF_ENDPOINT and url.startswith("https://huggingface.co"):
  125. return url.replace("https://huggingface.co", HF_ENDPOINT, 1)
  126. return url
  127. def process_headers(headers: dict) -> dict:
  128. processed_headers = {}
  129. for key, value in headers.items():
  130. if key.lower() in HEADER_SKIPPED:
  131. continue
  132. elif key.lower().startswith(HEADER_FORWARDED_PREFIX):
  133. new_key = key[len(HEADER_FORWARDED_PREFIX) :]
  134. processed_headers[new_key] = value
  135. # set accept-encoding to identity to avoid decompression
  136. # httpx automatically decodes the content and we want to keep it raw
  137. # See https://www.python-httpx.org/quickstart/#binary-response-content
  138. elif key.lower() == "accept-encoding":
  139. processed_headers[key] = "identity"
  140. else:
  141. processed_headers[key] = value
  142. return processed_headers