Skip to content

Commit 63b8950

Browse files
committed
Support returning Not Modified responses in FileResponse
1 parent fa53554 commit 63b8950

File tree

2 files changed

+55
-64
lines changed

2 files changed

+55
-64
lines changed

starlette/responses.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import warnings
1111
from collections.abc import AsyncIterable, Awaitable, Iterable, Mapping, Sequence
1212
from datetime import datetime
13-
from email.utils import format_datetime, formatdate
13+
from email.utils import format_datetime, formatdate, parsedate
1414
from functools import partial
1515
from mimetypes import guess_type
1616
from secrets import token_hex
@@ -297,6 +297,15 @@ def __init__(self, max_size: int) -> None:
297297
class FileResponse(Response):
298298
chunk_size = 64 * 1024
299299

300+
NOT_MODIFIED_HEADERS = {
301+
b"cache-control",
302+
b"content-location",
303+
b"date",
304+
b"etag",
305+
b"expires",
306+
b"vary",
307+
}
308+
300309
def __init__(
301310
self,
302311
path: str | os.PathLike[str],
@@ -362,12 +371,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
362371
stat_result = self.stat_result
363372

364373
headers = Headers(scope=scope)
374+
http_if_none_match = headers.get("if-none-match")
375+
http_if_modified_since = headers.get("if-modified-since")
365376
http_range = headers.get("range")
366377
http_if_range = headers.get("if-range")
367378

368-
if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
369-
await self._handle_simple(send, send_header_only, send_pathsend)
370-
else:
379+
if self.status_code == 200 and self._is_not_modified(http_if_none_match, http_if_modified_since):
380+
await self._handle_not_modified(send)
381+
elif self.status_code == 200 and http_range is not None and self._should_use_range(http_if_range):
371382
try:
372383
ranges = self._parse_range_header(http_range, stat_result.st_size)
373384
except MalformedRangeHeader as exc:
@@ -381,6 +392,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
381392
await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
382393
else:
383394
await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)
395+
else:
396+
await self._handle_simple(send, send_header_only, send_pathsend)
384397

385398
if self.background is not None:
386399
await self.background()
@@ -399,6 +412,11 @@ async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend
399412
more_body = len(chunk) == self.chunk_size
400413
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
401414

415+
async def _handle_not_modified(self, send: Send) -> None:
416+
headers = [(k, v) for k, v in self.raw_headers if k in FileResponse.NOT_MODIFIED_HEADERS]
417+
await send({"type": "http.response.start", "status": 304, "headers": headers})
418+
await send({"type": "http.response.body", "body": b"", "more_body": False})
419+
402420
async def _handle_single_range(
403421
self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
404422
) -> None:
@@ -452,8 +470,36 @@ async def _handle_multiple_ranges(
452470
}
453471
)
454472

455-
def _should_use_range(self, http_if_range: str) -> bool:
456-
return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]
473+
def _is_not_modified(self, http_if_none_match: str | None, http_if_modified_since: str | None) -> bool:
474+
"""
475+
Given the request and response headers, return `True` if an HTTP
476+
"Not Modified" response could be returned instead.
477+
"""
478+
if http_if_none_match is not None:
479+
try:
480+
match = [tag.strip(" W/") for tag in http_if_none_match.split(",")]
481+
etag = self.headers["etag"]
482+
return etag in match # Client already has the version with current tag
483+
except KeyError:
484+
pass
485+
486+
if http_if_modified_since:
487+
try:
488+
since = parsedate(http_if_modified_since)
489+
last_modified = parsedate(self.headers["last-modified"])
490+
if since is not None and last_modified is not None:
491+
return since >= last_modified
492+
except KeyError:
493+
pass
494+
495+
return False
496+
497+
def _should_use_range(self, http_if_range: str | None) -> bool:
498+
return http_if_range in (
499+
None,
500+
self.headers["last-modified"],
501+
self.headers["etag"],
502+
)
457503

458504
@staticmethod
459505
def _parse_range_header(http_range: str, file_size: int) -> list[tuple[int, int]]:

starlette/staticfiles.py

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,20 @@
44
import importlib.util
55
import os
66
import stat
7-
from email.utils import parsedate
87
from typing import Union
98

109
import anyio
1110
import anyio.to_thread
1211

1312
from starlette._utils import get_route_path
14-
from starlette.datastructures import URL, Headers
13+
from starlette.datastructures import URL
1514
from starlette.exceptions import HTTPException
1615
from starlette.responses import FileResponse, RedirectResponse, Response
1716
from starlette.types import Receive, Scope, Send
1817

1918
PathLike = Union[str, "os.PathLike[str]"]
2019

2120

22-
class NotModifiedResponse(Response):
23-
NOT_MODIFIED_HEADERS = (
24-
"cache-control",
25-
"content-location",
26-
"date",
27-
"etag",
28-
"expires",
29-
"vary",
30-
)
31-
32-
def __init__(self, headers: Headers):
33-
super().__init__(
34-
status_code=304,
35-
headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
36-
)
37-
38-
3921
class StaticFiles:
4022
def __init__(
4123
self,
@@ -126,7 +108,7 @@ async def get_response(self, path: str, scope: Scope) -> Response:
126108

127109
if stat_result and stat.S_ISREG(stat_result.st_mode):
128110
# We have a static file to serve.
129-
return self.file_response(full_path, stat_result, scope)
111+
return FileResponse(full_path, stat_result=stat_result)
130112

131113
elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
132114
# We're in HTML mode, and have got a directory URL.
@@ -139,7 +121,7 @@ async def get_response(self, path: str, scope: Scope) -> Response:
139121
url = URL(scope=scope)
140122
url = url.replace(path=url.path + "/")
141123
return RedirectResponse(url=url)
142-
return self.file_response(full_path, stat_result, scope)
124+
return FileResponse(full_path, stat_result=stat_result)
143125

144126
if self.html:
145127
# Check for '404.html' if we're in HTML mode.
@@ -166,20 +148,6 @@ def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
166148
continue
167149
return "", None
168150

169-
def file_response(
170-
self,
171-
full_path: PathLike,
172-
stat_result: os.stat_result,
173-
scope: Scope,
174-
status_code: int = 200,
175-
) -> Response:
176-
request_headers = Headers(scope=scope)
177-
178-
response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
179-
if self.is_not_modified(response.headers, request_headers):
180-
return NotModifiedResponse(response.headers)
181-
return response
182-
183151
async def check_config(self) -> None:
184152
"""
185153
Perform a one-off configuration check that StaticFiles is actually
@@ -195,26 +163,3 @@ async def check_config(self) -> None:
195163
raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
196164
if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
197165
raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
198-
199-
def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
200-
"""
201-
Given the request and response headers, return `True` if an HTTP
202-
"Not Modified" response could be returned instead.
203-
"""
204-
try:
205-
if_none_match = request_headers["if-none-match"]
206-
etag = response_headers["etag"]
207-
if etag in [tag.strip(" W/") for tag in if_none_match.split(",")]:
208-
return True
209-
except KeyError:
210-
pass
211-
212-
try:
213-
if_modified_since = parsedate(request_headers["if-modified-since"])
214-
last_modified = parsedate(response_headers["last-modified"])
215-
if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
216-
return True
217-
except KeyError:
218-
pass
219-
220-
return False

0 commit comments

Comments
 (0)