Skip to content

Commit f3e6bb2

Browse files
Merge pull request #328 from supertokens/feat/usercontext-request-helper
feat: Add a helper function to read the original request from the user context inside overrides
2 parents 98efe74 + 6e55f1e commit f3e6bb2

File tree

8 files changed

+170
-17
lines changed

8 files changed

+170
-17
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11+
## [0.14.1] - 2023-05-23
12+
13+
### Changes
14+
15+
- Added a new `get_request_from_user_context` function that can be used to read the original network request from the user context in overridden APIs and recipe functions
16+
1117
## [0.14.0] - 2023-05-18
1218
- Adds missing `check_database` boolean in `verify_session`
1319

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171
setup(
7272
name="supertokens_python",
73-
version="0.14.0",
73+
version="0.14.1",
7474
author="SuperTokens",
7575
license="Apache 2.0",
7676
author_email="[email protected]",

supertokens_python/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15+
from typing import Any, Callable, Dict, List, Optional, Union
16+
1517
from typing_extensions import Literal
16-
from typing import Callable, List, Union
18+
19+
from supertokens_python.framework.request import BaseRequest
1720

1821
from . import supertokens
1922
from .recipe_module import RecipeModule
@@ -39,3 +42,9 @@ def init(
3942

4043
def get_all_cors_headers() -> List[str]:
4144
return supertokens.Supertokens.get_instance().get_all_cors_headers()
45+
46+
47+
def get_request_from_user_context(
48+
user_context: Optional[Dict[str, Any]],
49+
) -> Optional[BaseRequest]:
50+
return Supertokens.get_instance().get_request_from_user_context(user_context)

supertokens_python/asyncio/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,18 @@
1111
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
14-
from typing import List, Union, Optional, Dict
14+
from typing import Dict, List, Optional, Union
1515

1616
from supertokens_python import Supertokens
1717
from supertokens_python.interfaces import (
1818
CreateUserIdMappingOkResult,
19+
DeleteUserIdMappingOkResult,
20+
GetUserIdMappingOkResult,
21+
UnknownMappingError,
1922
UnknownSupertokensUserIDError,
23+
UpdateOrDeleteUserIdMappingInfoOkResult,
2024
UserIdMappingAlreadyExistsError,
2125
UserIDTypes,
22-
UnknownMappingError,
23-
GetUserIdMappingOkResult,
24-
DeleteUserIdMappingOkResult,
25-
UpdateOrDeleteUserIdMappingInfoOkResult,
2626
)
2727
from supertokens_python.types import UsersResponse
2828

supertokens_python/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
SUPPORTED_CDI_VERSIONS = ["2.21"]
17-
VERSION = "0.14.0"
17+
VERSION = "0.14.1"
1818
TELEMETRY = "/telemetry"
1919
USER_COUNT = "/users/count"
2020
USER_DELETE = "/user/remove"

supertokens_python/supertokens.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,3 +552,18 @@ async def handle_supertokens_error(
552552
)
553553
return await recipe.handle_error(request, err, response)
554554
raise err
555+
556+
def get_request_from_user_context( # pylint: disable=no-self-use
557+
self,
558+
user_context: Optional[Dict[str, Any]] = None,
559+
) -> Optional[BaseRequest]:
560+
if user_context is None:
561+
return None
562+
563+
if "_default" not in user_context:
564+
return None
565+
566+
if not isinstance(user_context["_default"], dict):
567+
return None
568+
569+
return user_context.get("_default", {}).get("request")

supertokens_python/syncio/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,19 @@
1111
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
14-
from typing import List, Union, Optional, Dict
14+
from typing import Dict, List, Optional, Union
1515

1616
from supertokens_python import Supertokens
1717
from supertokens_python.async_to_sync_wrapper import sync
1818
from supertokens_python.interfaces import (
1919
CreateUserIdMappingOkResult,
20+
DeleteUserIdMappingOkResult,
21+
GetUserIdMappingOkResult,
22+
UnknownMappingError,
2023
UnknownSupertokensUserIDError,
24+
UpdateOrDeleteUserIdMappingInfoOkResult,
2125
UserIdMappingAlreadyExistsError,
2226
UserIDTypes,
23-
UnknownMappingError,
24-
GetUserIdMappingOkResult,
25-
DeleteUserIdMappingOkResult,
26-
UpdateOrDeleteUserIdMappingInfoOkResult,
2727
)
2828
from supertokens_python.types import UsersResponse
2929

tests/test_user_context.py

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,16 @@
1313
# under the License.
1414
from typing import Any, Dict, List, Optional
1515

16+
from fastapi import FastAPI
17+
from fastapi.testclient import TestClient
1618
from pytest import fixture, mark
17-
from supertokens_python import InputAppInfo, SupertokensConfig, init
19+
20+
from supertokens_python import (
21+
InputAppInfo,
22+
SupertokensConfig,
23+
get_request_from_user_context,
24+
init,
25+
)
1826
from supertokens_python.framework.fastapi import get_middleware
1927
from supertokens_python.recipe import emailpassword, session
2028
from supertokens_python.recipe.emailpassword.asyncio import sign_up
@@ -28,9 +36,6 @@
2836
RecipeInterface as SRecipeInterface,
2937
)
3038

31-
from fastapi import FastAPI
32-
from fastapi.testclient import TestClient
33-
3439
from .utils import clean_st, reset, setup_st, sign_in_request, start_st
3540

3641
works = False
@@ -277,3 +282,121 @@ async def create_new_session(
277282
create_new_session_context_works,
278283
]
279284
)
285+
286+
287+
@mark.asyncio
288+
async def test_get_request_from_user_context(driver_config_client: TestClient):
289+
signin_api_context_works, signin_context_works, create_new_session_context_works = (
290+
False,
291+
False,
292+
False,
293+
)
294+
295+
def apis_override_email_password(param: APIInterface):
296+
og_sign_in_post = param.sign_in_post
297+
298+
async def sign_in_post(
299+
form_fields: List[FormField],
300+
api_options: APIOptions,
301+
user_context: Dict[str, Any],
302+
):
303+
req = get_request_from_user_context(user_context)
304+
if req:
305+
assert req.method() == "POST"
306+
assert req.get_path() == "/auth/signin"
307+
nonlocal signin_api_context_works
308+
signin_api_context_works = True
309+
310+
return await og_sign_in_post(form_fields, api_options, user_context)
311+
312+
param.sign_in_post = sign_in_post
313+
return param
314+
315+
def functions_override_email_password(param: RecipeInterface):
316+
og_sign_in = param.sign_in
317+
318+
async def sign_in(email: str, password: str, user_context: Dict[str, Any]):
319+
req = get_request_from_user_context(user_context)
320+
if req:
321+
assert req.method() == "POST"
322+
assert req.get_path() == "/auth/signin"
323+
nonlocal signin_context_works
324+
signin_context_works = True
325+
326+
orginal_request = req
327+
user_context["_default"]["request"] = None
328+
329+
newReq = get_request_from_user_context(user_context)
330+
assert newReq is None
331+
332+
user_context["_default"]["request"] = orginal_request
333+
334+
return await og_sign_in(email, password, user_context)
335+
336+
param.sign_in = sign_in
337+
return param
338+
339+
def functions_override_session(param: SRecipeInterface):
340+
og_create_new_session = param.create_new_session
341+
342+
async def create_new_session(
343+
user_id: str,
344+
access_token_payload: Optional[Dict[str, Any]],
345+
session_data_in_database: Optional[Dict[str, Any]],
346+
disable_anti_csrf: Optional[bool],
347+
user_context: Dict[str, Any],
348+
):
349+
req = get_request_from_user_context(user_context)
350+
if req:
351+
assert req.method() == "POST"
352+
assert req.get_path() == "/auth/signin"
353+
nonlocal create_new_session_context_works
354+
create_new_session_context_works = True
355+
356+
response = await og_create_new_session(
357+
user_id,
358+
access_token_payload,
359+
session_data_in_database,
360+
disable_anti_csrf,
361+
user_context,
362+
)
363+
return response
364+
365+
param.create_new_session = create_new_session
366+
return param
367+
368+
init(
369+
supertokens_config=SupertokensConfig("http://localhost:3567"),
370+
app_info=InputAppInfo(
371+
app_name="SuperTokens Demo",
372+
api_domain="http://api.supertokens.io",
373+
website_domain="http://supertokens.io",
374+
),
375+
framework="fastapi",
376+
recipe_list=[
377+
emailpassword.init(
378+
override=emailpassword.InputOverrideConfig(
379+
apis=apis_override_email_password,
380+
functions=functions_override_email_password,
381+
)
382+
),
383+
session.init(
384+
override=session.InputOverrideConfig(
385+
functions=functions_override_session
386+
)
387+
),
388+
],
389+
)
390+
start_st()
391+
392+
await sign_up("[email protected]", "validpass123", {"manualCall": True})
393+
res = sign_in_request(driver_config_client, "[email protected]", "validpass123")
394+
395+
assert res.status_code == 200
396+
assert all(
397+
[
398+
signin_api_context_works,
399+
signin_context_works,
400+
create_new_session_context_works,
401+
]
402+
)

0 commit comments

Comments
 (0)