Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ full = [
"python-multipart>=0.0.18",
"pyyaml",
"httpx>=0.27.0,<0.29.0",
"backports.zstd; python_version < '3.14'",
]

[dependency-groups]
Expand All @@ -56,6 +57,7 @@ dev = [
"types-PyYAML==6.0.12.20250516",
"pytest==8.4.1",
"trio==0.30.0",
"zstandard>=0.25.0",
# Check dist
"twine==6.1.0"
]
Expand Down
142 changes: 142 additions & 0 deletions starlette/middleware/zstd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import sys
from typing import NoReturn

from starlette.datastructures import Headers, MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send

if sys.version_info >= (3, 14): # pragma: no cover
from compression.zstd import ZstdCompressor
else: # pragma: no cover
from backports.zstd import ZstdCompressor

DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)


class ZstdMiddleware:
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 3) -> None:
self.app = app
self.minimum_size = minimum_size
self.compresslevel = compresslevel

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http": # pragma: no cover
await self.app(scope, receive, send)
return

headers = Headers(scope=scope)
responder: ASGIApp
if "zstd" in headers.get("Accept-Encoding", ""):
responder = ZstdResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
else:
responder = IdentityResponder(self.app, self.minimum_size)

await responder(scope, receive, send)


class IdentityResponder:
content_encoding: str

def __init__(self, app: ASGIApp, minimum_size: int) -> None:
self.app = app
self.minimum_size = minimum_size
self.send: Send = unattached_send
self.initial_message: Message = {}
self.started = False
self.content_encoding_set = False
self.content_type_is_excluded = False

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.send = send
await self.app(scope, receive, self.send_with_compression)

async def send_with_compression(self, message: Message) -> None:
message_type = message["type"]
if message_type == "http.response.start":
# Don't send the initial message until we've determined how to
# modify the outgoing headers correctly.
self.initial_message = message
headers = Headers(raw=self.initial_message["headers"])
self.content_encoding_set = "content-encoding" in headers
self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
if not self.started:
self.started = True
await self.send(self.initial_message)
await self.send(message)
elif message_type == "http.response.body" and not self.started:
self.started = True
body = message.get("body", b"")
more_body = message.get("more_body", False)
if len(body) < self.minimum_size and not more_body:
# Don't apply compression to small outgoing responses.
await self.send(self.initial_message)
await self.send(message)
elif not more_body:
# Standard response.
body = self.apply_compression(body, more_body=False)

headers = MutableHeaders(raw=self.initial_message["headers"])
headers.add_vary_header("Accept-Encoding")
if body != message["body"]:
headers["Content-Encoding"] = self.content_encoding
headers["Content-Length"] = str(len(body))
message["body"] = body

await self.send(self.initial_message)
await self.send(message)
else:
# Initial body in streaming response.
body = self.apply_compression(body, more_body=True)

headers = MutableHeaders(raw=self.initial_message["headers"])
headers.add_vary_header("Accept-Encoding")
if body != message["body"]:
headers["Content-Encoding"] = self.content_encoding
del headers["Content-Length"]
message["body"] = body

await self.send(self.initial_message)
await self.send(message)
elif message_type == "http.response.body":
# Remaining body in streaming response.
body = message.get("body", b"")
more_body = message.get("more_body", False)

message["body"] = self.apply_compression(body, more_body=more_body)

await self.send(message)
elif message_type == "http.response.pathsend": # pragma: no branch
# Don't apply Zstd to pathsend responses
await self.send(self.initial_message)
await self.send(message)

def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
"""Apply compression on the response body.

If more_body is False, any compression file should be closed. If it
isn't, it won't be closed automatically until all background tasks
complete.
"""
return body


class ZstdResponder(IdentityResponder):
content_encoding = "zstd"

def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 3) -> None:
super().__init__(app, minimum_size)

self.compressor = ZstdCompressor(level=compresslevel)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await super().__call__(scope, receive, send)

def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
if more_body:
return self.compressor.compress(body, mode=ZstdCompressor.CONTINUE)
else:
return self.compressor.compress(body, mode=ZstdCompressor.FLUSH_FRAME)


async def unattached_send(message: Message) -> NoReturn:
raise RuntimeError("send awaitable not set") # pragma: no cover
202 changes: 202 additions & 0 deletions tests/middleware/test_zstd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
from __future__ import annotations

from pathlib import Path

import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.zstd import ZstdMiddleware
from starlette.requests import Request
from starlette.responses import ContentStream, FileResponse, PlainTextResponse, StreamingResponse
from starlette.routing import Route
from starlette.types import Message
from tests.types import TestClientFactory


def test_zstd_responses(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("x" * 4000, status_code=200)

app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(ZstdMiddleware)],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "zstd"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "zstd"
assert response.headers["Vary"] == "Accept-Encoding"
assert int(response.headers["Content-Length"]) < 4000


def test_zstd_not_in_accept_encoding(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("x" * 4000, status_code=200)

app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(ZstdMiddleware)],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "identity"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert "Content-Encoding" not in response.headers
assert response.headers["Vary"] == "Accept-Encoding"
assert int(response.headers["Content-Length"]) == 4000


def test_zstd_ignored_for_small_responses(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("OK", status_code=200)

app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(ZstdMiddleware)],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "zstd"})
assert response.status_code == 200
assert response.text == "OK"
assert "Content-Encoding" not in response.headers
assert "Vary" not in response.headers
assert int(response.headers["Content-Length"]) == 2


def test_zstd_streaming_response(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> StreamingResponse:
async def generator(bytes: bytes, count: int) -> ContentStream:
for index in range(count):
yield bytes

streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200)

app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(ZstdMiddleware)],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "zstd"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "zstd"
assert response.headers["Vary"] == "Accept-Encoding"
assert "Content-Length" not in response.headers


def test_zstd_streaming_response_identity(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> StreamingResponse:
async def generator(bytes: bytes, count: int) -> ContentStream:
for index in range(count):
yield bytes

streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200)

app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(ZstdMiddleware)],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "identity"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert "Content-Encoding" not in response.headers
assert response.headers["Vary"] == "Accept-Encoding"
assert "Content-Length" not in response.headers


def test_zstd_ignored_for_responses_with_encoding_set(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> StreamingResponse:
async def generator(bytes: bytes, count: int) -> ContentStream:
for index in range(count):
yield bytes

streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200, headers={"Content-Encoding": "text"})

app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(ZstdMiddleware)],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "zstd, text"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "text"
assert "Vary" not in response.headers
assert "Content-Length" not in response.headers


def test_zstd_ignored_on_server_sent_events(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> StreamingResponse:
async def generator(bytes: bytes, count: int) -> ContentStream:
for _ in range(count):
yield bytes

streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200, media_type="text/event-stream")

app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(ZstdMiddleware)],
)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "zstd"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert "Content-Encoding" not in response.headers
assert "Content-Length" not in response.headers


@pytest.mark.anyio
async def test_zstd_ignored_for_pathsend_responses(tmpdir: Path) -> None:
path = tmpdir / "example.txt"
with path.open("w") as file:
file.write("<file content>")

events: list[Message] = []

async def endpoint_with_pathsend(request: Request) -> FileResponse:
_ = await request.body()
return FileResponse(path)

app = Starlette(
routes=[Route("/", endpoint=endpoint_with_pathsend)],
middleware=[Middleware(ZstdMiddleware)],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
"headers": [(b"accept-encoding", b"zstd, text")],
"extensions": {"http.response.pathsend": {}},
}

async def receive() -> Message:
return {"type": "http.request", "body": b"", "more_body": False}

async def send(message: Message) -> None:
events.append(message)

await app(scope, receive, send)

assert len(events) == 2
assert events[0]["type"] == "http.response.start"
assert events[1]["type"] == "http.response.pathsend"
Loading
Loading