""" This middleware can be used when a known proxy is fronting the application, and is trusted to be properly setting the `X-Forwarded-Proto` and `X-Forwarded-For` headers with the connecting client information. Modifies the `client` and `scheme` information so that they reference the connecting client, rather that the connecting proxy. https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies """ from __future__ import annotations from typing import Union, cast from uvicorn._types import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, HTTPScope, Scope, WebSocketScope class ProxyHeadersMiddleware: def __init__( self, app: ASGI3Application, trusted_hosts: list[str] | str = "127.0.0.1", ) -> None: self.app = app if isinstance(trusted_hosts, str): self.trusted_hosts = {item.strip() for item in trusted_hosts.split(",")} else: self.trusted_hosts = set(trusted_hosts) self.always_trust = "*" in self.trusted_hosts def get_trusted_client_host(self, x_forwarded_for_hosts: list[str]) -> str | None: if self.always_trust: return x_forwarded_for_hosts[0] for host in reversed(x_forwarded_for_hosts): if host not in self.trusted_hosts: return host return None async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: if scope["type"] in ("http", "websocket"): scope = cast(Union["HTTPScope", "WebSocketScope"], scope) client_addr: tuple[str, int] | None = scope.get("client") client_host = client_addr[0] if client_addr else None if self.always_trust or client_host in self.trusted_hosts: headers = dict(scope["headers"]) if b"x-forwarded-proto" in headers: # Determine if the incoming request was http or https based on # the X-Forwarded-Proto header. x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1").strip() if scope["type"] == "websocket": scope["scheme"] = x_forwarded_proto.replace("http", "ws") else: scope["scheme"] = x_forwarded_proto if b"x-forwarded-for" in headers: # Determine the client address from the last trusted IP in the # X-Forwarded-For header. We've lost the connecting client's port # information by now, so only include the host. x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1") x_forwarded_for_hosts = [item.strip() for item in x_forwarded_for.split(",")] host = self.get_trusted_client_host(x_forwarded_for_hosts) port = 0 scope["client"] = (host, port) # type: ignore[arg-type] return await self.app(scope, receive, send)