Skip to content

Commit 9f1e708

Browse files
committed
Update folder tree
1 parent 34f6078 commit 9f1e708

File tree

5 files changed

+92
-81
lines changed

5 files changed

+92
-81
lines changed
+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import config
2+
import typing
3+
import aiohttp
4+
5+
from urllib.parse import quote_plus, urlparse
6+
7+
DISCORD_ENDPOINT = "https://discord.com/api"
8+
SCOPES = ["identify"]
9+
10+
11+
async def exchange_code(
12+
*, code: str, scope: str, redirect_uri: str, grant_type: str = "authorization_code"
13+
) -> typing.Tuple[dict, int]:
14+
"""Exchange discord oauth code for access and refresh tokens."""
15+
async with aiohttp.ClientSession() as session:
16+
async with session.post(
17+
"%s/v6/oauth2/token" % DISCORD_ENDPOINT,
18+
data=dict(
19+
code=code,
20+
scope=scope,
21+
grant_type=grant_type,
22+
redirect_uri=redirect_uri,
23+
client_id=config.discord_client_id(),
24+
client_secret=config.discord_client_secret(),
25+
),
26+
headers={"Content-Type": "application/x-www-form-urlencoded"},
27+
) as response:
28+
return await response.json(), response.status
29+
30+
31+
async def get_user(access_token: str) -> dict:
32+
"""Coroutine to fetch User data from discord using the users `access_token`"""
33+
async with aiohttp.ClientSession() as session:
34+
async with session.get(
35+
"%s/v6/users/@me" % DISCORD_ENDPOINT,
36+
headers={"Authorization": "Bearer %s" % access_token},
37+
) as response:
38+
return await response.json()
39+
40+
41+
def format_scopes(scopes: typing.List[str]) -> str:
42+
"""Format a list of scopes."""
43+
return " ".join(scopes)
44+
45+
46+
def get_redirect(callback: str, scopes: typing.List[str]) -> str:
47+
"""Generates the correct oauth link depending on our provided arguments."""
48+
return (
49+
"{BASE}/oauth2/authorize?response_type=code"
50+
"&client_id={client_id}"
51+
"&scope={scopes}"
52+
"&redirect_uri={redirect_uri}"
53+
"&prompt=consent"
54+
).format(
55+
BASE=DISCORD_ENDPOINT,
56+
scopes=format_scopes(scopes),
57+
redirect_uri=quote_plus(callback),
58+
client_id=config.discord_client_id(),
59+
)
60+
61+
62+
def is_valid_url(string: str) -> bool:
63+
"""Returns boolean describing if the provided string is a url"""
64+
result = urlparse(string)
65+
return all((result.scheme, result.netloc))
+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from datetime import datetime
2+
from pydantic import BaseModel
3+
4+
5+
class CallbackResponse(BaseModel):
6+
token: str
7+
exp: datetime
8+
9+
10+
class CallbackBody(BaseModel):
11+
code: str
12+
callback: str

api/versions/v1/routers/auth.py renamed to api/versions/v1/routers/auth/routes.py

+9-75
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,26 @@
11
import jwt
22
import utils
3-
import typing
43
import config
5-
import aiohttp
64

75
from api.models import User, Token
86

9-
from pydantic import BaseModel
107
from fastapi.params import Param
118
from fastapi import APIRouter, Request
129
from datetime import datetime, timedelta
13-
from urllib.parse import quote_plus, urlparse
1410
from fastapi.responses import RedirectResponse
11+
from .models import CallbackBody, CallbackResponse
12+
from .helpers import (
13+
SCOPES,
14+
get_user,
15+
get_redirect,
16+
is_valid_url,
17+
exchange_code,
18+
format_scopes,
19+
)
1520

1621
router = APIRouter(prefix="/auth")
1722

1823

19-
DISCORD_ENDPOINT = "https://discord.com/api"
20-
SCOPES = ["identify"]
21-
22-
23-
class CallbackResponse(BaseModel):
24-
token: str
25-
exp: datetime
26-
27-
28-
class CallbackBody(BaseModel):
29-
code: str
30-
callback: str
31-
32-
33-
async def exchange_code(
34-
*, code: str, scope: str, redirect_uri: str, grant_type: str = "authorization_code"
35-
) -> typing.Tuple[dict, int]:
36-
"""Exchange discord oauth code for access and refresh tokens."""
37-
async with aiohttp.ClientSession() as session:
38-
async with session.post(
39-
"%s/v6/oauth2/token" % DISCORD_ENDPOINT,
40-
data=dict(
41-
code=code,
42-
scope=scope,
43-
grant_type=grant_type,
44-
redirect_uri=redirect_uri,
45-
client_id=config.discord_client_id(),
46-
client_secret=config.discord_client_secret(),
47-
),
48-
headers={"Content-Type": "application/x-www-form-urlencoded"},
49-
) as response:
50-
return await response.json(), response.status
51-
52-
53-
async def get_user(access_token: str) -> dict:
54-
"""Coroutine to fetch User data from discord using the users `access_token`"""
55-
async with aiohttp.ClientSession() as session:
56-
async with session.get(
57-
"%s/v6/users/@me" % DISCORD_ENDPOINT,
58-
headers={"Authorization": "Bearer %s" % access_token},
59-
) as response:
60-
return await response.json()
61-
62-
63-
def format_scopes(scopes: typing.List[str]) -> str:
64-
"""Format a list of scopes."""
65-
return " ".join(scopes)
66-
67-
68-
def get_redirect(callback: str, scopes: typing.List[str]) -> str:
69-
"""Generates the correct oauth link depending on our provided arguments."""
70-
return (
71-
"{BASE}/oauth2/authorize?response_type=code"
72-
"&client_id={client_id}"
73-
"&scope={scopes}"
74-
"&redirect_uri={redirect_uri}"
75-
"&prompt=consent"
76-
).format(
77-
BASE=DISCORD_ENDPOINT,
78-
scopes=format_scopes(scopes),
79-
redirect_uri=quote_plus(callback),
80-
client_id=config.discord_client_id(),
81-
)
82-
83-
84-
def is_valid_url(string: str) -> bool:
85-
"""Returns boolean describing if the provided string is a url"""
86-
result = urlparse(string)
87-
return all((result.scheme, result.netloc))
88-
89-
9024
@router.get(
9125
"/discord/redirect",
9226
tags=["auth"],

api/versions/v1/routers/router.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from fastapi import APIRouter
2-
from .auth import router as auth_router
2+
from .auth.routes import router as auth_router
33

44
router = APIRouter(prefix="/v1")
55

tests/test_auth.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
from httpx import AsyncClient
33
from pytest_mock import MockerFixture
4-
from api.versions.v1.routers.auth import get_redirect, SCOPES
4+
from api.versions.v1.routers.auth.helpers import get_redirect, SCOPES
55

66

77
@pytest.mark.asyncio
@@ -42,7 +42,7 @@ async def test_callback_discord_error(app: AsyncClient, mocker: MockerFixture):
4242
async def exchange_code(**kwargs):
4343
return {"error": "internal server error"}, 500
4444

45-
mocker.patch("api.versions.v1.routers.auth.exchange_code", new=exchange_code)
45+
mocker.patch("api.versions.v1.routers.auth.routes.exchange_code", new=exchange_code)
4646

4747
res = await app.post(
4848
"/v1/auth/discord/callback",
@@ -57,7 +57,7 @@ async def test_callback_invalid_code(app: AsyncClient, mocker: MockerFixture):
5757
async def exchange_code(**kwargs):
5858
return {"error": 'invalid "code" in request'}, 400
5959

60-
mocker.patch("api.versions.v1.routers.auth.exchange_code", new=exchange_code)
60+
mocker.patch("api.versions.v1.routers.auth.routes.exchange_code", new=exchange_code)
6161
res = await app.post(
6262
"/v1/auth/discord/callback",
6363
json={"code": "invalid", "callback": "https://twt.gg"},
@@ -88,8 +88,8 @@ async def get_user(**kwargs):
8888
"avatar": "135fa48ba8f26417c4b9818ae2e37aa0",
8989
}
9090

91-
mocker.patch("api.versions.v1.routers.auth.get_user", new=get_user)
92-
mocker.patch("api.versions.v1.routers.auth.exchange_code", new=exchange_code)
91+
mocker.patch("api.versions.v1.routers.auth.routes.get_user", new=get_user)
92+
mocker.patch("api.versions.v1.routers.auth.routes.exchange_code", new=exchange_code)
9393

9494
res = await app.post(
9595
"/v1/auth/discord/callback",

0 commit comments

Comments
 (0)