Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 60 additions & 2 deletions authentik/enterprise/providers/scim/api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,72 @@
from datetime import datetime

from django.urls import reverse
from django.utils.translation import gettext as _
from rest_framework.exceptions import ValidationError

from authentik.enterprise.license import LicenseKey
from authentik.providers.scim.models import SCIMAuthenticationMode
from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMProvider
from authentik.sources.oauth.models import UserOAuthSourceConnection


class SCIMProviderSerializerMixin:

def _get_token(self, instance: SCIMProvider) -> UserOAuthSourceConnection | None:
user = instance.auth_oauth_user
conn = UserOAuthSourceConnection.objects.filter(
user=user, source=instance.auth_oauth
).first()
return conn

def get_auth_oauth_token_last_updated(self, instance: SCIMProvider) -> datetime | None:
conn = self._get_token(instance)
return conn.last_updated if conn else None

def get_auth_oauth_token_expires(self, instance: SCIMProvider) -> datetime | None:
conn = self._get_token(instance)
return conn.expires if conn else None

def get_auth_oauth_url_callback(self, instance: SCIMProvider) -> str | None:
if (
instance.auth_mode
in [
SCIMAuthenticationMode.TOKEN,
SCIMAuthenticationMode.OAUTH_SILENT,
]
or not instance.backchannel_application
):
return None
relative_url = reverse(
"authentik_enterprise_providers_scim:callback",
kwargs={"application_slug": instance.backchannel_application.slug},
)
if "request" not in self.context:
return relative_url
return self.context["request"].build_absolute_uri(relative_url)

def get_auth_oauth_url_start(self, instance: SCIMProvider) -> str | None:
if (
instance.auth_mode
in [
SCIMAuthenticationMode.TOKEN,
SCIMAuthenticationMode.OAUTH_SILENT,
]
or not instance.backchannel_application
):
return None
relative_url = reverse(
"authentik_enterprise_providers_scim:start",
kwargs={"application_slug": instance.backchannel_application.slug},
)
if "request" not in self.context:
return relative_url
return self.context["request"].build_absolute_uri(relative_url)

def validate_auth_mode(self, auth_mode: SCIMAuthenticationMode) -> SCIMAuthenticationMode:
if auth_mode == SCIMAuthenticationMode.OAUTH:
if auth_mode in [
SCIMAuthenticationMode.OAUTH_SILENT,
SCIMAuthenticationMode.OAUTH_INTERACTIVE,
]:
if not LicenseKey.cached_summary().status.is_valid:
raise ValidationError(_("Enterprise is required to use the OAuth mode."))
return auth_mode
1 change: 1 addition & 0 deletions authentik/enterprise/providers/scim/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ class AuthentikEnterpriseProviderSCIMConfig(EnterpriseConfig):
label = "authentik_enterprise_providers_scim"
verbose_name = "authentik Enterprise.Providers.SCIM"
default = True
mountpoint = "application/scim/"
34 changes: 21 additions & 13 deletions authentik/enterprise/providers/scim/auth_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from django.utils.timezone import now
from requests import Request, RequestException
from structlog.stdlib import get_logger

from authentik.common.oauth.constants import GRANT_TYPE_PASSWORD, GRANT_TYPE_REFRESH_TOKEN
from authentik.providers.scim.clients.exceptions import SCIMRequestException
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.providers.scim.models import SCIMAuthenticationMode
from authentik.sources.oauth.clients.base import BaseOAuthClient
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection

if TYPE_CHECKING:
Expand All @@ -18,23 +20,26 @@ class SCIMOAuthException(SCIMRequestException):


class SCIMOAuthAuth:

def __init__(self, provider: SCIMProvider):
self.provider = provider
self.user = provider.auth_oauth_user
self.logger = get_logger().bind()
self.connection = self.get_connection()

def retrieve_token(self):
if not self.provider.auth_oauth:
return None
def retrieve_token(self, conn: UserOAuthSourceConnection | None) -> dict[str, Any]:
source: OAuthSource = self.provider.auth_oauth
client = OAuth2Client(source, None)
client: BaseOAuthClient = source.source_type.callback_view(request=None).get_client(source)
access_token_url = source.source_type.access_token_url or ""
if source.source_type.urls_customizable and source.access_token_url:
access_token_url = source.access_token_url
data = client.get_access_token_args(None, None)
data["grant_type"] = "password"
if self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_SILENT:
data["grant_type"] = GRANT_TYPE_PASSWORD
elif self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_INTERACTIVE:
data["grant_type"] = GRANT_TYPE_REFRESH_TOKEN
if not conn:
raise SCIMOAuthException(None, "Could not refresh SCIM OAuth token")
data["refresh_token"] = conn.refresh_token
data.update(self.provider.auth_oauth_params)
try:
response = client.do_request(
Expand All @@ -54,19 +59,22 @@ def retrieve_token(self):
raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc

def get_connection(self):
token = UserOAuthSourceConnection.objects.filter(
source=self.provider.auth_oauth, user=self.user, expires__gt=now()
if not self.provider.auth_oauth:
return None
conn = UserOAuthSourceConnection.objects.filter(
source=self.provider.auth_oauth, user=self.user
).first()
if token and token.access_token:
return token
token = self.retrieve_token()
if conn and conn.access_token and conn.expires > now():
return conn
token = self.retrieve_token(conn)
access_token = token["access_token"]
expires_in = int(token.get("expires_in", 0))
token, _ = UserOAuthSourceConnection.objects.update_or_create(
source=self.provider.auth_oauth,
user=self.user,
defaults={
"access_token": access_token,
"refresh_token": token.get("refresh_token"),
"expires": now() + timedelta(seconds=expires_in),
},
)
Expand Down
5 changes: 4 additions & 1 deletion authentik/enterprise/providers/scim/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ def scim_provider_post_save(sender: type[Model], instance: SCIMProvider, created
"""Create service account before provider is saved"""
identifier = f"ak-providers-scim-{instance.pk}"
with audit_ignore():
if instance.auth_mode == SCIMAuthenticationMode.OAUTH:
if instance.auth_mode in [
SCIMAuthenticationMode.OAUTH_SILENT,
SCIMAuthenticationMode.OAUTH_INTERACTIVE,
]:
user, user_created = User.objects.update_or_create(
username=identifier,
defaults={
Expand Down
46 changes: 42 additions & 4 deletions authentik/enterprise/providers/scim/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def setUp(self) -> None:
self.provider = SCIMProvider.objects.create(
name=generate_id(),
url="https://localhost",
auth_mode=SCIMAuthenticationMode.OAUTH,
auth_mode=SCIMAuthenticationMode.OAUTH_SILENT,
auth_oauth=self.source,
auth_oauth_params={
"foo": "bar",
Expand All @@ -61,7 +61,7 @@ def setUp(self) -> None:
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
)

def test_retrieve_token(self):
def test_retrieve_token_silent(self):
"""Test token retrieval"""
with Mocker() as mocker:
token = generate_id()
Expand All @@ -86,6 +86,44 @@ def test_retrieve_token(self):
)
self.assertEqual(mocker.request_history[0].body, "grant_type=password&foo=bar")

def test_retrieve_token_interactive(self):
"""Test token retrieval"""
self.provider.auth_mode = SCIMAuthenticationMode.OAUTH_INTERACTIVE
self.provider.save()
refresh_token = generate_id()
access_token = generate_id()
UserOAuthSourceConnection.objects.create(
user=self.provider.auth_oauth_user,
source=self.source,
refresh_token=refresh_token,
access_token=access_token,
)
with Mocker() as mocker:
token = generate_id()
mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
self.provider.scim_auth()
conn = UserOAuthSourceConnection.objects.filter(
source=self.source,
user=self.provider.auth_oauth_user,
).first()
self.assertIsNotNone(conn)
self.assertTrue(conn.is_valid)
auth = (
b64encode(
b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode()))
)
.strip()
.decode()
)
self.assertEqual(
mocker.request_history[0].headers["Authorization"],
f"Basic {auth}",
)
self.assertEqual(
mocker.request_history[0].body,
f"grant_type=refresh_token&refresh_token={refresh_token}&foo=bar",
)

def test_existing_token(self):
"""Test existing token"""
UserOAuthSourceConnection.objects.create(
Expand Down Expand Up @@ -166,7 +204,7 @@ def test_api_create(self):
{
"name": generate_id(),
"url": "http://localhost",
"auth_mode": "oauth",
"auth_mode": "oauth_silent",
"auth_oauth": str(self.source.pk),
},
)
Expand All @@ -183,7 +221,7 @@ def test_api_create_no_license(self):
{
"name": generate_id(),
"url": "http://localhost",
"auth_mode": "oauth",
"auth_mode": "oauth_silent",
"auth_oauth": str(self.source.pk),
},
)
Expand Down
10 changes: 10 additions & 0 deletions authentik/enterprise/providers/scim/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from django.urls import path

from authentik.enterprise.providers.scim.views import SCIMOAuthStart, SCIMRedirectCallback

urlpatterns = [
path("<slug:application_slug>/oauth2/start/", SCIMOAuthStart.as_view(), name="start"),
path(
"<slug:application_slug>/oauth2/callback/", SCIMRedirectCallback.as_view(), name="callback"
),
]
61 changes: 61 additions & 0 deletions authentik/enterprise/providers/scim/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from datetime import timedelta

from django.http import HttpRequest, HttpResponseBadRequest
from django.shortcuts import redirect
from django.urls import reverse
from django.utils.timezone import now

from authentik.core.models import Application
from authentik.providers.scim.models import SCIMProvider
from authentik.sources.oauth.clients.base import BaseOAuthClient
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.registry import RequestKind, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect


class SCIMOAuthViewMixin:
provider: SCIMProvider

def get_client(self, source: OAuthSource, **kwargs) -> BaseOAuthClient:
source: OAuthSource = self.provider.auth_oauth
source_cls = registry.find(source.provider_type, kind=RequestKind.CALLBACK)
if not source_cls.client_class:
return super().get_client(source, **kwargs)
return source_cls.client_class(source, self.request, **kwargs)

def _get_scim_provider(self, app_slug: str):
app = Application.objects.filter(slug=app_slug).first()
if not app:
return None
provider = SCIMProvider.objects.filter(backchannel_application=app)
return provider.first()

def dispatch(self, request: HttpRequest, application_slug: str):
provider = self._get_scim_provider(application_slug)
if not provider or not provider.auth_oauth:
return HttpResponseBadRequest()
self.provider = provider
return super().dispatch(request, source_slug=provider.auth_oauth.slug)


class SCIMOAuthStart(SCIMOAuthViewMixin, OAuthRedirect):

def get_callback_url(self, source: OAuthSource):
return reverse("authentik_enterprise_providers_scim:callback", kwargs=self.kwargs)


class SCIMRedirectCallback(SCIMOAuthViewMixin, OAuthCallback):

def redirect_flow_manager(self, client: BaseOAuthClient):
expires_in = int(self.token.get("expires_in", 0))
UserOAuthSourceConnection.objects.update_or_create(
source=self.provider.auth_oauth,
user=self.provider.auth_oauth_user,
defaults={
"access_token": self.token.get("access_token"),
"refresh_token": self.token.get("refresh_token"),
"expires": now() + timedelta(seconds=expires_in),
},
)
return redirect("authentik_core:if-admin")
10 changes: 10 additions & 0 deletions authentik/providers/scim/api/providers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""SCIM Provider API Views"""

from rest_framework.fields import SerializerMethodField
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.providers import ProviderSerializer
Expand All @@ -16,6 +17,11 @@ class SCIMProviderSerializer(
):
"""SCIMProvider Serializer"""

auth_oauth_token_last_updated = SerializerMethodField()
auth_oauth_token_expires = SerializerMethodField()
auth_oauth_url_callback = SerializerMethodField()
auth_oauth_url_start = SerializerMethodField()

class Meta:
model = SCIMProvider
fields = [
Expand All @@ -35,6 +41,10 @@ class Meta:
"auth_mode",
"auth_oauth",
"auth_oauth_params",
"auth_oauth_token_last_updated",
"auth_oauth_token_expires",
"auth_oauth_url_callback",
"auth_oauth_url_start",
"compatibility_mode",
"service_provider_config_cache_timeout",
"exclude_users_service_account",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Generated by Django 5.2.14 on 2026-05-05 22:11

from django.db import migrations, models
from django.apps.registry import Apps

from django.db.backends.base.schema import BaseDatabaseSchemaEditor


def update_oauth(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
db_alias = schema_editor.connection.alias

SCIMProvider = apps.get("authentik_providers_scim", "scimprovider")

SCIMProvider.objects.using(db_alias).filter(auth_mode="oauth").update(auth_mode="oauth_silent")


class Migration(migrations.Migration):

dependencies = [
("authentik_providers_scim", "0019_scimprovider_group_filters_and_more"),
]

operations = [
migrations.AlterField(
model_name="scimprovider",
name="auth_mode",
field=models.TextField(
choices=[
("token", "Token"),
("oauth_silent", "OAuth (Silent)"),
("oauth_interactive", "OAuth (interactive)"),
],
default="token",
),
),
]
Loading
Loading