|
| 1 | +import http |
| 2 | +import socket |
| 3 | +import sys |
| 4 | +import unittest |
| 5 | +from unittest.mock import patch |
| 6 | + |
| 7 | +from websockets.asyncio.client import connect, unix_connect |
| 8 | +from websockets.asyncio.router import * |
| 9 | +from websockets.exceptions import InvalidStatus |
| 10 | + |
| 11 | +from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path |
| 12 | +from .server import EvalShellMixin, get_uri, handler |
| 13 | +from .utils import alist |
| 14 | + |
| 15 | + |
| 16 | +try: |
| 17 | + from werkzeug.routing import Map, Rule |
| 18 | +except ImportError: |
| 19 | + pass |
| 20 | + |
| 21 | + |
| 22 | +async def echo(websocket, count): |
| 23 | + message = await websocket.recv() |
| 24 | + for _ in range(count): |
| 25 | + await websocket.send(message) |
| 26 | + |
| 27 | + |
| 28 | +@unittest.skipUnless("werkzeug" in sys.modules, "werkzeug not installed") |
| 29 | +class RouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): |
| 30 | + # This is a small realistic example of werkzeug's basic URL routing |
| 31 | + # features: path matching, parameter extraction, and default values. |
| 32 | + |
| 33 | + async def test_router_matches_paths_and_extracts_parameters(self): |
| 34 | + """Router matches paths and extracts parameters.""" |
| 35 | + url_map = Map( |
| 36 | + [ |
| 37 | + Rule("/echo", defaults={"count": 1}, endpoint=echo), |
| 38 | + Rule("/echo/<int:count>", endpoint=echo), |
| 39 | + ] |
| 40 | + ) |
| 41 | + async with route(url_map, "localhost", 0) as server: |
| 42 | + async with connect(get_uri(server) + "/echo") as client: |
| 43 | + await client.send("hello") |
| 44 | + messages = await alist(client) |
| 45 | + self.assertEqual(messages, ["hello"]) |
| 46 | + |
| 47 | + async with connect(get_uri(server) + "/echo/3") as client: |
| 48 | + await client.send("hello") |
| 49 | + messages = await alist(client) |
| 50 | + self.assertEqual(messages, ["hello", "hello", "hello"]) |
| 51 | + |
| 52 | + @property # avoids an import-time dependency on werkzeug |
| 53 | + def url_map(self): |
| 54 | + return Map( |
| 55 | + [ |
| 56 | + Rule("/", endpoint=handler), |
| 57 | + Rule("/r", redirect_to="/"), |
| 58 | + ] |
| 59 | + ) |
| 60 | + |
| 61 | + async def test_route_with_query_string(self): |
| 62 | + """Router ignores query strings when matching paths.""" |
| 63 | + async with route(self.url_map, "localhost", 0) as server: |
| 64 | + async with connect(get_uri(server) + "/?a=b") as client: |
| 65 | + await self.assertEval(client, "ws.request.path", "/?a=b") |
| 66 | + |
| 67 | + async def test_redirect(self): |
| 68 | + """Router redirects connections according to redirect_to.""" |
| 69 | + async with route(self.url_map, "localhost", 0) as server: |
| 70 | + async with connect(get_uri(server) + "/r") as client: |
| 71 | + await self.assertEval(client, "ws.request.path", "/") |
| 72 | + |
| 73 | + async def test_secure_redirect(self): |
| 74 | + """Router redirects connections to a wss:// URI when TLS is enabled.""" |
| 75 | + async with route(self.url_map, "localhost", 0, ssl=SERVER_CONTEXT) as server: |
| 76 | + async with connect(get_uri(server) + "/r", ssl=CLIENT_CONTEXT) as client: |
| 77 | + await self.assertEval(client, "ws.request.path", "/") |
| 78 | + |
| 79 | + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) |
| 80 | + async def test_force_secure_redirect(self): |
| 81 | + """Router redirects ws:// connections to a wss:// URI when ssl=True.""" |
| 82 | + async with route(self.url_map, "localhost", 0, ssl=True) as server: |
| 83 | + redirect_uri = get_uri(server, secure=True) |
| 84 | + with self.assertRaises(InvalidStatus) as raised: |
| 85 | + async with connect(get_uri(server) + "/r"): |
| 86 | + self.fail("did not raise") |
| 87 | + self.assertEqual( |
| 88 | + raised.exception.response.headers["Location"], |
| 89 | + redirect_uri + "/", |
| 90 | + ) |
| 91 | + |
| 92 | + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) |
| 93 | + async def test_force_redirect_server_name(self): |
| 94 | + """Router redirects connections to the host declared in server_name.""" |
| 95 | + async with route(self.url_map, "localhost", 0, server_name="other") as server: |
| 96 | + with self.assertRaises(InvalidStatus) as raised: |
| 97 | + async with connect(get_uri(server) + "/r"): |
| 98 | + self.fail("did not raise") |
| 99 | + self.assertEqual( |
| 100 | + raised.exception.response.headers["Location"], |
| 101 | + "ws://other/", |
| 102 | + ) |
| 103 | + |
| 104 | + async def test_not_found(self): |
| 105 | + """Router rejects requests to unknown paths with an HTTP 404 error.""" |
| 106 | + async with route(self.url_map, "localhost", 0) as server: |
| 107 | + with self.assertRaises(InvalidStatus) as raised: |
| 108 | + async with connect(get_uri(server) + "/n"): |
| 109 | + self.fail("did not raise") |
| 110 | + self.assertEqual( |
| 111 | + str(raised.exception), |
| 112 | + "server rejected WebSocket connection: HTTP 404", |
| 113 | + ) |
| 114 | + |
| 115 | + async def test_process_request_function_returning_none(self): |
| 116 | + """Router supports a process_request function returning None.""" |
| 117 | + |
| 118 | + def process_request(ws, request): |
| 119 | + ws.process_request_ran = True |
| 120 | + |
| 121 | + async with route( |
| 122 | + self.url_map, "localhost", 0, process_request=process_request |
| 123 | + ) as server: |
| 124 | + async with connect(get_uri(server) + "/") as client: |
| 125 | + await self.assertEval(client, "ws.process_request_ran", "True") |
| 126 | + |
| 127 | + async def test_process_request_coroutine_returning_none(self): |
| 128 | + """Router supports a process_request coroutine returning None.""" |
| 129 | + |
| 130 | + async def process_request(ws, request): |
| 131 | + ws.process_request_ran = True |
| 132 | + |
| 133 | + async with route( |
| 134 | + self.url_map, "localhost", 0, process_request=process_request |
| 135 | + ) as server: |
| 136 | + async with connect(get_uri(server) + "/") as client: |
| 137 | + await self.assertEval(client, "ws.process_request_ran", "True") |
| 138 | + |
| 139 | + async def test_process_request_function_returning_response(self): |
| 140 | + """Router supports a process_request function returning a response.""" |
| 141 | + |
| 142 | + def process_request(ws, request): |
| 143 | + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") |
| 144 | + |
| 145 | + async with route( |
| 146 | + self.url_map, "localhost", 0, process_request=process_request |
| 147 | + ) as server: |
| 148 | + with self.assertRaises(InvalidStatus) as raised: |
| 149 | + async with connect(get_uri(server) + "/"): |
| 150 | + self.fail("did not raise") |
| 151 | + self.assertEqual( |
| 152 | + str(raised.exception), |
| 153 | + "server rejected WebSocket connection: HTTP 403", |
| 154 | + ) |
| 155 | + |
| 156 | + async def test_process_request_coroutine_returning_response(self): |
| 157 | + """Router supports a process_request coroutine returning a response.""" |
| 158 | + |
| 159 | + async def process_request(ws, request): |
| 160 | + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") |
| 161 | + |
| 162 | + async with route( |
| 163 | + self.url_map, "localhost", 0, process_request=process_request |
| 164 | + ) as server: |
| 165 | + with self.assertRaises(InvalidStatus) as raised: |
| 166 | + async with connect(get_uri(server) + "/"): |
| 167 | + self.fail("did not raise") |
| 168 | + self.assertEqual( |
| 169 | + str(raised.exception), |
| 170 | + "server rejected WebSocket connection: HTTP 403", |
| 171 | + ) |
| 172 | + |
| 173 | + async def test_custom_router_factory(self): |
| 174 | + """Router supports a custom router factory.""" |
| 175 | + |
| 176 | + class MyRouter(Router): |
| 177 | + async def handler(self, connection): |
| 178 | + connection.my_router_ran = True |
| 179 | + return await super().handler(connection) |
| 180 | + |
| 181 | + async with route( |
| 182 | + self.url_map, "localhost", 0, create_router=MyRouter |
| 183 | + ) as server: |
| 184 | + async with connect(get_uri(server)) as client: |
| 185 | + await self.assertEval(client, "ws.my_router_ran", "True") |
| 186 | + |
| 187 | + |
| 188 | +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") |
| 189 | +class UnixRouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): |
| 190 | + async def test_router_supports_unix_sockets(self): |
| 191 | + """Router supports Unix sockets.""" |
| 192 | + url_map = Map([Rule("/echo/<int:count>", endpoint=echo)]) |
| 193 | + with temp_unix_socket_path() as path: |
| 194 | + async with unix_route(url_map, path): |
| 195 | + async with unix_connect(path, "ws://localhost/echo/3") as client: |
| 196 | + await client.send("hello") |
| 197 | + messages = await alist(client) |
| 198 | + self.assertEqual(messages, ["hello", "hello", "hello"]) |
0 commit comments