| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
- import os
- from urllib.parse import urlparse
- import aiohttp
- from fastapi.responses import JSONResponse
- import logging
- from typing import Callable, Optional
- from functools import partial
- from fastapi import APIRouter, Request, Response
- from gpustack.api.exceptions import (
- BadRequestException,
- ForbiddenException,
- )
- from gpustack.config.config import get_global_config
- from gpustack.utils.network import use_proxy_env_for_url
router = APIRouter()
logger = logging.getLogger(__name__)

# Upstream origins (scheme + host) that the proxy is permitted to contact.
ALLOWED_SITES = [
    "https://modelscope.cn",
    "https://www.modelscope.cn",
    "https://huggingface.co",
]

# Incoming "X-Forwarded-<Name>" request headers are relayed upstream as "<Name>".
HEADER_FORWARDED_PREFIX = "x-forwarded-"

# Lowercase header names that are not relayed: they are filtered from the
# incoming request headers and from the upstream response headers.
HEADER_SKIPPED = [
    "date",
    "set-cookie",
    "host",
    "port",
    "proto",
    "referer",
    "server",
    "content-length",
    "transfer-encoding",
    "content-encoding",
    "cookie",
] + [
    HEADER_FORWARDED_PREFIX + suffix
    for suffix in ("host", "port", "proto", "server")
]

# Optional Hugging Face mirror base URL, read once at import time.
HF_ENDPOINT = os.environ.get("HF_ENDPOINT")

# Per-request aiohttp timeouts in seconds (no `total` limit is passed).
timeout = aiohttp.ClientTimeout(
    sock_connect=10.0,
    connect=15.0,
    sock_read=60.0,
)
def hf_token_process(url: str, headers: dict) -> dict:
    """
    Attach the configured Hugging Face token to requests bound for Hugging Face.

    The token is added only when ``url`` targets ``https://huggingface.co`` or,
    if the ``HF_ENDPOINT`` environment variable is set, that endpoint. Requests
    to any other allowed site (e.g. ModelScope) never receive the token.

    Args:
        url: The upstream URL, possibly already rewritten to ``HF_ENDPOINT``.
        headers: Outgoing request headers; mutated in place.

    Returns:
        The same ``headers`` dict, with ``Authorization`` set when applicable.
    """

    def _has_url_prefix(candidate: str, prefix: str) -> bool:
        # Require a URL boundary after the prefix so that e.g.
        # "https://huggingface.co.example.com" does not count as Hugging Face.
        if not candidate.startswith(prefix):
            return False
        rest = candidate[len(prefix) :]
        return rest == "" or rest[0] in "/?#"

    token = get_global_config().huggingface_token
    if not token:
        return headers

    hf_prefixes = ["https://huggingface.co"]
    endpoint = HF_ENDPOINT.rstrip("/") if HF_ENDPOINT else ""
    if endpoint:
        hf_prefixes.append(endpoint)

    # Check the URL itself rather than merely whether HF_ENDPOINT is set, so the
    # token is never sent to non-Hugging Face hosts.
    if any(_has_url_prefix(url, prefix) for prefix in hf_prefixes):
        headers["Authorization"] = f"Bearer {token}"
    return headers
@router.api_route("", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(request: Request, url: str):
    """
    Proxy a request to an allow-listed model hub.

    The method and ``url`` are validated first. The URL is then passed through
    ``replace_hf_endpoint``, and ``hf_token_process`` (bound to the final URL)
    is used as the header hook when forwarding via ``proxy_to``.
    """
    validate_http_method(request.method)
    validate_url(url)
    target = replace_hf_endpoint(url)
    inject_token = partial(hf_token_process, target)
    return await proxy_to(request, target, header_func=inject_token)
async def proxy_to(
    request: Request, url: str, header_func: Optional[Callable[[dict], dict]] = None
):
    """
    Forward ``request`` to ``url`` and relay the upstream response.

    Args:
        request: Incoming request. Its method and processed headers are
            forwarded; its body is forwarded for POST, PUT and DELETE.
        url: Fully-qualified upstream URL.
        header_func: Optional hook that receives the processed outgoing
            headers and returns the headers to actually send (e.g. to inject
            credentials).

    Returns:
        A ``Response`` carrying the upstream status, body and headers (minus
        those in ``HEADER_SKIPPED``), or a 500 ``JSONResponse`` with a
        ``detail`` message if anything in the forwarding step raises.
    """
    forwarded_headers = process_headers(request.headers)
    if header_func is not None:
        forwarded_headers = header_func(forwarded_headers)

    try:
        data = (
            await request.body()
            if request.method in ["POST", "PUT", "DELETE"]
            else None
        )
        use_proxy_env = use_proxy_env_for_url(url)
        async with aiohttp.ClientSession(
            timeout=timeout, trust_env=use_proxy_env
        ) as session:
            async with session.request(
                method=request.method,
                url=url,
                headers=forwarded_headers,
                data=data,
            ) as resp:
                content = await resp.read()
                headers = {
                    k: v
                    for k, v in resp.headers.items()
                    if k.lower() not in HEADER_SKIPPED
                }
                return Response(
                    status_code=resp.status,
                    content=content,
                    headers=headers,
                    # resp.headers is case-insensitive; the plain dict above is
                    # not, so look the content type up on the original mapping.
                    media_type=resp.headers.get("Content-Type"),
                )
    except Exception as e:
        # Request boundary: record the failure server-side before answering.
        logger.exception("Failed to proxy %s request to %s", request.method, url)
        # Some exceptions (e.g. a bare asyncio.TimeoutError) stringify to "",
        # which would leave the client with an empty detail message.
        detail = str(e) or type(e).__name__
        return JSONResponse(
            status_code=500,
            content={"detail": detail},
            media_type="application/json",
        )
def validate_http_method(method: str):
    """
    Ensure ``method`` is one of the HTTP methods this proxy forwards.

    Raises:
        BadRequestException: if ``method`` is not GET, POST, PUT or DELETE.
    """
    if method in ("GET", "POST", "PUT", "DELETE"):
        return
    raise BadRequestException(message=f"HTTP method '{method}' is not allowed")
def validate_url(url: str):
    """
    Ensure ``url`` is present, parseable and points at an allow-listed site.

    A URL is accepted only when its scheme and network location both exactly
    match those of an entry in ``ALLOWED_SITES``.

    Raises:
        BadRequestException: if ``url`` is empty, cannot be parsed, or lacks a
            scheme or network location.
        ForbiddenException: if the scheme/netloc pair is not allow-listed.
    """
    if not url:
        raise BadRequestException(message="Missing 'url' query parameter")

    try:
        target = urlparse(url)
    except Exception:
        raise BadRequestException(message="Invalid 'url' query parameter")

    if not (target.scheme and target.netloc):
        raise BadRequestException(message="Invalid 'url' query parameter")

    permitted = any(
        site.netloc == target.netloc and site.scheme == target.scheme
        for site in map(urlparse, ALLOWED_SITES)
    )
    if not permitted:
        raise ForbiddenException(message="This site is not allowed")
def replace_hf_endpoint(url: str, endpoint: Optional[str] = None) -> str:
    """
    Replace the huggingface.co origin with a mirror endpoint, if one is set.

    Args:
        url: Upstream URL requested by the client.
        endpoint: Mirror base URL. Defaults to ``HF_ENDPOINT`` (read from the
            environment at import time). When empty, ``url`` is returned as-is.

    Returns:
        ``url`` with its leading ``https://huggingface.co`` replaced by the
        endpoint (trailing slashes removed), or ``url`` unchanged when it does
        not target huggingface.co or no endpoint is configured.
    """
    if endpoint is None:
        endpoint = HF_ENDPOINT
    # Strip trailing slashes so "https://mirror/" + "/api/..." does not yield "//api".
    base = endpoint.rstrip("/") if endpoint else ""
    if not base:
        return url

    hf_origin = "https://huggingface.co"
    if not url.startswith(hf_origin):
        return url
    remainder = url[len(hf_origin) :]
    # Require a URL boundary so hosts such as "huggingface.co.example.com"
    # are not rewritten.
    if remainder and remainder[0] not in "/?#":
        return url
    return base + remainder
def process_headers(headers: dict) -> dict:
    """
    Build the headers to send upstream from the incoming request headers.

    Rules, applied per header (names compared case-insensitively):
      - names listed in ``HEADER_SKIPPED`` are dropped;
      - ``X-Forwarded-<Name>`` is forwarded as ``<Name>`` (prefix stripped,
        the remainder's original case kept);
      - ``Accept-Encoding`` is forced to ``identity``;
      - everything else is copied unchanged.

    Args:
        headers: Incoming request headers (any mapping providing ``.items()``).

    Returns:
        A new dict of headers to forward.
    """
    processed_headers = {}
    for key, value in headers.items():
        lower_key = key.lower()
        if lower_key in HEADER_SKIPPED:
            continue
        if lower_key.startswith(HEADER_FORWARDED_PREFIX):
            # NOTE(review): this also renames headers such as "X-Forwarded-For"
            # to a header literally named "For"; confirm that is intended.
            processed_headers[key[len(HEADER_FORWARDED_PREFIX) :]] = value
        elif lower_key == "accept-encoding":
            # Ask upstream for an uncompressed body. The forwarding client
            # (aiohttp) decompresses responses by default and the upstream
            # Content-Encoding header is not relayed (it is in HEADER_SKIPPED),
            # so requesting identity keeps the relayed bytes and headers consistent.
            processed_headers[key] = "identity"
        else:
            processed_headers[key] = value
    return processed_headers
|