Skip to content
This repository was archived by the owner on Jun 13, 2023. It is now read-only.

Commit 5fb0295

Browse files
authored
feat(fastapi): support async endpoint handlers (#416)
1 parent 58a6168 commit 5fb0295

File tree

6 files changed

+358
-18
lines changed

6 files changed

+358
-18
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,8 @@ Advanced options can be configured as a parameter to the init() method or as env
482482
|- |EPSAGON_LAMBDA_TIMEOUT_THRESHOLD_MS |Integer|`200` |The threshold in millieseconds to send the trace before a Lambda timeout occurs |
483483
|- |EPSAGON_PAYLOADS_TO_IGNORE |List |- |Array of dictionaries to not instrument. Example: `'[{"source": "serverless-plugin-warmup"}]'` |
484484
|- |EPSAGON_REMOVE_EXCEPTION_FRAMES|Boolean|`False` |Disable the automatic capture of exception frames data (Python 3) |
485+
|- |EPSAGON_FASTAPI_ASYNC_MODE|Boolean|`False` |Enable capturing of Fast API async endpoint handlers calls(Python 3) |
486+
485487

486488

487489

epsagon/trace.py

+86-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from epsagon.trace_transports import NoneTransport, HTTPTransport, LogTransport
2323
from .constants import (
2424
TIMEOUT_GRACE_TIME_MS,
25+
EPSAGON_MARKER,
2526
MAX_LABEL_SIZE,
2627
DEFAULT_SAMPLE_RATE,
2728
TRACE_URL_PREFIX,
@@ -35,6 +36,8 @@
3536
DEFAULT_MAX_TRACE_SIZE_BYTES = 64 * (2 ** 10)
3637
MAX_METADATA_FIELD_SIZE_LIMIT = 1024 * 3
3738
FAILED_TO_SERIALIZE_MESSAGE = 'Failed to serialize returned object to JSON'
39+
# check if python version is 3.7 and above
40+
IS_PY_VERSION_ABOVE_3_6 = sys.version_info[0] == 3 and sys.version_info[1] > 6
3841

3942

4043
# pylint: disable=invalid-name
@@ -95,6 +98,7 @@ def __init__(self):
9598
self.keys_to_ignore = None
9699
self.keys_to_allow = None
97100
self.use_single_trace = True
101+
self.use_async_tracer = False
98102
self.singleton_trace = None
99103
self.local_thread_to_unique_id = {}
100104
self.transport = NoneTransport()
@@ -200,11 +204,25 @@ def update_tracers(self):
200204
tracer.step_dict_output_path = self.step_dict_output_path
201205
tracer.sample_rate = self.sample_rate
202206

207+
def switch_to_async_tracer(self):
208+
"""
209+
Set the use_async_tracer flag to True.
210+
:return: None
211+
"""
212+
self.use_async_tracer = True
213+
214+
def is_async_tracer(self):
215+
"""
216+
Returns whether using an async tracer
217+
"""
218+
return self.use_async_tracer
219+
203220
def switch_to_multiple_traces(self):
204221
"""
205222
Set the use_single_trace flag to False.
206223
:return: None
207224
"""
225+
self.use_async_tracer = False
208226
self.use_single_trace = False
209227

210228
def _create_new_trace(self, unique_id=None):
@@ -233,6 +251,58 @@ def _create_new_trace(self, unique_id=None):
233251
unique_id=unique_id,
234252
)
235253

254+
@staticmethod
255+
def _get_current_task():
256+
"""
257+
Gets the current asyncio task safely
258+
:return: The task.
259+
"""
260+
# Dynamic import since this is only valid in Python3+
261+
asyncio = __import__('asyncio')
262+
263+
#check if python version 3.7 and above
264+
if IS_PY_VERSION_ABOVE_3_6:
265+
get_event_loop = asyncio.get_event_loop
266+
get_current_task = asyncio.current_task
267+
else:
268+
get_event_loop = asyncio.events._get_running_loop # pylint: disable=W0212
269+
get_current_task = asyncio.events._get_running_loop # pylint: disable=W0212
270+
try:
271+
if not get_event_loop():
272+
return None
273+
return get_current_task()
274+
except Exception: # pylint: disable=broad-except
275+
return None
276+
277+
def _get_tracer_async_mode(self, should_create):
278+
"""
279+
Get trace assuming async tracer.
280+
:return: The trace.
281+
"""
282+
task = type(self)._get_current_task()
283+
if not task:
284+
return None
285+
286+
trace = getattr(task, EPSAGON_MARKER, None)
287+
if not trace and should_create:
288+
trace = self._create_new_trace()
289+
setattr(task, EPSAGON_MARKER, trace)
290+
return trace
291+
292+
def _pop_trace_async_mode(self):
293+
"""
294+
Pops the trace from the current task, assuming async tracer
295+
:return: The trace.
296+
"""
297+
task = type(self)._get_current_task()
298+
if not task:
299+
return None
300+
301+
trace = getattr(task, EPSAGON_MARKER, None)
302+
if trace: # can safely remove tracer from async task
303+
delattr(task, EPSAGON_MARKER)
304+
return trace
305+
236306
def get_or_create_trace(self, unique_id=None):
237307
"""
238308
Gets or create a trace - thread-safe
@@ -267,6 +337,9 @@ def _get_trace(self, unique_id=None, should_create=False):
267337
:return: The trace.
268338
"""
269339
with TraceFactory.LOCK:
340+
if self.use_async_tracer:
341+
return self._get_tracer_async_mode(should_create=should_create)
342+
270343
unique_id = self.get_thread_local_unique_id(unique_id)
271344
if unique_id:
272345
trace = (
@@ -321,6 +394,8 @@ def pop_trace(self, trace=None):
321394
:return: unique id
322395
"""
323396
with self.LOCK:
397+
if self.use_async_tracer:
398+
return self._pop_trace_async_mode()
324399
if self.traces:
325400
trace = self.traces.pop(self.get_trace_identifier(trace), None)
326401
if not self.traces:
@@ -338,6 +413,11 @@ def get_thread_local_unique_id(self, unique_id=None):
338413
:param unique_id: input unique id
339414
:return: active id if there's an active unique id or given one
340415
"""
416+
if self.is_async_tracer():
417+
return self.local_thread_to_unique_id.get(
418+
type(self)._get_current_task(), unique_id
419+
)
420+
341421
return self.local_thread_to_unique_id.get(
342422
get_thread_id(), unique_id
343423
)
@@ -353,7 +433,12 @@ def set_thread_local_unique_id(self, unique_id=None):
353433
self.singleton_trace.unique_id if self.singleton_trace else None
354434
)
355435
)
356-
self.local_thread_to_unique_id[get_thread_id()] = unique_id
436+
437+
if self.is_async_tracer():
438+
self.local_thread_to_unique_id[
439+
type(self)._get_current_task()] = unique_id
440+
else:
441+
self.local_thread_to_unique_id[get_thread_id()] = unique_id
357442
return unique_id
358443

359444
def unset_thread_local_unique_id(self):

epsagon/wrappers/fastapi.py

+98-7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import json.decoder
88
import asyncio
9+
import os
910

1011
import warnings
1112
from fastapi import Request, Response
@@ -29,9 +30,17 @@
2930
SCOPE_UNIQUE_ID = 'trace_unique_id'
3031
SCOPE_CONTAINER_METADATA_COLLECTED = 'container_metadata'
3132
SCOPE_IGNORE_REQUEST = 'ignore_request'
33+
IS_ASYNC_MODE = False
34+
35+
def _initialize_async_mode(mode):
36+
global IS_ASYNC_MODE # pylint: disable=global-statement
37+
IS_ASYNC_MODE = mode
38+
39+
_initialize_async_mode(os.getenv(
40+
'EPSAGON_FASTAPI_ASYNC_MODE', 'FALSE') == 'TRUE')
3241

3342
def _handle_wrapper_params(_args, kwargs, original_request_param_name):
34-
"""
43+
"""f
3544
Handles the sync/async given parameters - gets the request object
3645
If original handler is set to get the Request object, then getting the
3746
request using this param. Otherwise, trying to get the Request object using
@@ -222,6 +231,71 @@ def _fastapi_handler(
222231
raised_err
223232
)
224233

234+
async def _async_fastapi_handler(
235+
original_handler,
236+
request,
237+
status_code,
238+
args,
239+
kwargs
240+
):
241+
"""
242+
FastAPI generic handler - for callbacks executed by a threadpool
243+
:param original_handler: the wrapped original handler
244+
:param request: the given handler request
245+
:param status_code: the default configured response status code.
246+
Can be None when called by exception handlers wrapper, as there's
247+
no status code configuration for exception handlers.
248+
"""
249+
has_setup_succeeded = False
250+
should_ignore_request = False
251+
252+
try:
253+
epsagon_scope, trace = _setup_handler(request)
254+
if epsagon_scope and trace:
255+
has_setup_succeeded = True
256+
if (
257+
ignore_request('', request.url.path.lower())
258+
or
259+
is_ignored_endpoint(request.url.path.lower())
260+
):
261+
should_ignore_request = True
262+
epsagon_scope[SCOPE_IGNORE_REQUEST] = True
263+
264+
except Exception: # pylint: disable=broad-except
265+
has_setup_succeeded = False
266+
267+
if not has_setup_succeeded or should_ignore_request:
268+
return await original_handler(*args, **kwargs)
269+
270+
created_runner = False
271+
response = None
272+
if not trace.runner:
273+
if not _setup_trace_runner(epsagon_scope, trace, request):
274+
return await original_handler(*args, **kwargs)
275+
276+
raised_err = None
277+
try:
278+
response = await original_handler(*args, **kwargs)
279+
except Exception as exception: # pylint: disable=W0703
280+
raised_err = exception
281+
finally:
282+
try:
283+
epsagon.trace.trace_factory.unset_thread_local_unique_id()
284+
except Exception: # pylint: disable=broad-except
285+
pass
286+
# no need to update request body if runner already created before
287+
if created_runner:
288+
_extract_request_body(trace, request)
289+
290+
return _handle_response(
291+
epsagon_scope,
292+
response,
293+
status_code,
294+
trace,
295+
raised_err
296+
)
297+
298+
225299

226300
# pylint: disable=too-many-statements
227301
def _wrap_handler(dependant, status_code):
@@ -230,9 +304,6 @@ def _wrap_handler(dependant, status_code):
230304
"""
231305
original_handler = dependant.call
232306
is_async = asyncio.iscoroutinefunction(original_handler)
233-
if is_async:
234-
# async endpoints are not supported
235-
return
236307

237308
original_request_param_name = dependant.request_param_name
238309
if not original_request_param_name:
@@ -249,7 +320,23 @@ def wrapped_handler(*args, **kwargs):
249320
original_handler, request, status_code, args, kwargs
250321
)
251322

252-
dependant.call = wrapped_handler
323+
async def async_wrapped_handler(*args, **kwargs):
324+
"""
325+
Asynchronous wrapper handler
326+
"""
327+
request: Request = _handle_wrapper_params(
328+
args, kwargs, original_request_param_name
329+
)
330+
return await _async_fastapi_handler(
331+
original_handler, request, status_code, args, kwargs
332+
)
333+
334+
if is_async and IS_ASYNC_MODE:
335+
# async endpoints
336+
dependant.call = async_wrapped_handler
337+
338+
elif not is_async and not IS_ASYNC_MODE:
339+
dependant.call = wrapped_handler
253340

254341

255342
def route_class_wrapper(wrapped, instance, args, kwargs):
@@ -280,7 +367,7 @@ def exception_handler_wrapper(original_handler):
280367
Wraps an exception handler
281368
"""
282369
is_async = asyncio.iscoroutinefunction(original_handler)
283-
if is_async:
370+
if is_async or IS_ASYNC_MODE:
284371
# async handlers are not supported
285372
return original_handler
286373

@@ -323,12 +410,16 @@ async def server_call_wrapper(wrapped, _instance, args, kwargs):
323410

324411
trace = None
325412
try:
326-
epsagon.trace.trace_factory.switch_to_multiple_traces()
413+
if IS_ASYNC_MODE:
414+
epsagon.trace.trace_factory.switch_to_async_tracer()
415+
else:
416+
epsagon.trace.trace_factory.switch_to_multiple_traces()
327417
unique_id = str(uuid.uuid4())
328418
trace = epsagon.trace.trace_factory.get_or_create_trace(
329419
unique_id=unique_id
330420
)
331421
trace.prepare()
422+
332423
scope[EPSAGON_MARKER] = {
333424
SCOPE_UNIQUE_ID: unique_id,
334425
}

requirements-dev.txt

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pytest-asyncio; python_version >= '3.5'
2020
pytest-aiohttp; python_version >= '3.5'
2121
httpx; python_version >= '3.5'
2222
asynctest; python_version >= '3.5'
23+
pytest-lazy-fixture; python_version >= '3.5'
2324
moto; python_version >= '3.5'
2425
moto==2.1.0; python_version < '3.5'
2526
tornado

tests/conftest.py

+1
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,4 @@ def reset_tracer_mode():
8181
Resets trace factory tracer mode to a single trace.
8282
"""
8383
epsagon.trace_factory.use_single_trace = True
84+
epsagon.use_async_tracer = False

0 commit comments

Comments
 (0)