Skip to content

Commit 4b9caad

Browse files
committedFeb 4, 2025·
Add a router based on werkzeug.routing.
Fix #311.
1 parent 602d719 commit 4b9caad

File tree

19 files changed

+879
-18
lines changed

19 files changed

+879
-18
lines changed
 

‎docs/conf.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,10 @@
8282
assert PythonDomain.object_types["data"].roles == ("data", "obj")
8383
PythonDomain.object_types["data"].roles = ("data", "class", "obj")
8484

85-
intersphinx_mapping = {"python": ("https://docs.python.org/3", None)}
85+
intersphinx_mapping = {
86+
"python": ("https://docs.python.org/3", None),
87+
"werkzeug": ("https://werkzeug.palletsprojects.com/en/stable/", None),
88+
}
8689

8790
spelling_show_suggestions = True
8891

‎docs/faq/client.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ The connection is closed when exiting the context manager.
8181
How do I reconnect when the connection drops?
8282
---------------------------------------------
8383

84-
Use :func:`~websockets.asyncio.client.connect` as an asynchronous iterator::
84+
Use :func:`connect` as an asynchronous iterator::
8585

8686
from websockets.asyncio.client import connect
8787
from websockets.exceptions import ConnectionClosed

‎docs/faq/server.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ Record all connections in a global variable::
116116
finally:
117117
CONNECTIONS.remove(websocket)
118118

119-
Then, call :func:`~websockets.asyncio.server.broadcast`::
119+
Then, call :func:`broadcast`::
120120

121121
from websockets.asyncio.server import broadcast
122122

@@ -219,6 +219,8 @@ You may route a connection to different handlers depending on the request path::
219219
# No handler for this path; close the connection.
220220
return
221221

222+
For more complex routing, you may use :func:`~websockets.asyncio.router.route`.
223+
222224
You may also route the connection based on the first message received from the
223225
client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to
224226
authenticate the connection before routing it, this is usually more convenient.

‎docs/project/changelog.rst

+6
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ Backwards-incompatible changes
5656

5757
See :doc:`keepalive and latency <../topics/keepalive>` for details.
5858

59+
New features
60+
............
61+
62+
* Added :func:`~asyncio.router.route` and :func:`~asyncio.router.unix_route` to
63+
dispatch connections to different handlers depending on the URL.
64+
5965
Improvements
6066
............
6167

‎docs/reference/asyncio/server.rst

+17-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,21 @@ Creating a server
1212
.. autofunction:: unix_serve
1313
:async:
1414

15+
Routing connections
16+
-------------------
17+
18+
.. automodule:: websockets.asyncio.router
19+
20+
.. autofunction:: route
21+
:async:
22+
23+
.. autofunction:: unix_route
24+
:async:
25+
26+
.. autoclass:: Router
27+
28+
.. currentmodule:: websockets.asyncio.server
29+
1530
Running a server
1631
----------------
1732

@@ -89,12 +104,12 @@ Using a connection
89104
Broadcast
90105
---------
91106

92-
.. autofunction:: websockets.asyncio.server.broadcast
107+
.. autofunction:: broadcast
93108

94109
HTTP Basic Authentication
95110
-------------------------
96111

97112
websockets supports HTTP Basic Authentication according to
98113
:rfc:`7235` and :rfc:`7617`.
99114

100-
.. autofunction:: websockets.asyncio.server.basic_auth
115+
.. autofunction:: basic_auth

‎docs/reference/features.rst

+2
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ Server
127127
+------------------------------------+--------+--------+--------+--------+
128128
| Perform HTTP Digest Authentication |||||
129129
+------------------------------------+--------+--------+--------+--------+
130+
| Dispatch connections to handlers |||||
131+
+------------------------------------+--------+--------+--------+--------+
130132

131133
Client
132134
------

‎docs/reference/sync/server.rst

+26-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,31 @@ Creating a server
1010

1111
.. autofunction:: unix_serve
1212

13+
Routing connections
14+
-------------------
15+
16+
.. automodule:: websockets.sync.router
17+
18+
.. autofunction:: route
19+
20+
.. autofunction:: unix_route
21+
22+
.. autoclass:: Router
23+
24+
.. currentmodule:: websockets.sync.server
25+
26+
Routing connections
27+
-------------------
28+
29+
.. autofunction:: route
30+
:async:
31+
32+
.. autofunction:: unix_route
33+
:async:
34+
35+
.. autoclass:: Server
36+
37+
1338
Running a server
1439
----------------
1540

@@ -78,4 +103,4 @@ HTTP Basic Authentication
78103
websockets supports HTTP Basic Authentication according to
79104
:rfc:`7235` and :rfc:`7617`.
80105

81-
.. autofunction:: websockets.sync.server.basic_auth
106+
.. autofunction:: basic_auth

‎docs/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ sphinx-inline-tabs
66
sphinxcontrib-spelling
77
sphinxcontrib-trio
88
sphinxext-opengraph
9+
werkzeug

‎src/websockets/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
"connect",
1313
"unix_connect",
1414
"ClientConnection",
15+
# .asyncio.router
16+
"route",
17+
"unix_route",
18+
"Router",
1519
# .asyncio.server
1620
"basic_auth",
1721
"broadcast",
@@ -79,6 +83,7 @@
7983
# When type checking, import non-deprecated aliases eagerly. Else, import on demand.
8084
if TYPE_CHECKING:
8185
from .asyncio.client import ClientConnection, connect, unix_connect
86+
from .asyncio.router import Router, route, unix_route
8287
from .asyncio.server import (
8388
Server,
8489
ServerConnection,
@@ -138,6 +143,10 @@
138143
"connect": ".asyncio.client",
139144
"unix_connect": ".asyncio.client",
140145
"ClientConnection": ".asyncio.client",
146+
# .asyncio.router
147+
"route": ".asyncio.router",
148+
"unix_route": ".asyncio.router",
149+
"Router": ".asyncio.router",
141150
# .asyncio.server
142151
"basic_auth": ".asyncio.server",
143152
"broadcast": ".asyncio.server",

‎src/websockets/asyncio/router.py

+196
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
from __future__ import annotations
2+
3+
import http
4+
import ssl as ssl_module
5+
import urllib.parse
6+
from typing import Any, Awaitable, Callable, Literal
7+
8+
from werkzeug.exceptions import NotFound
9+
from werkzeug.routing import Map, RequestRedirect
10+
11+
from ..http11 import Request, Response
12+
from .server import Server, ServerConnection, serve
13+
14+
15+
__all__ = ["route", "unix_route", "Router"]
16+
17+
18+
class Router:
19+
"""WebSocket router supporting :func:`route`."""
20+
21+
def __init__(
22+
self,
23+
url_map: Map,
24+
server_name: str | None = None,
25+
url_scheme: str = "ws",
26+
) -> None:
27+
self.url_map = url_map
28+
self.server_name = server_name
29+
self.url_scheme = url_scheme
30+
for rule in self.url_map.iter_rules():
31+
rule.websocket = True
32+
33+
def get_server_name(self, connection: ServerConnection, request: Request) -> str:
34+
if self.server_name is None:
35+
return request.headers["Host"]
36+
else:
37+
return self.server_name
38+
39+
def redirect(self, connection: ServerConnection, url: str) -> Response:
40+
response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}")
41+
response.headers["Location"] = url
42+
return response
43+
44+
def not_found(self, connection: ServerConnection) -> Response:
45+
return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found")
46+
47+
def route_request(
48+
self, connection: ServerConnection, request: Request
49+
) -> Response | None:
50+
"""Route incoming request."""
51+
url_map_adapter = self.url_map.bind(
52+
server_name=self.get_server_name(connection, request),
53+
url_scheme=self.url_scheme,
54+
)
55+
try:
56+
parsed = urllib.parse.urlparse(request.path)
57+
handler, kwargs = url_map_adapter.match(
58+
path_info=parsed.path,
59+
query_args=parsed.query,
60+
)
61+
except RequestRedirect as redirect:
62+
return self.redirect(connection, redirect.new_url)
63+
except NotFound:
64+
return self.not_found(connection)
65+
connection.handler, connection.handler_kwargs = handler, kwargs
66+
return None
67+
68+
async def handler(self, connection: ServerConnection) -> None:
69+
"""Handle a connection."""
70+
return await connection.handler(connection, **connection.handler_kwargs)
71+
72+
73+
def route(
74+
url_map: Map,
75+
*args: Any,
76+
server_name: str | None = None,
77+
ssl: ssl_module.SSLContext | Literal[True] | None = None,
78+
create_router: type[Router] | None = None,
79+
**kwargs: Any,
80+
) -> Awaitable[Server]:
81+
"""
82+
Create a WebSocket server dispatching connections to different handlers.
83+
84+
This feature requires the third-party library `werkzeug`_::
85+
86+
$ pip install werkzeug
87+
88+
.. _werkzeug: https://werkzeug.palletsprojects.com/
89+
90+
:func:`route` accepts the same arguments as
91+
:func:`~websockets.sync.server.serve`, except as described below.
92+
93+
The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns
94+
to connection handlers. In addition to the connection, handlers receive
95+
parameters captured in the URL as keyword arguments.
96+
97+
Here's an example::
98+
99+
100+
from websockets.asyncio.router import route
101+
from werkzeug.routing import Map, Rule
102+
103+
async def channel_handler(websocket, channel_id):
104+
...
105+
106+
url_map = Map([
107+
Rule("/channel/<uuid:channel_id>", endpoint=channel_handler),
108+
...
109+
])
110+
111+
# set this future to exit the server
112+
stop = asyncio.get_running_loop().create_future()
113+
114+
async with route(url_map, ...) as server:
115+
await stop
116+
117+
118+
Refer to the documentation of :mod:`werkzeug.routing` for details.
119+
120+
If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map,
121+
when the server runs behind a reverse proxy that modifies the ``Host``
122+
header or terminates TLS, you need additional configuration:
123+
124+
* Set ``server_name`` to the name of the server as seen by clients. When not
125+
provided, websockets uses the value of the ``Host`` header.
126+
127+
* Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling
128+
TLS. Under the hood, this bind the URL map with a ``url_scheme`` of
129+
``wss://`` instead of ``ws://``.
130+
131+
There is no need to specify ``websocket=True`` in each rule. It is added
132+
automatically.
133+
134+
Args:
135+
url_map: Mapping of URL patterns to connection handlers.
136+
server_name: Name of the server as seen by clients. If :obj:`None`,
137+
websockets uses the value of the ``Host`` header.
138+
ssl: Configuration for enabling TLS on the connection. Set it to
139+
:obj:`True` if a reverse proxy terminates TLS connections.
140+
create_router: Factory for the :class:`Router` dispatching requests to
141+
handlers. Set it to a wrapper or a subclass to customize routing.
142+
143+
"""
144+
url_scheme = "ws" if ssl is None else "wss"
145+
if ssl is not True and ssl is not None:
146+
kwargs["ssl"] = ssl
147+
148+
if create_router is None:
149+
create_router = Router
150+
151+
router = create_router(url_map, server_name, url_scheme)
152+
153+
_process_request: (
154+
Callable[
155+
[ServerConnection, Request],
156+
Awaitable[Response | None] | Response | None,
157+
]
158+
| None
159+
) = kwargs.pop("process_request", None)
160+
if _process_request is None:
161+
process_request: Callable[
162+
[ServerConnection, Request],
163+
Awaitable[Response | None] | Response | None,
164+
] = router.route_request
165+
else:
166+
167+
async def process_request(
168+
connection: ServerConnection, request: Request
169+
) -> Response | None:
170+
response = _process_request(connection, request)
171+
if isinstance(response, Awaitable):
172+
response = await response
173+
if response is not None:
174+
return response
175+
return router.route_request(connection, request)
176+
177+
return serve(router.handler, *args, process_request=process_request, **kwargs)
178+
179+
180+
def unix_route(
181+
url_map: Map,
182+
path: str | None = None,
183+
**kwargs: Any,
184+
) -> Awaitable[Server]:
185+
"""
186+
Create a WebSocket Unix server dispatching connections to different handlers.
187+
188+
:func:`unix_route` combines the behaviors of :func:`route` and
189+
:func:`~websockets.asyncio.server.unix_serve`.
190+
191+
Args:
192+
url_map: Mapping of URL patterns to connection handlers.
193+
path: File system path to the Unix socket.
194+
195+
"""
196+
return route(url_map, unix=True, path=path, **kwargs)

‎src/websockets/asyncio/server.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import sys
1010
from collections.abc import Awaitable, Generator, Iterable, Sequence
1111
from types import TracebackType
12-
from typing import Any, Callable, cast
12+
from typing import Any, Callable, Mapping, cast
1313

1414
from ..exceptions import InvalidHeader
1515
from ..extensions.base import ServerExtensionFactory
@@ -87,6 +87,8 @@ def __init__(
8787
self.server = server
8888
self.request_rcvd: asyncio.Future[None] = self.loop.create_future()
8989
self.username: str # see basic_auth()
90+
self.handler: Callable[[ServerConnection], Awaitable[None]] # see route()
91+
self.handler_kwargs: Mapping[str, Any] # see route()
9092

9193
def respond(self, status: StatusLike, text: str) -> Response:
9294
"""

‎src/websockets/sync/router.py

+190
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
from __future__ import annotations
2+
3+
import http
4+
import ssl as ssl_module
5+
import urllib.parse
6+
from typing import Any, Callable, Literal
7+
8+
from werkzeug.exceptions import NotFound
9+
from werkzeug.routing import Map, RequestRedirect
10+
11+
from ..http11 import Request, Response
12+
from .server import Server, ServerConnection, serve
13+
14+
15+
__all__ = ["route", "unix_route", "Router"]
16+
17+
18+
class Router:
19+
"""WebSocket router supporting :func:`route`."""
20+
21+
def __init__(
22+
self,
23+
url_map: Map,
24+
server_name: str | None = None,
25+
url_scheme: str = "ws",
26+
) -> None:
27+
self.url_map = url_map
28+
self.server_name = server_name
29+
self.url_scheme = url_scheme
30+
for rule in self.url_map.iter_rules():
31+
rule.websocket = True
32+
33+
def get_server_name(self, connection: ServerConnection, request: Request) -> str:
34+
if self.server_name is None:
35+
return request.headers["Host"]
36+
else:
37+
return self.server_name
38+
39+
def redirect(self, connection: ServerConnection, url: str) -> Response:
40+
response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}")
41+
response.headers["Location"] = url
42+
return response
43+
44+
def not_found(self, connection: ServerConnection) -> Response:
45+
return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found")
46+
47+
def route_request(
48+
self, connection: ServerConnection, request: Request
49+
) -> Response | None:
50+
"""Route incoming request."""
51+
url_map_adapter = self.url_map.bind(
52+
server_name=self.get_server_name(connection, request),
53+
url_scheme=self.url_scheme,
54+
)
55+
try:
56+
parsed = urllib.parse.urlparse(request.path)
57+
handler, kwargs = url_map_adapter.match(
58+
path_info=parsed.path,
59+
query_args=parsed.query,
60+
)
61+
except RequestRedirect as redirect:
62+
return self.redirect(connection, redirect.new_url)
63+
except NotFound:
64+
return self.not_found(connection)
65+
connection.handler, connection.handler_kwargs = handler, kwargs
66+
return None
67+
68+
def handler(self, connection: ServerConnection) -> None:
69+
"""Handle a connection."""
70+
return connection.handler(connection, **connection.handler_kwargs)
71+
72+
73+
def route(
74+
url_map: Map,
75+
*args: Any,
76+
server_name: str | None = None,
77+
ssl: ssl_module.SSLContext | Literal[True] | None = None,
78+
create_router: type[Router] | None = None,
79+
**kwargs: Any,
80+
) -> Server:
81+
"""
82+
Create a WebSocket server dispatching connections to different handlers.
83+
84+
This feature requires the third-party library `werkzeug`_::
85+
86+
$ pip install werkzeug
87+
88+
.. _werkzeug: https://werkzeug.palletsprojects.com/
89+
90+
:func:`route` accepts the same arguments as
91+
:func:`~websockets.sync.server.serve`, except as described below.
92+
93+
The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns
94+
to connection handlers. In addition to the connection, handlers receive
95+
parameters captured in the URL as keyword arguments.
96+
97+
Here's an example::
98+
99+
100+
from websockets.sync.router import route
101+
from werkzeug.routing import Map, Rule
102+
103+
def channel_handler(websocket, channel_id):
104+
...
105+
106+
url_map = Map([
107+
Rule("/channel/<uuid:channel_id>", endpoint=channel_handler),
108+
...
109+
])
110+
111+
with route(url_map, ...) as server:
112+
server.serve_forever()
113+
114+
Refer to the documentation of :mod:`werkzeug.routing` for details.
115+
116+
If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map,
117+
when the server runs behind a reverse proxy that modifies the ``Host``
118+
header or terminates TLS, you need additional configuration:
119+
120+
* Set ``server_name`` to the name of the server as seen by clients. When not
121+
provided, websockets uses the value of the ``Host`` header.
122+
123+
* Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling
124+
TLS. Under the hood, this bind the URL map with a ``url_scheme`` of
125+
``wss://`` instead of ``ws://``.
126+
127+
There is no need to specify ``websocket=True`` in each rule. It is added
128+
automatically.
129+
130+
Args:
131+
url_map: Mapping of URL patterns to connection handlers.
132+
server_name: Name of the server as seen by clients. If :obj:`None`,
133+
websockets uses the value of the ``Host`` header.
134+
ssl: Configuration for enabling TLS on the connection. Set it to
135+
:obj:`True` if a reverse proxy terminates TLS connections.
136+
create_router: Factory for the :class:`Router` dispatching requests to
137+
handlers. Set it to a wrapper or a subclass to customize routing.
138+
139+
"""
140+
url_scheme = "ws" if ssl is None else "wss"
141+
if ssl is not True and ssl is not None:
142+
kwargs["ssl"] = ssl
143+
144+
if create_router is None:
145+
create_router = Router
146+
147+
router = create_router(url_map, server_name, url_scheme)
148+
149+
_process_request: (
150+
Callable[
151+
[ServerConnection, Request],
152+
Response | None,
153+
]
154+
| None
155+
) = kwargs.pop("process_request", None)
156+
if _process_request is None:
157+
process_request: Callable[
158+
[ServerConnection, Request],
159+
Response | None,
160+
] = router.route_request
161+
else:
162+
163+
def process_request(
164+
connection: ServerConnection, request: Request
165+
) -> Response | None:
166+
response = _process_request(connection, request)
167+
if response is not None:
168+
return response
169+
return router.route_request(connection, request)
170+
171+
return serve(router.handler, *args, process_request=process_request, **kwargs)
172+
173+
174+
def unix_route(
175+
url_map: Map,
176+
path: str | None = None,
177+
**kwargs: Any,
178+
) -> Server:
179+
"""
180+
Create a WebSocket Unix server dispatching connections to different handlers.
181+
182+
:func:`unix_route` combines the behaviors of :func:`route` and
183+
:func:`~websockets.sync.server.unix_serve`.
184+
185+
Args:
186+
url_map: Mapping of URL patterns to connection handlers.
187+
path: File system path to the Unix socket.
188+
189+
"""
190+
return route(url_map, unix=True, path=path, **kwargs)

‎src/websockets/sync/server.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import warnings
1414
from collections.abc import Iterable, Sequence
1515
from types import TracebackType
16-
from typing import Any, Callable, cast
16+
from typing import Any, Callable, Mapping, cast
1717

1818
from ..exceptions import InvalidHeader
1919
from ..extensions.base import ServerExtensionFactory
@@ -82,6 +82,8 @@ def __init__(
8282
max_queue=max_queue,
8383
)
8484
self.username: str # see basic_auth()
85+
self.handler: Callable[[ServerConnection], None] # see route()
86+
self.handler_kwargs: Mapping[str, Any] # see route()
8587

8688
def respond(self, status: StatusLike, text: str) -> Response:
8789
"""

‎tests/asyncio/server.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import socket
3+
import urllib.parse
34

45

56
def get_host_port(server):
@@ -9,15 +10,16 @@ def get_host_port(server):
910
raise AssertionError("expected at least one IPv4 socket")
1011

1112

12-
def get_uri(server):
13-
secure = server.server._ssl_context is not None # hack
13+
def get_uri(server, secure=None):
14+
if secure is None:
15+
secure = server.server._ssl_context is not None # hack
1416
protocol = "wss" if secure else "ws"
1517
host, port = get_host_port(server)
1618
return f"{protocol}://{host}:{port}"
1719

1820

1921
async def handler(ws):
20-
path = ws.request.path
22+
path = urllib.parse.urlparse(ws.request.path).path
2123
if path == "/":
2224
# The default path is an eval shell.
2325
async for expr in ws:

‎tests/asyncio/test_router.py

+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import http
2+
import socket
3+
import sys
4+
import unittest
5+
from unittest.mock import patch
6+
7+
from websockets.asyncio.client import connect, unix_connect
8+
from websockets.asyncio.router import *
9+
from websockets.exceptions import InvalidStatus
10+
11+
from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path
12+
from .server import EvalShellMixin, get_uri, handler
13+
from .utils import alist
14+
15+
16+
try:
17+
from werkzeug.routing import Map, Rule
18+
except ImportError:
19+
pass
20+
21+
22+
async def echo(websocket, count):
23+
message = await websocket.recv()
24+
for _ in range(count):
25+
await websocket.send(message)
26+
27+
28+
@unittest.skipUnless("werkzeug" in sys.modules, "werkzeug not installed")
29+
class RouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase):
30+
# This is a small realistic example of werkzeug's basic URL routing
31+
# features: path matching, parameter extraction, and default values.
32+
33+
async def test_router_matches_paths_and_extracts_parameters(self):
34+
"""Router matches paths and extracts parameters."""
35+
url_map = Map(
36+
[
37+
Rule("/echo", defaults={"count": 1}, endpoint=echo),
38+
Rule("/echo/<int:count>", endpoint=echo),
39+
]
40+
)
41+
async with route(url_map, "localhost", 0) as server:
42+
async with connect(get_uri(server) + "/echo") as client:
43+
await client.send("hello")
44+
messages = await alist(client)
45+
self.assertEqual(messages, ["hello"])
46+
47+
async with connect(get_uri(server) + "/echo/3") as client:
48+
await client.send("hello")
49+
messages = await alist(client)
50+
self.assertEqual(messages, ["hello", "hello", "hello"])
51+
52+
@property # avoids an import-time dependency on werkzeug
53+
def url_map(self):
54+
return Map(
55+
[
56+
Rule("/", endpoint=handler),
57+
Rule("/r", redirect_to="/"),
58+
]
59+
)
60+
61+
async def test_route_with_query_string(self):
62+
"""Router ignores query strings when matching paths."""
63+
async with route(self.url_map, "localhost", 0) as server:
64+
async with connect(get_uri(server) + "/?a=b") as client:
65+
await self.assertEval(client, "ws.request.path", "/?a=b")
66+
67+
async def test_redirect(self):
68+
"""Router redirects connections according to redirect_to."""
69+
async with route(self.url_map, "localhost", 0) as server:
70+
async with connect(get_uri(server) + "/r") as client:
71+
await self.assertEval(client, "ws.request.path", "/")
72+
73+
async def test_secure_redirect(self):
74+
"""Router redirects connections to a wss:// URI when TLS is enabled."""
75+
async with route(self.url_map, "localhost", 0, ssl=SERVER_CONTEXT) as server:
76+
async with connect(get_uri(server) + "/r", ssl=CLIENT_CONTEXT) as client:
77+
await self.assertEval(client, "ws.request.path", "/")
78+
79+
@patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc)
80+
async def test_force_secure_redirect(self):
81+
"""Router redirects ws:// connections to a wss:// URI when ssl=True."""
82+
async with route(self.url_map, "localhost", 0, ssl=True) as server:
83+
redirect_uri = get_uri(server, secure=True)
84+
with self.assertRaises(InvalidStatus) as raised:
85+
async with connect(get_uri(server) + "/r"):
86+
self.fail("did not raise")
87+
self.assertEqual(
88+
raised.exception.response.headers["Location"],
89+
redirect_uri + "/",
90+
)
91+
92+
@patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc)
93+
async def test_force_redirect_server_name(self):
94+
"""Router redirects connections to the host declared in server_name."""
95+
async with route(self.url_map, "localhost", 0, server_name="other") as server:
96+
with self.assertRaises(InvalidStatus) as raised:
97+
async with connect(get_uri(server) + "/r"):
98+
self.fail("did not raise")
99+
self.assertEqual(
100+
raised.exception.response.headers["Location"],
101+
"ws://other/",
102+
)
103+
104+
async def test_not_found(self):
105+
"""Router rejects requests to unknown paths with an HTTP 404 error."""
106+
async with route(self.url_map, "localhost", 0) as server:
107+
with self.assertRaises(InvalidStatus) as raised:
108+
async with connect(get_uri(server) + "/n"):
109+
self.fail("did not raise")
110+
self.assertEqual(
111+
str(raised.exception),
112+
"server rejected WebSocket connection: HTTP 404",
113+
)
114+
115+
async def test_process_request_function_returning_none(self):
116+
"""Router supports a process_request function returning None."""
117+
118+
def process_request(ws, request):
119+
ws.process_request_ran = True
120+
121+
async with route(
122+
self.url_map, "localhost", 0, process_request=process_request
123+
) as server:
124+
async with connect(get_uri(server) + "/") as client:
125+
await self.assertEval(client, "ws.process_request_ran", "True")
126+
127+
async def test_process_request_coroutine_returning_none(self):
128+
"""Router supports a process_request coroutine returning None."""
129+
130+
async def process_request(ws, request):
131+
ws.process_request_ran = True
132+
133+
async with route(
134+
self.url_map, "localhost", 0, process_request=process_request
135+
) as server:
136+
async with connect(get_uri(server) + "/") as client:
137+
await self.assertEval(client, "ws.process_request_ran", "True")
138+
139+
async def test_process_request_function_returning_response(self):
140+
"""Router supports a process_request function returning a response."""
141+
142+
def process_request(ws, request):
143+
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
144+
145+
async with route(
146+
self.url_map, "localhost", 0, process_request=process_request
147+
) as server:
148+
with self.assertRaises(InvalidStatus) as raised:
149+
async with connect(get_uri(server) + "/"):
150+
self.fail("did not raise")
151+
self.assertEqual(
152+
str(raised.exception),
153+
"server rejected WebSocket connection: HTTP 403",
154+
)
155+
156+
async def test_process_request_coroutine_returning_response(self):
157+
"""Router supports a process_request coroutine returning a response."""
158+
159+
async def process_request(ws, request):
160+
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
161+
162+
async with route(
163+
self.url_map, "localhost", 0, process_request=process_request
164+
) as server:
165+
with self.assertRaises(InvalidStatus) as raised:
166+
async with connect(get_uri(server) + "/"):
167+
self.fail("did not raise")
168+
self.assertEqual(
169+
str(raised.exception),
170+
"server rejected WebSocket connection: HTTP 403",
171+
)
172+
173+
async def test_custom_router_factory(self):
174+
"""Router supports a custom router factory."""
175+
176+
class MyRouter(Router):
177+
async def handler(self, connection):
178+
connection.my_router_ran = True
179+
return await super().handler(connection)
180+
181+
async with route(
182+
self.url_map, "localhost", 0, create_router=MyRouter
183+
) as server:
184+
async with connect(get_uri(server)) as client:
185+
await self.assertEval(client, "ws.my_router_ran", "True")
186+
187+
188+
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets")
189+
class UnixRouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase):
190+
async def test_router_supports_unix_sockets(self):
191+
"""Router supports Unix sockets."""
192+
url_map = Map([Rule("/echo/<int:count>", endpoint=echo)])
193+
with temp_unix_socket_path() as path:
194+
async with unix_route(url_map, path):
195+
async with unix_connect(path, "ws://localhost/echo/3") as client:
196+
await client.send("hello")
197+
messages = await alist(client)
198+
self.assertEqual(messages, ["hello", "hello", "hello"])

‎tests/sync/server.py

+37-7
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
import contextlib
22
import ssl
33
import threading
4+
import urllib.parse
45

6+
from websockets.sync.router import *
57
from websockets.sync.server import *
68

79

8-
def get_uri(server):
9-
secure = isinstance(server.socket, ssl.SSLSocket) # hack
10+
def get_uri(server, secure=None):
11+
if secure is None:
12+
secure = isinstance(server.socket, ssl.SSLSocket) # hack
1013
protocol = "wss" if secure else "ws"
1114
host, port = server.socket.getsockname()
1215
return f"{protocol}://{host}:{port}"
1316

1417

1518
def handler(ws):
16-
path = ws.request.path
19+
path = urllib.parse.urlparse(ws.request.path).path
1720
if path == "/":
1821
# The default path is an eval shell.
1922
for expr in ws:
@@ -34,8 +37,14 @@ def assertEval(self, client, expr, value):
3437

3538

3639
@contextlib.contextmanager
37-
def run_server(handler=handler, host="localhost", port=0, **kwargs):
38-
with serve(handler, host, port, **kwargs) as server:
40+
def run_server_or_router(
41+
serve_or_route,
42+
handler_or_url_map,
43+
host="localhost",
44+
port=0,
45+
**kwargs,
46+
):
47+
with serve_or_route(handler_or_url_map, host, port, **kwargs) as server:
3948
thread = threading.Thread(target=server.serve_forever)
4049
thread.start()
4150

@@ -63,13 +72,34 @@ def handler(sock, addr):
6372
handler_thread.join()
6473

6574

75+
def run_server(handler=handler, **kwargs):
76+
return run_server_or_router(serve, handler, **kwargs)
77+
78+
79+
def run_router(url_map, **kwargs):
80+
return run_server_or_router(route, url_map, **kwargs)
81+
82+
6683
@contextlib.contextmanager
67-
def run_unix_server(path, handler=handler, **kwargs):
68-
with unix_serve(handler, path, **kwargs) as server:
84+
def run_unix_server_or_router(
85+
path,
86+
unix_serve_or_route,
87+
handler_or_url_map,
88+
**kwargs,
89+
):
90+
with unix_serve_or_route(handler_or_url_map, path, **kwargs) as server:
6991
thread = threading.Thread(target=server.serve_forever)
7092
thread.start()
7193
try:
7294
yield server
7395
finally:
7496
server.shutdown()
7597
thread.join()
98+
99+
100+
def run_unix_server(path, handler=handler, **kwargs):
101+
return run_unix_server_or_router(path, unix_serve, handler, **kwargs)
102+
103+
104+
def run_unix_router(path, url_map, **kwargs):
105+
return run_unix_server_or_router(path, unix_route, url_map, **kwargs)

‎tests/sync/test_router.py

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import http
2+
import socket
3+
import sys
4+
import unittest
5+
from unittest.mock import patch
6+
7+
from websockets.exceptions import InvalidStatus
8+
from websockets.sync.client import connect, unix_connect
9+
from websockets.sync.router import *
10+
11+
from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path
12+
from .server import EvalShellMixin, get_uri, handler, run_router, run_unix_router
13+
14+
15+
try:
16+
from werkzeug.routing import Map, Rule
17+
except ImportError:
18+
pass
19+
20+
21+
def echo(websocket, count):
22+
message = websocket.recv()
23+
for _ in range(count):
24+
websocket.send(message)
25+
26+
27+
@unittest.skipUnless("werkzeug" in sys.modules, "werkzeug not installed")
28+
class RouterTests(EvalShellMixin, unittest.TestCase):
29+
# This is a small realistic example of werkzeug's basic URL routing
30+
# features: path matching, parameter extraction, and default values.
31+
32+
def test_router_matches_paths_and_extracts_parameters(self):
33+
"""Router matches paths and extracts parameters."""
34+
url_map = Map(
35+
[
36+
Rule("/echo", defaults={"count": 1}, endpoint=echo),
37+
Rule("/echo/<int:count>", endpoint=echo),
38+
]
39+
)
40+
with run_router(url_map) as server:
41+
with connect(get_uri(server) + "/echo") as client:
42+
client.send("hello")
43+
messages = list(client)
44+
self.assertEqual(messages, ["hello"])
45+
46+
with connect(get_uri(server) + "/echo/3") as client:
47+
client.send("hello")
48+
messages = list(client)
49+
self.assertEqual(messages, ["hello", "hello", "hello"])
50+
51+
@property # avoids an import-time dependency on werkzeug
52+
def url_map(self):
53+
return Map(
54+
[
55+
Rule("/", endpoint=handler),
56+
Rule("/r", redirect_to="/"),
57+
]
58+
)
59+
60+
def test_route_with_query_string(self):
61+
"""Router ignores query strings when matching paths."""
62+
with run_router(self.url_map) as server:
63+
with connect(get_uri(server) + "/?a=b") as client:
64+
self.assertEval(client, "ws.request.path", "/?a=b")
65+
66+
def test_redirect(self):
67+
"""Router redirects connections according to redirect_to."""
68+
with run_router(self.url_map, server_name="localhost") as server:
69+
with self.assertRaises(InvalidStatus) as raised:
70+
with connect(get_uri(server) + "/r"):
71+
self.fail("did not raise")
72+
self.assertEqual(
73+
raised.exception.response.headers["Location"],
74+
"ws://localhost/",
75+
)
76+
77+
def test_secure_redirect(self):
78+
"""Router redirects connections to a wss:// URI when TLS is enabled."""
79+
with run_router(
80+
self.url_map, server_name="localhost", ssl=SERVER_CONTEXT
81+
) as server:
82+
with self.assertRaises(InvalidStatus) as raised:
83+
with connect(get_uri(server) + "/r", ssl=CLIENT_CONTEXT):
84+
self.fail("did not raise")
85+
self.assertEqual(
86+
raised.exception.response.headers["Location"],
87+
"wss://localhost/",
88+
)
89+
90+
@patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc)
91+
def test_force_secure_redirect(self):
92+
"""Router redirects ws:// connections to a wss:// URI when ssl=True."""
93+
with run_router(self.url_map, ssl=True) as server:
94+
redirect_uri = get_uri(server, secure=True)
95+
with self.assertRaises(InvalidStatus) as raised:
96+
with connect(get_uri(server) + "/r"):
97+
self.fail("did not raise")
98+
self.assertEqual(
99+
raised.exception.response.headers["Location"],
100+
redirect_uri + "/",
101+
)
102+
103+
@patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc)
104+
def test_force_redirect_server_name(self):
105+
"""Router redirects connections to the host declared in server_name."""
106+
with run_router(self.url_map, server_name="other") as server:
107+
with self.assertRaises(InvalidStatus) as raised:
108+
with connect(get_uri(server) + "/r"):
109+
self.fail("did not raise")
110+
self.assertEqual(
111+
raised.exception.response.headers["Location"],
112+
"ws://other/",
113+
)
114+
115+
def test_not_found(self):
116+
"""Router rejects requests to unknown paths with an HTTP 404 error."""
117+
with run_router(self.url_map) as server:
118+
with self.assertRaises(InvalidStatus) as raised:
119+
with connect(get_uri(server) + "/n"):
120+
self.fail("did not raise")
121+
self.assertEqual(
122+
str(raised.exception),
123+
"server rejected WebSocket connection: HTTP 404",
124+
)
125+
126+
def test_process_request_returning_none(self):
127+
"""Router supports a process_request returning None."""
128+
129+
def process_request(ws, request):
130+
ws.process_request_ran = True
131+
132+
with run_router(self.url_map, process_request=process_request) as server:
133+
with connect(get_uri(server) + "/") as client:
134+
self.assertEval(client, "ws.process_request_ran", "True")
135+
136+
def test_process_request_returning_response(self):
137+
"""Router supports a process_request returning a response."""
138+
139+
def process_request(ws, request):
140+
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
141+
142+
with run_router(self.url_map, process_request=process_request) as server:
143+
with self.assertRaises(InvalidStatus) as raised:
144+
with connect(get_uri(server) + "/"):
145+
self.fail("did not raise")
146+
self.assertEqual(
147+
str(raised.exception),
148+
"server rejected WebSocket connection: HTTP 403",
149+
)
150+
151+
def test_custom_router_factory(self):
152+
"""Router supports a custom router factory."""
153+
154+
class MyRouter(Router):
155+
def handler(self, connection):
156+
connection.my_router_ran = True
157+
return super().handler(connection)
158+
159+
with run_router(self.url_map, create_router=MyRouter) as server:
160+
with connect(get_uri(server)) as client:
161+
self.assertEval(client, "ws.my_router_ran", "True")
162+
163+
164+
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets")
165+
class UnixRouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase):
166+
def test_router_supports_unix_sockets(self):
167+
"""Router supports Unix sockets."""
168+
url_map = Map([Rule("/echo/<int:count>", endpoint=echo)])
169+
with temp_unix_socket_path() as path:
170+
with run_unix_router(path, url_map):
171+
with unix_connect(path, "ws://localhost/echo/3") as client:
172+
client.send("hello")
173+
messages = list(client)
174+
self.assertEqual(messages, ["hello", "hello", "hello"])

‎tests/test_exports.py

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import websockets
44
import websockets.asyncio.client
5+
import websockets.asyncio.router
56
import websockets.asyncio.server
67
import websockets.client
78
import websockets.datastructures
@@ -16,6 +17,7 @@
1617
for name in (
1718
[]
1819
+ websockets.asyncio.client.__all__
20+
+ websockets.asyncio.router.__all__
1921
+ websockets.asyncio.server.__all__
2022
+ websockets.client.__all__
2123
+ websockets.datastructures.__all__

‎tox.ini

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ pass_env =
1717
deps =
1818
py311,py312,py313,coverage,maxi_cov: mitmproxy
1919
py311,py312,py313,coverage,maxi_cov: python-socks[asyncio]
20+
werkzeug
2021

2122
[testenv:coverage]
2223
commands =
@@ -47,3 +48,4 @@ commands =
4748
deps =
4849
mypy
4950
python-socks
51+
werkzeug

0 commit comments

Comments
 (0)
Please sign in to comment.