Skip to content

Commit c155f0a

Browse files
authored
fastapi: start SMTP server for testing (#842)
1 parent b091610 commit c155f0a

File tree

4 files changed

+249
-7
lines changed

4 files changed

+249
-7
lines changed

gel/_internal/_integration/_fastapi/_cli/_patch.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,24 @@ def _get_fastapi_cli_import_site() -> types.FrameType | None:
2020
return None
2121

2222

23-
def maybe_patch_fastapi_cli() -> None:
23+
def maybe_patch_fastapi_cli() -> bool:
2424
if importlib.util.find_spec("fastapi") is None:
2525
# No FastAPI here, move along.
26-
return
26+
return False
2727

2828
try:
2929
import uvicorn # noqa: PLC0415 # pyright: ignore [reportMissingImports]
3030
except ImportError:
31-
return
31+
return False
3232

3333
fastapi_cli_import_site = _get_fastapi_cli_import_site()
3434
if fastapi_cli_import_site is None:
3535
# Not being imported by fastapi.cli
36-
return
36+
return False
37+
38+
if fastapi_cli_import_site.f_locals.get("command") != "dev":
39+
# Don't patch in production mode.
40+
return False
3741

3842
def _patched_uvicorn_run(*args: Any, **kwargs: Any) -> None:
3943
from . import _lifespan # noqa: PLC0415
@@ -58,3 +62,4 @@ def __getattribute__(self, name: str) -> Any:
5862
uvicorn.__name__,
5963
doc=uvicorn.__doc__,
6064
)
65+
return True
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# SPDX-PackageName: gel-python
2+
# SPDX-License-Identifier: Apache-2.0
3+
# SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors.
4+
5+
from __future__ import annotations
6+
from typing import cast, Optional, TYPE_CHECKING
7+
8+
import asyncio
9+
import email.message
10+
import email.parser
11+
import email.policy
12+
import signal
13+
14+
import gel
15+
16+
if TYPE_CHECKING:
17+
import rich_toolkit
18+
19+
20+
class SMTPServerProtocol(asyncio.Protocol):
21+
_transport: asyncio.Transport
22+
_mail_from: Optional[str]
23+
_rcpt_to: list[str]
24+
_parser: email.parser.BytesFeedParser
25+
_in_data: bool = False
26+
27+
def __init__(self, cli: rich_toolkit.RichToolkit):
28+
self._cli = cli
29+
self._buffer = bytearray()
30+
self._reset()
31+
32+
def connection_made(self, transport: asyncio.BaseTransport) -> None:
33+
trans = cast("asyncio.Transport", transport)
34+
self._transport = trans
35+
trans.write(b"220 localhost Simple SMTP server\r\n")
36+
37+
def connection_lost(self, exc: Optional[Exception]) -> None:
38+
del self._transport
39+
40+
def data_received(self, data: bytes) -> None:
41+
self._buffer.extend(data)
42+
43+
while True:
44+
newline_index = self._buffer.find(b"\r\n")
45+
if newline_index == -1:
46+
break
47+
48+
line = self._buffer[:newline_index]
49+
self._buffer = self._buffer[newline_index + 2 :]
50+
51+
self._handle_line(bytes(line))
52+
53+
def _handle_line(self, line: bytes) -> None:
54+
if self._in_data:
55+
if line == b".": # End of DATA mode
56+
message = self._parser.close()
57+
assert isinstance(message, email.message.EmailMessage)
58+
self._handle_message(message)
59+
self._reset()
60+
self._transport.write(b"250 OK\r\n")
61+
else:
62+
self._parser.feed(line + b"\r\n")
63+
return
64+
65+
# Handle SMTP commands
66+
upper = line.upper()
67+
if upper.startswith((b"HELO", b"EHLO")):
68+
self._transport.write(b"250 Hello\r\n")
69+
elif upper.startswith(b"MAIL FROM:"):
70+
self._mail_from = line[10:].strip().decode()
71+
self._transport.write(b"250 OK\r\n")
72+
elif upper.startswith(b"RCPT TO:"):
73+
self._rcpt_to.append(line[8:].strip().decode())
74+
self._transport.write(b"250 OK\r\n")
75+
elif upper == b"DATA":
76+
self._transport.write(b"354 End data with <CR><LF>.<CR><LF>\r\n")
77+
self._in_data = True
78+
elif upper == b"QUIT":
79+
self._transport.write(b"221 Bye\r\n")
80+
self._transport.close()
81+
else:
82+
self._transport.write(b"500 Unrecognized command\r\n")
83+
84+
def _handle_message(self, message: email.message.EmailMessage) -> None:
85+
self._cli.print("Received email:", tag="gel")
86+
self._cli.print(f" From: {self._mail_from}", tag="gel")
87+
self._cli.print(f" To: {', '.join(self._rcpt_to)}", tag="gel")
88+
self._cli.print(f" Subject: {message.get('Subject')}", tag="gel")
89+
has_gel_header = False
90+
for key in message:
91+
if key.lower().startswith("x-gel-"):
92+
self._cli.print(f" {key}: {message[key]}", tag="gel")
93+
has_gel_header = True
94+
if not has_gel_header:
95+
text_parts = []
96+
if message.is_multipart():
97+
for part in message.walk():
98+
content_type = part.get_content_type()
99+
content_disposition = part.get("Content-Disposition", "")
100+
if (
101+
content_type == "text/plain"
102+
and "attachment" not in content_disposition
103+
):
104+
charset = part.get_content_charset() or "utf-8"
105+
payload = part.get_payload(decode=True)
106+
if isinstance(payload, bytes):
107+
text = payload.decode(charset, errors="replace")
108+
text_parts.append(text)
109+
else:
110+
if message.get_content_type() == "text/plain":
111+
charset = message.get_content_charset() or "utf-8"
112+
payload = message.get_payload(decode=True)
113+
if isinstance(payload, bytes):
114+
text_parts.append(
115+
payload.decode(charset, errors="replace")
116+
)
117+
self._cli.print(
118+
"No X-Gel-* headers found, email content:", tag="gel"
119+
)
120+
if text_parts:
121+
for text in text_parts:
122+
self._cli.print(text, tag="gel")
123+
else:
124+
self._cli.print(
125+
repr(message.get_payload(decode=True)), tag="gel"
126+
)
127+
128+
def _reset(self) -> None:
129+
self._mail_from = None
130+
self._rcpt_to = []
131+
self._parser = email.parser.BytesFeedParser(policy=email.policy.SMTP)
132+
self._in_data = False
133+
self._buffer.clear()
134+
135+
136+
class SMTPServer:
137+
_server: asyncio.Server
138+
139+
async def maybe_start(
140+
self,
141+
client: gel.AsyncIOClient,
142+
) -> None:
143+
from fastapi_cli.utils.cli import get_rich_toolkit # noqa: PLC0415
144+
145+
# get_rich_toolkit() installs a SIGTERM handler underneath, which
146+
# causes unnecessary noise in the logs at CTRL + C shutdown.
147+
orig_handler = signal.getsignal(signal.SIGTERM)
148+
try:
149+
toolkit = get_rich_toolkit()
150+
finally:
151+
signal.signal(signal.SIGTERM, orig_handler)
152+
153+
try:
154+
config = await client.query_single("""
155+
select cfg::SMTPProviderConfig {
156+
host,
157+
port,
158+
security
159+
} filter .name =
160+
assert_single(cfg::Config).current_email_provider_name;
161+
""")
162+
except gel.QueryError as ex:
163+
toolkit.print(
164+
f"Skipping SMTP server startup due to "
165+
f"error reading configuration: {ex}",
166+
tag="gel",
167+
)
168+
return None
169+
170+
if config is None:
171+
toolkit.print(
172+
"No SMTP configuration found, skipping SMTP server startup",
173+
tag="gel",
174+
)
175+
return None
176+
if config.security not in {"PlainText", "STARTTLSOrPlainText"}:
177+
toolkit.print(
178+
"SMTP server only supports security=PlainText or "
179+
"STARTTLSOrPlainText, skipping SMTP server startup",
180+
tag="gel",
181+
)
182+
return None
183+
184+
try:
185+
self._server = await asyncio.get_running_loop().create_server(
186+
lambda: SMTPServerProtocol(toolkit),
187+
host=config.host,
188+
port=config.port,
189+
)
190+
except Exception as ex:
191+
toolkit.print(
192+
f"Skipping SMTP server startup due to error: {ex}",
193+
tag="gel",
194+
)
195+
else:
196+
toolkit.print(
197+
f"Started SMTP server on {config.host}:{config.port} "
198+
f"for testing purposes.",
199+
tag="gel",
200+
)
201+
202+
async def stop(self) -> None:
203+
if hasattr(self, "_server"):
204+
self._server.close()
205+
await self._server.wait_closed()

gel/_internal/_integration/_fastapi/_client.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import importlib.util
2020
import inspect
2121
import logging
22+
import os
2223
import sys
2324

2425
import fastapi
@@ -34,6 +35,7 @@
3435
import types
3536
from collections.abc import Callable, Iterator, Sequence
3637
from ._auth import GelAuth
38+
from ._cli._smtpd import SMTPServer
3739

3840

3941
_logger = logging.getLogger("gel.fastapi")
@@ -154,6 +156,7 @@ class GelLifespan:
154156
_bio_client_creator: Callable[..., gel.Client]
155157
_bio_client: gel.Client
156158
_client_accessed: bool = False
159+
_smtp_server: Optional[SMTPServer] = None
157160

158161
_auth: ExtensionShell[GelAuth]
159162

@@ -189,6 +192,8 @@ async def __aenter__(self) -> dict[str, Any]:
189192
if ext is not None:
190193
await ext.on_startup(self._app)
191194

195+
if self._smtp_server is not None:
196+
await self._smtp_server.maybe_start(self._client)
192197
self.installed = True
193198
return {
194199
self.state_name.value: self._client,
@@ -205,6 +210,8 @@ async def __aexit__(
205210
exc_val: Optional[BaseException],
206211
exc_tb: Optional[types.TracebackType],
207212
) -> None:
213+
if self._smtp_server is not None:
214+
await self._smtp_server.stop()
208215
for shell in self._shells:
209216
if shell.extension is not None:
210217
await shell.extension.on_shutdown(self._app)
@@ -354,7 +361,12 @@ def without_auth(self) -> Self:
354361
return self
355362

356363

357-
def gelify(app: fastapi.FastAPI, **kwargs: Any) -> GelLifespan:
364+
def gelify(
365+
app: fastapi.FastAPI,
366+
*,
367+
disable_testing_smtp_server: bool = False,
368+
**kwargs: Any,
369+
) -> GelLifespan:
358370
rv = GelLifespan(
359371
app,
360372
client_creator=gel.create_async_client,
@@ -365,7 +377,27 @@ def gelify(app: fastapi.FastAPI, **kwargs: Any) -> GelLifespan:
365377
getattr(rv, key)(value)
366378
else:
367379
raise ValueError(f"Unknown configuration option: {key}")
368-
_cli.maybe_patch_fastapi_cli()
380+
381+
# Patch FastAPI CLI when started with `fastapi dev`. Without `--no-reload`,
382+
# `fastapi dev` will import the FastAPI application object twice: first
383+
# in the reloader main process where we can do the patching successfully,
384+
# then in the subprocess that actually runs the application, where there
385+
# is no CLI to patch so `patched` would be `False`.
386+
patched = _cli.maybe_patch_fastapi_cli()
387+
if not disable_testing_smtp_server:
388+
# The SMTP server requires the Gel client to load the SMTP
389+
# configuration (in `__aenter__`), so we have to ensure that the
390+
# subprocess that runs the FastAPI application knows that we
391+
# are in `fastapi dev` mode through the `_GEL_FASTAPI_CLI_PATCH`.
392+
if patched:
393+
os.environ["_GEL_FASTAPI_CLI_PATCH"] = "patched"
394+
if os.environ.get("_GEL_FASTAPI_CLI_PATCH") == "patched":
395+
# Note: this is an `if` branch instead of an `elif` because we
396+
# also need an SMTP server under `fastapi dev --no-reload`.
397+
from ._cli._smtpd import SMTPServer # noqa: PLC0415
398+
399+
rv._smtp_server = SMTPServer()
400+
369401
return rv
370402

371403

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ test = [
4343

4444
'pytest>=3.6.0',
4545
'uvloop>=0.15.1; platform_system != "Windows"',
46-
'fastapi',
46+
'fastapi[standard]',
4747
'pyjwt',
4848
'httpx',
4949
]

0 commit comments

Comments
 (0)