Skip to content

Commit a2f14e7

Browse files
authored
Avoid double JSON encode/decode for socket.io (#4449)
* Avoid double JSON encode/decode for socket.io socket.io (python and js) already has a built in mechanism for JSON encoding and decoding messages over the websocket. To use it, we pass a custom `json` namespace which uses `format.json_dumps` (leveraging reflex serializers) to encode the messages. This avoids sending a JSON-encoded string of JSON over the wire, and reduces the number of serialization/deserialization passes over the message data. The side benefit is that debugging websocket messages in browser tools displays the parsed JSON hierarchy and is much easier to work with. * JSON5.parse in on_upload_progress handler responses
1 parent 053cbe7 commit a2f14e7

File tree

4 files changed

+97
-63
lines changed

4 files changed

+97
-63
lines changed

reflex/.templates/web/utils/state.js

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ export const applyEvent = async (event, socket) => {
300300
if (socket) {
301301
socket.emit(
302302
"event",
303-
JSON.stringify(event, (k, v) => (v === undefined ? null : v))
303+
event,
304304
);
305305
return true;
306306
}
@@ -407,6 +407,8 @@ export const connect = async (
407407
transports: transports,
408408
autoUnref: false,
409409
});
410+
// Ensure undefined fields in events are sent as null instead of removed
411+
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v)
410412

411413
function checkVisibility() {
412414
if (document.visibilityState === "visible") {
@@ -443,8 +445,7 @@ export const connect = async (
443445
});
444446

445447
// On each received message, queue the updates and events.
446-
socket.current.on("event", async (message) => {
447-
const update = JSON5.parse(message);
448+
socket.current.on("event", async (update) => {
448449
for (const substate in update.delta) {
449450
dispatch[substate](update.delta[substate]);
450451
}
@@ -456,7 +457,7 @@ export const connect = async (
456457
});
457458
socket.current.on("reload", async (event) => {
458459
event_processing = false;
459-
queueEvents([...initialEvents(), JSON5.parse(event)], socket);
460+
queueEvents([...initialEvents(), event], socket);
460461
});
461462

462463
document.addEventListener("visibilitychange", checkVisibility);
@@ -497,23 +498,31 @@ export const uploadFiles = async (
497498
// Whenever called, responseText will contain the entire response so far.
498499
const chunks = progressEvent.event.target.responseText.trim().split("\n");
499500
// So only process _new_ chunks beyond resp_idx.
500-
chunks.slice(resp_idx).map((chunk) => {
501-
event_callbacks.map((f, ix) => {
502-
f(chunk)
503-
.then(() => {
504-
if (ix === event_callbacks.length - 1) {
505-
// Mark this chunk as processed.
506-
resp_idx += 1;
507-
}
508-
})
509-
.catch((e) => {
510-
if (progressEvent.progress === 1) {
511-
// Chunk may be incomplete, so only report errors when full response is available.
512-
console.log("Error parsing chunk", chunk, e);
513-
}
514-
return;
515-
});
516-
});
501+
chunks.slice(resp_idx).map((chunk_json) => {
502+
try {
503+
const chunk = JSON5.parse(chunk_json);
504+
event_callbacks.map((f, ix) => {
505+
f(chunk)
506+
.then(() => {
507+
if (ix === event_callbacks.length - 1) {
508+
// Mark this chunk as processed.
509+
resp_idx += 1;
510+
}
511+
})
512+
.catch((e) => {
513+
if (progressEvent.progress === 1) {
514+
// Chunk may be incomplete, so only report errors when full response is available.
515+
console.log("Error processing chunk", chunk, e);
516+
}
517+
return;
518+
});
519+
});
520+
} catch (e) {
521+
if (progressEvent.progress === 1) {
522+
console.log("Error parsing chunk", chunk_json, e);
523+
}
524+
return;
525+
}
517526
});
518527
};
519528

reflex/app.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import traceback
1818
from datetime import datetime
1919
from pathlib import Path
20+
from types import SimpleNamespace
2021
from typing import (
2122
TYPE_CHECKING,
2223
Any,
@@ -363,6 +364,10 @@ def _setup_state(self) -> None:
363364
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
364365
ping_interval=constants.Ping.INTERVAL,
365366
ping_timeout=constants.Ping.TIMEOUT,
367+
json=SimpleNamespace(
368+
dumps=staticmethod(format.json_dumps),
369+
loads=staticmethod(json.loads),
370+
),
366371
transports=["websocket"],
367372
)
368373
elif getattr(self.sio, "async_mode", "") != "asgi":
@@ -1543,7 +1548,7 @@ async def emit_update(self, update: StateUpdate, sid: str) -> None:
15431548
"""
15441549
# Creating a task prevents the update from being blocked behind other coroutines.
15451550
await asyncio.create_task(
1546-
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
1551+
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
15471552
)
15481553

15491554
async def on_event(self, sid, data):
@@ -1556,7 +1561,7 @@ async def on_event(self, sid, data):
15561561
sid: The Socket.IO session id.
15571562
data: The event data.
15581563
"""
1559-
fields = json.loads(data)
1564+
fields = data
15601565
# Get the event.
15611566
event = Event(
15621567
**{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}

reflex/utils/format.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -664,18 +664,22 @@ def format_library_name(library_fullname: str):
664664
return lib
665665

666666

667-
def json_dumps(obj: Any) -> str:
667+
def json_dumps(obj: Any, **kwargs) -> str:
668668
"""Takes an object and returns a jsonified string.
669669
670670
Args:
671671
obj: The object to be serialized.
672+
kwargs: Additional keyword arguments to pass to json.dumps.
672673
673674
Returns:
674675
A string
675676
"""
676677
from reflex.utils import serializers
677678

678-
return json.dumps(obj, ensure_ascii=False, default=serializers.serialize)
679+
kwargs.setdefault("ensure_ascii", False)
680+
kwargs.setdefault("default", serializers.serialize)
681+
682+
return json.dumps(obj, **kwargs)
679683

680684

681685
def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]:

tests/units/test_state.py

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,6 +1840,24 @@ async def _coro_waiter():
18401840
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
18411841

18421842

1843+
class CopyingAsyncMock(AsyncMock):
1844+
"""An AsyncMock, but deepcopy the args and kwargs first."""
1845+
1846+
def __call__(self, *args, **kwargs):
1847+
"""Call the mock.
1848+
1849+
Args:
1850+
args: the arguments passed to the mock
1851+
kwargs: the keyword arguments passed to the mock
1852+
1853+
Returns:
1854+
The result of the mock call
1855+
"""
1856+
args = copy.deepcopy(args)
1857+
kwargs = copy.deepcopy(kwargs)
1858+
return super().__call__(*args, **kwargs)
1859+
1860+
18431861
@pytest.fixture(scope="function")
18441862
def mock_app_simple(monkeypatch) -> rx.App:
18451863
"""Simple Mock app fixture.
@@ -1856,7 +1874,7 @@ def mock_app_simple(monkeypatch) -> rx.App:
18561874

18571875
setattr(app_module, CompileVars.APP, app)
18581876
app.state = TestState
1859-
app.event_namespace.emit = AsyncMock() # type: ignore
1877+
app.event_namespace.emit = CopyingAsyncMock() # type: ignore
18601878

18611879
def _mock_get_app(*args, **kwargs):
18621880
return app_module
@@ -1960,21 +1978,19 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
19601978
mock_app.event_namespace.emit.assert_called_once()
19611979
mcall = mock_app.event_namespace.emit.mock_calls[0]
19621980
assert mcall.args[0] == str(SocketEvent.EVENT)
1963-
assert json.loads(mcall.args[1]) == dataclasses.asdict(
1964-
StateUpdate(
1965-
delta={
1966-
parent_state.get_full_name(): {
1967-
"upper": "",
1968-
"sum": 3.14,
1969-
},
1970-
grandchild_state.get_full_name(): {
1971-
"value2": "42",
1972-
},
1973-
GrandchildState3.get_full_name(): {
1974-
"computed": "",
1975-
},
1976-
}
1977-
)
1981+
assert mcall.args[1] == StateUpdate(
1982+
delta={
1983+
parent_state.get_full_name(): {
1984+
"upper": "",
1985+
"sum": 3.14,
1986+
},
1987+
grandchild_state.get_full_name(): {
1988+
"value2": "42",
1989+
},
1990+
GrandchildState3.get_full_name(): {
1991+
"computed": "",
1992+
},
1993+
}
19781994
)
19791995
assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
19801996

@@ -2156,51 +2172,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
21562172
assert mock_app.event_namespace is not None
21572173
emit_mock = mock_app.event_namespace.emit
21582174

2159-
first_ws_message = json.loads(emit_mock.mock_calls[0].args[1])
2175+
first_ws_message = emit_mock.mock_calls[0].args[1]
21602176
assert (
2161-
first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router")
2177+
first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
21622178
is not None
21632179
)
2164-
assert first_ws_message == {
2165-
"delta": {
2180+
assert first_ws_message == StateUpdate(
2181+
delta={
21662182
BackgroundTaskState.get_full_name(): {
21672183
"order": ["background_task:start"],
21682184
"computed_order": ["background_task:start"],
21692185
}
21702186
},
2171-
"events": [],
2172-
"final": True,
2173-
}
2187+
events=[],
2188+
final=True,
2189+
)
21742190
for call in emit_mock.mock_calls[1:5]:
2175-
assert json.loads(call.args[1]) == {
2176-
"delta": {
2191+
assert call.args[1] == StateUpdate(
2192+
delta={
21772193
BackgroundTaskState.get_full_name(): {
21782194
"computed_order": ["background_task:start"],
21792195
}
21802196
},
2181-
"events": [],
2182-
"final": True,
2183-
}
2184-
assert json.loads(emit_mock.mock_calls[-2].args[1]) == {
2185-
"delta": {
2197+
events=[],
2198+
final=True,
2199+
)
2200+
assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
2201+
delta={
21862202
BackgroundTaskState.get_full_name(): {
21872203
"order": exp_order,
21882204
"computed_order": exp_order,
21892205
"dict_list": {},
21902206
}
21912207
},
2192-
"events": [],
2193-
"final": True,
2194-
}
2195-
assert json.loads(emit_mock.mock_calls[-1].args[1]) == {
2196-
"delta": {
2208+
events=[],
2209+
final=True,
2210+
)
2211+
assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
2212+
delta={
21972213
BackgroundTaskState.get_full_name(): {
21982214
"computed_order": exp_order,
21992215
},
22002216
},
2201-
"events": [],
2202-
"final": True,
2203-
}
2217+
events=[],
2218+
final=True,
2219+
)
22042220

22052221

22062222
@pytest.mark.asyncio

0 commit comments

Comments
 (0)