Skip to content

Commit 1a92d45

Browse files
kepleremontnemery
andauthored
Fix playing TTS and local media source over DLNA (home-assistant#134903)
Co-authored-by: Erik Montnemery <[email protected]>
1 parent 7b80c1c commit 1a92d45

File tree

8 files changed

+134
-9
lines changed

8 files changed

+134
-9
lines changed

homeassistant/components/http/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ async def auth_middleware(
223223
# We first start with a string check to avoid parsing query params
224224
# for every request.
225225
elif (
226-
request.method == "GET"
226+
request.method in ["GET", "HEAD"]
227227
and SIGN_QUERY_PARAM in request.query_string
228228
and async_validate_signed_request(request)
229229
):

homeassistant/components/image/__init__.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,10 @@ def __init__(self, component: EntityComponent[ImageEntity]) -> None:
288288
"""Initialize an image view."""
289289
self.component = component
290290

291-
async def get(self, request: web.Request, entity_id: str) -> web.StreamResponse:
292-
"""Start a GET request."""
291+
async def _authenticate_request(
292+
self, request: web.Request, entity_id: str
293+
) -> ImageEntity:
294+
"""Authenticate request and return image entity."""
293295
if (image_entity := self.component.get_entity(entity_id)) is None:
294296
raise web.HTTPNotFound
295297

@@ -306,6 +308,31 @@ async def get(self, request: web.Request, entity_id: str) -> web.StreamResponse:
306308
# Invalid sigAuth or image entity access token
307309
raise web.HTTPForbidden
308310

311+
return image_entity
312+
313+
async def head(self, request: web.Request, entity_id: str) -> web.Response:
314+
"""Start a HEAD request.
315+
316+
This is sent by some DLNA renderers, like Samsung ones, prior to sending
317+
the GET request.
318+
"""
319+
image_entity = await self._authenticate_request(request, entity_id)
320+
321+
# Don't use `handle` as we don't care about the stream case, we only want
322+
# to verify that the image exists.
323+
try:
324+
image = await _async_get_image(image_entity, IMAGE_TIMEOUT)
325+
except (HomeAssistantError, ValueError) as ex:
326+
raise web.HTTPInternalServerError from ex
327+
328+
return web.Response(
329+
content_type=image.content_type,
330+
headers={"Content-Length": str(len(image.content))},
331+
)
332+
333+
async def get(self, request: web.Request, entity_id: str) -> web.StreamResponse:
334+
"""Start a GET request."""
335+
image_entity = await self._authenticate_request(request, entity_id)
309336
return await self.handle(request, image_entity)
310337

311338
async def handle(
@@ -317,7 +344,11 @@ async def handle(
317344
except (HomeAssistantError, ValueError) as ex:
318345
raise web.HTTPInternalServerError from ex
319346

320-
return web.Response(body=image.content, content_type=image.content_type)
347+
return web.Response(
348+
body=image.content,
349+
content_type=image.content_type,
350+
headers={"Content-Length": str(len(image.content))},
351+
)
321352

322353

323354
async def async_get_still_stream(

homeassistant/components/media_source/local_source.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,8 @@ def __init__(self, hass: HomeAssistant, source: LocalSource) -> None:
210210
self.hass = hass
211211
self.source = source
212212

213-
async def get(
214-
self, request: web.Request, source_dir_id: str, location: str
215-
) -> web.FileResponse:
216-
"""Start a GET request."""
213+
async def _validate_media_path(self, source_dir_id: str, location: str) -> Path:
214+
"""Validate media path and return it if valid."""
217215
try:
218216
raise_if_invalid_path(location)
219217
except ValueError as err:
@@ -233,6 +231,25 @@ async def get(
233231
if not mime_type or mime_type.split("/")[0] not in MEDIA_MIME_TYPES:
234232
raise web.HTTPNotFound
235233

234+
return media_path
235+
236+
async def head(
237+
self, request: web.Request, source_dir_id: str, location: str
238+
) -> None:
239+
"""Handle a HEAD request.
240+
241+
This is sent by some DLNA renderers, like Samsung ones, prior to sending
242+
the GET request.
243+
244+
Check whether the location exists or not.
245+
"""
246+
await self._validate_media_path(source_dir_id, location)
247+
248+
async def get(
249+
self, request: web.Request, source_dir_id: str, location: str
250+
) -> web.FileResponse:
251+
"""Handle a GET request."""
252+
media_path = await self._validate_media_path(source_dir_id, location)
236253
return web.FileResponse(media_path)
237254

238255

homeassistant/components/tts/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,21 @@ def __init__(self, manager: SpeechManager) -> None:
11851185
"""Initialize a tts view."""
11861186
self.manager = manager
11871187

1188+
async def head(self, request: web.Request, token: str) -> web.StreamResponse:
1189+
"""Start a HEAD request.
1190+
1191+
This is sent by some DLNA renderers, like Samsung ones, prior to sending
1192+
the GET request.
1193+
1194+
Check whether the token (file) exists and return its content type.
1195+
"""
1196+
stream = self.manager.token_to_stream.get(token)
1197+
1198+
if stream is None:
1199+
return web.Response(status=HTTPStatus.NOT_FOUND)
1200+
1201+
return web.Response(content_type=stream.content_type)
1202+
11881203
async def get(self, request: web.Request, token: str) -> web.StreamResponse:
11891204
"""Start a get request."""
11901205
stream = self.manager.token_to_stream.get(token)

tests/components/http/test_auth.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,16 +305,22 @@ async def test_auth_access_signed_path_with_refresh_token(
305305
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
306306
)
307307

308+
req = await client.head(signed_path)
309+
assert req.status == HTTPStatus.OK
310+
308311
req = await client.get(signed_path)
309312
assert req.status == HTTPStatus.OK
310313
data = await req.json()
311314
assert data["user_id"] == refresh_token.user.id
312315

313316
# Use signature on other path
317+
req = await client.head(f"/another_path?{signed_path.split('?')[1]}")
318+
assert req.status == HTTPStatus.UNAUTHORIZED
319+
314320
req = await client.get(f"/another_path?{signed_path.split('?')[1]}")
315321
assert req.status == HTTPStatus.UNAUTHORIZED
316322

317-
# We only allow GET
323+
# We only allow GET and HEAD
318324
req = await client.post(signed_path)
319325
assert req.status == HTTPStatus.UNAUTHORIZED
320326

tests/components/image/test_init.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,22 @@ async def test_fetch_image_authenticated(
174174
"""Test fetching an image with an authenticated client."""
175175
client = await hass_client()
176176

177+
# Using HEAD
178+
resp = await client.head("/api/image_proxy/image.test")
179+
assert resp.status == HTTPStatus.OK
180+
assert resp.content_type == "image/jpeg"
181+
assert resp.content_length == 4
182+
183+
resp = await client.head("/api/image_proxy/image.unknown")
184+
assert resp.status == HTTPStatus.NOT_FOUND
185+
186+
# Using GET
177187
resp = await client.get("/api/image_proxy/image.test")
178188
assert resp.status == HTTPStatus.OK
179189
body = await resp.read()
180190
assert body == b"Test"
191+
assert resp.content_type == "image/jpeg"
192+
assert resp.content_length == 4
181193

182194
resp = await client.get("/api/image_proxy/image.unknown")
183195
assert resp.status == HTTPStatus.NOT_FOUND
@@ -260,10 +272,19 @@ async def test_fetch_image_url_success(
260272

261273
client = await hass_client()
262274

275+
# Using HEAD
276+
resp = await client.head("/api/image_proxy/image.test")
277+
assert resp.status == HTTPStatus.OK
278+
assert resp.content_type == "image/png"
279+
assert resp.content_length == 4
280+
281+
# Using GET
263282
resp = await client.get("/api/image_proxy/image.test")
264283
assert resp.status == HTTPStatus.OK
265284
body = await resp.read()
266285
assert body == b"Test"
286+
assert resp.content_type == "image/png"
287+
assert resp.content_length == 4
267288

268289

269290
@respx.mock

tests/components/media_source/test_local_source.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,21 +105,33 @@ async def test_media_view(
105105
client = await hass_client()
106106

107107
# Protects against non-existent files
108+
resp = await client.head("/media/local/invalid.txt")
109+
assert resp.status == HTTPStatus.NOT_FOUND
110+
108111
resp = await client.get("/media/local/invalid.txt")
109112
assert resp.status == HTTPStatus.NOT_FOUND
110113

111114
resp = await client.get("/media/recordings/invalid.txt")
112115
assert resp.status == HTTPStatus.NOT_FOUND
113116

114117
# Protects against non-media files
118+
resp = await client.head("/media/local/not_media.txt")
119+
assert resp.status == HTTPStatus.NOT_FOUND
120+
115121
resp = await client.get("/media/local/not_media.txt")
116122
assert resp.status == HTTPStatus.NOT_FOUND
117123

118124
# Protects against unknown local media sources
125+
resp = await client.head("/media/unknown_source/not_media.txt")
126+
assert resp.status == HTTPStatus.NOT_FOUND
127+
119128
resp = await client.get("/media/unknown_source/not_media.txt")
120129
assert resp.status == HTTPStatus.NOT_FOUND
121130

122131
# Fetch available media
132+
resp = await client.head("/media/local/test.mp3")
133+
assert resp.status == HTTPStatus.OK
134+
123135
resp = await client.get("/media/local/test.mp3")
124136
assert resp.status == HTTPStatus.OK
125137

tests/components/tts/test_init.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,29 @@ async def test_web_view_wrong_file(
916916
assert req.status == HTTPStatus.NOT_FOUND
917917

918918

919+
@pytest.mark.parametrize(
920+
("setup", "expected_url_suffix"),
921+
[("mock_setup", "test"), ("mock_config_entry_setup", "tts.test")],
922+
indirect=["setup"],
923+
)
924+
async def test_web_view_wrong_file_with_head_request(
925+
hass: HomeAssistant,
926+
hass_client: ClientSessionGenerator,
927+
setup: str,
928+
expected_url_suffix: str,
929+
) -> None:
930+
"""Set up a TTS platform and receive wrong file from web."""
931+
client = await hass_client()
932+
933+
url = (
934+
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
935+
f"_en-us_-_{expected_url_suffix}.mp3"
936+
)
937+
938+
req = await client.head(url)
939+
assert req.status == HTTPStatus.NOT_FOUND
940+
941+
919942
@pytest.mark.parametrize(
920943
("setup", "expected_url_suffix"),
921944
[("mock_setup", "test"), ("mock_config_entry_setup", "tts.test")],

0 commit comments

Comments
 (0)